Skip to content

Commit

Permalink
🆙 Bump version
Browse files Browse the repository at this point in the history
  • Loading branch information
awtkns committed Nov 28, 2023
1 parent 9003787 commit f627be9
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 12 deletions.
23 changes: 13 additions & 10 deletions bananalyzer/runner/evals.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import json
import re
from difflib import SequenceMatcher
from typing import Any, Callable, Dict

import pytest
from deepdiff import DeepDiff

Result = Dict[str, Any]
NON_ALPHANUMERIC_REGEX = re.compile(r"[^a-zA-Z0-9]")


def validate_field_match(expected: Result, actual: Result, field: str) -> None:
Expand Down Expand Up @@ -65,35 +63,40 @@ def format_new_lines(d: Result) -> Result:


def sanitize_string(input_str: str) -> str:
return NON_ALPHANUMERIC_REGEX.sub("", input_str).lower()
return "".join(char for char in input_str if char.isalnum()).lower()


def is_string_similar(actual: str, expected: str, tolerance: int = 2) -> bool:
if tolerance == 0:
return actual == expected

sanitized_actual = sanitize_string(actual)
sanitized_expected = sanitize_string(expected)

# Check if alphanumeric content matches
if sanitized_actual != sanitized_expected:
return False

diff_count = native_count_differences(actual, expected)
if diff_count <= tolerance:
return True

return SequenceMatcher(None, actual, expected).ratio() >= 0.8


def native_count_differences(actual: str, expected: str):
non_alnum_actual = "".join(char for char in actual if not char.isalnum())
non_alnum_expected = "".join(char for char in expected if not char.isalnum())

# Compare the sequence of non-alphanumeric characters with a tolerance for
# additional/missing characters
diff_count = 0
for char1, char2 in zip(non_alnum_actual, non_alnum_expected):
if char1 != char2:
diff_count += 1

# Account for length difference if one sequence is longer than the other
length_diff = abs(len(non_alnum_actual) - len(non_alnum_expected))
diff_count += length_diff

if diff_count <= tolerance:
return True

return SequenceMatcher(None, non_alnum_actual, non_alnum_expected).ratio() >= 0.7
return diff_count


def get_matcher(expected_value: Any, actual_value: Any) -> Callable[[Any, Any], bool]:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "bananalyzer"
version = "0.5.7"
version = "0.5.8"
description = "Open source AI Agent evaluation framework for web tasks 🐒🍌"
authors = ["asim-shrestha <[email protected]>"]
readme = "README.md"
Expand Down
8 changes: 7 additions & 1 deletion tests/test_evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ def test_sanitize_string(input_str, expected):
2,
True,
),
(
"160 Falmouth Road Mashpee MA 02649",
"160 Falmouth Road, Mashpee, MA 02649",
2,
True,
),
],
)
def test_is_string_similar(actual, expected, tolerance, expected_result):
Expand All @@ -87,7 +93,7 @@ def test_validate_field_match_pass(expected, actual, field):
"expected, actual, field",
[
({"field": "example"}, {"field": "example 123"}, "field"),
({"field": "short string"}, {"field": "short string!!!"}, "field"),
({"field": "short's string"}, {"field": "~~~~~~shorts string!!!"}, "field"),
({"field": [1, 2, 3]}, {"field": [1, 2]}, "field"),
],
)
Expand Down

0 comments on commit f627be9

Please sign in to comment.