import numpy as np
import matplotlib.pyplot as plt
import astropy.units as u

from dust_extinction.grain_models import WD01

# Compare every WD01 grain-model variant on one figure: extinction
# normalized by A(V), plotted against wavelength on log-log axes.
fig, ax = plt.subplots()

# A default WD01 instance is created only to read ``possnames``; its keys
# are the model names that are passed back into the WD01 constructor below.
tmod = WD01()
possmodels = tmod.possnames.keys()

# generate the curves and plot them:
# 1000 wavelengths spaced logarithmically from 1e-2 to 1e4 micron, and the
# matching wavenumbers x = 1/lambda carrying astropy units of 1/micron,
# which is what each model is evaluated on.
lam = np.logspace(-2.0, 4.0, num=1000)
x = (1.0 / lam) / u.micron

for cmodel in possmodels:
    # build the model for this variant, then draw it against wavelength
    ext_model = WD01(cmodel)
    ax.plot(lam, ext_model(x), label=cmodel)

# axis labels and logarithmic scaling applied in a single call
ax.set(
    xlabel=r'$\lambda$ [$\mu m$]',
    ylabel=r'$A(x)/A(V)$',
    xscale='log',
    yscale='log',
)

ax.legend(loc='best')
plt.show()