import numpy as np
import matplotlib.pyplot as plt
import astropy.units as u

from dust_extinction.averages import GCC09_MWAvg

fig, ax = plt.subplots()

# define the extinction model
ext_model = GCC09_MWAvg()

# sample the model over its own valid range; x is inverse wavelength in
# 1/micron (np.arange excludes the upper endpoint)
x = np.arange(ext_model.x_range[0], ext_model.x_range[1], 0.1) / u.micron

# plot the average curve together with the observed data it is compared to
ax.plot(x, ext_model(x), label='GCC09_MWAvg')
ax.errorbar(ext_model.obsdata_x_fuse, ext_model.obsdata_axav_fuse,
            yerr=ext_model.obsdata_axav_unc_fuse,
            fmt='ko', label='obsdata (FUSE)')
ax.errorbar(ext_model.obsdata_x_iue, ext_model.obsdata_axav_iue,
            yerr=ext_model.obsdata_axav_unc_iue,
            fmt='bs', label='obsdata (IUE)')
ax.errorbar(ext_model.obsdata_x_bands, ext_model.obsdata_axav_bands,
            yerr=ext_model.obsdata_axav_unc_bands,
            fmt='g^', label='obsdata (Opt/NIR)')

ax.set_xlabel(r'$x$ [$\mu m^{-1}$]')
ax.set_ylabel(r'$A(x)/A(V)$')

# add a 2nd x-axis on top labeled in wavelength; ticks are placed at
# x = 1/lambda so they line up with the bottom (inverse-micron) axis
axis_xs = np.array([0.09, 0.1, 0.12, 0.15, 0.2, 0.3, 0.5, 1.0])
new_ticks = 1 / axis_xs
new_ticks_labels = [f"{z:.2f}" for z in axis_xs]
tax = ax.twiny()
tax.set_xticks(new_ticks)
tax.set_xticklabels(new_ticks_labels)
# copy the limits only after setting the ticks: set_xticks may expand the
# view limits to show every tick, which would misalign the two axes
tax.set_xlim(ax.get_xlim())
tax.set_xlabel(r"$\lambda$ [$\mu$m]")

ax.legend(loc='best')
plt.show()