Skip to content

Commit

Permalink
pushing
Browse files (browse the repository at this point in the history)
  • Loading branch information
soldni committed Aug 27, 2024
1 parent 6ae4c06 commit 7db5ee1
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 9 deletions.
14 changes: 8 additions & 6 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -133,25 +133,22 @@ jobs:
task:
- name: Check Python style
run: |
set -e
isort --check --verbose .
black --check --verbose .
ruff check .
- name: Check Rust style
run: |
rustfmt --edition 2021 src/*.rs --check
- name: Lint Python
run: |
flake8 tests/python/ && flake8 python/
ruff check --select I .
- name: Types Python
run: |
mypy tests/python/ && mypy python/
- name: Run Python tests
run: |
pytest -vs --color=yes tests/python/
pyright .
steps:
- name: Checkout repository
uses: actions/checkout@v3
Expand All @@ -166,6 +163,11 @@ jobs:
# The cache key depends on pyproject.toml and Cargo.toml
key: ${{ runner.os }}-venv-${{ hashFiles('**/pyproject.toml', '**/Cargo.toml, **/Cargo.lock') }}--${{ hashFiles('python/**', 'src/**') }}

- name: Install make
run: |
sudo apt-get update
sudo apt-get install -y make
- name: ${{ matrix.task.name }}
run: |
source .venv/bin/activate
Expand Down
7 changes: 5 additions & 2 deletions python/dolma/core/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,8 +493,11 @@ def decompress_path(path: str, dest: Optional[str] = None) -> str:
the original path will be returned.
"""
for supported_ext in get_supported_extensions():
# explicit string conversion
ext_text = str(supported_ext)

# not the supported extension
if not path.endswith(supported_ext):
if not path.endswith(ext_text):
continue

if dest is None:
Expand All @@ -504,7 +507,7 @@ def decompress_path(path: str, dest: Optional[str] = None) -> str:

# to get the decompressed file name, we remove the bit of the extension that
# indicates the compression type.
decompressed_fn = base_fn + ext.replace(supported_ext, "")
decompressed_fn = base_fn + ext.replace(ext_text, "")

# finally, we get cache directory and join the decompressed file name to it
dest = join_path("", get_cache_dir(), decompressed_fn)
Expand Down
2 changes: 1 addition & 1 deletion python/dolma/taggers/quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def preprocess(self, text: str) -> List[Tuple[str, Tuple[int, int]]]:

def predict_slice(self, text_slice: TextSlice) -> Iterable[Prediction]:
tokens, _ = zip(*self.preprocess(text_slice.text))
preds = self.classifier.predict(" ".join(tokens), k=-1)
preds = self.classifier.predict(" ".join(tokens), k=-1) # pyright: ignore
out = [
Prediction(label=label.replace("__label__", ""), score=score)
for label, score in sorted(zip(*preds), key=lambda x: x[1], reverse=True)
Expand Down

0 comments on commit 7db5ee1

Please sign in to comment.