diff --git a/torch_brain/utils/stitcher.py b/torch_brain/utils/stitcher.py index 2b7fa53..e52eba8 100644 --- a/torch_brain/utils/stitcher.py +++ b/torch_brain/utils/stitcher.py @@ -115,7 +115,7 @@ def __init__( self, session_ids: Iterable[str], modality_spec: Optional[ModalitySpec] = None, - metric_factory: Optional[Callable[[int], ModalitySpec]] = None, + metric_factory: Optional[Callable[..., torchmetrics.Metric]] = None, quiet=False, ): r"""