[QEff. Finetuning] Adding finetune_experiemental.py and related files #731
base: ft_experimental
Changes from all commits: b260357, d0d3251, a7278cb, a30e5a1, 04279f2
@@ -4,3 +4,279 @@

```python
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

"""
Main entry point for fine-tuning LLMs using the experimental finetune framework.
"""

import os
from pathlib import Path
from typing import Any, Dict, List, Tuple

from QEfficient.finetune.experimental.core.callbacks import replace_progress_callback
from QEfficient.finetune.experimental.core.component_registry import ComponentFactory
from QEfficient.finetune.experimental.core.config_manager import (
    ConfigManager,
    MasterConfig,
)
from QEfficient.finetune.experimental.core.dataset import SFTDataset  # noqa: F401
from QEfficient.finetune.experimental.core.logger import Logger
from QEfficient.finetune.experimental.core.model import HFModel  # noqa: F401
from QEfficient.finetune.experimental.core.optimizer import prepare_optimizer
from QEfficient.finetune.experimental.core.trainer import sft_trainer
from QEfficient.finetune.experimental.core.utils.peft_utils import convert_peft_config_to_lora_config
from QEfficient.finetune.experimental.core.utils.training_config_utils import prepare_training_config

logger = Logger(__name__)

# Try importing QAIC-specific module, proceed without it if it's unavailable
try:
    import torch_qaic  # noqa: F401
except ImportError as e:
    logger.log_rank_zero(
        f"Unable to import 'torch_qaic' package due to exception: {e}. Moving ahead without the torch_qaic extension.",
        level="warning",
    )


class FineTuningPipeline:
    """
    Main pipeline class for fine-tuning LLMs.
    """

    def __init__(self, config_manager: ConfigManager):
        """
        Initialize the fine-tuning pipeline with configuration.

        Args:
            config_manager: ConfigManager instance with loaded and validated configuration
        """
        self.config_manager = config_manager
        self.config = self.config_manager.config
        self.output_dir = Path(self.config.training["output_dir"])
        self._setup_environment()

    def _setup_environment(self) -> None:
        """Set up environment variables for output directories."""
        os.environ["OUTPUT_DIR"] = str(self.output_dir)
        os.environ["TRACKIO_DIR"] = str(self.output_dir / "trackio_logs")
        os.environ["TENSORBOARD_LOGGING_DIR"] = str(self.output_dir)

    def _create_datasets(self) -> Tuple[Any, Any]:
        """
        Create training and evaluation datasets.

        Returns:
            Tuple of (train_dataset, eval_dataset)
        """
        dataset_config = self.config_manager.get_dataset_config()

        dataset_type = dataset_config.get("dataset_type")
        dataset_name = dataset_config.get("dataset_name")
        train_split = dataset_config.get("train_split", "train")
        test_split = dataset_config.get("test_split", "test")
        seed = self.config.training["seed"]

        # Create a copy of dataset_config excluding keys that are passed explicitly
        # to avoid duplicate keyword arguments when unpacking
        excluded_keys = ("dataset_type", "dataset_name", "split", "seed", "train_split", "test_split")
        dataset_config_copy = {k: v for k, v in dataset_config.items() if k not in excluded_keys}

        # Helper function to create a dataset for a specific split
        def create_dataset_for_split(split_name: str) -> Any:
```
> **Contributor:** I believe `seed`, `dataset_type`, and `dataset_name` remain the same for both train and test. Only `split` needs to be explicitly passed.
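A minimal sketch of the reviewer's point, assuming the same `ComponentFactory.create_dataset` keyword API used in the diff: the shared arguments can be bound once (here via `functools.partial`, a hypothetical refactor that is not part of this PR) so that only the split varies between the two calls.

```python
# Hypothetical refactor illustrating the comment above: bind the shared
# arguments once and vary only the split. Assumes the surrounding names
# (ComponentFactory, dataset_type, dataset_name, seed, dataset_config_copy)
# from the diff are in scope.
from functools import partial

create_split = partial(
    ComponentFactory.create_dataset,
    dataset_type=dataset_type,
    dataset_name=dataset_name,
    seed=seed,
    **dataset_config_copy,
)

train_dataset = create_split(split=train_split)
eval_dataset = create_split(split=test_split)
```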
```python
            return ComponentFactory.create_dataset(
                dataset_type=dataset_type,
                dataset_name=dataset_name,
                split=split_name,
                seed=seed,
                **dataset_config_copy,
            )

        # Create training and evaluation datasets using config values
        train_dataset = create_dataset_for_split(train_split)
        eval_dataset = create_dataset_for_split(test_split)

        return train_dataset, eval_dataset

    def _create_model(self) -> Any:
        """
        Create and load the model instance.

        Returns:
            Model instance with loaded model and tokenizer
        """
        # Get model config as dict
        model_config = self.config_manager.get_model_config()

        # Extract required fields
        model_type = model_config.pop("model_type")
        model_name = model_config.pop("model_name")

        # Filter out PEFT-related fields, these shouldn't be passed to model creation
        excluded_keys = {"use_peft", "peft_config"}
        model_config_kwargs = {k: v for k, v in model_config.items() if k not in excluded_keys}

        model_instance = ComponentFactory.create_model(model_type, model_name, **model_config_kwargs)
        return model_instance

    def _create_optimizer(self) -> Tuple[Any, Dict[str, Any]]:
        """
        Create optimizer configuration.

        Returns:
            Tuple of (optimizer_class, optimizer_kwargs)
        """
        optimizer_config = self.config_manager.get_optimizer_config()
        return prepare_optimizer(optimizer_config)

    def _create_callbacks(self) -> List[Any]:
        """
        Create callback instances from configuration.

        Returns:
            List of callback instances
        """
        callback_config = self.config_manager.get_callback_config()
        callbacks = []

        # callback_config.callbacks is a dictionary of callback configurations
        for callback_name, callback_kwargs in callback_config["callbacks"].items():
            try:
                callback_instance = ComponentFactory.create_callback(callback_name, **callback_kwargs)
                callbacks.append(callback_instance)
            except ValueError as e:
                logger.log_rank_zero(f"Warning: Failed to create callback '{callback_name}': {e}", level="warning")

        return callbacks

    def _create_trainer(
        self,
        model: Any,
        tokenizer: Any,
        train_dataset: Any,
        eval_dataset: Any,
        optimizer_cls_and_kwargs: Tuple[Any, Dict[str, Any]],
        callbacks: List[Any],
        training_config: Dict[str, Any],
    ) -> Any:
        """
        Create and configure the trainer instance.

        Args:
            model: The model to train
            tokenizer: Tokenizer for processing
            train_dataset: Training dataset
            eval_dataset: Evaluation dataset
            optimizer_cls_and_kwargs: Optimizer class and kwargs tuple
            callbacks: List of callbacks
            training_config: Training configuration dictionary

        Returns:
            Trainer instance
        """
        trainer_type = training_config.pop("type")

        # Get PEFT config if enabled
        model_config_dict = self.config_manager.get_model_config()
        peft_config = None
        if model_config_dict.get("use_peft", False):
            peft_config_dataclass = model_config_dict.get("peft_config")
            if peft_config_dataclass is not None:
                peft_config = convert_peft_config_to_lora_config(peft_config_dataclass)

        # Build dependencies for trainer configuration
        dependencies = {}
        if peft_config is not None:
            dependencies["peft_config"] = peft_config
        trainer_cls, args_cls, additional_kwargs = ComponentFactory.create_trainer_config(trainer_type, **dependencies)

        # Clean up training config: remove fields that shouldn't be passed to TrainingArguments
        training_config.pop("device", None)
        # Note: torch_dtype was already converted to fp16/bf16 flag in prepare_training_config
        training_config.pop("deepspeed_config", None)
        training_config.pop("torch_dtype", None)

        # Create trainer arguments instance
        args = args_cls(**training_config)
        # Initialize trainer
        trainer = trainer_cls(
            model=model,
            processing_class=tokenizer,
            args=args,
            compute_loss_func=None,
            train_dataset=train_dataset.dataset,
            eval_dataset=eval_dataset.dataset,
            optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
            callbacks=callbacks,
            **additional_kwargs,
        )

        replace_progress_callback(trainer, callbacks, logger)

        return trainer

    def run(self) -> None:
```
> **Contributor:** This looks really clean, thanks.
| """ | ||
| Execute the complete fine-tuning pipeline. | ||
| """ | ||
| # Validate configuration | ||
| self.config_manager.validate_config() | ||
|
> **Contributor:** This should not be explicitly called. It should happen as part of the `__init__` of the `ConfigManager` class. CC: @tchawada, please do the needful.
>
> **Contributor:** Okay.
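A rough sketch of the reviewer's suggestion, assuming a `ConfigManager` whose internals are not shown in this diff (the `_load_config` helper below is a placeholder name): validation would run automatically at construction time, so callers such as `FineTuningPipeline.run()` would not need to invoke it.

```python
# Hypothetical shape of the suggested change; the real ConfigManager
# internals are not part of this diff, so loader details are elided.
class ConfigManager:
    def __init__(self, config_path: str | None = None):
        # Parse CLI args / YAML into a MasterConfig (placeholder loader).
        self.config = self._load_config(config_path)
        # Validate eagerly so every consumer receives an already-checked config.
        self.validate_config()

    def _load_config(self, config_path):
        ...  # not shown here

    def validate_config(self) -> None:
        ...  # raise on invalid or missing fields
```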
```python
        # Prepare training configuration
        training_config = prepare_training_config(config_manager=self.config_manager)

        # Create datasets
        logger.log_rank_zero("Creating datasets...")
        train_dataset, eval_dataset = self._create_datasets()

        # Create model and tokenizer
        logger.log_rank_zero("Loading model and tokenizer...")
        model_instance = self._create_model()
        model = model_instance.model
        tokenizer = model_instance.tokenizer

        # Create optimizer
        logger.log_rank_zero("Preparing optimizer...")
        optimizer_cls_and_kwargs = self._create_optimizer()

        # Create callbacks
        logger.log_rank_zero("Creating callbacks...")
        callbacks = self._create_callbacks()

        # Create trainer
        logger.log_rank_zero("Initializing trainer...")
        trainer = self._create_trainer(
            model=model,
            tokenizer=tokenizer,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
            callbacks=callbacks,
            training_config=training_config,
        )

        # Start training
        logger.log_rank_zero("Starting training...")
        trainer.train()


def main():
    """
    Main entry point for fine-tuning.

    Parses command-line arguments or config file and runs the fine-tuning pipeline.
    """
    # ConfigManager now handles argument parsing internally via its __init__
    # It will automatically detect and parse:
    # - Command-line args (if len(sys.argv) > 1)
    # - Config file path (if sys.argv[1] ends with .yaml)
    # - Or use defaults if no args provided
    config_manager = ConfigManager()

    # Create and run pipeline - pass ConfigManager directly to avoid redundant wrapping
    pipeline = FineTuningPipeline(config_manager)
    pipeline.run()


if __name__ == "__main__":
    main()
```
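For orientation, a hedged sketch of the configuration shape the pipeline reads, inferred only from the lookups in the code above (`config.training["output_dir"]`, `config.training["seed"]`, `dataset_config.get(...)`, `model_config.pop(...)`, `callback_config["callbacks"]`). The section layout and every value below are placeholders; the real schema lives in `ConfigManager` / `MasterConfig`, which are not part of this diff.

```python
# Illustrative only: key names come from the accessors used in the module above,
# values are placeholders, and the actual schema is defined by MasterConfig.
example_config = {
    "training": {
        "output_dir": "./outputs",  # exported as OUTPUT_DIR / TENSORBOARD_LOGGING_DIR
        "seed": 42,                 # forwarded to dataset creation
    },
    "dataset": {
        "dataset_type": "<registered dataset type>",
        "dataset_name": "<dataset identifier>",
        "train_split": "train",
        "test_split": "test",
    },
    "model": {
        "model_type": "<registered model type>",
        "model_name": "<model identifier>",
        "use_peft": False,
        "peft_config": None,
    },
    "optimizer": {},                 # consumed by prepare_optimizer
    "callbacks": {"callbacks": {}},  # name -> kwargs, forwarded to ComponentFactory.create_callback
}
```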
The second file in this diff modifies the callbacks module:
```diff
@@ -19,7 +19,7 @@
 from transformers.integrations.integration_utils import TensorBoardCallback
 from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState

-from QEfficient.finetune.experimental.core.component_registry import registry
+from QEfficient.finetune.experimental.core.component_registry import ComponentFactory, registry
 from QEfficient.finetune.experimental.core.utils.profiler_utils import (
     get_op_verifier_ctx,
     init_qaic_profiling,

@@ -197,9 +197,39 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra
         self.op_verifier_ctx_step.__exit__(None, None, None)


-def create_callbacks(name: str, **kwargs) -> Any:
-    """Create a callback instance."""
-    callback_class = registry.get_callback(name)
-    if callback_class is None:
-        raise ValueError(f"Unknown callback: {name}. Available: {registry.list_callbacks()}")
-    return callback_class(**kwargs)
+def replace_progress_callback(trainer: Any, callbacks: list[Any], logger: Any = None) -> None:
+    """
+    Replace default ProgressCallback with EnhancedProgressCallback if not already present.
+
+    Args:
+        trainer: Trainer instance
+        callbacks: List of callbacks already added
+        logger: Optional logger instance for warning messages
+    """
+    # Check if EnhancedProgressCallback is already in callbacks
+    has_enhanced = any(callback.__class__.__name__ == "EnhancedProgressCallback" for callback in callbacks)
+
+    if not has_enhanced:
+        try:
+            # Remove default ProgressCallback if present
+            trainer.remove_callback(ProgressCallback)
+        except (AttributeError, ValueError) as e:
+            # Callback not present or method doesn't exist, continue
+            if logger:
+                logger.log_rank_zero(
+                    f"Debug: Could not remove default ProgressCallback: {e}. This is expected if callback is not present.",
+                    level="debug",
+                )
+            pass
+
+        try:
+            # Add EnhancedProgressCallback
+            enhanced_callback = ComponentFactory.create_callback("enhanced_progressbar")
+            trainer.add_callback(enhanced_callback)
+        except Exception as e:
+            if logger:
+                logger.log_rank_zero(f"Warning: Could not add enhanced progress callback: {e}", level="warning")
+            else:
+                import warnings
+
+                warnings.warn(f"Could not add enhanced progress callback: {e}")
```
> **Contributor:** Use the logger for this.
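A small sketch of the change the comment asks for: route the failure message through the rank-zero logger rather than `warnings.warn`, so the behaviour matches the rest of the callback handling. This is only an illustration; it assumes the `Logger` class from `QEfficient.finetune.experimental.core.logger` (used in the entry-point file above) is importable here, and the `_fallback_logger` name is hypothetical.

```python
# Hypothetical rewrite of the except branch above, per the review comment:
# always report through the rank-zero logger instead of warnings.warn.
from QEfficient.finetune.experimental.core.logger import Logger  # assumption: importable in this module

_fallback_logger = Logger(__name__)

try:
    enhanced_callback = ComponentFactory.create_callback("enhanced_progressbar")
    trainer.add_callback(enhanced_callback)
except Exception as e:
    # Use the caller-supplied logger when given, otherwise the module fallback.
    (logger or _fallback_logger).log_rank_zero(
        f"Could not add enhanced progress callback: {e}", level="warning"
    )
```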
> Here we are passing and not blocking. That is fine here, but if the user has provided `device="qaic"` and `torch_qaic` fails to load, then we need to break the execution. This has to happen either here or inside `ConfigManager`. I believe this kind of validation, and all other validations of the config, should reside inside `ConfigManager`. CC: @tchawada
>
> Added in a new PR.
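A possible shape for the check described above, as a sketch only (the `device` field name, the `QAIC_AVAILABLE` flag, and where the check would live, e.g. inside `ConfigManager` validation, are assumptions): if the configuration requests the QAIC device but `torch_qaic` could not be imported, fail fast instead of silently continuing.

```python
# Hypothetical validation, per the comment above. Records whether the
# torch_qaic import at module load time succeeded, then aborts early when
# the config asks for the QAIC device but the package is missing.
try:
    import torch_qaic  # noqa: F401

    QAIC_AVAILABLE = True
except ImportError:
    QAIC_AVAILABLE = False


def validate_device(training_config: dict) -> None:
    """Raise when device="qaic" is requested but torch_qaic is unavailable."""
    device = training_config.get("device", "cpu")
    if device == "qaic" and not QAIC_AVAILABLE:
        raise RuntimeError(
            "device='qaic' was requested but the 'torch_qaic' package could not "
            "be imported; install it or choose a different device."
        )
```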