import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter
import numpy as np

import astropy.units as u
from astropy.modeling.models import BlackBody

from dust_extinction.parameter_averages import G23
from dust_extinction.conv_functions import unred

# Wavelength grid: 1000 points evenly spaced in log10 between 0.092 and
# 31 microns, kept inside the range where the G23 R(V)-dependent
# extinction curve is defined.
lam = np.logspace(np.log10(0.092), np.log10(31.0), num=1000)

# The spectra are evaluated on an Angstrom grid (1 micron = 1e4 Angstrom).
wavelengths = lam * 1e4 * u.AA
temperature = 10000 * u.K

# Intrinsic spectrum: a 10,000 K blackbody, with its output expressed in
# erg / (cm^2 Angstrom s sr) via the unit carried by `scale`.
bb_lam = BlackBody(
    temperature=temperature,
    scale=1.0 * u.erg / (u.cm ** 2 * u.AA * u.s * u.sr),
)
flux = bb_lam(wavelengths)

# G23 average extinction curve at R(V) = 3.1.
ext = G23(Rv=3.1)

# Redden: multiply the intrinsic flux by the model's extinction factor
# for E(B-V) = 1.0.
flux_ext_ebv10 = flux * ext.extinguish(wavelengths, Ebv=1.0)

# Deredden: a *positive* E(B-V) passed to unred removes the extinction.
# NOTE(review): no extinction model is passed to unred here; confirm its
# default matches G23(Rv=3.1), otherwise flux_unred will not reproduce
# the intrinsic blackbody exactly.
flux_unred = unred(wavelengths, flux_ext_ebv10, 1.0)

# Plot the reddened and dereddened spectra, with the intrinsic blackbody
# overlaid so it is easy to check that dereddening recovers it.
fig, ax = plt.subplots()

ax.plot(wavelengths, flux_ext_ebv10, label='reddened spectrum')
ax.plot(wavelengths, flux_unred, label='unreddened spectrum')
# Drawn last as a black dotted line so it stays visible on top of the
# unreddened curve wherever the two coincide.
ax.plot(wavelengths, flux, 'k:', label='intrinsic spectrum')

# Raw string: '\l' and '\A' are invalid escape sequences in a normal
# string literal (SyntaxWarning since Python 3.12).
ax.set_xlabel(r'$\lambda$ [$\AA$]')
# Take the y-axis units from the flux array itself instead of rendering
# the bare word "Flux" in math italics.
ax.set_ylabel(f'Flux [{flux.unit.to_string("latex_inline")}]')

# Both axes are logarithmic: the spectra span orders of magnitude in
# wavelength and flux. ScalarFormatter shows plain numbers on the x axis
# instead of powers of ten.
ax.set_xscale('log')
ax.xaxis.set_major_formatter(ScalarFormatter())
ax.set_yscale('log')

ax.set_title('Example unreddening a blackbody')

ax.legend(loc='best')
plt.tight_layout()
plt.show()