import numpy as np
import matplotlib.pyplot as plt
import astropy.units as u

from dust_extinction.averages import GCC09_MWAvg

fig, ax = plt.subplots(ncols=2)

# Extinction model: the GCC09 Milky Way average curve.  dust_extinction
# models only accept wavenumbers inside ``ext_model.x_range`` (1/micron) and
# raise a ValueError otherwise, so build the grid from that range instead of
# hard-coded limits.
ext_model = GCC09_MWAvg()
x = np.arange(ext_model.x_range[0], ext_model.x_range[1], 0.1) / u.micron

# Evaluate the normalized curve A(lambda)/A(V) once; each total-extinction
# curve is that curve scaled by its A(V).
norm_ext = ext_model(x)
# Plain float wavelengths in microns (avoids relying on quantity_support).
wave = (1.0 / x).to(u.micron).value

av_values = (0.5, 1.5, 2.5)
linestyles = ("-", "--", ":")
for av, ls in zip(av_values, linestyles):
    label = f"A(V) = {av} mag"
    total_ext = norm_ext * av
    ax[0].plot(wave, total_ext, linestyle=ls, label=label)
    # Dividing back by A(V) recovers the same normalized curve for every
    # A(V); distinct line styles keep the coinciding curves distinguishable.
    ax[1].plot(wave, total_ext / av, linestyle=ls, label=label)

# Shared axis decoration; raw strings keep the LaTeX backslashes literal
# (non-raw "\l"/"\m" are invalid escape sequences).
panel_titles = ("Total Extinction", "Normalized Extinction")
panel_ylabels = (r"$A(\lambda)$", r"$A(\lambda)/A(V)$")
for panel, title, ylabel in zip(ax, panel_titles, panel_ylabels):
    panel.set_title(title)
    panel.set_xlabel(r"$\lambda$ [$\mu m$]")
    panel.set_ylabel(ylabel)
    panel.set_xscale("log")
    # Limit the axis to the wavelengths the model actually covers.
    panel.set_xlim(wave.min(), wave.max())
    panel.legend(loc="best")

plt.tight_layout()
plt.show()