From 752dba971b519c6ce3b5636ffc7618ce05d33ce0 Mon Sep 17 00:00:00 2001
From: elronbandel
Date: Sun, 9 Mar 2025 12:43:46 +0200
Subject: [PATCH] BugFix: Use dumping of task data and source only when dumping

Signed-off-by: elronbandel
---
 src/unitxt/api.py     | 12 +++++++-----
 src/unitxt/dataset.py |  7 +++++--
 src/unitxt/schema.py  | 28 +++++++++++++++++-----------
 3 files changed, 29 insertions(+), 18 deletions(-)

diff --git a/src/unitxt/api.py b/src/unitxt/api.py
index 885fd6bff1..48bc42b614 100644
--- a/src/unitxt/api.py
+++ b/src/unitxt/api.py
@@ -21,7 +21,7 @@
 from .logging_utils import get_logger
 from .metric_utils import EvaluationResults, _compute, _inference_post_process
 from .operator import SourceOperator
-from .schema import loads_instance
+from .schema import SerializeInstancesBeforeDump, loads_instance
 from .settings_utils import get_constants, get_settings
 from .standard import DatasetRecipe
 from .task import Task
@@ -151,7 +151,7 @@ def _source_to_dataset(
     )
     if split is not None:
         stream = {split: stream[split]}
-    ds_builder._generators = stream
+    ds_builder._generators = SerializeInstancesBeforeDump()(stream)
 
     ds_builder.download_and_prepare(
         verification_mode="no_checks",
@@ -280,10 +280,12 @@ def produce(
     is_list = isinstance(instance_or_instances, list)
     if not is_list:
         instance_or_instances = [instance_or_instances]
-    result = _get_produce_with_cache(dataset_query, **kwargs)(instance_or_instances)
+    instances = _get_produce_with_cache(dataset_query, **kwargs)(instance_or_instances)
+    serialize = SerializeInstancesBeforeDump()
+    instances = [serialize.process_instance(instance) for instance in instances]
     if not is_list:
-        return result[0]
-    return Dataset.from_list(result).with_transform(loads_instance)
+        return instances[0]
+    return Dataset.from_list(instances).with_transform(loads_instance)
 
 
 def infer(
diff --git a/src/unitxt/dataset.py b/src/unitxt/dataset.py
index 95b2d21667..b8df4c535b 100644
--- a/src/unitxt/dataset.py
+++ b/src/unitxt/dataset.py
@@ -46,7 +46,7 @@
 from .random_utils import __file__ as _
 from .recipe import __file__ as _
 from .register import __file__ as _
-from .schema import loads_instance
+from .schema import SerializeInstancesBeforeDump, loads_instance
 from .serializers import __file__ as _
 from .settings_utils import get_constants
 from .span_lableing_operators import __file__ as _
@@ -54,6 +54,7 @@
 from .splitters import __file__ as _
 from .sql_utils import __file__ as _
 from .standard import __file__ as _
+from .stream import MultiStream
 from .stream import __file__ as _
 from .stream_operators import __file__ as _
 from .string_operators import __file__ as _
@@ -91,7 +92,9 @@ def generators(self):
             logger.info("Loading with huggingface unitxt copy...")
             dataset = get_dataset_artifact(self.config.name)
 
-            self._generators = dataset()
+            multi_stream: MultiStream = dataset()
+
+            self._generators = SerializeInstancesBeforeDump()(multi_stream)
 
         return self._generators
diff --git a/src/unitxt/schema.py b/src/unitxt/schema.py
index cbb85d8ad4..70f7480950 100644
--- a/src/unitxt/schema.py
+++ b/src/unitxt/schema.py
@@ -7,7 +7,7 @@
 from .artifact import Artifact
 from .dict_utils import dict_get
 from .image_operators import ImageDataString
-from .operator import InstanceOperatorValidator
+from .operator import InstanceOperator, InstanceOperatorValidator
 from .settings_utils import get_constants, get_settings
 from .type_utils import isoftype
 from .types import Image
@@ -87,6 +87,18 @@ def loads_instance(batch):
     return batch
 
 
+class SerializeInstancesBeforeDump(InstanceOperator):
+    def process(
+        self, instance: Dict[str, Any], stream_name: Optional[str] = None
+    ) -> Dict[str, Any]:
+        if settings.task_data_as_text:
+            instance["task_data"] = json.dumps(instance["task_data"])
+
+        if not isinstance(instance["source"], str):
+            instance["source"] = json.dumps(instance["source"])
+        return instance
+
+
 class FinalizeDataset(InstanceOperatorValidator):
     group_by: List[List[str]]
     remove_unnecessary_fields: bool = True
@@ -126,13 +138,6 @@ def _get_instance_task_data(
             task_data = {**task_data, **instance["reference_fields"]}
         return task_data
 
-    def serialize_instance_fields(self, instance, task_data):
-        if settings.task_data_as_text:
-            instance["task_data"] = json.dumps(task_data)
-
-        if not isinstance(instance["source"], str):
-            instance["source"] = json.dumps(instance["source"])
-        return instance
-
     def process(
         self, instance: Dict[str, Any], stream_name: Optional[str] = None
     ) -> Dict[str, Any]:
@@ -157,7 +162,7 @@ def process(
             for instance in instance.pop(constants.demos_field)
         ]
 
-        instance = self.serialize_instance_fields(instance, task_data)
+        instance["task_data"] = task_data
 
         if self.remove_unnecessary_fields:
             keys_to_delete = []
@@ -202,7 +207,8 @@ def validate(self, instance: Dict[str, Any], stream_name: Optional[str] = None):
             instance, dict
         ), f"Instance should be a dict, got {type(instance)}"
         schema = get_schema(stream_name)
+
         assert all(
             key in instance for key in schema
-        ), f"Instance should have the following keys: {schema}. Instance is: {instance}"
-        schema.encode_example(instance)
+        ), f"Instance should have the following keys: {schema.keys()}. Instance is: {instance.keys()}"
+        # schema.encode_example(instance)
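
Illustration (not part of the patch): a minimal standalone sketch of the serialization step that SerializeInstancesBeforeDump now applies at dump time. The helper name serialize_before_dump and the sample instance are hypothetical, and the task_data_as_text argument stands in for settings.task_data_as_text; the body mirrors the operator's process method in the diff above.

    import json
    from typing import Any, Dict

    # Hypothetical standalone equivalent of SerializeInstancesBeforeDump.process:
    # JSON-encode task_data and any non-string (e.g. chat-format) source, but
    # only at the point where instances are dumped into a HuggingFace Dataset.
    def serialize_before_dump(
        instance: Dict[str, Any], task_data_as_text: bool = True
    ) -> Dict[str, Any]:
        if task_data_as_text:
            instance["task_data"] = json.dumps(instance["task_data"])
        if not isinstance(instance["source"], str):
            instance["source"] = json.dumps(instance["source"])
        return instance

    example = {
        "source": [{"role": "user", "content": "2+2="}],  # chat-format source
        "task_data": {"answer": "4"},
    }
    print(serialize_before_dump(example))
    # {'source': '[{"role": "user", "content": "2+2="}]', 'task_data': '{"answer": "4"}'}

The net effect of the patch is that FinalizeDataset keeps task_data as a plain dict, and JSON encoding happens only at the three dump sites the diff touches (_source_to_dataset, produce, and the dataset builder's generators property), so in-memory consumers of instances no longer pay an encode/decode round trip.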