From 1acf4d93b215628d74df2c2192e8a2897c7758ea Mon Sep 17 00:00:00 2001 From: Konstantin Chernyshev Date: Tue, 14 May 2024 21:08:49 +0200 Subject: [PATCH] style: fix style --- .github/workflows/test.yml | 6 +++--- codebleu/codebleu.py | 8 ++++++-- codebleu/syntax_match.py | 2 +- codebleu/utils.py | 11 ++++++++++- 4 files changed, 20 insertions(+), 7 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f898b5c..51897ea 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,13 +18,13 @@ jobs: cache: 'pip' # caching pip dependencies - name: Install dependencies run: | - python -m pip install -e .[test] + python -m pip install -e .[all,test] - name: Run isort check run: python -m isort codebleu --check - name: Run black check run: python -m black codebleu --check - name: Run ruff check - run: python -m ruff codebleu + run: python -m ruff check codebleu - name: Run mypy check run: python -m mypy codebleu @@ -41,7 +41,7 @@ jobs: cache: 'pip' # caching pip dependencies - name: Install lib from source and dependencies run: | - python -m pip install -e .[test] + python -m pip install -e .[all,test] - name: Run tests run: python -m pytest diff --git a/codebleu/codebleu.py b/codebleu/codebleu.py index 401adde..27855cf 100644 --- a/codebleu/codebleu.py +++ b/codebleu/codebleu.py @@ -70,10 +70,14 @@ def make_weights(reference_tokens, key_word_list): weighted_ngram_match_score = weighted_ngram_match.corpus_bleu(tokenized_refs_with_weights, tokenized_hyps) # calculate syntax match - syntax_match_score = syntax_match.corpus_syntax_match(references, hypothesis, lang, tree_sitter_language=tree_sitter_language) + syntax_match_score = syntax_match.corpus_syntax_match( + references, hypothesis, lang, tree_sitter_language=tree_sitter_language + ) # calculate dataflow match - dataflow_match_score = dataflow_match.corpus_dataflow_match(references, hypothesis, lang, tree_sitter_language=tree_sitter_language) + dataflow_match_score = dataflow_match.corpus_dataflow_match( + references, hypothesis, lang, tree_sitter_language=tree_sitter_language + ) alpha, beta, gamma, theta = weights code_bleu_score = ( diff --git a/codebleu/syntax_match.py b/codebleu/syntax_match.py index 5e05c80..860ae12 100644 --- a/codebleu/syntax_match.py +++ b/codebleu/syntax_match.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from tree_sitter import Language, Parser +from tree_sitter import Parser from .parser import ( DFG_csharp, diff --git a/codebleu/utils.py b/codebleu/utils.py index 98acbcb..468df81 100644 --- a/codebleu/utils.py +++ b/codebleu/utils.py @@ -9,7 +9,6 @@ from tree_sitter import Language - AVAILABLE_LANGS = [ "java", "javascript", @@ -133,36 +132,46 @@ def get_tree_sitter_language(lang: str) -> Language: try: if lang == "java": import tree_sitter_java + return Language(tree_sitter_java.language()) elif lang == "javascript": import tree_sitter_javascript + return Language(tree_sitter_javascript.language()) elif lang == "c_sharp": import tree_sitter_c_sharp + return Language(tree_sitter_c_sharp.language()) elif lang == "php": import tree_sitter_php + try: return Language(tree_sitter_php.language()) # type: ignore[attr-defined] except AttributeError: return Language(tree_sitter_php.language_php()) elif lang == "c": import tree_sitter_c + return Language(tree_sitter_c.language()) elif lang == "cpp": import tree_sitter_cpp + return Language(tree_sitter_cpp.language()) elif lang == "python": import tree_sitter_python + return Language(tree_sitter_python.language()) elif lang == "go": import tree_sitter_go + return Language(tree_sitter_go.language()) elif lang == "ruby": import tree_sitter_ruby + return Language(tree_sitter_ruby.language()) elif lang == "rust": import tree_sitter_rust + return Language(tree_sitter_rust.language()) else: assert False, "Not reachable"