From 7973435a23d42ecb4b216add87ba71f3652b05e6 Mon Sep 17 00:00:00 2001 From: awtkns <32209255+awtkns@users.noreply.github.com> Date: Mon, 27 Nov 2023 14:34:32 -0800 Subject: [PATCH 1/6] =?UTF-8?q?=F0=9F=AB=A1=20Add=20alpha=20numeric=20scor?= =?UTF-8?q?ing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bananalyzer/data/schemas.py | 70 +++++++++-------------------- bananalyzer/runner/evals.py | 80 +++++++++++++++++++++++++++++++++ bananalyzer/runner/generator.py | 2 +- tests/test_evals.py | 48 ++++++++++++++++++++ tests/test_example.py | 3 +- 5 files changed, 153 insertions(+), 50 deletions(-) create mode 100644 bananalyzer/runner/evals.py create mode 100644 tests/test_evals.py diff --git a/bananalyzer/data/schemas.py b/bananalyzer/data/schemas.py index 18d3ef02..fcd5c42c 100644 --- a/bananalyzer/data/schemas.py +++ b/bananalyzer/data/schemas.py @@ -1,12 +1,14 @@ -import json from typing import Any, Dict, List, Literal, Optional, Union -import pytest -from deepdiff import DeepDiff from playwright.async_api import Page from pydantic import BaseModel, Field, model_validator from bananalyzer.data.fetch_schemas import fetch_schemas +from bananalyzer.runner.evals import ( + validate_end_url_match, + validate_field_match, + validate_json_match, +) GoalType = Literal[ "fetch", # Scrape specific JSON information from a single page. Does not require navigation @@ -18,19 +20,6 @@ ] -def format_new_lines(d: Dict[str, Any]) -> Dict[str, Any]: - """Recursively replace newlines in strings with spaces.""" - new_dict: Dict[str, Any] = {} - for k, v in d.items(): - if isinstance(v, dict): - new_dict[k] = format_new_lines(v) - elif isinstance(v, str): - new_dict[k] = v.replace("\n", " ") - else: - new_dict[k] = v - return new_dict - - class Eval(BaseModel): """ Base class for all evals. Evals are used to determine if an action or result is correct @@ -45,38 +34,23 @@ def eval_action(self, action: str) -> bool: """ raise NotImplementedError("eval_action not implemented") - def eval_results(self, page: Page, result: Dict[str, Any]) -> None: - if self.type == "json_match": - assert isinstance(self.expected, dict) - - # TODO: We should probably code gen to remove newlines or update test data to contain new lines - formatted_expected = format_new_lines(self.expected) - formatted_actual = format_new_lines(result) - - # TODO: Pass in schema in the backend and handle this OUTSIDE of tests - # Adding missing keys in actual with None if they are expected to be None - for key, value in formatted_expected.items(): - if value is None and key not in formatted_actual: - formatted_actual[key] = None - - diff = DeepDiff( - formatted_expected, - formatted_actual, - ignore_order=True, - report_repetition=True, - ) - if diff: - # Pretty print both expected and actual results - pretty_expected = json.dumps(formatted_expected, indent=4) - pretty_actual = json.dumps(formatted_actual, indent=4) - - diff_msg = f"Actual: {pretty_actual}\nExpected: {pretty_expected}" - pytest.fail(f"JSONEval mismatch!\n{diff_msg}") - - elif self.type == "end_url_match": - if page.url != self.expected: - diff_msg = f"Actual URL:\t{page.url}\nExpected URL:\t{self.expected}" - pytest.fail(f"URLEval mismatch!\n{diff_msg}") + def eval_results( + self, page: Page, result: Dict[str, Any], field: Optional[str] = None + ) -> None: + if ( + self.type == "json_match" + and field is not None + and type(self.expected) is dict + ): + return validate_field_match(self.expected, result, field) + + if self.type == "json_match" and type(self.expected) is dict: + return validate_json_match(self.expected, result) + + if self.type == "end_url_match" and type(self.expected) is str: + return validate_end_url_match(self.expected, page.url) + + raise NotImplementedError("No evaluation type implemented") class Example(BaseModel): diff --git a/bananalyzer/runner/evals.py b/bananalyzer/runner/evals.py new file mode 100644 index 00000000..788e3027 --- /dev/null +++ b/bananalyzer/runner/evals.py @@ -0,0 +1,80 @@ +import json +import re +from typing import Any, Dict + +import pytest +from deepdiff import DeepDiff + +Result = Dict[str, Any] +NON_ALPHANUMERIC_PATTERN = re.compile(r"[^a-zA-Z0-9]") + + +def validate_field_match(expected: Result, actual: Result, field: str) -> None: + expected_value = expected.get(field, None) + actual_value = actual.get(field, None) + + sanitized_expected = ( + sanitize_string(str(expected_value)) if expected_value is not None else None + ) + sanitized_actual = ( + sanitize_string(str(actual_value)) if actual_value is not None else None + ) + + if sanitized_expected != sanitized_actual: + diff_msg = f"Actual: {actual_value}\nExpected: {expected_value}" + pytest.fail(f"FieldEval mismatch!\n{diff_msg}") + + +def validate_json_match(expected: Result, actual: Result) -> None: + assert isinstance(expected, dict) + + # TODO: We should probably code gen to remove newlines or update test data to contain new lines + formatted_expected = format_new_lines(expected) + formatted_actual = format_new_lines(actual) + + # TODO: Pass in schema in the backend and handle this OUTSIDE of tests + # Adding missing keys in actual with None if they are expected to be None + for key, value in formatted_expected.items(): + if value is None and key not in formatted_actual: + formatted_actual[key] = None + + diff = DeepDiff( + formatted_expected, + formatted_actual, + ignore_order=True, + report_repetition=True, + ) + + if diff: + # Pretty print both expected and actual results + pretty_expected = json.dumps(formatted_expected, indent=4) + pretty_actual = json.dumps(formatted_actual, indent=4) + + diff_msg = f"Actual: {pretty_actual}\nExpected: {pretty_expected}" + pytest.fail(f"JSONEval mismatch!\n{diff_msg}") + + +def validate_end_url_match(expected: str, actual: str) -> None: + if actual != expected: + diff_msg = f"Actual URL:\t{actual}\nExpected URL:\t{expected}" + pytest.fail(f"URLEval mismatch!\n{diff_msg}") + + +def sanitize_string(input_str: str) -> str: + """Remove non-alphanumeric characters and convert to lowercase.""" + + sanitized = NON_ALPHANUMERIC_PATTERN.sub("", input_str) + return sanitized.lower() + + +def format_new_lines(d: Result) -> Result: + """Recursively replace newlines in strings with spaces.""" + new_dict: Result = {} + for k, v in d.items(): + if isinstance(v, dict): + new_dict[k] = format_new_lines(v) + elif isinstance(v, str): + new_dict[k] = v.replace("\n", " ") + else: + new_dict[k] = v + return new_dict diff --git a/bananalyzer/runner/generator.py b/bananalyzer/runner/generator.py index ab61f201..4927baf6 100644 --- a/bananalyzer/runner/generator.py +++ b/bananalyzer/runner/generator.py @@ -35,7 +35,7 @@ def _generate_eval_test(self, eval_: Eval, i: int) -> str: return f""" @pytest.mark.parametrize("key", {list(eval_.expected.keys())}) async def test_match_field(self, key, result) -> None: - assert self.example.evals[{i}].expected.get(key, None) == result.get(key, None) + assert self.example.evals[{i}].eval_results(None, result, field=key) """ return f""" diff --git a/tests/test_evals.py b/tests/test_evals.py new file mode 100644 index 00000000..d256dd28 --- /dev/null +++ b/tests/test_evals.py @@ -0,0 +1,48 @@ +import pytest + +from bananalyzer.runner.evals import sanitize_string, validate_field_match + + +@pytest.mark.parametrize( + "input_str, expected", + [ + ("hello world!", "helloworld"), + ("HELLO_WORLD", "helloworld"), + ("Hello1 WoRlD@!", "hello1world"), + ("", ""), + ("123456", "123456"), + ("!@#$%^&*()", ""), + ("Hello World 123", "helloworld123"), + (" ", ""), + ("HELLO", "hello"), + ("hello", "hello"), + ("HelloWorld2023", "helloworld2023"), + ], +) +def test_sanitize_string(input_str, expected): + assert sanitize_string(input_str) == expected + + +@pytest.mark.parametrize( + "expected, actual, field", + [ + ({"field": "TestValue123"}, {"field": "test value!123"}, "field"), + ({"field": "AnotherTest123"}, {"field": "another test 123"}, "field"), + ({"field": ""}, {"field": " "}, "field"), + ], +) +def test_validate_field_match(expected, actual, field): + validate_field_match(expected, actual, field) + + +@pytest.mark.parametrize( + "expected, actual, field", + [ + ({"field": "TestValue123"}, {"field": "DifferentValue123"}, "field"), + ({"field": None}, {"field": "testvalue"}, "field"), + ({"field": "Value"}, {"other_field": "Value"}, "field"), + ], +) +def test_validate_field_match_fail(expected, actual, field): + with pytest.raises(pytest.fail.Exception): + validate_field_match(expected, actual, field) diff --git a/tests/test_example.py b/tests/test_example.py index e00d9358..7ebf27f4 100644 --- a/tests/test_example.py +++ b/tests/test_example.py @@ -5,7 +5,8 @@ from pydantic import ValidationError from bananalyzer.data.fetch_schemas import fetch_schemas -from bananalyzer.data.schemas import Eval, Example, format_new_lines +from bananalyzer.data.schemas import Eval, Example +from bananalyzer.runner.evals import format_new_lines def test_format_new_lines() -> None: From 4e3cc538c51621d7f92e214d3e7e45e0e6d66ac7 Mon Sep 17 00:00:00 2001 From: awtkns <32209255+awtkns@users.noreply.github.com> Date: Mon, 27 Nov 2023 14:37:53 -0800 Subject: [PATCH 2/6] =?UTF-8?q?=F0=9F=AB=A1=20Add=20alpha=20numeric=20scor?= =?UTF-8?q?ing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bananalyzer/runner/generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bananalyzer/runner/generator.py b/bananalyzer/runner/generator.py index 4927baf6..f8e4c690 100644 --- a/bananalyzer/runner/generator.py +++ b/bananalyzer/runner/generator.py @@ -35,7 +35,7 @@ def _generate_eval_test(self, eval_: Eval, i: int) -> str: return f""" @pytest.mark.parametrize("key", {list(eval_.expected.keys())}) async def test_match_field(self, key, result) -> None: - assert self.example.evals[{i}].eval_results(None, result, field=key) + self.example.evals[{i}].eval_results(None, result, field=key) """ return f""" From 578474a149f46c231be6a45109f94d44e263bbf3 Mon Sep 17 00:00:00 2001 From: awtkns <32209255+awtkns@users.noreply.github.com> Date: Mon, 27 Nov 2023 15:12:21 -0800 Subject: [PATCH 3/6] =?UTF-8?q?=F0=9F=AB=A1=20Add=20a=20tolerance=20of=20t?= =?UTF-8?q?wo?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bananalyzer/runner/evals.py | 32 +++++++++++++----------- pyproject.toml | 2 +- tests/test_evals.py | 49 +++++++++++++++++++++++++++++++------ 3 files changed, 60 insertions(+), 23 deletions(-) diff --git a/bananalyzer/runner/evals.py b/bananalyzer/runner/evals.py index 788e3027..983b3dbc 100644 --- a/bananalyzer/runner/evals.py +++ b/bananalyzer/runner/evals.py @@ -6,21 +6,32 @@ from deepdiff import DeepDiff Result = Dict[str, Any] -NON_ALPHANUMERIC_PATTERN = re.compile(r"[^a-zA-Z0-9]") +NON_ALPHANUMERIC_REGEX = re.compile(r"[^a-zA-Z0-9]") + + +def sanitize_string(input_str: str) -> str: + return NON_ALPHANUMERIC_REGEX.sub("", input_str).lower() + + +def is_string_similar(actual: str, expected: str, tolerance: int = 2) -> bool: + length_difference = abs(len(actual) - len(expected)) + sanitized_actual = sanitize_string(actual) + sanitized_expected = sanitize_string(expected) + + return sanitized_actual == sanitized_expected and length_difference <= tolerance def validate_field_match(expected: Result, actual: Result, field: str) -> None: expected_value = expected.get(field, None) actual_value = actual.get(field, None) - sanitized_expected = ( - sanitize_string(str(expected_value)) if expected_value is not None else None - ) - sanitized_actual = ( - sanitize_string(str(actual_value)) if actual_value is not None else None + matcher = ( + is_string_similar + if isinstance(expected_value, str) and isinstance(actual_value, str) + else lambda x, y: x == y ) - if sanitized_expected != sanitized_actual: + if not matcher(actual_value, expected_value): diff_msg = f"Actual: {actual_value}\nExpected: {expected_value}" pytest.fail(f"FieldEval mismatch!\n{diff_msg}") @@ -60,13 +71,6 @@ def validate_end_url_match(expected: str, actual: str) -> None: pytest.fail(f"URLEval mismatch!\n{diff_msg}") -def sanitize_string(input_str: str) -> str: - """Remove non-alphanumeric characters and convert to lowercase.""" - - sanitized = NON_ALPHANUMERIC_PATTERN.sub("", input_str) - return sanitized.lower() - - def format_new_lines(d: Result) -> Result: """Recursively replace newlines in strings with spaces.""" new_dict: Result = {} diff --git a/pyproject.toml b/pyproject.toml index 6886b565..dc8649ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "bananalyzer" -version = "0.5.1" +version = "0.5.2" description = "Open source AI Agent evaluation framework for web tasks 🐒🍌" authors = ["asim-shrestha "] readme = "README.md" diff --git a/tests/test_evals.py b/tests/test_evals.py index d256dd28..027dadd3 100644 --- a/tests/test_evals.py +++ b/tests/test_evals.py @@ -1,6 +1,10 @@ import pytest -from bananalyzer.runner.evals import sanitize_string, validate_field_match +from bananalyzer.runner.evals import ( + is_string_similar, + sanitize_string, + validate_field_match, +) @pytest.mark.parametrize( @@ -23,24 +27,53 @@ def test_sanitize_string(input_str, expected): assert sanitize_string(input_str) == expected +@pytest.mark.parametrize( + "actual, expected, tolerance, expected_result", + [ + ("Hello-World", "hello world", 2, True), + ("test123", "test!123", 2, True), + ("string-with-chars", "stringwithchars", 2, True), + ("text_with_underscores", "textwithunderscores", 2, True), + ("hello", "he-llo", 1, True), + ("string", "string!!", 2, True), + ("foo", "foo--", 2, True), + ("short", "s-h-o-r-t", 2, False), + ("text", "text----", 2, False), + ("word", "w-o-r-d-e", 3, False), + ("name", "n-a-m-e--", 3, False), + ("different", "diff3r3nt", 2, False), + ("text", "txet", 2, False), + ("hello", "world", 2, False), + ("abc", "def", 2, False), + ("example", "ex-ample", 0, False), + ("", "", 2, True), + ("a", "a-", 1, True), + ("b", "b--", 1, False), + ], +) +def test_is_string_similar(actual, expected, tolerance, expected_result): + assert is_string_similar(actual, expected, tolerance) == expected_result + + @pytest.mark.parametrize( "expected, actual, field", [ - ({"field": "TestValue123"}, {"field": "test value!123"}, "field"), - ({"field": "AnotherTest123"}, {"field": "another test 123"}, "field"), - ({"field": ""}, {"field": " "}, "field"), + ({"field": "Hello World"}, {"field": "Hello-World"}, "field"), + ({"field": "test"}, {"field": "test!!"}, "field"), + ({"field": 123}, {"field": 123}, "field"), + ({"field": [1, 2, 3]}, {"field": [1, 2, 3]}, "field"), ], ) -def test_validate_field_match(expected, actual, field): +def test_validate_field_match_pass(expected, actual, field): validate_field_match(expected, actual, field) @pytest.mark.parametrize( "expected, actual, field", [ - ({"field": "TestValue123"}, {"field": "DifferentValue123"}, "field"), - ({"field": None}, {"field": "testvalue"}, "field"), - ({"field": "Value"}, {"other_field": "Value"}, "field"), + ({"field": "example"}, {"field": "example 123"}, "field"), + ({"field": "short string"}, {"field": "short string!!!"}, "field"), + ({"field": [1, 2, 3]}, {"field": [1, 2]}, "field"), ], ) def test_validate_field_match_fail(expected, actual, field): From 42a158a0818891173039e0aeb5852ee7184eae1b Mon Sep 17 00:00:00 2001 From: awtkns <32209255+awtkns@users.noreply.github.com> Date: Mon, 27 Nov 2023 15:25:00 -0800 Subject: [PATCH 4/6] =?UTF-8?q?=F0=9F=AB=A1=20Add=20a=20tolerance=20of=20t?= =?UTF-8?q?wo?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bananalyzer/runner/evals.py | 21 +++++++++++++++++++-- tests/test_evals.py | 3 +++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/bananalyzer/runner/evals.py b/bananalyzer/runner/evals.py index 983b3dbc..79d35664 100644 --- a/bananalyzer/runner/evals.py +++ b/bananalyzer/runner/evals.py @@ -14,11 +14,28 @@ def sanitize_string(input_str: str) -> str: def is_string_similar(actual: str, expected: str, tolerance: int = 2) -> bool: - length_difference = abs(len(actual) - len(expected)) sanitized_actual = sanitize_string(actual) sanitized_expected = sanitize_string(expected) - return sanitized_actual == sanitized_expected and length_difference <= tolerance + # Check if alphanumeric content matches + if sanitized_actual != sanitized_expected: + return False + + non_alnum_actual = ''.join(char for char in actual if not char.isalnum()) + non_alnum_expected = ''.join(char for char in expected if not char.isalnum()) + + # Compare the sequence of non-alphanumeric characters with a tolerance for + # additional/missing characters + diff_count = 0 + for char1, char2 in zip(non_alnum_actual, non_alnum_expected): + if char1 != char2: + diff_count += 1 + + # Account for length difference if one sequence is longer than the other + length_diff = abs(len(non_alnum_actual) - len(non_alnum_expected)) + diff_count += length_diff + + return diff_count <= tolerance def validate_field_match(expected: Result, actual: Result, field: str) -> None: diff --git a/tests/test_evals.py b/tests/test_evals.py index 027dadd3..c7528011 100644 --- a/tests/test_evals.py +++ b/tests/test_evals.py @@ -49,6 +49,9 @@ def test_sanitize_string(input_str, expected): ("", "", 2, True), ("a", "a-", 1, True), ("b", "b--", 1, False), + ("c+", "c-", 0, False), + ("d---", "d+++", 1, False), + ("++e+++", "---e--", 0, False), ], ) def test_is_string_similar(actual, expected, tolerance, expected_result): From f92d5b47145f60de8c53912a388da93b701860d5 Mon Sep 17 00:00:00 2001 From: awtkns <32209255+awtkns@users.noreply.github.com> Date: Mon, 27 Nov 2023 15:28:40 -0800 Subject: [PATCH 5/6] =?UTF-8?q?=F0=9F=AB=A1=20Add=20a=20tolerance=20of=20t?= =?UTF-8?q?wo?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bananalyzer/runner/evals.py | 74 +++++++++++++++++++------------------ 1 file changed, 38 insertions(+), 36 deletions(-) diff --git a/bananalyzer/runner/evals.py b/bananalyzer/runner/evals.py index 79d35664..139b0fa7 100644 --- a/bananalyzer/runner/evals.py +++ b/bananalyzer/runner/evals.py @@ -1,6 +1,6 @@ import json import re -from typing import Any, Dict +from typing import Any, Callable, Dict import pytest from deepdiff import DeepDiff @@ -9,45 +9,11 @@ NON_ALPHANUMERIC_REGEX = re.compile(r"[^a-zA-Z0-9]") -def sanitize_string(input_str: str) -> str: - return NON_ALPHANUMERIC_REGEX.sub("", input_str).lower() - - -def is_string_similar(actual: str, expected: str, tolerance: int = 2) -> bool: - sanitized_actual = sanitize_string(actual) - sanitized_expected = sanitize_string(expected) - - # Check if alphanumeric content matches - if sanitized_actual != sanitized_expected: - return False - - non_alnum_actual = ''.join(char for char in actual if not char.isalnum()) - non_alnum_expected = ''.join(char for char in expected if not char.isalnum()) - - # Compare the sequence of non-alphanumeric characters with a tolerance for - # additional/missing characters - diff_count = 0 - for char1, char2 in zip(non_alnum_actual, non_alnum_expected): - if char1 != char2: - diff_count += 1 - - # Account for length difference if one sequence is longer than the other - length_diff = abs(len(non_alnum_actual) - len(non_alnum_expected)) - diff_count += length_diff - - return diff_count <= tolerance - - def validate_field_match(expected: Result, actual: Result, field: str) -> None: expected_value = expected.get(field, None) actual_value = actual.get(field, None) - matcher = ( - is_string_similar - if isinstance(expected_value, str) and isinstance(actual_value, str) - else lambda x, y: x == y - ) - + matcher = get_matcher(expected_value, actual_value) if not matcher(actual_value, expected_value): diff_msg = f"Actual: {actual_value}\nExpected: {expected_value}" pytest.fail(f"FieldEval mismatch!\n{diff_msg}") @@ -99,3 +65,39 @@ def format_new_lines(d: Result) -> Result: else: new_dict[k] = v return new_dict + + +def sanitize_string(input_str: str) -> str: + return NON_ALPHANUMERIC_REGEX.sub("", input_str).lower() + + +def is_string_similar(actual: str, expected: str, tolerance: int = 2) -> bool: + sanitized_actual = sanitize_string(actual) + sanitized_expected = sanitize_string(expected) + + # Check if alphanumeric content matches + if sanitized_actual != sanitized_expected: + return False + + non_alnum_actual = "".join(char for char in actual if not char.isalnum()) + non_alnum_expected = "".join(char for char in expected if not char.isalnum()) + + # Compare the sequence of non-alphanumeric characters with a tolerance for + # additional/missing characters + diff_count = 0 + for char1, char2 in zip(non_alnum_actual, non_alnum_expected): + if char1 != char2: + diff_count += 1 + + # Account for length difference if one sequence is longer than the other + length_diff = abs(len(non_alnum_actual) - len(non_alnum_expected)) + diff_count += length_diff + + return diff_count <= tolerance + + +def get_matcher(expected_value: Any, actual_value: Any) -> Callable[[Any, Any], bool]: + if isinstance(expected_value, str) and isinstance(actual_value, str): + return is_string_similar + else: + return lambda x, y: x == y From 59b869a7d6886ca08ce4be8b4dbb278c00001497 Mon Sep 17 00:00:00 2001 From: awtkns <32209255+awtkns@users.noreply.github.com> Date: Mon, 27 Nov 2023 15:33:32 -0800 Subject: [PATCH 6/6] =?UTF-8?q?=F0=9F=AB=A1=20Add=20a=20tolerance=20of=20t?= =?UTF-8?q?wo?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bananalyzer/runner/evals.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/bananalyzer/runner/evals.py b/bananalyzer/runner/evals.py index 139b0fa7..968b8683 100644 --- a/bananalyzer/runner/evals.py +++ b/bananalyzer/runner/evals.py @@ -20,29 +20,26 @@ def validate_field_match(expected: Result, actual: Result, field: str) -> None: def validate_json_match(expected: Result, actual: Result) -> None: - assert isinstance(expected, dict) + if isinstance(expected, dict): + expected = format_new_lines(expected) + actual = format_new_lines(actual) - # TODO: We should probably code gen to remove newlines or update test data to contain new lines - formatted_expected = format_new_lines(expected) - formatted_actual = format_new_lines(actual) - - # TODO: Pass in schema in the backend and handle this OUTSIDE of tests - # Adding missing keys in actual with None if they are expected to be None - for key, value in formatted_expected.items(): - if value is None and key not in formatted_actual: - formatted_actual[key] = None + # TODO: Pass in schema in the backend and handle this OUTSIDE of tests + # Adding missing keys in actual with None if they are expected to be None + for key, value in expected.items(): + if value is None and key not in actual: + actual[key] = None diff = DeepDiff( - formatted_expected, - formatted_actual, + expected, + actual, ignore_order=True, report_repetition=True, ) - if diff: # Pretty print both expected and actual results - pretty_expected = json.dumps(formatted_expected, indent=4) - pretty_actual = json.dumps(formatted_actual, indent=4) + pretty_expected = json.dumps(expected, indent=4) + pretty_actual = json.dumps(actual, indent=4) diff_msg = f"Actual: {pretty_actual}\nExpected: {pretty_expected}" pytest.fail(f"JSONEval mismatch!\n{diff_msg}")