Merge pull request #340 from cdoern/max-num-tokens
feat: expose max_num_tokens as configurable
mergify[bot] authored Nov 8, 2024
2 parents baf4c30 + a8cb715 commit eacae02
Showing 4 changed files with 49 additions and 6 deletions.
10 changes: 9 additions & 1 deletion src/instructlab/sdg/generate_data.py
@@ -20,7 +20,11 @@
# pylint: disable=ungrouped-imports
from instructlab.sdg.datamixing import DataMixer, _get_question_hack, _get_response_hack
from instructlab.sdg.eval_data import generate_eval_task_data, mmlubench_pipe_init
from instructlab.sdg.llmblock import MODEL_FAMILY_MERLINITE, MODEL_FAMILY_MIXTRAL
from instructlab.sdg.llmblock import (
DEFAULT_MAX_NUM_TOKENS,
MODEL_FAMILY_MERLINITE,
MODEL_FAMILY_MIXTRAL,
)
from instructlab.sdg.pipeline import (
FULL_PIPELINES_PACKAGE,
SIMPLE_PIPELINES_PACKAGE,
@@ -188,6 +192,7 @@ def _context_init(
save_freq: int,
batch_num_workers: Optional[int],
batch_size: Optional[int],
max_num_tokens: Optional[int] = DEFAULT_MAX_NUM_TOKENS,
):
extra_kwargs = {}
if batch_size is not None:
@@ -201,6 +206,7 @@ def _context_init(
num_instructions_to_generate=num_instructions_to_generate,
checkpoint_dir=checkpoint_dir,
save_freq=save_freq,
max_num_tokens=max_num_tokens,
**extra_kwargs,
)

@@ -288,6 +294,7 @@ def generate_data(
pipeline: Optional[str] = "simple",
batch_size: Optional[int] = None,
checkpoint_dir: Optional[str] = None,
max_num_tokens: Optional[int] = DEFAULT_MAX_NUM_TOKENS,
) -> None:
"""Generate data for training and testing a model.
@@ -353,6 +360,7 @@
1, # save_freq
batch_size=batch_size,
batch_num_workers=num_cpus,
max_num_tokens=max_num_tokens,
)

knowledge_pipe, freeform_skills_pipe, grounded_skills_pipe = _sdg_init(
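With this change the per-sample token cap can be set from the top-level generate_data API rather than relying on the 4096 value previously hard-coded in llmblock.py. A minimal sketch, assuming the instructlab-sdg package is importable; it only inspects the new keyword's default instead of running a full generation, since the remaining generate_data arguments depend on the caller's environment:

import inspect

from instructlab.sdg.generate_data import generate_data
from instructlab.sdg.llmblock import DEFAULT_MAX_NUM_TOKENS

# The new keyword defaults to DEFAULT_MAX_NUM_TOKENS (4096); a caller passing,
# e.g., max_num_tokens=2048 has that value forwarded through _context_init into
# the PipelineContext built for the run.
sig = inspect.signature(generate_data)
assert sig.parameters["max_num_tokens"].default == DEFAULT_MAX_NUM_TOKENS == 4096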
25 changes: 21 additions & 4 deletions src/instructlab/sdg/llmblock.py
@@ -16,6 +16,8 @@

logger = logging.getLogger(__name__)

DEFAULT_MAX_NUM_TOKENS = 4096

MODEL_FAMILY_MIXTRAL = "mixtral"
MODEL_FAMILY_MERLINITE = "merlinite"

@@ -78,11 +80,23 @@ def __init__(
self.model_prompt = model_prompt
self.output_cols = output_cols
self.batch_params = batch_kwargs
max_num_token_override = ctx.max_num_tokens
self.parser_name = parser_kwargs.get("parser_name", None)
self.parsing_pattern = parser_kwargs.get("parsing_pattern", None)
self.parser_cleanup_tags = parser_kwargs.get("parser_cleanup_tags", None)
# max_num_tokens should only be applicable to knowledge blocks
# gen_knowledge is the full/simple pipeline's knowledge generation block
if block_name != "gen_knowledge":
logger.debug(
f"Not applying max_num_tokens to block {block_name}. This is only applicable for gen_knowledge."
)
max_num_token_override = DEFAULT_MAX_NUM_TOKENS
self.gen_kwargs = self._gen_kwargs(
gen_kwargs, model=self.ctx.model_id, temperature=0, max_tokens=4096
max_num_token_override,
gen_kwargs,
model=self.ctx.model_id,
temperature=0,
max_tokens=DEFAULT_MAX_NUM_TOKENS,
)
# Whether the LLM server supports a list of input prompts
# and supports the n parameter to generate n outputs per input
@@ -142,23 +156,26 @@ def _format_prompt(self, sample: Dict) -> str:

return prompt if model_prompt is None else model_prompt.format(prompt=prompt)

def _gen_kwargs(self, gen_kwargs, **defaults):
def _gen_kwargs(self, max_num_token_override, gen_kwargs, **defaults):
gen_kwargs = {**defaults, **gen_kwargs}
if (
"n" in gen_kwargs
and isinstance(gen_kwargs["n"], str)
and gen_kwargs["n"] == "scaled"
):
gen_kwargs["n"] = self.ctx.num_instructions_to_generate
if "max_tokens" in gen_kwargs:
gen_kwargs["max_tokens"] = int(gen_kwargs["max_tokens"])
if "temperature" in gen_kwargs:
gen_kwargs["temperature"] = float(gen_kwargs["temperature"])
if max_num_token_override != DEFAULT_MAX_NUM_TOKENS:
gen_kwargs["max_tokens"] = max_num_token_override
elif "max_tokens" in gen_kwargs:
gen_kwargs["max_tokens"] = int(gen_kwargs["max_tokens"])
return gen_kwargs

def _generate(self, samples) -> list:
prompts = [self._format_prompt(sample) for sample in samples]
logger.debug(f"STARTING GENERATION FOR LLMBlock USING PROMPTS: {prompts}")
print(self.gen_kwargs)
if self.server_supports_batched:
response = self.ctx.client.completions.create(
prompt=prompts, **self.gen_kwargs
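The effective max_tokens for a block therefore depends on the block name, the ctx-level max_num_tokens, and any max_tokens already present in the block's gen_kwargs: the ctx value wins only for gen_knowledge and only when it differs from the 4096 default. A standalone sketch of that precedence (an illustration mirroring the logic above, not the library code itself):

DEFAULT_MAX_NUM_TOKENS = 4096

def effective_max_tokens(block_name, ctx_max_num_tokens, gen_kwargs):
    """Mirror of the max_tokens precedence in LLMBlock, for illustration only."""
    override = ctx_max_num_tokens
    # The ctx-level override only applies to the knowledge generation block.
    if block_name != "gen_knowledge":
        override = DEFAULT_MAX_NUM_TOKENS
    # Defaults always include max_tokens, then gen_kwargs may replace it.
    merged = {"max_tokens": DEFAULT_MAX_NUM_TOKENS, **gen_kwargs}
    if override != DEFAULT_MAX_NUM_TOKENS:
        merged["max_tokens"] = override
    else:
        merged["max_tokens"] = int(merged["max_tokens"])
    return merged["max_tokens"]

assert effective_max_tokens("gen_knowledge", 512, {"max_tokens": 2048}) == 512
assert effective_max_tokens("gen_knowledge", 4096, {"max_tokens": 2048}) == 2048
assert effective_max_tokens("some_other_block", 512, {}) == 4096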
3 changes: 2 additions & 1 deletion src/instructlab/sdg/pipeline.py
@@ -48,6 +48,7 @@ class PipelineContext: # pylint: disable=too-many-instance-attributes
central executor pool.
dataset_num_procs: The number of processes to use when performing parallel
map operations on individual datasets.
max_num_tokens: the maximum number of tokens to generate per sample.
"""

# The default batch size of 8 has been determined as a good default for
@@ -65,6 +66,7 @@ class PipelineContext: # pylint: disable=too-many-instance-attributes
dataset_num_procs: Optional[int] = DEFAULT_DATASET_NUM_PROCS
checkpoint_dir: Optional[str] = None
save_freq: Optional[int] = 1
max_num_tokens: Optional[int] = llmblock.DEFAULT_MAX_NUM_TOKENS
batch_size: int = DEFAULT_BATCH_SIZE
batch_num_workers: Optional[int] = None

@@ -195,7 +197,6 @@ def _generate_single(self, dataset) -> Dataset:
drop_duplicates_cols = block_prop.get("drop_duplicates", False)
block = block_type(self.ctx, self, block_name, **block_config)
logger.info("Running block: %s", block_name)

# Execute the block and wrap errors with the block name/type
dataset = block.generate(dataset)
except Exception as err:
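For existing callers nothing changes: the new field defaults to the same 4096 that LLMBlock previously hard-coded, so behavior only shifts when a different value is passed in. A small check of that default, assuming the package is importable (constructing a full PipelineContext still needs the usual client/model fields, which are unchanged here and omitted):

from instructlab.sdg import llmblock
from instructlab.sdg.pipeline import PipelineContext

# The class-level default for the new field is llmblock.DEFAULT_MAX_NUM_TOKENS;
# a context built with max_num_tokens=512 caps gen_knowledge generation at 512
# tokens per sample for the whole run.
assert PipelineContext.max_num_tokens == llmblock.DEFAULT_MAX_NUM_TOKENS == 4096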
17 changes: 17 additions & 0 deletions tests/test_llmblock.py
@@ -86,3 +86,20 @@ def test_model_prompt_none(self, mock_load_config):
"FOO pear\nintroduction\nprinciples\nexamples\ngeneration BAR",
"model_prompt should be a non-empty string when set to None",
)

@patch("src.instructlab.sdg.block.Block._load_config")
def test_max_num_tokens_override(self, mock_load_config):
mock_load_config.return_value = self.config_return_value
self.mock_ctx.max_num_tokens = 512
# Ensure that a ctx-level max_num_tokens overrides the max_tokens passed via gen_kwargs
block = LLMBlock(
ctx=self.mock_ctx,
pipe=self.mock_pipe,
block_name="gen_knowledge",
config_path="",
output_cols=[],
model_prompt="",
gen_kwargs={"max_tokens": 2048},
)
num_tokens = block.gen_kwargs["max_tokens"]
assert num_tokens == 512
