import numpy as np
import matplotlib.pyplot as plt
import astropy.units as u

from dust_extinction.shapes import P92

# Plot the full P92 curve together with curves in which selected component
# amplitudes are set to zero, so each component's contribution can be seen.
fig, ax = plt.subplots()

# Wavelength grid (log-spaced, 1e-3 to 1e3) and the matching inverse
# wavelengths in 1/micron that the model is evaluated on.
lam = np.logspace(-3.0, 3.0, num=1000)
x = (1.0 / lam) / u.micron

# The x-axis values are the same for every curve; compute them once.
wavelengths = 1.0 / x

# Component amplitudes that get switched off (set to 0.0) in the sub-curves.
# The BKG component is never zeroed.
switchable = ('FUV', 'NUV', 'SIL1', 'SIL2', 'FIR')

# (legend label, switchable components left at their default amplitude).
# Order matters: it fixes the legend order and the colour assigned to each
# curve by the default colour cycle.
curves = [
    ('total', switchable),
    ('BKG only', ()),
    ('BKG+FUV only', ('FUV',)),
    ('BKG+NUV only', ('NUV',)),
    ('BKG+FIR+SIL1 only', ('FIR', 'SIL1')),
    ('BKG+FIR+SIL2 only', ('FIR', 'SIL2')),
    ('BKG+FIR only', ('FIR',)),
]

for label, kept in curves:
    # Zero the amplitude of every switchable component not kept.
    zeroed = {name + '_amp': 0.0 for name in switchable if name not in kept}
    ext_model = P92(**zeroed)
    ax.plot(wavelengths, ext_model(x), label=label)

ax.set_xscale('log')
ax.set_yscale('log')

ax.set_ylim(1e-3, 10.)

# Raw strings: '\l' and '\m' are not valid Python escape sequences, so the
# non-raw form triggers an invalid-escape warning on recent Python versions.
ax.set_xlabel(r'$\lambda$ [$\mu$m]')
ax.set_ylabel(r'$A(x)/A(V)$')

ax.set_title('P92')

ax.legend(loc='best')
plt.tight_layout()
plt.show()