Refactor synthesizers to better preserve seed info.
tymorrow committed Aug 29, 2024
1 parent 7646e94 commit 304074e
Showing 5 changed files with 83 additions and 61 deletions.
riid/data/sampleset.py (2 changes: 1 addition & 1 deletion)
@@ -1800,13 +1800,13 @@ def _pcf_dict_to_ss(pcf_dict: dict, verbose=True):
"total_counts": sum(spectrum["spectrum"]),
"neutron_counts": spectrum["header"]["Total_Neutron_Counts"],
"distance_cm": distance,
"areal_density": ad,
"ecal_order_0": order_0,
"ecal_order_1": order_1,
"ecal_order_2": order_2,
"ecal_order_3": order_3,
"ecal_low_e": low_E,
"atomic_number": an,
"areal_density": ad,
"occupancy_flag": spectrum["header"]["Occupancy_Flag"],
"tag": spectrum["header"]["Tag"],
}
riid/data/synthetic/base.py (55 changes: 13 additions & 42 deletions)
@@ -81,8 +81,7 @@ def _verify_n_samples_synthesized(self, actual: int, expected: int):
"Be sure to remove any columns from your seeds' sources DataFrame that "
"contain all zeroes.")

- def _get_batch(self, fg_seed, fg_sources, bg_seed, bg_sources, ecal,
- lt_targets, snr_targets, rt_targets=None, distance_cm=None):
+ def _get_batch(self, fg_seed, fg_sources, bg_seed, bg_sources, lt_targets, snr_targets):
if not (self.return_fg or self.return_gross):
raise ValueError("Computing to return nothing.")

@@ -127,10 +126,8 @@ def _get_batch(fg_seed, fg_sources, bg_seed, bg_sources, ecal,
# Sample sets
if self.return_fg:
snrs = fg_counts / np.sqrt(long_bg_counts.clip(1))
- fg_ss = get_fg_sample_set(fg_spectra, fg_sources, ecal, lt_targets,
- snrs=snrs, total_counts=fg_counts,
- real_times=rt_targets, distance_cm=distance_cm,
- timestamps=self._synthesis_start_dt)
+ fg_ss = get_fg_sample_set(fg_spectra, fg_sources, lt_targets,
+ snrs=snrs, total_counts=fg_counts)
self._n_samples_synthesized += fg_ss.n_samples
if self.return_gross:
tiled_fg_sources = _tile_sources_and_scale(
@@ -146,40 +143,28 @@ def _get_batch(fg_seed, fg_sources, bg_seed, bg_sources, ecal,
gross_sources = get_merged_sources_samplewise(tiled_fg_sources, tiled_bg_sources)
gross_counts = gross_spectra.sum(axis=1)
snrs = fg_counts / np.sqrt(bg_counts.clip(1))
- gross_ss = get_gross_sample_set(gross_spectra, gross_sources, ecal,
- lt_targets, snrs, gross_counts,
- real_times=rt_targets, distance_cm=distance_cm,
- timestamps=self._synthesis_start_dt)
+ gross_ss = get_gross_sample_set(gross_spectra, gross_sources,
+ lt_targets, snrs, gross_counts)
self._n_samples_synthesized += gross_ss.n_samples

return fg_ss, gross_ss


- def get_sample_set(spectra, sources, ecal, live_times, snrs, total_counts=None,
- real_times=None, distance_cm=None, timestamps=None,
- descriptions=None) -> SampleSet:
+ def _get_minimal_ss(spectra, sources, live_times, snrs, total_counts=None) -> SampleSet:
n_samples = spectra.shape[0]
if n_samples <= 0:
raise ValueError(f"Can't build SampleSet with {n_samples} samples.")

ss = SampleSet()
ss.spectra_state = SpectraState.Counts
ss.spectra = pd.DataFrame(spectra)
ss.sources = sources
ss.info.description = np.full(n_samples, "") # Ensures the length of info equal n_samples
- if descriptions:
- ss.info.description = descriptions
ss.info.snr = snrs
- ss.info.timestamp = timestamps
ss.info.total_counts = total_counts if total_counts is not None else spectra.sum(axis=1)
- ss.info.ecal_order_0 = ecal[0]
- ss.info.ecal_order_1 = ecal[1]
- ss.info.ecal_order_2 = ecal[2]
- ss.info.ecal_order_3 = ecal[3]
- ss.info.ecal_low_e = ecal[4]
ss.info.live_time = live_times
- ss.info.real_time = real_times if real_times is not None else live_times
- ss.info.distance_cm = distance_cm
ss.info.occupancy_flag = 0
- ss.info.tag = " "  # TODO: test if this can be empty string
+ ss.info.tag = " "  # TODO: test if this can be an empty string

return ss

@@ -196,44 +181,30 @@ def _tile_sources_and_scale(sources, n_samples, scalars) -> pd.DataFrame:
return tiled_sources


- def get_fg_sample_set(spectra, sources, ecal, live_times, snrs, total_counts,
- real_times=None, distance_cm=None, timestamps=None,
- descriptions=None) -> SampleSet:
+ def get_fg_sample_set(spectra, sources, live_times, snrs, total_counts) -> SampleSet:
tiled_sources = _tile_sources_and_scale(
sources,
spectra.shape[0],
spectra.sum(axis=1)
)
- ss = get_sample_set(
+ ss = _get_minimal_ss(
spectra=spectra,
sources=tiled_sources,
- ecal=ecal,
live_times=live_times,
snrs=snrs,
total_counts=total_counts,
- real_times=real_times,
- distance_cm=distance_cm,
- timestamps=timestamps,
- descriptions=descriptions
)
ss.spectra_type = SpectraType.Foreground
return ss


- def get_gross_sample_set(spectra, sources, ecal, live_times, snrs, total_counts,
- real_times=None, distance_cm=None, timestamps=None,
- descriptions=None) -> SampleSet:
- ss = get_sample_set(
+ def get_gross_sample_set(spectra, sources, live_times, snrs, total_counts) -> SampleSet:
+ ss = _get_minimal_ss(
spectra=spectra,
sources=sources,
- ecal=ecal,
live_times=live_times,
snrs=snrs,
total_counts=total_counts,
- real_times=real_times,
- distance_cm=distance_cm,
- timestamps=timestamps,
- descriptions=descriptions
)
ss.spectra_type = SpectraType.Gross
return ss
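
In short, riid/data/synthetic/base.py narrows the batch helpers: get_sample_set becomes the private _get_minimal_ss, and the builders no longer accept ecal, real_times, distance_cm, timestamps, or descriptions. Those seed-derived fields are now applied by the calling synthesizer after a batch is built. The sketch below illustrates that split; it is not code from this commit, the helper names build_minimal_fg_ss and copy_seed_info are illustrative, and the import paths are assumed to be riid.data.sampleset.

```python
# Illustrative sketch, not commit code: build a minimal SampleSet at batch time,
# then copy seed-level metadata onto it afterward (the pattern this refactor adopts).
import numpy as np
import pandas as pd
from riid.data.sampleset import SampleSet, SpectraState, SpectraType


def build_minimal_fg_ss(spectra, sources, live_times, snrs):
    """Populate only the fields known at batch time; `sources` has one row per sample."""
    ss = SampleSet()
    ss.spectra_state = SpectraState.Counts
    ss.spectra = pd.DataFrame(spectra)
    ss.sources = sources
    ss.info.description = np.full(spectra.shape[0], "")
    ss.info.snr = snrs
    ss.info.total_counts = spectra.sum(axis=1)
    ss.info.live_time = live_times
    ss.info.occupancy_flag = 0
    ss.info.tag = " "
    ss.spectra_type = SpectraType.Foreground
    return ss


def copy_seed_info(ss, seed_ecal, seed_info, timestamp):
    """Apply seed-derived metadata after synthesis, mirroring what the synthesizers now do."""
    ss.ecal = seed_ecal
    # Same real-time relation the commit uses.
    ss.info.real_time = ss.info.live_time * (1 - seed_info.dead_time_prop)
    ss.info.distance_cm = seed_info.distance_cm
    ss.info.dead_time_prop = seed_info.dead_time_prop
    ss.info.areal_density = seed_info.areal_density
    ss.info.atomic_number = seed_info.atomic_number
    ss.info.neutron_counts = seed_info.neutron_counts
    ss.info.timestamp = timestamp
```
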
riid/data/synthetic/passby.py (48 changes: 42 additions & 6 deletions)
@@ -177,7 +177,7 @@ def _calculate_passby_shape(self, fwhm: float):
return 1 / (np.power(samples, 2) + 1)

def _generate_single_passby(self, fwhm: float, snr: float, dwell_time: float,
- fg_seed: np.array, bg_seed: np.array, fg_ecal: np.array,
+ fg_seed: np.array, bg_seed: np.array,
fg_sources: pd.Series, bg_sources: pd.Series):
"""Generate a `SampleSet` with a sequence of spectra representative of a single pass-by.
@@ -212,7 +212,6 @@ def _generate_single_passby(self, fwhm: float, snr: float, dwell_time: float,
fg_sources,
bg_seed,
bg_sources,
- fg_ecal,
live_times,
snr_targets
)
@@ -275,17 +274,54 @@ def generate(self, fg_seeds_ss: SampleSet, bg_seeds_ss: SampleSet,
for fg_i in range(fg_seeds_ss.n_samples):
fg_pmf = fg_seeds_ss.spectra.iloc[fg_i]
fg_sources = fg_seeds_ss.sources.iloc[fg_i]
- fg_ecal = fg_seeds_ss.ecal[fg_i]
for t_i in range(self.events_per_seed):
fwhm = fwhm_targets[t_i]
snr = snr_targets[t_i]
dwell_time = dwell_time_targets[t_i]
- pb_args = (fwhm, snr, dwell_time, fg_pmf, bg_pmf,
- fg_ecal, fg_sources, bg_sources)
+ pb_args = (fg_i, fwhm, snr, dwell_time, fg_pmf, bg_pmf,
+ fg_sources, bg_sources)
args.append(pb_args)

# TODO: follow prevents periodic progress reports
- passbys = [self._generate_single_passby(*a) for a in args]
+ passbys = []
+ for a in args:
+ f, fwhm, snr, dwell_time, fg_pmf, bg_pmf, fg_sources, bg_sources = a
+ fg_passby_ss, gross_passby_ss = self._generate_single_passby(
+ fwhm, snr, dwell_time, fg_pmf, bg_pmf, fg_sources, bg_sources
+ )
+ live_times = None
+ if fg_passby_ss is not None:
+ live_times = fg_passby_ss.info.live_time
+ elif gross_passby_ss is not None:
+ live_times = gross_passby_ss.info.live_time
+ else:
+ live_times = 1.0

+ fg_seed_ecal = fg_seeds_ss.ecal[f]
+ fg_seed_info = fg_seeds_ss.info.iloc[f]
+ batch_rt_targets = live_times * (1 - fg_seed_info.dead_time_prop)
+ fg_seed_distance_cm = fg_seed_info.distance_cm
+ fg_seed_dead_time_prop = fg_seed_info.dead_time_prop
+ fg_seed_ad = fg_seed_info.areal_density
+ fg_seed_an = fg_seed_info.atomic_number
+ fg_seed_neutron_counts = fg_seed_info.neutron_counts

+ def _set_remaining_info(ss: SampleSet | None):
+ if ss is None:
+ return
+ ss.ecal = fg_seed_ecal
+ ss.info.real_time = batch_rt_targets
+ ss.info.distance_cm = fg_seed_distance_cm
+ ss.info.dead_time_prop = fg_seed_dead_time_prop
+ ss.info.areal_density = fg_seed_ad
+ ss.info.atomic_number = fg_seed_an
+ ss.info.neutron_counts = fg_seed_neutron_counts
+ ss.info.timestamp = self._synthesis_start_dt

+ _set_remaining_info(fg_passby_ss)
+ _set_remaining_info(gross_passby_ss)

+ passbys.append((fg_passby_ss, gross_passby_ss))

if verbose:
delay = time() - tstart
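
The pass-by change follows the same idea: each queued event now carries the foreground seed index (fg_i) instead of the seed's ecal, and _set_remaining_info copies the seed's calibration, distance, dead-time proportion, areal density, atomic number, and neutron counts onto both the foreground and gross pass-by sample sets after generation. Below is a small self-contained sketch of that bookkeeping pattern using plain pandas/numpy stand-ins; it is not riid code, and the data values are made up for illustration.

```python
# Self-contained illustration (not commit code): queue jobs with a seed *index*,
# generate, then enrich the result from the seed row, as the pass-by loop above does.
import numpy as np
import pandas as pd

seeds_info = pd.DataFrame({
    "distance_cm": [50.0, 100.0],
    "dead_time_prop": [0.01, 0.02],
    "atomic_number": [26.0, 82.0],
})
seeds_ecal = np.array([[0.0, 3000.0, 100.0, 0.0, 0.0],
                       [0.0, 2900.0, 90.0, 0.0, 0.0]])

# Each job carries the seed index plus per-event targets.
jobs = [(seed_idx, snr) for seed_idx in range(len(seeds_info)) for snr in (10.0, 50.0)]

events = []
for seed_idx, snr in jobs:
    # Stand-in for _generate_single_passby: only per-sample live times and SNR exist here.
    event_info = pd.DataFrame({"live_time": np.full(4, 0.5), "snr": snr})
    seed_row = seeds_info.iloc[seed_idx]
    # Same enrichment step the commit performs via _set_remaining_info.
    event_info["real_time"] = event_info["live_time"] * (1 - seed_row.dead_time_prop)
    event_info["distance_cm"] = seed_row.distance_cm
    event_info["atomic_number"] = seed_row.atomic_number
    events.append((seeds_ecal[seed_idx], event_info))

print(events[0][1])
```
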
riid/data/synthetic/static.py (34 changes: 27 additions & 7 deletions)
@@ -180,20 +180,40 @@ def _get_synthetic_samples(self, fg_seeds_ss: SampleSet, bg_seeds_ss: SampleSet,
bg_sources = bg_seeds_ss.sources.iloc[b]
fg_seed = fg_seeds_ss.spectra.iloc[f]
fg_sources = fg_seeds_ss.sources.iloc[f]
- fg_seed_rt = fg_seeds_ss.info.real_time.iloc[f]
- fg_seed_lt = fg_seeds_ss.info.live_time.iloc[f]
batch_lt_targets = lt_targets[batch_begin_idx:batch_end_idx]
- batch_rt_targets = lt_targets[batch_begin_idx:batch_end_idx] * (fg_seed_rt / fg_seed_lt)
batch_snr_targets = snr_targets[batch_begin_idx:batch_end_idx]
- distance_cm = fg_seeds_ss.info.distance_cm.iloc[f]

- ecal = fg_seeds_ss.ecal[f]
fg_batch_ss, gross_batch_ss = self._get_batch(
fg_seed, fg_sources,
bg_seed, bg_sources,
- ecal, batch_lt_targets, batch_snr_targets, batch_rt_targets,
- distance_cm
+ batch_lt_targets,
+ batch_snr_targets
)

+ fg_seed_ecal = fg_seeds_ss.ecal[f]
+ fg_seed_info = fg_seeds_ss.info.iloc[f]
+ batch_rt_targets = batch_lt_targets * (1 - fg_seed_info.dead_time_prop)
+ fg_seed_distance_cm = fg_seed_info.distance_cm
+ fg_seed_dead_time_prop = fg_seed_info.dead_time_prop
+ fg_seed_ad = fg_seed_info.areal_density
+ fg_seed_an = fg_seed_info.atomic_number
+ fg_seed_neutron_counts = fg_seed_info.neutron_counts

+ def _set_remaining_info(ss: SampleSet | None):
+ if ss is None:
+ return
+ ss.ecal = fg_seed_ecal
+ ss.info.real_time = batch_rt_targets
+ ss.info.distance_cm = fg_seed_distance_cm
+ ss.info.dead_time_prop = fg_seed_dead_time_prop
+ ss.info.areal_density = fg_seed_ad
+ ss.info.atomic_number = fg_seed_an
+ ss.info.neutron_counts = fg_seed_neutron_counts
+ ss.info.timestamp = self._synthesis_start_dt

+ _set_remaining_info(fg_batch_ss)
+ _set_remaining_info(gross_batch_ss)

fg_ss_batches.append(fg_batch_ss)
gross_ss_batches.append(gross_batch_ss)

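
riid/data/synthetic/static.py gets the identical _set_remaining_info treatment, so statically synthesized samples should now inherit their foreground seed's ecal terms, distance, dead-time proportion, areal density, atomic number, and neutron counts rather than defaults. The usage sketch below shows the expected effect; get_dummy_seeds, split_fg_and_bg, and the constructor/return signatures are assumptions about riid's public API, not part of this diff.

```python
# Hedged usage sketch (assumed API): seed metadata should survive into synthesized samples.
from riid.data.synthetic import get_dummy_seeds
from riid.data.synthetic.static import StaticSynthesizer

fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg()
static_synth = StaticSynthesizer(samples_per_seed=10)
fg_ss, gross_ss = static_synth.generate(fg_seeds_ss, bg_seeds_ss)

# Columns that previously fell back to defaults should now mirror the seeds.
print(fg_ss.info[["distance_cm", "areal_density", "atomic_number", "neutron_counts"]].head())
print(fg_ss.ecal[:3])  # per-sample energy calibration inherited from the foreground seeds
```
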
tests/staticsynth_tests.py (5 changes: 0 additions & 5 deletions)
@@ -217,20 +217,16 @@ def test_get_batch(self):
[0.3, 0.4, 0.3],
index=["X", "Y", "Z"]
)
- ecal = (0, 3000, 100, 0, 0)
lts = np.array([4.2]).astype(float)
snrs = np.array([63.2]).astype(float)
- distance_cm = 50

fg_ss, gross_ss = synth._get_batch(
fg_seed=fg_seed,
fg_sources=fg_sources,
bg_seed=bg_seed,
bg_sources=bg_sources,
- ecal=ecal,
lt_targets=lts,
snr_targets=snrs,
- distance_cm=distance_cm
)

self.assertTrue(np.allclose(
@@ -242,7 +238,6 @@
gross_ss.sources.loc[:, bg_sources.index].sum(axis=1) / synth.bg_cps,
lts,
))

self.assertTrue(np.allclose(
gross_ss.sources.loc[:, fg_sources.index],
fg_ss.sources.loc[:, fg_sources.index],
