From 81fad3c637209849ab182d2506a8c6dd7914ac13 Mon Sep 17 00:00:00 2001 From: Khaled Sulayman Date: Wed, 11 Dec 2024 11:42:20 -0500 Subject: [PATCH] allow generate_data logger parameter to overwrite locally defined loggers Signed-off-by: Khaled Sulayman --- src/instructlab/sdg/datamixing.py | 29 +++++++------- src/instructlab/sdg/generate_data.py | 33 +++++++++------- src/instructlab/sdg/pipeline.py | 15 +++++--- src/instructlab/sdg/utils/taxonomy.py | 55 +++++++++++++++------------ 4 files changed, 75 insertions(+), 57 deletions(-) diff --git a/src/instructlab/sdg/datamixing.py b/src/instructlab/sdg/datamixing.py index e6ca8675..25da098c 100644 --- a/src/instructlab/sdg/datamixing.py +++ b/src/instructlab/sdg/datamixing.py @@ -22,7 +22,7 @@ # when |knowledge| << |skills| MIN_UPSAMPLE_THRESHOLD = 0.03 ALLOWED_COLS = ["id", "messages", "metadata"] -logger = logging.getLogger(__name__) +LOGGER = logging.getLogger(__name__) class DatasetListing(TypedDict): @@ -40,7 +40,7 @@ def _adjust_train_sample_size(ds: Dataset, num_samples: int): Return a dataset with num_samples random samples selected from the original dataset. """ - logger.info(f"Rebalancing dataset to have {num_samples} samples ...") + LOGGER.info(f"Rebalancing dataset to have {num_samples} samples ...") df = ds.to_pandas() df = df.sample(n=num_samples, random_state=42, replace=True) return pandas.dataset_from_pandas_dataframe(df) @@ -135,10 +135,10 @@ def _load_ds(self, path): """ if not os.path.isabs(path): path = os.path.join(os.path.dirname(self.recipe_path), path) - logger.info(f"Loading dataset from {path} ...") + LOGGER.info(f"Loading dataset from {path} ...") dataset = load_dataset("json", data_files=path, split="train") - logger.info(f"Dataset columns: {dataset.column_names}") - logger.info(f"Dataset loaded with {len(dataset)} samples") + LOGGER.info(f"Dataset columns: {dataset.column_names}") + LOGGER.info(f"Dataset loaded with {len(dataset)} samples") return dataset def _load_and_sample_datasets(self, num_proc): @@ -161,7 +161,7 @@ def _create_mixed_dataset(self, num_proc): concatenating all datasets in this recipe """ if not self.dataset_added: - logger.error("No dataset added to the recipe") + LOGGER.error("No dataset added to the recipe") mixed_ds = self._load_and_sample_datasets(num_proc) mixed_ds = concatenate_datasets(mixed_ds) @@ -212,7 +212,7 @@ def save_mixed_dataset(self, output_path, num_proc): """ mixed_ds = self._create_mixed_dataset(num_proc) mixed_ds.to_json(output_path, orient="records", lines=True) - logger.info(f"Mixed Dataset saved to {output_path}") + LOGGER.info(f"Mixed Dataset saved to {output_path}") def _unescape(s): @@ -235,7 +235,7 @@ def _get_question_hack(synth_example): parts = synth_example["output"].split("?", 1) if len(parts) != 2: - logger.warning(f"Failed to split generated q&a: {synth_example['output']}") + LOGGER.warning(f"Failed to split generated q&a: {synth_example['output']}") return parts[0].strip() + "?" if len(parts) == 2 else "" @@ -251,7 +251,7 @@ def _get_response_hack(synth_example): parts = synth_example["output"].split("?", 1) if len(parts) != 2: - logger.warning(f"Failed to split generated q&a: {synth_example['output']}") + LOGGER.warning(f"Failed to split generated q&a: {synth_example['output']}") return parts[1].strip() if len(parts) == 2 else parts[0].strip() @@ -333,7 +333,7 @@ def __pick_documents(rec, p): selected_docs = [e for e in all_context if e != answer_document] if len(selected_docs) > 0: if len(selected_docs) < num_doc_in_context: - logger.debug( + LOGGER.debug( f"Number of unique documents is {len(selected_docs)} which is less than {num_doc_in_context}. Using all the documents in the expanded context." ) if random.uniform(0, 1) < p: @@ -352,7 +352,7 @@ def __pick_documents(rec, p): else selected_docs ) else: - logger.warning( + LOGGER.warning( "Only 1 unique document found. Disabling expanded context injection, which may lead to poorer knowledge retention results." ) docs = [answer_document] @@ -697,7 +697,7 @@ def collect( if knowledge_to_skills_ratio < MIN_UPSAMPLE_THRESHOLD: sampling_size = int(self._precomputed_skills_length * 0.03) - logger.info( + LOGGER.info( "\033[93mKnowledge detected to be less than %.2f%% of skills (%.2f%%), upsampling to: %d\033[0m", MIN_UPSAMPLE_THRESHOLD * 100, knowledge_to_skills_ratio * 100, @@ -739,7 +739,10 @@ def _gen_mixed_data(self, recipe, output_file_recipe, output_file_data): self.num_procs, ) - def generate(self): + def generate(self, logger=None): + if logger is not None: + global LOGGER # pylint: disable=global-statement + LOGGER = logger self._gen_mixed_data( self.knowledge_recipe, self.output_file_knowledge_recipe, diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py index 533db868..ae5c6582 100644 --- a/src/instructlab/sdg/generate_data.py +++ b/src/instructlab/sdg/generate_data.py @@ -34,7 +34,7 @@ read_taxonomy_leaf_nodes, ) -logger = logging.getLogger(__name__) +LOGGER = logging.getLogger(__name__) _SYS_PROMPT = "I am a Red Hat® Instruct Model, an AI language model developed by Red Hat and IBM Research based on the granite-3.0-8b-base model. My primary role is to serve as a chat assistant." @@ -90,7 +90,7 @@ def _gen_train_data( for output_dataset in machine_instruction_data: for synth_example in output_dataset: - logger.debug(synth_example) + LOGGER.debug(synth_example) user = _get_question_hack(synth_example) if len(synth_example.get("context", "")) > 0: user += "\n" + synth_example["context"] @@ -223,7 +223,7 @@ def _sdg_init(ctx, pipeline): config = yaml.safe_load(file) docling_model_path = config["models"][0]["path"] except (FileNotFoundError, NotADirectoryError, PermissionError) as e: - logger.warning(f"unable to read docling models path from config.yaml {e}") + LOGGER.warning(f"unable to read docling models path from config.yaml {e}") for d in data_dirs: pipeline_path = os.path.join(d, "pipelines", pipeline) @@ -285,7 +285,7 @@ def _mixer_init( # to be removed: logger def generate_data( client: openai.OpenAI, - logger: logging.Logger = logger, # pylint: disable=redefined-outer-name + logger: logging.Logger = None, # pylint: disable=redefined-outer-name system_prompt: Optional[str] = None, use_legacy_pretraining_format: Optional[bool] = True, model_family: Optional[str] = None, @@ -318,6 +318,10 @@ def generate_data( We expect three files to be present in this directory: "knowledge.yaml", "freeform_skills.yaml", and "grounded_skills.yaml". """ + if logger is not None: + global LOGGER # pylint: disable=global-statement + LOGGER = logger + generate_start = time.time() system_prompt = system_prompt if system_prompt is not None else _SYS_PROMPT @@ -336,7 +340,7 @@ def generate_data( document_output_dir = Path(output_dir) / f"documents-{date_suffix}" leaf_nodes = read_taxonomy_leaf_nodes( - taxonomy, taxonomy_base, yaml_rules, document_output_dir + taxonomy, taxonomy_base, yaml_rules, document_output_dir, logger=LOGGER ) if not leaf_nodes: raise GenerateException("Error: No new leaf nodes found in the taxonomy.") @@ -352,7 +356,7 @@ def generate_data( system_prompt, ) - logger.debug(f"Generating to: {os.path.join(output_dir, output_file_test)}") + LOGGER.debug(f"Generating to: {os.path.join(output_dir, output_file_test)}") model_family = models.get_model_family(model_family, model_name) @@ -385,7 +389,7 @@ def generate_data( ) if console_output: - logger.info( + LOGGER.info( "Synthesizing new instructions. If you aren't satisfied with the generated instructions, interrupt training (Ctrl-C) and try adjusting your YAML files. Adding more examples may help." ) @@ -402,6 +406,7 @@ def generate_data( document_output_dir, model_name, docling_model_path=docling_model_path, + logger=LOGGER, ) if not samples: @@ -417,17 +422,17 @@ def generate_data( else: pipe = freeform_skills_pipe - logger.debug("Samples: %s", samples) + LOGGER.debug("Samples: %s", samples) new_generated_data = pipe.generate(samples, leaf_node_path) if len(new_generated_data) == 0: empty_sdg_leaf_nodes.append(leaf_node_path) - logger.warning("Empty dataset for qna node: %s", leaf_node_path) + LOGGER.warning("Empty dataset for qna node: %s", leaf_node_path) continue generated_data.append(new_generated_data) - logger.info("Generated %d samples", len(generated_data)) - logger.debug("Generated data: %s", generated_data) + LOGGER.info("Generated %d samples", len(generated_data)) + LOGGER.debug("Generated data: %s", generated_data) if is_knowledge: # generate mmlubench data for the current leaf node @@ -453,12 +458,12 @@ def generate_data( system_prompt, ) - mixer.generate() + mixer.generate(logger=LOGGER) generate_duration = time.time() - generate_start - logger.info(f"Generation took {generate_duration:.2f}s") + LOGGER.info(f"Generation took {generate_duration:.2f}s") if len(empty_sdg_leaf_nodes) > 0: - logger.warning( + LOGGER.warning( "Leaf nodes with empty sdg output: {}".format( " ".join(empty_sdg_leaf_nodes) ) diff --git a/src/instructlab/sdg/pipeline.py b/src/instructlab/sdg/pipeline.py index 59613a8e..9a6e497c 100644 --- a/src/instructlab/sdg/pipeline.py +++ b/src/instructlab/sdg/pipeline.py @@ -23,7 +23,7 @@ from .blocks.block import Block from .registry import BlockRegistry -logger = logging.getLogger(__name__) +LOGGER = logging.getLogger(__name__) # This is part of the public API. @@ -134,13 +134,16 @@ def from_file(cls, ctx, pipeline_yaml): pipeline_yaml = os.path.join(resources.files(__package__), pipeline_yaml) return cls(ctx, pipeline_yaml, *_parse_pipeline_config_file(pipeline_yaml)) - def generate(self, dataset, checkpoint_name=None) -> Dataset: + def generate(self, dataset, checkpoint_name=None, logger=None) -> Dataset: """ Generate the dataset by running the pipeline steps. dataset: the input dataset checkpoint_name: unique subdir name for the checkpoint within checkpoint_dir """ + if logger is not None: + global LOGGER # pylint: disable=global-statement + LOGGER = logger # The checkpointer allows us to resume from where we left off # Saving the output of pipe instances along the way checkpoint_dir = None @@ -153,12 +156,12 @@ def generate(self, dataset, checkpoint_name=None) -> Dataset: # If not batching, simply delegate to _generate_single if not self.ctx.batching_enabled: - logger.info("Running pipeline single-threaded") + LOGGER.info("Running pipeline single-threaded") return self._generate_single(dataset) # Otherwise, split the dataset into batches and run each batch as a # future in the thread pool - logger.info( + LOGGER.info( "Running pipeline with multi-threaded batching. Using %s workers for batches of size %s", self.ctx.batch_num_workers, self.ctx.batch_size, @@ -197,7 +200,7 @@ def _generate_single(self, dataset) -> Dataset: drop_columns = block_prop.get("drop_columns", []) drop_duplicates_cols = block_prop.get("drop_duplicates", False) block = block_type(self.ctx, self, block_name, **block_config) - logger.info("Running block: %s", block_name) + LOGGER.info("Running block: %s", block_name) # Execute the block and wrap errors with the block name/type dataset = block.generate(dataset) except Exception as err: @@ -284,7 +287,7 @@ def _parse_pipeline_config_file(pipeline_yaml): "The pipeline config file format is from a future major version." ) if major <= _PIPELINE_CONFIG_PARSER_MAJOR and minor > _PIPELINE_CONFIG_PARSER_MINOR: - logger.warning( + LOGGER.warning( "The pipeline config file may have new features that will be ignored." ) diff --git a/src/instructlab/sdg/utils/taxonomy.py b/src/instructlab/sdg/utils/taxonomy.py index 24592fe1..88c575e2 100644 --- a/src/instructlab/sdg/utils/taxonomy.py +++ b/src/instructlab/sdg/utils/taxonomy.py @@ -30,7 +30,7 @@ # Initialize the pdf parser PDFParser = pdf_parser_v1() -logger = logging.getLogger(__name__) +LOGGER = logging.getLogger(__name__) def _is_taxonomy_file(fn: str) -> bool: @@ -41,7 +41,7 @@ def _is_taxonomy_file(fn: str) -> bool: return True if path.name.casefold() in {"qna.yml", "qna.yaml"}: # warning for incorrect extension or case variants - logger.warning( + LOGGER.warning( "Found a '%s' file: %s: taxonomy files must be named 'qna.yaml'. File will not be checked.", path.name, path, @@ -145,17 +145,17 @@ def _get_documents( file_contents = [] filepaths = [] - logger.info("Processing files...") + LOGGER.info("Processing files...") for pattern in file_patterns: # Use glob to find files matching the pattern matched_files = glob.glob( os.path.join(repo.working_dir, pattern), recursive=True ) - logger.info(f"Pattern '{pattern}' matched {len(matched_files)} files.") + LOGGER.info(f"Pattern '{pattern}' matched {len(matched_files)} files.") for file_path in matched_files: if os.path.isfile(file_path): - logger.info(f"Processing file: {file_path}") + LOGGER.info(f"Processing file: {file_path}") try: if file_path.lower().endswith(".md"): # Process Markdown files @@ -163,24 +163,24 @@ def _get_documents( content = file.read() file_contents.append(content) filepaths.append(Path(file_path)) - logger.info( + LOGGER.info( f"Appended Markdown content from {file_path}" ) elif file_path.lower().endswith(".pdf"): # Process PDF files using docling_parse's pdf_parser_v1 doc_key = f"key_{os.path.basename(file_path)}" # Unique document key - logger.info(f"Loading PDF document from {file_path}") + LOGGER.info(f"Loading PDF document from {file_path}") success = PDFParser.load_document(doc_key, file_path) if not success: - logger.warning( + LOGGER.warning( f"Failed to load PDF document: {file_path}" ) continue num_pages = PDFParser.number_of_pages(doc_key) - logger.info(f"PDF '{file_path}' has {num_pages} pages.") + LOGGER.info(f"PDF '{file_path}' has {num_pages} pages.") pdf_text = "" @@ -190,7 +190,7 @@ def _get_documents( doc_key, page ) if "pages" not in json_doc or not json_doc["pages"]: - logger.warning( + LOGGER.warning( f"Page {page + 1} could not be parsed in '{file_path}'" ) continue @@ -205,7 +205,7 @@ def _get_documents( if text.strip(): # Only append non-empty text pdf_text += text.strip() + "\n" except Exception as page_error: # pylint: disable=broad-exception-caught - logger.warning( + LOGGER.warning( f"Error parsing page {page + 1} of '{file_path}': {page_error}" ) continue @@ -216,24 +216,24 @@ def _get_documents( # Unload the document to free memory PDFParser.unload_document(doc_key) - logger.info(f"Unloaded PDF document: {file_path}") + LOGGER.info(f"Unloaded PDF document: {file_path}") else: - logger.info(f"Skipping unsupported file type: {file_path}") + LOGGER.info(f"Skipping unsupported file type: {file_path}") except Exception as file_error: # pylint: disable=broad-exception-caught - logger.error( + LOGGER.error( f"Error processing file '{file_path}': {file_error}" ) continue else: - logger.info(f"Skipping non-file path: {file_path}") + LOGGER.info(f"Skipping non-file path: {file_path}") if file_contents: return file_contents, filepaths raise SystemExit("Couldn't find knowledge documents") except (OSError, git.exc.GitCommandError, FileNotFoundError) as e: - logger.error("Error retrieving documents: %s", str(e)) + LOGGER.error("Error retrieving documents: %s", str(e)) raise e @@ -273,7 +273,7 @@ def _read_taxonomy_file( source=documents, document_output_dir=unique_output_dir, ) - logger.debug("Content from git repo fetched") + LOGGER.debug("Content from git repo fetched") for seed_example in contents.get("seed_examples"): context = seed_example.get("context", "") @@ -321,10 +321,10 @@ def read_taxonomy( if yaml_rules is not None: # user attempted to pass custom rules file yaml_rules_path = Path(yaml_rules) if yaml_rules_path.is_file(): # file was found, use specified config - logger.debug("Using YAML rules from %s", yaml_rules) + LOGGER.debug("Using YAML rules from %s", yaml_rules) yamllint_config = yaml_rules_path.read_text(encoding="utf-8") else: - logger.debug("Cannot find %s. Using default rules.", yaml_rules) + LOGGER.debug("Cannot find %s. Using default rules.", yaml_rules) seed_instruction_data = [] is_file = os.path.isfile(taxonomy) @@ -333,7 +333,7 @@ def read_taxonomy( taxonomy, yamllint_config, document_output_dir ) if warnings: - logger.warning( + LOGGER.warning( f"{warnings} warnings (see above) due to taxonomy file not (fully) usable." ) if errors: @@ -348,9 +348,9 @@ def read_taxonomy( total_errors = 0 total_warnings = 0 if taxonomy_files: - logger.debug("Found taxonomy files:") + LOGGER.debug("Found taxonomy files:") for e in taxonomy_files: - logger.debug(f"* {e}") + LOGGER.debug(f"* {e}") for f in taxonomy_files: file_path = os.path.join(taxonomy, f) data, warnings, errors = _read_taxonomy_file( @@ -361,7 +361,7 @@ def read_taxonomy( if data: seed_instruction_data.extend(data) if total_warnings: - logger.warning( + LOGGER.warning( f"{total_warnings} warnings (see above) due to taxonomy files that were not (fully) usable." ) if total_errors: @@ -372,8 +372,11 @@ def read_taxonomy( def read_taxonomy_leaf_nodes( - taxonomy, taxonomy_base, yaml_rules, document_output_dir=None + taxonomy, taxonomy_base, yaml_rules, document_output_dir=None, logger=None ): + if logger is not None: + global LOGGER # pylint: disable=global-statement + LOGGER = logger seed_instruction_data = read_taxonomy( taxonomy, taxonomy_base, yaml_rules, document_output_dir ) @@ -463,7 +466,11 @@ def leaf_node_to_samples( document_output_dir, model_name, docling_model_path=None, + logger=None, ): + if logger is not None: + global LOGGER # pylint: disable=global-statement + LOGGER = logger if not leaf_node: return [] if leaf_node[0].get("documents"):