Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions src/unitxt/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,8 @@ def from_dict(cls, d, overwrite_args=None):

@classmethod
def load(cls, path, artifact_identifier=None, overwrite_args=None):
with open(path) as f:
d = json_loads_with_artifacts(f.read())
d = artifacts_json_cache(path)
if "__type__" in d and d["__type__"] == "artifact_link":
cls.from_dict(d) # for verifications and warnings
Expand Down Expand Up @@ -382,7 +384,9 @@ def save(self, path):
raise UnitxtError(
f"Cannot save catalog artifacts that have changed since initialization. Detected differences in the following fields:\n{diffs}"
)
save_to_file(path, self.to_json())
save_to_file(
path, json_dumps_with_artifacts(source=self, dump_source_as_dict=True)
)

def verify_instance(
self, instance: Dict[str, Any], name: Optional[str] = None
Expand Down Expand Up @@ -468,6 +472,29 @@ def __repr__(self):
return super().__repr__()


def json_dumps_with_artifacts(source, dump_source_as_dict=False):
    """Serialize ``source`` to indented JSON, handling nested Artifact objects.

    Artifacts that carry a catalog id are emitted as that id string; artifacts
    without an id — and ``source`` itself when ``dump_source_as_dict`` is
    True — are expanded into their full dict form via ``to_dict()``.
    """

    def encode_artifact(obj):
        if not isinstance(obj, Artifact):
            # Not an artifact: hand back to json.dumps' normal handling.
            return obj
        expand_fully = (
            dump_source_as_dict and obj.__id__ == source.__id__
        ) or obj.__id__ is None
        return obj.to_dict() if expand_fully else obj.__id__

    return json.dumps(source, default=encode_artifact, indent=4)


def maybe_artifact_dict_to_object(d):
    """Revive ``d`` into an Artifact when it encodes one; otherwise return it
    unchanged. Intended for use as a ``json.loads`` object hook."""
    if not Artifact.is_artifact_dict(d):
        return d
    return Artifact.from_dict(d)


def json_loads_with_artifacts(s):
    """Parse JSON text ``s``, converting every embedded artifact dict into a
    live Artifact object via ``maybe_artifact_dict_to_object``."""
    decoded = json.loads(s, object_hook=maybe_artifact_dict_to_object)
    return decoded


class ArtifactLink(Artifact):
to: Artifact

Expand Down Expand Up @@ -636,7 +663,7 @@ def get_artifacts_data_classification(artifact: str) -> Optional[List[str]]:
)

try:
data_classification = json.loads(data_classification)
data_classification = json_loads_with_artifacts(data_classification)
except json.decoder.JSONDecodeError as e:
raise RuntimeError(error_msg) from e

Expand Down
6 changes: 3 additions & 3 deletions src/unitxt/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from tqdm import tqdm, trange
from tqdm.asyncio import tqdm_asyncio

from .artifact import Artifact
from .artifact import Artifact, json_loads_with_artifacts
from .dataclass import InternalField, NonPositionalField
from .deprecation_utils import deprecation
from .error_utils import UnitxtError, UnitxtWarning
Expand Down Expand Up @@ -2484,7 +2484,7 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
def _extract_queries(instance: Dict[str, Any]) -> Tuple[Optional[str], List]:
task_data = instance["task_data"]
if isinstance(task_data, str):
task_data = json.loads(task_data)
task_data = json_loads_with_artifacts(task_data)
question = task_data.get("question")

images = [None]
Expand Down Expand Up @@ -2854,7 +2854,7 @@ def _infer(
task_data = instance["task_data"]

if isinstance(task_data, str):
task_data = json.loads(task_data)
task_data = json_loads_with_artifacts(task_data)

for option in task_data["options"]:
requests.append(
Expand Down
12 changes: 11 additions & 1 deletion src/unitxt/llm_as_judge.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Dict, List, Optional, Union

from .api import infer
from .artifact import fetch_artifact
from .artifact import fetch_artifact, json_loads_with_artifacts
from .dict_utils import dict_get
from .error_utils import UnitxtError
from .inference import (
Expand Down Expand Up @@ -47,6 +47,16 @@
from .task import Task
from .templates import Template


def get_task_data_dict(task_data):
    """Return ``task_data`` as a dict.

    Task data occasionally arrives JSON-encoded as a string rather than as a
    dict; decode it in that case, otherwise pass it through untouched.
    """
    if isinstance(task_data, str):
        return json_loads_with_artifacts(task_data)
    return task_data

logger = get_logger(__name__)

class LLMJudge(BulkInstanceMetric):
Expand Down
27 changes: 14 additions & 13 deletions src/unitxt/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from datasets import Audio, Features, Sequence, Value
from datasets import Image as DatasetImage

from .artifact import Artifact
from .artifact import Artifact, json_dumps_with_artifacts, json_loads_with_artifacts
from .dict_utils import dict_get
from .image_operators import ImageDataString
from .operator import InstanceOperatorValidator
Expand All @@ -28,7 +28,7 @@
"audios": Sequence(Audio()),
},
"postprocessors": Sequence(Value("string")),
"task_data": Value(dtype="string"),
"task_data": Value("string"),
"data_classification_policy": Sequence(Value("string")),
}
)
Expand All @@ -40,7 +40,7 @@
"groups": Sequence(Value("string")),
"subset": Sequence(Value("string")),
"postprocessors": Sequence(Value("string")),
"task_data": Value(dtype="string"),
"task_data": Value("string"),
"data_classification_policy": Sequence(Value("string")),
"media": {
"images": Sequence(Image()),
Expand Down Expand Up @@ -76,13 +76,13 @@ def loads_batch(batch):
or batch["source"][0].startswith('[{"content":')
)
):
batch["source"] = [load_chat_source(d) for d in batch["source"]]
batch["source"] = [json_loads_with_artifacts(d) for d in batch["source"]]
if (
not settings.task_data_as_text
and "task_data" in batch
and isinstance(batch["task_data"][0], str)
):
batch["task_data"] = [json.loads(d) for d in batch["task_data"]]
batch["task_data"] = [json_loads_with_artifacts(d) for d in batch["task_data"]]
return batch

def loads_instance(instance):
Expand Down Expand Up @@ -145,10 +145,10 @@ def _get_instance_task_data(

def serialize_instance_fields(self, instance, task_data):
if settings.task_data_as_text:
instance["task_data"] = json.dumps(task_data)
instance["task_data"] = json_dumps_with_artifacts(task_data)

if not isinstance(instance["source"], str):
instance["source"] = json.dumps(instance["source"])
instance["source"] = json_dumps_with_artifacts(instance["source"])
return instance

def process(
Expand All @@ -163,9 +163,8 @@ def process(
task_data["metadata"]["demos_pool_size"] = instance["recipe_metadata"][
"demos_pool_size"
]
task_data["metadata"]["template"] = self.artifact_to_jsonable(
instance["recipe_metadata"]["template"]
)
task_data["metadata"]["template"] = instance["recipe_metadata"]["template"]

if "criteria" in task_data and isinstance(task_data["criteria"], Artifact):
task_data["criteria"] = self.artifact_to_jsonable(task_data["criteria"])
if constants.demos_field in instance:
Expand Down Expand Up @@ -194,19 +193,21 @@ def process(
group_attributes = [group_attributes]
for attribute in group_attributes:
group[attribute] = dict_get(data, attribute)
groups.append(json.dumps(group))
groups.append(json_dumps_with_artifacts(group))

instance["groups"] = groups
instance["subset"] = []

instance = self._prepare_media(instance)

instance["metrics"] = [
metric.to_json() if isinstance(metric, Artifact) else metric
json_dumps_with_artifacts(metric) if not isinstance(metric, str) else metric
for metric in instance["metrics"]
]
instance["postprocessors"] = [
processor.to_json() if isinstance(processor, Artifact) else processor
json_dumps_with_artifacts(processor)
if not isinstance(processor, str)
else processor
for processor in instance["postprocessors"]
]

Expand Down
46 changes: 45 additions & 1 deletion tests/library/test_artifact_recovery.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import json

from unitxt.artifact import (
Artifact,
MissingArtifactTypeError,
UnrecognizedArtifactTypeError,
json_dumps_with_artifacts,
json_loads_with_artifacts,
)
from unitxt.card import TaskCard
from unitxt.logging_utils import get_logger
from unitxt.templates import InputOutputTemplate

from tests.utils import UnitxtTestCase

Expand All @@ -15,12 +21,50 @@ def test_correct_artifact_recovery(self):
args = {
"__type__": "dataset_recipe",
"card": "cards.sst2",
"template_card_index": 0,
"template": {
"__type__": "input_output_template",
"input_format": "Given the following {type_of_input}, generate the corresponding {type_of_output}. {type_of_input}: {input}",
"output_format": "{output}",
"postprocessors": [
"processors.take_first_non_empty_line",
"processors.lower_case_till_punc",
],
},
"demos_pool_size": 100,
"num_demos": 0,
}
a = Artifact.from_dict(args)
self.assertEqual(a.num_demos, 0)
self.assertIsInstance(a.template, InputOutputTemplate)

def test_correct_artifact_loading_with_json_loads(self):
args = {
"__type__": "standard_recipe",
"card": "cards.sst2",
"template": {
"__type__": "input_output_template",
"input_format": "Given the following {type_of_input}, generate the corresponding {type_of_output}. {type_of_input}: {input}",
"output_format": "{output}",
"postprocessors": [
"processors.take_first_non_empty_line",
"processors.lower_case_till_punc",
],
},
"demos_pool_size": 100,
"num_demos": 0,
}

a = json_loads_with_artifacts(json.dumps(args))
self.assertEqual(a.num_demos, 0)

a = json_loads_with_artifacts(json.dumps({"x": args}))
self.assertEqual(a["x"].num_demos, 0)

self.assertIsInstance(a["x"].card, TaskCard)
self.assertIsInstance(a["x"].template, InputOutputTemplate)

d = json.loads(json_dumps_with_artifacts(a))
self.assertDictEqual(d, {"x": args})

def test_correct_artifact_recovery_with_overwrite(self):
args = {
Expand Down
Loading