diff --git a/src/ragas/llms/prompt.py b/src/ragas/llms/prompt.py index a97669cd5..5a3cb06e9 100644 --- a/src/ragas/llms/prompt.py +++ b/src/ragas/llms/prompt.py @@ -269,7 +269,7 @@ def get_all_keys(nested_json): return self - def save(self, cache_dir: t.Optional[str] = None) -> None: + def save(self, cache_dir: t.Optional[str] = None): cache_dir = cache_dir if cache_dir else get_cache_dir() cache_dir = os.path.join(cache_dir, self.language) if not os.path.exists(cache_dir): diff --git a/tests/unit/llms/test_prompt.py b/tests/unit/llms/test_prompt.py index 3f54a2618..12337a666 100644 --- a/tests/unit/llms/test_prompt.py +++ b/tests/unit/llms/test_prompt.py @@ -121,3 +121,12 @@ def test_prompt_object_names(): obj.name not in prompt_object_names ), f"Duplicate prompt name: {obj.name}" prompt_object_names.append(obj.name) + + +def test_save_and_load(tmp_path): + for testcase in TESTCASES: + prompt = Prompt(**testcase) + prompt.save(tmp_path) + loaded_prompt = prompt._load(prompt.language, prompt.name, tmp_path) + + assert prompt == loaded_prompt