Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Includes 30 custom nodes committed directly, 7 Civitai-exclusive loras stored via Git LFS, and a setup script that installs all dependencies and downloads HuggingFace-hosted models on vast.ai.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
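The setup script itself is not part of this diff; purely as a hedged sketch of the behavior described above (install dependencies, then pull HuggingFace-hosted models onto a vast.ai instance), it might look roughly like the following. The requirements file, repo ID, filename, and target directory are placeholders, not values taken from this commit.

# Hypothetical sketch only -- the actual setup script, repos, and paths are not
# shown in this diff.
import subprocess
from huggingface_hub import hf_hub_download

# Install ComfyUI and custom-node dependencies.
subprocess.run(["pip", "install", "-r", "requirements.txt"], check=True)

# Fetch a HuggingFace-hosted model file into the ComfyUI model folders.
hf_hub_download(
    repo_id="runwayml/stable-diffusion-v1-5",    # placeholder repo
    filename="v1-5-pruned-emaonly.safetensors",  # placeholder file
    local_dir="models/checkpoints",
)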
@@ -0,0 +1,359 @@
# Currently only sd15

import functools
import torch
import einops

from comfy import model_management, utils
from comfy.ldm.modules.attention import optimized_attention

module_mapping_sd15 = {
    0: "input_blocks.1.1.transformer_blocks.0.attn1",
    1: "input_blocks.1.1.transformer_blocks.0.attn2",
    2: "input_blocks.2.1.transformer_blocks.0.attn1",
    3: "input_blocks.2.1.transformer_blocks.0.attn2",
    4: "input_blocks.4.1.transformer_blocks.0.attn1",
    5: "input_blocks.4.1.transformer_blocks.0.attn2",
    6: "input_blocks.5.1.transformer_blocks.0.attn1",
    7: "input_blocks.5.1.transformer_blocks.0.attn2",
    8: "input_blocks.7.1.transformer_blocks.0.attn1",
    9: "input_blocks.7.1.transformer_blocks.0.attn2",
    10: "input_blocks.8.1.transformer_blocks.0.attn1",
    11: "input_blocks.8.1.transformer_blocks.0.attn2",
    12: "output_blocks.3.1.transformer_blocks.0.attn1",
    13: "output_blocks.3.1.transformer_blocks.0.attn2",
    14: "output_blocks.4.1.transformer_blocks.0.attn1",
    15: "output_blocks.4.1.transformer_blocks.0.attn2",
    16: "output_blocks.5.1.transformer_blocks.0.attn1",
    17: "output_blocks.5.1.transformer_blocks.0.attn2",
    18: "output_blocks.6.1.transformer_blocks.0.attn1",
    19: "output_blocks.6.1.transformer_blocks.0.attn2",
    20: "output_blocks.7.1.transformer_blocks.0.attn1",
    21: "output_blocks.7.1.transformer_blocks.0.attn2",
    22: "output_blocks.8.1.transformer_blocks.0.attn1",
    23: "output_blocks.8.1.transformer_blocks.0.attn2",
    24: "output_blocks.9.1.transformer_blocks.0.attn1",
    25: "output_blocks.9.1.transformer_blocks.0.attn2",
    26: "output_blocks.10.1.transformer_blocks.0.attn1",
    27: "output_blocks.10.1.transformer_blocks.0.attn2",
    28: "output_blocks.11.1.transformer_blocks.0.attn1",
    29: "output_blocks.11.1.transformer_blocks.0.attn2",
    30: "middle_block.1.transformer_blocks.0.attn1",
    31: "middle_block.1.transformer_blocks.0.attn2",
}


def compute_cond_mark(cond_or_uncond, sigmas):
    cond_or_uncond_size = int(sigmas.shape[0])

    cond_mark = []
    for cx in cond_or_uncond:
        cond_mark += [cx] * cond_or_uncond_size

    cond_mark = torch.Tensor(cond_mark).to(sigmas)
    return cond_mark
# Linear layer whose effective weight is the wrapped layer's weight plus a
# low-rank update: W = W_org + up @ down.
class LoRALinearLayer(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, rank: int = 256, org=None):
        super().__init__()
        self.down = torch.nn.Linear(in_features, rank, bias=False)
        self.up = torch.nn.Linear(rank, out_features, bias=False)
        # Kept in a plain list so the original layer is not registered as a submodule.
        self.org = [org]

    def forward(self, h):
        org_weight = self.org[0].weight.to(h)
        org_bias = self.org[0].bias.to(h) if self.org[0].bias is not None else None
        down_weight = self.down.weight
        up_weight = self.up.weight
        final_weight = org_weight + torch.mm(up_weight, down_weight)
        return torch.nn.functional.linear(h, final_weight, org_bias)
class AttentionSharingUnit(torch.nn.Module):
    # `transformer_options` passed to the most recent BasicTransformerBlock.forward
    # call.
    transformer_options: dict = {}

    def __init__(self, module, frames=2, use_control=True, rank=256):
        super().__init__()

        self.heads = module.heads
        self.frames = frames
        self.original_module = [module]
        q_in_channels, q_out_channels = (
            module.to_q.in_features,
            module.to_q.out_features,
        )
        k_in_channels, k_out_channels = (
            module.to_k.in_features,
            module.to_k.out_features,
        )
        v_in_channels, v_out_channels = (
            module.to_v.in_features,
            module.to_v.out_features,
        )
        o_in_channels, o_out_channels = (
            module.to_out[0].in_features,
            module.to_out[0].out_features,
        )

        hidden_size = k_out_channels

        # One LoRA adapter per frame for each projection of the wrapped attention module.
        self.to_q_lora = [
            LoRALinearLayer(q_in_channels, q_out_channels, rank, module.to_q)
            for _ in range(self.frames)
        ]
        self.to_k_lora = [
            LoRALinearLayer(k_in_channels, k_out_channels, rank, module.to_k)
            for _ in range(self.frames)
        ]
        self.to_v_lora = [
            LoRALinearLayer(v_in_channels, v_out_channels, rank, module.to_v)
            for _ in range(self.frames)
        ]
        self.to_out_lora = [
            LoRALinearLayer(o_in_channels, o_out_channels, rank, module.to_out[0])
            for _ in range(self.frames)
        ]

        self.to_q_lora = torch.nn.ModuleList(self.to_q_lora)
        self.to_k_lora = torch.nn.ModuleList(self.to_k_lora)
        self.to_v_lora = torch.nn.ModuleList(self.to_v_lora)
        self.to_out_lora = torch.nn.ModuleList(self.to_out_lora)

        # Projections for the attention computed across the frame axis in forward().
        self.temporal_i = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_n = torch.nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6
        )
        self.temporal_q = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_k = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_v = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_o = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )

        self.control_convs = None

        if use_control:
            # Per-frame projection of the control encoder features to this layer's width.
            self.control_convs = [
                torch.nn.Sequential(
                    torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
                    torch.nn.SiLU(),
                    torch.nn.Conv2d(256, hidden_size, kernel_size=1),
                )
                for _ in range(self.frames)
            ]
            self.control_convs = torch.nn.ModuleList(self.control_convs)

        self.control_signals = None
    def forward(self, h, context=None, value=None):
        transformer_options = self.transformer_options

        # Frames are folded into the batch dimension; unfold them here.
        modified_hidden_states = einops.rearrange(
            h, "(b f) d c -> f b d c", f=self.frames
        )

        if self.control_convs is not None:
            context_dim = int(modified_hidden_states.shape[2])
            control_outs = []
            for f in range(self.frames):
                control_signal = self.control_signals[context_dim].to(
                    modified_hidden_states
                )
                control = self.control_convs[f](control_signal)
                control = einops.rearrange(control, "b c h w -> b (h w) c")
                control_outs.append(control)
            control_outs = torch.stack(control_outs, dim=0)
            modified_hidden_states = modified_hidden_states + control_outs.to(
                modified_hidden_states
            )

        if context is None:
            framed_context = modified_hidden_states
        else:
            framed_context = einops.rearrange(
                context, "(b f) d c -> f b d c", f=self.frames
            )

        framed_cond_mark = einops.rearrange(
            compute_cond_mark(
                transformer_options["cond_or_uncond"],
                transformer_options["sigmas"],
            ),
            "(b f) -> f b",
            f=self.frames,
        ).to(modified_hidden_states)

        # Per-frame LoRA attention; when a `cond_overwrite` is supplied, the context
        # is swapped per sample according to the cond/uncond mark.
        attn_outs = []
        for f in range(self.frames):
            fcf = framed_context[f]

            if context is not None:
                cond_overwrite = transformer_options.get("cond_overwrite", [])
                if len(cond_overwrite) > f:
                    cond_overwrite = cond_overwrite[f]
                else:
                    cond_overwrite = None
                if cond_overwrite is not None:
                    cond_mark = framed_cond_mark[f][:, None, None]
                    fcf = cond_overwrite.to(fcf) * (1.0 - cond_mark) + fcf * cond_mark

            q = self.to_q_lora[f](modified_hidden_states[f])
            k = self.to_k_lora[f](fcf)
            v = self.to_v_lora[f](fcf)
            o = optimized_attention(q, k, v, self.heads)
            o = self.to_out_lora[f](o)
            o = self.original_module[0].to_out[1](o)
            attn_outs.append(o)

        attn_outs = torch.stack(attn_outs, dim=0)
        modified_hidden_states = modified_hidden_states + attn_outs.to(
            modified_hidden_states
        )
        modified_hidden_states = einops.rearrange(
            modified_hidden_states, "f b d c -> (b f) d c", f=self.frames
        )

        # Attention across the frame axis: tokens move to the batch, frames become
        # the sequence.
        x = modified_hidden_states
        x = self.temporal_n(x)
        x = self.temporal_i(x)
        d = x.shape[1]

        x = einops.rearrange(x, "(b f) d c -> (b d) f c", f=self.frames)

        q = self.temporal_q(x)
        k = self.temporal_k(x)
        v = self.temporal_v(x)

        x = optimized_attention(q, k, v, self.heads)
        x = self.temporal_o(x)
        x = einops.rearrange(x, "(b d) f c -> (b f) d c", d=d)

        modified_hidden_states = modified_hidden_states + x

        return modified_hidden_states - h
    @classmethod
    def hijack_transformer_block(cls):
        # Wrap BasicTransformerBlock.forward so the transformer_options of the most
        # recent call are visible to every AttentionSharingUnit.
        def register_get_transformer_options(func):
            @functools.wraps(func)
            def forward(self, x, context=None, transformer_options={}):
                cls.transformer_options = transformer_options
                return func(self, x, context, transformer_options)

            return forward

        from comfy.ldm.modules.attention import BasicTransformerBlock

        BasicTransformerBlock.forward = register_get_transformer_options(
            BasicTransformerBlock.forward
        )


AttentionSharingUnit.hijack_transformer_block()
class AdditionalAttentionCondsEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.blocks_0 = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(32, 32, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
        )  # 64*64*256

        self.blocks_1 = torch.nn.Sequential(
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
        )  # 32*32*256

        self.blocks_2 = torch.nn.Sequential(
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
        )  # 16*16*256

        self.blocks_3 = torch.nn.Sequential(
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
        )  # 8*8*256

        self.blks = [self.blocks_0, self.blocks_1, self.blocks_2, self.blocks_3]

    def __call__(self, h):
        results = {}
        for b in self.blks:
            h = b(h)
            results[int(h.shape[2]) * int(h.shape[3])] = h
        return results
# Thin container that registers the shared attention units as proper submodules.
class HookerLayers(torch.nn.Module):
    def __init__(self, layer_list):
        super().__init__()
        self.layers = torch.nn.ModuleList(layer_list)
class AttentionSharingPatcher(torch.nn.Module):
    def __init__(self, unet, frames=2, use_control=True, rank=256):
        super().__init__()
        model_management.unload_model_clones(unet)

        # Replace all 32 SD1.5 attention modules with frame-sharing units via
        # ComfyUI object patches.
        units = []
        for i in range(32):
            real_key = module_mapping_sd15[i]
            attn_module = utils.get_attr(unet.model.diffusion_model, real_key)
            u = AttentionSharingUnit(
                attn_module, frames=frames, use_control=use_control, rank=rank
            )
            units.append(u)
            unet.add_object_patch("diffusion_model." + real_key, u)

        self.hookers = HookerLayers(units)

        if use_control:
            self.kwargs_encoder = AdditionalAttentionCondsEncoder()
        else:
            self.kwargs_encoder = None

        self.dtype = torch.float32
        if model_management.should_use_fp16(model_management.get_torch_device()):
            self.dtype = torch.float16
            self.hookers.half()
        return

    def set_control(self, img):
        # Encode the control image once and hand the multi-resolution feature maps
        # to every patched attention unit.
        img = img.cpu().float() * 2.0 - 1.0
        signals = self.kwargs_encoder(img)
        for m in self.hookers.layers:
            m.control_signals = signals
        return
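For orientation, here is a hedged sketch of how the patcher above might be driven from ComfyUI; the checkpoint path, the clone-then-patch wiring, and the control-image shape are illustrative assumptions, not part of this commit.

# Illustrative usage only -- not part of this commit. Assumes an SD1.5 checkpoint
# and the ComfyUI ModelPatcher API used above (clone / add_object_patch).
import torch
from comfy.sd import load_checkpoint_guess_config

model, clip, vae, _ = load_checkpoint_guess_config(
    "models/checkpoints/sd15.safetensors"  # placeholder path
)

work_model = model.clone()  # patch a clone so the original model stays untouched
patcher = AttentionSharingPatcher(work_model, frames=2, use_control=True, rank=256)

# Optional control image: (batch, 3, height, width), values in [0, 1]; set_control
# rescales it to [-1, 1] and distributes the encoded signals to every unit.
patcher.set_control(torch.rand(1, 3, 512, 512))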