Skip to content

Commit

Permalink
style
Browse files Browse the repository at this point in the history
  • Loading branch information
soldni committed Oct 24, 2024
1 parent 4e3fbf2 commit a193813
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 23 deletions.
33 changes: 17 additions & 16 deletions classifiers/src/dolma_classifiers/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,16 @@
"""

import argparse
from functools import partial
from hashlib import md5
from collections import abc
import logging
import multiprocessing as mp
import os
import logging
import re
import time
from itertools import chain
from collections import abc
from functools import partial
from hashlib import md5
from itertools import chain, zip_longest
from math import ceil
from itertools import zip_longest
from queue import Queue as QueueType
from typing import (
TYPE_CHECKING,
Expand All @@ -38,23 +37,25 @@
)
from urllib.parse import urlparse

import msgspec
import fsspec
import jq
import msgspec
import smart_open
from smart_open.compression import _handle_zstd
import tqdm

import torch # pyright: ignore
from torch.utils.data import IterableDataset, DataLoader, get_worker_info # pyright: ignore
import tqdm
import wandb
from smart_open.compression import _handle_zstd
from torch.nn.utils.rnn import pad_sequence

from torch.utils.data import ( # pyright: ignore
DataLoader,
IterableDataset,
get_worker_info,
)
from transformers import BatchEncoding, PreTrainedTokenizer
import wandb
import jq

from .loggers import ProgressLogger, WandbLogger, get_logger
from .models import Registry
from .utils import setup, cleanup, get_local_gpu_rank, sanitize_model_name
from .loggers import get_logger, WandbLogger, ProgressLogger
from .utils import cleanup, get_local_gpu_rank, sanitize_model_name, setup


class Document(NamedTuple):
Expand Down
2 changes: 1 addition & 1 deletion classifiers/src/dolma_classifiers/loggers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
import logging
import os
import time

import wandb
Expand Down
13 changes: 9 additions & 4 deletions classifiers/src/dolma_classifiers/models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from functools import partial
import torch
from typing import Type, NamedTuple
from typing import NamedTuple, Type

import torch
from torch.nn import functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
from .utils import get_local_gpu_rank, sanitize_model_name
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizer,
)

from .loggers import get_logger
from .utils import get_local_gpu_rank, sanitize_model_name


class Prediction(NamedTuple):
Expand Down
7 changes: 5 additions & 2 deletions classifiers/src/dolma_classifiers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
import re
from typing import Any


from smart_open.compression import _handle_zstd, get_supported_compression_types, register_compressor
import torch
import torch.distributed as dist
from smart_open.compression import (
_handle_zstd,
get_supported_compression_types,
register_compressor,
)


def get_rank_and_world_size():
Expand Down

0 comments on commit a193813

Please sign in to comment.