From bedb9dfbeb0c6152e115e1e678ffb15c34fcfb3b Mon Sep 17 00:00:00 2001
From: Lenz
Date: Fri, 28 Apr 2023 12:28:15 +0200
Subject: [PATCH] Added BEN scheme sampling for training

---
 configs/params.toml                       |  2 +
 conftest.py                               | 12 ++++-
 src/cell_classification/model_builder.py  | 56 ++++++++++++++++++++----
 tests/model_builder_test.py               | 44 ++++++++++++++++++-
 tests/segmentation_data_prep_test.py      |  4 +-
 5 files changed, 105 insertions(+), 13 deletions(-)

diff --git a/configs/params.toml b/configs/params.toml
index 2877d92..b4a4411 100644
--- a/configs/params.toml
+++ b/configs/params.toml
@@ -42,6 +42,8 @@ quantile_warmup_steps = 100000
 confidence_thresholds = [0.1, 0.9]
 ema = 0.01
 location = false
+BEN = false
+
 [loss_kwargs]
 from_logits = false
 label_smoothing = 0.1
diff --git a/conftest.py b/conftest.py
index 76c000a..dd39927 100644
--- a/conftest.py
+++ b/conftest.py
@@ -1,10 +1,20 @@
 from typing import Any, Dict, Iterator
-
+import numpy as np
 import pytest
 import toml
+import os
+import sys
+os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+sys.path.append(os.path.join(os.path.dirname(__file__), "src"))
+sys.path.append(os.path.join(os.path.dirname(__file__), "tests"))
 
 
 @pytest.fixture(scope="function")
 def config_params() -> Iterator[Dict]:
     params: Dict[str, Any] = toml.load("./configs/params.toml")
     yield params
+
+@pytest.fixture(scope="function")
+def rng() -> Iterator[np.random.Generator]:
+    rng_ = np.random.default_rng(seed=42)
+    yield rng_
\ No newline at end of file
diff --git a/src/cell_classification/model_builder.py b/src/cell_classification/model_builder.py
index d7a69a0..d8ba786 100644
--- a/src/cell_classification/model_builder.py
+++ b/src/cell_classification/model_builder.py
@@ -134,18 +134,56 @@ def prep_data(self):
         )
 
         # shuffle, batch and augment the datasets
-        self.train_dataset = self.train_dataset.shuffle(self.params["shuffle_buffer_size"]).batch(
-            self.params["batch_size"] * np.max([self.num_gpus, 1])
-        )
-        self.validation_datasets = [validation_dataset.batch(
-            self.params["batch_size"] * np.max([self.num_gpus, 1])
-        ) for validation_dataset in self.validation_datasets]
-        self.test_datasets = [test_dataset.batch(
-            self.params["batch_size"] * np.max([self.num_gpus, 1])
-        ) for test_dataset in self.test_datasets]
+        if self.params["BEN"]:
+            self.train_dataset = self.make_pure_batches(self.train_dataset).shuffle(
+                self.params["shuffle_buffer_size"]
+            ).shuffle(self.params["shuffle_buffer_size"])
+            self.validation_datasets = [
+                self.make_pure_batches(validation_dataset) for validation_dataset in
+                self.validation_datasets
+            ]
+            self.test_datasets = [
+                self.make_pure_batches(test_dataset) for test_dataset in self.test_datasets
+            ]
+        else:
+            self.train_dataset = self.train_dataset.shuffle(self.params["shuffle_buffer_size"]).batch(
+                self.params["batch_size"] * np.max([self.num_gpus, 1])
+            )
+            self.validation_datasets = [validation_dataset.batch(
+                self.params["batch_size"] * np.max([self.num_gpus, 1])
+            ) for validation_dataset in self.validation_datasets]
+            self.test_datasets = [test_dataset.batch(
+                self.params["batch_size"] * np.max([self.num_gpus, 1])
+            ) for test_dataset in self.test_datasets]
 
         self.dataset_names = self.params["dataset_names"]
 
+    def make_pure_batches(self, dataset):
+        """Makes batches from a dataset such that each batch contains only one fov and marker
+        Args:
+            dataset: tf.data.Dataset
+        Returns:
+            dataset: tf.data.Dataset
+        """
+        def key_func(example):
+            """Returns a hash bucket for each example based on dataset, fov, and marker"""
+            i = 1e10
+            hash_ = tf.strings.to_hash_bucket(
+                example["dataset"] + example["folder_name"] + example["marker"], i
+            )
+            return hash_
+
+        def reduce_func(key, dataset):
+            """Reduces the windows to batches"""
+            return dataset.batch(self.params["batch_size"])
+
+        dataset = dataset.group_by_window(
+            key_func=key_func,
+            reduce_func=reduce_func,
+            window_size=self.params["batch_size"],
+        )
+        return dataset
+
     def prep_model(self):
         """Prepares the model for training"""
         # prepare folders
diff --git a/tests/model_builder_test.py b/tests/model_builder_test.py
index 94dede4..2d84db2 100644
--- a/tests/model_builder_test.py
+++ b/tests/model_builder_test.py
@@ -12,7 +12,7 @@
 from cell_classification.segmentation_data_prep import (feature_description,
                                                         parse_dict)
 
-from .segmentation_data_prep_test import prep_object_and_inputs
+from segmentation_data_prep_test import prep_object_and_inputs
 
 tf.config.run_functions_eagerly(True)
 
@@ -57,6 +57,18 @@ def test_prep_data(config_params):
     config_params["num_validation"] = [2, 2]
     config_params["num_test"] = [2, 2]
     config_params["batch_size"] = 2
+
+    # check if BEN works, so batches only contain samples from one dataset
+    config_params["BEN"] = True
+    trainer = ModelBuilder(config_params)
+    trainer.prep_data()
+    for batch in trainer.train_dataset:
+        # test that batches are pure with respect to marker, fov and datasets
+        assert len(set(list(batch["marker"].numpy()))) == 1
+        assert len(set(list(batch["folder_name"].numpy()))) == 1
+        assert len(set(list(batch["dataset"].numpy()))) == 1
+
+    config_params["BEN"] = False
     trainer = ModelBuilder(config_params)
     trainer.prep_data()
 
@@ -504,3 +516,33 @@ def test_fov_filter(config_params):
     for example in dataset_filtered:
         fov_list.append(example["folder_name"].numpy().decode())
     assert set(fov) == set(fov_list)
+
+
+def test_make_pure_batches(config_params):
+    with tempfile.TemporaryDirectory() as temp_dir:
+        data_prep, _, _, _ = prep_object_and_inputs(temp_dir, selected_markers=["CD4", "CD56"])
+        data_prep.tf_record_path = temp_dir
+        data_prep.tile_size = [32, 32]
+        data_prep.stride = [32, 32]
+        data_prep.make_tf_record()
+        tf_record_path = os.path.join(data_prep.tf_record_path, data_prep.dataset + ".tfrecord")
+        config_params["record_path"] = [tf_record_path]
+        config_params["path"] = temp_dir
+        config_params["experiment"] = "test"
+        config_params["dataset_names"] = ["test1"]
+        config_params["num_steps"] = 20
+        config_params["dataset_sample_probs"] = [1.0]
+        config_params["batch_size"] = 4
+        trainer = ModelBuilder(config_params)
+        dataset = tf.data.TFRecordDataset(tf_record_path)
+        dataset = dataset.map(lambda x: tf.io.parse_single_example(x, feature_description))
+        dataset = dataset.map(parse_dict)
+        dataset_pure = trainer.make_pure_batches(dataset)
+        dataset_pure = dataset_pure.shuffle(1000)
+        for example in dataset_pure:
+            # test batch_size
+            assert len(example["marker"]) == 4
+            # test that batches are pure with respect to marker, fov and datasets
+            assert len(set(list(example["marker"].numpy()))) == 1
+            assert len(set(list(example["folder_name"].numpy()))) == 1
+            assert len(set(list(example["dataset"].numpy()))) == 1
diff --git a/tests/segmentation_data_prep_test.py b/tests/segmentation_data_prep_test.py
index 6c44ada..a0c8687 100644
--- a/tests/segmentation_data_prep_test.py
+++ b/tests/segmentation_data_prep_test.py
@@ -32,8 +32,8 @@ def prep_object(
 
 
 def prep_object_and_inputs(
-    temp_dir, imaging_platform="imaging_platform", dataset="dataset", selected_markers=["CD4"],
-    num_folders=5, scale=[0.5, 1.0, 1.5, 2.0, 5.0],
+    temp_dir, imaging_platform="imaging_platform", dataset="dataset", num_folders=5,
+    selected_markers=["CD4"], scale=[0.5, 1.0, 1.5, 2.0, 5.0],
 ):
     # create temporary folders with data for the tests
     conversion_matrix = prepare_conversion_matrix()
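
Note (not part of the patch): a minimal, self-contained sketch of the batching scheme that make_pure_batches implements, for readers unfamiliar with tf.data.Dataset.group_by_window. Each example is hashed by its concatenated (dataset, folder_name, marker) key, collected into a per-key window of batch_size elements, and each window is batched on its own, so every emitted batch is pure. The toy data, the use of tf.strings.to_hash_bucket_fast, and the bucket count below are illustrative assumptions, not code from this commit.

import tensorflow as tf

# Toy dataset with the same string fields the patch groups on; values are invented.
examples = {
    "dataset": ["d1"] * 3 + ["d2"] * 3,
    "folder_name": ["fov1"] * 3 + ["fov2"] * 3,
    "marker": ["CD4"] * 3 + ["CD56"] * 3,
}
ds = tf.data.Dataset.from_tensor_slices(examples)

batch_size = 3
num_buckets = 2**31 - 1  # large integer bucket count keeps key collisions unlikely


def key_func(example):
    # Hash the concatenated group key so each (dataset, fov, marker)
    # combination lands in its own window.
    return tf.strings.to_hash_bucket_fast(
        example["dataset"] + example["folder_name"] + example["marker"], num_buckets
    )


def reduce_func(key, window):
    # Batch each per-key window separately.
    return window.batch(batch_size)


pure = ds.group_by_window(
    key_func=key_func, reduce_func=reduce_func, window_size=batch_size
)

for batch in pure:
    # Every batch holds exactly one (dataset, fov, marker) combination.
    assert len(set(batch["marker"].numpy())) == 1
    print(batch["marker"].numpy())

Shuffling is applied after grouping, as prep_data does above, because group_by_window emits whole batches: shuffling the batched dataset reorders batches without mixing examples across them.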