Skip to content

Commit

Permalink
Update TFModelBase to PyRIIDModel; implement saving models as `.o…
Browse files Browse the repository at this point in the history
…nnx` and update all models.
  • Loading branch information
alanjvano authored and tymorrow committed Dec 11, 2023
1 parent e3c01e9 commit 4b49297
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 38 deletions.
85 changes: 75 additions & 10 deletions riid/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
# Under the terms of Contract DE-NA0003525 with NTESS,
# the U.S. Government retains certain rights in this software.
"""This module contains the base TFModel class."""
import json
import os
import uuid
import warnings
from enum import Enum

import numpy as np
import onnxruntime
import pandas as pd
import tensorflow as tf
import tf2onnx

import riid
from riid.data.labeling import label_to_index_element
Expand All @@ -23,7 +27,7 @@ class ModelInput(Enum):
ForegroundSpectrum = 2


class TFModelBase:
class PyRIIDModel:
"""Base class for TensorFlow models."""

CUSTOM_OBJECTS = {"multi_f1": multi_f1, "single_f1": single_f1}
Expand Down Expand Up @@ -107,37 +111,98 @@ def save(self, file_path: str):
"""Save the model to a file.
Args:
file_path: file path at which to save the model
file_path: file path at which to save the model, can be either .h5 or
.onnx format
Raises:
`ValueError` when the given file path already exists
"""
if os.path.exists(file_path):
raise ValueError("Path already exists.")

SUPPORTED_EXTS = {
"H5": ".h5",
"ONNX": ".onnx"
}
root, ext = os.path.splitext(file_path)
if ext.lower() not in SUPPORTED_EXTS.values():
raise NameError("Model must be an .onnx or .h5 file.")

warnings.filterwarnings("ignore")

self.model.save(file_path, save_format="h5")
pd.DataFrame([[v] for v in self.info.values()], self.info.keys()).to_hdf(file_path, "_info")
if ext.lower() == SUPPORTED_EXTS["H5"]:
self.model.save(file_path, save_format="h5")
pd.DataFrame(
[[v] for v in self.info.values()],
self.info.keys()
).to_hdf(file_path, "_info")

else:
model_path = root + SUPPORTED_EXTS["ONNX"]
model_info_path = root + "_info.json"

model_info_df = pd.DataFrame(
[[v] for v in self.info.values()],
self.info.keys()
)
model_info_df[0].to_json(model_info_path, indent=4)

tf2onnx.convert.from_keras(
self.model,
input_signature=None,
output_path=model_path
)

warnings.resetwarnings()

def load(self, file_path: str):
"""Load the model from a file.
Args:
file_path: file path from which to load the model
file_path: file path from which to load the model, must be either an
.h5 or .onnx file
"""
SUPPORTED_EXTS = {
"H5": ".h5",
"ONNX": ".onnx"
}
root, ext = os.path.splitext(file_path)
if ext.lower() not in SUPPORTED_EXTS.values():
raise NameError("Model must be an .onnx or .h5 file.")

warnings.filterwarnings("ignore", category=DeprecationWarning)

self.model = tf.keras.models.load_model(
file_path,
custom_objects=self.CUSTOM_OBJECTS
)
self._info = pd.read_hdf(file_path, "_info")[0].to_dict()
if ext.lower() == SUPPORTED_EXTS["H5"]:
self.model = tf.keras.models.load_model(
file_path,
custom_objects=self.CUSTOM_OBJECTS
)
self._info = pd.read_hdf(file_path, "_info")[0].to_dict()

else:
model_path = root + SUPPORTED_EXTS["ONNX"]
model_info_path = root + "_info.json"

with open(model_info_path) as fin:
model_info = json.load(fin)
self._info = model_info

self.onnx_session = onnxruntime.InferenceSession(model_path)

warnings.resetwarnings()

def get_predictions(self, x, **kwargs):
    """Run inference with whichever backend is currently loaded.

    When `self.model` is set, the call is delegated to its `predict()` method.
    Otherwise the ONNX runtime session stored in `self.onnx_session` is used.

    Args:
        x: model input; either a single array-like, or a list/tuple of
            array-likes for multi-input models (matched positionally to the
            ONNX session inputs)
        **kwargs: forwarded to `self.model.predict()`; ignored by the ONNX path,
            which has no equivalent options (e.g. `verbose`, `batch_size`)

    Returns:
        The output of `self.model.predict()`, or the first output of the
        ONNX session.

    Raises:
        `ValueError` when the number of provided inputs does not match the
        number of inputs the ONNX session expects
    """
    if self.model is not None:
        return self.model.predict(x, **kwargs)

    session = self.onnx_session
    session_inputs = session.get_inputs()
    # Callers pass lists/tuples for multi-input models; a bare array is one input.
    values = list(x) if isinstance(x, (list, tuple)) else [x]
    if len(values) != len(session_inputs):
        raise ValueError(
            f"ONNX model expects {len(session_inputs)} input(s), "
            f"but {len(values)} were provided."
        )

    # np.asarray also accepts tensors, DataFrames, and nested lists, none of
    # which are guaranteed to expose `.astype`.
    feed = {
        session_input.name: np.asarray(value, dtype=np.float32)
        for session_input, value in zip(session_inputs, values)
    }
    first_output_name = session.get_outputs()[0].name
    return session.run([first_output_name], feed)[0]

def serialize(self) -> bytes:
"""Convert model to a bytes object.
Expand Down
6 changes: 3 additions & 3 deletions riid/models/bayes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import tensorflow_probability as tfp

from riid.data.sampleset import SampleSet
from riid.models import TFModelBase
from riid.models import PyRIIDModel


class PoissonBayesClassifier(TFModelBase):
class PoissonBayesClassifier(PyRIIDModel):
"""This Poisson-Bayes classifier calculates the conditional Poisson log probability of each
seed spectrum given the measurement.
Expand Down Expand Up @@ -139,7 +139,7 @@ def predict(self, gross_ss: SampleSet, bg_ss: SampleSet,
bg_spectra = tf.convert_to_tensor(bg_ss.spectra.values, dtype=tf.float32)
bg_lts = tf.convert_to_tensor(bg_ss.info.live_time.values, dtype=tf.float32)

prediction_probas = self.model.predict((
prediction_probas = self.get_predictions((
gross_spectra, gross_lts, bg_spectra, bg_lts
), batch_size=512, verbose=verbose)

Expand Down
30 changes: 13 additions & 17 deletions riid/models/neural_nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from riid.losses.sparsemax import SparsemaxLoss, sparsemax
from riid.metrics import (build_keras_semisupervised_metric_func, multi_f1,
single_f1)
from riid.models import ModelInput, TFModelBase
from riid.models import ModelInput, PyRIIDModel

tf2onnx.logging.basicConfig(level=tf2onnx.logging.WARNING)

Expand All @@ -47,7 +47,7 @@ def _get_reordered_spectra(old_spectra_df: pd.DataFrame, old_sources_df: pd.Data
return reordered_spectra_df


class MLPClassifier(TFModelBase):
class MLPClassifier(PyRIIDModel):
"""Multi-layer perceptron classifier."""
def __init__(self, hidden_layers: tuple = (512,), activation: str = "relu",
loss: str = "categorical_crossentropy",
Expand Down Expand Up @@ -245,7 +245,8 @@ def predict(self, ss: SampleSet, bg_ss: SampleSet = None, verbose=False):
X = [x_test, bg_ss.get_samples().astype(float)]
else:
X = x_test
results = self.model.predict(X, verbose=verbose)

results = self.get_predictions(X, verbose=verbose)

col_level_idx = SampleSet.SOURCES_MULTI_INDEX_NAMES.index(self.target_level)
col_level_subset = SampleSet.SOURCES_MULTI_INDEX_NAMES[:col_level_idx+1]
Expand All @@ -259,7 +260,7 @@ def predict(self, ss: SampleSet, bg_ss: SampleSet = None, verbose=False):
ss.classified_by = self.info["model_id"]


class MultiEventClassifier(TFModelBase):
class MultiEventClassifier(PyRIIDModel):
"""A classifier for spectra from multiple detectors observing the same event."""

def __init__(self, hidden_layers: tuple = (512,), activation: str = "relu",
Expand Down Expand Up @@ -423,7 +424,7 @@ def fit(self, list_of_ss: List[SampleSet], target_contributions: pd.DataFrame,

return history

def predict(self, list_of_ss: List[SampleSet]) -> pd.DataFrame:
def predict(self, list_of_ss: List[SampleSet], verbose=False) -> pd.DataFrame:
"""Classify the spectra in the provided `SampleSet`(s) based on each one's results.
Args:
Expand All @@ -433,7 +434,8 @@ def predict(self, list_of_ss: List[SampleSet]) -> pd.DataFrame:
`DataFrame` of predicted results for the `Sampleset`(s)
"""
X = [ss.prediction_probas for ss in list_of_ss]
results = self.model.predict(X) # output size will be n_samples by n_labels
# output size will be n_samples by n_labels
results = self.get_predictions(X, verbose=verbose)

col_level_idx = SampleSet.SOURCES_MULTI_INDEX_NAMES.index(self.target_level)
col_level_subset = SampleSet.SOURCES_MULTI_INDEX_NAMES[:col_level_idx+1]
Expand All @@ -446,7 +448,7 @@ def predict(self, list_of_ss: List[SampleSet]) -> pd.DataFrame:
return results_df


class LabelProportionEstimator(TFModelBase):
class LabelProportionEstimator(PyRIIDModel):
UNSUPERVISED_LOSS_FUNCS = {
"poisson_nll": poisson_nll_diff,
"normal_nll": normal_nll_diff,
Expand Down Expand Up @@ -847,7 +849,8 @@ def fit(self, seeds_ss: SampleSet, ss: SampleSet, bg_cps: int = 300, is_gross: b

return history

def predict(self, ss: SampleSet, bg_cps: int = 300, is_gross: bool = False):
def predict(self, ss: SampleSet, bg_cps: int = 300, is_gross: bool = False,
verbose=False):
"""Estimate the proportions of counts present in each sample of the provided SampleSet.
Results are stored inside the SampleSet's prediction_probas property.
Expand All @@ -861,16 +864,9 @@ def predict(self, ss: SampleSet, bg_cps: int = 300, is_gross: bool = False):
"""
test_spectra = ss.get_samples().astype(float)

if self.model is None:
outputs = self.onnx_session.run(
[self.onnx_session.get_outputs()[0].name],
{self.onnx_session.get_inputs()[0].name: test_spectra.astype(np.float32)}
)[0]
lpes = self.activation(tf.convert_to_tensor(outputs, dtype=tf.float32))
logits = self.get_predictions(test_spectra, verbose=verbose)

else:
logits = self.model.predict(test_spectra)
lpes = self.activation(tf.convert_to_tensor(logits, dtype=tf.float32))
lpes = self.activation(tf.convert_to_tensor(logits, dtype=tf.float32))

col_level_idx = SampleSet.SOURCES_MULTI_INDEX_NAMES.index(self._info["target_level"])
col_level_subset = SampleSet.SOURCES_MULTI_INDEX_NAMES[:col_level_idx+1]
Expand Down
25 changes: 17 additions & 8 deletions riid/models/neural_nets/arad.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from riid.data.sampleset import SampleSet
from riid.losses import mish, jensen_shannon_distance
from riid.models import TFModelBase
from riid.models import PyRIIDModel


@tf.keras.saving.register_keras_serializable(package="riid")
Expand Down Expand Up @@ -235,7 +235,7 @@ def call(self, x):
return decoded


class ARAD(TFModelBase):
class ARAD(PyRIIDModel):
"""PyRIID-compatible wrapper around ARAD models.
"""
def __init__(self, model: Model = ARADv2TF()):
Expand All @@ -247,11 +247,20 @@ def __init__(self, model: Model = ARADv2TF()):

self.model = model

# TODO: enable saving as ONNX
def fit(self, ss: SampleSet, epochs: int = 300, validation_split=0.2,
es_verbose: int = 0, verbose: bool = False):
"""Fit a model to the given `SampleSet`.
def fit(self, ss: SampleSet, epochs: int = 300, es_verbose: int = 0,
verbose: bool = False):
"""Fit a model to the given `SampleSet`."""
Args:
ss: `SampleSet` of `n` spectra where `n` >= 1
epochs: maximum number of training epochs
validation_split: percentage of the training data to use as validation data
es_verbose: verbosity level for `tf.keras.callbacks.EarlyStopping`
verbose: whether to show detailed model training output
Returns:
reconstructed_spectra: output of ARAD model
"""
if ss.n_samples <= 0:
raise ValueError("No spectr[a|um] provided!")

Expand Down Expand Up @@ -312,7 +321,7 @@ def fit(self, ss: SampleSet, epochs: int = 300, es_verbose: int = 0,
spectra,
epochs=epochs,
verbose=verbose,
validation_split=0.2,
validation_split=validation_split,
callbacks=callbacks,
shuffle=True,
batch_size=batch_size
Expand All @@ -337,7 +346,7 @@ def predict(self, ss: SampleSet, ood_threshold: float = 0.5,
norm_ss.normalize()
spectra = norm_ss.get_samples().astype(float)

reconstructed_spectra = self.model.predict(spectra, verbose=verbose)
reconstructed_spectra = self.get_predictions(spectra, verbose=verbose)

if isinstance(self.model, ARADv1TF):
reconstruction_metric = entropy
Expand Down

0 comments on commit 4b49297

Please sign in to comment.