Skip to content

Commit

Permalink
memory optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
mtakahiro committed Jan 31, 2023
1 parent a133a66 commit a6ef2b6
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 38 deletions.
48 changes: 15 additions & 33 deletions gsf/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,18 @@
import scipy.interpolate as interpolate
from astropy.io import fits,ascii
import corner
import emcee
import zeus
import pandas as pd
import asdf

# import from custom codes
from .function import check_line_man, check_line_cz_man, calc_Dn4, savecpkl, get_leastsq, print_err
from .zfit import check_redshift,get_chi2
from .writing import get_param
from .function_class import Func
from .minimizer import Minimizer
from .posterior_flexible import Post

############################
py_v = (sys.version_info[0])
Expand All @@ -41,7 +46,6 @@
LW = [2800, 3347, 3727, 3799, 3836, 3869, 4102, 4341, 4861, 4960, 5008, 5175, 6563, 6717, 6731]
fLW = np.zeros(len(LW), dtype='int')

# NRbb_lim = 10000 # BB data is associated with ids greater than this number.

class Mainbody():
'''
Expand Down Expand Up @@ -298,10 +302,6 @@ def update_input(self, inputs, c:float=3e18, Mpc_cm:float=3.08568025e+24, m0set:
self.band_rf['%s_lam'%(self.filts_rf[ii])] = fd[:,1]
self.band_rf['%s_res'%(self.filts_rf[ii])] = fd[:,2] / np.max(fd[:,2])

# Tau comparison?
# -> Deprecated;
# self.ftaucomp = 0

# Check if func model for SFH;
try:
self.SFH_FORM = int(inputs['SFH_FORM'])
Expand Down Expand Up @@ -473,9 +473,9 @@ def update_input(self, inputs, c:float=3e18, Mpc_cm:float=3.08568025e+24, m0set:
try:
self.Avmin = float(inputs['AVMIN'])
self.Avmax = float(inputs['AVMAX'])
if Avmin == Avmax:
if self.Avmin == self.Avmax:
self.nAV = 0
self.AVFIX = Avmin
self.AVFIX = self.Avmin
self.has_AVFIX = True
else:
self.nAV = 1
Expand Down Expand Up @@ -850,9 +850,9 @@ def fit_redshift(self, xm_tmp, fm_tmp, delzz=0.01, ezmin=0.01, zliml=0.01,
x_cz = self.dict['x'][con_cz] # Observed range
NR_cz = self.dict['NR'][con_cz]

# kind='cubic' causes an error if len(xm_tmp)<=3;
fint = interpolate.interp1d(xm_tmp, fm_tmp, kind='nearest', fill_value="extrapolate")
fm_s = fint(x_cz)
del fint

#
# If Eazy result exists;
Expand Down Expand Up @@ -997,9 +997,7 @@ def fit_redshift(self, xm_tmp, fm_tmp, delzz=0.01, ezmin=0.01, zliml=0.01,

# Visual inspection;
if self.fzvis==1:
#
# Ask interactively;
#
data_model_new = np.zeros((len(x_cz),4),'float')
data_model_new[:,0] = x_cz
data_model_new[:,1] = fm_s
Expand Down Expand Up @@ -1073,6 +1071,7 @@ def fit_redshift(self, xm_tmp, fm_tmp, delzz=0.01, ezmin=0.01, zliml=0.01,
print('Error is %.3f per cent.'%(eC2sigma*100))
print('##############################################################\n')
plt.show()
plt.close()

flag_z = raw_input('Do you want to continue with the input redshift, Cz0, Cz1, Cz2, and chi2/nu, %.5f %.5f %.5f %.5f %.5f? ([y]/n/m) '%\
(self.zgal, self.Cz0, self.Cz1, self.Cz2, self.fitc_cz_prev))
Expand Down Expand Up @@ -1113,7 +1112,7 @@ def get_zdist(self, f_interact=False, f_ascii=True):
fig = plt.figure(figsize=(6.5,2.5))
fig.subplots_adjust(top=0.96, bottom=0.16, left=0.09, right=0.99, hspace=0.15, wspace=0.25)
ax1 = fig.add_subplot(111)
n, nbins, patches = ax1.hist(self.res_cz.flatchain['z'], bins=200, density=True, color='gray', label='')
n, nbins, _ = ax1.hist(self.res_cz.flatchain['z'], bins=200, density=True, color='gray', label='')

yy = np.arange(0,np.max(n),1)
xx = yy * 0 + self.z_cz[1]
Expand Down Expand Up @@ -1152,11 +1151,9 @@ def get_zdist(self, f_interact=False, f_ascii=True):

if f_interact:
fig.savefig(file_out, dpi=300)
# return fig, ax1
else:
plt.savefig(file_out, dpi=300)
plt.close()
# return True
except:
print('z-distribution figure is not generated.')
pass
Expand Down Expand Up @@ -1517,15 +1514,6 @@ def main(self, cornerplot:bool=True, specplot=1, sigz=1.0, ezmin=0.01, ferr=0,
f_plot_chain : book
Plot MC sample chain.
'''
import emcee
import zeus
try:
import multiprocess
except:
import multiprocessing as multiprocess

from .posterior_flexible import Post

# Call likelihood/prior/posterior function;
class_post = Post(self)

Expand Down Expand Up @@ -1572,10 +1560,7 @@ def main(self, cornerplot:bool=True, specplot=1, sigz=1.0, ezmin=0.01, ferr=0,
print('#####################################')
print('\n\n')

Av_tmp = out.params['Av'].value
AA_tmp = np.zeros(len(self.age), dtype='float')
ZZ_tmp = np.zeros(len(self.age), dtype='float')
nrd_tmp, fm_tmp, xm_tmp = self.fnc.get_template(out, f_val=True, f_nrd=True, f_neb=False)
_, fm_tmp, xm_tmp = self.fnc.get_template(out, f_val=True, f_nrd=True, f_neb=False)
else:
csq = out.chisqr
rcsq = out.redchi
Expand Down Expand Up @@ -1740,7 +1725,7 @@ def main(self, cornerplot:bool=True, specplot=1, sigz=1.0, ezmin=0.01, ferr=0,

# Plot for chain.
if f_plot_chain:
fig, axes = plt.subplots(self.ndim, figsize=(10, 7), sharex=True)
_, axes = plt.subplots(self.ndim, figsize=(10, 7), sharex=True)
samples = sampler.get_chain()
labels = []
for key in out.params.valuesdict():
Expand All @@ -1761,6 +1746,9 @@ def main(self, cornerplot:bool=True, specplot=1, sigz=1.0, ezmin=0.01, ferr=0,
axes.set_xlabel("step number")
plt.savefig('%s/chain_%s.png'%(self.DIR_OUT,self.ID))
plt.close()
# For memory optimization;
del samples, axes


# Similar for nested;
# Dummy just to get structures;
Expand All @@ -1785,7 +1773,6 @@ def main(self, cornerplot:bool=True, specplot=1, sigz=1.0, ezmin=0.01, ferr=0,
params_value[key] = np.median(flat_samples[nburn:,ii])
ii += 1

import pandas as pd
flatchain = pd.DataFrame(data=flat_samples[nburn:,:], columns=var_names)

class get_res:
Expand Down Expand Up @@ -1847,7 +1834,6 @@ def __init__(self, flatchain, var_names, params_value, res):
params_value[key] = np.median(res0.samples[nburn:,ii])
ii += 1

import pandas as pd
flatchain = pd.DataFrame(data=res0.samples[nburn:,:], columns=var_names)

class get_res:
Expand Down Expand Up @@ -1886,7 +1872,6 @@ def __init__(self, flatchain, var_names, params_value, res):
'burnin':burnin, 'nwalkers':self.nwalk,'niter':self.nmc,'ndim':self.ndim},
savepath+cpklname) # Already burn in
else:
import asdf
cpklname = 'chain_' + self.ID + '_corner.asdf'
tree = {'chain':flatchain.to_dict(), 'burnin':burnin, 'nwalkers':self.nwalk,'niter':self.nmc,'ndim':self.ndim}
af = asdf.AsdfFile(tree)
Expand Down Expand Up @@ -1930,7 +1915,6 @@ def __init__(self, flatchain, var_names, params_value, res):

return 2 # Cannot set to 1, to distinguish from retuen True


elif flag_z == 'm':
zrecom = raw_input('What is your manual input for redshift? [%.3f] '%(self.zgal))
if zrecom != '':
Expand Down Expand Up @@ -2072,8 +2056,6 @@ def search_redshift(self, dict, xm_tmp, fm_tmp, zliml=0.01, zlimu=6.0, delzz=0.0
chi2s : numpy.array
Array of chi2 values corresponding to zspace.
'''
import scipy.interpolate as interpolate

zspace = np.arange(zliml,zlimu,delzz)
chi2s = np.zeros((len(zspace),2), 'float')
if prior == None:
Expand Down
11 changes: 6 additions & 5 deletions gsf/gsf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .maketmp_filt import maketemp,maketemp_tau
from .function_class import Func,Func_tau
from .basic_func import Basic,Basic_tau
from .function import read_input

import timeit
start = timeit.default_timer()
Expand Down Expand Up @@ -81,7 +82,7 @@ def run_gsf_template(inputs, fplt=0, tau_lim=0.001, idman=None, nthin=1, delwave
return MB


def run_gsf_all(parfile, fplt, cornerplot=True, f_Alog=True, idman:str=None,
def run_gsf_all(parfile, fplt, cornerplot=True, f_plot_chain=True, f_Alog=True, idman:str=None,
zman=None, zman_min=None, zman_max=None, f_label=True, f_symbol=True,
f_SFMS=False, f_fill=True, save_sed=True, figpdf=False, mmax=300,
f_prior_sfh=False, norder_sfh_prior=3,
Expand All @@ -104,7 +105,6 @@ def run_gsf_all(parfile, fplt, cornerplot=True, f_Alog=True, idman:str=None,
######################
# Read from Input file
######################
from .function import read_input
inputs = read_input(parfile)
MB = Mainbody(inputs, c=3e18, Mpc_cm=3.08568025e+24, m0set=25.0, pixelscale=0.06,
cosmo=cosmo, idman=idman, zman=zman, zman_min=zman_min, zman_max=zman_max)
Expand Down Expand Up @@ -171,7 +171,7 @@ def run_gsf_all(parfile, fplt, cornerplot=True, f_Alog=True, idman:str=None,
#
# 2. Main fitting part.
#
flag_suc = MB.main(cornerplot=cornerplot, f_shuffle=f_shuffle, amp_shuffle=amp_shuffle, Zini=Zini,
flag_suc = MB.main(cornerplot=cornerplot, f_plot_chain=f_plot_chain, f_shuffle=f_shuffle, amp_shuffle=amp_shuffle, Zini=Zini,
f_prior_sfh=f_prior_sfh, norder_sfh_prior=norder_sfh_prior)

while (flag_suc and flag_suc!=2):
Expand All @@ -191,7 +191,7 @@ def run_gsf_all(parfile, fplt, cornerplot=True, f_Alog=True, idman:str=None,
print('Going into another round with updated templates and redshift.')
print('\n\n')

flag_suc = MB.main(cornerplot=cornerplot, f_shuffle=f_shuffle, amp_shuffle=amp_shuffle, Zini=Zini,
flag_suc = MB.main(cornerplot=cornerplot, f_plot_chain=f_plot_chain, f_shuffle=f_shuffle, amp_shuffle=amp_shuffle, Zini=Zini,
f_prior_sfh=f_prior_sfh, norder_sfh_prior=norder_sfh_prior)

# Total calculation time
Expand Down Expand Up @@ -233,7 +233,6 @@ def run_gsf_all(parfile, fplt, cornerplot=True, f_Alog=True, idman:str=None,
dust_model=MB.dust_model, DIR_TMP=MB.DIR_TMP, f_label=f_label, f_fill=f_fill,
f_fancyplot=f_fancyplot, f_plot_resid=f_plot_resid, scale=scale, f_plot_filter=f_plot_filter)


if fplt == 6:
# Use the final redshift;
hd_sum = fits.open(os.path.join(MB.DIR_OUT, 'summary_%s.fits'%MB.ID))[0].header
Expand All @@ -252,6 +251,8 @@ def run_gsf_all(parfile, fplt, cornerplot=True, f_Alog=True, idman:str=None,
#from .plot_sed_logA import plot_corner_physparam_summary_tau as plot_corner_physparam_summary
print('One for Tau model is TBD...')

return MB


if __name__ == "__main__":
'''
Expand Down

0 comments on commit a6ef2b6

Please sign in to comment.