diff --git a/QEfficient/cloud/finetune_experimental.py b/QEfficient/cloud/finetune_experimental.py index d647b73a6..e613431ab 100644 --- a/QEfficient/cloud/finetune_experimental.py +++ b/QEfficient/cloud/finetune_experimental.py @@ -4,3 +4,278 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- + +""" +Main entry point for fine-tuning LLMs using the experimental finetune framework. +""" + +import os +from pathlib import Path +from typing import Any, Dict, List, Tuple + +from QEfficient.finetune.experimental.core.callbacks import replace_progress_callback +from QEfficient.finetune.experimental.core.component_registry import ComponentFactory +from QEfficient.finetune.experimental.core.config_manager import ( + ConfigManager, +) +from QEfficient.finetune.experimental.core.dataset import SFTDataset # noqa: F401 +from QEfficient.finetune.experimental.core.logger import Logger +from QEfficient.finetune.experimental.core.model import HFModel # noqa: F401 +from QEfficient.finetune.experimental.core.optimizer import prepare_optimizer +from QEfficient.finetune.experimental.core.trainer import sft_trainer # noqa: F401 +from QEfficient.finetune.experimental.core.utils.peft_utils import convert_peft_config_to_lora_config +from QEfficient.finetune.experimental.core.utils.training_config_utils import prepare_training_config + +logger = Logger(__name__) + +# Try importing QAIC-specific module, proceed without it if it's unavailable +try: + import torch_qaic # noqa: F401 +except ImportError as e: + logger.log_rank_zero( + f"Unable to import 'torch_qaic' package due to exception: {e}. Moving ahead without the torch_qaic extension.", + level="warning", + ) + + +class FineTuningPipeline: + """ + Main pipeline class for fine-tuning LLMs. + """ + + def __init__(self, config_manager: ConfigManager): + """ + Initialize the fine-tuning pipeline with configuration. + + Args: + config_manager: ConfigManager instance with loaded and validated configuration + """ + self.config_manager = config_manager + self.config = self.config_manager.config + self.output_dir = Path(self.config.training["output_dir"]) + self._setup_environment() + + def _setup_environment(self) -> None: + """Set up environment variables for output directories.""" + os.environ["OUTPUT_DIR"] = str(self.output_dir) + os.environ["TRACKIO_DIR"] = str(self.output_dir / "trackio_logs") + os.environ["TENSORBOARD_LOGGING_DIR"] = str(self.output_dir) + + def _create_datasets(self) -> Tuple[Any, Any]: + """ + Create training and evaluation datasets. + + Returns: + Tuple of (train_dataset, eval_dataset) + """ + dataset_config = self.config_manager.get_dataset_config() + + dataset_type = dataset_config.get("dataset_type") + dataset_name = dataset_config.get("dataset_name") + train_split = dataset_config.get("train_split", "train") + test_split = dataset_config.get("test_split", "test") + seed = self.config.training["seed"] + + # Create a copy of dataset_config excluding keys that are passed explicitly + # to avoid duplicate keyword arguments when unpacking + excluded_keys = ("dataset_type", "dataset_name", "split", "seed", "train_split", "test_split") + dataset_config_copy = {k: v for k, v in dataset_config.items() if k not in excluded_keys} + + # Helper function to create a dataset for a specific split + def create_dataset_for_split(split_name: str) -> Any: + return ComponentFactory.create_dataset( + dataset_type=dataset_type, + dataset_name=dataset_name, + split=split_name, + seed=seed, + **dataset_config_copy, + ) + + # Create training and evaluation datasets using config values + train_dataset = create_dataset_for_split(train_split) + eval_dataset = create_dataset_for_split(test_split) + + return train_dataset, eval_dataset + + def _create_model(self) -> Any: + """ + Create and load the model instance. + + Returns: + Model instance with loaded model and tokenizer + """ + # Get model config as dict + model_config = self.config_manager.get_model_config() + + # Extract required fields + model_type = model_config.pop("model_type") + model_name = model_config.pop("model_name") + + # Filter out PEFT-related fields, these shouldn't be passed to model creation + excluded_keys = {"use_peft", "peft_config"} + model_config_kwargs = {k: v for k, v in model_config.items() if k not in excluded_keys} + + model_instance = ComponentFactory.create_model(model_type, model_name, **model_config_kwargs) + return model_instance + + def _create_optimizer(self) -> Tuple[Any, Dict[str, Any]]: + """ + Create optimizer configuration. + + Returns: + Tuple of (optimizer_class, optimizer_kwargs) + """ + optimizer_config = self.config_manager.get_optimizer_config() + return prepare_optimizer(optimizer_config) + + def _create_callbacks(self) -> List[Any]: + """ + Create callback instances from configuration. + + Returns: + List of callback instances + """ + callback_config = self.config_manager.get_callback_config() + callbacks = [] + + # callback_config.callbacks is a dictionary of callback configurations + for callback_name, callback_kwargs in callback_config["callbacks"].items(): + try: + callback_instance = ComponentFactory.create_callback(callback_name, **callback_kwargs) + callbacks.append(callback_instance) + except ValueError as e: + logger.log_rank_zero(f"Warning: Failed to create callback '{callback_name}': {e}", level="warning") + + return callbacks + + def _create_trainer( + self, + model: Any, + tokenizer: Any, + train_dataset: Any, + eval_dataset: Any, + optimizer_cls_and_kwargs: Tuple[Any, Dict[str, Any]], + callbacks: List[Any], + training_config: Dict[str, Any], + ) -> Any: + """ + Create and configure the trainer instance. + + Args: + model: The model to train + tokenizer: Tokenizer for processing + train_dataset: Training dataset + eval_dataset: Evaluation dataset + optimizer_cls_and_kwargs: Optimizer class and kwargs tuple + callbacks: List of callbacks + training_config: Training configuration dictionary + + Returns: + Trainer instance + """ + trainer_type = training_config.pop("type") + + # Get PEFT config if enabled + model_config_dict = self.config_manager.get_model_config() + peft_config = None + if model_config_dict.get("use_peft", False): + peft_config_dataclass = model_config_dict.get("peft_config") + if peft_config_dataclass is not None: + peft_config = convert_peft_config_to_lora_config(peft_config_dataclass) + + # Build dependencies for trainer configuration + dependencies = {} + if peft_config is not None: + dependencies["peft_config"] = peft_config + trainer_cls, args_cls, additional_kwargs = ComponentFactory.create_trainer_config(trainer_type, **dependencies) + + # Clean up training config: remove fields that shouldn't be passed to TrainingArguments + training_config.pop("device", None) + # Note: torch_dtype was already converted to fp16/bf16 flag in prepare_training_config + training_config.pop("deepspeed_config", None) + training_config.pop("torch_dtype", None) + + # Create trainer arguments instance + args = args_cls(**training_config) + # Initialize trainer + trainer = trainer_cls( + model=model, + processing_class=tokenizer, + args=args, + compute_loss_func=None, + train_dataset=train_dataset.dataset, + eval_dataset=eval_dataset.dataset, + optimizer_cls_and_kwargs=optimizer_cls_and_kwargs, + callbacks=callbacks, + **additional_kwargs, + ) + + replace_progress_callback(trainer, callbacks, logger) + + return trainer + + def run(self) -> None: + """ + Execute the complete fine-tuning pipeline. + """ + # Validate configuration + self.config_manager.validate_config() + + # Prepare training configuration + training_config = prepare_training_config(config_manager=self.config_manager) + + # Create datasets + logger.log_rank_zero("Creating datasets...") + train_dataset, eval_dataset = self._create_datasets() + + # Create model and tokenizer + logger.log_rank_zero("Loading model and tokenizer...") + model_instance = self._create_model() + model = model_instance.model + tokenizer = model_instance.tokenizer + + # Create optimizer + logger.log_rank_zero("Preparing optimizer...") + optimizer_cls_and_kwargs = self._create_optimizer() + + # Create callbacks + logger.log_rank_zero("Creating callbacks...") + callbacks = self._create_callbacks() + + # Create trainer + logger.log_rank_zero("Initializing trainer...") + trainer = self._create_trainer( + model=model, + tokenizer=tokenizer, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + optimizer_cls_and_kwargs=optimizer_cls_and_kwargs, + callbacks=callbacks, + training_config=training_config, + ) + + # Start training + logger.log_rank_zero("Starting training...") + trainer.train() + + +def main(): + """ + Main entry point for fine-tuning. + + Parses command-line arguments or config file and runs the fine-tuning pipeline. + """ + # ConfigManager now handles argument parsing internally via its __init__ + # It will automatically detect and parse: + # - Command-line args (if len(sys.argv) > 1) + # - Config file path (if sys.argv[1] ends with .yaml) + # - Or use defaults if no args provided + config_manager = ConfigManager() + + # Create and run pipeline - pass ConfigManager directly to avoid redundant wrapping + pipeline = FineTuningPipeline(config_manager) + pipeline.run() + + +if __name__ == "__main__": + main() diff --git a/QEfficient/finetune/experimental/core/callbacks.py b/QEfficient/finetune/experimental/core/callbacks.py index 30659e3bb..bd1ce91c2 100644 --- a/QEfficient/finetune/experimental/core/callbacks.py +++ b/QEfficient/finetune/experimental/core/callbacks.py @@ -19,7 +19,7 @@ from transformers.integrations.integration_utils import TensorBoardCallback from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState -from QEfficient.finetune.experimental.core.component_registry import registry +from QEfficient.finetune.experimental.core.component_registry import ComponentFactory, registry from QEfficient.finetune.experimental.core.utils.profiler_utils import ( get_op_verifier_ctx, init_qaic_profiling, @@ -197,9 +197,39 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra self.op_verifier_ctx_step.__exit__(None, None, None) -def create_callbacks(name: str, **kwargs) -> Any: - """Create a callback instance.""" - callback_class = registry.get_callback(name) - if callback_class is None: - raise ValueError(f"Unknown callback: {name}. Available: {registry.list_callbacks()}") - return callback_class(**kwargs) +def replace_progress_callback(trainer: Any, callbacks: list[Any], logger: Any = None) -> None: + """ + Replace default ProgressCallback with EnhancedProgressCallback if not already present. + + Args: + trainer: Trainer instance + callbacks: List of callbacks already added + logger: Optional logger instance for warning messages + """ + # Check if EnhancedProgressCallback is already in callbacks + has_enhanced = any(callback.__class__.__name__ == "EnhancedProgressCallback" for callback in callbacks) + + if not has_enhanced: + try: + # Remove default ProgressCallback if present + trainer.remove_callback(ProgressCallback) + except (AttributeError, ValueError) as e: + # Callback not present or method doesn't exist, continue + if logger: + logger.log_rank_zero( + f"Debug: Could not remove default ProgressCallback: {e}. This is expected if callback is not present.", + level="debug", + ) + pass + + try: + # Add EnhancedProgressCallback + enhanced_callback = ComponentFactory.create_callback("enhanced_progressbar") + trainer.add_callback(enhanced_callback) + except Exception as e: + if logger: + logger.log_rank_zero(f"Warning: Could not add enhanced progress callback: {e}", level="warning") + else: + import warnings + + warnings.warn(f"Could not add enhanced progress callback: {e}") diff --git a/QEfficient/finetune/experimental/core/component_registry.py b/QEfficient/finetune/experimental/core/component_registry.py index 00252831f..043552275 100644 --- a/QEfficient/finetune/experimental/core/component_registry.py +++ b/QEfficient/finetune/experimental/core/component_registry.py @@ -6,7 +6,7 @@ # ----------------------------------------------------------------------------- import logging -from typing import Callable, Dict, Optional, Type +from typing import Any, Callable, Dict, Optional, Type # from QEfficient.finetune.experimental.core.logger import get_logger @@ -201,7 +201,7 @@ def list_callbacks(self) -> list[str]: class ComponentFactory: @staticmethod - def create_model(model_type: str, model_name: str, **kwargs) -> any: + def create_model(model_type: str, model_name: str, **kwargs) -> Any: """Create a model instance.""" model_class = registry.get_model(model_type) if model_class is None: @@ -209,6 +209,7 @@ def create_model(model_type: str, model_name: str, **kwargs) -> any: model_instance = model_class.create(model_name, **kwargs) return model_instance + @staticmethod def create_trainer_config(name: str, **dependencies) -> tuple: """ Create trainer configuration based on registered trainer modules. @@ -236,3 +237,41 @@ def create_trainer_config(name: str, **dependencies) -> tuple: raise ValueError(f"Required argument '{kwarg}' not provided for trainer '{name}'") return config["trainer_cls"], config["args_cls"], additional_kwargs + + @staticmethod + def create_dataset(dataset_type: str, dataset_name: str, split: str, seed: int = 42, **kwargs) -> Any: + """ + Create a dataset instance. + + Args: + dataset_type: Type of dataset to create (e.g., 'sft_dataset') + dataset_name: Name of the dataset to load + split: Dataset split ("train", "test", etc.) + seed: Random seed for reproducibility + **kwargs: Additional dataset configuration parameters + + Returns: + Dataset instance + """ + dataset_class = registry.get_dataset(dataset_type) + if dataset_class is None: + raise ValueError(f"Unknown dataset type: {dataset_type}. Available: {registry.list_datasets()}") + dataset_instance = dataset_class(dataset_name=dataset_name, split=split, seed=seed, **kwargs) + return dataset_instance + + @staticmethod + def create_callback(name: str, **kwargs) -> Any: + """ + Create a callback instance. + + Args: + name: Name of the callback to create + **kwargs: Additional callback configuration parameters + + Returns: + Callback instance + """ + callback_class = registry.get_callback(name) + if callback_class is None: + raise ValueError(f"Unknown callback: {name}. Available: {registry.list_callbacks()}") + return callback_class(**kwargs) diff --git a/QEfficient/finetune/experimental/core/config_manager.py b/QEfficient/finetune/experimental/core/config_manager.py index cf6737c25..5b5a8a819 100644 --- a/QEfficient/finetune/experimental/core/config_manager.py +++ b/QEfficient/finetune/experimental/core/config_manager.py @@ -172,6 +172,7 @@ class DatasetConfig: default="default", metadata={"help": "Name of the hf configuration file."}, ) + json_file_path: str = field(default=None, metadata={"help": "Path to a JSON file containing data."}) @dataclass @@ -698,6 +699,20 @@ def validate_config(self) -> None: self._push(errors, not dataset.get("tokenizer_name"), "dataset.tokenizer_name is required.") # ---------- Training ---------- + # torch_dtype validation + torch_dtype = training.get("torch_dtype") + valid_dtypes = {"fp16", "bf16", "fp32"} + self._push( + errors, + not torch_dtype, + "training.torch_dtype is required.", + ) + self._push( + errors, + torch_dtype and torch_dtype not in valid_dtypes, + f"training.torch_dtype must be one of {valid_dtypes}.", + ) + # Batch sizes self._push( errors, @@ -766,8 +781,24 @@ def get_dataset_config(self) -> Dict[str, Any]: return self.config.dataset def get_model_config(self) -> Dict[str, Any]: - """Get model configuration as dictionary.""" - return self.config.model + """ + Get model configuration as dictionary. + + Automatically handles torch_dtype conversion from training config if not set in model config. + """ + model_config = self.config.model + + # Get torch_dtype from training config and convert + # To do: check if it can be moved from training config to model config instead + if model_config.get("torch_dtype") is None: + training_config = self.get_training_config() + training_dtype = training_config.get("torch_dtype") + if training_dtype: + # Convert from training format (fp16/bf16) to model format (float16/bfloat16) + dtype_mapping = {"fp16": "float16", "bf16": "bfloat16"} + model_config["torch_dtype"] = dtype_mapping.get(training_dtype, "auto") + + return model_config def to_dict(self) -> Dict[str, Any]: """Convert configuration to dictionary.""" diff --git a/QEfficient/finetune/experimental/core/utils/peft_utils.py b/QEfficient/finetune/experimental/core/utils/peft_utils.py new file mode 100644 index 000000000..9c6cfaf3c --- /dev/null +++ b/QEfficient/finetune/experimental/core/utils/peft_utils.py @@ -0,0 +1,47 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Utility functions for PEFT (Parameter-Efficient Fine-Tuning) configuration. +""" + +from dataclasses import asdict +from typing import Any, Optional + +from peft import LoraConfig + + +def convert_peft_config_to_lora_config(peft_config: Any) -> Optional[LoraConfig]: + """ + Convert PeftConfig (dataclass or dict) to LoraConfig from peft library. + + Args: + peft_config: PeftConfig dataclass instance or dict + + Returns: + LoraConfig instance or None if PEFT is not enabled + """ + if peft_config is None: + return None + + # Convert dataclass to dictionary if needed + if hasattr(peft_config, "__dict__") and not isinstance(peft_config, dict): + peft_dict = asdict(peft_config) + else: + peft_dict = peft_config + + # Map PeftConfig fields to LoraConfig fields + lora_config_dict = { + "r": peft_dict.get("lora_r"), + "lora_alpha": peft_dict.get("lora_alpha"), + "lora_dropout": peft_dict.get("lora_dropout"), + "target_modules": peft_dict.get("target_modules"), + "bias": peft_dict.get("bias"), + "task_type": peft_dict.get("task_type"), + } + + return LoraConfig(**lora_config_dict) diff --git a/QEfficient/finetune/experimental/core/utils/training_config_utils.py b/QEfficient/finetune/experimental/core/utils/training_config_utils.py new file mode 100644 index 000000000..1cd6704e4 --- /dev/null +++ b/QEfficient/finetune/experimental/core/utils/training_config_utils.py @@ -0,0 +1,84 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Utility functions for preparing training configurations. +""" + +from typing import Any, Dict + +from QEfficient.finetune.experimental.core.config_manager import ConfigManager + + +def prepare_training_config( + config_manager: ConfigManager, + include_num_input_tokens_seen: bool = False, + use_cpu: bool = False, +) -> Dict[str, Any]: + """ + Prepare and transform training configuration for trainer initialization. + + Args: + config_manager: ConfigManager instance with loaded configuration + + Returns: + Dictionary of training arguments ready for trainer initialization + """ + # Get training config as dict and create mutable copy to avoid mutating original + training_config = dict(config_manager.get_training_config()) + + # Handle dtype conversion + # To do: (For Tanisha) Check if torch_dtype should rather be added directly in model_config only in config_manager.py + + torch_dtype = training_config.pop("torch_dtype", None) + if torch_dtype is None: + raise ValueError("'torch_dtype' field is required in training configuration. Expected one of: ['fp16', 'bf16']") + training_config[torch_dtype] = True + training_config["data_seed"] = training_config.get("seed") + + # Restoring the "torch_dtype" after torch_dtype conversion using the saved value + training_config["torch_dtype"] = torch_dtype + + # Handle scheduler configuration + scheduler_config = config_manager.get_scheduler_config() + training_config.setdefault("lr_scheduler_type", scheduler_config.get("scheduler_name")) + + # Set warmup_ratio and warmup_steps from scheduler_config if they exist and are not None + warmup_ratio = scheduler_config.get("warmup_ratio") + if warmup_ratio is not None: + training_config["warmup_ratio"] = warmup_ratio + warmup_steps = scheduler_config.get("warmup_steps") + if warmup_steps is not None: + training_config["warmup_steps"] = warmup_steps + + # Handle dataset configuration for dataloader settings + dataset_config = config_manager.get_dataset_config() + training_config.setdefault("dataloader_pin_memory", dataset_config.get("dataloader_pin_memory")) + training_config.setdefault("dataloader_persistent_workers", dataset_config.get("dataloader_persistent_workers")) + training_config.setdefault("dataloader_prefetch_factor", dataset_config.get("dataloader_prefetch_factor")) + training_config.setdefault("dataloader_drop_last", dataset_config.get("dataloader_drop_last")) + training_config.setdefault("dataloader_num_workers", dataset_config.get("dataloader_num_workers")) + training_config.setdefault("group_by_length", dataset_config.get("group_by_length")) + + # Handle DDP configuration + if training_config.get("ddp_config") is not None: + ddp_config = training_config.pop("ddp_config") + if not isinstance(ddp_config, dict): + from dataclasses import asdict, is_dataclass + + if is_dataclass(ddp_config): + ddp_config = asdict(ddp_config) + else: + raise TypeError( + f"ddp_config must be a dict or DdpConfig dataclass instance, " + f"got {type(ddp_config).__name__}: {ddp_config}" + ) + + # Merge ddp_config into training_config + training_config = {**training_config, **ddp_config} + + return training_config diff --git a/QEfficient/finetune/experimental/tests/test_callback.py b/QEfficient/finetune/experimental/tests/test_callback.py index 59ff4d117..e085da9c9 100644 --- a/QEfficient/finetune/experimental/tests/test_callback.py +++ b/QEfficient/finetune/experimental/tests/test_callback.py @@ -8,8 +8,7 @@ import pytest from transformers import TrainerCallback -from QEfficient.finetune.experimental.core.callbacks import create_callbacks -from QEfficient.finetune.experimental.core.component_registry import registry +from QEfficient.finetune.experimental.core.component_registry import ComponentFactory, registry class ModelSummaryCallback(TrainerCallback): @@ -46,7 +45,7 @@ def test_callbacks(callback_name): # Create callbacks using the factory config = CALLBACK_CONFIGS[callback_name] try: - callback_inst = create_callbacks(**config) + callback_inst = ComponentFactory.create_callback(**config) except ValueError as e: assert "Unknown callback" in str(e) return diff --git a/QEfficient/finetune/experimental/tests/test_config_manager.py b/QEfficient/finetune/experimental/tests/test_config_manager.py index 4e531595d..b4980ad2c 100644 --- a/QEfficient/finetune/experimental/tests/test_config_manager.py +++ b/QEfficient/finetune/experimental/tests/test_config_manager.py @@ -58,3 +58,30 @@ def test_config(config_path): assert optimizer_config is not None assert isinstance(optimizer_config, dict) assert (hasattr(optimizer_config, attr) for attr in ("optimizer_name", "lr")) + + +def test_torch_dtype_validation(): + """Test that torch_dtype validation works correctly.""" + # Test with default config - should have torch_dtype set to fp16 by default + config_manager = ConfigManager() + training_config = config_manager.get_training_config() + assert training_config.get("torch_dtype") == "fp16" + + # Validation should pass with default config + config_manager.validate_config() # Should not raise + + +def test_torch_dtype_invalid(): + """Test that invalid torch_dtype raises validation error.""" + from QEfficient.finetune.experimental.core.config_manager import MasterConfig, TrainingConfig + + # Create config with invalid torch_dtype + training_config = TrainingConfig(torch_dtype="invalid_dtype") + master_config = MasterConfig(training=training_config) + config_manager = ConfigManager(config=master_config) + + # Validation should fail + with pytest.raises(ValueError) as exc_info: + config_manager.validate_config() + + assert "torch_dtype must be one of" in str(exc_info.value) diff --git a/QEfficient/finetune/experimental/tests/test_finetune.py b/QEfficient/finetune/experimental/tests/test_finetune.py new file mode 100644 index 000000000..2c8ab8b3e --- /dev/null +++ b/QEfficient/finetune/experimental/tests/test_finetune.py @@ -0,0 +1,653 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Unit tests for finetune_experimental.py. +Tests for FineTuningPipeline class and main() function. +""" + +import os +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from QEfficient.cloud.finetune_experimental import FineTuningPipeline, main +from QEfficient.finetune.experimental.core.config_manager import MasterConfig + + +class DictLikeMock: + """A mock that supports both dict access ['key'] and attribute access .key""" + + def __init__(self, data): + self._data = data + for key, value in data.items(): + setattr(self, key, value) + + def __getitem__(self, key): + return self._data[key] + + def __contains__(self, key): + return key in self._data + + def get(self, key, default=None): + return self._data.get(key, default) + + +class TestFineTuningPipeline: + """Test suite for FineTuningPipeline class.""" + + @pytest.fixture + def mock_master_config(self): + """Create a mock MasterConfig for testing.""" + config = MagicMock(spec=MasterConfig) + # Use DictLikeMock to support both dict access ['key'] and attribute access .key + config.training = DictLikeMock({"output_dir": "./test_output", "seed": 42}) + return config + + @pytest.fixture + def mock_config_manager(self): + """Create a mock ConfigManager.""" + config_manager = MagicMock() + config_manager.get_training_config.return_value = { + "type": "sft", + "dtype": "fp16", + "seed": 42, + } + config_manager.get_dataset_config.return_value = { + "dataset_type": "sft_dataset", + "dataset_name": "test_dataset", + "train_split": "train", + "test_split": "test", + } + config_manager.get_model_config.return_value = { + "model_type": "hf", + "model_name": "test-model", + "use_peft": False, + } + config_manager.get_optimizer_config.return_value = { + "optimizer_name": "adamw", + "lr": 1e-4, + } + config_manager.get_callback_config.return_value = {"callbacks": {}} + config_manager.validate_config = MagicMock() + return config_manager + + def test_initialization(self, mock_config_manager): + """Test pipeline initialization.""" + # Set up config_manager.config to return a mock that has training dict access + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": "./test_output"}) + mock_config_manager.config = mock_config_obj + + pipeline = FineTuningPipeline(mock_config_manager) + + assert pipeline.config_manager == mock_config_manager + assert pipeline.config == mock_config_obj + assert isinstance(pipeline.output_dir, Path) + assert pipeline.output_dir == Path("./test_output") + + def test_setup_environment(self, mock_config_manager): + """Test environment variable setup.""" + # Set up config_manager.config + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": "./test_output"}) + mock_config_manager.config = mock_config_obj + + # Clear environment variables + env_vars = ["OUTPUT_DIR", "TRACKIO_DIR", "TENSORBOARD_LOGGING_DIR"] + for var in env_vars: + if var in os.environ: + del os.environ[var] + + pipeline = FineTuningPipeline(mock_config_manager) + + # Verify environment variables are set + assert os.environ["OUTPUT_DIR"] == str(pipeline.output_dir) + assert os.environ["TRACKIO_DIR"] == str(pipeline.output_dir / "trackio_logs") + assert os.environ["TENSORBOARD_LOGGING_DIR"] == str(pipeline.output_dir) + + def test_prepare_training_config(self, mock_config_manager): + """Test training config preparation via prepare_training_config utility.""" + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": "./test_output"}) + mock_config_manager.config = mock_config_obj + + with patch("QEfficient.cloud.finetune_experimental.prepare_training_config") as mock_prepare: + mock_prepare.return_value = {"fp16": True, "seed": 42, "type": "sft"} + + # Call prepare_training_config directly + result = mock_prepare(config_manager=mock_config_manager) + + # Verify prepare_training_config was called + assert mock_prepare.call_count > 0 + assert result == {"fp16": True, "seed": 42, "type": "sft"} + + @pytest.mark.parametrize( + "train_split,test_split,expected_train_split,expected_test_split", + [ + ("train", "test", "train", "test"), # Default splits + ("training", "testing", "training", "testing"), # Custom splits + ], + ) + def test_create_datasets( + self, + mock_config_manager, + train_split, + test_split, + expected_train_split, + expected_test_split, + ): + """Test dataset creation with default and custom split names.""" + # Set up config_manager.config.training to support dict access for seed and output_dir + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": "./test_output", "seed": 42}) + mock_config_manager.config = mock_config_obj + + # Update dataset config with the split names + mock_config_manager.get_dataset_config.return_value = { + "dataset_type": "sft_dataset", + "dataset_name": "test_dataset", + "train_split": train_split, + "test_split": test_split, + } + + with patch("QEfficient.cloud.finetune_experimental.ComponentFactory") as mock_factory: + mock_train_dataset = MagicMock() + mock_eval_dataset = MagicMock() + + def create_dataset_side_effect(*args, **kwargs): + split = kwargs.get("split", "") + # Match based on expected split names + if expected_train_split in split or (expected_train_split == "train" and "train" in split): + return mock_train_dataset + return mock_eval_dataset + + mock_factory.create_dataset.side_effect = create_dataset_side_effect + + pipeline = FineTuningPipeline(mock_config_manager) + train_dataset, eval_dataset = pipeline._create_datasets() + + # Verify datasets were created + assert train_dataset == mock_train_dataset + assert eval_dataset == mock_eval_dataset + + # Verify create_dataset was called twice (train and test) + assert mock_factory.create_dataset.call_count == 2 + + # Verify correct parameters were passed + calls = mock_factory.create_dataset.call_args_list + assert calls[0].kwargs["split"] == expected_train_split + assert calls[1].kwargs["split"] == expected_test_split + assert calls[0].kwargs["seed"] == 42 + assert calls[0].kwargs["dataset_type"] == "sft_dataset" + assert calls[0].kwargs["dataset_name"] == "test_dataset" + + @pytest.mark.parametrize( + "torch_dtype,expected_dtype", + [ + ("fp16", "float16"), # fp16 -> float16 + ("bf16", "bfloat16"), # bf16 -> bfloat16 + ("unknown", "auto"), # Unknown dtype -> auto + ], + ) + def test_create_model_dtype_conversion(self, mock_config_manager, torch_dtype, expected_dtype): + """Test model creation with different dtype conversions.""" + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": "./test_output"}) + mock_config_manager.config = mock_config_obj + + # Mock get_model_config to return config with torch_dtype already converted + # (This conversion is done by ConfigManager.get_model_config, not by _create_model) + mock_config_manager.get_model_config.return_value = { + "model_type": "hf", + "model_name": "test-model", + "torch_dtype": expected_dtype, # Already converted by get_model_config + } + + mock_model_instance = MagicMock() + mock_model_instance.model = MagicMock() + mock_model_instance.tokenizer = MagicMock() + + with patch("QEfficient.cloud.finetune_experimental.ComponentFactory") as mock_factory: + mock_factory.create_model.return_value = mock_model_instance + + pipeline = FineTuningPipeline(mock_config_manager) + result = pipeline._create_model() + + assert result == mock_model_instance + + # Verify model was created with correct dtype (already converted by ConfigManager) + assert mock_factory.create_model.call_count > 0 + call_kwargs = mock_factory.create_model.call_args.kwargs + assert call_kwargs.get("torch_dtype") == expected_dtype + + def test_create_optimizer(self, mock_config_manager): + """Test optimizer creation.""" + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": "./test_output"}) + mock_config_manager.config = mock_config_obj + + mock_optimizer_cls = MagicMock() + mock_optimizer_kwargs = {"lr": 1e-4} + + with patch("QEfficient.cloud.finetune_experimental.prepare_optimizer") as mock_prepare: + mock_prepare.return_value = (mock_optimizer_cls, mock_optimizer_kwargs) + + pipeline = FineTuningPipeline(mock_config_manager) + optimizer_cls, optimizer_kwargs = pipeline._create_optimizer() + + assert optimizer_cls == mock_optimizer_cls + assert optimizer_kwargs == mock_optimizer_kwargs + + assert mock_prepare.call_count > 0 + assert mock_prepare.call_args[0][0] == mock_config_manager.get_optimizer_config.return_value + + @pytest.mark.parametrize( + "callback_config,expected_count,expected_names", + [ + ( + { + "early_stopping": {"early_stopping_patience": 3}, + "tensorboard": {}, + }, + 2, + ["early_stopping", "tensorboard"], + ), + ( + { + "early_stopping": {"early_stopping_patience": 3}, + "tensorboard": {}, + "checkpoint": {"save_strategy": "epoch"}, + }, + 3, + ["early_stopping", "tensorboard", "checkpoint"], + ), + ], + ) + def test_create_callbacks(self, mock_config_manager, callback_config, expected_count, expected_names): + """Test callback creation with different numbers of callbacks.""" + mock_callback_config = {"callbacks": callback_config} + mock_config_manager.get_callback_config.return_value = mock_callback_config + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": "./test_output"}) + mock_config_manager.config = mock_config_obj + + # Create mock callbacks based on expected count + mock_callbacks = [MagicMock() for _ in range(expected_count)] + + with patch("QEfficient.cloud.finetune_experimental.ComponentFactory.create_callback") as mock_create: + mock_create.side_effect = mock_callbacks + + pipeline = FineTuningPipeline(mock_config_manager) + callbacks = pipeline._create_callbacks() + + assert len(callbacks) == expected_count + for mock_cb in mock_callbacks: + assert mock_cb in callbacks + + # Verify callbacks were created with correct names + assert mock_create.call_count == expected_count + for i, expected_name in enumerate(expected_names): + assert mock_create.call_args_list[i][0][0] == expected_name + + def test_create_callbacks_with_failure(self, mock_config_manager): + """Test callback creation with one failure.""" + mock_callback_config = { + "callbacks": { + "early_stopping": {"early_stopping_patience": 3}, + "invalid_callback": {}, + } + } + mock_config_manager.get_callback_config.return_value = mock_callback_config + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": "./test_output"}) + mock_config_manager.config = mock_config_obj + + mock_callback = MagicMock() + + with patch("QEfficient.cloud.finetune_experimental.ComponentFactory.create_callback") as mock_create: + with patch("QEfficient.cloud.finetune_experimental.logger") as mock_logger: + mock_create.side_effect = [ + mock_callback, + ValueError("Unknown callback"), + ] + + pipeline = FineTuningPipeline(mock_config_manager) + callbacks = pipeline._create_callbacks() + + # Should only have the successful callback + assert len(callbacks) == 1 + assert mock_callback in callbacks + + # Should log warning for failed callback + log_calls = [call[0][0] for call in mock_logger.log_rank_zero.call_args_list if call] + assert any("Warning" in str(msg) and "invalid_callback" in str(msg) for msg in log_calls) + + def test_create_trainer(self, mock_config_manager): + """Test trainer creation.""" + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": "./test_output"}) + mock_config_manager.config = mock_config_obj + + mock_config_manager.get_training_config.return_value = { + "type": "sft", + "dtype": "fp16", + "device": "cpu", + } + mock_config_manager.get_model_config.return_value = { + "model_type": "hf", + "model_name": "test-model", + "use_peft": False, + } + + mock_trainer_cls = MagicMock() + mock_args_cls = MagicMock() + mock_args_instance = MagicMock() + mock_args_cls.return_value = mock_args_instance + + mock_trainer_instance = MagicMock() + mock_trainer_cls.return_value = mock_trainer_instance + + mock_model = MagicMock() + mock_tokenizer = MagicMock() + mock_train_dataset = MagicMock() + mock_eval_dataset = MagicMock() + mock_optimizer_cls = MagicMock() + mock_optimizer_kwargs = {} + mock_callbacks = [MagicMock()] + + training_config = {"type": "sft", "output_dir": "./output", "fp16": True} + + with patch( + "QEfficient.cloud.finetune_experimental.ComponentFactory.create_trainer_config" + ) as mock_create_trainer: + with patch("QEfficient.cloud.finetune_experimental.replace_progress_callback") as mock_replace: + mock_create_trainer.return_value = (mock_trainer_cls, mock_args_cls, {}) + + pipeline = FineTuningPipeline(mock_config_manager) + trainer = pipeline._create_trainer( + model=mock_model, + tokenizer=mock_tokenizer, + train_dataset=mock_train_dataset, + eval_dataset=mock_eval_dataset, + optimizer_cls_and_kwargs=(mock_optimizer_cls, mock_optimizer_kwargs), + callbacks=mock_callbacks, + training_config=training_config.copy(), + ) + + assert trainer == mock_trainer_instance + + # Verify trainer was created with correct parameters + assert mock_trainer_cls.call_count > 0 + call_kwargs = mock_trainer_cls.call_args.kwargs + assert call_kwargs["model"] == mock_model + assert call_kwargs["processing_class"] == mock_tokenizer + assert call_kwargs["args"] == mock_args_instance + assert call_kwargs["compute_loss_func"] is None + assert call_kwargs["train_dataset"] == mock_train_dataset.dataset + assert call_kwargs["eval_dataset"] == mock_eval_dataset.dataset + assert call_kwargs["optimizer_cls_and_kwargs"] == (mock_optimizer_cls, mock_optimizer_kwargs) + assert call_kwargs["callbacks"] == mock_callbacks + + # Verify progress callback replacement was called + assert mock_replace.call_count > 0 + replace_call_args = mock_replace.call_args.args + assert replace_call_args[0] == mock_trainer_instance + assert replace_call_args[1] == mock_callbacks + # Third argument should be logger (can be None or Logger instance) + assert len(replace_call_args) >= 3 + + def test_run_full_pipeline(self, mock_config_manager): + """Test full pipeline execution.""" + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": "./test_output"}) + mock_config_manager.config = mock_config_obj + + mock_train_dataset = MagicMock() + mock_eval_dataset = MagicMock() + mock_model_instance = MagicMock() + mock_model_instance.model = MagicMock() + mock_model_instance.tokenizer = MagicMock() + mock_optimizer_cls = MagicMock() + mock_optimizer_kwargs = {} + mock_callbacks = [MagicMock()] + mock_trainer = MagicMock() + + with patch( + "QEfficient.cloud.finetune_experimental.prepare_training_config", return_value={"type": "sft", "fp16": True} + ): + with patch.object( + FineTuningPipeline, "_create_datasets", return_value=(mock_train_dataset, mock_eval_dataset) + ): + with patch.object(FineTuningPipeline, "_create_model", return_value=mock_model_instance): + with patch.object( + FineTuningPipeline, + "_create_optimizer", + return_value=(mock_optimizer_cls, mock_optimizer_kwargs), + ): + with patch.object(FineTuningPipeline, "_create_callbacks", return_value=mock_callbacks): + with patch.object(FineTuningPipeline, "_create_trainer", return_value=mock_trainer): + with patch("QEfficient.cloud.finetune_experimental.logger") as mock_logger: + pipeline = FineTuningPipeline(mock_config_manager) + pipeline.run() + + # Verify all steps were executed + assert mock_config_manager.validate_config.call_count > 0 + assert pipeline._create_datasets.call_count > 0 + assert pipeline._create_model.call_count > 0 + assert pipeline._create_optimizer.call_count > 0 + assert pipeline._create_callbacks.call_count > 0 + assert pipeline._create_trainer.call_count > 0 + assert mock_trainer.train.call_count > 0 + + # Verify logging occurred + log_messages = [ + call[0][0] for call in mock_logger.log_rank_zero.call_args_list if call + ] + assert any("Creating datasets" in msg for msg in log_messages) + assert any("Loading model" in msg for msg in log_messages) + assert any("Preparing optimizer" in msg for msg in log_messages) + assert any("Creating callbacks" in msg for msg in log_messages) + assert any("Initializing trainer" in msg for msg in log_messages) + assert any("Starting training" in msg for msg in log_messages) + + def test_run_with_validation_error(self, mock_config_manager): + """Test pipeline run with validation error.""" + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": "./test_output"}) + mock_config_manager.config = mock_config_obj + mock_config_manager.validate_config.side_effect = ValueError("Invalid config") + + pipeline = FineTuningPipeline(mock_config_manager) + + with pytest.raises(ValueError, match="Invalid config"): + pipeline.run() + + @pytest.mark.parametrize( + "output_dir,expected_path", + [ + ("/absolute/path/to/output", "/absolute/path/to/output"), + ("./relative/output", "relative/output"), # Path normalizes ./relative/output to relative/output + ], + ) + def test_output_dir_path_handling(self, mock_config_manager, output_dir, expected_path): + """Test output directory path handling for both absolute and relative paths.""" + # Set up config_manager.config to have training dict + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": output_dir}) + mock_config_manager.config = mock_config_obj + + pipeline = FineTuningPipeline(mock_config_manager) + + assert isinstance(pipeline.output_dir, Path) + assert str(pipeline.output_dir) == expected_path + + +class TestMainFunction: + """Test suite for main() function.""" + + def test_main_function(self): + """Test main function execution.""" + mock_config_manager = MagicMock() + mock_pipeline = MagicMock() + + with patch("QEfficient.cloud.finetune_experimental.ConfigManager", return_value=mock_config_manager): + with patch("QEfficient.cloud.finetune_experimental.FineTuningPipeline", return_value=mock_pipeline): + main() + + # Verify pipeline was created and run + from QEfficient.cloud.finetune_experimental import FineTuningPipeline + + assert FineTuningPipeline.call_count > 0 + assert FineTuningPipeline.call_args[0][0] == mock_config_manager + assert mock_pipeline.run.call_count > 0 + + def test_main_with_config_error(self): + """Test main function with config initialization error.""" + with patch("QEfficient.cloud.finetune_experimental.ConfigManager", side_effect=ValueError("Config error")): + with pytest.raises(ValueError, match="Config error"): + main() + + def test_main_with_pipeline_error(self): + """Test main function with pipeline error.""" + mock_config_manager = MagicMock() + mock_pipeline = MagicMock() + mock_pipeline.run.side_effect = RuntimeError("Training failed") + + with patch("QEfficient.cloud.finetune_experimental.ConfigManager", return_value=mock_config_manager): + with patch("QEfficient.cloud.finetune_experimental.FineTuningPipeline", return_value=mock_pipeline): + with pytest.raises(RuntimeError, match="Training failed"): + main() + + +class TestFineTuningPipelineEnhanced: + """Enhanced test suite for FineTuningPipeline class with additional edge cases.""" + + @pytest.fixture + def mock_master_config(self): + """Create a mock MasterConfig for testing.""" + config = MagicMock(spec=MasterConfig) + # Use DictLikeMock to support both dict access ['key'] and attribute access .key + config.training = DictLikeMock({"output_dir": "./test_output", "seed": 42}) + return config + + @pytest.fixture + def mock_config_manager(self): + """Create a mock ConfigManager.""" + config_manager = MagicMock() + config_manager.get_training_config.return_value = { + "type": "sft", + "dtype": "fp16", + "seed": 42, + } + config_manager.get_dataset_config.return_value = { + "dataset_type": "sft_dataset", + "dataset_name": "test_dataset", + "train_split": "train", + "test_split": "test", + } + config_manager.get_model_config.return_value = { + "model_type": "hf", + "model_name": "test-model", + "use_peft": False, + } + config_manager.get_optimizer_config.return_value = { + "optimizer_name": "adamw", + "lr": 1e-4, + } + config_manager.get_callback_config.return_value = {"callbacks": {}} + config_manager.validate_config = MagicMock() + return config_manager + + def test_create_datasets_with_additional_config_params(self, mock_config_manager): + """Test that additional dataset config parameters are properly propagated.""" + mock_config_manager.get_dataset_config.return_value = { + "dataset_type": "sft_dataset", + "dataset_name": "test_dataset", + "train_split": "train", + "test_split": "test", + "max_seq_length": 512, + "batch_size": 16, + "custom_param": "custom_value", + } + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": "./test_output", "seed": 42}) + mock_config_manager.config = mock_config_obj + + with patch("QEfficient.cloud.finetune_experimental.ComponentFactory") as mock_factory: + mock_factory.create_dataset.return_value = MagicMock() + + pipeline = FineTuningPipeline(mock_config_manager) + pipeline._create_datasets() + + # Verify additional parameters are passed through + calls = mock_factory.create_dataset.call_args_list + assert calls[0].kwargs.get("max_seq_length") == 512 + assert calls[0].kwargs.get("batch_size") == 16 + assert calls[0].kwargs.get("custom_param") == "custom_value" + # Verify excluded keys are not passed + assert "train_split" not in calls[0].kwargs + assert "test_split" not in calls[0].kwargs + + def test_create_model_with_additional_model_params(self, mock_config_manager): + """Test that additional model config parameters are properly propagated.""" + mock_config_manager.get_model_config.return_value = { + "model_type": "hf", + "model_name": "test-model", + "use_peft": False, + "trust_remote_code": True, + "device_map": "auto", + "custom_model_param": "value", + } + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": "./test_output"}) + mock_config_manager.config = mock_config_obj + + with patch("QEfficient.cloud.finetune_experimental.ComponentFactory") as mock_factory: + mock_factory.create_model.return_value = MagicMock() + + pipeline = FineTuningPipeline(mock_config_manager) + pipeline._create_model() + + call_kwargs = mock_factory.create_model.call_args.kwargs + assert call_kwargs.get("trust_remote_code") is True + assert call_kwargs.get("device_map") == "auto" + assert call_kwargs.get("custom_model_param") == "value" + # Verify PEFT keys are excluded + assert "use_peft" not in call_kwargs + assert "peft_config" not in call_kwargs + + def test_run_method_calls_validate_config_first(self, mock_config_manager): + """Test that run() calls validate_config before other operations.""" + mock_config_obj = MagicMock() + mock_config_obj.training = DictLikeMock({"output_dir": "./test_output", "seed": 42}) + mock_config_manager.config = mock_config_obj + + call_order = [] + + def track_validate(): + call_order.append("validate") + return None + + mock_config_manager.validate_config.side_effect = track_validate + + with patch( + "QEfficient.cloud.finetune_experimental.prepare_training_config", return_value={"type": "sft", "fp16": True} + ): + with patch.object(FineTuningPipeline, "_create_datasets", return_value=(MagicMock(), MagicMock())): + with patch.object(FineTuningPipeline, "_create_model", return_value=MagicMock()): + with patch.object(FineTuningPipeline, "_create_optimizer", return_value=(MagicMock(), {})): + with patch.object(FineTuningPipeline, "_create_callbacks", return_value=[]): + with patch.object(FineTuningPipeline, "_create_trainer", return_value=MagicMock()): + with patch("QEfficient.cloud.finetune_experimental.logger"): + pipeline = FineTuningPipeline(mock_config_manager) + pipeline.run() + + # Verify validate_config was called first + assert call_order[0] == "validate" + assert mock_config_manager.validate_config.call_count == 1