import numpy as np
import matplotlib.pyplot as plt
import astropy.units as u

from dust_extinction.shapes import G21

fig, ax = plt.subplots()

# generate the curves and plot them
# wavelengths in microns, log-spaced between 1.01 and 39.9
lam = np.logspace(np.log10(1.01), np.log10(39.9), num=1000)
# the G21 model is evaluated at x = 1/lambda, in units of 1/micron
x = (1.0 / lam) / u.micron

# (legend label, G21 keyword overrides) for each curve; setting a silicate
# amplitude to 0.0 drops that component from the plotted curve
components = [
    ('total', {}),
    ('power-law only', {'sil1_amp': 0.0, 'sil2_amp': 0.0}),
    ('power-law+sil1 only', {'sil2_amp': 0.0}),
    ('power-law+sil2 only', {'sil1_amp': 0.0}),
]
for label, params in components:
    ext_model = G21(**params)
    # plot against lam (microns) directly rather than recomputing 1/x
    ax.plot(lam, ext_model(x), label=label)

ax.set_xscale('log')
ax.set_yscale('log')

# raw strings: '\l' and '\m' are invalid Python escape sequences
ax.set_xlabel(r'$\lambda$ [$\mu$m]')
ax.set_ylabel(r'$A(x)/A(V)$')

ax.set_title('G21')

ax.legend(loc='best')
plt.tight_layout()
plt.show()