Skip to content

Commit

Permalink
style: fix style
Browse files Browse the repository at this point in the history
  • Loading branch information
k4black committed May 14, 2024
1 parent c9fc0ad commit 1acf4d9
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 7 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ jobs:
cache: 'pip' # caching pip dependencies
- name: Install dependencies
run: |
python -m pip install -e .[test]
python -m pip install -e .[all,test]
- name: Run isort check
run: python -m isort codebleu --check
- name: Run black check
run: python -m black codebleu --check
- name: Run ruff check
run: python -m ruff codebleu
run: python -m ruff check codebleu
- name: Run mypy check
run: python -m mypy codebleu

Expand All @@ -41,7 +41,7 @@ jobs:
cache: 'pip' # caching pip dependencies
- name: Install lib from source and dependencies
run: |
python -m pip install -e .[test]
python -m pip install -e .[all,test]
- name: Run tests
run: python -m pytest

Expand Down
8 changes: 6 additions & 2 deletions codebleu/codebleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,14 @@ def make_weights(reference_tokens, key_word_list):
weighted_ngram_match_score = weighted_ngram_match.corpus_bleu(tokenized_refs_with_weights, tokenized_hyps)

# calculate syntax match
syntax_match_score = syntax_match.corpus_syntax_match(references, hypothesis, lang, tree_sitter_language=tree_sitter_language)
syntax_match_score = syntax_match.corpus_syntax_match(
references, hypothesis, lang, tree_sitter_language=tree_sitter_language
)

# calculate dataflow match
dataflow_match_score = dataflow_match.corpus_dataflow_match(references, hypothesis, lang, tree_sitter_language=tree_sitter_language)
dataflow_match_score = dataflow_match.corpus_dataflow_match(
references, hypothesis, lang, tree_sitter_language=tree_sitter_language
)

alpha, beta, gamma, theta = weights
code_bleu_score = (
Expand Down
2 changes: 1 addition & 1 deletion codebleu/syntax_match.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from tree_sitter import Language, Parser
from tree_sitter import Parser

from .parser import (
DFG_csharp,
Expand Down
11 changes: 10 additions & 1 deletion codebleu/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from tree_sitter import Language


AVAILABLE_LANGS = [
"java",
"javascript",
Expand Down Expand Up @@ -133,36 +132,46 @@ def get_tree_sitter_language(lang: str) -> Language:
try:
if lang == "java":
import tree_sitter_java

return Language(tree_sitter_java.language())
elif lang == "javascript":
import tree_sitter_javascript

return Language(tree_sitter_javascript.language())
elif lang == "c_sharp":
import tree_sitter_c_sharp

return Language(tree_sitter_c_sharp.language())
elif lang == "php":
import tree_sitter_php

try:
return Language(tree_sitter_php.language()) # type: ignore[attr-defined]
except AttributeError:
return Language(tree_sitter_php.language_php())
elif lang == "c":
import tree_sitter_c

return Language(tree_sitter_c.language())
elif lang == "cpp":
import tree_sitter_cpp

return Language(tree_sitter_cpp.language())
elif lang == "python":
import tree_sitter_python

return Language(tree_sitter_python.language())
elif lang == "go":
import tree_sitter_go

return Language(tree_sitter_go.language())
elif lang == "ruby":
import tree_sitter_ruby

return Language(tree_sitter_ruby.language())
elif lang == "rust":
import tree_sitter_rust

return Language(tree_sitter_rust.language())
else:
assert False, "Not reachable"
Expand Down

0 comments on commit 1acf4d9

Please sign in to comment.