diff --git a/.github/actions/install-internal-pip/action.yml b/.github/actions/install-internal-pip/action.yml index 912fe1f9fe..aeb4619af7 100644 --- a/.github/actions/install-internal-pip/action.yml +++ b/.github/actions/install-internal-pip/action.yml @@ -30,4 +30,5 @@ runs: else URL="git+ssh://git@${{ inputs.host }}/${{ inputs.repo }}.git" fi - pip install "$URL" ${{ inputs.pip-extra-args }} \ No newline at end of file + echo "Installing from URL: $URL" + pip install --no-cache-dir --force-reinstall "$URL" ${{ inputs.pip-extra-args }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5b1b709f15..6f81ade2c6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,7 +28,7 @@ repos: hooks: - id: enforce-relative-imports name: Enforce Relative Imports - entry: python utils/enforce_relative_imports.py + entry: python3 utils/enforce_relative_imports.py language: system # Adjust the files pattern to match your needs files: ^src/.*\.py$ @@ -40,7 +40,7 @@ repos: hooks: - id: enforce-library-imports name: Enforce Library Imports - entry: python utils/enforce_library_imports.py + entry: python3 utils/enforce_library_imports.py language: system # Adjust the files pattern to match your needs exclude: (^src/.*\.py$)|utils/enforce_library_imports.py|utils/enforce_relative_imports.py diff --git a/examples/evaluate_tool_calling_with_reflection.py b/examples/evaluate_tool_calling_with_reflection.py index 47da026752..2d69369dde 100644 --- a/examples/evaluate_tool_calling_with_reflection.py +++ b/examples/evaluate_tool_calling_with_reflection.py @@ -64,7 +64,10 @@ test_set=data, split="test", format="formats.chat_api", - metrics=["metrics.tool_calling.reflection.syntactic"], + metrics=[ + "metrics.tool_calling.reflection.syntactic", + "metrics.tool_calling.reflection", + ], max_test_instances=10, ) diff --git a/prepare/metrics/tool_calling.py b/prepare/metrics/tool_calling.py index c4297d0a92..b01bb1dc86 100644 --- a/prepare/metrics/tool_calling.py +++ b/prepare/metrics/tool_calling.py @@ -1,6 +1,7 @@ from unitxt.catalog import add_to_catalog from unitxt.metrics import ( MultiTurnToolCallingMetric, + ReflectionToolCallingMetric, ReflectionToolCallingMetricSyntactic, ToolCallingMetric, ToolCallKeyValueExtraction, @@ -48,15 +49,23 @@ add_to_catalog( MultiTurnToolCallingMetric( - __description__="""Metric that evaluates tool call predictions for the validity with regards to the tools schema.""" + __description__="""A metric that assesses tool call predictions for their conformity to the tool schema.""" ), "metrics.tool_calling.multi_turn.validity", overwrite=True, ) +add_to_catalog( + ReflectionToolCallingMetric( + __description__="""A metric that assesses tool call predictions for both syntactic correctness and semantic validity, using predefined checks combined with LLM-based evaluations. For each instance, it returns a score reflecting its overall validity, as well as a breakdown of the specific checks/metrics that passed or failed, including hallucination check, value format alignment, function selection and agentic constraints satisfaction. 
Each metric also contains an evidence from the input, an explanation describing the reflection decision, a confidence, and a validity score with a range of 1-5 (higher score -> more valid).""" + ), + "metrics.tool_calling.reflection", + overwrite=True, +) + add_to_catalog( ReflectionToolCallingMetricSyntactic( - __description__="""This metric evaluates whether a model's tool call outputs are structurally valid by checking their compliance with the provided tool schema. For each instance, it returns a binary score (True for valid, False for invalid), and aggregates these into a global percentage across all instances. The evaluation covers a wide range of possible issues, including nonexistent functions or parameters, incorrect parameter types, missing required parameters, values outside allowed ranges, JSON schema violations, invalid or empty API specifications, and malformed tool calls. The main reported score, overall_valid (aliased as score), reflects the proportion of calls that are fully valid, making the metric a measure of syntactic and schema-level correctness rather than semantic accuracy.""" + __description__="""This metric evaluates whether a model's tool call outputs are structurally valid by checking their compliance with the provided tool schema. For each instance, it returns a binary score (True for valid, False for invalid), and aggregates these into a global percentage across all instances. The evaluation covers a wide range of possible issues, including nonexistent functions or parameters, incorrect parameter types, missing required parameters, values outside allowed ranges, JSON schema violations, invalid or empty API specifications, and malformed tool calls. The main reported score, overall_valid (aliased as score), reflects the proportion of calls that are fully valid, making the metric a measure of syntactic and schema-level correctness rather than semantic accuracy. Each metric also contains an explanation describing the errors that it detected (if no errors were found - the explanation will be None).""" ), "metrics.tool_calling.reflection.syntactic", overwrite=True, diff --git a/src/unitxt/catalog/metrics/tool_calling/multi_turn/validity.json b/src/unitxt/catalog/metrics/tool_calling/multi_turn/validity.json index d15cd09f49..5df68bdbb5 100644 --- a/src/unitxt/catalog/metrics/tool_calling/multi_turn/validity.json +++ b/src/unitxt/catalog/metrics/tool_calling/multi_turn/validity.json @@ -1,4 +1,4 @@ { "__type__": "multi_turn_tool_calling_metric", - "__description__": "Metric that evaluates tool call predictions for the validity with regards to the tools schema." + "__description__": "A metric that assesses tool call predictions for their conformity to the tool schema." } diff --git a/src/unitxt/catalog/metrics/tool_calling/reflection.json b/src/unitxt/catalog/metrics/tool_calling/reflection.json new file mode 100644 index 0000000000..7b52d5ec12 --- /dev/null +++ b/src/unitxt/catalog/metrics/tool_calling/reflection.json @@ -0,0 +1,4 @@ +{ + "__type__": "reflection_tool_calling_metric", + "__description__": "A metric that assesses tool call predictions for both syntactic correctness and semantic validity, using predefined checks combined with LLM-based evaluations. For each instance, it returns a score reflecting its overall validity, as well as a breakdown of the specific checks/metrics that passed or failed, including hallucination check, value format alignment, function selection and agentic constraints satisfaction. 
Each metric also contains an evidence from the input, an explanation describing the reflection decision, a confidence, and a validity score with a range of 1-5 (higher score -> more valid)." +} diff --git a/src/unitxt/catalog/metrics/tool_calling/reflection/syntactic.json b/src/unitxt/catalog/metrics/tool_calling/reflection/syntactic.json index 85b5a37227..d4a1e4bf8b 100644 --- a/src/unitxt/catalog/metrics/tool_calling/reflection/syntactic.json +++ b/src/unitxt/catalog/metrics/tool_calling/reflection/syntactic.json @@ -1,4 +1,4 @@ { "__type__": "reflection_tool_calling_metric_syntactic", - "__description__": "This metric evaluates whether a model's tool call outputs are structurally valid by checking their compliance with the provided tool schema. For each instance, it returns a binary score (True for valid, False for invalid), and aggregates these into a global percentage across all instances. The evaluation covers a wide range of possible issues, including nonexistent functions or parameters, incorrect parameter types, missing required parameters, values outside allowed ranges, JSON schema violations, invalid or empty API specifications, and malformed tool calls. The main reported score, overall_valid (aliased as score), reflects the proportion of calls that are fully valid, making the metric a measure of syntactic and schema-level correctness rather than semantic accuracy." + "__description__": "This metric evaluates whether a model's tool call outputs are structurally valid by checking their compliance with the provided tool schema. For each instance, it returns a binary score (True for valid, False for invalid), and aggregates these into a global percentage across all instances. The evaluation covers a wide range of possible issues, including nonexistent functions or parameters, incorrect parameter types, missing required parameters, values outside allowed ranges, JSON schema violations, invalid or empty API specifications, and malformed tool calls. The main reported score, overall_valid (aliased as score), reflects the proportion of calls that are fully valid, making the metric a measure of syntactic and schema-level correctness rather than semantic accuracy. Each metric also contains an explanation describing the errors that it detected (if no errors were found - the explanation will be None)." } diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py index 7e93751d4b..0aff478446 100644 --- a/src/unitxt/metrics.py +++ b/src/unitxt/metrics.py @@ -1,4 +1,5 @@ import ast +import asyncio import json import math import os @@ -17,6 +18,7 @@ Dict, Generator, Generic, + Iterable, List, Literal, Optional, @@ -891,13 +893,307 @@ def map( } +class ReflectionToolCallingMixin: + @staticmethod + def convert_tools_inventory(tools): + from llmevalkit.function_calling.pipeline.types import ( + ToolSpec as LLMEvalKitToolSpec, + ) + + return [ + LLMEvalKitToolSpec( + type="function", + function={**tool}, + ) + for tool in tools + ] + + @staticmethod + def convert_tool_call(prediction: ToolCall): + from llmevalkit.function_calling.pipeline.types import ( + ToolCall as LLMEvalKitToolCall, + ) + + return LLMEvalKitToolCall( + type="function", + function={ + "name": prediction["name"], + "arguments": json.dumps(prediction["arguments"]), + "parsed_arguments": prediction["arguments"], + }, + ) + + +class ReflectionToolCallingMetric(ReductionInstanceMetric[str, Dict[str, float]]): + """Measures syntactic and semantic validity of tool calls. 
+ + The final output contains two main fields: "semantic" and "static" (i.e., semantic). + Under the semantics we define two types of metrics: general and function selection. + + General metrics evaluate the overall quality and correctness of the tool call. + These metrics contains: + 1. General hallucination check: Evaluate whether each parameter value in the function call is correct and directly supported by the provided conversation history and adhere the tool specifications. + 2. Value format alignment: Check if the format of the parameter values aligns with the expected formats defined in the tool specifications. + + Function selection metrics evaluate the appropriateness of the selected function for the given context. + These metrics include: + 1. Function selection appropriateness: Assess whether the chosen function is suitable for the task at hand. + 2. Agentic constraints satisfaction: Assess whether the proposed tool call satisfies all agentic constraints required for execution. + + Static metrics evaluate the syntactic validity of the tool call. + It contains the following metrics: + - non_existent_function: tool name not found. + - non_existent_parameter: argument name not in tool spec. + - incorrect_parameter_type: argument type mismatch. + - missing_required_parameter: required argument missing. + - allowed_values_violation: argument value outside allowed set. + - json_schema_violation: call violates JSON schema. + - empty_api_spec: no tool spec provided. + - invalid_api_spec: tool spec is invalid. + - invalid_tool_call: call is not a valid tool invocation. + - overall_valid: validity of the call (main score). + - score: alias of overall_valid. + + Here is an example for a aggregated reflection output after calling reduce. + The range of each score is [0, 1] (where higher indicates less errors). + { + "static_non_existent_function": 1.0, + "static_non_existent_parameter": 1.0, + "static_incorrect_parameter_type": 1.0, + "static_missing_required_parameter": 1.0, + "static_allowed_values_violation": 1.0, + "static_json_schema_violation": 1.0, + "static_empty_api_spec": 1.0, + "static_invalid_api_spec": 1.0, + "static_invalid_tool_call": 1.0, + "semantic_general_hallucination_check": 0.0, + "semantic_general_value_format_alignment": 0.0, + "semantic_avg_score_general": 1.0, + "semantic_function_selection_appropriateness": 0.0, + "semantic_agentic_constraints_satisfaction": 0.0, + "semantic_avg_score_function_selection": 1.0, + "overall_valid": 1.0 + } + + Where overall_valid is the final decision made by the reflection pipeline, indicating whether the tool call is valid or not. + + Before the aggregation each metric contains also evidence, explanation, a more fine-grained score, etc. + + Reference: https://github.ibm.com/MLT/LLMEvalKit + """ + + main_score = "overall_valid" + reduction = MeanReduction() + prediction_type = ToolCall + _requirements_list = { + "llmevalkit": "Install with \"pip install 'git+ssh://git@github.ibm.com/MLT/LLMEvalKit.git'\".\nTo gain access please reach the team." 
+ } + runtime_pipeline: bool = True # Whether to use the runtime pipeline or the longer evaluation pipeline with actionable recommendations + + def prepare(self): + provider_to_default_reflector_model = { + "watsonx": "meta-llama/llama-4-maverick-17b-128e-instruct-fp8", + "open-ai": "gpt-4o", + "rits": "openai/gpt-oss-120b", + "azure": "gpt-4o", + "mock": "mock", + } + provider = ( + settings.default_provider if not settings.mock_inference_mode else "mock" + ) + if provider not in provider_to_default_reflector_model: + raise ValueError( + f"Unsupported provider for ReflectionToolCallingMetric: {provider}. Supported providers are: {list(provider_to_default_reflector_model.keys())}" + ) + self.requirements = self._get_missing_requirements_by_provider(provider) + super().prepare() + self.setup_pipeline( + reflector_model_name=provider_to_default_reflector_model.get(provider), + provider_name=provider, + ) + + def setup_pipeline( + self, reflector_model_name: str, provider_name: Optional[str] = None + ): + if provider_name: + llmeval_provider_name = self._get_llmeval_provider_name(provider_name) + requirements = self._get_missing_requirements_by_provider(provider_name) + self.check_missing_requirements(requirements) + + metrics_client = self._get_metrics_client( + llmeval_provider_name, reflector_model_name + ) + self.reflection_pipeline = self._build_reflection_pipeline(metrics_client) + return self.reflection_pipeline + + @staticmethod + def _get_llmeval_provider_name(provider_name: str) -> str: + mapping = { + "watsonx": "watsonx.output_val", + "open-ai": "openai.async.output_val", + "rits": "litellm.rits.output_val", + "azure": "azure_openai.async.output_val", + "mock": "mock.output_val", + } + llmeval_provider_name = mapping.get(provider_name) + if llmeval_provider_name is None: + raise ValueError(f"Unsupported provider by llmevalkit: {provider_name}") + return llmeval_provider_name + + @staticmethod + def _get_missing_requirements_by_provider(provider_name: str): + provider_libs = { + "watsonx": "ibm_watsonx_ai", + "open-ai": "openai", + "rits": "litellm", + "azure": "openai", + } + required_lib = provider_libs.get(provider_name) + return [required_lib] if required_lib else [] + + @staticmethod + def _get_metrics_client(llmeval_provider_name: str, reflector_model_name: str): + from llmevalkit.llm import get_llm + + metrics_client_cls = get_llm(llmeval_provider_name) + return metrics_client_cls(model_name=reflector_model_name) + + def _build_reflection_pipeline(self, metrics_client): + from llmevalkit.function_calling.consts import ( + METRIC_AGENTIC_CONSTRAINTS_SATISFACTION, + METRIC_FUNCTION_SELECTION_APPROPRIATENESS, + METRIC_GENERAL_HALLUCINATION_CHECK, + METRIC_GENERAL_VALUE_FORMAT_ALIGNMENT, + ) + from llmevalkit.function_calling.pipeline.pipeline import ReflectionPipeline + + return ReflectionPipeline( + metrics_client=metrics_client, + general_metrics=[ + METRIC_GENERAL_HALLUCINATION_CHECK, + METRIC_GENERAL_VALUE_FORMAT_ALIGNMENT, + ], + function_metrics=[ + METRIC_FUNCTION_SELECTION_APPROPRIATENESS, + METRIC_AGENTIC_CONSTRAINTS_SATISFACTION, + ], + parameter_metrics=[], + runtime_pipeline=self.runtime_pipeline, + ) + + async def map( + self, + prediction: ToolCall, + references: None, + task_data: Dict[str, Any], + ): + from llmevalkit.function_calling.pipeline.types import PipelineResult + + # Convert unitxt dialog to LLMEvalKit format + if "dialog" in task_data: + conversation_history = [dict(turn) for turn in task_data["dialog"]] + elif "query" in task_data: + 
conversation_history = [{"role": "user", "content": task_data["query"]}] + else: + raise ValueError("task_data must contain either 'dialog' or 'query' field.") + + # Convert unitxt tool inventory to LLMEvalKit format + tools_inventory = ReflectionToolCallingMixin.convert_tools_inventory( + task_data.get("tools", []) + ) + + # Convert unitxt tool call to LLMEvalKit format + tool_call_converted = ReflectionToolCallingMixin.convert_tool_call(prediction) + + # Run reflection (syntactic + semantic) + result: PipelineResult = await self.reflection_pipeline.run_async( + conversation=conversation_history, + inventory=tools_inventory, + call=tool_call_converted, + retries=3, + continue_on_static=True, + ) + return result.model_dump() + + def map_stream( + self, + items: Iterable[Tuple[ToolCall, None, Dict[str, Any]]], + *, + max_concurrency: int = 8, + ) -> List[Dict[str, Any]]: + """Run self.map in parallel over an iterable and return results in order.""" + + async def process_all(): + items_iter = iter(enumerate(items)) + results = [] + pending = set() + while True: + while len(pending) < max_concurrency: + try: + idx, (pred, refs, data) = next(items_iter) + if isinstance(pred, list): + for p in pred: + task = asyncio.create_task(self.map(p, refs, data)) + task.idx = idx + pending.add(task) + else: + task = asyncio.create_task(self.map(pred, refs, data)) + task.idx = idx + pending.add(task) + except StopIteration: + break + if not pending: + break + done, pending = await asyncio.wait( + pending, return_when=asyncio.FIRST_COMPLETED + ) + for task in done: + results.append((task.idx, await task)) + results.sort() + return [r for _, r in results] + + return asyncio.run(process_all()) + + def reduce_one(self, intermidate: Dict[str, Any]) -> Dict[str, float]: + return intermidate + + def reduce(self, intermidates: List[Dict[str, Any]]) -> Dict[str, float]: + flat_instances = [] + for instance in intermidates: + flat_instance_dict = {} + for metric, metric_type_dict in ( + instance.get("static", {}).get("metrics", {}).items() + ): + flat_instance_dict[f"static_{metric}"] = float( + metric_type_dict["valid"] + ) + + for metric_type, metric_type_dict in instance.get("semantic", {}).items(): + if metric_type_dict is not None: + for metric, metric_dict in metric_type_dict.get( + "metrics", {} + ).items(): + flat_instance_dict[f"semantic_{metric}"] = float( + metric_dict["is_issue"] + ) + flat_instance_dict[f"semantic_avg_score_{metric_type}"] = float( + metric_type_dict.get("avg_score") + ) + + flat_instance_dict["overall_valid"] = float(instance["overall_valid"]) + flat_instances.append(flat_instance_dict) + + return self.reduction.reduce(flat_instances) + + class ReflectionToolCallingMetricSyntactic( ReductionInstanceMetric[str, Dict[str, float]] ): """Measures syntactic and schema validity of tool calls. - Range: [0, 1] (higher is better) - Returns 1.0 if the tool call is valid (all checks pass), 0.0 otherwise. + Range: [0, 1] (higher indicates less errors). + Returns 1.0 if the tool call is valid for each metric, 0.0 otherwise. + overall_valid equals 1.0 if all metrics are valid, 0.0 otherwise. Global score is the percentage of valid instances across the dataset. 
Scores: @@ -930,47 +1226,39 @@ def map( task_data: Dict[str, Any], ) -> Dict[str, float]: from llmevalkit.function_calling.pipeline.pipeline import ReflectionPipeline - from llmevalkit.function_calling.pipeline.types import ( - ToolCall as LLMEvalKitToolCall, - ) - from llmevalkit.function_calling.pipeline.types import ( - ToolSpec as LLMEvalKitToolSpec, - ) # Convert unitxt tool inventory to LLMEvalKit format - tools_inventory = [] - for tool in task_data.get("tools", []): - tools_inventory.append( - LLMEvalKitToolSpec( - type="function", - function={**tool}, - ) - ) + tools_inventory = ReflectionToolCallingMixin.convert_tools_inventory( + task_data.get("tools", []) + ) # Convert unitxt tool call to LLMEvalKit format - tool_call = LLMEvalKitToolCall( - type="function", - function={ - "name": prediction["name"], - "arguments": json.dumps(prediction["arguments"]), - "parsed_arguments": prediction["arguments"], - }, - ) + tool_call = ReflectionToolCallingMixin.convert_tool_call(prediction) # Run static validation static_result = ReflectionPipeline.static_only(tools_inventory, tool_call) - result_dict = { - ( - metric_name - if metric_name != "json_schema_validation" - else "json_schema_violation" - ): not metric_dict.valid - for metric_name, metric_dict in static_result.metrics.items() - } - result_dict["overall_valid"] = static_result.final_decision + result_dict = static_result.model_dump() + result_dict["overall_valid"] = result_dict.pop("final_decision") + result_dict["metrics"]["json_schema_violation"] = result_dict["metrics"].pop( + "json_schema_validation" + ) return result_dict + def reduce_one(self, intermidate: Dict[str, float]) -> Dict[str, float]: + return intermidate + + def reduce(self, intermediates: List[Dict[str, float]]) -> Dict[str, float]: + flat_instances = [] + for instance in intermediates: + flat_instance_dict = {} + for metric, metric_dict in instance.get("metrics", {}).items(): + flat_instance_dict[metric] = float(metric_dict["valid"]) + flat_instance_dict["overall_valid"] = instance["overall_valid"] + flat_instances.append(flat_instance_dict) + + return self.reduction.reduce(flat_instances) + class MetricWithConfidenceInterval(Metric): # The number of resamples used to estimate the confidence intervals of this metric. 
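A hedged end-to-end sketch of the two catalog entries registered above, patterned on examples/evaluate_tool_calling_with_reflection.py and the tests below. The task name "tasks.tool_calling.supervised", the instance field names, and the prediction format are assumptions to adapt to your setup; mock_inference_mode is switched on so the semantic reflection pipeline runs without a real inference provider.

import unitxt
from unitxt.api import create_dataset, evaluate

unitxt.settings.mock_inference_mode = True  # use the "mock" reflector backend, as in the tests below

weather_tool = {
    "name": "get_weather",
    "description": "Get the current weather in a given location",
    "parameters": {
        "type": "object",
        "required": ["location"],
        "properties": {"location": {"type": "string"}},
    },
}
data = [
    {
        "query": "What is the weather in San Francisco?",
        "tools": [weather_tool],
        "reference_calls": [
            {"name": "get_weather", "arguments": {"location": "San Francisco"}}
        ],
    }
]
dataset = create_dataset(
    task="tasks.tool_calling.supervised",  # assumed catalog task name
    test_set=data,
    split="test",
    format="formats.chat_api",
    metrics=[
        "metrics.tool_calling.reflection.syntactic",
        "metrics.tool_calling.reflection",
    ],
    max_test_instances=10,
)
# One JSON-encoded tool call per instance (prediction format assumed by the task's template).
predictions = ['{"name": "get_weather", "arguments": {"location": "San Francisco"}}']
results = evaluate(predictions=predictions, data=dataset)
# Aggregated scores (overall_valid plus the static_* / semantic_* breakdown) appear under the global score dict.
print(results[0]["score"]["global"])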
diff --git a/tests/library/test_metrics.py b/tests/library/test_metrics.py index ac66289250..2bf2a94eac 100644 --- a/tests/library/test_metrics.py +++ b/tests/library/test_metrics.py @@ -1,6 +1,8 @@ +import random from math import isnan from typing import Dict, List +import unitxt from unitxt.api import create_dataset, evaluate from unitxt.inference import MockInferenceEngine from unitxt.llm_as_judge import LLMAsJudge, TaskBasedLLMasJudge @@ -61,6 +63,7 @@ Perplexity, PrecisionBinary, RecallBinary, + ReflectionToolCallingMetric, ReflectionToolCallingMetricSyntactic, RelaxedCorrectness, RocAuc, @@ -80,7 +83,7 @@ check_scores, test_metric, ) -from unitxt.types import ToolCall +from unitxt.types import Dialog, Tool, ToolCall from tests.utils import UnitxtTestCase @@ -146,6 +149,8 @@ class TestMetrics(UnitxtTestCase): + use_mock_model: bool = True + def test_unsorted_list_exact_match(self): metric = UnsortedListExactMatch() @@ -1649,73 +1654,81 @@ def test_tool_calling_metric(self): outputs[0]["score"]["global"]["argument_schema_validation"], 0.0 ) - def test_complex_tool_call_real_static_only(self): - """Test a complex tool call with multiple types of validation issues.""" - # Create a complex tool call with multiple issues - metric = ReflectionToolCallingMetricSyntactic() + def test_reflection_tool_calling_metric(self): + unitxt.settings.mock_inference_mode = True + metric = ReflectionToolCallingMetric() + prediction = ToolCall( **{ "name": "advanced_weather", "arguments": { "location": "San Francisco", - "days": "7", # Wrong type (string vs integer) - "format": "complete", # Invalid enum value - "include_alerts": "yes", # Wrong type (string vs boolean) - "coordinates": "37.7749,-122.4194", # Wrong type (should be array) - "debug_mode": True, # Non-existent parameter - "extra_info": "all", # Non-existent parameter - # Missing required parameter: api_key + "days": 7, + "format": "summary", }, } ) - references = [] - task_data = { "tools": [ - { - "name": "advanced_weather", - "description": "Get advanced weather forecast", - "parameters": { - "type": "object", - "required": ["location", "api_key"], - "properties": { - "location": {"type": "string"}, - "api_key": {"type": "string"}, - "days": {"type": "integer"}, - "format": { - "type": "string", - "enum": ["brief", "detailed", "summary"], - }, - "include_alerts": {"type": "boolean"}, - "coordinates": { - "type": "array", - "items": {"type": "number"}, + Tool( + { + "name": "advanced_weather", + "description": "Get advanced weather forecast", + "parameters": { + "type": "object", + "required": ["location", "days"], + "properties": { + "location": {"type": "string"}, + "days": {"type": "integer"}, + "format": { + "type": "string", + "enum": ["brief", "detailed", "summary"], + "default": "detailed", + }, }, }, - }, - } - ] + } + ) + ], + "dialog": Dialog( + [ + { + "role": "user", + "content": "What's the weather like in San Francisco in the next 7 days? 
Give me a detailed forecast.", + } + ], + ), } # Call the map method - result = metric.map(prediction, references, task_data) + result = metric.map_stream(items=[(prediction, None, task_data)])[0] # Verify all the scores - self.assertEqual(result["overall_valid"], False) - self.assertEqual(result["non_existent_function"], False) - - # 1 out of 2 required parameters (missing api_key) - self.assertEqual(result["missing_required_parameter"], True) + self.assertFalse(result["overall_valid"]) + self.assertTrue(result["static"]["final_decision"]) + self.assertTrue( + result["semantic"]["general"]["metrics"]["general_hallucination_check"][ + "is_issue" + ] + ) - # 5 valid parameters out of 7 total (2 non-existent) - self.assertEqual(result["non_existent_parameter"], True) + unitxt_provider_name = "mock" if self.use_mock_model else "watsonx" + model_name = "meta-llama/llama-4-maverick-17b-128e-instruct-fp8" + metric.setup_pipeline( + reflector_model_name=model_name, provider_name=unitxt_provider_name + ) - self.assertEqual(result["allowed_values_violation"], True) - self.assertEqual(result["incorrect_parameter_type"], True) + result = metric.map_stream(items=[(prediction, None, task_data)])[0] - # Schema validation might not reflect our expectations due to updated logic - # We're primarily interested in value precision calculation + # Verify all the scores + self.assertFalse(result["overall_valid"]) + self.assertTrue(result["static"]["final_decision"]) + self.assertTrue( + result["semantic"]["general"]["metrics"]["general_hallucination_check"][ + "is_issue" + ] + ) def test_partial_value_precision_enum_violations_real_static_only(self): """Test partial value precision when some parameters have invalid enum values.""" @@ -1769,9 +1782,254 @@ def test_partial_value_precision_enum_violations_real_static_only(self): # Assert partial value precision - all parameters exist in the schema # 4 out of 6 parameters have valid values (unit and format have enum violations) - self.assertAlmostEqual(result["allowed_values_violation"], True) + self.assertAlmostEqual( + result["metrics"]["allowed_values_violation"]["valid"], False + ) self.assertAlmostEqual(result["overall_valid"], False) - self.assertAlmostEqual(result["missing_required_parameter"], False) + self.assertAlmostEqual( + result["metrics"]["missing_required_parameter"]["valid"], True + ) + + def test_reflection_tool_calling_metric_reduce(self): + # Instance 1: valid call + instance1 = { + "overall_valid": True, + "static": { + "metrics": { + "non_existent_function": {"valid": True}, + "missing_required_parameter": {"valid": True}, + }, + "final_decision": True, + }, + "semantic": { + "general": { + "metrics": { + "general_hallucination_check": {"is_issue": False}, + "general_value_format_alignment": {"is_issue": False}, + }, + "avg_score": 1.0, + }, + "function_selection": { + "metrics": { + "function_selection_appropriateness": {"is_issue": False}, + "agentic_constraints_satisfaction": {"is_issue": False}, + }, + "avg_score": 1.0, + }, + }, + } + # Instance 2: invalid call + instance2 = { + "overall_valid": False, + "static": { + "metrics": { + "non_existent_function": {"valid": False}, + "missing_required_parameter": {"valid": False}, + }, + "final_decision": False, + }, + "semantic": { + "general": { + "metrics": { + "general_hallucination_check": {"is_issue": True}, + "general_value_format_alignment": {"is_issue": True}, + }, + "avg_score": 0.0, + }, + "function_selection": { + "metrics": { + "function_selection_appropriateness": {"is_issue": 
True}, + "agentic_constraints_satisfaction": {"is_issue": True}, + }, + "avg_score": 0.0, + }, + }, + } + # Instance 3: partially valid + instance3 = { + "overall_valid": True, + "static": { + "metrics": { + "non_existent_function": {"valid": True}, + "missing_required_parameter": {"valid": False}, + }, + "final_decision": True, + }, + "semantic": { + "general": { + "metrics": { + "general_hallucination_check": {"is_issue": False}, + "general_value_format_alignment": {"is_issue": True}, + }, + "avg_score": 0.5, + }, + "function_selection": { + "metrics": { + "function_selection_appropriateness": {"is_issue": False}, + "agentic_constraints_satisfaction": {"is_issue": True}, + }, + "avg_score": 0.5, + }, + }, + } + + unitxt.settings.mock_inference_mode = True + metric = ReflectionToolCallingMetric() + reduced = metric.reduce([instance1, instance2, instance3]) + + # All outputs should be floats in [0, 1] + import math + + self.assertIsInstance(reduced, dict) + for k, v in reduced.items(): + self.assertIsInstance(v, float, msg=f"value for {k} is not float") + self.assertFalse(math.isnan(v), msg=f"NaN at key {k}") + self.assertGreaterEqual(v, 0.0, msg=f"value for {k} < 0") + self.assertLessEqual(v, 1.0, msg=f"value for {k} > 1") + + # Overall validity is the mean of booleans + expected_overall_valid = (1 + 0 + 1) / 3 + self.assertAlmostEqual(reduced["overall_valid"], expected_overall_valid) + + # Static aggregated metrics, now flattened + # non_existent_function: True, False, True -> 2/3 + nef_expected = (1 + 0 + 1) / 3 + # missing_required_parameter: True, False, False -> 1/3 + mrp_expected = (1 + 0 + 0) / 3 + self.assertIn("static_non_existent_function", reduced) + self.assertIn("static_missing_required_parameter", reduced) + self.assertAlmostEqual(reduced["static_non_existent_function"], nef_expected) + self.assertAlmostEqual( + reduced["static_missing_required_parameter"], mrp_expected + ) + + # Semantic per family average scores + semantic_avg_expected = (1.0 + 0.0 + 0.5) / 3 + self.assertIn("semantic_avg_score_general", reduced) + self.assertIn("semantic_avg_score_function_selection", reduced) + self.assertAlmostEqual( + reduced["semantic_avg_score_general"], semantic_avg_expected + ) + self.assertAlmostEqual( + reduced["semantic_avg_score_function_selection"], semantic_avg_expected + ) + + # Semantic issue rates for individual checks, flattened, mean of is_issue + # general_hallucination_check: False, True, False -> 1/3 + ghc_expected = (0 + 1 + 0) / 3 + # general_value_format_alignment: False, True, True -> 2/3 + gvfa_expected = (0 + 1 + 1) / 3 + # function_selection_appropriateness: False, True, False -> 1/3 + fsa_expected = (0 + 1 + 0) / 3 + # agentic_constraints_satisfaction: False, True, True -> 2/3 + acs_expected = (0 + 1 + 1) / 3 + + self.assertIn("semantic_general_hallucination_check", reduced) + self.assertIn("semantic_general_value_format_alignment", reduced) + self.assertIn("semantic_function_selection_appropriateness", reduced) + self.assertIn("semantic_agentic_constraints_satisfaction", reduced) + + self.assertAlmostEqual( + reduced["semantic_general_hallucination_check"], ghc_expected + ) + self.assertAlmostEqual( + reduced["semantic_general_value_format_alignment"], gvfa_expected + ) + self.assertAlmostEqual( + reduced["semantic_function_selection_appropriateness"], fsa_expected + ) + self.assertAlmostEqual( + reduced["semantic_agentic_constraints_satisfaction"], acs_expected + ) + + def test_reflection_tool_calling_metric_syntactic_reduce(self): + from unitxt.metrics 
import ReflectionToolCallingMetricSyntactic + + # Instance 1: invalid function, all other checks pass + instance1 = { + "metrics": { + "non_existent_function": {"valid": False}, + "non_existent_parameter": {"valid": True}, + "incorrect_parameter_type": {"valid": True}, + "missing_required_parameter": {"valid": True}, + "allowed_values_violation": {"valid": True}, + "empty_api_spec": {"valid": True}, + "invalid_api_spec": {"valid": True}, + "invalid_tool_call": {"valid": True}, + "json_schema_violation": {"valid": True}, + }, + "overall_valid": False, + } + # Instance 2: all checks pass + instance2 = { + "metrics": { + "non_existent_function": {"valid": True}, + "non_existent_parameter": {"valid": True}, + "incorrect_parameter_type": {"valid": True}, + "missing_required_parameter": {"valid": True}, + "allowed_values_violation": {"valid": True}, + "empty_api_spec": {"valid": True}, + "invalid_api_spec": {"valid": True}, + "invalid_tool_call": {"valid": True}, + "json_schema_violation": {"valid": True}, + }, + "overall_valid": True, + } + # Instance 3: missing required parameter, all others pass + instance3 = { + "metrics": { + "non_existent_function": {"valid": True}, + "non_existent_parameter": {"valid": True}, + "incorrect_parameter_type": {"valid": True}, + "missing_required_parameter": {"valid": False}, + "allowed_values_violation": {"valid": True}, + "empty_api_spec": {"valid": True}, + "invalid_api_spec": {"valid": True}, + "invalid_tool_call": {"valid": True}, + "json_schema_violation": {"valid": True}, + }, + "overall_valid": False, + } + + metric = ReflectionToolCallingMetricSyntactic() + inputs = [instance1, instance2, instance3] + reduced = metric.reduce(inputs) + + # 1) Key set is exactly metrics + overall_valid + metric_names = [*instance1["metrics"]] + expected_keys = set(metric_names) | {"overall_valid"} + + self.assertEqual(set(reduced.keys()), expected_keys) + + # 2) All values are floats in [0, 1] and not NaN + for k, v in reduced.items(): + self.assertIsInstance(v, float, f"value for {k} is not float") + self.assertFalse(isnan(v), f"NaN at {k}") + self.assertGreaterEqual(v, 0.0, f"value for {k} < 0") + self.assertLessEqual(v, 1.0, f"value for {k} > 1") + + # 3) Per metric aggregation equals mean of valid flags across instances + def mean_valid(name: str) -> float: + vals = [1.0 if inst["metrics"][name]["valid"] else 0.0 for inst in inputs] + return sum(vals) / len(vals) + + for m in metric_names: + self.assertAlmostEqual(reduced[m], mean_valid(m), msg=f"mismatch for {m}") + + # Spot checks, same as your original examples + self.assertAlmostEqual(reduced["non_existent_function"], (0 + 1 + 1) / 3) + self.assertAlmostEqual(reduced["missing_required_parameter"], (1 + 1 + 0) / 3) + self.assertAlmostEqual(reduced["json_schema_violation"], (1 + 1 + 1) / 3) + + # 4) overall_valid equals mean of per instance overall_valid + expected_overall_valid = (0 + 1 + 0) / 3 + self.assertAlmostEqual(reduced["overall_valid"], expected_overall_valid) + + # 5) Order invariance, shuffle inputs and expect identical result + shuffled = inputs[:] + random.shuffle(shuffled) + reduced_shuffled = metric.reduce(shuffled) + self.assertEqual(reduced, reduced_shuffled) def test_tool_calling_metric_syntactic_reflector(self): metric = ReflectionToolCallingMetricSyntactic() @@ -1804,8 +2062,8 @@ def test_tool_calling_metric_syntactic_reflector(self): ) # Exact match should be 1.0 when prediction and reference are identical - self.assertEqual(outputs["overall_valid"], True) - 
self.assertEqual(outputs["non_existent_function"], False) + self.assertTrue(outputs["overall_valid"]) + self.assertTrue(outputs["metrics"]["non_existent_function"]["valid"]) # Test case 2: Different tool name prediction = {"name": "different_tool", "arguments": {"param1": "value1"}} @@ -1816,8 +2074,8 @@ def test_tool_calling_metric_syntactic_reflector(self): task_data=tools_data, ) - self.assertEqual(outputs["overall_valid"], False) - self.assertEqual(outputs["non_existent_function"], True) + self.assertFalse(outputs["overall_valid"]) + self.assertFalse(outputs["metrics"]["non_existent_function"]["valid"]) # Test case 3: Different parameter names prediction = { @@ -1832,18 +2090,18 @@ def test_tool_calling_metric_syntactic_reflector(self): ) # Exact match should be 0.0, tool choice 1.0 - self.assertEqual(outputs["overall_valid"], False) - self.assertEqual(outputs["non_existent_function"], False) + self.assertFalse(outputs["overall_valid"]) + self.assertTrue(outputs["metrics"]["non_existent_function"]["valid"]) # param1 is present but param2 is missing out of 2 required parameters in the schema - self.assertEqual(outputs["missing_required_parameter"], True) + self.assertFalse(outputs["metrics"]["missing_required_parameter"]["valid"]) # 1 valid parameter (param1) out of 2 total parameters (param1, wrongParam) - self.assertEqual(outputs["non_existent_parameter"], True) + self.assertFalse(outputs["metrics"]["non_existent_parameter"]["valid"]) # Since wrongParam doesn't exist in the schema, value precision only considers param1 # param1 has the correct type, so value precision is 1.0 - self.assertEqual(outputs["incorrect_parameter_type"], False) + self.assertTrue(outputs["metrics"]["incorrect_parameter_type"]["valid"]) # Test case 4: Different parameter values prediction = { @@ -1859,7 +2117,7 @@ def test_tool_calling_metric_syntactic_reflector(self): # Parameter choice should be 1.0 (all names match) # Note: overall_valid is 1.0 because the static validation only checks schema conformance, not value equality - self.assertEqual(outputs["overall_valid"], True) + self.assertTrue(outputs["overall_valid"]) # Test case 5a: Empty arguments prediction = {"name": "test_tool", "arguments": {}} @@ -1872,11 +2130,11 @@ def test_tool_calling_metric_syntactic_reflector(self): ) # Recall should be 0 for empty arguments (missing required parameters) - self.assertEqual(outputs["missing_required_parameter"], True) + self.assertFalse(outputs["metrics"]["missing_required_parameter"]["valid"]) # Precision is 1.0 because there are no invalid parameter names (no non-existent parameters) - self.assertEqual(outputs["non_existent_parameter"], False) + self.assertTrue(outputs["metrics"]["non_existent_parameter"]["valid"]) # Value precision is 1.0 because there are no parameters with type or enum violations - self.assertEqual(outputs["incorrect_parameter_type"], False) + self.assertTrue(outputs["metrics"]["incorrect_parameter_type"]["valid"]) prediction = {"name": "test_tool", "arguments": {}} reference = {"name": "test_tool", "arguments": {}} @@ -1890,11 +2148,11 @@ def test_tool_calling_metric_syntactic_reflector(self): # Recall should still be 0 since there are required parameters in the schema # (regardless of the references, which are not used for validation) - self.assertEqual(outputs["missing_required_parameter"], True) + self.assertFalse(outputs["metrics"]["missing_required_parameter"]["valid"]) # Precision is 1.0 because there are no invalid parameter names - self.assertEqual(outputs["non_existent_parameter"], 
False) + self.assertTrue(outputs["metrics"]["non_existent_parameter"]["valid"]) # Value precision is 1.0 because there are no parameters with type or enum violations - self.assertEqual(outputs["incorrect_parameter_type"], False) + self.assertTrue(outputs["metrics"]["incorrect_parameter_type"]["valid"]) # Test case 6: Multiple references with one match prediction = {"name": "test_tool", "arguments": {"param1": "value1"}} @@ -1907,10 +2165,10 @@ def test_tool_calling_metric_syntactic_reflector(self): ) # overall_valid should be 0.0 because param2 is missing (it's required) - self.assertEqual(outputs["overall_valid"], False) - self.assertEqual(outputs["non_existent_function"], False) - self.assertEqual(outputs["missing_required_parameter"], True) - self.assertEqual(outputs["non_existent_parameter"], False) + self.assertFalse(outputs["overall_valid"]) + self.assertTrue(outputs["metrics"]["non_existent_function"]["valid"]) + self.assertFalse(outputs["metrics"]["missing_required_parameter"]["valid"]) + self.assertTrue(outputs["metrics"]["non_existent_parameter"]["valid"]) # Test case 7: Parameter types prediction = { @@ -1925,7 +2183,7 @@ def test_tool_calling_metric_syntactic_reflector(self): ) # Parameters should have correct types - self.assertEqual(outputs["json_schema_violation"], False) + self.assertEqual(outputs["metrics"]["json_schema_violation"]["valid"], True) # Test case 8: Wrong parameter types prediction = { @@ -1941,7 +2199,7 @@ def test_tool_calling_metric_syntactic_reflector(self): # json_schema_violation is separate from the type validation # schema validation can still pass even if there are type errors - self.assertEqual(outputs["json_schema_violation"], False) + self.assertTrue(outputs["metrics"]["json_schema_violation"]["valid"]) def test_overall_valid_success_real_map(self): metric = ReflectionToolCallingMetricSyntactic() @@ -1986,9 +2244,9 @@ def test_overall_valid_success_real_map(self): result = metric.map(prediction, references, task_data) # Assert expected results - self.assertEqual(result["overall_valid"], True) - self.assertEqual(result["non_existent_function"], False) - self.assertEqual(result["missing_required_parameter"], False) + self.assertTrue(result["overall_valid"]) + self.assertTrue(result["metrics"]["non_existent_function"]["valid"]) + self.assertTrue(result["metrics"]["missing_required_parameter"]["valid"]) def test_non_existent_function_real_map(self): metric = ReflectionToolCallingMetricSyntactic() @@ -2020,9 +2278,9 @@ def test_non_existent_function_real_map(self): result = metric.map(prediction, references, task_data) # Assert expected results - self.assertEqual(result["overall_valid"], False) - self.assertEqual(result["non_existent_function"], True) - self.assertEqual(result["missing_required_parameter"], False) + self.assertFalse(result["overall_valid"]) + self.assertFalse(result["metrics"]["non_existent_function"]["valid"]) + self.assertTrue(result["metrics"]["missing_required_parameter"]["valid"]) def test_missing_required_parameter_real_map(self): metric = ReflectionToolCallingMetricSyntactic() @@ -2060,9 +2318,9 @@ def test_missing_required_parameter_real_map(self): result = metric.map(prediction, references, task_data) # Assert expected results - self.assertEqual(result["overall_valid"], False) - self.assertEqual(result["missing_required_parameter"], True) - self.assertEqual(result["allowed_values_violation"], False) + self.assertFalse(result["overall_valid"]) + self.assertFalse(result["metrics"]["missing_required_parameter"]["valid"]) + 
self.assertTrue(result["metrics"]["allowed_values_violation"]["valid"]) def test_non_existent_parameter_real_map(self): metric = ReflectionToolCallingMetricSyntactic() @@ -2094,111 +2352,8 @@ def test_non_existent_parameter_real_map(self): result = metric.map(prediction, references, task_data) # Assert expected results - self.assertEqual(result["overall_valid"], False) - self.assertEqual(result["non_existent_parameter"], True) - - def test_incorrect_parameter_type_real_map(self): - metric = ReflectionToolCallingMetricSyntactic() - # Create sample inputs with wrong parameter type - prediction = ToolCall(**{"name": "get_weather", "arguments": {"location": 42}}) - - references = [ - {"name": "get_weather", "arguments": {"location": "San Francisco"}} - ] - - task_data = { - "tools": [ - { - "name": "get_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "required": ["location"], - "properties": {"location": {"type": "string"}}, - }, - }, - ] - } - - # Call the map method - result = metric.map(prediction, references, task_data) - - # Assert expected results - self.assertEqual(result["overall_valid"], False) - self.assertEqual(result["incorrect_parameter_type"], True) - self.assertEqual(result["missing_required_parameter"], False) - self.assertEqual(result["allowed_values_violation"], False) - - def test_multiple_parameter_issues_real_map(self): - metric = ReflectionToolCallingMetricSyntactic() - # Create sample inputs with multiple issues - prediction = ToolCall( - **{ - "name": "get_weather", - "arguments": { - "unit": 123, - "format": True, - "unknown_param": "value1", - "extra_param": "value2", - }, - } - ) - - references = [ - { - "name": "get_weather", - "arguments": { - "location": "San Francisco", - "date": "2025-08-19", - "unit": "celsius", - "format": "brief", - }, - } - ] - - task_data = { - "tools": [ - { - "name": "get_weather", - "description": "Get the weather in a given location", - "parameters": { - "type": "object", - "required": ["location", "date"], - "properties": { - "location": {"type": "string"}, - "date": {"type": "string"}, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - }, - "format": { - "type": "string", - "enum": ["brief", "detailed"], - }, - }, - }, - } - ] - } - - # Call the map method - result = metric.map(prediction, references, task_data) - - # Assert expected results - self.assertEqual(result["overall_valid"], False) - - # Missing 2 out of 2 required parameters - self.assertEqual(result["missing_required_parameter"], True) - - # 2 invalid parameters out of 4 total - self.assertEqual(result["non_existent_parameter"], True) - - # Both unit and format have type issues and no other validation issues are reported - # So 0 out of 2 valid parameters (excluding the non-existent ones) - self.assertEqual(result["incorrect_parameter_type"], True) - - # Schema validation might not reflect our expectations due to updated logic - # We're primarily interested in value precision calculation + self.assertFalse(result["overall_valid"], False) + self.assertFalse(result["metrics"]["non_existent_parameter"]["valid"]) def test_allowed_values_violation(self): metric = ReflectionToolCallingMetricSyntactic() @@ -2241,10 +2396,10 @@ def test_allowed_values_violation(self): result = metric.map(prediction, references, task_data) # Assert expected results - self.assertEqual(result["overall_valid"], False) + self.assertFalse(result["overall_valid"]) # With 1 invalid enum value out of 2 parameters, value 
precision should be 0.5 - self.assertEqual(result["allowed_values_violation"], True) - self.assertEqual(result["incorrect_parameter_type"], False) + self.assertFalse(result["metrics"]["allowed_values_violation"]["valid"]) + self.assertTrue(result["metrics"]["incorrect_parameter_type"]["valid"]) def test_json_schema_violation_specific_real_map(self): metric = ReflectionToolCallingMetricSyntactic() @@ -2279,12 +2434,12 @@ def test_json_schema_violation_specific_real_map(self): result = metric.map(prediction, references, task_data) # Assert expected results - self.assertEqual(result["overall_valid"], False) # Overall validation fails - self.assertEqual( - result["missing_required_parameter"], True + self.assertFalse(result["overall_valid"]) # Overall validation fails + self.assertFalse( + result["metrics"]["missing_required_parameter"]["valid"] ) # Missing required param # json_schema_violation specifically should be 1.0 because it's marked valid - self.assertEqual(result["json_schema_violation"], False) + self.assertTrue(result["metrics"]["json_schema_violation"]["valid"]) def test_partial_recall_missing_parameters_real_map(self): metric = ReflectionToolCallingMetricSyntactic() @@ -2321,8 +2476,8 @@ def test_partial_recall_missing_parameters_real_map(self): result = metric.map(prediction, references, task_data) # Assert partial recall - 1 out of 2 required parameters provided - self.assertEqual(result["overall_valid"], False) - self.assertEqual(result["missing_required_parameter"], True) + self.assertFalse(result["overall_valid"]) + self.assertFalse(result["metrics"]["missing_required_parameter"]["valid"]) def test_partial_precision_non_existent_parameters_real_map(self): """Test partial precision score when some parameters don't exist in the schema.""" @@ -2366,8 +2521,8 @@ def test_partial_precision_non_existent_parameters_real_map(self): result = metric.map(prediction, references, task_data) # Assert partial precision - 3 out of 6 parameters are valid - self.assertEqual(result["non_existent_parameter"], True) - self.assertEqual(result["overall_valid"], False) + self.assertFalse(result["metrics"]["non_existent_parameter"]["valid"]) + self.assertFalse(result["overall_valid"]) def test_partial_value_precision_type_errors_real_map(self): """Test partial value precision when some parameters have incorrect types.""" @@ -2413,7 +2568,9 @@ def test_partial_value_precision_type_errors_real_map(self): result = metric.map(prediction, references, task_data) self.assertAlmostEqual(result["overall_valid"], False) - self.assertAlmostEqual(result["incorrect_parameter_type"], True) + self.assertAlmostEqual( + result["metrics"]["incorrect_parameter_type"]["valid"], False + ) def test_partial_value_precision_enum_violations_real_map(self): """Test partial value precision when some parameters have invalid enum values.""" @@ -2465,7 +2622,9 @@ def test_partial_value_precision_enum_violations_real_map(self): result = metric.map(prediction, references, task_data) self.assertAlmostEqual(result["overall_valid"], False) - self.assertAlmostEqual(result["allowed_values_violation"], True) + self.assertAlmostEqual( + result["metrics"]["allowed_values_violation"]["valid"], False + ) def test_tool_calling_key_value_metric(self): metric = ToolCallKeyValueExtraction(metric="metrics.accuracy")
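For reference, a minimal sketch of the refactored syntactic result shape exercised by the tests above: per-check details now sit under result["metrics"][name]["valid"], the final decision under result["overall_valid"], and reduce() averages each flag across instances into [0, 1]. It assumes llmevalkit is installed from the internal repository.

from unitxt.metrics import ReflectionToolCallingMetricSyntactic
from unitxt.types import ToolCall

metric = ReflectionToolCallingMetricSyntactic()

# A deliberately broken call: "location" should be a string, not an integer.
prediction = ToolCall(name="get_weather", arguments={"location": 42})
task_data = {
    "tools": [
        {
            "name": "get_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "required": ["location"],
                "properties": {"location": {"type": "string"}},
            },
        }
    ]
}

result = metric.map(prediction, [], task_data)
assert not result["overall_valid"]
assert not result["metrics"]["incorrect_parameter_type"]["valid"]

# reduce() flattens each per-check "valid" flag and averages it across instances.
aggregated = metric.reduce([result])
print(aggregated["overall_valid"], aggregated["incorrect_parameter_type"])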