diff --git a/.spellcheck-en-custom.txt b/.spellcheck-en-custom.txt index c06f114b..64aeae66 100644 --- a/.spellcheck-en-custom.txt +++ b/.spellcheck-en-custom.txt @@ -3,7 +3,9 @@ # SPDX-License-Identifier: Apache-2.0 Backport backported +CLI codebase +config configs Dataset dataset @@ -17,6 +19,8 @@ FIXME freeform ICL icl +ie +Jinja JSON Langchain's LLM @@ -39,7 +43,9 @@ Splitter subdirectory subfolder Tatsu +templating Tesseract +TODO tokenizer tokenizers unchunked @@ -47,3 +53,4 @@ upsampled UUID vLLM yaml +yamls diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c7f5a12..9b95d287 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,27 @@ +## Unreleased 0.7.x + +### Features + +#### Custom Blocks and Teacher Models via BlockRegistry and PromptRegistry + +Advanced users are now able to supply custom Pipeline `Block` implementations by registering new blocks with the `BlockRegistry`. It's also possible to register new chat templates for custom teacher models using the new `PromptRegistry`. + +See the `tests/testdata/custom_block.py` and `tests/testdata/custom_block_pipeline.yaml` files in this repository for an example of how to create custom blocks and use them from your own pipeline config yamls. + +See the `tests/testdata/custom_prompt.py` file in this repository for an example how to register custom chat templates used when formatting prompts. + +### Breaking Changes + +#### Pipeline configs and Prompt templates switched to Jinja + +All of our [Pipeline config yamls](src/instructlab/sdg/pipelines) and [prompt template files](src/instructlab/sdg/configs) have moved to [Jinja templates](https://pypi.org/project/Jinja2/) instead of Python string `format()` calls. This brings more expressiveness into our templating language - especially for prompt templates - but does mean any variable substitutions need to be updated from single brackets to double brackets - ie `{document}` becomes `{{document}}`. This only impacts you if you were using custom pipeline config yaml files or custom prompt templates in your config blocks. + +#### ImportBlock removed from Pipeline blocks + +Any users that were specifying custom pipeline configs (instead of using the default `full` or `simple` shipped by us) and also using the `ImportBlock` will now need to rewrite their pipelines to no longer use that block. We do not anticipate that anyone was actually using this block, but please reach out if you were so we can capture your needs in a future release. + +### Fixes + ## v0.6.2 ### Fixes diff --git a/pyproject.toml b/pyproject.toml index 88978472..aceddb3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,8 +97,8 @@ exclude = [ "^src/instructlab/sdg/generate_data\\.py$", "^src/instructlab/sdg/utils/taxonomy\\.py$", "^src/instructlab/sdg/default_flows\\.py$", - "^src/instructlab/sdg/llmblock\\.py$", - "^src/instructlab/sdg/utilblocks\\.py$", + "^src/instructlab/sdg/blocks/llmblock\\.py$", + "^src/instructlab/sdg/blocks/utilblocks\\.py$", ] # honor excludes by not following there through imports follow_imports = "silent" diff --git a/requirements.txt b/requirements.txt index 3d577c7a..4f5bf7b0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,7 @@ GitPython>=3.1.42,<4.0.0 gguf>=0.6.0 httpx>=0.25.0,<1.0.0 instructlab-schema>=0.4.0 +jinja2>=3.0.0 langchain-text-splitters # Note: this dependency goes along with langchain-text-splitters and may be # removed once that one is removed. diff --git a/src/instructlab/sdg/__init__.py b/src/instructlab/sdg/__init__.py index b2500ae3..490df8e4 100644 --- a/src/instructlab/sdg/__init__.py +++ b/src/instructlab/sdg/__init__.py @@ -3,6 +3,8 @@ # NOTE: This package imports Torch and other heavy packages. __all__ = ( "Block", + "BlockConfigParserError", + "BlockRegistry", "CombineColumnsBlock", "ConditionalLLMBlock", "DuplicateColumnsBlock", @@ -11,27 +13,44 @@ "FilterByValueBlockError", "FlattenColumnsBlock", "GenerateException", - "ImportBlock", + "IterBlock", "LLMBlock", + "LLMLogProbBlock", + "LLMMessagesBlock", "Pipeline", "PipelineBlockError", "PipelineConfigParserError", "PipelineContext", + "PromptRegistry", "RenameColumnsBlock", "SamplePopulatorBlock", "SelectorBlock", "SetToMajorityValueBlock", - "SIMPLE_PIPELINES_PACKAGE", "FULL_PIPELINES_PACKAGE", + "SIMPLE_PIPELINES_PACKAGE", "generate_data", ) # Local -from .block import Block -from .filterblock import FilterByValueBlock, FilterByValueBlockError +from .blocks.block import Block, BlockConfigParserError +from .blocks.filterblock import FilterByValueBlock, FilterByValueBlockError +from .blocks.iterblock import IterBlock +from .blocks.llmblock import ( + ConditionalLLMBlock, + LLMBlock, + LLMLogProbBlock, + LLMMessagesBlock, +) +from .blocks.utilblocks import ( + CombineColumnsBlock, + DuplicateColumnsBlock, + FlattenColumnsBlock, + RenameColumnsBlock, + SamplePopulatorBlock, + SelectorBlock, + SetToMajorityValueBlock, +) from .generate_data import generate_data -from .importblock import ImportBlock -from .llmblock import ConditionalLLMBlock, LLMBlock from .pipeline import ( FULL_PIPELINES_PACKAGE, SIMPLE_PIPELINES_PACKAGE, @@ -41,14 +60,6 @@ PipelineConfigParserError, PipelineContext, ) -from .utilblocks import ( - CombineColumnsBlock, - DuplicateColumnsBlock, - FlattenColumnsBlock, - RenameColumnsBlock, - SamplePopulatorBlock, - SelectorBlock, - SetToMajorityValueBlock, -) +from .registry import BlockRegistry, PromptRegistry from .utils import GenerateException from .utils.taxonomy import TaxonomyReadingException diff --git a/src/instructlab/sdg/block.py b/src/instructlab/sdg/block.py deleted file mode 100644 index 20205801..00000000 --- a/src/instructlab/sdg/block.py +++ /dev/null @@ -1,37 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -# Standard -from abc import ABC -from typing import Any, Dict, Union -import logging -import os.path - -# Third Party -import yaml - -logger = logging.getLogger(__name__) - - -# This is part of the public API. -class Block(ABC): - def __init__(self, ctx, pipe, block_name: str) -> None: - self.ctx = ctx - self.pipe = pipe - self.block_name = block_name - - def _load_config(self, config_path: str) -> Union[Dict[str, Any], None]: - """ - Load the configuration file for this block. - - If the supplied configuration file is a relative path, it is assumed - to be part of this Python package. - - :param config_path: The path to the configuration file. - :return: The loaded configuration. - """ - if not os.path.isabs(config_path): - config_path = os.path.join( - os.path.dirname(self.pipe.config_path), config_path - ) - with open(config_path, "r", encoding="utf-8") as config_file: - return yaml.safe_load(config_file) diff --git a/src/instructlab/sdg/blocks/__init__.py b/src/instructlab/sdg/blocks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/instructlab/sdg/blocks/block.py b/src/instructlab/sdg/blocks/block.py new file mode 100644 index 00000000..bc955596 --- /dev/null +++ b/src/instructlab/sdg/blocks/block.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Standard +from abc import ABC +from typing import Any, Dict, Union +import logging +import os.path + +# Third Party +from jinja2 import Template, UndefinedError +import yaml + +# Local +from ..registry import BlockRegistry + +logger = logging.getLogger(__name__) + + +# This is part of the public API. +@BlockRegistry.register("Block") +class Block(ABC): + def __init__(self, ctx, pipe, block_name: str) -> None: + self.ctx = ctx + self.pipe = pipe + self.block_name = block_name + + def _validate(self, prompt_template: Template, input_dict: Dict[str, Any]) -> bool: + """ + Validate the input data for this block. This method validates whether all required + variables in the Jinja template are provided in the input_dict. + + Args: + prompt_template (Template): The Jinja2 template object. + input_dict (Dict[str, Any]): A dictionary of input values to check against + the template. + + Returns: + True if the input data is valid (i.e., no missing variables), False otherwise. + """ + + try: + # Try rendering the template with the input_dict + prompt_template.render(input_dict) + return True + except UndefinedError as e: + # Jinja throws an UndefinedError for any undefnined template variables, + # assuming the prompt_template was created using StrictUndefined. This + # is the case for anything using PromptRegistry.template_from_string. + logger.error(f"Missing key: {e}") + return False + + def _load_config(self, config_path: str) -> Union[Dict[str, Any], None]: + """ + Load the configuration file for this block. + + If the supplied configuration file is a relative path, it is assumed + to be part of this Python package. + + Args: + config_path (str): The path to the configuration file. + + Returns: + The loaded configuration. + """ + if not os.path.isabs(config_path): + config_path = os.path.join( + os.path.dirname(self.pipe.config_path), config_path + ) + with open(config_path, "r", encoding="utf-8") as config_file: + return yaml.safe_load(config_file) + + +# This is part of the public API. +class BlockConfigParserError(Exception): + """An exception raised while parsing a block's configuration.""" diff --git a/src/instructlab/sdg/filterblock.py b/src/instructlab/sdg/blocks/filterblock.py similarity index 98% rename from src/instructlab/sdg/filterblock.py rename to src/instructlab/sdg/blocks/filterblock.py index 07372aef..181c973e 100644 --- a/src/instructlab/sdg/filterblock.py +++ b/src/instructlab/sdg/blocks/filterblock.py @@ -8,6 +8,7 @@ from datasets import Dataset # Local +from ..registry import BlockRegistry from .block import Block logger = logging.getLogger(__name__) @@ -86,6 +87,7 @@ def convert_column(sample): # This is part of the public API. +@BlockRegistry.register("FilterByValueBlock") class FilterByValueBlock(Block): def __init__( self, diff --git a/src/instructlab/sdg/blocks/iterblock.py b/src/instructlab/sdg/blocks/iterblock.py new file mode 100644 index 00000000..b75b2ab4 --- /dev/null +++ b/src/instructlab/sdg/blocks/iterblock.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Standard +import logging + +# Third Party +from datasets import Dataset + +# Local +from ..pipeline import _lookup_block_type +from ..registry import BlockRegistry +from .block import Block + +logger = logging.getLogger(__name__) + + +# This is part of the public API. +@BlockRegistry.register("IterBlock") +class IterBlock(Block): + """ + Call another block multiple times for a single set of input + samples, concatening the results of each iteration's call to that + other block in the final returned output. + + Args: + num_iters: The number of times to iterate over the block + block_type: The type of the other block to call (ie LLMBlock) + block_config: Any necessary configuration that will get passed to the + other block to properly configure it. + + Returns: + A Dataset containing all output samples from each iteration + """ + + def __init__( + self, + ctx, + pipe, + block_name, + num_iters, + block_type, + **block_config, + ) -> None: + super().__init__(ctx, pipe, block_name) + self.num_iters = num_iters + block_type = _lookup_block_type(block_type) + self.block = block_type(ctx, pipe, block_name, **block_config) + + def generate(self, samples: Dataset) -> Dataset: + generated_samples = [] + num_iters = self.num_iters + + for _ in range(num_iters): + batch_generated = self.block.generate(samples) + generated_samples.extend(batch_generated) + + return Dataset.from_list(generated_samples) diff --git a/src/instructlab/sdg/llmblock.py b/src/instructlab/sdg/blocks/llmblock.py similarity index 57% rename from src/instructlab/sdg/llmblock.py rename to src/instructlab/sdg/blocks/llmblock.py index 0e9a5f22..89d9a27e 100644 --- a/src/instructlab/sdg/llmblock.py +++ b/src/instructlab/sdg/blocks/llmblock.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # Standard -from collections import ChainMap from typing import Any, Dict import logging import re @@ -13,29 +12,15 @@ import openai # Local -from .block import Block +# Import prompts to register default chat templates +from .. import prompts as default_prompts # pylint: disable=unused-import +from ..registry import BlockRegistry, PromptRegistry +from .block import Block, BlockConfigParserError logger = logging.getLogger(__name__) DEFAULT_MAX_NUM_TOKENS = 4096 -MODEL_FAMILY_MIXTRAL = "mixtral" -MODEL_FAMILY_MERLINITE = "merlinite" - -_MODEL_PROMPT_MIXTRAL = " [INST] {prompt} [/INST]" -_MODEL_PROMPT_MERLINITE = "'<|system|>\nYou are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.\n<|user|>\n{prompt}\n<|assistant|>\n'" - -_MODEL_PROMPTS = { - MODEL_FAMILY_MIXTRAL: _MODEL_PROMPT_MIXTRAL, - MODEL_FAMILY_MERLINITE: _MODEL_PROMPT_MERLINITE, -} - - -def _get_model_prompt(model_family): - if model_family not in _MODEL_PROMPTS: - raise ValueError(f"Unknown model family: {model_family}") - return _MODEL_PROMPTS[model_family] - def server_supports_batched(client, model_id: str) -> bool: supported = getattr(client, "server_supports_batched", None) @@ -71,7 +56,14 @@ def server_supports_batched(client, model_id: str) -> bool: return supported +def template_from_struct_and_config(struct, config): + # replace None with empty strings + filtered_config = {k: (v if v is not None else "") for k, v in config.items()} + return PromptRegistry.template_from_string(struct.format(**filtered_config)) + + # This is part of the public API. +@BlockRegistry.register("LLMBlock") # pylint: disable=dangerous-default-value class LLMBlock(Block): # pylint: disable=too-many-instance-attributes @@ -92,7 +84,9 @@ def __init__( self.prompt_struct = ( """{system}\n{introduction}\n{principles}\n{examples}\n{generation}""" ) - self.prompt_template = self.prompt_struct.format(**self.block_config) + self.prompt_template = template_from_struct_and_config( + self.prompt_struct, self.block_config + ) self.model_prompt = model_prompt self.output_cols = output_cols self.batch_params = batch_kwargs @@ -162,15 +156,25 @@ def _parse(self, generated_string) -> dict: # 2. Non-empty string - the pipeline has specified a custom model prompt # 3. Empty string - the pipeline has specified that no model prompt is needed def _format_prompt(self, sample: Dict) -> str: - prompt = self.prompt_template.format(**sample).strip() + prompt_templated_str = self.prompt_template.render(sample).strip() model_prompt = None if self.model_prompt is None: - model_prompt = _get_model_prompt(self.ctx.model_family) + model_prompt = PromptRegistry.get_template(self.ctx.model_family) elif self.model_prompt: - model_prompt = self.model_prompt + model_prompt = PromptRegistry.template_from_string(self.model_prompt) + else: + # Our model prompt is an empty string, which we'll render + # verbatim without wrapping in the messages format + model_prompt = PromptRegistry.get_template("blank") + + messages = [{"role": "user", "content": prompt_templated_str}] - return prompt if model_prompt is None else model_prompt.format(prompt=prompt) + return model_prompt.render( + messages=messages, + prompt=prompt_templated_str, + add_generation_prompt=True, + ).strip() def _gen_kwargs(self, max_num_token_override, gen_kwargs, **defaults): gen_kwargs = {**defaults, **gen_kwargs} @@ -217,7 +221,11 @@ def generate(self, samples: Dataset) -> Dataset: Generate the output from the block. This method should first validate the input data, then generate the output, and finally parse the generated output before returning it. - :return: The parsed output after generation. + Args: + samples (Dataset): The samples used as input data + + Returns: + The parsed output after generation. """ num_samples = self.batch_params.get("num_samples", None) logger.debug("Generating outputs for {} samples".format(len(samples))) @@ -265,27 +273,9 @@ def generate(self, samples: Dataset) -> Dataset: return Dataset.from_list(new_data) - def _validate(self, prompt_template: str, input_dict: Dict[str, Any]) -> bool: - """ - Validate the input data for this block. This method should be implemented by subclasses - to define how the block validates its input data. - - :return: True if the input data is valid, False otherwise. - """ - - class Default(dict): - def __missing__(self, key: str) -> None: - raise KeyError(key) - - try: - prompt_template.format_map(ChainMap(input_dict, Default())) - return True - except KeyError as e: - logger.error("Missing key: {}".format(e)) - return False - # This is part of the public API. +@BlockRegistry.register("ConditionalLLMBlock") class ConditionalLLMBlock(LLMBlock): def __init__( self, @@ -300,6 +290,15 @@ def __init__( parser_kwargs={}, batch_kwargs={}, ) -> None: + if not config_paths: + raise BlockConfigParserError( + f"ConditionalLLMBlock config_paths of block {block_name} requires at least one entry" + ) + for config_path in config_paths: + if len(config_path) != 2: + raise BlockConfigParserError( + f"ConditionalLLMBlock config_paths of block {block_name} should be a list of config path and selector column names" + ) super().__init__( ctx, pipe, @@ -314,11 +313,13 @@ def __init__( self.selector_column_name = selector_column_name self.prompt_template = {} if len(config_paths) == 1 and config_paths[0][1] == "All": - self.prompt_template = self.prompt_struct.format(**self.block_config) + self.prompt_template = template_from_struct_and_config( + self.prompt_struct, self.block_config + ) else: for config, config_key in config_paths: - self.prompt_template[config_key] = self.prompt_struct.format( - **self._load_config(config) + self.prompt_template[config_key] = template_from_struct_and_config( + self.prompt_struct, self._load_config(config) ) def _format_prompt(self, sample: Dict) -> str: @@ -333,5 +334,171 @@ def _format_prompt(self, sample: Dict) -> str: def _validate(self, prompt_template: str, input_dict: Dict[str, Any]) -> bool: if isinstance(prompt_template, dict): - prompt_template = prompt_template[input_dict[self.selector_column_name]] + if not self.selector_column_name in input_dict: + logger.error( + f"ConditionalLLMBlock {self.block_name} missing key: {self.selector_column_name}" + ) + return False + config_key = input_dict[self.selector_column_name] + if not config_key in prompt_template: + logger.error( + f"ConditionalLLMBlock {self.block_name} selector key {config_key} not found in block config" + ) + return False + prompt_template = prompt_template[config_key] return super()._validate(prompt_template, input_dict) + + +# This is part of the public API. +@BlockRegistry.register("LLMLogProbBlock") +class LLMLogProbBlock(LLMBlock): + def __init__( + self, + ctx, + pipe, + block_name, + config_path, + output_cols, + model_prompt=None, + gen_kwargs={}, + parser_kwargs={}, + batch_kwargs={}, + ) -> None: + super().__init__( + ctx, + pipe, + block_name, + config_path, + output_cols, + model_prompt=model_prompt, + gen_kwargs=gen_kwargs, + parser_kwargs=parser_kwargs, + batch_kwargs=batch_kwargs, + ) + + # def _generate_logprobs(self, samples, **gen_kwargs): + # prompts = [ + # self.model_prompt.format(prompt=self._format_prompt(sample)) + # for sample in samples + # ] + # generate_args = {**self.defaults, **gen_kwargs} + + # # verify if logprobs is mentioned in the generate_args, if not add it and return top10 logprobs + # if "logprobs" not in generate_args: + # generate_args["logprobs"] = 10 + + # if self.server_supports_batched: + # response = self.client.completions.create(prompt=prompts, **generate_args) + # return [choice.logprobs.top_logprobs for choice in response.choices] + + # n = gen_kwargs.get("n", 1) + # results = [] + # for prompt in prompts: + # for _ in range(n): + # response = self.client.completions.create( + # prompt=prompt, **generate_args + # ) + # results.append(response.choices[0].logprobs.top_logprobs) + # return results + + # def _parse(self, generations: List[List[Dict]]) -> List[List[str]]: + # # override the parse method to convert the generations to json string + # # convert the generations to json string to save as dataset + # # this is because the dataset can only store key value pairs which are consistent + # return [[json.dumps(item) for item in sublist] for sublist in generations] + + # def generate(self, samples: Dataset, **gen_kwargs) -> Dataset: + # """ + # Generate the output from the block. This method should first validate the input data, + # then generate the output, and finally parse the generated output before returning it. + + # Returns: + # The parsed output after generation. + # """ + # num_samples = self.block_config.get("num_samples", None) + # logger.debug("Generating outputs for {} samples".format(len(samples))) + + # if (num_samples is not None) and ("num_samples" not in samples.column_names): + # samples = samples.add_column("num_samples", [num_samples] * len(samples)) + + # # validate each sample + # # Log errors and remove invalid samples + # valid_samples = [] + + # for sample in samples: + # if self._validate(self.prompt_template, sample): + # valid_samples.append(sample) + # else: + # logger.warning( + # f"Sample failed validation: {sample}" + # ) # Log details of the failed sample + + # samples = valid_samples + + # if len(samples) == 0: + # logger.warning( + # "No valid samples to generate outputs for, returning empty dataset" + # ) + # return Dataset.from_list([]) + + # # generate the output + + # outputs = self._generate_logprobs(samples, **gen_kwargs) + # logger.debug("Generated outputs: %s", outputs) + + # output_dataset = Dataset.from_list(samples) + # output_dataset = output_dataset.add_column( + # self.output_cols[0], + # self._parse(outputs), # pylint: disable=no-value-for-parameter + # ) + + # return output_dataset + + +# This is part of the public API. +@BlockRegistry.register("LLMMessagesBlock") +class LLMMessagesBlock(LLMBlock): + def __init__( + self, + ctx, + pipe, + block_name, + config_path, + output_cols, + model_prompt=None, + gen_kwargs={}, + parser_kwargs={}, + batch_kwargs={}, + ) -> None: + super().__init__( + ctx, + pipe, + block_name, + config_path, + output_cols, + model_prompt=model_prompt, + gen_kwargs=gen_kwargs, + parser_kwargs=parser_kwargs, + batch_kwargs=batch_kwargs, + ) + + # def _generate(self, samples) -> list: + # generate_args = {**self.defaults, **gen_kwargs} + + # if "n" in generate_args and generate_args.get("temperature", 0) <= 0: + # generate_args["temperature"] = 0.7 + # logger.warning( + # "Temperature should be greater than 0 for n > 1, setting temperature to 0.7" + # ) + + # messages = samples[self.input_col] + + # results = [] + # n = gen_kwargs.get("n", 1) + # for message in messages: + # responses = self.client.chat.completions.create(messages=message, **generate_args) + # if n > 1: + # results.append([choice.message.content for choice in responses.choices]) + # else: + # results.append(responses.choices[0].message.content) + # return results diff --git a/src/instructlab/sdg/utilblocks.py b/src/instructlab/sdg/blocks/utilblocks.py similarity index 94% rename from src/instructlab/sdg/utilblocks.py rename to src/instructlab/sdg/blocks/utilblocks.py index 00d1e32a..a03e6cb7 100644 --- a/src/instructlab/sdg/utilblocks.py +++ b/src/instructlab/sdg/blocks/utilblocks.py @@ -10,12 +10,14 @@ from instructlab.sdg.utils import pandas # Local +from ..registry import BlockRegistry from .block import Block logger = logging.getLogger(__name__) # This is part of the public API. +@BlockRegistry.register("SamplePopulatorBlock") class SamplePopulatorBlock(Block): def __init__( self, ctx, pipe, block_name, config_paths, column_name, post_fix="" @@ -46,6 +48,7 @@ def generate(self, samples) -> Dataset: # This is part of the public API. +@BlockRegistry.register("SelectorBlock") class SelectorBlock(Block): def __init__( self, ctx, pipe, block_name, choice_map, choice_col, output_col @@ -75,6 +78,7 @@ def generate(self, samples: Dataset) -> Dataset: # This is part of the public API. +@BlockRegistry.register("CombineColumnsBlock") class CombineColumnsBlock(Block): def __init__( self, ctx, pipe, block_name, columns, output_col, separator="\n\n" @@ -103,6 +107,7 @@ def generate(self, samples: Dataset) -> Dataset: ) +@BlockRegistry.register("FlattenColumnsBlock") class FlattenColumnsBlock(Block): """Melt/transform a data from a wide format to a long format see pandas.melt for a description @@ -132,6 +137,7 @@ def generate(self, samples: Dataset) -> Dataset: return pandas.dataset_from_pandas_dataframe(flatten_df) +@BlockRegistry.register("DuplicateColumnsBlock") class DuplicateColumnsBlock(Block): def __init__(self, ctx, pipe, block_name: str, columns_map: dict) -> None: """Create duplicate of columns specified in column map. @@ -150,6 +156,7 @@ def generate(self, samples: Dataset): return samples +@BlockRegistry.register("RenameColumnsBlock") class RenameColumnsBlock(Block): def __init__(self, ctx, pipe, block_name: str, columns_map: dict) -> None: """Rename dataset columns. @@ -165,6 +172,7 @@ def generate(self, samples: Dataset): return samples +@BlockRegistry.register("SetToMajorityValueBlock") class SetToMajorityValueBlock(Block): """Set the value of the specified column to the most common value (the mode) diff --git a/src/instructlab/sdg/configs/knowledge/evaluate_faithfulness.yaml b/src/instructlab/sdg/configs/knowledge/evaluate_faithfulness.yaml index 828bb31f..fa39d916 100644 --- a/src/instructlab/sdg/configs/knowledge/evaluate_faithfulness.yaml +++ b/src/instructlab/sdg/configs/knowledge/evaluate_faithfulness.yaml @@ -58,10 +58,10 @@ generation: | * Return the answer between [Start of Answer] and [End of Answer] tags. [Start of Context] - {document} + {{document}} [End of Context] [Start of Response] - {response} + {{response}} [End of Response] start_tags: ["[Start of Explanation]", "[Start of Answer]"] diff --git a/src/instructlab/sdg/configs/knowledge/evaluate_question.yaml b/src/instructlab/sdg/configs/knowledge/evaluate_question.yaml index 3505e23c..88b473c7 100644 --- a/src/instructlab/sdg/configs/knowledge/evaluate_question.yaml +++ b/src/instructlab/sdg/configs/knowledge/evaluate_question.yaml @@ -32,7 +32,7 @@ examples: "" generation: | [Start of Question] - {question} + {{question}} [End of Question] start_tags: ["[Start of Explanation]", "[Start of Rating]"] diff --git a/src/instructlab/sdg/configs/knowledge/evaluate_relevancy.yaml b/src/instructlab/sdg/configs/knowledge/evaluate_relevancy.yaml index 81f1d666..cf944005 100644 --- a/src/instructlab/sdg/configs/knowledge/evaluate_relevancy.yaml +++ b/src/instructlab/sdg/configs/knowledge/evaluate_relevancy.yaml @@ -72,11 +72,11 @@ generation: | Begin your response by providing the feedback followed by the score. Be as objective as possible. [Start of Question] - {question} + {{question}} [End of Question] [Start of Response] - {response} + {{response}} [End of Response] * Return the feedback within the [Start of Feedback] and [End of Feedback] tags. diff --git a/src/instructlab/sdg/configs/knowledge/generate_questions_responses.yaml b/src/instructlab/sdg/configs/knowledge/generate_questions_responses.yaml index a2decd76..d7dd61e6 100644 --- a/src/instructlab/sdg/configs/knowledge/generate_questions_responses.yaml +++ b/src/instructlab/sdg/configs/knowledge/generate_questions_responses.yaml @@ -1,6 +1,6 @@ system: You are a very knowledgeable AI Assistant that will faithfully assist the user with their task. -introduction: Develop a series of educational question and answer pairs from a chapter in a {domain} textbook. +introduction: Develop a series of educational question and answer pairs from a chapter in a {{domain}} textbook. principles: | The questions should: @@ -28,29 +28,29 @@ examples: | Here are some examples of questions: [Document] - {icl_document} + {{icl_document}} [QUESTION] - {icl_query_1} + {{icl_query_1}} [ANSWER] - {icl_response_1} + {{icl_response_1}} [END] [QUESTION] - {icl_query_2} + {{icl_query_2}} [ANSWER] - {icl_response_2} + {{icl_response_2}} [END] [QUESTION] - {icl_query_3} + {{icl_query_3}} [ANSWER] - {icl_response_3} + {{icl_response_3}} [END] generation: | Here is the document: [DOCUMENT] - {document_outline} - {document} + {{document_outline}} + {{document}} diff --git a/src/instructlab/sdg/configs/knowledge/mcq_generation.yaml b/src/instructlab/sdg/configs/knowledge/mcq_generation.yaml index 091001c5..1d8c6aca 100644 --- a/src/instructlab/sdg/configs/knowledge/mcq_generation.yaml +++ b/src/instructlab/sdg/configs/knowledge/mcq_generation.yaml @@ -75,7 +75,7 @@ generation: | Here is the document: [Start of Document] - {document} + {{document}} [End of Document] start_tags: ["[Start of Question]", "[Start of Answer]"] diff --git a/src/instructlab/sdg/configs/knowledge/simple_generate_qa.yaml b/src/instructlab/sdg/configs/knowledge/simple_generate_qa.yaml index 784902b4..63310db5 100644 --- a/src/instructlab/sdg/configs/knowledge/simple_generate_qa.yaml +++ b/src/instructlab/sdg/configs/knowledge/simple_generate_qa.yaml @@ -1,6 +1,6 @@ system: You are a very knowledgeable AI Assistant that will faithfully assist the user with their task. -introduction: Develop a series of educational question and answer pairs from a chapter in a {domain} textbook. +introduction: Develop a series of educational question and answer pairs from a chapter in a {{domain}} textbook. principles: | Here are the requirements: @@ -14,23 +14,23 @@ principles: | examples: | Here is a sample section of the document as an example: - {icl_document} + {{icl_document}} Here are some examples to help you understand the type of questions that are asked for this document: - {icl_query_1} - {icl_response_1} + {{icl_query_1}} + {{icl_response_1}} - {icl_query_2} - {icl_response_2} + {{icl_query_2}} + {{icl_response_2}} - {icl_query_3} - {icl_response_3} + {{icl_query_3}} + {{icl_response_3}} Here is the document: - {document_outline} - {document} + {{document_outline}} + {{document}} generation: | Provide a single question and answer pair based on the document. diff --git a/src/instructlab/sdg/configs/knowledge/spellcheck.yaml b/src/instructlab/sdg/configs/knowledge/spellcheck.yaml index daf1dafa..c6e7b614 100644 --- a/src/instructlab/sdg/configs/knowledge/spellcheck.yaml +++ b/src/instructlab/sdg/configs/knowledge/spellcheck.yaml @@ -11,7 +11,7 @@ examples: "" generation: | Document: - {document} + {{document}} start_tags: [""] end_tags: [""] diff --git a/src/instructlab/sdg/configs/skills/contexts.yaml b/src/instructlab/sdg/configs/skills/contexts.yaml index be257c80..813680e1 100644 --- a/src/instructlab/sdg/configs/skills/contexts.yaml +++ b/src/instructlab/sdg/configs/skills/contexts.yaml @@ -1,6 +1,6 @@ system: You are a very knowledgeable AI Assistant that will faithfully assist the user with their task. -introduction: You are asked to come up with a diverse context for - {task_description}. +introduction: You are asked to come up with a diverse context for - {{task_description}}. principles: | Please follow these guiding principles when generating responses: * Use proper grammar and punctuation. @@ -11,7 +11,7 @@ principles: | examples: | To better assist you with this task, here is an example of a context: [Start of Context] - {seed_context} + {{seed_context}} [End of Context] generation: | diff --git a/src/instructlab/sdg/configs/skills/evaluate_freeform_pair.yaml b/src/instructlab/sdg/configs/skills/evaluate_freeform_pair.yaml index 1dd3e38d..298181f4 100644 --- a/src/instructlab/sdg/configs/skills/evaluate_freeform_pair.yaml +++ b/src/instructlab/sdg/configs/skills/evaluate_freeform_pair.yaml @@ -29,11 +29,11 @@ examples: | generation: | Here's the question and the answer you need to evaluate: [Start of Question] - {question} + {{question}} [End of Question] [Start of Answer] - {response} + {{response}} [End of Answer] Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the answer on a scale of 1 to 3 as mentioned above. diff --git a/src/instructlab/sdg/configs/skills/evaluate_freeform_questions.yaml b/src/instructlab/sdg/configs/skills/evaluate_freeform_questions.yaml index 50bafcb8..850fe006 100644 --- a/src/instructlab/sdg/configs/skills/evaluate_freeform_questions.yaml +++ b/src/instructlab/sdg/configs/skills/evaluate_freeform_questions.yaml @@ -9,7 +9,7 @@ principles: | * The questions should be in English. * The questions should be 1 to 2 sentences long and should be properly formatted. * The question should not be offensive, abusive, or harmful. It should be safe and respectful. - * The question should be relevant to the task given - {task_description}. + * The question should be relevant to the task given - {{task_description}}. If the question meets the above requirements, please rate it 1. If not, please rate it 0. @@ -32,10 +32,10 @@ examples: | generation: | Here's the question you need to evaluate: - Task Description: {task_description} + Task Description: {{task_description}} [Start of Question] - {question} + {{question}} [End of Question] Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the question on a scale of 0 or 1 as mentioned above. Strictly follow the format below: diff --git a/src/instructlab/sdg/configs/skills/evaluate_grounded_pair.yaml b/src/instructlab/sdg/configs/skills/evaluate_grounded_pair.yaml index 15132c88..993c52f2 100644 --- a/src/instructlab/sdg/configs/skills/evaluate_grounded_pair.yaml +++ b/src/instructlab/sdg/configs/skills/evaluate_grounded_pair.yaml @@ -35,15 +35,15 @@ generation: | Here's the context, question and the answer you need to evaluate: [Start of Context] - {context} + {{context}} [End of Context] [Start of Question] - {question} + {{question}} [End of Question] [Start of Answer] - {response} + {{response}} [End of Answer] * Return the evaluation between [Start of Evaluation] and [End of Evaluation] tags. diff --git a/src/instructlab/sdg/configs/skills/evaluate_grounded_questions.yaml b/src/instructlab/sdg/configs/skills/evaluate_grounded_questions.yaml index 6999987f..fd761806 100644 --- a/src/instructlab/sdg/configs/skills/evaluate_grounded_questions.yaml +++ b/src/instructlab/sdg/configs/skills/evaluate_grounded_questions.yaml @@ -9,7 +9,7 @@ principles: | * The questions should be in English. * The questions should be 1 to 2 sentences long and should be properly formatted. * The question should not be offensive, abusive, or harmful. It should be safe and respectful. - * The question should be relevant to the task given - {task_description}. + * The question should be relevant to the task given - {{task_description}}. * Most importantly all the questions should be grounded in the context provided and should be answerable solely based on the provided context. If the question meets the above requirements, please rate it 1. If not, please rate it 0. @@ -37,10 +37,10 @@ generation: | Here's the context and question you need to evaluate. Return the evaluation between [Start of Evaluation] and [End of Evaluation] tags. [Start of Context] - {context} + {{context}} [End of Context] [Start of Question] - {question} + {{question}} [End of Question] Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the question on a scale of 0 or 1 as mentioned above. diff --git a/src/instructlab/sdg/configs/skills/freeform_questions.yaml b/src/instructlab/sdg/configs/skills/freeform_questions.yaml index f3d1ed90..de0e64d4 100644 --- a/src/instructlab/sdg/configs/skills/freeform_questions.yaml +++ b/src/instructlab/sdg/configs/skills/freeform_questions.yaml @@ -1,7 +1,7 @@ system: You are a very knowledgeable AI Assistant that will faithfully assist the user with their task. introduction: | - You are asked to come up with a set of {num_samples} diverse questions - {task_description}. + You are asked to come up with a set of {{num_samples}} diverse questions - {{task_description}}. principles: | Please follow these guiding principles when generating responses: @@ -19,11 +19,11 @@ examples: | To better assist you with this task, here is an example: [Start of Question] - {seed_question} + {{seed_question}} [End of Question] generation: | - Now generate {num_samples} such questions, remember to follow the principles mentioned above and use the same format as the examples. Remember to use the same style and format as the example above. Return each question between [Start of Question] and [End of Question] tags. + Now generate {{num_samples}} such questions, remember to follow the principles mentioned above and use the same format as the examples. Remember to use the same style and format as the example above. Return each question between [Start of Question] and [End of Question] tags. start_tags: ["[Start of Question]"] end_tags: ["[End of Question]"] diff --git a/src/instructlab/sdg/configs/skills/freeform_responses.yaml b/src/instructlab/sdg/configs/skills/freeform_responses.yaml index cf7ff177..b68c4790 100644 --- a/src/instructlab/sdg/configs/skills/freeform_responses.yaml +++ b/src/instructlab/sdg/configs/skills/freeform_responses.yaml @@ -13,18 +13,18 @@ principles: | examples: | To better assist you with this task, here is an example: [Start of Question] - {seed_question} + {{seed_question}} [End of Question] [Start of Response] - {seed_response} + {{seed_response}} [End of Response] generation: | Now generate a response to the following prompt. Remember to use the same style and format as the example above. [Start of Question] - {question} + {{question}} [End of Question] Return the response between [Start of Response] and [End of Response] tags. diff --git a/src/instructlab/sdg/configs/skills/grounded_questions.yaml b/src/instructlab/sdg/configs/skills/grounded_questions.yaml index 904523c9..979d35f3 100644 --- a/src/instructlab/sdg/configs/skills/grounded_questions.yaml +++ b/src/instructlab/sdg/configs/skills/grounded_questions.yaml @@ -1,7 +1,7 @@ system: You are a very knowledgeable AI Assistant that will faithfully assist the user with their task. introduction: | - You are asked to come up with a set of {num_samples} diverse questions - {task_description}. + You are asked to come up with a set of {{num_samples}} diverse questions - {{task_description}}. principles: | Please follow these guiding principles when generating responses: @@ -21,17 +21,17 @@ examples: | To better assist you with this task, here is an example: [Start of Context] - {seed_context} + {{seed_context}} [End of Context] [Start of Question] - {seed_question} + {{seed_question}} [End of Question] generation: | - Now generate {num_samples} such questions, remember to follow the principles mentioned above and use the same format as the examples. Remember to use the same style and format as the example above. Do not return any contexts or answers, only the questions. Return each question between [Start of Question] and [End of Question] tags. + Now generate {{num_samples}} such questions, remember to follow the principles mentioned above and use the same format as the examples. Remember to use the same style and format as the example above. Do not return any contexts or answers, only the questions. Return each question between [Start of Question] and [End of Question] tags. [Start of Context] - {context} + {{context}} [End of Context] start_tags: ["[Start of Question]"] diff --git a/src/instructlab/sdg/configs/skills/grounded_responses.yaml b/src/instructlab/sdg/configs/skills/grounded_responses.yaml index bacd5c10..d8161591 100644 --- a/src/instructlab/sdg/configs/skills/grounded_responses.yaml +++ b/src/instructlab/sdg/configs/skills/grounded_responses.yaml @@ -14,15 +14,15 @@ examples: | To better assist you with this task, here is an example: [Start of Context] - {seed_context} + {{seed_context}} [End of Context] [Start of Question] - {seed_question} + {{seed_question}} [End of Question] [Start of Response] - {seed_response} + {{seed_response}} [End of Response] generation: | @@ -30,10 +30,10 @@ generation: | Return the response between [Start of Response] and [End of Response] tags. [Start of Context] - {context} + {{context}} [End of Context] [Start of Question] - {question} + {{question}} [End of Question] Return the response between [Start of Response] and [End of Response] tags. diff --git a/src/instructlab/sdg/configs/skills/simple_generate_qa_freeform.yaml b/src/instructlab/sdg/configs/skills/simple_generate_qa_freeform.yaml index 5144b22c..4a4217af 100644 --- a/src/instructlab/sdg/configs/skills/simple_generate_qa_freeform.yaml +++ b/src/instructlab/sdg/configs/skills/simple_generate_qa_freeform.yaml @@ -13,12 +13,12 @@ principles: | 7. The output should be an appropriate response to the input and the instruction. Long outputs are preferable. examples: | - The task is {task_description}. + The task is {{task_description}}. Here is an example to help you understand the type of questions that are asked for: - {seed_question} - {seed_response} + {{seed_question}} + {{seed_response}} generation: | Provide a single question and answer pair based on the examples. diff --git a/src/instructlab/sdg/configs/skills/simple_generate_qa_grounded.yaml b/src/instructlab/sdg/configs/skills/simple_generate_qa_grounded.yaml index 54588f91..b0423218 100644 --- a/src/instructlab/sdg/configs/skills/simple_generate_qa_grounded.yaml +++ b/src/instructlab/sdg/configs/skills/simple_generate_qa_grounded.yaml @@ -13,16 +13,16 @@ principles: | 7. The output should be an appropriate response to the input and the instruction. Long outputs are preferable. examples: | - The task is {task_description}. + The task is {{task_description}}. Here is some context for the example question: - {seed_context} + {{seed_context}} Here is an example to help you understand the type of questions that are asked for: - {seed_question} - {seed_response} + {{seed_question}} + {{seed_response}} generation: | Provide a single question and answer pair based on the example. diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py index aac007b7..74e279b9 100644 --- a/src/instructlab/sdg/generate_data.py +++ b/src/instructlab/sdg/generate_data.py @@ -18,14 +18,9 @@ import yaml # First Party -# pylint: disable=ungrouped-imports +from instructlab.sdg.blocks.llmblock import DEFAULT_MAX_NUM_TOKENS from instructlab.sdg.datamixing import DataMixer, _get_question_hack, _get_response_hack from instructlab.sdg.eval_data import generate_eval_task_data, mmlubench_pipe_init -from instructlab.sdg.llmblock import ( - DEFAULT_MAX_NUM_TOKENS, - MODEL_FAMILY_MERLINITE, - MODEL_FAMILY_MIXTRAL, -) from instructlab.sdg.pipeline import ( FULL_PIPELINES_PACKAGE, SIMPLE_PIPELINES_PACKAGE, @@ -359,10 +354,7 @@ def generate_data( logger.debug(f"Generating to: {os.path.join(output_dir, output_file_test)}") - if models.get_model_family(model_family, model_name) == "mixtral": - model_family = MODEL_FAMILY_MIXTRAL - else: - model_family = MODEL_FAMILY_MERLINITE + model_family = models.get_model_family(model_family, model_name) ctx = _context_init( client, diff --git a/src/instructlab/sdg/importblock.py b/src/instructlab/sdg/importblock.py deleted file mode 100644 index e65a7f01..00000000 --- a/src/instructlab/sdg/importblock.py +++ /dev/null @@ -1,54 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -# Standard -import logging - -# Third Party -from datasets import Dataset - -# Local -from .block import Block - -logger = logging.getLogger(__name__) - - -# This is part of the public API. -class ImportBlock(Block): - def __init__( - self, - ctx, - pipe, - block_name, - path, - ) -> None: - """ - ImportBlock imports a chain of blocks from another pipeline config file. - - Parameters: - - ctx (PipelineContext): A PipelineContext object containing runtime parameters. - - pipe (Pipeline): The Pipeline containing this block in its chain. - - block_name (str): An identifier for this block. - - path (str): A path (absolute, or relative to the instructlab.sdg package) to a pipeline config file. - """ - super().__init__(ctx, pipe, block_name) - self.path = path - - # FIXME: find a better fix for this circular import error: - # - # src/instructlab/sdg/__init__.py:29: in - # from .importblock import ImportBlock - # src/instructlab/sdg/importblock.py:6: in - # from . import pipeline - # src/instructlab/sdg/pipeline.py:102: in - # "ImportBlock": importblock.ImportBlock, - # E AttributeError: partially initialized module 'src.instructlab.sdg.importblock' has no attribute 'ImportBlock' (most likely due to a circular import) - # - # pylint: disable=C0415 - # Local - from . import pipeline - - self.pipeline = pipeline.Pipeline.from_file(self.ctx, self.path) - - def generate(self, samples) -> Dataset: - logger.info("ImportBlock chaining to blocks from {self.path}") - return self.pipeline.generate(samples) diff --git a/src/instructlab/sdg/pipeline.py b/src/instructlab/sdg/pipeline.py index 52621f81..59613a8e 100644 --- a/src/instructlab/sdg/pipeline.py +++ b/src/instructlab/sdg/pipeline.py @@ -19,8 +19,9 @@ from instructlab.sdg.utils import pandas # Local -from . import filterblock, importblock, llmblock, utilblocks -from .block import Block +from .blocks import llmblock +from .blocks.block import Block +from .registry import BlockRegistry logger = logging.getLogger(__name__) @@ -255,25 +256,11 @@ def _get_batch_indices(self, batch_index: int, total_size: int) -> Iterable[int] ) -_block_types = { - "CombineColumnsBlock": utilblocks.CombineColumnsBlock, - "ConditionalLLMBlock": llmblock.ConditionalLLMBlock, - "DuplicateColumnsBlock": utilblocks.DuplicateColumnsBlock, - "FilterByValueBlock": filterblock.FilterByValueBlock, - "FlattenColumnsBlock": utilblocks.FlattenColumnsBlock, - "ImportBlock": importblock.ImportBlock, - "LLMBlock": llmblock.LLMBlock, - "RenameColumnsBlock": utilblocks.RenameColumnsBlock, - "SamplePopulatorBlock": utilblocks.SamplePopulatorBlock, - "SelectorBlock": utilblocks.SelectorBlock, - "SetToMajorityValueBlock": utilblocks.SetToMajorityValueBlock, -} - - def _lookup_block_type(block_type): - if not block_type in _block_types: + block_types = BlockRegistry.get_registry() + if not block_type in block_types: raise PipelineConfigParserError(f"Unknown block type {block_type}") - return _block_types[block_type] + return block_types[block_type] _PIPELINE_CONFIG_PARSER_MAJOR = 1 diff --git a/src/instructlab/sdg/pipelines/schema/v1.json b/src/instructlab/sdg/pipelines/schema/v1.json index 64b9c477..690d3d73 100644 --- a/src/instructlab/sdg/pipelines/schema/v1.json +++ b/src/instructlab/sdg/pipelines/schema/v1.json @@ -34,17 +34,6 @@ }, "config": { "anyOf": [ - { - "type": "object", - "description": "ImportBlock", - "required": ["path"], - "additionalProperties": false, - "properties": { - "path": { - "type": "string" - } - } - }, { "type": "object", "description": "FilterByValueBlock", diff --git a/src/instructlab/sdg/prompts.py b/src/instructlab/sdg/prompts.py new file mode 100644 index 00000000..77d35a53 --- /dev/null +++ b/src/instructlab/sdg/prompts.py @@ -0,0 +1,24 @@ +# Local +from .registry import PromptRegistry + +# {{ prompt }} gives us the config's raw prompt string, not wrapped in +# any messages format + +# {{ messages }} gives us the config's prompt in messages format, +# where the config's prompt becomes the content value of a user role +# message + + +@PromptRegistry.register("blank") +def blank_chat_template(): + return """{{ prompt }}""" + + +@PromptRegistry.register("merlinite", "granite") +def merlinite_chat_template(): + return """{% for message in messages %}{% if message['role'] == 'pretraining' %}{{ '<|pretrain|>' + message['content'] + '<|endoftext|>' + '<|/pretrain|>' }}{% elif message['role'] == 'system' %}{{ '<|system|>' + '\n' + message['content'] + '\n' }}{% elif message['role'] == 'user' %}{{ '<|user|>' + '\n' + message['content'] + '\n' }}{% elif message['role'] == 'assistant' %}{{ '<|assistant|>' + '\n' + message['content'] + '<|endoftext|>' + ('' if loop.last else '\n') }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<|assistant|>' + '\n' }}{% endif %}{% endfor %}""" + + +@PromptRegistry.register("mixtral", "mistral") +def mixtral_chat_template(): + return """{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.first and system_message is defined %}\n {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n {%- else %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + ''}}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n{%- endfor %}\n""" diff --git a/src/instructlab/sdg/registry.py b/src/instructlab/sdg/registry.py new file mode 100644 index 00000000..e71e0172 --- /dev/null +++ b/src/instructlab/sdg/registry.py @@ -0,0 +1,107 @@ +# Standard +from typing import Dict +import logging + +# Third Party +from jinja2 import Environment, StrictUndefined, Template + +logger = logging.getLogger(__name__) + + +class BlockRegistry: + """Registry for block classes to avoid manual additions to block type map.""" + + _registry: Dict[str, type] = {} + + @classmethod + def register(cls, block_name: str): + """ + Decorator to register a block class under a specified name. + + Args: + block_name (str): Name under which to register the block. + """ + + def decorator(block_class): + cls._registry[block_name] = block_class + logger.debug( + f"Registered block '{block_name}' with class '{block_class.__name__}'" + ) + return block_class + + return decorator + + @classmethod + def get_registry(cls): + """ + Retrieve the current registry map of block types. + + Returns: + Dictionary of registered block names and classes. + """ + return cls._registry + + +class PromptRegistry: + """Registry for managing Jinja2 prompt templates.""" + + _registry: Dict[str, Template] = {} + _template_env: Environment = Environment(undefined=StrictUndefined) + + @classmethod + def register(cls, *names: str): + """Decorator to register Jinja2 template functions by name. + + Args: + names (str): Names of the templates to register. + + Returns: + A decorator that registers the Jinja2 template functions. + """ + + def decorator(func): + template_str = func() + template = cls.template_from_string(template_str) + for name in names: + cls._registry[name] = template + logger.debug(f"Registered prompt template '{name}'") + return func + + return decorator + + @classmethod + def get_template(cls, name: str) -> Template: + """Retrieve a Jinja2 template by name. + + Args: + name (str): Name of the template to retrieve. + + Returns: + The Jinja2 template instance. + """ + if name not in cls._registry: + raise KeyError(f"Prompt template '{name}' not found.") + return cls._registry[name] + + @classmethod + def get_registry(cls): + """ + Retrieve the current registry map of block types. + + Returns: + Dictionary of registered block names and classes. + """ + return cls._registry + + @classmethod + def template_from_string(cls, template_str): + """ + Create a Jinja Template using our Environment from the source string + + Args: + template_str: The template source, as a string-like thing + + Returns: + Jinja Template + """ + return cls._template_env.from_string(template_str) diff --git a/src/instructlab/sdg/utils/models.py b/src/instructlab/sdg/utils/models.py index a2736988..da01421b 100644 --- a/src/instructlab/sdg/utils/models.py +++ b/src/instructlab/sdg/utils/models.py @@ -5,25 +5,22 @@ import re # First Party +from instructlab.sdg.registry import PromptRegistry from instructlab.sdg.utils import GenerateException # When otherwise unknown, ilab uses this as the default family DEFAULT_MODEL_FAMILY = "merlinite" -# Model families understood by ilab -MODEL_FAMILIES = set(("merlinite", "mixtral")) - -# Map model names to their family -MODEL_FAMILY_MAPPINGS = {"granite": "merlinite", "mistral": "mixtral"} - def get_model_family(model_family, model_path): - model_family_retrieved = MODEL_FAMILY_MAPPINGS.get(model_family, model_family) - if model_family_retrieved and model_family_retrieved.lower() not in MODEL_FAMILIES: - raise GenerateException("Unknown model family: %s" % model_family_retrieved) + registry = PromptRegistry.get_registry() + + # A model_family was given, so use it explicitly + if model_family: + if model_family not in registry: + raise GenerateException("Unknown model family: %s" % model_family) + return model_family # Try to guess the model family based on the model's filename guess = re.match(r"^\w*", os.path.basename(model_path)).group(0).lower() - guess = MODEL_FAMILY_MAPPINGS.get(guess, guess) - - return guess if guess in MODEL_FAMILIES else DEFAULT_MODEL_FAMILY + return guess if guess in registry else DEFAULT_MODEL_FAMILY diff --git a/tests/functional/test_custom_block.py b/tests/functional/test_custom_block.py new file mode 100644 index 00000000..01bcdb07 --- /dev/null +++ b/tests/functional/test_custom_block.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Standard +import pathlib +import subprocess +import sys + + +def test_custom_block(testdata_path: pathlib.Path): + script = testdata_path.joinpath("custom_block.py") + subprocess.check_call([sys.executable, str(script)], text=True) + + +def test_custom_prompt(testdata_path: pathlib.Path): + script = testdata_path.joinpath("custom_prompt.py") + subprocess.check_call([sys.executable, str(script)], text=True) diff --git a/tests/test_default_pipeline_configs.py b/tests/test_default_pipeline_configs.py index a774fb2d..8009874b 100644 --- a/tests/test_default_pipeline_configs.py +++ b/tests/test_default_pipeline_configs.py @@ -9,13 +9,15 @@ from datasets import Dataset # First Party -from instructlab.sdg.filterblock import FilterByValueBlock -from instructlab.sdg.llmblock import ConditionalLLMBlock, LLMBlock -from instructlab.sdg.pipeline import Pipeline, PipelineContext -from instructlab.sdg.utilblocks import ( +from instructlab.sdg import ( CombineColumnsBlock, + ConditionalLLMBlock, DuplicateColumnsBlock, + FilterByValueBlock, FlattenColumnsBlock, + LLMBlock, + Pipeline, + PipelineContext, RenameColumnsBlock, SamplePopulatorBlock, SelectorBlock, @@ -35,7 +37,7 @@ def _noop_generate(self, samples): @patch.object(RenameColumnsBlock, "generate", _noop_generate) @patch.object(SamplePopulatorBlock, "generate", _noop_generate) @patch.object(SelectorBlock, "generate", _noop_generate) -@patch("instructlab.sdg.llmblock.server_supports_batched", lambda c, m: True) +@patch("instructlab.sdg.blocks.llmblock.server_supports_batched", lambda c, m: True) @patch.object(Pipeline, "_drop_duplicates", lambda self, dataset, cols: dataset) class TestDefaultPipelineConfigs(unittest.TestCase): def setUp(self): diff --git a/tests/test_filterblock.py b/tests/test_filterblock.py index e84258ab..ab43890a 100644 --- a/tests/test_filterblock.py +++ b/tests/test_filterblock.py @@ -9,8 +9,7 @@ from datasets import Dataset, Features, Value # First Party -from instructlab.sdg.filterblock import FilterByValueBlock -from instructlab.sdg.pipeline import PipelineContext +from instructlab.sdg import FilterByValueBlock class TestFilterByValueBlock(unittest.TestCase): diff --git a/tests/test_generate_data.py b/tests/test_generate_data.py index d3f6ced5..c5415636 100644 --- a/tests/test_generate_data.py +++ b/tests/test_generate_data.py @@ -20,9 +20,8 @@ import yaml # First Party +from instructlab.sdg import LLMBlock, PipelineContext from instructlab.sdg.generate_data import _context_init, _sdg_init, generate_data -from instructlab.sdg.llmblock import LLMBlock -from instructlab.sdg.pipeline import PipelineContext TEST_SYS_PROMPT = "I am, Red Hat® Instruct Model based on Granite 7B, an AI language model developed by Red Hat and IBM Research, based on the Granite-7b-base language model. My primary function is to be a chat assistant." diff --git a/tests/test_importblock.py b/tests/test_importblock.py deleted file mode 100644 index d13d59cc..00000000 --- a/tests/test_importblock.py +++ /dev/null @@ -1,109 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -# Standard -from unittest.mock import MagicMock, patch -import os -import tempfile -import unittest - -# Third Party -from datasets import Dataset, Features, Value - -# First Party -from instructlab.sdg.importblock import ImportBlock -from instructlab.sdg.pipeline import Pipeline - -# Local -from .conftest import get_single_threaded_ctx - - -class TestImportBlockWithMockPipeline(unittest.TestCase): - @patch("instructlab.sdg.pipeline.Pipeline") - def setUp(self, mock_pipeline): - self.ctx = get_single_threaded_ctx() - self.pipe = MagicMock() - self.block_name = "test_block" - self.path = "/path/to/config" - self.mock_pipeline = mock_pipeline - self.import_block = ImportBlock(self.ctx, self.pipe, self.block_name, self.path) - self.dataset = Dataset.from_dict({}) - - def test_initialization(self): - self.assertEqual(self.import_block.block_name, self.block_name) - self.assertEqual(self.import_block.path, self.path) - self.mock_pipeline.from_file.assert_called_once_with(self.ctx, self.path) - - def test_generate(self): - self.mock_pipeline.from_file.return_value.generate.return_value = self.dataset - samples = self.import_block.generate(self.dataset) - self.mock_pipeline.from_file.return_value.generate.assert_called_once_with( - samples - ) - self.assertEqual(samples, self.dataset) - - -_CHILD_YAML = """\ -version: "1.0" -blocks: -- name: greater_than_thirty - type: FilterByValueBlock - config: - filter_column: age - filter_value: 30 - operation: gt - convert_dtype: int -""" - - -_PARENT_YAML_FMT = """\ -version: "1.0" -blocks: -- name: forty_or_under - type: FilterByValueBlock - config: - filter_column: age - filter_value: 40 - operation: le - convert_dtype: int - default_value: 1000 -- name: import_child - type: ImportBlock - config: - path: %s -- name: big_bdays - type: FilterByValueBlock - config: - filter_column: age - filter_value: - - 30 - - 40 - operation: eq - convert_dtype: int -""" - - -class TestImportBlockWithFilterByValue(unittest.TestCase): - def setUp(self): - self.ctx = get_single_threaded_ctx() - self.child_yaml = self._write_tmp_yaml(_CHILD_YAML) - self.parent_yaml = self._write_tmp_yaml(_PARENT_YAML_FMT % self.child_yaml) - self.dataset = Dataset.from_dict( - {"age": ["25", "30", "35", "40", "45"]}, - features=Features({"age": Value("string")}), - ) - - def tearDown(self): - os.remove(self.parent_yaml) - os.remove(self.child_yaml) - - def _write_tmp_yaml(self, content): - tmp_file = tempfile.NamedTemporaryFile(delete=False, mode="w", suffix=".yaml") - tmp_file.write(content) - tmp_file.close() - return tmp_file.name - - def test_generate(self): - pipeline = Pipeline.from_file(self.ctx, self.parent_yaml) - filtered_dataset = pipeline.generate(self.dataset) - self.assertEqual(len(filtered_dataset), 1) - self.assertEqual(filtered_dataset["age"], [40]) diff --git a/tests/test_iterblock.py b/tests/test_iterblock.py new file mode 100644 index 00000000..b7561b24 --- /dev/null +++ b/tests/test_iterblock.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Standard +from unittest.mock import MagicMock +import unittest + +# Third Party +from datasets import Dataset, Features, Value + +# First Party +from instructlab.sdg import Block, BlockRegistry, IterBlock + + +class TestIterBlock(unittest.TestCase): + @BlockRegistry.register("TestCounterBlock") + class TestCounterBlock(Block): + def __init__( + self, + ctx, + pipe, + block_name, + column, + increment=1, + ) -> None: + super().__init__(ctx, pipe, block_name) + self.column = column + self.increment = increment + self.counter = 0 + + def generate(self, samples: Dataset): + samples = samples.map( + lambda x: {self.column: x[self.column] + self.counter} + ) + self.counter += self.increment + return samples + + def setUp(self): + self.ctx = MagicMock() + self.ctx.dataset_num_procs = 1 + self.pipe = MagicMock() + self.block = IterBlock( + self.ctx, + self.pipe, + "iter_test", + num_iters=4, + block_type="TestCounterBlock", + column="counter", + ) + self.dataset = Dataset.from_dict( + {"counter": [0]}, + features=Features({"counter": Value("int32")}), + ) + + def test_simple_iterate(self): + iterated_dataset = self.block.generate(self.dataset) + # We iterated 4 times, so 4 items in our dataset - one from + # each iteration + self.assertEqual(len(iterated_dataset), 4) + # Each iteration increment our counter, because of our custom + # block used that just increments counters + self.assertEqual(iterated_dataset["counter"], [0, 1, 2, 3]) diff --git a/tests/test_llmblock.py b/tests/test_llmblock.py index 613ea846..f9d78177 100644 --- a/tests/test_llmblock.py +++ b/tests/test_llmblock.py @@ -1,18 +1,29 @@ # SPDX-License-Identifier: Apache-2.0 # Standard +from importlib import resources from unittest.mock import MagicMock, patch +import os import unittest # Third Party from datasets import Dataset, Features, Value from httpx import URL -from openai import InternalServerError, NotFoundError, OpenAI +from openai import InternalServerError, NotFoundError +import pytest # First Party -from src.instructlab.sdg.llmblock import LLMBlock, server_supports_batched +from src.instructlab.sdg import ( + BlockConfigParserError, + ConditionalLLMBlock, + LLMBlock, + LLMLogProbBlock, + LLMMessagesBlock, +) +from src.instructlab.sdg.blocks.llmblock import server_supports_batched +@patch("src.instructlab.sdg.blocks.block.Block._load_config") class TestLLMBlockModelPrompt(unittest.TestCase): def setUp(self): self.mock_ctx = MagicMock() @@ -20,7 +31,7 @@ def setUp(self): self.mock_ctx.model_id = "test_model" self.mock_pipe = MagicMock() self.config_return_value = { - "system": "{fruit}", + "system": "{{fruit}}", "introduction": "introduction", "principles": "principles", "examples": "examples", @@ -31,10 +42,9 @@ def setUp(self): features=Features({"fruit": Value("string")}), ) - @patch("src.instructlab.sdg.block.Block._load_config") def test_model_prompt_empty_string(self, mock_load_config): mock_load_config.return_value = self.config_return_value - # Ensure that if an empty model_prompt is not specified, no model prompt is used. + # Ensure that if an empty model_prompt is specified, no model prompt is used. block = LLMBlock( ctx=self.mock_ctx, pipe=self.mock_pipe, @@ -50,7 +60,6 @@ def test_model_prompt_empty_string(self, mock_load_config): "no model prompt should be used when explicitly set to an empty string", ) - @patch("src.instructlab.sdg.block.Block._load_config") def test_model_prompt_none(self, mock_load_config): mock_load_config.return_value = self.config_return_value # Ensure that if a custom model_prompt is not specified, it defaults to setting it to @@ -70,8 +79,7 @@ def test_model_prompt_none(self, mock_load_config): "model_prompt based on model_family should be used set to None", ) - @patch("src.instructlab.sdg.block.Block._load_config") - def test_model_prompt_none(self, mock_load_config): + def test_model_prompt_custom(self, mock_load_config): mock_load_config.return_value = self.config_return_value # Ensure that if a custom model_prompt is specified, it is used correctly block = LLMBlock( @@ -80,20 +88,88 @@ def test_model_prompt_none(self, mock_load_config): block_name="test_block", config_path="", output_cols=[], - model_prompt="FOO {prompt} BAR", + model_prompt="FOO {{prompt}} BAR", ) prompt = block._format_prompt(self.dataset[1]) self.assertEqual( prompt, "FOO pear\nintroduction\nprinciples\nexamples\ngeneration BAR", - "model_prompt should be a non-empty string when set to None", + "custom model_prompt was not used when explicitly set", + ) + + +class TestLLMBlockWithRealConfigs(unittest.TestCase): + def setUp(self): + self.mock_ctx = MagicMock() + self.mock_ctx.model_family = "mixtral" + self.mock_ctx.model_id = "test_model" + self.mock_pipe = MagicMock() + + def test_configs_with_invalid_sample(self): + for config_type in ["knowledge", "skills"]: + for config_yaml in resources.files( + f"instructlab.sdg.configs.{config_type}" + ).iterdir(): + if config_yaml.suffix != ".yaml": + continue + block = LLMBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name=config_yaml.stem, + config_path=config_yaml, + output_cols=[], + ) + sample = {"foo": "bar"} + assert not block._validate( + block.prompt_template, sample + ), f"{config_type} config {config_yaml.name} validated even though it was given a sample with none of the expected fields" + + def test_simple_generate_qa_with_valid_sample(self): + config_yaml = os.path.join( + resources.files("instructlab.sdg.configs.knowledge"), + "simple_generate_qa.yaml", + ) + block = LLMBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + config_path=config_yaml, + output_cols=[], ) + sample = { + "domain": "domain goes here", + "document": "document goes here", + "document_outline": "document outline goes here", + "icl_document": "context goes here", + "icl_query_1": "query 1 goes here", + "icl_response_1": "response 1 goes here", + "icl_query_2": "query 2 goes here", + "icl_response_2": "response 2 goes here", + "icl_query_3": "query 3 goes here", + "icl_response_3": "response 3 goes here", + } + assert block._validate(block.prompt_template, sample) + + +@patch("src.instructlab.sdg.blocks.block.Block._load_config") +class TestLLMBlockOtherFunctions(unittest.TestCase): + def setUp(self): + self.mock_ctx = MagicMock() + self.mock_ctx.model_family = "mixtral" + self.mock_ctx.model_id = "test_model" + self.mock_pipe = MagicMock() + self.config_return_value = { + "system": "{{fruit}}", + "introduction": "introduction", + "principles": "principles", + "examples": "examples", + "generation": "generation", + } - @patch("src.instructlab.sdg.block.Block._load_config") def test_max_num_tokens_override(self, mock_load_config): mock_load_config.return_value = self.config_return_value self.mock_ctx.max_num_tokens = 512 - # Ensure that if a custom model_prompt is specified, it is used correctly + # Ensure that if max_tokens is specified, it is used correctly block = LLMBlock( ctx=self.mock_ctx, pipe=self.mock_pipe, @@ -106,26 +182,49 @@ def test_max_num_tokens_override(self, mock_load_config): num_tokens = block.gen_kwargs["max_tokens"] assert num_tokens == 512 + def test_validate(self, mock_load_config): + mock_load_config.return_value = { + "system": "{{var1}} {{var2}}", + "introduction": "introduction", + "principles": "principles", + "examples": "examples", + "generation": "generation", + } + block = LLMBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + config_path="", + output_cols=[], + ) + + assert not block._validate(block.prompt_template, {}) + assert block._validate(block.prompt_template, {"var1": "foo", "var2": "bar"}) + + +class TestLLMBlockBatching(unittest.TestCase): + def setUp(self): + self.mock_ctx = MagicMock() + self.mock_ctx.model_family = "mixtral" + self.mock_ctx.model_id = "test_model" + self.mock_pipe = MagicMock() + self.mock_client = MagicMock() + self.mock_client.server_supports_batched = None + self.mock_client.base_url = URL("http://localhost:8000/v1") + self.mock_client.get = MagicMock() + self.mock_ctx.client = self.mock_client + def test_server_supports_batched_llama_cpp(self): resp_text = """{"message":"Hello from InstructLab! Visit us at https://instructlab.ai"}""" - mock_client = MagicMock() - mock_client.server_supports_batched = None - mock_client.base_url = URL("http://localhost:8000/v1") - mock_client.get = MagicMock() - mock_client.get.return_value = MagicMock() - mock_client.get().text = resp_text - self.mock_ctx.client = mock_client + self.mock_client.get.return_value = MagicMock() + self.mock_client.get().text = resp_text supports_batched = server_supports_batched(self.mock_ctx.client, "my-model") assert not supports_batched def test_server_supports_batched_other_llama_cpp(self): resp_text = "another server" - mock_client = MagicMock() - mock_client.server_supports_batched = None - mock_client.base_url = URL("http://localhost:8000/v1") - mock_client.get = MagicMock() - mock_client.get.return_value = MagicMock() - mock_client.get().text = resp_text + self.mock_client.get.return_value = MagicMock() + self.mock_client.get().text = resp_text mock_completion = MagicMock() mock_completion.create = MagicMock() mock_completion.create.side_effect = InternalServerError( @@ -133,17 +232,12 @@ def test_server_supports_batched_other_llama_cpp(self): response=MagicMock(), body=MagicMock(), ) - mock_client.completions = mock_completion - self.mock_ctx.client = mock_client + self.mock_client.completions = mock_completion supports_batched = server_supports_batched(self.mock_ctx.client, "my-model") assert not supports_batched def test_server_supports_batched_vllm(self): - mock_client = MagicMock() - mock_client.server_supports_batched = None - mock_client.base_url = URL("http://localhost:8000/v1") - mock_client.get = MagicMock() - mock_client.get.side_effect = NotFoundError( + self.mock_client.get.side_effect = NotFoundError( "mock error", response=MagicMock(), body=MagicMock(), @@ -153,7 +247,117 @@ def test_server_supports_batched_vllm(self): mock_completion = MagicMock() mock_completion.create = MagicMock() mock_completion.create.return_value = mock_completion_resp - mock_client.completions = mock_completion - self.mock_ctx.client = mock_client + self.mock_client.completions = mock_completion supports_batched = server_supports_batched(self.mock_ctx.client, "my-model") assert supports_batched + + +@patch("src.instructlab.sdg.blocks.block.Block._load_config") +class TestConditionalLLMBlock(unittest.TestCase): + def setUp(self): + self.mock_ctx = MagicMock() + self.mock_ctx.model_family = "mixtral" + self.mock_ctx.model_id = "test_model" + self.mock_pipe = MagicMock() + + def test_invalid_config_paths(self, _mock_load_config): + with pytest.raises(BlockConfigParserError) as exc: + ConditionalLLMBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + config_paths=[], + output_cols=[], + selector_column_name="selector", + ) + assert "at least one entry" in str(exc.value) + + with pytest.raises(BlockConfigParserError) as exc: + ConditionalLLMBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + config_paths=[["foo"]], + output_cols=[], + selector_column_name="selector", + ) + assert "config path and selector" in str(exc.value) + + def test_validate(self, mock_load_config): + mock_load_config.return_value = { + "system": "{{var1}} {{var2}}", + "introduction": "introduction", + "principles": "principles", + "examples": "examples", + "generation": "generation", + } + block = ConditionalLLMBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + config_paths=[["/foo/bar", "_A_"]], + output_cols=[], + selector_column_name="selector", + ) + + assert not block._validate(block.prompt_template, {}) + assert not block._validate( + block.prompt_template, {"selector": "_B_", "var1": "foo", "var2": "bar"} + ) + assert block._validate( + block.prompt_template, {"selector": "_A_", "var1": "foo", "var2": "bar"} + ) + + +@patch("src.instructlab.sdg.blocks.block.Block._load_config") +class TestLLMLogProbBlock(unittest.TestCase): + def setUp(self): + self.mock_ctx = MagicMock() + self.mock_ctx.model_family = "mixtral" + self.mock_ctx.model_id = "test_model" + self.mock_pipe = MagicMock() + self.config_return_value = { + "system": "{{fruit}}", + "introduction": "introduction", + "principles": "principles", + "examples": "examples", + "generation": "generation", + } + + def test_constructor_works(self, mock_load_config): + mock_load_config.return_value = self.config_return_value + block = LLMLogProbBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + config_path="", + output_cols=[], + ) + assert block is not None + + +@patch("src.instructlab.sdg.blocks.block.Block._load_config") +class TestLLMMessagesBlock(unittest.TestCase): + def setUp(self): + self.mock_ctx = MagicMock() + self.mock_ctx.model_family = "mixtral" + self.mock_ctx.model_id = "test_model" + self.mock_pipe = MagicMock() + self.config_return_value = { + "system": "{{fruit}}", + "introduction": "introduction", + "principles": "principles", + "examples": "examples", + "generation": "generation", + } + + def test_constructor_works(self, mock_load_config): + mock_load_config.return_value = self.config_return_value + block = LLMMessagesBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + config_path="", + output_cols=[], + ) + assert block is not None diff --git a/tests/test_models.py b/tests/test_models.py index 27cf4d68..e3755553 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -13,7 +13,7 @@ class TestModels: def test_granite_model_family(self): assert ( models.get_model_family("granite", "./models/granite-7b-lab-Q4_K_M.gguf") - == "merlinite" + == "granite" ) def test_merlinite_model_family(self): @@ -32,12 +32,30 @@ def test_mixtral_model_family(self): == "mixtral" ) + def test_mistral_model_family(self): + assert ( + models.get_model_family( + "mistral", "./models/mistral-7b-instruct-v0.2.Q4_K_M.gguf" + ) + == "mistral" + ) + def test_default_model_family(self): + assert ( + models.get_model_family(None, "./models/foo-8x7b-instruct-v0.1.Q4_K_M.gguf") + == "merlinite" + ) + assert ( + models.get_model_family("", "./models/foo-8x7b-instruct-v0.1.Q4_K_M.gguf") + == "merlinite" + ) + + def test_model_family_overrides(self): assert ( models.get_model_family( "mixtral", "./models/foo-8x7b-instruct-v0.1.Q4_K_M.gguf" ) - == "merlinite" + == "mixtral" ) def test_unknown_model_family(self): diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 7367161c..6c801845 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -14,17 +14,18 @@ import pytest # First Party -from instructlab.sdg.block import Block -from instructlab.sdg.pipeline import Pipeline, PipelineBlockError +from instructlab.sdg import Block, Pipeline, PipelineBlockError ## Helpers ## @contextmanager def block_types(block_types_dict): + get_registry_mock = mock.MagicMock() + get_registry_mock.return_value = block_types_dict with mock.patch( - "instructlab.sdg.pipeline._block_types", - block_types_dict, + "instructlab.sdg.registry.BlockRegistry.get_registry", + get_registry_mock, ): yield diff --git a/tests/test_registry.py b/tests/test_registry.py new file mode 100644 index 00000000..0f197a6d --- /dev/null +++ b/tests/test_registry.py @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 + +# First Party +from src.instructlab.sdg.registry import BlockRegistry + + +def test_block_registry(): + @BlockRegistry.register("TestFooClass") + class TestFooClass: + pass + + registry = BlockRegistry.get_registry() + assert registry is not None + assert registry["TestFooClass"] is TestFooClass diff --git a/tests/test_sample_populator_block.py b/tests/test_sample_populator_block.py index e19eed31..f363b9d0 100644 --- a/tests/test_sample_populator_block.py +++ b/tests/test_sample_populator_block.py @@ -8,7 +8,7 @@ from datasets import Dataset, Features, Value # First Party -from instructlab.sdg.utilblocks import SamplePopulatorBlock +from instructlab.sdg import SamplePopulatorBlock class TestSamplePopulatorBlock(unittest.TestCase): @@ -17,7 +17,7 @@ def setUp(self): self.ctx.dataset_num_procs = 1 self.pipe = MagicMock() - @patch("instructlab.sdg.block.Block._load_config") + @patch("instructlab.sdg.blocks.block.Block._load_config") def test_generate(self, mock_load_config): def load_config(file_name): if file_name == "coffee.yaml" or file_name == "tea.yaml": diff --git a/tests/test_utilblocks.py b/tests/test_utilblocks.py index 5ac6233f..a8849d6a 100644 --- a/tests/test_utilblocks.py +++ b/tests/test_utilblocks.py @@ -8,7 +8,7 @@ from datasets import Dataset, Features, Value # First Party -from src.instructlab.sdg.utilblocks import ( +from src.instructlab.sdg import ( DuplicateColumnsBlock, FlattenColumnsBlock, RenameColumnsBlock, diff --git a/tests/testdata/custom_block.py b/tests/testdata/custom_block.py new file mode 100644 index 00000000..409d4ebc --- /dev/null +++ b/tests/testdata/custom_block.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Standard +import pathlib + +# Third Party +from datasets import Dataset + +# First Party +from instructlab.sdg import Block, BlockRegistry, Pipeline, PipelineContext + + +@BlockRegistry.register("EchoBlock") +class EchoBlock(Block): + def generate(self, samples: Dataset): + return samples + + +pipeline_context = PipelineContext(None, "mixtral", "my_model", 5) +pipeline_yaml = pathlib.Path(__file__).parent.joinpath("custom_block_pipeline.yaml") +pipeline = Pipeline.from_file(pipeline_context, pipeline_yaml) +input_ds = Dataset.from_list( + [ + { + "fruit": "apple", + "color": "red", + } + ] +) +output_ds = pipeline.generate(input_ds) +assert len(output_ds) == 1 +assert output_ds[0]["fruit"] == "apple" +assert output_ds[0]["color"] == "red" diff --git a/tests/testdata/custom_block_pipeline.yaml b/tests/testdata/custom_block_pipeline.yaml new file mode 100644 index 00000000..0f4c447e --- /dev/null +++ b/tests/testdata/custom_block_pipeline.yaml @@ -0,0 +1,5 @@ +version: "1.0" +blocks: + - name: echo + type: EchoBlock + config: {} diff --git a/tests/testdata/custom_prompt.py b/tests/testdata/custom_prompt.py new file mode 100644 index 00000000..ea5f06c4 --- /dev/null +++ b/tests/testdata/custom_prompt.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: Apache-2.0 + +# First Party +from instructlab.sdg import PromptRegistry + + +# Register our custom chat template under the "custom_model_family" +# model family +@PromptRegistry.register("custom_model_family") +def custom_chat_template(): + return """{% for message in messages %}{% if message['role'] == 'system' %}{{ '<>' + '\n' + message['content'] + '\n' }}{% elif message['role'] == 'user' %}{{ '<>' + '\n' + message['content'] + '\n' }}{% elif message['role'] == 'assistant' %}{{ '<>' + '\n' + message['content'] + ('' if loop.last else '\n') }}{% endif %}{% endfor %}""" + + +# Lookup the chat template for "custom_model_family" model family +template = PromptRegistry.get_template("custom_model_family") +assert template is not None + +# Ensure the template found is our custom one +prompt = template.render( + messages=[ + {"role": "system", "content": "system prompt goes here"}, + {"role": "user", "content": "user content goes here"}, + ] +) +expected_prompt = ( + "<>\nsystem prompt goes here\n<>\nuser content goes here\n" +) +assert prompt == expected_prompt