-
Notifications
You must be signed in to change notification settings - Fork 8
/
alignprop_trainer.py
479 lines (398 loc) · 19.3 KB
/
alignprop_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
import os
import textwrap
import random
from collections import defaultdict
from typing import Any, Callable, List, Optional, Tuple, Union
from warnings import warn
import time
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from aesthetic_scorer import AestheticScorerDiff
from accelerate.utils import ProjectConfiguration, set_seed
from transformers import is_wandb_available
import hpsv2
from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
import torchvision
from sd_pipeline import DiffusionPipeline
from config.alignprop_config import AlignPropConfig
from trl.trainer import BaseTrainer
if is_wandb_available():
import wandb
logger = get_logger(__name__)
def hps_loss_fn(inference_dtype=None, device=None):
    """Build a differentiable HPSv2 loss closure.

    Creates the ``ViT-H-14`` open_clip model, loads the HPSv2 checkpoint from
    ``~/.cache/hpsv2`` (downloading it first if it is missing), freezes the
    model and returns a function that scores images against prompts.

    Args:
        inference_dtype: dtype the reward model is cast to.
        device: device the reward model and tokenized prompts are moved to.

    Returns:
        A callable ``loss_fn(im_pix, prompts) -> (loss, scores)`` where
        ``scores`` is the diagonal of the image/text feature product and
        ``loss`` is ``1.0 - scores``.

    Raises:
        requests.HTTPError: if the checkpoint download returns an error status.
    """
    # Local imports: only needed for the one-off checkpoint download.
    import requests
    from tqdm import tqdm

    model_name = "ViT-H-14"
    # Only the model is used; the preprocessing transforms returned alongside
    # it are discarded because normalisation is done manually below.
    model, _, _ = create_model_and_transforms(
        model_name,
        'laion2B-s32B-b79K',
        precision=inference_dtype,
        device=device,
        jit=False,
        force_quick_gelu=False,
        force_custom_text=False,
        force_patch_dropout=False,
        force_image_size=None,
        pretrained_image=False,
        image_mean=None,
        image_std=None,
        light_augmentation=True,
        aug_cfg={},
        output_dict=True,
        with_score_predictor=False,
        with_region_predictor=False
    )
    tokenizer = get_tokenizer(model_name)

    link = "https://huggingface.co/spaces/xswu/HPSv2/resolve/main/HPS_v2_compressed.pt"
    # Create the cache directory if it doesn't exist
    cache_dir = os.path.expanduser('~/.cache/hpsv2')
    os.makedirs(cache_dir, exist_ok=True)
    checkpoint_path = os.path.join(cache_dir, "HPS_v2_compressed.pt")

    # Download the file if it doesn't exist
    if not os.path.exists(checkpoint_path):
        # Write to a per-process temporary file and move it into place only
        # once the download finished, so an interrupted or failed download
        # never leaves a truncated file at the final cache path.
        partial_path = f"{checkpoint_path}.{os.getpid()}.part"
        try:
            with requests.get(link, stream=True, timeout=60) as response:
                # Fail loudly instead of caching an HTML error page.
                response.raise_for_status()
                total_size = int(response.headers.get('content-length', 0))
                with open(partial_path, 'wb') as file, tqdm(
                    desc="Downloading HPS_v2_compressed.pt",
                    total=total_size,
                    unit='iB',
                    unit_scale=True,
                    unit_divisor=1024,
                ) as progress_bar:
                    for data in response.iter_content(chunk_size=1024):
                        size = file.write(data)
                        progress_bar.update(size)
            os.replace(partial_path, checkpoint_path)
        finally:
            # Only still present if the download did not complete.
            if os.path.exists(partial_path):
                os.remove(partial_path)

    # force download of model via score
    hpsv2.score([], "")

    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model = model.to(device, dtype=inference_dtype)
    model.eval()
    # The reward model is never optimised; freezing it avoids allocating
    # gradients for its weights on every backward pass. Gradients still flow
    # back to the input images.
    model.requires_grad_(False)

    target_size = 224
    resize = torchvision.transforms.Resize(target_size)
    normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                                 std=[0.26862954, 0.26130258, 0.27577711])

    def loss_fn(im_pix, prompts):
        """Return ``(loss, scores)`` for a batch of images and their prompts."""
        # Map pixel values from [-1, 1] to [0, 1] before normalising.
        im_pix = ((im_pix / 2) + 0.5).clamp(0, 1)
        x_var = resize(im_pix)
        x_var = normalize(x_var).to(im_pix.dtype)
        caption = tokenizer(prompts)
        caption = caption.to(device)
        outputs = model(x_var, caption)
        image_features, text_features = outputs["image_features"], outputs["text_features"]
        logits = image_features @ text_features.T
        # Entry (i, i) pairs image i with its own prompt.
        scores = torch.diagonal(logits)
        loss = 1.0 - scores
        return loss, scores

    return loss_fn
def aesthetic_loss_fn(aesthetic_target=None,
                      grad_scale=0,
                      device=None,
                      accelerator=None,
                      torch_dtype=None):
    """Build a differentiable aesthetic-score loss closure.

    Args:
        aesthetic_target: score to pull the images towards. When ``None`` the
            score is simply maximised.
        grad_scale: factor the loss is multiplied by before it is returned.
        device: device the scorer is moved to.
        accelerator: accepted for interface compatibility; not used here.
        torch_dtype: dtype the scorer is created with and cast to.

    Returns:
        A callable ``loss_fn(im_pix_un) -> (loss, rewards)``.
    """
    side = 224
    to_clip_stats = torchvision.transforms.Normalize(
        mean=[0.48145466, 0.4578275, 0.40821073],
        std=[0.26862954, 0.26130258, 0.27577711],
    )
    scorer = AestheticScorerDiff(dtype=torch_dtype).to(device, dtype=torch_dtype)
    # The scorer is a fixed reward model; its weights are never trained.
    scorer.requires_grad_(False)
    shrink = torchvision.transforms.Resize(side)

    def loss_fn(im_pix_un):
        """Score a batch of images and return ``(scaled loss, rewards)``."""
        # Map pixel values from [-1, 1] to [0, 1] before normalising.
        pixels = (im_pix_un / 2 + 0.5).clamp(0, 1)
        pixels = to_clip_stats(shrink(pixels)).to(im_pix_un.dtype)
        rewards = scorer(pixels)
        if aesthetic_target is not None:
            # using L1 to keep on same scale
            loss = torch.abs(rewards - aesthetic_target)
        else:
            # default maximization
            loss = -1 * rewards
        return loss * grad_scale, rewards

    return loss_fn
class AlignPropTrainer(BaseTrainer):
    """
    The AlignPropTrainer fine-tunes diffusion models by backpropagating a differentiable reward
    (HPSv2 or an aesthetic score) through the image sampling process.
    Note, this trainer is heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/
    As of now only Stable Diffusion based pipelines are supported

    Attributes:
        config (`AlignPropConfig`):
            Configuration object for AlignPropTrainer. Check the documentation of `AlignPropConfig` for more details.
        prompt_function (`Callable[[], Tuple[str, Any]]`):
            Function to generate prompts to guide model
        sd_pipeline (`DiffusionPipeline`):
            Stable Diffusion pipeline to be used for training.
        image_samples_hook (`Optional[Callable[[Any, Any, Any], Any]]`):
            Hook to be called to log images
    """

    _tag_names = ["trl", "alignprop"]

    def __init__(
        self,
        config: AlignPropConfig,
        prompt_function: Callable[[], Tuple[str, Any]],
        sd_pipeline: DiffusionPipeline,
        image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
    ):
        """Set up the accelerator, pipeline, optimizer and reward loss.

        Raises:
            ValueError: if `config.resume_from` is a directory that contains no checkpoints.
            NotImplementedError: if `config.reward_fn` is neither `'hps'` nor `'aesthetic'`.
        """
        if image_samples_hook is None:
            warn("No image_samples_hook provided; no images will be logged")

        self.prompt_fn = prompt_function
        self.config = config
        self.image_samples_callback = image_samples_hook

        accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)

        if self.config.resume_from:
            self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
            if "checkpoint_" not in os.path.basename(self.config.resume_from):
                # get the most recent checkpoint in this directory
                checkpoints = [name for name in os.listdir(self.config.resume_from) if "checkpoint_" in name]
                if not checkpoints:
                    raise ValueError(f"No checkpoints found in {self.config.resume_from}")
                latest = max(int(name.split("_")[-1]) for name in checkpoints)
                self.config.resume_from = os.path.join(
                    self.config.resume_from,
                    f"checkpoint_{latest}",
                )

            # `resume_from` now always names a `checkpoint_<N>` directory. Continue the save
            # counter from N + 1 in both cases, so that resuming from an explicit checkpoint
            # path does not start saving again at index 0.
            resumed_iteration = int(self.config.resume_from.split("_")[-1])
            accelerator_project_config.iteration = resumed_iteration + 1

        self.accelerator = Accelerator(
            log_with=self.config.log_with,
            mixed_precision=self.config.mixed_precision,
            project_config=accelerator_project_config,
            # gradients are accumulated across `train_gradient_accumulation_steps` sampled batches
            # before each optimizer step (see the loop in `step`).
            gradient_accumulation_steps=self.config.train_gradient_accumulation_steps,
            **self.config.accelerator_kwargs,
        )

        is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
        config_dict = config.to_dict()
        config_dict['checkpoints_dir'] = self.config.project_kwargs['project_dir']

        if self.accelerator.is_main_process:
            self.accelerator.init_trackers(
                self.config.tracker_project_name,
                config=dict(alignprop_trainer_config=config_dict)
                if not is_using_tensorboard
                else config.to_dict(),
                init_kwargs=self.config.tracker_kwargs,
            )

        logger.info(f"\n{config}")

        set_seed(self.config.seed, device_specific=True)

        self.sd_pipeline = sd_pipeline

        self.sd_pipeline.set_progress_bar_config(
            position=1,
            disable=not self.accelerator.is_local_main_process,
            leave=False,
            desc="Timestep",
            dynamic_ncols=True,
        )

        # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
        # as these weights are only used for inference, keeping weights in full precision is not required.
        if self.accelerator.mixed_precision == "fp16":
            inference_dtype = torch.float16
        elif self.accelerator.mixed_precision == "bf16":
            inference_dtype = torch.bfloat16
        else:
            inference_dtype = torch.float32

        self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
        self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
        self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)

        trainable_layers = self.sd_pipeline.get_trainable_layers()

        self.accelerator.register_save_state_pre_hook(self._save_model_hook)
        self.accelerator.register_load_state_pre_hook(self._load_model_hook)

        # Enable TF32 for faster training on Ampere GPUs,
        # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
        if self.config.allow_tf32:
            torch.backends.cuda.matmul.allow_tf32 = True

        self.optimizer = self._setup_optimizer(
            trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
        )

        # Embedding of the negative prompt, computed once and reused for every sample.
        self.neg_prompt_embed = self.sd_pipeline.text_encoder(
            self.sd_pipeline.tokenizer(
                [""] if self.config.negative_prompts is None else self.config.negative_prompts,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.sd_pipeline.tokenizer.model_max_length,
            ).input_ids.to(self.accelerator.device)
        )[0]

        # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
        # more memory
        self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast

        if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
            unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
            self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
        else:
            self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)

        if self.config.reward_fn == 'hps':
            self.loss_fn = hps_loss_fn(inference_dtype, self.accelerator.device)
        elif self.config.reward_fn == 'aesthetic':
            self.loss_fn = aesthetic_loss_fn(grad_scale=self.config.grad_scale,
                                             aesthetic_target=self.config.aesthetic_target,
                                             accelerator=self.accelerator,
                                             torch_dtype=inference_dtype,
                                             device=self.accelerator.device)
        else:
            raise NotImplementedError

        if config.resume_from:
            logger.info(f"Resuming from {config.resume_from}")
            self.accelerator.load_state(config.resume_from)
            self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
        else:
            self.first_epoch = 0

        # Fixed set of prompts reused every time evaluation images are logged.
        self.eval_prompts, self.eval_prompt_metadata = zip(*[self.prompt_fn() for _ in range(config.train_batch_size)])

    def step(self, epoch: int, global_step: int):
        """
        Perform a single step of training.

        Args:
            epoch (int): The current epoch.
            global_step (int): The current global step.

        Side Effects:
            - Model weights are updated
            - Logs the statistics to the accelerator trackers.
            - If `self.image_samples_callback` is not None, it will be called (on the main process, every
              `config.log_image_freq` global steps) with freshly generated evaluation prompt_image_pairs,
              global_step, and the accelerator tracker.
            - A checkpoint is saved every `config.save_freq` epochs (except epoch 0).

        Returns:
            global_step (int): The updated global step.

        Raises:
            ValueError: if no optimizer step was performed after the accumulation loop.
        """
        info = defaultdict(list)
        print(f"Epoch: {epoch}, Global Step: {global_step}")

        self.sd_pipeline.unet.train()

        for _ in range(self.config.train_gradient_accumulation_steps):
            with self.accelerator.accumulate(self.sd_pipeline.unet), self.autocast(), torch.enable_grad():
                prompt_image_pairs = self._generate_samples(
                    batch_size=self.config.train_batch_size,
                )

                # The HPS loss also needs the prompts; the aesthetic loss only needs the images.
                if "hps" in self.config.reward_fn:
                    loss, rewards = self.loss_fn(prompt_image_pairs["images"], prompt_image_pairs["prompts"])
                else:
                    loss, rewards = self.loss_fn(prompt_image_pairs["images"])

                rewards_vis = self.accelerator.gather(rewards).detach().cpu().numpy()

                loss = loss.mean()
                loss = loss * self.config.loss_coeff

                self.accelerator.backward(loss)

                if self.accelerator.sync_gradients:
                    self.accelerator.clip_grad_norm_(
                        self.trainable_layers.parameters()
                        if not isinstance(self.trainable_layers, list)
                        else self.trainable_layers,
                        self.config.train_max_grad_norm,
                    )

                self.optimizer.step()
                self.optimizer.zero_grad()

            info["reward_mean"].append(rewards_vis.mean())
            info["reward_std"].append(rewards_vis.std())
            info["loss"].append(loss.item())

        # Checks if the accelerator has performed an optimization step behind the scenes
        if self.accelerator.sync_gradients:
            # log training-related stuff
            info = {k: torch.mean(torch.tensor(v)) for k, v in info.items()}
            info.update({"epoch": epoch})
            self.accelerator.log(info, step=global_step)
            global_step += 1
            info = defaultdict(list)
        else:
            raise ValueError(
                "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
            )

        # Logs generated images
        if self.image_samples_callback is not None and global_step % self.config.log_image_freq == 0 and self.accelerator.is_main_process:
            print("Logging images")
            self._log_eval_images(global_step)

        if epoch != 0 and epoch % self.config.save_freq == 0:
            print("Saving checkpoint")
            self.accelerator.save_state()

        print("Step Done")
        return global_step

    def _log_eval_images(self, global_step: int) -> None:
        """Generate evaluation images under a fixed seed and pass them to the image hook.

        The global RNG state is saved beforehand and restored afterwards, so logging images
        does not perturb (or narrow) the random stream used for training samples.
        """
        device = self.accelerator.device
        restore_cuda = torch.cuda.is_available() and device.type == "cuda"
        cpu_rng_state = torch.get_rng_state()
        cuda_rng_state = torch.cuda.get_rng_state(device) if restore_cuda else None
        try:
            # Fix the random seed for reproducibility
            torch.manual_seed(self.config.seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(self.config.seed)

            prompt_image_pairs_eval = self._generate_samples(
                batch_size=self.config.train_batch_size, with_grad=False, prompts=self.eval_prompts
            )
            self.image_samples_callback(prompt_image_pairs_eval, global_step, self.accelerator.trackers[0])
        finally:
            torch.set_rng_state(cpu_rng_state)
            if cuda_rng_state is not None:
                torch.cuda.set_rng_state(cuda_rng_state, device)

    def _setup_optimizer(self, trainable_layers_parameters):
        """Create the AdamW optimizer (8-bit bitsandbytes variant if configured)."""
        if self.config.train_use_8bit_adam:
            import bitsandbytes

            optimizer_cls = bitsandbytes.optim.AdamW8bit
        else:
            optimizer_cls = torch.optim.AdamW

        return optimizer_cls(
            trainable_layers_parameters,
            lr=self.config.train_learning_rate,
            betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
            weight_decay=self.config.train_adam_weight_decay,
            eps=self.config.train_adam_epsilon,
        )

    def _save_model_hook(self, models, weights, output_dir):
        """Accelerate pre-save hook: delegate model saving to the pipeline."""
        self.sd_pipeline.save_checkpoint(models, weights, output_dir)
        weights.pop()  # ensures that accelerate doesn't try to handle saving of the model

    def _load_model_hook(self, models, input_dir):
        """Accelerate pre-load hook: delegate model loading to the pipeline."""
        self.sd_pipeline.load_checkpoint(models, input_dir)
        models.pop()  # ensures that accelerate doesn't try to handle loading of the model

    def _generate_samples(self, batch_size, with_grad=True, prompts=None):
        """
        Generate samples from the model

        Args:
            batch_size (int): Batch size to use for sampling
            with_grad (bool): Whether the generated RGBs should have gradients attached to it.
            prompts (Optional[Sequence[str]]): Prompts to use. When None, `batch_size` prompts are
                drawn from `self.prompt_fn`; otherwise the prompt metadata is filled with empty dicts.

        Returns:
            prompt_image_pairs (Dict[Any]): dict with keys "images", "prompts" and "prompt_metadata".
        """
        prompt_image_pairs = {}

        sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)

        if prompts is None:
            prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
        else:
            prompt_metadata = [{} for _ in range(batch_size)]

        prompt_ids = self.sd_pipeline.tokenizer(
            prompts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.sd_pipeline.tokenizer.model_max_length,
        ).input_ids.to(self.accelerator.device)

        prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]

        if with_grad:
            sd_output = self.sd_pipeline.rgb_with_grad(
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=sample_neg_prompt_embeds,
                num_inference_steps=self.config.sample_num_steps,
                guidance_scale=self.config.sample_guidance_scale,
                eta=self.config.sample_eta,
                backprop_strategy=self.config.backprop_strategy,
                backprop_kwargs=self.config.backprop_kwargs[self.config.backprop_strategy],
                output_type="pt",
            )
        else:
            sd_output = self.sd_pipeline(
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=sample_neg_prompt_embeds,
                num_inference_steps=self.config.sample_num_steps,
                guidance_scale=self.config.sample_guidance_scale,
                eta=self.config.sample_eta,
                output_type="pt",
            )

        images = sd_output.images

        prompt_image_pairs["images"] = images
        prompt_image_pairs["prompts"] = prompts
        prompt_image_pairs["prompt_metadata"] = prompt_metadata

        return prompt_image_pairs

    def train(self, epochs: Optional[int] = None):
        """
        Train the model for a given number of epochs

        Args:
            epochs (Optional[int]): Upper bound (exclusive) of the epoch range. Defaults to
                `config.num_epochs`. Training starts at `self.first_epoch`.
        """
        global_step = 0
        if epochs is None:
            epochs = self.config.num_epochs
        for epoch in range(self.first_epoch, epochs):
            global_step = self.step(epoch, global_step)

    def _save_pretrained(self, save_directory):
        """Save the pipeline to `save_directory` and create the model card."""
        self.sd_pipeline.save_pretrained(save_directory)
        self.create_model_card()