diff --git a/CHANGELOG.md b/CHANGELOG.md
index 919041bde..dcb88bcf3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,6 @@
 # Change Log for SD.Next

-## Update for 2024-11-28
+## Update for 2024-11-30

 ### New models and integrations

@@ -67,7 +67,8 @@
 - fix xyz-grid with lora
 - fix api script callbacks
 - fix gpu memory monitoring
-- simplify img2img/inpaint/sketch canvas handling
+- simplify img2img/inpaint/sketch canvas handling
+- fix prompt caching

 ## Update for 2024-11-21

diff --git a/modules/lora/extra_networks_lora.py b/modules/lora/extra_networks_lora.py
index 3aea659d9..c875ba0d5 100644
--- a/modules/lora/extra_networks_lora.py
+++ b/modules/lora/extra_networks_lora.py
@@ -113,22 +113,21 @@ def __init__(self):
         self.errors = {}

     def activate(self, p, params_list, step=0):
-        t0 = time.time()
         self.errors.clear()
         if self.active:
             if self.model != shared.opts.sd_model_checkpoint: # reset if model changed
                 self.active = False
         if len(params_list) > 0 and not self.active: # activate patches once
-            shared.log.debug(f'Activate network: type=LoRA model="{shared.opts.sd_model_checkpoint}"')
+            # shared.log.debug(f'Activate network: type=LoRA model="{shared.opts.sd_model_checkpoint}"')
             self.active = True
             self.model = shared.opts.sd_model_checkpoint
         names, te_multipliers, unet_multipliers, dyn_dims = parse(p, params_list, step)
-        networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims)
-        t1 = time.time()
+        networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims) # load
+        networks.network_load() # backup/apply
         if len(networks.loaded_networks) > 0 and step == 0:
             infotext(p)
             prompt(p)
-            shared.log.info(f'Load network: type=LoRA apply={[n.name for n in networks.loaded_networks]} te={te_multipliers} unet={unet_multipliers} dims={dyn_dims} load={t1-t0:.2f}')
+            shared.log.info(f'Load network: type=LoRA apply={[n.name for n in networks.loaded_networks]} te={te_multipliers} unet={unet_multipliers} time={networks.get_timers()}')

     def deactivate(self, p):
         t0 = time.time()
diff --git a/modules/lora/networks.py b/modules/lora/networks.py
index 51ef27a8a..86c6e5ed0 100644
--- a/modules/lora/networks.py
+++ b/modules/lora/networks.py
@@ -54,7 +54,8 @@ def total_time():
 def get_timers():
     t = { 'total': round(sum(timer.values()), 2) }
     for k, v in timer.items():
-        t[k] = round(v, 2)
+        if v > 0.1:
+            t[k] = round(v, 2)
     return t


@@ -216,6 +217,7 @@ def maybe_recompile_model(names, te_multipliers):


 def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
+    global backup_size # pylint: disable=global-statement
     networks_on_disk: list[network.NetworkOnDisk] = [available_network_aliases.get(name, None) for name in names]
     if any(x is None for x in networks_on_disk):
         list_available_networks()
@@ -304,10 +306,9 @@ def set_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm
     with devices.inference_context():
         if weights_backup is not None:
             if updown is not None:
-                if len(weights_backup.shape) == 4 and weights_backup.shape[1] == 9:
-                    # inpainting model. zero pad updown to make channel[1] 4 to 9
+                if len(weights_backup.shape) == 4 and weights_backup.shape[1] == 9: # inpainting model. zero pad updown to make channel[1] 4 to 9
                     updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) # pylint: disable=not-callable
-                weights_backup = weights_backup.clone().to(device)
+                weights_backup = weights_backup.clone().to(self.weight.device)
                 weights_backup += updown.to(weights_backup)
             if getattr(self, "quant_type", None) in ['nf4', 'fp4']:
                 bnb = model_quant.load_bnb('Load network: type=LoRA', silent=True)
@@ -375,18 +376,18 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
     network_layer_name = getattr(self, 'network_layer_name', None)
     current_names = getattr(self, "network_current_names", ())
     wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
-    if network_layer_name is not None and any([net.modules.get(network_layer_name, None) for net in loaded_networks]): # noqa: C419
-        maybe_backup_weights(self, wanted_names)
-    if current_names != wanted_names:
-        batch_updown = None
-        batch_ex_bias = None
-        t0 = time.time()
-        for net in loaded_networks:
-            # default workflow where module is known and has weights
-            module = net.modules.get(network_layer_name, None)
-            if module is not None and hasattr(self, 'weight'):
-                try:
-                    with devices.inference_context():
+    with devices.inference_context():
+        if network_layer_name is not None and any([net.modules.get(network_layer_name, None) for net in loaded_networks]): # noqa: C419
+            maybe_backup_weights(self, wanted_names)
+        if current_names != wanted_names:
+            batch_updown = None
+            batch_ex_bias = None
+            t0 = time.time()
+            for net in loaded_networks:
+                # default workflow where module is known and has weights
+                module = net.modules.get(network_layer_name, None)
+                if module is not None and hasattr(self, 'weight'):
+                    try:
                         weight = self.weight.to(devices.device) # calculate quant weights once
                         updown, ex_bias = module.calc_updown(weight)
                         if batch_updown is not None and updown is not None:
@@ -402,22 +403,22 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
                             batch_updown = batch_updown.to(devices.cpu)
                         if batch_ex_bias is not None:
                             batch_ex_bias = batch_ex_bias.to(devices.cpu)
-                except RuntimeError as e:
-                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
-                    if debug:
-                        module_name = net.modules.get(network_layer_name, None)
-                        shared.log.error(f'LoRA apply weight name="{net.name}" module="{module_name}" layer="{network_layer_name}" {e}')
-                        errors.display(e, 'LoRA')
-                        raise RuntimeError('LoRA apply weight') from e
-                continue
-            if module is None:
-                continue
-            shared.log.warning(f'LoRA network="{net.name}" layer="{network_layer_name}" unsupported operation')
-            extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
-        t1 = time.time()
-        timer['calc'] += t1 - t0
-        set_weights(self, batch_updown, batch_ex_bias) # Set or restore weights from backup
-        self.network_current_names = wanted_names
+                    except RuntimeError as e:
+                        extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
+                        if debug:
+                            module_name = net.modules.get(network_layer_name, None)
+                            shared.log.error(f'LoRA apply weight name="{net.name}" module="{module_name}" layer="{network_layer_name}" {e}')
+                            errors.display(e, 'LoRA')
+                            raise RuntimeError('LoRA apply weight') from e
+                    continue
+                if module is None:
+                    continue
+                shared.log.warning(f'LoRA network="{net.name}" layer="{network_layer_name}" unsupported operation')
+                extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
+            t1 = time.time()
+            timer['calc'] += t1 - t0
+            set_weights(self, batch_updown, batch_ex_bias) # Set or restore weights from backup
+            self.network_current_names = wanted_names


 def network_load(): # called from processing
     timer['calc'] = 0
     timer['apply'] = 0
     sd_model = getattr(shared.sd_model, "pipe", shared.sd_model) # wrapped model compatiblility
-    if shared.opts.diffusers_offload_mode != "none":
+    if shared.opts.diffusers_offload_mode == "sequential":
         sd_models.disable_offload(sd_model)
         sd_models.move_model(sd_model, device=devices.cpu)
     modules = []
@@ -441,11 +442,9 @@ def network_load(): # called from processing
             pbar.remove_task(task)
     modules.clear()
     if debug:
-        shared.log.debug(f'Load network: type=LoRA modules={len(modules)}')
-    if shared.opts.diffusers_offload_mode != "none":
+        shared.log.debug(f'Load network: type=LoRA modules={len(modules)} backup={backup_size} time={get_timers()}')
+    if shared.opts.diffusers_offload_mode == "sequential":
         sd_models.set_diffuser_offload(sd_model, op="model")
-    if debug:
-        shared.log.debug(f'Load network: type=LoRA time={get_timers()} backup={backup_size}')


 def list_available_networks():
diff --git a/modules/processing_callbacks.py b/modules/processing_callbacks.py
index e1bf723cc..f3eb0bc37 100644
--- a/modules/processing_callbacks.py
+++ b/modules/processing_callbacks.py
@@ -4,7 +4,6 @@
 import torch
 import numpy as np
 from modules import shared, processing_correction, extra_networks, timer, prompt_parser_diffusers
-from modules.lora.networks import network_load


 p = None
@@ -69,7 +68,6 @@ def diffusers_callback(pipe, step: int = 0, timestep: int = 0, kwargs: dict = {}
         time.sleep(0.1)
     if hasattr(p, "stepwise_lora") and shared.native:
         extra_networks.activate(p, p.extra_network_data, step=step)
-        network_load()
     if latents is None:
         return kwargs
     elif shared.opts.nan_skip:
diff --git a/modules/processing_diffusers.py b/modules/processing_diffusers.py
index ae24f5f80..463a15280 100644
--- a/modules/processing_diffusers.py
+++ b/modules/processing_diffusers.py
@@ -199,11 +199,6 @@ def process_hires(p: processing.StableDiffusionProcessing, output):
     if hasattr(shared.sd_model, "vae") and output.images is not None and len(output.images) > 0:
         output.images = processing_vae.vae_decode(latents=output.images, model=shared.sd_model, full_quality=p.full_quality, output_type='pil', width=p.hr_upscale_to_x, height=p.hr_upscale_to_y) # controlnet cannnot deal with latent input
     p.task_args['image'] = output.images # replace so hires uses new output
-    sd_models.move_model(shared.sd_model, devices.device)
-    if hasattr(shared.sd_model, 'unet'):
-        sd_models.move_model(shared.sd_model.unet, devices.device)
-    if hasattr(shared.sd_model, 'transformer'):
-        sd_models.move_model(shared.sd_model.transformer, devices.device)
     update_sampler(p, shared.sd_model, second_pass=True)
     orig_denoise = p.denoising_strength
     p.denoising_strength = strength
@@ -227,6 +222,11 @@ def process_hires(p: processing.StableDiffusionProcessing, output):
         shared.state.job = 'HiRes'
         shared.state.sampling_steps = hires_args.get('prior_num_inference_steps', None) or p.steps or hires_args.get('num_inference_steps', None)
         try:
+            sd_models.move_model(shared.sd_model, devices.device)
+            if hasattr(shared.sd_model, 'unet'):
+                sd_models.move_model(shared.sd_model.unet, devices.device)
+            if hasattr(shared.sd_model, 'transformer'):
+                sd_models.move_model(shared.sd_model.transformer, devices.device)
             sd_models_compile.check_deepcache(enable=True)
             output = shared.sd_model(**hires_args) # pylint: disable=not-callable
             if isinstance(output, dict):
@@ -405,6 +405,9 @@ def process_diffusers(p: processing.StableDiffusionProcessing):
             shared.sd_model = orig_pipeline
             return results

+    if shared.opts.diffusers_offload_mode == "balanced":
+        shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
+
     # sanitize init_images
     if hasattr(p, 'init_images') and getattr(p, 'init_images', None) is None:
         del p.init_images
@@ -427,10 +430,6 @@ def process_diffusers(p: processing.StableDiffusionProcessing):
         if p.negative_prompts is None or len(p.negative_prompts) == 0:
             p.negative_prompts = p.all_negative_prompts[p.iteration * p.batch_size:(p.iteration+1) * p.batch_size]

-    # load loras
-    networks.network_load()
-
-    sd_models.move_model(shared.sd_model, devices.device)
     sd_models_compile.openvino_recompile_model(p, hires=False, refiner=False) # recompile if a parameter changes

     if 'base' not in p.skip:
@@ -461,6 +460,10 @@ def process_diffusers(p: processing.StableDiffusionProcessing):
     timer.process.add('lora', networks.total_time())

     shared.sd_model = orig_pipeline
+
+    if shared.opts.diffusers_offload_mode == "balanced":
+        shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
+
     if p.state == '':
         global last_p # pylint: disable=global-statement
         last_p = p
diff --git a/modules/prompt_parser_diffusers.py b/modules/prompt_parser_diffusers.py
index 2edef4bf5..c74731c6d 100644
--- a/modules/prompt_parser_diffusers.py
+++ b/modules/prompt_parser_diffusers.py
@@ -16,6 +16,7 @@
 token_dict = None # used by helper get_tokens
 token_type = None # used by helper get_tokens
 cache = OrderedDict()
+last_attention = None
 embedder = None


@@ -52,7 +53,7 @@ def __init__(self, prompts, negative_prompts, steps, clip_skip, p):
         self.prompts = prompts
         self.negative_prompts = negative_prompts
         self.batchsize = len(self.prompts)
-        self.attention = None
+        self.attention = last_attention
         self.allsame = self.compare_prompts() # collapses batched prompts to single prompt if possible
         self.steps = steps
         self.clip_skip = clip_skip
@@ -78,6 +79,8 @@ def __init__(self, prompts, negative_prompts, steps, clip_skip, p):
                 self.scheduled_encode(pipe, batchidx)
             else:
                 self.encode(pipe, prompt, negative_prompt, batchidx)
+        if shared.opts.diffusers_offload_mode == "balanced":
+            pipe = sd_models.apply_balanced_offload(pipe)
         self.checkcache(p)
         debug(f"Prompt encode: time={(time.time() - t0):.3f}")

@@ -113,6 +116,7 @@ def flatten(xss):
             debug(f"Prompt cache: add={key}")
             while len(cache) > int(shared.opts.sd_textencoder_cache_size):
                 cache.popitem(last=False)
+            return True
         if item:
             self.__dict__.update(cache[key])
             cache.move_to_end(key)
@@ -161,7 +165,9 @@ def extend_embeds(self, batchidx, idx): # Extends scheduled prompt via index
             self.negative_pooleds[batchidx].append(self.negative_pooleds[batchidx][idx])

     def encode(self, pipe, positive_prompt, negative_prompt, batchidx):
+        global last_attention # pylint: disable=global-statement
         self.attention = shared.opts.prompt_attention
+        last_attention = self.attention
         if self.attention == "xhinker":
             prompt_embed, positive_pooled, negative_embed, negative_pooled = get_xhinker_text_embeddings(pipe, positive_prompt, negative_prompt, self.clip_skip)
         else:
@@ -178,7 +184,6 @@ def encode(self, pipe, positive_prompt, negative_prompt, batchidx):
         if debug_enabled:
             get_tokens(pipe, 'positive', positive_prompt)
             get_tokens(pipe, 'negative', negative_prompt)
-        pipe = prepare_model()

     def __call__(self, key, step=0):
         batch = getattr(self, key)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index ccba0bfb5..2cf7b3931 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -13,6 +13,7 @@ from rich import progress # pylint: disable=redefined-builtin
 import torch
 import safetensors.torch
+import accelerate
 from omegaconf import OmegaConf
 from ldm.util import instantiate_from_config
 from modules import paths, shared, shared_state, modelloader, devices, script_callbacks, sd_vae, sd_unet, errors, sd_models_config, sd_models_compile, sd_hijack_accelerate, sd_detect
@@ -310,6 +311,7 @@ def set_accelerate(sd_model):


 def set_diffuser_offload(sd_model, op: str = 'model'):
+    t0 = time.time()
     if not shared.native:
         shared.log.warning('Attempting to use offload with backend=original')
         return
@@ -363,41 +365,50 @@ def set_diffuser_offload(sd_model, op: str = 'model'):
             sd_model = apply_balanced_offload(sd_model)
     except Exception as e:
         shared.log.error(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} {e}')
+    process_timer.add('offload', time.time() - t0)
+
+
+class OffloadHook(accelerate.hooks.ModelHook):
+    def init_hook(self, module):
+        return module
+
+    def pre_forward(self, module, *args, **kwargs):
+        if devices.normalize_device(module.device) != devices.normalize_device(devices.device):
+            device_index = torch.device(devices.device).index
+            if device_index is None:
+                device_index = 0
+            max_memory = {
+                device_index: f"{shared.opts.diffusers_offload_max_gpu_memory}GiB",
+                "cpu": f"{shared.opts.diffusers_offload_max_cpu_memory}GiB",
+            }
+            device_map = accelerate.infer_auto_device_map(module, max_memory=max_memory)
+            module = accelerate.hooks.remove_hook_from_module(module, recurse=True)
+            offload_dir = getattr(module, "offload_dir", os.path.join(shared.opts.accelerate_offload_path, module.__class__.__name__))
+            module = accelerate.dispatch_model(module, device_map=device_map, offload_dir=offload_dir)
+            module = accelerate.hooks.add_hook_to_module(module, OffloadHook(), append=True)
+            module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access
+        return args, kwargs
+
+    def post_forward(self, module, output):
+        return output
+
+    def detach_hook(self, module):
+        return module
+
+
+offload_hook_instance = OffloadHook()


 def apply_balanced_offload(sd_model):
-    from accelerate import infer_auto_device_map, dispatch_model
-    from accelerate.hooks import add_hook_to_module, remove_hook_from_module, ModelHook
+    t0 = time.time()
     excluded = ['OmniGenPipeline']
     if sd_model.__class__.__name__ in excluded:
         return sd_model
-
-    class dispatch_from_cpu_hook(ModelHook):
-        def init_hook(self, module):
-            return module
-
-        def pre_forward(self, module, *args, **kwargs):
-            if devices.normalize_device(module.device) != devices.normalize_device(devices.device):
-                device_index = torch.device(devices.device).index
-                if device_index is None:
-                    device_index = 0
-                max_memory = {
-                    device_index: f"{shared.opts.diffusers_offload_max_gpu_memory}GiB",
-                    "cpu": f"{shared.opts.diffusers_offload_max_cpu_memory}GiB",
-                }
-                device_map = infer_auto_device_map(module, max_memory=max_memory)
-                module = remove_hook_from_module(module, recurse=True)
-                offload_dir = getattr(module, "offload_dir", os.path.join(shared.opts.accelerate_offload_path, module.__class__.__name__))
-                module = dispatch_model(module, device_map=device_map, offload_dir=offload_dir)
-                module = add_hook_to_module(module, dispatch_from_cpu_hook(), append=True)
-                module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access
-            return args, kwargs
-
-        def post_forward(self, module, output):
-            return output
-
-        def detach_hook(self, module):
-            return module
+    fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
+    debug_move(f'Apply offload: type=balanced fn={fn}')
+    checkpoint_name = sd_model.sd_checkpoint_info.name if getattr(sd_model, "sd_checkpoint_info", None) is not None else None
+    if checkpoint_name is None:
+        checkpoint_name = sd_model.__class__.__name__

     def apply_balanced_offload_to_module(pipe):
         if hasattr(pipe, "pipe"):
@@ -409,23 +420,19 @@ def apply_balanced_offload_to_module(pipe):
         for module_name in keys: # pylint: disable=protected-access
             module = getattr(pipe, module_name, None)
             if isinstance(module, torch.nn.Module):
-                checkpoint_name = pipe.sd_checkpoint_info.name if getattr(pipe, "sd_checkpoint_info", None) is not None else None
-                if checkpoint_name is None:
-                    checkpoint_name = pipe.__class__.__name__
-                offload_dir = os.path.join(shared.opts.accelerate_offload_path, checkpoint_name, module_name)
                 network_layer_name = getattr(module, "network_layer_name", None)
-                module = remove_hook_from_module(module, recurse=True)
+                module = accelerate.hooks.remove_hook_from_module(module, recurse=True)
                 try:
-                    module = module.to("cpu")
-                    module.offload_dir = offload_dir
-                    module = add_hook_to_module(module, dispatch_from_cpu_hook(), append=True)
+                    module = module.to(devices.cpu, non_blocking=True)
+                    module.offload_dir = os.path.join(shared.opts.accelerate_offload_path, checkpoint_name, module_name)
+                    # module = accelerate.hooks.add_hook_to_module(module, OffloadHook(), append=True)
+                    module = accelerate.hooks.add_hook_to_module(module, offload_hook_instance, append=True)
                     module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access
                     if network_layer_name:
                         module.network_layer_name = network_layer_name
                 except Exception as e:
                     if 'bitsandbytes' not in str(e):
                         shared.log.error(f'Balanced offload: module={module_name} {e}')
-        devices.torch_gc(fast=True)

     apply_balanced_offload_to_module(sd_model)
     if hasattr(sd_model, "pipe"):
@@ -435,6 +442,8 @@ def apply_balanced_offload_to_module(pipe):
     if hasattr(sd_model, "decoder_pipe"):
         apply_balanced_offload_to_module(sd_model.decoder_pipe)
     set_accelerate(sd_model)
+    devices.torch_gc(fast=True)
+    process_timer.add('offload', time.time() - t0)
     return sd_model
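For reference, the balanced-offload rework above replaces the per-call nested `dispatch_from_cpu_hook` class with a single module-level `OffloadHook` instance that is attached to every pipeline component. The following is a minimal sketch of that shared-hook pattern, not SD.Next code: `MoveToDeviceHook` and `attach_shared_hook` are illustrative names, and the sketch only moves weights and tensor inputs rather than computing a full `device_map`.

```python
# Minimal sketch of the shared-hook pattern behind apply_balanced_offload():
# one accelerate ModelHook instance is attached to every submodule and brings the
# module (plus its tensor inputs) onto the execution device right before forward().
import torch
import accelerate


class MoveToDeviceHook(accelerate.hooks.ModelHook):
    def __init__(self, execution_device: torch.device):
        self.execution_device = execution_device

    def init_hook(self, module):
        return module

    def pre_forward(self, module, *args, **kwargs):
        if any(p.device != self.execution_device for p in module.parameters()):
            module.to(self.execution_device)  # bring offloaded weights back on demand
        args = tuple(a.to(self.execution_device) if torch.is_tensor(a) else a for a in args)
        kwargs = {k: v.to(self.execution_device) if torch.is_tensor(v) else v for k, v in kwargs.items()}
        return args, kwargs

    def post_forward(self, module, output):
        return output

    def detach_hook(self, module):
        return module


def attach_shared_hook(modules, device: torch.device):
    hook = MoveToDeviceHook(device)  # single instance, reused across all modules
    for module in modules:
        accelerate.hooks.remove_hook_from_module(module, recurse=True)
        module.to("cpu")  # park weights in system RAM until the module is needed
        accelerate.hooks.add_hook_to_module(module, hook, append=True)
    return hook
```

Reusing one hook object mirrors the new `offload_hook_instance` in the patch: the hook is stateless per call, so there is no need to rebuild it for every module on every offload pass the way the removed nested class did.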
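The prompt_parser_diffusers changes tighten the embedding cache: `checkcache` now returns `True` after storing an entry, and the new `last_attention` global lets a fresh `PromptEmbedder` detect that the attention parser changed since the cached entry was produced. A rough standalone sketch of that bounded `OrderedDict` cache follows; `make_key`, `MAX_ENTRIES`, `cache_put`, and `cache_get` are hypothetical helper names (SD.Next sizes the real cache from `shared.opts.sd_textencoder_cache_size`).

```python
# Sketch of the bounded prompt-embedding cache pattern: OrderedDict + LRU eviction.
from collections import OrderedDict

MAX_ENTRIES = 16
cache: "OrderedDict[str, dict]" = OrderedDict()


def make_key(prompt: str, negative: str, clip_skip: int, attention: str) -> str:
    # the attention/parser mode must be part of the key, otherwise embeddings
    # computed with a different parser would be returned as stale cache hits
    return f"{attention}|{clip_skip}|{prompt}|{negative}"


def cache_put(key: str, embeds: dict) -> bool:
    cache[key] = embeds
    while len(cache) > MAX_ENTRIES:
        cache.popitem(last=False)  # drop the least recently used entry
    return True  # signal that the entry was stored


def cache_get(key: str):
    embeds = cache.get(key)
    if embeds is not None:
        cache.move_to_end(key)  # mark as most recently used
    return embeds
```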
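Finally, the LoRA log line now reports `time={networks.get_timers()}` instead of a single load duration, and `get_timers` drops phases that took 0.1s or less. A small illustrative sketch of that per-phase timing and threshold-filtered reporting, assuming hypothetical phase names and a `timed` decorator that the real module does not necessarily use:

```python
# Sketch of per-phase timing with threshold-filtered reporting, as in get_timers().
import time

timer = {'load': 0.0, 'backup': 0.0, 'calc': 0.0, 'apply': 0.0}


def timed(phase: str):
    # decorator accumulating wall-clock time of a function into one phase bucket
    def wrap(fn):
        def inner(*args, **kwargs):
            t0 = time.time()
            try:
                return fn(*args, **kwargs)
            finally:
                timer[phase] = timer.get(phase, 0.0) + (time.time() - t0)
        return inner
    return wrap


def get_timers(threshold: float = 0.1) -> dict:
    report = {'total': round(sum(timer.values()), 2)}
    report.update({k: round(v, 2) for k, v in timer.items() if v > threshold})
    return report
```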