Skip to content

Commit

Permalink
Add dpmpp_3m_sde CFG++
Browse files Browse the repository at this point in the history
  • Loading branch information
pamparamm committed Jul 13, 2024
1 parent b49dafb commit 32ca983
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 4 deletions.
72 changes: 72 additions & 0 deletions ppm_cfgpp_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,78 @@ def post_cfg_function(args):
return x


@torch.no_grad()
def sample_dpmpp_3m_sde_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""DPM-Solver++(3M) SDE."""

if len(sigmas) <= 1:
return x

seed = extra_args.get("seed", None)
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])

denoised_1, denoised_2 = None, None
h, h_1, h_2 = None, None, None

temp = [0]
def post_cfg_function(args):
temp[0] = args["uncond_denoised"]
return args["denoised"]

model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
h = s - t
h_eta = h * (eta + 1)

x = torch.exp(-h_eta) * (x + (denoised - temp[0])) + (-h_eta).expm1().neg() * denoised

if h_2 is not None:
r0 = h_1 / h
r1 = h_2 / h
d1_0 = (denoised - denoised_1) / r0
d1_1 = (denoised_1 - denoised_2) / r1
d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1)
d2 = (d1_0 - d1_1) / (r0 + r1)
phi_2 = h_eta.neg().expm1() / h_eta + 1
phi_3 = phi_2 / h_eta - 0.5
x = x + phi_2 * d1 - phi_3 * d2
elif h_1 is not None:
r = h_1 / h
d = (denoised - denoised_1) / r
phi_2 = h_eta.neg().expm1() / h_eta + 1
x = x + phi_2 * d

if eta:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise

denoised_1, denoised_2 = denoised, denoised_1
h_1, h_2 = h, h_1
return x


@torch.no_grad()
def sample_dpmpp_3m_sde_gpu_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if len(sigmas) <= 1:
return x

sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_3m_sde_cfg_pp(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)


@torch.no_grad()
def sample_dpmpp_2m_sde_gpu_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
if len(sigmas) <= 1:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "comfyui-ppm"
description = "Fixed AttentionCouple/NegPip(negative weights in prompts), more CFG++ samplers, etc."
version = "1.0.0"
version = "1.0.1"
license = "AGPL-3.0"

[project.urls]
Expand Down
14 changes: 11 additions & 3 deletions samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@

INITIALIZED = False
CFGPP_SAMPLER_NAMES_ORIGINAL = ["euler_cfg_pp", "euler_ancestral_cfg_pp"]
CFGPP_SAMPLER_NAMES = CFGPP_SAMPLER_NAMES_ORIGINAL + ["dpmpp_2m_cfg_pp", "dpmpp_2m_sde_cfg_pp", "dpmpp_2m_sde_gpu_cfg_pp"]
CFGPP_SAMPLER_NAMES = CFGPP_SAMPLER_NAMES_ORIGINAL + [
"dpmpp_2m_cfg_pp",
"dpmpp_2m_sde_cfg_pp",
"dpmpp_2m_sde_gpu_cfg_pp",
"dpmpp_3m_sde_cfg_pp",
"dpmpp_3m_sde_gpu_cfg_pp",
]


def inject_samplers():
Expand All @@ -20,6 +26,7 @@ def INPUT_TYPES(s):
return {
"required": {
"sampler_name": (CFGPP_SAMPLER_NAMES,),
"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False}),
}
}

Expand All @@ -28,10 +35,11 @@ def INPUT_TYPES(s):

FUNCTION = "get_sampler"

def get_sampler(self, sampler_name):
def get_sampler(self, sampler_name, eta: float):
if sampler_name in CFGPP_SAMPLER_NAMES_ORIGINAL:
sampler_func = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
else:
sampler_func = getattr(ppm_cfgpp_sampling, "sample_{}".format(sampler_name))
sampler = KSAMPLER(sampler_func)
extra_options = {} if sampler_name in {"euler_cfg_pp", "dpmpp_2m_cfg_pp"} else {"eta": eta}
sampler = KSAMPLER(sampler_func, extra_options=extra_options)
return (sampler,)

0 comments on commit 32ca983

Please sign in to comment.