fix: to work with 3.12 and add examples tests #10

Merged
merged 10 commits
Nov 16, 2023
Changes from all commits
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -39,7 +39,7 @@ jobs:
fetch-depth: 0
- uses: actions/setup-python@v4
with:
python-version: '3.11'
python-version: '3.12'
cache: 'pip' # caching pip dependencies
- name: Install dependencies
run: |
156 changes: 9 additions & 147 deletions codebleu/bleu.py
@@ -9,25 +9,11 @@

"""BLEU score implementation."""
import math
import sys
import warnings
from collections import Counter
from fractions import Fraction as _Fraction
from typing import Any

from .utils import ngrams


# _normalize=False was removed in 3.12, add custom class for back-compatibility
class Fraction(_Fraction):
# We're immutable, so use __new__ not __init__
def __new__(cls, numerator: Any = 0, denominator: Any = None, *, _normalize: bool = True) -> "Fraction":
if sys.version_info >= (3, 12):
return super(Fraction, cls).__new__(cls, numerator, denominator)
else:
return super(Fraction, cls).__new__(cls, numerator, denominator, _normalize=False)


def sentence_bleu(
references,
hypothesis,
@@ -163,9 +149,9 @@ def corpus_bleu(
# For each order of ngram, calculate the numerator and
# denominator for the corpus-level modified precision.
for i, _ in enumerate(weights, start=1):
p_i = modified_precision(references, hypothesis, i)
p_numerators[i] += p_i.numerator
p_denominators[i] += p_i.denominator
p_i_numerator, p_i_denominator = modified_precision(references, hypothesis, i)
p_numerators[i] += p_i_numerator
p_denominators[i] += p_i_denominator

# Calculate the hypothesis length and the closest reference length.
# Adds them to the corpus-level hypothesis and reference counts.
@@ -182,8 +168,8 @@
if hyp_lengths < 4 and weights == (0.25, 0.25, 0.25, 0.25):
weights = (1 / hyp_lengths,) * hyp_lengths

# Collects the various precision values for the different ngram orders.
p_n = [Fraction(p_numerators[i], p_denominators[i], _normalize=False) for i, _ in enumerate(weights, start=1)]
# Collects the various recall values for the different ngram orders.
p_n = [(p_numerators[i], p_denominators[i]) for i, _ in enumerate(weights, start=1)]

# Returns 0 if there's no matching n-grams
# We only need to check for p_numerators[1] == 0, since if there's
@@ -199,7 +185,7 @@
# it tries to retain the Fraction object as much as the
# smoothing method allows.
p_n = smoothing_function(p_n, references=references, hypothesis=hypothesis, hyp_len=hyp_lengths)
s = (w_i * math.log(p_i) for w_i, p_i in zip(weights, p_n))
s = (w_i * math.log(p_i[0] / p_i[1]) for w_i, p_i in zip(weights, p_n))
s = bp * math.exp(math.fsum(s))
return s

@@ -295,7 +281,8 @@ def modified_precision(references, hypothesis, n):
# Usually this happens when the ngram order is > len(reference).
denominator = max(1, sum(counts.values()))

return Fraction(numerator, denominator, _normalize=False)
# return Fraction(numerator, denominator, _normalize=False)
return numerator, denominator


def closest_ref_length(references, hyp_len):
@@ -444,133 +431,8 @@ def __init__(self, epsilon=0.1, alpha=5, k=5):
self.alpha = alpha
self.k = k

def method0(self, p_n, *args, **kwargs):
"""
No smoothing.
"""
p_n_new = []
for i, p_i in enumerate(p_n):
if p_i.numerator != 0:
p_n_new.append(p_i)
else:
_msg = str(
"\nThe hypothesis contains 0 counts of {}-gram overlaps.\n"
"Therefore the BLEU score evaluates to 0, independently of\n"
"how many N-gram overlaps of lower order it contains.\n"
"Consider using lower n-gram order or use "
"SmoothingFunction()"
).format(i + 1)
warnings.warn(_msg)
# When numerator==0 where denonminator==0 or !=0, the result
# for the precision score should be equal to 0 or undefined.
# Due to BLEU geometric mean computation in logarithm space,
# we we need to take the return sys.float_info.min such that
# math.log(sys.float_info.min) returns a 0 precision score.
p_n_new.append(sys.float_info.min)
return p_n_new

def method1(self, p_n, *args, **kwargs):
"""
Smoothing method 1: Add *epsilon* counts to precision with 0 counts.
"""
return [(p_i.numerator + self.epsilon) / p_i.denominator if p_i.numerator == 0 else p_i for p_i in p_n]

def method2(self, p_n, *args, **kwargs):
"""
Smoothing method 2: Add 1 to both numerator and denominator from
Chin-Yew Lin and Franz Josef Och (2004) Automatic evaluation of
machine translation quality using longest common subsequence and
skip-bigram statistics. In ACL04.
"""
return [Fraction(p_i.numerator + 1, p_i.denominator + 1, _normalize=False) for p_i in p_n]

def method3(self, p_n, *args, **kwargs):
"""
Smoothing method 3: NIST geometric sequence smoothing
The smoothing is computed by taking 1 / ( 2^k ), instead of 0, for each
precision score whose matching n-gram count is null.
k is 1 for the first 'n' value for which the n-gram match count is null/
For example, if the text contains:
- one 2-gram match
- and (consequently) two 1-gram matches
the n-gram count for each individual precision score would be:
- n=1 => prec_count = 2 (two unigrams)
- n=2 => prec_count = 1 (one bigram)
- n=3 => prec_count = 1/2 (no trigram, taking 'smoothed' value of 1 / ( 2^k ), with k=1)
- n=4 => prec_count = 1/4 (no fourgram, taking 'smoothed' value of 1 / ( 2^k ), with k=2)
"""
incvnt = 1 # From the mteval-v13a.pl, it's referred to as k.
for i, p_i in enumerate(p_n):
if p_i.numerator == 0:
p_n[i] = 1 / (2**incvnt * p_i.denominator)
incvnt += 1
return p_n

def method4(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs):
"""
Smoothing method 4:
Shorter translations may have inflated precision values due to having
smaller denominators; therefore, we give them proportionally
smaller smoothed counts. Instead of scaling to 1/(2^k), Chen and Cherry
suggests dividing by 1/ln(len(T)), where T is the length of the translation.
"""
hyp_len = hyp_len if hyp_len else len(hypothesis)
for i, p_i in enumerate(p_n):
if p_i.numerator == 0 and hyp_len != 0:
incvnt = i + 1 * self.k / math.log(hyp_len) # Note that this K is different from the K from NIST.
p_n[i] = incvnt / p_i.denominator
return p_n

def method5(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs):
"""
Smoothing method 5:
The matched counts for similar values of n should be similar. To a
calculate the n-gram matched count, it averages the n−1, n and n+1 gram
matched counts.
"""
hyp_len = hyp_len if hyp_len else len(hypothesis)
m = {}
# Requires an precision value for an addition ngram order.
p_n_plus1 = p_n + [modified_precision(references, hypothesis, 5)]
m[-1] = p_n[0] + 1
for i, p_i in enumerate(p_n):
p_n[i] = (m[i - 1] + p_i + p_n_plus1[i + 1]) / 3
m[i] = p_n[i]
return p_n

def method6(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs):
"""
Smoothing method 6:
Interpolates the maximum likelihood estimate of the precision *p_n* with
a prior estimate *pi0*. The prior is estimated by assuming that the ratio
between pn and pn−1 will be the same as that between pn−1 and pn−2; from
Gao and He (2013) Training MRF-Based Phrase Translation Models using
Gradient Ascent. In NAACL.
"""
hyp_len = hyp_len if hyp_len else len(hypothesis)
# This smoothing only works when p_1 and p_2 is non-zero.
# Raise an error with an appropriate message when the input is too short
# to use this smoothing technique.
assert p_n[2], "This smoothing method requires non-zero precision for bigrams."
for i, p_i in enumerate(p_n):
if i in [0, 1]: # Skips the first 2 orders of ngrams.
continue
else:
pi0 = 0 if p_n[i - 2] == 0 else p_n[i - 1] ** 2 / p_n[i - 2]
# No. of ngrams in translation that matches the reference.
m = p_i.numerator
# No. of ngrams in translation.
ngrams_count = sum(1 for _ in ngrams(hypothesis, i + 1))
# Calculates the interpolated precision.
p_n[i] = (m + self.alpha * pi0) / (ngrams_count + self.alpha)
return p_n

def method7(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs):
"""
Smoothing method 7:
Interpolates methods 4 and 5.
"""
hyp_len = hyp_len if hyp_len else len(hypothesis)
p_n = self.method4(p_n, references, hypothesis, hyp_len)
p_n = self.method5(p_n, references, hypothesis, hyp_len)
return p_n
return [((p_i[0] + self.epsilon), p_i[1]) if p_i[0] == 0 else p_i for p_i in p_n]
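
The core of the bleu.py change: NLTK's BLEU code kept un-normalized fractions via the private `_normalize=False` keyword, which Python 3.12 removed from `fractions.Fraction`, so the PR drops the custom `Fraction` subclass and carries plain `(numerator, denominator)` tuples instead, dividing only when the log is taken. A minimal sketch of both behaviours, using made-up n-gram counts (not code from the PR):

```python
# Sketch only: made-up counts, illustrating the 3.12 incompatibility and the fix.
import math
from fractions import Fraction

numerator, denominator = 2, 8  # hypothetical modified n-gram precision counts

# Pre-3.12, the code stored Fraction(2, 8, _normalize=False) so corpus_bleu could
# keep summing raw numerators and denominators across sentences. On Python 3.12
# that call raises TypeError (the private keyword is gone), and a plain Fraction
# normalizes the counts away:
print(Fraction(numerator, denominator))  # prints 1/4; the original 2 and 8 are lost

# The PR stores a (numerator, denominator) tuple instead and divides only inside
# the log when combining the per-order precisions:
p_i = (numerator, denominator)
weight = 0.25
print(weight * math.log(p_i[0] / p_i[1]))
```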
4 changes: 2 additions & 2 deletions codebleu/dataflow_match.py
@@ -47,11 +47,11 @@ def corpus_dataflow_match(references, candidates, lang, langso_so_file):
candidate = candidates[i]
for reference in references_sample:
try:
candidate = remove_comments_and_docstrings(candidate, "java")
candidate = remove_comments_and_docstrings(candidate, lang)
except Exception:
pass
try:
reference = remove_comments_and_docstrings(reference, "java")
reference = remove_comments_and_docstrings(reference, lang)
except Exception:
pass

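
The dataflow_match.py change replaces the hardcoded `"java"` with the `lang` argument, so comments and docstrings are stripped with the right grammar for non-Java inputs before data-flow graphs are built. A hedged usage sketch follows; the import path is an assumption based on the upstream CodeXGLUE layout, and only the two-argument call shown in the diff is taken from the PR:

```python
# Sketch only: the import path below is assumed, not taken from this PR.
from codebleu.parser import remove_comments_and_docstrings  # assumed location

candidate = 'def add(a, b):\n    """Sum two numbers."""\n    return a + b  # inline comment\n'

try:
    # Before this PR the second argument was always "java", so Python docstrings
    # and '#' comments could survive preprocessing; passing lang fixes that.
    candidate = remove_comments_and_docstrings(candidate, "python")
except Exception:
    # Mirrors the diff: if preprocessing fails, the raw string is used as-is.
    pass

print(candidate)
```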
30 changes: 15 additions & 15 deletions codebleu/parser/build.py
@@ -1,20 +1,20 @@
# Copyright (c) Microsoft Corporation.
# Copyright (c) 2023 Konstantin Chernyshev.
# Licensed under the MIT license.

from tree_sitter import Language

Language.build_library(
"my-languages.so",
[
"tree-sitter/go",
"tree-sitter/javascript",
"tree-sitter/python",
"tree-sitter/php",
"tree-sitter/java",
"tree-sitter/ruby",
"tree-sitter/c-sharp",
"tree-sitter/c",
"tree-sitter/cpp",
],
)
if __name__ == "__main__":
Language.build_library(
"my-languages.so",
[
"tree-sitter/go",
"tree-sitter/javascript",
"tree-sitter/python",
"tree-sitter/php",
"tree-sitter/java",
"tree-sitter/ruby",
"tree-sitter/c-sharp",
"tree-sitter/c",
"tree-sitter/cpp",
],
)
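
Wrapping `Language.build_library` in an `if __name__ == "__main__":` guard means importing parser/build.py no longer compiles the grammars as a side effect; the build runs only when the script is executed directly. An illustrative note (module path and invocation assumed, not from the PR):

```python
# Illustrative only. With the guard in place, importing the module is a no-op:
import codebleu.parser.build  # no longer triggers Language.build_library at import time

# The shared library is built only on direct execution, e.g.
#     python codebleu/parser/build.py
# which writes my-languages.so, assuming the tree-sitter/* grammar checkouts
# listed above are present in the working directory.
```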
28 changes: 17 additions & 11 deletions codebleu/syntax_match.py
@@ -30,23 +30,23 @@ def calc_syntax_match(references, candidate, lang, lang_so_file):


def corpus_syntax_match(references, candidates, lang, lang_so_file):
# print(os.listdir())
JAVA_LANGUAGE = Language(lang_so_file, lang)
tree_sitter_language = Language(lang_so_file, lang)
parser = Parser()
parser.set_language(JAVA_LANGUAGE)
parser.set_language(tree_sitter_language)
match_count = 0
match_count_candidate_to_reference = 0
total_count = 0

for i in range(len(candidates)):
references_sample = references[i]
candidate = candidates[i]
for reference in references_sample:
try:
candidate = remove_comments_and_docstrings(candidate, "java")
candidate = remove_comments_and_docstrings(candidate, lang)
except Exception:
pass
try:
reference = remove_comments_and_docstrings(reference, "java")
reference = remove_comments_and_docstrings(reference, lang)
except Exception:
pass

@@ -69,15 +69,21 @@ def get_all_sub_trees(root_node):
return sub_tree_sexp_list

cand_sexps = [x[0] for x in get_all_sub_trees(candidate_tree)]
ref_sexps = get_all_sub_trees(reference_tree)
ref_sexps = [x[0] for x in get_all_sub_trees(reference_tree)]

# print(cand_sexps)
# print(ref_sexps)

for sub_tree, depth in ref_sexps:
# TODO: fix, now we count number of reference subtrees matching candidate,
# but we should count number of candidate subtrees matching reference
# See (4) in "3.2 Syntactic AST Match" of https://arxiv.org/pdf/2009.10297.pdf
for sub_tree in ref_sexps:
if sub_tree in cand_sexps:
match_count += 1
total_count += len(ref_sexps)

for sub_tree in cand_sexps:
if sub_tree in ref_sexps:
match_count_candidate_to_reference += 1

total_count += len(ref_sexps)
# print(f'match_count {match_count} / {total_count}')
# print(f'match_count_fixed {match_count_candidate_to_reference} / {total_count}')
score = match_count / total_count
return score
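
The syntax_match.py diff also adds a second counter and a TODO about the matching direction: the returned score still counts reference subtrees found in the candidate, while the TODO (citing eq. (4) in section 3.2 of the CodeBLEU paper) argues the count should run over candidate subtrees. A toy sketch of the two directions, with made-up subtree s-expressions rather than real tree-sitter output:

```python
# Toy sketch: made-up s-expressions, not real tree-sitter subtrees.
cand_sexps = ["(call)", "(call)", "(identifier)"]
ref_sexps = ["(call)", "(identifier)", "(return_statement)"]

# Current score: reference subtrees that also occur in the candidate.
match_count = sum(1 for sub_tree in ref_sexps if sub_tree in cand_sexps)

# Direction suggested by the TODO: candidate subtrees found in the reference,
# accumulated above as match_count_candidate_to_reference.
match_count_candidate_to_reference = sum(1 for sub_tree in cand_sexps if sub_tree in ref_sexps)

total_count = len(ref_sexps)
print(match_count / total_count)                         # 2/3
print(match_count_candidate_to_reference / total_count)  # 3/3; differs once subtrees repeat
```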