import warnings
import matplotlib.pyplot as plt
import numpy as np

from astropy.modeling.fitting import LevMarLSQFitter
import astropy.units as u

from dust_extinction.averages import G03_LMCAvg
from dust_extinction.shapes import FM90

# get an observed extinction curve to fit
g03_model = G03_LMCAvg()

# observed wavenumbers, given units of 1/micron
x = g03_model.obsdata_x / u.micron
# convert the observed A(x)/A(V) to E(x-V)/E(B-V):
#   E(x-V)/E(B-V) = (A(x) - A(V))/E(B-V)
#                 = (A(x)/A(V) - 1) * A(V)/E(B-V)
#                 = (A(x)/A(V) - 1) * R(V)
y = (g03_model.obsdata_axav - 1.0) * g03_model.Rv
# only fit the UV portion (FM90 only valid in UV): x > 3.125 1/micron
(gindxs,) = np.where(x > 3.125 / u.micron)

# initialize the model
fm90_init = FM90()

# pick the fitter
fit = LevMarLSQFitter()

# fit the data to the FM90 model using the fitter
#   use the initialized model as the starting point
#   the fit is done on the plain (unitless) wavenumber values

# ignore some warnings
#   UserWarning is to avoid the units of x warning
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=UserWarning)
    g03_fit = fit(fm90_init, x[gindxs].value, y[gindxs])

# plot the observed data, initial guess, and final fit
fig, ax = plt.subplots()

ax.plot(x, y, "ko", label="Observed Curve")
ax.plot(x[gindxs], fm90_init(x[gindxs]), label="Initial guess")
ax.plot(x[gindxs], g03_fit(x[gindxs]), label="Fitted model")

# raw strings for mathtext: "\m" in a normal string literal is an invalid
#   escape sequence (SyntaxWarning on Python 3.12+)
ax.set_xlabel(r"$x$ [$\mu m^{-1}$]")
ax.set_ylabel(r"$E(x-V)/E(B-V)$")

ax.set_title("Example FM90 Fit to G03_LMCAvg curve")

ax.legend(loc="best")
plt.tight_layout()
plt.show()