
Unable to access model gradients with DeepSpeed and Accelerate #3184

Open
2 of 4 tasks
shouyezhe opened this issue Oct 22, 2024 · 0 comments


System Info

- `Accelerate` version: 0.34.0
- Platform: Linux-5.15.0-117-generic-x86_64-with-glibc2.17
- `accelerate` bash location: /home/miao/anaconda3/envs/train/bin/accelerate
- Python version: 3.8.20
- Numpy version: 1.24.3
- PyTorch version (GPU?): 2.4.0 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- PyTorch MUSA available: False
- System RAM: 503.55 GB
- GPU type: NVIDIA GeForce RTX 3090
- `Accelerate` default config:
        - compute_environment: LOCAL_MACHINE
        - distributed_type: DEEPSPEED
        - mixed_precision: no
        - use_cpu: False
        - debug: False
        - num_processes: 4
        - machine_rank: 0
        - num_machines: 1
        - gpu_ids: 0,1,2,3
        - rdzv_backend: static
        - same_network: True
        - main_training_function: main
        - enable_cpu_affinity: False
        - downcast_bf16: no
        - tpu_use_cluster: False
        - tpu_use_sudo: False
        - tpu_env: []

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

When using the official training script for Diffusers (https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) with DeepSpeed and ZeRO-2, I'm trying to save the gradients of the model at each training step. However, I'm encountering difficulties due to the way Accelerate wraps DeepSpeed's operations.

Current Behavior

I added the following code between `accelerator.backward(loss)` (line 1030) and `optimizer.step()` (line 1033) of that script:

```python
from deepspeed.utils import safe_get_full_grad

gradients = {}
for n, lp in unet.named_parameters():
    # 1. Access the full states
    #  1.1) gradient lookup
    # For zero1 and zero2, gradient lookup must be called after `backward` and before `step`
    # For zero3, gradient lookup must be called after `backward`
    hp_grad = safe_get_full_grad(lp)
    if hp_grad is not None:
        gradients[n] = hp_grad.detach().clone()
```

Problem

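The current implementation of Accelerate wraps both DeepSpeed's backward and step operations into a single `accelerator.backward` call: by the time it returns, `engine.step()` has already updated the weights and zeroed the gradients. Code placed between `accelerator.backward(loss)` and `optimizer.step()` therefore runs too late, so users cannot access the gradients between the two operations, which is necessary for gradient analysis or custom gradient processing.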
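To make the ordering problem concrete, here is a toy, framework-free model of the coupling. `ToyParam` and `ToyCoupledEngine` are hypothetical stand-ins, not DeepSpeed or Accelerate APIs; the engine's `backward` mimics the backward-then-step sequence, including the zero-grad that `engine.step()` performs. Anything that runs after it, like the gradient loop above, only sees cleared gradients in this model.

```python
class ToyParam:
    """Hypothetical parameter holding a value and an optional gradient."""

    def __init__(self, value):
        self.value = value
        self.grad = None


class ToyCoupledEngine:
    """Hypothetical engine whose backward() also steps, like the wrapped DeepSpeed call."""

    def __init__(self, params, lr=0.1):
        self.params = params
        self.lr = lr

    def backward(self, grads):
        # 1) "backward": populate gradients (passed in directly to keep the toy small)
        for p, g in zip(self.params, grads):
            p.grad = g
        # 2) "step": apply a plain SGD update ...
        for p in self.params:
            p.value -= self.lr * p.grad
        # 3) ... then clear the gradients, as engine.step() does
        for p in self.params:
            p.grad = None


params = [ToyParam(1.0), ToyParam(2.0)]
engine = ToyCoupledEngine(params)
engine.backward([0.5, -1.0])

# Code placed "between backward and step" in the training loop effectively runs here:
captured = [p.grad for p in params]
print(captured)  # [None, None]
```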

Suggested Solution

Modify Accelerate's DeepSpeed integration to allow users to access gradients between the backward and step operations. This could be achieved by:

  1. Separating the backward and step operations in Accelerate's DeepSpeed wrapper. (I don't understand why Accelerate couples DeepSpeed's backward and step in the first place.)
  2. Providing a temporary workaround that exposes the full gradients when using DeepSpeed with Accelerate. I modified `accelerate/utils/deepspeed.py` (line 178) and `accelerate/accelerator.py` (line 2188) as follows:
In `accelerate/utils/deepspeed.py` (`DeepSpeedEngineWrapper.backward`):

```python
def backward(self, loss, gradients=None, **kwargs):
    self.engine.backward(loss, **kwargs)
    # 1. Access the full states
    #  1.1) gradient lookup
    # For zero1 and zero2, gradient lookup must be called after `backward` and before `step`
    # For zero3, gradient lookup must be called after `backward`
    if gradients is not None:
        import torch
        from deepspeed.utils import safe_get_full_grad

        with torch.no_grad():
            for n, lp in self.engine.module.named_parameters():
                grad = lp.grad if lp.grad is not None else safe_get_full_grad(lp)
                if grad is not None:
                    # Clone defensively: `engine.step()` zeroes gradients, which could
                    # clear a stored reference in place
                    gradients[n] = grad.detach().clone()
    # Deepspeed's `engine.step` performs the following operations:
    # - gradient accumulation check
    # - gradient clipping
    # - optimizer step
    # - zero grad
    # - checking overflow
    # - lr_scheduler step (only if engine.lr_scheduler is not None)
    self.engine.step()
    return gradients
```

In `accelerate/accelerator.py` (DeepSpeed branch of `Accelerator.backward`):

```python
# `gradients` travels in **kwargs; the wrapper returns the filled dict, or None if it was not passed
return self.deepspeed_engine_wrapped.backward(loss, **kwargs)
```

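With these changes in place, I obtained the desired gradients in the training loop as follows:

```python
gradients = {}  # filled with {parameter name: full gradient tensor}
gradients = accelerator.backward(loss, gradients=gradients)
```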
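For comparison, suggestion 1 (separating backward and step) could take the form of a pre-step hook: the wrapper still calls `engine.backward` and `engine.step` in that order, but runs user callbacks in between. This is only a sketch: `HookableEngineWrapper` and `register_pre_step_hook` are hypothetical names, not part of Accelerate, and the only assumption about the engine is that it exposes `backward(loss, **kwargs)` and `step()`, as DeepSpeed's engine does. `FakeEngine` stands in for DeepSpeed so the call order can be checked without GPUs.

```python
from typing import Any, Callable, List


class HookableEngineWrapper:
    """Hypothetical variant of Accelerate's DeepSpeed wrapper with pre-step hooks."""

    def __init__(self, engine: Any) -> None:
        self.engine = engine
        self._pre_step_hooks: List[Callable[[Any], None]] = []

    def register_pre_step_hook(self, hook: Callable[[Any], None]) -> Callable[[Any], None]:
        # Hooks receive the engine, so they can walk engine.module.named_parameters()
        self._pre_step_hooks.append(hook)
        return hook

    def backward(self, loss: Any, **kwargs: Any) -> None:
        self.engine.backward(loss, **kwargs)
        # Gradients are still available here (ZeRO-1/2 lookups must happen before `step`)
        for hook in self._pre_step_hooks:
            hook(self.engine)
        self.engine.step()


class FakeEngine:
    """Stand-in for a DeepSpeed engine that records the order of calls."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def backward(self, loss: Any, **kwargs: Any) -> None:
        self.calls.append("backward")

    def step(self) -> None:
        self.calls.append("step")


engine = FakeEngine()
wrapper = HookableEngineWrapper(engine)
wrapper.register_pre_step_hook(lambda eng: eng.calls.append("hook"))
wrapper.backward(loss=1.0)
print(engine.calls)  # ['backward', 'hook', 'step']
```

In a real run the hook body would be the `safe_get_full_grad` loop from the reproduction, and `accelerator.backward` would keep its current single-call signature.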

Expected behavior

I should be able to access and save the full gradients of the model parameters at each training step when using DeepSpeed with ZeRO-2.
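Until Accelerate supports this, a possible patch-free alternative (untested, and relying on an assumption about Accelerate internals): with `distributed_type: DEEPSPEED`, `accelerator.prepare` appears to return the DeepSpeed engine itself in place of `unet`, and `accelerator.backward` forwards to the engine's `backward` and `step`. If that holds, the training loop could call those two methods directly and read the gradients in between. `backward_and_capture` is a hypothetical helper; the `_Fake*` classes exist only to check the call order and to show why the gradients are copied.

```python
from typing import Any, Callable, Dict, Optional


def backward_and_capture(
    engine: Any,
    loss: Any,
    grad_lookup: Callable[[Any], Optional[Any]],
    copy_fn: Callable[[Any], Any] = lambda g: g.detach().clone(),
) -> Dict[str, Any]:
    """Hypothetical helper: backward, capture full gradients, then step.

    `engine` is assumed to be the DeepSpeed engine returned by `accelerator.prepare`;
    in a real run `grad_lookup` would be `deepspeed.utils.safe_get_full_grad`.
    """
    engine.backward(loss)
    captured: Dict[str, Any] = {}
    # ZeRO-1/2: the lookup must happen after `backward` and before `step`
    for name, param in engine.module.named_parameters():
        grad = grad_lookup(param)
        if grad is not None:
            # Copy, because `engine.step()` zeroes gradients afterwards
            captured[name] = copy_fn(grad)
    engine.step()
    return captured


# --- Framework-free check of the call order (fakes only, not DeepSpeed) ---
class _FakeParam:
    def __init__(self) -> None:
        self.grad = None


class _FakeModule:
    def __init__(self) -> None:
        self.params = {"w": _FakeParam(), "b": _FakeParam()}

    def named_parameters(self):
        return iter(self.params.items())


class _FakeEngine:
    def __init__(self) -> None:
        self.module = _FakeModule()

    def backward(self, loss: Any) -> None:
        for p in self.module.params.values():
            p.grad = [loss, loss]  # pretend gradient

    def step(self) -> None:
        for p in self.module.params.values():
            p.grad[:] = [0, 0]  # zero the gradient buffer in place


engine = _FakeEngine()
captured = backward_and_capture(engine, 3, grad_lookup=lambda p: p.grad, copy_fn=list)
print(captured)  # {'w': [3, 3], 'b': [3, 3]}
```

In `train_text_to_image.py` this would stand in for `accelerator.backward(loss)`, e.g. `gradients = backward_and_capture(unet, loss, safe_get_full_grad)`, with the later `optimizer.step()` left in place; as far as I can tell, Accelerate's DeepSpeed optimizer wrapper turns that call into a no-op. Because the ZeRO-1/2 lookup appears to all-reduce across data-parallel ranks, the helper should run on every rank, not only the main process.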
