From 8d141e6052341772bc56dc63306be6a1639f05d9 Mon Sep 17 00:00:00 2001 From: Koren Lazar Date: Tue, 26 Aug 2025 15:58:10 +0300 Subject: [PATCH 01/14] Add ReflectionToolCallingMetric and update related metrics - Introduced ReflectionToolCallingMetric for assessing syntactic and semantic validity of tool calls. - Updated MultiTurnToolCallingMetric description for clarity. - Added reflection.json to catalog with appropriate descriptions. - Enhanced test coverage for ReflectionToolCallingMetric and its reduction logic. --- prepare/metrics/tool_calling.py | 11 +- .../tool_calling/multi_turn/validity.json | 2 +- .../metrics/tool_calling/reflection.json | 4 + src/unitxt/metrics.py | 308 +++++++++++++-- tests/library/test_metrics.py | 372 ++++++++++++++---- 5 files changed, 597 insertions(+), 100 deletions(-) create mode 100644 src/unitxt/catalog/metrics/tool_calling/reflection.json diff --git a/prepare/metrics/tool_calling.py b/prepare/metrics/tool_calling.py index c4297d0a92..cab880fa5d 100644 --- a/prepare/metrics/tool_calling.py +++ b/prepare/metrics/tool_calling.py @@ -1,6 +1,7 @@ from unitxt.catalog import add_to_catalog from unitxt.metrics import ( MultiTurnToolCallingMetric, + ReflectionToolCallingMetric, ReflectionToolCallingMetricSyntactic, ToolCallingMetric, ToolCallKeyValueExtraction, @@ -48,12 +49,20 @@ add_to_catalog( MultiTurnToolCallingMetric( - __description__="""Metric that evaluates tool call predictions for the validity with regards to the tools schema.""" + __description__="""A metric that assesses tool call predictions for their conformity to the tool schema.""" ), "metrics.tool_calling.multi_turn.validity", overwrite=True, ) +add_to_catalog( + ReflectionToolCallingMetric( + __description__="""A metric that assesses tool call predictions for both syntactic correctness and semantic validity, using predefined checks combined with LLM-based evaluations.""" + ), + "metrics.tool_calling.reflection", + overwrite=True, +) + add_to_catalog( ReflectionToolCallingMetricSyntactic( __description__="""This metric evaluates whether a model's tool call outputs are structurally valid by checking their compliance with the provided tool schema. For each instance, it returns a binary score (True for valid, False for invalid), and aggregates these into a global percentage across all instances. The evaluation covers a wide range of possible issues, including nonexistent functions or parameters, incorrect parameter types, missing required parameters, values outside allowed ranges, JSON schema violations, invalid or empty API specifications, and malformed tool calls. The main reported score, overall_valid (aliased as score), reflects the proportion of calls that are fully valid, making the metric a measure of syntactic and schema-level correctness rather than semantic accuracy.""" diff --git a/src/unitxt/catalog/metrics/tool_calling/multi_turn/validity.json b/src/unitxt/catalog/metrics/tool_calling/multi_turn/validity.json index d15cd09f49..5df68bdbb5 100644 --- a/src/unitxt/catalog/metrics/tool_calling/multi_turn/validity.json +++ b/src/unitxt/catalog/metrics/tool_calling/multi_turn/validity.json @@ -1,4 +1,4 @@ { "__type__": "multi_turn_tool_calling_metric", - "__description__": "Metric that evaluates tool call predictions for the validity with regards to the tools schema." + "__description__": "A metric that assesses tool call predictions for their conformity to the tool schema." } diff --git a/src/unitxt/catalog/metrics/tool_calling/reflection.json b/src/unitxt/catalog/metrics/tool_calling/reflection.json new file mode 100644 index 0000000000..6f1d560357 --- /dev/null +++ b/src/unitxt/catalog/metrics/tool_calling/reflection.json @@ -0,0 +1,4 @@ +{ + "__type__": "reflection_tool_calling_metric", + "__description__": "A metric that assesses tool call predictions for both syntactic correctness and semantic validity, using predefined checks combined with LLM-based evaluations." +} diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py index 7e93751d4b..cbb6c9e8df 100644 --- a/src/unitxt/metrics.py +++ b/src/unitxt/metrics.py @@ -1,9 +1,13 @@ import ast +import asyncio +import contextlib import json import math import os +from queue import Empty, Queue import re import string +import threading import uuid import warnings from abc import ABC, abstractmethod @@ -14,9 +18,11 @@ from functools import lru_cache from typing import ( Any, + AsyncIterator, Dict, Generator, Generic, + Iterable, List, Literal, Optional, @@ -24,6 +30,7 @@ Type, TypeVar, Union, + override, ) import evaluate @@ -34,6 +41,8 @@ from scipy.stats import bootstrap from scipy.stats._warnings_errors import DegenerateDataWarning +import unitxt + from .artifact import Artifact from .base_metric import Metric from .collections import ListCollection @@ -889,6 +898,257 @@ def map( return { "argument_schema_validation": argument_schema_validation, } + + +class ReflectionToolCallingMixin: + @staticmethod + def convert_tools_inventory(tools): + from llmevalkit.function_calling.pipeline.types import ToolSpec as LLMEvalKitToolSpec + return [ + LLMEvalKitToolSpec( + type="function", + function={**tool}, + ) + for tool in tools + ] + + @staticmethod + def convert_tool_call(prediction): + from llmevalkit.function_calling.pipeline.types import ToolCall as LLMEvalKitToolCall + return LLMEvalKitToolCall( + type="function", + function={ + "name": prediction["name"], + "arguments": json.dumps(prediction["arguments"]), + "parsed_arguments": prediction["arguments"], + }, + ) + + +class ReflectionToolCallingMetric( + ReductionInstanceMetric[str, Dict[str, float]] +): + """Measures syntactic and semantic validity of tool calls. + + The final output contains two main fields: "semantic" and "static" (i.e., semantic). + Under the semantics we define two types of metrics: general and function selection. + + General metrics evaluate the overall quality and correctness of the tool call. + These metrics contains: + 1. General hallucination check: Evaluate whether each parameter value in the function call is correct and directly supported by the provided conversation history and adhere the tool specifications. + 2. Value format alignment: Check if the format of the parameter values aligns with the expected formats defined in the tool specifications. + + Function selection metrics evaluate the appropriateness of the selected function for the given context. + These metrics include: + 1. Function selection appropriateness: Assess whether the chosen function is suitable for the task at hand. + 2. Agentic constraints satisfaction: Assess whether the proposed tool call satisfies all agentic constraints required for execution. + + Static metrics evaluate the syntactic validity of the tool call. + It contains the following metrics: + - non_existent_function: tool name not found. + - non_existent_parameter: argument name not in tool spec. + - incorrect_parameter_type: argument type mismatch. + - missing_required_parameter: required argument missing. + - allowed_values_violation: argument value outside allowed set. + - json_schema_violation: call violates JSON schema. + - empty_api_spec: no tool spec provided. + - invalid_api_spec: tool spec is invalid. + - invalid_tool_call: call is not a valid tool invocation. + - overall_valid: validity of the call (main score). + - score: alias of overall_valid. + + Here is an example for a aggregated reflection output after calling reduce. + The range of each score is [0, 1] (where higher indicates less errors). + { + "static_non_existent_function": 1.0, + "static_non_existent_parameter": 1.0, + "static_incorrect_parameter_type": 1.0, + "static_missing_required_parameter": 1.0, + "static_allowed_values_violation": 1.0, + "static_json_schema_violation": 1.0, + "static_empty_api_spec": 1.0, + "static_invalid_api_spec": 1.0, + "static_invalid_tool_call": 1.0, + "semantic_general_hallucination_check": 0.0, + "semantic_general_value_format_alignment": 0.0, + "semantic_avg_score_general": 1.0, + "semantic_function_selection_appropriateness": 0.0, + "semantic_agentic_constraints_satisfaction": 0.0, + "semantic_avg_score_function_selection": 1.0, + "overall_valid": 1.0 + } + + Where overall_valid is the final decision made by the reflection pipeline, indicating whether the tool call is valid or not. + + Before the aggregation each metric contains also evidence, explanation, a more fine-grained score, etc. + + Reference: https://github.ibm.com/MLT/LLMEvalKit + """ + from llmevalkit.function_calling.pipeline.pipeline import ReflectionPipeline + from llmevalkit.function_calling.pipeline.types import PipelineResult + + main_score = "overall_valid" + reduction = MeanReduction() + prediction_type = ToolCall + _requirements_list = { + "llmevalkit": "pip install 'git+ssh://git@github.ibm.com/MLT/LLMEvalKit.git'" + } + runtime_pipeline: bool = True # Whether to use the runtime pipeline or the longer evaluation pipeline with actionable recommendations + + def prepare(self): + provider_to_default_reflector_model = { + "watsonx": "meta-llama/llama-4-maverick-17b-128e-instruct-fp8", + "open-ai": "gpt-4o", + "rits": "openai/gpt-oss-120b", + "azure": "gpt-4o", + "mock": "mock" + } + self.requirements = self._check_missing_requirements_by_provider(unitxt.settings.default_provider) + super().prepare() + self.setup_pipeline( + reflector_model_name=provider_to_default_reflector_model.get(unitxt.settings.default_provider), + provider_name=unitxt.settings.default_provider + ) + + def setup_pipeline(self, reflector_model_name: str, provider_name: str) -> ReflectionPipeline: + llmeval_provider_name = self._get_llmeval_provider_name(provider_name) + self._check_missing_requirements_by_provider(unitxt.settings.default_provider) # TODO: fix it + metrics_client = self._get_metrics_client(llmeval_provider_name, reflector_model_name) + self.reflection_pipeline = self._build_reflection_pipeline(metrics_client) + return self.reflection_pipeline + + @staticmethod + def _get_llmeval_provider_name(provider_name: str) -> str: + mapping = { + "watsonx": "watsonx.output_val", + "open-ai": "openai.async.output_val", + "rits": "litellm.rits.output_val", + "azure": "azure_openai.async.output_val", + "mock": "mock", + } + llmeval_provider_name = mapping.get(provider_name) + if llmeval_provider_name is None: + raise ValueError(f"Unsupported provider by llmevalkit: {provider_name}") + return llmeval_provider_name + + @staticmethod + def _check_missing_requirements_by_provider(provider_name: str): + provider_libs = { + "watsonx": "ibm_watsonx_ai", + "open-ai": "openai", + "rits": "litellm", + "azure": "openai" + } + required_lib = provider_libs.get(provider_name) + return [required_lib] if required_lib else [] + + @staticmethod + def _get_metrics_client(llmeval_provider_name: str, reflector_model_name: str): + from llmevalkit.llm import get_llm + MetricsClientCls = get_llm(llmeval_provider_name) + return MetricsClientCls(model_name=reflector_model_name) + + def _build_reflection_pipeline(self, metrics_client): + from llmevalkit.function_calling.pipeline.pipeline import ReflectionPipeline + from llmevalkit.function_calling.consts import ( + METRIC_GENERAL_HALLUCINATION_CHECK, + METRIC_AGENTIC_CONSTRAINTS_SATISFACTION, + METRIC_FUNCTION_SELECTION_APPROPRIATENESS, + METRIC_GENERAL_VALUE_FORMAT_ALIGNMENT + ) + return ReflectionPipeline( + metrics_client=metrics_client, + general_metrics=[ + METRIC_GENERAL_HALLUCINATION_CHECK, + METRIC_GENERAL_VALUE_FORMAT_ALIGNMENT, + ], + function_metrics=[ + METRIC_FUNCTION_SELECTION_APPROPRIATENESS, + METRIC_AGENTIC_CONSTRAINTS_SATISFACTION, + ], + parameter_metrics=[], + runtime_pipeline=self.runtime_pipeline, + ) + + async def map( + self, + prediction: ToolCall, + references: None, + task_data: Dict[str, Any], + ) -> PipelineResult: + from llmevalkit.function_calling.pipeline.types import PipelineResult + + # Convert unitxt dialog to LLMEvalKit format + conversation_history = [dict(turn) for turn in task_data["dialog"]] + + # Convert unitxt tool inventory to LLMEvalKit format + tools_inventory = ReflectionToolCallingMixin.convert_tools_inventory(task_data.get("tools", [])) + + # Convert unitxt tool call to LLMEvalKit format + tool_call = ReflectionToolCallingMixin.convert_tool_call(prediction) + + # Run reflection (syntactic + semantic) + result: PipelineResult = await self.reflection_pipeline.run_async( + conversation=conversation_history, + inventory=tools_inventory, + call=tool_call, + retries=3, + continue_on_static=True, + ) + return result.model_dump() + + + def map_stream( + self, + items: Iterable[Tuple[ToolCall, None, Dict[str, Any]]], + *, + max_concurrency: int = 8, + ) -> List[Dict[str, Any]]: + """ + Run self.map in parallel over an iterable and return results in order. + """ + async def process_all(): + items_iter = iter(enumerate(items)) + results = [] + pending = set() + while True: + while len(pending) < max_concurrency: + try: + idx, (pred, refs, data) = next(items_iter) + task = asyncio.create_task(self.map(pred, refs, data)) + task.idx = idx + pending.add(task) + except StopIteration: + break + if not pending: + break + done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) + for task in done: + results.append((task.idx, await task)) + results.sort() + return [r for _, r in results] + return asyncio.run(process_all()) + + def reduce_one(self, intermidate: Dict[str, Any]) -> Dict[str, float]: + return intermidate + + def reduce(self, intermidates: List[Dict[str, Any]]) -> Dict[str, float]: + flat_instances = [] + for instance in intermidates: + flat_instance_dict = {} + for metric, metric_type_dict in instance.get("static", {}).get("metrics", {}).items(): + flat_instance_dict[f"static_{metric}"] = float(metric_type_dict["valid"]) + + for metric_type, metric_type_dict in instance.get("semantic", {}).items(): + for metric, metric_dict in metric_type_dict.get("metrics", {}).items(): + flat_instance_dict[f"semantic_{metric}"] = float(metric_dict["is_issue"]) + flat_instance_dict[f"semantic_avg_score_{metric_type}"] = float(metric_type_dict.get("avg_score")) + + flat_instance_dict["overall_valid"] = float(instance["overall_valid"]) + flat_instances.append(flat_instance_dict) + + + return self.reduction.reduce(flat_instances) class ReflectionToolCallingMetricSyntactic( @@ -896,8 +1156,9 @@ class ReflectionToolCallingMetricSyntactic( ): """Measures syntactic and schema validity of tool calls. - Range: [0, 1] (higher is better) - Returns 1.0 if the tool call is valid (all checks pass), 0.0 otherwise. + Range: [0, 1] (higher indicates less errors). + Returns 1.0 if the tool call is valid for each metric, 0.0 otherwise. + overall_valid equals 1.0 if all metrics are valid, 0.0 otherwise. Global score is the percentage of valid instances across the dataset. Scores: @@ -938,39 +1199,32 @@ def map( ) # Convert unitxt tool inventory to LLMEvalKit format - tools_inventory = [] - for tool in task_data.get("tools", []): - tools_inventory.append( - LLMEvalKitToolSpec( - type="function", - function={**tool}, - ) - ) + tools_inventory = ReflectionToolCallingMixin.convert_tools_inventory(task_data.get("tools", [])) # Convert unitxt tool call to LLMEvalKit format - tool_call = LLMEvalKitToolCall( - type="function", - function={ - "name": prediction["name"], - "arguments": json.dumps(prediction["arguments"]), - "parsed_arguments": prediction["arguments"], - }, - ) + tool_call = ReflectionToolCallingMixin.convert_tool_call(prediction) # Run static validation static_result = ReflectionPipeline.static_only(tools_inventory, tool_call) - result_dict = { - ( - metric_name - if metric_name != "json_schema_validation" - else "json_schema_violation" - ): not metric_dict.valid - for metric_name, metric_dict in static_result.metrics.items() - } - result_dict["overall_valid"] = static_result.final_decision + result_dict = static_result.model_dump() + result_dict["overall_valid"] = result_dict.pop("final_decision") + result_dict["metrics"]["json_schema_violation"] = result_dict["metrics"].pop("json_schema_validation") return result_dict + def reduce_one(self, intermidate: Dict[str, float]) -> Dict[str, float]: + return intermidate + + def reduce(self, intermediates: List[Dict[str, float]]) -> Dict[str, float]: + flat_instances = [] + for instance in intermediates: + flat_instance_dict = {} + for metric, metric_dict in instance.get("metrics", {}).items(): + flat_instance_dict[metric] = float(metric_dict["valid"]) + flat_instance_dict["overall_valid"] = instance["overall_valid"] + flat_instances.append(flat_instance_dict) + + return self.reduction.reduce(flat_instances) class MetricWithConfidenceInterval(Metric): # The number of resamples used to estimate the confidence intervals of this metric. diff --git a/tests/library/test_metrics.py b/tests/library/test_metrics.py index ac66289250..36e633270c 100644 --- a/tests/library/test_metrics.py +++ b/tests/library/test_metrics.py @@ -1,6 +1,10 @@ +import asyncio from math import isnan from typing import Dict, List +import random +import unitxt +from unitxt.types import Dialog, Tool from unitxt.api import create_dataset, evaluate from unitxt.inference import MockInferenceEngine from unitxt.llm_as_judge import LLMAsJudge, TaskBasedLLMasJudge @@ -61,6 +65,7 @@ Perplexity, PrecisionBinary, RecallBinary, + ReflectionToolCallingMetric, ReflectionToolCallingMetricSyntactic, RelaxedCorrectness, RocAuc, @@ -146,6 +151,7 @@ class TestMetrics(UnitxtTestCase): + use_mock_model: bool = True def test_unsorted_list_exact_match(self): metric = UnsortedListExactMatch() @@ -1649,22 +1655,17 @@ def test_tool_calling_metric(self): outputs[0]["score"]["global"]["argument_schema_validation"], 0.0 ) - def test_complex_tool_call_real_static_only(self): - """Test a complex tool call with multiple types of validation issues.""" - # Create a complex tool call with multiple issues - metric = ReflectionToolCallingMetricSyntactic() + def test_reflection_tool_calling_metric(self): + unitxt.settings.default_provider = "mock" if self.use_mock_model else "rits" + metric = ReflectionToolCallingMetric() + prediction = ToolCall( **{ "name": "advanced_weather", "arguments": { "location": "San Francisco", - "days": "7", # Wrong type (string vs integer) - "format": "complete", # Invalid enum value - "include_alerts": "yes", # Wrong type (string vs boolean) - "coordinates": "37.7749,-122.4194", # Wrong type (should be array) - "debug_mode": True, # Non-existent parameter - "extra_info": "all", # Non-existent parameter - # Missing required parameter: api_key + "days": 7, + "format": "summary", }, } ) @@ -1672,50 +1673,47 @@ def test_complex_tool_call_real_static_only(self): references = [] task_data = { - "tools": [ + "tools": [Tool( { "name": "advanced_weather", "description": "Get advanced weather forecast", "parameters": { "type": "object", - "required": ["location", "api_key"], + "required": ["location", "days"], "properties": { "location": {"type": "string"}, - "api_key": {"type": "string"}, "days": {"type": "integer"}, "format": { "type": "string", "enum": ["brief", "detailed", "summary"], - }, - "include_alerts": {"type": "boolean"}, - "coordinates": { - "type": "array", - "items": {"type": "number"}, + "default": "detailed" }, }, }, - } - ] + }) + ], + "dialog": Dialog([{"role": "user", "content": "What's the weather like in San Francisco in the next 7 days? Give me a detailed forecast."}],) } # Call the map method - result = metric.map(prediction, references, task_data) + result = metric.map_stream(items=[(prediction, None, task_data)])[0] # Verify all the scores self.assertEqual(result["overall_valid"], False) - self.assertEqual(result["non_existent_function"], False) + self.assertEqual(result["static"]["final_decision"], True) + self.assertEqual(result["semantic"]["general"]["metrics"]["general_hallucination_check"]["is_issue"], True) - # 1 out of 2 required parameters (missing api_key) - self.assertEqual(result["missing_required_parameter"], True) + unitxt_provider_name = "mock" if self.use_mock_model else "watsonx" + model_name = "meta-llama/llama-4-maverick-17b-128e-instruct-fp8" + metric.setup_pipeline(reflector_model_name=model_name, provider_name=unitxt_provider_name) - # 5 valid parameters out of 7 total (2 non-existent) - self.assertEqual(result["non_existent_parameter"], True) + result = metric.map_stream(items=[(prediction, None, task_data)])[0] - self.assertEqual(result["allowed_values_violation"], True) - self.assertEqual(result["incorrect_parameter_type"], True) + # Verify all the scores + self.assertEqual(result["overall_valid"], False) + self.assertEqual(result["static"]["final_decision"], True) + self.assertEqual(result["semantic"]["general"]["metrics"]["general_hallucination_check"]["is_issue"], True) - # Schema validation might not reflect our expectations due to updated logic - # We're primarily interested in value precision calculation def test_partial_value_precision_enum_violations_real_static_only(self): """Test partial value precision when some parameters have invalid enum values.""" @@ -1769,10 +1767,242 @@ def test_partial_value_precision_enum_violations_real_static_only(self): # Assert partial value precision - all parameters exist in the schema # 4 out of 6 parameters have valid values (unit and format have enum violations) - self.assertAlmostEqual(result["allowed_values_violation"], True) + self.assertAlmostEqual(result["metrics"]["allowed_values_violation"]["valid"], False) self.assertAlmostEqual(result["overall_valid"], False) - self.assertAlmostEqual(result["missing_required_parameter"], False) + self.assertAlmostEqual(result["metrics"]["missing_required_parameter"]["valid"], True) + + def test_reflection_tool_calling_metric_reduce(self): + # Instance 1: valid call + instance1 = { + "overall_valid": True, + "static": { + "metrics": { + "non_existent_function": {"valid": True}, + "missing_required_parameter": {"valid": True}, + }, + "final_decision": True, + }, + "semantic": { + "general": { + "metrics": { + "general_hallucination_check": {"is_issue": False}, + "general_value_format_alignment": {"is_issue": False}, + }, + "avg_score": 1.0, + }, + "function_selection": { + "metrics": { + "function_selection_appropriateness": {"is_issue": False}, + "agentic_constraints_satisfaction": {"is_issue": False}, + }, + "avg_score": 1.0, + }, + }, + } + # Instance 2: invalid call + instance2 = { + "overall_valid": False, + "static": { + "metrics": { + "non_existent_function": {"valid": False}, + "missing_required_parameter": {"valid": False}, + }, + "final_decision": False, + }, + "semantic": { + "general": { + "metrics": { + "general_hallucination_check": {"is_issue": True}, + "general_value_format_alignment": {"is_issue": True}, + }, + "avg_score": 0.0, + }, + "function_selection": { + "metrics": { + "function_selection_appropriateness": {"is_issue": True}, + "agentic_constraints_satisfaction": {"is_issue": True}, + }, + "avg_score": 0.0, + }, + }, + } + # Instance 3: partially valid + instance3 = { + "overall_valid": True, + "static": { + "metrics": { + "non_existent_function": {"valid": True}, + "missing_required_parameter": {"valid": False}, + }, + "final_decision": True, + }, + "semantic": { + "general": { + "metrics": { + "general_hallucination_check": {"is_issue": False}, + "general_value_format_alignment": {"is_issue": True}, + }, + "avg_score": 0.5, + }, + "function_selection": { + "metrics": { + "function_selection_appropriateness": {"is_issue": False}, + "agentic_constraints_satisfaction": {"is_issue": True}, + }, + "avg_score": 0.5, + }, + }, + } + + unitxt.settings.default_provider = "mock" if self.use_mock_model else "rits" + metric = ReflectionToolCallingMetric() + reduced = metric.reduce([instance1, instance2, instance3]) + + # All outputs should be floats in [0, 1] + import math + self.assertIsInstance(reduced, dict) + for k, v in reduced.items(): + self.assertIsInstance(v, float, msg=f"value for {k} is not float") + self.assertFalse(math.isnan(v), msg=f"NaN at key {k}") + self.assertGreaterEqual(v, 0.0, msg=f"value for {k} < 0") + self.assertLessEqual(v, 1.0, msg=f"value for {k} > 1") + + # Overall validity is the mean of booleans + expected_overall_valid = (1 + 0 + 1) / 3 + self.assertAlmostEqual(reduced["overall_valid"], expected_overall_valid) + + # Static aggregated metrics, now flattened + # non_existent_function: True, False, True -> 2/3 + nef_expected = (1 + 0 + 1) / 3 + # missing_required_parameter: True, False, False -> 1/3 + mrp_expected = (1 + 0 + 0) / 3 + self.assertIn("static_non_existent_function", reduced) + self.assertIn("static_missing_required_parameter", reduced) + self.assertAlmostEqual(reduced["static_non_existent_function"], nef_expected) + self.assertAlmostEqual(reduced["static_missing_required_parameter"], mrp_expected) + + # Semantic per family average scores + semantic_avg_expected = (1.0 + 0.0 + 0.5) / 3 + self.assertIn("semantic_avg_score_general", reduced) + self.assertIn("semantic_avg_score_function_selection", reduced) + self.assertAlmostEqual(reduced["semantic_avg_score_general"], semantic_avg_expected) + self.assertAlmostEqual(reduced["semantic_avg_score_function_selection"], semantic_avg_expected) + + # Semantic issue rates for individual checks, flattened, mean of is_issue + # general_hallucination_check: False, True, False -> 1/3 + ghc_expected = (0 + 1 + 0) / 3 + # general_value_format_alignment: False, True, True -> 2/3 + gvfa_expected = (0 + 1 + 1) / 3 + # function_selection_appropriateness: False, True, False -> 1/3 + fsa_expected = (0 + 1 + 0) / 3 + # agentic_constraints_satisfaction: False, True, True -> 2/3 + acs_expected = (0 + 1 + 1) / 3 + + self.assertIn("semantic_general_hallucination_check", reduced) + self.assertIn("semantic_general_value_format_alignment", reduced) + self.assertIn("semantic_function_selection_appropriateness", reduced) + self.assertIn("semantic_agentic_constraints_satisfaction", reduced) + + self.assertAlmostEqual(reduced["semantic_general_hallucination_check"], ghc_expected) + self.assertAlmostEqual(reduced["semantic_general_value_format_alignment"], gvfa_expected) + self.assertAlmostEqual(reduced["semantic_function_selection_appropriateness"], fsa_expected) + self.assertAlmostEqual(reduced["semantic_agentic_constraints_satisfaction"], acs_expected) + + + + def test_reflection_tool_calling_metric_syntactic_reduce(self): + from unitxt.metrics import ReflectionToolCallingMetricSyntactic + + # Instance 1: invalid function, all other checks pass + instance1 = { + "metrics": { + "non_existent_function": {"valid": False}, + "non_existent_parameter": {"valid": True}, + "incorrect_parameter_type": {"valid": True}, + "missing_required_parameter": {"valid": True}, + "allowed_values_violation": {"valid": True}, + "empty_api_spec": {"valid": True}, + "invalid_api_spec": {"valid": True}, + "invalid_tool_call": {"valid": True}, + "json_schema_violation": {"valid": True}, + }, + "overall_valid": False, + } + # Instance 2: all checks pass + instance2 = { + "metrics": { + "non_existent_function": {"valid": True}, + "non_existent_parameter": {"valid": True}, + "incorrect_parameter_type": {"valid": True}, + "missing_required_parameter": {"valid": True}, + "allowed_values_violation": {"valid": True}, + "empty_api_spec": {"valid": True}, + "invalid_api_spec": {"valid": True}, + "invalid_tool_call": {"valid": True}, + "json_schema_violation": {"valid": True}, + }, + "overall_valid": True, + } + # Instance 3: missing required parameter, all others pass + instance3 = { + "metrics": { + "non_existent_function": {"valid": True}, + "non_existent_parameter": {"valid": True}, + "incorrect_parameter_type": {"valid": True}, + "missing_required_parameter": {"valid": False}, + "allowed_values_violation": {"valid": True}, + "empty_api_spec": {"valid": True}, + "invalid_api_spec": {"valid": True}, + "invalid_tool_call": {"valid": True}, + "json_schema_violation": {"valid": True}, + }, + "overall_valid": False, + } + + metric = ReflectionToolCallingMetricSyntactic() + inputs = [instance1, instance2, instance3] + reduced = metric.reduce(inputs) + + # 1) Key set is exactly metrics + overall_valid + metric_names = list(instance1["metrics"].keys()) + expected_keys = set(metric_names + ["overall_valid"]) + self.assertEqual(set(reduced.keys()), expected_keys) + + # 2) All values are floats in [0, 1] and not NaN + for k, v in reduced.items(): + self.assertIsInstance(v, float, f"value for {k} is not float") + self.assertFalse(isnan(v), f"NaN at {k}") + self.assertGreaterEqual(v, 0.0, f"value for {k} < 0") + self.assertLessEqual(v, 1.0, f"value for {k} > 1") + + # 3) Per metric aggregation equals mean of valid flags across instances + def mean_valid(name: str) -> float: + vals = [ + 1.0 if inst["metrics"][name]["valid"] else 0.0 + for inst in inputs + ] + return sum(vals) / len(vals) + + for m in metric_names: + self.assertAlmostEqual(reduced[m], mean_valid(m), msg=f"mismatch for {m}") + + # Spot checks, same as your original examples + self.assertAlmostEqual(reduced["non_existent_function"], (0 + 1 + 1) / 3) + self.assertAlmostEqual(reduced["missing_required_parameter"], (1 + 1 + 0) / 3) + self.assertAlmostEqual(reduced["json_schema_violation"], (1 + 1 + 1) / 3) + + # 4) overall_valid equals mean of per instance overall_valid + expected_overall_valid = (0 + 1 + 0) / 3 + self.assertAlmostEqual(reduced["overall_valid"], expected_overall_valid) + + # 5) Order invariance, shuffle inputs and expect identical result + shuffled = inputs[:] + random.shuffle(shuffled) + reduced_shuffled = metric.reduce(shuffled) + self.assertEqual(reduced, reduced_shuffled) + + def test_tool_calling_metric_syntactic_reflector(self): metric = ReflectionToolCallingMetricSyntactic() tools_data = { @@ -1805,7 +2035,7 @@ def test_tool_calling_metric_syntactic_reflector(self): # Exact match should be 1.0 when prediction and reference are identical self.assertEqual(outputs["overall_valid"], True) - self.assertEqual(outputs["non_existent_function"], False) + self.assertEqual(outputs["metrics"]["non_existent_function"]["valid"], True) # Test case 2: Different tool name prediction = {"name": "different_tool", "arguments": {"param1": "value1"}} @@ -1817,7 +2047,7 @@ def test_tool_calling_metric_syntactic_reflector(self): ) self.assertEqual(outputs["overall_valid"], False) - self.assertEqual(outputs["non_existent_function"], True) + self.assertEqual(outputs["metrics"]["non_existent_function"]["valid"], False) # Test case 3: Different parameter names prediction = { @@ -1833,17 +2063,17 @@ def test_tool_calling_metric_syntactic_reflector(self): # Exact match should be 0.0, tool choice 1.0 self.assertEqual(outputs["overall_valid"], False) - self.assertEqual(outputs["non_existent_function"], False) + self.assertEqual(outputs["metrics"]["non_existent_function"]["valid"], True) # param1 is present but param2 is missing out of 2 required parameters in the schema - self.assertEqual(outputs["missing_required_parameter"], True) + self.assertEqual(outputs["metrics"]["missing_required_parameter"]["valid"], False) # 1 valid parameter (param1) out of 2 total parameters (param1, wrongParam) - self.assertEqual(outputs["non_existent_parameter"], True) + self.assertEqual(outputs["metrics"]["non_existent_parameter"]["valid"], False) # Since wrongParam doesn't exist in the schema, value precision only considers param1 # param1 has the correct type, so value precision is 1.0 - self.assertEqual(outputs["incorrect_parameter_type"], False) + self.assertEqual(outputs["metrics"]["incorrect_parameter_type"]["valid"], True) # Test case 4: Different parameter values prediction = { @@ -1872,11 +2102,11 @@ def test_tool_calling_metric_syntactic_reflector(self): ) # Recall should be 0 for empty arguments (missing required parameters) - self.assertEqual(outputs["missing_required_parameter"], True) + self.assertEqual(outputs["metrics"]["missing_required_parameter"]["valid"], False) # Precision is 1.0 because there are no invalid parameter names (no non-existent parameters) - self.assertEqual(outputs["non_existent_parameter"], False) + self.assertEqual(outputs["metrics"]["non_existent_parameter"]["valid"], True) # Value precision is 1.0 because there are no parameters with type or enum violations - self.assertEqual(outputs["incorrect_parameter_type"], False) + self.assertEqual(outputs["metrics"]["incorrect_parameter_type"]["valid"], True) prediction = {"name": "test_tool", "arguments": {}} reference = {"name": "test_tool", "arguments": {}} @@ -1890,11 +2120,11 @@ def test_tool_calling_metric_syntactic_reflector(self): # Recall should still be 0 since there are required parameters in the schema # (regardless of the references, which are not used for validation) - self.assertEqual(outputs["missing_required_parameter"], True) + self.assertEqual(outputs["metrics"]["missing_required_parameter"]["valid"], False) # Precision is 1.0 because there are no invalid parameter names - self.assertEqual(outputs["non_existent_parameter"], False) + self.assertEqual(outputs["metrics"]["non_existent_parameter"]["valid"], True) # Value precision is 1.0 because there are no parameters with type or enum violations - self.assertEqual(outputs["incorrect_parameter_type"], False) + self.assertEqual(outputs["metrics"]["incorrect_parameter_type"]["valid"], True) # Test case 6: Multiple references with one match prediction = {"name": "test_tool", "arguments": {"param1": "value1"}} @@ -1908,9 +2138,9 @@ def test_tool_calling_metric_syntactic_reflector(self): # overall_valid should be 0.0 because param2 is missing (it's required) self.assertEqual(outputs["overall_valid"], False) - self.assertEqual(outputs["non_existent_function"], False) - self.assertEqual(outputs["missing_required_parameter"], True) - self.assertEqual(outputs["non_existent_parameter"], False) + self.assertEqual(outputs["metrics"]["non_existent_function"]["valid"], True) + self.assertEqual(outputs["metrics"]["missing_required_parameter"]["valid"], False) + self.assertEqual(outputs["metrics"]["non_existent_parameter"]["valid"], True) # Test case 7: Parameter types prediction = { @@ -1925,7 +2155,7 @@ def test_tool_calling_metric_syntactic_reflector(self): ) # Parameters should have correct types - self.assertEqual(outputs["json_schema_violation"], False) + self.assertEqual(outputs["metrics"]["json_schema_violation"]["valid"], True) # Test case 8: Wrong parameter types prediction = { @@ -1941,7 +2171,7 @@ def test_tool_calling_metric_syntactic_reflector(self): # json_schema_violation is separate from the type validation # schema validation can still pass even if there are type errors - self.assertEqual(outputs["json_schema_violation"], False) + self.assertEqual(outputs["metrics"]["json_schema_violation"]["valid"], True) def test_overall_valid_success_real_map(self): metric = ReflectionToolCallingMetricSyntactic() @@ -1987,8 +2217,8 @@ def test_overall_valid_success_real_map(self): # Assert expected results self.assertEqual(result["overall_valid"], True) - self.assertEqual(result["non_existent_function"], False) - self.assertEqual(result["missing_required_parameter"], False) + self.assertEqual(result["metrics"]["non_existent_function"]["valid"], True) + self.assertEqual(result["metrics"]["missing_required_parameter"]["valid"], True) def test_non_existent_function_real_map(self): metric = ReflectionToolCallingMetricSyntactic() @@ -2021,8 +2251,8 @@ def test_non_existent_function_real_map(self): # Assert expected results self.assertEqual(result["overall_valid"], False) - self.assertEqual(result["non_existent_function"], True) - self.assertEqual(result["missing_required_parameter"], False) + self.assertEqual(result["metrics"]["non_existent_function"]["valid"], False) + self.assertEqual(result["metrics"]["missing_required_parameter"]["valid"], True) def test_missing_required_parameter_real_map(self): metric = ReflectionToolCallingMetricSyntactic() @@ -2061,8 +2291,8 @@ def test_missing_required_parameter_real_map(self): # Assert expected results self.assertEqual(result["overall_valid"], False) - self.assertEqual(result["missing_required_parameter"], True) - self.assertEqual(result["allowed_values_violation"], False) + self.assertEqual(result["metrics"]["missing_required_parameter"]["valid"], False) + self.assertEqual(result["metrics"]["allowed_values_violation"]["valid"], True) def test_non_existent_parameter_real_map(self): metric = ReflectionToolCallingMetricSyntactic() @@ -2095,7 +2325,7 @@ def test_non_existent_parameter_real_map(self): # Assert expected results self.assertEqual(result["overall_valid"], False) - self.assertEqual(result["non_existent_parameter"], True) + self.assertEqual(result["metrics"]["non_existent_parameter"]["valid"], False) def test_incorrect_parameter_type_real_map(self): metric = ReflectionToolCallingMetricSyntactic() @@ -2125,9 +2355,9 @@ def test_incorrect_parameter_type_real_map(self): # Assert expected results self.assertEqual(result["overall_valid"], False) - self.assertEqual(result["incorrect_parameter_type"], True) - self.assertEqual(result["missing_required_parameter"], False) - self.assertEqual(result["allowed_values_violation"], False) + self.assertEqual(result["metrics"]["incorrect_parameter_type"]["valid"], False) + self.assertEqual(result["metrics"]["missing_required_parameter"]["valid"], True) + self.assertEqual(result["metrics"]["allowed_values_violation"]["valid"], True) def test_multiple_parameter_issues_real_map(self): metric = ReflectionToolCallingMetricSyntactic() @@ -2188,14 +2418,14 @@ def test_multiple_parameter_issues_real_map(self): self.assertEqual(result["overall_valid"], False) # Missing 2 out of 2 required parameters - self.assertEqual(result["missing_required_parameter"], True) + self.assertEqual(result["metrics"]["missing_required_parameter"]["valid"], False) # 2 invalid parameters out of 4 total - self.assertEqual(result["non_existent_parameter"], True) + self.assertEqual(result["metrics"]["non_existent_parameter"]["valid"], False) # Both unit and format have type issues and no other validation issues are reported # So 0 out of 2 valid parameters (excluding the non-existent ones) - self.assertEqual(result["incorrect_parameter_type"], True) + self.assertEqual(result["metrics"]["incorrect_parameter_type"]["valid"], False) # Schema validation might not reflect our expectations due to updated logic # We're primarily interested in value precision calculation @@ -2243,8 +2473,8 @@ def test_allowed_values_violation(self): # Assert expected results self.assertEqual(result["overall_valid"], False) # With 1 invalid enum value out of 2 parameters, value precision should be 0.5 - self.assertEqual(result["allowed_values_violation"], True) - self.assertEqual(result["incorrect_parameter_type"], False) + self.assertEqual(result["metrics"]["allowed_values_violation"]["valid"], False) + self.assertEqual(result["metrics"]["incorrect_parameter_type"]["valid"], True) def test_json_schema_violation_specific_real_map(self): metric = ReflectionToolCallingMetricSyntactic() @@ -2281,10 +2511,10 @@ def test_json_schema_violation_specific_real_map(self): # Assert expected results self.assertEqual(result["overall_valid"], False) # Overall validation fails self.assertEqual( - result["missing_required_parameter"], True + result["metrics"]["missing_required_parameter"]["valid"], False ) # Missing required param # json_schema_violation specifically should be 1.0 because it's marked valid - self.assertEqual(result["json_schema_violation"], False) + self.assertEqual(result["metrics"]["json_schema_violation"]["valid"], True) def test_partial_recall_missing_parameters_real_map(self): metric = ReflectionToolCallingMetricSyntactic() @@ -2322,7 +2552,7 @@ def test_partial_recall_missing_parameters_real_map(self): # Assert partial recall - 1 out of 2 required parameters provided self.assertEqual(result["overall_valid"], False) - self.assertEqual(result["missing_required_parameter"], True) + self.assertEqual(result["metrics"]["missing_required_parameter"]["valid"], False) def test_partial_precision_non_existent_parameters_real_map(self): """Test partial precision score when some parameters don't exist in the schema.""" @@ -2366,7 +2596,7 @@ def test_partial_precision_non_existent_parameters_real_map(self): result = metric.map(prediction, references, task_data) # Assert partial precision - 3 out of 6 parameters are valid - self.assertEqual(result["non_existent_parameter"], True) + self.assertEqual(result["metrics"]["non_existent_parameter"]["valid"], False) self.assertEqual(result["overall_valid"], False) def test_partial_value_precision_type_errors_real_map(self): @@ -2413,7 +2643,7 @@ def test_partial_value_precision_type_errors_real_map(self): result = metric.map(prediction, references, task_data) self.assertAlmostEqual(result["overall_valid"], False) - self.assertAlmostEqual(result["incorrect_parameter_type"], True) + self.assertAlmostEqual(result["metrics"]["incorrect_parameter_type"]["valid"], False) def test_partial_value_precision_enum_violations_real_map(self): """Test partial value precision when some parameters have invalid enum values.""" @@ -2465,7 +2695,7 @@ def test_partial_value_precision_enum_violations_real_map(self): result = metric.map(prediction, references, task_data) self.assertAlmostEqual(result["overall_valid"], False) - self.assertAlmostEqual(result["allowed_values_violation"], True) + self.assertAlmostEqual(result["metrics"]["allowed_values_violation"]["valid"], False) def test_tool_calling_key_value_metric(self): metric = ToolCallKeyValueExtraction(metric="metrics.accuracy") From cc6e827bfba37167bc35e2600ef2b03b1c43f21e Mon Sep 17 00:00:00 2001 From: Koren Lazar Date: Mon, 1 Sep 2025 10:28:02 +0300 Subject: [PATCH 02/14] removed redundant import which makes tests fail. --- src/unitxt/metrics.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py index cbb6c9e8df..3ad0826764 100644 --- a/src/unitxt/metrics.py +++ b/src/unitxt/metrics.py @@ -30,7 +30,6 @@ Type, TypeVar, Union, - override, ) import evaluate From f3efba19c209f420bf82ce147d8f3673a842a761 Mon Sep 17 00:00:00 2001 From: Koren Lazar Date: Mon, 1 Sep 2025 14:59:38 +0300 Subject: [PATCH 03/14] Minor fix for mock provider name in llmevalkit --- src/unitxt/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py index 3ad0826764..3c0032f7cf 100644 --- a/src/unitxt/metrics.py +++ b/src/unitxt/metrics.py @@ -1023,7 +1023,7 @@ def _get_llmeval_provider_name(provider_name: str) -> str: "open-ai": "openai.async.output_val", "rits": "litellm.rits.output_val", "azure": "azure_openai.async.output_val", - "mock": "mock", + "mock": "mock.output_val", } llmeval_provider_name = mapping.get(provider_name) if llmeval_provider_name is None: From 0537dd0b91e14eef7111fa95790db6b65771f02b Mon Sep 17 00:00:00 2001 From: Koren Lazar Date: Tue, 2 Sep 2025 10:44:34 +0300 Subject: [PATCH 04/14] Update descriptions for ReflectionToolCallingMetric and ReflectionToolCallingMetricSyntactic; enhance clarity and detail on evaluation criteria and installation instructions. --- prepare/metrics/tool_calling.py | 4 ++-- .../catalog/metrics/tool_calling/reflection.json | 2 +- .../metrics/tool_calling/reflection/syntactic.json | 2 +- src/unitxt/metrics.py | 10 ++-------- 4 files changed, 6 insertions(+), 12 deletions(-) diff --git a/prepare/metrics/tool_calling.py b/prepare/metrics/tool_calling.py index cab880fa5d..b01bb1dc86 100644 --- a/prepare/metrics/tool_calling.py +++ b/prepare/metrics/tool_calling.py @@ -57,7 +57,7 @@ add_to_catalog( ReflectionToolCallingMetric( - __description__="""A metric that assesses tool call predictions for both syntactic correctness and semantic validity, using predefined checks combined with LLM-based evaluations.""" + __description__="""A metric that assesses tool call predictions for both syntactic correctness and semantic validity, using predefined checks combined with LLM-based evaluations. For each instance, it returns a score reflecting its overall validity, as well as a breakdown of the specific checks/metrics that passed or failed, including hallucination check, value format alignment, function selection and agentic constraints satisfaction. Each metric also contains an evidence from the input, an explanation describing the reflection decision, a confidence, and a validity score with a range of 1-5 (higher score -> more valid).""" ), "metrics.tool_calling.reflection", overwrite=True, @@ -65,7 +65,7 @@ add_to_catalog( ReflectionToolCallingMetricSyntactic( - __description__="""This metric evaluates whether a model's tool call outputs are structurally valid by checking their compliance with the provided tool schema. For each instance, it returns a binary score (True for valid, False for invalid), and aggregates these into a global percentage across all instances. The evaluation covers a wide range of possible issues, including nonexistent functions or parameters, incorrect parameter types, missing required parameters, values outside allowed ranges, JSON schema violations, invalid or empty API specifications, and malformed tool calls. The main reported score, overall_valid (aliased as score), reflects the proportion of calls that are fully valid, making the metric a measure of syntactic and schema-level correctness rather than semantic accuracy.""" + __description__="""This metric evaluates whether a model's tool call outputs are structurally valid by checking their compliance with the provided tool schema. For each instance, it returns a binary score (True for valid, False for invalid), and aggregates these into a global percentage across all instances. The evaluation covers a wide range of possible issues, including nonexistent functions or parameters, incorrect parameter types, missing required parameters, values outside allowed ranges, JSON schema violations, invalid or empty API specifications, and malformed tool calls. The main reported score, overall_valid (aliased as score), reflects the proportion of calls that are fully valid, making the metric a measure of syntactic and schema-level correctness rather than semantic accuracy. Each metric also contains an explanation describing the errors that it detected (if no errors were found - the explanation will be None).""" ), "metrics.tool_calling.reflection.syntactic", overwrite=True, diff --git a/src/unitxt/catalog/metrics/tool_calling/reflection.json b/src/unitxt/catalog/metrics/tool_calling/reflection.json index 6f1d560357..7b52d5ec12 100644 --- a/src/unitxt/catalog/metrics/tool_calling/reflection.json +++ b/src/unitxt/catalog/metrics/tool_calling/reflection.json @@ -1,4 +1,4 @@ { "__type__": "reflection_tool_calling_metric", - "__description__": "A metric that assesses tool call predictions for both syntactic correctness and semantic validity, using predefined checks combined with LLM-based evaluations." + "__description__": "A metric that assesses tool call predictions for both syntactic correctness and semantic validity, using predefined checks combined with LLM-based evaluations. For each instance, it returns a score reflecting its overall validity, as well as a breakdown of the specific checks/metrics that passed or failed, including hallucination check, value format alignment, function selection and agentic constraints satisfaction. Each metric also contains an evidence from the input, an explanation describing the reflection decision, a confidence, and a validity score with a range of 1-5 (higher score -> more valid)." } diff --git a/src/unitxt/catalog/metrics/tool_calling/reflection/syntactic.json b/src/unitxt/catalog/metrics/tool_calling/reflection/syntactic.json index 85b5a37227..d4a1e4bf8b 100644 --- a/src/unitxt/catalog/metrics/tool_calling/reflection/syntactic.json +++ b/src/unitxt/catalog/metrics/tool_calling/reflection/syntactic.json @@ -1,4 +1,4 @@ { "__type__": "reflection_tool_calling_metric_syntactic", - "__description__": "This metric evaluates whether a model's tool call outputs are structurally valid by checking their compliance with the provided tool schema. For each instance, it returns a binary score (True for valid, False for invalid), and aggregates these into a global percentage across all instances. The evaluation covers a wide range of possible issues, including nonexistent functions or parameters, incorrect parameter types, missing required parameters, values outside allowed ranges, JSON schema violations, invalid or empty API specifications, and malformed tool calls. The main reported score, overall_valid (aliased as score), reflects the proportion of calls that are fully valid, making the metric a measure of syntactic and schema-level correctness rather than semantic accuracy." + "__description__": "This metric evaluates whether a model's tool call outputs are structurally valid by checking their compliance with the provided tool schema. For each instance, it returns a binary score (True for valid, False for invalid), and aggregates these into a global percentage across all instances. The evaluation covers a wide range of possible issues, including nonexistent functions or parameters, incorrect parameter types, missing required parameters, values outside allowed ranges, JSON schema violations, invalid or empty API specifications, and malformed tool calls. The main reported score, overall_valid (aliased as score), reflects the proportion of calls that are fully valid, making the metric a measure of syntactic and schema-level correctness rather than semantic accuracy. Each metric also contains an explanation describing the errors that it detected (if no errors were found - the explanation will be None)." } diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py index 3c0032f7cf..f75cb9b748 100644 --- a/src/unitxt/metrics.py +++ b/src/unitxt/metrics.py @@ -990,7 +990,7 @@ class ReflectionToolCallingMetric( reduction = MeanReduction() prediction_type = ToolCall _requirements_list = { - "llmevalkit": "pip install 'git+ssh://git@github.ibm.com/MLT/LLMEvalKit.git'" + "llmevalkit": "Install with \"pip install 'git+ssh://git@github.ibm.com/MLT/LLMEvalKit.git'\".\nTo gain access please reach the team." } runtime_pipeline: bool = True # Whether to use the runtime pipeline or the longer evaluation pipeline with actionable recommendations @@ -1033,7 +1033,7 @@ def _get_llmeval_provider_name(provider_name: str) -> str: @staticmethod def _check_missing_requirements_by_provider(provider_name: str): provider_libs = { - "watsonx": "ibm_watsonx_ai", + "watsonx": "ibm-watsonx-ai", "open-ai": "openai", "rits": "litellm", "azure": "openai" @@ -1190,12 +1190,6 @@ def map( task_data: Dict[str, Any], ) -> Dict[str, float]: from llmevalkit.function_calling.pipeline.pipeline import ReflectionPipeline - from llmevalkit.function_calling.pipeline.types import ( - ToolCall as LLMEvalKitToolCall, - ) - from llmevalkit.function_calling.pipeline.types import ( - ToolSpec as LLMEvalKitToolSpec, - ) # Convert unitxt tool inventory to LLMEvalKit format tools_inventory = ReflectionToolCallingMixin.convert_tools_inventory(task_data.get("tools", [])) From 3f98a589c81094f31d89f5949a3e5ea7b502038b Mon Sep 17 00:00:00 2001 From: Koren Lazar Date: Tue, 2 Sep 2025 11:04:31 +0300 Subject: [PATCH 05/14] Refactor ReflectionToolCallingMetric to use settings directly instead of unitxt.settings; update provider name format for watsonx. --- src/unitxt/metrics.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py index f75cb9b748..6b083f4829 100644 --- a/src/unitxt/metrics.py +++ b/src/unitxt/metrics.py @@ -40,8 +40,6 @@ from scipy.stats import bootstrap from scipy.stats._warnings_errors import DegenerateDataWarning -import unitxt - from .artifact import Artifact from .base_metric import Metric from .collections import ListCollection @@ -1002,16 +1000,16 @@ def prepare(self): "azure": "gpt-4o", "mock": "mock" } - self.requirements = self._check_missing_requirements_by_provider(unitxt.settings.default_provider) + self.requirements = self._check_missing_requirements_by_provider(settings.default_provider) super().prepare() self.setup_pipeline( - reflector_model_name=provider_to_default_reflector_model.get(unitxt.settings.default_provider), - provider_name=unitxt.settings.default_provider + reflector_model_name=provider_to_default_reflector_model.get(settings.default_provider), + provider_name=settings.default_provider ) def setup_pipeline(self, reflector_model_name: str, provider_name: str) -> ReflectionPipeline: llmeval_provider_name = self._get_llmeval_provider_name(provider_name) - self._check_missing_requirements_by_provider(unitxt.settings.default_provider) # TODO: fix it + self._check_missing_requirements_by_provider(settings.default_provider) # TODO: fix it metrics_client = self._get_metrics_client(llmeval_provider_name, reflector_model_name) self.reflection_pipeline = self._build_reflection_pipeline(metrics_client) return self.reflection_pipeline @@ -1033,7 +1031,7 @@ def _get_llmeval_provider_name(provider_name: str) -> str: @staticmethod def _check_missing_requirements_by_provider(provider_name: str): provider_libs = { - "watsonx": "ibm-watsonx-ai", + "watsonx": "ibm_watsonx_ai", "open-ai": "openai", "rits": "litellm", "azure": "openai" From 1848f9f473f5c1efea5ace33b7eaba6371a45ab1 Mon Sep 17 00:00:00 2001 From: Koren Lazar Date: Mon, 8 Sep 2025 12:09:27 +0300 Subject: [PATCH 06/14] Fixed minor bugs to support different tasks. --- .../evaluate_tool_calling_with_reflection.py | 2 +- src/unitxt/metrics.py | 35 ++++++++++++------- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/examples/evaluate_tool_calling_with_reflection.py b/examples/evaluate_tool_calling_with_reflection.py index 47da026752..b52d74f5ba 100644 --- a/examples/evaluate_tool_calling_with_reflection.py +++ b/examples/evaluate_tool_calling_with_reflection.py @@ -64,7 +64,7 @@ test_set=data, split="test", format="formats.chat_api", - metrics=["metrics.tool_calling.reflection.syntactic"], + metrics=["metrics.tool_calling.reflection.syntactic", "metrics.tool_calling.reflection"], max_test_instances=10, ) diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py index 6b083f4829..b0434e3f9d 100644 --- a/src/unitxt/metrics.py +++ b/src/unitxt/metrics.py @@ -910,7 +910,7 @@ def convert_tools_inventory(tools): ] @staticmethod - def convert_tool_call(prediction): + def convert_tool_call(prediction: ToolCall): from llmevalkit.function_calling.pipeline.types import ToolCall as LLMEvalKitToolCall return LLMEvalKitToolCall( type="function", @@ -1076,19 +1076,24 @@ async def map( from llmevalkit.function_calling.pipeline.types import PipelineResult # Convert unitxt dialog to LLMEvalKit format - conversation_history = [dict(turn) for turn in task_data["dialog"]] - + if "dialog" in task_data: + conversation_history = [dict(turn) for turn in task_data["dialog"]] + elif "query" in task_data: + conversation_history = [{"role": "user", "content": task_data["query"]}] + else: + raise ValueError("task_data must contain either 'dialog' or 'query' field.") + # Convert unitxt tool inventory to LLMEvalKit format tools_inventory = ReflectionToolCallingMixin.convert_tools_inventory(task_data.get("tools", [])) # Convert unitxt tool call to LLMEvalKit format - tool_call = ReflectionToolCallingMixin.convert_tool_call(prediction) + tool_call_converted = ReflectionToolCallingMixin.convert_tool_call(prediction) # Run reflection (syntactic + semantic) result: PipelineResult = await self.reflection_pipeline.run_async( conversation=conversation_history, inventory=tools_inventory, - call=tool_call, + call=tool_call_converted, retries=3, continue_on_static=True, ) @@ -1112,9 +1117,15 @@ async def process_all(): while len(pending) < max_concurrency: try: idx, (pred, refs, data) = next(items_iter) - task = asyncio.create_task(self.map(pred, refs, data)) - task.idx = idx - pending.add(task) + if isinstance(pred, list): + for p in pred: + task = asyncio.create_task(self.map(p, refs, data)) + task.idx = idx + pending.add(task) + else: + task = asyncio.create_task(self.map(pred, refs, data)) + task.idx = idx + pending.add(task) except StopIteration: break if not pending: @@ -1137,14 +1148,14 @@ def reduce(self, intermidates: List[Dict[str, Any]]) -> Dict[str, float]: flat_instance_dict[f"static_{metric}"] = float(metric_type_dict["valid"]) for metric_type, metric_type_dict in instance.get("semantic", {}).items(): - for metric, metric_dict in metric_type_dict.get("metrics", {}).items(): - flat_instance_dict[f"semantic_{metric}"] = float(metric_dict["is_issue"]) - flat_instance_dict[f"semantic_avg_score_{metric_type}"] = float(metric_type_dict.get("avg_score")) + if metric_type_dict is not None: + for metric, metric_dict in metric_type_dict.get("metrics", {}).items(): + flat_instance_dict[f"semantic_{metric}"] = float(metric_dict["is_issue"]) + flat_instance_dict[f"semantic_avg_score_{metric_type}"] = float(metric_type_dict.get("avg_score")) flat_instance_dict["overall_valid"] = float(instance["overall_valid"]) flat_instances.append(flat_instance_dict) - return self.reduction.reduce(flat_instances) From 8367826a32f64702f7aba4c3105d205d138be1e4 Mon Sep 17 00:00:00 2001 From: Koren Lazar Date: Mon, 8 Sep 2025 12:47:35 +0300 Subject: [PATCH 07/14] Fixed requirements issue, general import bug, and added some guards for the provider. --- src/unitxt/metrics.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py index b0434e3f9d..d1ef5416d4 100644 --- a/src/unitxt/metrics.py +++ b/src/unitxt/metrics.py @@ -981,9 +981,6 @@ class ReflectionToolCallingMetric( Reference: https://github.ibm.com/MLT/LLMEvalKit """ - from llmevalkit.function_calling.pipeline.pipeline import ReflectionPipeline - from llmevalkit.function_calling.pipeline.types import PipelineResult - main_score = "overall_valid" reduction = MeanReduction() prediction_type = ToolCall @@ -1000,16 +997,22 @@ def prepare(self): "azure": "gpt-4o", "mock": "mock" } - self.requirements = self._check_missing_requirements_by_provider(settings.default_provider) + provider = settings.default_provider if not settings.mock_inference_mode else "mock" + if provider not in provider_to_default_reflector_model: + raise ValueError(f"Unsupported provider for ReflectionToolCallingMetric: {provider}. Supported providers are: {list(provider_to_default_reflector_model.keys())}") + self.requirements = self._get_missing_requirements_by_provider(provider) super().prepare() self.setup_pipeline( - reflector_model_name=provider_to_default_reflector_model.get(settings.default_provider), - provider_name=settings.default_provider + reflector_model_name=provider_to_default_reflector_model.get(provider), + provider_name=provider ) - - def setup_pipeline(self, reflector_model_name: str, provider_name: str) -> ReflectionPipeline: - llmeval_provider_name = self._get_llmeval_provider_name(provider_name) - self._check_missing_requirements_by_provider(settings.default_provider) # TODO: fix it + + def setup_pipeline(self, reflector_model_name: str, provider_name: Optional[str] = None): + if provider_name is not None: + llmeval_provider_name = self._get_llmeval_provider_name(provider_name) + requirements = self._get_missing_requirements_by_provider(provider_name) + self.check_missing_requirements(requirements) + metrics_client = self._get_metrics_client(llmeval_provider_name, reflector_model_name) self.reflection_pipeline = self._build_reflection_pipeline(metrics_client) return self.reflection_pipeline @@ -1029,7 +1032,7 @@ def _get_llmeval_provider_name(provider_name: str) -> str: return llmeval_provider_name @staticmethod - def _check_missing_requirements_by_provider(provider_name: str): + def _get_missing_requirements_by_provider(provider_name: str): provider_libs = { "watsonx": "ibm_watsonx_ai", "open-ai": "openai", @@ -1072,7 +1075,7 @@ async def map( prediction: ToolCall, references: None, task_data: Dict[str, Any], - ) -> PipelineResult: + ): from llmevalkit.function_calling.pipeline.types import PipelineResult # Convert unitxt dialog to LLMEvalKit format @@ -1098,7 +1101,7 @@ async def map( continue_on_static=True, ) return result.model_dump() - + def map_stream( self, From 2cac4148cc6c21ac40a19968642789da3b913bbf Mon Sep 17 00:00:00 2001 From: Koren Lazar Date: Mon, 8 Sep 2025 13:21:09 +0300 Subject: [PATCH 08/14] Fixed pre-commit issues. --- .pre-commit-config.yaml | 4 +- .../evaluate_tool_calling_with_reflection.py | 5 +- src/unitxt/metrics.py | 119 ++++++++----- tests/library/test_metrics.py | 160 ++++++++++++------ 4 files changed, 184 insertions(+), 104 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5b1b709f15..6f81ade2c6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,7 +28,7 @@ repos: hooks: - id: enforce-relative-imports name: Enforce Relative Imports - entry: python utils/enforce_relative_imports.py + entry: python3 utils/enforce_relative_imports.py language: system # Adjust the files pattern to match your needs files: ^src/.*\.py$ @@ -40,7 +40,7 @@ repos: hooks: - id: enforce-library-imports name: Enforce Library Imports - entry: python utils/enforce_library_imports.py + entry: python3 utils/enforce_library_imports.py language: system # Adjust the files pattern to match your needs exclude: (^src/.*\.py$)|utils/enforce_library_imports.py|utils/enforce_relative_imports.py diff --git a/examples/evaluate_tool_calling_with_reflection.py b/examples/evaluate_tool_calling_with_reflection.py index b52d74f5ba..2d69369dde 100644 --- a/examples/evaluate_tool_calling_with_reflection.py +++ b/examples/evaluate_tool_calling_with_reflection.py @@ -64,7 +64,10 @@ test_set=data, split="test", format="formats.chat_api", - metrics=["metrics.tool_calling.reflection.syntactic", "metrics.tool_calling.reflection"], + metrics=[ + "metrics.tool_calling.reflection.syntactic", + "metrics.tool_calling.reflection", + ], max_test_instances=10, ) diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py index d1ef5416d4..e615ae67fb 100644 --- a/src/unitxt/metrics.py +++ b/src/unitxt/metrics.py @@ -1,13 +1,10 @@ import ast import asyncio -import contextlib import json import math import os -from queue import Empty, Queue import re import string -import threading import uuid import warnings from abc import ABC, abstractmethod @@ -18,7 +15,6 @@ from functools import lru_cache from typing import ( Any, - AsyncIterator, Dict, Generator, Generic, @@ -895,12 +891,15 @@ def map( return { "argument_schema_validation": argument_schema_validation, } - + class ReflectionToolCallingMixin: @staticmethod def convert_tools_inventory(tools): - from llmevalkit.function_calling.pipeline.types import ToolSpec as LLMEvalKitToolSpec + from llmevalkit.function_calling.pipeline.types import ( + ToolSpec as LLMEvalKitToolSpec, + ) + return [ LLMEvalKitToolSpec( type="function", @@ -911,7 +910,10 @@ def convert_tools_inventory(tools): @staticmethod def convert_tool_call(prediction: ToolCall): - from llmevalkit.function_calling.pipeline.types import ToolCall as LLMEvalKitToolCall + from llmevalkit.function_calling.pipeline.types import ( + ToolCall as LLMEvalKitToolCall, + ) + return LLMEvalKitToolCall( type="function", function={ @@ -920,11 +922,9 @@ def convert_tool_call(prediction: ToolCall): "parsed_arguments": prediction["arguments"], }, ) - -class ReflectionToolCallingMetric( - ReductionInstanceMetric[str, Dict[str, float]] -): + +class ReflectionToolCallingMetric(ReductionInstanceMetric[str, Dict[str, float]]): """Measures syntactic and semantic validity of tool calls. The final output contains two main fields: "semantic" and "static" (i.e., semantic). @@ -981,6 +981,7 @@ class ReflectionToolCallingMetric( Reference: https://github.ibm.com/MLT/LLMEvalKit """ + main_score = "overall_valid" reduction = MeanReduction() prediction_type = ToolCall @@ -991,29 +992,37 @@ class ReflectionToolCallingMetric( def prepare(self): provider_to_default_reflector_model = { - "watsonx": "meta-llama/llama-4-maverick-17b-128e-instruct-fp8", - "open-ai": "gpt-4o", - "rits": "openai/gpt-oss-120b", - "azure": "gpt-4o", - "mock": "mock" + "watsonx": "meta-llama/llama-4-maverick-17b-128e-instruct-fp8", + "open-ai": "gpt-4o", + "rits": "openai/gpt-oss-120b", + "azure": "gpt-4o", + "mock": "mock", } - provider = settings.default_provider if not settings.mock_inference_mode else "mock" + provider = ( + settings.default_provider if not settings.mock_inference_mode else "mock" + ) if provider not in provider_to_default_reflector_model: - raise ValueError(f"Unsupported provider for ReflectionToolCallingMetric: {provider}. Supported providers are: {list(provider_to_default_reflector_model.keys())}") + raise ValueError( + f"Unsupported provider for ReflectionToolCallingMetric: {provider}. Supported providers are: {list(provider_to_default_reflector_model.keys())}" + ) self.requirements = self._get_missing_requirements_by_provider(provider) super().prepare() self.setup_pipeline( reflector_model_name=provider_to_default_reflector_model.get(provider), - provider_name=provider + provider_name=provider, ) - def setup_pipeline(self, reflector_model_name: str, provider_name: Optional[str] = None): + def setup_pipeline( + self, reflector_model_name: str, provider_name: Optional[str] = None + ): if provider_name is not None: llmeval_provider_name = self._get_llmeval_provider_name(provider_name) requirements = self._get_missing_requirements_by_provider(provider_name) self.check_missing_requirements(requirements) - - metrics_client = self._get_metrics_client(llmeval_provider_name, reflector_model_name) + + metrics_client = self._get_metrics_client( + llmeval_provider_name, reflector_model_name + ) self.reflection_pipeline = self._build_reflection_pipeline(metrics_client) return self.reflection_pipeline @@ -1037,7 +1046,7 @@ def _get_missing_requirements_by_provider(provider_name: str): "watsonx": "ibm_watsonx_ai", "open-ai": "openai", "rits": "litellm", - "azure": "openai" + "azure": "openai", } required_lib = provider_libs.get(provider_name) return [required_lib] if required_lib else [] @@ -1045,17 +1054,19 @@ def _get_missing_requirements_by_provider(provider_name: str): @staticmethod def _get_metrics_client(llmeval_provider_name: str, reflector_model_name: str): from llmevalkit.llm import get_llm - MetricsClientCls = get_llm(llmeval_provider_name) - return MetricsClientCls(model_name=reflector_model_name) + + metrics_client_cls = get_llm(llmeval_provider_name) + return metrics_client_cls(model_name=reflector_model_name) def _build_reflection_pipeline(self, metrics_client): - from llmevalkit.function_calling.pipeline.pipeline import ReflectionPipeline from llmevalkit.function_calling.consts import ( - METRIC_GENERAL_HALLUCINATION_CHECK, METRIC_AGENTIC_CONSTRAINTS_SATISFACTION, METRIC_FUNCTION_SELECTION_APPROPRIATENESS, - METRIC_GENERAL_VALUE_FORMAT_ALIGNMENT + METRIC_GENERAL_HALLUCINATION_CHECK, + METRIC_GENERAL_VALUE_FORMAT_ALIGNMENT, ) + from llmevalkit.function_calling.pipeline.pipeline import ReflectionPipeline + return ReflectionPipeline( metrics_client=metrics_client, general_metrics=[ @@ -1085,9 +1096,11 @@ async def map( conversation_history = [{"role": "user", "content": task_data["query"]}] else: raise ValueError("task_data must contain either 'dialog' or 'query' field.") - + # Convert unitxt tool inventory to LLMEvalKit format - tools_inventory = ReflectionToolCallingMixin.convert_tools_inventory(task_data.get("tools", [])) + tools_inventory = ReflectionToolCallingMixin.convert_tools_inventory( + task_data.get("tools", []) + ) # Convert unitxt tool call to LLMEvalKit format tool_call_converted = ReflectionToolCallingMixin.convert_tool_call(prediction) @@ -1102,16 +1115,14 @@ async def map( ) return result.model_dump() - def map_stream( self, items: Iterable[Tuple[ToolCall, None, Dict[str, Any]]], *, max_concurrency: int = 8, ) -> List[Dict[str, Any]]: - """ - Run self.map in parallel over an iterable and return results in order. - """ + """Run self.map in parallel over an iterable and return results in order.""" + async def process_all(): items_iter = iter(enumerate(items)) results = [] @@ -1133,28 +1144,41 @@ async def process_all(): break if not pending: break - done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) + done, pending = await asyncio.wait( + pending, return_when=asyncio.FIRST_COMPLETED + ) for task in done: results.append((task.idx, await task)) results.sort() return [r for _, r in results] + return asyncio.run(process_all()) - + def reduce_one(self, intermidate: Dict[str, Any]) -> Dict[str, float]: return intermidate - + def reduce(self, intermidates: List[Dict[str, Any]]) -> Dict[str, float]: flat_instances = [] for instance in intermidates: flat_instance_dict = {} - for metric, metric_type_dict in instance.get("static", {}).get("metrics", {}).items(): - flat_instance_dict[f"static_{metric}"] = float(metric_type_dict["valid"]) + for metric, metric_type_dict in ( + instance.get("static", {}).get("metrics", {}).items() + ): + flat_instance_dict[f"static_{metric}"] = float( + metric_type_dict["valid"] + ) for metric_type, metric_type_dict in instance.get("semantic", {}).items(): if metric_type_dict is not None: - for metric, metric_dict in metric_type_dict.get("metrics", {}).items(): - flat_instance_dict[f"semantic_{metric}"] = float(metric_dict["is_issue"]) - flat_instance_dict[f"semantic_avg_score_{metric_type}"] = float(metric_type_dict.get("avg_score")) + for metric, metric_dict in metric_type_dict.get( + "metrics", {} + ).items(): + flat_instance_dict[f"semantic_{metric}"] = float( + metric_dict["is_issue"] + ) + flat_instance_dict[f"semantic_avg_score_{metric_type}"] = float( + metric_type_dict.get("avg_score") + ) flat_instance_dict["overall_valid"] = float(instance["overall_valid"]) flat_instances.append(flat_instance_dict) @@ -1204,7 +1228,9 @@ def map( from llmevalkit.function_calling.pipeline.pipeline import ReflectionPipeline # Convert unitxt tool inventory to LLMEvalKit format - tools_inventory = ReflectionToolCallingMixin.convert_tools_inventory(task_data.get("tools", [])) + tools_inventory = ReflectionToolCallingMixin.convert_tools_inventory( + task_data.get("tools", []) + ) # Convert unitxt tool call to LLMEvalKit format tool_call = ReflectionToolCallingMixin.convert_tool_call(prediction) @@ -1214,12 +1240,14 @@ def map( result_dict = static_result.model_dump() result_dict["overall_valid"] = result_dict.pop("final_decision") - result_dict["metrics"]["json_schema_violation"] = result_dict["metrics"].pop("json_schema_validation") + result_dict["metrics"]["json_schema_violation"] = result_dict["metrics"].pop( + "json_schema_validation" + ) return result_dict def reduce_one(self, intermidate: Dict[str, float]) -> Dict[str, float]: return intermidate - + def reduce(self, intermediates: List[Dict[str, float]]) -> Dict[str, float]: flat_instances = [] for instance in intermediates: @@ -1231,6 +1259,7 @@ def reduce(self, intermediates: List[Dict[str, float]]) -> Dict[str, float]: return self.reduction.reduce(flat_instances) + class MetricWithConfidenceInterval(Metric): # The number of resamples used to estimate the confidence intervals of this metric. # Use None to disable confidence interval computation. diff --git a/tests/library/test_metrics.py b/tests/library/test_metrics.py index 36e633270c..e64f3d7842 100644 --- a/tests/library/test_metrics.py +++ b/tests/library/test_metrics.py @@ -1,10 +1,8 @@ -import asyncio +import random from math import isnan from typing import Dict, List -import random import unitxt -from unitxt.types import Dialog, Tool from unitxt.api import create_dataset, evaluate from unitxt.inference import MockInferenceEngine from unitxt.llm_as_judge import LLMAsJudge, TaskBasedLLMasJudge @@ -85,7 +83,7 @@ check_scores, test_metric, ) -from unitxt.types import ToolCall +from unitxt.types import Dialog, Tool, ToolCall from tests.utils import UnitxtTestCase @@ -152,6 +150,7 @@ class TestMetrics(UnitxtTestCase): use_mock_model: bool = True + def test_unsorted_list_exact_match(self): metric = UnsortedListExactMatch() @@ -1665,34 +1664,41 @@ def test_reflection_tool_calling_metric(self): "arguments": { "location": "San Francisco", "days": 7, - "format": "summary", + "format": "summary", }, } ) - references = [] - task_data = { - "tools": [Tool( - { - "name": "advanced_weather", - "description": "Get advanced weather forecast", - "parameters": { - "type": "object", - "required": ["location", "days"], - "properties": { - "location": {"type": "string"}, - "days": {"type": "integer"}, - "format": { - "type": "string", - "enum": ["brief", "detailed", "summary"], - "default": "detailed" + "tools": [ + Tool( + { + "name": "advanced_weather", + "description": "Get advanced weather forecast", + "parameters": { + "type": "object", + "required": ["location", "days"], + "properties": { + "location": {"type": "string"}, + "days": {"type": "integer"}, + "format": { + "type": "string", + "enum": ["brief", "detailed", "summary"], + "default": "detailed", + }, }, }, - }, - }) + } + ) ], - "dialog": Dialog([{"role": "user", "content": "What's the weather like in San Francisco in the next 7 days? Give me a detailed forecast."}],) + "dialog": Dialog( + [ + { + "role": "user", + "content": "What's the weather like in San Francisco in the next 7 days? Give me a detailed forecast.", + } + ], + ), } # Call the map method @@ -1701,19 +1707,30 @@ def test_reflection_tool_calling_metric(self): # Verify all the scores self.assertEqual(result["overall_valid"], False) self.assertEqual(result["static"]["final_decision"], True) - self.assertEqual(result["semantic"]["general"]["metrics"]["general_hallucination_check"]["is_issue"], True) + self.assertEqual( + result["semantic"]["general"]["metrics"]["general_hallucination_check"][ + "is_issue" + ], + True, + ) unitxt_provider_name = "mock" if self.use_mock_model else "watsonx" model_name = "meta-llama/llama-4-maverick-17b-128e-instruct-fp8" - metric.setup_pipeline(reflector_model_name=model_name, provider_name=unitxt_provider_name) + metric.setup_pipeline( + reflector_model_name=model_name, provider_name=unitxt_provider_name + ) result = metric.map_stream(items=[(prediction, None, task_data)])[0] # Verify all the scores self.assertEqual(result["overall_valid"], False) self.assertEqual(result["static"]["final_decision"], True) - self.assertEqual(result["semantic"]["general"]["metrics"]["general_hallucination_check"]["is_issue"], True) - + self.assertEqual( + result["semantic"]["general"]["metrics"]["general_hallucination_check"][ + "is_issue" + ], + True, + ) def test_partial_value_precision_enum_violations_real_static_only(self): """Test partial value precision when some parameters have invalid enum values.""" @@ -1767,9 +1784,13 @@ def test_partial_value_precision_enum_violations_real_static_only(self): # Assert partial value precision - all parameters exist in the schema # 4 out of 6 parameters have valid values (unit and format have enum violations) - self.assertAlmostEqual(result["metrics"]["allowed_values_violation"]["valid"], False) + self.assertAlmostEqual( + result["metrics"]["allowed_values_violation"]["valid"], False + ) self.assertAlmostEqual(result["overall_valid"], False) - self.assertAlmostEqual(result["metrics"]["missing_required_parameter"]["valid"], True) + self.assertAlmostEqual( + result["metrics"]["missing_required_parameter"]["valid"], True + ) def test_reflection_tool_calling_metric_reduce(self): # Instance 1: valid call @@ -1860,6 +1881,7 @@ def test_reflection_tool_calling_metric_reduce(self): # All outputs should be floats in [0, 1] import math + self.assertIsInstance(reduced, dict) for k, v in reduced.items(): self.assertIsInstance(v, float, msg=f"value for {k} is not float") @@ -1879,14 +1901,20 @@ def test_reflection_tool_calling_metric_reduce(self): self.assertIn("static_non_existent_function", reduced) self.assertIn("static_missing_required_parameter", reduced) self.assertAlmostEqual(reduced["static_non_existent_function"], nef_expected) - self.assertAlmostEqual(reduced["static_missing_required_parameter"], mrp_expected) + self.assertAlmostEqual( + reduced["static_missing_required_parameter"], mrp_expected + ) # Semantic per family average scores semantic_avg_expected = (1.0 + 0.0 + 0.5) / 3 self.assertIn("semantic_avg_score_general", reduced) self.assertIn("semantic_avg_score_function_selection", reduced) - self.assertAlmostEqual(reduced["semantic_avg_score_general"], semantic_avg_expected) - self.assertAlmostEqual(reduced["semantic_avg_score_function_selection"], semantic_avg_expected) + self.assertAlmostEqual( + reduced["semantic_avg_score_general"], semantic_avg_expected + ) + self.assertAlmostEqual( + reduced["semantic_avg_score_function_selection"], semantic_avg_expected + ) # Semantic issue rates for individual checks, flattened, mean of is_issue # general_hallucination_check: False, True, False -> 1/3 @@ -1903,12 +1931,18 @@ def test_reflection_tool_calling_metric_reduce(self): self.assertIn("semantic_function_selection_appropriateness", reduced) self.assertIn("semantic_agentic_constraints_satisfaction", reduced) - self.assertAlmostEqual(reduced["semantic_general_hallucination_check"], ghc_expected) - self.assertAlmostEqual(reduced["semantic_general_value_format_alignment"], gvfa_expected) - self.assertAlmostEqual(reduced["semantic_function_selection_appropriateness"], fsa_expected) - self.assertAlmostEqual(reduced["semantic_agentic_constraints_satisfaction"], acs_expected) - - + self.assertAlmostEqual( + reduced["semantic_general_hallucination_check"], ghc_expected + ) + self.assertAlmostEqual( + reduced["semantic_general_value_format_alignment"], gvfa_expected + ) + self.assertAlmostEqual( + reduced["semantic_function_selection_appropriateness"], fsa_expected + ) + self.assertAlmostEqual( + reduced["semantic_agentic_constraints_satisfaction"], acs_expected + ) def test_reflection_tool_calling_metric_syntactic_reduce(self): from unitxt.metrics import ReflectionToolCallingMetricSyntactic @@ -1964,8 +1998,9 @@ def test_reflection_tool_calling_metric_syntactic_reduce(self): reduced = metric.reduce(inputs) # 1) Key set is exactly metrics + overall_valid - metric_names = list(instance1["metrics"].keys()) - expected_keys = set(metric_names + ["overall_valid"]) + metric_names = [*instance1["metrics"]] + expected_keys = set(metric_names) | {"overall_valid"} + self.assertEqual(set(reduced.keys()), expected_keys) # 2) All values are floats in [0, 1] and not NaN @@ -1977,10 +2012,7 @@ def test_reflection_tool_calling_metric_syntactic_reduce(self): # 3) Per metric aggregation equals mean of valid flags across instances def mean_valid(name: str) -> float: - vals = [ - 1.0 if inst["metrics"][name]["valid"] else 0.0 - for inst in inputs - ] + vals = [1.0 if inst["metrics"][name]["valid"] else 0.0 for inst in inputs] return sum(vals) / len(vals) for m in metric_names: @@ -2001,8 +2033,6 @@ def mean_valid(name: str) -> float: reduced_shuffled = metric.reduce(shuffled) self.assertEqual(reduced, reduced_shuffled) - - def test_tool_calling_metric_syntactic_reflector(self): metric = ReflectionToolCallingMetricSyntactic() tools_data = { @@ -2066,7 +2096,9 @@ def test_tool_calling_metric_syntactic_reflector(self): self.assertEqual(outputs["metrics"]["non_existent_function"]["valid"], True) # param1 is present but param2 is missing out of 2 required parameters in the schema - self.assertEqual(outputs["metrics"]["missing_required_parameter"]["valid"], False) + self.assertEqual( + outputs["metrics"]["missing_required_parameter"]["valid"], False + ) # 1 valid parameter (param1) out of 2 total parameters (param1, wrongParam) self.assertEqual(outputs["metrics"]["non_existent_parameter"]["valid"], False) @@ -2102,7 +2134,9 @@ def test_tool_calling_metric_syntactic_reflector(self): ) # Recall should be 0 for empty arguments (missing required parameters) - self.assertEqual(outputs["metrics"]["missing_required_parameter"]["valid"], False) + self.assertEqual( + outputs["metrics"]["missing_required_parameter"]["valid"], False + ) # Precision is 1.0 because there are no invalid parameter names (no non-existent parameters) self.assertEqual(outputs["metrics"]["non_existent_parameter"]["valid"], True) # Value precision is 1.0 because there are no parameters with type or enum violations @@ -2120,7 +2154,9 @@ def test_tool_calling_metric_syntactic_reflector(self): # Recall should still be 0 since there are required parameters in the schema # (regardless of the references, which are not used for validation) - self.assertEqual(outputs["metrics"]["missing_required_parameter"]["valid"], False) + self.assertEqual( + outputs["metrics"]["missing_required_parameter"]["valid"], False + ) # Precision is 1.0 because there are no invalid parameter names self.assertEqual(outputs["metrics"]["non_existent_parameter"]["valid"], True) # Value precision is 1.0 because there are no parameters with type or enum violations @@ -2139,7 +2175,9 @@ def test_tool_calling_metric_syntactic_reflector(self): # overall_valid should be 0.0 because param2 is missing (it's required) self.assertEqual(outputs["overall_valid"], False) self.assertEqual(outputs["metrics"]["non_existent_function"]["valid"], True) - self.assertEqual(outputs["metrics"]["missing_required_parameter"]["valid"], False) + self.assertEqual( + outputs["metrics"]["missing_required_parameter"]["valid"], False + ) self.assertEqual(outputs["metrics"]["non_existent_parameter"]["valid"], True) # Test case 7: Parameter types @@ -2291,7 +2329,9 @@ def test_missing_required_parameter_real_map(self): # Assert expected results self.assertEqual(result["overall_valid"], False) - self.assertEqual(result["metrics"]["missing_required_parameter"]["valid"], False) + self.assertEqual( + result["metrics"]["missing_required_parameter"]["valid"], False + ) self.assertEqual(result["metrics"]["allowed_values_violation"]["valid"], True) def test_non_existent_parameter_real_map(self): @@ -2418,7 +2458,9 @@ def test_multiple_parameter_issues_real_map(self): self.assertEqual(result["overall_valid"], False) # Missing 2 out of 2 required parameters - self.assertEqual(result["metrics"]["missing_required_parameter"]["valid"], False) + self.assertEqual( + result["metrics"]["missing_required_parameter"]["valid"], False + ) # 2 invalid parameters out of 4 total self.assertEqual(result["metrics"]["non_existent_parameter"]["valid"], False) @@ -2552,7 +2594,9 @@ def test_partial_recall_missing_parameters_real_map(self): # Assert partial recall - 1 out of 2 required parameters provided self.assertEqual(result["overall_valid"], False) - self.assertEqual(result["metrics"]["missing_required_parameter"]["valid"], False) + self.assertEqual( + result["metrics"]["missing_required_parameter"]["valid"], False + ) def test_partial_precision_non_existent_parameters_real_map(self): """Test partial precision score when some parameters don't exist in the schema.""" @@ -2643,7 +2687,9 @@ def test_partial_value_precision_type_errors_real_map(self): result = metric.map(prediction, references, task_data) self.assertAlmostEqual(result["overall_valid"], False) - self.assertAlmostEqual(result["metrics"]["incorrect_parameter_type"]["valid"], False) + self.assertAlmostEqual( + result["metrics"]["incorrect_parameter_type"]["valid"], False + ) def test_partial_value_precision_enum_violations_real_map(self): """Test partial value precision when some parameters have invalid enum values.""" @@ -2695,7 +2741,9 @@ def test_partial_value_precision_enum_violations_real_map(self): result = metric.map(prediction, references, task_data) self.assertAlmostEqual(result["overall_valid"], False) - self.assertAlmostEqual(result["metrics"]["allowed_values_violation"]["valid"], False) + self.assertAlmostEqual( + result["metrics"]["allowed_values_violation"]["valid"], False + ) def test_tool_calling_key_value_metric(self): metric = ToolCallKeyValueExtraction(metric="metrics.accuracy") From 43a584b7545abc134b12fb7a8ed96de9d099d814 Mon Sep 17 00:00:00 2001 From: Koren Lazar Date: Mon, 8 Sep 2025 13:27:20 +0300 Subject: [PATCH 09/14] made sure that we reinstall libraries from git. --- .github/actions/install-internal-pip/action.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/actions/install-internal-pip/action.yml b/.github/actions/install-internal-pip/action.yml index 912fe1f9fe..41de9d515f 100644 --- a/.github/actions/install-internal-pip/action.yml +++ b/.github/actions/install-internal-pip/action.yml @@ -30,4 +30,4 @@ runs: else URL="git+ssh://git@${{ inputs.host }}/${{ inputs.repo }}.git" fi - pip install "$URL" ${{ inputs.pip-extra-args }} \ No newline at end of file + pip install --no-cache-dir --force-reinstall "$URL" ${{ inputs.pip-extra-args }} From e229bab86a0d0345c970ea1465293dd73f6527ed Mon Sep 17 00:00:00 2001 From: Koren Lazar Date: Mon, 8 Sep 2025 13:46:09 +0300 Subject: [PATCH 10/14] Add logging for installation URL and version info in internal pip action --- .github/actions/install-internal-pip/action.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/actions/install-internal-pip/action.yml b/.github/actions/install-internal-pip/action.yml index 41de9d515f..8c8f71e383 100644 --- a/.github/actions/install-internal-pip/action.yml +++ b/.github/actions/install-internal-pip/action.yml @@ -30,4 +30,7 @@ runs: else URL="git+ssh://git@${{ inputs.host }}/${{ inputs.repo }}.git" fi + echo "Installing from URL: $URL" pip install --no-cache-dir --force-reinstall "$URL" ${{ inputs.pip-extra-args }} + # Get and print the installed version/commit hash + python -c "import importlib.metadata; import sys; print(f'Installed {importlib.metadata.version(\"llmevalkit\")} from {sys.modules.get(\"llmevalkit\").__path__[0] if \"llmevalkit\" in sys.modules else \"not loaded\"}')" || echo "Failed to get LLMEvalKit version info" From e11b0a34e5f6db233b7472e4dc0d13e2754bfbc8 Mon Sep 17 00:00:00 2001 From: Koren Lazar Date: Mon, 8 Sep 2025 14:13:22 +0300 Subject: [PATCH 11/14] minor change --- src/unitxt/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py index e615ae67fb..0aff478446 100644 --- a/src/unitxt/metrics.py +++ b/src/unitxt/metrics.py @@ -1015,7 +1015,7 @@ def prepare(self): def setup_pipeline( self, reflector_model_name: str, provider_name: Optional[str] = None ): - if provider_name is not None: + if provider_name: llmeval_provider_name = self._get_llmeval_provider_name(provider_name) requirements = self._get_missing_requirements_by_provider(provider_name) self.check_missing_requirements(requirements) From 1c05b4aa4c40de50e04c652621bcb47a0674e584 Mon Sep 17 00:00:00 2001 From: Koren Lazar Date: Tue, 9 Sep 2025 11:19:54 +0300 Subject: [PATCH 12/14] fixed assignment of mock provider --- tests/library/test_metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/library/test_metrics.py b/tests/library/test_metrics.py index e64f3d7842..297dd15e8d 100644 --- a/tests/library/test_metrics.py +++ b/tests/library/test_metrics.py @@ -1655,7 +1655,7 @@ def test_tool_calling_metric(self): ) def test_reflection_tool_calling_metric(self): - unitxt.settings.default_provider = "mock" if self.use_mock_model else "rits" + unitxt.settings.mock_inference_mode = True metric = ReflectionToolCallingMetric() prediction = ToolCall( @@ -1875,7 +1875,7 @@ def test_reflection_tool_calling_metric_reduce(self): }, } - unitxt.settings.default_provider = "mock" if self.use_mock_model else "rits" + unitxt.settings.mock_inference_mode = True metric = ReflectionToolCallingMetric() reduced = metric.reduce([instance1, instance2, instance3]) From 31d2c6531761e9112768cdea2e449c5723b3788c Mon Sep 17 00:00:00 2001 From: Koren Lazar <44236526+korenLazar@users.noreply.github.com> Date: Tue, 9 Sep 2025 11:22:28 +0300 Subject: [PATCH 13/14] Update .github/actions/install-internal-pip/action.yml Co-authored-by: Elron Bandel --- .github/actions/install-internal-pip/action.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/actions/install-internal-pip/action.yml b/.github/actions/install-internal-pip/action.yml index 8c8f71e383..aeb4619af7 100644 --- a/.github/actions/install-internal-pip/action.yml +++ b/.github/actions/install-internal-pip/action.yml @@ -32,5 +32,3 @@ runs: fi echo "Installing from URL: $URL" pip install --no-cache-dir --force-reinstall "$URL" ${{ inputs.pip-extra-args }} - # Get and print the installed version/commit hash - python -c "import importlib.metadata; import sys; print(f'Installed {importlib.metadata.version(\"llmevalkit\")} from {sys.modules.get(\"llmevalkit\").__path__[0] if \"llmevalkit\" in sys.modules else \"not loaded\"}')" || echo "Failed to get LLMEvalKit version info" From 8a438d6c4e1bfd342b6cf24f780d15e6c95dec43 Mon Sep 17 00:00:00 2001 From: Koren Lazar Date: Tue, 9 Sep 2025 11:50:26 +0300 Subject: [PATCH 14/14] removed two unittests that were causing problems and fixed assertEqual to assertFalse/assertTrue. --- tests/library/test_metrics.py | 221 ++++++++-------------------------- 1 file changed, 51 insertions(+), 170 deletions(-) diff --git a/tests/library/test_metrics.py b/tests/library/test_metrics.py index 297dd15e8d..2bf2a94eac 100644 --- a/tests/library/test_metrics.py +++ b/tests/library/test_metrics.py @@ -1705,13 +1705,12 @@ def test_reflection_tool_calling_metric(self): result = metric.map_stream(items=[(prediction, None, task_data)])[0] # Verify all the scores - self.assertEqual(result["overall_valid"], False) - self.assertEqual(result["static"]["final_decision"], True) - self.assertEqual( + self.assertFalse(result["overall_valid"]) + self.assertTrue(result["static"]["final_decision"]) + self.assertTrue( result["semantic"]["general"]["metrics"]["general_hallucination_check"][ "is_issue" - ], - True, + ] ) unitxt_provider_name = "mock" if self.use_mock_model else "watsonx" @@ -1723,13 +1722,12 @@ def test_reflection_tool_calling_metric(self): result = metric.map_stream(items=[(prediction, None, task_data)])[0] # Verify all the scores - self.assertEqual(result["overall_valid"], False) - self.assertEqual(result["static"]["final_decision"], True) - self.assertEqual( + self.assertFalse(result["overall_valid"]) + self.assertTrue(result["static"]["final_decision"]) + self.assertTrue( result["semantic"]["general"]["metrics"]["general_hallucination_check"][ "is_issue" - ], - True, + ] ) def test_partial_value_precision_enum_violations_real_static_only(self): @@ -2064,8 +2062,8 @@ def test_tool_calling_metric_syntactic_reflector(self): ) # Exact match should be 1.0 when prediction and reference are identical - self.assertEqual(outputs["overall_valid"], True) - self.assertEqual(outputs["metrics"]["non_existent_function"]["valid"], True) + self.assertTrue(outputs["overall_valid"]) + self.assertTrue(outputs["metrics"]["non_existent_function"]["valid"]) # Test case 2: Different tool name prediction = {"name": "different_tool", "arguments": {"param1": "value1"}} @@ -2076,8 +2074,8 @@ def test_tool_calling_metric_syntactic_reflector(self): task_data=tools_data, ) - self.assertEqual(outputs["overall_valid"], False) - self.assertEqual(outputs["metrics"]["non_existent_function"]["valid"], False) + self.assertFalse(outputs["overall_valid"]) + self.assertFalse(outputs["metrics"]["non_existent_function"]["valid"]) # Test case 3: Different parameter names prediction = { @@ -2092,20 +2090,18 @@ def test_tool_calling_metric_syntactic_reflector(self): ) # Exact match should be 0.0, tool choice 1.0 - self.assertEqual(outputs["overall_valid"], False) - self.assertEqual(outputs["metrics"]["non_existent_function"]["valid"], True) + self.assertFalse(outputs["overall_valid"]) + self.assertTrue(outputs["metrics"]["non_existent_function"]["valid"]) # param1 is present but param2 is missing out of 2 required parameters in the schema - self.assertEqual( - outputs["metrics"]["missing_required_parameter"]["valid"], False - ) + self.assertFalse(outputs["metrics"]["missing_required_parameter"]["valid"]) # 1 valid parameter (param1) out of 2 total parameters (param1, wrongParam) - self.assertEqual(outputs["metrics"]["non_existent_parameter"]["valid"], False) + self.assertFalse(outputs["metrics"]["non_existent_parameter"]["valid"]) # Since wrongParam doesn't exist in the schema, value precision only considers param1 # param1 has the correct type, so value precision is 1.0 - self.assertEqual(outputs["metrics"]["incorrect_parameter_type"]["valid"], True) + self.assertTrue(outputs["metrics"]["incorrect_parameter_type"]["valid"]) # Test case 4: Different parameter values prediction = { @@ -2121,7 +2117,7 @@ def test_tool_calling_metric_syntactic_reflector(self): # Parameter choice should be 1.0 (all names match) # Note: overall_valid is 1.0 because the static validation only checks schema conformance, not value equality - self.assertEqual(outputs["overall_valid"], True) + self.assertTrue(outputs["overall_valid"]) # Test case 5a: Empty arguments prediction = {"name": "test_tool", "arguments": {}} @@ -2134,13 +2130,11 @@ def test_tool_calling_metric_syntactic_reflector(self): ) # Recall should be 0 for empty arguments (missing required parameters) - self.assertEqual( - outputs["metrics"]["missing_required_parameter"]["valid"], False - ) + self.assertFalse(outputs["metrics"]["missing_required_parameter"]["valid"]) # Precision is 1.0 because there are no invalid parameter names (no non-existent parameters) - self.assertEqual(outputs["metrics"]["non_existent_parameter"]["valid"], True) + self.assertTrue(outputs["metrics"]["non_existent_parameter"]["valid"]) # Value precision is 1.0 because there are no parameters with type or enum violations - self.assertEqual(outputs["metrics"]["incorrect_parameter_type"]["valid"], True) + self.assertTrue(outputs["metrics"]["incorrect_parameter_type"]["valid"]) prediction = {"name": "test_tool", "arguments": {}} reference = {"name": "test_tool", "arguments": {}} @@ -2154,13 +2148,11 @@ def test_tool_calling_metric_syntactic_reflector(self): # Recall should still be 0 since there are required parameters in the schema # (regardless of the references, which are not used for validation) - self.assertEqual( - outputs["metrics"]["missing_required_parameter"]["valid"], False - ) + self.assertFalse(outputs["metrics"]["missing_required_parameter"]["valid"]) # Precision is 1.0 because there are no invalid parameter names - self.assertEqual(outputs["metrics"]["non_existent_parameter"]["valid"], True) + self.assertTrue(outputs["metrics"]["non_existent_parameter"]["valid"]) # Value precision is 1.0 because there are no parameters with type or enum violations - self.assertEqual(outputs["metrics"]["incorrect_parameter_type"]["valid"], True) + self.assertTrue(outputs["metrics"]["incorrect_parameter_type"]["valid"]) # Test case 6: Multiple references with one match prediction = {"name": "test_tool", "arguments": {"param1": "value1"}} @@ -2173,12 +2165,10 @@ def test_tool_calling_metric_syntactic_reflector(self): ) # overall_valid should be 0.0 because param2 is missing (it's required) - self.assertEqual(outputs["overall_valid"], False) - self.assertEqual(outputs["metrics"]["non_existent_function"]["valid"], True) - self.assertEqual( - outputs["metrics"]["missing_required_parameter"]["valid"], False - ) - self.assertEqual(outputs["metrics"]["non_existent_parameter"]["valid"], True) + self.assertFalse(outputs["overall_valid"]) + self.assertTrue(outputs["metrics"]["non_existent_function"]["valid"]) + self.assertFalse(outputs["metrics"]["missing_required_parameter"]["valid"]) + self.assertTrue(outputs["metrics"]["non_existent_parameter"]["valid"]) # Test case 7: Parameter types prediction = { @@ -2209,7 +2199,7 @@ def test_tool_calling_metric_syntactic_reflector(self): # json_schema_violation is separate from the type validation # schema validation can still pass even if there are type errors - self.assertEqual(outputs["metrics"]["json_schema_violation"]["valid"], True) + self.assertTrue(outputs["metrics"]["json_schema_violation"]["valid"]) def test_overall_valid_success_real_map(self): metric = ReflectionToolCallingMetricSyntactic() @@ -2254,9 +2244,9 @@ def test_overall_valid_success_real_map(self): result = metric.map(prediction, references, task_data) # Assert expected results - self.assertEqual(result["overall_valid"], True) - self.assertEqual(result["metrics"]["non_existent_function"]["valid"], True) - self.assertEqual(result["metrics"]["missing_required_parameter"]["valid"], True) + self.assertTrue(result["overall_valid"]) + self.assertTrue(result["metrics"]["non_existent_function"]["valid"]) + self.assertTrue(result["metrics"]["missing_required_parameter"]["valid"]) def test_non_existent_function_real_map(self): metric = ReflectionToolCallingMetricSyntactic() @@ -2288,9 +2278,9 @@ def test_non_existent_function_real_map(self): result = metric.map(prediction, references, task_data) # Assert expected results - self.assertEqual(result["overall_valid"], False) - self.assertEqual(result["metrics"]["non_existent_function"]["valid"], False) - self.assertEqual(result["metrics"]["missing_required_parameter"]["valid"], True) + self.assertFalse(result["overall_valid"]) + self.assertFalse(result["metrics"]["non_existent_function"]["valid"]) + self.assertTrue(result["metrics"]["missing_required_parameter"]["valid"]) def test_missing_required_parameter_real_map(self): metric = ReflectionToolCallingMetricSyntactic() @@ -2328,11 +2318,9 @@ def test_missing_required_parameter_real_map(self): result = metric.map(prediction, references, task_data) # Assert expected results - self.assertEqual(result["overall_valid"], False) - self.assertEqual( - result["metrics"]["missing_required_parameter"]["valid"], False - ) - self.assertEqual(result["metrics"]["allowed_values_violation"]["valid"], True) + self.assertFalse(result["overall_valid"]) + self.assertFalse(result["metrics"]["missing_required_parameter"]["valid"]) + self.assertTrue(result["metrics"]["allowed_values_violation"]["valid"]) def test_non_existent_parameter_real_map(self): metric = ReflectionToolCallingMetricSyntactic() @@ -2364,113 +2352,8 @@ def test_non_existent_parameter_real_map(self): result = metric.map(prediction, references, task_data) # Assert expected results - self.assertEqual(result["overall_valid"], False) - self.assertEqual(result["metrics"]["non_existent_parameter"]["valid"], False) - - def test_incorrect_parameter_type_real_map(self): - metric = ReflectionToolCallingMetricSyntactic() - # Create sample inputs with wrong parameter type - prediction = ToolCall(**{"name": "get_weather", "arguments": {"location": 42}}) - - references = [ - {"name": "get_weather", "arguments": {"location": "San Francisco"}} - ] - - task_data = { - "tools": [ - { - "name": "get_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "required": ["location"], - "properties": {"location": {"type": "string"}}, - }, - }, - ] - } - - # Call the map method - result = metric.map(prediction, references, task_data) - - # Assert expected results - self.assertEqual(result["overall_valid"], False) - self.assertEqual(result["metrics"]["incorrect_parameter_type"]["valid"], False) - self.assertEqual(result["metrics"]["missing_required_parameter"]["valid"], True) - self.assertEqual(result["metrics"]["allowed_values_violation"]["valid"], True) - - def test_multiple_parameter_issues_real_map(self): - metric = ReflectionToolCallingMetricSyntactic() - # Create sample inputs with multiple issues - prediction = ToolCall( - **{ - "name": "get_weather", - "arguments": { - "unit": 123, - "format": True, - "unknown_param": "value1", - "extra_param": "value2", - }, - } - ) - - references = [ - { - "name": "get_weather", - "arguments": { - "location": "San Francisco", - "date": "2025-08-19", - "unit": "celsius", - "format": "brief", - }, - } - ] - - task_data = { - "tools": [ - { - "name": "get_weather", - "description": "Get the weather in a given location", - "parameters": { - "type": "object", - "required": ["location", "date"], - "properties": { - "location": {"type": "string"}, - "date": {"type": "string"}, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - }, - "format": { - "type": "string", - "enum": ["brief", "detailed"], - }, - }, - }, - } - ] - } - - # Call the map method - result = metric.map(prediction, references, task_data) - - # Assert expected results - self.assertEqual(result["overall_valid"], False) - - # Missing 2 out of 2 required parameters - self.assertEqual( - result["metrics"]["missing_required_parameter"]["valid"], False - ) - - # 2 invalid parameters out of 4 total - self.assertEqual(result["metrics"]["non_existent_parameter"]["valid"], False) - - # Both unit and format have type issues and no other validation issues are reported - # So 0 out of 2 valid parameters (excluding the non-existent ones) - self.assertEqual(result["metrics"]["incorrect_parameter_type"]["valid"], False) - - # Schema validation might not reflect our expectations due to updated logic - # We're primarily interested in value precision calculation + self.assertFalse(result["overall_valid"], False) + self.assertFalse(result["metrics"]["non_existent_parameter"]["valid"]) def test_allowed_values_violation(self): metric = ReflectionToolCallingMetricSyntactic() @@ -2513,10 +2396,10 @@ def test_allowed_values_violation(self): result = metric.map(prediction, references, task_data) # Assert expected results - self.assertEqual(result["overall_valid"], False) + self.assertFalse(result["overall_valid"]) # With 1 invalid enum value out of 2 parameters, value precision should be 0.5 - self.assertEqual(result["metrics"]["allowed_values_violation"]["valid"], False) - self.assertEqual(result["metrics"]["incorrect_parameter_type"]["valid"], True) + self.assertFalse(result["metrics"]["allowed_values_violation"]["valid"]) + self.assertTrue(result["metrics"]["incorrect_parameter_type"]["valid"]) def test_json_schema_violation_specific_real_map(self): metric = ReflectionToolCallingMetricSyntactic() @@ -2551,12 +2434,12 @@ def test_json_schema_violation_specific_real_map(self): result = metric.map(prediction, references, task_data) # Assert expected results - self.assertEqual(result["overall_valid"], False) # Overall validation fails - self.assertEqual( - result["metrics"]["missing_required_parameter"]["valid"], False + self.assertFalse(result["overall_valid"]) # Overall validation fails + self.assertFalse( + result["metrics"]["missing_required_parameter"]["valid"] ) # Missing required param # json_schema_violation specifically should be 1.0 because it's marked valid - self.assertEqual(result["metrics"]["json_schema_violation"]["valid"], True) + self.assertTrue(result["metrics"]["json_schema_violation"]["valid"]) def test_partial_recall_missing_parameters_real_map(self): metric = ReflectionToolCallingMetricSyntactic() @@ -2593,10 +2476,8 @@ def test_partial_recall_missing_parameters_real_map(self): result = metric.map(prediction, references, task_data) # Assert partial recall - 1 out of 2 required parameters provided - self.assertEqual(result["overall_valid"], False) - self.assertEqual( - result["metrics"]["missing_required_parameter"]["valid"], False - ) + self.assertFalse(result["overall_valid"]) + self.assertFalse(result["metrics"]["missing_required_parameter"]["valid"]) def test_partial_precision_non_existent_parameters_real_map(self): """Test partial precision score when some parameters don't exist in the schema.""" @@ -2640,8 +2521,8 @@ def test_partial_precision_non_existent_parameters_real_map(self): result = metric.map(prediction, references, task_data) # Assert partial precision - 3 out of 6 parameters are valid - self.assertEqual(result["metrics"]["non_existent_parameter"]["valid"], False) - self.assertEqual(result["overall_valid"], False) + self.assertFalse(result["metrics"]["non_existent_parameter"]["valid"]) + self.assertFalse(result["overall_valid"]) def test_partial_value_precision_type_errors_real_map(self): """Test partial value precision when some parameters have incorrect types."""