From 69230b9d15d38f89bcb5f5464efdf887a2272688 Mon Sep 17 00:00:00 2001 From: Vinam Arora Date: Thu, 28 Nov 2024 02:28:35 -0500 Subject: [PATCH] Fix metric_factory's type --- torch_brain/utils/stitcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_brain/utils/stitcher.py b/torch_brain/utils/stitcher.py index 2b7fa53..e52eba8 100644 --- a/torch_brain/utils/stitcher.py +++ b/torch_brain/utils/stitcher.py @@ -115,7 +115,7 @@ def __init__( self, session_ids: Iterable[str], modality_spec: Optional[ModalitySpec] = None, - metric_factory: Optional[Callable[[int], ModalitySpec]] = None, + metric_factory: Optional[Callable[..., torchmetrics.Metric]] = None, quiet=False, ): r"""