Skip to content

Commit d98c11f

Browse files
committed
Add checks for ARAD models; clean up PyRIIDModel saving extensions.
1 parent 147f780 commit d98c11f

File tree

3 files changed

+50
-35
lines changed

3 files changed

+50
-35
lines changed

examples/modeling/arad.py

+19-8
Original file line numberDiff line numberDiff line change
@@ -48,35 +48,46 @@ def show_summaries(model):
4848

4949
# Generate some training data
5050
fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg()
51-
mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3).generate(1)
51+
mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3).generate(300)
5252

5353
static_synth = StaticSynthesizer(
54-
samples_per_seed=250,
55-
snr_function="log10",
54+
samples_per_seed=10,
55+
snr_function_args=(0, 0),
5656
return_fg=False,
5757
return_gross=True,
5858
)
59-
_, train_ss = static_synth.generate(fg_seeds_ss, mixed_bg_seed_ss)
59+
_, train_ss = static_synth.generate(fg_seeds_ss[0], mixed_bg_seed_ss)
60+
train_ss.downsample_spectra(target_bins=128)
6061
train_ss.normalize()
6162

63+
# Train the models
6264
print("training ARADv1...")
6365
arad_v1.fit(train_ss, epochs=50, verbose=True)
6466
print("training ARADv2...")
6567
arad_v2.fit(train_ss, epochs=50, verbose=True)
6668

6769
# Generate some test data
68-
static_synth.samples_per_seed = 50
69-
_, test_ss = static_synth.generate(fg_seeds_ss, mixed_bg_seed_ss)
70+
static_synth.samples_per_seed = 3
71+
_, test_ss = static_synth.generate(fg_seeds_ss[0], mixed_bg_seed_ss)
72+
test_ss.downsample_spectra(target_bins=128)
7073
test_ss.normalize()
7174

7275
# Predict
73-
arad_v1_reconstructions = arad_v1.predict(test_ss, verbose=True)
76+
arad_v1_reconstructions = arad_v1.predict(
77+
test_ss,
78+
verbose=True,
79+
ood_threshold=2.349
80+
)
7481
recon_errors = test_ss.info["recon_error"].values
7582
ood_decisions = test_ss.info["ood"].values
7683
print((f"ARADv1: mean reconstruction error = {np.mean(recon_errors):.3f} (KLD)\n"
7784
f" OOD rate = {np.mean(ood_decisions):.2f}"))
7885

79-
arad_v2_reconstructions = arad_v2.predict(test_ss, verbose=True)
86+
arad_v2_reconstructions = arad_v2.predict(
87+
test_ss,
88+
verbose=True,
89+
ood_threshold=0.15678
90+
)
8091
recon_errors = test_ss.info["recon_error"].values
8192
ood_decisions = test_ss.info["ood"].values
8293
print((f"ARADv2: mean reconstruction error = {np.mean(recon_errors):.3f} (JSD)\n"

riid/models/__init__.py

+8-14
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class PyRIIDModel:
3131
"""Base class for TensorFlow models."""
3232

3333
CUSTOM_OBJECTS = {"multi_f1": multi_f1, "single_f1": single_f1}
34+
SUPPORTED_SAVE_EXTS = {"H5": ".h5", "ONNX": ".onnx"}
3435

3536
def __init__(self, *args, **kwargs):
3637
self._info = {}
@@ -120,25 +121,21 @@ def save(self, file_path: str):
120121
if os.path.exists(file_path):
121122
raise ValueError("Path already exists.")
122123

123-
SUPPORTED_EXTS = {
124-
"H5": ".h5",
125-
"ONNX": ".onnx"
126-
}
127124
root, ext = os.path.splitext(file_path)
128-
if ext.lower() not in SUPPORTED_EXTS.values():
125+
if ext.lower() not in self.SUPPORTED_SAVE_EXTS.values():
129126
raise NameError("Model must be an .onnx or .h5 file.")
130127

131128
warnings.filterwarnings("ignore")
132129

133-
if ext.lower() == SUPPORTED_EXTS["H5"]:
130+
if ext.lower() == self.SUPPORTED_SAVE_EXTS["H5"]:
134131
self.model.save(file_path, save_format="h5")
135132
pd.DataFrame(
136133
[[v] for v in self.info.values()],
137134
self.info.keys()
138135
).to_hdf(file_path, "_info")
139136

140137
else:
141-
model_path = root + SUPPORTED_EXTS["ONNX"]
138+
model_path = root + self.SUPPORTED_SAVE_EXTS["ONNX"]
142139
model_info_path = root + "_info.json"
143140

144141
model_info_df = pd.DataFrame(
@@ -162,25 +159,22 @@ def load(self, file_path: str):
162159
file_path: file path from which to load the model, must be either an
163160
.h5 or .onnx file
164161
"""
165-
SUPPORTED_EXTS = {
166-
"H5": ".h5",
167-
"ONNX": ".onnx"
168-
}
162+
169163
root, ext = os.path.splitext(file_path)
170-
if ext.lower() not in SUPPORTED_EXTS.values():
164+
if ext.lower() not in self.SUPPORTED_SAVE_EXTS.values():
171165
raise NameError("Model must be an .onnx or .h5 file.")
172166

173167
warnings.filterwarnings("ignore", category=DeprecationWarning)
174168

175-
if ext.lower() == SUPPORTED_EXTS["H5"]:
169+
if ext.lower() == self.SUPPORTED_SAVE_EXTS["H5"]:
176170
self.model = tf.keras.models.load_model(
177171
file_path,
178172
custom_objects=self.CUSTOM_OBJECTS
179173
)
180174
self._info = pd.read_hdf(file_path, "_info")[0].to_dict()
181175

182176
else:
183-
model_path = root + SUPPORTED_EXTS["ONNX"]
177+
model_path = root + self.SUPPORTED_SAVE_EXTS["ONNX"]
184178
model_info_path = root + "_info.json"
185179

186180
with open(model_info_path) as fin:

riid/models/neural_nets/arad.py

+23-13
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from scipy.spatial.distance import jensenshannon
1515
from scipy.stats import entropy
1616

17-
from riid.data.sampleset import SampleSet
17+
from riid.data.sampleset import SampleSet, SpectraState
1818
from riid.losses import jensen_shannon_distance, mish
1919
from riid.models import PyRIIDModel
2020

@@ -247,6 +247,21 @@ def __init__(self, model: Model = ARADv2TF()):
247247

248248
self.model = model
249249

250+
def _check_spectra(self, ss):
    """Check that the spectra of a `SampleSet` are compatible with ARAD models.

    The checks run in order and the first failing one raises, so the error
    message always describes the earliest unmet requirement.

    Args:
        ss: `SampleSet` whose spectra are to be validated

    Raises:
        ValueError: if `ss` contains no samples, if its spectra do not all
            sum to one, if its spectra state is not
            `SpectraState.L1Normalized`, or if its spectra do not have
            exactly 128 channels
    """
    if ss.n_samples <= 0:
        raise ValueError("No spectr[a|um] provided!")
    if not ss.all_spectra_sum_to_one():
        raise ValueError("All spectra must sum to one.")
    if ss.spectra_state != SpectraState.L1Normalized:
        raise ValueError(
            f"SpectraState must be L1Normalized, provided SpectraState is {ss.spectra_state}."
        )
    if ss.n_channels != 128:
        raise ValueError(
            f"Spectra must have 128 channels, provided spectra have {ss.n_channels} channels."
        )
264+
250265
def fit(self, ss: SampleSet, epochs: int = 300, validation_split=0.2,
251266
es_verbose: int = 0, verbose: bool = False):
252267
"""Fit a model to the given `SampleSet`.
@@ -261,13 +276,9 @@ def fit(self, ss: SampleSet, epochs: int = 300, validation_split=0.2,
261276
Returns:
262277
reconstructed_spectra: output of ARAD model
263278
"""
264-
if ss.n_samples <= 0:
265-
raise ValueError("No spectr[a|um] provided!")
279+
self._check_spectra(ss)
266280

267-
norm_ss = ss[:]
268-
norm_ss.downsample_spectra(target_bins=128)
269-
norm_ss.normalize()
270-
x = norm_ss.get_samples().astype(float)
281+
x = ss.get_samples().astype(float)
271282

272283
is_v1 = isinstance(self.model, ARADv1TF)
273284
is_v2 = isinstance(self.model, ARADv2TF)
@@ -343,12 +354,11 @@ def predict(self, ss: SampleSet, ood_threshold: float = 0.5,
343354
Returns:
344355
reconstructed_spectra: output of ARAD model
345356
"""
346-
norm_ss = ss[:]
347-
norm_ss.downsample_spectra(target_bins=128)
348-
norm_ss.normalize()
349-
spectra = norm_ss.get_samples().astype(float)
357+
self._check_spectra(ss)
358+
359+
x = ss.get_samples().astype(float)
350360

351-
reconstructed_spectra = self.get_predictions(spectra, verbose=verbose)
361+
reconstructed_spectra = self.get_predictions(x, verbose=verbose)
352362

353363
is_v1 = isinstance(self.model, ARADv1TF)
354364
is_v2 = isinstance(self.model, ARADv2TF)
@@ -357,7 +367,7 @@ def predict(self, ss: SampleSet, ood_threshold: float = 0.5,
357367
elif is_v2:
358368
reconstruction_metric = jensenshannon
359369

360-
reconstruction_errors = reconstruction_metric(spectra, reconstructed_spectra, axis=1)
370+
reconstruction_errors = reconstruction_metric(x, reconstructed_spectra, axis=1)
361371
ood_decisions = reconstruction_errors > ood_threshold
362372
ss.info["recon_error"] = reconstruction_errors
363373
ss.info["ood"] = ood_decisions

0 commit comments

Comments
 (0)