Skip to content

Commit

Permalink
Merge pull request #19 from reworkd/alphanumic
Browse files Browse the repository at this point in the history
🫡 Alphanumeric scoring
  • Loading branch information
awtkns authored Nov 27, 2023
2 parents a639692 + 59b869a commit 4d2eecc
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 51 deletions.
68 changes: 20 additions & 48 deletions bananalyzer/data/schemas.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import json
from typing import Any, Dict, List, Literal, Optional, Union

import pytest
from deepdiff import DeepDiff
from playwright.async_api import Page
from pydantic import BaseModel, Field, model_validator

from bananalyzer.data.fetch_schemas import fetch_schemas
from bananalyzer.runner.evals import (
validate_end_url_match,
validate_field_match,
validate_json_match,
)

GoalType = Literal[
"fetch", # Scrape specific JSON information from a single page. Does not require navigation
Expand All @@ -18,19 +20,6 @@
]


def format_new_lines(d: Dict[str, Any]) -> Dict[str, Any]:
"""Recursively replace newlines in strings with spaces."""
new_dict: Dict[str, Any] = {}
for k, v in d.items():
if isinstance(v, dict):
new_dict[k] = format_new_lines(v)
elif isinstance(v, str):
new_dict[k] = v.replace("\n", " ")
else:
new_dict[k] = v
return new_dict


class Eval(BaseModel):
"""
Base class for all evals. Evals are used to determine if an action or result is correct
Expand All @@ -46,39 +35,22 @@ def eval_action(self, action: str) -> bool:
raise NotImplementedError("eval_action not implemented")

def eval_results(
self, page: Page, result: Union[List[str], Dict[str, Any]]
self, page: Page, result: Dict[str, Any], field: Optional[str] = None
) -> None:
if self.type == "json_match":
# TODO: We should probably code gen to remove newlines or update test data to contain new lines
if isinstance(self.expected, dict):
assert isinstance(result, dict)
self.expected = format_new_lines(self.expected)
result = format_new_lines(result)

# TODO: Pass in schema in the backend and handle this OUTSIDE of tests
# Adding missing keys in actual with None if they are expected to be None
for key, value in self.expected.items():
if value is None and key not in result:
result[key] = None

diff = DeepDiff(
self.expected,
result,
ignore_order=True,
report_repetition=True,
)
if diff:
# Pretty print both expected and actual results
pretty_expected = json.dumps(self.expected, indent=4)
pretty_actual = json.dumps(result, indent=4)

diff_msg = f"Actual: {pretty_actual}\nExpected: {pretty_expected}"
pytest.fail(f"JSONEval mismatch!\n{diff_msg}")

elif self.type == "end_url_match":
if page.url != self.expected:
diff_msg = f"Actual URL:\t{page.url}\nExpected URL:\t{self.expected}"
pytest.fail(f"URLEval mismatch!\n{diff_msg}")
if (
self.type == "json_match"
and field is not None
and type(self.expected) is dict
):
return validate_field_match(self.expected, result, field)

if self.type == "json_match" and type(self.expected) is dict:
return validate_json_match(self.expected, result)

if self.type == "end_url_match" and type(self.expected) is str:
return validate_end_url_match(self.expected, page.url)

raise NotImplementedError("No evaluation type implemented")


class Example(BaseModel):
Expand Down
100 changes: 100 additions & 0 deletions bananalyzer/runner/evals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import json
import re
from typing import Any, Callable, Dict

import pytest
from deepdiff import DeepDiff

# A scraped-result payload: an arbitrary JSON-like mapping of field -> value.
Result = Dict[str, Any]
# Matches every character that is NOT an ASCII letter or digit; used to
# normalize strings before the lenient alphanumeric comparison below.
NON_ALPHANUMERIC_REGEX = re.compile(r"[^a-zA-Z0-9]")


def validate_field_match(expected: Result, actual: Result, field: str) -> None:
    """Assert that a single field agrees between the expected and actual results.

    String values are compared leniently (see ``is_string_similar``); every
    other type must compare equal. Fails the running test on a mismatch.
    """
    want = expected.get(field, None)
    got = actual.get(field, None)

    compare = get_matcher(want, got)
    if compare(got, want):
        return
    pytest.fail(f"FieldEval mismatch!\nActual: {got}\nExpected: {want}")


def validate_json_match(expected: Result, actual: Result) -> None:
    """Deep-compare two result dicts and fail the running test on any difference.

    Newlines inside string values are normalized to spaces on both sides
    before comparison, and keys whose expected value is ``None`` are
    backfilled into ``actual`` so their absence is not reported as a diff.
    """
    if isinstance(expected, dict):
        expected = format_new_lines(expected)
        actual = format_new_lines(actual)

        # TODO: Pass in schema in the backend and handle this OUTSIDE of tests
        # Treat a missing key as present-with-None when None is expected.
        for key in expected:
            if expected[key] is None and key not in actual:
                actual[key] = None

    difference = DeepDiff(expected, actual, ignore_order=True, report_repetition=True)
    if not difference:
        return

    # Pretty print both sides so the failure output is readable.
    msg = (
        f"Actual: {json.dumps(actual, indent=4)}\n"
        f"Expected: {json.dumps(expected, indent=4)}"
    )
    pytest.fail(f"JSONEval mismatch!\n{msg}")


def validate_end_url_match(expected: str, actual: str) -> None:
    """Fail the running test unless the final page URL equals the expected URL."""
    if actual == expected:
        return
    pytest.fail(f"URLEval mismatch!\nActual URL:\t{actual}\nExpected URL:\t{expected}")


def format_new_lines(d: Result) -> Result:
    """Return a copy of *d* with newlines in string values replaced by spaces.

    Nested dictionaries are processed recursively; non-string, non-dict
    leaves are carried over untouched.
    """

    def _clean(value: Any) -> Any:
        # One-line purpose: normalize a single value, recursing into dicts.
        if isinstance(value, dict):
            return {key: _clean(inner) for key, inner in value.items()}
        if isinstance(value, str):
            return value.replace("\n", " ")
        return value

    return {key: _clean(value) for key, value in d.items()}


def sanitize_string(input_str: str) -> str:
    """Lower-case *input_str* after stripping every non-ASCII-alphanumeric character."""
    return re.sub(r"[^a-zA-Z0-9]", "", input_str).lower()


def is_string_similar(actual: str, expected: str, tolerance: int = 2) -> bool:
    """Compare two strings leniently.

    The ASCII-alphanumeric content (case-insensitive) must match exactly;
    the remaining characters may differ in at most *tolerance* positions,
    where any length difference between the two punctuation streams also
    counts toward the mismatch total.
    """

    def _alnum_only(s: str) -> str:
        # Inline of sanitize_string: keep ASCII alphanumerics, lower-case.
        return re.sub(r"[^a-zA-Z0-9]", "", s).lower()

    if _alnum_only(actual) != _alnum_only(expected):
        return False

    punct_actual = "".join(ch for ch in actual if not ch.isalnum())
    punct_expected = "".join(ch for ch in expected if not ch.isalnum())

    # Positionally compare the punctuation streams, then charge one point
    # for every extra/missing character in the longer stream.
    mismatches = sum(a != b for a, b in zip(punct_actual, punct_expected))
    mismatches += abs(len(punct_actual) - len(punct_expected))

    return mismatches <= tolerance


def get_matcher(expected_value: Any, actual_value: Any) -> Callable[[Any, Any], bool]:
    """Pick the comparison function for a pair of field values.

    Two strings are compared with the lenient ``is_string_similar``; every
    other combination falls back to plain equality.
    """
    both_strings = isinstance(expected_value, str) and isinstance(actual_value, str)
    if both_strings:
        return is_string_similar
    return lambda x, y: x == y
2 changes: 1 addition & 1 deletion bananalyzer/runner/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def _generate_eval_test(self, eval_: Eval, i: int) -> str:
return f"""
@pytest.mark.parametrize("key", {list(eval_.expected.keys())})
async def test_match_field(self, key, result) -> None:
assert self.example.evals[{i}].expected.get(key, None) == result.get(key, None)
self.example.evals[{i}].eval_results(None, result, field=key)
"""
return f"""
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "bananalyzer"
version = "0.5.4"
version = "0.5.5"
description = "Open source AI Agent evaluation framework for web tasks 🐒🍌"
authors = ["asim-shrestha <[email protected]>"]
readme = "README.md"
Expand Down
84 changes: 84 additions & 0 deletions tests/test_evals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import pytest

from bananalyzer.runner.evals import (
is_string_similar,
sanitize_string,
validate_field_match,
)


@pytest.mark.parametrize(
    "input_str, expected",
    [
        ("hello world!", "helloworld"),
        ("HELLO_WORLD", "helloworld"),
        ("Hello1 WoRlD@!", "hello1world"),
        ("", ""),
        ("123456", "123456"),
        ("!@#$%^&*()", ""),
        ("Hello World 123", "helloworld123"),
        (" ", ""),
        ("HELLO", "hello"),
        ("hello", "hello"),
        ("HelloWorld2023", "helloworld2023"),
    ],
)
def test_sanitize_string(input_str, expected):
    """sanitize_string keeps only ASCII alphanumerics and lower-cases them."""
    assert sanitize_string(input_str) == expected


@pytest.mark.parametrize(
    "actual, expected, tolerance, expected_result",
    [
        ("Hello-World", "hello world", 2, True),
        ("test123", "test!123", 2, True),
        ("string-with-chars", "stringwithchars", 2, True),
        ("text_with_underscores", "textwithunderscores", 2, True),
        ("hello", "he-llo", 1, True),
        ("string", "string!!", 2, True),
        ("foo", "foo--", 2, True),
        ("short", "s-h-o-r-t", 2, False),
        ("text", "text----", 2, False),
        ("word", "w-o-r-d-e", 3, False),
        ("name", "n-a-m-e--", 3, False),
        ("different", "diff3r3nt", 2, False),
        ("text", "txet", 2, False),
        ("hello", "world", 2, False),
        ("abc", "def", 2, False),
        ("example", "ex-ample", 0, False),
        ("", "", 2, True),
        ("a", "a-", 1, True),
        ("b", "b--", 1, False),
        ("c+", "c-", 0, False),
        ("d---", "d+++", 1, False),
        ("++e+++", "---e--", 0, False),
    ],
)
def test_is_string_similar(actual, expected, tolerance, expected_result):
    """Alphanumeric content must match exactly; non-alphanumeric characters
    may differ (positionally, plus any length difference) by at most
    ``tolerance``."""
    assert is_string_similar(actual, expected, tolerance) == expected_result


@pytest.mark.parametrize(
    "expected, actual, field",
    [
        ({"field": "Hello World"}, {"field": "Hello-World"}, "field"),
        ({"field": "test"}, {"field": "test!!"}, "field"),
        ({"field": 123}, {"field": 123}, "field"),
        ({"field": [1, 2, 3]}, {"field": [1, 2, 3]}, "field"),
    ],
)
def test_validate_field_match_pass(expected, actual, field):
    """Matching fields (lenient for strings, exact otherwise) must not fail."""
    validate_field_match(expected, actual, field)


@pytest.mark.parametrize(
    "expected, actual, field",
    [
        ({"field": "example"}, {"field": "example 123"}, "field"),
        ({"field": "short string"}, {"field": "short string!!!"}, "field"),
        ({"field": [1, 2, 3]}, {"field": [1, 2]}, "field"),
    ],
)
def test_validate_field_match_fail(expected, actual, field):
    """Mismatching fields must raise pytest's failure exception."""
    with pytest.raises(pytest.fail.Exception):
        validate_field_match(expected, actual, field)
3 changes: 2 additions & 1 deletion tests/test_example_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from pytest_mock import MockFixture

from bananalyzer.data.fetch_schemas import fetch_schemas
from bananalyzer.data.schemas import Eval, Example, format_new_lines
from bananalyzer.data.schemas import Eval, Example
from bananalyzer.runner.evals import format_new_lines


def test_format_new_lines() -> None:
Expand Down

0 comments on commit 4d2eecc

Please sign in to comment.