diff --git a/riid/data/sampleset.py b/riid/data/sampleset.py index 47dd033..c9e5547 100644 --- a/riid/data/sampleset.py +++ b/riid/data/sampleset.py @@ -1800,13 +1800,13 @@ def _pcf_dict_to_ss(pcf_dict: dict, verbose=True): "total_counts": sum(spectrum["spectrum"]), "neutron_counts": spectrum["header"]["Total_Neutron_Counts"], "distance_cm": distance, - "areal_density": ad, "ecal_order_0": order_0, "ecal_order_1": order_1, "ecal_order_2": order_2, "ecal_order_3": order_3, "ecal_low_e": low_E, "atomic_number": an, + "areal_density": ad, "occupancy_flag": spectrum["header"]["Occupancy_Flag"], "tag": spectrum["header"]["Tag"], } diff --git a/riid/data/synthetic/base.py b/riid/data/synthetic/base.py index 9209c7f..da6af9e 100644 --- a/riid/data/synthetic/base.py +++ b/riid/data/synthetic/base.py @@ -81,8 +81,8 @@ def _verify_n_samples_synthesized(self, actual: int, expected: int): "Be sure to remove any columns from your seeds' sources DataFrame that " "contain all zeroes.") - def _get_batch(self, fg_seed, fg_sources, bg_seed, bg_sources, ecal, - lt_targets, snr_targets, rt_targets=None, distance_cm=None): + def _get_batch(self, fg_seed, fg_sources, bg_seed, bg_sources, + lt_targets, snr_targets) -> tuple[SampleSet, SampleSet | None]: if not (self.return_fg or self.return_gross): raise ValueError("Computing to return nothing.") @@ -127,10 +127,8 @@ def _get_batch(self, fg_seed, fg_sources, bg_seed, bg_sources, ecal, # Sample sets if self.return_fg: snrs = fg_counts / np.sqrt(long_bg_counts.clip(1)) - fg_ss = get_fg_sample_set(fg_spectra, fg_sources, ecal, lt_targets, - snrs=snrs, total_counts=fg_counts, - real_times=rt_targets, distance_cm=distance_cm, - timestamps=self._synthesis_start_dt) + fg_ss = get_fg_sample_set(fg_spectra, fg_sources, lt_targets, + snrs=snrs, total_counts=fg_counts) self._n_samples_synthesized += fg_ss.n_samples if self.return_gross: tiled_fg_sources = _tile_sources_and_scale( @@ -146,40 +144,28 @@ def _get_batch(self, fg_seed, fg_sources, bg_seed, bg_sources, ecal, gross_sources = get_merged_sources_samplewise(tiled_fg_sources, tiled_bg_sources) gross_counts = gross_spectra.sum(axis=1) snrs = fg_counts / np.sqrt(bg_counts.clip(1)) - gross_ss = get_gross_sample_set(gross_spectra, gross_sources, ecal, - lt_targets, snrs, gross_counts, - real_times=rt_targets, distance_cm=distance_cm, - timestamps=self._synthesis_start_dt) + gross_ss = get_gross_sample_set(gross_spectra, gross_sources, + lt_targets, snrs, gross_counts) self._n_samples_synthesized += gross_ss.n_samples return fg_ss, gross_ss -def get_sample_set(spectra, sources, ecal, live_times, snrs, total_counts=None, - real_times=None, distance_cm=None, timestamps=None, - descriptions=None) -> SampleSet: +def _get_minimal_ss(spectra, sources, live_times, snrs, total_counts=None) -> SampleSet: n_samples = spectra.shape[0] + if n_samples <= 0: + raise ValueError(f"Can't build SampleSet with {n_samples} samples.") ss = SampleSet() ss.spectra_state = SpectraState.Counts ss.spectra = pd.DataFrame(spectra) ss.sources = sources ss.info.description = np.full(n_samples, "") # Ensures the length of info equal n_samples - if descriptions: - ss.info.description = descriptions ss.info.snr = snrs - ss.info.timestamp = timestamps ss.info.total_counts = total_counts if total_counts is not None else spectra.sum(axis=1) - ss.info.ecal_order_0 = ecal[0] - ss.info.ecal_order_1 = ecal[1] - ss.info.ecal_order_2 = ecal[2] - ss.info.ecal_order_3 = ecal[3] - ss.info.ecal_low_e = ecal[4] ss.info.live_time = live_times - ss.info.real_time = real_times if real_times is not None else live_times - ss.info.distance_cm = distance_cm ss.info.occupancy_flag = 0 - ss.info.tag = " " # TODO: test if this can be empty string + ss.info.tag = " " # TODO: test if this can be an empty string return ss @@ -196,44 +182,30 @@ def _tile_sources_and_scale(sources, n_samples, scalars) -> pd.DataFrame: return tiled_sources -def get_fg_sample_set(spectra, sources, ecal, live_times, snrs, total_counts, - real_times=None, distance_cm=None, timestamps=None, - descriptions=None) -> SampleSet: +def get_fg_sample_set(spectra, sources, live_times, snrs, total_counts) -> SampleSet: tiled_sources = _tile_sources_and_scale( sources, spectra.shape[0], spectra.sum(axis=1) ) - ss = get_sample_set( + ss = _get_minimal_ss( spectra=spectra, sources=tiled_sources, - ecal=ecal, live_times=live_times, snrs=snrs, total_counts=total_counts, - real_times=real_times, - distance_cm=distance_cm, - timestamps=timestamps, - descriptions=descriptions ) ss.spectra_type = SpectraType.Foreground return ss -def get_gross_sample_set(spectra, sources, ecal, live_times, snrs, total_counts, - real_times=None, distance_cm=None, timestamps=None, - descriptions=None) -> SampleSet: - ss = get_sample_set( +def get_gross_sample_set(spectra, sources, live_times, snrs, total_counts) -> SampleSet: + ss = _get_minimal_ss( spectra=spectra, sources=sources, - ecal=ecal, live_times=live_times, snrs=snrs, total_counts=total_counts, - real_times=real_times, - distance_cm=distance_cm, - timestamps=timestamps, - descriptions=descriptions ) ss.spectra_type = SpectraType.Gross return ss diff --git a/riid/data/synthetic/passby.py b/riid/data/synthetic/passby.py index 8423461..70e7a44 100644 --- a/riid/data/synthetic/passby.py +++ b/riid/data/synthetic/passby.py @@ -177,7 +177,7 @@ def _calculate_passby_shape(self, fwhm: float): return 1 / (np.power(samples, 2) + 1) def _generate_single_passby(self, fwhm: float, snr: float, dwell_time: float, - fg_seed: np.array, bg_seed: np.array, fg_ecal: np.array, + fg_seed: np.array, bg_seed: np.array, fg_sources: pd.Series, bg_sources: pd.Series): """Generate a `SampleSet` with a sequence of spectra representative of a single pass-by. @@ -212,7 +212,6 @@ def _generate_single_passby(self, fwhm: float, snr: float, dwell_time: float, fg_sources, bg_seed, bg_sources, - fg_ecal, live_times, snr_targets ) @@ -275,17 +274,54 @@ def generate(self, fg_seeds_ss: SampleSet, bg_seeds_ss: SampleSet, for fg_i in range(fg_seeds_ss.n_samples): fg_pmf = fg_seeds_ss.spectra.iloc[fg_i] fg_sources = fg_seeds_ss.sources.iloc[fg_i] - fg_ecal = fg_seeds_ss.ecal[fg_i] for t_i in range(self.events_per_seed): fwhm = fwhm_targets[t_i] snr = snr_targets[t_i] dwell_time = dwell_time_targets[t_i] - pb_args = (fwhm, snr, dwell_time, fg_pmf, bg_pmf, - fg_ecal, fg_sources, bg_sources) + pb_args = (fg_i, fwhm, snr, dwell_time, fg_pmf, bg_pmf, + fg_sources, bg_sources) args.append(pb_args) # TODO: follow prevents periodic progress reports - passbys = [self._generate_single_passby(*a) for a in args] + passbys = [] + for a in args: + f, fwhm, snr, dwell_time, fg_pmf, bg_pmf, fg_sources, bg_sources = a + fg_passby_ss, gross_passby_ss = self._generate_single_passby( + fwhm, snr, dwell_time, fg_pmf, bg_pmf, fg_sources, bg_sources + ) + live_times = None + if fg_passby_ss is not None: + live_times = fg_passby_ss.info.live_time + elif gross_passby_ss is not None: + live_times = gross_passby_ss.info.live_time + else: + live_times = 1.0 + + fg_seed_ecal = fg_seeds_ss.ecal[f] + fg_seed_info = fg_seeds_ss.info.iloc[f] + batch_rt_targets = live_times * (1 - fg_seed_info.dead_time_prop) + fg_seed_distance_cm = fg_seed_info.distance_cm + fg_seed_dead_time_prop = fg_seed_info.dead_time_prop + fg_seed_ad = fg_seed_info.areal_density + fg_seed_an = fg_seed_info.atomic_number + fg_seed_neutron_counts = fg_seed_info.neutron_counts + + def _set_remaining_info(ss: SampleSet | None): + if ss is None: + return + ss.ecal = fg_seed_ecal + ss.info.real_time = batch_rt_targets + ss.info.distance_cm = fg_seed_distance_cm + ss.info.dead_time_prop = fg_seed_dead_time_prop + ss.info.areal_density = fg_seed_ad + ss.info.atomic_number = fg_seed_an + ss.info.neutron_counts = fg_seed_neutron_counts + ss.info.timestamp = self._synthesis_start_dt + + _set_remaining_info(fg_passby_ss) + _set_remaining_info(gross_passby_ss) + + passbys.append((fg_passby_ss, gross_passby_ss)) if verbose: delay = time() - tstart diff --git a/riid/data/synthetic/static.py b/riid/data/synthetic/static.py index cdf0aed..313448d 100644 --- a/riid/data/synthetic/static.py +++ b/riid/data/synthetic/static.py @@ -180,20 +180,40 @@ def _get_synthetic_samples(self, fg_seeds_ss: SampleSet, bg_seeds_ss: SampleSet, bg_sources = bg_seeds_ss.sources.iloc[b] fg_seed = fg_seeds_ss.spectra.iloc[f] fg_sources = fg_seeds_ss.sources.iloc[f] - fg_seed_rt = fg_seeds_ss.info.real_time.iloc[f] - fg_seed_lt = fg_seeds_ss.info.live_time.iloc[f] batch_lt_targets = lt_targets[batch_begin_idx:batch_end_idx] - batch_rt_targets = lt_targets[batch_begin_idx:batch_end_idx] * (fg_seed_rt / fg_seed_lt) batch_snr_targets = snr_targets[batch_begin_idx:batch_end_idx] - distance_cm = fg_seeds_ss.info.distance_cm.iloc[f] - ecal = fg_seeds_ss.ecal[f] fg_batch_ss, gross_batch_ss = self._get_batch( fg_seed, fg_sources, bg_seed, bg_sources, - ecal, batch_lt_targets, batch_snr_targets, batch_rt_targets, - distance_cm + batch_lt_targets, + batch_snr_targets ) + + fg_seed_ecal = fg_seeds_ss.ecal[f] + fg_seed_info = fg_seeds_ss.info.iloc[f] + batch_rt_targets = batch_lt_targets * (1 - fg_seed_info.dead_time_prop) + fg_seed_distance_cm = fg_seed_info.distance_cm + fg_seed_dead_time_prop = fg_seed_info.dead_time_prop + fg_seed_ad = fg_seed_info.areal_density + fg_seed_an = fg_seed_info.atomic_number + fg_seed_neutron_counts = fg_seed_info.neutron_counts + + def _set_remaining_info(ss: SampleSet | None): + if ss is None: + return + ss.ecal = fg_seed_ecal + ss.info.real_time = batch_rt_targets + ss.info.distance_cm = fg_seed_distance_cm + ss.info.dead_time_prop = fg_seed_dead_time_prop + ss.info.areal_density = fg_seed_ad + ss.info.atomic_number = fg_seed_an + ss.info.neutron_counts = fg_seed_neutron_counts + ss.info.timestamp = self._synthesis_start_dt + + _set_remaining_info(fg_batch_ss) + _set_remaining_info(gross_batch_ss) + fg_ss_batches.append(fg_batch_ss) gross_ss_batches.append(gross_batch_ss) diff --git a/tests/staticsynth_tests.py b/tests/staticsynth_tests.py index afbd297..8ec6ddf 100644 --- a/tests/staticsynth_tests.py +++ b/tests/staticsynth_tests.py @@ -217,20 +217,16 @@ def test_get_batch(self): [0.3, 0.4, 0.3], index=["X", "Y", "Z"] ) - ecal = (0, 3000, 100, 0, 0) lts = np.array([4.2]).astype(float) snrs = np.array([63.2]).astype(float) - distance_cm = 50 fg_ss, gross_ss = synth._get_batch( fg_seed=fg_seed, fg_sources=fg_sources, bg_seed=bg_seed, bg_sources=bg_sources, - ecal=ecal, lt_targets=lts, snr_targets=snrs, - distance_cm=distance_cm ) self.assertTrue(np.allclose( @@ -242,7 +238,6 @@ def test_get_batch(self): gross_ss.sources.loc[:, bg_sources.index].sum(axis=1) / synth.bg_cps, lts, )) - self.assertTrue(np.allclose( gross_ss.sources.loc[:, fg_sources.index], fg_ss.sources.loc[:, fg_sources.index],