Skip to content

Commit

Permalink
Merge pull request #409 from bbrowning/research-sync
Browse files Browse the repository at this point in the history
Reconcile core data generation features with latest research advances
  • Loading branch information
bbrowning authored Dec 10, 2024
2 parents 74d1bbe + db3a1ad commit fd53dcd
Show file tree
Hide file tree
Showing 52 changed files with 1,060 additions and 432 deletions.
7 changes: 7 additions & 0 deletions .spellcheck-en-custom.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
# SPDX-License-Identifier: Apache-2.0
Backport
backported
CLI
codebase
config
configs
Dataset
dataset
Expand All @@ -17,6 +19,8 @@ FIXME
freeform
ICL
icl
ie
Jinja
JSON
Langchain's
LLM
Expand All @@ -39,11 +43,14 @@ Splitter
subdirectory
subfolder
Tatsu
templating
Tesseract
TODO
tokenizer
tokenizers
unchunked
upsampled
UUID
vLLM
yaml
yamls
24 changes: 24 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,27 @@
## Unreleased 0.7.x

### Features

#### Custom Blocks and Teacher Models via BlockRegistry and PromptRegistry

Advanced users are now able to supply custom Pipeline `Block` implementations by registering new blocks with the `BlockRegistry`. It's also possible to register new chat templates for custom teacher models using the new `PromptRegistry`.

See the `tests/testdata/custom_block.py` and `tests/testdata/custom_block_pipeline.yaml` files in this repository for an example of how to create custom blocks and use them from your own pipeline config yamls.

See the `tests/testdata/custom_prompt.py` file in this repository for an example how to register custom chat templates used when formatting prompts.

### Breaking Changes

#### Pipeline configs and Prompt templates switched to Jinja

All of our [Pipeline config yamls](src/instructlab/sdg/pipelines) and [prompt template files](src/instructlab/sdg/configs) have moved to [Jinja templates](https://pypi.org/project/Jinja2/) instead of Python string `format()` calls. This brings more expressiveness into our templating language - especially for prompt templates - but does mean any variable substitutions need to be updated from single brackets to double brackets - ie `{document}` becomes `{{document}}`. This only impacts you if you were using custom pipeline config yaml files or custom prompt templates in your config blocks.

#### ImportBlock removed from Pipeline blocks

Any users that were specifying custom pipeline configs (instead of using the default `full` or `simple` shipped by us) and also using the `ImportBlock` will now need to rewrite their pipelines to no longer use that block. We do not anticipate that anyone was actually using this block, but please reach out if you were so we can capture your needs in a future release.

### Fixes

## v0.6.2

### Fixes
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ exclude = [
"^src/instructlab/sdg/generate_data\\.py$",
"^src/instructlab/sdg/utils/taxonomy\\.py$",
"^src/instructlab/sdg/default_flows\\.py$",
"^src/instructlab/sdg/llmblock\\.py$",
"^src/instructlab/sdg/utilblocks\\.py$",
"^src/instructlab/sdg/blocks/llmblock\\.py$",
"^src/instructlab/sdg/blocks/utilblocks\\.py$",
]
# honor excludes by not following there through imports
follow_imports = "silent"
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ GitPython>=3.1.42,<4.0.0
gguf>=0.6.0
httpx>=0.25.0,<1.0.0
instructlab-schema>=0.4.0
jinja2>=3.0.0
langchain-text-splitters
# Note: this dependency goes along with langchain-text-splitters and may be
# removed once that one is removed.
Expand Down
41 changes: 26 additions & 15 deletions src/instructlab/sdg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# NOTE: This package imports Torch and other heavy packages.
__all__ = (
"Block",
"BlockConfigParserError",
"BlockRegistry",
"CombineColumnsBlock",
"ConditionalLLMBlock",
"DuplicateColumnsBlock",
Expand All @@ -11,27 +13,44 @@
"FilterByValueBlockError",
"FlattenColumnsBlock",
"GenerateException",
"ImportBlock",
"IterBlock",
"LLMBlock",
"LLMLogProbBlock",
"LLMMessagesBlock",
"Pipeline",
"PipelineBlockError",
"PipelineConfigParserError",
"PipelineContext",
"PromptRegistry",
"RenameColumnsBlock",
"SamplePopulatorBlock",
"SelectorBlock",
"SetToMajorityValueBlock",
"SIMPLE_PIPELINES_PACKAGE",
"FULL_PIPELINES_PACKAGE",
"SIMPLE_PIPELINES_PACKAGE",
"generate_data",
)

# Local
from .block import Block
from .filterblock import FilterByValueBlock, FilterByValueBlockError
from .blocks.block import Block, BlockConfigParserError
from .blocks.filterblock import FilterByValueBlock, FilterByValueBlockError
from .blocks.iterblock import IterBlock
from .blocks.llmblock import (
ConditionalLLMBlock,
LLMBlock,
LLMLogProbBlock,
LLMMessagesBlock,
)
from .blocks.utilblocks import (
CombineColumnsBlock,
DuplicateColumnsBlock,
FlattenColumnsBlock,
RenameColumnsBlock,
SamplePopulatorBlock,
SelectorBlock,
SetToMajorityValueBlock,
)
from .generate_data import generate_data
from .importblock import ImportBlock
from .llmblock import ConditionalLLMBlock, LLMBlock
from .pipeline import (
FULL_PIPELINES_PACKAGE,
SIMPLE_PIPELINES_PACKAGE,
Expand All @@ -41,14 +60,6 @@
PipelineConfigParserError,
PipelineContext,
)
from .utilblocks import (
CombineColumnsBlock,
DuplicateColumnsBlock,
FlattenColumnsBlock,
RenameColumnsBlock,
SamplePopulatorBlock,
SelectorBlock,
SetToMajorityValueBlock,
)
from .registry import BlockRegistry, PromptRegistry
from .utils import GenerateException
from .utils.taxonomy import TaxonomyReadingException
37 changes: 0 additions & 37 deletions src/instructlab/sdg/block.py

This file was deleted.

Empty file.
75 changes: 75 additions & 0 deletions src/instructlab/sdg/blocks/block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# SPDX-License-Identifier: Apache-2.0

# Standard
from abc import ABC
from typing import Any, Dict, Union
import logging
import os.path

# Third Party
from jinja2 import Template, UndefinedError
import yaml

# Local
from ..registry import BlockRegistry

logger = logging.getLogger(__name__)


# This is part of the public API.
@BlockRegistry.register("Block")
class Block(ABC):
def __init__(self, ctx, pipe, block_name: str) -> None:
self.ctx = ctx
self.pipe = pipe
self.block_name = block_name

def _validate(self, prompt_template: Template, input_dict: Dict[str, Any]) -> bool:
"""
Validate the input data for this block. This method validates whether all required
variables in the Jinja template are provided in the input_dict.
Args:
prompt_template (Template): The Jinja2 template object.
input_dict (Dict[str, Any]): A dictionary of input values to check against
the template.
Returns:
True if the input data is valid (i.e., no missing variables), False otherwise.
"""

try:
# Try rendering the template with the input_dict
prompt_template.render(input_dict)
return True
except UndefinedError as e:
# Jinja throws an UndefinedError for any undefnined template variables,
# assuming the prompt_template was created using StrictUndefined. This
# is the case for anything using PromptRegistry.template_from_string.
logger.error(f"Missing key: {e}")
return False

def _load_config(self, config_path: str) -> Union[Dict[str, Any], None]:
"""
Load the configuration file for this block.
If the supplied configuration file is a relative path, it is assumed
to be part of this Python package.
Args:
config_path (str): The path to the configuration file.
Returns:
The loaded configuration.
"""
if not os.path.isabs(config_path):
config_path = os.path.join(
os.path.dirname(self.pipe.config_path), config_path
)
with open(config_path, "r", encoding="utf-8") as config_file:
return yaml.safe_load(config_file)


# This is part of the public API.
class BlockConfigParserError(Exception):
"""An exception raised while parsing a block's configuration."""
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from datasets import Dataset

# Local
from ..registry import BlockRegistry
from .block import Block

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -86,6 +87,7 @@ def convert_column(sample):


# This is part of the public API.
@BlockRegistry.register("FilterByValueBlock")
class FilterByValueBlock(Block):
def __init__(
self,
Expand Down
57 changes: 57 additions & 0 deletions src/instructlab/sdg/blocks/iterblock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0

# Standard
import logging

# Third Party
from datasets import Dataset

# Local
from ..pipeline import _lookup_block_type
from ..registry import BlockRegistry
from .block import Block

logger = logging.getLogger(__name__)


# This is part of the public API.
@BlockRegistry.register("IterBlock")
class IterBlock(Block):
"""
Call another block multiple times for a single set of input
samples, concatening the results of each iteration's call to that
other block in the final returned output.
Args:
num_iters: The number of times to iterate over the block
block_type: The type of the other block to call (ie LLMBlock)
block_config: Any necessary configuration that will get passed to the
other block to properly configure it.
Returns:
A Dataset containing all output samples from each iteration
"""

def __init__(
self,
ctx,
pipe,
block_name,
num_iters,
block_type,
**block_config,
) -> None:
super().__init__(ctx, pipe, block_name)
self.num_iters = num_iters
block_type = _lookup_block_type(block_type)
self.block = block_type(ctx, pipe, block_name, **block_config)

def generate(self, samples: Dataset) -> Dataset:
generated_samples = []
num_iters = self.num_iters

for _ in range(num_iters):
batch_generated = self.block.generate(samples)
generated_samples.extend(batch_generated)

return Dataset.from_list(generated_samples)
Loading

0 comments on commit fd53dcd

Please sign in to comment.