import warnings
import matplotlib.pyplot as plt

from astropy.utils.exceptions import AstropyWarning
import astropy.units as u

from astropy.modeling.fitting import LevMarLSQFitter

from dust_extinction.averages import GCC09_MWAvg
from dust_extinction.shapes import P92

# instantiate the GCC09 Milky Way average extinction model
#   it carries the observed data points that are fit below
g09_model = GCC09_MWAvg()

# get the observed extinction curve to fit
#   x values get inverse-micron units attached; y and its uncertainty
#   are used as-is
x = g09_model.obsdata_x / u.micron
y = g09_model.obsdata_axav
y_unc = g09_model.obsdata_axav_unc

# initialize the model with its default parameter values
p92_init = P92()

# fix a number of the parameters
#   mainly to avoid fitting parameters that are constrained at
#   wavelengths where the observed data for this case does not exist
fixed_params = (
    "FUV_lambda",
    "SIL1_amp", "SIL1_lambda", "SIL1_b",
    "SIL2_amp", "SIL2_lambda", "SIL2_b",
    "FIR_amp", "FIR_lambda", "FIR_b",
)
for pname in fixed_params:
    getattr(p92_init, pname).fixed = True

# pick the fitter
fit = LevMarLSQFitter()

# fit the data to the P92 model using the fitter
#   use the initialized model as the starting point
#   each point is weighted by 1/uncertainty
#   x is passed as a bare array (units stripped with .value)
#
# ignore some warnings, only for the duration of the fit
#   UserWarning is to avoid the units of x warning
#   AstropyWarning ignored to avoid the "fit may have been unsuccessful" warning
#   fit is fine, but this means the build of the docs fails
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=UserWarning)
    warnings.simplefilter("ignore", category=AstropyWarning)
    p92_fit = fit(p92_init, x.value, y, weights=1.0 / y_unc)

# plot the observed data, initial guess, and final fit
fig, ax = plt.subplots()

ax.errorbar(x.value, y, yerr=y_unc, fmt='ko', label='Observed Curve')
ax.plot(x.value, p92_init(x), label='Initial guess')
ax.plot(x.value, p92_fit(x), label='Fitted model')

# raw string so the LaTeX '\mu' is not parsed as a Python escape
#   ('\m' is an invalid escape: SyntaxWarning on Python 3.12+)
ax.set_xlabel(r'$x$ [$\mu m^{-1}$]')
ax.set_ylabel('$A(x)/A(V)$')

ax.set_title('Example P92 Fit to GCC09_MWAvg average curve')

ax.legend(loc='best')
plt.tight_layout()
plt.show()