import warnings

import numpy as np
import matplotlib.pyplot as plt

import astropy.units as u
from astropy.modeling.fitting import LevMarLSQFitter
from astropy.utils.exceptions import AstropyWarning

from dust_extinction.averages import GCC09_MWAvg
from dust_extinction.shapes import P92
from dust_extinction.warnings import SpectralUnitsWarning

# Fit the P92 extinction-curve shape to the observed GCC09 Milky Way
# average extinction curve and plot the data, initial guess and best fit.

# the GCC09 Milky Way average model carries the observed data we fit
g09_model = GCC09_MWAvg()

# observed extinction curve: x in inverse microns, A(x)/A(V) and its uncertainty
x = g09_model.obsdata_x / u.micron
y = g09_model.obsdata_axav
y_unc = g09_model.obsdata_axav_unc

# initialize the model
p92_init = P92()

# fix a number of the parameters
#   mainly to avoid fitting parameters that are constrained at
#   wavelengths where the observed data for this case does not exist
p92_init.FUV_lambda.fixed = True
p92_init.SIL1_amp.fixed = True
p92_init.SIL1_lambda.fixed = True
p92_init.SIL1_b.fixed = True
p92_init.SIL2_amp.fixed = True
p92_init.SIL2_lambda.fixed = True
p92_init.SIL2_b.fixed = True
p92_init.FIR_amp.fixed = True
p92_init.FIR_lambda.fixed = True
p92_init.FIR_b.fixed = True

# pick the fitter
fit = LevMarLSQFitter()

# fit the data to the P92 model using the fitter
#   use the initialized model as the starting point
#   weight each point by the inverse of its uncertainty
#   x is passed as plain values (x.value) rather than as a Quantity
#
# ignore some warnings
#   SpectralUnitsWarning: raised because x is passed to the model without units
#   AstropyWarning: the "fit may have been unsuccessful" warning; the fit is
#     fine, but this warning would otherwise make the docs build fail
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=SpectralUnitsWarning)
    warnings.simplefilter("ignore", category=AstropyWarning)
    p92_fit = fit(p92_init, x.value, y, weights=1.0 / y_unc)

# plot the observed data, initial guess, and final fit
fig, ax = plt.subplots()

ax.errorbar(x.value, y, yerr=y_unc, fmt='ko', label='Observed Curve')
ax.plot(x.value, p92_init(x), label='Initial guess')
ax.plot(x.value, p92_fit(x), label='Fitted model')

# raw strings so the LaTeX backslashes are not treated as (invalid) escapes
ax.set_xlabel(r'$x$ [$\mu m^{-1}$]')
ax.set_ylabel(r'$A(x)/A(V)$')

# 2nd (top) x-axis labelled in wavelength [micron]; the tick positions are
# 1/lambda so they line up with the bottom axis in inverse microns
axis_xs = np.array([0.1, 0.12, 0.15, 0.2, 0.3, 0.5, 1.0])
new_ticks = 1 / axis_xs
new_ticks_labels = [f"{z:.2f}" for z in axis_xs]
tax = ax.twiny()
# match the limits of the bottom axis before placing the ticks
tax.set_xlim(ax.get_xlim())
tax.set_xticks(new_ticks)
tax.set_xticklabels(new_ticks_labels)
tax.set_xlabel(r"$\lambda$ [$\mu$m]")

ax.legend(loc='best')
plt.tight_layout()
plt.show()