From 8f22428d4c25de57d829d5901337fc07f9e7a523 Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Thu, 24 Oct 2024 00:59:01 -0700 Subject: [PATCH] imports --- classifiers/src/dolma_classifiers/inference.py | 3 +-- classifiers/src/dolma_classifiers/utils.py | 14 +++++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/classifiers/src/dolma_classifiers/inference.py b/classifiers/src/dolma_classifiers/inference.py index 6bfa7173..6a36d326 100644 --- a/classifiers/src/dolma_classifiers/inference.py +++ b/classifiers/src/dolma_classifiers/inference.py @@ -182,8 +182,7 @@ def writer_worker( break # I've finished processing this source; close the file - stack.pop(path.source).close() - + stack.pop(path.source).close() # pyright: ignore console_logger.info(f"Closed {source_destination_mapping[path.source]}") progress_logger.increment(files=1) diff --git a/classifiers/src/dolma_classifiers/utils.py b/classifiers/src/dolma_classifiers/utils.py index 2f04967e..a4db13e2 100644 --- a/classifiers/src/dolma_classifiers/utils.py +++ b/classifiers/src/dolma_classifiers/utils.py @@ -1,8 +1,10 @@ import os import re from contextlib import ExitStack -from typing import Any, ContextManager, Dict +from hashlib import md5 +from typing import Any, ContextManager, Dict, Generic, TypeVar +import msgspec import torch import torch.distributed as dist from smart_open.compression import ( @@ -50,12 +52,14 @@ def sanitize_model_name(model_name: str, suffix_data: Any = None) -> str: return stripped_trailing_underscores -class KeyedExitStack: +T = TypeVar("T") + +class KeyedExitStack(Generic[T]): """From https://claude.site/artifacts/7150ff45-3cb1-41e5-be5c-0f0890aa332e""" def __init__(self): self.stack = ExitStack() - self.resources: Dict[str, ContextManager] = {} + self.resources: Dict[str, T] = {} def __enter__(self): self.stack.__enter__() @@ -64,7 +68,7 @@ def __enter__(self): def __exit__(self, *exc_details): return self.stack.__exit__(*exc_details) - def push(self, key: str, cm: ContextManager) -> Any: + def push(self, key: str, cm: ContextManager[T]) -> T: """Push a context manager onto the stack with an associated key.""" resource = self.stack.enter_context(cm) self.resources[key] = resource @@ -74,7 +78,7 @@ def __contains__(self, key: str) -> bool: """Check if a resource with the given key is in the stack.""" return key in self.resources - def __getitem__(self, key: str) -> Any: + def __getitem__(self, key: str) -> T: """Get a resource by key.""" return self.resources[key]