Skip to content

Commit

Permalink
Port the new IterBlock from research code
Browse files Browse the repository at this point in the history
This adds a new Block type - `IterBlock` - that calls another block N
times for a set of given input samples. Every iteration through the
loop, the samples returned from the child block's `generate` call get
added to the list of samples produced from this block.

So, if you use an `IterBlock` to call an `LLMBlock` 5 times, you'll get
5 samples generated (and 5 calls to the LLM) for every sample in the
source dataset. The output dataset will contain all 5 generated samples
resulting from each 1 input sample.

Signed-off-by: Ben Browning <[email protected]>
  • Loading branch information
bbrowning committed Nov 27, 2024
1 parent 1007e8e commit 8b0d52a
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/instructlab/sdg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"FilterByValueBlockError",
"FlattenColumnsBlock",
"GenerateException",
"IterBlock",
"LLMBlock",
"LLMLogProbBlock",
"LLMMessagesBlock",
Expand All @@ -32,6 +33,7 @@
# Local
from .blocks.block import Block
from .blocks.filterblock import FilterByValueBlock, FilterByValueBlockError
from .blocks.iterblock import IterBlock
from .blocks.llmblock import (
ConditionalLLMBlock,
LLMBlock,
Expand Down
42 changes: 42 additions & 0 deletions src/instructlab/sdg/blocks/iterblock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0

# Standard
import logging

# Third Party
from datasets import Dataset

# Local
from ..pipeline import _lookup_block_type
from ..registry import BlockRegistry
from .block import Block

logger = logging.getLogger(__name__)


# This is part of the public API.
@BlockRegistry.register("IterBlock")
class IterBlock(Block):
def __init__(
self,
ctx,
pipe,
block_name,
num_iters,
block_type,
**block_config,
) -> None:
super().__init__(ctx, pipe, block_name)
self.num_iters = num_iters
block_type = _lookup_block_type(block_type)
self.block = block_type(ctx, pipe, block_name, **block_config)

def generate(self, samples: Dataset) -> Dataset:
generated_samples = []
num_iters = self.num_iters

for _ in range(num_iters):
batch_generated = self.block.generate(samples)
generated_samples.extend(batch_generated)

return Dataset.from_list(generated_samples)
61 changes: 61 additions & 0 deletions tests/test_iterblock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0

# Standard
from unittest.mock import MagicMock
import unittest

# Third Party
from datasets import Dataset, Features, Value

# First Party
from instructlab.sdg import Block, BlockRegistry, IterBlock


class TestIterBlock(unittest.TestCase):
@BlockRegistry.register("TestCounterBlock")
class TestCounterBlock(Block):
def __init__(
self,
ctx,
pipe,
block_name,
column,
increment=1,
) -> None:
super().__init__(ctx, pipe, block_name)
self.column = column
self.increment = increment
self.counter = 0

def generate(self, samples: Dataset):
samples = samples.map(
lambda x: {self.column: x[self.column] + self.counter}
)
self.counter += self.increment
return samples

def setUp(self):
self.ctx = MagicMock()
self.ctx.dataset_num_procs = 1
self.pipe = MagicMock()
self.block = IterBlock(
self.ctx,
self.pipe,
"iter_test",
num_iters=4,
block_type="TestCounterBlock",
column="counter",
)
self.dataset = Dataset.from_dict(
{"counter": [0]},
features=Features({"counter": Value("int32")}),
)

def test_simple_iterate(self):
iterated_dataset = self.block.generate(self.dataset)
# We iterated 4 times, so 4 items in our dataset - one from
# each iteration
self.assertEqual(len(iterated_dataset), 4)
# Each iteration increment our counter, because of our custom
# block used that just increments counters
self.assertEqual(iterated_dataset["counter"], [0, 1, 2, 3])

0 comments on commit 8b0d52a

Please sign in to comment.