import numpy as np
import matplotlib.pyplot as plt
import astropy.units as u

from dust_extinction.grain_models import DBP90, WD01, D03, ZDA04, C11, J13

fig, ax = plt.subplots()

# generate the curves and plot them
# wavelengths in microns, log-spaced from 1e-4 to 1e5; x is the matching
# wavenumber in 1/micron, which is what the extinction models take as input
lam = np.logspace(-4.0, 5.0, num=1000)
x = (1.0 / lam) / u.micron

# (model class, parameter-set name) pairs to plot. Keeping each class next to
# its name avoids two parallel lists that must be kept in the same order.
model_specs = [
    (DBP90, "MWRV31"),
    (WD01, "MWRV31"),
    (WD01, "MWRV40"),
    (WD01, "MWRV55"),
    (D03, "MWRV31"),
    (D03, "MWRV40"),
    (D03, "MWRV55"),
    (ZDA04, "BARE-GR-S"),
    (C11, "MWRV31"),
    (J13, "MWRV31"),
]

for cmodel, cname in model_specs:
    ext_model = cmodel(cname)

    # evaluate only inside the model's x_range (presumably its supported
    # wavenumber interval -- confirm against dust_extinction docs)
    in_range = (x.value >= ext_model.x_range[0]) & (
        x.value <= ext_model.x_range[1]
    )
    yvals = ext_model(x[in_range])
    ax.plot(lam[in_range], yvals, label=f"{ext_model.__class__.__name__}  {cname}")

# Raw strings so the LaTeX backslashes (\lambda, \mu) are passed to
# matplotlib verbatim instead of being parsed as invalid Python escapes.
ax.set_xlabel(r'$\lambda$ [$\mu m$]')
ax.set_ylabel(r'$A(x)/A(V)$')

# single title call; a second set_title would silently replace the first
ax.set_title('Milky Way Grain Models')

# both the wavelength span (many decades) and A(x)/A(V) are shown on log axes
ax.set_xscale('log')
ax.set_yscale('log')

ax.legend(loc='best')
plt.tight_layout()
plt.show()