Skip to content

Commit

Permalink
Merge pull request #8 from mtakahiro/version1.3
Browse files Browse the repository at this point in the history
Version 1.3.0
  • Loading branch information
Takahiro Morishita authored Jul 26, 2020
2 parents fd1aeba + cf3e93f commit 90126d2
Show file tree
Hide file tree
Showing 11 changed files with 267 additions and 322 deletions.
2 changes: 1 addition & 1 deletion gsf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__author__ = 'Takahiro Morishita'
__email__ = '[email protected]'
__version__ = '1.2'
__version__ = '1.3'
__credits__ = 'STScI'
121 changes: 40 additions & 81 deletions gsf/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self, inputs, c=3e18, Mpc_cm=3.08568025e+24, m0set=25.0, pixelscale
def update_input(self, inputs, c=3e18, Mpc_cm=3.08568025e+24, m0set=25.0, pixelscale=0.06, Lsun=3.839*1e33, cosmo=None):
'''
INPUT:
==========
======
parfile: Ascii file that lists parameters for everything.
Mpc_cm : cm/Mpc
Expand Down Expand Up @@ -292,11 +292,13 @@ def get_lines(self, LW0):

def read_data(self, Cz0, Cz1, zgal, add_fir=False):
'''
Input:
======
Cz0, Cz1 : Normalization coeffs for grism spectra.
zgal : Current redshift estimate.
Note:
=======
=====
Can be used for any SFH
'''
Expand Down Expand Up @@ -408,12 +410,12 @@ def read_data(self, Cz0, Cz1, zgal, add_fir=False):
def search_redshift(self, dict, xm_tmp, fm_tmp, zliml=0.01, zlimu=6.0, delzz=0.01, lines=False, prior=None, method='powell'):
'''
Purpose:
=========
========
Search redshift space to find the best redshift and probability distribution.
Input:
=========
======
fm_tmp : a library for various templates. Should be in [ n * len(wavelength)].
xm_tmp : a wavelength array, common for the templates above, at z=0. Should be in [len(wavelength)].
Expand All @@ -425,7 +427,7 @@ def search_redshift(self, dict, xm_tmp, fm_tmp, zliml=0.01, zlimu=6.0, delzz=0.0
method : powell is more accurate. nelder is faster.
Return:
=========
=======
zspace :
chi2s :
Expand Down Expand Up @@ -513,19 +515,19 @@ def residual_z(pars,z):
def fit_redshift(self, dict, xm_tmp, fm_tmp, delzz=0.01, ezmin=0.01, zliml=0.01, zlimu=6., snlim=0):
'''
Purpose:
==========
========
Find an optimal redshift, before going into a big fit, by using several templates.
Input:
==========
======
delzz : Delta z in redshift search space
zliml : Lower limit range for redshift
zlimu : Upper limit range for redshift
ezmin : Minimum redshift uncertainty.
snlim : SN limit for data points. Those below the number will be cut from the fit.
Note:
==========
=====
Can be used for any SFH.
'''
Expand All @@ -538,7 +540,6 @@ def fit_redshift(self, dict, xm_tmp, fm_tmp, delzz=0.01, ezmin=0.01, zliml=0.01,
sn = dict['fy']/dict['ey']
# Only spec data?
con_cz = (dict['NR']<10000) & (sn>snlim)
#con_cz = (dict['NR']<100000) & (sn>snlim)
fy_cz = dict['fy'][con_cz] # Already scaled by self.Cz0
ey_cz = dict['ey'][con_cz]
x_cz = dict['x'][con_cz] # Observed range
Expand Down Expand Up @@ -800,19 +801,24 @@ def add_param(self, fit_params, sigz=1.0):


#def main(self, zgal, flag_m, Cz0, Cz1, cornerplot=True, specplot=1, sigz=1.0, ezmin=0.01, ferr=0, f_move=False):
def main(self, flag_m, cornerplot=True, specplot=1, sigz=1.0, ezmin=0.01, ferr=0, f_move=False, verbose=False):
def main(self, cornerplot=True, specplot=1, sigz=1.0, ezmin=0.01, ferr=0, f_move=False, verbose=False, skip_fitz=False):
'''
Input:
========
flag_m : related to redshift error in redshift check func.
======
ferr : For error parameter
zgal : Input redshift.
#
#
# sigz (float): confidence interval for redshift fit.
# ezmin (float): minimum error in redshift.
#
skip_fitz (bool): Skip redshift fit.
sigz (float): confidence interval for redshift fit.
ezmin (float): minimum error in redshift.p
'''
import emcee
try:
import multiprocess
except:
import multiprocessing as multiprocess

from .posterior_flexible import Post

print('########################')
print('### Fitting Function ###')
Expand Down Expand Up @@ -886,7 +892,6 @@ def main(self, flag_m, cornerplot=True, specplot=1, sigz=1.0, ezmin=0.01, ferr=0
self.dict = dict

# Call likelihood/prior/posterior function;
from .posterior_flexible import Post
class_post = Post(self)

###############################
Expand Down Expand Up @@ -986,22 +991,20 @@ def main(self, flag_m, cornerplot=True, specplot=1, sigz=1.0, ezmin=0.01, ferr=0
AA_tmp = np.zeros(len(self.age), dtype='float64')
ZZ_tmp = np.zeros(len(self.age), dtype='float64')
fm_tmp, xm_tmp = fnc.tmp04_val(out, self.zgal, self.lib)
#print(self.lib[:,0])

########################
# Check redshift
########################
flag_z = self.fit_redshift(dict, xm_tmp, fm_tmp)
if skip_fitz:
flag_z = 'y'
else:
flag_z = self.fit_redshift(dict, xm_tmp, fm_tmp)

#################################################
# Gor for mcmc phase
#################################################
if flag_z == 'y' or flag_z == '':
#zrecom = self.zprev
#Czrec0 = self.Cz0
#Czrec1 = self.Cz1

# plot z-distribution
self.get_zdist()

#######################
Expand Down Expand Up @@ -1040,17 +1043,14 @@ def main(self, flag_m, cornerplot=True, specplot=1, sigz=1.0, ezmin=0.01, ferr=0

################################
print('\nMinimizer Defined\n')
import emcee
mini = Minimizer(class_post.lnprob, out.params, fcn_args=[dict['fy'],dict['ey'],dict['wht2'],self.f_dust], f_disp=self.f_disp, moves=emcee.moves.DEMove(sigma=1e-05, gamma0=None)) #, f_move=f_move)
mini = Minimizer(class_post.lnprob, out.params, fcn_args=[dict['fy'],dict['ey'],dict['wht2'],self.f_dust], f_disp=self.f_disp, \
moves=[(emcee.moves.DEMove(), 0.8), (emcee.moves.DESnookerMove(), 0.2),]
)
#moves=emcee.moves.DEMove(sigma=1e-05, gamma0=None))
print('######################')
print('### Starting emcee ###')
print('######################')

try:
import multiprocess
except:
import multiprocessing as multiprocess

ncpu0 = int(multiprocess.cpu_count()/2)
try:
ncpu = int(inputs['NCPU'])
Expand Down Expand Up @@ -1086,44 +1086,6 @@ def main(self, flag_m, cornerplot=True, specplot=1, sigz=1.0, ezmin=0.01, ferr=0
print('### Saving chain took %.1f sec'%(tcalc_mc))
print('#################################')

'''
Avmc = np.percentile(res.flatchain['Av'], [16,50,84])
Avpar = np.zeros((1,3), dtype='float64')
Avpar[0,:] = Avmc
####################
# Best parameters
####################
Amc = np.zeros((len(self.age),3), dtype='float64')
Ab = np.zeros(len(self.age), dtype='float64')
Zmc = np.zeros((len(self.age),3), dtype='float64')
Zb = np.zeros(len(self.age), dtype='float64')
NZbest = np.zeros(len(self.age), dtype='int')
f0 = fits.open(self.DIR_TMP + 'ms_' + self.ID + '_PA' + self.PA + '.fits')
sedpar = f0[1]
ms = np.zeros(len(self.age), dtype='float64')
for aa in range(len(self.age)):
Ab[aa] = res.params['A'+str(aa)].value
Amc[aa,:] = np.percentile(res.flatchain['A'+str(aa)], [16,50,84])
try:
Zb[aa] = res.params['Z'+str(aa)].value
Zmc[aa,:] = np.percentile(res.flatchain['Z'+str(aa)], [16,50,84])
except:
Zb[aa] = res.params['Z0'].value
Zmc[aa,:] = np.percentile(res.flatchain['Z0'], [16,50,84])
NZbest[aa]= bfnc.Z2NZ(Zb[aa])
ms[aa] = sedpar.data['ML_' + str(NZbest[aa])][aa]
Avb = res.params['Av'].value
if self.f_dust:
Mdust_mc = np.zeros(3, dtype='float64')
Tdust_mc = np.zeros(3, dtype='float64')
Mdust_mc[:] = np.percentile(res.flatchain['MDUST'], [16,50,84])
Tdust_mc[:] = np.percentile(res.flatchain['TDUST'], [16,50,84])
print(Mdust_mc)
print(Tdust_mc)
'''

####################
# MCMC corner plot.
Expand All @@ -1136,7 +1098,7 @@ def main(self, flag_m, cornerplot=True, specplot=1, sigz=1.0, ezmin=0.01, ferr=0
plot_density=False, levels=[0.68, 0.95, 0.997], truth_color='gray', color='#4682b4')
fig1.savefig('SPEC_' + self.ID + '_PA' + self.PA + '_corner.pdf')

# Do analysis on MCMC results.
# Analyze MCMC results.
# Write to file.
stop = timeit.default_timer()
tcalc = stop - start
Expand All @@ -1147,7 +1109,7 @@ def main(self, flag_m, cornerplot=True, specplot=1, sigz=1.0, ezmin=0.01, ferr=0
stop_mc = timeit.default_timer()
tcalc_mc = stop_mc - start_mc

return False #, self.zgal, self.Cz0, self.Cz1
return False


elif flag_z == 'm':
Expand Down Expand Up @@ -1176,7 +1138,7 @@ def main(self, flag_m, cornerplot=True, specplot=1, sigz=1.0, ezmin=0.01, ferr=0
print('\n\n')
print('Generate model templates with input redshift and Scale.')
print('\n\n')
return True #, self.zgal, self.Cz0, self.Cz1
return True

else:
print('\n\n')
Expand All @@ -1192,33 +1154,30 @@ def main(self, flag_m, cornerplot=True, specplot=1, sigz=1.0, ezmin=0.01, ferr=0
self.Cz0 = self.Czrec0
self.Cz1 = self.Czrec1

return True #, self.zgal, self.Cz0, self.Cz1
return True

else:
print('\n\n')
print('There is nothing to do.')
print('Terminating process.')
print('\n\n')
return -1 #, self.zgal, self.Czrec0, self.Czrec1
return -1



def quick_fit(self, zgal, flag_m, Cz0, Cz1, specplot=1, sigz=1.0, ezmin=0.01, ferr=0, f_move=False):
def quick_fit(self, zgal, Cz0, Cz1, specplot=1, sigz=1.0, ezmin=0.01, ferr=0, f_move=False):
'''
Purpose:
==========
Fit input data with a prepared template library, to get a chi-min result.
Input:
==========
flag_m : related to redshift error in redshift check func.
ferr : For error parameter
zgal : Input redshift.
#
#
# sigz (float): confidence interval for redshift fit.
# ezmin (float): minimum error in redshift.
#
sigz (float): confidence interval for redshift fit.
ezmin (float): minimum error in redshift.
'''
from .posterior_flexible import Post

Expand Down
41 changes: 38 additions & 3 deletions gsf/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,41 @@
LW0 = [2800, 3347, 3727, 3799, 3836, 3869, 4102, 4341, 4861, 4960, 5008, 5175, 6563, 6717, 6731]
fLW = np.zeros(len(LW0), dtype='int') # flag.


def printProgressBar (iteration, total, prefix = '', suffix = '', decimals = 1, length = 100, fill = None, printEnd = "\r", emojis = ['🥚','🐣','🐥','🦆']):
'''
Call in a loop to create terminal progress bar
@params:
iteration - Required : current iteration (Int)
total - Required : total iterations (Int)
prefix - Optional : prefix string (Str)
suffix - Optional : suffix string (Str)
decimals - Optional : positive number of decimals in percent complete (Int)
length - Optional : character length of bar (Int)
fill - Optional : bar fill character (Str)
printEnd - Optional : end character (e.g. "\r", "\r\n") (Str)
'''
percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total)))
if fill == None:
if float(percent) < 33:
fill = emojis[0]
elif float(percent) < 66:
fill = emojis[1]
elif float(percent) < 99:
fill = emojis[2]
else:
fill = emojis[3]


filledLength = int(length * iteration // total)
bar = fill * filledLength + '-' * (length - filledLength)
string = '(%d/%d)'%(iteration,total)
print(f'\r{prefix} |{bar}| {percent}% {suffix} {string}', end = printEnd)
# Print New Line on Complete
if iteration == total:
print()


def get_input():
'''
This returns somewhat a common default input dictionary.
Expand All @@ -34,13 +69,13 @@ def get_input():
def read_input(parfile):
'''
Purpose:
==========
========
#
# Get info from param file.
#
Return:
===========
=======
Input dictionary.
'''
Expand Down Expand Up @@ -818,7 +853,7 @@ def filconv_fast(filts, band, l0, f0, fw=False):
def filconv(band0, l0, f0, DIR, fw=False):
'''
Input:
============
======
f0: Flux for spectrum, in fnu
l0: Wavelength for spectrum, in AA (that matches filter response curve's.)
Expand Down
12 changes: 7 additions & 5 deletions gsf/gsf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
def run_gsf_template(inputs, fplt=0):
'''
Purpose:
==========
========
This is only for 0 and 1, to get templates.
Not for fitting, nor plotting.
Expand Down Expand Up @@ -130,7 +130,7 @@ def run_gsf_all(parfile, fplt, cornerplot=True):
#
MB.zprev = MB.zgal #zrecom # redshift from previous run

flag_suc = MB.main(0, cornerplot=cornerplot)
flag_suc = MB.main(cornerplot=cornerplot)

while (flag_suc and flag_suc!=-1):

Expand All @@ -142,7 +142,7 @@ def run_gsf_all(parfile, fplt, cornerplot=True):
print('Going into another trial with updated templates and redshift.')
print('\n\n')

flag_suc = MB.main(1, cornerplot=cornerplot)
flag_suc = MB.main(cornerplot=cornerplot)

# Total calculation time
stop = timeit.default_timer()
Expand All @@ -152,10 +152,12 @@ def run_gsf_all(parfile, fplt, cornerplot=True):
if fplt <= 3 and flag_suc != -1:
from .plot_sfh import plot_sfh
from .plot_sed import plot_sed
plot_sfh(MB, f_comp=MB.ftaucomp, fil_path=MB.DIR_FILT,

plot_sfh(MB, f_comp=MB.ftaucomp, fil_path=MB.DIR_FILT, mmax=100,
inputs=MB.inputs, dust_model=MB.dust_model, DIR_TMP=MB.DIR_TMP, f_SFMS=True)

plot_sed(MB, fil_path=MB.DIR_FILT,
figpdf=False, save_sed=True, inputs=MB.inputs, nmc_rand=1000,
figpdf=False, save_sed=True, inputs=MB.inputs, mmax=30,
dust_model=MB.dust_model, DIR_TMP=MB.DIR_TMP, f_label=True)


Expand Down
Loading

0 comments on commit 90126d2

Please sign in to comment.