From 426a171074dac1d61244a5851a6cd4e785b8a278 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Thu, 21 Nov 2024 09:17:14 -0500 Subject: [PATCH 01/12] Remove ImportBlock as a pipeline block While this was technically part of our public Python API, it appears to be entirely unused. Let's pull it out now to make syncing with the latest research advancements easier. Signed-off-by: Ben Browning --- src/instructlab/sdg/__init__.py | 2 - src/instructlab/sdg/importblock.py | 54 --------- src/instructlab/sdg/pipeline.py | 3 +- src/instructlab/sdg/pipelines/schema/v1.json | 11 -- tests/test_importblock.py | 109 ------------------- 5 files changed, 1 insertion(+), 178 deletions(-) delete mode 100644 src/instructlab/sdg/importblock.py delete mode 100644 tests/test_importblock.py diff --git a/src/instructlab/sdg/__init__.py b/src/instructlab/sdg/__init__.py index b2500ae3..784c03ee 100644 --- a/src/instructlab/sdg/__init__.py +++ b/src/instructlab/sdg/__init__.py @@ -11,7 +11,6 @@ "FilterByValueBlockError", "FlattenColumnsBlock", "GenerateException", - "ImportBlock", "LLMBlock", "Pipeline", "PipelineBlockError", @@ -30,7 +29,6 @@ from .block import Block from .filterblock import FilterByValueBlock, FilterByValueBlockError from .generate_data import generate_data -from .importblock import ImportBlock from .llmblock import ConditionalLLMBlock, LLMBlock from .pipeline import ( FULL_PIPELINES_PACKAGE, diff --git a/src/instructlab/sdg/importblock.py b/src/instructlab/sdg/importblock.py deleted file mode 100644 index e65a7f01..00000000 --- a/src/instructlab/sdg/importblock.py +++ /dev/null @@ -1,54 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -# Standard -import logging - -# Third Party -from datasets import Dataset - -# Local -from .block import Block - -logger = logging.getLogger(__name__) - - -# This is part of the public API. -class ImportBlock(Block): - def __init__( - self, - ctx, - pipe, - block_name, - path, - ) -> None: - """ - ImportBlock imports a chain of blocks from another pipeline config file. - - Parameters: - - ctx (PipelineContext): A PipelineContext object containing runtime parameters. - - pipe (Pipeline): The Pipeline containing this block in its chain. - - block_name (str): An identifier for this block. - - path (str): A path (absolute, or relative to the instructlab.sdg package) to a pipeline config file. - """ - super().__init__(ctx, pipe, block_name) - self.path = path - - # FIXME: find a better fix for this circular import error: - # - # src/instructlab/sdg/__init__.py:29: in - # from .importblock import ImportBlock - # src/instructlab/sdg/importblock.py:6: in - # from . import pipeline - # src/instructlab/sdg/pipeline.py:102: in - # "ImportBlock": importblock.ImportBlock, - # E AttributeError: partially initialized module 'src.instructlab.sdg.importblock' has no attribute 'ImportBlock' (most likely due to a circular import) - # - # pylint: disable=C0415 - # Local - from . import pipeline - - self.pipeline = pipeline.Pipeline.from_file(self.ctx, self.path) - - def generate(self, samples) -> Dataset: - logger.info("ImportBlock chaining to blocks from {self.path}") - return self.pipeline.generate(samples) diff --git a/src/instructlab/sdg/pipeline.py b/src/instructlab/sdg/pipeline.py index 52621f81..508bb9da 100644 --- a/src/instructlab/sdg/pipeline.py +++ b/src/instructlab/sdg/pipeline.py @@ -19,7 +19,7 @@ from instructlab.sdg.utils import pandas # Local -from . import filterblock, importblock, llmblock, utilblocks +from . import filterblock, llmblock, utilblocks from .block import Block logger = logging.getLogger(__name__) @@ -261,7 +261,6 @@ def _get_batch_indices(self, batch_index: int, total_size: int) -> Iterable[int] "DuplicateColumnsBlock": utilblocks.DuplicateColumnsBlock, "FilterByValueBlock": filterblock.FilterByValueBlock, "FlattenColumnsBlock": utilblocks.FlattenColumnsBlock, - "ImportBlock": importblock.ImportBlock, "LLMBlock": llmblock.LLMBlock, "RenameColumnsBlock": utilblocks.RenameColumnsBlock, "SamplePopulatorBlock": utilblocks.SamplePopulatorBlock, diff --git a/src/instructlab/sdg/pipelines/schema/v1.json b/src/instructlab/sdg/pipelines/schema/v1.json index 64b9c477..690d3d73 100644 --- a/src/instructlab/sdg/pipelines/schema/v1.json +++ b/src/instructlab/sdg/pipelines/schema/v1.json @@ -34,17 +34,6 @@ }, "config": { "anyOf": [ - { - "type": "object", - "description": "ImportBlock", - "required": ["path"], - "additionalProperties": false, - "properties": { - "path": { - "type": "string" - } - } - }, { "type": "object", "description": "FilterByValueBlock", diff --git a/tests/test_importblock.py b/tests/test_importblock.py deleted file mode 100644 index d13d59cc..00000000 --- a/tests/test_importblock.py +++ /dev/null @@ -1,109 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -# Standard -from unittest.mock import MagicMock, patch -import os -import tempfile -import unittest - -# Third Party -from datasets import Dataset, Features, Value - -# First Party -from instructlab.sdg.importblock import ImportBlock -from instructlab.sdg.pipeline import Pipeline - -# Local -from .conftest import get_single_threaded_ctx - - -class TestImportBlockWithMockPipeline(unittest.TestCase): - @patch("instructlab.sdg.pipeline.Pipeline") - def setUp(self, mock_pipeline): - self.ctx = get_single_threaded_ctx() - self.pipe = MagicMock() - self.block_name = "test_block" - self.path = "/path/to/config" - self.mock_pipeline = mock_pipeline - self.import_block = ImportBlock(self.ctx, self.pipe, self.block_name, self.path) - self.dataset = Dataset.from_dict({}) - - def test_initialization(self): - self.assertEqual(self.import_block.block_name, self.block_name) - self.assertEqual(self.import_block.path, self.path) - self.mock_pipeline.from_file.assert_called_once_with(self.ctx, self.path) - - def test_generate(self): - self.mock_pipeline.from_file.return_value.generate.return_value = self.dataset - samples = self.import_block.generate(self.dataset) - self.mock_pipeline.from_file.return_value.generate.assert_called_once_with( - samples - ) - self.assertEqual(samples, self.dataset) - - -_CHILD_YAML = """\ -version: "1.0" -blocks: -- name: greater_than_thirty - type: FilterByValueBlock - config: - filter_column: age - filter_value: 30 - operation: gt - convert_dtype: int -""" - - -_PARENT_YAML_FMT = """\ -version: "1.0" -blocks: -- name: forty_or_under - type: FilterByValueBlock - config: - filter_column: age - filter_value: 40 - operation: le - convert_dtype: int - default_value: 1000 -- name: import_child - type: ImportBlock - config: - path: %s -- name: big_bdays - type: FilterByValueBlock - config: - filter_column: age - filter_value: - - 30 - - 40 - operation: eq - convert_dtype: int -""" - - -class TestImportBlockWithFilterByValue(unittest.TestCase): - def setUp(self): - self.ctx = get_single_threaded_ctx() - self.child_yaml = self._write_tmp_yaml(_CHILD_YAML) - self.parent_yaml = self._write_tmp_yaml(_PARENT_YAML_FMT % self.child_yaml) - self.dataset = Dataset.from_dict( - {"age": ["25", "30", "35", "40", "45"]}, - features=Features({"age": Value("string")}), - ) - - def tearDown(self): - os.remove(self.parent_yaml) - os.remove(self.child_yaml) - - def _write_tmp_yaml(self, content): - tmp_file = tempfile.NamedTemporaryFile(delete=False, mode="w", suffix=".yaml") - tmp_file.write(content) - tmp_file.close() - return tmp_file.name - - def test_generate(self): - pipeline = Pipeline.from_file(self.ctx, self.parent_yaml) - filtered_dataset = pipeline.generate(self.dataset) - self.assertEqual(len(filtered_dataset), 1) - self.assertEqual(filtered_dataset["age"], [40]) From dee4424e059db313730b5aba9f792e21ccabe046 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Thu, 21 Nov 2024 10:06:12 -0500 Subject: [PATCH 02/12] Switch to Jinja2 templates for prompt templates This stubs in support for Jinja templates in the LLMBlock prompt templates, opening us up to more expressive prompts and handling things like loops that take a variable number of input elements when rendering templates. NOTE: This is a backwards-incompatible change in prompt templates. Any users that had custom pipelines specified will need to update their template variables to look like `{{variable}}` instead of `{variable}` as a result of this change. Co-authored-by: shivchander Co-authored-by: abhi1092 Signed-off-by: Ben Browning --- requirements.txt | 1 + src/instructlab/sdg/block.py | 24 +++++++ src/instructlab/sdg/llmblock.py | 49 ++++++------- tests/test_llmblock.py | 119 +++++++++++++++++++++++--------- 4 files changed, 135 insertions(+), 58 deletions(-) diff --git a/requirements.txt b/requirements.txt index 3d577c7a..4f5bf7b0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,7 @@ GitPython>=3.1.42,<4.0.0 gguf>=0.6.0 httpx>=0.25.0,<1.0.0 instructlab-schema>=0.4.0 +jinja2>=3.0.0 langchain-text-splitters # Note: this dependency goes along with langchain-text-splitters and may be # removed once that one is removed. diff --git a/src/instructlab/sdg/block.py b/src/instructlab/sdg/block.py index 20205801..41cffd50 100644 --- a/src/instructlab/sdg/block.py +++ b/src/instructlab/sdg/block.py @@ -2,11 +2,13 @@ # Standard from abc import ABC +from collections import ChainMap from typing import Any, Dict, Union import logging import os.path # Third Party +from jinja2 import Template, UndefinedError import yaml logger = logging.getLogger(__name__) @@ -19,6 +21,28 @@ def __init__(self, ctx, pipe, block_name: str) -> None: self.pipe = pipe self.block_name = block_name + def _validate(self, prompt_template: Template, input_dict: Dict[str, Any]) -> bool: + """ + Validate the input data for this block. This method validates whether all required + variables in the Jinja template are provided in the input_dict. + + :param prompt_template: The Jinja2 template object. + :param input_dict: A dictionary of input values to check against the template. + :return: True if the input data is valid (i.e., no missing variables), False otherwise. + """ + + class Default(dict): + def __missing__(self, key: str) -> None: + raise KeyError(key) + + try: + # Try rendering the template with the input_dict + prompt_template.render(ChainMap(input_dict, Default())) + return True + except UndefinedError as e: + logger.error(f"Missing key: {e}") + return False + def _load_config(self, config_path: str) -> Union[Dict[str, Any], None]: """ Load the configuration file for this block. diff --git a/src/instructlab/sdg/llmblock.py b/src/instructlab/sdg/llmblock.py index 0e9a5f22..f63404f0 100644 --- a/src/instructlab/sdg/llmblock.py +++ b/src/instructlab/sdg/llmblock.py @@ -1,13 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # Standard -from collections import ChainMap from typing import Any, Dict import logging import re # Third Party from datasets import Dataset +from jinja2 import StrictUndefined, Template from tqdm import tqdm import httpx import openai @@ -70,6 +70,12 @@ def server_supports_batched(client, model_id: str) -> bool: logger.info(f"LLM server supports batched inputs: {client.server_supports_batched}") return supported +def template_from_struct_and_config(struct, config): + # replace None with empty strings + filtered_config = { + k: (v if v is not None else "") for k, v in config.items() + } + return Template(struct.format(**filtered_config), undefined=StrictUndefined) # This is part of the public API. # pylint: disable=dangerous-default-value @@ -92,7 +98,7 @@ def __init__( self.prompt_struct = ( """{system}\n{introduction}\n{principles}\n{examples}\n{generation}""" ) - self.prompt_template = self.prompt_struct.format(**self.block_config) + self.prompt_template = template_from_struct_and_config(self.prompt_struct, self.block_config) self.model_prompt = model_prompt self.output_cols = output_cols self.batch_params = batch_kwargs @@ -162,7 +168,7 @@ def _parse(self, generated_string) -> dict: # 2. Non-empty string - the pipeline has specified a custom model prompt # 3. Empty string - the pipeline has specified that no model prompt is needed def _format_prompt(self, sample: Dict) -> str: - prompt = self.prompt_template.format(**sample).strip() + prompt = self.prompt_template.render(sample).strip() model_prompt = None if self.model_prompt is None: @@ -265,25 +271,6 @@ def generate(self, samples: Dataset) -> Dataset: return Dataset.from_list(new_data) - def _validate(self, prompt_template: str, input_dict: Dict[str, Any]) -> bool: - """ - Validate the input data for this block. This method should be implemented by subclasses - to define how the block validates its input data. - - :return: True if the input data is valid, False otherwise. - """ - - class Default(dict): - def __missing__(self, key: str) -> None: - raise KeyError(key) - - try: - prompt_template.format_map(ChainMap(input_dict, Default())) - return True - except KeyError as e: - logger.error("Missing key: {}".format(e)) - return False - # This is part of the public API. class ConditionalLLMBlock(LLMBlock): @@ -300,6 +287,9 @@ def __init__( parser_kwargs={}, batch_kwargs={}, ) -> None: + assert config_paths, "ConditionalLLMBlock config_paths requires at least one entry" + for config_path in config_paths: + assert len(config_path) == 2, "ConditionalLLMBlock config_paths each entry should be a list of config path and selector column names" super().__init__( ctx, pipe, @@ -314,12 +304,10 @@ def __init__( self.selector_column_name = selector_column_name self.prompt_template = {} if len(config_paths) == 1 and config_paths[0][1] == "All": - self.prompt_template = self.prompt_struct.format(**self.block_config) + self.prompt_template = template_from_struct_and_config(self.prompt_struct, self.block_config) else: for config, config_key in config_paths: - self.prompt_template[config_key] = self.prompt_struct.format( - **self._load_config(config) - ) + self.prompt_template[config_key] = template_from_struct_and_config(self.prompt_struct, self._load_config(config)) def _format_prompt(self, sample: Dict) -> str: if isinstance(self.prompt_template, dict): @@ -333,5 +321,12 @@ def _format_prompt(self, sample: Dict) -> str: def _validate(self, prompt_template: str, input_dict: Dict[str, Any]) -> bool: if isinstance(prompt_template, dict): - prompt_template = prompt_template[input_dict[self.selector_column_name]] + if not self.selector_column_name in input_dict: + logger.error(f"ConditionalLLMBlock {self.block_name} missing key: {self.selector_column_name}") + return False + config_key = input_dict[self.selector_column_name] + if not config_key in prompt_template: + logger.error(f"ConditionalLLMBlock {self.block_name} selector key {config_key} not found in block config") + return False + prompt_template = prompt_template[config_key] return super()._validate(prompt_template, input_dict) diff --git a/tests/test_llmblock.py b/tests/test_llmblock.py index 613ea846..7f2a1037 100644 --- a/tests/test_llmblock.py +++ b/tests/test_llmblock.py @@ -7,12 +7,13 @@ # Third Party from datasets import Dataset, Features, Value from httpx import URL -from openai import InternalServerError, NotFoundError, OpenAI +from openai import InternalServerError, NotFoundError # First Party -from src.instructlab.sdg.llmblock import LLMBlock, server_supports_batched +from src.instructlab.sdg.llmblock import ConditionalLLMBlock, LLMBlock, server_supports_batched +@patch("src.instructlab.sdg.block.Block._load_config") class TestLLMBlockModelPrompt(unittest.TestCase): def setUp(self): self.mock_ctx = MagicMock() @@ -20,7 +21,7 @@ def setUp(self): self.mock_ctx.model_id = "test_model" self.mock_pipe = MagicMock() self.config_return_value = { - "system": "{fruit}", + "system": "{{fruit}}", "introduction": "introduction", "principles": "principles", "examples": "examples", @@ -31,7 +32,6 @@ def setUp(self): features=Features({"fruit": Value("string")}), ) - @patch("src.instructlab.sdg.block.Block._load_config") def test_model_prompt_empty_string(self, mock_load_config): mock_load_config.return_value = self.config_return_value # Ensure that if an empty model_prompt is not specified, no model prompt is used. @@ -50,7 +50,6 @@ def test_model_prompt_empty_string(self, mock_load_config): "no model prompt should be used when explicitly set to an empty string", ) - @patch("src.instructlab.sdg.block.Block._load_config") def test_model_prompt_none(self, mock_load_config): mock_load_config.return_value = self.config_return_value # Ensure that if a custom model_prompt is not specified, it defaults to setting it to @@ -70,8 +69,7 @@ def test_model_prompt_none(self, mock_load_config): "model_prompt based on model_family should be used set to None", ) - @patch("src.instructlab.sdg.block.Block._load_config") - def test_model_prompt_none(self, mock_load_config): + def test_model_prompt_custom(self, mock_load_config): mock_load_config.return_value = self.config_return_value # Ensure that if a custom model_prompt is specified, it is used correctly block = LLMBlock( @@ -89,11 +87,25 @@ def test_model_prompt_none(self, mock_load_config): "model_prompt should be a non-empty string when set to None", ) - @patch("src.instructlab.sdg.block.Block._load_config") +@patch("src.instructlab.sdg.block.Block._load_config") +class TestLLMBlockOtherFunctions(unittest.TestCase): + def setUp(self): + self.mock_ctx = MagicMock() + self.mock_ctx.model_family = "mixtral" + self.mock_ctx.model_id = "test_model" + self.mock_pipe = MagicMock() + self.config_return_value = { + "system": "{{fruit}}", + "introduction": "introduction", + "principles": "principles", + "examples": "examples", + "generation": "generation", + } + def test_max_num_tokens_override(self, mock_load_config): mock_load_config.return_value = self.config_return_value self.mock_ctx.max_num_tokens = 512 - # Ensure that if a custom model_prompt is specified, it is used correctly + # Ensure that if max_tokens is specified, it is used correctly block = LLMBlock( ctx=self.mock_ctx, pipe=self.mock_pipe, @@ -106,26 +118,48 @@ def test_max_num_tokens_override(self, mock_load_config): num_tokens = block.gen_kwargs["max_tokens"] assert num_tokens == 512 + def test_validate(self, mock_load_config): + mock_load_config.return_value = { + "system": "{{var1}} {{var2}}", + "introduction": "introduction", + "principles": "principles", + "examples": "examples", + "generation": "generation", + } + block = LLMBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + config_path="", + output_cols=[], + ) + + assert not block._validate(block.prompt_template, {}) + assert block._validate(block.prompt_template, {"var1": "foo", "var2": "bar"}) + +class TestLLMBlockBatching(unittest.TestCase): + def setUp(self): + self.mock_ctx = MagicMock() + self.mock_ctx.model_family = "mixtral" + self.mock_ctx.model_id = "test_model" + self.mock_pipe = MagicMock() + self.mock_client = MagicMock() + self.mock_client.server_supports_batched = None + self.mock_client.base_url = URL("http://localhost:8000/v1") + self.mock_client.get = MagicMock() + self.mock_ctx.client = self.mock_client + def test_server_supports_batched_llama_cpp(self): resp_text = """{"message":"Hello from InstructLab! Visit us at https://instructlab.ai"}""" - mock_client = MagicMock() - mock_client.server_supports_batched = None - mock_client.base_url = URL("http://localhost:8000/v1") - mock_client.get = MagicMock() - mock_client.get.return_value = MagicMock() - mock_client.get().text = resp_text - self.mock_ctx.client = mock_client + self.mock_client.get.return_value = MagicMock() + self.mock_client.get().text = resp_text supports_batched = server_supports_batched(self.mock_ctx.client, "my-model") assert not supports_batched def test_server_supports_batched_other_llama_cpp(self): resp_text = "another server" - mock_client = MagicMock() - mock_client.server_supports_batched = None - mock_client.base_url = URL("http://localhost:8000/v1") - mock_client.get = MagicMock() - mock_client.get.return_value = MagicMock() - mock_client.get().text = resp_text + self.mock_client.get.return_value = MagicMock() + self.mock_client.get().text = resp_text mock_completion = MagicMock() mock_completion.create = MagicMock() mock_completion.create.side_effect = InternalServerError( @@ -133,17 +167,12 @@ def test_server_supports_batched_other_llama_cpp(self): response=MagicMock(), body=MagicMock(), ) - mock_client.completions = mock_completion - self.mock_ctx.client = mock_client + self.mock_client.completions = mock_completion supports_batched = server_supports_batched(self.mock_ctx.client, "my-model") assert not supports_batched def test_server_supports_batched_vllm(self): - mock_client = MagicMock() - mock_client.server_supports_batched = None - mock_client.base_url = URL("http://localhost:8000/v1") - mock_client.get = MagicMock() - mock_client.get.side_effect = NotFoundError( + self.mock_client.get.side_effect = NotFoundError( "mock error", response=MagicMock(), body=MagicMock(), @@ -153,7 +182,35 @@ def test_server_supports_batched_vllm(self): mock_completion = MagicMock() mock_completion.create = MagicMock() mock_completion.create.return_value = mock_completion_resp - mock_client.completions = mock_completion - self.mock_ctx.client = mock_client + self.mock_client.completions = mock_completion supports_batched = server_supports_batched(self.mock_ctx.client, "my-model") assert supports_batched + +@patch("src.instructlab.sdg.block.Block._load_config") +class TestConditionalLLMBlock(unittest.TestCase): + def setUp(self): + self.mock_ctx = MagicMock() + self.mock_ctx.model_family = "mixtral" + self.mock_ctx.model_id = "test_model" + self.mock_pipe = MagicMock() + + def test_validate(self, mock_load_config): + mock_load_config.return_value = { + "system": "{{var1}} {{var2}}", + "introduction": "introduction", + "principles": "principles", + "examples": "examples", + "generation": "generation", + } + block = ConditionalLLMBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + config_paths=[["/foo/bar", "_A_"]], + output_cols=[], + selector_column_name="selector", + ) + + assert not block._validate(block.prompt_template, {}) + assert not block._validate(block.prompt_template, {"selector": "_B_", "var1": "foo", "var2": "bar"}) + assert block._validate(block.prompt_template, {"selector": "_A_", "var1": "foo", "var2": "bar"}) From 5e0cc2324dca260107d2982211ab42487ebcb683 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Fri, 22 Nov 2024 09:14:17 -0500 Subject: [PATCH 03/12] Add BlockRegistry and PromptRegistry The Block and Prompt registries are how we keep track of what our supported Block types are and which Prompts map to which teacher models. Co-authored-by: shivchander Co-authored-by: abhi1092 Signed-off-by: Ben Browning --- pyproject.toml | 4 +- src/instructlab/sdg/__init__.py | 24 ++-- src/instructlab/sdg/blocks/__init__.py | 0 src/instructlab/sdg/{ => blocks}/block.py | 4 + .../sdg/{ => blocks}/filterblock.py | 0 src/instructlab/sdg/{ => blocks}/llmblock.py | 3 + .../sdg/{ => blocks}/utilblocks.py | 0 src/instructlab/sdg/generate_data.py | 2 +- src/instructlab/sdg/pipeline.py | 4 +- src/instructlab/sdg/registry.py | 120 ++++++++++++++++++ tests/test_default_pipeline_configs.py | 12 +- tests/test_filterblock.py | 3 +- tests/test_generate_data.py | 4 +- tests/test_llmblock.py | 9 +- tests/test_pipeline.py | 4 +- tests/test_registry.py | 12 ++ tests/test_sample_populator_block.py | 4 +- tests/test_utilblocks.py | 2 +- 18 files changed, 176 insertions(+), 35 deletions(-) create mode 100644 src/instructlab/sdg/blocks/__init__.py rename src/instructlab/sdg/{ => blocks}/block.py (96%) rename src/instructlab/sdg/{ => blocks}/filterblock.py (100%) rename src/instructlab/sdg/{ => blocks}/llmblock.py (99%) rename src/instructlab/sdg/{ => blocks}/utilblocks.py (100%) create mode 100644 src/instructlab/sdg/registry.py create mode 100644 tests/test_registry.py diff --git a/pyproject.toml b/pyproject.toml index 88978472..aceddb3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,8 +97,8 @@ exclude = [ "^src/instructlab/sdg/generate_data\\.py$", "^src/instructlab/sdg/utils/taxonomy\\.py$", "^src/instructlab/sdg/default_flows\\.py$", - "^src/instructlab/sdg/llmblock\\.py$", - "^src/instructlab/sdg/utilblocks\\.py$", + "^src/instructlab/sdg/blocks/llmblock\\.py$", + "^src/instructlab/sdg/blocks/utilblocks\\.py$", ] # honor excludes by not following there through imports follow_imports = "silent" diff --git a/src/instructlab/sdg/__init__.py b/src/instructlab/sdg/__init__.py index 784c03ee..6ed09cb4 100644 --- a/src/instructlab/sdg/__init__.py +++ b/src/instructlab/sdg/__init__.py @@ -26,10 +26,19 @@ ) # Local -from .block import Block -from .filterblock import FilterByValueBlock, FilterByValueBlockError +from .blocks.block import Block +from .blocks.filterblock import FilterByValueBlock, FilterByValueBlockError +from .blocks.llmblock import ConditionalLLMBlock, LLMBlock +from .blocks.utilblocks import ( + CombineColumnsBlock, + DuplicateColumnsBlock, + FlattenColumnsBlock, + RenameColumnsBlock, + SamplePopulatorBlock, + SelectorBlock, + SetToMajorityValueBlock, +) from .generate_data import generate_data -from .llmblock import ConditionalLLMBlock, LLMBlock from .pipeline import ( FULL_PIPELINES_PACKAGE, SIMPLE_PIPELINES_PACKAGE, @@ -39,14 +48,5 @@ PipelineConfigParserError, PipelineContext, ) -from .utilblocks import ( - CombineColumnsBlock, - DuplicateColumnsBlock, - FlattenColumnsBlock, - RenameColumnsBlock, - SamplePopulatorBlock, - SelectorBlock, - SetToMajorityValueBlock, -) from .utils import GenerateException from .utils.taxonomy import TaxonomyReadingException diff --git a/src/instructlab/sdg/blocks/__init__.py b/src/instructlab/sdg/blocks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/instructlab/sdg/block.py b/src/instructlab/sdg/blocks/block.py similarity index 96% rename from src/instructlab/sdg/block.py rename to src/instructlab/sdg/blocks/block.py index 41cffd50..1e0a39b2 100644 --- a/src/instructlab/sdg/block.py +++ b/src/instructlab/sdg/blocks/block.py @@ -11,10 +11,14 @@ from jinja2 import Template, UndefinedError import yaml +# Local +from ..registry import BlockRegistry + logger = logging.getLogger(__name__) # This is part of the public API. +@BlockRegistry.register("Block") class Block(ABC): def __init__(self, ctx, pipe, block_name: str) -> None: self.ctx = ctx diff --git a/src/instructlab/sdg/filterblock.py b/src/instructlab/sdg/blocks/filterblock.py similarity index 100% rename from src/instructlab/sdg/filterblock.py rename to src/instructlab/sdg/blocks/filterblock.py diff --git a/src/instructlab/sdg/llmblock.py b/src/instructlab/sdg/blocks/llmblock.py similarity index 99% rename from src/instructlab/sdg/llmblock.py rename to src/instructlab/sdg/blocks/llmblock.py index f63404f0..d84b08fa 100644 --- a/src/instructlab/sdg/llmblock.py +++ b/src/instructlab/sdg/blocks/llmblock.py @@ -13,6 +13,7 @@ import openai # Local +from ..registry import BlockRegistry from .block import Block logger = logging.getLogger(__name__) @@ -78,6 +79,7 @@ def template_from_struct_and_config(struct, config): return Template(struct.format(**filtered_config), undefined=StrictUndefined) # This is part of the public API. +@BlockRegistry.register("LLMBlock") # pylint: disable=dangerous-default-value class LLMBlock(Block): # pylint: disable=too-many-instance-attributes @@ -273,6 +275,7 @@ def generate(self, samples: Dataset) -> Dataset: # This is part of the public API. +@BlockRegistry.register("ConditionalLLMBlock") class ConditionalLLMBlock(LLMBlock): def __init__( self, diff --git a/src/instructlab/sdg/utilblocks.py b/src/instructlab/sdg/blocks/utilblocks.py similarity index 100% rename from src/instructlab/sdg/utilblocks.py rename to src/instructlab/sdg/blocks/utilblocks.py diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py index aac007b7..70a09efd 100644 --- a/src/instructlab/sdg/generate_data.py +++ b/src/instructlab/sdg/generate_data.py @@ -21,7 +21,7 @@ # pylint: disable=ungrouped-imports from instructlab.sdg.datamixing import DataMixer, _get_question_hack, _get_response_hack from instructlab.sdg.eval_data import generate_eval_task_data, mmlubench_pipe_init -from instructlab.sdg.llmblock import ( +from instructlab.sdg.blocks.llmblock import ( DEFAULT_MAX_NUM_TOKENS, MODEL_FAMILY_MERLINITE, MODEL_FAMILY_MIXTRAL, diff --git a/src/instructlab/sdg/pipeline.py b/src/instructlab/sdg/pipeline.py index 508bb9da..66161155 100644 --- a/src/instructlab/sdg/pipeline.py +++ b/src/instructlab/sdg/pipeline.py @@ -19,8 +19,8 @@ from instructlab.sdg.utils import pandas # Local -from . import filterblock, llmblock, utilblocks -from .block import Block +from .blocks import filterblock, llmblock, utilblocks +from .blocks.block import Block logger = logging.getLogger(__name__) diff --git a/src/instructlab/sdg/registry.py b/src/instructlab/sdg/registry.py new file mode 100644 index 00000000..128446f8 --- /dev/null +++ b/src/instructlab/sdg/registry.py @@ -0,0 +1,120 @@ +# Standard +from typing import Dict, List, Union +import logging + +# Third Party +from jinja2 import Template + +logger = logging.getLogger(__name__) + + +class BlockRegistry: + """Registry for block classes to avoid manual additions to block type map.""" + + _registry: Dict[str, type] = {} + + @classmethod + def register(cls, block_name: str): + """ + Decorator to register a block class under a specified name. + + :param block_name: Name under which to register the block. + """ + + def decorator(block_class): + cls._registry[block_name] = block_class + logger.debug( + f"Registered block '{block_name}' with class '{block_class.__name__}'" + ) + return block_class + + return decorator + + @classmethod + def get_registry(cls): + """ + Retrieve the current registry map of block types. + + :return: Dictionary of registered block names and classes. + """ + logger.debug("Fetching the block registry map.") + return cls._registry + + +class PromptRegistry: + """Registry for managing Jinja2 prompt templates.""" + + _registry: Dict[str, Template] = {} + + @classmethod + def register(cls, name: str): + """Decorator to register a Jinja2 template function by name. + + :param name: Name of the template to register. + :return: A decorator that registers the Jinja2 template function. + """ + + def decorator(func): + template_str = func() + cls._registry[name] = Template(template_str) + logger.debug(f"Registered prompt template '{name}'") + return func + + return decorator + + @classmethod + def get_template(cls, name: str) -> Template: + """Retrieve a Jinja2 template by name. + + :param name: Name of the template to retrieve. + :return: The Jinja2 template instance. + """ + if name not in cls._registry: + raise KeyError(f"Template '{name}' not found.") + logger.debug(f"Retrieving prompt template '{name}'") + return cls._registry[name] + + @classmethod + def get_registry(cls): + """ + Retrieve the current registry map of block types. + + :return: Dictionary of registered block names and classes. + """ + logger.debug("Fetching the block registry map.") + return cls._registry + + @classmethod + def render_template( + cls, + name: str, + messages: Union[str, List[Dict[str, str]]], + add_generation_prompt: bool = True, + ) -> str: + """Render the template with the provided messages or query. + + :param name: Name of the template to render. + :param messages: Either a single query string or a list of messages (each as a dict with 'role' and 'content'). + :param add_generation_prompt: Whether to add a generation prompt at the end. + :return: The rendered prompt as a string. + """ + + # Special handling for "blank" template + if name == "blank": + if not isinstance(messages, str): + raise ValueError( + "The 'blank' template can only be used with a single query string, not a list of messages." + ) + return messages # Return the query as-is without templating + + # Get the template + template = cls.get_template(name) + + # If `messages` is a string, wrap it in a list with a default user role + if isinstance(messages, str): + messages = [{"role": "user", "content": messages}] + + # Render the template with the `messages` list + return template.render( + messages=messages, add_generation_prompt=add_generation_prompt + ) diff --git a/tests/test_default_pipeline_configs.py b/tests/test_default_pipeline_configs.py index a774fb2d..8009874b 100644 --- a/tests/test_default_pipeline_configs.py +++ b/tests/test_default_pipeline_configs.py @@ -9,13 +9,15 @@ from datasets import Dataset # First Party -from instructlab.sdg.filterblock import FilterByValueBlock -from instructlab.sdg.llmblock import ConditionalLLMBlock, LLMBlock -from instructlab.sdg.pipeline import Pipeline, PipelineContext -from instructlab.sdg.utilblocks import ( +from instructlab.sdg import ( CombineColumnsBlock, + ConditionalLLMBlock, DuplicateColumnsBlock, + FilterByValueBlock, FlattenColumnsBlock, + LLMBlock, + Pipeline, + PipelineContext, RenameColumnsBlock, SamplePopulatorBlock, SelectorBlock, @@ -35,7 +37,7 @@ def _noop_generate(self, samples): @patch.object(RenameColumnsBlock, "generate", _noop_generate) @patch.object(SamplePopulatorBlock, "generate", _noop_generate) @patch.object(SelectorBlock, "generate", _noop_generate) -@patch("instructlab.sdg.llmblock.server_supports_batched", lambda c, m: True) +@patch("instructlab.sdg.blocks.llmblock.server_supports_batched", lambda c, m: True) @patch.object(Pipeline, "_drop_duplicates", lambda self, dataset, cols: dataset) class TestDefaultPipelineConfigs(unittest.TestCase): def setUp(self): diff --git a/tests/test_filterblock.py b/tests/test_filterblock.py index e84258ab..ab43890a 100644 --- a/tests/test_filterblock.py +++ b/tests/test_filterblock.py @@ -9,8 +9,7 @@ from datasets import Dataset, Features, Value # First Party -from instructlab.sdg.filterblock import FilterByValueBlock -from instructlab.sdg.pipeline import PipelineContext +from instructlab.sdg import FilterByValueBlock class TestFilterByValueBlock(unittest.TestCase): diff --git a/tests/test_generate_data.py b/tests/test_generate_data.py index d3f6ced5..23b3dcb9 100644 --- a/tests/test_generate_data.py +++ b/tests/test_generate_data.py @@ -21,8 +21,8 @@ # First Party from instructlab.sdg.generate_data import _context_init, _sdg_init, generate_data -from instructlab.sdg.llmblock import LLMBlock -from instructlab.sdg.pipeline import PipelineContext +from instructlab.sdg import LLMBlock +from instructlab.sdg import PipelineContext TEST_SYS_PROMPT = "I am, Red Hat® Instruct Model based on Granite 7B, an AI language model developed by Red Hat and IBM Research, based on the Granite-7b-base language model. My primary function is to be a chat assistant." diff --git a/tests/test_llmblock.py b/tests/test_llmblock.py index 7f2a1037..73b96ea0 100644 --- a/tests/test_llmblock.py +++ b/tests/test_llmblock.py @@ -10,10 +10,11 @@ from openai import InternalServerError, NotFoundError # First Party -from src.instructlab.sdg.llmblock import ConditionalLLMBlock, LLMBlock, server_supports_batched +from src.instructlab.sdg import ConditionalLLMBlock, LLMBlock +from src.instructlab.sdg.blocks.llmblock import server_supports_batched -@patch("src.instructlab.sdg.block.Block._load_config") +@patch("src.instructlab.sdg.blocks.block.Block._load_config") class TestLLMBlockModelPrompt(unittest.TestCase): def setUp(self): self.mock_ctx = MagicMock() @@ -87,7 +88,7 @@ def test_model_prompt_custom(self, mock_load_config): "model_prompt should be a non-empty string when set to None", ) -@patch("src.instructlab.sdg.block.Block._load_config") +@patch("src.instructlab.sdg.blocks.block.Block._load_config") class TestLLMBlockOtherFunctions(unittest.TestCase): def setUp(self): self.mock_ctx = MagicMock() @@ -186,7 +187,7 @@ def test_server_supports_batched_vllm(self): supports_batched = server_supports_batched(self.mock_ctx.client, "my-model") assert supports_batched -@patch("src.instructlab.sdg.block.Block._load_config") +@patch("src.instructlab.sdg.blocks.block.Block._load_config") class TestConditionalLLMBlock(unittest.TestCase): def setUp(self): self.mock_ctx = MagicMock() diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 7367161c..a07df5ef 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -14,8 +14,8 @@ import pytest # First Party -from instructlab.sdg.block import Block -from instructlab.sdg.pipeline import Pipeline, PipelineBlockError +from instructlab.sdg import Block +from instructlab.sdg import Pipeline, PipelineBlockError ## Helpers ## diff --git a/tests/test_registry.py b/tests/test_registry.py new file mode 100644 index 00000000..ddad7001 --- /dev/null +++ b/tests/test_registry.py @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: Apache-2.0 + +# First Party +from src.instructlab.sdg.registry import BlockRegistry + +def test_block_registry(): + @BlockRegistry.register("TestFooClass") + class TestFooClass: + pass + registry = BlockRegistry.get_registry() + assert registry is not None + assert registry["TestFooClass"] is TestFooClass diff --git a/tests/test_sample_populator_block.py b/tests/test_sample_populator_block.py index e19eed31..f363b9d0 100644 --- a/tests/test_sample_populator_block.py +++ b/tests/test_sample_populator_block.py @@ -8,7 +8,7 @@ from datasets import Dataset, Features, Value # First Party -from instructlab.sdg.utilblocks import SamplePopulatorBlock +from instructlab.sdg import SamplePopulatorBlock class TestSamplePopulatorBlock(unittest.TestCase): @@ -17,7 +17,7 @@ def setUp(self): self.ctx.dataset_num_procs = 1 self.pipe = MagicMock() - @patch("instructlab.sdg.block.Block._load_config") + @patch("instructlab.sdg.blocks.block.Block._load_config") def test_generate(self, mock_load_config): def load_config(file_name): if file_name == "coffee.yaml" or file_name == "tea.yaml": diff --git a/tests/test_utilblocks.py b/tests/test_utilblocks.py index 5ac6233f..a8849d6a 100644 --- a/tests/test_utilblocks.py +++ b/tests/test_utilblocks.py @@ -8,7 +8,7 @@ from datasets import Dataset, Features, Value # First Party -from src.instructlab.sdg.utilblocks import ( +from src.instructlab.sdg import ( DuplicateColumnsBlock, FlattenColumnsBlock, RenameColumnsBlock, From 3b2bc7da7bd9a9ee5df842d38b4e776dc3b8e66d Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Mon, 25 Nov 2024 15:58:13 -0500 Subject: [PATCH 04/12] Move model prompts to jinja templates and messages This brings in changes to move our model prompt templates to Jinja templates and the HuggingFace messages formats, used by their chat templates. Signed-off-by: Ben Browning --- src/instructlab/sdg/__init__.py | 5 +- src/instructlab/sdg/blocks/llmblock.py | 76 +++++++++++++++----------- src/instructlab/sdg/generate_data.py | 8 +-- src/instructlab/sdg/prompts.py | 20 +++++++ src/instructlab/sdg/registry.py | 39 +------------ tests/test_generate_data.py | 3 +- tests/test_llmblock.py | 6 +- tests/test_pipeline.py | 3 +- tests/test_registry.py | 2 + 9 files changed, 79 insertions(+), 83 deletions(-) create mode 100644 src/instructlab/sdg/prompts.py diff --git a/src/instructlab/sdg/__init__.py b/src/instructlab/sdg/__init__.py index 6ed09cb4..e5dceeae 100644 --- a/src/instructlab/sdg/__init__.py +++ b/src/instructlab/sdg/__init__.py @@ -20,8 +20,10 @@ "SamplePopulatorBlock", "SelectorBlock", "SetToMajorityValueBlock", - "SIMPLE_PIPELINES_PACKAGE", + "MODEL_FAMILY_MERLINITE", + "MODEL_FAMILY_MIXTRAL", "FULL_PIPELINES_PACKAGE", + "SIMPLE_PIPELINES_PACKAGE", "generate_data", ) @@ -48,5 +50,6 @@ PipelineConfigParserError, PipelineContext, ) +from .prompts import MODEL_FAMILY_MERLINITE, MODEL_FAMILY_MIXTRAL from .utils import GenerateException from .utils.taxonomy import TaxonomyReadingException diff --git a/src/instructlab/sdg/blocks/llmblock.py b/src/instructlab/sdg/blocks/llmblock.py index d84b08fa..a7744249 100644 --- a/src/instructlab/sdg/blocks/llmblock.py +++ b/src/instructlab/sdg/blocks/llmblock.py @@ -13,30 +13,13 @@ import openai # Local -from ..registry import BlockRegistry +from ..registry import BlockRegistry, PromptRegistry from .block import Block logger = logging.getLogger(__name__) DEFAULT_MAX_NUM_TOKENS = 4096 -MODEL_FAMILY_MIXTRAL = "mixtral" -MODEL_FAMILY_MERLINITE = "merlinite" - -_MODEL_PROMPT_MIXTRAL = " [INST] {prompt} [/INST]" -_MODEL_PROMPT_MERLINITE = "'<|system|>\nYou are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.\n<|user|>\n{prompt}\n<|assistant|>\n'" - -_MODEL_PROMPTS = { - MODEL_FAMILY_MIXTRAL: _MODEL_PROMPT_MIXTRAL, - MODEL_FAMILY_MERLINITE: _MODEL_PROMPT_MERLINITE, -} - - -def _get_model_prompt(model_family): - if model_family not in _MODEL_PROMPTS: - raise ValueError(f"Unknown model family: {model_family}") - return _MODEL_PROMPTS[model_family] - def server_supports_batched(client, model_id: str) -> bool: supported = getattr(client, "server_supports_batched", None) @@ -71,13 +54,13 @@ def server_supports_batched(client, model_id: str) -> bool: logger.info(f"LLM server supports batched inputs: {client.server_supports_batched}") return supported + def template_from_struct_and_config(struct, config): # replace None with empty strings - filtered_config = { - k: (v if v is not None else "") for k, v in config.items() - } + filtered_config = {k: (v if v is not None else "") for k, v in config.items()} return Template(struct.format(**filtered_config), undefined=StrictUndefined) + # This is part of the public API. @BlockRegistry.register("LLMBlock") # pylint: disable=dangerous-default-value @@ -100,7 +83,9 @@ def __init__( self.prompt_struct = ( """{system}\n{introduction}\n{principles}\n{examples}\n{generation}""" ) - self.prompt_template = template_from_struct_and_config(self.prompt_struct, self.block_config) + self.prompt_template = template_from_struct_and_config( + self.prompt_struct, self.block_config + ) self.model_prompt = model_prompt self.output_cols = output_cols self.batch_params = batch_kwargs @@ -170,15 +155,30 @@ def _parse(self, generated_string) -> dict: # 2. Non-empty string - the pipeline has specified a custom model prompt # 3. Empty string - the pipeline has specified that no model prompt is needed def _format_prompt(self, sample: Dict) -> str: - prompt = self.prompt_template.render(sample).strip() + prompt_templated_str = self.prompt_template.render(sample).strip() + wrap_in_messages_format = True model_prompt = None if self.model_prompt is None: - model_prompt = _get_model_prompt(self.ctx.model_family) + model_prompt = PromptRegistry.get_template(self.ctx.model_family) elif self.model_prompt: - model_prompt = self.model_prompt + model_prompt = Template(self.model_prompt) + else: + # Our model prompt is an empty string, which we'll render + # verbatim without wrapping in the messages format + model_prompt = PromptRegistry.get_template("blank") + wrap_in_messages_format = False + + if wrap_in_messages_format: + messages = [{"role": "user", "content": prompt_templated_str}] + else: + messages = prompt_templated_str - return prompt if model_prompt is None else model_prompt.format(prompt=prompt) + return model_prompt.render( + messages=messages, + prompt=prompt_templated_str, + add_generation_prompt=True, + ).strip() def _gen_kwargs(self, max_num_token_override, gen_kwargs, **defaults): gen_kwargs = {**defaults, **gen_kwargs} @@ -290,9 +290,13 @@ def __init__( parser_kwargs={}, batch_kwargs={}, ) -> None: - assert config_paths, "ConditionalLLMBlock config_paths requires at least one entry" + assert ( + config_paths + ), "ConditionalLLMBlock config_paths requires at least one entry" for config_path in config_paths: - assert len(config_path) == 2, "ConditionalLLMBlock config_paths each entry should be a list of config path and selector column names" + assert ( + len(config_path) == 2 + ), "ConditionalLLMBlock config_paths each entry should be a list of config path and selector column names" super().__init__( ctx, pipe, @@ -307,10 +311,14 @@ def __init__( self.selector_column_name = selector_column_name self.prompt_template = {} if len(config_paths) == 1 and config_paths[0][1] == "All": - self.prompt_template = template_from_struct_and_config(self.prompt_struct, self.block_config) + self.prompt_template = template_from_struct_and_config( + self.prompt_struct, self.block_config + ) else: for config, config_key in config_paths: - self.prompt_template[config_key] = template_from_struct_and_config(self.prompt_struct, self._load_config(config)) + self.prompt_template[config_key] = template_from_struct_and_config( + self.prompt_struct, self._load_config(config) + ) def _format_prompt(self, sample: Dict) -> str: if isinstance(self.prompt_template, dict): @@ -325,11 +333,15 @@ def _format_prompt(self, sample: Dict) -> str: def _validate(self, prompt_template: str, input_dict: Dict[str, Any]) -> bool: if isinstance(prompt_template, dict): if not self.selector_column_name in input_dict: - logger.error(f"ConditionalLLMBlock {self.block_name} missing key: {self.selector_column_name}") + logger.error( + f"ConditionalLLMBlock {self.block_name} missing key: {self.selector_column_name}" + ) return False config_key = input_dict[self.selector_column_name] if not config_key in prompt_template: - logger.error(f"ConditionalLLMBlock {self.block_name} selector key {config_key} not found in block config") + logger.error( + f"ConditionalLLMBlock {self.block_name} selector key {config_key} not found in block config" + ) return False prompt_template = prompt_template[config_key] return super()._validate(prompt_template, input_dict) diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py index 70a09efd..ae12056c 100644 --- a/src/instructlab/sdg/generate_data.py +++ b/src/instructlab/sdg/generate_data.py @@ -18,20 +18,16 @@ import yaml # First Party -# pylint: disable=ungrouped-imports +from instructlab.sdg.blocks.llmblock import DEFAULT_MAX_NUM_TOKENS from instructlab.sdg.datamixing import DataMixer, _get_question_hack, _get_response_hack from instructlab.sdg.eval_data import generate_eval_task_data, mmlubench_pipe_init -from instructlab.sdg.blocks.llmblock import ( - DEFAULT_MAX_NUM_TOKENS, - MODEL_FAMILY_MERLINITE, - MODEL_FAMILY_MIXTRAL, -) from instructlab.sdg.pipeline import ( FULL_PIPELINES_PACKAGE, SIMPLE_PIPELINES_PACKAGE, Pipeline, PipelineContext, ) +from instructlab.sdg.prompts import MODEL_FAMILY_MERLINITE, MODEL_FAMILY_MIXTRAL from instructlab.sdg.utils import GenerateException, models from instructlab.sdg.utils.json import jldump from instructlab.sdg.utils.taxonomy import ( diff --git a/src/instructlab/sdg/prompts.py b/src/instructlab/sdg/prompts.py new file mode 100644 index 00000000..cc446e7c --- /dev/null +++ b/src/instructlab/sdg/prompts.py @@ -0,0 +1,20 @@ +# Local +from .registry import PromptRegistry + +MODEL_FAMILY_MIXTRAL = "mixtral" +MODEL_FAMILY_MERLINITE = "merlinite" + + +@PromptRegistry.register("blank") +def blank_chat_template(): + return """{{ messages }}""" + + +@PromptRegistry.register(MODEL_FAMILY_MERLINITE) +def merlinite_chat_template(): + return """{% for message in messages %}{% if message['role'] == 'pretraining' %}{{ '<|pretrain|>' + message['content'] + '<|endoftext|>' + '<|/pretrain|>' }}{% elif message['role'] == 'system' %}{{ '<|system|>' + '\n' + message['content'] + '\n' }}{% elif message['role'] == 'user' %}{{ '<|user|>' + '\n' + message['content'] + '\n' }}{% elif message['role'] == 'assistant' %}{{ '<|assistant|>' + '\n' + message['content'] + '<|endoftext|>' + ('' if loop.last else '\n') }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<|assistant|>' + '\n' }}{% endif %}{% endfor %}""" + + +@PromptRegistry.register(MODEL_FAMILY_MIXTRAL) +def mixtral_chat_template(): + return """{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.first and system_message is defined %}\n {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n {%- else %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + ''}}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n{%- endfor %}\n""" diff --git a/src/instructlab/sdg/registry.py b/src/instructlab/sdg/registry.py index 128446f8..69597b69 100644 --- a/src/instructlab/sdg/registry.py +++ b/src/instructlab/sdg/registry.py @@ -1,5 +1,5 @@ # Standard -from typing import Dict, List, Union +from typing import Dict import logging # Third Party @@ -70,7 +70,7 @@ def get_template(cls, name: str) -> Template: :return: The Jinja2 template instance. """ if name not in cls._registry: - raise KeyError(f"Template '{name}' not found.") + raise KeyError(f"Prompt template '{name}' not found.") logger.debug(f"Retrieving prompt template '{name}'") return cls._registry[name] @@ -83,38 +83,3 @@ def get_registry(cls): """ logger.debug("Fetching the block registry map.") return cls._registry - - @classmethod - def render_template( - cls, - name: str, - messages: Union[str, List[Dict[str, str]]], - add_generation_prompt: bool = True, - ) -> str: - """Render the template with the provided messages or query. - - :param name: Name of the template to render. - :param messages: Either a single query string or a list of messages (each as a dict with 'role' and 'content'). - :param add_generation_prompt: Whether to add a generation prompt at the end. - :return: The rendered prompt as a string. - """ - - # Special handling for "blank" template - if name == "blank": - if not isinstance(messages, str): - raise ValueError( - "The 'blank' template can only be used with a single query string, not a list of messages." - ) - return messages # Return the query as-is without templating - - # Get the template - template = cls.get_template(name) - - # If `messages` is a string, wrap it in a list with a default user role - if isinstance(messages, str): - messages = [{"role": "user", "content": messages}] - - # Render the template with the `messages` list - return template.render( - messages=messages, add_generation_prompt=add_generation_prompt - ) diff --git a/tests/test_generate_data.py b/tests/test_generate_data.py index 23b3dcb9..c5415636 100644 --- a/tests/test_generate_data.py +++ b/tests/test_generate_data.py @@ -20,9 +20,8 @@ import yaml # First Party +from instructlab.sdg import LLMBlock, PipelineContext from instructlab.sdg.generate_data import _context_init, _sdg_init, generate_data -from instructlab.sdg import LLMBlock -from instructlab.sdg import PipelineContext TEST_SYS_PROMPT = "I am, Red Hat® Instruct Model based on Granite 7B, an AI language model developed by Red Hat and IBM Research, based on the Granite-7b-base language model. My primary function is to be a chat assistant." diff --git a/tests/test_llmblock.py b/tests/test_llmblock.py index 73b96ea0..d22d8e39 100644 --- a/tests/test_llmblock.py +++ b/tests/test_llmblock.py @@ -35,7 +35,7 @@ def setUp(self): def test_model_prompt_empty_string(self, mock_load_config): mock_load_config.return_value = self.config_return_value - # Ensure that if an empty model_prompt is not specified, no model prompt is used. + # Ensure that if an empty model_prompt is specified, no model prompt is used. block = LLMBlock( ctx=self.mock_ctx, pipe=self.mock_pipe, @@ -79,13 +79,13 @@ def test_model_prompt_custom(self, mock_load_config): block_name="test_block", config_path="", output_cols=[], - model_prompt="FOO {prompt} BAR", + model_prompt="FOO {{prompt}} BAR", ) prompt = block._format_prompt(self.dataset[1]) self.assertEqual( prompt, "FOO pear\nintroduction\nprinciples\nexamples\ngeneration BAR", - "model_prompt should be a non-empty string when set to None", + "custom model_prompt was not used when explicitly set", ) @patch("src.instructlab.sdg.blocks.block.Block._load_config") diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index a07df5ef..01848bbf 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -14,8 +14,7 @@ import pytest # First Party -from instructlab.sdg import Block -from instructlab.sdg import Pipeline, PipelineBlockError +from instructlab.sdg import Block, Pipeline, PipelineBlockError ## Helpers ## diff --git a/tests/test_registry.py b/tests/test_registry.py index ddad7001..0f197a6d 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -3,10 +3,12 @@ # First Party from src.instructlab.sdg.registry import BlockRegistry + def test_block_registry(): @BlockRegistry.register("TestFooClass") class TestFooClass: pass + registry = BlockRegistry.get_registry() assert registry is not None assert registry["TestFooClass"] is TestFooClass From 12f450c48d062e51977be1ccfc73699960e76842 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Mon, 25 Nov 2024 20:08:16 -0500 Subject: [PATCH 05/12] Stub in LLMLogProbBlock and LLMMessagesBlock These new blocks don't do anything yet, but stubbing them into the codebase and will continue working on figuring out what they're supposed to do and wiring things up with tests. Co-authored-by: shivchander Co-authored-by: abhi1092 Signed-off-by: Ben Browning --- src/instructlab/sdg/__init__.py | 9 +- src/instructlab/sdg/blocks/llmblock.py | 154 +++++++++++++++++++++++++ tests/test_llmblock.py | 72 +++++++++++- 3 files changed, 231 insertions(+), 4 deletions(-) diff --git a/src/instructlab/sdg/__init__.py b/src/instructlab/sdg/__init__.py index e5dceeae..6970a6ac 100644 --- a/src/instructlab/sdg/__init__.py +++ b/src/instructlab/sdg/__init__.py @@ -12,6 +12,8 @@ "FlattenColumnsBlock", "GenerateException", "LLMBlock", + "LLMLogProbBlock", + "LLMMessagesBlock", "Pipeline", "PipelineBlockError", "PipelineConfigParserError", @@ -30,7 +32,12 @@ # Local from .blocks.block import Block from .blocks.filterblock import FilterByValueBlock, FilterByValueBlockError -from .blocks.llmblock import ConditionalLLMBlock, LLMBlock +from .blocks.llmblock import ( + ConditionalLLMBlock, + LLMBlock, + LLMLogProbBlock, + LLMMessagesBlock, +) from .blocks.utilblocks import ( CombineColumnsBlock, DuplicateColumnsBlock, diff --git a/src/instructlab/sdg/blocks/llmblock.py b/src/instructlab/sdg/blocks/llmblock.py index a7744249..3aeb9db4 100644 --- a/src/instructlab/sdg/blocks/llmblock.py +++ b/src/instructlab/sdg/blocks/llmblock.py @@ -345,3 +345,157 @@ def _validate(self, prompt_template: str, input_dict: Dict[str, Any]) -> bool: return False prompt_template = prompt_template[config_key] return super()._validate(prompt_template, input_dict) + + +# This is part of the public API. +@BlockRegistry.register("LLMLogProbBlock") +class LLMLogProbBlock(LLMBlock): + def __init__( + self, + ctx, + pipe, + block_name, + config_path, + output_cols, + model_prompt=None, + gen_kwargs={}, + parser_kwargs={}, + batch_kwargs={}, + ) -> None: + super().__init__( + ctx, + pipe, + block_name, + config_path, + output_cols, + model_prompt=model_prompt, + gen_kwargs=gen_kwargs, + parser_kwargs=parser_kwargs, + batch_kwargs=batch_kwargs, + ) + + # def _generate_logprobs(self, samples, **gen_kwargs): + # prompts = [ + # self.model_prompt.format(prompt=self._format_prompt(sample)) + # for sample in samples + # ] + # generate_args = {**self.defaults, **gen_kwargs} + + # # verify if logprobs is mentioned in the generate_args, if not add it and return top10 logprobs + # if "logprobs" not in generate_args: + # generate_args["logprobs"] = 10 + + # if self.server_supports_batched: + # response = self.client.completions.create(prompt=prompts, **generate_args) + # return [choice.logprobs.top_logprobs for choice in response.choices] + + # n = gen_kwargs.get("n", 1) + # results = [] + # for prompt in prompts: + # for _ in range(n): + # response = self.client.completions.create( + # prompt=prompt, **generate_args + # ) + # results.append(response.choices[0].logprobs.top_logprobs) + # return results + + # def _parse(self, generations: List[List[Dict]]) -> List[List[str]]: + # # override the parse method to convert the generations to json string + # # convert the generations to json string to save as dataset + # # this is because the dataset can only store key value pairs which are consistent + # return [[json.dumps(item) for item in sublist] for sublist in generations] + + # def generate(self, samples: Dataset, **gen_kwargs) -> Dataset: + # """ + # Generate the output from the block. This method should first validate the input data, + # then generate the output, and finally parse the generated output before returning it. + + # :return: The parsed output after generation. + # """ + # num_samples = self.block_config.get("num_samples", None) + # logger.debug("Generating outputs for {} samples".format(len(samples))) + + # if (num_samples is not None) and ("num_samples" not in samples.column_names): + # samples = samples.add_column("num_samples", [num_samples] * len(samples)) + + # # validate each sample + # # Log errors and remove invalid samples + # valid_samples = [] + + # for sample in samples: + # if self._validate(self.prompt_template, sample): + # valid_samples.append(sample) + # else: + # logger.warning( + # f"Sample failed validation: {sample}" + # ) # Log details of the failed sample + + # samples = valid_samples + + # if len(samples) == 0: + # logger.warning( + # "No valid samples to generate outputs for, returning empty dataset" + # ) + # return Dataset.from_list([]) + + # # generate the output + + # outputs = self._generate_logprobs(samples, **gen_kwargs) + # logger.debug("Generated outputs: %s", outputs) + + # output_dataset = Dataset.from_list(samples) + # output_dataset = output_dataset.add_column( + # self.output_cols[0], + # self._parse(outputs), # pylint: disable=no-value-for-parameter + # ) + + # return output_dataset + + +# This is part of the public API. +@BlockRegistry.register("LLMMessagesBlock") +class LLMMessagesBlock(LLMBlock): + def __init__( + self, + ctx, + pipe, + block_name, + config_path, + output_cols, + model_prompt=None, + gen_kwargs={}, + parser_kwargs={}, + batch_kwargs={}, + ) -> None: + super().__init__( + ctx, + pipe, + block_name, + config_path, + output_cols, + model_prompt=model_prompt, + gen_kwargs=gen_kwargs, + parser_kwargs=parser_kwargs, + batch_kwargs=batch_kwargs, + ) + + # def _generate(self, samples) -> list: + # generate_args = {**self.defaults, **gen_kwargs} + + # if "n" in generate_args and generate_args.get("temperature", 0) <= 0: + # generate_args["temperature"] = 0.7 + # logger.warning( + # "Temperature should be greater than 0 for n > 1, setting temperature to 0.7" + # ) + + # messages = samples[self.input_col] + + # results = [] + # n = gen_kwargs.get("n", 1) + # for message in messages: + # responses = self.client.chat.completions.create(messages=message, **generate_args) + # if n > 1: + # results.append([choice.message.content for choice in responses.choices]) + # else: + # results.append(responses.choices[0].message.content) + # return results diff --git a/tests/test_llmblock.py b/tests/test_llmblock.py index d22d8e39..d52dec58 100644 --- a/tests/test_llmblock.py +++ b/tests/test_llmblock.py @@ -10,7 +10,12 @@ from openai import InternalServerError, NotFoundError # First Party -from src.instructlab.sdg import ConditionalLLMBlock, LLMBlock +from src.instructlab.sdg import ( + ConditionalLLMBlock, + LLMBlock, + LLMLogProbBlock, + LLMMessagesBlock, +) from src.instructlab.sdg.blocks.llmblock import server_supports_batched @@ -88,6 +93,7 @@ def test_model_prompt_custom(self, mock_load_config): "custom model_prompt was not used when explicitly set", ) + @patch("src.instructlab.sdg.blocks.block.Block._load_config") class TestLLMBlockOtherFunctions(unittest.TestCase): def setUp(self): @@ -138,6 +144,7 @@ def test_validate(self, mock_load_config): assert not block._validate(block.prompt_template, {}) assert block._validate(block.prompt_template, {"var1": "foo", "var2": "bar"}) + class TestLLMBlockBatching(unittest.TestCase): def setUp(self): self.mock_ctx = MagicMock() @@ -187,6 +194,7 @@ def test_server_supports_batched_vllm(self): supports_batched = server_supports_batched(self.mock_ctx.client, "my-model") assert supports_batched + @patch("src.instructlab.sdg.blocks.block.Block._load_config") class TestConditionalLLMBlock(unittest.TestCase): def setUp(self): @@ -213,5 +221,63 @@ def test_validate(self, mock_load_config): ) assert not block._validate(block.prompt_template, {}) - assert not block._validate(block.prompt_template, {"selector": "_B_", "var1": "foo", "var2": "bar"}) - assert block._validate(block.prompt_template, {"selector": "_A_", "var1": "foo", "var2": "bar"}) + assert not block._validate( + block.prompt_template, {"selector": "_B_", "var1": "foo", "var2": "bar"} + ) + assert block._validate( + block.prompt_template, {"selector": "_A_", "var1": "foo", "var2": "bar"} + ) + + +@patch("src.instructlab.sdg.blocks.block.Block._load_config") +class TestLLMLogProbBlock(unittest.TestCase): + def setUp(self): + self.mock_ctx = MagicMock() + self.mock_ctx.model_family = "mixtral" + self.mock_ctx.model_id = "test_model" + self.mock_pipe = MagicMock() + self.config_return_value = { + "system": "{{fruit}}", + "introduction": "introduction", + "principles": "principles", + "examples": "examples", + "generation": "generation", + } + + def test_constructor_works(self, mock_load_config): + mock_load_config.return_value = self.config_return_value + block = LLMLogProbBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + config_path="", + output_cols=[], + ) + assert block is not None + + +@patch("src.instructlab.sdg.blocks.block.Block._load_config") +class TestLLMMessagesBlock(unittest.TestCase): + def setUp(self): + self.mock_ctx = MagicMock() + self.mock_ctx.model_family = "mixtral" + self.mock_ctx.model_id = "test_model" + self.mock_pipe = MagicMock() + self.config_return_value = { + "system": "{{fruit}}", + "introduction": "introduction", + "principles": "principles", + "examples": "examples", + "generation": "generation", + } + + def test_constructor_works(self, mock_load_config): + mock_load_config.return_value = self.config_return_value + block = LLMMessagesBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + config_path="", + output_cols=[], + ) + assert block is not None From 79d68fb7a17ab5a51dc6397ee25bc5d1c6a05df7 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Mon, 25 Nov 2024 20:31:37 -0500 Subject: [PATCH 06/12] Add CHANGELOG.md entries for research reconciliation Signed-off-by: Ben Browning --- .spellcheck-en-custom.txt | 7 +++++++ CHANGELOG.md | 24 ++++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/.spellcheck-en-custom.txt b/.spellcheck-en-custom.txt index c06f114b..64aeae66 100644 --- a/.spellcheck-en-custom.txt +++ b/.spellcheck-en-custom.txt @@ -3,7 +3,9 @@ # SPDX-License-Identifier: Apache-2.0 Backport backported +CLI codebase +config configs Dataset dataset @@ -17,6 +19,8 @@ FIXME freeform ICL icl +ie +Jinja JSON Langchain's LLM @@ -39,7 +43,9 @@ Splitter subdirectory subfolder Tatsu +templating Tesseract +TODO tokenizer tokenizers unchunked @@ -47,3 +53,4 @@ upsampled UUID vLLM yaml +yamls diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c7f5a12..9b95d287 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,27 @@ +## Unreleased 0.7.x + +### Features + +#### Custom Blocks and Teacher Models via BlockRegistry and PromptRegistry + +Advanced users are now able to supply custom Pipeline `Block` implementations by registering new blocks with the `BlockRegistry`. It's also possible to register new chat templates for custom teacher models using the new `PromptRegistry`. + +See the `tests/testdata/custom_block.py` and `tests/testdata/custom_block_pipeline.yaml` files in this repository for an example of how to create custom blocks and use them from your own pipeline config yamls. + +See the `tests/testdata/custom_prompt.py` file in this repository for an example how to register custom chat templates used when formatting prompts. + +### Breaking Changes + +#### Pipeline configs and Prompt templates switched to Jinja + +All of our [Pipeline config yamls](src/instructlab/sdg/pipelines) and [prompt template files](src/instructlab/sdg/configs) have moved to [Jinja templates](https://pypi.org/project/Jinja2/) instead of Python string `format()` calls. This brings more expressiveness into our templating language - especially for prompt templates - but does mean any variable substitutions need to be updated from single brackets to double brackets - ie `{document}` becomes `{{document}}`. This only impacts you if you were using custom pipeline config yaml files or custom prompt templates in your config blocks. + +#### ImportBlock removed from Pipeline blocks + +Any users that were specifying custom pipeline configs (instead of using the default `full` or `simple` shipped by us) and also using the `ImportBlock` will now need to rewrite their pipelines to no longer use that block. We do not anticipate that anyone was actually using this block, but please reach out if you were so we can capture your needs in a future release. + +### Fixes + ## v0.6.2 ### Fixes From d8451934740c3c4ea2bfaade4170963cfeb7cdb7 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Tue, 26 Nov 2024 12:08:07 -0500 Subject: [PATCH 07/12] Update all knowledge configs to use jinja templates In addition to updating the knowledge configs to use jinja templates, this adds additional tests to validate that we are using jinja templates instead of python string formats. That also required tightening up our usage of jinja `Template` to always preferred `StrictUndefined` behavior everywhere we use it. Signed-off-by: Ben Browning --- src/instructlab/sdg/blocks/llmblock.py | 2 +- .../knowledge/evaluate_faithfulness.yaml | 4 +- .../configs/knowledge/evaluate_question.yaml | 2 +- .../configs/knowledge/evaluate_relevancy.yaml | 4 +- .../generate_questions_responses.yaml | 20 +++--- .../sdg/configs/knowledge/mcq_generation.yaml | 2 +- .../configs/knowledge/simple_generate_qa.yaml | 20 +++--- .../sdg/configs/knowledge/spellcheck.yaml | 2 +- src/instructlab/sdg/registry.py | 4 +- tests/test_llmblock.py | 62 +++++++++++++++++++ 10 files changed, 92 insertions(+), 30 deletions(-) diff --git a/src/instructlab/sdg/blocks/llmblock.py b/src/instructlab/sdg/blocks/llmblock.py index 3aeb9db4..692629b6 100644 --- a/src/instructlab/sdg/blocks/llmblock.py +++ b/src/instructlab/sdg/blocks/llmblock.py @@ -162,7 +162,7 @@ def _format_prompt(self, sample: Dict) -> str: if self.model_prompt is None: model_prompt = PromptRegistry.get_template(self.ctx.model_family) elif self.model_prompt: - model_prompt = Template(self.model_prompt) + model_prompt = Template(self.model_prompt, undefined=StrictUndefined) else: # Our model prompt is an empty string, which we'll render # verbatim without wrapping in the messages format diff --git a/src/instructlab/sdg/configs/knowledge/evaluate_faithfulness.yaml b/src/instructlab/sdg/configs/knowledge/evaluate_faithfulness.yaml index 828bb31f..fa39d916 100644 --- a/src/instructlab/sdg/configs/knowledge/evaluate_faithfulness.yaml +++ b/src/instructlab/sdg/configs/knowledge/evaluate_faithfulness.yaml @@ -58,10 +58,10 @@ generation: | * Return the answer between [Start of Answer] and [End of Answer] tags. [Start of Context] - {document} + {{document}} [End of Context] [Start of Response] - {response} + {{response}} [End of Response] start_tags: ["[Start of Explanation]", "[Start of Answer]"] diff --git a/src/instructlab/sdg/configs/knowledge/evaluate_question.yaml b/src/instructlab/sdg/configs/knowledge/evaluate_question.yaml index 3505e23c..88b473c7 100644 --- a/src/instructlab/sdg/configs/knowledge/evaluate_question.yaml +++ b/src/instructlab/sdg/configs/knowledge/evaluate_question.yaml @@ -32,7 +32,7 @@ examples: "" generation: | [Start of Question] - {question} + {{question}} [End of Question] start_tags: ["[Start of Explanation]", "[Start of Rating]"] diff --git a/src/instructlab/sdg/configs/knowledge/evaluate_relevancy.yaml b/src/instructlab/sdg/configs/knowledge/evaluate_relevancy.yaml index 81f1d666..cf944005 100644 --- a/src/instructlab/sdg/configs/knowledge/evaluate_relevancy.yaml +++ b/src/instructlab/sdg/configs/knowledge/evaluate_relevancy.yaml @@ -72,11 +72,11 @@ generation: | Begin your response by providing the feedback followed by the score. Be as objective as possible. [Start of Question] - {question} + {{question}} [End of Question] [Start of Response] - {response} + {{response}} [End of Response] * Return the feedback within the [Start of Feedback] and [End of Feedback] tags. diff --git a/src/instructlab/sdg/configs/knowledge/generate_questions_responses.yaml b/src/instructlab/sdg/configs/knowledge/generate_questions_responses.yaml index a2decd76..d7dd61e6 100644 --- a/src/instructlab/sdg/configs/knowledge/generate_questions_responses.yaml +++ b/src/instructlab/sdg/configs/knowledge/generate_questions_responses.yaml @@ -1,6 +1,6 @@ system: You are a very knowledgeable AI Assistant that will faithfully assist the user with their task. -introduction: Develop a series of educational question and answer pairs from a chapter in a {domain} textbook. +introduction: Develop a series of educational question and answer pairs from a chapter in a {{domain}} textbook. principles: | The questions should: @@ -28,29 +28,29 @@ examples: | Here are some examples of questions: [Document] - {icl_document} + {{icl_document}} [QUESTION] - {icl_query_1} + {{icl_query_1}} [ANSWER] - {icl_response_1} + {{icl_response_1}} [END] [QUESTION] - {icl_query_2} + {{icl_query_2}} [ANSWER] - {icl_response_2} + {{icl_response_2}} [END] [QUESTION] - {icl_query_3} + {{icl_query_3}} [ANSWER] - {icl_response_3} + {{icl_response_3}} [END] generation: | Here is the document: [DOCUMENT] - {document_outline} - {document} + {{document_outline}} + {{document}} diff --git a/src/instructlab/sdg/configs/knowledge/mcq_generation.yaml b/src/instructlab/sdg/configs/knowledge/mcq_generation.yaml index 091001c5..1d8c6aca 100644 --- a/src/instructlab/sdg/configs/knowledge/mcq_generation.yaml +++ b/src/instructlab/sdg/configs/knowledge/mcq_generation.yaml @@ -75,7 +75,7 @@ generation: | Here is the document: [Start of Document] - {document} + {{document}} [End of Document] start_tags: ["[Start of Question]", "[Start of Answer]"] diff --git a/src/instructlab/sdg/configs/knowledge/simple_generate_qa.yaml b/src/instructlab/sdg/configs/knowledge/simple_generate_qa.yaml index 784902b4..63310db5 100644 --- a/src/instructlab/sdg/configs/knowledge/simple_generate_qa.yaml +++ b/src/instructlab/sdg/configs/knowledge/simple_generate_qa.yaml @@ -1,6 +1,6 @@ system: You are a very knowledgeable AI Assistant that will faithfully assist the user with their task. -introduction: Develop a series of educational question and answer pairs from a chapter in a {domain} textbook. +introduction: Develop a series of educational question and answer pairs from a chapter in a {{domain}} textbook. principles: | Here are the requirements: @@ -14,23 +14,23 @@ principles: | examples: | Here is a sample section of the document as an example: - {icl_document} + {{icl_document}} Here are some examples to help you understand the type of questions that are asked for this document: - {icl_query_1} - {icl_response_1} + {{icl_query_1}} + {{icl_response_1}} - {icl_query_2} - {icl_response_2} + {{icl_query_2}} + {{icl_response_2}} - {icl_query_3} - {icl_response_3} + {{icl_query_3}} + {{icl_response_3}} Here is the document: - {document_outline} - {document} + {{document_outline}} + {{document}} generation: | Provide a single question and answer pair based on the document. diff --git a/src/instructlab/sdg/configs/knowledge/spellcheck.yaml b/src/instructlab/sdg/configs/knowledge/spellcheck.yaml index daf1dafa..c6e7b614 100644 --- a/src/instructlab/sdg/configs/knowledge/spellcheck.yaml +++ b/src/instructlab/sdg/configs/knowledge/spellcheck.yaml @@ -11,7 +11,7 @@ examples: "" generation: | Document: - {document} + {{document}} start_tags: [""] end_tags: [""] diff --git a/src/instructlab/sdg/registry.py b/src/instructlab/sdg/registry.py index 69597b69..f321a394 100644 --- a/src/instructlab/sdg/registry.py +++ b/src/instructlab/sdg/registry.py @@ -3,7 +3,7 @@ import logging # Third Party -from jinja2 import Template +from jinja2 import StrictUndefined, Template logger = logging.getLogger(__name__) @@ -56,7 +56,7 @@ def register(cls, name: str): def decorator(func): template_str = func() - cls._registry[name] = Template(template_str) + cls._registry[name] = Template(template_str, undefined=StrictUndefined) logger.debug(f"Registered prompt template '{name}'") return func diff --git a/tests/test_llmblock.py b/tests/test_llmblock.py index d52dec58..160dfc85 100644 --- a/tests/test_llmblock.py +++ b/tests/test_llmblock.py @@ -1,7 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # Standard +from importlib import resources from unittest.mock import MagicMock, patch +import os import unittest # Third Party @@ -94,6 +96,66 @@ def test_model_prompt_custom(self, mock_load_config): ) +class TestLLMBlockWithRealConfigs(unittest.TestCase): + def setUp(self): + self.mock_ctx = MagicMock() + self.mock_ctx.model_family = "mixtral" + self.mock_ctx.model_id = "test_model" + self.mock_pipe = MagicMock() + + def test_knowledge_configs_with_invalid_sample(self): + configs = [ + "evaluate_faithfulness.yaml", + "evaluate_question.yaml", + "evaluate_relevancy.yaml", + "generate_questions_responses.yaml", + "mcq_generation.yaml", + "spellcheck.yaml", + "simple_generate_qa.yaml", + ] + for config in configs: + config_yaml = os.path.join( + resources.files("instructlab.sdg.configs.knowledge"), config + ) + block = LLMBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name=config, + config_path=config_yaml, + output_cols=[], + ) + sample = {"foo": "bar"} + assert not block._validate( + block.prompt_template, sample + ), f"knowledge config {config} validated even though it was given a sample with none of the expected fields" + + def test_simple_generate_qa_with_valid_sample(self): + config_yaml = os.path.join( + resources.files("instructlab.sdg.configs.knowledge"), + "simple_generate_qa.yaml", + ) + block = LLMBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + config_path=config_yaml, + output_cols=[], + ) + sample = { + "domain": "domain goes here", + "document": "document goes here", + "document_outline": "document outline goes here", + "icl_document": "context goes here", + "icl_query_1": "query 1 goes here", + "icl_response_1": "response 1 goes here", + "icl_query_2": "query 2 goes here", + "icl_response_2": "response 2 goes here", + "icl_query_3": "query 3 goes here", + "icl_response_3": "response 3 goes here", + } + assert block._validate(block.prompt_template, sample) + + @patch("src.instructlab.sdg.blocks.block.Block._load_config") class TestLLMBlockOtherFunctions(unittest.TestCase): def setUp(self): From c7d638224d5a2dcc4014d3b8021c2c14eda12f8b Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Tue, 26 Nov 2024 12:55:07 -0500 Subject: [PATCH 08/12] Update all skill configs to use jinja templates This also makes the test running `Block._validate` on all our shipped configs a bit more generic so that it can cover all skill and knowledge yaml files without having to keep a separate list of config files to test. Signed-off-by: Ben Browning --- .../sdg/configs/skills/contexts.yaml | 4 +- .../skills/evaluate_freeform_pair.yaml | 4 +- .../skills/evaluate_freeform_questions.yaml | 6 +-- .../skills/evaluate_grounded_pair.yaml | 6 +-- .../skills/evaluate_grounded_questions.yaml | 6 +-- .../configs/skills/freeform_questions.yaml | 6 +-- .../configs/skills/freeform_responses.yaml | 6 +-- .../configs/skills/grounded_questions.yaml | 10 ++--- .../configs/skills/grounded_responses.yaml | 10 ++--- .../skills/simple_generate_qa_freeform.yaml | 6 +-- .../skills/simple_generate_qa_grounded.yaml | 8 ++-- src/instructlab/sdg/registry.py | 3 -- tests/test_llmblock.py | 43 ++++++++----------- 13 files changed, 54 insertions(+), 64 deletions(-) diff --git a/src/instructlab/sdg/configs/skills/contexts.yaml b/src/instructlab/sdg/configs/skills/contexts.yaml index be257c80..813680e1 100644 --- a/src/instructlab/sdg/configs/skills/contexts.yaml +++ b/src/instructlab/sdg/configs/skills/contexts.yaml @@ -1,6 +1,6 @@ system: You are a very knowledgeable AI Assistant that will faithfully assist the user with their task. -introduction: You are asked to come up with a diverse context for - {task_description}. +introduction: You are asked to come up with a diverse context for - {{task_description}}. principles: | Please follow these guiding principles when generating responses: * Use proper grammar and punctuation. @@ -11,7 +11,7 @@ principles: | examples: | To better assist you with this task, here is an example of a context: [Start of Context] - {seed_context} + {{seed_context}} [End of Context] generation: | diff --git a/src/instructlab/sdg/configs/skills/evaluate_freeform_pair.yaml b/src/instructlab/sdg/configs/skills/evaluate_freeform_pair.yaml index 1dd3e38d..298181f4 100644 --- a/src/instructlab/sdg/configs/skills/evaluate_freeform_pair.yaml +++ b/src/instructlab/sdg/configs/skills/evaluate_freeform_pair.yaml @@ -29,11 +29,11 @@ examples: | generation: | Here's the question and the answer you need to evaluate: [Start of Question] - {question} + {{question}} [End of Question] [Start of Answer] - {response} + {{response}} [End of Answer] Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the answer on a scale of 1 to 3 as mentioned above. diff --git a/src/instructlab/sdg/configs/skills/evaluate_freeform_questions.yaml b/src/instructlab/sdg/configs/skills/evaluate_freeform_questions.yaml index 50bafcb8..850fe006 100644 --- a/src/instructlab/sdg/configs/skills/evaluate_freeform_questions.yaml +++ b/src/instructlab/sdg/configs/skills/evaluate_freeform_questions.yaml @@ -9,7 +9,7 @@ principles: | * The questions should be in English. * The questions should be 1 to 2 sentences long and should be properly formatted. * The question should not be offensive, abusive, or harmful. It should be safe and respectful. - * The question should be relevant to the task given - {task_description}. + * The question should be relevant to the task given - {{task_description}}. If the question meets the above requirements, please rate it 1. If not, please rate it 0. @@ -32,10 +32,10 @@ examples: | generation: | Here's the question you need to evaluate: - Task Description: {task_description} + Task Description: {{task_description}} [Start of Question] - {question} + {{question}} [End of Question] Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the question on a scale of 0 or 1 as mentioned above. Strictly follow the format below: diff --git a/src/instructlab/sdg/configs/skills/evaluate_grounded_pair.yaml b/src/instructlab/sdg/configs/skills/evaluate_grounded_pair.yaml index 15132c88..993c52f2 100644 --- a/src/instructlab/sdg/configs/skills/evaluate_grounded_pair.yaml +++ b/src/instructlab/sdg/configs/skills/evaluate_grounded_pair.yaml @@ -35,15 +35,15 @@ generation: | Here's the context, question and the answer you need to evaluate: [Start of Context] - {context} + {{context}} [End of Context] [Start of Question] - {question} + {{question}} [End of Question] [Start of Answer] - {response} + {{response}} [End of Answer] * Return the evaluation between [Start of Evaluation] and [End of Evaluation] tags. diff --git a/src/instructlab/sdg/configs/skills/evaluate_grounded_questions.yaml b/src/instructlab/sdg/configs/skills/evaluate_grounded_questions.yaml index 6999987f..fd761806 100644 --- a/src/instructlab/sdg/configs/skills/evaluate_grounded_questions.yaml +++ b/src/instructlab/sdg/configs/skills/evaluate_grounded_questions.yaml @@ -9,7 +9,7 @@ principles: | * The questions should be in English. * The questions should be 1 to 2 sentences long and should be properly formatted. * The question should not be offensive, abusive, or harmful. It should be safe and respectful. - * The question should be relevant to the task given - {task_description}. + * The question should be relevant to the task given - {{task_description}}. * Most importantly all the questions should be grounded in the context provided and should be answerable solely based on the provided context. If the question meets the above requirements, please rate it 1. If not, please rate it 0. @@ -37,10 +37,10 @@ generation: | Here's the context and question you need to evaluate. Return the evaluation between [Start of Evaluation] and [End of Evaluation] tags. [Start of Context] - {context} + {{context}} [End of Context] [Start of Question] - {question} + {{question}} [End of Question] Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the question on a scale of 0 or 1 as mentioned above. diff --git a/src/instructlab/sdg/configs/skills/freeform_questions.yaml b/src/instructlab/sdg/configs/skills/freeform_questions.yaml index f3d1ed90..de0e64d4 100644 --- a/src/instructlab/sdg/configs/skills/freeform_questions.yaml +++ b/src/instructlab/sdg/configs/skills/freeform_questions.yaml @@ -1,7 +1,7 @@ system: You are a very knowledgeable AI Assistant that will faithfully assist the user with their task. introduction: | - You are asked to come up with a set of {num_samples} diverse questions - {task_description}. + You are asked to come up with a set of {{num_samples}} diverse questions - {{task_description}}. principles: | Please follow these guiding principles when generating responses: @@ -19,11 +19,11 @@ examples: | To better assist you with this task, here is an example: [Start of Question] - {seed_question} + {{seed_question}} [End of Question] generation: | - Now generate {num_samples} such questions, remember to follow the principles mentioned above and use the same format as the examples. Remember to use the same style and format as the example above. Return each question between [Start of Question] and [End of Question] tags. + Now generate {{num_samples}} such questions, remember to follow the principles mentioned above and use the same format as the examples. Remember to use the same style and format as the example above. Return each question between [Start of Question] and [End of Question] tags. start_tags: ["[Start of Question]"] end_tags: ["[End of Question]"] diff --git a/src/instructlab/sdg/configs/skills/freeform_responses.yaml b/src/instructlab/sdg/configs/skills/freeform_responses.yaml index cf7ff177..b68c4790 100644 --- a/src/instructlab/sdg/configs/skills/freeform_responses.yaml +++ b/src/instructlab/sdg/configs/skills/freeform_responses.yaml @@ -13,18 +13,18 @@ principles: | examples: | To better assist you with this task, here is an example: [Start of Question] - {seed_question} + {{seed_question}} [End of Question] [Start of Response] - {seed_response} + {{seed_response}} [End of Response] generation: | Now generate a response to the following prompt. Remember to use the same style and format as the example above. [Start of Question] - {question} + {{question}} [End of Question] Return the response between [Start of Response] and [End of Response] tags. diff --git a/src/instructlab/sdg/configs/skills/grounded_questions.yaml b/src/instructlab/sdg/configs/skills/grounded_questions.yaml index 904523c9..979d35f3 100644 --- a/src/instructlab/sdg/configs/skills/grounded_questions.yaml +++ b/src/instructlab/sdg/configs/skills/grounded_questions.yaml @@ -1,7 +1,7 @@ system: You are a very knowledgeable AI Assistant that will faithfully assist the user with their task. introduction: | - You are asked to come up with a set of {num_samples} diverse questions - {task_description}. + You are asked to come up with a set of {{num_samples}} diverse questions - {{task_description}}. principles: | Please follow these guiding principles when generating responses: @@ -21,17 +21,17 @@ examples: | To better assist you with this task, here is an example: [Start of Context] - {seed_context} + {{seed_context}} [End of Context] [Start of Question] - {seed_question} + {{seed_question}} [End of Question] generation: | - Now generate {num_samples} such questions, remember to follow the principles mentioned above and use the same format as the examples. Remember to use the same style and format as the example above. Do not return any contexts or answers, only the questions. Return each question between [Start of Question] and [End of Question] tags. + Now generate {{num_samples}} such questions, remember to follow the principles mentioned above and use the same format as the examples. Remember to use the same style and format as the example above. Do not return any contexts or answers, only the questions. Return each question between [Start of Question] and [End of Question] tags. [Start of Context] - {context} + {{context}} [End of Context] start_tags: ["[Start of Question]"] diff --git a/src/instructlab/sdg/configs/skills/grounded_responses.yaml b/src/instructlab/sdg/configs/skills/grounded_responses.yaml index bacd5c10..d8161591 100644 --- a/src/instructlab/sdg/configs/skills/grounded_responses.yaml +++ b/src/instructlab/sdg/configs/skills/grounded_responses.yaml @@ -14,15 +14,15 @@ examples: | To better assist you with this task, here is an example: [Start of Context] - {seed_context} + {{seed_context}} [End of Context] [Start of Question] - {seed_question} + {{seed_question}} [End of Question] [Start of Response] - {seed_response} + {{seed_response}} [End of Response] generation: | @@ -30,10 +30,10 @@ generation: | Return the response between [Start of Response] and [End of Response] tags. [Start of Context] - {context} + {{context}} [End of Context] [Start of Question] - {question} + {{question}} [End of Question] Return the response between [Start of Response] and [End of Response] tags. diff --git a/src/instructlab/sdg/configs/skills/simple_generate_qa_freeform.yaml b/src/instructlab/sdg/configs/skills/simple_generate_qa_freeform.yaml index 5144b22c..4a4217af 100644 --- a/src/instructlab/sdg/configs/skills/simple_generate_qa_freeform.yaml +++ b/src/instructlab/sdg/configs/skills/simple_generate_qa_freeform.yaml @@ -13,12 +13,12 @@ principles: | 7. The output should be an appropriate response to the input and the instruction. Long outputs are preferable. examples: | - The task is {task_description}. + The task is {{task_description}}. Here is an example to help you understand the type of questions that are asked for: - {seed_question} - {seed_response} + {{seed_question}} + {{seed_response}} generation: | Provide a single question and answer pair based on the examples. diff --git a/src/instructlab/sdg/configs/skills/simple_generate_qa_grounded.yaml b/src/instructlab/sdg/configs/skills/simple_generate_qa_grounded.yaml index 54588f91..b0423218 100644 --- a/src/instructlab/sdg/configs/skills/simple_generate_qa_grounded.yaml +++ b/src/instructlab/sdg/configs/skills/simple_generate_qa_grounded.yaml @@ -13,16 +13,16 @@ principles: | 7. The output should be an appropriate response to the input and the instruction. Long outputs are preferable. examples: | - The task is {task_description}. + The task is {{task_description}}. Here is some context for the example question: - {seed_context} + {{seed_context}} Here is an example to help you understand the type of questions that are asked for: - {seed_question} - {seed_response} + {{seed_question}} + {{seed_response}} generation: | Provide a single question and answer pair based on the example. diff --git a/src/instructlab/sdg/registry.py b/src/instructlab/sdg/registry.py index f321a394..52fa0e5e 100644 --- a/src/instructlab/sdg/registry.py +++ b/src/instructlab/sdg/registry.py @@ -37,7 +37,6 @@ def get_registry(cls): :return: Dictionary of registered block names and classes. """ - logger.debug("Fetching the block registry map.") return cls._registry @@ -71,7 +70,6 @@ def get_template(cls, name: str) -> Template: """ if name not in cls._registry: raise KeyError(f"Prompt template '{name}' not found.") - logger.debug(f"Retrieving prompt template '{name}'") return cls._registry[name] @classmethod @@ -81,5 +79,4 @@ def get_registry(cls): :return: Dictionary of registered block names and classes. """ - logger.debug("Fetching the block registry map.") return cls._registry diff --git a/tests/test_llmblock.py b/tests/test_llmblock.py index 160dfc85..15519fb7 100644 --- a/tests/test_llmblock.py +++ b/tests/test_llmblock.py @@ -103,31 +103,24 @@ def setUp(self): self.mock_ctx.model_id = "test_model" self.mock_pipe = MagicMock() - def test_knowledge_configs_with_invalid_sample(self): - configs = [ - "evaluate_faithfulness.yaml", - "evaluate_question.yaml", - "evaluate_relevancy.yaml", - "generate_questions_responses.yaml", - "mcq_generation.yaml", - "spellcheck.yaml", - "simple_generate_qa.yaml", - ] - for config in configs: - config_yaml = os.path.join( - resources.files("instructlab.sdg.configs.knowledge"), config - ) - block = LLMBlock( - ctx=self.mock_ctx, - pipe=self.mock_pipe, - block_name=config, - config_path=config_yaml, - output_cols=[], - ) - sample = {"foo": "bar"} - assert not block._validate( - block.prompt_template, sample - ), f"knowledge config {config} validated even though it was given a sample with none of the expected fields" + def test_configs_with_invalid_sample(self): + for config_type in ["knowledge", "skills"]: + for config_yaml in resources.files( + f"instructlab.sdg.configs.{config_type}" + ).iterdir(): + if config_yaml.suffix != ".yaml": + continue + block = LLMBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name=config_yaml.stem, + config_path=config_yaml, + output_cols=[], + ) + sample = {"foo": "bar"} + assert not block._validate( + block.prompt_template, sample + ), f"{config_type} config {config_yaml.name} validated even though it was given a sample with none of the expected fields" def test_simple_generate_qa_with_valid_sample(self): config_yaml = os.path.join( From aa96e7a45d96f95cb5c244688ad0a6fc03239680 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Tue, 26 Nov 2024 15:50:21 -0500 Subject: [PATCH 09/12] Actually use BlockRegistry to lookup blocks This gets rid of the hardcoded block types dict and drives everything off the BlockRegistry. This means I also added a functional test showing how users can create and register their own Block implementations and use those in their pipeline config files - see `tests/testdata/custom_block_pipeline.yaml` and `tests/testdata/custom_block.py` for those examples. Signed-off-by: Ben Browning --- src/instructlab/sdg/__init__.py | 3 +++ src/instructlab/sdg/blocks/filterblock.py | 2 ++ src/instructlab/sdg/blocks/utilblocks.py | 8 ++++++ src/instructlab/sdg/pipeline.py | 22 ++++----------- tests/functional/test_custom_block.py | 11 ++++++++ tests/test_pipeline.py | 6 +++-- tests/testdata/custom_block.py | 33 +++++++++++++++++++++++ tests/testdata/custom_block_pipeline.yaml | 5 ++++ 8 files changed, 71 insertions(+), 19 deletions(-) create mode 100644 tests/functional/test_custom_block.py create mode 100644 tests/testdata/custom_block.py create mode 100644 tests/testdata/custom_block_pipeline.yaml diff --git a/src/instructlab/sdg/__init__.py b/src/instructlab/sdg/__init__.py index 6970a6ac..9edda477 100644 --- a/src/instructlab/sdg/__init__.py +++ b/src/instructlab/sdg/__init__.py @@ -3,6 +3,7 @@ # NOTE: This package imports Torch and other heavy packages. __all__ = ( "Block", + "BlockRegistry", "CombineColumnsBlock", "ConditionalLLMBlock", "DuplicateColumnsBlock", @@ -18,6 +19,7 @@ "PipelineBlockError", "PipelineConfigParserError", "PipelineContext", + "PromptRegistry", "RenameColumnsBlock", "SamplePopulatorBlock", "SelectorBlock", @@ -58,5 +60,6 @@ PipelineContext, ) from .prompts import MODEL_FAMILY_MERLINITE, MODEL_FAMILY_MIXTRAL +from .registry import BlockRegistry, PromptRegistry from .utils import GenerateException from .utils.taxonomy import TaxonomyReadingException diff --git a/src/instructlab/sdg/blocks/filterblock.py b/src/instructlab/sdg/blocks/filterblock.py index 07372aef..181c973e 100644 --- a/src/instructlab/sdg/blocks/filterblock.py +++ b/src/instructlab/sdg/blocks/filterblock.py @@ -8,6 +8,7 @@ from datasets import Dataset # Local +from ..registry import BlockRegistry from .block import Block logger = logging.getLogger(__name__) @@ -86,6 +87,7 @@ def convert_column(sample): # This is part of the public API. +@BlockRegistry.register("FilterByValueBlock") class FilterByValueBlock(Block): def __init__( self, diff --git a/src/instructlab/sdg/blocks/utilblocks.py b/src/instructlab/sdg/blocks/utilblocks.py index 00d1e32a..a03e6cb7 100644 --- a/src/instructlab/sdg/blocks/utilblocks.py +++ b/src/instructlab/sdg/blocks/utilblocks.py @@ -10,12 +10,14 @@ from instructlab.sdg.utils import pandas # Local +from ..registry import BlockRegistry from .block import Block logger = logging.getLogger(__name__) # This is part of the public API. +@BlockRegistry.register("SamplePopulatorBlock") class SamplePopulatorBlock(Block): def __init__( self, ctx, pipe, block_name, config_paths, column_name, post_fix="" @@ -46,6 +48,7 @@ def generate(self, samples) -> Dataset: # This is part of the public API. +@BlockRegistry.register("SelectorBlock") class SelectorBlock(Block): def __init__( self, ctx, pipe, block_name, choice_map, choice_col, output_col @@ -75,6 +78,7 @@ def generate(self, samples: Dataset) -> Dataset: # This is part of the public API. +@BlockRegistry.register("CombineColumnsBlock") class CombineColumnsBlock(Block): def __init__( self, ctx, pipe, block_name, columns, output_col, separator="\n\n" @@ -103,6 +107,7 @@ def generate(self, samples: Dataset) -> Dataset: ) +@BlockRegistry.register("FlattenColumnsBlock") class FlattenColumnsBlock(Block): """Melt/transform a data from a wide format to a long format see pandas.melt for a description @@ -132,6 +137,7 @@ def generate(self, samples: Dataset) -> Dataset: return pandas.dataset_from_pandas_dataframe(flatten_df) +@BlockRegistry.register("DuplicateColumnsBlock") class DuplicateColumnsBlock(Block): def __init__(self, ctx, pipe, block_name: str, columns_map: dict) -> None: """Create duplicate of columns specified in column map. @@ -150,6 +156,7 @@ def generate(self, samples: Dataset): return samples +@BlockRegistry.register("RenameColumnsBlock") class RenameColumnsBlock(Block): def __init__(self, ctx, pipe, block_name: str, columns_map: dict) -> None: """Rename dataset columns. @@ -165,6 +172,7 @@ def generate(self, samples: Dataset): return samples +@BlockRegistry.register("SetToMajorityValueBlock") class SetToMajorityValueBlock(Block): """Set the value of the specified column to the most common value (the mode) diff --git a/src/instructlab/sdg/pipeline.py b/src/instructlab/sdg/pipeline.py index 66161155..59613a8e 100644 --- a/src/instructlab/sdg/pipeline.py +++ b/src/instructlab/sdg/pipeline.py @@ -19,8 +19,9 @@ from instructlab.sdg.utils import pandas # Local -from .blocks import filterblock, llmblock, utilblocks +from .blocks import llmblock from .blocks.block import Block +from .registry import BlockRegistry logger = logging.getLogger(__name__) @@ -255,24 +256,11 @@ def _get_batch_indices(self, batch_index: int, total_size: int) -> Iterable[int] ) -_block_types = { - "CombineColumnsBlock": utilblocks.CombineColumnsBlock, - "ConditionalLLMBlock": llmblock.ConditionalLLMBlock, - "DuplicateColumnsBlock": utilblocks.DuplicateColumnsBlock, - "FilterByValueBlock": filterblock.FilterByValueBlock, - "FlattenColumnsBlock": utilblocks.FlattenColumnsBlock, - "LLMBlock": llmblock.LLMBlock, - "RenameColumnsBlock": utilblocks.RenameColumnsBlock, - "SamplePopulatorBlock": utilblocks.SamplePopulatorBlock, - "SelectorBlock": utilblocks.SelectorBlock, - "SetToMajorityValueBlock": utilblocks.SetToMajorityValueBlock, -} - - def _lookup_block_type(block_type): - if not block_type in _block_types: + block_types = BlockRegistry.get_registry() + if not block_type in block_types: raise PipelineConfigParserError(f"Unknown block type {block_type}") - return _block_types[block_type] + return block_types[block_type] _PIPELINE_CONFIG_PARSER_MAJOR = 1 diff --git a/tests/functional/test_custom_block.py b/tests/functional/test_custom_block.py new file mode 100644 index 00000000..65456a15 --- /dev/null +++ b/tests/functional/test_custom_block.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Standard +import pathlib +import subprocess +import sys + + +def test_custom_block(testdata_path: pathlib.Path): + script = testdata_path.joinpath("custom_block.py") + subprocess.check_call([sys.executable, str(script)], text=True) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 01848bbf..6c801845 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -21,9 +21,11 @@ @contextmanager def block_types(block_types_dict): + get_registry_mock = mock.MagicMock() + get_registry_mock.return_value = block_types_dict with mock.patch( - "instructlab.sdg.pipeline._block_types", - block_types_dict, + "instructlab.sdg.registry.BlockRegistry.get_registry", + get_registry_mock, ): yield diff --git a/tests/testdata/custom_block.py b/tests/testdata/custom_block.py new file mode 100644 index 00000000..409d4ebc --- /dev/null +++ b/tests/testdata/custom_block.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Standard +import pathlib + +# Third Party +from datasets import Dataset + +# First Party +from instructlab.sdg import Block, BlockRegistry, Pipeline, PipelineContext + + +@BlockRegistry.register("EchoBlock") +class EchoBlock(Block): + def generate(self, samples: Dataset): + return samples + + +pipeline_context = PipelineContext(None, "mixtral", "my_model", 5) +pipeline_yaml = pathlib.Path(__file__).parent.joinpath("custom_block_pipeline.yaml") +pipeline = Pipeline.from_file(pipeline_context, pipeline_yaml) +input_ds = Dataset.from_list( + [ + { + "fruit": "apple", + "color": "red", + } + ] +) +output_ds = pipeline.generate(input_ds) +assert len(output_ds) == 1 +assert output_ds[0]["fruit"] == "apple" +assert output_ds[0]["color"] == "red" diff --git a/tests/testdata/custom_block_pipeline.yaml b/tests/testdata/custom_block_pipeline.yaml new file mode 100644 index 00000000..0f4c447e --- /dev/null +++ b/tests/testdata/custom_block_pipeline.yaml @@ -0,0 +1,5 @@ +version: "1.0" +blocks: + - name: echo + type: EchoBlock + config: {} From 5e9608cfb8d905e4e6ff0712a49ffc59be41f673 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Wed, 27 Nov 2024 11:54:45 -0500 Subject: [PATCH 10/12] Use PromptRegistry for all chat templates This removes the mapping of model families in SDG itself between granite, mixtral, mistral, merlinite, etc. Instead, it uses the PromptRegistry to lookup chat templates based on the model family given. And, if no model family is given, it still falls back to doing a best-guess based on the file path of the selected teacher model. A simple test was added to demonstrate how to register and use custom chat templates for generating prompts via the PromptRegistry. Signed-off-by: Ben Browning --- src/instructlab/sdg/__init__.py | 3 --- src/instructlab/sdg/blocks/llmblock.py | 9 +++------ src/instructlab/sdg/generate_data.py | 6 +----- src/instructlab/sdg/prompts.py | 14 ++++++++----- src/instructlab/sdg/registry.py | 8 +++++--- src/instructlab/sdg/utils/models.py | 21 +++++++++---------- tests/functional/test_custom_block.py | 5 +++++ tests/test_models.py | 22 ++++++++++++++++++-- tests/testdata/custom_prompt.py | 28 ++++++++++++++++++++++++++ 9 files changed, 80 insertions(+), 36 deletions(-) create mode 100644 tests/testdata/custom_prompt.py diff --git a/src/instructlab/sdg/__init__.py b/src/instructlab/sdg/__init__.py index 9edda477..4ba5bf0d 100644 --- a/src/instructlab/sdg/__init__.py +++ b/src/instructlab/sdg/__init__.py @@ -24,8 +24,6 @@ "SamplePopulatorBlock", "SelectorBlock", "SetToMajorityValueBlock", - "MODEL_FAMILY_MERLINITE", - "MODEL_FAMILY_MIXTRAL", "FULL_PIPELINES_PACKAGE", "SIMPLE_PIPELINES_PACKAGE", "generate_data", @@ -59,7 +57,6 @@ PipelineConfigParserError, PipelineContext, ) -from .prompts import MODEL_FAMILY_MERLINITE, MODEL_FAMILY_MIXTRAL from .registry import BlockRegistry, PromptRegistry from .utils import GenerateException from .utils.taxonomy import TaxonomyReadingException diff --git a/src/instructlab/sdg/blocks/llmblock.py b/src/instructlab/sdg/blocks/llmblock.py index 692629b6..8a5832de 100644 --- a/src/instructlab/sdg/blocks/llmblock.py +++ b/src/instructlab/sdg/blocks/llmblock.py @@ -13,6 +13,8 @@ import openai # Local +# Import prompts to register default chat templates +from .. import prompts as default_prompts # pylint: disable=unused-import from ..registry import BlockRegistry, PromptRegistry from .block import Block @@ -156,7 +158,6 @@ def _parse(self, generated_string) -> dict: # 3. Empty string - the pipeline has specified that no model prompt is needed def _format_prompt(self, sample: Dict) -> str: prompt_templated_str = self.prompt_template.render(sample).strip() - wrap_in_messages_format = True model_prompt = None if self.model_prompt is None: @@ -167,12 +168,8 @@ def _format_prompt(self, sample: Dict) -> str: # Our model prompt is an empty string, which we'll render # verbatim without wrapping in the messages format model_prompt = PromptRegistry.get_template("blank") - wrap_in_messages_format = False - if wrap_in_messages_format: - messages = [{"role": "user", "content": prompt_templated_str}] - else: - messages = prompt_templated_str + messages = [{"role": "user", "content": prompt_templated_str}] return model_prompt.render( messages=messages, diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py index ae12056c..74e279b9 100644 --- a/src/instructlab/sdg/generate_data.py +++ b/src/instructlab/sdg/generate_data.py @@ -27,7 +27,6 @@ Pipeline, PipelineContext, ) -from instructlab.sdg.prompts import MODEL_FAMILY_MERLINITE, MODEL_FAMILY_MIXTRAL from instructlab.sdg.utils import GenerateException, models from instructlab.sdg.utils.json import jldump from instructlab.sdg.utils.taxonomy import ( @@ -355,10 +354,7 @@ def generate_data( logger.debug(f"Generating to: {os.path.join(output_dir, output_file_test)}") - if models.get_model_family(model_family, model_name) == "mixtral": - model_family = MODEL_FAMILY_MIXTRAL - else: - model_family = MODEL_FAMILY_MERLINITE + model_family = models.get_model_family(model_family, model_name) ctx = _context_init( client, diff --git a/src/instructlab/sdg/prompts.py b/src/instructlab/sdg/prompts.py index cc446e7c..77d35a53 100644 --- a/src/instructlab/sdg/prompts.py +++ b/src/instructlab/sdg/prompts.py @@ -1,20 +1,24 @@ # Local from .registry import PromptRegistry -MODEL_FAMILY_MIXTRAL = "mixtral" -MODEL_FAMILY_MERLINITE = "merlinite" +# {{ prompt }} gives us the config's raw prompt string, not wrapped in +# any messages format + +# {{ messages }} gives us the config's prompt in messages format, +# where the config's prompt becomes the content value of a user role +# message @PromptRegistry.register("blank") def blank_chat_template(): - return """{{ messages }}""" + return """{{ prompt }}""" -@PromptRegistry.register(MODEL_FAMILY_MERLINITE) +@PromptRegistry.register("merlinite", "granite") def merlinite_chat_template(): return """{% for message in messages %}{% if message['role'] == 'pretraining' %}{{ '<|pretrain|>' + message['content'] + '<|endoftext|>' + '<|/pretrain|>' }}{% elif message['role'] == 'system' %}{{ '<|system|>' + '\n' + message['content'] + '\n' }}{% elif message['role'] == 'user' %}{{ '<|user|>' + '\n' + message['content'] + '\n' }}{% elif message['role'] == 'assistant' %}{{ '<|assistant|>' + '\n' + message['content'] + '<|endoftext|>' + ('' if loop.last else '\n') }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<|assistant|>' + '\n' }}{% endif %}{% endfor %}""" -@PromptRegistry.register(MODEL_FAMILY_MIXTRAL) +@PromptRegistry.register("mixtral", "mistral") def mixtral_chat_template(): return """{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.first and system_message is defined %}\n {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n {%- else %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + ''}}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n{%- endfor %}\n""" diff --git a/src/instructlab/sdg/registry.py b/src/instructlab/sdg/registry.py index 52fa0e5e..2db22611 100644 --- a/src/instructlab/sdg/registry.py +++ b/src/instructlab/sdg/registry.py @@ -46,7 +46,7 @@ class PromptRegistry: _registry: Dict[str, Template] = {} @classmethod - def register(cls, name: str): + def register(cls, *names: str): """Decorator to register a Jinja2 template function by name. :param name: Name of the template to register. @@ -55,8 +55,10 @@ def register(cls, name: str): def decorator(func): template_str = func() - cls._registry[name] = Template(template_str, undefined=StrictUndefined) - logger.debug(f"Registered prompt template '{name}'") + template = Template(template_str, undefined=StrictUndefined) + for name in names: + cls._registry[name] = template + logger.debug(f"Registered prompt template '{name}'") return func return decorator diff --git a/src/instructlab/sdg/utils/models.py b/src/instructlab/sdg/utils/models.py index a2736988..da01421b 100644 --- a/src/instructlab/sdg/utils/models.py +++ b/src/instructlab/sdg/utils/models.py @@ -5,25 +5,22 @@ import re # First Party +from instructlab.sdg.registry import PromptRegistry from instructlab.sdg.utils import GenerateException # When otherwise unknown, ilab uses this as the default family DEFAULT_MODEL_FAMILY = "merlinite" -# Model families understood by ilab -MODEL_FAMILIES = set(("merlinite", "mixtral")) - -# Map model names to their family -MODEL_FAMILY_MAPPINGS = {"granite": "merlinite", "mistral": "mixtral"} - def get_model_family(model_family, model_path): - model_family_retrieved = MODEL_FAMILY_MAPPINGS.get(model_family, model_family) - if model_family_retrieved and model_family_retrieved.lower() not in MODEL_FAMILIES: - raise GenerateException("Unknown model family: %s" % model_family_retrieved) + registry = PromptRegistry.get_registry() + + # A model_family was given, so use it explicitly + if model_family: + if model_family not in registry: + raise GenerateException("Unknown model family: %s" % model_family) + return model_family # Try to guess the model family based on the model's filename guess = re.match(r"^\w*", os.path.basename(model_path)).group(0).lower() - guess = MODEL_FAMILY_MAPPINGS.get(guess, guess) - - return guess if guess in MODEL_FAMILIES else DEFAULT_MODEL_FAMILY + return guess if guess in registry else DEFAULT_MODEL_FAMILY diff --git a/tests/functional/test_custom_block.py b/tests/functional/test_custom_block.py index 65456a15..01bcdb07 100644 --- a/tests/functional/test_custom_block.py +++ b/tests/functional/test_custom_block.py @@ -9,3 +9,8 @@ def test_custom_block(testdata_path: pathlib.Path): script = testdata_path.joinpath("custom_block.py") subprocess.check_call([sys.executable, str(script)], text=True) + + +def test_custom_prompt(testdata_path: pathlib.Path): + script = testdata_path.joinpath("custom_prompt.py") + subprocess.check_call([sys.executable, str(script)], text=True) diff --git a/tests/test_models.py b/tests/test_models.py index 27cf4d68..e3755553 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -13,7 +13,7 @@ class TestModels: def test_granite_model_family(self): assert ( models.get_model_family("granite", "./models/granite-7b-lab-Q4_K_M.gguf") - == "merlinite" + == "granite" ) def test_merlinite_model_family(self): @@ -32,12 +32,30 @@ def test_mixtral_model_family(self): == "mixtral" ) + def test_mistral_model_family(self): + assert ( + models.get_model_family( + "mistral", "./models/mistral-7b-instruct-v0.2.Q4_K_M.gguf" + ) + == "mistral" + ) + def test_default_model_family(self): + assert ( + models.get_model_family(None, "./models/foo-8x7b-instruct-v0.1.Q4_K_M.gguf") + == "merlinite" + ) + assert ( + models.get_model_family("", "./models/foo-8x7b-instruct-v0.1.Q4_K_M.gguf") + == "merlinite" + ) + + def test_model_family_overrides(self): assert ( models.get_model_family( "mixtral", "./models/foo-8x7b-instruct-v0.1.Q4_K_M.gguf" ) - == "merlinite" + == "mixtral" ) def test_unknown_model_family(self): diff --git a/tests/testdata/custom_prompt.py b/tests/testdata/custom_prompt.py new file mode 100644 index 00000000..ea5f06c4 --- /dev/null +++ b/tests/testdata/custom_prompt.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: Apache-2.0 + +# First Party +from instructlab.sdg import PromptRegistry + + +# Register our custom chat template under the "custom_model_family" +# model family +@PromptRegistry.register("custom_model_family") +def custom_chat_template(): + return """{% for message in messages %}{% if message['role'] == 'system' %}{{ '<>' + '\n' + message['content'] + '\n' }}{% elif message['role'] == 'user' %}{{ '<>' + '\n' + message['content'] + '\n' }}{% elif message['role'] == 'assistant' %}{{ '<>' + '\n' + message['content'] + ('' if loop.last else '\n') }}{% endif %}{% endfor %}""" + + +# Lookup the chat template for "custom_model_family" model family +template = PromptRegistry.get_template("custom_model_family") +assert template is not None + +# Ensure the template found is our custom one +prompt = template.render( + messages=[ + {"role": "system", "content": "system prompt goes here"}, + {"role": "user", "content": "user content goes here"}, + ] +) +expected_prompt = ( + "<>\nsystem prompt goes here\n<>\nuser content goes here\n" +) +assert prompt == expected_prompt From 42fb71d0495bdca65dcc454ec1eef632ff78dd11 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Wed, 27 Nov 2024 12:23:30 -0500 Subject: [PATCH 11/12] Port the new IterBlock from research code This adds a new Block type - `IterBlock` - that calls another block N times for a set of given input samples. Every iteration through the loop, the samples returned from the child block's `generate` call get added to the list of samples produced from this block. So, if you use an `IterBlock` to call an `LLMBlock` 5 times, you'll get 5 samples generated (and 5 calls to the LLM) for every sample in the source dataset. The output dataset will contain all 5 generated samples resulting from each 1 input sample. Co-authored-by: shivchander Co-authored-by: abhi1092 Signed-off-by: Ben Browning --- src/instructlab/sdg/__init__.py | 2 + src/instructlab/sdg/blocks/block.py | 27 ++++++----- src/instructlab/sdg/blocks/iterblock.py | 57 +++++++++++++++++++++++ src/instructlab/sdg/blocks/llmblock.py | 14 ++++-- src/instructlab/sdg/registry.py | 43 +++++++++++++---- tests/test_iterblock.py | 61 +++++++++++++++++++++++++ 6 files changed, 178 insertions(+), 26 deletions(-) create mode 100644 src/instructlab/sdg/blocks/iterblock.py create mode 100644 tests/test_iterblock.py diff --git a/src/instructlab/sdg/__init__.py b/src/instructlab/sdg/__init__.py index 4ba5bf0d..6f091424 100644 --- a/src/instructlab/sdg/__init__.py +++ b/src/instructlab/sdg/__init__.py @@ -12,6 +12,7 @@ "FilterByValueBlockError", "FlattenColumnsBlock", "GenerateException", + "IterBlock", "LLMBlock", "LLMLogProbBlock", "LLMMessagesBlock", @@ -32,6 +33,7 @@ # Local from .blocks.block import Block from .blocks.filterblock import FilterByValueBlock, FilterByValueBlockError +from .blocks.iterblock import IterBlock from .blocks.llmblock import ( ConditionalLLMBlock, LLMBlock, diff --git a/src/instructlab/sdg/blocks/block.py b/src/instructlab/sdg/blocks/block.py index 1e0a39b2..3bceca65 100644 --- a/src/instructlab/sdg/blocks/block.py +++ b/src/instructlab/sdg/blocks/block.py @@ -2,7 +2,6 @@ # Standard from abc import ABC -from collections import ChainMap from typing import Any, Dict, Union import logging import os.path @@ -30,20 +29,23 @@ def _validate(self, prompt_template: Template, input_dict: Dict[str, Any]) -> bo Validate the input data for this block. This method validates whether all required variables in the Jinja template are provided in the input_dict. - :param prompt_template: The Jinja2 template object. - :param input_dict: A dictionary of input values to check against the template. - :return: True if the input data is valid (i.e., no missing variables), False otherwise. - """ + Args: + prompt_template (Template): The Jinja2 template object. + input_dict (Dict[str, Any]): A dictionary of input values to check against + the template. - class Default(dict): - def __missing__(self, key: str) -> None: - raise KeyError(key) + Returns: + True if the input data is valid (i.e., no missing variables), False otherwise. + """ try: # Try rendering the template with the input_dict - prompt_template.render(ChainMap(input_dict, Default())) + prompt_template.render(input_dict) return True except UndefinedError as e: + # Jinja throws an UndefinedError for any undefnined template variables, + # assuming the prompt_template was created using StrictUndefined. This + # is the case for anything using PromptRegistry.template_from_string. logger.error(f"Missing key: {e}") return False @@ -54,8 +56,11 @@ def _load_config(self, config_path: str) -> Union[Dict[str, Any], None]: If the supplied configuration file is a relative path, it is assumed to be part of this Python package. - :param config_path: The path to the configuration file. - :return: The loaded configuration. + Args: + config_path (str): The path to the configuration file. + + Returns: + The loaded configuration. """ if not os.path.isabs(config_path): config_path = os.path.join( diff --git a/src/instructlab/sdg/blocks/iterblock.py b/src/instructlab/sdg/blocks/iterblock.py new file mode 100644 index 00000000..b75b2ab4 --- /dev/null +++ b/src/instructlab/sdg/blocks/iterblock.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Standard +import logging + +# Third Party +from datasets import Dataset + +# Local +from ..pipeline import _lookup_block_type +from ..registry import BlockRegistry +from .block import Block + +logger = logging.getLogger(__name__) + + +# This is part of the public API. +@BlockRegistry.register("IterBlock") +class IterBlock(Block): + """ + Call another block multiple times for a single set of input + samples, concatening the results of each iteration's call to that + other block in the final returned output. + + Args: + num_iters: The number of times to iterate over the block + block_type: The type of the other block to call (ie LLMBlock) + block_config: Any necessary configuration that will get passed to the + other block to properly configure it. + + Returns: + A Dataset containing all output samples from each iteration + """ + + def __init__( + self, + ctx, + pipe, + block_name, + num_iters, + block_type, + **block_config, + ) -> None: + super().__init__(ctx, pipe, block_name) + self.num_iters = num_iters + block_type = _lookup_block_type(block_type) + self.block = block_type(ctx, pipe, block_name, **block_config) + + def generate(self, samples: Dataset) -> Dataset: + generated_samples = [] + num_iters = self.num_iters + + for _ in range(num_iters): + batch_generated = self.block.generate(samples) + generated_samples.extend(batch_generated) + + return Dataset.from_list(generated_samples) diff --git a/src/instructlab/sdg/blocks/llmblock.py b/src/instructlab/sdg/blocks/llmblock.py index 8a5832de..92e7402e 100644 --- a/src/instructlab/sdg/blocks/llmblock.py +++ b/src/instructlab/sdg/blocks/llmblock.py @@ -7,7 +7,6 @@ # Third Party from datasets import Dataset -from jinja2 import StrictUndefined, Template from tqdm import tqdm import httpx import openai @@ -60,7 +59,7 @@ def server_supports_batched(client, model_id: str) -> bool: def template_from_struct_and_config(struct, config): # replace None with empty strings filtered_config = {k: (v if v is not None else "") for k, v in config.items()} - return Template(struct.format(**filtered_config), undefined=StrictUndefined) + return PromptRegistry.template_from_string(struct.format(**filtered_config)) # This is part of the public API. @@ -163,7 +162,7 @@ def _format_prompt(self, sample: Dict) -> str: if self.model_prompt is None: model_prompt = PromptRegistry.get_template(self.ctx.model_family) elif self.model_prompt: - model_prompt = Template(self.model_prompt, undefined=StrictUndefined) + model_prompt = PromptRegistry.template_from_string(self.model_prompt) else: # Our model prompt is an empty string, which we'll render # verbatim without wrapping in the messages format @@ -222,7 +221,11 @@ def generate(self, samples: Dataset) -> Dataset: Generate the output from the block. This method should first validate the input data, then generate the output, and finally parse the generated output before returning it. - :return: The parsed output after generation. + Args: + samples (Dataset): The samples used as input data + + Returns: + The parsed output after generation. """ num_samples = self.batch_params.get("num_samples", None) logger.debug("Generating outputs for {} samples".format(len(samples))) @@ -407,7 +410,8 @@ def __init__( # Generate the output from the block. This method should first validate the input data, # then generate the output, and finally parse the generated output before returning it. - # :return: The parsed output after generation. + # Returns: + # The parsed output after generation. # """ # num_samples = self.block_config.get("num_samples", None) # logger.debug("Generating outputs for {} samples".format(len(samples))) diff --git a/src/instructlab/sdg/registry.py b/src/instructlab/sdg/registry.py index 2db22611..e71e0172 100644 --- a/src/instructlab/sdg/registry.py +++ b/src/instructlab/sdg/registry.py @@ -3,7 +3,7 @@ import logging # Third Party -from jinja2 import StrictUndefined, Template +from jinja2 import Environment, StrictUndefined, Template logger = logging.getLogger(__name__) @@ -18,7 +18,8 @@ def register(cls, block_name: str): """ Decorator to register a block class under a specified name. - :param block_name: Name under which to register the block. + Args: + block_name (str): Name under which to register the block. """ def decorator(block_class): @@ -35,7 +36,8 @@ def get_registry(cls): """ Retrieve the current registry map of block types. - :return: Dictionary of registered block names and classes. + Returns: + Dictionary of registered block names and classes. """ return cls._registry @@ -44,18 +46,22 @@ class PromptRegistry: """Registry for managing Jinja2 prompt templates.""" _registry: Dict[str, Template] = {} + _template_env: Environment = Environment(undefined=StrictUndefined) @classmethod def register(cls, *names: str): - """Decorator to register a Jinja2 template function by name. + """Decorator to register Jinja2 template functions by name. - :param name: Name of the template to register. - :return: A decorator that registers the Jinja2 template function. + Args: + names (str): Names of the templates to register. + + Returns: + A decorator that registers the Jinja2 template functions. """ def decorator(func): template_str = func() - template = Template(template_str, undefined=StrictUndefined) + template = cls.template_from_string(template_str) for name in names: cls._registry[name] = template logger.debug(f"Registered prompt template '{name}'") @@ -67,8 +73,11 @@ def decorator(func): def get_template(cls, name: str) -> Template: """Retrieve a Jinja2 template by name. - :param name: Name of the template to retrieve. - :return: The Jinja2 template instance. + Args: + name (str): Name of the template to retrieve. + + Returns: + The Jinja2 template instance. """ if name not in cls._registry: raise KeyError(f"Prompt template '{name}' not found.") @@ -79,6 +88,20 @@ def get_registry(cls): """ Retrieve the current registry map of block types. - :return: Dictionary of registered block names and classes. + Returns: + Dictionary of registered block names and classes. """ return cls._registry + + @classmethod + def template_from_string(cls, template_str): + """ + Create a Jinja Template using our Environment from the source string + + Args: + template_str: The template source, as a string-like thing + + Returns: + Jinja Template + """ + return cls._template_env.from_string(template_str) diff --git a/tests/test_iterblock.py b/tests/test_iterblock.py new file mode 100644 index 00000000..b7561b24 --- /dev/null +++ b/tests/test_iterblock.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Standard +from unittest.mock import MagicMock +import unittest + +# Third Party +from datasets import Dataset, Features, Value + +# First Party +from instructlab.sdg import Block, BlockRegistry, IterBlock + + +class TestIterBlock(unittest.TestCase): + @BlockRegistry.register("TestCounterBlock") + class TestCounterBlock(Block): + def __init__( + self, + ctx, + pipe, + block_name, + column, + increment=1, + ) -> None: + super().__init__(ctx, pipe, block_name) + self.column = column + self.increment = increment + self.counter = 0 + + def generate(self, samples: Dataset): + samples = samples.map( + lambda x: {self.column: x[self.column] + self.counter} + ) + self.counter += self.increment + return samples + + def setUp(self): + self.ctx = MagicMock() + self.ctx.dataset_num_procs = 1 + self.pipe = MagicMock() + self.block = IterBlock( + self.ctx, + self.pipe, + "iter_test", + num_iters=4, + block_type="TestCounterBlock", + column="counter", + ) + self.dataset = Dataset.from_dict( + {"counter": [0]}, + features=Features({"counter": Value("int32")}), + ) + + def test_simple_iterate(self): + iterated_dataset = self.block.generate(self.dataset) + # We iterated 4 times, so 4 items in our dataset - one from + # each iteration + self.assertEqual(len(iterated_dataset), 4) + # Each iteration increment our counter, because of our custom + # block used that just increments counters + self.assertEqual(iterated_dataset["counter"], [0, 1, 2, 3]) From db3a1ad185fee213b2661e4fc83851d30277251b Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Tue, 10 Dec 2024 13:46:29 -0500 Subject: [PATCH 12/12] Validate blocks by raising BlockConfigParserError instead of asserts Asserts outside of tests should only be used for programming errors in our own code and not to validate user-facing things. Signed-off-by: Ben Browning --- src/instructlab/sdg/__init__.py | 3 ++- src/instructlab/sdg/blocks/block.py | 5 +++++ src/instructlab/sdg/blocks/llmblock.py | 16 +++++++++------- tests/test_llmblock.py | 25 +++++++++++++++++++++++++ 4 files changed, 41 insertions(+), 8 deletions(-) diff --git a/src/instructlab/sdg/__init__.py b/src/instructlab/sdg/__init__.py index 6f091424..490df8e4 100644 --- a/src/instructlab/sdg/__init__.py +++ b/src/instructlab/sdg/__init__.py @@ -3,6 +3,7 @@ # NOTE: This package imports Torch and other heavy packages. __all__ = ( "Block", + "BlockConfigParserError", "BlockRegistry", "CombineColumnsBlock", "ConditionalLLMBlock", @@ -31,7 +32,7 @@ ) # Local -from .blocks.block import Block +from .blocks.block import Block, BlockConfigParserError from .blocks.filterblock import FilterByValueBlock, FilterByValueBlockError from .blocks.iterblock import IterBlock from .blocks.llmblock import ( diff --git a/src/instructlab/sdg/blocks/block.py b/src/instructlab/sdg/blocks/block.py index 3bceca65..bc955596 100644 --- a/src/instructlab/sdg/blocks/block.py +++ b/src/instructlab/sdg/blocks/block.py @@ -68,3 +68,8 @@ def _load_config(self, config_path: str) -> Union[Dict[str, Any], None]: ) with open(config_path, "r", encoding="utf-8") as config_file: return yaml.safe_load(config_file) + + +# This is part of the public API. +class BlockConfigParserError(Exception): + """An exception raised while parsing a block's configuration.""" diff --git a/src/instructlab/sdg/blocks/llmblock.py b/src/instructlab/sdg/blocks/llmblock.py index 92e7402e..89d9a27e 100644 --- a/src/instructlab/sdg/blocks/llmblock.py +++ b/src/instructlab/sdg/blocks/llmblock.py @@ -15,7 +15,7 @@ # Import prompts to register default chat templates from .. import prompts as default_prompts # pylint: disable=unused-import from ..registry import BlockRegistry, PromptRegistry -from .block import Block +from .block import Block, BlockConfigParserError logger = logging.getLogger(__name__) @@ -290,13 +290,15 @@ def __init__( parser_kwargs={}, batch_kwargs={}, ) -> None: - assert ( - config_paths - ), "ConditionalLLMBlock config_paths requires at least one entry" + if not config_paths: + raise BlockConfigParserError( + f"ConditionalLLMBlock config_paths of block {block_name} requires at least one entry" + ) for config_path in config_paths: - assert ( - len(config_path) == 2 - ), "ConditionalLLMBlock config_paths each entry should be a list of config path and selector column names" + if len(config_path) != 2: + raise BlockConfigParserError( + f"ConditionalLLMBlock config_paths of block {block_name} should be a list of config path and selector column names" + ) super().__init__( ctx, pipe, diff --git a/tests/test_llmblock.py b/tests/test_llmblock.py index 15519fb7..f9d78177 100644 --- a/tests/test_llmblock.py +++ b/tests/test_llmblock.py @@ -10,9 +10,11 @@ from datasets import Dataset, Features, Value from httpx import URL from openai import InternalServerError, NotFoundError +import pytest # First Party from src.instructlab.sdg import ( + BlockConfigParserError, ConditionalLLMBlock, LLMBlock, LLMLogProbBlock, @@ -258,6 +260,29 @@ def setUp(self): self.mock_ctx.model_id = "test_model" self.mock_pipe = MagicMock() + def test_invalid_config_paths(self, _mock_load_config): + with pytest.raises(BlockConfigParserError) as exc: + ConditionalLLMBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + config_paths=[], + output_cols=[], + selector_column_name="selector", + ) + assert "at least one entry" in str(exc.value) + + with pytest.raises(BlockConfigParserError) as exc: + ConditionalLLMBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + config_paths=[["foo"]], + output_cols=[], + selector_column_name="selector", + ) + assert "config path and selector" in str(exc.value) + def test_validate(self, mock_load_config): mock_load_config.return_value = { "system": "{{var1}} {{var2}}",