Commit: lora refactor in progress

Signed-off-by: Vladimir Mandic <[email protected]>
vladmandic committed Nov 30, 2024
1 parent 6ec93f2 commit eee85e5
Showing 7 changed files with 109 additions and 95 deletions.
5 changes: 3 additions & 2 deletions CHANGELOG.md
@@ -1,6 +1,6 @@
 # Change Log for SD.Next

-## Update for 2024-11-28
+## Update for 2024-11-30

 ### New models and integrations

@@ -67,7 +67,8 @@
 - fix xyz-grid with lora
 - fix api script callbacks
 - fix gpu memory monitoring
-- simplify img2img/inpaint/sketch canvas handling
+- simplify img2img/inpaint/sketch canvas handling
+- fix prompt caching

 ## Update for 2024-11-21
9 changes: 4 additions & 5 deletions modules/lora/extra_networks_lora.py
@@ -113,22 +113,21 @@ def __init__(self):
         self.errors = {}

     def activate(self, p, params_list, step=0):
-        t0 = time.time()
         self.errors.clear()
         if self.active:
             if self.model != shared.opts.sd_model_checkpoint: # reset if model changed
                 self.active = False
         if len(params_list) > 0 and not self.active: # activate patches once
-            shared.log.debug(f'Activate network: type=LoRA model="{shared.opts.sd_model_checkpoint}"')
+            # shared.log.debug(f'Activate network: type=LoRA model="{shared.opts.sd_model_checkpoint}"')
             self.active = True
             self.model = shared.opts.sd_model_checkpoint
         names, te_multipliers, unet_multipliers, dyn_dims = parse(p, params_list, step)
-        networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims)
-        t1 = time.time()
+        networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims) # load
+        networks.network_load() # backup/apply
         if len(networks.loaded_networks) > 0 and step == 0:
             infotext(p)
             prompt(p)
-        shared.log.info(f'Load network: type=LoRA apply={[n.name for n in networks.loaded_networks]} te={te_multipliers} unet={unet_multipliers} dims={dyn_dims} load={t1-t0:.2f}')
+        shared.log.info(f'Load network: type=LoRA apply={[n.name for n in networks.loaded_networks]} te={te_multipliers} unet={unet_multipliers} time={networks.get_timers()}')

     def deactivate(self, p):
         t0 = time.time()
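Note: the activate() change above splits LoRA activation into two phases — networks.load_networks() resolves and loads the requested networks, then networks.network_load() backs up and patches module weights — with per-phase timing reported via networks.get_timers() instead of a local t0/t1 pair. A minimal standalone sketch of the activate-once guard, with illustrative class and argument names (not the module's actual API):

class LoraActivator:
    # toy model of activate(): patch once per checkpoint, reset when the checkpoint changes
    def __init__(self):
        self.active = False
        self.model = None

    def activate(self, checkpoint, names):
        if self.active and self.model != checkpoint:  # reset if model changed
            self.active = False
        if len(names) > 0 and not self.active:  # activate patches once
            self.active = True
            self.model = checkpoint

activator = LoraActivator()
activator.activate('sd15.safetensors', ['detail-lora'])
activator.activate('sdxl.safetensors', ['detail-lora'])  # checkpoint changed: guard resets, then re-activates
assert activator.model == 'sdxl.safetensors'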
73 changes: 36 additions & 37 deletions modules/lora/networks.py
@@ -54,7 +54,8 @@ def total_time():
 def get_timers():
     t = { 'total': round(sum(timer.values()), 2) }
     for k, v in timer.items():
-        t[k] = round(v, 2)
+        if v > 0.1:
+            t[k] = round(v, 2)
     return t


@@ -216,6 +217,7 @@ def maybe_recompile_model(names, te_multipliers):


 def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
+    global backup_size # pylint: disable=global-statement
     networks_on_disk: list[network.NetworkOnDisk] = [available_network_aliases.get(name, None) for name in names]
     if any(x is None for x in networks_on_disk):
         list_available_networks()
@@ -304,10 +306,9 @@ def set_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm
     with devices.inference_context():
         if weights_backup is not None:
             if updown is not None:
-                if len(weights_backup.shape) == 4 and weights_backup.shape[1] == 9:
-                    # inpainting model. zero pad updown to make channel[1] 4 to 9
+                if len(weights_backup.shape) == 4 and weights_backup.shape[1] == 9: # inpainting model. zero pad updown to make channel[1] 4 to 9
                     updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) # pylint: disable=not-callable
-                weights_backup = weights_backup.clone().to(device)
+                weights_backup = weights_backup.clone().to(self.weight.device)
                 weights_backup += updown.to(weights_backup)
             if getattr(self, "quant_type", None) in ['nf4', 'fp4']:
                 bnb = model_quant.load_bnb('Load network: type=LoRA', silent=True)
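Note: restoring the backup to self.weight.device rather than a device captured earlier means the restore follows wherever offloading has moved the layer in the meantime. A small pure-PyTorch illustration of the pattern (no SD.Next imports; names are illustrative):

import torch

layer = torch.nn.Linear(4, 4)
weights_backup = layer.weight.detach().clone()  # backup taken at load time

# offloading may move the layer between devices after the backup is taken;
# targeting the layer's current device always stays consistent with it
restored = weights_backup.clone().to(layer.weight.device)
layer.weight.data.copy_(restored)
assert layer.weight.device == restored.device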
@@ -375,18 +376,18 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
     network_layer_name = getattr(self, 'network_layer_name', None)
     current_names = getattr(self, "network_current_names", ())
     wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
-    if network_layer_name is not None and any([net.modules.get(network_layer_name, None) for net in loaded_networks]): # noqa: C419
-        maybe_backup_weights(self, wanted_names)
-    if current_names != wanted_names:
-        batch_updown = None
-        batch_ex_bias = None
-        t0 = time.time()
-        for net in loaded_networks:
-            # default workflow where module is known and has weights
-            module = net.modules.get(network_layer_name, None)
-            if module is not None and hasattr(self, 'weight'):
-                try:
-                    with devices.inference_context():
+    with devices.inference_context():
+        if network_layer_name is not None and any([net.modules.get(network_layer_name, None) for net in loaded_networks]): # noqa: C419
+            maybe_backup_weights(self, wanted_names)
+        if current_names != wanted_names:
+            batch_updown = None
+            batch_ex_bias = None
+            t0 = time.time()
+            for net in loaded_networks:
+                # default workflow where module is known and has weights
+                module = net.modules.get(network_layer_name, None)
+                if module is not None and hasattr(self, 'weight'):
+                    try:
                         weight = self.weight.to(devices.device) # calculate quant weights once
                         updown, ex_bias = module.calc_updown(weight)
                         if batch_updown is not None and updown is not None:
@@ -402,30 +403,30 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
                             batch_updown = batch_updown.to(devices.cpu)
                         if batch_ex_bias is not None:
                             batch_ex_bias = batch_ex_bias.to(devices.cpu)
-                except RuntimeError as e:
-                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
-                    if debug:
-                        module_name = net.modules.get(network_layer_name, None)
-                        shared.log.error(f'LoRA apply weight name="{net.name}" module="{module_name}" layer="{network_layer_name}" {e}')
-                        errors.display(e, 'LoRA')
-                        raise RuntimeError('LoRA apply weight') from e
-                    continue
-            if module is None:
-                continue
-            shared.log.warning(f'LoRA network="{net.name}" layer="{network_layer_name}" unsupported operation')
-            extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
-        t1 = time.time()
-        timer['calc'] += t1 - t0
-        set_weights(self, batch_updown, batch_ex_bias) # Set or restore weights from backup
-        self.network_current_names = wanted_names
+                    except RuntimeError as e:
+                        extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
+                        if debug:
+                            module_name = net.modules.get(network_layer_name, None)
+                            shared.log.error(f'LoRA apply weight name="{net.name}" module="{module_name}" layer="{network_layer_name}" {e}')
+                            errors.display(e, 'LoRA')
+                            raise RuntimeError('LoRA apply weight') from e
+                        continue
+                if module is None:
+                    continue
+                shared.log.warning(f'LoRA network="{net.name}" layer="{network_layer_name}" unsupported operation')
+                extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
+            t1 = time.time()
+            timer['calc'] += t1 - t0
+            set_weights(self, batch_updown, batch_ex_bias) # Set or restore weights from backup
+            self.network_current_names = wanted_names


 def network_load(): # called from processing
     timer['backup'] = 0
     timer['calc'] = 0
     timer['apply'] = 0
     sd_model = getattr(shared.sd_model, "pipe", shared.sd_model) # wrapped model compatiblility
-    if shared.opts.diffusers_offload_mode != "none":
+    if shared.opts.diffusers_offload_mode == "sequential":
         sd_models.disable_offload(sd_model)
         sd_models.move_model(sd_model, device=devices.cpu)
     modules = []
@@ -441,11 +442,9 @@ def network_load(): # called from processing
             pbar.remove_task(task)
     modules.clear()
     if debug:
-        shared.log.debug(f'Load network: type=LoRA modules={len(modules)}')
-    if shared.opts.diffusers_offload_mode != "none":
+        shared.log.debug(f'Load network: type=LoRA modules={len(modules)} backup={backup_size} time={get_timers()}')
+    if shared.opts.diffusers_offload_mode == "sequential":
         sd_models.set_diffuser_offload(sd_model, op="model")
-    if debug:
-        shared.log.debug(f'Load network: type=LoRA time={get_timers()} backup={backup_size}')


 def list_available_networks():
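Note: get_timers() now reports only phases that took longer than 0.1s, keeping log lines short while the total stays exact. A self-contained sketch of the filtering, with made-up timer values:

timer = {'load': 0.42, 'backup': 0.02, 'calc': 1.337, 'apply': 0.003}

def get_timers():
    # the total includes every phase; individual phases appear only above the 0.1s noise floor
    t = {'total': round(sum(timer.values()), 2)}
    for k, v in timer.items():
        if v > 0.1:
            t[k] = round(v, 2)
    return t

print(get_timers())  # {'total': 1.78, 'load': 0.42, 'calc': 1.34}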
2 changes: 0 additions & 2 deletions modules/processing_callbacks.py
@@ -4,7 +4,6 @@
 import torch
 import numpy as np
 from modules import shared, processing_correction, extra_networks, timer, prompt_parser_diffusers
-from modules.lora.networks import network_load


 p = None
@@ -69,7 +68,6 @@ def diffusers_callback(pipe, step: int = 0, timestep: int = 0, kwargs: dict = {}
         time.sleep(0.1)
     if hasattr(p, "stepwise_lora") and shared.native:
         extra_networks.activate(p, p.extra_network_data, step=step)
-        network_load()
     if latents is None:
         return kwargs
     elif shared.opts.nan_skip:
21 changes: 12 additions & 9 deletions modules/processing_diffusers.py
@@ -199,11 +199,6 @@ def process_hires(p: processing.StableDiffusionProcessing, output):
     if hasattr(shared.sd_model, "vae") and output.images is not None and len(output.images) > 0:
         output.images = processing_vae.vae_decode(latents=output.images, model=shared.sd_model, full_quality=p.full_quality, output_type='pil', width=p.hr_upscale_to_x, height=p.hr_upscale_to_y) # controlnet cannnot deal with latent input
     p.task_args['image'] = output.images # replace so hires uses new output
-    sd_models.move_model(shared.sd_model, devices.device)
-    if hasattr(shared.sd_model, 'unet'):
-        sd_models.move_model(shared.sd_model.unet, devices.device)
-    if hasattr(shared.sd_model, 'transformer'):
-        sd_models.move_model(shared.sd_model.transformer, devices.device)
     update_sampler(p, shared.sd_model, second_pass=True)
     orig_denoise = p.denoising_strength
     p.denoising_strength = strength
@@ -227,6 +222,11 @@ def process_hires(p: processing.StableDiffusionProcessing, output):
     shared.state.job = 'HiRes'
     shared.state.sampling_steps = hires_args.get('prior_num_inference_steps', None) or p.steps or hires_args.get('num_inference_steps', None)
     try:
+        sd_models.move_model(shared.sd_model, devices.device)
+        if hasattr(shared.sd_model, 'unet'):
+            sd_models.move_model(shared.sd_model.unet, devices.device)
+        if hasattr(shared.sd_model, 'transformer'):
+            sd_models.move_model(shared.sd_model.transformer, devices.device)
         sd_models_compile.check_deepcache(enable=True)
         output = shared.sd_model(**hires_args) # pylint: disable=not-callable
         if isinstance(output, dict):
@@ -405,6 +405,9 @@ def process_diffusers(p: processing.StableDiffusionProcessing):
         shared.sd_model = orig_pipeline
         return results

+    if shared.opts.diffusers_offload_mode == "balanced":
+        shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
+
     # sanitize init_images
     if hasattr(p, 'init_images') and getattr(p, 'init_images', None) is None:
         del p.init_images
@@ -427,10 +430,6 @@ def process_diffusers(p: processing.StableDiffusionProcessing):
     if p.negative_prompts is None or len(p.negative_prompts) == 0:
         p.negative_prompts = p.all_negative_prompts[p.iteration * p.batch_size:(p.iteration+1) * p.batch_size]

-    # load loras
-    networks.network_load()
-
-    sd_models.move_model(shared.sd_model, devices.device)
     sd_models_compile.openvino_recompile_model(p, hires=False, refiner=False) # recompile if a parameter changes

     if 'base' not in p.skip:
@@ -461,6 +460,10 @@ def process_diffusers(p: processing.StableDiffusionProcessing):
         timer.process.add('lora', networks.total_time())

     shared.sd_model = orig_pipeline
+
+    if shared.opts.diffusers_offload_mode == "balanced":
+        shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
+
     if p.state == '':
         global last_p # pylint: disable=global-statement
         last_p = p
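Note: process_diffusers() no longer calls networks.network_load() directly; instead, when offload mode is balanced it re-applies offload both before the run and after the original pipeline is restored. A rough standalone sketch of that bracketing pattern, with apply_balanced_offload stubbed in place of sd_models.apply_balanced_offload:

def apply_balanced_offload(model):
    # stub: the real helper redistributes pipeline components between GPU and CPU and returns the wrapped model
    return model

def run_pipeline(model, offload_mode, **args):
    if offload_mode == 'balanced':
        model = apply_balanced_offload(model)  # rebalance before the run
    try:
        return model(**args)
    finally:
        if offload_mode == 'balanced':
            apply_balanced_offload(model)  # rebalance again once the run finishes

result = run_pipeline(lambda **kw: kw, 'balanced', prompt='test', steps=20)
print(result)  # {'prompt': 'test', 'steps': 20}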
9 changes: 7 additions & 2 deletions modules/prompt_parser_diffusers.py
@@ -16,6 +16,7 @@
 token_dict = None # used by helper get_tokens
 token_type = None # used by helper get_tokens
 cache = OrderedDict()
+last_attention = None
 embedder = None


@@ -52,7 +53,7 @@ def __init__(self, prompts, negative_prompts, steps, clip_skip, p):
         self.prompts = prompts
         self.negative_prompts = negative_prompts
         self.batchsize = len(self.prompts)
-        self.attention = None
+        self.attention = last_attention
         self.allsame = self.compare_prompts() # collapses batched prompts to single prompt if possible
         self.steps = steps
         self.clip_skip = clip_skip
@@ -78,6 +79,8 @@ def __init__(self, prompts, negative_prompts, steps, clip_skip, p):
                     self.scheduled_encode(pipe, batchidx)
                 else:
                     self.encode(pipe, prompt, negative_prompt, batchidx)
+        if shared.opts.diffusers_offload_mode == "balanced":
+            pipe = sd_models.apply_balanced_offload(pipe)
         self.checkcache(p)
         debug(f"Prompt encode: time={(time.time() - t0):.3f}")
@@ -113,6 +116,7 @@ def flatten(xss):
debug(f"Prompt cache: add={key}")
while len(cache) > int(shared.opts.sd_textencoder_cache_size):
cache.popitem(last=False)
return True
if item:
self.__dict__.update(cache[key])
cache.move_to_end(key)
@@ -161,7 +165,9 @@ def extend_embeds(self, batchidx, idx): # Extends scheduled prompt via index
             self.negative_pooleds[batchidx].append(self.negative_pooleds[batchidx][idx])

     def encode(self, pipe, positive_prompt, negative_prompt, batchidx):
+        global last_attention # pylint: disable=global-statement
         self.attention = shared.opts.prompt_attention
+        last_attention = self.attention
         if self.attention == "xhinker":
             prompt_embed, positive_pooled, negative_embed, negative_pooled = get_xhinker_text_embeddings(pipe, positive_prompt, negative_prompt, self.clip_skip)
         else:
Expand All @@ -178,7 +184,6 @@ def encode(self, pipe, positive_prompt, negative_prompt, batchidx):
         if debug_enabled:
             get_tokens(pipe, 'positive', positive_prompt)
             get_tokens(pipe, 'negative', negative_prompt)
-        pipe = prepare_model()

     def __call__(self, key, step=0):
         batch = getattr(self, key)
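Note: two small caching fixes land here — checkcache() now returns True after storing an entry, so callers can tell a store from a miss, and the module-level last_attention remembers which attention parser was last used, so a fresh __init__ seeds self.attention from it rather than None (part of the "fix prompt caching" entry in the changelog). A toy version of the OrderedDict LRU behavior (CACHE_SIZE and the helper signature are illustrative):

from collections import OrderedDict

cache = OrderedDict()
CACHE_SIZE = 4  # stand-in for shared.opts.sd_textencoder_cache_size

def checkcache(key, item=None):
    if item is not None:
        cache[key] = item
        while len(cache) > CACHE_SIZE:
            cache.popitem(last=False)  # evict the least recently used entry
        return True
    if key in cache:
        cache.move_to_end(key)  # mark as most recently used
        return cache[key]
    return None

checkcache(('prompt', 'fixed'), {'embeds': [0.1, 0.2]})
print(checkcache(('prompt', 'fixed')))  # {'embeds': [0.1, 0.2]}
print(checkcache(('prompt', 'other')))  # None: not cached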