Skip to content

Commit 8576050

Browse files
committed
Implement LLMMessagesBlock
This was just a commented-out stub before, ported over from the research prototypes. This change uncomments it, ports it to work within the current structure of InstructLab, and adds a few tests to verify the logic within the block. Fixes #414. Signed-off-by: Ben Browning <[email protected]>
1 parent 02ccaef commit 8576050

File tree

2 files changed

+107
-46
lines changed

2 files changed

+107
-46
lines changed

src/instructlab/sdg/blocks/llmblock.py

+50-34
Original file line numberDiff line numberDiff line change
@@ -457,48 +457,64 @@ def __init__(
457457

458458
# This is part of the public API.
@BlockRegistry.register("LLMMessagesBlock")
class LLMMessagesBlock(Block):
    """Send pre-built chat messages straight to the chat-completions API.

    Unlike ``LLMBlock``, this block renders no prompt template: each
    sample's ``input_col`` already holds the ``messages`` payload for an
    OpenAI-style ``chat.completions.create`` call. The completion (or the
    list of completions when ``n > 1``) is written to ``output_col``.
    """

    def __init__(
        self,
        ctx,
        pipe,
        block_name,
        input_col,
        output_col,
        gen_kwargs=None,
    ) -> None:
        """Create the block.

        Args:
            ctx: Pipeline context providing ``model_id`` and an OpenAI
                ``client``.
            pipe: The pipeline this block belongs to.
            block_name: Name of this block instance.
            input_col: Dataset column containing chat ``messages`` payloads.
            output_col: Dataset column to write completions into.
            gen_kwargs: Optional overrides for the generation arguments
                (e.g. ``n``, ``temperature``, ``max_tokens``).
        """
        super().__init__(ctx, pipe, block_name)
        self.input_col = input_col
        self.output_col = output_col
        # NOTE: gen_kwargs defaults to None (not {}) to avoid the
        # shared-mutable-default-argument pitfall; an empty dict is
        # substituted here instead.
        self.gen_kwargs = self._gen_kwargs(
            gen_kwargs or {},
            model=self.ctx.model_id,
            temperature=0,
            max_tokens=DEFAULT_MAX_NUM_TOKENS,
        )

    def _gen_kwargs(self, gen_kwargs, **defaults):
        """Merge user-supplied generation kwargs over ``defaults``.

        Coerces ``temperature`` to ``float`` and bumps it to 0.7 when the
        caller requested multiple completions (``n > 1``) with a
        non-positive temperature — greedy decoding would otherwise return
        ``n`` identical samples.
        """
        gen_kwargs = {**defaults, **gen_kwargs}
        if "temperature" in gen_kwargs:
            gen_kwargs["temperature"] = float(gen_kwargs["temperature"])
        if (
            "n" in gen_kwargs
            and gen_kwargs["n"] > 1
            and gen_kwargs.get("temperature", 0) <= 0
        ):
            gen_kwargs["temperature"] = 0.7
            logger.warning(
                "Temperature should be greater than 0 for n > 1, setting temperature to 0.7"
            )
        return gen_kwargs

    def _generate(self, samples) -> list:
        """Call the chat-completions API once per sample and collect results.

        Returns a list with one entry per sample: a single string when
        ``n == 1``, otherwise a list of ``n`` strings.
        """
        messages = samples[self.input_col]
        logger.debug("STARTING GENERATION FOR LLMMessagesBlock")
        logger.debug(f"Generation arguments: {self.gen_kwargs}")
        results = []
        progress_bar = tqdm(
            range(len(samples)), desc=f"{self.block_name} Chat Completion Generation"
        )
        n = self.gen_kwargs.get("n", 1)
        for message in messages:
            logger.debug(f"CREATING CHAT COMPLETION FOR MESSAGE: {message}")
            responses = self.ctx.client.chat.completions.create(
                messages=message, **self.gen_kwargs
            )
            if n > 1:
                results.append([choice.message.content for choice in responses.choices])
            else:
                results.append(responses.choices[0].message.content)
            # One sample is processed per iteration; advancing by n would
            # overshoot the bar's total of len(samples) whenever n > 1.
            progress_bar.update(1)
        progress_bar.close()
        return results

    def generate(self, samples: Dataset) -> Dataset:
        """Generate completions for every sample, adding ``output_col``."""
        outputs = self._generate(samples)
        logger.debug("Generated outputs: %s", outputs)
        samples = samples.add_column(self.output_col, outputs)
        return samples

tests/test_llmblock.py

+57-12
Original file line numberDiff line numberDiff line change
@@ -336,28 +336,73 @@ def test_constructor_works(self, mock_load_config):
336336
assert block is not None
337337

338338

339-
@patch("src.instructlab.sdg.blocks.block.Block._load_config")
340339
class TestLLMMessagesBlock(unittest.TestCase):
    """Tests for LLMMessagesBlock construction and generation."""

    def setUp(self):
        # A fake pipeline context wired to a fake OpenAI client so no
        # network traffic ever happens.
        self.mock_client = MagicMock()
        self.mock_pipe = MagicMock()
        self.mock_ctx = MagicMock()
        self.mock_ctx.model_family = "mixtral"
        self.mock_ctx.model_id = "test_model"
        self.mock_ctx.client = self.mock_client

    def _build_block(self, gen_kwargs=None):
        # Every test constructs the block the same way; only gen_kwargs vary.
        ctor_args = dict(
            ctx=self.mock_ctx,
            pipe=self.mock_pipe,
            block_name="gen_knowledge",
            input_col="messages",
            output_col="output",
        )
        if gen_kwargs is not None:
            ctor_args["gen_kwargs"] = gen_kwargs
        return LLMMessagesBlock(**ctor_args)

    def test_constructor_works(self):
        block = self._build_block()
        assert block is not None

    def test_temperature_validation(self):
        # Asking for n > 1 samples at temperature <= 0 must bump the
        # temperature to something positive.
        block = self._build_block(gen_kwargs={"n": 5, "temperature": 0})
        assert block.gen_kwargs["temperature"] != 0

        # With a single sample an explicit zero temperature is respected.
        block = self._build_block(gen_kwargs={"n": 1, "temperature": 0})
        assert block.gen_kwargs["temperature"] == 0

    def test_calls_chat_completion_api(self):
        # Stub the chat-completions endpoint so we don't hit a server here.
        fake_choice = MagicMock()
        fake_choice.message = MagicMock()
        fake_choice.message.content = "generated response"
        fake_response = MagicMock()
        fake_response.choices = [fake_choice]
        create_mock = self.mock_client.chat.completions.create
        create_mock.return_value = fake_response

        block = self._build_block(gen_kwargs={"n": 1, "temperature": 0})
        samples = Dataset.from_dict(
            {"messages": ["my message"]},
            features=Features({"messages": Value("string")}),
        )
        output = block.generate(samples)
        assert len(output) == 1
        assert "output" in output.column_names
        create_mock.assert_called()
        assert create_mock.call_args.kwargs["messages"] == "my message"

0 commit comments

Comments
 (0)