Skip to content

Commit

Permalink
Merge pull request #30 from mtakahiro/minor_speeding_up
Browse files Browse the repository at this point in the history
Minor speeding up
  • Loading branch information
mtakahiro authored May 7, 2022
2 parents 4f55256 + a26db80 commit 604c838
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 171 deletions.
58 changes: 35 additions & 23 deletions gsf/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,12 @@ class Mainbody():
or the width to the next age bin.
'''

def __init__(self, inputs, c=3e18, Mpc_cm=3.08568025e+24, m0set=25.0, pixelscale=0.06, Lsun=3.839*1e33, cosmo=None, idman=None, zman=None):
def __init__(self, inputs, c:float=3e18, Mpc_cm:float=3.08568025e+24, m0set:float=25.0, pixelscale:float=0.06, Lsun:float=3.839*1e33, cosmo=None, idman=None, zman=None):
self.update_input(inputs, idman=idman, zman=zman)


def update_input(self, inputs, c=3e18, Mpc_cm=3.08568025e+24, m0set=25.0, pixelscale=0.06, Lsun=3.839*1e33, cosmo=None, \
idman=None, zman=None, sigz=5.0):
def update_input(self, inputs, c:float=3e18, Mpc_cm:float=3.08568025e+24, m0set:float=25.0, pixelscale:float=0.06, Lsun:float=3.839*1e33, cosmo=None, \
idman=None, zman=None, sigz:float=5.0):
'''
The purpose of this module is to register/update the parameter attributes in `Mainbody`
by visiting the configuration file.
Expand Down Expand Up @@ -196,6 +196,7 @@ def update_input(self, inputs, c=3e18, Mpc_cm=3.08568025e+24, m0set=25.0, pixels

# Nebular emission;
self.fneb = False
self.nlogU = 0
try:
if int(inputs['ADD_NEBULAE']) == 1:
self.fneb = True
Expand All @@ -204,13 +205,16 @@ def update_input(self, inputs, c=3e18, Mpc_cm=3.08568025e+24, m0set=25.0, pixels
self.logUMAX = float(inputs['logUMAX'])
self.DELlogU = float(inputs['DELlogU'])
self.logUs = np.arange(self.logUMIN, self.logUMAX, self.DELlogU)
self.nlogU = len(self.logUs)
except:
self.logUMIN = -2.5
self.logUMAX = -2.0
self.DELlogU = 0.5
self.logUs = np.arange(self.logUMIN, self.logUMAX, self.DELlogU)
self.nlogU = len(self.logUs)
try:
self.logUFIX = float(inputs['logUFIX'])
self.nlogU = 1
except:
self.logUFIX = None
except:
Expand Down Expand Up @@ -303,7 +307,6 @@ def update_input(self, inputs, c=3e18, Mpc_cm=3.08568025e+24, m0set=25.0, pixels
aamin.append(nn)
self.aamin = aamin

#self.npeak = np.arange(0,len(self.age),1)
self.npeak = len(self.age)
self.nage = np.arange(0,len(self.age),1)

Expand Down Expand Up @@ -342,6 +345,7 @@ def update_input(self, inputs, c=3e18, Mpc_cm=3.08568025e+24, m0set=25.0, pixels
print('Cannot find ZMC. Set to %d.'%(self.fzmc))

# Metallicity
self.has_ZFIX = False
try:
self.ZFIX = float(inputs['ZFIX'])
try:
Expand All @@ -353,16 +357,19 @@ def update_input(self, inputs, c=3e18, Mpc_cm=3.08568025e+24, m0set=25.0, pixels
self.Zall = np.arange(self.Zmin, self.Zmax, self.delZ)
print('\n##########################')
print('ZFIX is found.\nZ will be fixed to: %.2f'%(self.ZFIX))
self.has_ZFIX = True
except:
self.Zmax, self.Zmin = float(inputs['ZMAX']), float(inputs['ZMIN'])
self.delZ = float(inputs['DELZ'])
if self.Zmax == self.Zmin or self.delZ == 0:
self.delZ = 0.0
self.ZFIX = self.Zmin
self.Zall = np.asarray([self.ZFIX])
self.has_ZFIX = True
elif np.abs(self.Zmax - self.Zmin) <= self.delZ:
self.ZFIX = self.Zmin
self.Zall = np.asarray([self.ZFIX])
self.has_ZFIX = True
else:
self.Zall = np.arange(self.Zmin, self.Zmax, self.delZ)
# If BPASS;
Expand All @@ -387,6 +394,7 @@ def update_input(self, inputs, c=3e18, Mpc_cm=3.08568025e+24, m0set=25.0, pixels
self.Zall = np.arange(self.Zmin, self.Zmax, self.delZ) # in logZsun
print('\n##########################')
print('ZFIX is found.\nZ will be fixed to: %.2f'%(self.ZFIX))
self.has_ZFIX = True
except:
print('ZFIX is not found.')
print('Metallicities available in BPASS are limited and discrete. ZFIX is recommended.',self.Zall)
Expand All @@ -399,19 +407,22 @@ def update_input(self, inputs, c=3e18, Mpc_cm=3.08568025e+24, m0set=25.0, pixels


# N of param:
self.has_AVFIX = False
try:
Avfix = float(inputs['AVFIX'])
self.AVFIX = Avfix
self.nAV = 0
print('\n##########################')
print('AVFIX is found.\nAv will be fixed to:\n %.2f'%(Avfix))
self.has_AVFIX = True
except:
try:
self.Avmin = float(inputs['AVMIN'])
self.Avmax = float(inputs['AVMAX'])
if Avmin == Avmax:
self.nAV = 0
self.AVFIX = Avmin
self.has_AVFIX = True
else:
self.nAV = 1
except:
Expand Down Expand Up @@ -539,8 +550,10 @@ def update_input(self, inputs, c=3e18, Mpc_cm=3.08568025e+24, m0set=25.0, pixels
try:
# Length of each ssp templates.
self.tau0 = np.asarray([float(x.strip()) for x in inputs['TAU0'].split(',')])
self.ntau0 = len(self.tau0)
except:
self.tau0 = np.asarray([-1.0])
self.ntau0 = 1

# IMF
try:
Expand Down Expand Up @@ -591,7 +604,7 @@ def get_lines(self, LW0):
return LW, fLW


def read_data(self, Cz0, Cz1, zgal, add_fir=False, idman=None):
def read_data(self, Cz0:float, Cz1:float, zgal:float, add_fir:bool=False, idman=None):
'''
Parameters
----------
Expand Down Expand Up @@ -646,9 +659,7 @@ def read_data(self, Cz0, Cz1, zgal, add_fir=False, idman=None):
eybb = np.asarray([])
exbb = np.asarray([])

#con_bb = (eybb>0)
con_bb = ()

xx2 = xbb[con_bb]
ex2 = exbb[con_bb]
fy2 = fybb[con_bb]
Expand All @@ -657,11 +668,11 @@ def read_data(self, Cz0, Cz1, zgal, add_fir=False, idman=None):
xx01 = np.append(xx0,xx1)
fy01 = np.append(fy0,fy1)
ey01 = np.append(ey0,ey1)
xx = np.append(xx01,xx2)
fy = np.append(fy01,fy2)
ey = np.append(ey01,ey2)
xx = np.append(xx01,xx2)
fy = np.append(fy01,fy2)
ey = np.append(ey01,ey2)

wht = 1./np.square(ey)
wht = 1./np.square(ey)
con_wht = (ey<0)
wht[con_wht] = 0

Expand Down Expand Up @@ -702,20 +713,22 @@ def read_data(self, Cz0, Cz1, zgal, add_fir=False, idman=None):
b = nrd_yyd
nrd_yyd_sort = b[np.lexsort(([-1,1]*b[:,[1,0]]).T)]
NR = nrd_yyd_sort[:,0]
x = nrd_yyd_sort[:,1]
x = nrd_yyd_sort[:,1]
fy = nrd_yyd_sort[:,2]
ey = nrd_yyd_sort[:,3]
wht = nrd_yyd_sort[:,4]
wht2= nrd_yyd_sort[:,5]

sn = fy/ey
self.n_optir = len(sn)

dict = {}
dict = {'NR':NR, 'x':x, 'fy':fy, 'ey':ey, 'NRbb':NRbb, 'xbb':xx2, 'exbb':ex2, 'fybb':fy2, 'eybb':ey2, 'wht':wht, 'wht2': wht2, 'sn':sn}

return dict


def search_redshift(self, dict, xm_tmp, fm_tmp, zliml=0.01, zlimu=6.0, delzz=0.01, lines=False, prior=None, method='powell'):
def search_redshift(self, dict, xm_tmp, fm_tmp, zliml:float=0.01, zlimu:float=6.0, delzz:float=0.01, lines:bool=False, prior=None, method:str='powell'):
'''
This module explores the redshift space to find the best redshift and probability distribution.
Expand Down Expand Up @@ -775,9 +788,6 @@ def search_redshift(self, dict, xm_tmp, fm_tmp, zliml=0.01, zlimu=6.0, delzz=0.0
xobs = np.append(x01,x2)

wht = 1./np.square(eycon)
#if lines:
# wht2, ypoly = check_line_cz_man(fcon, xobs, wht, fm_s, z)
#else:
wht2 = wht

# Set parameters;
Expand All @@ -786,8 +796,9 @@ def search_redshift(self, dict, xm_tmp, fm_tmp, zliml=0.01, zlimu=6.0, delzz=0.0
fit_par_cz.add('C%d'%nn, value=1., min=0., max=1e5)

def residual_z(pars,z):
vals = pars.valuesdict()

'''
'''
vals = pars.valuesdict()
xm_s = xm_tmp * (1+z)
fm_s = np.zeros(len(xm_tmp),'float')

Expand All @@ -810,7 +821,7 @@ def residual_z(pars,z):
out_cz = minimize(residual_z, fit_par_cz, args=([zspace[zz]]), method=method)
keys = fit_report(out_cz).split('\n')

csq = out_cz.chisqr
csq = out_cz.chisqr
rcsq = out_cz.redchi
fitc_cz = [csq, rcsq]

Expand Down Expand Up @@ -1466,9 +1477,9 @@ def get_shuffle(self, out, nshuf=3.0, amp=1e-4):
return pos


def main(self, cornerplot=True, specplot=1, sigz=1.0, ezmin=0.01, ferr=0,
f_move=False, verbose=False, skip_fitz=False, out=None, f_plot_accept=True,
f_shuffle=True, amp_shuffle=1e-2, check_converge=True, Zini=None, f_plot_chain=True):
def main(self, cornerplot:bool=True, specplot=1, sigz=1.0, ezmin=0.01, ferr=0,
f_move:bool=False, verbose:bool=False, skip_fitz:bool=False, out=None, f_plot_accept:bool=True,
f_shuffle:bool=True, amp_shuffle=1e-2, check_converge:bool=True, Zini=None, f_plot_chain:bool=True):
'''
Main module of this script.
Expand Down Expand Up @@ -1580,6 +1591,7 @@ def main(self, cornerplot=True, specplot=1, sigz=1.0, ezmin=0.01, ferr=0,
fit_name = 'leastsq'
else:
fit_name = self.fneld

out = minimize(class_post.residual, self.fit_params, args=(self.dict['fy'], self.dict['ey'], self.dict['wht2'], self.f_dust), method=fit_name)
print('\nMinimizer refinement;')
print(fit_report(out))
Expand Down Expand Up @@ -1640,7 +1652,7 @@ def main(self, cornerplot=True, specplot=1, sigz=1.0, ezmin=0.01, ferr=0,
kwargs={'f_val':True, 'out':out, 'lnpreject':-np.inf},\
)
# Run MCMC
nburn = int(self.nmc / 10)
nburn = int(self.nmc/10)

print('Running burn-in')
sampler.run_mcmc(pos, nburn)
Expand Down
71 changes: 24 additions & 47 deletions gsf/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import scipy.interpolate as interpolate

c = 3.e18 # A/s
d = 10**(73.6/2.5) # From [ergs/s/cm2/A] to [ergs/s/cm2/Hz]
#d = 10**(73.6/2.5) # From [ergs/s/cm2/A] to [ergs/s/cm2/Hz]

################
# Line library
Expand Down Expand Up @@ -236,14 +236,12 @@ def get_leastsq(MB, ZZtmp, fneld, age, fit_params, residual, fy, ey, wht, ID0, c

fwz.write('# minimizer: %s\n' % fit_name)

try:
if MB.has_ZFIX:
ZZtmp = [MB.ZFIX]
except:
pass

for zz in range(len(ZZtmp)):
for zz in range(MB.nZ):
ZZ = ZZtmp[zz]
for aa in range(len(age)):
for aa in range(MB.npeak):
if MB.ZEVOL == 1 or aa == 0:
fit_params['Z'+str(aa)].value = ZZ

Expand All @@ -255,15 +253,15 @@ def get_leastsq(MB, ZZtmp, fneld, age, fit_params, residual, fy, ey, wht, ID0, c

fwz.write('%s %.2f %.5f'%(ID0, ZZ, fitc[1]))

AA_tmp = np.zeros(len(age), dtype='float')
ZZ_tmp = np.zeros(len(age), dtype='float')
for aa in range(len(age)):
AA_tmp = np.zeros(MB.npeak, dtype='float')
ZZ_tmp = np.zeros(MB.npeak, dtype='float')
for aa in range(MB.npeak):
AA_tmp[aa] = out_tmp.params['A'+str(aa)].value
fwz.write(' %.5f'%(AA_tmp[aa]))

Av_tmp = out_tmp.params['Av'].value
fwz.write(' %.5f'%(Av_tmp))
for aa in range(len(age)):
for aa in range(MB.npeak):
if MB.ZEVOL == 1 or aa == 0:
ZZ_tmp[aa] = out_tmp.params['Z'+str(aa)].value
fwz.write(' %.5f'%(ZZ_tmp[aa]))
Expand Down Expand Up @@ -539,7 +537,9 @@ def get_filt(LIBFILT, NFILT):
f = open(LIBFILT + '', 'r')


def get_fit(x,y,xer,yer, nsfh='Del.'):
def get_fit(x, y, xer, yer, nsfh:str = 'Del.'):
'''
'''
from lmfit import Model, Parameters, minimize, fit_report, Minimizer

fit_params = Parameters()
Expand All @@ -549,7 +549,7 @@ def get_fit(x,y,xer,yer, nsfh='Del.'):
fit_params.add('A', value=1, min=0, max=5000)

def residual_tmp(pars):
vals = pars.valuesdict()
vals = pars.valuesdict()
t0_tmp, tau_tmp, A_tmp = vals['t0'],vals['tau'],vals['A']

if nsfh == 'Del.':
Expand All @@ -559,27 +559,16 @@ def residual_tmp(pars):
elif nsfh == 'Cons.':
model = SFH_cons(t0_tmp, tau_tmp, A_tmp, tt=x)

#con = (model>minsfr)
con = (model>0)
#print(model[con])
#resid = np.abs(model - y)[con] / np.sqrt(yer[con])
#resid = np.square(model - y)[con] / np.square(yer[con])
#resid = np.square(np.log10(model[con]) - y[con]) / np.square(yer[con])
#resid = (np.log10(model[con]) - y[con]) / np.sqrt(yer[con])
resid = (np.log10(model[con]) - y[con]) / yer[con]
#print(yer[con])
#resid = (model - y)[con] / (yer[con])
# i.e. residual/sigma
return resid


out = minimize(residual_tmp, fit_params, method='powell')
#out = minimize(residual, fit_params, method='nelder')
print(fit_report(out))

t0 = out.params['t0'].value
tau = out.params['tau'].value
A = out.params['A'].value
t0 = out.params['t0'].value
tau = out.params['tau'].value
A = out.params['A'].value
param = [t0, tau, A]

keys = fit_report(out).split('\n')
Expand Down Expand Up @@ -739,7 +728,7 @@ def dust_kc(lm, fl, Av, nr, Rv=4.05, gamma=0, lmlimu=3.115, lmv=5000/10000, f_Al
return fl_cor, lmmc*10000., nrd


def dust_calz(lm, fl, Av, nr, Rv=4.05, lmlimu=3.115, f_Alam=False):
def dust_calz(lm, fl, Av:float, nr, Rv:float = 4.05, lmlimu:float = 3.115, f_Alam:bool = False):
'''
Parameters
----------
Expand All @@ -756,40 +745,28 @@ def dust_calz(lm, fl, Av, nr, Rv=4.05, lmlimu=3.115, f_Alam=False):
lmlimu : float
Upper limit. 2.2 in Calz+00
'''
Kl = np.zeros(len(lm), dtype='float')
nrd = np.zeros(len(lm), dtype='float')
lmmc = np.zeros(len(lm), dtype='float')
flc = np.zeros(len(lm), dtype='float')
Kl = lm[:]*0 #np.zeros(len(lm), dtype='float')
nrd = lm[:]*0 #np.zeros(len(lm), dtype='float')
lmmc = lm[:]*0 #np.zeros(len(lm), dtype='float')
flc = lm[:]*0 #np.zeros(len(lm), dtype='float')

lmm = lm/10000. # in micron
lmm = lm/10000. # in micron
con1 = (lmm<=0.63)
con2 = (lmm>0.63) & (lmm<=lmlimu)
con3 = (lmm>lmlimu)

Kl1 = (2.659 * (-2.156 + 1.509/lmm[con1] - 0.198/lmm[con1]**2 + 0.011/lmm[con1]**3) + Rv)
Kl2 = (2.659 * (-1.857 + 1.040/lmm[con2]) + Rv)
Kl3 = (2.659 * (-1.857 + 1.040/lmlimu + lmm[con3] * 0) + Rv)
Kl[con1] = Kl1
Kl[con2] = Kl2
Kl[con3] = Kl3
Kl[con1] = (2.659 * (-2.156 + 1.509/lmm[con1] - 0.198/lmm[con1]**2 + 0.011/lmm[con1]**3) + Rv)
Kl[con2] = (2.659 * (-1.857 + 1.040/lmm[con2]) + Rv)
Kl[con3] = (2.659 * (-1.857 + 1.040/lmlimu + lmm[con3] * 0) + Rv)

nr1 = nr[con1]
nr2 = nr[con2]
nr3 = nr[con3]
nrd[con1] = nr[con1]
nrd[con2] = nr[con2]
nrd[con3] = nr[con3]

lmm1 = lmm[con1]
lmm2 = lmm[con2]
lmm3 = lmm[con3]
lmmc[con1] = lmm[con1]
lmmc[con2] = lmm[con2]
lmmc[con3] = lmm[con3]

fl1 = fl[con1]
fl2 = fl[con2]
fl3 = fl[con3]
flc[con1] = fl[con1]
flc[con2] = fl[con2]
flc[con3] = fl[con3]
Expand Down
Loading

0 comments on commit 604c838

Please sign in to comment.