From 008afb11a7d5be1091514b6a097d4cda49629851 Mon Sep 17 00:00:00 2001 From: Vinam Arora Date: Wed, 27 Nov 2024 10:30:37 -0500 Subject: [PATCH] Rename SitchEvaluators --- examples/poyo/train.py | 4 ++-- examples/poyo_plus/train.py | 4 ++-- torch_brain/utils/stitcher.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/poyo/train.py b/examples/poyo/train.py index 663b729..981f3d4 100644 --- a/examples/poyo/train.py +++ b/examples/poyo/train.py @@ -20,7 +20,7 @@ from torch_brain.models.poyo import POYOTokenizer, poyo_mp from torch_brain.utils import callbacks as tbrain_callbacks from torch_brain.utils import seed_everything -from torch_brain.utils.stitcher import MultiSessionDecodingStitchEvaluator +from torch_brain.utils.stitcher import DecodingStitchEvaluator from torch_brain.data import Dataset, collate from torch_brain.data.sampler import ( DistributedStitchingFixedWindowSampler, @@ -83,7 +83,7 @@ def main(cfg: DictConfig): modality_spec=modality_spec, ) - stitch_evaluator = MultiSessionDecodingStitchEvaluator( + stitch_evaluator = DecodingStitchEvaluator( session_ids=data_module.get_session_ids(), modality_spec=modality_spec, ) diff --git a/examples/poyo_plus/train.py b/examples/poyo_plus/train.py index 5cca621..8beef99 100644 --- a/examples/poyo_plus/train.py +++ b/examples/poyo_plus/train.py @@ -17,7 +17,7 @@ from torch_brain.utils import callbacks as tbrain_callbacks from torch_brain.utils import seed_everything from torch_brain.utils.datamodules import DataModule -from torch_brain.utils.stitcher import MultiSessionMultiTaskStitchEvaluator +from torch_brain.utils.stitcher import MultiTaskDecodingStitchEvaluator # higher speed on machines with tensor cores torch.set_float32_matmul_precision("medium") @@ -178,7 +178,7 @@ def main(cfg: DictConfig): steps_per_epoch=len(data_module.train_dataloader()), ) - evaluator = MultiSessionMultiTaskStitchEvaluator( + evaluator = MultiTaskDecodingStitchEvaluator( dataset_config_dict=data_module.get_recording_config_dict() ) diff --git a/torch_brain/utils/stitcher.py b/torch_brain/utils/stitcher.py index e11e193..2b7fa53 100644 --- a/torch_brain/utils/stitcher.py +++ b/torch_brain/utils/stitcher.py @@ -74,7 +74,7 @@ def stitch(timestamps: torch.Tensor, values: torch.Tensor) -> torch.Tensor: return averages -class MultiSessionDecodingStitchEvaluator(L.Callback): +class DecodingStitchEvaluator(L.Callback): r"""A convenient stitching and evaluation framework to use when: 1. Your model outputs have associated timestamps 2. And your sampling strategy involves overlapping time windows, requiring @@ -226,7 +226,7 @@ def on_test_epoch_end(self, *args, **kwargs): self.on_validation_epoch_end(*args, **kwargs, prefix="test") -class MultiSessionMultiTaskStitchEvaluator(L.Callback): +class MultiTaskDecodingStitchEvaluator(L.Callback): def __init__(self, dataset_config_dict: dict): metrics = defaultdict(lambda: defaultdict(dict)) # setup the metrics