Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Includes 30 custom nodes committed directly, 7 Civitai-exclusive loras stored via Git LFS, and a setup script that installs all dependencies and downloads HuggingFace-hosted models on vast.ai. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
242 lines
9.8 KiB
Python
242 lines
9.8 KiB
Python
import math
|
|
from itertools import groupby
|
|
from typing import Any, Callable, Literal
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def parse_unet_blocks(model, unet_block_list: str, attn: Literal["attn1", "attn2"] | None):
|
|
output: list[tuple[str, int, int | None]] = []
|
|
names: list[str] = []
|
|
|
|
# Get all Self-attention blocks
|
|
input_blocks: list[tuple[int, str]] = []
|
|
middle_blocks: list[tuple[int, str]] = []
|
|
output_blocks: list[tuple[int, str]] = []
|
|
for name, module in model.model.diffusion_model.named_modules():
|
|
if module.__class__.__name__ == "BasicTransformerBlock" and (attn is None or hasattr(module, attn)):
|
|
parts = name.split(".")
|
|
unet_part = parts[0]
|
|
block_id = int(parts[1])
|
|
if unet_part.startswith("input"):
|
|
input_blocks.append((block_id, name))
|
|
elif unet_part.startswith("middle"):
|
|
middle_blocks.append((block_id - 1, name))
|
|
elif unet_part.startswith("output"):
|
|
output_blocks.append((block_id, name))
|
|
|
|
def group_blocks(blocks: list[tuple[int, str]]):
|
|
grouped_blocks = [(i, list(gr)) for i, gr in groupby(blocks, lambda b: b[0])]
|
|
return [(i, len(gr), list(idx[1] for idx in gr)) for i, gr in grouped_blocks]
|
|
|
|
input_blocks_gr, middle_blocks_gr, output_blocks_gr = (
|
|
group_blocks(input_blocks),
|
|
group_blocks(middle_blocks),
|
|
group_blocks(output_blocks),
|
|
)
|
|
|
|
user_inputs = [b.strip() for b in unet_block_list.split(",")]
|
|
for user_input in user_inputs:
|
|
unet_part_s, indices = user_input[0], user_input[1:].split(".")
|
|
match unet_part_s:
|
|
case "d":
|
|
unet_part, unet_group = "input", input_blocks_gr
|
|
case "m":
|
|
unet_part, unet_group = "middle", middle_blocks_gr
|
|
case "u":
|
|
unet_part, unet_group = "output", output_blocks_gr
|
|
case _:
|
|
raise ValueError(f"Block {user_input}: Unknown block prefix {unet_part_s}")
|
|
|
|
block_index_range = [int(b.strip()) for b in indices[0].split("-")]
|
|
block_index_range_start = block_index_range[0]
|
|
block_index_range_end = block_index_range[0] if len(block_index_range) != 2 else block_index_range[1]
|
|
for block_index in range(block_index_range_start, block_index_range_end + 1):
|
|
if block_index < 0 or block_index >= len(unet_group):
|
|
raise ValueError(
|
|
f"Block {user_input}: Block index in out of range 0 <= {block_index} < {len(unet_group)}"
|
|
)
|
|
|
|
block_group = unet_group[block_index]
|
|
block_index_real = block_group[0]
|
|
|
|
if len(indices) == 1:
|
|
output.append((unet_part, block_index_real, None))
|
|
names.extend(block_group[2])
|
|
else:
|
|
transformer_index_range = [int(b.strip()) for b in indices[1].split("-")]
|
|
transformer_index_range_start = transformer_index_range[0]
|
|
transformer_index_range_end = (
|
|
transformer_index_range[0] if len(transformer_index_range) != 2 else transformer_index_range[1]
|
|
)
|
|
for transformer_index in range(transformer_index_range_start, transformer_index_range_end + 1):
|
|
if transformer_index is not None and (transformer_index < 0 or transformer_index >= block_group[1]):
|
|
raise ValueError(
|
|
f"Block {user_input}: Transformer index in out of range 0 <= {transformer_index} < {block_group[1]}"
|
|
)
|
|
|
|
output.append((unet_part, block_index_real, transformer_index))
|
|
names.append(block_group[2][transformer_index])
|
|
|
|
return output, names
|
|
|
|
|
|
# Copied from https://github.com/comfyanonymous/ComfyUI/blob/719fb2c81d716ce8edd7f1bdc7804ae160a71d3a/comfy/model_patcher.py#L21 for backward compatibility
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
    """Register ``patch`` in ``model_options["transformer_options"]["patches_replace"][name]``.

    The patch is keyed by ``(block_name, number)``, or by
    ``(block_name, number, transformer_index)`` when ``transformer_index`` is given.
    The ``transformer_options`` dict, its ``patches_replace`` dict and the
    per-``name`` dict are shallow-copied before being written. Dicts that other
    holders still reference are therefore left untouched. ``model_options`` itself
    is updated in place and returned.
    """
    options = model_options["transformer_options"].copy()

    replacements = options["patches_replace"].copy() if "patches_replace" in options else {}
    named = replacements[name].copy() if name in replacements else {}

    key = (block_name, number) if transformer_index is None else (block_name, number, transformer_index)
    named[key] = patch

    replacements[name] = named
    options["patches_replace"] = replacements
    model_options["transformer_options"] = options
    return model_options
|
|
|
|
|
|
def set_model_options_value(model_options, key: str, value: Any):
    """Set ``transformer_options[key] = value`` on a shallow copy and store it back.

    The previous ``transformer_options`` dict is not mutated. ``model_options`` is
    updated in place and returned.
    """
    patched = model_options["transformer_options"].copy()
    patched.update({key: value})
    model_options["transformer_options"] = patched
    return model_options
|
|
|
|
|
|
def perturbed_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, extra_options, mask=None):
    """Perturbed self-attention: return the values unchanged.

    Attending with an identity attention map gives ``I @ v == v``, so queries and
    keys are not needed. ``q``, ``k``, ``extra_options`` and ``mask`` are accepted
    but unused.
    """
    del q, k, extra_options, mask  # intentionally unused
    return v
|
|
|
|
|
|
# Modified 'Algorithm 2 Classifier-Free Guidance with Rescale' from Common Diffusion Noise Schedules and Sample Steps are Flawed (Lin et al.).
def rescale_guidance(
    guidance: torch.Tensor, cond_pred: torch.Tensor, cfg_result: torch.Tensor, rescale=0.0, rescale_mode="full"
):
    """Scale ``guidance`` toward matching the std of ``cond_pred``.

    Args:
        guidance: Guidance term to scale.
        cond_pred: Prediction whose per-sample std is the target.
        cfg_result: Base added to ``guidance`` when ``rescale_mode == "full"``.
        rescale: Blend weight. ``0.0`` returns ``guidance`` unchanged (same object),
            and ``1.0`` applies the full ``std(cond_pred) / std(base + guidance)`` factor.
        rescale_mode: ``"full"`` measures ``cfg_result + guidance``. Any other value
            measures ``cond_pred + guidance``.

    Returns:
        ``guidance * factor``, where the factor is computed per sample (dim 0) over all
        remaining dimensions.
    """
    if rescale == 0.0:
        return guidance

    base = cfg_result if rescale_mode == "full" else cond_pred
    guidance_result = base + guidance

    # Reduce over every non-batch dim. For 4-D inputs this is (1, 2, 3), as before.
    # Higher-rank inputs now get per-sample statistics; previously they got
    # per-trailing-column statistics, and lower-rank inputs raised.
    dims = tuple(range(1, cond_pred.ndim))
    std_cond = torch.std(cond_pred, dim=dims, keepdim=True)
    std_guidance = torch.std(guidance_result, dim=dims, keepdim=True)

    factor = std_cond / std_guidance
    factor = rescale * factor + (1.0 - rescale)
    return guidance * factor
|
|
|
|
|
|
# Gaussian blur
def gaussian_blur_2d(img, kernel_size, sigma):
    """Blur each channel of ``img`` over its last two dims with a square Gaussian kernel.

    Args:
        img: Tensor whose dims ``-3, -2, -1`` are read as (channels, height, width).
            The convolution uses ``groups=channels``, so channels are blurred
            independently. Presumably 4-D (batch, channels, height, width) — confirm
            with callers.
        kernel_size: Requested kernel side length. It is clamped to the smaller
            spatial side (rounded up to odd). This keeps the reflect padding
            (``kernel_size // 2``) below both spatial sizes, as ``F.pad`` requires.
            Odd sizes preserve the spatial shape; an even size grows each spatial
            dim by one.
        sigma: Gaussian standard deviation in pixels. ``0`` is the zero-width limit
            (identity), and an unblurred copy is returned.

    Returns:
        The blurred tensor, with the dtype and device of ``img``.
    """
    if sigma == 0:
        # A zero-width Gaussian is a delta; evaluating it would compute 0 / 0 = NaN.
        return img.clone()

    # Clamp against the smaller of the two spatial sides, not only the last one.
    # Otherwise a wide input (small dim -2) could need reflect padding larger than that axis.
    min_side = min(img.shape[-2], img.shape[-1])
    max_kernel = min_side if min_side % 2 == 1 else min_side + 1
    kernel_size = min(kernel_size, max_kernel)
    ksize_half = (kernel_size - 1) * 0.5

    # 1-D normalized Gaussian sampled at integer offsets around the center.
    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    x_kernel = (pdf / pdf.sum()).to(device=img.device, dtype=img.dtype)

    # Outer product -> 2-D kernel, one copy per channel for the grouped convolution.
    kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
    kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])

    padding = [kernel_size // 2] * 4
    img = F.pad(img, padding, mode="reflect")
    return F.conv2d(img, kernel2d, groups=img.shape[-3])
|
|
|
|
|
|
def seg_attention_wrapper(attention, blur_sigma=1.0):
    """Build a Smoothed Energy Guidance (SEG) self-attention replacement.

    Args:
        attention: Callable invoked as ``attention(q, k, v, heads=n_heads)``.
        blur_sigma: Sigma of the Gaussian blur applied to the query laid out as a 2-D
            grid. A negative value means infinite blur: every position gets the
            per-channel mean query. ``0`` leaves the query unchanged.

    Returns:
        ``seg_attention(q, k, v, extra_options, mask=None)``. It reads
        ``extra_options["n_heads"]`` and ``extra_options["original_shape"][2:4]``.
        ``mask`` is accepted but not forwarded to ``attention``.
    """

    def seg_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, extra_options, mask=None):
        """Smoothed Energy Guidance self-attention"""
        heads = extra_options["n_heads"]
        bs, area, inner_dim = q.shape

        # Recover the token grid from the original aspect ratio. The shorter side is
        # computed explicitly and reshape infers the longer one.
        height_orig, width_orig = extra_options["original_shape"][2:4]
        aspect_ratio = width_orig / height_orig

        if aspect_ratio >= 1.0:
            height = round((area / aspect_ratio) ** 0.5)
            q = q.permute(0, 2, 1).reshape(bs, inner_dim, height, -1)
        else:
            width = round((area * aspect_ratio) ** 0.5)
            q = q.permute(0, 2, 1).reshape(bs, inner_dim, -1, width)

        if blur_sigma < 0:
            # Infinite blur. Built out of place: the reshape above can be a view of
            # the caller's q, and an in-place write would corrupt it.
            q = q.mean(dim=(-2, -1), keepdim=True).expand_as(q).contiguous()
        elif blur_sigma > 0:
            # Smallest odd size >= ceil(6 * sigma), i.e. roughly +/- 3 sigma.
            kernel_size = math.ceil(6 * blur_sigma) + 1 - math.ceil(6 * blur_sigma) % 2
            q = gaussian_blur_2d(q, kernel_size, blur_sigma)
        # blur_sigma == 0: a zero-width Gaussian is the identity, so q is used as is
        # (blurring with sigma 0 would divide by zero and yield NaNs).

        q = q.reshape(bs, inner_dim, -1).permute(0, 2, 1)

        return attention(q, k, v, heads=heads)

    return seg_attention
|
|
|
|
|
|
# Modified algorithm from 2411.10257 'The Unreasonable Effectiveness of Guidance for Diffusion Models' (Figure 6.)
def swg_pred_calc(
    x: torch.Tensor, tile_width: int, tile_height: int, tile_overlap: int, calc_func: Callable[..., tuple[torch.Tensor]]
):
    """Run ``calc_func`` over overlapping spatial windows of ``x`` and average the results.

    Windows start every ``tile_width`` columns / ``tile_height`` rows. Each one
    extends ``tile_overlap`` pixels further, so neighbouring windows share
    ``tile_overlap`` pixels per axis. Windows at the right/bottom edge are truncated
    to the tensor. Each output pixel is the mean of the predictions of all windows
    covering it.

    Args:
        x: Tensor unpacked as (batch, channels, height, width).
        tile_width: Horizontal window stride in pixels; must be positive.
        tile_height: Vertical window stride in pixels; must be positive.
        tile_overlap: Extra pixels each window extends past its stride; must be >= 0.
        calc_func: Called as ``calc_func(x_in=window)``. Element ``[0]`` of its
            result is accumulated into the window's region.

    Returns:
        Tensor shaped like ``x``.

    Raises:
        ValueError: if a tile size is not positive or ``tile_overlap`` is negative.
    """
    if tile_width <= 0 or tile_height <= 0:
        raise ValueError(f"Tile size must be positive, got {tile_width}x{tile_height}")
    if tile_overlap < 0:
        # Negative overlap would leave gaps whose coverage count is 0 (0 / 0 = NaN).
        raise ValueError(f"Tile overlap must be non-negative, got {tile_overlap}")

    _, _, h, w = x.shape
    swg_pred = torch.zeros_like(x)
    overlap = torch.zeros_like(x)

    # Windows start at multiples of the tile size. Exactly ceil(size / tile) of them
    # start inside the tensor, so no window is empty, and this count stays finite
    # even when tile_overlap >= tile size.
    tiles_w = math.ceil(w / tile_width)
    tiles_h = math.ceil(h / tile_height)

    for w_i in range(tiles_w):
        left, right = tile_width * w_i, tile_width * (w_i + 1) + tile_overlap
        for h_i in range(tiles_h):
            top, bottom = tile_height * h_i, tile_height * (h_i + 1) + tile_overlap

            swg_pred_window = calc_func(x_in=x[:, :, top:bottom, left:right])[0]
            swg_pred[:, :, top:bottom, left:right] += swg_pred_window
            # Count how many windows cover each pixel, for the final average.
            overlap[:, :, top:bottom, left:right] += torch.ones_like(swg_pred_window)

    return swg_pred / overlap
|
|
|
|
|
|
# Saliency-adaptive Noise Fusion based on High-fidelity Person-centric Subject-to-Image Synthesis (Wang et al.)
# https://github.com/CodeGoat24/Face-diffuser/blob/edff1a5178ac9984879d9f5e542c1d0f0059ca5f/facediffuser/pipeline.py#L535-L562
def snf_guidance(t_guidance: torch.Tensor, s_guidance: torch.Tensor):
    """Fuse two guidance tensors element-wise, keeping the more salient one at each position.

    Saliency of a tensor is its absolute value blurred with a 3x3 Gaussian
    (sigma 1). A softmax is then taken over the ``h * w`` spatial positions of each
    (batch, channel) map. At every element, the tensor with the higher saliency
    wins. On ties ``torch.argmax`` returns the first index, i.e. ``t_guidance``.

    Both inputs are unpacked as 4-D (batch, channels, height, width), and
    ``s_guidance`` is reshaped with ``t_guidance``'s dimensions.
    """
    batch, channels, height, width = t_guidance.shape

    def saliency(guidance: torch.Tensor) -> torch.Tensor:
        blurred = gaussian_blur_2d(guidance.abs(), 3, 1)
        flat = blurred.reshape(batch * channels, height * width)
        return flat.softmax(dim=1).reshape(batch, channels, height, width)

    t_saliency = saliency(t_guidance)
    s_saliency = saliency(s_guidance)

    candidates = torch.stack((t_guidance, s_guidance), dim=0)
    winner = torch.stack((t_saliency, s_saliency), dim=0).argmax(dim=0, keepdim=True)
    return candidates.gather(0, winner).squeeze(0)
|