Skip to content

Commit

Permalink
fix: upsample the phase10 knowledge dataset
Browse files Browse the repository at this point in the history
When we mix the knowledge dataset with skills today, we do not account for the potential discrepancy
in size between the generated knowledge data and skills data. This can lead to the model
forgetting the data it was trained on in the knowledge phase. As a simple workaround, we
upsample the knowledge samples before mixing them in with the generated skills dataset.

Signed-off-by: Oleg S <[email protected]>
  • Loading branch information
RobotSail committed Nov 14, 2024
1 parent b6f07a8 commit d6a6e7c
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 2 deletions.
12 changes: 12 additions & 0 deletions src/instructlab/sdg/datamixing.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,10 @@ def __init__(
date_suffix,
sys_prompt,
num_procs,
# HACK(osilkin): `knowledge_upsample_amount` is necessary when the size of skills data greatly exceeds knowledge data.
# This will be replaced in the future with a more robust mechanism for SDG to manage the
# upsampling relationship.
knowledge_upsample_amount: int,
auxiliary_inst=None,
):
self.data_dirs = data_dirs
Expand All @@ -555,6 +559,7 @@ def __init__(
self.date_suffix = date_suffix
self.num_procs = num_procs
self.auxiliary_inst = auxiliary_inst
self.knowledge_upsample_amount = knowledge_upsample_amount

self.knowledge_recipe = self._load_default_recipe("knowledge.yaml")
self.skills_recipe = self._load_default_recipe("skills.yaml")
Expand Down Expand Up @@ -615,10 +620,17 @@ def collect(
output_file_leaf_skills = (
f"node_datasets_{self.date_suffix}/{leaf_node_path}_p10.jsonl"
)
# HACK(osilkin): `knowledge_upsample_amount` is currently used when the generated knowledge data
# is orders of magnitude smaller (approx. < 3%) than the skills dataset.
# It is used to upsample that dataset so that the model doesn't forget it in training.
#
# This workaround is currently hacky, as we lack insight into the size of both datasets
# when we generate this data, and it may vary across different scenarios.
self._gen_leaf_node_data(
skills_phase_data,
self.skills_recipe,
output_file_leaf_skills,
sampling_size=self.knowledge_upsample_amount,
)
else:
messages = new_generated_data.map(
Expand Down
19 changes: 17 additions & 2 deletions src/instructlab/sdg/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from instructlab.sdg.datamixing import DataMixer, _get_question_hack, _get_response_hack
from instructlab.sdg.eval_data import generate_eval_task_data, mmlubench_pipe_init
from instructlab.sdg.llmblock import (
DEFAULT_KNOWLEDGE_UPSAMPLE_AMOUNT,
DEFAULT_MAX_NUM_TOKENS,
MODEL_FAMILY_MERLINITE,
MODEL_FAMILY_MIXTRAL,
Expand Down Expand Up @@ -254,7 +255,14 @@ def load_pipeline(yaml_basename):
)


def _mixer_init(ctx, output_dir, date_suffix, knowledge_auxiliary_inst, system_prompt):
def _mixer_init(
ctx,
output_dir,
date_suffix,
knowledge_auxiliary_inst,
system_prompt,
upsample_amount: int,
):
data_dirs = [os.path.join(xdg_data_home(), "instructlab", "sdg")]
data_dirs.extend(os.path.join(dir, "instructlab", "sdg") for dir in xdg_data_dirs())

Expand All @@ -264,6 +272,7 @@ def _mixer_init(ctx, output_dir, date_suffix, knowledge_auxiliary_inst, system_p
date_suffix,
system_prompt,
ctx.dataset_num_procs,
upsample_amount,
knowledge_auxiliary_inst,
)

Expand Down Expand Up @@ -295,6 +304,7 @@ def generate_data(
batch_size: Optional[int] = None,
checkpoint_dir: Optional[str] = None,
max_num_tokens: Optional[int] = DEFAULT_MAX_NUM_TOKENS,
upsample_amount: Optional[int] = DEFAULT_KNOWLEDGE_UPSAMPLE_AMOUNT,
) -> None:
"""Generate data for training and testing a model.
Expand Down Expand Up @@ -372,7 +382,12 @@ def generate_data(
mmlu_bench_pipe = mmlubench_pipe_init(mmlu_ctx)

mixer = _mixer_init(
ctx, output_dir, date_suffix, knowledge_pipe.auxiliary_inst, system_prompt
ctx,
output_dir,
date_suffix,
knowledge_pipe.auxiliary_inst,
system_prompt,
upsample_amount,
)

if console_output:
Expand Down
1 change: 1 addition & 0 deletions src/instructlab/sdg/llmblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
logger = logging.getLogger(__name__)

# Default for the `max_num_tokens` argument of `generate_data`.
DEFAULT_MAX_NUM_TOKENS = 4096
# Default for the `upsample_amount` argument of `generate_data`; it is forwarded to
# `DataMixer` as `knowledge_upsample_amount` and passed as `sampling_size` when the
# phase-1.0 ("_p10") leaf-node dataset is generated.
# HACK: stop-gap value for when the generated knowledge data is much smaller than the skills data.
DEFAULT_KNOWLEDGE_UPSAMPLE_AMOUNT = 5000

MODEL_FAMILY_MIXTRAL = "mixtral"
MODEL_FAMILY_MERLINITE = "merlinite"
Expand Down

0 comments on commit d6a6e7c

Please sign in to comment.