Skip to content

Commit

Permalink
imports
Browse files Browse the repository at this point in the history
  • Loading branch information
soldni committed Oct 24, 2024
1 parent e14a36b commit 8f22428
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
3 changes: 1 addition & 2 deletions classifiers/src/dolma_classifiers/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,7 @@ def writer_worker(
break

# I've finished processing this source; close the file
stack.pop(path.source).close()

stack.pop(path.source).close() # pyright: ignore
console_logger.info(f"Closed {source_destination_mapping[path.source]}")
progress_logger.increment(files=1)

Expand Down
14 changes: 9 additions & 5 deletions classifiers/src/dolma_classifiers/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
import re
from contextlib import ExitStack
from typing import Any, ContextManager, Dict
from hashlib import md5
from typing import Any, ContextManager, Dict, Generic, TypeVar

import msgspec
import torch
import torch.distributed as dist
from smart_open.compression import (
Expand Down Expand Up @@ -50,12 +52,14 @@ def sanitize_model_name(model_name: str, suffix_data: Any = None) -> str:
return stripped_trailing_underscores


class KeyedExitStack:
T = TypeVar("T")

class KeyedExitStack(Generic[T]):
"""From https://claude.site/artifacts/7150ff45-3cb1-41e5-be5c-0f0890aa332e"""

def __init__(self):
self.stack = ExitStack()
self.resources: Dict[str, ContextManager] = {}
self.resources: Dict[str, T] = {}

def __enter__(self):
self.stack.__enter__()
Expand All @@ -64,7 +68,7 @@ def __enter__(self):
def __exit__(self, *exc_details):
return self.stack.__exit__(*exc_details)

def push(self, key: str, cm: ContextManager) -> Any:
def push(self, key: str, cm: ContextManager[T]) -> T:
"""Push a context manager onto the stack with an associated key."""
resource = self.stack.enter_context(cm)
self.resources[key] = resource
Expand All @@ -74,7 +78,7 @@ def __contains__(self, key: str) -> bool:
"""Check if a resource with the given key is in the stack."""
return key in self.resources

def __getitem__(self, key: str) -> Any:
def __getitem__(self, key: str) -> T:
"""Get a resource by key."""
return self.resources[key]

Expand Down

0 comments on commit 8f22428

Please sign in to comment.