Skip to content

Commit

Permalink
Cherry pick zstd compressor (#180)
Browse files Browse the repository at this point in the history
* Cherry pick zstd compressor

* Bump version

* Pin rust in gh action

* force linux matrix to grab pinned rust

* Try 1.78.0

* Force time version
  • Loading branch information
undfined authored Jul 26, 2024
1 parent 5868453 commit a01a222
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 65 deletions.
77 changes: 37 additions & 40 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
name: CI

on:
push:
branches:
- main
- master
tags:
- '*'
pull_request:
branches:
- main
- master
workflow_dispatch:

push:
branches:
- main
- master
tags:
- "*"
pull_request:
branches:
- main
- master
workflow_dispatch:

permissions:
contents: read
Expand All @@ -22,7 +21,6 @@ env:
DOLMA_TEST_S3_PREFIX: s3://dolma-tests
RUST_CHANNEL: stable


jobs:
info:
name: Run info
Expand All @@ -40,32 +38,31 @@ jobs:
echo "PR base repo: ${{ github.event.pull_request.base.repo.full_name }}/tree/${{ github.event.pull_request.base.ref }}"
echo "PR head repo: ${{ github.event.pull_request.head.repo.full_name }}/tree/${{ github.event.pull_request.head.ref }}"
should_build:
name: "Check if build"
runs-on: ubuntu-latest
steps:
- name: checkout
uses: actions/checkout@v3
with:
fetch-depth: 0
ref: ${{ github.ref }}
- name: List branches and tags
run: |
git branch -a
git tag -l
git log | head -n 1000
- id: check_version
run: |
set +e
has_updated=$(git diff --name-only '${{ github.event.pull_request.base.sha }}' | grep -E 'pyproject.toml|Cargo.toml')
is_main_or_release='${{ github.ref == 'refs/heads/main' || github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/tags/') }}'
if [[ -n "${has_updated}" ]] || [[ "${is_main_or_release}" == 'true' ]]; then
echo "should_build=true" >> $GITHUB_OUTPUT
else
echo "should_build=false" >> $GITHUB_OUTPUT
fi
shell: bash
- name: checkout
uses: actions/checkout@v3
with:
fetch-depth: 0
ref: ${{ github.ref }}
- name: List branches and tags
run: |
git branch -a
git tag -l
git log | head -n 1000
- id: check_version
run: |
set +e
has_updated=$(git diff --name-only '${{ github.event.pull_request.base.sha }}' | grep -E 'pyproject.toml|Cargo.toml')
is_main_or_release='${{ github.ref == 'refs/heads/main' || github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/tags/') }}'
if [[ -n "${has_updated}" ]] || [[ "${is_main_or_release}" == 'true' ]]; then
echo "should_build=true" >> $GITHUB_OUTPUT
else
echo "should_build=false" >> $GITHUB_OUTPUT
fi
shell: bash
outputs:
should_build: ${{ steps.check_version.outputs.should_build }}

Expand All @@ -88,7 +85,7 @@ jobs:

- name: Setup system libraries
if: steps.cache-venv.outputs.cache-hit != 'true'
run: |
run: |
sudo apt-get update
sudo apt-get install --yes --upgrade build-essential cmake protobuf-compiler libssl-dev glibc-source musl-tools
Expand All @@ -103,7 +100,7 @@ jobs:
if: steps.cache-venv.outputs.cache-hit != 'true'
uses: actions/setup-python@v4
with:
python-version: '3.8'
python-version: "3.8"
architecture: "x64"

- name: Create a new Python environment & install maturin
Expand Down Expand Up @@ -188,7 +185,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.10'
python-version: "3.10"
- name: Install 32bit version of libc
if: ${{ matrix.target == 'x86' || contains(matrix.target, 'i686') }}
run: |
Expand Down Expand Up @@ -222,7 +219,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.10'
python-version: "3.10"
architecture: ${{ matrix.target }}
- name: Build wheels
uses: PyO3/maturin-action@v1
Expand All @@ -247,7 +244,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.10'
python-version: "3.10"
- name: Build wheels
uses: PyO3/maturin-action@v1
with:
Expand Down
9 changes: 5 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ threadpool = "1.8.1"
tokenizers = { version = "0.15.0", features = ["http"] }
tokio = { version = "1.27.0", features = ["full"] }
tokio-util = "0.7.7"
time = "0.3.36"
unicode-segmentation = "1.7"
openssl = { version = "0.10.63", features = ["vendored"] }
adblock = { version = "0.8.6", features = ["content-blocking"] }
Expand Down
31 changes: 10 additions & 21 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "dolma"
version = "1.0.4"
version = "1.0.5"
description = "Data filters"
license = { text = "Apache-2.0" }
readme = "README.md"
Expand Down Expand Up @@ -30,6 +30,7 @@ dependencies = [
"numpy",
"necessary>=0.4.3",
"charset-normalizer>=3.2.0",
"zstandard>=0.23.0",
]
classifiers = [
"Development Status :: 5 - Production/Stable",
Expand Down Expand Up @@ -117,35 +118,27 @@ pii = ["presidio_analyzer==2.2.32", "regex"]
# language detection; by default, we use fastttext, everything else is optional
lang = [
"fasttext-wheel==0.9.2",
"LTpycld2==0.42", # fork of pycld2 that works on Apple Silicon
"LTpycld2==0.42", # fork of pycld2 that works on Apple Silicon
"lingua-language-detector>=2.0.0",
"langdetect>=1.0.9"
"langdetect>=1.0.9",
]

# extension to parse warc files
warc = [
"fastwarc",
"w3lib",
"url-normalize",

]
warc = ["fastwarc", "w3lib", "url-normalize"]
trafilatura = [
# must include warc dependencies
"dolma[warc]",
# core package
"trafilatura>=1.6.1",
# following are all for speeding up trafilatura
"brotli",
"cchardet >= 2.1.7; python_version < '3.11'", # build issue
"faust-cchardet >= 2.1.18; python_version >= '3.11'", # fix for build
"cchardet >= 2.1.7; python_version < '3.11'", # build issue
"faust-cchardet >= 2.1.18; python_version >= '3.11'", # fix for build
"htmldate[speed] >= 1.4.3",
"py3langid >= 0.2.2",
]

resiliparse = [
"dolma[warc]",
"resiliparse",
]
resiliparse = ["dolma[warc]", "resiliparse"]

# all extensions
all = [
Expand All @@ -154,15 +147,11 @@ all = [
"dolma[pii]",
"dolma[trafilatura]",
"dolma[resiliparse]",
"dolma[lang]"
"dolma[lang]",
]

[build-system]
requires = [
"maturin[patchelf]>=1.1,<2.0",
"setuptools >= 61.0.0",
"wheel"
]
requires = ["maturin[patchelf]>=1.1,<2.0", "setuptools >= 61.0.0", "wheel"]
build-backend = "maturin"


Expand Down
38 changes: 38 additions & 0 deletions python/dolma/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import importlib
import io
import os
import re
import string
Expand All @@ -14,8 +15,11 @@

import nltk
import uniseg.wordbreak
import zstandard
from necessary import necessary
from nltk.tokenize.punkt import PunktSentenceTokenizer
from omegaconf import OmegaConf as om
from smart_open import register_compressor

try:
nltk.data.find("tokenizers/punkt")
Expand Down Expand Up @@ -148,3 +152,37 @@ def dataclass_to_dict(dataclass_instance) -> dict:

# force typecasting because a dataclass instance will always be a dict
return cast(dict, om.to_object(om.structured(dataclass_instance)))


def add_compression():
"""
Adds support for zstandard (.zst) compression format to the smart_open library.
This function registers a custom compressor for the .zst file extension in the smart_open library.
The compressor uses the zstandard library to handle zstandard compression.
"""

def _handle_zstd(file_obj, mode):
result = zstandard.open(filename=file_obj, mode=mode)
# zstandard.open returns an io.TextIOWrapper in text mode, but otherwise
# returns a raw stream reader/writer, and we need the `io` wrapper
# to make FileLikeProxy work correctly.
if "b" in mode and "w" in mode:
result = io.BufferedWriter(result)
elif "b" in mode and "r" in mode:
result = io.BufferedReader(result)
return result

register_compressor(".zst", _handle_zstd)
register_compressor(".zstd", _handle_zstd)


with necessary(("smart_open", "7.0.4"), soft=True) as SMART_OPEN_HAS_ZSTD:
if SMART_OPEN_HAS_ZSTD:
# add additional extension for smart_open
from smart_open.compression import _handle_zstd

register_compressor(".zstd", _handle_zstd)
else:
# add zstd compression
add_compression()

0 comments on commit a01a222

Please sign in to comment.