-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Port the new IterBlock from research code
This adds a new Block type - `IterBlock` - that calls another block N times for a set of given input samples. Every iteration through the loop, the samples returned from the child block's `generate` call get added to the list of samples produced from this block. So, if you use an `IterBlock` to call an `LLMBlock` 5 times, you'll get 5 samples generated (and 5 calls to the LLM) for every sample in the source dataset. The output dataset will contain all 5 generated samples resulting from each 1 input sample. Signed-off-by: Ben Browning <[email protected]>
- Loading branch information
Showing
3 changed files
with
105 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# Standard | ||
import logging | ||
|
||
# Third Party | ||
from datasets import Dataset | ||
|
||
# Local | ||
from ..pipeline import _lookup_block_type | ||
from ..registry import BlockRegistry | ||
from .block import Block | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
# This is part of the public API. | ||
@BlockRegistry.register("IterBlock") | ||
class IterBlock(Block): | ||
def __init__( | ||
self, | ||
ctx, | ||
pipe, | ||
block_name, | ||
num_iters, | ||
block_type, | ||
**block_config, | ||
) -> None: | ||
super().__init__(ctx, pipe, block_name) | ||
self.num_iters = num_iters | ||
block_type = _lookup_block_type(block_type) | ||
self.block = block_type(ctx, pipe, block_name, **block_config) | ||
|
||
def generate(self, samples: Dataset) -> Dataset: | ||
generated_samples = [] | ||
num_iters = self.num_iters | ||
|
||
for _ in range(num_iters): | ||
batch_generated = self.block.generate(samples) | ||
generated_samples.extend(batch_generated) | ||
|
||
return Dataset.from_list(generated_samples) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# Standard | ||
from unittest.mock import MagicMock | ||
import unittest | ||
|
||
# Third Party | ||
from datasets import Dataset, Features, Value | ||
|
||
# First Party | ||
from instructlab.sdg import Block, BlockRegistry, IterBlock | ||
|
||
|
||
class TestIterBlock(unittest.TestCase): | ||
@BlockRegistry.register("TestCounterBlock") | ||
class TestCounterBlock(Block): | ||
def __init__( | ||
self, | ||
ctx, | ||
pipe, | ||
block_name, | ||
column, | ||
increment=1, | ||
) -> None: | ||
super().__init__(ctx, pipe, block_name) | ||
self.column = column | ||
self.increment = increment | ||
self.counter = 0 | ||
|
||
def generate(self, samples: Dataset): | ||
samples = samples.map( | ||
lambda x: {self.column: x[self.column] + self.counter} | ||
) | ||
self.counter += self.increment | ||
return samples | ||
|
||
def setUp(self): | ||
self.ctx = MagicMock() | ||
self.ctx.dataset_num_procs = 1 | ||
self.pipe = MagicMock() | ||
self.block = IterBlock( | ||
self.ctx, | ||
self.pipe, | ||
"iter_test", | ||
num_iters=4, | ||
block_type="TestCounterBlock", | ||
column="counter", | ||
) | ||
self.dataset = Dataset.from_dict( | ||
{"counter": [0]}, | ||
features=Features({"counter": Value("int32")}), | ||
) | ||
|
||
def test_simple_iterate(self): | ||
iterated_dataset = self.block.generate(self.dataset) | ||
# We iterated 4 times, so 4 items in our dataset - one from | ||
# each iteration | ||
self.assertEqual(len(iterated_dataset), 4) | ||
# Each iteration increment our counter, because of our custom | ||
# block used that just increments counters | ||
self.assertEqual(iterated_dataset["counter"], [0, 1, 2, 3]) |