diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index dcfa117db8..e823e6ecc6 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -237,6 +237,10 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments, default=False, metadata={"help": "Whether or not to disable gradient checkpointing."}, ) + use_reentrant_gc: bool = field( + default=True, + metadata={"help": "Whether or not to use reentrant gradient checkpointing."}, + ) upcast_layernorm: bool = field( default=False, metadata={"help": "Whether or not to upcast the layernorm weights in fp32."}, diff --git a/src/llamafactory/model/model_utils/checkpointing.py b/src/llamafactory/model/model_utils/checkpointing.py index 3397a8cd90..0fad48cf3d 100644 --- a/src/llamafactory/model/model_utils/checkpointing.py +++ b/src/llamafactory/model/model_utils/checkpointing.py @@ -156,7 +156,9 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum _gradient_checkpointing_enable, use_unsloth_gc=model_args.use_unsloth_gc ) model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model) - model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True}) + model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": model_args.use_reentrant_gc} + ) setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled logger.info_rank0("Gradient checkpointing enabled.")