import math
from typing import Callable, List

import numpy as np
import torch
from einops import rearrange, repeat
from torch import Tensor
from tqdm.auto import tqdm

from .layers import DoubleStreamMixerProcessor, timestep_embedding
from .utils import ControlNetContainer

def model_forward(
model,
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
block_controlnet_hidden_states=None,
guidance: Tensor | None = None,
    neg_mode: bool = False,
) -> Tensor:
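    """Flux transformer forward pass.

    Embeds the packed image and text tokens, builds the modulation vector from
    the timestep embedding (plus the guidance embedding, when the model is
    guidance-distilled, and the pooled text vector `y`), then runs the
    double-stream blocks — switching IP-Adapter hidden states to the negative
    set when `neg_mode` is set and adding ControlNet residuals when provided —
    followed by the single-stream blocks and the final projection back to
    patch space."""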
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
    # project the packed image token sequence into the model's hidden size
img = model.img_in(img)
vec = model.time_in(timestep_embedding(timesteps, 256))
if model.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + model.guidance_in(timestep_embedding(guidance, 256))
vec = vec + model.vector_in(y)
txt = model.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = model.pe_embedder(ids)
if block_controlnet_hidden_states is not None:
controlnet_depth = len(block_controlnet_hidden_states)
for index_block, block in enumerate(model.double_blocks):
if hasattr(block, "processor"):
if isinstance(block.processor, DoubleStreamMixerProcessor):
if neg_mode:
for ip in block.processor.ip_adapters:
ip.ip_hidden_states = ip.in_hidden_states_neg
else:
for ip in block.processor.ip_adapters:
ip.ip_hidden_states = ip.in_hidden_states_pos
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
        # add ControlNet residual (residuals are reused cyclically across blocks)
        if block_controlnet_hidden_states is not None:
            img = img + block_controlnet_hidden_states[index_block % controlnet_depth]
img = torch.cat((txt, img), 1)
for block in model.single_blocks:
img = block(img, vec=vec, pe=pe)
img = img[:, txt.shape[1] :, ...]
img = model.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
def get_noise(
num_samples: int,
height: int,
width: int,
device: torch.device,
dtype: torch.dtype,
seed: int,
):
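    """Draw seeded Gaussian noise in the packed-latent shape: 16 channels at
    1/8 of the pixel resolution, rounded so the 2x2 patchification in
    `prepare` divides evenly."""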
return torch.randn(
num_samples,
16,
# allow for packing
2 * math.ceil(height / 16),
2 * math.ceil(width / 16),
device=device,
dtype=dtype,
generator=torch.Generator(device=device).manual_seed(seed),
)
def prepare(txt_t5, vec_clip, img: Tensor) -> dict[str, Tensor]:
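    """Pack the latent image into a token sequence and build positional ids.

    The image is split into 2x2 patches and flattened to (b, h*w/4, c*4), and
    each token gets a (0, row, col) id for the rotary position embedding;
    text tokens get all-zero ids. Single-sample conditioning tensors are
    broadcast to the batch size."""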
txt = txt_t5
vec = vec_clip
bs, c, h, w = img.shape
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
if txt.shape[0] == 1 and bs > 1:
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
txt_ids = torch.zeros(bs, txt.shape[1], 3)
if vec.shape[0] == 1 and bs > 1:
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
return {
"img": img,
"img_ids": img_ids.to(img.device, dtype=img.dtype),
"txt": txt.to(img.device, dtype=img.dtype),
"txt_ids": txt_ids.to(img.device, dtype=img.dtype),
"vec": vec.to(img.device, dtype=img.dtype),
}
def time_shift(mu: float, sigma: float, t: Tensor):
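    """Shift timesteps in logit space by `mu` (larger `mu` keeps more of the
    schedule at high noise); `sigma` controls the sharpness of the curve."""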
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_lin_function(
x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
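    """Return the line through (x1, y1) and (x2, y2); used to interpolate the
    schedule shift `mu` from the image token count."""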
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
def get_schedule(
num_steps: int,
image_seq_len: int,
base_shift: float = 0.5,
max_shift: float = 1.15,
shift: bool = True,
) -> list[float]:
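    """Build a decreasing timestep schedule from 1 to 0 with num_steps + 1
    entries, optionally shifted so that longer image sequences (larger
    images) spend more of the trajectory at high noise."""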
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
        # estimate mu by linear interpolation between two reference points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist()
def denoise(
model,
# model input
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
vec: Tensor,
neg_txt: Tensor,
neg_txt_ids: Tensor,
neg_vec: Tensor,
# sampling parameters
timesteps: list[float],
guidance: float = 4.0,
    true_gs=1,
    timestep_to_start_cfg=0,
    image2image_strength=None,
    orig_image=None,
    callback=None,
    width=512,
    height=512,
):
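    """Sample with Euler steps over the rectified-flow schedule.

    Each step predicts the flow at t_curr and integrates img toward t_prev;
    from `timestep_to_start_cfg` onward a negative-prompt pass is added and
    the two predictions are combined with scale `true_gs`. When
    `image2image_strength` and `orig_image` are given, sampling starts partway
    through the schedule from a correspondingly noised original image."""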
    i = 0
if image2image_strength is not None and orig_image is not None:
t_idx = np.clip(int((1 - np.clip(image2image_strength, 0.0, 1.0)) * len(timesteps)), 0, len(timesteps) - 1)
t = timesteps[t_idx]
timesteps = timesteps[t_idx:]
        orig_image = rearrange(orig_image, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2).to(img.device, dtype=img.dtype)
img = t * img + (1.0 - t) * orig_image
    img_ids = img_ids.to(img.device, dtype=img.dtype)
    txt = txt.to(img.device, dtype=img.dtype)
    txt_ids = txt_ids.to(img.device, dtype=img.dtype)
    vec = vec.to(img.device, dtype=img.dtype)
if hasattr(model, "guidance_in"):
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
else:
# this is ignored for schnell
guidance_vec = None
    for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:]), desc="Sampling", total=len(timesteps) - 1):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
pred = model_forward(
model,
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
)
if i >= timestep_to_start_cfg:
neg_pred = model_forward(
model,
img=img,
img_ids=img_ids,
txt=neg_txt,
txt_ids=neg_txt_ids,
y=neg_vec,
timesteps=t_vec,
guidance=guidance_vec,
                neg_mode=True,
)
pred = neg_pred + true_gs * (pred - neg_pred)
img = img + (t_prev - t_curr) * pred
if callback is not None:
unpacked = unpack(img.float(), height, width)
callback(step=i, x=img, x0=unpacked, total_steps=len(timesteps) - 1)
i += 1
return img
def denoise_controlnet(
model,
    controlnets_container: List[ControlNetContainer] | None,
# model input
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
vec: Tensor,
neg_txt: Tensor,
neg_txt_ids: Tensor,
neg_vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    guidance: float = 4.0,
    true_gs=1,
    timestep_to_start_cfg=0,
    image2image_strength=None,
    orig_image=None,
    callback=None,
    width=512,
    height=512,
):
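    """Variant of `denoise` that injects ControlNet residuals.

    At each step, every ControlNet whose start/end step window covers the
    current step is evaluated and its scaled block residuals are summed;
    `model_forward` adds them to the double-stream image hidden states. The
    negative CFG pass gets its own residuals computed against the negative
    conditioning."""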
    i = 0
    if controlnets_container is None:
        controlnets_container = []
    if image2image_strength is not None and orig_image is not None:
        t_idx = np.clip(int((1 - np.clip(image2image_strength, 0.0, 1.0)) * len(timesteps)), 0, len(timesteps) - 1)
t = timesteps[t_idx]
timesteps = timesteps[t_idx:]
        orig_image = rearrange(orig_image, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2).to(img.device, dtype=img.dtype)
img = t * img + (1.0 - t) * orig_image
img_ids = img_ids.to(img.device, dtype=img.dtype)
txt = txt.to(img.device, dtype=img.dtype)
txt_ids = txt_ids.to(img.device, dtype=img.dtype)
vec = vec.to(img.device, dtype=img.dtype)
for container in controlnets_container:
container.controlnet_cond = container.controlnet_cond.to(img.device, dtype=img.dtype)
container.controlnet.to(img.device, dtype=img.dtype)
if hasattr(model, "guidance_in"):
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
else:
guidance_vec = None
    for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:]), desc="Sampling", total=len(timesteps) - 1):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        if guidance_vec is not None:
            guidance_vec = guidance_vec.to(img.device, dtype=img.dtype)
controlnet_hidden_states = None
for container in controlnets_container:
if container.controlnet_start_step <= i <= container.controlnet_end_step:
block_res_samples = container.controlnet(
img=img,
img_ids=img_ids,
controlnet_cond=container.controlnet_cond,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
)
if controlnet_hidden_states is None:
controlnet_hidden_states = [sample * container.controlnet_gs for sample in block_res_samples]
else:
if len(controlnet_hidden_states) == len(block_res_samples):
for j in range(len(controlnet_hidden_states)):
controlnet_hidden_states[j] += block_res_samples[j] * container.controlnet_gs
pred = model_forward(
model,
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
block_controlnet_hidden_states=controlnet_hidden_states
)
neg_controlnet_hidden_states = None
if i >= timestep_to_start_cfg:
for container in controlnets_container:
if container.controlnet_start_step <= i <= container.controlnet_end_step:
neg_block_res_samples = container.controlnet(
img=img,
img_ids=img_ids,
controlnet_cond=container.controlnet_cond,
txt=neg_txt,
txt_ids=neg_txt_ids,
y=neg_vec,
timesteps=t_vec,
guidance=guidance_vec,
)
if neg_controlnet_hidden_states is None:
neg_controlnet_hidden_states = [sample * container.controlnet_gs for sample in neg_block_res_samples]
else:
if len(neg_controlnet_hidden_states) == len(neg_block_res_samples):
for j in range(len(neg_controlnet_hidden_states)):
neg_controlnet_hidden_states[j] += neg_block_res_samples[j] * container.controlnet_gs
neg_pred = model_forward(
model,
img=img,
img_ids=img_ids,
txt=neg_txt,
txt_ids=neg_txt_ids,
y=neg_vec,
timesteps=t_vec,
guidance=guidance_vec,
block_controlnet_hidden_states=neg_controlnet_hidden_states,
neg_mode=True,
)
pred = neg_pred + true_gs * (pred - neg_pred)
img = img + (t_prev - t_curr) * pred
if callback is not None:
unpacked = unpack(img.float(), height, width)
callback(step=i, x=img, x0=unpacked, total_steps=len(timesteps) - 1)
i += 1
return img
def unpack(x: Tensor, height: int, width: int) -> Tensor:
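    """Invert the 2x2 patch packing from `prepare`, reshaping the token
    sequence back into a (b, c, h, w) latent image."""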
return rearrange(
x,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(height / 16),
w=math.ceil(width / 16),
ph=2,
pw=2,
)
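
# Illustrative usage sketch. `model`, the T5/CLIP embeddings (`txt_t5`,
# `vec_clip`), and the negative-prompt tensors are assumed to be produced
# elsewhere in the pipeline; they are not defined in this module.
#
#   x = get_noise(1, height, width, device=device, dtype=dtype, seed=seed)
#   inp = prepare(txt_t5, vec_clip, x)
#   schedule = get_schedule(num_steps, inp["img"].shape[1])
#   x = denoise(model, **inp, neg_txt=neg_txt, neg_txt_ids=neg_txt_ids,
#               neg_vec=neg_vec, timesteps=schedule, guidance=4.0,
#               width=width, height=height)
#   latents = unpack(x.float(), height, width)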