Skip to content

Commit

Permalink
Merge pull request #54 from mtakahiro/gsf-bpass
Browse files Browse the repository at this point in the history
bpass
  • Loading branch information
mtakahiro authored Jan 27, 2024
2 parents 049234c + d33e4d1 commit c0b6a90
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 21 deletions.
33 changes: 26 additions & 7 deletions gsf/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def __init__(self, inputs, c:float=3e18, Mpc_cm:float=3.08568025e+24, m0set:floa
'Fitting' : ['MC_SAMP', 'NMC', 'NWALK', 'NMCZ', 'NWALKZ',
'FNELD', 'NCPU', 'F_ERR', 'ZVIS', 'F_MDYN',
'NTEMP', 'DISP', 'SIG_TEMP', 'VDISP',
'FORCE_AGE', 'NORDER_SFH_PRIOR', ],
'FORCE_AGE', 'NORDER_SFH_PRIOR', 'NEBULAE_PRIOR'],

'Data' : ['ID', 'MAGZP', 'DIR_TEMP',
'CAT_BB', 'CAT_BB_DUST', 'SNLIM',
Expand Down Expand Up @@ -377,6 +377,15 @@ def update_input(self, inputs, c:float=3e18, Mpc_cm:float=3.08568025e+24, m0set:
if 'ADD_NEBULAE' in self.input_keys:
if str2bool(inputs['ADD_NEBULAE']):
self.fneb = True
try:
# Correlation between Aneb and LW age? May add some time; see posterior_flexible
if inputs['NEBULAE_PRIOR'] == '1':
self.neb_correlate = True
else:
self.neb_correlate = False
except:
self.neb_correlate = False

try:
self.logUMIN = float(inputs['logUMIN'])
except:
Expand Down Expand Up @@ -1789,23 +1798,27 @@ def prepare_class(self, add_fir=None):

def get_shuffle(self, out, nshuf=3.0, amp=1e-4):
'''
amp : amplitude, 0 to 1.
Shuffles initial parameters of each walker, to give it extra randomeness.
'''
if amp>1:
amp = 1
pos = np.zeros((self.nwalk, self.ndim), 'float')
for ii in range(pos.shape[0]):
aa = 0
for aatmp,key in enumerate(out.params.valuesdict()):
if out.params[key].vary == True:
pos[ii,aa] += out.params[key].value
# This is critical to avoid parameters fall on the boundary.
delpar = (out.params[key].max-out.params[key].min)/1000
delpar = (out.params[key].max-out.params[key].min) * amp/2.
# or not,
delpar = 0
# delpar = 0
if np.random.uniform(0,1) > (1. - 1./self.ndim):
pos[ii,aa] = np.random.uniform(out.params[key].min+delpar, out.params[key].max-delpar)
pos[ii,aa] = np.random.uniform(out.params[key].value-delpar, out.params[key].value+delpar)
else:
if pos[ii,aa]<out.params[key].min+delpar or pos[ii,aa]>out.params[key].max-delpar:
pos[ii,aa] = np.random.uniform(out.params[key].min+delpar, out.params[key].max-delpar)
if pos[ii,aa]<out.params[key].min or pos[ii,aa]>out.params[key].max:
pos[ii,aa] = np.random.uniform(out.params[key].value-delpar, out.params[key].value+delpar)

aa += 1
return pos
Expand Down Expand Up @@ -1956,10 +1969,16 @@ def main(self, cornerplot:bool=True, specplot=1, sigz=1.0, ezmin=0.01, ferr=0,
pos = self.get_shuffle(out, amp=amp_shuffle)
else:
pos = amp_shuffle * np.random.randn(self.nwalk, self.ndim)
pos += self.get_shuffle(out, amp=0)
# Check boundary;
aa = 0
for aatmp,key in enumerate(out.params.valuesdict()):
if out.params[key].vary:
pos[:,aa] += out.params[key].value
con = (out.params[key].min > pos[:,aa])
pos[:,aa][con] = out.params[key].min
con = (out.params[key].max < pos[:,aa])
pos[:,aa][con] = out.params[key].max
# pos[:,aa] = out.params[key].value
aa += 1

if self.f_zeus:
Expand Down
86 changes: 76 additions & 10 deletions gsf/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from datetime import datetime
from astropy import units as u
from astropy.cosmology import WMAP9
from dust_extinction.averages import G03_SMCBar

################
# Line library
Expand Down Expand Up @@ -739,7 +740,7 @@ def apply_dust(yy, xx, nr, Av, dust_model=0):
elif dust_model == 2: # LMC
yyd, xxd, nrd = dust_gen(xx, yy, Av, nr, Rv=4.05, gamma=-0.06, Eb=2.8)
elif dust_model == 3: # SMC
yyd, xxd, nrd = dust_smc(xx, yy, Av, nr, Rv=2.74, x0=4.703, gamma=1.212, f_Alam=False)
yyd, xxd, nrd = dust_smc(xx, yy, Av, nr, Rv=2.74, x0=4.6, gamma=1.00, f_Alam=False)
elif dust_model == 4: # Kriek&Conroy with gamma=-0.2
yyd, xxd, nrd = dust_kc(xx, yy, Av, nr, Rv=4.05, gamma=-0.2)
else:
Expand Down Expand Up @@ -770,7 +771,7 @@ def dust_smc(lm, fl, Av, nr, Rv=2.74, x0=4.6, gamma=1.0, f_Alam=False):
# if any(np.diff(lm)<0):
# print('Something is wrong in lm: dust_smc of function.py')

lmm = lm/10000. # into micron
lmm = lm/10000. # into micron
nrd = nr #np.concatenate([nr1,nr2,nr3])
lmmc = lmm #np.concatenate([lmm1,lmm2,lmm3])
flc = fl #np.concatenate([fl1,fl2,fl3])
Expand All @@ -779,31 +780,61 @@ def dust_smc(lm, fl, Av, nr, Rv=2.74, x0=4.6, gamma=1.0, f_Alam=False):
c1,c2,c3,c4 = -4.959, 2.264, 0.389, 0.461
# SMC Wing Sample;
# c1,c2,c3,c4 = -0.856, 1.038, 3.215, 0.107
# x0,gamma = 4.703,1.212

x = 1./lmmc

# Manual
Dx = x**2 / ((x**2-x0**2)**2 + x**2*gamma**2)
Fx = 0.5392 * (x - 5.9)**2 + 0.05644 * (x-5.9)**3
con_fx = (x<5.9)
Fx[con_fx] = 0

EBlam_to_EB = c1 + c2*x + c3*Dx + c4*Fx
Alam = Av / Rv * EBlam_to_EB
Alam_to_Av = 1 + EBlam_to_EB / Rv

# By following Gordon's script here, https://github.com/karllark/dust_extinction,
# Generate region redder than 2760A by interpolation
lam_red = 2760. # AA
ref_wavs = np.array([0.276, 0.296, 0.37, 0.44, 0.55,
0.65, 0.81, 1.25, 1.65, 2.198, 3.1])*10**4
ref_ext = np.array([2.220, 2.000, 1.672, 1.374, 1.00,
0.801, 0.567, 0.25, 0.169, 0.11, 0.])

if np.max(lm) > lam_red:
Alam_to_Av[lm > lam_red] = np.interp(lm[lm > lam_red], ref_wavs, ref_ext, right=0.)

# Dust attenuation package;
# ext_model = G03_SMCBar()
# Alam_to_Av = ext_model(x/u.micron)

Alam = Av * Alam_to_Av
fl_cor = flc[:] * 10**(-0.4*Alam[:])

if False:
if False:#True:#
import matplotlib.pyplot as plt
# define the extinction model
ext_model = G03_SMCBar()
# generate the curves and plot them

plt.close()
xs = np.arange(0.5, 10, 0.01)
x = xs
Dx = x**2 / ((x**2-x0**2)**2 + x**2*gamma**2)
x = np.arange(ext_model.x_range[0], ext_model.x_range[1],0.1)/u.micron
plt.plot(x,ext_model(x),label='G03 SMCBar')

print(c1,c2,c3,c4, x0, gamma)
lm = np.arange(0.1, 2.0, 0.01)
x = 1/lm
Dx = x**2 / ((x**2-x0**2)**2 + (x**2)*(gamma**2))
Fx = 0.5392 * (x - 5.9)**2 + 0.05644 * (x-5.9)**3
con_fx = (x<5.9)
Fx[con_fx] = 0
EBlam_to_EB = c1 + c2*x + c3*Dx + c4*Fx
Av = 2.0
Av = 2.5
Alam = Av / Rv * EBlam_to_EB
plt.scatter(x, EBlam_to_EB, )
Alam_to_Av = 1 + EBlam_to_EB / Rv
plt.scatter(x, Alam_to_Av, color='k', marker='+')
plt.xlim(0.1, 9.0)
plt.show()
# Why mismatch with Gordon?
hoge

if f_Alam:
Expand Down Expand Up @@ -1000,6 +1031,41 @@ def dust_calz(lm, fl, Av:float, nr, Rv:float = 4.05, lmlimu:float = 3.115, f_Ala
Alam = Kl * Av / Rv
fl_cor = flc[:] * 10**(-0.4*Alam[:])

if False:#True:#
import matplotlib.pyplot as plt
plt.close()
lmm = np.arange(0.1, 2.0, 0.01)
nr = np.arange(0,len(lmm),1)
Kl = lmm[:]*0 #np.zeros(len(lm), dtype='float')
nrd = lmm[:]*0 #np.zeros(len(lm), dtype='float')
lmmc = lmm[:]*0 #np.zeros(len(lm), dtype='float')
flc = lmm[:]*0 #np.zeros(len(lm), dtype='float')
con1 = (lmm<=0.63)
con2 = (lmm>0.63) & (lmm<=lmlimu)
con3 = (lmm>lmlimu)

Kl[con1] = (2.659 * (-2.156 + 1.509/lmm[con1] - 0.198/lmm[con1]**2 + 0.011/lmm[con1]**3) + Rv)
Kl[con2] = (2.659 * (-1.857 + 1.040/lmm[con2]) + Rv)
Kl[con3] = (2.659 * (-1.857 + 1.040/lmlimu + lmm[con3] * 0) + Rv)

nrd[con1] = nr[con1]
nrd[con2] = nr[con2]
nrd[con3] = nr[con3]

lmmc[con1] = lmm[con1]
lmmc[con2] = lmm[con2]
lmmc[con3] = lmm[con3]

Av = 2.5
Alam = Kl * Av / Rv
plt.scatter(lmm, 10**(-0.4*Alam[:]), )
plt.xlim(0.1, 2.0)
plt.xscale('log')
plt.yscale('log')
plt.show()
hoge


if f_Alam:
return fl_cor, lmmc*10000., nrd, Alam
else:
Expand Down
72 changes: 68 additions & 4 deletions gsf/posterior_flexible.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from numpy import exp as np_exp
from numpy import log as np_log
from scipy.special import erf
from scipy.stats import lognorm
from scipy.stats import lognorm,norm

class Post():
'''
Expand Down Expand Up @@ -141,6 +141,18 @@ def swap_pars_inv(self, pars):
return pars


def get_lw_age(self, vals):
'''
'''
tlw_tmp = 0
amp_tmp = 0
for nn in range(len(self.mb.age)):
key = 'A%d'%nn
tlw_tmp += 10**vals[key].value * self.mb.age[nn]
amp_tmp += 10**vals[key].value
tlw_tmp /= amp_tmp
return tlw_tmp, amp_tmp

def lnprob_emcee(self, pos, pars, fy:float, ey:float, wht:float, NR:float, f_fir:bool, f_chind:bool=True, SNlim:float=1.0, f_scale:bool=False,
lnpreject=-np.inf, f_like:bool=False, flat_prior:bool=False, gauss_prior:bool=True, f_val:bool=True, nsigma:float=1.0, out=None,
f_prior_sfh=False, alpha_sfh_prior=100, norder_sfh_prior=3, verbose=False, NRbb_lim=10000):
Expand Down Expand Up @@ -278,14 +290,34 @@ def lnprob_emcee(self, pos, pars, fy:float, ey:float, wht:float, NR:float, f_fir

# lognormal-prior for any params;
for ii,key_param in enumerate(self.mb.key_params_prior):
sigma = self.mb.key_params_prior_sigma[ii]
respr += self.get_lognormal_prior(vals, key_param, sigma=sigma, mu=0)
if key_param[:2] == 'AV':
sigma = self.mb.key_params_prior_sigma[ii]
respr += self.get_normal_prior(vals, key_param, sigma=sigma, mu=0)
else:
sigma = self.mb.key_params_prior_sigma[ii]
respr += self.get_lognormal_prior(vals, key_param, sigma=sigma, mu=0)

# Prior for emission line template??;
# Still in experiment;
if self.mb.neb_correlate:
respr += self.get_prior_neb(vals)

lnposterior = lnlike + respr

if not np.isfinite(lnposterior):
return lnpreject

return lnposterior


def get_prior_neb(self, vals, alpha=1.0):
'''
'''
tlw_tmp, amp_tmp = self.get_lw_age(vals)
# respr = np.log(tlw_tmp * 10**vals['Aneb']) #self.get_lognormal_prior(vals, key_param, sigma=sigma, mu=0)
Aneb_predict = np.log10(1/tlw_tmp) / (np.log10(self.mb.age.max()) - np.log10(self.mb.age.min()))
respr = -0.5 * ((Aneb_predict-(vals['Aneb']+np.log10(amp_tmp)))**2 * alpha)
return respr


def get_sfh_prior(self, vals, norder=3, alpha=100.0):
Expand Down Expand Up @@ -338,4 +370,36 @@ def get_lognormal_prior(self, vals, key_param, mu=0, sigma=100.0, check_prior=Fa
plt.plot(yy, np.log(self.mb.prior[key_param].pdf(yy)))
plt.show()

return np.log(self.mb.prior[key_param].pdf(y))
respr = np.log(self.mb.prior[key_param].pdf(y))

if not np.isfinite(respr):
respr = -1e10

return respr

def get_normal_prior(self, vals, key_param, mu=0, sigma=100.0, check_prior=False):
'''
'''
y = vals[key_param]

if self.mb.prior == None:
self.mb.prior = {}
for key_param_tmp in self.mb.fit_params:
self.mb.prior[key_param_tmp] = None

if self.mb.prior[key_param] == None:
self.mb.logger.info('Using normal prior for %s'%key_param)
self.mb.prior[key_param] = norm()
if check_prior:
import matplotlib.pyplot as plt
plt.close()
yy = np.arange(-2,2,0.1)
plt.plot(yy, np.log(self.mb.prior[key_param].pdf((yy-mu) * np.sqrt(2) / sigma)))
plt.show()

respr = np.log(self.mb.prior[key_param].pdf((y-mu) * np.sqrt(2) / sigma))

if not np.isfinite(respr):
respr = -1e10

return respr

0 comments on commit c0b6a90

Please sign in to comment.