From a8cb715539a6d86d3f86ccc6735067d9ccdbe386 Mon Sep 17 00:00:00 2001
From: Charlie Doern
Date: Wed, 6 Nov 2024 15:55:12 -0500
Subject: [PATCH] expose max_num_tokens as configurable

max-num-tokens is a nice way to run a shorter or longer SDG run. Locally I
have been modifying the pipeline YAML from 2048 to 512, which ends up just
generating less data. Exposing this to the CLI could allow power users to run
different types of SDG runs.

Signed-off-by: Charlie Doern
---
 src/instructlab/sdg/generate_data.py | 10 +++++++++-
 src/instructlab/sdg/llmblock.py      | 25 +++++++++++++++++++++----
 src/instructlab/sdg/pipeline.py      |  3 ++-
 tests/test_llmblock.py               | 17 +++++++++++++++++
 4 files changed, 49 insertions(+), 6 deletions(-)

diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py
index f3e38e65..7b423053 100644
--- a/src/instructlab/sdg/generate_data.py
+++ b/src/instructlab/sdg/generate_data.py
@@ -20,7 +20,11 @@
 # pylint: disable=ungrouped-imports
 from instructlab.sdg.datamixing import DataMixer, _get_question_hack, _get_response_hack
 from instructlab.sdg.eval_data import generate_eval_task_data, mmlubench_pipe_init
-from instructlab.sdg.llmblock import MODEL_FAMILY_MERLINITE, MODEL_FAMILY_MIXTRAL
+from instructlab.sdg.llmblock import (
+    DEFAULT_MAX_NUM_TOKENS,
+    MODEL_FAMILY_MERLINITE,
+    MODEL_FAMILY_MIXTRAL,
+)
 from instructlab.sdg.pipeline import (
     FULL_PIPELINES_PACKAGE,
     SIMPLE_PIPELINES_PACKAGE,
@@ -183,6 +187,7 @@ def _context_init(
     save_freq: int,
     batch_num_workers: Optional[int],
     batch_size: Optional[int],
+    max_num_tokens: Optional[int] = DEFAULT_MAX_NUM_TOKENS,
 ):
     extra_kwargs = {}
     if batch_size is not None:
@@ -196,6 +201,7 @@ def _context_init(
         num_instructions_to_generate=num_instructions_to_generate,
         checkpoint_dir=checkpoint_dir,
         save_freq=save_freq,
+        max_num_tokens=max_num_tokens,
         **extra_kwargs,
     )
 
@@ -281,6 +287,7 @@ def generate_data(
     pipeline: Optional[str] = "simple",
     batch_size: Optional[int] = None,
     checkpoint_dir: Optional[str] = None,
+    max_num_tokens: Optional[int] = DEFAULT_MAX_NUM_TOKENS,
 ) -> None:
     """Generate data for training and testing a model.
 
@@ -343,6 +350,7 @@ def generate_data(
         1,  # save_freq
         batch_size=batch_size,
         batch_num_workers=num_cpus,
+        max_num_tokens=max_num_tokens,
     )
 
     knowledge_pipe, freeform_skills_pipe, grounded_skills_pipe = _sdg_init(
diff --git a/src/instructlab/sdg/llmblock.py b/src/instructlab/sdg/llmblock.py
index 2ddd30c2..ddb85477 100644
--- a/src/instructlab/sdg/llmblock.py
+++ b/src/instructlab/sdg/llmblock.py
@@ -16,6 +16,8 @@
 
 logger = logging.getLogger(__name__)
 
+DEFAULT_MAX_NUM_TOKENS = 4096
+
 MODEL_FAMILY_MIXTRAL = "mixtral"
 MODEL_FAMILY_MERLINITE = "merlinite"
 
@@ -78,11 +80,23 @@ def __init__(
         self.model_prompt = model_prompt
         self.output_cols = output_cols
         self.batch_params = batch_kwargs
+        max_num_token_override = ctx.max_num_tokens
         self.parser_name = parser_kwargs.get("parser_name", None)
         self.parsing_pattern = parser_kwargs.get("parsing_pattern", None)
         self.parser_cleanup_tags = parser_kwargs.get("parser_cleanup_tags", None)
+        # max_num_tokens should only be applicable to knowledge blocks;
+        # gen_knowledge is the full/simple pipeline's knowledge generation block
+        if block_name != "gen_knowledge":
+            logger.debug(
+                f"Not applying max_num_tokens to block {block_name}. This is only applicable for gen_knowledge."
+            )
+            max_num_token_override = DEFAULT_MAX_NUM_TOKENS
         self.gen_kwargs = self._gen_kwargs(
-            gen_kwargs, model=self.ctx.model_id, temperature=0, max_tokens=4096
+            max_num_token_override,
+            gen_kwargs,
+            model=self.ctx.model_id,
+            temperature=0,
+            max_tokens=DEFAULT_MAX_NUM_TOKENS,
         )
         # Whether the LLM server supports a list of input prompts
         # and supports the n parameter to generate n outputs per input
@@ -142,7 +156,7 @@ def _format_prompt(self, sample: Dict) -> str:
 
         return prompt if model_prompt is None else model_prompt.format(prompt=prompt)
 
-    def _gen_kwargs(self, gen_kwargs, **defaults):
+    def _gen_kwargs(self, max_num_token_override, gen_kwargs, **defaults):
         gen_kwargs = {**defaults, **gen_kwargs}
         if (
             "n" in gen_kwargs
@@ -150,15 +164,18 @@ def _gen_kwargs(self, gen_kwargs, **defaults):
             and gen_kwargs["n"] == "scaled"
         ):
             gen_kwargs["n"] = self.ctx.num_instructions_to_generate
-        if "max_tokens" in gen_kwargs:
-            gen_kwargs["max_tokens"] = int(gen_kwargs["max_tokens"])
         if "temperature" in gen_kwargs:
             gen_kwargs["temperature"] = float(gen_kwargs["temperature"])
+        if max_num_token_override != DEFAULT_MAX_NUM_TOKENS:
+            gen_kwargs["max_tokens"] = max_num_token_override
+        elif "max_tokens" in gen_kwargs:
+            gen_kwargs["max_tokens"] = int(gen_kwargs["max_tokens"])
         return gen_kwargs
 
     def _generate(self, samples) -> list:
         prompts = [self._format_prompt(sample) for sample in samples]
         logger.debug(f"STARTING GENERATION FOR LLMBlock USING PROMPTS: {prompts}")
+        logger.debug(f"Generation kwargs for LLMBlock: {self.gen_kwargs}")
         if self.server_supports_batched:
             response = self.ctx.client.completions.create(
                 prompt=prompts, **self.gen_kwargs
diff --git a/src/instructlab/sdg/pipeline.py b/src/instructlab/sdg/pipeline.py
index 20b03fb4..52621f81 100644
--- a/src/instructlab/sdg/pipeline.py
+++ b/src/instructlab/sdg/pipeline.py
@@ -48,6 +48,7 @@ class PipelineContext:  # pylint: disable=too-many-instance-attributes
         central executor pool.
     dataset_num_procs: The number of processes to use when performing parallel
         map operations on individual datasets.
+    max_num_tokens: The maximum number of tokens to generate per sample.
     """
 
     # The default batch size of 8 has been determined as a good default for
@@ -65,6 +66,7 @@ class PipelineContext:  # pylint: disable=too-many-instance-attributes
     dataset_num_procs: Optional[int] = DEFAULT_DATASET_NUM_PROCS
     checkpoint_dir: Optional[str] = None
    save_freq: Optional[int] = 1
+    max_num_tokens: Optional[int] = llmblock.DEFAULT_MAX_NUM_TOKENS
     batch_size: int = DEFAULT_BATCH_SIZE
     batch_num_workers: Optional[int] = None
 
@@ -195,7 +197,6 @@ def _generate_single(self, dataset) -> Dataset:
                 drop_duplicates_cols = block_prop.get("drop_duplicates", False)
                 block = block_type(self.ctx, self, block_name, **block_config)
 
                 logger.info("Running block: %s", block_name)
-                # Execute the block and wrap errors with the block name/type
                 dataset = block.generate(dataset)
             except Exception as err:
diff --git a/tests/test_llmblock.py b/tests/test_llmblock.py
index f7835f32..bbdc8441 100644
--- a/tests/test_llmblock.py
+++ b/tests/test_llmblock.py
@@ -86,3 +86,20 @@ def test_model_prompt_none(self, mock_load_config):
             "FOO pear\nintroduction\nprinciples\nexamples\ngeneration BAR",
             "model_prompt should be a non-empty string when set to None",
         )
+
+    @patch("src.instructlab.sdg.block.Block._load_config")
+    def test_max_num_tokens_override(self, mock_load_config):
+        mock_load_config.return_value = self.config_return_value
+        self.mock_ctx.max_num_tokens = 512
+        # Ensure the context's max_num_tokens overrides the max_tokens in gen_kwargs
+        block = LLMBlock(
+            ctx=self.mock_ctx,
+            pipe=self.mock_pipe,
+            block_name="gen_knowledge",
+            config_path="",
+            output_cols=[],
+            model_prompt="",
+            gen_kwargs={"max_tokens": 2048},
+        )
+        num_tokens = block.gen_kwargs["max_tokens"]
+        assert num_tokens == 512