
Commit

allow generate_data logger parameter to overwrite locally defined loggers

Signed-off-by: Khaled Sulayman <[email protected]>
khaledsulayman committed Dec 11, 2024
1 parent dcbabc5 commit 81fad3c
Showing 4 changed files with 75 additions and 57 deletions.
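The change follows one pattern in every touched module: the module-level logger is renamed to LOGGER, and the public entry point accepts an optional logger that, when provided, is rebound onto that module global, so every helper in the module logs through the caller's logger from then on. Because the rebinding happens at module level, it is process-wide for that module, not limited to the call that passed the logger. A minimal self-contained sketch of the pattern (the function names here are illustrative, not the modules'):

import logging

LOGGER = logging.getLogger(__name__)


def _helper():
    # Logs through whatever LOGGER is currently bound to.
    LOGGER.info("doing work ...")


def generate(logger=None):
    if logger is not None:
        global LOGGER  # pylint: disable=global-statement
        LOGGER = logger
    _helper()


generate(logger=logging.getLogger("my_app.sdg"))  # _helper() now logs via "my_app.sdg"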
29 changes: 16 additions & 13 deletions src/instructlab/sdg/datamixing.py
@@ -22,7 +22,7 @@
# when |knowledge| << |skills|
MIN_UPSAMPLE_THRESHOLD = 0.03
ALLOWED_COLS = ["id", "messages", "metadata"]
logger = logging.getLogger(__name__)
LOGGER = logging.getLogger(__name__)


class DatasetListing(TypedDict):
@@ -40,7 +40,7 @@ def _adjust_train_sample_size(ds: Dataset, num_samples: int):
Return a dataset with num_samples random samples selected from the
original dataset.
"""
logger.info(f"Rebalancing dataset to have {num_samples} samples ...")
LOGGER.info(f"Rebalancing dataset to have {num_samples} samples ...")
df = ds.to_pandas()
df = df.sample(n=num_samples, random_state=42, replace=True)
return pandas.dataset_from_pandas_dataframe(df)
@@ -135,10 +135,10 @@ def _load_ds(self, path):
"""
if not os.path.isabs(path):
path = os.path.join(os.path.dirname(self.recipe_path), path)
logger.info(f"Loading dataset from {path} ...")
LOGGER.info(f"Loading dataset from {path} ...")
dataset = load_dataset("json", data_files=path, split="train")
logger.info(f"Dataset columns: {dataset.column_names}")
logger.info(f"Dataset loaded with {len(dataset)} samples")
LOGGER.info(f"Dataset columns: {dataset.column_names}")
LOGGER.info(f"Dataset loaded with {len(dataset)} samples")
return dataset

def _load_and_sample_datasets(self, num_proc):
Expand All @@ -161,7 +161,7 @@ def _create_mixed_dataset(self, num_proc):
concatenating all datasets in this recipe
"""
if not self.dataset_added:
logger.error("No dataset added to the recipe")
LOGGER.error("No dataset added to the recipe")

mixed_ds = self._load_and_sample_datasets(num_proc)
mixed_ds = concatenate_datasets(mixed_ds)
@@ -212,7 +212,7 @@ def save_mixed_dataset(self, output_path, num_proc):
"""
mixed_ds = self._create_mixed_dataset(num_proc)
mixed_ds.to_json(output_path, orient="records", lines=True)
logger.info(f"Mixed Dataset saved to {output_path}")
LOGGER.info(f"Mixed Dataset saved to {output_path}")


def _unescape(s):
@@ -235,7 +235,7 @@ def _get_question_hack(synth_example):

parts = synth_example["output"].split("?", 1)
if len(parts) != 2:
logger.warning(f"Failed to split generated q&a: {synth_example['output']}")
LOGGER.warning(f"Failed to split generated q&a: {synth_example['output']}")
return parts[0].strip() + "?" if len(parts) == 2 else ""


@@ -251,7 +251,7 @@ def _get_response_hack(synth_example):

parts = synth_example["output"].split("?", 1)
if len(parts) != 2:
logger.warning(f"Failed to split generated q&a: {synth_example['output']}")
LOGGER.warning(f"Failed to split generated q&a: {synth_example['output']}")
return parts[1].strip() if len(parts) == 2 else parts[0].strip()


@@ -333,7 +333,7 @@ def __pick_documents(rec, p):
selected_docs = [e for e in all_context if e != answer_document]
if len(selected_docs) > 0:
if len(selected_docs) < num_doc_in_context:
logger.debug(
LOGGER.debug(
f"Number of unique documents is {len(selected_docs)} which is less than {num_doc_in_context}. Using all the documents in the expanded context."
)
if random.uniform(0, 1) < p:
@@ -352,7 +352,7 @@ def __pick_documents(rec, p):
else selected_docs
)
else:
logger.warning(
LOGGER.warning(
"Only 1 unique document found. Disabling expanded context injection, which may lead to poorer knowledge retention results."
)
docs = [answer_document]
@@ -697,7 +697,7 @@ def collect(
if knowledge_to_skills_ratio < MIN_UPSAMPLE_THRESHOLD:
sampling_size = int(self._precomputed_skills_length * 0.03)

logger.info(
LOGGER.info(
"\033[93mKnowledge detected to be less than %.2f%% of skills (%.2f%%), upsampling to: %d\033[0m",
MIN_UPSAMPLE_THRESHOLD * 100,
knowledge_to_skills_ratio * 100,
@@ -739,7 +739,10 @@ def _gen_mixed_data(self, recipe, output_file_recipe, output_file_data):
self.num_procs,
)

def generate(self):
def generate(self, logger=None):
if logger is not None:
global LOGGER # pylint: disable=global-statement
LOGGER = logger
self._gen_mixed_data(
self.knowledge_recipe,
self.output_file_knowledge_recipe,
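For the forwarded logger to actually surface the messages touched above ("Loading dataset from ...", "Mixed Dataset saved to ..."), the caller has to configure whatever logger it passes in. A short sketch, assuming the application owns a logger named "my_app.sdg" (the name and format are assumptions, any configured logger works):

import logging

log = logging.getLogger("my_app.sdg")
log.setLevel(logging.INFO)

handler = logging.StreamHandler()
handler.setFormatter(
    logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s")
)
log.addHandler(handler)

# Passed as generate(logger=log) on the mixer (or via generate_data below),
# datamixing.py's LOGGER is rebound to this logger, so its INFO lines such as
# "Mixed Dataset saved to ..." show up on this handler.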
33 changes: 19 additions & 14 deletions src/instructlab/sdg/generate_data.py
@@ -34,7 +34,7 @@
read_taxonomy_leaf_nodes,
)

logger = logging.getLogger(__name__)
LOGGER = logging.getLogger(__name__)

_SYS_PROMPT = "I am a Red Hat® Instruct Model, an AI language model developed by Red Hat and IBM Research based on the granite-3.0-8b-base model. My primary role is to serve as a chat assistant."

@@ -90,7 +90,7 @@ def _gen_train_data(

for output_dataset in machine_instruction_data:
for synth_example in output_dataset:
logger.debug(synth_example)
LOGGER.debug(synth_example)
user = _get_question_hack(synth_example)
if len(synth_example.get("context", "")) > 0:
user += "\n" + synth_example["context"]
@@ -223,7 +223,7 @@ def _sdg_init(ctx, pipeline):
config = yaml.safe_load(file)
docling_model_path = config["models"][0]["path"]
except (FileNotFoundError, NotADirectoryError, PermissionError) as e:
logger.warning(f"unable to read docling models path from config.yaml {e}")
LOGGER.warning(f"unable to read docling models path from config.yaml {e}")

for d in data_dirs:
pipeline_path = os.path.join(d, "pipelines", pipeline)
@@ -285,7 +285,7 @@ def _mixer_init(
# to be removed: logger
def generate_data(
client: openai.OpenAI,
logger: logging.Logger = logger, # pylint: disable=redefined-outer-name
logger: logging.Logger = None, # pylint: disable=redefined-outer-name
system_prompt: Optional[str] = None,
use_legacy_pretraining_format: Optional[bool] = True,
model_family: Optional[str] = None,
@@ -318,6 +318,10 @@ def generate_data(
We expect three files to be present in this directory: "knowledge.yaml",
"freeform_skills.yaml", and "grounded_skills.yaml".
"""
if logger is not None:
global LOGGER # pylint: disable=global-statement
LOGGER = logger

generate_start = time.time()

system_prompt = system_prompt if system_prompt is not None else _SYS_PROMPT
@@ -336,7 +340,7 @@
document_output_dir = Path(output_dir) / f"documents-{date_suffix}"

leaf_nodes = read_taxonomy_leaf_nodes(
taxonomy, taxonomy_base, yaml_rules, document_output_dir
taxonomy, taxonomy_base, yaml_rules, document_output_dir, logger=LOGGER
)
if not leaf_nodes:
raise GenerateException("Error: No new leaf nodes found in the taxonomy.")
@@ -352,7 +356,7 @@
system_prompt,
)

logger.debug(f"Generating to: {os.path.join(output_dir, output_file_test)}")
LOGGER.debug(f"Generating to: {os.path.join(output_dir, output_file_test)}")

model_family = models.get_model_family(model_family, model_name)

@@ -385,7 +389,7 @@
)

if console_output:
logger.info(
LOGGER.info(
"Synthesizing new instructions. If you aren't satisfied with the generated instructions, interrupt training (Ctrl-C) and try adjusting your YAML files. Adding more examples may help."
)

@@ -402,6 +406,7 @@
document_output_dir,
model_name,
docling_model_path=docling_model_path,
logger=LOGGER,
)

if not samples:
@@ -417,17 +422,17 @@
else:
pipe = freeform_skills_pipe

logger.debug("Samples: %s", samples)
LOGGER.debug("Samples: %s", samples)

new_generated_data = pipe.generate(samples, leaf_node_path)
if len(new_generated_data) == 0:
empty_sdg_leaf_nodes.append(leaf_node_path)
logger.warning("Empty dataset for qna node: %s", leaf_node_path)
LOGGER.warning("Empty dataset for qna node: %s", leaf_node_path)
continue
generated_data.append(new_generated_data)

logger.info("Generated %d samples", len(generated_data))
logger.debug("Generated data: %s", generated_data)
LOGGER.info("Generated %d samples", len(generated_data))
LOGGER.debug("Generated data: %s", generated_data)

if is_knowledge:
# generate mmlubench data for the current leaf node
Expand All @@ -453,12 +458,12 @@ def generate_data(
system_prompt,
)

mixer.generate()
mixer.generate(logger=LOGGER)

generate_duration = time.time() - generate_start
logger.info(f"Generation took {generate_duration:.2f}s")
LOGGER.info(f"Generation took {generate_duration:.2f}s")
if len(empty_sdg_leaf_nodes) > 0:
logger.warning(
LOGGER.warning(
"Leaf nodes with empty sdg output: {}".format(
" ".join(empty_sdg_leaf_nodes)
)
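From the caller's side, the new parameter means generate_data no longer forces the module's own logger. A hedged sketch of the call: only client and logger are visible in the hunks above, so the remaining keyword arguments are assumptions about the collapsed part of the signature.

import logging

import openai

from instructlab.sdg.generate_data import generate_data


def run_sdg(taxonomy_path: str, output_dir: str) -> None:
    log = logging.getLogger("my_app.sdg")

    client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    # `client` and `logger` appear in the diff; `taxonomy` and `output_dir`
    # are assumed names from the part of the signature not shown here.
    generate_data(
        client=client,
        logger=log,
        taxonomy=taxonomy_path,
        output_dir=output_dir,
    )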
15 changes: 9 additions & 6 deletions src/instructlab/sdg/pipeline.py
@@ -23,7 +23,7 @@
from .blocks.block import Block
from .registry import BlockRegistry

logger = logging.getLogger(__name__)
LOGGER = logging.getLogger(__name__)


# This is part of the public API.
@@ -134,13 +134,16 @@ def from_file(cls, ctx, pipeline_yaml):
pipeline_yaml = os.path.join(resources.files(__package__), pipeline_yaml)
return cls(ctx, pipeline_yaml, *_parse_pipeline_config_file(pipeline_yaml))

def generate(self, dataset, checkpoint_name=None) -> Dataset:
def generate(self, dataset, checkpoint_name=None, logger=None) -> Dataset:
"""
Generate the dataset by running the pipeline steps.
dataset: the input dataset
checkpoint_name: unique subdir name for the checkpoint within checkpoint_dir
"""

if logger is not None:
global LOGGER # pylint: disable=global-statement
LOGGER = logger
# The checkpointer allows us to resume from where we left off
# Saving the output of pipe instances along the way
checkpoint_dir = None
@@ -153,12 +156,12 @@ def generate(self, dataset, checkpoint_name=None) -> Dataset:

# If not batching, simply delegate to _generate_single
if not self.ctx.batching_enabled:
logger.info("Running pipeline single-threaded")
LOGGER.info("Running pipeline single-threaded")
return self._generate_single(dataset)

# Otherwise, split the dataset into batches and run each batch as a
# future in the thread pool
logger.info(
LOGGER.info(
"Running pipeline with multi-threaded batching. Using %s workers for batches of size %s",
self.ctx.batch_num_workers,
self.ctx.batch_size,
@@ -197,7 +200,7 @@ def _generate_single(self, dataset) -> Dataset:
drop_columns = block_prop.get("drop_columns", [])
drop_duplicates_cols = block_prop.get("drop_duplicates", False)
block = block_type(self.ctx, self, block_name, **block_config)
logger.info("Running block: %s", block_name)
LOGGER.info("Running block: %s", block_name)
# Execute the block and wrap errors with the block name/type
dataset = block.generate(dataset)
except Exception as err:
@@ -284,7 +287,7 @@ def _parse_pipeline_config_file(pipeline_yaml):
"The pipeline config file format is from a future major version."
)
if major <= _PIPELINE_CONFIG_PARSER_MAJOR and minor > _PIPELINE_CONFIG_PARSER_MINOR:
logger.warning(
LOGGER.warning(
"The pipeline config file may have new features that will be ignored."
)

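Pipeline.generate() gets the same optional logger, so a pipeline run outside generate_data can also be routed to a caller's logger. A sketch under a few assumptions: the enclosing class is pipeline.py's Pipeline, and how the pipeline context (ctx) is built lies outside these hunks and is taken as given.

import logging

from datasets import Dataset

from instructlab.sdg.pipeline import Pipeline


def run_pipeline(ctx, pipeline_yaml: str) -> Dataset:
    log = logging.getLogger("my_app.pipeline")

    # Pipeline.from_file(ctx, pipeline_yaml) appears in the hunk above.
    pipe = Pipeline.from_file(ctx, pipeline_yaml)

    samples = Dataset.from_list([{"document": "seed text"}])  # illustrative input

    # `logger=log` rebinds pipeline.py's LOGGER, so "Running block: ..." and
    # the batching messages are emitted through `log`.
    return pipe.generate(samples, checkpoint_name="demo", logger=log)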
