Skip to content

Commit

Permalink
feat: update for tree-sitter==0.22 with pre-build wheels
Browse files Browse the repository at this point in the history
  • Loading branch information
k4black committed May 14, 2024
1 parent 413982c commit c9fc0ad
Show file tree
Hide file tree
Showing 10 changed files with 120 additions and 124 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ jobs:
pip uninstall -y codebleu || true
# TODO: check the sdist package is not installed
pip install --upgrade --no-deps --no-index --find-links=./dist codebleu
# install dependencies for the package and tests
pip install .[test]
# install dependencies for the package languages and tests
pip install .[all,test]
- name: Test itself
run: python -m pytest --cov-report=xml
- name: Upload coverage
Expand Down
15 changes: 12 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@ or directly from git repo (require internet connection to download tree-sitter):
pip install git+https://github.com/k4black/codebleu.git
```

Also you have to install tree-sitter language you need (e.g. python, rust, etc):
```bash
pip install tree-sitter-python
```
Or you can install all languages:
```bash
pip install codebleu[all]
```


## Usage

Expand Down Expand Up @@ -96,11 +105,11 @@ Make your own fork and clone it:
git clone https://github.com/k4black/codebleu
```

For development, you need to install library (for so file to compile) with `test` extra:
For development, you need to install library with `all` precompiled languages and `test` extra:
(require internet connection to download tree-sitter)
```bash
python -m pip install -e .[test]
python -m pip install -e .\[test\] # for macos
python -m pip install -e .[all,test]
python -m pip install -e .\[all,test\] # for macos
```

For testing just run pytest:
Expand Down
22 changes: 6 additions & 16 deletions codebleu/codebleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,9 @@
from typing import Callable, Dict, List, Optional, Tuple, Union

from . import bleu, dataflow_match, syntax_match, weighted_ngram_match
from .utils import AVAILABLE_LANGS, get_tree_sitter_language

PACKAGE_DIR = Path(__file__).parent
AVAILABLE_LANGS = [
"java",
"javascript",
"c_sharp",
"php",
"c",
"cpp",
"python",
"go",
"ruby",
"rust",
] # keywords available


def calc_codebleu(
Expand All @@ -28,7 +17,6 @@ def calc_codebleu(
weights: Tuple[float, float, float, float] = (0.25, 0.25, 0.25, 0.25),
tokenizer: Optional[Callable] = None,
keywords_dir: Path = PACKAGE_DIR / "keywords",
lang_so_file: Path = PACKAGE_DIR / "my-languages.so",
) -> Dict[str, float]:
"""Calculate CodeBLEU score
Expand All @@ -48,7 +36,9 @@ def calc_codebleu(
assert lang in AVAILABLE_LANGS, f"Language {lang} is not supported (yet). Available languages: {AVAILABLE_LANGS}"
assert len(weights) == 4, "weights should be a tuple of 4 floats (alpha, beta, gamma, theta)"
assert keywords_dir.exists(), f"keywords_dir {keywords_dir} does not exist"
assert lang_so_file.exists(), f"lang_so_file {lang_so_file} does not exist"

# get the tree-sitter language for a given language
tree_sitter_language = get_tree_sitter_language(lang)

# preprocess inputs
references = [[x.strip() for x in ref] if isinstance(ref, list) else [ref.strip()] for ref in references]
Expand Down Expand Up @@ -80,10 +70,10 @@ def make_weights(reference_tokens, key_word_list):
weighted_ngram_match_score = weighted_ngram_match.corpus_bleu(tokenized_refs_with_weights, tokenized_hyps)

# calculate syntax match
syntax_match_score = syntax_match.corpus_syntax_match(references, hypothesis, lang, str(lang_so_file))
syntax_match_score = syntax_match.corpus_syntax_match(references, hypothesis, lang, tree_sitter_language=tree_sitter_language)

# calculate dataflow match
dataflow_match_score = dataflow_match.corpus_dataflow_match(references, hypothesis, lang, str(lang_so_file))
dataflow_match_score = dataflow_match.corpus_dataflow_match(references, hypothesis, lang, tree_sitter_language=tree_sitter_language)

alpha, beta, gamma, theta = weights
code_bleu_score = (
Expand Down
11 changes: 7 additions & 4 deletions codebleu/dataflow_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT license.
import logging

from tree_sitter import Language, Parser
from tree_sitter import Parser

from .parser import (
DFG_csharp,
Expand All @@ -17,6 +17,7 @@
remove_comments_and_docstrings,
tree_to_token_index,
)
from .utils import get_tree_sitter_language

dfg_function = {
"python": DFG_python,
Expand All @@ -36,10 +37,12 @@ def calc_dataflow_match(references, candidate, lang, langso_so_file):
return corpus_dataflow_match([references], [candidate], lang, langso_so_file)


def corpus_dataflow_match(references, candidates, lang, langso_so_file):
LANGUAGE = Language(langso_so_file, lang)
def corpus_dataflow_match(references, candidates, lang, tree_sitter_language=None):
if not tree_sitter_language:
tree_sitter_language = get_tree_sitter_language(lang)

parser = Parser()
parser.set_language(LANGUAGE)
parser.language = tree_sitter_language
parser = [parser, dfg_function[lang]]
match_count = 0
total_count = 0
Expand Down
15 changes: 9 additions & 6 deletions codebleu/syntax_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
DFG_ruby,
remove_comments_and_docstrings,
)
from .utils import get_tree_sitter_language

dfg_function = {
"python": DFG_python,
Expand All @@ -25,14 +26,16 @@
}


def calc_syntax_match(references, candidate, lang, lang_so_file):
return corpus_syntax_match([references], [candidate], lang, lang_so_file)
def calc_syntax_match(references, candidate, lang):
return corpus_syntax_match([references], [candidate], lang)


def corpus_syntax_match(references, candidates, lang, lang_so_file):
tree_sitter_language = Language(lang_so_file, lang)
def corpus_syntax_match(references, candidates, lang, tree_sitter_language=None):
if not tree_sitter_language:
tree_sitter_language = get_tree_sitter_language(lang)

parser = Parser()
parser.set_language(tree_sitter_language)
parser.language = tree_sitter_language
match_count = 0
match_count_candidate_to_reference = 0
total_count = 0
Expand Down Expand Up @@ -61,7 +64,7 @@ def get_all_sub_trees(root_node):
node_stack.append([root_node, depth])
while len(node_stack) != 0:
cur_node, cur_depth = node_stack.pop()
sub_tree_sexp_list.append([cur_node.sexp(), cur_depth])
sub_tree_sexp_list.append([str(cur_node), cur_depth])
for child_node in cur_node.children:
if len(child_node.children) != 0:
depth = cur_depth + 1
Expand Down
66 changes: 66 additions & 0 deletions codebleu/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,22 @@

from itertools import chain

from tree_sitter import Language


AVAILABLE_LANGS = [
"java",
"javascript",
"c_sharp",
"php",
"c",
"cpp",
"python",
"go",
"ruby",
"rust",
] # keywords available


def pad_sequence(
sequence,
Expand Down Expand Up @@ -104,3 +120,53 @@ def ngrams(
history.append(item)
yield tuple(history)
del history[0]


def get_tree_sitter_language(lang: str) -> Language:
"""
Get the tree-sitter language for a given language.
:param lang: the language name to get the tree-sitter language for
:return: the tree-sitter language
"""
assert lang in AVAILABLE_LANGS, f"Language {lang} not available. Available languages: {AVAILABLE_LANGS}"

try:
if lang == "java":
import tree_sitter_java
return Language(tree_sitter_java.language())
elif lang == "javascript":
import tree_sitter_javascript
return Language(tree_sitter_javascript.language())
elif lang == "c_sharp":
import tree_sitter_c_sharp
return Language(tree_sitter_c_sharp.language())
elif lang == "php":
import tree_sitter_php
try:
return Language(tree_sitter_php.language()) # type: ignore[attr-defined]
except AttributeError:
return Language(tree_sitter_php.language_php())
elif lang == "c":
import tree_sitter_c
return Language(tree_sitter_c.language())
elif lang == "cpp":
import tree_sitter_cpp
return Language(tree_sitter_cpp.language())
elif lang == "python":
import tree_sitter_python
return Language(tree_sitter_python.language())
elif lang == "go":
import tree_sitter_go
return Language(tree_sitter_go.language())
elif lang == "ruby":
import tree_sitter_ruby
return Language(tree_sitter_ruby.language())
elif lang == "rust":
import tree_sitter_rust
return Language(tree_sitter_rust.language())
else:
assert False, "Not reachable"
except ImportError:
raise ImportError(
f"Tree-sitter language for {lang} not available. Please install the language parser using `pip install tree-sitter-{lang}`."
)
2 changes: 1 addition & 1 deletion evaluate_app/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ Each of the scores is in range `[0, 1]`, where `1` is the best score.

[//]: # (*Give code examples of the metric being used. Try to include examples that clear up any potential ambiguity left from the metric description above. If possible, provide a range of examples that show both typical and atypical results, as well as examples where a variety of input parameters are passed.*)

Using pip package (`pip install codebleu`):
Using pip package (`pip install codebleu`), also you have to install tree-sitter language you need (e.g. `pip install tree-sitter-python` or `pip install codebleu[all]` to install all languages):
```python
from codebleu import calc_codebleu

Expand Down
2 changes: 1 addition & 1 deletion evaluate_app/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
git+https://github.com/huggingface/evaluate@main
codebleu>=0.2.0,<1.0.0
codebleu>=0.5.0,<1.0.0
20 changes: 16 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[build-system]
requires = ["setuptools>=61.0.0", "wheel", "tree-sitter>=0.20.0,<1.0.0", "requests>=2.0.0,<3.0.0"]
requires = ["setuptools>=61.0.0", "wheel"]
build-backend = "setuptools.build_meta"


Expand All @@ -9,7 +9,7 @@ description = "Unofficial CodeBLEU implementation that supports Linux, MacOS and
readme = "README.md"
license = {text = "MIT License"}
authors = [
{name = "Konstantin Chernyshev", email = "[email protected]"},
{name = "Konstantin Chernyshev", email = "kdchernyshev+github@gmail.com"},
]
keywords = ["codebleu", "code", "bleu", "nlp", "natural language processing", "programming", "evaluate", "evaluation", "code generation", "metrics"]
dynamic = ["version"]
Expand All @@ -23,7 +23,7 @@ classifiers = [
]

dependencies = [
"tree-sitter >=0.20.0,<1.0.0",
"tree-sitter >=0.22.0,<1.0.0",
"setuptools >=61.0.0", # distutils removed in 3.12, but distutils.ccompiler used in tree-sitter
]

Expand All @@ -37,7 +37,7 @@ exclude = ["tests", "tests.*", "codebleu.parser.tree-sitter"]


[tool.setuptools.package-data]
"*" = ["py.typed", "*.txt", "*.so", "*.dylib", "*.dll", "keywords/*"]
"*" = ["py.typed", "*.txt", "keywords/*"]


[project.scripts]
Expand All @@ -47,6 +47,18 @@ codebleu = "codebleu.__main__:main"
homepage = "https://github.com/k4black/codebleu"

[project.optional-dependencies]
all = [
"tree-sitter-python ~=0.21",
"tree-sitter-go ~=0.21",
"tree-sitter-javascript ~=0.21",
"tree-sitter-ruby ~=0.21",
"tree-sitter-php ~=0.22",
"tree-sitter-java ~=0.21",
"tree-sitter-c-sharp ~=0.21",
"tree-sitter-c ~=0.21",
"tree-sitter-cpp ~=0.22",
"tree-sitter-rust ~=0.21",
]
test = [
"pytest >=7.0.0,<9.0.0",
"pytest-cov >=4.0.0,<6.0.0",
Expand Down
87 changes: 0 additions & 87 deletions setup.py

This file was deleted.

0 comments on commit c9fc0ad

Please sign in to comment.