Skip to content

Commit

Permalink
Rename SitchEvaluators
Browse files Browse the repository at this point in the history
  • Loading branch information
vinamarora8 committed Nov 27, 2024
1 parent ab185b6 commit 008afb1
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions examples/poyo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torch_brain.models.poyo import POYOTokenizer, poyo_mp
from torch_brain.utils import callbacks as tbrain_callbacks
from torch_brain.utils import seed_everything
from torch_brain.utils.stitcher import MultiSessionDecodingStitchEvaluator
from torch_brain.utils.stitcher import DecodingStitchEvaluator
from torch_brain.data import Dataset, collate
from torch_brain.data.sampler import (
DistributedStitchingFixedWindowSampler,
Expand Down Expand Up @@ -83,7 +83,7 @@ def main(cfg: DictConfig):
modality_spec=modality_spec,
)

stitch_evaluator = MultiSessionDecodingStitchEvaluator(
stitch_evaluator = DecodingStitchEvaluator(
session_ids=data_module.get_session_ids(),
modality_spec=modality_spec,
)
Expand Down
4 changes: 2 additions & 2 deletions examples/poyo_plus/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from torch_brain.utils import callbacks as tbrain_callbacks
from torch_brain.utils import seed_everything
from torch_brain.utils.datamodules import DataModule
from torch_brain.utils.stitcher import MultiSessionMultiTaskStitchEvaluator
from torch_brain.utils.stitcher import MultiTaskDecodingStitchEvaluator

# higher speed on machines with tensor cores
torch.set_float32_matmul_precision("medium")
Expand Down Expand Up @@ -178,7 +178,7 @@ def main(cfg: DictConfig):
steps_per_epoch=len(data_module.train_dataloader()),
)

evaluator = MultiSessionMultiTaskStitchEvaluator(
evaluator = MultiTaskDecodingStitchEvaluator(
dataset_config_dict=data_module.get_recording_config_dict()
)

Expand Down
4 changes: 2 additions & 2 deletions torch_brain/utils/stitcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def stitch(timestamps: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
return averages


class MultiSessionDecodingStitchEvaluator(L.Callback):
class DecodingStitchEvaluator(L.Callback):
r"""A convenient stitching and evaluation framework to use when:
1. Your model outputs have associated timestamps
2. And your sampling strategy involves overlapping time windows, requiring
Expand Down Expand Up @@ -226,7 +226,7 @@ def on_test_epoch_end(self, *args, **kwargs):
self.on_validation_epoch_end(*args, **kwargs, prefix="test")


class MultiSessionMultiTaskStitchEvaluator(L.Callback):
class MultiTaskDecodingStitchEvaluator(L.Callback):
def __init__(self, dataset_config_dict: dict):
metrics = defaultdict(lambda: defaultdict(dict))
# setup the metrics
Expand Down

0 comments on commit 008afb1

Please sign in to comment.