Added BEN scheme sampling for training #62

Open · wants to merge 1 commit into base: main
2 changes: 2 additions & 0 deletions configs/params.toml
@@ -42,6 +42,8 @@ quantile_warmup_steps = 100000
confidence_thresholds = [0.1, 0.9]
ema = 0.01
location = false
BEN = false

[loss_kwargs]
from_logits = false
label_smoothing = 0.1
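The new `BEN` flag sits next to the existing training parameters and defaults to off. For reference, the configuration is read with the `toml` package elsewhere in the repo (see `conftest.py`); a minimal sketch of flipping the switch, assuming the repo-root path used by the tests:

```python
import toml

# Load the training configuration; the new "BEN" key toggles pure-batch sampling.
params = toml.load("./configs/params.toml")
params["BEN"] = True  # batches will then be grouped by (dataset, fov, marker)
```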
12 changes: 11 additions & 1 deletion conftest.py
@@ -1,10 +1,20 @@
from typing import Any, Dict, Iterator

import numpy as np
import pytest
import toml
import os
import sys

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
sys.path.append(os.path.join(os.path.dirname(__file__), "src"))
sys.path.append(os.path.join(os.path.dirname(__file__), "tests"))

@pytest.fixture(scope="function")
def config_params() -> Iterator[Dict]:
    params: Dict[str, Any] = toml.load("./configs/params.toml")
    yield params

@pytest.fixture(scope="function")
def rng() -> Iterator[np.random.Generator]:
    rng_ = np.random.default_rng(seed=42)
    yield rng_
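Setting `CUDA_VISIBLE_DEVICES` to `-1` hides all GPUs so the test session runs on CPU, and the two `sys.path` entries make the `src/` and `tests/` directories importable as plain modules, which is what lets the test below drop its package-relative import. A minimal sketch of the effect, assuming pytest is launched from the repository root:

```python
import os
import sys

# With tests/ appended to sys.path (as conftest.py now does), helper modules
# in that folder can be imported directly instead of via a relative import.
sys.path.append(os.path.join(os.path.dirname(__file__), "tests"))
from segmentation_data_prep_test import prep_object_and_inputs  # noqa: E402
```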
56 changes: 47 additions & 9 deletions src/cell_classification/model_builder.py
@@ -134,18 +134,56 @@ def prep_data(self):
        )

        # shuffle, batch and augment the datasets
        self.train_dataset = self.train_dataset.shuffle(self.params["shuffle_buffer_size"]).batch(
            self.params["batch_size"] * np.max([self.num_gpus, 1])
        )
        self.validation_datasets = [validation_dataset.batch(
            self.params["batch_size"] * np.max([self.num_gpus, 1])
        ) for validation_dataset in self.validation_datasets]
        self.test_datasets = [test_dataset.batch(
            self.params["batch_size"] * np.max([self.num_gpus, 1])
        ) for test_dataset in self.test_datasets]
        if self.params["BEN"]:
            self.train_dataset = self.make_pure_batches(self.train_dataset).shuffle(
                self.params["shuffle_buffer_size"]
            )
            self.validation_datasets = [
                self.make_pure_batches(validation_dataset)
                for validation_dataset in self.validation_datasets
            ]
            self.test_datasets = [
                self.make_pure_batches(test_dataset) for test_dataset in self.test_datasets
            ]
        else:
            self.train_dataset = self.train_dataset.shuffle(self.params["shuffle_buffer_size"]).batch(
                self.params["batch_size"] * np.max([self.num_gpus, 1])
            )
            self.validation_datasets = [validation_dataset.batch(
                self.params["batch_size"] * np.max([self.num_gpus, 1])
            ) for validation_dataset in self.validation_datasets]
            self.test_datasets = [test_dataset.batch(
                self.params["batch_size"] * np.max([self.num_gpus, 1])
            ) for test_dataset in self.test_datasets]

        self.dataset_names = self.params["dataset_names"]

    def make_pure_batches(self, dataset):
        """Makes batches from a dataset such that each batch contains only a single
        dataset, fov and marker combination
        Args:
            dataset: tf.data.Dataset
        Returns:
            dataset: tf.data.Dataset
        """
        def key_func(example):
            """Returns a hash bucket for each example based on dataset, fov, and marker"""
            i = int(1e10)  # number of hash buckets; large enough to make collisions unlikely
            hash_ = tf.strings.to_hash_bucket(
                example["dataset"] + example["folder_name"] + example["marker"], i
            )
            return hash_

        def reduce_func(key, dataset):
            """Reduces the windows to batches"""
            return dataset.batch(self.params["batch_size"])

        dataset = dataset.group_by_window(
            key_func=key_func,
            reduce_func=reduce_func,
            window_size=self.params["batch_size"],
        )
        return dataset

    def prep_model(self):
        """Prepares the model for training"""
        # prepare folders
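As a standalone illustration of the grouping that `make_pure_batches` relies on (a hedged sketch with toy data, not code from this PR): `tf.data`'s `group_by_window` buckets elements by the key returned from `key_func` and hands each window of up to `window_size` elements to `reduce_func`, so batching inside `reduce_func` yields batches that are pure with respect to the key.

```python
import tensorflow as tf

# Toy dataset with eight examples from two markers; real examples would also
# carry "dataset" and "folder_name" fields, as in make_pure_batches.
examples = tf.data.Dataset.from_tensor_slices(
    {"marker": [b"CD4", b"CD4", b"CD4", b"CD4", b"CD56", b"CD56", b"CD56", b"CD56"]}
)
batch_size = 4


def key_func(example):
    # Same hashing idea as make_pure_batches, restricted to the marker field.
    return tf.strings.to_hash_bucket(example["marker"], 1000)


def reduce_func(key, window):
    return window.batch(batch_size)


pure = examples.group_by_window(
    key_func=key_func, reduce_func=reduce_func, window_size=batch_size
)

for batch in pure:
    assert len(set(batch["marker"].numpy())) == 1  # each batch holds a single marker
```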
44 changes: 43 additions & 1 deletion tests/model_builder_test.py
@@ -12,7 +12,7 @@
from cell_classification.segmentation_data_prep import (feature_description,
                                                        parse_dict)

from .segmentation_data_prep_test import prep_object_and_inputs
from segmentation_data_prep_test import prep_object_and_inputs

tf.config.run_functions_eagerly(True)

@@ -57,6 +57,18 @@ def test_prep_data(config_params):
config_params["num_validation"] = [2, 2]
config_params["num_test"] = [2, 2]
config_params["batch_size"] = 2

# check if BEN works, so batches only contain samples from one dataset
config_params["BEN"] = True
trainer = ModelBuilder(config_params)
trainer.prep_data()
for batch in trainer.train_dataset:
# test that batches are pure with respect to marker, fov and datasets
assert len(set(list(batch["marker"].numpy()))) == 1
assert len(set(list(batch["folder_name"].numpy()))) == 1
assert len(set(list(batch["dataset"].numpy()))) == 1

config_params["BEN"] = False
trainer = ModelBuilder(config_params)
trainer.prep_data()

@@ -504,3 +516,33 @@ def test_fov_filter(config_params):
        for example in dataset_filtered:
            fov_list.append(example["folder_name"].numpy().decode())
        assert set(fov) == set(fov_list)


def test_make_pure_batches(config_params):
    with tempfile.TemporaryDirectory() as temp_dir:
        data_prep, _, _, _ = prep_object_and_inputs(temp_dir, selected_markers=["CD4", "CD56"])
        data_prep.tf_record_path = temp_dir
        data_prep.tile_size = [32, 32]
        data_prep.stride = [32, 32]
        data_prep.make_tf_record()
        tf_record_path = os.path.join(data_prep.tf_record_path, data_prep.dataset + ".tfrecord")
        config_params["record_path"] = [tf_record_path]
        config_params["path"] = temp_dir
        config_params["experiment"] = "test"
        config_params["dataset_names"] = ["test1"]
        config_params["num_steps"] = 20
        config_params["dataset_sample_probs"] = [1.0]
        config_params["batch_size"] = 4
        trainer = ModelBuilder(config_params)
        dataset = tf.data.TFRecordDataset(tf_record_path)
        dataset = dataset.map(lambda x: tf.io.parse_single_example(x, feature_description))
        dataset = dataset.map(parse_dict)
        dataset_pure = trainer.make_pure_batches(dataset)
        dataset_pure = dataset_pure.shuffle(1000)
        for example in dataset_pure:
            # test batch_size
            assert len(example["marker"]) == 4
            # test that batches are pure with respect to marker, fov and datasets
            assert len(set(list(example["marker"].numpy()))) == 1
            assert len(set(list(example["folder_name"].numpy()))) == 1
            assert len(set(list(example["dataset"].numpy()))) == 1
4 changes: 2 additions & 2 deletions tests/segmentation_data_prep_test.py
@@ -32,8 +32,8 @@ def prep_object(


def prep_object_and_inputs(
    temp_dir, imaging_platform="imaging_platform", dataset="dataset", selected_markers=["CD4"],
    num_folders=5, scale=[0.5, 1.0, 1.5, 2.0, 5.0],
    temp_dir, imaging_platform="imaging_platform", dataset="dataset", num_folders=5,
    selected_markers=["CD4"], scale=[0.5, 1.0, 1.5, 2.0, 5.0],
):
    # create temporary folders with data for the tests
    conversion_matrix = prepare_conversion_matrix()