
Fix callbacks in dyn samplers
pamparamm committed Sep 4, 2024
1 parent 8d508bb commit d451c6b
Showing 3 changed files with 33 additions and 19 deletions.
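
The change moves each sampler's callback so it fires right after the model evaluation (instead of after the dynamic/SMEA sub-step) and threads `i`, `sigma`, and `callback` into the sub-step helpers so they report their intermediate results too. A minimal sketch of a consumer, assuming only the payload keys visible in the diffs below; the body is illustrative:

# Sketch of a step callback compatible with the payload these samplers emit;
# the dict keys mirror the diffs below, the print logic is hypothetical.
def step_callback(info):
    i = info["i"]                  # outer step index (sub-steps reuse it)
    sigma = info["sigma"]          # noise level for this step
    sigma_hat = info["sigma_hat"]  # churned sigma actually fed to the model
    denoised = info["denoised"]    # model's denoised prediction
    print(f"step {i}: sigma={float(sigma):.4f} sigma_hat={float(sigma_hat):.4f}")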
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,7 +1,7 @@
 [project]
 name = "comfyui-ppm"
 description = "Fixed AttentionCouple/NegPip(negative weights in prompts), more CFG++ samplers, etc."
-version = "1.0.9"
+version = "1.0.10"
 license = { text = "GNU Affero General Public License v3" }
 
 [project.urls]
25 changes: 16 additions & 9 deletions sampling/ppm_cfgpp_dyn_sampling.py
@@ -11,7 +11,7 @@
 
 
 @torch.no_grad()
-def dy_sampling_step_cfg_pp(x, model, sigma_next, sigma_hat, **extra_args):
+def dy_sampling_step_cfg_pp(x, model, sigma_next, i, sigma, sigma_hat, callback, **extra_args):
     temp = [0]
     def post_cfg_function(args):
         temp[0] = args["uncond_denoised"]
@@ -37,6 +37,9 @@ def post_cfg_function(args):
 
     with Rescaler(model, c, 'nearest-exact', **extra_args) as rescaler:
         denoised = model(c, sigma_hat * c.new_ones([c.shape[0]]), **rescaler.extra_args)
+    if callback is not None:
+        callback({"x": c, "i": i, "sigma": sigma, "sigma_hat": sigma_hat, "denoised": denoised})
+
     d = to_d(c, sigma_hat, temp[0])
     c = denoised + d * sigma_next
 
@@ -83,19 +86,19 @@ def post_cfg_function(args):
             eps = torch.randn_like(x) * s_noise
             x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
         denoised = model(x, sigma_hat * s_in, **extra_args)
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
         d = to_d(x, sigma_hat, temp[0])
         # Euler method
         x = denoised + d * sigmas[i + 1]
         if sigmas[i + 1] > 0:
             if i // 2 == 1:
-                x = dy_sampling_step_cfg_pp(x, model, sigmas[i + 1], sigma_hat, **extra_args)
-        if callback is not None:
-            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
+                x = dy_sampling_step_cfg_pp(x, model, sigmas[i + 1], i, sigmas[i], sigma_hat, callback, **extra_args)
     return x
 
 
 @torch.no_grad()
-def smea_sampling_step_cfg_pp(x, model, sigma_next, sigma_hat, **extra_args):
+def smea_sampling_step_cfg_pp(x, model, sigma_next, i, sigma, sigma_hat, callback, **extra_args):
     temp = [0]
     def post_cfg_function(args):
         temp[0] = args["uncond_denoised"]
@@ -106,8 +109,12 @@ def post_cfg_function(args):
 
     m, n = x.shape[2], x.shape[3]
     x = torch.nn.functional.interpolate(input=x, scale_factor=(1.25, 1.25), mode='nearest-exact')
+
     with Rescaler(model, x, 'nearest-exact', **extra_args) as rescaler:
         denoised = model(x, sigma_hat * x.new_ones([x.shape[0]]), **rescaler.extra_args)
+    if callback is not None:
+        callback({"x": x, "i": i, "sigma": sigma, "sigma_hat": sigma_hat, "denoised": denoised})
+
     d = to_d(x, sigma_hat, temp[0])
     x = denoised + d * sigma_next
     x = torch.nn.functional.interpolate(input=x, size=(m,n), mode='nearest-exact')
@@ -138,14 +145,14 @@ def post_cfg_function(args):
             eps = torch.randn_like(x) * s_noise
             x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
         denoised = model(x, sigma_hat * s_in, **extra_args)
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
         d = to_d(x, sigma_hat, temp[0])
         # Euler method
         x = denoised + d * sigmas[i + 1]
         if sigmas[i + 1] > 0:
             if i + 1 // 2 == 1:
-                x = dy_sampling_step_cfg_pp(x, model, sigmas[i + 1], sigma_hat, **extra_args)
+                x = dy_sampling_step_cfg_pp(x, model, sigmas[i + 1], i, sigmas[i], sigma_hat, callback, **extra_args)
             if i + 1 // 2 == 0:
-                x = smea_sampling_step_cfg_pp(x, model, sigmas[i + 1], sigma_hat, **extra_args)
-        if callback is not None:
-            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
+                x = smea_sampling_step_cfg_pp(x, model, sigmas[i + 1], i, sigmas[i], sigma_hat, callback, **extra_args)
     return x
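
A side effect visible in this hunk: the dy/SMEA sub-steps report the same outer index `i` as the step that spawned them, so a callback can tally sub-steps as repeated indices. A sketch under that assumption only (the "i" key is taken from the diff above; the helper name is hypothetical):

# Hypothetical helper: counts sub-step callbacks by watching for a repeated
# outer step index in consecutive invocations.
def make_counting_callback():
    seen = []
    def cb(info):
        if seen and seen[-1] == info["i"]:
            print(f"sub-step inside outer step {info['i']}")
        seen.append(info["i"])
    return cb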
25 changes: 16 additions & 9 deletions sampling/ppm_dyn_sampling.py
@@ -33,7 +33,7 @@ def __exit__(self, type, value, traceback):
 
 
 @torch.no_grad()
-def dy_sampling_step(x, model, dt, sigma_hat, **extra_args):
+def dy_sampling_step(x, model, dt, i, sigma, sigma_hat, callback, **extra_args):
     original_shape = x.shape
     batch_size, channels, m, n = original_shape[0], original_shape[1], original_shape[2] // 2, original_shape[3] // 2
     extra_row = x.shape[2] % 2 == 1
@@ -51,6 +51,9 @@ def dy_sampling_step(x, model, dt, sigma_hat, **extra_args):
 
     with Rescaler(model, c, "nearest-exact", **extra_args) as rescaler:
         denoised = model(c, sigma_hat * c.new_ones([c.shape[0]]), **rescaler.extra_args)
+    if callback is not None:
+        callback({"x": c, "i": i, "sigma": sigma, "sigma_hat": sigma_hat, "denoised": denoised})
+
     d = to_d(c, sigma_hat, denoised)
     c = c + d * dt
 
@@ -90,23 +93,27 @@ def sample_euler_dy(
             eps = torch.randn_like(x) * s_noise
             x = x - eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
         denoised = model(x, sigma_hat * s_in, **extra_args)
+        if callback is not None:
+            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised})
         d = to_d(x, sigma_hat, denoised)
         # Euler method
         x = x + d * dt
         if sigmas[i + 1] > 0:
             if i // 2 == 1:
-                x = dy_sampling_step(x, model, dt, sigma_hat, **extra_args)
-        if callback is not None:
-            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised})
+                x = dy_sampling_step(x, model, dt, i, sigmas[i], sigma_hat, callback, **extra_args)
     return x
 
 
 @torch.no_grad()
-def smea_sampling_step(x, model, dt, sigma_hat, **extra_args):
+def smea_sampling_step(x, model, dt, i, sigma, sigma_hat, callback, **extra_args):
     m, n = x.shape[2], x.shape[3]
     x = torch.nn.functional.interpolate(input=x, scale_factor=(1.25, 1.25), mode="nearest-exact")
+
     with Rescaler(model, x, "nearest-exact", **extra_args) as rescaler:
         denoised = model(x, sigma_hat * x.new_ones([x.shape[0]]), **rescaler.extra_args)
+    if callback is not None:
+        callback({"x": x, "i": i, "sigma": sigma, "sigma_hat": sigma_hat, "denoised": denoised})
+
     d = to_d(x, sigma_hat, denoised)
     x = x + d * dt
     x = torch.nn.functional.interpolate(input=x, size=(m, n), mode="nearest-exact")
Expand All @@ -130,14 +137,14 @@ def sample_euler_smea_dy(
eps = torch.randn_like(x) * s_noise
x = x - eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
if callback is not None:
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised})
d = to_d(x, sigma_hat, denoised)
# Euler method
x = x + d * dt
if sigmas[i + 1] > 0:
if i + 1 // 2 == 1:
x = dy_sampling_step(x, model, dt, sigma_hat, **extra_args)
x = dy_sampling_step(x, model, dt, i, sigmas[i], sigma_hat, callback, **extra_args)
if i + 1 // 2 == 0:
x = smea_sampling_step(x, model, dt, sigma_hat, **extra_args)
if callback is not None:
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised})
x = smea_sampling_step(x, model, dt, i, sigmas[i], sigma_hat, callback, **extra_args)
return x

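For reference, a hedged usage sketch. The outer samplers' full signatures are not shown in this diff, so the k-diffusion-style (model, x, sigmas, extra_args=None, callback=None, ...) calling convention below is an assumption; only the module and function names are taken from the hunks above, and importing this way assumes you run from the repository root.

# Hypothetical usage; assumes sample_euler_dy follows the usual k-diffusion
# sampler signature (model, x, sigmas, extra_args=None, callback=None, ...).
import torch
from sampling.ppm_dyn_sampling import sample_euler_dy

def preview(info):
    # After this fix, fires for every model call, including dy sub-steps.
    print(f"i={info['i']} sigma_hat={float(info['sigma_hat']):.4f}")

# model and sigmas come from the surrounding ComfyUI/k-diffusion setup:
# x = sample_euler_dy(model, torch.randn(1, 4, 64, 64), sigmas, callback=preview)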