diff --git a/QEfficient/cloud/finetune_experimental.py b/QEfficient/cloud/finetune_experimental.py index d647b73a6..e423c827a 100644 --- a/QEfficient/cloud/finetune_experimental.py +++ b/QEfficient/cloud/finetune_experimental.py @@ -4,3 +4,279 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- + +""" +Main entry point for fine-tuning LLMs using the experimental finetune framework. +""" + +import os +from pathlib import Path +from typing import Any, Dict, List, Tuple + +from QEfficient.finetune.experimental.core.callbacks import replace_progress_callback +from QEfficient.finetune.experimental.core.component_registry import ComponentFactory +from QEfficient.finetune.experimental.core.config_manager import ( + ConfigManager, + MasterConfig, +) +from QEfficient.finetune.experimental.core.dataset import SFTDataset # noqa: F401 +from QEfficient.finetune.experimental.core.logger import Logger +from QEfficient.finetune.experimental.core.model import HFModel # noqa: F401 +from QEfficient.finetune.experimental.core.optimizer import prepare_optimizer +from QEfficient.finetune.experimental.core.trainer import sft_trainer +from QEfficient.finetune.experimental.core.utils.peft_utils import convert_peft_config_to_lora_config +from QEfficient.finetune.experimental.core.utils.training_config_utils import prepare_training_config + +logger = Logger(__name__) + +# Try importing QAIC-specific module, proceed without it if it's unavailable +try: + import torch_qaic # noqa: F401 +except ImportError as e: + logger.log_rank_zero( + f"Unable to import 'torch_qaic' package due to exception: {e}. Moving ahead without the torch_qaic extension.", + level="warning", + ) + + +class FineTuningPipeline: + """ + Main pipeline class for fine-tuning LLMs. + """ + + def __init__(self, config_manager: ConfigManager): + """ + Initialize the fine-tuning pipeline with configuration. + + Args: + config_manager: ConfigManager instance with loaded and validated configuration + """ + self.config_manager = config_manager + self.config = self.config_manager.config + self.output_dir = Path(self.config.training["output_dir"]) + self._setup_environment() + + def _setup_environment(self) -> None: + """Set up environment variables for output directories.""" + os.environ["OUTPUT_DIR"] = str(self.output_dir) + os.environ["TRACKIO_DIR"] = str(self.output_dir / "trackio_logs") + os.environ["TENSORBOARD_LOGGING_DIR"] = str(self.output_dir) + + def _create_datasets(self) -> Tuple[Any, Any]: + """ + Create training and evaluation datasets. 
+ + Returns: + Tuple of (train_dataset, eval_dataset) + """ + dataset_config = self.config_manager.get_dataset_config() + + dataset_type = dataset_config.get("dataset_type") + dataset_name = dataset_config.get("dataset_name") + train_split = dataset_config.get("train_split", "train") + test_split = dataset_config.get("test_split", "test") + seed = self.config.training["seed"] + + # Create a copy of dataset_config excluding keys that are passed explicitly + # to avoid duplicate keyword arguments when unpacking + excluded_keys = ("dataset_type", "dataset_name", "split", "seed", "train_split", "test_split") + dataset_config_copy = {k: v for k, v in dataset_config.items() if k not in excluded_keys} + + # Helper function to create a dataset for a specific split + def create_dataset_for_split(split_name: str) -> Any: + return ComponentFactory.create_dataset( + dataset_type=dataset_type, + dataset_name=dataset_name, + split=split_name, + seed=seed, + **dataset_config_copy, + ) + + # Create training and evaluation datasets using config values + train_dataset = create_dataset_for_split(train_split) + eval_dataset = create_dataset_for_split(test_split) + + return train_dataset, eval_dataset + + def _create_model(self) -> Any: + """ + Create and load the model instance. + + Returns: + Model instance with loaded model and tokenizer + """ + # Get model config as dict + model_config = self.config_manager.get_model_config() + + # Extract required fields + model_type = model_config.pop("model_type") + model_name = model_config.pop("model_name") + + # Filter out PEFT-related fields, these shouldn't be passed to model creation + excluded_keys = {"use_peft", "peft_config"} + model_config_kwargs = {k: v for k, v in model_config.items() if k not in excluded_keys} + + model_instance = ComponentFactory.create_model(model_type, model_name, **model_config_kwargs) + return model_instance + + def _create_optimizer(self) -> Tuple[Any, Dict[str, Any]]: + """ + Create optimizer configuration. + + Returns: + Tuple of (optimizer_class, optimizer_kwargs) + """ + optimizer_config = self.config_manager.get_optimizer_config() + return prepare_optimizer(optimizer_config) + + def _create_callbacks(self) -> List[Any]: + """ + Create callback instances from configuration. + + Returns: + List of callback instances + """ + callback_config = self.config_manager.get_callback_config() + callbacks = [] + + # callback_config.callbacks is a dictionary of callback configurations + for callback_name, callback_kwargs in callback_config["callbacks"].items(): + try: + callback_instance = ComponentFactory.create_callback(callback_name, **callback_kwargs) + callbacks.append(callback_instance) + except ValueError as e: + logger.log_rank_zero(f"Warning: Failed to create callback '{callback_name}': {e}", level="warning") + + return callbacks + + def _create_trainer( + self, + model: Any, + tokenizer: Any, + train_dataset: Any, + eval_dataset: Any, + optimizer_cls_and_kwargs: Tuple[Any, Dict[str, Any]], + callbacks: List[Any], + training_config: Dict[str, Any], + ) -> Any: + """ + Create and configure the trainer instance. 
+ + Args: + model: The model to train + tokenizer: Tokenizer for processing + train_dataset: Training dataset + eval_dataset: Evaluation dataset + optimizer_cls_and_kwargs: Optimizer class and kwargs tuple + callbacks: List of callbacks + training_config: Training configuration dictionary + + Returns: + Trainer instance + """ + trainer_type = training_config.pop("type") + + # Get PEFT config if enabled + model_config_dict = self.config_manager.get_model_config() + peft_config = None + if model_config_dict.get("use_peft", False): + peft_config_dataclass = model_config_dict.get("peft_config") + if peft_config_dataclass is not None: + peft_config = convert_peft_config_to_lora_config(peft_config_dataclass) + + # Build dependencies for trainer configuration + dependencies = {} + if peft_config is not None: + dependencies["peft_config"] = peft_config + trainer_cls, args_cls, additional_kwargs = ComponentFactory.create_trainer_config(trainer_type, **dependencies) + + # Clean up training config: remove fields that shouldn't be passed to TrainingArguments + training_config.pop("device", None) + # Note: torch_dtype was already converted to fp16/bf16 flag in prepare_training_config + training_config.pop("deepspeed_config", None) + training_config.pop("torch_dtype", None) + + # Create trainer arguments instance + args = args_cls(**training_config) + # Initialize trainer + trainer = trainer_cls( + model=model, + processing_class=tokenizer, + args=args, + compute_loss_func=None, + train_dataset=train_dataset.dataset, + eval_dataset=eval_dataset.dataset, + optimizer_cls_and_kwargs=optimizer_cls_and_kwargs, + callbacks=callbacks, + **additional_kwargs, + ) + + replace_progress_callback(trainer, callbacks, logger) + + return trainer + + def run(self) -> None: + """ + Execute the complete fine-tuning pipeline. + """ + # Validate configuration + self.config_manager.validate_config() + + # Prepare training configuration + training_config = prepare_training_config(config_manager=self.config_manager) + + # Create datasets + logger.log_rank_zero("Creating datasets...") + train_dataset, eval_dataset = self._create_datasets() + + # Create model and tokenizer + logger.log_rank_zero("Loading model and tokenizer...") + model_instance = self._create_model() + model = model_instance.model + tokenizer = model_instance.tokenizer + + # Create optimizer + logger.log_rank_zero("Preparing optimizer...") + optimizer_cls_and_kwargs = self._create_optimizer() + + # Create callbacks + logger.log_rank_zero("Creating callbacks...") + callbacks = self._create_callbacks() + + # Create trainer + logger.log_rank_zero("Initializing trainer...") + trainer = self._create_trainer( + model=model, + tokenizer=tokenizer, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + optimizer_cls_and_kwargs=optimizer_cls_and_kwargs, + callbacks=callbacks, + training_config=training_config, + ) + + # Start training + logger.log_rank_zero("Starting training...") + trainer.train() + + +def main(): + """ + Main entry point for fine-tuning. + + Parses command-line arguments or config file and runs the fine-tuning pipeline. 
+ """ + # ConfigManager now handles argument parsing internally via its __init__ + # It will automatically detect and parse: + # - Command-line args (if len(sys.argv) > 1) + # - Config file path (if sys.argv[1] ends with .yaml) + # - Or use defaults if no args provided + config_manager = ConfigManager() + + # Create and run pipeline - pass ConfigManager directly to avoid redundant wrapping + pipeline = FineTuningPipeline(config_manager) + pipeline.run() + + +if __name__ == "__main__": + main() diff --git a/QEfficient/finetune/experimental/core/callbacks.py b/QEfficient/finetune/experimental/core/callbacks.py index 30659e3bb..bd1ce91c2 100644 --- a/QEfficient/finetune/experimental/core/callbacks.py +++ b/QEfficient/finetune/experimental/core/callbacks.py @@ -19,7 +19,7 @@ from transformers.integrations.integration_utils import TensorBoardCallback from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState -from QEfficient.finetune.experimental.core.component_registry import registry +from QEfficient.finetune.experimental.core.component_registry import ComponentFactory, registry from QEfficient.finetune.experimental.core.utils.profiler_utils import ( get_op_verifier_ctx, init_qaic_profiling, @@ -197,9 +197,39 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra self.op_verifier_ctx_step.__exit__(None, None, None) -def create_callbacks(name: str, **kwargs) -> Any: - """Create a callback instance.""" - callback_class = registry.get_callback(name) - if callback_class is None: - raise ValueError(f"Unknown callback: {name}. Available: {registry.list_callbacks()}") - return callback_class(**kwargs) +def replace_progress_callback(trainer: Any, callbacks: list[Any], logger: Any = None) -> None: + """ + Replace default ProgressCallback with EnhancedProgressCallback if not already present. + + Args: + trainer: Trainer instance + callbacks: List of callbacks already added + logger: Optional logger instance for warning messages + """ + # Check if EnhancedProgressCallback is already in callbacks + has_enhanced = any(callback.__class__.__name__ == "EnhancedProgressCallback" for callback in callbacks) + + if not has_enhanced: + try: + # Remove default ProgressCallback if present + trainer.remove_callback(ProgressCallback) + except (AttributeError, ValueError) as e: + # Callback not present or method doesn't exist, continue + if logger: + logger.log_rank_zero( + f"Debug: Could not remove default ProgressCallback: {e}. 
This is expected if callback is not present.", + level="debug", + ) + pass + + try: + # Add EnhancedProgressCallback + enhanced_callback = ComponentFactory.create_callback("enhanced_progressbar") + trainer.add_callback(enhanced_callback) + except Exception as e: + if logger: + logger.log_rank_zero(f"Warning: Could not add enhanced progress callback: {e}", level="warning") + else: + import warnings + + warnings.warn(f"Could not add enhanced progress callback: {e}") diff --git a/QEfficient/finetune/experimental/core/component_registry.py b/QEfficient/finetune/experimental/core/component_registry.py index d1f948031..043552275 100644 --- a/QEfficient/finetune/experimental/core/component_registry.py +++ b/QEfficient/finetune/experimental/core/component_registry.py @@ -6,7 +6,7 @@ # ----------------------------------------------------------------------------- import logging -from typing import Callable, Dict, Optional, Type +from typing import Any, Callable, Dict, Optional, Type # from QEfficient.finetune.experimental.core.logger import get_logger @@ -201,10 +201,77 @@ def list_callbacks(self) -> list[str]: class ComponentFactory: @staticmethod - def create_model(model_type: str, model_name: str, **kwargs) -> any: + def create_model(model_type: str, model_name: str, **kwargs) -> Any: """Create a model instance.""" model_class = registry.get_model(model_type) if model_class is None: raise ValueError(f"Unknown model: {model_type}. Available: {registry.list_models()}") model_instance = model_class.create(model_name, **kwargs) return model_instance + + @staticmethod + def create_trainer_config(name: str, **dependencies) -> tuple: + """ + Create trainer configuration based on registered trainer modules. + + Args: + name: Name of the trainer type + **dependencies: Any dependencies needed to configure the trainer + + Returns: + tuple: (trainer_class, args_class, additional_kwargs) + """ + config = registry.get_trainer_module(name) + + # Process required kwargs based on available dependencies + additional_kwargs = {} + for kwarg, default in config["required_kwargs"].items(): + if kwarg in dependencies: + additional_kwargs[kwarg] = dependencies[kwarg] + elif default != "REQUIRED": + additional_kwargs[kwarg] = default + + # Check for missing required arguments + for kwarg, default in config["required_kwargs"].items(): + if kwarg not in additional_kwargs and default == "REQUIRED": + raise ValueError(f"Required argument '{kwarg}' not provided for trainer '{name}'") + + return config["trainer_cls"], config["args_cls"], additional_kwargs + + @staticmethod + def create_dataset(dataset_type: str, dataset_name: str, split: str, seed: int = 42, **kwargs) -> Any: + """ + Create a dataset instance. + + Args: + dataset_type: Type of dataset to create (e.g., 'sft_dataset') + dataset_name: Name of the dataset to load + split: Dataset split ("train", "test", etc.) + seed: Random seed for reproducibility + **kwargs: Additional dataset configuration parameters + + Returns: + Dataset instance + """ + dataset_class = registry.get_dataset(dataset_type) + if dataset_class is None: + raise ValueError(f"Unknown dataset type: {dataset_type}. Available: {registry.list_datasets()}") + dataset_instance = dataset_class(dataset_name=dataset_name, split=split, seed=seed, **kwargs) + return dataset_instance + + @staticmethod + def create_callback(name: str, **kwargs) -> Any: + """ + Create a callback instance. 
+ + Args: + name: Name of the callback to create + **kwargs: Additional callback configuration parameters + + Returns: + Callback instance + """ + callback_class = registry.get_callback(name) + if callback_class is None: + raise ValueError(f"Unknown callback: {name}. Available: {registry.list_callbacks()}") + return callback_class(**kwargs) diff --git a/QEfficient/finetune/experimental/core/config_manager.py b/QEfficient/finetune/experimental/core/config_manager.py index b28c2e1e3..5b5a8a819 100644 --- a/QEfficient/finetune/experimental/core/config_manager.py +++ b/QEfficient/finetune/experimental/core/config_manager.py @@ -11,6 +11,7 @@ import json import os +import sys from dataclasses import asdict, dataclass, field, fields, is_dataclass from pathlib import Path from typing import Any, Dict, List, Optional, Union @@ -18,7 +19,9 @@ import yaml from transformers.hf_argparser import HfArgumentParser -from QEfficient.finetune.experimental.core.component_registry import registry +from QEfficient.finetune.experimental.core.logger import Logger + +logger = Logger(__name__) @dataclass @@ -55,6 +58,10 @@ class SchedulerConfig: "ratio of total training steps for the warmup phase." }, ) + warmup_ratio: int = field( + default=0.1, + metadata={"help": "ratio of total training steps for the warmup phase. value is within [0-1) range."}, + ) @dataclass @@ -70,7 +77,7 @@ class DatasetConfig: metadata={"help": "The type of dataset (e.g., 'seq_completion')."}, ) dataset_name: str = field( - default="knkarthick/samsum", + default="yahma/alpaca-cleaned", metadata={"help": "The name or path of the dataset."}, ) dataset_subset: str = field( @@ -93,7 +100,7 @@ class DatasetConfig: default=0.8, metadata={"help": "Ratio for train/test split, used when only train_split is provided."}, ) - input_columns: list[str] = field( + input_columns: List[str] = field( default_factory=lambda: ["text"], metadata={"help": "List of column names containing input text."}, ) @@ -113,6 +120,22 @@ class DatasetConfig: default=4, metadata={"help": "Number of workers for dataset processing."}, ) + prompt_template: str = field( + default=None, + metadata={"help": "Template for formatting prompts (e.g., 'User: {input} Assistant: ')."}, + ) + prompt_func: str = field( + default=None, + metadata={"help": "Function for formatting prompts (e.g., 'User: {input} Assistant: ')."}, + ) + completion_template: str = field( + default=None, + metadata={"help": "Template for formatting output completions (e.g., '{output}')."}, + ) + completion_func: str = field( + default=None, + metadata={"help": "Function for formatting output completions (e.g., '{output}')."}, + ) collate_fn: str = field( default="dynamic_padding", metadata={"help": "The collation function to use (e.g., 'dynamic_padding')."}, @@ -145,6 +168,11 @@ class DatasetConfig: default=1, metadata={"help": "Number of workers for the DataLoader."}, ) + config_name: str = field( + default="default", + metadata={"help": "Name of the hf configuration file."}, + ) + json_file_path: str = field(default=None, metadata={"help": "Path to a JSON file containing data."}) @dataclass @@ -163,7 +191,7 @@ class PeftConfig: default=0.1, metadata={"help": "The dropout probability for Lora layers."}, ) - target_modules: list[str] = field( + target_modules: List[str] = field( default_factory=lambda: ["q_proj", "v_proj"], metadata={"help": "The modules to apply Lora to."}, ) @@ -252,7 +280,7 @@ class DdpConfig: """Arguments for Distributed Data Parallel (DDP) training.""" ddp_backend: str = field( - 
default="qccl", + default=None, metadata={"help": "The DDP backend to use (e.g., 'nccl', 'gloo', 'qccl')."}, ) ddp_find_unused_parameters: bool = field( @@ -293,10 +321,6 @@ class TrainingConfig: default=42, metadata={"help": "Random seed for reproducibility."}, ) - device: str = field( - default="qaic", - metadata={"help": "The device to use for training ('cuda', 'cpu', etc.)."}, - ) do_eval: bool = field( default=True, metadata={"help": "Whether to run evaluation during training."}, @@ -329,7 +353,6 @@ class TrainingConfig: default=-1, metadata={"help": "If > 0: set total number of training steps to perform."}, ) - log_level: str = field( default="info", metadata={"help": "Set the verbosity level of the logs ('debug', 'info', 'warning', 'error')."}, @@ -363,12 +386,6 @@ class TrainingConfig: default="eval_loss", metadata={"help": "The metric to use to compare two models ('eval_loss', etc.)."}, ) - - dtype: str = field( - default="fp16", - metadata={"help": "The data type to use for training (e.g., 'fp16', 'bf16')."}, - ) - gradient_checkpointing: bool = field( default=False, metadata={"help": "Whether to use gradient checkpointing."}, @@ -377,9 +394,16 @@ class TrainingConfig: default_factory=GradientCheckpointingKwargs, metadata={"help": "Arguments for gradient checkpointing."}, ) - + device: str = field( + default="qaic", + metadata={"help": "The device to use for training ('cuda', 'cpu', etc.)."}, + ) + torch_dtype: str = field( + default="fp16", + metadata={"help": "The torch data type to use for model weights (e.g., 'fp32', 'fp16', 'bf16')."}, + ) torch_compile: bool = field( - default=True, + default=False, metadata={"help": "Whether to compile the model with `torch.compile`."}, ) include_num_input_tokens_seen: bool = field( @@ -412,7 +436,7 @@ class TrainingConfig: metadata={"help": "DDP configuration dictionary."}, ) use_cpu: Optional[bool] = field( - default=None, + default=False, metadata={"help": "Whether to explicitly run training on CPU."}, ) resume_from_checkpoint: Optional[str] = field( @@ -460,47 +484,85 @@ class MasterConfig: ) -def parse_arguments(config_path: Optional[str] = None, args: Optional[List[str]] = None) -> MasterConfig: - """Create argument parser for the new finetuning interface.""" - parser = HfArgumentParser(MasterConfig) - - if config_path: - config_path = os.path.abspath(config_path) - if not os.path.exists(config_path): - raise FileNotFoundError(f"Config file not found: {config_path}") - if not (config_path.endswith(".yaml") or config_path.endswith(".yml")): - raise ValueError(f"Expected a .yaml/.yml file, got: {config_path}") - - try: - (master_config,) = parser.parse_yaml_file(yaml_file=config_path) - return master_config - except Exception as e: - raise ValueError(f"Failed to parse YAML config '{config_path}': {e}") - - args = [] if args is None else args - # If a single positional YAML file was passed via args, parse it as YAML - if len(args) == 1 and (args[0].endswith(".yaml") or args[0].endswith(".yml")): - yaml_path = os.path.abspath(args[0]) - (master_config,) = parser.parse_yaml_file(yaml_file=yaml_path) - else: - (master_config,) = parser.parse_args_into_dataclasses(args=args) - master_config = asdict(master_config) - master_config = MasterConfig(**master_config) - - return master_config - - class ConfigManager: """Manages configuration loading, validation, and updates.""" - def __init__(self, config: MasterConfig): + def __init__(self, config: Optional[MasterConfig] = None, config_path: Optional[str] = None): """ Initialize ConfigManager 
with either: - Path to config file (str or Path) - Configuration dictionary - - None (creates empty config) """ - self.config = config + if config: + self.config = config + else: + self.config = MasterConfig() + + if config_path and not config: + logger.log_rank_zero("Loading configuration from config_path...") + config_path = os.path.abspath(config_path) + if not os.path.exists(config_path): + raise FileNotFoundError(f"Config file not found: {config_path}") + if not (config_path.endswith(".yaml") or config_path.endswith(".yml")): + raise ValueError(f"Expected a .yaml/.yml file, got: {config_path}") + try: + self.load_config(config_path) + except Exception as e: + raise ValueError(f"Failed to parse YAML config '{config_path}': {e}") + + elif config and not config_path: + logger.log_rank_zero("Loading configuration from config object...") + + elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): + logger.log_rank_zero("Loading configuration from config_path from CLI...") + config_path = os.path.abspath(sys.argv[1]) + if not os.path.exists(config_path): + raise FileNotFoundError(f"Config file not found: {config_path}") + try: + self.load_config(config_path) + except Exception as e: + raise ValueError(f"Failed to parse YAML config '{config_path}': {e}") + + elif len(sys.argv) > 2: + logger.log_rank_zero("Loading configuration flags from CLI...") + parser = HfArgumentParser( + ( + TrainingConfig, + ModelConfig, + DatasetConfig, + OptimizerConfig, + SchedulerConfig, + CallbackConfig, + PeftConfig, + DdpConfig, + GradientCheckpointingKwargs, + ) + ) + train_args, model_args, data_args, opt_args, schd_args, call_args, peft_args, ddp_args, gck_args, extra = ( + parser.parse_args_into_dataclasses(return_remaining_strings=True) + ) + train_args.ddp_config = ddp_args + train_args.gradient_checkpointing_kwargs = gck_args + model_args.peft_config = peft_args + self.config = MasterConfig( + model=model_args, + dataset=data_args, + training=train_args, + callbacks=call_args, + optimizers=opt_args, + scheduler=schd_args, + extra_params=extra, + ) + + else: + logger.log_rank_zero("Using default configuration...") + self.config = asdict(self.config) + self.config = MasterConfig(**self.config) + # Validate loaded config + try: + self.validate_config() + except Exception as e: + logger.log_rank_zero(f"Config validation failed with error: {e}") def load_config(self, config_path: Union[str, Path]) -> None: """Load configuration from file.""" @@ -517,7 +579,6 @@ def load_config(self, config_path: Union[str, Path]) -> None: config_dict = json.load(f) else: raise ValueError(f"Unsupported configuration file format: {config_path.suffix}") - self.update_config(config_dict) def _ensure_extra_params(self, obj) -> Dict[str, Any]: @@ -598,16 +659,20 @@ def validate_config(self) -> None: """ Validate configuration parameters for MasterConfig. 
""" + cfg = self.config errors: List[str] = [] - cfg = self.config model = getattr(cfg, "model", {}) dataset = getattr(cfg, "dataset", {}) training = getattr(cfg, "training", {}) # ---------- Model ---------- self._push(errors, not model.get("model_name"), "model.model_name is required.") - + # Device + valid_devices = ["cpu", "cuda", "qaic"] + training_device = model.get("device", "qaic") + if training_device not in valid_devices: + self._push(errors, training_device not in valid_devices, f"training.device must be one of {valid_devices}.") # PEFT validation if model.get("use_peft"): pc = model.get("peft_config", {}) @@ -632,34 +697,46 @@ def validate_config(self) -> None: # ---------- Dataset ---------- self._push(errors, not dataset.get("dataset_name"), "dataset.dataset_name is required.") self._push(errors, not dataset.get("tokenizer_name"), "dataset.tokenizer_name is required.") - self._push(errors, dataset.get("max_seq_length", 0) <= 0, "dataset.max_seq_length must be positive.") # ---------- Training ---------- + # torch_dtype validation + torch_dtype = training.get("torch_dtype") + valid_dtypes = {"fp16", "bf16", "fp32"} + self._push( + errors, + not torch_dtype, + "training.torch_dtype is required.", + ) + self._push( + errors, + torch_dtype and torch_dtype not in valid_dtypes, + f"training.torch_dtype must be one of {valid_dtypes}.", + ) + # Batch sizes self._push( errors, - training.get("per_device_train_batch_size", 0) <= 0, + training.get("per_device_train_batch_size", 1) <= 0, "training.per_device_train_batch_size must be positive.", ) self._push( errors, - training.get("per_device_eval_batch_size", 0) <= 0, + training.get("per_device_eval_batch_size", 1) <= 0, "training.per_device_eval_batch_size must be positive.", ) # Epochs / steps - n_epochs = training.get("num_train_epochs", 0) - max_steps = training.get("max_steps", -1) + n_epochs = training.get("num_train_epochs", 1) self._push( errors, - n_epochs <= 0 and max_steps <= 0, - "Either training.num_train_epochs > 0 or training.max_steps > 0 must be set.", + n_epochs <= 0, + "Either training.num_train_epochs > 0 must be set.", ) # Gradient accumulation self._push( errors, - training.get("gradient_accumulation_steps", 0) <= 0, + training.get("gradient_accumulation_steps", 1) <= 0, "training.gradient_accumulation_steps must be positive.", ) @@ -667,12 +744,6 @@ def validate_config(self) -> None: self._push(errors, training.get("logging_steps", 0) < 0, "training.logging_steps must be >= 0.") self._push(errors, training.get("save_total_limit", 0) < 0, "training.save_total_limit must be >= 0.") - # Device - valid_devices = ["cpu", "cuda", "qaic"] - training_device = training.get("device", None) - if training_device not in valid_devices: - self._push(errors, training_device not in valid_devices, f"training.device must be one of {valid_devices}.") - # DDP config ddp = training.get("ddp_config", {}) if isinstance(ddp, dict): @@ -710,8 +781,24 @@ def get_dataset_config(self) -> Dict[str, Any]: return self.config.dataset def get_model_config(self) -> Dict[str, Any]: - """Get model configuration as dictionary.""" - return self.config.model + """ + Get model configuration as dictionary. + + Automatically handles torch_dtype conversion from training config if not set in model config. 
+ """ + model_config = self.config.model + + # Get torch_dtype from training config and convert + # To do: check if it can be moved from training config to model config instead + if model_config.get("torch_dtype") is None: + training_config = self.get_training_config() + training_dtype = training_config.get("torch_dtype") + if training_dtype: + # Convert from training format (fp16/bf16) to model format (float16/bfloat16) + dtype_mapping = {"fp16": "float16", "bf16": "bfloat16"} + model_config["torch_dtype"] = dtype_mapping.get(training_dtype, "auto") + + return model_config def to_dict(self) -> Dict[str, Any]: """Convert configuration to dictionary.""" @@ -722,32 +809,3 @@ def __getattr__(self, name: str) -> Any: if hasattr(self.config, name): return getattr(self.config, name) raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") - - -def create_trainer_config(name: str, **dependencies) -> tuple: - """ - Create trainer configuration based on registered trainer modules. - - Args: - name: Name of the trainer type - **dependencies: Any dependencies needed to configure the trainer - - Returns: - tuple: (trainer_class, args_class, additional_kwargs) - """ - config = registry.get_trainer_module(name) - - # Process required kwargs based on available dependencies - additional_kwargs = {} - for kwarg, default in config["required_kwargs"].items(): - if kwarg in dependencies: - additional_kwargs[kwarg] = dependencies[kwarg] - elif default != "REQUIRED": - additional_kwargs[kwarg] = default - - # Check for missing required arguments - for kwarg, default in config["required_kwargs"].items(): - if kwarg not in additional_kwargs and default == "REQUIRED": - raise ValueError(f"Required argument '{kwarg}' not provided for trainer '{name}'") - - return config["trainer_cls"], config["args_cls"], additional_kwargs diff --git a/QEfficient/finetune/experimental/core/utils/peft_utils.py b/QEfficient/finetune/experimental/core/utils/peft_utils.py new file mode 100644 index 000000000..9c6cfaf3c --- /dev/null +++ b/QEfficient/finetune/experimental/core/utils/peft_utils.py @@ -0,0 +1,47 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Utility functions for PEFT (Parameter-Efficient Fine-Tuning) configuration. +""" + +from dataclasses import asdict +from typing import Any, Optional + +from peft import LoraConfig + + +def convert_peft_config_to_lora_config(peft_config: Any) -> Optional[LoraConfig]: + """ + Convert PeftConfig (dataclass or dict) to LoraConfig from peft library. 
+ + Args: + peft_config: PeftConfig dataclass instance or dict + + Returns: + LoraConfig instance or None if PEFT is not enabled + """ + if peft_config is None: + return None + + # Convert dataclass to dictionary if needed + if hasattr(peft_config, "__dict__") and not isinstance(peft_config, dict): + peft_dict = asdict(peft_config) + else: + peft_dict = peft_config + + # Map PeftConfig fields to LoraConfig fields + lora_config_dict = { + "r": peft_dict.get("lora_r"), + "lora_alpha": peft_dict.get("lora_alpha"), + "lora_dropout": peft_dict.get("lora_dropout"), + "target_modules": peft_dict.get("target_modules"), + "bias": peft_dict.get("bias"), + "task_type": peft_dict.get("task_type"), + } + + return LoraConfig(**lora_config_dict) diff --git a/QEfficient/finetune/experimental/core/utils/training_config_utils.py b/QEfficient/finetune/experimental/core/utils/training_config_utils.py new file mode 100644 index 000000000..1cd6704e4 --- /dev/null +++ b/QEfficient/finetune/experimental/core/utils/training_config_utils.py @@ -0,0 +1,84 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Utility functions for preparing training configurations. +""" + +from typing import Any, Dict + +from QEfficient.finetune.experimental.core.config_manager import ConfigManager + + +def prepare_training_config( + config_manager: ConfigManager, + include_num_input_tokens_seen: bool = False, + use_cpu: bool = False, +) -> Dict[str, Any]: + """ + Prepare and transform training configuration for trainer initialization. + + Args: + config_manager: ConfigManager instance with loaded configuration + + Returns: + Dictionary of training arguments ready for trainer initialization + """ + # Get training config as dict and create mutable copy to avoid mutating original + training_config = dict(config_manager.get_training_config()) + + # Handle dtype conversion + # To do: (For Tanisha) Check if torch_dtype should rather be added directly in model_config only in config_manager.py + + torch_dtype = training_config.pop("torch_dtype", None) + if torch_dtype is None: + raise ValueError("'torch_dtype' field is required in training configuration. 
Expected one of: ['fp16', 'bf16']") + training_config[torch_dtype] = True + training_config["data_seed"] = training_config.get("seed") + + # Restoring the "torch_dtype" after torch_dtype conversion using the saved value + training_config["torch_dtype"] = torch_dtype + + # Handle scheduler configuration + scheduler_config = config_manager.get_scheduler_config() + training_config.setdefault("lr_scheduler_type", scheduler_config.get("scheduler_name")) + + # Set warmup_ratio and warmup_steps from scheduler_config if they exist and are not None + warmup_ratio = scheduler_config.get("warmup_ratio") + if warmup_ratio is not None: + training_config["warmup_ratio"] = warmup_ratio + warmup_steps = scheduler_config.get("warmup_steps") + if warmup_steps is not None: + training_config["warmup_steps"] = warmup_steps + + # Handle dataset configuration for dataloader settings + dataset_config = config_manager.get_dataset_config() + training_config.setdefault("dataloader_pin_memory", dataset_config.get("dataloader_pin_memory")) + training_config.setdefault("dataloader_persistent_workers", dataset_config.get("dataloader_persistent_workers")) + training_config.setdefault("dataloader_prefetch_factor", dataset_config.get("dataloader_prefetch_factor")) + training_config.setdefault("dataloader_drop_last", dataset_config.get("dataloader_drop_last")) + training_config.setdefault("dataloader_num_workers", dataset_config.get("dataloader_num_workers")) + training_config.setdefault("group_by_length", dataset_config.get("group_by_length")) + + # Handle DDP configuration + if training_config.get("ddp_config") is not None: + ddp_config = training_config.pop("ddp_config") + if not isinstance(ddp_config, dict): + from dataclasses import asdict, is_dataclass + + if is_dataclass(ddp_config): + ddp_config = asdict(ddp_config) + else: + raise TypeError( + f"ddp_config must be a dict or DdpConfig dataclass instance, " + f"got {type(ddp_config).__name__}: {ddp_config}" + ) + + # Merge ddp_config into training_config + training_config = {**training_config, **ddp_config} + + return training_config diff --git a/QEfficient/finetune/experimental/tests/test_callback.py b/QEfficient/finetune/experimental/tests/test_callback.py index 59ff4d117..e085da9c9 100644 --- a/QEfficient/finetune/experimental/tests/test_callback.py +++ b/QEfficient/finetune/experimental/tests/test_callback.py @@ -8,8 +8,7 @@ import pytest from transformers import TrainerCallback -from QEfficient.finetune.experimental.core.callbacks import create_callbacks -from QEfficient.finetune.experimental.core.component_registry import registry +from QEfficient.finetune.experimental.core.component_registry import ComponentFactory, registry class ModelSummaryCallback(TrainerCallback): @@ -46,7 +45,7 @@ def test_callbacks(callback_name): # Create callbacks using the factory config = CALLBACK_CONFIGS[callback_name] try: - callback_inst = create_callbacks(**config) + callback_inst = ComponentFactory.create_callback(**config) except ValueError as e: assert "Unknown callback" in str(e) return diff --git a/QEfficient/finetune/experimental/tests/test_config_manager.py b/QEfficient/finetune/experimental/tests/test_config_manager.py index fd2abfd48..b4980ad2c 100644 --- a/QEfficient/finetune/experimental/tests/test_config_manager.py +++ b/QEfficient/finetune/experimental/tests/test_config_manager.py @@ -4,13 +4,11 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- - - from pathlib import Path import pytest -from 
QEfficient.finetune.experimental.core.config_manager import ConfigManager, parse_arguments +from QEfficient.finetune.experimental.core.config_manager import ConfigManager @pytest.fixture @@ -19,15 +17,15 @@ def config_path() -> Path: return (here / "test_config.yaml").resolve() +def test_default_config(): + config_manager = ConfigManager() + assert config_manager is not None + assert config_manager.config is not None + + def test_config(config_path): - master_config = parse_arguments(args=[]) - config_manager = ConfigManager(master_config) + config_manager = ConfigManager(config_path=config_path) assert isinstance(config_manager, ConfigManager) - config_manager.load_config(config_path) - try: - config_manager.validate_config() - except Exception as e: - pytest.fail(f"Config validation failed with error: {e}") # Test that all required fields are present missing = [ @@ -60,3 +58,30 @@ def test_config(config_path): assert optimizer_config is not None assert isinstance(optimizer_config, dict) assert (hasattr(optimizer_config, attr) for attr in ("optimizer_name", "lr")) + + +def test_torch_dtype_validation(): + """Test that torch_dtype validation works correctly.""" + # Test with default config - should have torch_dtype set to fp16 by default + config_manager = ConfigManager() + training_config = config_manager.get_training_config() + assert training_config.get("torch_dtype") == "fp16" + + # Validation should pass with default config + config_manager.validate_config() # Should not raise + + +def test_torch_dtype_invalid(): + """Test that invalid torch_dtype raises validation error.""" + from QEfficient.finetune.experimental.core.config_manager import MasterConfig, TrainingConfig + + # Create config with invalid torch_dtype + training_config = TrainingConfig(torch_dtype="invalid_dtype") + master_config = MasterConfig(training=training_config) + config_manager = ConfigManager(config=master_config) + + # Validation should fail + with pytest.raises(ValueError) as exc_info: + config_manager.validate_config() + + assert "torch_dtype must be one of" in str(exc_info.value) diff --git a/QEfficient/finetune/experimental/tests/test_finetune.py b/QEfficient/finetune/experimental/tests/test_finetune.py new file mode 100644 index 000000000..5182a4395 --- /dev/null +++ b/QEfficient/finetune/experimental/tests/test_finetune.py @@ -0,0 +1,654 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Unit tests for finetune_experimental.py. +Tests for FineTuningPipeline class and main() function. 
+""" + +import os +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from QEfficient.cloud.finetune_experimental import FineTuningPipeline, main +from QEfficient.finetune.experimental.core.config_manager import MasterConfig, TrainingConfig +from QEfficient.finetune.experimental.core.utils.training_config_utils import prepare_training_config + + +class DictLikeMock: + """A mock that supports both dict access ['key'] and attribute access .key""" + + def __init__(self, data): + self._data = data + for key, value in data.items(): + setattr(self, key, value) + + def __getitem__(self, key): + return self._data[key] + + def __contains__(self, key): + return key in self._data + + def get(self, key, default=None): + return self._data.get(key, default) + + +class TestFineTuningPipeline: + """Test suite for FineTuningPipeline class.""" + + @pytest.fixture + def mock_master_config(self): + """Create a mock MasterConfig for testing.""" + config = MagicMock(spec=MasterConfig) + # Use DictLikeMock to support both dict access ['key'] and attribute access .key + config.training = DictLikeMock({"output_dir": "./test_output", "seed": 42}) + return config + + @pytest.fixture + def mock_config_manager(self): + """Create a mock ConfigManager.""" + config_manager = MagicMock() + config_manager.get_training_config.return_value = { + "type": "sft", + "dtype": "fp16", + "seed": 42, + } + config_manager.get_dataset_config.return_value = { + "dataset_type": "sft_dataset", + "dataset_name": "test_dataset", + "train_split": "train", + "test_split": "test", + } + config_manager.get_model_config.return_value = { + "model_type": "hf", + "model_name": "test-model", + "use_peft": False, + } + config_manager.get_optimizer_config.return_value = { + "optimizer_name": "adamw", + "lr": 1e-4, + } + config_manager.get_callback_config.return_value = {"callbacks": {}} + config_manager.validate_config = MagicMock() + return config_manager + + def test_initialization(self, mock_config_manager): + """Test pipeline initialization.""" + # Set up config_manager.config to return a mock that has training dict access + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": "./test_output"}) + mock_config_manager.config = mock_config_obj + + pipeline = FineTuningPipeline(mock_config_manager) + + assert pipeline.config_manager == mock_config_manager + assert pipeline.config == mock_config_obj + assert isinstance(pipeline.output_dir, Path) + assert pipeline.output_dir == Path("./test_output") + + def test_setup_environment(self, mock_config_manager): + """Test environment variable setup.""" + # Set up config_manager.config + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": "./test_output"}) + mock_config_manager.config = mock_config_obj + + # Clear environment variables + env_vars = ["OUTPUT_DIR", "TRACKIO_DIR", "TENSORBOARD_LOGGING_DIR"] + for var in env_vars: + if var in os.environ: + del os.environ[var] + + pipeline = FineTuningPipeline(mock_config_manager) + + # Verify environment variables are set + assert os.environ["OUTPUT_DIR"] == str(pipeline.output_dir) + assert os.environ["TRACKIO_DIR"] == str(pipeline.output_dir / "trackio_logs") + assert os.environ["TENSORBOARD_LOGGING_DIR"] == str(pipeline.output_dir) + + def test_prepare_training_config(self, mock_config_manager): + """Test training config preparation via prepare_training_config utility.""" + mock_config_obj = MagicMock() + mock_config_obj.training = 
DictLikeMock({"output_dir": "./test_output"}) + mock_config_manager.config = mock_config_obj + + with patch("QEfficient.cloud.finetune_experimental.prepare_training_config") as mock_prepare: + mock_prepare.return_value = {"fp16": True, "seed": 42, "type": "sft"} + + # Call prepare_training_config directly + result = mock_prepare(config_manager=mock_config_manager) + + # Verify prepare_training_config was called + assert mock_prepare.call_count > 0 + assert result == {"fp16": True, "seed": 42, "type": "sft"} + + @pytest.mark.parametrize( + "train_split,test_split,expected_train_split,expected_test_split", + [ + ("train", "test", "train", "test"), # Default splits + ("training", "testing", "training", "testing"), # Custom splits + ], + ) + def test_create_datasets( + self, + mock_config_manager, + train_split, + test_split, + expected_train_split, + expected_test_split, + ): + """Test dataset creation with default and custom split names.""" + # Set up config_manager.config.training to support dict access for seed and output_dir + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": "./test_output", "seed": 42}) + mock_config_manager.config = mock_config_obj + + # Update dataset config with the split names + mock_config_manager.get_dataset_config.return_value = { + "dataset_type": "sft_dataset", + "dataset_name": "test_dataset", + "train_split": train_split, + "test_split": test_split, + } + + with patch("QEfficient.cloud.finetune_experimental.ComponentFactory") as mock_factory: + mock_train_dataset = MagicMock() + mock_eval_dataset = MagicMock() + + def create_dataset_side_effect(*args, **kwargs): + split = kwargs.get("split", "") + # Match based on expected split names + if expected_train_split in split or (expected_train_split == "train" and "train" in split): + return mock_train_dataset + return mock_eval_dataset + + mock_factory.create_dataset.side_effect = create_dataset_side_effect + + pipeline = FineTuningPipeline(mock_config_manager) + train_dataset, eval_dataset = pipeline._create_datasets() + + # Verify datasets were created + assert train_dataset == mock_train_dataset + assert eval_dataset == mock_eval_dataset + + # Verify create_dataset was called twice (train and test) + assert mock_factory.create_dataset.call_count == 2 + + # Verify correct parameters were passed + calls = mock_factory.create_dataset.call_args_list + assert calls[0].kwargs["split"] == expected_train_split + assert calls[1].kwargs["split"] == expected_test_split + assert calls[0].kwargs["seed"] == 42 + assert calls[0].kwargs["dataset_type"] == "sft_dataset" + assert calls[0].kwargs["dataset_name"] == "test_dataset" + + @pytest.mark.parametrize( + "torch_dtype,expected_dtype", + [ + ("fp16", "float16"), # fp16 -> float16 + ("bf16", "bfloat16"), # bf16 -> bfloat16 + ("unknown", "auto"), # Unknown dtype -> auto + ], + ) + def test_create_model_dtype_conversion(self, mock_config_manager, torch_dtype, expected_dtype): + """Test model creation with different dtype conversions.""" + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": "./test_output"}) + mock_config_manager.config = mock_config_obj + + # Mock get_model_config to return config with torch_dtype already converted + # (This conversion is done by ConfigManager.get_model_config, not by _create_model) + mock_config_manager.get_model_config.return_value = { + "model_type": "hf", + "model_name": "test-model", + "torch_dtype": expected_dtype, # Already converted by get_model_config + } + + 
mock_model_instance = MagicMock() + mock_model_instance.model = MagicMock() + mock_model_instance.tokenizer = MagicMock() + + with patch("QEfficient.cloud.finetune_experimental.ComponentFactory") as mock_factory: + mock_factory.create_model.return_value = mock_model_instance + + pipeline = FineTuningPipeline(mock_config_manager) + result = pipeline._create_model() + + assert result == mock_model_instance + + # Verify model was created with correct dtype (already converted by ConfigManager) + assert mock_factory.create_model.call_count > 0 + call_kwargs = mock_factory.create_model.call_args.kwargs + assert call_kwargs.get("torch_dtype") == expected_dtype + + def test_create_optimizer(self, mock_config_manager): + """Test optimizer creation.""" + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": "./test_output"}) + mock_config_manager.config = mock_config_obj + + mock_optimizer_cls = MagicMock() + mock_optimizer_kwargs = {"lr": 1e-4} + + with patch("QEfficient.cloud.finetune_experimental.prepare_optimizer") as mock_prepare: + mock_prepare.return_value = (mock_optimizer_cls, mock_optimizer_kwargs) + + pipeline = FineTuningPipeline(mock_config_manager) + optimizer_cls, optimizer_kwargs = pipeline._create_optimizer() + + assert optimizer_cls == mock_optimizer_cls + assert optimizer_kwargs == mock_optimizer_kwargs + + assert mock_prepare.call_count > 0 + assert mock_prepare.call_args[0][0] == mock_config_manager.get_optimizer_config.return_value + + @pytest.mark.parametrize( + "callback_config,expected_count,expected_names", + [ + ( + { + "early_stopping": {"early_stopping_patience": 3}, + "tensorboard": {}, + }, + 2, + ["early_stopping", "tensorboard"], + ), + ( + { + "early_stopping": {"early_stopping_patience": 3}, + "tensorboard": {}, + "checkpoint": {"save_strategy": "epoch"}, + }, + 3, + ["early_stopping", "tensorboard", "checkpoint"], + ), + ], + ) + def test_create_callbacks(self, mock_config_manager, callback_config, expected_count, expected_names): + """Test callback creation with different numbers of callbacks.""" + mock_callback_config = {"callbacks": callback_config} + mock_config_manager.get_callback_config.return_value = mock_callback_config + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": "./test_output"}) + mock_config_manager.config = mock_config_obj + + # Create mock callbacks based on expected count + mock_callbacks = [MagicMock() for _ in range(expected_count)] + + with patch("QEfficient.cloud.finetune_experimental.ComponentFactory.create_callback") as mock_create: + mock_create.side_effect = mock_callbacks + + pipeline = FineTuningPipeline(mock_config_manager) + callbacks = pipeline._create_callbacks() + + assert len(callbacks) == expected_count + for mock_cb in mock_callbacks: + assert mock_cb in callbacks + + # Verify callbacks were created with correct names + assert mock_create.call_count == expected_count + for i, expected_name in enumerate(expected_names): + assert mock_create.call_args_list[i][0][0] == expected_name + + def test_create_callbacks_with_failure(self, mock_config_manager): + """Test callback creation with one failure.""" + mock_callback_config = { + "callbacks": { + "early_stopping": {"early_stopping_patience": 3}, + "invalid_callback": {}, + } + } + mock_config_manager.get_callback_config.return_value = mock_callback_config + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": "./test_output"}) + mock_config_manager.config = mock_config_obj + 
+ mock_callback = MagicMock() + + with patch("QEfficient.cloud.finetune_experimental.ComponentFactory.create_callback") as mock_create: + with patch("QEfficient.cloud.finetune_experimental.logger") as mock_logger: + mock_create.side_effect = [ + mock_callback, + ValueError("Unknown callback"), + ] + + pipeline = FineTuningPipeline(mock_config_manager) + callbacks = pipeline._create_callbacks() + + # Should only have the successful callback + assert len(callbacks) == 1 + assert mock_callback in callbacks + + # Should log warning for failed callback + log_calls = [call[0][0] for call in mock_logger.log_rank_zero.call_args_list if call] + assert any("Warning" in str(msg) and "invalid_callback" in str(msg) for msg in log_calls) + + def test_create_trainer(self, mock_config_manager): + """Test trainer creation.""" + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": "./test_output"}) + mock_config_manager.config = mock_config_obj + + mock_config_manager.get_training_config.return_value = { + "type": "sft", + "dtype": "fp16", + "device": "cpu", + } + mock_config_manager.get_model_config.return_value = { + "model_type": "hf", + "model_name": "test-model", + "use_peft": False, + } + + mock_trainer_cls = MagicMock() + mock_args_cls = MagicMock() + mock_args_instance = MagicMock() + mock_args_cls.return_value = mock_args_instance + + mock_trainer_instance = MagicMock() + mock_trainer_cls.return_value = mock_trainer_instance + + mock_model = MagicMock() + mock_tokenizer = MagicMock() + mock_train_dataset = MagicMock() + mock_eval_dataset = MagicMock() + mock_optimizer_cls = MagicMock() + mock_optimizer_kwargs = {} + mock_callbacks = [MagicMock()] + + training_config = {"type": "sft", "output_dir": "./output", "fp16": True} + + with patch( + "QEfficient.cloud.finetune_experimental.ComponentFactory.create_trainer_config" + ) as mock_create_trainer: + with patch("QEfficient.cloud.finetune_experimental.replace_progress_callback") as mock_replace: + mock_create_trainer.return_value = (mock_trainer_cls, mock_args_cls, {}) + + pipeline = FineTuningPipeline(mock_config_manager) + trainer = pipeline._create_trainer( + model=mock_model, + tokenizer=mock_tokenizer, + train_dataset=mock_train_dataset, + eval_dataset=mock_eval_dataset, + optimizer_cls_and_kwargs=(mock_optimizer_cls, mock_optimizer_kwargs), + callbacks=mock_callbacks, + training_config=training_config.copy(), + ) + + assert trainer == mock_trainer_instance + + # Verify trainer was created with correct parameters + assert mock_trainer_cls.call_count > 0 + call_kwargs = mock_trainer_cls.call_args.kwargs + assert call_kwargs["model"] == mock_model + assert call_kwargs["processing_class"] == mock_tokenizer + assert call_kwargs["args"] == mock_args_instance + assert call_kwargs["compute_loss_func"] is None + assert call_kwargs["train_dataset"] == mock_train_dataset.dataset + assert call_kwargs["eval_dataset"] == mock_eval_dataset.dataset + assert call_kwargs["optimizer_cls_and_kwargs"] == (mock_optimizer_cls, mock_optimizer_kwargs) + assert call_kwargs["callbacks"] == mock_callbacks + + # Verify progress callback replacement was called + assert mock_replace.call_count > 0 + replace_call_args = mock_replace.call_args.args + assert replace_call_args[0] == mock_trainer_instance + assert replace_call_args[1] == mock_callbacks + # Third argument should be logger (can be None or Logger instance) + assert len(replace_call_args) >= 3 + + def test_run_full_pipeline(self, mock_config_manager): + """Test full pipeline 
execution.""" + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": "./test_output"}) + mock_config_manager.config = mock_config_obj + + mock_train_dataset = MagicMock() + mock_eval_dataset = MagicMock() + mock_model_instance = MagicMock() + mock_model_instance.model = MagicMock() + mock_model_instance.tokenizer = MagicMock() + mock_optimizer_cls = MagicMock() + mock_optimizer_kwargs = {} + mock_callbacks = [MagicMock()] + mock_trainer = MagicMock() + + with patch( + "QEfficient.cloud.finetune_experimental.prepare_training_config", return_value={"type": "sft", "fp16": True} + ): + with patch.object( + FineTuningPipeline, "_create_datasets", return_value=(mock_train_dataset, mock_eval_dataset) + ): + with patch.object(FineTuningPipeline, "_create_model", return_value=mock_model_instance): + with patch.object( + FineTuningPipeline, + "_create_optimizer", + return_value=(mock_optimizer_cls, mock_optimizer_kwargs), + ): + with patch.object(FineTuningPipeline, "_create_callbacks", return_value=mock_callbacks): + with patch.object(FineTuningPipeline, "_create_trainer", return_value=mock_trainer): + with patch("QEfficient.cloud.finetune_experimental.logger") as mock_logger: + pipeline = FineTuningPipeline(mock_config_manager) + pipeline.run() + + # Verify all steps were executed + assert mock_config_manager.validate_config.call_count > 0 + assert pipeline._create_datasets.call_count > 0 + assert pipeline._create_model.call_count > 0 + assert pipeline._create_optimizer.call_count > 0 + assert pipeline._create_callbacks.call_count > 0 + assert pipeline._create_trainer.call_count > 0 + assert mock_trainer.train.call_count > 0 + + # Verify logging occurred + log_messages = [ + call[0][0] for call in mock_logger.log_rank_zero.call_args_list if call + ] + assert any("Creating datasets" in msg for msg in log_messages) + assert any("Loading model" in msg for msg in log_messages) + assert any("Preparing optimizer" in msg for msg in log_messages) + assert any("Creating callbacks" in msg for msg in log_messages) + assert any("Initializing trainer" in msg for msg in log_messages) + assert any("Starting training" in msg for msg in log_messages) + + def test_run_with_validation_error(self, mock_config_manager): + """Test pipeline run with validation error.""" + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": "./test_output"}) + mock_config_manager.config = mock_config_obj + mock_config_manager.validate_config.side_effect = ValueError("Invalid config") + + pipeline = FineTuningPipeline(mock_config_manager) + + with pytest.raises(ValueError, match="Invalid config"): + pipeline.run() + + @pytest.mark.parametrize( + "output_dir,expected_path", + [ + ("/absolute/path/to/output", "/absolute/path/to/output"), + ("./relative/output", "relative/output"), # Path normalizes ./relative/output to relative/output + ], + ) + def test_output_dir_path_handling(self, mock_config_manager, output_dir, expected_path): + """Test output directory path handling for both absolute and relative paths.""" + # Set up config_manager.config to have training dict + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": output_dir}) + mock_config_manager.config = mock_config_obj + + pipeline = FineTuningPipeline(mock_config_manager) + + assert isinstance(pipeline.output_dir, Path) + assert str(pipeline.output_dir) == expected_path + + +class TestMainFunction: + """Test suite for main() function.""" + + def test_main_function(self): + 
"""Test main function execution.""" + mock_config_manager = MagicMock() + mock_pipeline = MagicMock() + + with patch("QEfficient.cloud.finetune_experimental.ConfigManager", return_value=mock_config_manager): + with patch("QEfficient.cloud.finetune_experimental.FineTuningPipeline", return_value=mock_pipeline): + main() + + # Verify pipeline was created and run + from QEfficient.cloud.finetune_experimental import FineTuningPipeline + + assert FineTuningPipeline.call_count > 0 + assert FineTuningPipeline.call_args[0][0] == mock_config_manager + assert mock_pipeline.run.call_count > 0 + + def test_main_with_config_error(self): + """Test main function with config initialization error.""" + with patch("QEfficient.cloud.finetune_experimental.ConfigManager", side_effect=ValueError("Config error")): + with pytest.raises(ValueError, match="Config error"): + main() + + def test_main_with_pipeline_error(self): + """Test main function with pipeline error.""" + mock_config_manager = MagicMock() + mock_pipeline = MagicMock() + mock_pipeline.run.side_effect = RuntimeError("Training failed") + + with patch("QEfficient.cloud.finetune_experimental.ConfigManager", return_value=mock_config_manager): + with patch("QEfficient.cloud.finetune_experimental.FineTuningPipeline", return_value=mock_pipeline): + with pytest.raises(RuntimeError, match="Training failed"): + main() + + +class TestFineTuningPipelineEnhanced: + """Enhanced test suite for FineTuningPipeline class with additional edge cases.""" + + @pytest.fixture + def mock_master_config(self): + """Create a mock MasterConfig for testing.""" + config = MagicMock(spec=MasterConfig) + # Use DictLikeMock to support both dict access ['key'] and attribute access .key + config.training = DictLikeMock({"output_dir": "./test_output", "seed": 42}) + return config + + @pytest.fixture + def mock_config_manager(self): + """Create a mock ConfigManager.""" + config_manager = MagicMock() + config_manager.get_training_config.return_value = { + "type": "sft", + "dtype": "fp16", + "seed": 42, + } + config_manager.get_dataset_config.return_value = { + "dataset_type": "sft_dataset", + "dataset_name": "test_dataset", + "train_split": "train", + "test_split": "test", + } + config_manager.get_model_config.return_value = { + "model_type": "hf", + "model_name": "test-model", + "use_peft": False, + } + config_manager.get_optimizer_config.return_value = { + "optimizer_name": "adamw", + "lr": 1e-4, + } + config_manager.get_callback_config.return_value = {"callbacks": {}} + config_manager.validate_config = MagicMock() + return config_manager + + def test_create_datasets_with_additional_config_params(self, mock_config_manager): + """Test that additional dataset config parameters are properly propagated.""" + mock_config_manager.get_dataset_config.return_value = { + "dataset_type": "sft_dataset", + "dataset_name": "test_dataset", + "train_split": "train", + "test_split": "test", + "max_seq_length": 512, + "batch_size": 16, + "custom_param": "custom_value", + } + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": "./test_output", "seed": 42}) + mock_config_manager.config = mock_config_obj + + with patch("QEfficient.cloud.finetune_experimental.ComponentFactory") as mock_factory: + mock_factory.create_dataset.return_value = MagicMock() + + pipeline = FineTuningPipeline(mock_config_manager) + pipeline._create_datasets() + + # Verify additional parameters are passed through + calls = mock_factory.create_dataset.call_args_list + assert 
calls[0].kwargs.get("max_seq_length") == 512 + assert calls[0].kwargs.get("batch_size") == 16 + assert calls[0].kwargs.get("custom_param") == "custom_value" + # Verify excluded keys are not passed + assert "train_split" not in calls[0].kwargs + assert "test_split" not in calls[0].kwargs + + def test_create_model_with_additional_model_params(self, mock_config_manager): + """Test that additional model config parameters are properly propagated.""" + mock_config_manager.get_model_config.return_value = { + "model_type": "hf", + "model_name": "test-model", + "use_peft": False, + "trust_remote_code": True, + "device_map": "auto", + "custom_model_param": "value", + } + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": "./test_output"}) + mock_config_manager.config = mock_config_obj + + with patch("QEfficient.cloud.finetune_experimental.ComponentFactory") as mock_factory: + mock_factory.create_model.return_value = MagicMock() + + pipeline = FineTuningPipeline(mock_config_manager) + pipeline._create_model() + + call_kwargs = mock_factory.create_model.call_args.kwargs + assert call_kwargs.get("trust_remote_code") is True + assert call_kwargs.get("device_map") == "auto" + assert call_kwargs.get("custom_model_param") == "value" + # Verify PEFT keys are excluded + assert "use_peft" not in call_kwargs + assert "peft_config" not in call_kwargs + + def test_run_method_calls_validate_config_first(self, mock_config_manager): + """Test that run() calls validate_config before other operations.""" + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": "./test_output", "seed": 42}) + mock_config_manager.config = mock_config_obj + + call_order = [] + + def track_validate(): + call_order.append("validate") + return None + + mock_config_manager.validate_config.side_effect = track_validate + + with patch( + "QEfficient.cloud.finetune_experimental.prepare_training_config", return_value={"type": "sft", "fp16": True} + ): + with patch.object(FineTuningPipeline, "_create_datasets", return_value=(MagicMock(), MagicMock())): + with patch.object(FineTuningPipeline, "_create_model", return_value=MagicMock()): + with patch.object(FineTuningPipeline, "_create_optimizer", return_value=(MagicMock(), {})): + with patch.object(FineTuningPipeline, "_create_callbacks", return_value=[]): + with patch.object(FineTuningPipeline, "_create_trainer", return_value=MagicMock()): + with patch("QEfficient.cloud.finetune_experimental.logger"): + pipeline = FineTuningPipeline(mock_config_manager) + pipeline.run() + + # Verify validate_config was called first + assert call_order[0] == "validate" + assert mock_config_manager.validate_config.call_count == 1
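
Reviewer note: a minimal usage sketch of the entry point introduced by this patch. It only exercises calls visible in the diff (ConfigManager(config_path=...), FineTuningPipeline(config_manager), pipeline.run(), and the sys.argv-driven main()); the config path and the module-run invocation are illustrative assumptions, not part of the patch.

# Minimal usage sketch (illustrative only; the YAML path below is hypothetical).
#
# CLI: pass a YAML config as the single positional argument; ConfigManager.__init__
# picks it up from sys.argv, assuming the module can be run as shown:
#   python -m QEfficient.cloud.finetune_experimental ./configs/sft.yaml
#
# Programmatic equivalent:
from QEfficient.cloud.finetune_experimental import FineTuningPipeline
from QEfficient.finetune.experimental.core.config_manager import ConfigManager

config_manager = ConfigManager(config_path="./configs/sft.yaml")  # hypothetical config file
pipeline = FineTuningPipeline(config_manager)
pipeline.run()  # validates config, builds datasets/model/optimizer/callbacks, then trains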