import numpy as np
import matplotlib.pyplot as plt
import astropy.units as u

from dust_extinction.grain_models import WD01

# Plot the WD01 grain-model extinction curves for the LMC and SMC
# parameter sets as A(x)/A(V) versus wavelength on log-log axes.
fig, ax = plt.subplots()

# Wavelength grid (log-spaced) and the matching inverse-wavelength grid
# x = 1/lambda, which is what the models are evaluated on and what their
# x_range limits are compared against below.
lam = np.logspace(-4.0, 4.0, num=1000)
x = (1.0 / lam) / u.micron

# Each entry pairs a model class with the name of the parameter set to load.
models = [WD01, WD01, WD01]
modelnames = ["LMCAvg", "LMC2", "SMCBar"]

for cmodel, cname in zip(models, modelnames):
    ext_model = cmodel(cname)

    # Only evaluate the model on the part of the grid inside its x_range.
    x_min, x_max = ext_model.x_range[0], ext_model.x_range[1]
    in_range = (x.value >= x_min) & (x.value <= x_max)

    yvals = ext_model(x[in_range])
    ax.plot(
        lam[in_range],
        yvals,
        label=f"{ext_model.__class__.__name__}  {cname}",
    )

# Raw strings keep the LaTeX backslashes literal; in a normal string
# "\l" and "\m" are invalid escape sequences and trigger a warning.
ax.set_xlabel(r'$\lambda$ [$\mu m$]')
ax.set_ylabel('$A(x)/A(V)$')

ax.set_xscale('log')
ax.set_yscale('log')

# Single title call: an earlier 'Grain Models' title was always overwritten
# by this one, so only the effective title is set.
ax.set_title('LMC & SMC')

ax.legend(loc='best')
plt.tight_layout()
plt.show()