Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled

Includes 30 custom nodes committed directly, 7 Civitai-exclusive
loras stored via Git LFS, and a setup script that installs all
dependencies and downloads HuggingFace-hosted models on vast.ai.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-09 00:55:26 +00:00
parent 2b70ab9ad0
commit f09734b0ee
2274 changed files with 748556 additions and 3 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,167 @@
#credit to comfyanonymous for this module
#from https://github.com/comfyanonymous/ComfyUI_bitsandbytes_NF4
import comfy.ops
import torch
import folder_paths
from ...libs.utils import install_package
# bitsandbytes provides the 4-bit parameter/quant-state types this module wraps.
try:
    from bitsandbytes.nn.modules import Params4bit, QuantState
except ImportError as e:
    # QuantState has no usable stand-in (copy_quant_state below requires it),
    # so a missing/old bitsandbytes is fatal here. The previous fallback
    # assignment `Params4bit = torch.nn.Parameter` was dead code: the module
    # import aborted on the very next line anyway. Chain the original error
    # so the real import failure stays visible.
    raise ImportError("Please install bitsandbytes>=0.43.3") from e
def functional_linear_4bits(x, weight, bias):
    """Apply a linear layer whose weight is stored in bitsandbytes 4-bit form.

    Ensures a compatible bitsandbytes is installed, runs matmul_4bit with the
    weight's quant_state, and casts the result back to x's dtype/device.
    """
    try:
        install_package("bitsandbytes", "0.43.3", True, "0.43.3")
        import bitsandbytes as bnb
    except ImportError:
        raise ImportError("Please install bitsandbytes>=0.43.3")
    result = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
    return result.to(x)
def copy_quant_state(state, device: torch.device = None):
    """Return a copy of a bitsandbytes QuantState with its tensors on *device*.

    Returns None for a None input. When *device* is omitted, the device of
    state.absmax is used. Nested (double-quantized) states are copied too.
    """
    if state is None:
        return None

    target = device if device is not None else state.absmax.device

    nested_state = None
    if state.nested:
        inner = state.state2
        nested_state = QuantState(
            absmax=inner.absmax.to(target),
            shape=inner.shape,
            code=inner.code.to(target),
            blocksize=inner.blocksize,
            quant_type=inner.quant_type,
            dtype=inner.dtype,
        )

    return QuantState(
        absmax=state.absmax.to(target),
        shape=state.shape,
        code=state.code.to(target),
        blocksize=state.blocksize,
        quant_type=state.quant_type,
        dtype=state.dtype,
        offset=state.offset.to(target) if state.nested else None,
        state2=nested_state,
    )
class ForgeParams4bit(Params4bit):
    # Params4bit subclass whose .to() keeps the bitsandbytes quant_state in
    # sync when the parameter is moved between devices.
    def to(self, *args, **kwargs):
        """Move/convert the parameter.

        Moving a not-yet-quantized tensor onto CUDA triggers the actual 4-bit
        quantization; any other move deep-copies the quant state alongside the
        data so the parameter stays usable on the new device.
        """
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None and device.type == "cuda" and not self.bnb_quantized:
            # First move to CUDA: quantize on the target device.
            return self._quantize(device)
        else:
            n = ForgeParams4bit(
                torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
                requires_grad=self.requires_grad,
                quant_state=copy_quant_state(self.quant_state, device),
                blocksize=self.blocksize,
                compress_statistics=self.compress_statistics,
                quant_type=self.quant_type,
                quant_storage=self.quant_storage,
                bnb_quantized=self.bnb_quantized,
                module=self.module
            )
            # Mirror the moved state back onto the owning module and this
            # instance, matching what the returned parameter carries.
            self.module.quant_state = n.quant_state
            self.data = n.data
            self.quant_state = n.quant_state
            return n
class ForgeLoader4Bit(torch.nn.Module):
    # Module that lazily materializes a 4-bit weight from a state dict, either
    # from prequantized bitsandbytes tensors or by wrapping a full weight so it
    # quantizes on its first move to CUDA.
    def __init__(self, *, device, dtype, quant_type, **kwargs):
        super().__init__()
        # Placeholder parameter: records the target device/dtype until the real
        # weight arrives via _load_from_state_dict, after which it is deleted.
        self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
        self.weight = None
        self.quant_state = None
        self.bias = None
        self.quant_type = quant_type
    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """Serialize the module, appending the bitsandbytes quant-state tensors
        as '<prefix>weight.<key>' entries next to the packed weight."""
        super()._save_to_state_dict(destination, prefix, keep_vars)
        quant_state = getattr(self.weight, "quant_state", None)
        if quant_state is not None:
            for k, v in quant_state.as_dict(packed=True).items():
                destination[prefix + "weight." + k] = v if keep_vars else v.detach()
        return
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        # Keys like '<prefix>weight.bitsandbytes__nf4' mark a prequantized checkpoint.
        quant_state_keys = {k[len(prefix + "weight."):] for k in state_dict.keys() if k.startswith(prefix + "weight.")}
        if any('bitsandbytes' in k for k in quant_state_keys):
            # Checkpoint already stores 4-bit data plus quant state: rebuild directly.
            quant_state_dict = {k: state_dict[prefix + "weight." + k] for k in quant_state_keys}
            self.weight = ForgeParams4bit().from_prequantized(
                data=state_dict[prefix + 'weight'],
                quantized_stats=quant_state_dict,
                requires_grad=False,
                device=self.dummy.device,
                module=self
            )
            self.quant_state = self.weight.quant_state
            if prefix + 'bias' in state_dict:
                self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
            del self.dummy
        elif hasattr(self, 'dummy'):
            # Full-precision checkpoint: wrap the weight in ForgeParams4bit so
            # it is quantized on its first move to CUDA (see ForgeParams4bit.to).
            if prefix + 'weight' in state_dict:
                self.weight = ForgeParams4bit(
                    state_dict[prefix + 'weight'].to(self.dummy),
                    requires_grad=False,
                    compress_statistics=True,
                    quant_type=self.quant_type,
                    quant_storage=torch.uint8,
                    module=self,
                )
                self.quant_state = self.weight.quant_state
            if prefix + 'bias' in state_dict:
                self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
            del self.dummy
        else:
            # Weight already materialized: fall back to the default loader.
            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
# Module-level knobs, set by the loader before OPS layers are instantiated.
current_device = None                # target device for new layers — presumably a torch.device; confirm with the loader
current_dtype = None                 # compute dtype for new layers
current_manual_cast_enabled = False  # whether the manual-cast forward path is active
current_bnb_dtype = None             # bitsandbytes quant type passed to ForgeLoader4Bit (e.g. 'nf4'/'fp4' — TODO confirm)
class OPS(comfy.ops.manual_cast):
    # comfy ops container whose Linear stores its weight in bitsandbytes 4-bit.
    class Linear(ForgeLoader4Bit):
        def __init__(self, *args, device=None, dtype=None, **kwargs):
            super().__init__(device=device, dtype=dtype, quant_type=current_bnb_dtype)
            self.parameters_manual_cast = current_manual_cast_enabled
        def forward(self, x):
            # Re-attach the quant state in case self.weight was replaced elsewhere.
            self.weight.quant_state = self.quant_state
            if self.bias is not None and self.bias.dtype != x.dtype:
                # Maybe this can also be set to all non-bnb ops since the cost is very low.
                # And it only invokes one time, and most linear does not have bias
                self.bias.data = self.bias.data.to(x.dtype)
            if not self.parameters_manual_cast:
                return functional_linear_4bits(x, self.weight, self.bias)
            elif not self.weight.bnb_quantized:
                # Weight not yet quantized: quantize on the compute device, run,
                # then move the weight back to where it lived.
                assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!'
                layer_original_device = self.weight.device
                self.weight = self.weight._quantize(x.device)
                bias = self.bias.to(x.device) if self.bias is not None else None
                out = functional_linear_4bits(x, self.weight, bias)
                self.weight = self.weight.to(layer_original_device)
                return out
            else:
                # NOTE(review): weights_manual_cast and main_stream_worker are not
                # defined anywhere in this file — this branch raises NameError
                # unless they are injected at runtime; verify before relying on it.
                weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
                with main_stream_worker(weight, bias, signal):
                    return functional_linear_4bits(x, weight, bias)

View File

@@ -0,0 +1,475 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
import numpy as np
class REBNCONV(nn.Module):
    """Conv2d(3x3, dilated) -> BatchNorm2d -> ReLU.

    With stride=1 the padding equals the dilation, so spatial size is preserved.
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
        super(REBNCONV, self).__init__()
        # padding == dilation keeps H/W unchanged for a 3x3 kernel at stride 1.
        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu_s1(self.bn_s1(self.conv_s1(x)))
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src,tar):
src = F.interpolate(src,size=tar.shape[2:],mode='bilinear')
return src
### RSU-7 ###
class RSU7(nn.Module):
    """Residual U-block of depth 7: 5 pooling levels plus a dilated bottleneck.

    Output has out_ch channels at the input resolution; the decoder output is
    added to the input projection (residual connection).
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
        super(RSU7,self).__init__()
        self.in_ch = in_ch
        self.mid_ch = mid_ch
        self.out_ch = out_ch
        # Input projection to out_ch; its output is the residual skip.
        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) ## 1 -> 1/2
        # Encoder: conv followed by 2x downsample at each level.
        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
        # Bottleneck: dilated conv instead of a further downsample.
        self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
        # Decoder: each level consumes cat(previous decoder output, encoder skip).
        self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):
        b, c, h, w = x.shape
        hx = x
        hxin = self.rebnconvin(hx)
        # Encoder path.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)
        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)
        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)
        hx6 = self.rebnconv6(hx)
        hx7 = self.rebnconv7(hx6)
        # Decoder path: upsample to match each encoder skip before concatenation.
        hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
        hx6dup = _upsample_like(hx6d,hx5)
        hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)
        hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)
        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
        # Residual connection to the input projection.
        return hx1d + hxin
### RSU-6 ###
class RSU6(nn.Module):
    """Residual U-block of depth 6: 4 pooling levels plus a dilated bottleneck."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6,self).__init__()
        # Input projection; its output is the residual skip.
        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
        # Encoder.
        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
        # Dilated bottleneck.
        self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
        # Decoder.
        self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):
        hx = x
        hxin = self.rebnconvin(hx)
        # Encoder path.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)
        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)
        hx5 = self.rebnconv5(hx)
        hx6 = self.rebnconv6(hx5)
        # Decoder path with encoder skips.
        hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)
        hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)
        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
        # Residual connection.
        return hx1d + hxin
### RSU-5 ###
class RSU5(nn.Module):
    """Residual U-block of depth 5: 3 pooling levels plus a dilated bottleneck."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5,self).__init__()
        # Input projection; its output is the residual skip.
        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
        # Encoder.
        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
        # Dilated bottleneck.
        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
        # Decoder.
        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):
        hx = x
        hxin = self.rebnconvin(hx)
        # Encoder path.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)
        hx4 = self.rebnconv4(hx)
        hx5 = self.rebnconv5(hx4)
        # Decoder path with encoder skips.
        hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)
        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
        # Residual connection.
        return hx1d + hxin
### RSU-4 ###
class RSU4(nn.Module):
    """Residual U-block of depth 4: 2 pooling levels plus a dilated bottleneck."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4,self).__init__()
        # Input projection; its output is the residual skip.
        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
        # Encoder.
        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        # Dilated bottleneck.
        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
        # Decoder.
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):
        hx = x
        hxin = self.rebnconvin(hx)
        # Encoder path.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx4 = self.rebnconv4(hx3)
        # Decoder path with encoder skips.
        hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)
        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
        # Residual connection.
        return hx1d + hxin
### RSU-4F ###
class RSU4F(nn.Module):
    """Dilation-only residual U-block: keeps full resolution, growing the
    receptive field with dilations 1/2/4/8 instead of pooling."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        skip = self.rebnconvin(x)
        # Encoder with growing dilation.
        enc1 = self.rebnconv1(skip)
        enc2 = self.rebnconv2(enc1)
        enc3 = self.rebnconv3(enc2)
        enc4 = self.rebnconv4(enc3)
        # Decoder mirrors the encoder; no upsampling needed (same resolution).
        dec3 = self.rebnconv3d(torch.cat((enc4, enc3), 1))
        dec2 = self.rebnconv2d(torch.cat((dec3, enc2), 1))
        dec1 = self.rebnconv1d(torch.cat((dec2, enc1), 1))
        return dec1 + skip
class myrebnconv(nn.Module):
    """Fully configurable Conv2d -> BatchNorm2d -> ReLU block."""

    def __init__(self, in_ch=3,
                 out_ch=1,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1):
        super(myrebnconv, self).__init__()
        self.conv = nn.Conv2d(in_ch,
                              out_ch,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              dilation=dilation,
                              groups=groups)
        self.bn = nn.BatchNorm2d(out_ch)
        self.rl = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        return self.rl(out)
def preprocess_image(im, model_input_size: list) -> torch.Tensor:
    """Convert an HWC image array to the model's normalized NCHW float tensor.

    Resizes bilinearly to model_input_size, scales to [0, 1], then normalizes
    with mean 0.5 / std 1.0 per channel. Assumes *im* is HWC with 3 channels —
    TODO confirm with callers.
    """
    # im = im.resize(model_input_size, Image.BILINEAR)
    im_np = np.array(im)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2,0,1)
    # NOTE(review): casting the interpolated result back to uint8 truncates the
    # bilinear values before the division below. This matches the upstream BRIA
    # RMBG reference code, but confirm the truncation is intentional.
    im_tensor = F.interpolate(torch.unsqueeze(im_tensor,0), size=model_input_size, mode='bilinear').type(torch.uint8)
    image = torch.divide(im_tensor,255.0)
    image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
    return image
def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
    """Resize a (1, C, H, W) prediction to im_size and min-max scale to uint8.

    Returns a numpy uint8 array in [0, 255] with singleton dims squeezed out.
    """
    result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    # Guard against a constant prediction: the original (result-mi)/(ma-mi)
    # divided by zero and produced NaNs; map a flat map to all zeros instead.
    scale = ma - mi
    if scale == 0:
        result = torch.zeros_like(result)
    else:
        result = (result - mi) / scale
    im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
    im_array = np.squeeze(im_array)
    return im_array
class BriaRMBG(nn.Module):
    """U^2-Net-style segmentation network (BRIA RMBG): 6 encoder stages, 5
    decoder stages, and 6 side outputs upsampled to the input resolution.

    forward(x) returns ([d1..d6] sigmoid maps, [hx1d..hx5d, hx6] features).
    """

    def __init__(self, config: dict = {"in_ch": 3, "out_ch": 1}):
        # The mutable default dict is only ever read, so sharing it is safe.
        super(BriaRMBG, self).__init__()
        in_ch = config["in_ch"]
        out_ch = config["out_ch"]
        self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
        self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # unused in forward
        # Encoder: RSU depth decreases as resolution drops.
        self.stage1 = RSU7(64, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage6 = RSU4F(512, 256, 512)
        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)
        # Side heads: one per decoder level plus the bottleneck.
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
        # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)

    def forward(self, x):
        hx = x
        hxin = self.conv_in(hx)
        # hx = self.pool_in(hxin)
        # ----- encoder -----
        hx1 = self.stage1(hxin)
        hx = self.pool12(hx1)
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)
        # ----- decoder: concatenate upsampled decoder output with encoder skip -----
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)
        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
        # ----- side outputs, upsampled to the input resolution -----
        d1 = _upsample_like(self.side1(hx1d), x)
        d2 = _upsample_like(self.side2(hx2d), x)
        d3 = _upsample_like(self.side3(hx3d), x)
        d4 = _upsample_like(self.side4(hx4d), x)
        d5 = _upsample_like(self.side5(hx5d), x)
        d6 = _upsample_like(self.side6(hx6), x)
        # torch.sigmoid replaces the deprecated F.sigmoid (identical math).
        maps = [torch.sigmoid(d) for d in (d1, d2, d3, d4, d5, d6)]
        return maps, [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]

View File

@@ -0,0 +1,822 @@
#credit to nullquant for this module
#from https://github.com/nullquant/ComfyUI-BrushNet
import os
import types
import torch
# Optional dependency: accelerate provides empty-weight init + sharded loading.
try:
    from accelerate import init_empty_weights, load_checkpoint_and_dispatch
except Exception:
    # A bare `except:` here would also swallow KeyboardInterrupt/SystemExit;
    # keep the best-effort fallback but only for ordinary errors.
    init_empty_weights, load_checkpoint_and_dispatch = None, None

import comfy

# Optional sibling modules; fall back to None placeholders when unavailable.
try:
    from .model import BrushNetModel, PowerPaintModel
    from .model_patch import add_model_patch_option, patch_model_function_wrapper
    from .powerpaint_utils import TokenizerWrapper, add_tokens
except Exception:
    BrushNetModel, PowerPaintModel = None, None
    add_model_patch_option, patch_model_function_wrapper = None, None
    TokenizerWrapper, add_tokens = None, None

cwd_path = os.path.dirname(os.path.realpath(__file__))
brushnet_config_file = os.path.join(cwd_path, 'config', 'brushnet.json')
brushnet_xl_config_file = os.path.join(cwd_path, 'config', 'brushnet_xl.json')
powerpaint_config_file = os.path.join(cwd_path, 'config', 'powerpaint.json')

# VAE latent scaling factors for the two supported base model families.
sd15_scaling_factor = 0.18215
sdxl_scaling_factor = 0.13025

# Model classes considered candidates for unloading to free memory.
ModelsToUnload = [comfy.sd1_clip.SD1ClipModel, comfy.ldm.models.autoencoder.AutoencoderKL]
class BrushNet:
    # Check models compatibility
    def check_compatibilty(self, model, brushnet):
        """Verify that the base model family (SD1.5 / SDXL) matches the loaded
        BrushNet/PowerPaint weights.

        Returns (is_SDXL, is_PP); raises on a mismatch or unsupported base model.
        """
        is_SDXL = False
        is_PP = False
        if isinstance(model.model.model_config, comfy.supported_models.SD15):
            print('Base model type: SD1.5')
            is_SDXL = False
            if brushnet["SDXL"]:
                raise Exception("Base model is SD15, but BrushNet is SDXL type")
            if brushnet["PP"]:
                is_PP = True
        elif isinstance(model.model.model_config, comfy.supported_models.SDXL):
            print('Base model type: SDXL')
            is_SDXL = True
            if not brushnet["SDXL"]:
                raise Exception("Base model is SDXL, but BrushNet is SD15 type")
        else:
            print('Base model type: ', type(model.model.model_config))
            raise Exception("Unsupported model type: " + str(type(model.model.model_config)))
        return (is_SDXL, is_PP)
def check_image_mask(self, image, mask, name):
if len(image.shape) < 4:
# image tensor shape should be [B, H, W, C], but batch somehow is missing
image = image[None, :, :, :]
if len(mask.shape) > 3:
# mask tensor shape should be [B, H, W] but we get [B, H, W, C], image may be?
# take first mask, red channel
mask = (mask[:, :, :, 0])[:, :, :]
elif len(mask.shape) < 3:
# mask tensor shape should be [B, H, W] but batch somehow is missing
mask = mask[None, :, :]
if image.shape[0] > mask.shape[0]:
print(name, "gets batch of images (%d) but only %d masks" % (image.shape[0], mask.shape[0]))
if mask.shape[0] == 1:
print(name, "will copy the mask to fill batch")
mask = torch.cat([mask] * image.shape[0], dim=0)
else:
print(name, "will add empty masks to fill batch")
empty_mask = torch.zeros([image.shape[0] - mask.shape[0], mask.shape[1], mask.shape[2]])
mask = torch.cat([mask, empty_mask], dim=0)
elif image.shape[0] < mask.shape[0]:
print(name, "gets batch of images (%d) but too many (%d) masks" % (image.shape[0], mask.shape[0]))
mask = mask[:image.shape[0], :, :]
return (image, mask)
    # Prepare image and mask
    def prepare_image(self, image, mask):
        """Normalize image/mask, binarize the mask, and zero out masked pixels.

        Returns (masked_image, mask) where masked_image = image * (1 - mask).
        Raises when image and mask spatial sizes differ.
        """
        image, mask = self.check_image_mask(image, mask, 'BrushNet')
        print("BrushNet image.shape =", image.shape, "mask.shape =", mask.shape)
        if mask.shape[2] != image.shape[2] or mask.shape[1] != image.shape[1]:
            raise Exception("Image and mask should be the same size")
        # As a suggestion of inferno46n2 (https://github.com/nullquant/ComfyUI-BrushNet/issues/64)
        mask = mask.round()
        masked_image = image * (1.0 - mask[:, :, :, None])
        return (masked_image, mask)
# Get origin of the mask
def cut_with_mask(self, mask, width, height):
iy, ix = (mask == 1).nonzero(as_tuple=True)
h0, w0 = mask.shape
if iy.numel() == 0:
x_c = w0 / 2.0
y_c = h0 / 2.0
else:
x_min = ix.min().item()
x_max = ix.max().item()
y_min = iy.min().item()
y_max = iy.max().item()
if x_max - x_min > width or y_max - y_min > height:
raise Exception("Mask is bigger than provided dimensions")
x_c = (x_min + x_max) / 2.0
y_c = (y_min + y_max) / 2.0
width2 = width / 2.0
height2 = height / 2.0
if w0 <= width:
x0 = 0
w = w0
else:
x0 = max(0, x_c - width2)
w = width
if x0 + width > w0:
x0 = w0 - width
if h0 <= height:
y0 = 0
h = h0
else:
y0 = max(0, y_c - height2)
h = height
if y0 + height > h0:
y0 = h0 - height
return (int(x0), int(y0), int(w), int(h))
    # Prepare conditioning_latents
    @torch.inference_mode()
    def get_image_latents(self, masked_image, mask, vae, scaling_factor):
        """VAE-encode the masked image and build [latents, inverse latent mask].

        The mask is inverted (1 = keep) and resized to the latent resolution.
        Returns a two-element list [image_latents, interpolated_mask].
        """
        processed_image = masked_image.to(vae.device)
        # Encode RGB channels only; apply the model family's latent scaling.
        image_latents = vae.encode(processed_image[:, :, :, :3]) * scaling_factor
        processed_mask = 1. - mask[:, None, :, :]
        interpolated_mask = torch.nn.functional.interpolate(
            processed_mask,
            size=(
                image_latents.shape[-2],
                image_latents.shape[-1]
            )
        )
        interpolated_mask = interpolated_mask.to(image_latents.device)
        conditioning_latents = [image_latents, interpolated_mask]
        print('BrushNet CL: image_latents shape =', image_latents.shape, 'interpolated_mask shape =',
              interpolated_mask.shape)
        return conditioning_latents
def brushnet_blocks(self, sd):
brushnet_down_block = 0
brushnet_mid_block = 0
brushnet_up_block = 0
for key in sd:
if 'brushnet_down_block' in key:
brushnet_down_block += 1
if 'brushnet_mid_block' in key:
brushnet_mid_block += 1
if 'brushnet_up_block' in key:
brushnet_up_block += 1
return (brushnet_down_block, brushnet_mid_block, brushnet_up_block, len(sd))
    def get_model_type(self, brushnet_file):
        """Classify a checkpoint file by its brushnet block counts.

        Returns (is_SDXL, is_PP); raises for unrecognized layouts.
        """
        sd = comfy.utils.load_torch_file(brushnet_file)
        brushnet_down_block, brushnet_mid_block, brushnet_up_block, keys = self.brushnet_blocks(sd)
        del sd
        if brushnet_down_block == 24 and brushnet_mid_block == 2 and brushnet_up_block == 30:
            is_SDXL = False
            # SD1.5 BrushNet and PowerPaint share block counts; the total key
            # count distinguishes them.
            if keys == 322:
                is_PP = False
                print('BrushNet model type: SD1.5')
            else:
                is_PP = True
                print('PowerPaint model type: SD1.5')
        elif brushnet_down_block == 18 and brushnet_mid_block == 2 and brushnet_up_block == 22:
            print('BrushNet model type: Loading SDXL')
            is_SDXL = True
            is_PP = False
        else:
            raise Exception("Unknown BrushNet model")
        return is_SDXL, is_PP
    def load_brushnet_model(self, brushnet_file, dtype='float16'):
        """Load BrushNet/PowerPaint weights into the matching model class.

        Returns a 1-tuple containing {'brushnet', 'SDXL', 'PP', 'dtype'}.
        """
        is_SDXL, is_PP = self.get_model_type(brushnet_file)
        # Build the model skeleton without allocating real weights.
        with init_empty_weights():
            if is_SDXL:
                brushnet_config = BrushNetModel.load_config(brushnet_xl_config_file)
                brushnet_model = BrushNetModel.from_config(brushnet_config)
            elif is_PP:
                brushnet_config = PowerPaintModel.load_config(powerpaint_config_file)
                brushnet_model = PowerPaintModel.from_config(brushnet_config)
            else:
                brushnet_config = BrushNetModel.load_config(brushnet_config_file)
                brushnet_model = BrushNetModel.from_config(brushnet_config)
        if is_PP:
            print("PowerPaint model file:", brushnet_file)
        else:
            print("BrushNet model file:", brushnet_file)
        if dtype == 'float16':
            torch_dtype = torch.float16
        elif dtype == 'bfloat16':
            torch_dtype = torch.bfloat16
        elif dtype == 'float32':
            torch_dtype = torch.float32
        else:
            # NOTE(review): any unrecognized dtype string silently falls back
            # to float64 — confirm this is intended.
            torch_dtype = torch.float64
        brushnet_model = load_checkpoint_and_dispatch(
            brushnet_model,
            brushnet_file,
            device_map="sequential",
            max_memory=None,
            offload_folder=None,
            offload_state_dict=False,
            dtype=torch_dtype,
            force_hooks=False,
        )
        if is_PP:
            print("PowerPaint model is loaded")
        elif is_SDXL:
            print("BrushNet SDXL model is loaded")
        else:
            print("BrushNet SD1.5 model is loaded")
        return ({"brushnet": brushnet_model, "SDXL": is_SDXL, "PP": is_PP, "dtype": torch_dtype},)
    def brushnet_model_update(self, model, vae, image, mask, brushnet, positive, negative, scale, start_at, end_at):
        """Patch a cloned base model with BrushNet inpainting conditioning.

        Returns (model, positive, negative, {'samples': latent}) for sampling.
        Raises when a PowerPaint checkpoint was loaded instead of BrushNet.
        """
        is_SDXL, is_PP = self.check_compatibilty(model, brushnet)
        if is_PP:
            raise Exception("PowerPaint model was loaded, please use PowerPaint node")
        # Make a copy of the model so that we're not patching it everywhere in the workflow.
        model = model.clone()
        # prepare image and mask
        # no batches for original image and mask
        masked_image, mask = self.prepare_image(image, mask)
        batch = masked_image.shape[0]
        width = masked_image.shape[2]
        height = masked_image.shape[1]
        # Prefer the model's own latent scale factor; otherwise fall back to the
        # family default.
        if hasattr(model.model.model_config, 'latent_format') and hasattr(model.model.model_config.latent_format,
                                                                          'scale_factor'):
            scaling_factor = model.model.model_config.latent_format.scale_factor
        elif is_SDXL:
            scaling_factor = sdxl_scaling_factor
        else:
            scaling_factor = sd15_scaling_factor
        torch_dtype = brushnet['dtype']
        # prepare conditioning latents
        conditioning_latents = self.get_image_latents(masked_image, mask, vae, scaling_factor)
        conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
        conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
        # unload vae
        del vae
        # for loaded_model in comfy.model_management.current_loaded_models:
        #     if type(loaded_model.model.model) in ModelsToUnload:
        #         comfy.model_management.current_loaded_models.remove(loaded_model)
        #         loaded_model.model_unload()
        #         del loaded_model
        # prepare embeddings
        prompt_embeds = positive[0][0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
        negative_prompt_embeds = negative[0][0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
        # Pad the shorter embedding with repeats of its last 77-token chunk so
        # positive and negative prompts have the same token length.
        max_tokens = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
        if prompt_embeds.shape[1] < max_tokens:
            multiplier = max_tokens // 77 - prompt_embeds.shape[1] // 77
            prompt_embeds = torch.concat([prompt_embeds] + [prompt_embeds[:, -77:, :]] * multiplier, dim=1)
            print('BrushNet: negative prompt more than 75 tokens:', negative_prompt_embeds.shape,
                  'multiplying prompt_embeds')
        if negative_prompt_embeds.shape[1] < max_tokens:
            multiplier = max_tokens // 77 - negative_prompt_embeds.shape[1] // 77
            negative_prompt_embeds = torch.concat(
                [negative_prompt_embeds] + [negative_prompt_embeds[:, -77:, :]] * multiplier, dim=1)
            print('BrushNet: positive prompt more than 75 tokens:', prompt_embeds.shape,
                  'multiplying negative_prompt_embeds')
        # SDXL additionally needs pooled embeddings and time ids; synthesize
        # placeholders when the conditioning lacks pooled_output.
        if len(positive[0]) > 1 and 'pooled_output' in positive[0][1] and positive[0][1]['pooled_output'] is not None:
            pooled_prompt_embeds = positive[0][1]['pooled_output'].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
        else:
            print('BrushNet: positive conditioning has not pooled_output')
            if is_SDXL:
                print('BrushNet will not produce correct results')
            pooled_prompt_embeds = torch.empty([2, 1280], device=brushnet['brushnet'].device).to(dtype=torch_dtype)
        if len(negative[0]) > 1 and 'pooled_output' in negative[0][1] and negative[0][1]['pooled_output'] is not None:
            negative_pooled_prompt_embeds = negative[0][1]['pooled_output'].to(dtype=torch_dtype).to(
                brushnet['brushnet'].device)
        else:
            print('BrushNet: negative conditioning has not pooled_output')
            if is_SDXL:
                print('BrushNet will not produce correct results')
            negative_pooled_prompt_embeds = torch.empty([1, pooled_prompt_embeds.shape[1]],
                                                        device=brushnet['brushnet'].device).to(dtype=torch_dtype)
        time_ids = torch.FloatTensor([[height, width, 0., 0., height, width]]).to(dtype=torch_dtype).to(
            brushnet['brushnet'].device)
        if not is_SDXL:
            # SD1.5 path uses no pooled embeddings or time ids.
            pooled_prompt_embeds = None
            negative_pooled_prompt_embeds = None
            time_ids = None
        # apply patch to model
        brushnet_conditioning_scale = scale
        control_guidance_start = start_at
        control_guidance_end = end_at
        add_brushnet_patch(model,
                           brushnet['brushnet'],
                           torch_dtype,
                           conditioning_latents,
                           (brushnet_conditioning_scale, control_guidance_start, control_guidance_end),
                           prompt_embeds, negative_prompt_embeds,
                           pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids,
                           False)
        # Empty latent batch matching the conditioning resolution.
        latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]],
                             device=brushnet['brushnet'].device)
        return (model, positive, negative, {"samples": latent},)
    # PowerPaint
    def load_powerpaint_clip(self, base_clip_file, pp_clip_file):
        """Build a PowerPaint CLIP: a base CLIP extended with learned task tokens.

        Adds the P_ctxt / P_shape / P_obj placeholder tokens, then loads the
        PowerPaint text-encoder weights over the base ones (non-strict).
        Returns a 1-tuple with the patched CLIP object.
        """
        pp_clip = comfy.sd.load_clip(ckpt_paths=[base_clip_file])
        print('PowerPaint base CLIP file: ', base_clip_file)
        pp_tokenizer = TokenizerWrapper(pp_clip.tokenizer.clip_l.tokenizer)
        pp_text_encoder = pp_clip.patcher.model.clip_l.transformer
        add_tokens(
            tokenizer=pp_tokenizer,
            text_encoder=pp_text_encoder,
            placeholder_tokens=["P_ctxt", "P_shape", "P_obj"],
            initialize_tokens=["a", "a", "a"],
            num_vectors_per_token=10,
        )
        pp_text_encoder.load_state_dict(comfy.utils.load_torch_file(pp_clip_file), strict=False)
        print('PowerPaint CLIP file: ', pp_clip_file)
        pp_clip.tokenizer.clip_l.tokenizer = pp_tokenizer
        pp_clip.patcher.model.clip_l.transformer = pp_text_encoder
        return (pp_clip,)
def powerpaint_model_update(self, model, vae, image, mask, powerpaint, clip, positive, negative, fitting, function, scale, start_at, end_at, save_memory):
    """Clone `model` and attach a PowerPaint inpainting patch to it.

    Selects a pair of placeholder prompts (A, B) based on the PowerPaint
    `function`, blends their CLIP embeddings with weight `fitting`, encodes
    the masked image into conditioning latents, and wires everything into
    the cloned model via add_brushnet_patch.

    Returns:
        (model, positive, negative, {"samples": latent}) where `latent` is a
        zero tensor sized to the conditioning latents.
    """
    is_SDXL, is_PP = self.check_compatibilty(model, powerpaint)
    if not is_PP:
        raise Exception("BrushNet model was loaded, please use BrushNet node")
    # Make a copy of the model so that we're not patching it everywhere in the workflow.
    model = model.clone()
    # prepare image and mask
    # no batches for original image and mask
    masked_image, mask = self.prepare_image(image, mask)
    batch = masked_image.shape[0]
    # width = masked_image.shape[2]
    # height = masked_image.shape[1]
    # Use the model's own VAE scale factor when available; otherwise fall back
    # to the SD1.5 constant defined elsewhere in this module.
    if hasattr(model.model.model_config, 'latent_format') and hasattr(model.model.model_config.latent_format,
                                                                      'scale_factor'):
        scaling_factor = model.model.model_config.latent_format.scale_factor
    else:
        scaling_factor = sd15_scaling_factor
    torch_dtype = powerpaint['dtype']
    # prepare conditioning latents
    conditioning_latents = self.get_image_latents(masked_image, mask, vae, scaling_factor)
    conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
    conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
    # prepare embeddings
    # Each PowerPaint task maps to (promptA, promptB) placeholder tokens;
    # A and B are blended below with the `fitting` weight.
    if function == "object removal":
        promptA = "P_ctxt"
        promptB = "P_ctxt"
        negative_promptA = "P_obj"
        negative_promptB = "P_obj"
        print('You should add to positive prompt: "empty scene blur"')
        # positive = positive + " empty scene blur"
    elif function == "context aware":
        promptA = "P_ctxt"
        promptB = "P_ctxt"
        negative_promptA = ""
        negative_promptB = ""
        # positive = positive + " empty scene"
        print('You should add to positive prompt: "empty scene"')
    elif function == "shape guided":
        promptA = "P_shape"
        promptB = "P_ctxt"
        negative_promptA = "P_shape"
        negative_promptB = "P_ctxt"
    elif function == "image outpainting":
        promptA = "P_ctxt"
        promptB = "P_ctxt"
        negative_promptA = "P_obj"
        negative_promptB = "P_obj"
        # positive = positive + " empty scene"
        print('You should add to positive prompt: "empty scene"')
    else:
        # default task: text-guided object insertion
        promptA = "P_obj"
        promptB = "P_obj"
        negative_promptA = "P_obj"
        negative_promptB = "P_obj"
    tokens = clip.tokenize(promptA)
    prompt_embedsA = clip.encode_from_tokens(tokens, return_pooled=False)
    tokens = clip.tokenize(negative_promptA)
    negative_prompt_embedsA = clip.encode_from_tokens(tokens, return_pooled=False)
    tokens = clip.tokenize(promptB)
    prompt_embedsB = clip.encode_from_tokens(tokens, return_pooled=False)
    tokens = clip.tokenize(negative_promptB)
    negative_prompt_embedsB = clip.encode_from_tokens(tokens, return_pooled=False)
    # Linear blend between the A and B task embeddings.
    prompt_embeds_pp = (prompt_embedsA * fitting + (1.0 - fitting) * prompt_embedsB).to(dtype=torch_dtype).to(
        powerpaint['brushnet'].device)
    negative_prompt_embeds_pp = (negative_prompt_embedsA * fitting + (1.0 - fitting) * negative_prompt_embedsB).to(
        dtype=torch_dtype).to(powerpaint['brushnet'].device)
    # unload vae and CLIPs
    del vae
    del clip
    # for loaded_model in comfy.model_management.current_loaded_models:
    #     if type(loaded_model.model.model) in ModelsToUnload:
    #         comfy.model_management.current_loaded_models.remove(loaded_model)
    #         loaded_model.model_unload()
    #         del loaded_model
    # apply patch to model
    brushnet_conditioning_scale = scale
    control_guidance_start = start_at
    control_guidance_end = end_at
    if save_memory != 'none':
        powerpaint['brushnet'].set_attention_slice(save_memory)
    # NOTE(review): the negative embeds are passed in the positional slot that
    # add_brushnet_patch names `prompt_embeds` (and vice versa). This swap
    # appears intentional for PowerPaint — confirm against brushnet_inference.
    add_brushnet_patch(model,
                       powerpaint['brushnet'],
                       torch_dtype,
                       conditioning_latents,
                       (brushnet_conditioning_scale, control_guidance_start, control_guidance_end),
                       negative_prompt_embeds_pp, prompt_embeds_pp,
                       None, None, None,
                       False)
    latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]],
                         device=powerpaint['brushnet'].device)
    return (model, positive, negative, {"samples": latent},)
@torch.inference_mode()
def brushnet_inference(x, timesteps, transformer_options, debug):
    """Run one BrushNet forward pass for the current sampler step.

    Reads the BrushNet state stored under
    transformer_options['model_patch']['brushnet'], assembles the batch of
    conditioning latents + prompt embeddings matching the incoming latents
    (handling classifier-free guidance and AnimateDiff sub-indexes), and
    calls the BrushNet model.

    Returns:
        (input_samples, mid_sample, output_samples): residuals to add to the
        UNet blocks, or ([], 0, []) when the patch data is missing/invalid.
    """
    if 'model_patch' not in transformer_options:
        print('BrushNet inference: there is no model_patch key in transformer_options')
        return ([], 0, [])
    mp = transformer_options['model_patch']
    if 'brushnet' not in mp:
        # (fixed log typo: previously printed 'mdel_patch')
        print('BrushNet inference: there is no brushnet key in model_patch')
        return ([], 0, [])
    bo = mp['brushnet']
    if 'model' not in bo:
        print('BrushNet inference: there is no model key in brushnet')
        return ([], 0, [])
    brushnet = bo['model']
    if not (isinstance(brushnet, BrushNetModel) or isinstance(brushnet, PowerPaintModel)):
        print('BrushNet model is not a BrushNetModel class')
        return ([], 0, [])

    # Unpack the state stored by add_brushnet_patch.
    torch_dtype = bo['dtype']
    cl_list = bo['latents']
    brushnet_conditioning_scale, control_guidance_start, control_guidance_end = bo['controls']
    pe = bo['prompt_embeds']
    npe = bo['negative_prompt_embeds']
    ppe, nppe, time_ids = bo['add_embeds']

    #do_classifier_free_guidance = mp['free_guidance']
    # Two conditioning groups in one call means cond+uncond are batched.
    do_classifier_free_guidance = len(transformer_options['cond_or_uncond']) > 1

    # Work on detached copies moved to the BrushNet device/dtype.
    x = x.detach().clone()
    x = x.to(torch_dtype).to(brushnet.device)
    timesteps = timesteps.detach().clone()
    timesteps = timesteps.to(torch_dtype).to(brushnet.device)

    step = mp['step']

    added_cond_kwargs = {}
    if do_classifier_free_guidance and step == 0:
        print('BrushNet inference: do_classifier_free_guidance is True')

    sub_idx = None
    if 'ad_params' in transformer_options and 'sub_idxs' in transformer_options['ad_params']:
        sub_idx = transformer_options['ad_params']['sub_idxs']

    # we have batch input images
    batch = cl_list[0].shape[0]
    # we have incoming latents
    latents_incoming = x.shape[0]
    # and we already got some
    latents_got = bo['latent_id']
    if step == 0 or batch > 1:
        print('BrushNet inference, step = %d: image batch = %d, got %d latents, starting from %d'
              % (step, batch, latents_incoming, latents_got))

    image_latents = []
    masks = []
    prompt_embeds = []
    negative_prompt_embeds = []
    pooled_prompt_embeds = []
    negative_pooled_prompt_embeds = []
    if sub_idx:
        # AnimateDiff indexes detected: select the frames for this window.
        if step == 0:
            print('BrushNet inference: AnimateDiff indexes detected and applied')
        batch = len(sub_idx)
        if do_classifier_free_guidance:
            # cond half: latents + positive embeds per frame...
            for i in sub_idx:
                image_latents.append(cl_list[0][i][None,:,:,:])
                masks.append(cl_list[1][i][None,:,:,:])
                prompt_embeds.append(pe)
                negative_prompt_embeds.append(npe)
                pooled_prompt_embeds.append(ppe)
                negative_pooled_prompt_embeds.append(nppe)
            # ...uncond half: the same latents again (negative embeds were
            # already collected above).
            for i in sub_idx:
                image_latents.append(cl_list[0][i][None,:,:,:])
                masks.append(cl_list[1][i][None,:,:,:])
        else:
            for i in sub_idx:
                image_latents.append(cl_list[0][i][None,:,:,:])
                masks.append(cl_list[1][i][None,:,:,:])
                prompt_embeds.append(pe)
                pooled_prompt_embeds.append(ppe)
    else:
        # do_classifier_free_guidance = 2 passes, 1st pass is cond, 2nd is uncond
        continue_batch = True
        for i in range(latents_incoming):
            number = latents_got + i
            if number < batch:
                # 1st pass, cond
                image_latents.append(cl_list[0][number][None,:,:,:])
                masks.append(cl_list[1][number][None,:,:,:])
                prompt_embeds.append(pe)
                pooled_prompt_embeds.append(ppe)
            elif do_classifier_free_guidance and number < batch * 2:
                # 2nd pass, uncond
                image_latents.append(cl_list[0][number-batch][None,:,:,:])
                masks.append(cl_list[1][number-batch][None,:,:,:])
                negative_prompt_embeds.append(npe)
                negative_pooled_prompt_embeds.append(nppe)
            else:
                # latent batch: more latents than images — reuse image 0.
                image_latents.append(cl_list[0][0][None,:,:,:])
                masks.append(cl_list[1][0][None,:,:,:])
                prompt_embeds.append(pe)
                pooled_prompt_embeds.append(ppe)
                # NOTE(review): `-i` (not `-1`) mirrors the upstream code;
                # its exact intent for the restart offset is unclear — verify.
                latents_got = -i
                continue_batch = False
        if continue_batch:
            # we don't have full batch yet: remember where to resume next call
            if do_classifier_free_guidance:
                if number < batch * 2 - 1:
                    bo['latent_id'] = number + 1
                else:
                    bo['latent_id'] = 0
            else:
                if number < batch - 1:
                    bo['latent_id'] = number + 1
                else:
                    bo['latent_id'] = 0
        else:
            bo['latent_id'] = 0

    # Stack [latent, mask] per sample into the BrushNet conditioning input.
    cl = []
    for il, m in zip(image_latents, masks):
        cl.append(torch.concat([il, m], dim=1))
    cl2apply = torch.concat(cl, dim=0)
    conditioning_latents = cl2apply.to(torch_dtype).to(brushnet.device)

    # Positive embeds first, then negative — same order as the latents above.
    prompt_embeds.extend(negative_prompt_embeds)
    prompt_embeds = torch.concat(prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)

    if ppe is not None:
        # SDXL path: pooled embeds and time ids go through added_cond_kwargs.
        added_cond_kwargs = {}
        added_cond_kwargs['time_ids'] = torch.concat([time_ids] * latents_incoming, dim = 0).to(torch_dtype).to(brushnet.device)
        pooled_prompt_embeds.extend(negative_pooled_prompt_embeds)
        pooled_prompt_embeds = torch.concat(pooled_prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)
        added_cond_kwargs['text_embeds'] = pooled_prompt_embeds
    else:
        added_cond_kwargs = None

    # RAUNet and similar tricks may change spatial size: resize conditioning.
    if x.shape[2] != conditioning_latents.shape[2] or x.shape[3] != conditioning_latents.shape[3]:
        if step == 0:
            print('BrushNet inference: image', conditioning_latents.shape, 'and latent', x.shape, 'have different size, resizing image')
        conditioning_latents = torch.nn.functional.interpolate(
            conditioning_latents, size=(
                x.shape[2],
                x.shape[3],
            ), mode='bicubic',
        ).to(torch_dtype).to(brushnet.device)

    if step == 0:
        print('BrushNet inference: sample', x.shape, ', CL', conditioning_latents.shape, 'dtype', torch_dtype)
    if debug: print('BrushNet: step =', step)

    # Zero out the control outside the [start, end] step window.
    if step < control_guidance_start or step > control_guidance_end:
        cond_scale = 0.0
    else:
        cond_scale = brushnet_conditioning_scale

    return brushnet(x,
                    encoder_hidden_states=prompt_embeds,
                    brushnet_cond=conditioning_latents,
                    timestep = timesteps,
                    conditioning_scale=cond_scale,
                    guess_mode=False,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                    debug=debug,
                    )
def add_brushnet_patch(model, brushnet, torch_dtype, conditioning_latents,
                       controls,
                       prompt_embeds, negative_prompt_embeds,
                       pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids,
                       debug):
    """Wire a BrushNet/PowerPaint model into a (cloned) ComfyUI model.

    Stores inference state under transformer_options['model_patch']['brushnet'],
    registers `brushnet_forward` via patch_model_function_wrapper, and wraps
    the forward of every UNet layer so BrushNet residuals (`add_sample_after`)
    are added to the block outputs during sampling.

    Args:
        model: ComfyUI ModelPatcher (will be mutated in place).
        brushnet: BrushNetModel / PowerPaintModel instance.
        torch_dtype: dtype used for BrushNet inference.
        conditioning_latents: [latents, masks] tensors from the VAE encode.
        controls: (conditioning_scale, guidance_start_step, guidance_end_step).
        prompt_embeds / negative_prompt_embeds: CLIP embeddings.
        pooled_prompt_embeds / negative_pooled_prompt_embeds / time_ids:
            SDXL-only additional embeddings (None for SD1.5).
        debug: verbose logging flag, forwarded to brushnet_inference.
    """
    is_SDXL = isinstance(model.model.model_config, comfy.supported_models.SDXL)

    if model.model.model_config.custom_operations is None:
        fp8 = model.model.model_config.optimizations.get("fp8", model.model.model_config.scaled_fp8 is not None)
        operations = comfy.ops.pick_operations(model.model.model_config.unet_config.get("dtype", None),
                                               model.model.manual_cast_dtype,
                                               fp8_optimizations=fp8,
                                               scaled_fp8=model.model.model_config.scaled_fp8)
    else:
        # such as gguf
        operations = model.model.model_config.custom_operations

    # Per-block (index, layer type) targets: the BrushNet residual is added
    # after the *last* layer of that type in the indexed UNet block.
    if is_SDXL:
        input_blocks = [[0, operations.Conv2d],
                        [1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
                        [2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
                        [3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
                        [4, comfy.ldm.modules.attention.SpatialTransformer],
                        [5, comfy.ldm.modules.attention.SpatialTransformer],
                        [6, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
                        [7, comfy.ldm.modules.attention.SpatialTransformer],
                        [8, comfy.ldm.modules.attention.SpatialTransformer]]
        middle_block = [0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]
        output_blocks = [[0, comfy.ldm.modules.attention.SpatialTransformer],
                         [1, comfy.ldm.modules.attention.SpatialTransformer],
                         [2, comfy.ldm.modules.attention.SpatialTransformer],
                         [2, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
                         [3, comfy.ldm.modules.attention.SpatialTransformer],
                         [4, comfy.ldm.modules.attention.SpatialTransformer],
                         [5, comfy.ldm.modules.attention.SpatialTransformer],
                         [5, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
                         [6, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
                         [7, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
                         [8, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]]
    else:
        input_blocks = [[0, operations.Conv2d],
                        [1, comfy.ldm.modules.attention.SpatialTransformer],
                        [2, comfy.ldm.modules.attention.SpatialTransformer],
                        [3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
                        [4, comfy.ldm.modules.attention.SpatialTransformer],
                        [5, comfy.ldm.modules.attention.SpatialTransformer],
                        [6, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
                        [7, comfy.ldm.modules.attention.SpatialTransformer],
                        [8, comfy.ldm.modules.attention.SpatialTransformer],
                        [9, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
                        [10, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
                        [11, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]]
        middle_block = [0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]
        output_blocks = [[0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
                         [1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
                         [2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
                         [2, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
                         [3, comfy.ldm.modules.attention.SpatialTransformer],
                         [4, comfy.ldm.modules.attention.SpatialTransformer],
                         [5, comfy.ldm.modules.attention.SpatialTransformer],
                         [5, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
                         [6, comfy.ldm.modules.attention.SpatialTransformer],
                         [7, comfy.ldm.modules.attention.SpatialTransformer],
                         [8, comfy.ldm.modules.attention.SpatialTransformer],
                         [8, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
                         [9, comfy.ldm.modules.attention.SpatialTransformer],
                         [10, comfy.ldm.modules.attention.SpatialTransformer],
                         [11, comfy.ldm.modules.attention.SpatialTransformer]]

    def last_layer_index(block, tp):
        """Return (index of the last layer of type `tp` in `block`, type list).

        Index is -1 when the type is absent; the type list is returned for
        diagnostics.
        """
        layer_list = []
        for layer in block:
            layer_list.append(type(layer))
        layer_list.reverse()
        if tp not in layer_list:
            # Bug fix: the original returned `layer_list.reverse()`, which is
            # always None, so the diagnostic print showed "None". Restore the
            # forward order and return the actual list.
            layer_list.reverse()
            return -1, layer_list
        return len(layer_list) - 1 - layer_list.index(tp), layer_list

    def brushnet_forward(model, x, timesteps, transformer_options, control):
        # Runs once per apply_model call (registered via
        # patch_model_function_wrapper); computes residuals and distributes
        # them onto the patched UNet layers.
        if 'brushnet' not in transformer_options['model_patch']:
            input_samples = []
            mid_sample = 0
            output_samples = []
        else:
            # brushnet inference
            input_samples, mid_sample, output_samples = brushnet_inference(x, timesteps, transformer_options, debug)
        # give additional samples to blocks
        for i, tp in input_blocks:
            idx, layer_list = last_layer_index(model.input_blocks[i], tp)
            if idx < 0:
                print("BrushNet can't find", tp, "layer in", i, "input block:", layer_list)
                continue
            model.input_blocks[i][idx].add_sample_after = input_samples.pop(0) if input_samples else 0
        idx, layer_list = last_layer_index(model.middle_block, middle_block[1])
        if idx < 0:
            print("BrushNet can't find", middle_block[1], "layer in middle block", layer_list)
        else:
            # Bug fix: previously this assignment also ran when idx == -1,
            # silently attaching the sample to the *last* middle-block layer.
            model.middle_block[idx].add_sample_after = mid_sample
        for i, tp in output_blocks:
            idx, layer_list = last_layer_index(model.output_blocks[i], tp)
            if idx < 0:
                # (fixed "outnput" typo in this diagnostic)
                print("BrushNet can't find", tp, "layer in", i, "output block:", layer_list)
                continue
            model.output_blocks[i][idx].add_sample_after = output_samples.pop(0) if output_samples else 0

    patch_model_function_wrapper(model, brushnet_forward)
    to = add_model_patch_option(model)
    mp = to['model_patch']
    if 'brushnet' not in mp:
        mp['brushnet'] = {}
    bo = mp['brushnet']
    bo['model'] = brushnet
    bo['dtype'] = torch_dtype
    bo['latents'] = conditioning_latents
    bo['controls'] = controls
    bo['prompt_embeds'] = prompt_embeds
    bo['negative_prompt_embeds'] = negative_prompt_embeds
    bo['add_embeds'] = (pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids)
    bo['latent_id'] = 0

    # patch layers `forward` so we can apply brushnet
    def forward_patched_by_brushnet(self, x, *args, **kwargs):
        h = self.original_forward(x, *args, **kwargs)
        # (dropped the original's `and type(self)` — always truthy)
        if hasattr(self, 'add_sample_after'):
            to_add = self.add_sample_after
            if torch.is_tensor(to_add):
                # interpolate due to RAUNet
                if h.shape[2] != to_add.shape[2] or h.shape[3] != to_add.shape[3]:
                    to_add = torch.nn.functional.interpolate(to_add, size=(h.shape[2], h.shape[3]), mode='bicubic')
                h += to_add.to(h.dtype).to(h.device)
            else:
                h += self.add_sample_after
            # consume the residual so it is applied at most once
            self.add_sample_after = 0
        return h

    for i, block in enumerate(model.model.diffusion_model.input_blocks):
        for j, layer in enumerate(block):
            if not hasattr(layer, 'original_forward'):
                layer.original_forward = layer.forward
            layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
            layer.add_sample_after = 0
    for j, layer in enumerate(model.model.diffusion_model.middle_block):
        if not hasattr(layer, 'original_forward'):
            layer.original_forward = layer.forward
        layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
        layer.add_sample_after = 0
    for i, block in enumerate(model.model.diffusion_model.output_blocks):
        for j, layer in enumerate(block):
            if not hasattr(layer, 'original_forward'):
                layer.original_forward = layer.forward
            layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
            layer.add_sample_after = 0

View File

@@ -0,0 +1,58 @@
{
"_class_name": "BrushNetModel",
"_diffusers_version": "0.27.0.dev0",
"_name_or_path": "runs/logs/brushnet_randommask/checkpoint-100000",
"act_fn": "silu",
"addition_embed_type": null,
"addition_embed_type_num_heads": 64,
"addition_time_embed_dim": null,
"attention_head_dim": 8,
"block_out_channels": [
320,
640,
1280,
1280
],
"brushnet_conditioning_channel_order": "rgb",
"class_embed_type": null,
"conditioning_channels": 5,
"conditioning_embedding_out_channels": [
16,
32,
96,
256
],
"cross_attention_dim": 768,
"down_block_types": [
"DownBlock2D",
"DownBlock2D",
"DownBlock2D",
"DownBlock2D"
],
"downsample_padding": 1,
"encoder_hid_dim": null,
"encoder_hid_dim_type": null,
"flip_sin_to_cos": true,
"freq_shift": 0,
"global_pool_conditions": false,
"in_channels": 4,
"layers_per_block": 2,
"mid_block_scale_factor": 1,
"mid_block_type": "MidBlock2D",
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_attention_heads": null,
"num_class_embeds": null,
"only_cross_attention": false,
"projection_class_embeddings_input_dim": null,
"resnet_time_scale_shift": "default",
"transformer_layers_per_block": 1,
"up_block_types": [
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
"UpBlock2D"
],
"upcast_attention": false,
"use_linear_projection": false
}

View File

@@ -0,0 +1,63 @@
{
"_class_name": "BrushNetModel",
"_diffusers_version": "0.27.0.dev0",
"_name_or_path": "runs/logs/brushnetsdxl_randommask/checkpoint-80000",
"act_fn": "silu",
"addition_embed_type": "text_time",
"addition_embed_type_num_heads": 64,
"addition_time_embed_dim": 256,
"attention_head_dim": [
5,
10,
20
],
"block_out_channels": [
320,
640,
1280
],
"brushnet_conditioning_channel_order": "rgb",
"class_embed_type": null,
"conditioning_channels": 5,
"conditioning_embedding_out_channels": [
16,
32,
96,
256
],
"cross_attention_dim": 2048,
"down_block_types": [
"DownBlock2D",
"DownBlock2D",
"DownBlock2D"
],
"downsample_padding": 1,
"encoder_hid_dim": null,
"encoder_hid_dim_type": null,
"flip_sin_to_cos": true,
"freq_shift": 0,
"global_pool_conditions": false,
"in_channels": 4,
"layers_per_block": 2,
"mid_block_scale_factor": 1,
"mid_block_type": "MidBlock2D",
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_attention_heads": null,
"num_class_embeds": null,
"only_cross_attention": false,
"projection_class_embeddings_input_dim": 2816,
"resnet_time_scale_shift": "default",
"transformer_layers_per_block": [
1,
2,
10
],
"up_block_types": [
"UpBlock2D",
"UpBlock2D",
"UpBlock2D"
],
"upcast_attention": null,
"use_linear_projection": true
}

View File

@@ -0,0 +1,57 @@
{
"_class_name": "BrushNetModel",
"_diffusers_version": "0.27.2",
"act_fn": "silu",
"addition_embed_type": null,
"addition_embed_type_num_heads": 64,
"addition_time_embed_dim": null,
"attention_head_dim": 8,
"block_out_channels": [
320,
640,
1280,
1280
],
"brushnet_conditioning_channel_order": "rgb",
"class_embed_type": null,
"conditioning_channels": 5,
"conditioning_embedding_out_channels": [
16,
32,
96,
256
],
"cross_attention_dim": 768,
"down_block_types": [
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D"
],
"downsample_padding": 1,
"encoder_hid_dim": null,
"encoder_hid_dim_type": null,
"flip_sin_to_cos": true,
"freq_shift": 0,
"global_pool_conditions": false,
"in_channels": 4,
"layers_per_block": 2,
"mid_block_scale_factor": 1,
"mid_block_type": "UNetMidBlock2DCrossAttn",
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_attention_heads": null,
"num_class_embeds": null,
"only_cross_attention": false,
"projection_class_embeddings_input_dim": null,
"resnet_time_scale_shift": "default",
"transformer_layers_per_block": 1,
"up_block_types": [
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D"
],
"upcast_attention": false,
"use_linear_projection": false
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,137 @@
import torch
import comfy
# Check and add 'model_patch' to model.model_options['transformer_options']
def add_model_patch_option(model):
    """Ensure model.model_options['transformer_options']['model_patch'] exists.

    Returns:
        The `transformer_options` dict (with a 'model_patch' sub-dict
        guaranteed to be present).
    """
    opts = model.model_options.setdefault('transformer_options', {})
    opts.setdefault("model_patch", {})
    return opts
# Patch model with model_function_wrapper
def patch_model_function_wrapper(model, forward_patch, remove=False):
    """Install (or remove) a BrushNet forward patch on `model`.

    Registers a model_function_wrapper that, on every apply_model call,
    recovers the current sampler step from the stored sigma schedule and runs
    every registered patch callback before the normal UNet forward. Also
    monkey-patches comfy.samplers.sample and openaimodel.apply_control (once)
    so the sigma schedule gets recorded and controls can be resized.

    Args:
        model: ComfyUI ModelPatcher to patch.
        forward_patch: callback (unet, xc, t, transformer_options, control).
        remove: if True, unregister `forward_patch` instead of adding it.

    Raises:
        Exception: if the base model is neither SD1.5 nor SDXL.
    """
    def brushnet_model_function_wrapper(apply_model_method, options_dict):
        # Runs instead of the plain apply_model; options_dict carries
        # 'input', 'timestep' and the conditioning dict 'c'.
        to = options_dict['c']['transformer_options']
        control = None
        if 'control' in options_dict['c']:
            control = options_dict['c']['control']
        x = options_dict['input']
        timestep = options_dict['timestep']
        # check if there are patches to execute
        if 'model_patch' not in to or 'forward' not in to['model_patch']:
            return apply_model_method(x, timestep, **options_dict['c'])
        mp = to['model_patch']
        unet = mp['unet']
        # Recover the current step index by matching the current sigma
        # against the full schedule stored by modified_sample.
        all_sigmas = mp['all_sigmas']
        sigma = to['sigmas'][0].item()
        total_steps = all_sigmas.shape[0] - 1
        step = torch.argmin((all_sigmas - sigma).abs()).item()
        mp['step'] = step
        mp['total_steps'] = total_steps
        # comfy.model_base.apply_model — rebuild the UNet input exactly as
        # the sampler would, so patches see the same xc/t the UNet will get.
        xc = model.model.model_sampling.calculate_input(timestep, x)
        if 'c_concat' in options_dict['c'] and options_dict['c']['c_concat'] is not None:
            xc = torch.cat([xc] + [options_dict['c']['c_concat']], dim=1)
        t = model.model.model_sampling.timestep(timestep).float()
        # execute all patches
        for method in mp['forward']:
            method(unet, xc, t, to, control)
        return apply_model_method(x, timestep, **options_dict['c'])
    if "model_function_wrapper" in model.model_options and model.model_options["model_function_wrapper"]:
        # NOTE(review): any pre-existing wrapper is replaced, not chained.
        print('BrushNet is going to replace existing model_function_wrapper:',
              model.model_options["model_function_wrapper"])
    model.set_model_unet_function_wrapper(brushnet_model_function_wrapper)
    to = add_model_patch_option(model)
    mp = to['model_patch']
    if isinstance(model.model.model_config, comfy.supported_models.SD15):
        mp['SDXL'] = False
    elif isinstance(model.model.model_config, comfy.supported_models.SDXL):
        mp['SDXL'] = True
    else:
        print('Base model type: ', type(model.model.model_config))
        raise Exception("Unsupported model type: ", type(model.model.model_config))
    if 'forward' not in mp:
        mp['forward'] = []
    if remove:
        if forward_patch in mp['forward']:
            mp['forward'].remove(forward_patch)
    else:
        mp['forward'].append(forward_patch)
    mp['unet'] = model.model.diffusion_model
    mp['step'] = 0
    mp['total_steps'] = 1
    # apply patches to code
    # The 'BrushNet' marker in the replacement docstrings makes this idempotent.
    if comfy.samplers.sample.__doc__ is None or 'BrushNet' not in comfy.samplers.sample.__doc__:
        comfy.samplers.original_sample = comfy.samplers.sample
        comfy.samplers.sample = modified_sample
    if comfy.ldm.modules.diffusionmodules.openaimodel.apply_control.__doc__ is None or \
            'BrushNet' not in comfy.ldm.modules.diffusionmodules.openaimodel.apply_control.__doc__:
        comfy.ldm.modules.diffusionmodules.openaimodel.original_apply_control = comfy.ldm.modules.diffusionmodules.openaimodel.apply_control
        comfy.ldm.modules.diffusionmodules.openaimodel.apply_control = modified_apply_control
# Model needs current step number and cfg at inference step. It is possible to write a custom KSampler but I'd like to use ComfyUI's one.
# The first versions had modified_common_ksampler, but it broke custom KSampler nodes
def modified_sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={},
                    latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
    """Drop-in replacement for comfy.samplers.sample installed by BrushNet nodes.

    Identical to the stock sampler except that the full sigma schedule is
    stashed in the model's patch options so brushnet_inference can recover
    the current step index at each call. (The 'BrushNet' marker in this
    docstring is what prevents double-patching.)
    """
    guider = comfy.samplers.CFGGuider(model)
    guider.set_conds(positive, negative)
    guider.set_cfg(cfg)

    # --- BrushNet addition: record the whole schedule for step lookup ------
    patch_options = add_model_patch_option(model)
    patch_options['model_patch']['all_sigmas'] = sigmas
    # -----------------------------------------------------------------------

    return guider.sample(noise, latent_image, sampler, sigmas, denoise_mask,
                         callback, disable_pbar, seed)
# To use Controlnet with RAUNet it is much easier to modify apply_control a little
def modified_apply_control(h, control, name):
    """Modified by BrushNet nodes: apply_control with spatial resizing.

    Pops the next control residual queued under `name` and adds it to `h`,
    bicubically resizing the residual first when its spatial size differs
    (needed when RAUNet changes feature-map sizes mid-network).

    Args:
        h: feature tensor (N, C, H, W).
        control: dict of residual lists keyed by block name, or None.
        name: which list to pop from (e.g. 'input'/'middle'/'output').

    Returns:
        `h` with the residual added in place when one was available.
    """
    if control is not None and name in control and len(control[name]) > 0:
        ctrl = control[name].pop()
        if ctrl is not None:
            if h.shape[2] != ctrl.shape[2] or h.shape[3] != ctrl.shape[3]:
                ctrl = torch.nn.functional.interpolate(ctrl, size=(h.shape[2], h.shape[3]), mode='bicubic').to(
                    h.dtype).to(h.device)
            try:
                h += ctrl
            except Exception:
                # Bug fix: the original called `print.warning(...)`, which
                # raised AttributeError instead of warning. The add stays
                # best-effort: incompatible shapes are reported, not fatal.
                print("warning control could not be applied {} {}".format(h.shape, ctrl.shape))
    return h
def add_model_patch(model):
    """Refresh BrushNet bookkeeping on `model` when a brushnet patch exists.

    Re-detects the base model family (SD1.5 vs SDXL), re-binds the UNet and
    resets the step counters. No-op when no 'brushnet' entry is present.

    Raises:
        Exception: if the base model is neither SD1.5 nor SDXL.
    """
    patch = add_model_patch_option(model)['model_patch']
    if "brushnet" not in patch:
        return
    config = model.model.model_config
    if isinstance(config, comfy.supported_models.SD15):
        patch['SDXL'] = False
    elif isinstance(config, comfy.supported_models.SDXL):
        patch['SDXL'] = True
    else:
        print('Base model type: ', type(config))
        raise Exception("Unsupported model type: ", type(config))
    patch['unet'] = model.model.diffusion_model
    patch['step'] = 0
    patch['total_steps'] = 1

View File

@@ -0,0 +1,467 @@
import copy
import random
import torch
import torch.nn as nn
from transformers import CLIPTokenizer
from typing import Any, List, Optional, Union
class TokenizerWrapper:
    """Wrapper around a `CLIPTokenizer` adding multi-vector placeholder tokens.

    Modified from
    https://github.com/huggingface/diffusers/blob/e51f19aee82c8dd874b715a09dbc521d88835d68/src/diffusers/loaders.py#L358  # noqa

    A "placeholder token" (e.g. ``P_obj``) is expanded into several concrete
    tokenizer tokens (``P_obj_0 ... P_obj_{n-1}``) before tokenization and
    collapsed back on decode. Any attribute not found on the wrapper is
    delegated to the wrapped tokenizer via ``__getattr__``.

    Args:
        tokenizer (CLIPTokenizer): The tokenizer to wrap.
    """

    def __init__(self, tokenizer: CLIPTokenizer):
        # The underlying tokenizer; unknown attribute access is forwarded to it.
        self.wrapped = tokenizer
        # Maps placeholder token -> list of concrete tokens it expands to.
        self.token_map = {}

    def __getattr__(self, name: str) -> Any:
        """Delegate unknown attribute access to the wrapped tokenizer."""
        # NOTE(review): __getattr__ is only invoked after normal attribute
        # lookup has already failed, so this branch is effectively dead code.
        if name in self.__dict__:
            return getattr(self, name)
        # if name == "wrapped":
        #     return getattr(self, 'wrapped')#super().__getattr__("wrapped")
        try:
            return getattr(self.wrapped, name)
        except AttributeError:
            # NOTE(review): 'name' is a literal here (missing f-string
            # interpolation); message is runtime behavior, left unchanged.
            raise AttributeError(
                "'name' cannot be found in both "
                f"'{self.__class__.__name__}' and "
                f"'{self.__class__.__name__}.tokenizer'."
            )

    def try_adding_tokens(self, tokens: Union[str, List[str]], *args, **kwargs):
        """Add tokens to the wrapped tokenizer, asserting they are new.

        Args:
            tokens (Union[str, List[str]]): The tokens to be added.

        Raises:
            AssertionError: If the tokenizer already contained the token(s).
        """
        num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs)
        assert num_added_tokens != 0, (
            f"The tokenizer already contains the token {tokens}. Please pass "
            "a different `placeholder_token` that is not already in the "
            "tokenizer."
        )

    def get_token_info(self, token: str) -> dict:
        """Get the information of a token, including its start and end index in
        the current tokenizer.

        Args:
            token (str): The token to be queried.

        Returns:
            dict: The information of the token, including its start and end
                index in current tokenizer.
        """
        token_ids = self.__call__(token).input_ids
        # start/end are taken from the token *ids* at positions 1 and -2
        # (skipping BOS/EOS); assumes contiguous ids for multi-vector tokens —
        # TODO(review): confirm the intended semantics.
        start, end = token_ids[1], token_ids[-2] + 1
        return {"name": token, "start": start, "end": end}

    def add_placeholder_token(self, placeholder_token: str, *args, num_vec_per_token: int = 1, **kwargs):
        """Add placeholder tokens to the tokenizer.

        Args:
            placeholder_token (str): The placeholder token to be added.
            num_vec_per_token (int, optional): The number of vectors of
                the added placeholder token.
            *args, **kwargs: The arguments for `self.wrapped.add_tokens`.

        Raises:
            ValueError: If an existing placeholder is a substring of the new
                one (the expansions could then get confused).
        """
        output = []
        if num_vec_per_token == 1:
            self.try_adding_tokens(placeholder_token, *args, **kwargs)
            output.append(placeholder_token)
        else:
            output = []
            # Expand into numbered sub-tokens: "<tok>_0", "<tok>_1", ...
            for i in range(num_vec_per_token):
                ith_token = placeholder_token + f"_{i}"
                self.try_adding_tokens(ith_token, *args, **kwargs)
                output.append(ith_token)
        # Substring check against existing placeholders; NOTE(review): this
        # runs after the tokens were already added to the wrapped tokenizer.
        for token in self.token_map:
            if token in placeholder_token:
                raise ValueError(
                    f"The tokenizer already has placeholder token {token} "
                    f"that can get confused with {placeholder_token} "
                    "keep placeholder tokens independent"
                )
        self.token_map[placeholder_token] = output

    def replace_placeholder_tokens_in_text(
        self, text: Union[str, List[str]], vector_shuffle: bool = False, prop_tokens_to_load: float = 1.0
    ) -> Union[str, List[str]]:
        """Replace the keywords in text with placeholder tokens. This function
        will be called in `self.__call__` and `self.encode`.

        Args:
            text (Union[str, List[str]]): The text to be processed.
            vector_shuffle (bool, optional): Whether to shuffle the vectors.
                Defaults to False.
            prop_tokens_to_load (float, optional): The proportion of tokens to
                be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.

        Returns:
            Union[str, List[str]]: The processed text.
        """
        if isinstance(text, list):
            output = []
            for i in range(len(text)):
                # NOTE(review): the recursive call drops prop_tokens_to_load
                # (list items always use the default 1.0) — verify if intended.
                output.append(self.replace_placeholder_tokens_in_text(text[i], vector_shuffle=vector_shuffle))
            return output
        for placeholder_token in self.token_map:
            if placeholder_token in text:
                tokens = self.token_map[placeholder_token]
                # Keep a leading fraction of the expansion vectors.
                tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
                if vector_shuffle:
                    tokens = copy.copy(tokens)
                    random.shuffle(tokens)
                text = text.replace(placeholder_token, " ".join(tokens))
        return text

    def replace_text_with_placeholder_tokens(self, text: Union[str, List[str]]) -> Union[str, List[str]]:
        """Replace the placeholder tokens in text with the original keywords.
        This function will be called in `self.decode`.

        Args:
            text (Union[str, List[str]]): The text to be processed.

        Returns:
            Union[str, List[str]]: The processed text.
        """
        if isinstance(text, list):
            output = []
            for i in range(len(text)):
                output.append(self.replace_text_with_placeholder_tokens(text[i]))
            return output
        for placeholder_token, tokens in self.token_map.items():
            merged_tokens = " ".join(tokens)
            if merged_tokens in text:
                text = text.replace(merged_tokens, placeholder_token)
        return text

    def __call__(
        self,
        text: Union[str, List[str]],
        *args,
        vector_shuffle: bool = False,
        prop_tokens_to_load: float = 1.0,
        **kwargs,
    ):
        """Tokenize `text` after expanding placeholder tokens.

        Args:
            text (Union[str, List[str]]): The text to be tokenized.
            vector_shuffle (bool, optional): Whether to shuffle the vectors.
                Defaults to False.
            prop_tokens_to_load (float, optional): The proportion of tokens to
                be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0
            *args, **kwargs: The arguments for `self.wrapped.__call__`.
        """
        replaced_text = self.replace_placeholder_tokens_in_text(
            text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
        )
        return self.wrapped.__call__(replaced_text, *args, **kwargs)

    def encode(self, text: Union[str, List[str]], *args, **kwargs):
        """Encode the passed text to token index (placeholder expansion with
        defaults: no shuffle, full proportion).

        Args:
            text (Union[str, List[str]]): The text to be encode.
            *args, **kwargs: The arguments for `self.wrapped.__call__`.
        """
        replaced_text = self.replace_placeholder_tokens_in_text(text)
        return self.wrapped(replaced_text, *args, **kwargs)

    def decode(self, token_ids, return_raw: bool = False, *args, **kwargs) -> Union[str, List[str]]:
        """Decode the token index to text.

        Args:
            token_ids: The token index to be decoded.
            return_raw: Whether keep the placeholder token in the text.
                Defaults to False.
            *args, **kwargs: The arguments for `self.wrapped.decode`.

        Returns:
            Union[str, List[str]]: The decoded text.
        """
        text = self.wrapped.decode(token_ids, *args, **kwargs)
        if return_raw:
            return text
        replaced_text = self.replace_text_with_placeholder_tokens(text)
        return replaced_text

    def __repr__(self):
        """The representation of the wrapper."""
        # NOTE(review): _module_cls, _module_name and _from_pretrained are
        # never set on this class; attribute lookup falls through __getattr__
        # to the wrapped tokenizer and will typically raise AttributeError —
        # inherited from the upstream source, confirm before relying on repr().
        s = super().__repr__()
        prefix = f"Wrapped Module Class: {self._module_cls}\n"
        prefix += f"Wrapped Module Name: {self._module_name}\n"
        if self._from_pretrained:
            prefix += f"From Pretrained: {self._from_pretrained}\n"
        s = prefix + s
        return s
class EmbeddingLayerWithFixes(nn.Module):
    """Embedding layer wrapper that supports extra "external" embeddings.

    Token ids at or above the wrapped layer's vocabulary size are served from
    externally registered embedding tensors instead of the wrapped weight
    matrix. This design is inspired by https://github.com/AUTOMATIC1111/stable-
    diffusion-webui/blob/22bcc7be428c94e9408f589966c2040187245d81/modules/sd_hi
    jack.py#L224 # noqa.

    Args:
        wrapped (nn.Embedding): The embedding layer to be wrapped.
        external_embeddings (Union[dict, List[dict]], optional): The external
            embeddings added to this layer. Each dict has keys 'name',
            'embedding', 'start', 'end' (see :meth:`add_embeddings`).
            Defaults to None.
    """
    def __init__(self, wrapped: nn.Embedding, external_embeddings: Optional[Union[dict, List[dict]]] = None):
        super().__init__()
        self.wrapped = wrapped
        # Vocabulary size of the wrapped layer; any id >= num_embeddings is
        # assumed to belong to an external embedding.
        self.num_embeddings = wrapped.weight.shape[0]
        self.external_embeddings = []
        if external_embeddings:
            self.add_embeddings(external_embeddings)
        # Holds only the embeddings registered with trainable=True so the
        # optimizer can find them as parameters of this module.
        self.trainable_embeddings = nn.ParameterDict()
    @property
    def weight(self):
        """Get the weight of wrapped embedding layer."""
        return self.wrapped.weight
    def check_duplicate_names(self, embeddings: List[dict]):
        """Assert that no two external embeddings share a name.

        Args:
            embeddings (List[dict]): A list of embeddings to be checked.
        """
        names = [emb["name"] for emb in embeddings]
        assert len(names) == len(set(names)), (
            "Found duplicated names in 'external_embeddings'. Name list: " f"'{names}'"
        )
    def check_ids_overlap(self, embeddings):
        """Assert that the [start, end) id ranges of the external embeddings
        do not overlap each other.

        Args:
            embeddings (List[dict]): A list of embeddings to be checked.
        """
        ids_range = [[emb["start"], emb["end"], emb["name"]] for emb in embeddings]
        ids_range.sort()  # sort by 'start'
        # check if 'end' has overlapping
        for idx in range(len(ids_range) - 1):
            name1, name2 = ids_range[idx][-1], ids_range[idx + 1][-1]
            assert ids_range[idx][1] <= ids_range[idx + 1][0], (
                f"Found ids overlapping between embeddings '{name1}' " f"and '{name2}'."
            )
    def add_embeddings(self, embeddings: Optional[Union[dict, List[dict]]]):
        """Add external embeddings to this layer.

        Args:
            embeddings (Union[dict, list[dict]]): The external embeddings to
                be added. Each dict must contain the following 4 fields: 'name'
                (the name of this embedding), 'embedding' (the embedding
                tensor), 'start' (the start token id of this embedding), 'end'
                (the end token id of this embedding). An optional 'trainable'
                flag wraps the tensor in an nn.Parameter. For example:
                `{name: NAME, start: START, end: END, embedding: torch.Tensor}`
        """
        if isinstance(embeddings, dict):
            embeddings = [embeddings]
        self.external_embeddings += embeddings
        # Validate the full accumulated list, not just the new entries.
        self.check_duplicate_names(self.external_embeddings)
        self.check_ids_overlap(self.external_embeddings)
        # set for trainable
        added_trainable_emb_info = []
        for embedding in embeddings:
            trainable = embedding.get("trainable", False)
            if trainable:
                name = embedding["name"]
                embedding["embedding"] = torch.nn.Parameter(embedding["embedding"])
                self.trainable_embeddings[name] = embedding["embedding"]
                added_trainable_emb_info.append(name)
        added_emb_info = [emb["name"] for emb in embeddings]
        added_emb_info = ", ".join(added_emb_info)
        # NOTE(review): the trailing "current" positional arg looks like a
        # leftover logger name from mmengine's print_log; builtin print just
        # appends the literal word "current" to the message — confirm intent.
        print(f"Successfully add external embeddings: {added_emb_info}.", "current")
        if added_trainable_emb_info:
            added_trainable_emb_info = ", ".join(added_trainable_emb_info)
            print("Successfully add trainable external embeddings: " f"{added_trainable_emb_info}", "current")
    def replace_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Clamp external token ids to 0 so the wrapped layer can embed them.

        Ids >= num_embeddings would index out of range in the wrapped weight;
        their embeddings are substituted later in `replace_embeddings`.

        Args:
            input_ids (torch.Tensor): The input ids to be replaced.
        Returns:
            torch.Tensor: The replaced input ids.
        """
        input_ids_fwd = input_ids.clone()
        input_ids_fwd[input_ids_fwd >= self.num_embeddings] = 0
        return input_ids_fwd
    def replace_embeddings(
        self, input_ids: torch.Tensor, embedding: torch.Tensor, external_embedding: dict
    ) -> torch.Tensor:
        """Splice one external embedding into the embedded sequence. Noted
        that, in this function we use `torch.cat` to avoid inplace
        modification.

        Args:
            input_ids (torch.Tensor): The original token ids. Shape like
                [LENGTH, ].
            embedding (torch.Tensor): The embedding of token ids after
                `replace_input_ids` function.
            external_embedding (dict): The external embedding to be replaced.
        Returns:
            torch.Tensor: The replaced embedding.
        """
        new_embedding = []
        name = external_embedding["name"]
        start = external_embedding["start"]
        end = external_embedding["end"]
        target_ids_to_replace = [i for i in range(start, end)]
        ext_emb = external_embedding["embedding"].to(embedding.device)
        # do not need to replace
        if not (input_ids == start).any():
            return embedding
        # start replace
        # Linear scan: copy untouched spans, and at each occurrence of the
        # run's first id verify the full contiguous run [start, end) before
        # substituting ext_emb for it.
        s_idx, e_idx = 0, 0
        while e_idx < len(input_ids):
            if input_ids[e_idx] == start:
                if e_idx != 0:
                    # add embedding do not need to replace
                    new_embedding.append(embedding[s_idx:e_idx])
                # check if the next embedding need to replace is valid
                actually_ids_to_replace = [int(i) for i in input_ids[e_idx: e_idx + end - start]]
                assert actually_ids_to_replace == target_ids_to_replace, (
                    f"Invalid 'input_ids' in position: {s_idx} to {e_idx}. "
                    f"Expect '{target_ids_to_replace}' for embedding "
                    f"'{name}' but found '{actually_ids_to_replace}'."
                )
                new_embedding.append(ext_emb)
                # Jump past the replaced run.
                s_idx = e_idx + end - start
                e_idx = s_idx + 1
            else:
                e_idx += 1
        if e_idx == len(input_ids):
            new_embedding.append(embedding[s_idx:e_idx])
        return torch.cat(new_embedding, dim=0)
    def forward(self, input_ids: torch.Tensor, external_embeddings: Optional[List[dict]] = None, out_dtype = None):
        """Embed ``input_ids``, splicing in any external embeddings.

        Args:
            input_ids (torch.Tensor): The token ids shape like [bz, LENGTH] or
                [LENGTH, ].
            external_embeddings (Optional[List[dict]]): The external
                embeddings. If not passed, only `self.external_embeddings`
                will be used. Defaults to None.
            out_dtype: Target dtype of the returned tensor; None keeps the
                wrapped layer's dtype.
        """
        assert input_ids.ndim in [1, 2]
        if input_ids.ndim == 1:
            input_ids = input_ids.unsqueeze(0)
        if external_embeddings is None and not self.external_embeddings:
            # Fast path: no replacements needed.
            # NOTE(review): this forwards out_dtype to the wrapped layer,
            # which plain torch.nn.Embedding does not accept — presumably the
            # wrapped layer is a comfy.ops embedding that does; confirm.
            return self.wrapped(input_ids, out_dtype=out_dtype)
        input_ids_fwd = self.replace_input_ids(input_ids)
        inputs_embeds = self.wrapped(input_ids_fwd)
        vecs = []
        if external_embeddings is None:
            external_embeddings = []
        elif isinstance(external_embeddings, dict):
            external_embeddings = [external_embeddings]
        embeddings = self.external_embeddings + external_embeddings
        # Apply every external embedding to every sequence in the batch.
        for input_id, embedding in zip(input_ids, inputs_embeds):
            new_embedding = embedding
            for external_embedding in embeddings:
                new_embedding = self.replace_embeddings(input_id, new_embedding, external_embedding)
            vecs.append(new_embedding)
        # .to(None) is a no-op, so out_dtype=None preserves the dtype.
        return torch.stack(vecs).to(out_dtype)
def add_tokens(
    tokenizer, text_encoder, placeholder_tokens: list, initialize_tokens: list = None,
    num_vectors_per_token: int = 1
):
    """Register new placeholder tokens and their embeddings for training.

    Adds each placeholder to the tokenizer, wraps the text encoder's token
    embedding in :class:`EmbeddingLayerWithFixes`, and registers one trainable
    embedding per placeholder — seeded either from the embedding of the
    corresponding ``initialize_tokens`` entry or from small uniform noise.

    # TODO: support add tokens as dict, then we can load pretrained tokens.
    """
    if initialize_tokens is not None:
        assert len(initialize_tokens) == len(
            placeholder_tokens
        ), "placeholder_token should be the same length as initialize_token"

    for token in placeholder_tokens:
        tokenizer.add_placeholder_token(token, num_vec_per_token=num_vectors_per_token)

    # text_encoder.set_embedding_layer()
    # Wrap the raw token embedding so it can serve external embeddings.
    text_encoder.text_model.embeddings.token_embedding = EmbeddingLayerWithFixes(
        text_encoder.text_model.embeddings.token_embedding
    )
    embedding_layer = text_encoder.text_model.embeddings.token_embedding
    assert embedding_layer is not None, (
        "Do not support get embedding layer for current text encoder. " "Please check your configuration."
    )

    # Build one seed embedding of shape [num_vectors_per_token, dim] per
    # placeholder token.
    initialize_embedding = []
    if initialize_tokens is not None:
        for init_token in initialize_tokens:
            init_id = tokenizer(init_token).input_ids[1]
            seed_vec = embedding_layer.weight[init_id]
            initialize_embedding.append(seed_vec[None, ...].repeat(num_vectors_per_token, 1))
    else:
        for _ in placeholder_tokens:
            # Only used to read the embedding width; the values are random.
            init_id = tokenizer("a").input_ids[1]
            emb_dim = embedding_layer.weight[init_id].shape[0]
            init_weight = (torch.rand(num_vectors_per_token, emb_dim) - 0.5) / 2.0
            initialize_embedding.append(init_weight)

    # initialize_embedding = torch.cat(initialize_embedding,dim=0)
    token_info_all = []
    for token, seed in zip(placeholder_tokens, initialize_embedding):
        token_info = tokenizer.get_token_info(token)
        token_info["embedding"] = seed
        token_info["trainable"] = True
        token_info_all.append(token_info)
    embedding_layer.add_embeddings(token_info_all)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,2 @@
#credit to city96 for this module
#from https://github.com/city96/ComfyUI_ExtraModels/

View File

@@ -0,0 +1,120 @@
"""
List of all DiT model types / settings
"""
sampling_settings = {
"beta_schedule" : "sqrt_linear",
"linear_start" : 0.0001,
"linear_end" : 0.02,
"timesteps" : 1000,
}
dit_conf = {
"XL/2": { # DiT_XL_2
"unet_config": {
"depth" : 28,
"num_heads" : 16,
"patch_size" : 2,
"hidden_size" : 1152,
},
"sampling_settings" : sampling_settings,
},
"XL/4": { # DiT_XL_4
"unet_config": {
"depth" : 28,
"num_heads" : 16,
"patch_size" : 4,
"hidden_size" : 1152,
},
"sampling_settings" : sampling_settings,
},
"XL/8": { # DiT_XL_8
"unet_config": {
"depth" : 28,
"num_heads" : 16,
"patch_size" : 8,
"hidden_size" : 1152,
},
"sampling_settings" : sampling_settings,
},
"L/2": { # DiT_L_2
"unet_config": {
"depth" : 24,
"num_heads" : 16,
"patch_size" : 2,
"hidden_size" : 1024,
},
"sampling_settings" : sampling_settings,
},
"L/4": { # DiT_L_4
"unet_config": {
"depth" : 24,
"num_heads" : 16,
"patch_size" : 4,
"hidden_size" : 1024,
},
"sampling_settings" : sampling_settings,
},
"L/8": { # DiT_L_8
"unet_config": {
"depth" : 24,
"num_heads" : 16,
"patch_size" : 8,
"hidden_size" : 1024,
},
"sampling_settings" : sampling_settings,
},
"B/2": { # DiT_B_2
"unet_config": {
"depth" : 12,
"num_heads" : 12,
"patch_size" : 2,
"hidden_size" : 768,
},
"sampling_settings" : sampling_settings,
},
"B/4": { # DiT_B_4
"unet_config": {
"depth" : 12,
"num_heads" : 12,
"patch_size" : 4,
"hidden_size" : 768,
},
"sampling_settings" : sampling_settings,
},
"B/8": { # DiT_B_8
"unet_config": {
"depth" : 12,
"num_heads" : 12,
"patch_size" : 8,
"hidden_size" : 768,
},
"sampling_settings" : sampling_settings,
},
"S/2": { # DiT_S_2
"unet_config": {
"depth" : 12,
"num_heads" : 6,
"patch_size" : 2,
"hidden_size" : 384,
},
"sampling_settings" : sampling_settings,
},
"S/4": { # DiT_S_4
"unet_config": {
"depth" : 12,
"num_heads" : 6,
"patch_size" : 4,
"hidden_size" : 384,
},
"sampling_settings" : sampling_settings,
},
"S/8": { # DiT_S_8
"unet_config": {
"depth" : 12,
"num_heads" : 6,
"patch_size" : 8,
"hidden_size" : 384,
},
"sampling_settings" : sampling_settings,
},
}

View File

@@ -0,0 +1,661 @@
GNU AFFERO GENERAL PUBLIC LICENSE
Version 3, 19 November 2007
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
Preamble
The GNU Affero General Public License is a free, copyleft license for
software and other kinds of works, specifically designed to ensure
cooperation with the community in the case of network server software.
The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
our General Public Licenses are intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users.
When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.
Developers that use our General Public Licenses protect your rights
with two steps: (1) assert copyright on the software, and (2) offer
you this License which gives you legal permission to copy, distribute
and/or modify the software.
A secondary benefit of defending all users' freedom is that
improvements made in alternate versions of the program, if they
receive widespread use, become available for other developers to
incorporate. Many developers of free software are heartened and
encouraged by the resulting cooperation. However, in the case of
software used on network servers, this result may fail to come about.
The GNU General Public License permits making a modified version and
letting the public access it on a server without ever releasing its
source code to the public.
The GNU Affero General Public License is designed specifically to
ensure that, in such cases, the modified source code becomes available
to the community. It requires the operator of a network server to
provide the source code of the modified version running there to the
users of that server. Therefore, public use of a modified version, on
a publicly accessible server, gives the public access to the source
code of the modified version.
An older license, called the Affero General Public License and
published by Affero, was designed to accomplish similar goals. This is
a different license, not a version of the Affero GPL, but Affero has
released a new version of the Affero GPL which permits relicensing under
this license.
The precise terms and conditions for copying, distribution and
modification follow.
TERMS AND CONDITIONS
0. Definitions.
"This License" refers to version 3 of the GNU Affero General Public License.
"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.
"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.
To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.
A "covered work" means either the unmodified Program or a work based
on the Program.
To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.
To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.
An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.
1. Source Code.
The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.
A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.
The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.
The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.
The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.
The Corresponding Source for a work in source code form is that
same work.
2. Basic Permissions.
All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.
You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.
Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.
When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.
4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.
You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.
5. Conveying Modified Source Versions.
You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:
a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.
b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".
c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.
d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.
A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.
6. Conveying Non-Source Forms.
You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:
a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.
b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.
c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.
d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.
e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.
A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.
A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.
If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).
The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.
Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.
7. Additional Terms.
"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.
When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.
Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:
a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or
b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or
c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or
d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or
e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or
f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.
All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.
If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.
Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.
8. Termination.
You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).
However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.
Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.
Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.
9. Acceptance Not Required for Having Copies.
You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.
10. Automatic Licensing of Downstream Recipients.
Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.
An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.
You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.
11. Patents.
A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".
A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.
In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.
If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.
If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.
A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.
Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.
12. No Surrender of Others' Freedom.
If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
13. Remote Network Interaction; Use with the GNU General Public License.
Notwithstanding any other provision of this License, if you modify the
Program, your modified version must prominently offer all users
interacting with it remotely through a computer network (if your version
supports such interaction) an opportunity to receive the Corresponding
Source of your version by providing access to the Corresponding Source
from a network server at no charge, through some standard or customary
means of facilitating copying of software. This Corresponding Source
shall include the Corresponding Source for any work covered by version 3
of the GNU General Public License that is incorporated pursuant to the
following paragraph.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the work with which it is combined will remain governed by version
3 of the GNU General Public License.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
the GNU Affero General Public License from time to time. Such new versions
will be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU Affero General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU Affero General Public License, you may choose any version ever published
by the Free Software Foundation.
If the Program specifies that a proxy can decide which future
versions of the GNU Affero General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.
Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.
15. Disclaimer of Warranty.
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. Limitation of Liability.
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.
17. Interpretation of Sections 15 and 16.
If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Programs
If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.
To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
<one line to give the program's name and a brief idea of what it does.>
Copyright (C) <year> <name of author>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published
by the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Also add information on how to contact you by electronic and paper mail.
If your software can interact with users remotely through a computer
network, you should also make sure that it provides a way for users to
get its source. For example, if your program is a web application, its
interface could display a "Source" link that leads users to an archive
of the code. There are many ways you could offer source, and different
solutions will be better for different programs; see section 13 for the
specific requirements.
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU AGPL, see
<https://www.gnu.org/licenses/>.

View File

@@ -0,0 +1,139 @@
"""
List of all PixArt model types / settings
"""
sampling_settings = {
"beta_schedule" : "sqrt_linear",
"linear_start" : 0.0001,
"linear_end" : 0.02,
"timesteps" : 1000,
}
pixart_conf = {
"PixArtMS_XL_2": { # models/PixArtMS
"target": "PixArtMS",
"unet_config": {
"input_size" : 1024//8,
"depth" : 28,
"num_heads" : 16,
"patch_size" : 2,
"hidden_size" : 1152,
"pe_interpolation": 2,
},
"sampling_settings" : sampling_settings,
},
"PixArtMS_Sigma_XL_2": {
"target": "PixArtMSSigma",
"unet_config": {
"input_size" : 1024//8,
"token_num" : 300,
"depth" : 28,
"num_heads" : 16,
"patch_size" : 2,
"hidden_size" : 1152,
"micro_condition": False,
"pe_interpolation": 2,
"model_max_length": 300,
},
"sampling_settings" : sampling_settings,
},
"PixArtMS_Sigma_XL_2_900M": {
"target": "PixArtMSSigma",
"unet_config": {
"input_size": 1024 // 8,
"token_num": 300,
"depth": 42,
"num_heads": 16,
"patch_size": 2,
"hidden_size": 1152,
"micro_condition": False,
"pe_interpolation": 2,
"model_max_length": 300,
},
"sampling_settings": sampling_settings,
},
"PixArtMS_Sigma_XL_2_2K": {
"target": "PixArtMSSigma",
"unet_config": {
"input_size" : 2048//8,
"token_num" : 300,
"depth" : 28,
"num_heads" : 16,
"patch_size" : 2,
"hidden_size" : 1152,
"micro_condition": False,
"pe_interpolation": 4,
"model_max_length": 300,
},
"sampling_settings" : sampling_settings,
},
"PixArt_XL_2": { # models/PixArt
"target": "PixArt",
"unet_config": {
"input_size" : 512//8,
"token_num" : 120,
"depth" : 28,
"num_heads" : 16,
"patch_size" : 2,
"hidden_size" : 1152,
"pe_interpolation": 1,
},
"sampling_settings" : sampling_settings,
},
}
pixart_conf.update({ # controlnet models
"ControlPixArtHalf": {
"target": "ControlPixArtHalf",
"unet_config": pixart_conf["PixArt_XL_2"]["unet_config"],
"sampling_settings": pixart_conf["PixArt_XL_2"]["sampling_settings"],
},
"ControlPixArtMSHalf": {
"target": "ControlPixArtMSHalf",
"unet_config": pixart_conf["PixArtMS_XL_2"]["unet_config"],
"sampling_settings": pixart_conf["PixArtMS_XL_2"]["sampling_settings"],
}
})
pixart_res = {
"PixArtMS_XL_2": { # models/PixArtMS 1024x1024
'0.25': [512, 2048], '0.26': [512, 1984], '0.27': [512, 1920], '0.28': [512, 1856],
'0.32': [576, 1792], '0.33': [576, 1728], '0.35': [576, 1664], '0.40': [640, 1600],
'0.42': [640, 1536], '0.48': [704, 1472], '0.50': [704, 1408], '0.52': [704, 1344],
'0.57': [768, 1344], '0.60': [768, 1280], '0.68': [832, 1216], '0.72': [832, 1152],
'0.78': [896, 1152], '0.82': [896, 1088], '0.88': [960, 1088], '0.94': [960, 1024],
'1.00': [1024,1024], '1.07': [1024, 960], '1.13': [1088, 960], '1.21': [1088, 896],
'1.29': [1152, 896], '1.38': [1152, 832], '1.46': [1216, 832], '1.67': [1280, 768],
'1.75': [1344, 768], '2.00': [1408, 704], '2.09': [1472, 704], '2.40': [1536, 640],
'2.50': [1600, 640], '2.89': [1664, 576], '3.00': [1728, 576], '3.11': [1792, 576],
'3.62': [1856, 512], '3.75': [1920, 512], '3.88': [1984, 512], '4.00': [2048, 512],
},
"PixArt_XL_2": { # models/PixArt 512x512
'0.25': [256,1024], '0.26': [256, 992], '0.27': [256, 960], '0.28': [256, 928],
'0.32': [288, 896], '0.33': [288, 864], '0.35': [288, 832], '0.40': [320, 800],
'0.42': [320, 768], '0.48': [352, 736], '0.50': [352, 704], '0.52': [352, 672],
'0.57': [384, 672], '0.60': [384, 640], '0.68': [416, 608], '0.72': [416, 576],
'0.78': [448, 576], '0.82': [448, 544], '0.88': [480, 544], '0.94': [480, 512],
'1.00': [512, 512], '1.07': [512, 480], '1.13': [544, 480], '1.21': [544, 448],
'1.29': [576, 448], '1.38': [576, 416], '1.46': [608, 416], '1.67': [640, 384],
'1.75': [672, 384], '2.00': [704, 352], '2.09': [736, 352], '2.40': [768, 320],
'2.50': [800, 320], '2.89': [832, 288], '3.00': [864, 288], '3.11': [896, 288],
'3.62': [928, 256], '3.75': [960, 256], '3.88': [992, 256], '4.00': [1024,256]
},
"PixArtMS_Sigma_XL_2_2K": {
'0.25': [1024, 4096], '0.26': [1024, 3968], '0.27': [1024, 3840], '0.28': [1024, 3712],
'0.32': [1152, 3584], '0.33': [1152, 3456], '0.35': [1152, 3328], '0.40': [1280, 3200],
'0.42': [1280, 3072], '0.48': [1408, 2944], '0.50': [1408, 2816], '0.52': [1408, 2688],
'0.57': [1536, 2688], '0.60': [1536, 2560], '0.68': [1664, 2432], '0.72': [1664, 2304],
'0.78': [1792, 2304], '0.82': [1792, 2176], '0.88': [1920, 2176], '0.94': [1920, 2048],
'1.00': [2048, 2048], '1.07': [2048, 1920], '1.13': [2176, 1920], '1.21': [2176, 1792],
'1.29': [2304, 1792], '1.38': [2304, 1664], '1.46': [2432, 1664], '1.67': [2560, 1536],
'1.75': [2688, 1536], '2.00': [2816, 1408], '2.09': [2944, 1408], '2.40': [3072, 1280],
'2.50': [3200, 1280], '2.89': [3328, 1152], '3.00': [3456, 1152], '3.11': [3584, 1152],
'3.62': [3712, 1024], '3.75': [3840, 1024], '3.88': [3968, 1024], '4.00': [4096, 1024]
}
}
# These should be the same
pixart_res.update({
"PixArtMS_Sigma_XL_2": pixart_res["PixArtMS_XL_2"],
"PixArtMS_Sigma_XL_2_512": pixart_res["PixArt_XL_2"],
})

View File

@@ -0,0 +1,216 @@
# For using the diffusers format weights
# Based on the original ComfyUI function +
# https://github.com/PixArt-alpha/PixArt-alpha/blob/master/tools/convert_pixart_alpha_to_diffusers.py
import torch
# Extra (reference key, diffusers key) pairs used only by multi-scale (MS)
# checkpoints, which carry micro-conditioning embedders for image resolution
# and aspect ratio.
conversion_map_ms = [ # for multi_scale_train (MS)
    # Resolution
    ("csize_embedder.mlp.0.weight", "adaln_single.emb.resolution_embedder.linear_1.weight"),
    ("csize_embedder.mlp.0.bias", "adaln_single.emb.resolution_embedder.linear_1.bias"),
    ("csize_embedder.mlp.2.weight", "adaln_single.emb.resolution_embedder.linear_2.weight"),
    ("csize_embedder.mlp.2.bias", "adaln_single.emb.resolution_embedder.linear_2.bias"),
    # Aspect ratio
    ("ar_embedder.mlp.0.weight", "adaln_single.emb.aspect_ratio_embedder.linear_1.weight"),
    ("ar_embedder.mlp.0.bias", "adaln_single.emb.aspect_ratio_embedder.linear_1.bias"),
    ("ar_embedder.mlp.2.weight", "adaln_single.emb.aspect_ratio_embedder.linear_2.weight"),
    ("ar_embedder.mlp.2.bias", "adaln_single.emb.aspect_ratio_embedder.linear_2.bias"),
]
def get_depth(state_dict):
    """Count transformer blocks in a diffusers-layout state dict.

    Each block contributes exactly one self-attention key bias, so the
    number of keys ending in '.attn1.to_k.bias' equals the model depth.
    """
    block_count = 0
    for name in state_dict.keys():
        if name.endswith('.attn1.to_k.bias'):
            block_count += 1
    return block_count
def get_lora_depth(state_dict):
    """Count transformer blocks covered by a PEFT-style LoRA state dict.

    One '.attn1.to_k.lora_A.weight' key exists per adapted block.
    """
    block_count = 0
    for name in state_dict.keys():
        if name.endswith('.attn1.to_k.lora_A.weight'):
            block_count += 1
    return block_count
def get_conversion_map(state_dict):
    """Return (PixArt reference key, HF diffusers key) pairs for 1:1 renames.

    Covers embedders, AdaLN, the final layer, and the per-block weights that
    map directly; fused qkv/kv attention weights are handled separately by
    the callers. Depth is inferred from the state dict itself.
    """
    static_pairs = [ # main SD conversion map (PixArt reference, HF Diffusers)
        # Patch embeddings
        ("x_embedder.proj.weight", "pos_embed.proj.weight"),
        ("x_embedder.proj.bias", "pos_embed.proj.bias"),
        # Caption projection
        ("y_embedder.y_embedding", "caption_projection.y_embedding"),
        ("y_embedder.y_proj.fc1.weight", "caption_projection.linear_1.weight"),
        ("y_embedder.y_proj.fc1.bias", "caption_projection.linear_1.bias"),
        ("y_embedder.y_proj.fc2.weight", "caption_projection.linear_2.weight"),
        ("y_embedder.y_proj.fc2.bias", "caption_projection.linear_2.bias"),
        # AdaLN-single LN
        ("t_embedder.mlp.0.weight", "adaln_single.emb.timestep_embedder.linear_1.weight"),
        ("t_embedder.mlp.0.bias", "adaln_single.emb.timestep_embedder.linear_1.bias"),
        ("t_embedder.mlp.2.weight", "adaln_single.emb.timestep_embedder.linear_2.weight"),
        ("t_embedder.mlp.2.bias", "adaln_single.emb.timestep_embedder.linear_2.bias"),
        # Shared norm
        ("t_block.1.weight", "adaln_single.linear.weight"),
        ("t_block.1.bias", "adaln_single.linear.bias"),
        # Final block
        ("final_layer.linear.weight", "proj_out.weight"),
        ("final_layer.linear.bias", "proj_out.bias"),
        ("final_layer.scale_shift_table", "scale_shift_table"),
    ]
    # Per-transformer-block direct renames (attention projections and MLP).
    block_pairs = []
    for idx in range(get_depth(state_dict)):
        ref = f"blocks.{idx}"
        hf = f"transformer_blocks.{idx}"
        block_pairs += [
            (f"{ref}.scale_shift_table", f"{hf}.scale_shift_table"),
            # Self-attention output projection
            (f"{ref}.attn.proj.weight", f"{hf}.attn1.to_out.0.weight"),
            (f"{ref}.attn.proj.bias", f"{hf}.attn1.to_out.0.bias"),
            # Feed-forward
            (f"{ref}.mlp.fc1.weight", f"{hf}.ff.net.0.proj.weight"),
            (f"{ref}.mlp.fc1.bias", f"{hf}.ff.net.0.proj.bias"),
            (f"{ref}.mlp.fc2.weight", f"{hf}.ff.net.2.weight"),
            (f"{ref}.mlp.fc2.bias", f"{hf}.ff.net.2.bias"),
            # Cross-attention output projection
            (f"{ref}.cross_attn.proj.weight", f"{hf}.attn2.to_out.0.weight"),
            (f"{ref}.cross_attn.proj.bias", f"{hf}.attn2.to_out.0.bias"),
        ]
    return static_pairs + block_pairs
def find_prefix(state_dict, target_key):
    """Return the prefix preceding the first key that ends with *target_key*.

    Returns "" when no key matches. The prefix is computed by slicing off
    exactly len(target_key) trailing characters; the previous
    ``key.split(target_key)[0]`` truncated at the *first* occurrence of the
    suffix text, which returned a wrong (too short) prefix whenever the
    suffix substring also appeared earlier in the key.
    """
    for key in state_dict.keys():
        if key.endswith(target_key):
            return key[:len(key) - len(target_key)]
    return ""
def convert_state_dict(state_dict):
    """Convert a diffusers-format PixArt state dict to the reference layout.

    Direct 1:1 renames come from get_conversion_map() (plus the MS extras
    when the resolution embedder is present). The reference model fuses the
    separate diffusers q/k/v projections, so those are concatenated here:
    self-attention q+k+v -> qkv, cross-attention k+v -> kv_linear.
    Leftover and missing keys are reported to stdout, not raised.
    """
    # MS checkpoints carry extra resolution/aspect-ratio embedders.
    if "adaln_single.emb.resolution_embedder.linear_1.weight" in state_dict.keys():
        cmap = get_conversion_map(state_dict) + conversion_map_ms
    else:
        cmap = get_conversion_map(state_dict)
    # Reference keys whose diffusers counterpart is absent from the input.
    missing = [k for k, v in cmap if v not in state_dict]
    new_state_dict = {k: state_dict[v] for k, v in cmap if k not in missing}
    # Track every consumed diffusers key so leftovers can be reported below.
    matched = list(v for k, v in cmap if v in state_dict.keys())
    for depth in range(get_depth(state_dict)):
        for wb in ["weight", "bias"]:
            # Self Attention: fuse q/k/v into a single qkv tensor (dim 0).
            key = lambda a: f"transformer_blocks.{depth}.attn1.to_{a}.{wb}"
            new_state_dict[f"blocks.{depth}.attn.qkv.{wb}"] = torch.cat((
                state_dict[key('q')], state_dict[key('k')], state_dict[key('v')]
            ), dim=0)
            matched += [key('q'), key('k'), key('v')]
            # Cross-attention: q stays separate, k/v fuse into kv_linear.
            key = lambda a: f"transformer_blocks.{depth}.attn2.to_{a}.{wb}"
            new_state_dict[f"blocks.{depth}.cross_attn.q_linear.{wb}"] = state_dict[key('q')]
            new_state_dict[f"blocks.{depth}.cross_attn.kv_linear.{wb}"] = torch.cat((
                state_dict[key('k')], state_dict[key('v')]
            ), dim=0)
            matched += [key('q'), key('k'), key('v')]
    if len(matched) < len(state_dict):
        print(f"PixArt: UNET conversion has leftover keys! ({len(matched)} vs {len(state_dict)})")
        print(list(set(state_dict.keys()) - set(matched)))
    if len(missing) > 0:
        print(f"PixArt: UNET conversion has missing keys!")
        print(missing)
    return new_state_dict
# Same as above but for LoRA weights:
def convert_lora_state_dict(state_dict, peft=True):
    """Convert a diffusers-layout PixArt LoRA state dict to reference naming.

    peft=True expects PEFT/kohya-style keys ('.lora_A/.lora_B'); peft=False
    expects OneTrainer keys ('lora_transformer_' prefix, underscore-joined
    names, plus per-weight '.alpha' scales). Output keys use the kohya
    '.lora_down/.lora_up' convention on reference-model names. Text-encoder
    (T5) LoRA keys are dropped with a warning. Mirrors convert_state_dict:
    direct renames first, then fused qkv / kv concatenation per block.
    """
    # Output-side (kohya) key builders, applied to reference ".weight" names.
    rep_ak = lambda x: x.replace(".weight", ".lora_down.weight")
    rep_bk = lambda x: x.replace(".weight", ".lora_up.weight")
    rep_pk = lambda x: x.replace(".weight", ".alpha")
    if peft: # peft
        # Input-side key builders for the PEFT naming scheme.
        rep_ap = lambda x: x.replace(".weight", ".lora_A.weight")
        rep_bp = lambda x: x.replace(".weight", ".lora_B.weight")
        rep_pp = lambda x: x.replace(".weight", ".alpha")
        # Strip whatever wrapper prefix precedes the transformer keys.
        prefix = find_prefix(state_dict, "adaln_single.linear.lora_A.weight")
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
    else: # OneTrainer
        # OneTrainer joins the diffusers path with underscores and drops ".weight".
        rep_ap = lambda x: x.replace(".", "_")[:-7] + ".lora_down.weight"
        rep_bp = lambda x: x.replace(".", "_")[:-7] + ".lora_up.weight"
        rep_pp = lambda x: x.replace(".", "_")[:-7] + ".alpha"
        prefix = "lora_transformer_"
        t5_marker = "lora_te_encoder"
        t5_keys = []
        # Strip the transformer prefix in place; collect (and drop) T5 keys.
        for key in list(state_dict.keys()):
            if key.startswith(prefix):
                state_dict[key[len(prefix):]] = state_dict.pop(key)
            elif t5_marker in key:
                t5_keys.append(state_dict.pop(key))
        if len(t5_keys) > 0:
            print(f"Text Encoder not supported for PixArt LoRA, ignoring {len(t5_keys)} keys")
    # Build the LoRA key map from the plain-weight map: each ".weight" pair
    # becomes a down/up pair (plus alpha for OneTrainer).
    cmap = []
    cmap_unet = get_conversion_map(state_dict) + conversion_map_ms # todo: 512 model
    for k, v in cmap_unet:
        if v.endswith(".weight"):
            cmap.append((rep_ak(k), rep_ap(v)))
            cmap.append((rep_bk(k), rep_bp(v)))
            if not peft:
                cmap.append((rep_pk(k), rep_pp(v)))
    missing = [k for k, v in cmap if v not in state_dict]
    new_state_dict = {k: state_dict[v] for k, v in cmap if k not in missing}
    # Track consumed input keys to report leftovers at the end.
    matched = list(v for k, v in cmap if v in state_dict.keys())
    lora_depth = get_lora_depth(state_dict)
    # Handle both halves of each LoRA pair: (input A -> output down),
    # (input B -> output up).
    for fp, fk in ((rep_ap, rep_ak), (rep_bp, rep_bk)):
        for depth in range(lora_depth):
            # Self Attention: fuse q/k/v LoRA mats into one qkv tensor.
            key = lambda a: fp(f"transformer_blocks.{depth}.attn1.to_{a}.weight")
            new_state_dict[fk(f"blocks.{depth}.attn.qkv.weight")] = torch.cat((
                state_dict[key('q')], state_dict[key('k')], state_dict[key('v')]
            ), dim=0)
            matched += [key('q'), key('k'), key('v')]
            if not peft:
                # NOTE(review): only the q alpha is kept for the fused qkv —
                # assumes q/k/v alphas are equal; confirm against OneTrainer output.
                akey = lambda a: rep_pp(f"transformer_blocks.{depth}.attn1.to_{a}.weight")
                new_state_dict[rep_pk((f"blocks.{depth}.attn.qkv.weight"))] = state_dict[akey("q")]
                matched += [akey('q'), akey('k'), akey('v')]
            # Self Attention projection?
            key = lambda a: fp(f"transformer_blocks.{depth}.attn1.to_{a}.weight")
            new_state_dict[fk(f"blocks.{depth}.attn.proj.weight")] = state_dict[key('out.0')]
            matched += [key('out.0')]
            # Cross-attention: q separate, k/v fused into kv_linear.
            key = lambda a: fp(f"transformer_blocks.{depth}.attn2.to_{a}.weight")
            new_state_dict[fk(f"blocks.{depth}.cross_attn.q_linear.weight")] = state_dict[key('q')]
            new_state_dict[fk(f"blocks.{depth}.cross_attn.kv_linear.weight")] = torch.cat((
                state_dict[key('k')], state_dict[key('v')]
            ), dim=0)
            matched += [key('q'), key('k'), key('v')]
            if not peft:
                akey = lambda a: rep_pp(f"transformer_blocks.{depth}.attn2.to_{a}.weight")
                new_state_dict[rep_pk((f"blocks.{depth}.cross_attn.q_linear.weight"))] = state_dict[akey("q")]
                new_state_dict[rep_pk((f"blocks.{depth}.cross_attn.kv_linear.weight"))] = state_dict[akey("k")]
                matched += [akey('q'), akey('k'), akey('v')]
            # Cross Attention projection?
            key = lambda a: fp(f"transformer_blocks.{depth}.attn2.to_{a}.weight")
            new_state_dict[fk(f"blocks.{depth}.cross_attn.proj.weight")] = state_dict[key('out.0')]
            matched += [key('out.0')]
            # Feed-forward layers map 1:1.
            key = fp(f"transformer_blocks.{depth}.ff.net.0.proj.weight")
            new_state_dict[fk(f"blocks.{depth}.mlp.fc1.weight")] = state_dict[key]
            matched += [key]
            key = fp(f"transformer_blocks.{depth}.ff.net.2.weight")
            new_state_dict[fk(f"blocks.{depth}.mlp.fc2.weight")] = state_dict[key]
            matched += [key]
    if len(matched) < len(state_dict):
        print(f"PixArt: LoRA conversion has leftover keys! ({len(matched)} vs {len(state_dict)})")
        print(list(set(state_dict.keys()) - set(matched)))
    if len(missing) > 0:
        print(f"PixArt: LoRA conversion has missing keys! (probably)")
        print(missing)
    return new_state_dict

View File

@@ -0,0 +1,331 @@
import os
import json
import copy
import torch
import math
import comfy.supported_models_base
import comfy.latent_formats
import comfy.model_patcher
import comfy.model_base
import comfy.utils
import comfy.conds
from comfy import model_management
from .diffusers_convert import convert_state_dict, convert_lora_state_dict
# checkpoint loading
class EXM_PixArt(comfy.supported_models_base.BASE):
    """comfy supported-model description for PixArt checkpoints."""
    unet_config = {}
    unet_extra_config = {}
    latent_format = comfy.latent_formats.SD15

    def __init__(self, model_conf):
        """Populate the instance from a pixart_conf-style config dict."""
        target = model_conf.get("target")
        unet_conf = model_conf.get("unet_config", {})
        sampling = model_conf.get("sampling_settings", {})
        self.model_target = target
        self.unet_config = unet_conf
        self.sampling_settings = sampling
        # Replace the class-level latent format with an instance of it.
        self.latent_format = self.latent_format()
        # UNET is created by this extension, not by core comfy.
        unet_conf["disable_unet_model_creation"] = True

    def model_type(self, state_dict, prefix=""):
        """PixArt checkpoints are epsilon-prediction models."""
        return comfy.model_base.ModelType.EPS
class EXM_PixArt_Model(comfy.model_base.BaseModel):
    """BaseModel subclass that forwards PixArt-specific conditioning inputs."""

    def __init__(self, *args, **kwargs):
        # No extra state; kept for parity with comfy.model_base.BaseModel.
        super().__init__(*args, **kwargs)

    def extra_conds(self, **kwargs):
        """Attach img_hw / aspect_ratio / cn_hint conds when supplied."""
        out = super().extra_conds(**kwargs)
        # img_hw and aspect_ratio arrive as plain values; wrap them as tensors.
        for cond_name in ("img_hw", "aspect_ratio"):
            value = kwargs.get(cond_name, None)
            if value is not None:
                out[cond_name] = comfy.conds.CONDRegular(torch.tensor(value))
        # ControlNet hint is already a tensor; pass it through unchanged.
        cn_hint = kwargs.get("cn_hint", None)
        if cn_hint is not None:
            out["cn_hint"] = comfy.conds.CONDRegular(cn_hint)
        return out
def load_pixart(model_path, model_conf=None):
    """Load a PixArt checkpoint and wrap it in a comfy ModelPatcher.

    model_path: path to a .safetensors/.pt checkpoint (reference or
        diffusers layout; full checkpoints with a "model.diffusion_model."
        prefix are also accepted).
    model_conf: optional pixart_conf-style dict; when None the architecture
        is guessed from the state dict via guess_pixart_config().
    """
    state_dict = comfy.utils.load_torch_file(model_path)
    # Some checkpoints nest the weights under a top-level "model" key.
    state_dict = state_dict.get("model", state_dict)
    # prefix: strip full-checkpoint wrapper prefixes if present.
    for prefix in ["model.diffusion_model.", ]:
        if any(True for x in state_dict if x.startswith(prefix)):
            state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
    # diffusers: this key only exists in the diffusers layout.
    if "adaln_single.linear.weight" in state_dict:
        state_dict = convert_state_dict(state_dict) # Diffusers
    # guess auto config
    if model_conf is None:
        model_conf = guess_pixart_config(state_dict)
    parameters = comfy.utils.calculate_parameters(state_dict)
    unet_dtype = model_management.unet_dtype(model_params=parameters)
    load_device = comfy.model_management.get_torch_device()
    offload_device = comfy.model_management.unet_offload_device()
    # ignore fp8/etc and use directly for now
    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
    if manual_cast_dtype:
        print(f"PixArt: falling back to {manual_cast_dtype}")
        unet_dtype = manual_cast_dtype
    model_conf = EXM_PixArt(model_conf) # convert to object
    model = EXM_PixArt_Model( # same as comfy.model_base.BaseModel
        model_conf,
        model_type=comfy.model_base.ModelType.EPS,
        device=model_management.get_torch_device()
    )
    # Instantiate the diffusion model class named by the config. Imports are
    # deferred so only the needed architecture module is loaded.
    if model_conf.model_target == "PixArtMS":
        from .models.PixArtMS import PixArtMS
        model.diffusion_model = PixArtMS(**model_conf.unet_config)
    elif model_conf.model_target == "PixArt":
        from .models.PixArt import PixArt
        model.diffusion_model = PixArt(**model_conf.unet_config)
    elif model_conf.model_target == "PixArtMSSigma":
        from .models.PixArtMS import PixArtMS
        model.diffusion_model = PixArtMS(**model_conf.unet_config)
        # Sigma uses the SDXL latent space instead of SD1.5's.
        model.latent_format = comfy.latent_formats.SDXL()
    elif model_conf.model_target == "ControlPixArtMSHalf":
        from .models.PixArtMS import PixArtMS
        from .models.pixart_controlnet import ControlPixArtMSHalf
        model.diffusion_model = PixArtMS(**model_conf.unet_config)
        model.diffusion_model = ControlPixArtMSHalf(model.diffusion_model)
    elif model_conf.model_target == "ControlPixArtHalf":
        from .models.PixArt import PixArt
        from .models.pixart_controlnet import ControlPixArtHalf
        model.diffusion_model = PixArt(**model_conf.unet_config)
        model.diffusion_model = ControlPixArtHalf(model.diffusion_model)
    else:
        raise NotImplementedError(f"Unknown model target '{model_conf.model_target}'")
    # strict=False: mismatched keys are reported below rather than raising.
    m, u = model.diffusion_model.load_state_dict(state_dict, strict=False)
    if len(m) > 0: print("Missing UNET keys", m)
    if len(u) > 0: print("Leftover UNET keys", u)
    model.diffusion_model.dtype = unet_dtype
    model.diffusion_model.eval()
    model.diffusion_model.to(unet_dtype)
    model_patcher = comfy.model_patcher.ModelPatcher(
        model,
        load_device=load_device,
        offload_device=offload_device,
    )
    return model_patcher
def guess_pixart_config(sd):
    """
    Heuristically reconstruct a PixArt model config from a converted state dict.

    Uses the caption-embedding length to tell Alpha (120) from Sigma (300),
    the micro-conditioning embedder keys to tell MS from base Alpha, and the
    positional-embedding token count (when present) for the input size.
    Returns a dict with "target", "unet_config" and "sampling_settings".
    """
    # Shared settings based on DiT_XL_2 - could be enumerated from attention
    # shapes / pos_embed width instead of being hard-coded.
    config = {
        "num_heads": 16,
        "patch_size": 2,
        "hidden_size": 1152,
    }
    # One attention output projection exists per block; default to 28.
    block_count = sum(1 for key in sd.keys() if key.endswith(".attn.proj.weight"))
    config["depth"] = block_count or 28
    # Caption length: not present in the diffusers version for Sigma, so
    # fall back to 300 (needs better logic to guess this).
    y_embedding = sd.get("y_embedder.y_embedding")
    if y_embedding is not None:
        config["model_max_length"] = y_embedding.shape[0]
    else:
        config["model_max_length"] = 300
    if "pos_embed" in sd:
        # token count -> latent side length (dumb guess for pe_interpolation)
        token_count = sd["pos_embed"].shape[1]
        config["input_size"] = int(math.sqrt(token_count)) * config["patch_size"]
        config["pe_interpolation"] = config["input_size"] // (512 // 8)
    if config["model_max_length"] == 300:
        # Sigma
        target_arch = "PixArtMSSigma"
        config["micro_condition"] = False
        if "input_size" not in config:
            # The diffusers weights for 1K/2K are exactly the same...?
            # replace patch embed logic with HyDiT?
            print(f"PixArt: diffusers weights - 2K model will be broken, use manual loading!")
            config["input_size"] = 1024 // 8
    elif "csize_embedder.mlp.0.weight" in sd:
        # Alpha MS (micro-conditioning present)
        target_arch = "PixArtMS"
        config["micro_condition"] = True
        if "input_size" not in config:
            config["input_size"] = 1024 // 8
            config["pe_interpolation"] = 2
    else:
        # Base Alpha PixArt
        target_arch = "PixArt"
        if "input_size" not in config:
            config["input_size"] = 512 // 8
            config["pe_interpolation"] = 1
    print("PixArt guessed config:", target_arch, config)
    return {
        "target": target_arch,
        "unet_config": config,
        "sampling_settings": {
            "beta_schedule": "sqrt_linear",
            "linear_start": 0.0001,
            "linear_end": 0.02,
            "timesteps": 1000,
        }
    }
# lora
class EXM_PixArt_ModelPatcher(comfy.model_patcher.ModelPatcher):
    """ModelPatcher variant whose LoRA weight math understands PixArt's fused q/k/v tensors."""

    def calculate_weight(self, patches, weight, key):
        """
        This is almost the same as the comfy function, but stripped down to just the LoRA patch code.
        The problem with the original code is the q/k/v keys being combined into one for the attention.
        In the diffusers code, they're treated as separate keys, but in the reference code they're recombined (q+kv|qkv).
        This means, for example, that the [1152,1152] weights become [3456,1152] in the state dict.
        The issue with this is that the LoRA weights are [128,1152],[1152,128] and become [384,1162],[3456,128] instead.
        This is the best thing I could think of that would fix that, but it's very fragile.
        - Check key shape to determine if it needs the fallback logic
        - Cut the input into parts based on the shape (undoing the torch.cat)
        - Do the matrix multiplication logic
        - Recombine them to match the expected shape
        """
        for p in patches:
            alpha = p[0]
            v = p[1]
            strength_model = p[2]
            # Per-key model strength multiplier (distinct from the LoRA alpha).
            if strength_model != 1.0:
                weight *= strength_model
            # Nested patch lists are resolved recursively into a single tensor.
            if isinstance(v, list):
                v = (self.calculate_weight(v[1:], v[0].clone(), key),)
            if len(v) == 2:
                patch_type = v[0]
                v = v[1]
            # NOTE(review): patch_type is only bound when len(v) == 2; other
            # patch formats would raise NameError here — confirm all callers
            # pass (type, payload) tuples.
            if patch_type == "lora":
                mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
                mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
                if v[2] is not None:
                    # Rescale by stored alpha / rank.
                    alpha *= v[2] / mat2.shape[0]
                try:
                    mat1 = mat1.flatten(start_dim=1)
                    mat2 = mat2.flatten(start_dim=1)
                    ch1 = mat1.shape[0] // mat2.shape[1]
                    ch2 = mat2.shape[0] // mat1.shape[1]
                    ### Fallback logic for shape mismatch ###
                    # Fused qkv/kv weight: apply the LoRA chunk-wise along dim 0,
                    # then concatenate the per-chunk deltas back together.
                    if mat1.shape[0] != mat2.shape[1] and ch1 == ch2 and (mat1.shape[0] / mat2.shape[1]) % 1 == 0:
                        mat1 = mat1.chunk(ch1, dim=0)
                        mat2 = mat2.chunk(ch1, dim=0)
                        weight += torch.cat(
                            [alpha * torch.mm(mat1[x], mat2[x]) for x in range(ch1)],
                            dim=0,
                        ).reshape(weight.shape).type(weight.dtype)
                    else:
                        # Regular (unfused) LoRA: weight += alpha * up @ down.
                        weight += (alpha * torch.mm(mat1, mat2)).reshape(weight.shape).type(weight.dtype)
                except Exception as e:
                    # Best-effort: report and keep the unpatched weight for this key.
                    print("ERROR", key, e)
        return weight

    def clone(self):
        # Same as ModelPatcher.clone, but returns this subclass so the custom
        # LoRA handling survives cloning.
        n = EXM_PixArt_ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device,
                                    weight_inplace_update=self.weight_inplace_update)
        n.patches = {}
        for k in self.patches:
            # Shallow-copy each patch list so additions don't leak between clones.
            n.patches[k] = self.patches[k][:]
        n.object_patches = self.object_patches.copy()
        n.model_options = copy.deepcopy(self.model_options)
        n.model_keys = self.model_keys
        return n
def replace_model_patcher(model):
    """Build an EXM_PixArt_ModelPatcher mirroring an existing patcher's state."""
    patcher = EXM_PixArt_ModelPatcher(
        model=model.model,
        size=model.size,
        load_device=model.load_device,
        offload_device=model.offload_device,
        current_device=model.current_device,
        weight_inplace_update=model.weight_inplace_update,
    )
    # Shallow-copy each patch list so later additions don't leak back into
    # the source patcher.
    patcher.patches = {key: value[:] for key, value in model.patches.items()}
    patcher.object_patches = model.object_patches.copy()
    patcher.model_options = copy.deepcopy(model.model_options)
    return patcher
def find_peft_alpha(path):
    """
    Locate the LoRA alpha for a PEFT checkpoint.

    PEFT doesn't store `lora_alpha` inside the model file, so look for a
    sidecar JSON next to the checkpoint (`<name>.json`, `<name>.config.json`
    or `adapter_config.json` in the same folder).

    Args:
        path: path to the LoRA checkpoint file.

    Returns:
        The alpha value from the first config that has one, or 8.0 as a
        fallback when no usable config exists.
    """
    def load_json(json_path):
        # Returns the alpha from the JSON config, or None when it is absent.
        with open(json_path) as f:
            data = json.load(f)
        alpha = data.get("lora_alpha")
        alpha = alpha or data.get("alpha")
        if not alpha:
            print(" Found config but `lora_alpha` is missing!")
        else:
            print(f" Found config at {json_path} [alpha:{alpha}]")
        return alpha
    # For some weird reason peft doesn't include the alpha in the actual model
    print("PixArt: Warning! This is a PEFT LoRA. Trying to find config...")
    files = [
        f"{os.path.splitext(path)[0]}.json",
        f"{os.path.splitext(path)[0]}.config.json",
        os.path.join(os.path.dirname(path), "adapter_config.json"),
    ]
    for file in files:
        if os.path.isfile(file):
            alpha = load_json(file)
            # Previously a config without alpha returned None, which crashed
            # torch.tensor(alpha) downstream; keep searching / fall back instead.
            if alpha:
                return alpha
    print(" Missing config/alpha! assuming alpha of 8. Consider converting it/adding a config json to it.")
    return 8.0
def load_pixart_lora(model, lora, lora_path, strength):
    """
    Apply a PixArt LoRA state dict to a model patcher.

    Converts PEFT / OneTrainer key layouts to the reference layout first,
    then loads the patches through comfy's LoRA machinery.

    Args:
        model: ModelPatcher to patch, or None to only run conversion/reporting.
        lora: raw LoRA state dict.
        lora_path: checkpoint path (used to find a PEFT alpha sidecar config).
        strength: patch strength forwarded to add_patches.

    Returns:
        An EXM_PixArt_ModelPatcher with the patches applied, or None when
        model was None.
    """
    # Strip the up-weight suffix to recover the base key name.
    k_back = lambda x: x.replace(".lora_up.weight", "")
    # need to convert the actual weights for this to work.
    if any(True for x in lora.keys() if x.endswith("adaln_single.linear.lora_A.weight")):
        # PEFT layout: alpha lives in a sidecar config, not in the weights,
        # so synthesize ".alpha" entries from the recovered value.
        lora = convert_lora_state_dict(lora, peft=True)
        alpha = find_peft_alpha(lora_path)
        lora.update({f"{k_back(x)}.alpha": torch.tensor(alpha) for x in lora.keys() if "lora_up" in x})
    else: # OneTrainer
        lora = convert_lora_state_dict(lora, peft=False)
    key_map = {k_back(x): f"diffusion_model.{k_back(x)}.weight" for x in lora.keys() if "lora_up" in x} # fake
    loaded = comfy.lora.load_lora(lora, key_map)
    if model is not None:
        # switch to custom model patcher when using LoRAs
        if isinstance(model, EXM_PixArt_ModelPatcher):
            new_modelpatcher = model.clone()
        else:
            new_modelpatcher = replace_model_patcher(model)
        k = new_modelpatcher.add_patches(loaded, strength)
    else:
        k = ()
        new_modelpatcher = None
    # Report any LoRA keys that didn't match a model weight.
    k = set(k)
    for x in loaded:
        if (x not in k):
            print("NOT LOADED", x)
    return new_modelpatcher

View File

@@ -0,0 +1,250 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------
import math
import torch
import torch.nn as nn
import os
import numpy as np
from timm.models.layers import DropPath
from timm.models.vision_transformer import PatchEmbed, Mlp
from .utils import auto_grad_checkpoint, to_2tuple
from .PixArt_blocks import t2i_modulate, CaptionEmbedder, AttentionKVCompress, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, LabelEmbedder, FinalLayer
class PixArtBlock(nn.Module):
    """
    A PixArt transformer block conditioned via adaLN-single: KV-compressible
    self-attention, text cross-attention, and an MLP, each gated/shifted by
    parameters derived from the timestep embedding.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0, input_size=None, sampling=None, sr_ratio=1, qk_norm=False, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = AttentionKVCompress(
            hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
            qk_norm=qk_norm, **block_kwargs
        )
        self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        # tanh-approximated GELU keeps compatibility with older pytorch builds
        gelu_factory = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=gelu_factory, drop=0)
        self.drop_path = nn.Identity() if drop_path <= 0. else DropPath(drop_path)
        self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
        self.sampling = sampling
        self.sr_ratio = sr_ratio

    def forward(self, x, y, t, mask=None, **kwargs):
        B, N, C = x.shape
        # Six adaLN-single parameters: shift/scale/gate for attention and MLP.
        mods = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mods
        attn_out = self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C)
        x = x + self.drop_path(gate_msa * attn_out)
        x = x + self.cross_attn(x, y, mask)
        mlp_out = self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))
        x = x + self.drop_path(gate_mlp * mlp_out)
        return x
### Core PixArt Model ###
class PixArt(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    Args:
        input_size: latent spatial size (image pixels // 8).
        patch_size: patch embedding size.
        in_channels: latent channels.
        hidden_size: transformer width.
        depth: number of transformer blocks.
        num_heads: attention heads.
        mlp_ratio: MLP hidden-dim multiplier.
        class_dropout_prob: caption dropout probability (classifier-free guidance).
        pred_sigma: if True, the model outputs [eps, sigma] (2x channels).
        drop_path: stochastic depth rate (linearly scaled per block).
        caption_channels: text-encoder feature dim (T5 = 4096).
        pe_interpolation: positional-embedding interpolation factor.
        model_max_length: max caption token count.
        qk_norm: apply LayerNorm to q/k in self-attention.
        kv_compress_config: optional per-layer KV compression settings.
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        pred_sigma=True,
        drop_path: float = 0.,
        caption_channels=4096,
        pe_interpolation=1.0,
        pe_precision=None,
        config=None,
        model_max_length=120,
        qk_norm=False,
        kv_compress_config=None,
        **kwargs,
    ):
        super().__init__()
        self.pred_sigma = pred_sigma
        self.in_channels = in_channels
        # With pred_sigma the output stacks [eps, sigma] on the channel dim.
        self.out_channels = in_channels * 2 if pred_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.pe_interpolation = pe_interpolation
        self.pe_precision = pe_precision
        self.depth = depth
        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        num_patches = self.x_embedder.num_patches
        self.base_size = input_size // self.patch_size
        # Will use fixed sin-cos embedding:
        self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        # Projects the timestep embedding to the 6 adaLN-single params per block.
        self.t_block = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )
        self.y_embedder = CaptionEmbedder(
            in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
            act_layer=approx_gelu, token_num=model_max_length
        )
        drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)]  # stochastic depth decay rule
        self.kv_compress_config = kv_compress_config
        if kv_compress_config is None:
            # No compression by default.
            self.kv_compress_config = {
                'sampling': None,
                'scale_factor': 1,
                'kv_compress_layer': [],
            }
        self.blocks = nn.ModuleList([
            PixArtBlock(
                hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
                input_size=(input_size // patch_size, input_size // patch_size),
                sampling=self.kv_compress_config['sampling'],
                sr_ratio=int(
                    self.kv_compress_config['scale_factor']
                ) if i in self.kv_compress_config['kv_compress_layer'] else 1,
                qk_norm=qk_norm,
            )
            for i in range(depth)
        ])
        self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)

    def forward_raw(self, x, t, y, mask=None, data_info=None):
        """
        Original forward pass of PixArt.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N, 1, 120, C) tensor of class labels
        """
        # NOTE(review): self.dtype is not set in this __init__ — it is assigned
        # by subclasses (PixArtMS) or by the loader; confirm for direct use.
        x = x.to(self.dtype)
        timestep = t.to(self.dtype)
        y = y.to(self.dtype)
        pos_embed = self.pos_embed.to(self.dtype)
        self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size
        x = self.x_embedder(x) + pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(timestep.to(x.dtype))  # (N, D)
        t0 = self.t_block(t)
        y = self.y_embedder(y, self.training)  # (N, 1, L, D)
        if mask is not None:
            # Drop padded caption tokens; y_lens records per-sample token counts.
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])
        for block in self.blocks:
            x = auto_grad_checkpoint(block, x, y, t0, y_lens)  # (N, T, D) #support grad checkpoint
        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)  # (N, out_channels, H, W)
        return x

    def forward(self, x, timesteps, context, y=None, **kwargs):
        """
        Forward pass that adapts comfy input to original forward function
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        timesteps: (N,) tensor of diffusion timesteps
        context: (N, 1, 120, C) conditioning
        y: extra conditioning.
        """
        ## Still accepts the input w/o that dim but returns garbage
        if len(context.shape) == 3:
            context = context.unsqueeze(1)
        ## run original forward pass
        out = self.forward_raw(
            x = x.to(self.dtype),
            t = timesteps.to(self.dtype),
            y = context.to(self.dtype),
        )
        ## only return EPS
        # The sigma half of the output (when pred_sigma) is discarded.
        out = out.to(torch.float)
        eps, rest = out[:, :self.in_channels], out[:, self.in_channels:]
        return eps

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        # Assumes a square token grid (h == w), unlike PixArtMS.unpatchify.
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, pe_interpolation=1.0, base_size=16):
    """
    Build a fixed 2D sin-cos positional embedding.

    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    if isinstance(grid_size, int):
        grid_size = to_2tuple(grid_size)
    # Coordinates rescaled relative to the base grid and interpolation factor.
    grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0]/base_size) / pe_interpolation
    grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1]/base_size) / pe_interpolation
    mesh = np.stack(np.meshgrid(grid_w, grid_h), axis=0)  # w varies fastest
    mesh = mesh.reshape([2, 1, grid_size[1], grid_size[0]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if cls_token and extra_tokens > 0:
        # Prepend zero embeddings for the extra (cls-style) tokens.
        padding = np.zeros([extra_tokens, embed_dim])
        pos_embed = np.concatenate([padding, pos_embed], axis=0)
    return pos_embed.astype(np.float32)
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Concatenate 1D sin-cos embeddings of the h and w coordinate grids."""
    assert embed_dim % 2 == 0
    # use half of dimensions to encode grid_h, the other half for grid_w
    halves = [
        get_1d_sincos_pos_embed_from_grid(embed_dim // 2, axis_grid)  # (H*W, D/2)
        for axis_grid in (grid[0], grid[1])
    ]
    return np.concatenate(halves, axis=1)  # (H*W, D)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # Geometric frequency ladder, identical to the Transformer sinusoid.
    omega = np.arange(half, dtype=np.float64)
    omega /= embed_dim / 2.
    freqs = 1. / 10000 ** omega  # (D/2,)
    positions = np.reshape(pos, -1)  # (M,)
    angles = np.outer(positions, freqs)  # (M, D/2), outer product
    # First half sines, second half cosines.
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)

View File

@@ -0,0 +1,273 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------
import torch
import torch.nn as nn
from tqdm import tqdm
from timm.models.layers import DropPath
from timm.models.vision_transformer import Mlp
from .utils import auto_grad_checkpoint, to_2tuple
from .PixArt_blocks import t2i_modulate, CaptionEmbedder, AttentionKVCompress, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, SizeEmbedder
from .PixArt import PixArt, get_2d_sincos_pos_embed
class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding.

    Unlike the timm version, no fixed image size is assumed: any spatial
    size divisible by the patch size works.
    """
    def __init__(
            self,
            patch_size=16,
            in_chans=3,
            embed_dim=768,
            norm_layer=None,
            flatten=True,
            bias=True,
    ):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size
        self.flatten = flatten
        # Non-overlapping conv == per-patch linear projection.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        tokens = self.proj(x)
        if self.flatten:
            # BCHW -> BNC token layout
            tokens = tokens.flatten(2).transpose(1, 2)
        return self.norm(tokens)
class PixArtMSBlock(nn.Module):
    """
    A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.
    Multi-scale variant: forwards the (H, W) token grid to self-attention so
    KV compression can reshape tokens back to 2D.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
                 sampling=None, sr_ratio=1, qk_norm=False, **block_kwargs):
        super().__init__()
        self.hidden_size = hidden_size
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = AttentionKVCompress(
            hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
            qk_norm=qk_norm, **block_kwargs
        )
        self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        # tanh-approximated GELU keeps compatibility with older pytorch builds
        gelu_factory = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=gelu_factory, drop=0)
        self.drop_path = nn.Identity() if drop_path <= 0. else DropPath(drop_path)
        self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)

    def forward(self, x, y, t, mask=None, HW=None, **kwargs):
        B, N, C = x.shape
        # Six adaLN parameters: shift/scale/gate for attention and MLP.
        mods = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mods
        x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
        x = x + self.cross_attn(x, y, mask)
        x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
        return x
### Core PixArt Model ###
class PixArtMS(PixArt):
    """
    Diffusion model with a Transformer backbone.

    Multi-scale / micro-conditioned PixArt variant: positional embeddings are
    computed on the fly for the actual latent size, and (optionally) the image
    size and aspect ratio are embedded into the timestep conditioning.
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        learn_sigma=True,
        pred_sigma=True,
        drop_path: float = 0.,
        caption_channels=4096,
        pe_interpolation=None,
        pe_precision=None,
        config=None,
        model_max_length=120,
        micro_condition=True,
        qk_norm=False,
        kv_compress_config=None,
        **kwargs,
    ):
        # NOTE(review): learn_sigma is forwarded to PixArt.__init__, which only
        # absorbs it via **kwargs — confirm it is intentionally unused there.
        super().__init__(
            input_size=input_size,
            patch_size=patch_size,
            in_channels=in_channels,
            hidden_size=hidden_size,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            class_dropout_prob=class_dropout_prob,
            learn_sigma=learn_sigma,
            pred_sigma=pred_sigma,
            drop_path=drop_path,
            pe_interpolation=pe_interpolation,
            config=config,
            model_max_length=model_max_length,
            qk_norm=qk_norm,
            kv_compress_config=kv_compress_config,
            **kwargs,
        )
        self.dtype = torch.get_default_dtype()
        # Token grid size; filled in by forward_raw and read by unpatchify.
        self.h = self.w = 0
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.t_block = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )
        # Size-agnostic patch embedder replaces PixArt's fixed-size one.
        self.x_embedder = PatchEmbed(patch_size, in_channels, hidden_size, bias=True)
        self.y_embedder = CaptionEmbedder(in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob, act_layer=approx_gelu, token_num=model_max_length)
        self.micro_conditioning = micro_condition
        if self.micro_conditioning:
            self.csize_embedder = SizeEmbedder(hidden_size//3)  # c_size embed
            self.ar_embedder = SizeEmbedder(hidden_size//3)  # aspect ratio embed
        drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)]  # stochastic depth decay rule
        if kv_compress_config is None:
            # No KV compression by default.
            kv_compress_config = {
                'sampling': None,
                'scale_factor': 1,
                'kv_compress_layer': [],
            }
        self.blocks = nn.ModuleList([
            PixArtMSBlock(
                hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
                input_size=(input_size // patch_size, input_size // patch_size),
                sampling=kv_compress_config['sampling'],
                sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
                qk_norm=qk_norm,
            )
            for i in range(depth)
        ])
        self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)

    def forward_raw(self, x, t, y, mask=None, data_info=None, **kwargs):
        """
        Original forward pass of PixArt.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N, 1, 120, C) tensor of class labels
        """
        bs = x.shape[0]
        x = x.to(self.dtype)
        timestep = t.to(self.dtype)
        y = y.to(self.dtype)
        pe_interpolation = self.pe_interpolation
        if pe_interpolation is None or self.pe_precision is not None:
            # calculate pe_interpolation on-the-fly from the latent size
            pe_interpolation = round((x.shape[-1]+x.shape[-2])/2.0 / (512/8.0), self.pe_precision or 0)
        self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size
        # Positional embedding is rebuilt for the actual latent resolution.
        pos_embed = torch.from_numpy(
            get_2d_sincos_pos_embed(
                self.pos_embed.shape[-1], (self.h, self.w), pe_interpolation=pe_interpolation,
                base_size=self.base_size
            )
        ).unsqueeze(0).to(device=x.device, dtype=self.dtype)
        x = self.x_embedder(x) + pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(timestep)  # (N, D)
        if self.micro_conditioning:
            # Fold image-size and aspect-ratio embeddings into the timestep.
            c_size, ar = data_info['img_hw'].to(self.dtype), data_info['aspect_ratio'].to(self.dtype)
            csize = self.csize_embedder(c_size, bs)  # (N, D)
            ar = self.ar_embedder(ar, bs)  # (N, D)
            t = t + torch.cat([csize, ar], dim=1)
        t0 = self.t_block(t)
        y = self.y_embedder(y, self.training)  # (N, D)
        if mask is not None:
            # Drop padded caption tokens; y_lens records per-sample token counts.
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])
        for block in self.blocks:
            x = auto_grad_checkpoint(block, x, y, t0, y_lens, (self.h, self.w), **kwargs)  # (N, T, D) #support grad checkpoint
        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)  # (N, out_channels, H, W)
        return x

    def forward(self, x, timesteps, context, img_hw=None, aspect_ratio=None, **kwargs):
        """
        Forward pass that adapts comfy input to original forward function
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        timesteps: (N,) tensor of diffusion timesteps
        context: (N, 1, 120, C) conditioning
        img_hw: height|width conditioning
        aspect_ratio: aspect ratio conditioning
        """
        ## size/ar from cond with fallback based on the latent image shape.
        bs = x.shape[0]
        data_info = {}
        if img_hw is None:
            # Fallback: reconstruct the pixel size from the latent shape (x8).
            data_info["img_hw"] = torch.tensor(
                [[x.shape[2]*8, x.shape[3]*8]],
                dtype=self.dtype,
                device=x.device
            ).repeat(bs, 1)
        else:
            data_info["img_hw"] = img_hw.to(dtype=x.dtype, device=x.device)
        # NOTE(review): `or True` makes this branch unconditional — any
        # aspect_ratio conditioning passed in is ignored and the ratio is
        # always derived from the latent shape. Looks like a deliberate
        # override with the else kept for later; confirm before "fixing".
        if aspect_ratio is None or True:
            data_info["aspect_ratio"] = torch.tensor(
                [[x.shape[2]/x.shape[3]]],
                dtype=self.dtype,
                device=x.device
            ).repeat(bs, 1)
        else:
            data_info["aspect_ratio"] = aspect_ratio.to(dtype=x.dtype, device=x.device)
        ## Still accepts the input w/o that dim but returns garbage
        if len(context.shape) == 3:
            context = context.unsqueeze(1)
        ## run original forward pass
        out = self.forward_raw(
            x = x.to(self.dtype),
            t = timesteps.to(self.dtype),
            y = context.to(self.dtype),
            data_info=data_info,
        )
        ## only return EPS
        # The sigma half of the output (when pred_sigma) is discarded.
        out = out.to(torch.float)
        eps, rest = out[:, :self.in_channels], out[:, self.in_channels:]
        return eps

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        # Uses the (possibly non-square) grid recorded by forward_raw.
        assert self.h * self.w == x.shape[1]
        x = x.reshape(shape=(x.shape[0], self.h, self.w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, self.h * p, self.w * p))
        return imgs

View File

@@ -0,0 +1,477 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.vision_transformer import Mlp, Attention as Attention_
from einops import rearrange
from comfy import model_management
# Import xformers only when ComfyUI reports it available; the attention
# classes below branch on model_management.xformers_enabled() and fall back
# to torch scaled_dot_product_attention otherwise (with the caveats printed).
if model_management.xformers_enabled():
    import xformers
    import xformers.ops
else:
    print("""
########################################
PixArt: Not using xformers!
Expect images to be non-deterministic!
Batch sizes > 1 are most likely broken
########################################
""")
def modulate(x, shift, scale):
    """DiT-style adaLN: x * (1 + scale) + shift, broadcast over the token dim."""
    return torch.addcmul(shift.unsqueeze(1), x, scale.unsqueeze(1) + 1)
def t2i_modulate(x, shift, scale):
    """adaLN-single variant: x * (1 + scale) + shift, no extra broadcast dim."""
    return torch.addcmul(shift, x, scale + 1)
class MultiHeadCrossAttention(nn.Module):
    """Cross-attention from image tokens (queries) to caption tokens (keys/values)."""
    def __init__(self, d_model, num_heads, attn_drop=0., proj_drop=0., **block_kwargs):
        super(MultiHeadCrossAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.q_linear = nn.Linear(d_model, d_model)
        # Single projection producing both keys and values from the condition.
        self.kv_linear = nn.Linear(d_model, d_model*2)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(d_model, d_model)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, cond, mask=None):
        # query/value: img tokens; key: condition; mask: if padding tokens
        B, N, C = x.shape
        # The batch is flattened into one "ragged" sequence (leading dim 1) so
        # variable-length captions can be handled with a block-diagonal mask.
        q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
        kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
        k, v = kv.unbind(2)
        if model_management.xformers_enabled():
            attn_bias = None
            if mask is not None:
                # Here `mask` is a list of per-sample caption lengths.
                attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
            x = xformers.ops.memory_efficient_attention(
                q, k, v,
                p=self.attn_drop.p,
                attn_bias=attn_bias
            )
        else:
            # SDPA expects (batch, heads, tokens, head_dim).
            q, k, v = map(lambda t: t.permute(0, 2, 1, 3),(q, k, v),)
            attn_mask = None
            if mask is not None and len(mask) > 1:
                # Create equivalent of xformer diagonal block mask, still only correct for square masks
                # But depth doesn't matter as tensors can expand in that dimension
                attn_mask_template = torch.ones(
                    [q.shape[2] // B, mask[0]],
                    dtype=torch.bool,
                    device=q.device
                )
                attn_mask = torch.block_diag(attn_mask_template)
                # create a mask on the diagonal for each mask in the batch
                for n in range(B - 1):
                    attn_mask = torch.block_diag(attn_mask, attn_mask_template)
            x = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p
            ).permute(0, 2, 1, 3).contiguous()
        # Restore the (B, N, C) layout before the output projection.
        x = x.view(B, -1, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class AttentionKVCompress(Attention_):
    """Multi-head Attention block with KV token compression and qk norm."""
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=True,
        sampling='conv',
        sr_ratio=1,
        qk_norm=False,
        **block_kwargs,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool: If True, add a learnable bias to query, key, value.
            sampling (str): KV downsampling mode, one of
                ['conv', 'ave', 'uniform', 'uniform_every'].
            sr_ratio (int): spatial reduction ratio for keys/values (1 = off).
            qk_norm (bool): LayerNorm q and k over the full channel dim.
        """
        super().__init__(dim, num_heads=num_heads, qkv_bias=qkv_bias, **block_kwargs)
        self.sampling=sampling    # ['conv', 'ave', 'uniform', 'uniform_every']
        self.sr_ratio = sr_ratio
        if sr_ratio > 1 and sampling == 'conv':
            # Avg Conv Init.
            # Depthwise conv initialized to average pooling; learnable afterwards.
            self.sr = nn.Conv2d(dim, dim, groups=dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.sr.weight.data.fill_(1/sr_ratio**2)
            self.sr.bias.data.zero_()
            self.norm = nn.LayerNorm(dim)
        if qk_norm:
            # Applied before the head split, i.e. over the full channel dim.
            self.q_norm = nn.LayerNorm(dim)
            self.k_norm = nn.LayerNorm(dim)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()

    def downsample_2d(self, tensor, H, W, scale_factor, sampling=None):
        # Spatially downsamples (B, N, C) tokens; returns (tensor, new_N).
        # NOTE(review): the pass-through branch below returns only the tensor
        # (no tuple); callers only reach this with sampling set and
        # sr_ratio > 1, so the mismatch is currently unreachable — confirm
        # before reusing this helper elsewhere.
        if sampling is None or scale_factor == 1:
            return tensor
        B, N, C = tensor.shape

        if sampling == 'uniform_every':
            # Keep every scale_factor-th token along the sequence dim.
            return tensor[:, ::scale_factor], int(N // scale_factor)

        # Back to 2D (B, C, H, W) for spatial downsampling.
        tensor = tensor.reshape(B, H, W, C).permute(0, 3, 1, 2)
        new_H, new_W = int(H / scale_factor), int(W / scale_factor)
        new_N = new_H * new_W

        if sampling == 'ave':
            tensor = F.interpolate(
                tensor, scale_factor=1 / scale_factor, mode='nearest'
            ).permute(0, 2, 3, 1)
        elif sampling == 'uniform':
            tensor = tensor[:, :, ::scale_factor, ::scale_factor].permute(0, 2, 3, 1)
        elif sampling == 'conv':
            tensor = self.sr(tensor).reshape(B, C, -1).permute(0, 2, 1)
            tensor = self.norm(tensor)
        else:
            raise ValueError

        return tensor.reshape(B, new_N, C).contiguous(), new_N

    def forward(self, x, mask=None, HW=None, block_id=None):
        B, N, C = x.shape  # 2 4096 1152
        new_N = N
        if HW is None:
            # Assume a square token grid when the caller provides no size.
            H = W = int(N ** 0.5)
        else:
            H, W = HW
        qkv = self.qkv(x).reshape(B, N, 3, C)
        q, k, v = qkv.unbind(2)
        dtype = q.dtype
        # q/k norm happens on the full channel dim, before the head split.
        q = self.q_norm(q)
        k = self.k_norm(k)

        # KV compression
        if self.sr_ratio > 1:
            k, new_N = self.downsample_2d(k, H, W, self.sr_ratio, sampling=self.sampling)
            v, new_N = self.downsample_2d(v, H, W, self.sr_ratio, sampling=self.sampling)

        # Head split; restore the pre-norm dtype (LayerNorm may upcast).
        q = q.reshape(B, N, self.num_heads, C // self.num_heads).to(dtype)
        k = k.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype)
        v = v.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype)

        attn_bias = None
        if mask is not None:
            # Additive mask: -inf where the (0/1) mask is 0.
            attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
            attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float('-inf'))

        # Switch between torch / xformers attention
        if model_management.xformers_enabled():
            x = xformers.ops.memory_efficient_attention(
                q, k, v,
                p=self.attn_drop.p,
                attn_bias=attn_bias
            )
        else:
            # SDPA expects (batch, heads, tokens, head_dim).
            q, k, v = map(lambda t: t.transpose(1, 2),(q, k, v),)
            x = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p,
                attn_mask=attn_bias
            ).transpose(1, 2).contiguous()

        x = x.view(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
#################################################################################
# AMP attention with fp32 softmax to fix loss NaN problem during training #
#################################################################################
class Attention(Attention_):
    # timm Attention with an optional fp32 softmax path (enable by setting
    # `self.fp32_attention = True`) to avoid NaNs under AMP training.
    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)
        use_fp32_attention = getattr(self, 'fp32_attention', False)
        if use_fp32_attention:
            q, k = q.float(), k.float()
        # Disable autocast when fp32 attention is requested so the matmul and
        # softmax actually run in fp32.
        with torch.cuda.amp.autocast(enabled=not use_fp32_attention):
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class FinalLayer(nn.Module):
    """
    The final layer of PixArt (DiT-style adaLN variant): LayerNorm modulated
    by the conditioning vector, then a linear projection to patch pixels.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        # The conditioning vector yields one shift/scale pair for the norm.
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
class T2IFinalLayer(nn.Module):
    """
    The final layer of PixArt (T2I variant): shift/scale come from a learned
    table added to the timestep embedding rather than an adaLN MLP.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5)
        self.out_channels = out_channels

    def forward(self, x, t):
        # Learned table offset by the conditioning, split into shift and scale.
        table = self.scale_shift_table[None] + t[:, None]
        shift, scale = table.chunk(2, dim=1)
        return self.linear(t2i_modulate(self.norm_final(x), shift, scale))
class MaskFinalLayer(nn.Module):
    """
    The final layer of PixArt's mask branch: adaLN modulation from a separate
    conditioning embedding size, then projection to patch pixels.
    """
    def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(c_emb_size, 2 * final_hidden_size, bias=True)
        )

    def forward(self, x, t):
        shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
        normed = modulate(self.norm_final(x), shift, scale)
        return self.linear(normed)
class DecoderLayer(nn.Module):
    """
    Projection from the transformer width down to the decoder hidden size,
    with adaLN modulation from the conditioning embedding.
    """
    def __init__(self, hidden_size, decoder_hidden_size):
        super().__init__()
        self.norm_decoder = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, decoder_hidden_size, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, t):
        shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
        normed = modulate(self.norm_decoder(x), shift, scale)
        return self.linear(normed)
#################################################################################
# Embedding Layers for Timesteps and Class Labels #
#################################################################################
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations via a sinusoidal
    frequency embedding followed by a two-layer MLP.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
        angles = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        if dim % 2:
            # Pad odd dims with a zero column.
            pad = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, pad], dim=-1)
        return embedding

    def forward(self, t):
        freq_emb = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(freq_emb.to(t.dtype))
class SizeEmbedder(TimestepEmbedder):
    """Embeds per-sample scalar size conditioning (one embedding per column
    of ``s``, concatenated) into vector representations."""
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size)
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.outdim = hidden_size
    def forward(self, s, bs):
        if s.ndim == 1:
            s = s[:, None]
        assert s.ndim == 2
        # Broadcast shared conditioning across the batch when needed.
        if s.shape[0] != bs:
            s = s.repeat(bs // s.shape[0], 1)
        assert s.shape[0] == bs
        b, dims = s.shape[0], s.shape[1]
        # Flatten, embed each scalar, then regroup per batch element; this is
        # the einops "b d -> (b d)" / "(b d) d2 -> b (d d2)" pair expressed
        # with plain reshape.
        flat = s.reshape(-1)
        freq_emb = self.timestep_embedding(flat, self.frequency_embedding_size)
        emb = self.mlp(freq_emb.to(flat.dtype))
        return emb.reshape(b, dims * self.outdim)
class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label
    dropout for classifier-free guidance: dropped labels are replaced by an
    extra "null" embedding row at index ``num_classes``.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        # Reserve one extra embedding row for the CFG null label when
        # dropout is enabled (bool adds 0 or 1 to the table size).
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob
    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.

        force_drop_ids: optional (N,) tensor; entries equal to 1 force a drop.
        """
        if force_drop_ids is None:
            # Bug fix: sample on the labels' device instead of hard-coding
            # .cuda(), so CPU / non-default-GPU execution works too.
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            drop_ids = force_drop_ids == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels
    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings
class CaptionEmbedder(nn.Module):
    """
    Embeds caption features into vector representations. Also handles caption
    dropout for classifier-free guidance: dropped captions are replaced by the
    learned unconditional embedding ``y_embedding``.
    """
    def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120):
        super().__init__()
        self.y_proj = Mlp(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0)
        self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5))
        self.uncond_prob = uncond_prob
    def token_drop(self, caption, force_drop_ids=None):
        """
        Drops captions to enable classifier-free guidance.

        caption: (N, 1, token_num, in_channels) caption features.
        force_drop_ids: optional (N,) tensor; entries equal to 1 force a drop.
        """
        if force_drop_ids is None:
            # Bug fix: sample on the caption's device instead of hard-coding
            # .cuda(), so CPU / non-default-GPU execution works too.
            drop_ids = torch.rand(caption.shape[0], device=caption.device) < self.uncond_prob
        else:
            drop_ids = force_drop_ids == 1
        caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
        return caption
    def forward(self, caption, train, force_drop_ids=None):
        if train:
            assert caption.shape[2:] == self.y_embedding.shape
        use_dropout = self.uncond_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            caption = self.token_drop(caption, force_drop_ids)
        caption = self.y_proj(caption)
        return caption
class CaptionEmbedderDoubleBr(nn.Module):
    """
    Double-branch caption embedder: yields a pooled global caption embedding
    plus per-token caption features, each with its own learned unconditional
    replacement for classifier-free guidance dropout.
    """
    def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120):
        super().__init__()
        self.proj = Mlp(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0)
        # Unconditional replacements: one row for the global branch, one full
        # token grid for the per-token branch.
        self.embedding = nn.Parameter(torch.randn(1, in_channels) / 10 ** 0.5)
        self.y_embedding = nn.Parameter(torch.randn(token_num, in_channels) / 10 ** 0.5)
        self.uncond_prob = uncond_prob
    def token_drop(self, global_caption, caption, force_drop_ids=None):
        """
        Drops captions to enable classifier-free guidance.

        global_caption: (N, in_channels) pooled caption embedding.
        caption: (N, 1, token_num, in_channels) per-token features.
        """
        if force_drop_ids is None:
            # Bug fix: sample on the input's device instead of hard-coding
            # .cuda(), so CPU / non-default-GPU execution works too.
            drop_ids = torch.rand(global_caption.shape[0], device=global_caption.device) < self.uncond_prob
        else:
            drop_ids = force_drop_ids == 1
        global_caption = torch.where(drop_ids[:, None], self.embedding, global_caption)
        caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
        return global_caption, caption
    def forward(self, caption, train, force_drop_ids=None):
        assert caption.shape[2:] == self.y_embedding.shape
        # NOTE(review): .squeeze() removes ALL singleton dims, so a batch of
        # one collapses the batch dimension too -- confirm callers never pass N=1.
        global_caption = caption.mean(dim=2).squeeze()
        use_dropout = self.uncond_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            global_caption, caption = self.token_drop(global_caption, caption, force_drop_ids)
        y_embed = self.proj(global_caption)
        return y_embed, caption

View File

@@ -0,0 +1,312 @@
import re
import torch
import torch.nn as nn
from copy import deepcopy
from torch import Tensor
from torch.nn import Module, Linear, init
from typing import Any, Mapping
from .PixArt import PixArt, get_2d_sincos_pos_embed
from .PixArtMS import PixArtMSBlock, PixArtMS
from .utils import auto_grad_checkpoint
# The implementation of ControlNet-Half architrecture
# https://github.com/lllyasviel/ControlNet/discussions/188
class ControlT2IDitBlockHalf(Module):
    """One trainable block copy for the ControlNet-Half architecture.

    The base PixArtMS block is deep-copied and unfrozen; block 0 gets a
    zero-initialized input projection and every block gets a zero-initialized
    output projection, so the control branch starts as a no-op.
    """
    # Bug fix: the original annotated the parameter as the literal value
    # (`block_index: 0`); use a proper int annotation with a default of 0.
    def __init__(self, base_block: PixArtMSBlock, block_index: int = 0) -> None:
        super().__init__()
        self.copied_block = deepcopy(base_block)
        self.block_index = block_index
        for p in self.copied_block.parameters():
            p.requires_grad_(True)
        self.copied_block.load_state_dict(base_block.state_dict())
        self.copied_block.train()
        self.hidden_size = hidden_size = base_block.hidden_size
        if self.block_index == 0:
            # Zero init: the hint contributes nothing at training step 0.
            self.before_proj = Linear(hidden_size, hidden_size)
            init.zeros_(self.before_proj.weight)
            init.zeros_(self.before_proj.bias)
        self.after_proj = Linear(hidden_size, hidden_size)
        init.zeros_(self.after_proj.weight)
        init.zeros_(self.after_proj.bias)
    def forward(self, x, y, t, mask=None, c=None):
        """Return (c, c_skip): updated control state and skip contribution."""
        if self.block_index == 0:
            # the first block: project the hint and fuse it with x
            c = self.before_proj(c)
            c = self.copied_block(x + c, y, t, mask)
            c_skip = self.after_proj(c)
        else:
            # load from previous c and produce the c for skip connection
            c = self.copied_block(c, y, t, mask)
            c_skip = self.after_proj(c)
        return c, c_skip
# The implementation of ControlPixArtHalf net
class ControlPixArtHalf(Module):
    """ControlNet-Half wrapper around a single-resolution PixArt model.

    The first ``copy_blocks_num`` transformer blocks get trainable copies
    (``self.controlnet``); the base model stays frozen. Unknown attribute
    access falls through to the base model so the wrapper is a drop-in
    replacement for a PixArt.
    """
    # only support single res model
    def __init__(self, base_model: PixArt, copy_blocks_num: int = 13) -> None:
        super().__init__()
        self.dtype = torch.get_default_dtype()
        self.base_model = base_model.eval()
        self.controlnet = []
        self.copy_blocks_num = copy_blocks_num
        self.total_blocks_num = len(base_model.blocks)
        # Freeze the base model; only the copied control blocks are trainable.
        for p in self.base_model.parameters():
            p.requires_grad_(False)
        # Copy first copy_blocks_num block
        for i in range(copy_blocks_num):
            self.controlnet.append(ControlT2IDitBlockHalf(base_model.blocks[i], i))
        self.controlnet = nn.ModuleList(self.controlnet)
    def __getattr__(self, name: str) -> Tensor or Module:
        # Route wrapper-defined members normally and delegate everything else
        # to the frozen base model. NOTE(review): the `Tensor or Module`
        # annotation evaluates to just Tensor at runtime; documentation only.
        if name in ['forward', 'forward_with_dpmsolver', 'forward_with_cfg', 'forward_c', 'load_state_dict']:
            return self.__dict__[name]
        elif name in ['base_model', 'controlnet']:
            return super().__getattr__(name)
        else:
            return getattr(self.base_model, name)
    def forward_c(self, c):
        """Patch-embed the control hint and add a freshly computed positional
        embedding for the hint's resolution. Callers only invoke this when
        ``c`` is not None (the trailing ternary is then always the add path)."""
        self.h, self.w = c.shape[-2]//self.patch_size, c.shape[-1]//self.patch_size
        pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.pos_embed.shape[-1], (self.h, self.w), lewei_scale=self.lewei_scale, base_size=self.base_size)).unsqueeze(0).to(c.device).to(self.dtype)
        return self.x_embedder(c) + pos_embed if c is not None else c
    # def forward(self, x, t, c, **kwargs):
    #     return self.base_model(x, t, c=self.forward_c(c), **kwargs)
    def forward_raw(self, x, timestep, y, mask=None, data_info=None, c=None, **kwargs):
        # modify the original PixArtMS forward function
        if c is not None:
            c = c.to(self.dtype)
            c = self.forward_c(c)
        """
        Forward pass of PixArt.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N, 1, 120, C) tensor of class labels
        """
        x = x.to(self.dtype)
        timestep = timestep.to(self.dtype)
        y = y.to(self.dtype)
        pos_embed = self.pos_embed.to(self.dtype)
        self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size
        x = self.x_embedder(x) + pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(timestep.to(x.dtype))  # (N, D)
        t0 = self.t_block(t)
        y = self.y_embedder(y, self.training)  # (N, 1, L, D)
        if mask is not None:
            # Pack the variable-length captions into one flat sequence;
            # y_lens records each sample's true caption length.
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])
        # define the first layer
        x = auto_grad_checkpoint(self.base_model.blocks[0], x, y, t0, y_lens, **kwargs)  # (N, T, D) #support grad checkpoint
        if c is not None:
            # update c: run control copies alongside the first copy_blocks_num
            # base blocks, injecting each block's skip contribution into x.
            for index in range(1, self.copy_blocks_num + 1):
                c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, t0, y_lens, c, **kwargs)
                x = auto_grad_checkpoint(self.base_model.blocks[index], x + c_skip, y, t0, y_lens, **kwargs)
            # update x: remaining blocks run without control injection.
            for index in range(self.copy_blocks_num + 1, self.total_blocks_num):
                x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs)
        else:
            for index in range(1, self.total_blocks_num):
                x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs)
        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)  # (N, out_channels, H, W)
        return x
    def forward(self, x, timesteps, context, cn_hint=None, **kwargs):
        """
        Forward pass that adapts comfy input to original forward function
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        timesteps: (N,) tensor of diffusion timesteps
        context: (N, 1, 120, C) conditioning
        cn_hint: controlnet hint
        """
        ## Still accepts the input w/o that dim but returns garbage
        if len(context.shape) == 3:
            context = context.unsqueeze(1)
        ## run original forward pass
        out = self.forward_raw(
            x = x.to(self.dtype),
            timestep = timesteps.to(self.dtype),
            y = context.to(self.dtype),
            c = cn_hint,
        )
        ## only return EPS
        out = out.to(torch.float)
        eps, rest = out[:, :self.in_channels], out[:, self.in_channels:]
        return eps
    def forward_with_dpmsolver(self, x, t, y, data_info, c, **kwargs):
        # Return only the first chunk (EPS) of the model output.
        model_out = self.forward_raw(x, t, y, data_info=data_info, c=c, **kwargs)
        return model_out.chunk(2, dim=1)[0]
    # def forward_with_dpmsolver(self, x, t, y, data_info, c, **kwargs):
    #     return self.base_model.forward_with_dpmsolver(x, t, y, data_info=data_info, c=self.forward_c(c), **kwargs)
    def forward_with_cfg(self, x, t, y, cfg_scale, data_info, c, **kwargs):
        return self.base_model.forward_with_cfg(x, t, y, cfg_scale, data_info, c=self.forward_c(c), **kwargs)
    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        """Load either a wrapped checkpoint (all keys prefixed with
        base_model/controlnet) or a plain PixArt checkpoint, in which case
        "blocks.N..." keys are rewritten to "blocks.N.base_block..." before
        loading into the base model."""
        if all((k.startswith('base_model') or k.startswith('controlnet')) for k in state_dict.keys()):
            return super().load_state_dict(state_dict, strict)
        else:
            new_key = {}
            for k in state_dict.keys():
                new_key[k] = re.sub(r"(blocks\.\d+)(.*)", r"\1.base_block\2", k)
            for k, v in new_key.items():
                if k != v:
                    print(f"replace {k} to {v}")
                    state_dict[v] = state_dict.pop(k)
            return self.base_model.load_state_dict(state_dict, strict)
    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        # Relies on self.h/self.w set by the preceding forward pass.
        assert self.h * self.w == x.shape[1]
        x = x.reshape(shape=(x.shape[0], self.h, self.w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, self.h * p, self.w * p))
        return imgs
    # @property
    # def dtype(self):
    #     # return the data type of the model parameters
    #     return next(self.parameters()).dtype
# The implementation for PixArtMS_Half + 1024 resolution
class ControlPixArtMSHalf(ControlPixArtHalf):
    """ControlNet-Half over the multi-scale PixArtMS (1024px) model.

    Same control-block routing as ControlPixArtHalf, plus image-size and
    aspect-ratio conditioning embeddings.
    """
    # support multi-scale res model (multi-scale model can also be applied to single reso training & inference)
    def __init__(self, base_model: PixArtMS, copy_blocks_num: int = 13) -> None:
        super().__init__(base_model=base_model, copy_blocks_num=copy_blocks_num)
    def forward_raw(self, x, timestep, y, mask=None, data_info=None, c=None, **kwargs):
        # modify the original PixArtMS forward function
        """
        Forward pass of PixArt.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N, 1, 120, C) tensor of class labels
        """
        if c is not None:
            c = c.to(self.dtype)
            c = self.forward_c(c)
        bs = x.shape[0]
        x = x.to(self.dtype)
        timestep = timestep.to(self.dtype)
        y = y.to(self.dtype)
        # Size / aspect-ratio conditioning supplied via data_info.
        c_size, ar = data_info['img_hw'].to(self.dtype), data_info['aspect_ratio'].to(self.dtype)
        self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size
        # Positional embedding is recomputed for the current resolution.
        pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.pos_embed.shape[-1], (self.h, self.w), lewei_scale=self.lewei_scale, base_size=self.base_size)).unsqueeze(0).to(x.device).to(self.dtype)
        x = self.x_embedder(x) + pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(timestep)  # (N, D)
        csize = self.csize_embedder(c_size, bs)  # (N, D)
        ar = self.ar_embedder(ar, bs)  # (N, D)
        # Size/aspect embeddings are folded into the timestep embedding.
        t = t + torch.cat([csize, ar], dim=1)
        t0 = self.t_block(t)
        y = self.y_embedder(y, self.training)  # (N, D)
        if mask is not None:
            # Pack the variable-length captions into one flat sequence.
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])
        # define the first layer
        x = auto_grad_checkpoint(self.base_model.blocks[0], x, y, t0, y_lens, **kwargs)  # (N, T, D) #support grad checkpoint
        if c is not None:
            # update c: control copies run alongside the first copy_blocks_num
            # base blocks, injecting skip contributions into x.
            for index in range(1, self.copy_blocks_num + 1):
                c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, t0, y_lens, c, **kwargs)
                x = auto_grad_checkpoint(self.base_model.blocks[index], x + c_skip, y, t0, y_lens, **kwargs)
            # update x: remaining blocks run without control injection.
            for index in range(self.copy_blocks_num + 1, self.total_blocks_num):
                x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs)
        else:
            for index in range(1, self.total_blocks_num):
                x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs)
        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)  # (N, out_channels, H, W)
        return x
    def forward(self, x, timesteps, context, img_hw=None, aspect_ratio=None, cn_hint=None, **kwargs):
        """
        Forward pass that adapts comfy input to original forward function
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        timesteps: (N,) tensor of diffusion timesteps
        context: (N, 1, 120, C) conditioning
        img_hw: height|width conditioning
        aspect_ratio: aspect ratio conditioning
        cn_hint: controlnet hint
        """
        ## size/ar from cond with fallback based on the latent image shape.
        bs = x.shape[0]
        data_info = {}
        if img_hw is None:
            # Fallback: pixel size inferred from the latent (8x upscale).
            data_info["img_hw"] = torch.tensor(
                [[x.shape[2]*8, x.shape[3]*8]],
                dtype=self.dtype,
                device=x.device
            ).repeat(bs, 1)
        else:
            data_info["img_hw"] = img_hw.to(x.dtype)
        # NOTE(review): `or True` makes this condition always true, so the
        # aspect_ratio argument is never used and the latent-derived fallback
        # always wins -- confirm this override is intentional.
        if aspect_ratio is None or True:
            data_info["aspect_ratio"] = torch.tensor(
                [[x.shape[2]/x.shape[3]]],
                dtype=self.dtype,
                device=x.device
            ).repeat(bs, 1)
        else:
            data_info["aspect_ratio"] = aspect_ratio.to(x.dtype)
        ## Still accepts the input w/o that dim but returns garbage
        if len(context.shape) == 3:
            context = context.unsqueeze(1)
        ## run original forward pass
        out = self.forward_raw(
            x = x.to(self.dtype),
            timestep = timesteps.to(self.dtype),
            y = context.to(self.dtype),
            c = cn_hint,
            data_info=data_info,
        )
        ## only return EPS
        out = out.to(torch.float)
        eps, rest = out[:, :self.in_channels], out[:, self.in_channels:]
        return eps

View File

@@ -0,0 +1,122 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
from collections.abc import Iterable
from itertools import repeat
def _ntuple(n):
def parse(x):
if isinstance(x, Iterable) and not isinstance(x, str):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
def set_grad_checkpoint(model, use_fp32_attention=False, gc_step=1):
    """Tag every submodule of *model* for gradient checkpointing.

    The flags are read back by ``auto_grad_checkpoint``.
    """
    assert isinstance(model, nn.Module)
    def tag(module):
        module.grad_checkpointing = True
        module.fp32_attention = use_fp32_attention
        module.grad_checkpointing_step = gc_step
    model.apply(tag)
def auto_grad_checkpoint(module, *args, **kwargs):
    """Run *module*, checkpointing it when tagged by ``set_grad_checkpoint``."""
    if not getattr(module, 'grad_checkpointing', False):
        # Untagged callables run directly.
        return module(*args, **kwargs)
    if isinstance(module, Iterable):
        # Tagged containers checkpoint in segments of grad_checkpointing_step.
        gc_step = module[0].grad_checkpointing_step
        return checkpoint_sequential(module, gc_step, *args, **kwargs)
    return checkpoint(module, *args, **kwargs)
def checkpoint_sequential(functions, step, input, *args, **kwargs):
    """Run *functions* sequentially on *input*, checkpointing in segments of
    *step* functions; extra positional args are forwarded to every function.
    The final segment is executed without checkpointing."""
    # Hack for keyword-only parameter in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(kwargs))
    def make_runner(start, end):
        def run(value):
            for idx in range(start, end + 1):
                value = functions[idx](value, *args)
            return value
        return run
    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())
    # the last chunk has to be non-volatile
    end = -1
    segment = len(functions) // step
    for start in range(0, step * (segment - 1), step):
        end = start + step - 1
        input = checkpoint(make_runner(start, end), input, preserve_rng_state=preserve)
    return make_runner(end + 1, len(functions) - 1)(input)
def get_rel_pos(q_size, k_size, rel_pos):
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.
    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).
    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    if rel_pos.shape[0] == max_rel_dist:
        resized = rel_pos
    else:
        # Linearly interpolate the table to the required number of offsets.
        as_channels = rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1)
        interpolated = F.interpolate(as_channels, size=max_rel_dist, mode="linear")
        resized = interpolated.reshape(-1, max_rel_dist).permute(1, 0)
    # Scale the coords with short length if shapes for q and k are different.
    q_scale = max(k_size / q_size, 1.0)
    k_scale = max(q_size / k_size, 1.0)
    q_coords = torch.arange(q_size)[:, None] * q_scale
    k_coords = torch.arange(k_size)[None, :] * k_scale
    relative_coords = (q_coords - k_coords) + (k_size - 1) * k_scale
    return resized[relative_coords.long()]
def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
    Returns:
        attn (Tensor): attention map with added relative positional embeddings.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    height_table = get_rel_pos(q_h, k_h, rel_pos_h)
    width_table = get_rel_pos(q_w, k_w, rel_pos_w)
    batch, _, dim = q.shape
    spatial_q = q.reshape(batch, q_h, q_w, dim)
    # Per-axis contributions: query features contracted with each table.
    rel_h = torch.einsum("bhwc,hkc->bhwk", spatial_q, height_table)
    rel_w = torch.einsum("bhwc,wkc->bhwk", spatial_q, width_table)
    attn_5d = attn.view(batch, q_h, q_w, k_h, k_w)
    attn_5d = attn_5d + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    return attn_5d.view(batch, q_h * q_w, k_h * k_w)

View File

@@ -0,0 +1,38 @@
import torch
from comfy import model_management
def string_to_dtype(s="none", mode=None):
s = s.lower().strip()
if s in ["default", "as-is"]:
return None
elif s in ["auto", "auto (comfy)"]:
if mode == "vae":
return model_management.vae_device()
elif mode == "text_encoder":
return model_management.text_encoder_dtype()
elif mode == "unet":
return model_management.unet_dtype()
else:
raise NotImplementedError(f"Unknown dtype mode '{mode}'")
elif s in ["none", "auto (hf)", "auto (hf/bnb)"]:
return None
elif s in ["fp32", "float32", "float"]:
return torch.float32
elif s in ["bf16", "bfloat16"]:
return torch.bfloat16
elif s in ["fp16", "float16", "half"]:
return torch.float16
elif "fp8" in s or "float8" in s:
if "e5m2" in s:
return torch.float8_e5m2
elif "e4m3" in s:
return torch.float8_e4m3fn
else:
raise NotImplementedError(f"Unknown 8bit dtype '{s}'")
elif "bnb" in s:
assert s in ["bnb8bit", "bnb4bit"], f"Unknown bnb mode '{s}'"
return s
elif s is None:
return None
else:
raise NotImplementedError(f"Unknown dtype '{s}'")

View File

@@ -0,0 +1,139 @@
#credit to Acly for this module
#from https://github.com/Acly/comfyui-inpaint-nodes
import torch
import torch.nn.functional as F
import comfy
from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher
from comfy.model_management import cast_to_device
from ...libs.log import log_node_warn, log_node_error, log_node_info
class InpaintHead(torch.nn.Module):
    """Fooocus inpaint head: a 3x3 conv over a 5-channel [mask, latent]
    feed producing 320 feature channels."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Weights are loaded later from the patch file; empty init is a placeholder.
        self.head = torch.nn.Parameter(torch.empty(size=(320, 5, 3, 3), device="cpu"))
    def __call__(self, x):
        padded = F.pad(x, (1, 1, 1, 1), "replicate")
        return F.conv2d(padded, weight=self.head)
# injected_model_patcher_calculate_weight = False
# original_calculate_weight = None
class applyFooocusInpaint:
    """Context manager that temporarily replaces comfy's LoRA weight merge
    with a version that understands "fooocus" patches.

    Fooocus inpaint patches store quantized deltas as
    ``(uint8_tensor, w_min, w_max)`` triples; ``calculate_weight_patched``
    dequantizes and merges them, delegating all other patch types to the
    saved original implementation.
    """
    def calculate_weight_patched(self, patches, weight, key, intermediate_dtype=torch.float32):
        remaining = []
        for p in patches:
            alpha = p[0]
            v = p[1]
            is_fooocus_patch = isinstance(v, tuple) and len(v) == 2 and v[0] == "fooocus"
            if not is_fooocus_patch:
                remaining.append(p)
                continue
            if alpha != 0.0:
                v = v[1]
                w1 = cast_to_device(v[0], weight.device, torch.float32)
                if w1.shape == weight.shape:
                    # Dequantize: stored as uint8 in [0, 255], rescaled to [w_min, w_max].
                    w_min = cast_to_device(v[1], weight.device, torch.float32)
                    w_max = cast_to_device(v[2], weight.device, torch.float32)
                    w1 = (w1 / 255.0) * (w_max - w_min) + w_min
                    weight += alpha * cast_to_device(w1, weight.device, weight.dtype)
                else:
                    print(
                        f"[ApplyFooocusInpaint] Shape mismatch {key}, weight not merged ({w1.shape} != {weight.shape})"
                    )
        if len(remaining) > 0:
            # Non-fooocus patches go through the original implementation.
            return self.original_calculate_weight(remaining, weight, key, intermediate_dtype)
        return weight
    def __enter__(self):
        # Newer comfy exposes comfy.lora.calculate_weight; older versions keep
        # it on ModelPatcher. Save whichever exists, then install the patch.
        try:
            print("[comfyui-easy-use] Injecting patched comfy.lora.calculate_weight.calculate_weight")
            self.original_calculate_weight = comfy.lora.calculate_weight
            comfy.lora.calculate_weight = self.calculate_weight_patched
        except AttributeError:
            print("[comfyui-easy-use] Injecting patched comfy.model_patcher.ModelPatcher.calculate_weight")
            self.original_calculate_weight = ModelPatcher.calculate_weight
            ModelPatcher.calculate_weight = self.calculate_weight_patched
    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the saved implementation. Bug fix: catch only AttributeError
        # (mirroring __enter__) instead of a bare `except:`, which would also
        # swallow KeyboardInterrupt/SystemExit.
        try:
            comfy.lora.calculate_weight = self.original_calculate_weight
        except AttributeError:
            ModelPatcher.calculate_weight = self.original_calculate_weight
# def inject_patched_calculate_weight():
# global injected_model_patcher_calculate_weight
# if not injected_model_patcher_calculate_weight:
# try:
# print("[comfyui-easy-use] Injecting patched comfy.lora.calculate_weight.calculate_weight")
# original_calculate_weight = comfy.lora.calculate_weight
# comfy.lora.original_calculate_weight = original_calculate_weight
# comfy.lora.calculate_weight = calculate_weight_patched
# except AttributeError:
# print("[comfyui-easy-use] Injecting patched comfy.model_patcher.ModelPatcher.calculate_weight")
# original_calculate_weight = ModelPatcher.calculate_weight
# ModelPatcher.original_calculate_weight = original_calculate_weight
# ModelPatcher.calculate_weight = calculate_weight_patched
# injected_model_patcher_calculate_weight = True
class InpaintWorker:
    """Applies a Fooocus inpaint patch (head model + quantized lora) to a
    ComfyUI model via an input-block patch and weight patches."""
    def __init__(self, node_name):
        self.node_name = node_name if node_name is not None else ""
    def load_fooocus_patch(self, lora: dict, to_load: dict):
        """Map model keys to ("fooocus", value) patch entries found in *lora*.

        NOTE(review): the walrus `if value := lora.get(key, None)` also skips
        falsy stored values, not just missing keys -- confirm stored values
        are always truthy tuples/tensors.
        """
        patch_dict = {}
        loaded_keys = set()
        for key in to_load.values():
            if value := lora.get(key, None):
                patch_dict[key] = ("fooocus", value)
                loaded_keys.add(key)
        not_loaded = sum(1 for x in lora if x not in loaded_keys)
        if not_loaded > 0:
            log_node_info(self.node_name,
                f"{len(loaded_keys)} Lora keys loaded, {not_loaded} remaining keys not found in model."
            )
        return patch_dict
    def _input_block_patch(self, h: torch.Tensor, transformer_options: dict):
        # Inject the inpaint-head feature into the first input block only.
        if transformer_options["block"][1] == 0:
            # Rebuild the cached feature when the batch/spatial shape changes.
            if self._inpaint_block is None or self._inpaint_block.shape != h.shape:
                assert self._inpaint_head_feature is not None
                batch = h.shape[0] // self._inpaint_head_feature.shape[0]
                self._inpaint_block = self._inpaint_head_feature.to(h).repeat(batch, 1, 1, 1)
            h = h + self._inpaint_block
        return h
    def patch(self, model, latent, patch):
        """Clone *model* and attach the Fooocus inpaint head plus lora patches.

        :param latent: dict with "samples" and "noise_mask" tensors.
        :param patch: (inpaint_head_model, inpaint_lora) pair.
        :return: 1-tuple containing the patched model clone.
        """
        base_model: BaseModel = model.model
        latent_pixels = base_model.process_latent_in(latent["samples"])
        noise_mask = latent["noise_mask"].round()
        # Downsample the pixel-space mask to latent resolution (8x).
        latent_mask = F.max_pool2d(noise_mask, (8, 8)).round().to(latent_pixels)
        inpaint_head_model, inpaint_lora = patch
        feed = torch.cat([latent_mask, latent_pixels], dim=1)
        inpaint_head_model.to(device=feed.device, dtype=feed.dtype)
        # Feature injected by _input_block_patch during sampling.
        self._inpaint_head_feature = inpaint_head_model(feed)
        self._inpaint_block = None
        lora_keys = comfy.lora.model_lora_keys_unet(model.model, {})
        lora_keys.update({x: x for x in base_model.state_dict().keys()})
        loaded_lora = self.load_fooocus_patch(inpaint_lora, lora_keys)
        m = model.clone()
        m.set_model_input_block_patch(self._input_block_patch)
        patched = m.add_patches(loaded_lora, 1.0)
        m.model_options['transformer_options']['fooocus'] = True
        not_patched_count = sum(1 for x in loaded_lora if x not in patched)
        if not_patched_count > 0:
            log_node_error(self.node_name, f"Failed to patch {not_patched_count} keys")
        # inject_patched_calculate_weight()
        return (m,)

View File

@@ -0,0 +1,156 @@
import torch
import numpy as np
import cv2
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from .simple_extractor_dataset import SimpleFolderDataset
from .transforms import transform_logits
from tqdm import tqdm
from PIL import Image
def get_palette(num_cls):
""" Returns the color map for visualizing the segmentation mask.
Args:
num_cls: Number of classes
Returns:
The color map
"""
n = num_cls
palette = [0] * (n * 3)
for j in range(0, n):
lab = j
palette[j * 3 + 0] = 0
palette[j * 3 + 1] = 0
palette[j * 3 + 2] = 0
i = 0
while lab:
palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
i += 1
lab >>= 3
return palette
def delete_irregular(logits_result):
    """Suppress implausible garment logits based on contour layout.

    Compares the vertical centroids of the largest upper-cloth (class 4) and
    dress (class 7) contours to decide between "dresses" and "cloth_pant",
    then zeroes out the conflicting class logits and re-runs argmax.

    :param logits_result: (H, W, num_classes) logits array (modified in place).
    :return: (parsing_result padded by 1px of background, wear_type string).
    """
    parsing_result = np.argmax(logits_result, axis=2)
    upper_cloth = np.where(parsing_result == 4, 255, 0)
    contours, hierarchy = cv2.findContours(upper_cloth.astype(np.uint8),
                                           cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1)
    area = []
    for i in range(len(contours)):
        a = cv2.contourArea(contours[i], True)
        area.append(abs(a))
    if len(area) != 0:
        # Vertical centroid of the largest upper-cloth contour.
        # NOTE(review): divides by M["m00"]; a degenerate contour would raise
        # ZeroDivisionError -- confirm inputs cannot produce zero-area moments.
        top = area.index(max(area))
        M = cv2.moments(contours[top])
        cY = int(M["m01"] / M["m00"])
    dresses = np.where(parsing_result == 7, 255, 0)
    contours_dress, hierarchy_dress = cv2.findContours(dresses.astype(np.uint8),
                                                       cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1)
    area_dress = []
    for j in range(len(contours_dress)):
        a_d = cv2.contourArea(contours_dress[j], True)
        area_dress.append(abs(a_d))
    if len(area_dress) != 0:
        # Vertical centroid of the largest dress contour.
        top_dress = area_dress.index(max(area_dress))
        M_dress = cv2.moments(contours_dress[top_dress])
        cY_dress = int(M_dress["m01"] / M_dress["m00"])
    wear_type = "dresses"
    if len(area) != 0:
        if len(area_dress) != 0 and cY_dress > cY:
            # Dress centroid below the cloth centroid: treat as a dress and
            # suppress upper-cloth-related classes everywhere.
            irregular_list = np.array([4, 5, 6])
            logits_result[:, :, irregular_list] = -1
        else:
            # Cloth + pants: suppress dress/limb classes above the cloth centroid.
            irregular_list = np.array([5, 6, 7, 8, 9, 10, 12, 13])
            logits_result[:cY, :, irregular_list] = -1
            wear_type = "cloth_pant"
        parsing_result = np.argmax(logits_result, axis=2)
    # pad border
    parsing_result = np.pad(parsing_result, pad_width=1, mode='constant', constant_values=0)
    return parsing_result, wear_type
def hole_fill(img):
    """Fill enclosed holes in a binary (0/255) mask.

    Flood-fills the background from the top-left corner, inverts the result,
    and ORs it with the original so any region not reachable from the border
    (a hole) becomes foreground.

    Bug fix: cv2.floodFill mutates its input in place; the original function
    clobbered the caller's array. Flood-fill a private copy instead.
    """
    img_copy = img.copy()
    flooded = img.copy()
    # floodFill requires a mask 2px larger than the image.
    mask = np.zeros((img.shape[0] + 2, img.shape[1] + 2), dtype=np.uint8)
    cv2.floodFill(flooded, mask, (0, 0), 255)
    img_inverse = cv2.bitwise_not(flooded)
    dst = cv2.bitwise_or(img_copy, img_inverse)
    return dst
def refine_mask(mask):
    """Redraw *mask* keeping the dominant contour; when that contour exceeds
    2000 px, every other contour is kept as well (skin case)."""
    contours, hierarchy = cv2.findContours(mask.astype(np.uint8),
                                           cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1)
    areas = [abs(cv2.contourArea(cnt, True)) for cnt in contours]
    # Renamed from `refine_mask` to avoid shadowing the function itself.
    result = np.zeros_like(mask).astype(np.uint8)
    if len(areas) != 0:
        biggest = areas.index(max(areas))
        cv2.drawContours(result, contours, biggest, color=255, thickness=-1)
        # keep large area in skin case
        for idx in range(len(areas)):
            if idx != biggest and areas[biggest] > 2000:
                cv2.drawContours(result, contours, idx, color=255, thickness=-1)
    return result
def refine_hole(parsing_result_filled, parsing_result, arm_mask):
    """Return a mask of large (>2000 px) hole regions created by hole-filling
    of class 4, excluding arm pixels, combined with the arm mask itself."""
    # Pixels that became class 4 only after filling, minus the arm region.
    newly_filled = cv2.bitwise_and(np.where(parsing_result_filled == 4, 255, 0),
                                   np.where(parsing_result != 4, 255, 0)) - arm_mask * 255
    contours, hierarchy = cv2.findContours(newly_filled, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1)
    hole_mask = np.zeros_like(parsing_result).astype(np.uint8)
    for idx, cnt in enumerate(contours):
        signed_area = cv2.contourArea(cnt, True)
        # keep hole > 2000 pixels
        if abs(signed_area) > 2000:
            cv2.drawContours(hole_mask, contours, idx, color=255, thickness=-1)
    return hole_mask + arm_mask
def onnx_inference(lip_session, input_dir, mask_components=[0]):
    """Run the LIP human-parsing ONNX model and build a class mask.

    :param lip_session: onnxruntime InferenceSession for the LIP model.
    :param input_dir: source consumed by SimpleFolderDataset.
    :param mask_components: class indices included in the returned mask.
    :return: (parsed_image, mask_image) float tensors in [0, 1], each (1, H, W, 3).

    NOTE(review): the mutable default `mask_components=[0]` is shared across
    calls; it is only read here, but confirm no caller mutates it.
    NOTE(review): only the last loader batch contributes to the return values
    -- presumably the dataset yields a single image; verify against callers.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
    ])
    input_size = [473, 473]
    dataset_lip = SimpleFolderDataset(root=input_dir, input_size=input_size, transform=transform)
    dataloader_lip = DataLoader(dataset_lip)
    # 20 LIP classes.
    palette = get_palette(20)
    with torch.no_grad():
        for _, batch in enumerate(tqdm(dataloader_lip)):
            image, meta = batch
            # Affine metadata used to map logits back onto the original image.
            c = meta['center'].numpy()[0]
            s = meta['scale'].numpy()[0]
            w = meta['width'].numpy()[0]
            h = meta['height'].numpy()[0]
            output = lip_session.run(None, {"input.1": image.numpy().astype(np.float32)})
            upsample = torch.nn.Upsample(size=input_size, mode='bilinear', align_corners=True)
            upsample_output = upsample(torch.from_numpy(output[1][0]).unsqueeze(0))
            upsample_output = upsample_output.squeeze()
            upsample_output = upsample_output.permute(1, 2, 0)  # CHW -> HWC
            logits_result_lip = transform_logits(upsample_output.data.cpu().numpy(), c, s, w, h,
                                                 input_size=input_size)
            parsing_result = np.argmax(logits_result_lip, axis=2)
            # P-mode image: pixel values are class ids, palette gives colors.
            output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
            output_img.putpalette(palette)
            # Binary mask of the requested class ids.
            mask = np.isin(output_img, mask_components).astype(np.uint8)
            mask_image = Image.fromarray(mask * 255)
            mask_image = mask_image.convert("RGB")
            mask_image = torch.from_numpy(np.array(mask_image).astype(np.float32) / 255.0).unsqueeze(0)
            output_img = output_img.convert('RGB')
            output_img = torch.from_numpy(np.array(output_img).astype(np.float32) / 255.0).unsqueeze(0)
    return output_img, mask_image

View File

@@ -0,0 +1,109 @@
import numpy as np
import torch
from PIL import Image
from .parsing_api import onnx_inference
from ...libs.utils import install_package
class HumanParsing:
    """Lazily-initialized wrapper around the LIP human-parsing ONNX model."""
    def __init__(self, model_path):
        self.model_path = model_path
        # ONNX session is created on first call.
        self.session = None
    def __call__(self, input_image, mask_components):
        if self.session is None:
            install_package('onnxruntime')
            import onnxruntime as ort
            opts = ort.SessionOptions()
            opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
            opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
            # session_options.add_session_config_entry('gpu_id', str(gpu_id))
            self.session = ort.InferenceSession(
                self.model_path,
                sess_options=opts,
                providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
            )
        # Returns (parsed_image, mask).
        return onnx_inference(self.session, input_image, mask_components)
class HumanParts:
    """Lazy wrapper around the "human parts" segmentation ONNX model.

    ``self.classes`` maps the component index exposed to the caller to the
    model's raw label id:
        0: background(0),  1: face(13),       2: hair(2),
        3: glasses(4),     4: top-clothes(5), 5: bottom-clothes(9),
        6: torso-skin(10), 7: left-arm(14),   8: right-arm(15),
        9: left-leg(16),  10: right-leg(17), 11: left-foot(18),
       12: right-foot(19)
    """

    def __init__(self, model_path):
        self.model_path = model_path
        self.session = None  # created lazily on first call
        # Raw model label ids, ordered by the component index above.
        self.classes = [0, 13, 2, 4, 5, 9, 10, 14, 15, 16, 17, 18, 19]

    def __call__(self, input_image, mask_components):
        if self.session is None:
            install_package('onnxruntime')
            import onnxruntime as ort
            self.session = ort.InferenceSession(
                self.model_path,
                providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'])
        mask, = self.get_mask(self.session, input_image, 0, mask_components)
        return mask

    def get_mask(self, model, image, rotation, mask_components):
        """Segment *image* and return a 1-tuple with a binary part mask.

        Args:
            model: onnxruntime session; expects a 512x512 RGB image in
                [-1, 1] and outputs a per-class score map.
            image: (1, H, W, C) float tensor in [0, 1].
            rotation: degrees to pre-rotate the image (rotated back after).
            mask_components: indices into ``self.classes`` to include.

        Returns:
            1-tuple holding a (1, 1, H, W) uint8 tensor of 0/1 values.
        """
        image = image.squeeze(0)
        image_np = image.numpy() * 255
        pil_image = Image.fromarray(image_np.astype(np.uint8))
        original_size = pil_image.size  # to resize the mask later
        # resize to 512x512 as the model expects
        pil_image = pil_image.resize((512, 512))
        center = (256, 256)
        if rotation != 0:
            pil_image = pil_image.rotate(rotation, center=center)
        # normalize the image into [-1, 1]
        image_np = np.array(pil_image).astype(np.float32) / 127.5 - 1
        image_np = np.expand_dims(image_np, axis=0)
        # use the onnx model to get the per-pixel label map
        input_name = model.get_inputs()[0].name
        output_name = model.get_outputs()[0].name
        result = model.run([output_name], {input_name: image_np})
        result = np.array(result[0]).argmax(axis=3).squeeze(0)
        mask = np.zeros_like(result)
        for class_index in mask_components:
            # Mark every pixel whose predicted label matches this component.
            # (Removed a dead ``score`` accumulator that was never read.)
            mask[result == self.classes[class_index]] = 255
        # back to the original size
        mask_image = Image.fromarray(mask.astype(np.uint8), mode="L")
        if rotation != 0:
            mask_image = mask_image.rotate(-rotation, center=center)
        mask_image = mask_image.resize(original_size)
        # and back to numpy, scaled to 0/1
        mask = np.array(mask_image).astype(np.float32) / 255
        # add batch and channel dimensions: (1, 1, H, W)
        mask = np.expand_dims(mask, axis=0)
        mask = np.expand_dims(mask, axis=0)
        del image_np, result  # free up memory, maybe not necessary
        # ensure to return a "binary mask_image"
        return (torch.from_numpy(mask.astype(np.uint8)),)

View File

@@ -0,0 +1,88 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@Author : Peike Li
@Contact : peike.li@yahoo.com
@File : dataset.py
@Time : 8/30/19 9:12 PM
@Desc : Dataset Definition
@License : This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import os
import cv2
import numpy as np
from PIL import Image
from torch.utils import data
from .transforms import get_affine_transform
class SimpleFolderDataset(data.Dataset):
    """Dataset over an image folder, a single image file, or one PIL image.

    Each item is affine-warped to ``input_size`` (padding the source box to
    the target aspect ratio) and returned together with the transform meta
    needed to map predictions back to the original image.
    """

    def __init__(self, root, input_size=(512, 512), transform=None):
        # Tuple default avoids the shared-mutable-default pitfall of the
        # previous ``[512, 512]`` default; behavior is unchanged.
        self.root = root
        self.transform = transform
        # height/width ratio used to pad boxes to the model's aspect ratio
        self.aspect_ratio = input_size[1] * 1.0 / input_size[0]
        self.input_size = np.asarray(input_size)
        self.is_pil_image = False
        if isinstance(root, Image.Image):
            # A single in-memory PIL image.
            self.file_list = [root]
            self.is_pil_image = True
        elif os.path.isfile(root):
            # A single image file: keep its directory as root.
            self.file_list = [os.path.basename(root)]
            self.root = os.path.dirname(root)
        else:
            self.file_list = os.listdir(self.root)

    def __len__(self):
        return len(self.file_list)

    def _box2cs(self, box):
        """Convert an (x, y, w, h) box to (center, scale)."""
        x, y, w, h = box[:4]
        return self._xywh2cs(x, y, w, h)

    def _xywh2cs(self, x, y, w, h):
        """Compute box center and a scale padded to the target aspect ratio."""
        center = np.zeros((2), dtype=np.float32)
        center[0] = x + w * 0.5
        center[1] = y + h * 0.5
        # Expand the short side so the box matches the model aspect ratio.
        if w > self.aspect_ratio * h:
            h = w * 1.0 / self.aspect_ratio
        elif w < self.aspect_ratio * h:
            w = h * self.aspect_ratio
        scale = np.array([w, h], dtype=np.float32)
        return center, scale

    def __getitem__(self, index):
        if self.is_pil_image:
            # PIL gives RGB; swap to BGR to match cv2-loaded images.
            img = np.asarray(self.file_list[index])[:, :, [2, 1, 0]]
        else:
            img_name = self.file_list[index]
            img_path = os.path.join(self.root, img_name)
            img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        h, w, _ = img.shape

        # Get person center and scale for the whole image.
        person_center, s = self._box2cs([0, 0, w - 1, h - 1])
        r = 0
        trans = get_affine_transform(person_center, s, r, self.input_size)
        # Renamed from ``input`` to avoid shadowing the builtin.
        warped = cv2.warpAffine(
            img,
            trans,
            (int(self.input_size[1]), int(self.input_size[0])),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=(0, 0, 0))

        warped = self.transform(warped)
        meta = {
            'center': person_center,
            'height': h,
            'width': w,
            'scale': s,
            'rotation': r
        }
        return warped, meta

View File

@@ -0,0 +1,167 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# ------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import cv2
import torch
class BRG2Tensor_transform(object):
    """Convert an HWC numpy image into a CHW torch tensor.

    Byte images are promoted to float; all other dtypes pass through.
    """

    def __call__(self, pic):
        tensor = torch.from_numpy(pic.transpose((2, 0, 1)))
        return tensor.float() if isinstance(tensor, torch.ByteTensor) else tensor
class BGR2RGB_transform(object):
    """Reverse the channel axis of a CHW tensor (BGR -> RGB)."""

    def __call__(self, tensor):
        channel_order = [2, 1, 0]
        return tensor[channel_order, :, :]
def flip_back(output_flipped, matched_parts):
    """Undo a horizontal flip of joint heatmaps.

    Mirrors the heatmaps along the width axis and swaps each matched
    left/right joint channel pair.

    Args:
        output_flipped: numpy.ndarray(batch_size, num_joints, height, width).
        matched_parts: iterable of (left_idx, right_idx) joint index pairs.
    """
    assert output_flipped.ndim == 4, \
        'output_flipped should be [batch_size, num_joints, height, width]'
    # Mirror horizontally (width axis).
    output_flipped = output_flipped[:, :, :, ::-1]
    # Exchange matched left/right joint channels.
    for pair in matched_parts:
        a, b = pair[0], pair[1]
        tmp = output_flipped[:, a, :, :].copy()
        output_flipped[:, a, :, :] = output_flipped[:, b, :, :]
        output_flipped[:, b, :, :] = tmp
    return output_flipped
def fliplr_joints(joints, joints_vis, width, matched_parts):
    """Horizontally flip joint coordinates and swap left/right joints.

    Args:
        joints: (num_joints, 2+) array of coordinates, modified in place.
        joints_vis: per-joint visibility flags, modified in place.
        width: image width used to mirror the x coordinate.
        matched_parts: iterable of (left_idx, right_idx) pairs.

    Returns:
        (joints * joints_vis, joints_vis) — invisible joints zeroed out.
    """
    # Mirror the x coordinate across the image width.
    joints[:, 0] = width - joints[:, 0] - 1
    # Exchange each matched left/right joint and its visibility flags.
    for pair in matched_parts:
        a, b = pair[0], pair[1]
        joints[a, :], joints[b, :] = joints[b, :], joints[a, :].copy()
        joints_vis[a, :], joints_vis[b, :] = joints_vis[b, :], joints_vis[a, :].copy()
    return joints * joints_vis, joints_vis
def transform_preds(coords, center, scale, input_size):
    """Map predicted coordinates from model space back onto the original image."""
    inv_trans = get_affine_transform(center, scale, 0, input_size, inv=1)
    target_coords = np.zeros(coords.shape)
    # Apply the inverse affine transform point by point.
    for idx in range(coords.shape[0]):
        target_coords[idx, 0:2] = affine_transform(coords[idx, 0:2], inv_trans)
    return target_coords
def transform_parsing(pred, center, scale, width, height, input_size):
    """Warp a parsing label map back to the original image size.

    Nearest-neighbour interpolation keeps labels discrete.
    """
    inv_trans = get_affine_transform(center, scale, 0, input_size, inv=1)
    return cv2.warpAffine(
        pred,
        inv_trans,
        (int(width), int(height)),
        flags=cv2.INTER_NEAREST,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=(0))
def transform_logits(logits, center, scale, width, height, input_size):
    """Warp per-class logits (H, W, C) back to the original image size.

    Each channel is warped independently with bilinear interpolation, then
    the channels are re-stacked along the last axis.
    """
    inv_trans = get_affine_transform(center, scale, 0, input_size, inv=1)
    warped_channels = [
        cv2.warpAffine(
            logits[:, :, ch],
            inv_trans,
            (int(width), int(height)),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=(0))
        for ch in range(logits.shape[2])
    ]
    return np.stack(warped_channels, axis=2)
def get_affine_transform(center,
                         scale,
                         rot,
                         output_size,
                         shift=np.array([0, 0], dtype=np.float32),
                         inv=0):
    """Build the 2x3 affine matrix mapping a scaled, rotated box to output_size.

    Args:
        center: (x, y) box center in the source image.
        scale: box size; a scalar is promoted to (scale, scale).
        rot: rotation in degrees.
        output_size: (height, width) of the destination.
        shift: fractional shift of the source box (read-only default).
        inv: when truthy, return the inverse transform (dst -> src).

    Returns:
        2x3 float affine matrix for cv2.warpAffine.
    """
    if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
        # Accept a scalar scale. (Removed a leftover debug print here.)
        scale = np.array([scale, scale])

    scale_tmp = scale
    src_w = scale_tmp[0]
    dst_w = output_size[1]
    dst_h = output_size[0]

    rot_rad = np.pi * rot / 180
    # Direction vectors define the second point of each triangle.
    src_dir = get_dir([0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0, (dst_w - 1) * -0.5], np.float32)

    # Three corresponding points fully determine the affine map.
    src = np.zeros((3, 2), dtype=np.float32)
    dst = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    dst[0, :] = [(dst_w - 1) * 0.5, (dst_h - 1) * 0.5]
    dst[1, :] = np.array([(dst_w - 1) * 0.5, (dst_h - 1) * 0.5]) + dst_dir
    src[2:, :] = get_3rd_point(src[0, :], src[1, :])
    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
    return trans
def affine_transform(pt, t):
    """Apply a 2x3 affine matrix *t* to the 2D point *pt*."""
    homogeneous = np.array([pt[0], pt[1], 1.0])
    return t.dot(homogeneous)[:2]
def get_3rd_point(a, b):
    """Return the third corner of a right triangle on a, b.

    Computed as b plus (a - b) rotated 90 degrees counter-clockwise.
    """
    delta = a - b
    return b + np.array([-delta[1], delta[0]], dtype=np.float32)
def get_dir(src_point, rot_rad):
    """Rotate a 2D point by *rot_rad* radians (counter-clockwise).

    Returns a 2-element list, matching the original implementation.
    """
    sn, cs = np.sin(rot_rad), np.cos(rot_rad)
    return [
        src_point[0] * cs - src_point[1] * sn,
        src_point[0] * sn + src_point[1] * cs,
    ]
def crop(img, center, scale, output_size, rot=0):
    """Crop *img* around *center* at *scale* into *output_size* via warpAffine."""
    trans = get_affine_transform(center, scale, rot, output_size)
    return cv2.warpAffine(
        img,
        trans,
        (int(output_size[1]), int(output_size[0])),
        flags=cv2.INTER_LINEAR)

View File

@@ -0,0 +1,184 @@
#credit to huchenlei for this module
#from https://github.com/huchenlei/ComfyUI-IC-Light-Native
import torch
import numpy as np
from typing import Tuple, TypedDict, Callable
import comfy.model_management
from comfy.sd import load_unet
from comfy.ldm.models.autoencoder import AutoencoderKL
from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher
from PIL import Image
from nodes import VAEEncode
from ...libs.image import np2tensor, pil2tensor
class UnetParams(TypedDict):
    """Arguments ComfyUI passes to a UNet call through the
    model_function_wrapper mechanism."""
    input: torch.Tensor  # noisy latent batch fed to the UNet
    timestep: torch.Tensor  # current denoising timestep(s)
    c: dict  # conditioning dict (c_concat, c_crossattn, ...)
    cond_or_uncond: torch.Tensor  # marks which batch entries are cond vs uncond
class VAEEncodeArgMax(VAEEncode):
    """VAE encode that uses the distribution mode instead of sampling.

    Temporarily disables stochastic sampling on the VAE's regularization so
    the encoded latent is deterministic (the distribution's mode / "argmax").
    """

    def encode(self, vae, pixels):
        assert isinstance(
            vae.first_stage_model, AutoencoderKL
        ), "ArgMax only supported for AutoencoderKL"
        original_sample_mode = vae.first_stage_model.regularization.sample
        vae.first_stage_model.regularization.sample = False
        try:
            return super().encode(vae, pixels)
        finally:
            # Always restore the original mode, even if encoding raises —
            # the previous version left sampling disabled on error.
            vae.first_stage_model.regularization.sample = original_sample_mode
class ICLight:
    """Patches a ComfyUI model with the IC-Light relighting UNet.

    IC-Light consumes 8 input channels (4 latent + 4 concat conditioning).
    The extra conditioning is injected through a ``model_function_wrapper``
    and the IC-Light weights are overlaid as ``'diff'`` patches on the base
    UNet. Helper methods build gradient lighting/background images.
    """

    @staticmethod
    def apply_c_concat(params: UnetParams, concat_conds) -> UnetParams:
        """Apply c_concat on unet call."""
        sample = params["input"]
        # Repeat the conditioning so it matches the sample batch size.
        params["c"]["c_concat"] = torch.cat(
            (
                [concat_conds.to(sample.device)]
                * (sample.shape[0] // concat_conds.shape[0])
            ),
            dim=0,
        )
        return params

    @staticmethod
    def create_custom_conv(
        original_conv: torch.nn.Module,
        dtype: torch.dtype,
        device=torch.device,
    ) -> torch.nn.Module:
        """Return an 8-input-channel conv whose first 4 input channels copy
        *original_conv*'s weights; the extra channels start zeroed.

        NOTE(review): the ``device=torch.device`` default looks like a typo
        for a required device argument — confirm callers always pass one.
        """
        with torch.no_grad():
            new_conv_in = torch.nn.Conv2d(
                8,
                original_conv.out_channels,
                original_conv.kernel_size,
                original_conv.stride,
                original_conv.padding,
            )
            # Zero the new channels, copy the original 4-channel weights.
            new_conv_in.weight.zero_()
            new_conv_in.weight[:, :4, :, :].copy_(original_conv.weight)
            new_conv_in.bias = original_conv.bias
            return new_conv_in.to(dtype=dtype, device=device)

    def generate_lighting_image(self, original_image, direction):
        """Build a grayscale gradient image for *direction*, matching the
        input's spatial size, and return it as an image tensor."""
        _, image_height, image_width, _ = original_image.shape
        if direction == 'Left Light':
            # Bright (255) on the left, fading to dark on the right.
            gradient = np.linspace(255, 0, image_width)
            image = np.tile(gradient, (image_height, 1))
            input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
            return np2tensor(input_bg)
        elif direction == 'Right Light':
            gradient = np.linspace(0, 255, image_width)
            image = np.tile(gradient, (image_height, 1))
            input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
            return np2tensor(input_bg)
        elif direction == 'Top Light':
            gradient = np.linspace(255, 0, image_height)[:, None]
            image = np.tile(gradient, (1, image_width))
            input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
            return np2tensor(input_bg)
        elif direction == 'Bottom Light':
            gradient = np.linspace(0, 255, image_height)[:, None]
            image = np.tile(gradient, (1, image_width))
            input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
            return np2tensor(input_bg)
        elif direction == 'Circle Light':
            # Radial gradient: white center fading to black at the corners.
            x = np.linspace(-1, 1, image_width)
            y = np.linspace(-1, 1, image_height)
            x, y = np.meshgrid(x, y)
            r = np.sqrt(x ** 2 + y ** 2)
            r = r / r.max()
            color1 = np.array([0, 0, 0])[np.newaxis, np.newaxis, :]
            color2 = np.array([255, 255, 255])[np.newaxis, np.newaxis, :]
            gradient = (color1 * r[..., np.newaxis] + color2 * (1 - r)[..., np.newaxis]).astype(np.uint8)
            image = pil2tensor(Image.fromarray(gradient))
            return image
        else:
            # Unknown/none: 1x1 black placeholder.
            image = pil2tensor(Image.new('RGB', (1, 1), (0, 0, 0)))
            return image

    def generate_source_image(self, original_image, source):
        """Build the lighting source/background image for *source*."""
        batch_size, image_height, image_width, _ = original_image.shape
        if source == 'Use Flipped Background Image':
            if batch_size < 2:
                raise ValueError('Must be at least 2 image to use flipped background image.')
            # Second batch entry, mirrored horizontally.
            original_image = [img.unsqueeze(0) for img in original_image]
            image = torch.flip(original_image[1], [2])
            return image
        elif source == 'Ambient':
            # Flat mid-dark gray.
            input_bg = np.zeros(shape=(image_height, image_width, 3), dtype=np.uint8) + 64
            return np2tensor(input_bg)
        elif source == 'Left Light':
            # Softer gradients (224 -> 32) than generate_lighting_image.
            gradient = np.linspace(224, 32, image_width)
            image = np.tile(gradient, (image_height, 1))
            input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
            return np2tensor(input_bg)
        elif source == 'Right Light':
            gradient = np.linspace(32, 224, image_width)
            image = np.tile(gradient, (image_height, 1))
            input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
            return np2tensor(input_bg)
        elif source == 'Top Light':
            gradient = np.linspace(224, 32, image_height)[:, None]
            image = np.tile(gradient, (1, image_width))
            input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
            return np2tensor(input_bg)
        elif source == 'Bottom Light':
            gradient = np.linspace(32, 224, image_height)[:, None]
            image = np.tile(gradient, (1, image_width))
            input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
            return np2tensor(input_bg)
        else:
            image = pil2tensor(Image.new('RGB', (1, 1), (0, 0, 0)))
            return image

    def apply(self, ic_model_path, model, c_concat: dict, ic_model=None) -> Tuple[ModelPatcher]:
        """Attach IC-Light to *model*.

        Wraps the UNet call to inject ``c_concat`` and overlays the IC-Light
        weights as 'diff' patches. Returns (patched model, loaded ic_model)
        so the loaded UNet can be reused across calls.
        """
        device = comfy.model_management.get_torch_device()
        dtype = comfy.model_management.unet_dtype()
        work_model = model.clone()
        # Apply scale factor.
        base_model: BaseModel = work_model.model
        scale_factor = base_model.model_config.latent_format.scale_factor
        # [B, 4, H, W]
        concat_conds: torch.Tensor = c_concat["samples"] * scale_factor
        # [1, 4 * B, H, W]
        concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)

        def unet_dummy_apply(unet_apply: Callable, params: UnetParams):
            """A dummy unet apply wrapper serving as the endpoint of wrapper
            chain."""
            return unet_apply(x=params["input"], t=params["timestep"], **params["c"])

        # Chain onto any wrapper another extension may already have installed.
        existing_wrapper = work_model.model_options.get(
            "model_function_wrapper", unet_dummy_apply
        )

        def wrapper_func(unet_apply: Callable, params: UnetParams):
            return existing_wrapper(unet_apply, params=self.apply_c_concat(params, concat_conds))

        work_model.set_model_unet_function_wrapper(wrapper_func)
        if not ic_model:
            ic_model = load_unet(ic_model_path)
        ic_model_state_dict = ic_model.model.diffusion_model.state_dict()
        # Overlay IC-Light weights as diffs; conv_in gets pad_weight because
        # IC-Light expects 8 input channels vs the base model's 4.
        work_model.add_patches(
            patches={
                ("diffusion_model." + key): (
                    'diff',
                    [
                        value.to(dtype=dtype, device=device),
                        {"pad_weight": key == 'input_blocks.0.0.weight'}
                    ]
                )
                for key, value in ic_model_state_dict.items()
            }
        )
        return (work_model, ic_model)

View File

@@ -0,0 +1,268 @@
#credit to shakker-labs and instantX for this module
#from https://github.com/Shakker-Labs/ComfyUI-IPAdapter-Flux
import torch
from PIL import Image
import numpy as np
from .attention_processor import IPAFluxAttnProcessor2_0
from .utils import is_model_pathched, FluxUpdateModules
from .sd3.resampler import TimeResampler
from .sd3.joinblock import JointBlockIPWrapper, IPAttnProcessor
image_proj_model = None
class MLPProjModel(torch.nn.Module):
    """Projects image-encoder embeddings into *num_tokens* cross-attention
    tokens of width *cross_attention_dim*."""

    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens
        # Two-layer MLP that emits all tokens at once.
        self.proj = torch.nn.Sequential(
            torch.nn.Linear(id_embeddings_dim, id_embeddings_dim * 2),
            torch.nn.GELU(),
            torch.nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_tokens),
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, id_embeds):
        projected = self.proj(id_embeds)
        # (batch, num_tokens * dim) -> (batch, num_tokens, dim)
        tokens = projected.reshape(-1, self.num_tokens, self.cross_attention_dim)
        return self.norm(tokens)
class InstantXFluxIpadapterApply:
    """Applies the InstantX/Shakker-Labs IP-Adapter to a Flux model.

    The image projection model is kept in the module-level
    ``image_proj_model`` cache so repeated applications reuse it.
    """

    def __init__(self, num_tokens=128):
        self.device = None  # set from the ``provider`` arg in apply_ipadapter
        self.dtype = torch.float16
        self.num_tokens = num_tokens
        self.ip_ckpt = None
        self.clip_vision = None
        self.image_encoder = None
        self.clip_image_processor = None
        # state_dict
        self.state_dict = None
        # Flux dimensions: text context dim and transformer hidden size.
        self.joint_attention_dim = 4096
        self.hidden_size = 3072

    def set_ip_adapter(self, flux_model, weight, timestep_percent_range=(0.0, 1.0)):
        """Create one IPAFluxAttnProcessor2_0 per double/single block.

        The percent range is converted to sigma values so each processor can
        gate itself on the sampler's current timestep.
        """
        s = flux_model.model_sampling
        percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
        timestep_range = (percent_to_timestep_function(timestep_percent_range[0]),
                          percent_to_timestep_function(timestep_percent_range[1]))
        ip_attn_procs = {}  # 19+38=57
        dsb_count = len(flux_model.diffusion_model.double_blocks)
        for i in range(dsb_count):
            name = f"double_blocks.{i}"
            ip_attn_procs[name] = IPAFluxAttnProcessor2_0(
                hidden_size=self.hidden_size,
                cross_attention_dim=self.joint_attention_dim,
                num_tokens=self.num_tokens,
                scale=weight,
                timestep_range=timestep_range
            ).to(self.device, dtype=self.dtype)
        ssb_count = len(flux_model.diffusion_model.single_blocks)
        for i in range(ssb_count):
            name = f"single_blocks.{i}"
            ip_attn_procs[name] = IPAFluxAttnProcessor2_0(
                hidden_size=self.hidden_size,
                cross_attention_dim=self.joint_attention_dim,
                num_tokens=self.num_tokens,
                scale=weight,
                timestep_range=timestep_range
            ).to(self.device, dtype=self.dtype)
        return ip_attn_procs

    def load_ip_adapter(self, flux_model, weight, timestep_percent_range=(0.0, 1.0)):
        """Load checkpoint weights into the cached projection model and the
        freshly created attention processors."""
        global image_proj_model
        image_proj_model.load_state_dict(self.state_dict["image_proj"], strict=True)
        ip_attn_procs = self.set_ip_adapter(flux_model, weight, timestep_percent_range)
        ip_layers = torch.nn.ModuleList(ip_attn_procs.values())
        ip_layers.load_state_dict(self.state_dict["ip_adapter"], strict=True)
        return ip_attn_procs

    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        """Encode *pil_image* (or reuse *clip_image_embeds*) and project the
        result into IP-Adapter prompt embeddings."""
        if pil_image is not None:
            if isinstance(pil_image, Image.Image):
                pil_image = [pil_image]
            clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
            # Pooled CLIP-vision output is the adapter's image representation.
            clip_image_embeds = self.image_encoder(
                clip_image.to(self.device, dtype=self.image_encoder.dtype)).pooler_output
            clip_image_embeds = clip_image_embeds.to(dtype=self.dtype)
        else:
            clip_image_embeds = clip_image_embeds.to(self.device, dtype=self.dtype)
        global image_proj_model
        image_prompt_embeds = image_proj_model(clip_image_embeds)
        return image_prompt_embeds

    def apply_ipadapter(self, model, ipadapter, image, weight, start_at, end_at, provider=None, use_tiled=False):
        """Main entry point; returns (patched model clone, input image)."""
        self.device = provider.lower()
        if "clipvision" in ipadapter:
            self.image_encoder = ipadapter["clipvision"]['model']['image_encoder'].to(self.device, dtype=self.dtype)
            self.clip_image_processor = ipadapter["clipvision"]['model']['clip_image_processor']
        if "ipadapter" in ipadapter:
            self.ip_ckpt = ipadapter["ipadapter"]['file']
            self.state_dict = ipadapter["ipadapter"]['model']
        # process image: tensor (1, H, W, C) in [0, 1] -> PIL
        pil_image = image.numpy()[0] * 255.0
        pil_image = Image.fromarray(pil_image.astype(np.uint8))
        # initialize ipadapter (projection model cached at module level)
        global image_proj_model
        if image_proj_model is None:
            image_proj_model = MLPProjModel(
                cross_attention_dim=self.joint_attention_dim,  # 4096
                id_embeddings_dim=1152,
                num_tokens=self.num_tokens,
            )
        image_proj_model.to(self.device, dtype=self.dtype)
        ip_attn_procs = self.load_ip_adapter(model.model, weight, (start_at, end_at))
        # process control image
        image_prompt_embeds = self.get_image_embeds(pil_image=pil_image, clip_image_embeds=None)
        # set model: patch the Flux blocks on a clone of the model
        bi = model.clone()
        FluxUpdateModules(bi, ip_attn_procs, image_prompt_embeds)
        return (bi, image)
def patch_sd3(
    patcher,
    ip_procs,
    resampler: TimeResampler,
    clip_embeds,
    weight=1.0,
    start=0.0,
    end=1.0,
):
    """
    Patches a model_sampler to add the ipadapter.

    A unet-function wrapper runs each step's timestep through *resampler*
    to produce image tokens and a time embedding; the per-block
    JointBlockIPWrapper instances read them via the shared ``ip_options``
    dict.
    """
    mmdit = patcher.model.diffusion_model
    timestep_schedule_max = patcher.model.model_config.sampling_settings.get(
        "timesteps", 1000
    )
    # hook the model's forward function
    # so that when it gets called, we can grab the timestep and send it to the resampler
    ip_options = {
        "hidden_states": None,
        "t_emb": None,
        "weight": weight,
    }

    def ddit_wrapper(forward, args):
        # this is between 0 and 1, so the adapters can calculate start_point and end_point
        # actually, do we need to get the sigma value instead?
        t_percent = 1 - args["timestep"].flatten()[0].cpu().item()
        if start <= t_percent <= end:
            batch_size = args["input"].shape[0] // len(args["cond_or_uncond"])
            # if we're only doing cond or only doing uncond, only pass one of them through the resampler
            embeds = clip_embeds[args["cond_or_uncond"]]
            # slight efficiency optimization todo: pass the embeds through and then afterwards
            # repeat to the batch size
            embeds = torch.repeat_interleave(embeds, batch_size, dim=0)
            # the resampler wants between 0 and MAX_STEPS
            timestep = args["timestep"] * timestep_schedule_max
            image_emb, t_emb = resampler(embeds, timestep, need_temb=True)
            # these will need to be accessible to the IPAdapters
            ip_options["hidden_states"] = image_emb
            ip_options["t_emb"] = t_emb
        else:
            # Outside the active window: adapters see None and do nothing.
            ip_options["hidden_states"] = None
            ip_options["t_emb"] = None
        return forward(args["input"], args["timestep"], **args["c"])

    patcher.set_model_unet_function_wrapper(ddit_wrapper)
    # patch each dit block
    for i, block in enumerate(mmdit.joint_blocks):
        wrapper = JointBlockIPWrapper(block, ip_procs[i], ip_options)
        patcher.set_model_patch_replace(wrapper, "dit", "double_block", i)
class InstantXSD3IpadapterApply:
    """Applies the InstantX IP-Adapter to an SD3(.5) model via joint-block
    wrappers installed by ``patch_sd3``."""

    def __init__(self):
        self.device = None  # set from ``provider`` in apply_ipadapter
        self.dtype = torch.float16
        self.clip_image_processor = None
        self.image_encoder = None
        self.resampler = None  # TimeResampler, built in apply_ipadapter
        self.procs = None  # per-block IPAttnProcessor ModuleList

    @torch.inference_mode()
    def encode(self, image):
        """CLIP-encode *image*; returns [cond, uncond(zeros)] hidden states."""
        clip_image = self.clip_image_processor.image_processor(image, return_tensors="pt", do_rescale=False).pixel_values
        # Penultimate hidden layer, as is conventional for IP-Adapters.
        clip_image_embeds = self.image_encoder(
            clip_image.to(self.device, dtype=self.image_encoder.dtype),
            output_hidden_states=True,
        ).hidden_states[-2]
        # Zero embeddings serve as the unconditional branch.
        clip_image_embeds = torch.cat(
            [clip_image_embeds, torch.zeros_like(clip_image_embeds)], dim=0
        )
        clip_image_embeds = clip_image_embeds.to(dtype=torch.float16)
        return clip_image_embeds

    def apply_ipadapter(self, model, ipadapter, image, weight, start_at, end_at, provider=None, use_tiled=False):
        """Main entry point; returns (patched model clone, input image).

        NOTE(review): ``use_tiled`` is accepted but unused here — confirm it
        only applies to other adapter variants.
        """
        self.device = provider.lower()
        if "clipvision" in ipadapter:
            self.image_encoder = ipadapter["clipvision"]['model']['image_encoder'].to(self.device, dtype=self.dtype)
            self.clip_image_processor = ipadapter["clipvision"]['model']['clip_image_processor']
        if "ipadapter" in ipadapter:
            self.ip_ckpt = ipadapter["ipadapter"]['file']
            self.state_dict = ipadapter["ipadapter"]['model']
        # Resampler dimensions are fixed for the published SD3.5 adapter.
        self.resampler = TimeResampler(
            dim=1280,
            depth=4,
            dim_head=64,
            heads=20,
            num_queries=64,
            embedding_dim=1152,
            output_dim=2432,
            ff_mult=4,
            timestep_in_dim=320,
            timestep_flip_sin_to_cos=True,
            timestep_freq_shift=0,
        )
        self.resampler.eval()
        self.resampler.to(self.device, dtype=self.dtype)
        self.resampler.load_state_dict(self.state_dict["image_proj"])
        # now we'll create the attention processors
        # ip_adapter.keys looks like [0.proj, 0.to_k, ..., 1.proj, 1.to_k, ...]
        n_procs = len(
            set(x.split(".")[0] for x in self.state_dict["ip_adapter"].keys())
        )
        self.procs = torch.nn.ModuleList(
            [
                # this is hardcoded for SD3.5L
                IPAttnProcessor(
                    hidden_size=2432,
                    cross_attention_dim=2432,
                    ip_hidden_states_dim=2432,
                    ip_encoder_hidden_states_dim=2432,
                    head_dim=64,
                    timesteps_emb_dim=1280,
                ).to(self.device, dtype=torch.float16)
                for _ in range(n_procs)
            ]
        )
        self.procs.load_state_dict(self.state_dict["ip_adapter"])
        work_model = model.clone()
        embeds = self.encode(image)
        patch_sd3(
            work_model,
            self.procs,
            self.resampler,
            embeds,
            weight,
            start_at,
            end_at,
        )
        return (work_model, image)

View File

@@ -0,0 +1,87 @@
import numbers
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class RMSNorm(nn.Module):
    """Root-mean-square normalization (LayerNorm without mean subtraction).

    Normalization is computed in float32 for stability. With an affine gain
    in half precision, the normalized activations are cast to the gain's
    dtype before scaling; without a gain, the input dtype is restored.
    """

    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        if isinstance(dim, numbers.Integral):
            dim = (dim,)
        self.dim = torch.Size(dim)
        # Learnable per-channel gain, or None when affine is disabled.
        self.weight = nn.Parameter(torch.ones(dim)) if elementwise_affine else None

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Mean square over the last dimension, computed in float32.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        if self.weight is None:
            return hidden_states.to(input_dtype)
        # Match a half-precision gain's dtype before scaling.
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            hidden_states = hidden_states.to(self.weight.dtype)
        return hidden_states * self.weight
class IPAFluxAttnProcessor2_0(nn.Module):
    """IP-Adapter attention processor for Flux-style blocks.

    Projects image embeddings into key/value space and cross-attends the
    model's query against them; returns the scaled attention output, or
    ``None`` when the current timestep is outside the active range.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, timestep_range=None):
        super().__init__()
        self.hidden_size = hidden_size  # 3072 for Flux
        self.cross_attention_dim = cross_attention_dim  # 4096 for Flux
        self.scale = scale
        self.num_tokens = num_tokens
        proj_in = cross_attention_dim or hidden_size
        self.to_k_ip = nn.Linear(proj_in, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(proj_in, hidden_size, bias=False)
        # Per-head RMS normalization of the projected keys/values (head_dim=128).
        self.norm_added_k = RMSNorm(128, eps=1e-5, elementwise_affine=False)
        self.norm_added_v = RMSNorm(128, eps=1e-5, elementwise_affine=False)
        self.timestep_range = timestep_range

    def __call__(
        self,
        num_heads,
        query,
        image_emb: torch.FloatTensor,
        t: torch.FloatTensor
    ) -> torch.FloatTensor:
        # Sigmas decrease during sampling, hence this comparison order:
        # skip when outside (range[0] >= t >= range[1]).
        if self.timestep_range is not None and (
                t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]):
            return None
        # Project the image tokens into key/value space.
        keys = self.to_k_ip(image_emb)
        values = self.to_v_ip(image_emb)
        keys = rearrange(keys, 'B L (H D) -> B H L D', H=num_heads)
        values = rearrange(values, 'B L (H D) -> B H L D', H=num_heads)
        keys = self.norm_added_k(keys)
        values = self.norm_added_v(values)
        # Cross-attend the model's query against the image tokens.
        attn_out = F.scaled_dot_product_attention(
            query.to(image_emb.device).to(image_emb.dtype),
            keys,
            values,
            dropout_p=0.0, is_causal=False)
        attn_out = rearrange(attn_out, "B H L D -> B L (H D)", H=num_heads)
        attn_out = attn_out.to(query.dtype).to(query.device)
        return self.scale * attn_out

View File

@@ -0,0 +1,153 @@
import torch
from torch import Tensor, nn
from .math import attention
from ..attention_processor import IPAFluxAttnProcessor2_0
from comfy.ldm.flux.layers import DoubleStreamBlock, SingleStreamBlock
from comfy import model_management as mm
class DoubleStreamBlockIPA(nn.Module):
    """Flux DoubleStreamBlock with IP-Adapter attention injected.

    Wraps an existing block by aliasing its submodules (weights stay
    shared, not copied) and, for each registered adapter, adds an extra
    image-token attention result into the image stream after the joint
    img/txt attention.
    """

    def __init__(self, original_block: DoubleStreamBlock, ip_adapter, image_emb):
        super().__init__()
        # NOTE(review): this round-trip recomputes the same mlp hidden dim;
        # kept as-is from the original.
        mlp_hidden_dim = original_block.img_mlp[0].out_features
        mlp_ratio = mlp_hidden_dim / original_block.hidden_size
        mlp_hidden_dim = int(original_block.hidden_size * mlp_ratio)
        self.num_heads = original_block.num_heads
        self.hidden_size = original_block.hidden_size
        # Alias the wrapped block's modules so weights stay shared.
        self.img_mod = original_block.img_mod
        self.img_norm1 = original_block.img_norm1
        self.img_attn = original_block.img_attn
        self.img_norm2 = original_block.img_norm2
        self.img_mlp = original_block.img_mlp
        self.txt_mod = original_block.txt_mod
        self.txt_norm1 = original_block.txt_norm1
        self.txt_attn = original_block.txt_attn
        self.txt_norm2 = original_block.txt_norm2
        self.txt_mlp = original_block.txt_mlp
        self.flipped_img_txt = original_block.flipped_img_txt
        # Parallel lists: one image embedding per adapter.
        self.ip_adapter = ip_adapter
        self.image_emb = image_emb
        self.device = mm.get_torch_device()

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, t: Tensor, attn_mask=None):
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # Joint attention over both streams; concat order depends on the
        # model variant (flipped_img_txt).
        if self.flipped_img_txt:
            # run actual attention
            attn = attention(torch.cat((img_q, txt_q), dim=2),
                             torch.cat((img_k, txt_k), dim=2),
                             torch.cat((img_v, txt_v), dim=2),
                             pe=pe, mask=attn_mask)
            img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
        else:
            # run actual attention
            attn = attention(torch.cat((txt_q, img_q), dim=2),
                             torch.cat((txt_k, img_k), dim=2),
                             torch.cat((txt_v, img_v), dim=2),
                             pe=pe, mask=attn_mask)
            txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

        for adapter, image in zip(self.ip_adapter, self.image_emb):
            # this does a separate attention for each adapter; the adapter
            # returns None when outside its timestep window.
            ip_hidden_states = adapter(self.num_heads, img_q, image, t)
            if ip_hidden_states is not None:
                ip_hidden_states = ip_hidden_states.to(self.device)
                img_attn = img_attn + ip_hidden_states

        # calculate the img bloks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

        # calculate the txt bloks
        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)

        # Clamp fp16 to avoid inf/nan propagation.
        if txt.dtype == torch.float16:
            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)

        return img, txt
class SingleStreamBlockIPA(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface,
    extended with IP-Adapter attention. Wraps an existing SingleStreamBlock
    by aliasing its submodules (weights stay shared).
    """

    def __init__(self, original_block: SingleStreamBlock, ip_adapter, image_emb):
        super().__init__()
        self.hidden_dim = original_block.hidden_size
        self.num_heads = original_block.num_heads
        self.scale = original_block.scale
        self.mlp_hidden_dim = original_block.mlp_hidden_dim
        # qkv and mlp_in
        self.linear1 = original_block.linear1
        # proj and mlp_out
        self.linear2 = original_block.linear2
        self.norm = original_block.norm
        self.hidden_size = original_block.hidden_size
        self.pre_norm = original_block.pre_norm
        self.mlp_act = original_block.mlp_act
        self.modulation = original_block.modulation
        # Parallel lists: one image embedding per adapter.
        self.ip_adapter = ip_adapter
        self.image_emb = image_emb
        self.device = mm.get_torch_device()

    def add_adapter(self, ip_adapter: IPAFluxAttnProcessor2_0, image_emb):
        """Register an additional adapter/image-embedding pair."""
        self.ip_adapter.append(ip_adapter)
        self.image_emb.append(image_emb)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, t: Tensor, attn_mask=None) -> Tensor:
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        # linear1 emits qkv and the mlp input in one projection.
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe, mask=attn_mask)

        for adapter, image in zip(self.ip_adapter, self.image_emb):
            # this does a separate attention for each adapter
            # maybe we want a single joint attention call for all adapters?
            ip_hidden_states = adapter(self.num_heads, q, image, t)
            if ip_hidden_states is not None:
                ip_hidden_states = ip_hidden_states.to(self.device)
                attn = attn + ip_hidden_states

        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        x += mod.gate * output
        # Clamp fp16 to avoid inf/nan propagation.
        if x.dtype == torch.float16:
            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
        return x

View File

@@ -0,0 +1,35 @@
import torch
from einops import rearrange
from torch import Tensor
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
    """Apply rotary position embeddings to q/k, then run fused attention."""
    q_rot, k_rot = apply_rope(q, k, pe)
    n_heads = q_rot.shape[1]
    return optimized_attention(q_rot, k_rot, v, n_heads, skip_reshape=True, mask=mask)
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    """Build per-position 2x2 rotation matrices for rotary embeddings.

    Returns a float32 tensor of shape (..., n, dim // 2, 2, 2) on pos.device.
    """
    assert dim % 2 == 0
    # float64 linspace is unsupported on MPS/XPU, so compute on CPU there and
    # move the result back to the original device at the end.
    needs_cpu = comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu()
    device = torch.device("cpu") if needs_cpu else pos.device
    exponents = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float64, device=device)
    omega = 1.0 / (theta ** exponents)
    angles = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
    flat = torch.stack([torch.cos(angles), -torch.sin(angles), torch.sin(angles), torch.cos(angles)], dim=-1)
    matrices = rearrange(flat, "b n d (i j) -> b n d i j", i=2, j=2)
    return matrices.to(dtype=torch.float32, device=pos.device)
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
    """Rotate query and key tensors by precomputed rotary matrices.

    `freqs_cis` holds 2x2 rotation matrices (as produced by `rope`); each
    consecutive channel pair of xq/xk is treated as a 2-vector and rotated,
    then reshaped back and cast to the input dtype.
    """
    def _rotate(t: Tensor) -> Tensor:
        pairs = t.float().reshape(*t.shape[:-1], -1, 1, 2)
        rotated = freqs_cis[..., 0] * pairs[..., 0] + freqs_cis[..., 1] * pairs[..., 1]
        return rotated.reshape(*t.shape).type_as(t)
    return _rotate(xq), _rotate(xk)

View File

@@ -0,0 +1,219 @@
import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.modules.diffusionmodules.mmdit import (RMSNorm, JointBlock,)
class AdaLayerNorm(nn.Module):
    """
    Adaptive layer norm conditioned on an embedding (adaLN / adaLN-Zero).

    Parameters:
        embedding_dim (`int`): width of the normalized features.
        time_embedding_dim (`int`, optional): width of the conditioning
            embedding; falls back to `embedding_dim` when not given.
        mode (`str`): "normal" emits shift/scale only; "zero" additionally
            emits attention/MLP gates and MLP shift/scale (adaLN-Zero).
    """
    def __init__(self, embedding_dim: int, time_embedding_dim=None, mode="normal"):
        super().__init__()
        self.silu = nn.SiLU()
        # Number of modulation parameter groups per mode; unknown modes fail
        # fast here with a KeyError.
        groups = {"zero": 6, "normal": 2}[mode]
        self.linear = nn.Linear(
            time_embedding_dim or embedding_dim, groups * embedding_dim, bias=True
        )
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
        self.mode = mode

    def forward(
        self,
        x,
        hidden_dtype=None,
        emb=None,
    ):
        """Normalize `x` and modulate it with parameters derived from `emb`."""
        params = self.linear(self.silu(emb))
        if self.mode == "normal":
            shift, scale = params.chunk(2, dim=1)
            return self.norm(x) * (1 + scale[:, None]) + shift[:, None]
        elif self.mode == "zero":
            (
                shift_msa,
                scale_msa,
                gate_msa,
                shift_mlp,
                scale_mlp,
                gate_mlp,
            ) = params.chunk(6, dim=1)
            x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
            return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class IPAttnProcessor(nn.Module):
    """IP-Adapter attention processor for SD3/MMDiT-style joint blocks.

    Projects image-conditioning hidden states into key/value space and lets
    the image-stream queries attend over [image K/V ; IP K/V]. The adaptive
    norm on the IP states is conditioned on the timestep embedding.
    """
    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
        ip_hidden_states_dim=None,
        ip_encoder_hidden_states_dim=None,
        head_dim=None,
        timesteps_emb_dim=1280,
    ):
        super().__init__()
        # Timestep-conditioned adaptive LayerNorm over the IP hidden states.
        self.norm_ip = AdaLayerNorm(
            ip_hidden_states_dim, time_embedding_dim=timesteps_emb_dim
        )
        self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
        # Per-head RMS norms for query/key stabilization.
        self.norm_q = RMSNorm(head_dim, 1e-6)
        self.norm_k = RMSNorm(head_dim, 1e-6)
        self.norm_ip_k = RMSNorm(head_dim, 1e-6)
    def forward(
        self,
        ip_hidden_states,
        img_query,
        img_key=None,
        img_value=None,
        t_emb=None,
        n_heads=1,
    ):
        """Return IP-Adapter attention output in (b, l, h*d) layout, or None
        when there are no IP hidden states (or no K/V projections) to use."""
        if ip_hidden_states is None:
            return None
        if not hasattr(self, "to_k_ip") or not hasattr(self, "to_v_ip"):
            return None
        # norm ip input
        norm_ip_hidden_states = self.norm_ip(ip_hidden_states, emb=t_emb)
        # to k and v
        ip_key = self.to_k_ip(norm_ip_hidden_states)
        ip_value = self.to_v_ip(norm_ip_hidden_states)
        # reshape
        img_query = rearrange(img_query, "b l (h d) -> b h l d", h=n_heads)
        img_key = rearrange(img_key, "b l (h d) -> b h l d", h=n_heads)
        # note that the image is in a different shape: b l h d
        # so we transpose to b h l d
        # NOTE(review): assumes img_value arrives as (b, l, h, d), unlike
        # query/key which arrive flattened as (b, l, h*d) — verify callers.
        img_value = torch.transpose(img_value, 1, 2)
        ip_key = rearrange(ip_key, "b l (h d) -> b h l d", h=n_heads)
        ip_value = rearrange(ip_value, "b l (h d) -> b h l d", h=n_heads)
        # norm
        img_query = self.norm_q(img_query)
        img_key = self.norm_k(img_key)
        ip_key = self.norm_ip_k(ip_key)
        # cat image-stream and IP keys/values along the sequence axis
        key = torch.cat([img_key, ip_key], dim=2)
        value = torch.cat([img_value, ip_value], dim=2)
        #
        ip_hidden_states = F.scaled_dot_product_attention(
            img_query, key, value, dropout_p=0.0, is_causal=False
        )
        ip_hidden_states = rearrange(ip_hidden_states, "b h l d -> b l (h d)")
        ip_hidden_states = ip_hidden_states.to(img_query.dtype)
        return ip_hidden_states
class JointBlockIPWrapper:
    """To be used as a patch_replace with Comfy.

    Wraps an MMDiT JointBlock and injects IP-Adapter attention into the image
    stream. Deliberately not an nn.Module: it only holds references to the
    original block and the adapter, which are managed elsewhere.
    """
    def __init__(
        self,
        original_block: JointBlock,
        adapter: IPAttnProcessor,
        ip_options=None,
    ):
        self.original_block = original_block
        self.adapter = adapter
        if ip_options is None:
            ip_options = {}
        # Expected keys (set per-step by the caller): "hidden_states",
        # "t_emb", "weight".
        self.ip_options = ip_options
    def block_mixing(self, context, x, context_block, x_block, c):
        """
        Comes from mmdit.py. Modified to add ipadapter attention.
        """
        context_qkv, context_intermediates = context_block.pre_attention(context, c)
        if x_block.x_block_self_attn:
            x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c)
        else:
            x_qkv, x_intermediates = x_block.pre_attention(x, c)
        # Joint attention over concatenated [context ; image] sequences.
        qkv = tuple(torch.cat((context_qkv[j], x_qkv[j]), dim=1) for j in range(3))
        attn = optimized_attention(
            qkv[0],
            qkv[1],
            qkv[2],
            heads=x_block.attn.num_heads,
        )
        context_attn, x_attn = (
            attn[:, : context_qkv[0].shape[1]],
            attn[:, context_qkv[0].shape[1] :],
        )
        # if the current timestep is not in the ipadapter enabling range, then the resampler wasn't run
        # and the hidden states will be None
        if (
            self.ip_options["hidden_states"] is not None
            and self.ip_options["t_emb"] is not None
        ):
            # IP-Adapter
            ip_attn = self.adapter(
                self.ip_options["hidden_states"],
                *x_qkv,
                self.ip_options["t_emb"],
                x_block.attn.num_heads,
            )
            x_attn = x_attn + ip_attn * self.ip_options["weight"]
        # Everything else is unchanged
        if not context_block.pre_only:
            context = context_block.post_attention(context_attn, *context_intermediates)
        else:
            context = None
        if x_block.x_block_self_attn:
            attn2 = optimized_attention(
                x_qkv2[0],
                x_qkv2[1],
                x_qkv2[2],
                heads=x_block.attn2.num_heads,
            )
            x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
        else:
            x = x_block.post_attention(x_attn, *x_intermediates)
        return context, x
    def __call__(self, args, _):
        """patch_replace entry point: maps Comfy's {"txt","img","vec"} dict
        onto block_mixing and back."""
        # Code from mmdit.py:
        # in this case, we're blocks_replace[("double_block", i)]
        # note that although we're passed the original block,
        # we can't actually get it from inside its wrapper
        # (which would simplify the whole code...)
        # ```
        # def block_wrap(args):
        #     out = {}
        #     out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
        #     return out
        # out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
        # context = out["txt"]
        # x = out["img"]
        # ```
        c, x = self.block_mixing(
            args["txt"],
            args["img"],
            self.original_block.context_block,
            self.original_block.x_block,
            c=args["vec"],
        )
        return {"txt": c, "img": x}

View File

@@ -0,0 +1,385 @@
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
import math
import torch
import torch.nn as nn
from typing import Optional
ACTIVATION_FUNCTIONS = {
    "swish": nn.SiLU(),
    "silu": nn.SiLU(),
    "mish": nn.Mish(),
    "gelu": nn.GELU(),
    "relu": nn.ReLU(),
}


def get_activation(act_fn: str) -> nn.Module:
    """Look up an activation module by (case-insensitive) name.

    Args:
        act_fn (str): Name of activation function.
    Returns:
        nn.Module: the shared activation instance from ACTIVATION_FUNCTIONS.
    Raises:
        ValueError: if the name is not a supported activation.
    """
    key = act_fn.lower()
    if key not in ACTIVATION_FUNCTIONS:
        raise ValueError(f"Unsupported activation function: {key}")
    return ACTIVATION_FUNCTIONS[key]
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    Create DDPM-style sinusoidal timestep embeddings.

    Args
        timesteps (torch.Tensor):
            a 1-D Tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int):
            the dimension of the output.
        flip_sin_to_cos (bool):
            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
        downscale_freq_shift (float):
            Controls the delta between frequencies between dimensions
        scale (float):
            Scaling factor applied to the embeddings.
        max_period (int):
            Controls the maximum frequency of the embeddings
    Returns
        torch.Tensor: an [N x dim] Tensor of positional embeddings, zero-padded
        on the right when embedding_dim is odd.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
    half_dim = embedding_dim // 2
    # Geometric frequency ladder from 1 down to ~1/max_period.
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    frequencies = torch.exp(exponent / (half_dim - downscale_freq_shift))
    args = scale * (timesteps[:, None].float() * frequencies[None, :])
    sin_part, cos_part = torch.sin(args), torch.cos(args)
    halves = (cos_part, sin_part) if flip_sin_to_cos else (sin_part, cos_part)
    emb = torch.cat(halves, dim=-1)
    # zero pad when the requested width is odd
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
class Timesteps(nn.Module):
    """Stateless module that maps timesteps to sinusoidal embeddings using
    settings fixed at construction time (see get_timestep_embedding)."""
    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def forward(self, timesteps):
        """Return the [N, num_channels] sinusoidal embedding of `timesteps`."""
        return get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
class TimestepEmbedding(nn.Module):
    """Two-layer MLP projecting a (sinusoidal) timestep embedding.

    Optionally adds a bias-free projection of an extra condition to the input
    (cond_proj) and applies a post-activation after the second linear layer.
    """
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()
        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
        # Optional conditioning projection into the input space.
        self.cond_proj = (
            nn.Linear(cond_proj_dim, in_channels, bias=False)
            if cond_proj_dim is not None
            else None
        )
        self.act = get_activation(act_fn)
        time_embed_dim_out = time_embed_dim if out_dim is None else out_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
        self.post_act = None if post_act_fn is None else get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        """Project `sample`; when `condition` is given, add its projection first."""
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)
        if self.act is not None:
            sample = self.act(sample)
        sample = self.linear_2(sample)
        return sample if self.post_act is None else self.post_act(sample)
# FFN
def FeedForward(dim, mult=4):
    """Pre-norm MLP: LayerNorm -> Linear(dim, dim*mult) -> GELU -> Linear back.

    Both linear layers are bias-free; the hidden width is int(dim * mult).
    """
    hidden = int(dim * mult)
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias=False),
    ]
    return nn.Sequential(*layers)
def reshape_tensor(x, heads):
    """Split the channel dim into attention heads: (b, n, h*d) -> (b, h, n, d)."""
    bs, length, width = x.shape
    per_head = width // heads
    split = x.view(bs, length, heads, per_head)
    # Move the head axis ahead of the sequence axis, then force the final layout.
    return split.permute(0, 2, 1, 3).reshape(bs, heads, length, per_head)
class PerceiverAttention(nn.Module):
    """Perceiver-style cross-attention: latent queries attend over the
    concatenation of input features and the latents themselves."""
    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        # NOTE: self.scale is not referenced in forward (a local double-sqrt
        # rescaling is used there instead); kept for state compatibility.
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        # Keys and values are produced by one fused projection.
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)
    def forward(self, x, latents, shift=None, scale=None):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latent (torch.Tensor): latent features
                shape (b, n2, D)
            shift, scale (torch.Tensor, optional): per-batch adaLN modulation
                applied to the normalized latents when both are provided.
        """
        x = self.norm1(x)
        latents = self.norm2(latents)
        if shift is not None and scale is not None:
            latents = latents * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        b, l, _ = latents.shape
        q = self.to_q(latents)
        # K/V are computed over [input ; latents], so latents also self-attend.
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)
        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(
            -2, -1
        )  # More stable with f16 than dividing afterwards
        # Softmax is computed in fp32 for numerical stability, then cast back.
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v
        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
        return self.to_out(out)
class Resampler(nn.Module):
    """Perceiver resampler: a fixed set of learned latent queries repeatedly
    attends to projected input features, producing `num_queries` tokens of
    width `output_dim`."""

    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        *args,
        **kwargs,
    ):
        super().__init__()
        # Learned queries, scaled down so initial attention logits stay small.
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
        self.proj_in = nn.Linear(embedding_dim, dim)
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)
        self.layers = nn.ModuleList(
            nn.ModuleList(
                [
                    PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                    FeedForward(dim=dim, mult=ff_mult),
                ]
            )
            for _ in range(depth)
        )

    def forward(self, x):
        """Resample features `x` (b, n, embedding_dim) into latent tokens."""
        queries = self.latents.repeat(x.size(0), 1, 1)
        feats = self.proj_in(x)
        for attn, ff in self.layers:
            queries = attn(feats, queries) + queries
            queries = ff(queries) + queries
        return self.norm_out(self.proj_out(queries))
class TimeResampler(nn.Module):
    """Resampler whose layers are additionally modulated by the diffusion
    timestep (adaLN-style shift/scale derived from a timestep embedding).

    forward() returns the resampled tokens and, when need_temb=True, also the
    timestep embedding so callers can reuse it downstream.
    """
    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        timestep_in_dim=320,
        timestep_flip_sin_to_cos=True,
        timestep_freq_shift=0,
    ):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
        self.proj_in = nn.Linear(embedding_dim, dim)
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        # msa
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        # ff
                        FeedForward(dim=dim, mult=ff_mult),
                        # adaLN: produces 4 chunks (msa shift/scale, mlp shift/scale)
                        nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim, bias=True)),
                    ]
                )
            )
        # time
        self.time_proj = Timesteps(
            timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift
        )
        self.time_embedding = TimestepEmbedding(timestep_in_dim, dim, act_fn="silu")
        # adaLN
        # self.adaLN_modulation = nn.Sequential(
        #     nn.SiLU(),
        #     nn.Linear(timestep_out_dim, 6 * timestep_out_dim, bias=True)
        # )
    def forward(self, x, timestep, need_temb=False):
        """Resample `x` into latent tokens, modulated by `timestep`."""
        timestep_emb = self.embedding_time(x, timestep)  # bs, dim
        latents = self.latents.repeat(x.size(0), 1, 1)
        x = self.proj_in(x)
        # Timestep conditioning is also added directly to the input features.
        x = x + timestep_emb[:, None]
        for attn, ff, adaLN_modulation in self.layers:
            shift_msa, scale_msa, shift_mlp, scale_mlp = adaLN_modulation(
                timestep_emb
            ).chunk(4, dim=1)
            latents = attn(x, latents, shift_msa, scale_msa) + latents
            res = latents
            # Manually unroll the FeedForward Sequential so the MLP shift/scale
            # can be injected right after its leading LayerNorm (index 0).
            for idx_ff in range(len(ff)):
                layer_ff = ff[idx_ff]
                latents = layer_ff(latents)
                if idx_ff == 0 and isinstance(layer_ff, nn.LayerNorm):  # adaLN
                    latents = latents * (
                        1 + scale_mlp.unsqueeze(1)
                    ) + shift_mlp.unsqueeze(1)
            latents = latents + res
            # latents = ff(latents) + latents
        latents = self.proj_out(latents)
        latents = self.norm_out(latents)
        if need_temb:
            return latents, timestep_emb
        else:
            return latents
    def embedding_time(self, sample, timestep):
        """Return the (batch, dim) timestep embedding, matching sample's
        device/dtype; scalar or python-number timesteps are broadcast."""
        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])
        t_emb = self.time_proj(timesteps)
        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)
        emb = self.time_embedding(t_emb, None)
        return emb

View File

@@ -0,0 +1,136 @@
import torch
from torch import Tensor
from .flux.layers import DoubleStreamBlockIPA, SingleStreamBlockIPA
from comfy.ldm.flux.layers import timestep_embedding
from types import MethodType
def FluxUpdateModules(bi, ip_attn_procs, image_emb):
    """Patch a Flux ModelPatcher so its stream blocks run IP-Adapter attention.

    Args:
        bi: Comfy ModelPatcher holding the Flux diffusion model.
        ip_attn_procs: mapping from block names ("double_blocks.N" /
            "single_blocks.N") to their adapter modules.
        image_emb: image embedding shared by every adapter added in this call.

    Uses add_object_patch (the same mechanism ComfyUI uses internally, e.g.
    model.add_patches for loras) so the patches are reversible.
    """
    flux_model = bi.model
    # Replace forward_orig with the IPA-aware variant so patched blocks also
    # receive the timestep argument.
    bi.add_object_patch("diffusion_model.forward_orig", MethodType(forward_orig_ipa, flux_model.diffusion_model))
    _patch_stream_blocks(bi, flux_model.diffusion_model.double_blocks, "double_blocks",
                         DoubleStreamBlockIPA, ip_attn_procs, image_emb)
    _patch_stream_blocks(bi, flux_model.diffusion_model.single_blocks, "single_blocks",
                         SingleStreamBlockIPA, ip_attn_procs, image_emb)


def _patch_stream_blocks(bi, blocks, prefix, wrapper_cls, ip_attn_procs, image_emb):
    """Wrap each block of one stream in wrapper_cls via an object patch.

    If a block was already wrapped by a previous IPA patch, its existing
    adapter/embedding lists are preserved and the new pair is appended.
    """
    for i, original in enumerate(blocks):
        patch_name = f"{prefix}.{i}"
        maybe_patched_layer = bi.get_model_object(f"diffusion_model.{patch_name}")
        procs = [ip_attn_procs[patch_name]]
        embs = [image_emb]
        # if there's already a patch there, collect its adapters and replace it
        if isinstance(maybe_patched_layer, wrapper_cls):
            procs = maybe_patched_layer.ip_adapter + procs
            embs = maybe_patched_layer.image_emb + embs
        # initialize the IPA wrapper with the accumulated adapters/embeddings
        bi.add_object_patch(f"diffusion_model.{patch_name}", wrapper_cls(original, procs, embs))
def is_model_pathched(model):
    """Return True if any (sub)module of `model` is a DoubleStreamBlockIPA,
    i.e. the model has already been patched for IP-Adapter.

    (Function name — "pathched" — is kept as-is for existing callers.)
    """
    def _scan(module):
        if isinstance(module, DoubleStreamBlockIPA):
            return True
        return any(_scan(child) for child in module.children())
    return _scan(model)
def forward_orig_ipa(
    self,
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    timesteps: Tensor,
    y: Tensor,
    guidance: Tensor|None = None,
    control=None,
    transformer_options={},
    attn_mask: Tensor = None,
) -> Tensor:
    """Replacement for Flux's forward_orig that forwards `timesteps` into
    IPA-wrapped blocks so adapters can gate themselves per step.

    Bound onto the diffusion model via MethodType in FluxUpdateModules.
    Mirrors the upstream implementation except for the isinstance checks.
    NOTE(review): `transformer_options={}` is a mutable default; it is only
    read here, but a None default would be safer.
    """
    patches_replace = transformer_options.get("patches_replace", {})
    if img.ndim != 3 or txt.ndim != 3:
        raise ValueError("Input img and txt tensors must have 3 dimensions.")
    # running on sequences img
    img = self.img_in(img)
    vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
    if self.params.guidance_embed:
        if guidance is None:
            raise ValueError("Didn't get guidance strength for guidance distilled model.")
        vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
    vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
    txt = self.txt_in(txt)
    # Positional embedding covers the concatenated [txt ; img] sequence.
    ids = torch.cat((txt_ids, img_ids), dim=1)
    pe = self.pe_embedder(ids)
    blocks_replace = patches_replace.get("dit", {})
    for i, block in enumerate(self.double_blocks):
        if ("double_block", i) in blocks_replace:
            def block_wrap(args):
                out = {}
                if isinstance(block, DoubleStreamBlockIPA): # ipadapter
                    out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], t=args["timesteps"], attn_mask=args.get("attn_mask"))
                else:
                    out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args.get("attn_mask"))
                return out
            out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "timesteps": timesteps, "attn_mask": attn_mask}, {"original_block": block_wrap})
            txt = out["txt"]
            img = out["img"]
        else:
            if isinstance(block, DoubleStreamBlockIPA): # ipadapter
                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, t=timesteps, attn_mask=attn_mask)
            else:
                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)
        if control is not None: # Controlnet
            control_i = control.get("input")
            if i < len(control_i):
                add = control_i[i]
                if add is not None:
                    img += add
    # Single-stream blocks operate on the fused [txt ; img] sequence.
    img = torch.cat((txt, img), 1)
    for i, block in enumerate(self.single_blocks):
        if ("single_block", i) in blocks_replace:
            def block_wrap(args):
                out = {}
                if isinstance(block, SingleStreamBlockIPA): # ipadapter
                    out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], t=args["timesteps"], attn_mask=args.get("attn_mask"))
                else:
                    out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args.get("attn_mask"))
                return out
            out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "timesteps": timesteps, "attn_mask": attn_mask}, {"original_block": block_wrap})
            img = out["img"]
        else:
            if isinstance(block, SingleStreamBlockIPA): # ipadapter
                img = block(img, vec=vec, pe=pe, t=timesteps, attn_mask=attn_mask)
            else:
                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
        if control is not None: # Controlnet
            control_o = control.get("output")
            if i < len(control_o):
                add = control_o[i]
                if add is not None:
                    # Controlnet residuals apply to the image part only.
                    img[:, txt.shape[1] :, ...] += add
    # Drop the text tokens before the final projection.
    img = img[:, txt.shape[1] :, ...]
    img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
    return img

View File

@@ -0,0 +1,42 @@
{
"_name_or_path": "THUDM/chatglm3-6b-base",
"model_type": "chatglm",
"architectures": [
"ChatGLMModel"
],
"auto_map": {
"AutoConfig": "configuration_chatglm.ChatGLMConfig",
"AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration",
"AutoModelForCausalLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
"AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
"AutoModelForSequenceClassification": "modeling_chatglm.ChatGLMForSequenceClassification"
},
"add_bias_linear": false,
"add_qkv_bias": true,
"apply_query_key_layer_scaling": true,
"apply_residual_connection_post_layernorm": false,
"attention_dropout": 0.0,
"attention_softmax_in_fp32": true,
"bias_dropout_fusion": true,
"ffn_hidden_size": 13696,
"fp32_residual_connection": false,
"hidden_dropout": 0.0,
"hidden_size": 4096,
"kv_channels": 128,
"layernorm_epsilon": 1e-05,
"multi_query_attention": true,
"multi_query_group_num": 2,
"num_attention_heads": 32,
"num_layers": 28,
"original_rope": true,
"padded_vocab_size": 65024,
"post_layer_norm": true,
"rmsnorm": true,
"seq_length": 32768,
"use_cache": true,
"torch_dtype": "float16",
"transformers_version": "4.30.2",
"tie_word_embeddings": false,
"eos_token_id": 2,
"pad_token_id": 0
}

View File

@@ -0,0 +1,60 @@
from transformers import PretrainedConfig
class ChatGLMConfig(PretrainedConfig):
    """Configuration for ChatGLM models (`model_type = "chatglm"`)."""
    model_type = "chatglm"

    def __init__(
        self,
        num_layers=28,
        padded_vocab_size=65024,
        hidden_size=4096,
        ffn_hidden_size=13696,
        kv_channels=128,
        num_attention_heads=32,
        seq_length=2048,
        hidden_dropout=0.0,
        classifier_dropout=None,
        attention_dropout=0.0,
        layernorm_epsilon=1e-5,
        rmsnorm=True,
        apply_residual_connection_post_layernorm=False,
        post_layer_norm=True,
        add_bias_linear=False,
        add_qkv_bias=False,
        bias_dropout_fusion=True,
        multi_query_attention=False,
        multi_query_group_num=1,
        apply_query_key_layer_scaling=True,
        attention_softmax_in_fp32=True,
        fp32_residual_connection=False,
        quantization_bit=0,
        pre_seq_len=None,
        prefix_projection=False,
        **kwargs
    ):
        # Vocabulary: vocab_size mirrors padded_vocab_size.
        self.padded_vocab_size = padded_vocab_size
        self.vocab_size = padded_vocab_size
        # Transformer topology.
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.kv_channels = kv_channels
        self.num_attention_heads = num_attention_heads
        self.seq_length = seq_length
        # Dropout probabilities.
        self.hidden_dropout = hidden_dropout
        self.classifier_dropout = classifier_dropout
        self.attention_dropout = attention_dropout
        # Normalization settings.
        self.layernorm_epsilon = layernorm_epsilon
        self.rmsnorm = rmsnorm
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.post_layer_norm = post_layer_norm
        # Bias / fusion options.
        self.add_bias_linear = add_bias_linear
        self.add_qkv_bias = add_qkv_bias
        self.bias_dropout_fusion = bias_dropout_fusion
        # Multi-query attention.
        self.multi_query_attention = multi_query_attention
        self.multi_query_group_num = multi_query_group_num
        # Numerics.
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.fp32_residual_connection = fp32_residual_connection
        # Quantization / prefix tuning.
        self.quantization_bit = quantization_bit
        self.pre_seq_len = pre_seq_len
        self.prefix_projection = prefix_projection
        super().__init__(**kwargs)

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,300 @@
import json
import os
import re
from typing import List, Optional, Union, Dict
from sentencepiece import SentencePieceProcessor
from transformers import PreTrainedTokenizer
from transformers.utils import logging, PaddingStrategy
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
class SPTokenizer:
    """Thin wrapper around a SentencePiece model that appends ChatGLM's
    special tokens ([MASK]/[gMASK]/[sMASK]/sop/eop and role markers) with
    ids directly after the SentencePiece vocabulary."""
    def __init__(self, model_path: str):
        # reload tokenizer
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)
        # BOS / EOS token IDs
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        # NOTE(review): pad id is mapped to the SP unk id, so padding and
        # <unk> share one id — confirm this matches the model card.
        self.pad_id: int = self.sp_model.unk_id()
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
        role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
        special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
        self.special_tokens = {}
        self.index_special_tokens = {}
        # Assign special-token ids sequentially after the SP vocab.
        for token in special_tokens:
            self.special_tokens[token] = self.n_words
            self.index_special_tokens[self.n_words] = token
            self.n_words += 1
        # Regex alternation matching any role marker literally.
        self.role_special_token_expression = "|".join([re.escape(token) for token in role_special_tokens])
    def tokenize(self, s: str, encode_special_tokens=False):
        """Split text into pieces; optionally keep role markers as single tokens."""
        if encode_special_tokens:
            last_index = 0
            t = []
            for match in re.finditer(self.role_special_token_expression, s):
                if last_index < match.start():
                    t.extend(self.sp_model.EncodeAsPieces(s[last_index:match.start()]))
                t.append(s[match.start():match.end()])
                last_index = match.end()
            if last_index < len(s):
                t.extend(self.sp_model.EncodeAsPieces(s[last_index:]))
            return t
        else:
            return self.sp_model.EncodeAsPieces(s)
    def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
        """Encode text to ids, optionally wrapping with BOS/EOS."""
        assert type(s) is str
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t
    def decode(self, t: List[int]) -> str:
        """Decode ids to text; special-token ids are rendered verbatim and
        flush the pending SentencePiece buffer first."""
        text, buffer = "", []
        for token in t:
            if token in self.index_special_tokens:
                if buffer:
                    text += self.sp_model.decode(buffer)
                    buffer = []
                text += self.index_special_tokens[token]
            else:
                buffer.append(token)
        if buffer:
            text += self.sp_model.decode(buffer)
        return text
    def decode_tokens(self, tokens: List[str]) -> str:
        """Join piece strings back into text via SentencePiece."""
        text = self.sp_model.DecodePieces(tokens)
        return text
    def convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab. """
        if token in self.special_tokens:
            return self.special_tokens[token]
        return self.sp_model.PieceToId(token)
    def convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        if index in self.index_special_tokens:
            return self.index_special_tokens[index]
        # Control ids (eos/bos/pad) and negative ids render as empty strings.
        if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
            return ""
        return self.sp_model.IdToPiece(index)
class ChatGLMTokenizer(PreTrainedTokenizer):
vocab_files_names = {"vocab_file": "tokenizer.model"}
model_input_names = ["input_ids", "attention_mask", "position_ids"]
def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, encode_special_tokens=False,
**kwargs):
self.name = "GLMTokenizer"
self.vocab_file = vocab_file
self.tokenizer = SPTokenizer(vocab_file)
self.special_tokens = {
"<bos>": self.tokenizer.bos_id,
"<eos>": self.tokenizer.eos_id,
"<pad>": self.tokenizer.pad_id
}
self.encode_special_tokens = encode_special_tokens
super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
encode_special_tokens=encode_special_tokens,
**kwargs)
def get_command(self, token):
if token in self.special_tokens:
return self.special_tokens[token]
assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
return self.tokenizer.special_tokens[token]
@property
def unk_token(self) -> str:
return "<unk>"
@property
def pad_token(self) -> str:
return "<unk>"
@property
def pad_token_id(self):
return self.get_command("<pad>")
@property
def eos_token(self) -> str:
return "</s>"
@property
def eos_token_id(self):
return self.get_command("<eos>")
@property
def vocab_size(self):
return self.tokenizer.n_words
def get_vocab(self):
""" Returns vocab as a dict """
vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
def _tokenize(self, text, **kwargs):
return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens)
def _convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """
return self.tokenizer.convert_token_to_id(token)
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.tokenizer.convert_id_to_token(index)
def convert_tokens_to_string(self, tokens: List[str]) -> str:
return self.tokenizer.decode_tokens(tokens)
def save_vocabulary(self, save_directory, filename_prefix=None):
"""
Save the vocabulary and special tokens file to a directory.
Args:
save_directory (`str`):
The directory in which to save the vocabulary.
filename_prefix (`str`, *optional*):
An optional prefix to add to the named of the saved files.
Returns:
`Tuple(str)`: Paths to the files saved.
"""
if os.path.isdir(save_directory):
vocab_file = os.path.join(
save_directory, self.vocab_files_names["vocab_file"]
)
else:
vocab_file = save_directory
with open(self.vocab_file, 'rb') as fin:
proto_str = fin.read()
with open(vocab_file, "wb") as writer:
writer.write(proto_str)
return (vocab_file,)
def get_prefix_tokens(self):
prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
return prefix_tokens
def build_single_message(self, role, metadata, message):
assert role in ["system", "user", "assistant", "observation"], role
role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
message_tokens = self.tokenizer.encode(message)
tokens = role_tokens + message_tokens
return tokens
def build_chat_input(self, query, history=None, role="user"):
if history is None:
history = []
input_ids = []
for item in history:
content = item["content"]
if item["role"] == "system" and "tools" in item:
content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
input_ids.extend(self.build_single_message(role, "", query))
input_ids.extend([self.get_command("<|assistant|>")])
return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
def build_inputs_with_special_tokens(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
    """
    Build model inputs from a sequence or a pair of sequences by adding
    the tokenizer's special tokens:

    - single sequence: ``prefix + A``
    - pair of sequences: ``prefix + A + B + <eos>``

    Args:
        token_ids_0 (`List[int]`): First sequence of ids.
        token_ids_1 (`List[int]`, *optional*): Second sequence for pairs.

    Returns:
        `List[int]`: Input ids with the appropriate special tokens.
    """
    combined = list(self.get_prefix_tokens()) + token_ids_0
    if token_ids_1 is None:
        return combined
    return combined + token_ids_1 + [self.get_command("<eos>")]
def _pad(
    self,
    encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
    max_length: Optional[int] = None,
    padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
    pad_to_multiple_of: Optional[int] = None,
    return_attention_mask: Optional[bool] = None,
    **kwargs
) -> dict:
    """
    Pad encoded inputs on the LEFT up to a predefined or batch-longest length.

    Args:
        encoded_inputs:
            Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
        max_length: maximum length of the returned list and optionally padding length (see below).
        padding_strategy: PaddingStrategy to use for padding.
            - PaddingStrategy.LONGEST: pad to the longest sequence in the batch
            - PaddingStrategy.MAX_LENGTH: pad to the max length (default)
            - PaddingStrategy.DO_NOT_PAD: do not pad
        pad_to_multiple_of: (optional) round the padded length up to a multiple of this
            value (useful for NVIDIA Tensor Cores, compute capability >= 7.5).
        return_attention_mask:
            (optional) accepted for base-class compatibility; this override always
            ensures an attention mask is present.
    """
    # This tokenizer only supports left padding; everything below assumes it.
    assert self.padding_side == "left"
    required_input = encoded_inputs[self.model_input_names[0]]
    seq_length = len(required_input)
    if padding_strategy == PaddingStrategy.LONGEST:
        max_length = len(required_input)
    # Round max_length up to the requested multiple.
    if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
    needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
    # Initialize attention mask / position ids if not present, before padding.
    if "attention_mask" not in encoded_inputs:
        encoded_inputs["attention_mask"] = [1] * seq_length
    if "position_ids" not in encoded_inputs:
        encoded_inputs["position_ids"] = list(range(seq_length))
    if needs_to_be_padded:
        difference = max_length - len(required_input)
        # Left padding: zeros go in FRONT of the mask and position ids,
        # pad tokens in front of the input ids.
        if "attention_mask" in encoded_inputs:
            encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
        if "position_ids" in encoded_inputs:
            encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
        encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
    return encoded_inputs

View File

@@ -0,0 +1,12 @@
{
"name_or_path": "THUDM/chatglm3-6b-base",
"remove_space": false,
"do_lower_case": false,
"tokenizer_class": "ChatGLMTokenizer",
"auto_map": {
"AutoTokenizer": [
"tokenization_chatglm.ChatGLMTokenizer",
null
]
}
}

View File

@@ -0,0 +1,18 @@
{
"attention_dropout": 0.0,
"dropout": 0.0,
"hidden_act": "quick_gelu",
"hidden_size": 1024,
"image_size": 336,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-05,
"model_type": "clip_vision_model",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 24,
"patch_size": 14,
"projection_dim": 768,
"torch_dtype": "float32"
}

View File

@@ -0,0 +1,303 @@
import json
import os
import torch
import subprocess
import sys
import comfy.supported_models
import comfy.model_patcher
import comfy.model_management
import comfy.model_detection as model_detection
import comfy.model_base as model_base
from comfy.model_base import sdxl_pooled, CLIPEmbeddingNoiseAugmentation, Timestep, ModelType
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
from comfy.clip_vision import ClipVisionModel, Output
from comfy.utils import load_torch_file
from .chatglm.modeling_chatglm import ChatGLMModel, ChatGLMConfig
from .chatglm.tokenization_chatglm import ChatGLMTokenizer
class KolorsUNetModel(UNetModel):
    """SDXL UNet with an extra projection mapping ChatGLM3 text embeddings
    (4096-d) down to the 2048-d cross-attention context the UNet expects."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # 4096 -> 2048: Kolors conditions on ChatGLM hidden states instead of CLIP.
        self.encoder_hid_proj = torch.nn.Linear(4096, 2048, bias=True)

    def forward(self, *args, **kwargs):
        # NOTE(review): torch.cuda.amp.autocast is deprecated in newer torch
        # in favor of torch.amp.autocast("cuda") — confirm the minimum
        # supported torch version before changing.
        with torch.cuda.amp.autocast(enabled=True):
            # Project the text context before it reaches the base UNet.
            if "context" in kwargs:
                kwargs["context"] = self.encoder_hid_proj(kwargs["context"])
            result = super().forward(*args, **kwargs)
        return result
class KolorsSDXL(model_base.SDXL):
    """SDXL model variant wired to the Kolors UNet (ChatGLM context projection)."""

    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        # Deliberately call BaseModel.__init__ (not SDXL.__init__) so the
        # diffusion model is instantiated as KolorsUNetModel.
        model_base.BaseModel.__init__(self, model_config, model_type, device=device, unet_model=KolorsUNetModel)
        self.embedder = Timestep(256)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(
            noise_schedule_config={"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"},
            timestep_dim=1280,
        )

    def encode_adm(self, **kwargs):
        """Build the ADM conditioning vector: pooled text output + six 256-d size embeddings."""
        clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
        width = kwargs.get("width", 768)
        height = kwargs.get("height", 768)
        crop_w = kwargs.get("crop_w", 0)
        crop_h = kwargs.get("crop_h", 0)
        target_width = kwargs.get("target_width", width)
        target_height = kwargs.get("target_height", height)
        # Embed the size values in the same order the reference SDXL code uses.
        size_values = (height, width, crop_h, crop_w, target_height, target_width)
        embeds = [self.embedder(torch.Tensor([value])) for value in size_values]
        flat = torch.flatten(torch.cat(embeds)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
class Kolors(comfy.supported_models.SDXL):
    """supported_models entry matching Kolors checkpoints.

    Identical to stock SDXL detection except adm_in_channels is 5632
    (4096-d pooled text output + 6 * 256-d size embeddings) instead of 2816.
    """
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 2, 2, 10, 10],
        "context_dim": 2048,
        "adm_in_channels": 5632,
        "use_temporal_attention": False,
    }

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate a KolorsSDXL, then re-tag it as plain SDXL for downstream code."""
        out = KolorsSDXL(self, model_type=self.model_type(state_dict, prefix), device=device, )
        # Swap the class so isinstance checks elsewhere treat it as stock SDXL.
        out.__class__ = model_base.SDXL
        if self.inpaint_model():
            out.set_inpaint()
        return out
def kolors_unet_config_from_diffusers_unet(state_dict, dtype=None):
    """Detect a Kolors/SDXL UNet config from a diffusers-format state dict.

    Mirrors comfy.model_detection.unet_config_from_diffusers_unet but with
    Kolors variants (adm_in_channels == 5632) added to the candidate table.

    Returns:
        The converted ComfyUI unet config dict, or None if nothing matches.
    """
    match = {}
    transformer_depth = []
    attn_res = 1
    count_blocks = model_detection.count_blocks
    down_blocks = count_blocks(state_dict, "down_blocks.{}")
    for i in range(down_blocks):
        attn_blocks = count_blocks(
            state_dict, "down_blocks.{}.attentions.".format(i) + '{}')
        res_blocks = count_blocks(
            state_dict, "down_blocks.{}.resnets.".format(i) + '{}')
        for ab in range(attn_blocks):
            transformer_count = count_blocks(
                state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
            transformer_depth.append(transformer_count)
            if transformer_count > 0:
                # Cross-attention key width gives the context dimension.
                match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(
                    i, ab)].shape[1]
        attn_res *= 2
        if attn_blocks == 0:
            # NOTE(review): this inner loop reuses `i`, shadowing the outer
            # loop variable; harmless in Python (the outer iterator rebinds
            # it) but worth renaming for clarity.
            for i in range(res_blocks):
                transformer_depth.append(0)
    match["transformer_depth"] = transformer_depth
    match["model_channels"] = state_dict["conv_in.weight"].shape[0]
    match["in_channels"] = state_dict["conv_in.weight"].shape[1]
    match["adm_in_channels"] = None
    if "class_embedding.linear_1.weight" in state_dict:
        match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1]
    elif "add_embedding.linear_1.weight" in state_dict:
        match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
    # Candidate configurations; only keys present in `match` are compared.
    Kolors = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
              'num_classes': 'sequential', 'adm_in_channels': 5632, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
              'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
              'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
              'use_temporal_attention': False, 'use_temporal_resblock': False}
    # Same as Kolors but 9 input channels (latent + mask + masked latent).
    Kolors_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True,
                      'legacy': False,
                      'num_classes': 'sequential', 'adm_in_channels': 5632, 'dtype': dtype, 'in_channels': 9,
                      'model_channels': 320,
                      'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4],
                      'transformer_depth_middle': 10,
                      'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
                      'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
                      'use_temporal_attention': False, 'use_temporal_resblock': False}
    # 8 input channels (latent + image conditioning), instruct-pix2pix style.
    Kolors_ip2p = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True,
                   'legacy': False,
                   'num_classes': 'sequential', 'adm_in_channels': 5632, 'dtype': dtype, 'in_channels': 8,
                   'model_channels': 320,
                   'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4],
                   'transformer_depth_middle': 10,
                   'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
                   'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
                   'use_temporal_attention': False, 'use_temporal_resblock': False}
    SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True,
            'legacy': False,
            'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4,
            'model_channels': 320,
            'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4],
            'transformer_depth_middle': 10,
            'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
            'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
            'use_temporal_attention': False, 'use_temporal_resblock': False}
    SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True,
                     'legacy': False,
                     'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4,
                     'model_channels': 320,
                     'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 1, 1], 'channel_mult': [1, 2, 4],
                     'transformer_depth_middle': 1,
                     'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
                     'transformer_depth_output': [0, 0, 0, 0, 0, 0, 1, 1, 1],
                     'use_temporal_attention': False, 'use_temporal_resblock': False}
    SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True,
                       'legacy': False,
                       'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4,
                       'model_channels': 320,
                       'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 0, 0], 'channel_mult': [1, 2, 4],
                       'transformer_depth_middle': 0,
                       'use_linear_in_transformer': True, 'num_head_channels': 64, 'context_dim': 1,
                       'transformer_depth_output': [0, 0, 0, 0, 0, 0, 0, 0, 0],
                       'use_temporal_attention': False, 'use_temporal_resblock': False}
    supported_models = [Kolors, Kolors_inpaint,
                        Kolors_ip2p, SDXL, SDXL_mid_cnet, SDXL_small_cnet]
    # First candidate whose values agree with every detected key wins.
    for unet_config in supported_models:
        matches = True
        for k in match:
            if match[k] != unet_config[k]:
                # print("key {} does not match".format(k), match[k], "||", unet_config[k])
                matches = False
                break
        if matches:
            return model_detection.convert_config(unet_config)
    return None
# chatglm3 model
class chatGLM3Model(torch.nn.Module):
    """Wrapper that builds the ChatGLM3 text encoder and loads its weights.

    Uses `accelerate` (when available) to materialize the model on meta
    tensors and place weights directly on *offload_device*; falls back to a
    normal ``load_state_dict`` otherwise. Checkpoints whose path contains
    "4bit"/"8bit" are quantized accordingly before weights are loaded.
    """

    def __init__(self, textmodel_json_config=None, device='cpu', offload_device='cpu', model_path=None):
        super().__init__()
        if model_path is None:
            raise ValueError("model_path is required")
        self.device = device
        # Default to the bundled ChatGLM config shipped next to this module.
        if textmodel_json_config is None:
            textmodel_json_config = os.path.join(
                os.path.dirname(os.path.realpath(__file__)),
                "chatglm",
                "config_chatglm.json"
            )
        with open(textmodel_json_config, 'r') as file:
            config = json.load(file)
        textmodel_json_config = ChatGLMConfig(**config)
        is_accelerate_available = False
        try:
            from accelerate import init_empty_weights
            from accelerate.utils import set_module_tensor_to_device
            is_accelerate_available = True
        except ImportError:
            # accelerate is optional; fall back to a plain state-dict load.
            pass
        from contextlib import nullcontext
        with (init_empty_weights() if is_accelerate_available else nullcontext()):
            with torch.no_grad():
                print('torch version:', torch.__version__)
                self.text_encoder = ChatGLMModel(textmodel_json_config).eval()
                if '4bit' in model_path:
                    # 4-bit quantization needs cpm_kernels; install on demand.
                    try:
                        import cpm_kernels
                    except ImportError:
                        print("Installing cpm_kernels...")
                        subprocess.run([sys.executable, "-m", "pip", "install", "cpm_kernels"], check=True)
                    self.text_encoder.quantize(4)
                elif '8bit' in model_path:
                    self.text_encoder.quantize(8)
        sd = load_torch_file(model_path)
        if is_accelerate_available:
            # Materialize each meta tensor directly on the offload device.
            for key in sd:
                set_module_tensor_to_device(self.text_encoder, key, device=offload_device, value=sd[key])
        else:
            print("WARNING: Accelerate not available, use load_state_dict load model")
            # Bug fix: load_state_dict() was previously called with no
            # argument, which raises TypeError; pass the loaded weights.
            self.text_encoder.load_state_dict(sd)
def load_chatglm3(model_path=None):
    """Instantiate the ChatGLM3 text encoder and tokenizer from a checkpoint.

    Args:
        model_path: Path to the text-encoder weights; None returns None.

    Returns:
        A dict with "text_encoder" and "tokenizer" entries, or None when no
        path was supplied.
    """
    if model_path is None:
        return None
    glm_wrapper = chatGLM3Model(
        device=comfy.model_management.text_encoder_device(),
        offload_device=comfy.model_management.text_encoder_offload_device(),
        model_path=model_path,
    )
    # The tokenizer files ship alongside this module.
    tokenizer_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'chatglm', "tokenizer")
    return {
        "text_encoder": glm_wrapper.text_encoder,
        "tokenizer": ChatGLMTokenizer.from_pretrained(tokenizer_dir),
    }
# clipvision model
def load_clipvision_vitl_336(path):
    """Load a ViT-L/14@336 CLIP vision checkpoint into a ClipVisionModel.

    Args:
        path: Checkpoint file path.

    Raises:
        Exception: if the checkpoint does not contain the 24-layer ViT-L keys.
    """
    sd = load_torch_file(path)
    if "vision_model.encoder.layers.22.layer_norm1.weight" not in sd:
        raise Exception("Unsupported clip vision model")
    json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
    clip = ClipVisionModel(json_config)
    missing, unexpected = clip.load_sd(sd)
    if missing:
        print("missing clip vision: {}".format(missing))
    # Drop every consumed key so only the unexpected leftovers remain in sd.
    leftover = set(unexpected)
    for key in list(sd.keys()):
        if key not in leftover:
            del sd[key]
    return clip
class applyKolorsUnet:
    """Context manager that temporarily monkey-patches ComfyUI's loaders so a
    Kolors checkpoint loads correctly, restoring everything on exit.

    Patches: the diffusers->comfy key map (adds encoder_hid_proj), the UNet
    config detector, the supported-models list, and the clip-vision loader.
    """

    def __enter__(self):
        import comfy.ldm.modules.diffusionmodules.openaimodel
        import comfy.utils
        import comfy.clip_vision
        # UNET_MAP_BASIC supports .add — presumably a set; copy for restore.
        self.original_UNET_MAP_BASIC = comfy.utils.UNET_MAP_BASIC.copy()
        comfy.utils.UNET_MAP_BASIC.add(("encoder_hid_proj.weight", "encoder_hid_proj.weight"),)
        comfy.utils.UNET_MAP_BASIC.add(("encoder_hid_proj.bias", "encoder_hid_proj.bias"),)
        self.original_unet_config_from_diffusers_unet = model_detection.unet_config_from_diffusers_unet
        model_detection.unet_config_from_diffusers_unet = kolors_unet_config_from_diffusers_unet
        import comfy.supported_models
        self.original_supported_models = comfy.supported_models.models
        # Restrict detection to Kolors only while inside the context.
        comfy.supported_models.models = [Kolors]
        self.original_load_clipvision_from_sd = comfy.clip_vision.load_clipvision_from_sd
        comfy.clip_vision.load_clipvision_from_sd = load_clipvision_vitl_336

    def __exit__(self, type, value, traceback):
        # Undo every patch in the same order it was applied.
        import comfy.ldm.modules.diffusionmodules.openaimodel
        import comfy.utils
        import comfy.supported_models
        import comfy.clip_vision
        comfy.utils.UNET_MAP_BASIC = self.original_UNET_MAP_BASIC
        model_detection.unet_config_from_diffusers_unet = self.original_unet_config_from_diffusers_unet
        comfy.supported_models.models = self.original_supported_models
        comfy.clip_vision.load_clipvision_from_sd = self.original_load_clipvision_from_sd
def is_kolors_model(model):
    """Detect a Kolors checkpoint by its ADM channel width (5632 vs SDXL's 2816)."""
    config = model.model.model_config.unet_config if hasattr(model, 'model') else None
    if config is None:
        return False
    return config.get("adm_in_channels") == 5632

View File

@@ -0,0 +1,66 @@
import torch
from torch.nn import Linear
from types import MethodType
import comfy.model_management
import comfy.samplers
from comfy.cldm.cldm import ControlNet
from comfy.controlnet import ControlLora
def patch_controlnet(model, control_net):
    """Adapt a ControlNet/ControlLora so it works with a Kolors model.

    Wraps the control model's forward so the 4096-d ChatGLM context is first
    projected to 2048-d via the Kolors UNet's encoder_hid_proj, and makes the
    label embedding compatible (Kolors uses 5632 ADM channels).

    Returns the same *control_net* object, patched in place.
    """
    import comfy.controlnet
    if isinstance(control_net, ControlLora):
        # Drop the lora's own label_emb input weights; the patched forward
        # relies on the base model's projection instead.
        del_keys = []
        for k in control_net.control_weights:
            if k.startswith("label_emb.0.0."):
                del_keys.append(k)
        for k in del_keys:
            control_net.control_weights.pop(k)
        super_pre_run = ControlLora.pre_run
        super_copy = ControlLora.copy
        super_forward = ControlNet.forward
        def KolorsControlNet_forward(self, x, hint, timesteps, context, **kwargs):
            # Project ChatGLM context to the width the control model expects.
            with torch.cuda.amp.autocast(enabled=True):
                context = model.model.diffusion_model.encoder_hid_proj(context)
                return super_forward(self, x, hint, timesteps, context, **kwargs)
        def KolorsControlLora_pre_run(self, *args, **kwargs):
            # ControlLora builds control_model lazily in pre_run; hook its
            # forward right after it exists.
            result = super_pre_run(self, *args, **kwargs)
            if hasattr(self, "control_model"):
                self.control_model.forward = MethodType(
                    KolorsControlNet_forward, self.control_model)
            return result
        control_net.pre_run = MethodType(
            KolorsControlLora_pre_run, control_net)
        def KolorsControlLora_copy(self, *args, **kwargs):
            # Copies must carry the same hook, otherwise a cloned control
            # would lose the Kolors projection.
            c = super_copy(self, *args, **kwargs)
            c.pre_run = MethodType(
                KolorsControlLora_pre_run, c)
            return c
        control_net.copy = MethodType(KolorsControlLora_copy, control_net)
    elif isinstance(control_net, comfy.controlnet.ControlNet):
        # Share the base model's 5632-channel label embedding with the control model.
        model_label_emb = model.model.diffusion_model.label_emb
        control_net.control_model.label_emb = model_label_emb
        control_net.control_model_wrapped.model.label_emb = model_label_emb
        super_forward = ControlNet.forward
        def KolorsControlNet_forward(self, x, hint, timesteps, context, **kwargs):
            with torch.cuda.amp.autocast(enabled=True):
                context = model.model.diffusion_model.encoder_hid_proj(context)
                return super_forward(self, x, hint, timesteps, context, **kwargs)
        control_net.control_model.forward = MethodType(
            KolorsControlNet_forward, control_net.control_model)
    else:
        raise NotImplementedError(f"Type {control_net} not supported for KolorsControlNetPatch")
    return control_net

View File

@@ -0,0 +1,105 @@
import re
import random
import gc
import comfy.model_management as mm
from nodes import ConditioningConcat, ConditioningZeroOut, ConditioningSetTimestepRange, ConditioningCombine
def chatglm3_text_encode(chatglm3_model, prompt, clean_gpu=False):
    """Encode *prompt* with the ChatGLM3 text encoder into ComfyUI conditioning.

    Args:
        chatglm3_model: Dict with "tokenizer" and "text_encoder" entries
            (as produced by load_chatglm3).
        prompt: Prompt string; ``{a|b|c}`` picks one option at random and a
            remaining ``|`` splits the prompt into a batch of prompts.
        clean_gpu: Unload other models and empty the cache around encoding.

    Returns:
        ``[[prompt_embeds, {"pooled_output": text_proj}]]`` — standard
        ComfyUI conditioning.
    """
    device = mm.get_torch_device()
    offload_device = mm.unet_offload_device()
    if clean_gpu:
        mm.unload_all_models()
        mm.soft_empty_cache()
    # Randomly select an option from {a|b|c} brackets.
    def choose_random_option(match):
        options = match.group(1).split('|')
        return random.choice(options)
    prompt = re.sub(r'\{([^{}]*)\}', choose_random_option, prompt)
    # A bare "|" outside braces turns the prompt into a batch of prompts.
    if "|" in prompt:
        prompt = prompt.split("|")
    # NOTE(review): batch_size is computed but never used below — the
    # tokenizer call handles batching itself; confirm before removing.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    # Define tokenizers and text encoders
    tokenizer = chatglm3_model['tokenizer']
    text_encoder = chatglm3_model['text_encoder']
    text_encoder.to(device)
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=256,
        truncation=True,
        return_tensors="pt",
    ).to(device)
    output = text_encoder(
        input_ids=text_inputs['input_ids'],
        attention_mask=text_inputs['attention_mask'],
        position_ids=text_inputs['position_ids'],
        output_hidden_states=True)
    # Second-to-last hidden state, permuted to [batch, seq, hidden] — the
    # shape comments below assume ChatGLM's [seq, batch, hidden] layout;
    # TODO confirm against the model implementation.
    # [batch_size, 77, 4096]
    prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
    text_proj = output.hidden_states[-1][-1, :, :].clone()  # [batch_size, 4096]
    bs_embed, seq_len, _ = prompt_embeds.shape
    # repeat(1, 1, 1)/view keep the shape; retained from the reference code.
    prompt_embeds = prompt_embeds.repeat(1, 1, 1)
    prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
    bs_embed = text_proj.shape[0]
    text_proj = text_proj.repeat(1, 1).view(bs_embed, -1)
    # Move the encoder off the compute device when done.
    text_encoder.to(offload_device)
    if clean_gpu:
        mm.soft_empty_cache()
        gc.collect()
    return [[prompt_embeds, {"pooled_output": text_proj},]]
def chatglm3_adv_text_encode(chatglm3_model, text, clean_gpu=False):
    """Advanced ChatGLM3 encode supporting BREAK and TIMESTEP prompt syntax.

    ``BREAK`` splits the prompt into segments whose conditionings are
    concatenated. A trailing ``TIMESTEP[:start[:end]]`` token limits the
    conditioning to a timestep range: before *start* the conditioning is
    zeroed, between *start* and *end* the real conditioning applies.

    Returns a ComfyUI conditioning list.
    """
    time_start = 0
    time_end = 1
    # Grab everything from the first "TIMESTEP" to end of string.
    match = re.search(r'TIMESTEP.*$', text)
    if match:
        timestep = match.group()
        timestep = timestep.split(' ')
        timestep = timestep[0]
        # Remove the marker from the prompt text itself.
        text = text.replace(timestep, '')
        value = timestep.split(':')
        if len(value) >= 3:
            time_start = float(value[1])
            time_end = float(value[2])
        elif len(value) == 2:
            time_start = float(value[1])
            time_end = 1
        elif len(value) == 1:
            # Bare "TIMESTEP" defaults to starting at 0.1.
            time_start = 0.1
            time_end = 1
    # Split on BREAK, dropping empty segments; guarantee at least one.
    pass3 = [x.strip() for x in text.split("BREAK")]
    pass3 = [x for x in pass3 if x != '']
    if len(pass3) == 0:
        pass3 = ['']
    conditioning = None
    # Encode each segment and concatenate the conditionings in order.
    # NOTE(review): the loop variable shadows the `text` parameter.
    for text in pass3:
        cond = chatglm3_text_encode(chatglm3_model, text, clean_gpu)
        if conditioning is not None:
            conditioning = ConditioningConcat().concat(conditioning, cond)[0]
        else:
            conditioning = cond
    # Apply the timestep range: zeroed cond before time_start, real cond after.
    if time_start > 0 or time_end < 1:
        conditioning_2, = ConditioningSetTimestepRange().set_range(conditioning, 0, time_start)
        conditioning_1, = ConditioningZeroOut().zero_out(conditioning)
        conditioning_1, = ConditioningSetTimestepRange().set_range(conditioning_1, time_start, time_end)
        conditioning, = ConditioningCombine().combine(conditioning_1, conditioning_2)
    return conditioning

View File

@@ -0,0 +1,219 @@
#credit to huchenlei for this module
#from https://github.com/huchenlei/ComfyUI-layerdiffuse
import torch
import comfy.model_management
import comfy.lora
import copy
from typing import Optional
from enum import Enum
from comfy.utils import load_torch_file
from comfy.conds import CONDRegular
from comfy_extras.nodes_compositing import JoinImageWithAlpha
try:
from .model import ModelPatcher, TransparentVAEDecoder, calculate_weight_adjust_channel
except:
ModelPatcher, TransparentVAEDecoder, calculate_weight_adjust_channel = None, None, None
from .attension_sharing import AttentionSharingPatcher
from ...config import LAYER_DIFFUSION, LAYER_DIFFUSION_DIR, LAYER_DIFFUSION_VAE
from ...libs.utils import to_lora_patch_dict, get_local_filepath, get_sd_version
load_layer_model_state_dict = load_torch_file
class LayerMethod(Enum):
    """Layer-diffusion generation modes; values match the UI option labels."""
    FG_ONLY_ATTN = "Attention Injection"      # transparent foreground via attention injection
    FG_ONLY_CONV = "Conv Injection"           # transparent foreground via conv injection
    FG_TO_BLEND = "Foreground"                # foreground conditioned on a blend latent
    FG_BLEND_TO_BG = "Foreground to Background"
    BG_TO_BLEND = "Background"                # background conditioned on a blend latent
    BG_BLEND_TO_FG = "Background to Foreground"
    EVERYTHING = "Everything"                 # fg + blend + bg triplets in one batch
class LayerDiffuse:
    """Helper implementing LayerDiffuse transparent-image generation.

    Patches a model with the layer-diffusion LoRA/attention weights before
    sampling and decodes the resulting latents into RGBA images afterwards.
    """

    def __init__(self) -> None:
        # Lazily-created TransparentVAEDecoder (see layer_diffusion_decode).
        self.vae_transparent_decoder = None
        # Number of images per logical sample (1 = fg only, 2 = fg+blend, 3 = everything).
        self.frames = 1

    def get_layer_diffusion_method(self, method, has_blend_latent):
        """Resolve the UI method label, upgrading *_TO_BLEND variants when a blend latent exists."""
        method = LayerMethod(method)
        if method == LayerMethod.BG_TO_BLEND and has_blend_latent:
            method = LayerMethod.BG_BLEND_TO_FG
        elif method == LayerMethod.FG_TO_BLEND and has_blend_latent:
            method = LayerMethod.FG_BLEND_TO_BG
        return method

    def apply_layer_c_concat(self, cond, uncond, c_concat):
        """Attach *c_concat* as the c_concat model condition on both cond and uncond."""
        def write_c_concat(cond):
            new_cond = []
            for t in cond:
                # Shallow-copy the options dict so the caller's conds are untouched.
                n = [t[0], t[1].copy()]
                if "model_conds" not in n[1]:
                    n[1]["model_conds"] = {}
                n[1]["model_conds"]["c_concat"] = CONDRegular(c_concat)
                new_cond.append(n)
            return new_cond
        return (write_c_concat(cond), write_c_concat(uncond))

    def apply_layer_diffusion(self, model, method, weight, samples, blend_samples, positive, negative, image=None, additional_cond=(None, None, None)):
        """Patch *model* for the given layer-diffusion *method*.

        Downloads/loads the matching layer weights, applies them either as an
        AttentionSharingPatcher (SD1) or LoRA patches (SDXL), and rewrites the
        conditioning as the method requires.

        Returns:
            (patched_model, positive, negative)
        """
        control_img: Optional[torch.TensorType] = None
        sd_version = get_sd_version(model)
        model_url = LAYER_DIFFUSION[method.value][sd_version]["model_url"]
        if image is not None:
            # [B, H, W, C] -> [B, C, H, W]
            image = image.movedim(-1, 1)
        # Patch the lora weight-merging function so patches may widen conv
        # channels; best effort — ignore failures if the hook is unavailable.
        try:
            if hasattr(comfy.lora, "calculate_weight"):
                comfy.lora.calculate_weight = calculate_weight_adjust_channel(comfy.lora.calculate_weight)
            else:
                ModelPatcher.calculate_weight = calculate_weight_adjust_channel(ModelPatcher.calculate_weight)
        except:
            pass
        # SD1 batches interleave fg/blend(/bg) frames, so batch size must be a multiple.
        if method in [LayerMethod.FG_ONLY_CONV, LayerMethod.FG_ONLY_ATTN] and sd_version == 'sd1':
            self.frames = 1
        elif method in [LayerMethod.BG_TO_BLEND, LayerMethod.FG_TO_BLEND, LayerMethod.BG_BLEND_TO_FG, LayerMethod.FG_BLEND_TO_BG] and sd_version == 'sd1':
            self.frames = 2
            batch_size, _, height, width = samples['samples'].shape
            if batch_size % 2 != 0:
                raise Exception(f"The batch size should be a multiple of 2. 批次大小需为2的倍数")
            control_img = image
        elif method == LayerMethod.EVERYTHING and sd_version == 'sd1':
            batch_size, _, height, width = samples['samples'].shape
            self.frames = 3
            if batch_size % 3 != 0:
                raise Exception(f"The batch size should be a multiple of 3. 批次大小需为3的倍数")
        if model_url is None:
            raise Exception(f"{method.value} is not supported for {sd_version} model")
        model_path = get_local_filepath(model_url, LAYER_DIFFUSION_DIR)
        layer_lora_state_dict = load_layer_model_state_dict(model_path)
        work_model = model.clone()
        if sd_version == 'sd1':
            # SD1 uses attention-sharing modules rather than plain lora patches.
            patcher = AttentionSharingPatcher(
                work_model, self.frames, use_control=control_img is not None
            )
            patcher.load_state_dict(layer_lora_state_dict, strict=True)
            if control_img is not None:
                patcher.set_control(control_img)
        else:
            layer_lora_patch_dict = to_lora_patch_dict(layer_lora_state_dict)
            work_model.add_patches(layer_lora_patch_dict, weight)
        # Attach the conditioning the chosen method requires.
        if method in [LayerMethod.FG_ONLY_ATTN, LayerMethod.FG_ONLY_CONV]:
            samp_model = work_model
        elif sd_version == 'sdxl':
            # SDXL variants condition via latent c_concat.
            if method in [LayerMethod.BG_TO_BLEND, LayerMethod.FG_TO_BLEND]:
                c_concat = model.model.latent_format.process_in(samples["samples"])
            else:
                c_concat = model.model.latent_format.process_in(torch.cat([samples["samples"], blend_samples["samples"]], dim=1))
            samp_model, positive, negative = (work_model,) + self.apply_layer_c_concat(positive, negative, c_concat)
        elif sd_version == 'sd1':
            # SD1 variants pass extra conds through transformer_options.
            if method in [LayerMethod.BG_TO_BLEND, LayerMethod.BG_BLEND_TO_FG]:
                additional_cond = (additional_cond[0], None)
            elif method in [LayerMethod.FG_TO_BLEND, LayerMethod.FG_BLEND_TO_BG]:
                additional_cond = (additional_cond[1], None)
            work_model.model_options.setdefault("transformer_options", {})
            work_model.model_options["transformer_options"]["cond_overwrite"] = [
                cond[0][0] if cond is not None else None
                for cond in additional_cond
            ]
            samp_model = work_model
        return samp_model, positive, negative

    def join_image_with_alpha(self, image, alpha):
        """Write *alpha* into the 4th channel of *image* ([B, H, W, C] in and out)."""
        out = image.movedim(-1, 1)
        if out.shape[1] == 3:  # RGB -> add an opaque alpha channel first
            out = torch.cat([out, torch.ones_like(out[:, :1, :, :])], dim=1)
        for i in range(out.shape[0]):
            out[i, 3, :, :] = alpha
        return out.movedim(1, -1)

    def image_to_alpha(self, image, latent):
        """Decode (pixels, latent) pairs into RGBA via the transparent VAE decoder.

        Returns:
            (rgba_images, alpha) where alpha is inverted so 1 = opaque.
        """
        pixel = image.movedim(-1, 1)  # [B, H, W, C] => [B, C, H, W]
        decoded = []
        # Decode in sub-batches to bound memory use.
        sub_batch_size = 16
        for start_idx in range(0, latent.shape[0], sub_batch_size):
            decoded.append(
                self.vae_transparent_decoder.decode_pixel(
                    pixel[start_idx: start_idx + sub_batch_size],
                    latent[start_idx: start_idx + sub_batch_size],
                )
            )
        pixel_with_alpha = torch.cat(decoded, dim=0)
        # [B, C, H, W] => [B, H, W, C]; channel 0 is the (inverted) alpha.
        pixel_with_alpha = pixel_with_alpha.movedim(1, -1)
        image = pixel_with_alpha[..., 1:]
        alpha = pixel_with_alpha[..., 0]
        alpha = 1.0 - alpha
        # The node method name differs across ComfyUI versions.
        try:
            new_images, = JoinImageWithAlpha().execute(image, alpha)
        except:
            new_images, = JoinImageWithAlpha().join_image_with_alpha(image, alpha)
        return new_images, alpha

    def make_3d_mask(self, mask):
        """Normalize a mask tensor to 3 dims ([1?, H, W])."""
        if len(mask.shape) == 4:
            return mask.squeeze(0)
        elif len(mask.shape) == 2:
            return mask.unsqueeze(0)
        return mask

    def masks_to_list(self, masks):
        """Turn a batch of masks into a list of 3-D masks (empty 64x64 mask if None)."""
        if masks is None:
            empty_mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
            return ([empty_mask],)
        res = []
        for mask in masks:
            res.append(mask)
        return [self.make_3d_mask(x) for x in res]

    def layer_diffusion_decode(self, layer_diffusion_method, latent, blend_samples, samp_images, model):
        """Decode sampled images/latents into RGBA where the method supports it.

        Returns:
            (new_images, samp_images, alpha) — original images pass through
            unchanged when the method/model combination has no alpha decode.
        """
        alpha = []
        if layer_diffusion_method is not None:
            sd_version = get_sd_version(model)
            if sd_version not in ['sdxl', 'sd1']:
                raise Exception(f"Only SDXL and SD1.5 model supported for Layer Diffusion")
            method = self.get_layer_diffusion_method(layer_diffusion_method, blend_samples is not None)
            # Only these method/version combinations produce an alpha channel.
            sd15_allow = True if sd_version == 'sd1' and method in [LayerMethod.FG_ONLY_ATTN, LayerMethod.EVERYTHING, LayerMethod.BG_TO_BLEND, LayerMethod.BG_BLEND_TO_FG] else False
            sdxl_allow = True if sd_version == 'sdxl' and method in [LayerMethod.FG_ONLY_CONV, LayerMethod.FG_ONLY_ATTN, LayerMethod.BG_BLEND_TO_FG] else False
            if sdxl_allow or sd15_allow:
                # Create the transparent VAE decoder once and cache it.
                if self.vae_transparent_decoder is None:
                    model_url = LAYER_DIFFUSION_VAE['decode'][sd_version]["model_url"]
                    if model_url is None:
                        raise Exception(f"{method.value} is not supported for {sd_version} model")
                    decoder_file = get_local_filepath(model_url, LAYER_DIFFUSION_DIR)
                    self.vae_transparent_decoder = TransparentVAEDecoder(
                        load_torch_file(decoder_file),
                        device=comfy.model_management.get_torch_device(),
                        dtype=(torch.float16 if comfy.model_management.should_use_fp16() else torch.float32),
                    )
                if method in [LayerMethod.EVERYTHING, LayerMethod.BG_BLEND_TO_FG, LayerMethod.BG_TO_BLEND]:
                    # Interleaved batches: only every frames-th image is a
                    # foreground that gets an alpha decode; others pass through.
                    new_images = []
                    sliced_samples = copy.copy({"samples": latent})
                    for index in range(len(samp_images)):
                        if index % self.frames == 0:
                            img = samp_images[index::self.frames]
                            alpha_images, _alpha = self.image_to_alpha(img, sliced_samples["samples"][index::self.frames])
                            alpha.append(self.make_3d_mask(_alpha[0]))
                            new_images.append(alpha_images[0])
                        else:
                            new_images.append(samp_images[index])
                else:
                    new_images, alpha = self.image_to_alpha(samp_images, latent)
            else:
                new_images = samp_images
        else:
            new_images = samp_images
        return (new_images, samp_images, alpha)

View File

@@ -0,0 +1,359 @@
# Currently only sd15
import functools
import torch
import einops
from comfy import model_management, utils
from comfy.ldm.modules.attention import optimized_attention
# Maps the sequential indices used by the SD1.5 layer-diffusion weights to the
# attention module paths inside the UNet (attn1 = self-attn, attn2 = cross-attn).
module_mapping_sd15 = {
    0: "input_blocks.1.1.transformer_blocks.0.attn1",
    1: "input_blocks.1.1.transformer_blocks.0.attn2",
    2: "input_blocks.2.1.transformer_blocks.0.attn1",
    3: "input_blocks.2.1.transformer_blocks.0.attn2",
    4: "input_blocks.4.1.transformer_blocks.0.attn1",
    5: "input_blocks.4.1.transformer_blocks.0.attn2",
    6: "input_blocks.5.1.transformer_blocks.0.attn1",
    7: "input_blocks.5.1.transformer_blocks.0.attn2",
    8: "input_blocks.7.1.transformer_blocks.0.attn1",
    9: "input_blocks.7.1.transformer_blocks.0.attn2",
    10: "input_blocks.8.1.transformer_blocks.0.attn1",
    11: "input_blocks.8.1.transformer_blocks.0.attn2",
    12: "output_blocks.3.1.transformer_blocks.0.attn1",
    13: "output_blocks.3.1.transformer_blocks.0.attn2",
    14: "output_blocks.4.1.transformer_blocks.0.attn1",
    15: "output_blocks.4.1.transformer_blocks.0.attn2",
    16: "output_blocks.5.1.transformer_blocks.0.attn1",
    17: "output_blocks.5.1.transformer_blocks.0.attn2",
    18: "output_blocks.6.1.transformer_blocks.0.attn1",
    19: "output_blocks.6.1.transformer_blocks.0.attn2",
    20: "output_blocks.7.1.transformer_blocks.0.attn1",
    21: "output_blocks.7.1.transformer_blocks.0.attn2",
    22: "output_blocks.8.1.transformer_blocks.0.attn1",
    23: "output_blocks.8.1.transformer_blocks.0.attn2",
    24: "output_blocks.9.1.transformer_blocks.0.attn1",
    25: "output_blocks.9.1.transformer_blocks.0.attn2",
    26: "output_blocks.10.1.transformer_blocks.0.attn1",
    27: "output_blocks.10.1.transformer_blocks.0.attn2",
    28: "output_blocks.11.1.transformer_blocks.0.attn1",
    29: "output_blocks.11.1.transformer_blocks.0.attn2",
    30: "middle_block.1.transformer_blocks.0.attn1",
    31: "middle_block.1.transformer_blocks.0.attn2",
}
def compute_cond_mark(cond_or_uncond, sigmas):
    """Expand per-group cond/uncond flags to one mark per batch entry.

    Each flag in *cond_or_uncond* is repeated sigmas.shape[0] times, and the
    result is returned as a tensor on sigmas' device/dtype.
    """
    per_group = int(sigmas.shape[0])
    marks = [float(flag) for flag in cond_or_uncond for _ in range(per_group)]
    return torch.tensor(marks).to(sigmas)
class LoRALinearLayer(torch.nn.Module):
    """Low-rank residual wrapper around an existing linear layer.

    Computes ``linear(h, W_org + up @ down, b_org)`` where *up*/*down* are
    trainable rank-`rank` factors and the original layer stays frozen.
    """

    def __init__(self, in_features: int, out_features: int, rank: int = 256, org=None):
        super().__init__()
        self.down = torch.nn.Linear(in_features, rank, bias=False)
        self.up = torch.nn.Linear(rank, out_features, bias=False)
        # Kept in a plain list so the wrapped layer is NOT registered as a
        # submodule (its parameters stay out of this module's state_dict).
        self.org = [org]

    def forward(self, h):
        base = self.org[0]
        weight = base.weight.to(h)
        bias = base.bias.to(h) if base.bias is not None else None
        # Rank-`rank` additive correction to the frozen weight.
        delta = self.up.weight @ self.down.weight
        return torch.nn.functional.linear(h, weight + delta, bias)
class AttentionSharingUnit(torch.nn.Module):
    """Frame-aware replacement for one of the UNet's attention modules.

    Wraps an existing attention module with per-frame LoRA projections
    (one set per frame) plus a temporal attention pass over the frame
    axis, and optionally mixes in per-frame control signals.  The forward
    pass returns ``output - input``, i.e. a residual relative to the
    incoming hidden states — NOTE(review): presumably the patched caller
    adds this residual back; confirm against the object-patch wrapper.
    """

    # `transformer_options` passed to the most recent BasicTransformerBlock.forward
    # call.
    transformer_options: dict = {}

    def __init__(self, module, frames=2, use_control=True, rank=256):
        """Build per-frame LoRA + temporal layers around `module`.

        module: the original attention module being wrapped (must expose
            .heads, .to_q/.to_k/.to_v and .to_out).
        frames: number of jointly processed frames (batch is (b*frames)).
        use_control: whether to build the control-signal conv branches.
        rank: LoRA rank for all per-frame projections.
        """
        super().__init__()
        self.heads = module.heads
        self.frames = frames
        # Plain list so the wrapped module is not registered as a submodule.
        self.original_module = [module]
        q_in_channels, q_out_channels = (
            module.to_q.in_features,
            module.to_q.out_features,
        )
        k_in_channels, k_out_channels = (
            module.to_k.in_features,
            module.to_k.out_features,
        )
        v_in_channels, v_out_channels = (
            module.to_v.in_features,
            module.to_v.out_features,
        )
        o_in_channels, o_out_channels = (
            module.to_out[0].in_features,
            module.to_out[0].out_features,
        )

        hidden_size = k_out_channels

        # One LoRA-merged projection per frame, each wrapping the shared
        # original q/k/v/out weights.
        self.to_q_lora = [
            LoRALinearLayer(q_in_channels, q_out_channels, rank, module.to_q)
            for _ in range(self.frames)
        ]
        self.to_k_lora = [
            LoRALinearLayer(k_in_channels, k_out_channels, rank, module.to_k)
            for _ in range(self.frames)
        ]
        self.to_v_lora = [
            LoRALinearLayer(v_in_channels, v_out_channels, rank, module.to_v)
            for _ in range(self.frames)
        ]
        self.to_out_lora = [
            LoRALinearLayer(o_in_channels, o_out_channels, rank, module.to_out[0])
            for _ in range(self.frames)
        ]
        self.to_q_lora = torch.nn.ModuleList(self.to_q_lora)
        self.to_k_lora = torch.nn.ModuleList(self.to_k_lora)
        self.to_v_lora = torch.nn.ModuleList(self.to_v_lora)
        self.to_out_lora = torch.nn.ModuleList(self.to_out_lora)

        # Temporal attention over the frame axis: norm -> in-proj ->
        # q/k/v attention -> out-proj.
        self.temporal_i = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_n = torch.nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6
        )
        self.temporal_q = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_k = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_v = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_o = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )

        # Optional per-frame conv heads that project a 256-channel control
        # feature map to this block's hidden size.
        self.control_convs = None
        if use_control:
            self.control_convs = [
                torch.nn.Sequential(
                    torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
                    torch.nn.SiLU(),
                    torch.nn.Conv2d(256, hidden_size, kernel_size=1),
                )
                for _ in range(self.frames)
            ]
            self.control_convs = torch.nn.ModuleList(self.control_convs)

        # Dict of {spatial_size: feature_map}, set externally via
        # AttentionSharingPatcher.set_control().
        self.control_signals = None

    def forward(self, h, context=None, value=None):
        # `value` is accepted for signature compatibility but unused.
        transformer_options = self.transformer_options

        # (batch*frames, tokens, channels) -> (frames, batch, tokens, channels)
        modified_hidden_states = einops.rearrange(
            h, "(b f) d c -> f b d c", f=self.frames
        )

        if self.control_convs is not None:
            # Control signals are keyed by token count (h*w of the feature map).
            context_dim = int(modified_hidden_states.shape[2])
            control_outs = []
            for f in range(self.frames):
                control_signal = self.control_signals[context_dim].to(
                    modified_hidden_states
                )
                control = self.control_convs[f](control_signal)
                control = einops.rearrange(control, "b c h w -> b (h w) c")
                control_outs.append(control)
            control_outs = torch.stack(control_outs, dim=0)
            modified_hidden_states = modified_hidden_states + control_outs.to(
                modified_hidden_states
            )

        # Self-attention uses the (possibly control-shifted) hidden states
        # as context; cross-attention reshapes the supplied context per frame.
        if context is None:
            framed_context = modified_hidden_states
        else:
            framed_context = einops.rearrange(
                context, "(b f) d c -> f b d c", f=self.frames
            )

        # 1.0 where the batch element is cond, 0.0 where uncond (per frame).
        framed_cond_mark = einops.rearrange(
            compute_cond_mark(
                transformer_options["cond_or_uncond"],
                transformer_options["sigmas"],
            ),
            "(b f) -> f b",
            f=self.frames,
        ).to(modified_hidden_states)

        attn_outs = []
        for f in range(self.frames):
            fcf = framed_context[f]

            if context is not None:
                # Optionally overwrite the uncond part of this frame's
                # context with an externally supplied embedding.
                cond_overwrite = transformer_options.get("cond_overwrite", [])
                if len(cond_overwrite) > f:
                    cond_overwrite = cond_overwrite[f]
                else:
                    cond_overwrite = None
                if cond_overwrite is not None:
                    cond_mark = framed_cond_mark[f][:, None, None]
                    fcf = cond_overwrite.to(fcf) * (1.0 - cond_mark) + fcf * cond_mark

            # Per-frame LoRA-adjusted attention, then the original module's
            # to_out[1] (dropout) layer.
            q = self.to_q_lora[f](modified_hidden_states[f])
            k = self.to_k_lora[f](fcf)
            v = self.to_v_lora[f](fcf)
            o = optimized_attention(q, k, v, self.heads)
            o = self.to_out_lora[f](o)
            o = self.original_module[0].to_out[1](o)
            attn_outs.append(o)

        attn_outs = torch.stack(attn_outs, dim=0)
        modified_hidden_states = modified_hidden_states + attn_outs.to(
            modified_hidden_states
        )
        modified_hidden_states = einops.rearrange(
            modified_hidden_states, "f b d c -> (b f) d c", f=self.frames
        )

        # Temporal pass: attend across the frame axis for each token position.
        x = modified_hidden_states
        x = self.temporal_n(x)
        x = self.temporal_i(x)
        d = x.shape[1]

        x = einops.rearrange(x, "(b f) d c -> (b d) f c", f=self.frames)

        q = self.temporal_q(x)
        k = self.temporal_k(x)
        v = self.temporal_v(x)

        x = optimized_attention(q, k, v, self.heads)
        x = self.temporal_o(x)
        x = einops.rearrange(x, "(b d) f c -> (b f) d c", d=d)

        modified_hidden_states = modified_hidden_states + x

        # Residual relative to the input hidden states.
        return modified_hidden_states - h

    @classmethod
    def hijack_transformer_block(cls):
        """Monkey-patch BasicTransformerBlock.forward to capture its
        transformer_options into a class attribute, making them visible to
        every AttentionSharingUnit during the same forward call."""

        def register_get_transformer_options(func):
            @functools.wraps(func)
            def forward(self, x, context=None, transformer_options={}):
                cls.transformer_options = transformer_options
                return func(self, x, context, transformer_options)

            return forward

        from comfy.ldm.modules.attention import BasicTransformerBlock

        BasicTransformerBlock.forward = register_get_transformer_options(
            BasicTransformerBlock.forward
        )


# Applied at import time: every BasicTransformerBlock forward from now on
# records its transformer_options for the sharing units above.
AttentionSharingUnit.hijack_transformer_block()
class AdditionalAttentionCondsEncoder(torch.nn.Module):
    """Encodes a conditioning image into 256-channel feature maps at four
    progressively halved resolutions (1/8, 1/16, 1/32, 1/64 of the input).

    The returned dict is keyed by token count (height * width of each map),
    which is how downstream attention layers look up the matching signal.
    """

    @staticmethod
    def _conv_silu(in_ch, out_ch, stride):
        """One 3x3 conv (padding 1) followed by SiLU, as a layer list."""
        return [
            torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, stride=stride),
            torch.nn.SiLU(),
        ]

    def __init__(self):
        super().__init__()
        # Stem: 3 -> 256 channels while downsampling by 8 (three stride-2 convs).
        stem_spec = [
            (3, 32, 1),
            (32, 32, 1),
            (32, 64, 2),
            (64, 64, 1),
            (64, 128, 2),
            (128, 128, 1),
            (128, 256, 2),
            (256, 256, 1),
        ]
        stem_layers = []
        for in_ch, out_ch, stride in stem_spec:
            stem_layers += self._conv_silu(in_ch, out_ch, stride)
        self.blocks_0 = torch.nn.Sequential(*stem_layers)  # 64*64*256

        # Each further stage halves the resolution once and refines.
        self.blocks_1 = torch.nn.Sequential(
            *(self._conv_silu(256, 256, 2) + self._conv_silu(256, 256, 1))
        )  # 32*32*256
        self.blocks_2 = torch.nn.Sequential(
            *(self._conv_silu(256, 256, 2) + self._conv_silu(256, 256, 1))
        )  # 16*16*256
        self.blocks_3 = torch.nn.Sequential(
            *(self._conv_silu(256, 256, 2) + self._conv_silu(256, 256, 1))
        )  # 8*8*256

        self.blks = [self.blocks_0, self.blocks_1, self.blocks_2, self.blocks_3]

    def __call__(self, h):
        # NOTE: overrides __call__ directly (not forward), mirroring how the
        # encoder is invoked as a plain callable.
        feats = {}
        for stage in self.blks:
            h = stage(h)
            feats[int(h.shape[2]) * int(h.shape[3])] = h
        return feats
class HookerLayers(torch.nn.Module):
    """Container that registers a list of patch layers as real submodules.

    Wrapping the attention-sharing units in an nn.ModuleList ties their
    parameters to this single module, so .half()/.to()/state_dict calls on
    the container reach every unit at once.
    """

    def __init__(self, layer_list):
        super().__init__()
        self.layers = torch.nn.ModuleList(layer_list)
class AttentionSharingPatcher(torch.nn.Module):
    """Installs AttentionSharingUnit wrappers on all 32 SD1.5 attention
    modules of a UNet via ComfyUI object patches."""

    def __init__(self, unet, frames=2, use_control=True, rank=256):
        """Patch `unet` in place.

        unet: a ComfyUI ModelPatcher wrapping the diffusion model.
        frames: number of jointly denoised frames per sample.
        use_control: also build the control-signal encoder.
        rank: LoRA rank forwarded to each sharing unit.
        """
        super().__init__()

        # Ensure no other clones of this model hold stale module references.
        model_management.unload_model_clones(unet)

        units = []
        # module_mapping_sd15 maps indices 0..31 to the dotted paths of the
        # UNet's attn1/attn2 modules (defined earlier in this file).
        for i in range(32):
            real_key = module_mapping_sd15[i]
            attn_module = utils.get_attr(unet.model.diffusion_model, real_key)
            u = AttentionSharingUnit(
                attn_module, frames=frames, use_control=use_control, rank=rank
            )
            units.append(u)
            # Replace the original attention module through ComfyUI's
            # non-destructive object-patch mechanism.
            unet.add_object_patch("diffusion_model." + real_key, u)

        self.hookers = HookerLayers(units)

        if use_control:
            self.kwargs_encoder = AdditionalAttentionCondsEncoder()
        else:
            self.kwargs_encoder = None

        # Match the device's preferred precision; halve the units if fp16.
        self.dtype = torch.float32
        if model_management.should_use_fp16(model_management.get_torch_device()):
            self.dtype = torch.float16
            self.hookers.half()
        return

    def set_control(self, img):
        """Encode `img` (expected in [0, 1] — TODO confirm caller range;
        rescaled here to [-1, 1]) and distribute the multi-resolution
        control signals to every installed sharing unit."""
        img = img.cpu().float() * 2.0 - 1.0
        signals = self.kwargs_encoder(img)
        for m in self.hookers.layers:
            m.control_signals = signals
        return

View File

@@ -0,0 +1,390 @@
import torch.nn as nn
import torch
import cv2
import numpy as np
import comfy.model_management
from comfy.model_patcher import ModelPatcher
from tqdm import tqdm
from typing import Optional, Tuple
from ...libs.utils import install_package
from packaging import version
try:
install_package("diffusers", "0.27.2", True, "0.25.0")
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers import __version__
if __version__:
if version.parse(__version__) < version.parse("0.26.0"):
from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
else:
from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
import functools
def zero_module(module):
    """Zero-initialize every parameter of *module* in place and return it."""
    for param in module.parameters():
        param.detach().zero_()
    return module


class LatentTransparencyOffsetEncoder(torch.nn.Module):
    """Maps a 4-channel input to a 4-channel latent offset at 1/8 resolution.

    The final 3x3 projection is zero-initialized, so the encoder initially
    emits an all-zero offset and learns a correction during training.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # 4 -> 256 channels while downsampling by 8 (three stride-2 convs),
        # then a zeroed 256 -> 4 projection.
        specs = [
            (4, 32, 1),
            (32, 32, 1),
            (32, 64, 2),
            (64, 64, 1),
            (64, 128, 2),
            (128, 128, 1),
            (128, 256, 2),
            (256, 256, 1),
        ]
        layers = []
        for c_in, c_out, stride in specs:
            layers.append(
                torch.nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=stride)
            )
            layers.append(nn.SiLU())
        layers.append(
            zero_module(torch.nn.Conv2d(256, 4, kernel_size=3, padding=1, stride=1))
        )
        self.blocks = torch.nn.Sequential(*layers)

    def __call__(self, x):
        # Invoked as a plain callable; overrides __call__ directly.
        return self.blocks(x)
# 1024 * 1024 * 3 -> 16 * 16 * 512 -> 1024 * 1024 * 3
class UNet1024(ModelMixin, ConfigMixin):
    """Pixel-space UNet that decodes a 4-channel output from an RGB image
    plus an SD latent.

    Built from diffusers' 2D UNet blocks but with NO timestep conditioning
    (temb_channels=None and temb is always None in forward).  The 4-channel
    latent is projected by a zero-initialized 1x1 conv and added to the
    feature map before the fourth down block (i == 3), whose channel count
    matches block_out_channels[2].
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = (
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "AttnDownBlock2D",
            "AttnDownBlock2D",
        ),
        up_block_types: Tuple[str] = (
            "AttnUpBlock2D",
            "AttnUpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
        block_out_channels: Tuple[int] = (32, 32, 64, 128, 256, 512, 512),
        layers_per_block: int = 2,
        mid_block_scale_factor: float = 1,
        downsample_padding: int = 1,
        downsample_type: str = "conv",
        upsample_type: str = "conv",
        dropout: float = 0.0,
        act_fn: str = "silu",
        attention_head_dim: Optional[int] = 8,
        norm_num_groups: int = 4,
        norm_eps: float = 1e-5,
    ):
        super().__init__()

        # input
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)
        )
        # Zero-initialized so the latent injection starts as a no-op.
        self.latent_conv_in = zero_module(
            nn.Conv2d(4, block_out_channels[2], kernel_size=1)
        )

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        # down: one block per entry in down_block_types; every block except
        # the last downsamples by 2.
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=None,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=(
                    attention_head_dim
                    if attention_head_dim is not None
                    else output_channel
                ),
                downsample_padding=downsample_padding,
                resnet_time_scale_shift="default",
                downsample_type=downsample_type,
                dropout=dropout,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            temb_channels=None,
            dropout=dropout,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift="default",
            attention_head_dim=(
                attention_head_dim
                if attention_head_dim is not None
                else block_out_channels[-1]
            ),
            resnet_groups=norm_num_groups,
            attn_groups=None,
            add_attention=True,
        )

        # up: mirrors the down path; skip-connection channel counts come
        # from reversed block_out_channels.
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[
                min(i + 1, len(block_out_channels) - 1)
            ]

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=None,
                add_upsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=(
                    attention_head_dim
                    if attention_head_dim is not None
                    else output_channel
                ),
                resnet_time_scale_shift="default",
                upsample_type=upsample_type,
                dropout=dropout,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
        )
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(
            block_out_channels[0], out_channels, kernel_size=3, padding=1
        )

    def forward(self, x, latent):
        """x: image tensor (B, in_channels, H, W); latent: (B, 4, h, w) SD
        latent whose spatial size must match the feature map entering down
        block 3 — TODO confirm expected H/W ratio against callers."""
        sample_latent = self.latent_conv_in(latent)
        sample = self.conv_in(x)
        emb = None  # no timestep embedding in this model

        down_block_res_samples = (sample,)
        for i, downsample_block in enumerate(self.down_blocks):
            # Inject the projected latent right before the fourth down block.
            if i == 3:
                sample = sample + sample_latent

            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
            down_block_res_samples += res_samples

        sample = self.mid_block(sample, emb)

        # Consume the skip connections in reverse, one group per up block.
        for upsample_block in self.up_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
            down_block_res_samples = down_block_res_samples[
                : -len(upsample_block.resnets)
            ]
            sample = upsample_block(sample, res_samples, emb)

        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)
        return sample
def checkerboard(shape):
    """Return an integer 0/1 checkerboard array of the given shape.

    Works for any dimensionality: a cell is 1 when the sum of its
    coordinates is odd.
    """
    coord_sum = sum(np.indices(shape))
    return coord_sum % 2
def fill_checkerboard_bg(y: torch.Tensor) -> torch.Tensor:
    """Composite an alpha-first, channel-last tensor over a faint checkerboard.

    ``y[..., 0]`` is alpha, ``y[..., 1:]`` the color channels.  The color is
    blended over a low-contrast (0.45/0.55) checkerboard with 64-pixel cells
    for transparency visualisation.
    """
    alpha, color = y[..., :1], y[..., 1:]
    _, height, width, _ = color.shape
    pattern = checkerboard(shape=(height // 64, width // 64))
    # Nearest-neighbor upscale keeps the cells crisp.
    pattern = cv2.resize(pattern, (width, height), interpolation=cv2.INTER_NEAREST)
    # Compress the 0/1 pattern to a faint gray checkerboard around 0.5.
    pattern = (0.5 + (pattern - 0.5) * 0.1)[None, ..., None]
    background = torch.from_numpy(pattern).to(color)
    return color * alpha + background * (1 - alpha)
class TransparentVAEDecoder:
    """Decodes an RGB image plus its SD latent into a 4-channel
    (alpha + RGB) image using a UNet1024.

    Inference is test-time augmented: the UNet runs on 8 views
    (4 rotations x optional horizontal flip) and the per-pixel median is
    taken for robustness.
    """

    def __init__(self, sd, device, dtype):
        """sd: UNet1024 state dict (loaded strictly so wrong checkpoints
        fail loudly); device/dtype: where and how to run the model."""
        self.load_device = device
        self.dtype = dtype

        model = UNet1024(in_channels=3, out_channels=4)
        model.load_state_dict(sd, strict=True)
        model.to(self.load_device, dtype=self.dtype)
        model.eval()

        self.model = model

    @torch.no_grad()
    def estimate_single_pass(self, pixel: torch.Tensor, latent: torch.Tensor) -> torch.Tensor:
        """One raw UNet pass, no augmentation."""
        y = self.model(pixel, latent)
        return y

    @torch.no_grad()
    def estimate_augmented(self, pixel: torch.Tensor, latent: torch.Tensor) -> torch.Tensor:
        """Run all 8 flip/rotation variants and return their per-pixel median.

        Each pass rotates (and optionally flips) the inputs, predicts, then
        undoes the transform before stacking.
        """
        # (flip, number of 90-degree rotations) for each augmented view.
        args = [
            [False, 0],
            [False, 1],
            [False, 2],
            [False, 3],
            [True, 0],
            [True, 1],
            [True, 2],
            [True, 3],
        ]

        result = []

        for flip, rok in tqdm(args):
            feed_pixel = pixel.clone()
            feed_latent = latent.clone()

            if flip:
                feed_pixel = torch.flip(feed_pixel, dims=(3,))
                feed_latent = torch.flip(feed_latent, dims=(3,))

            feed_pixel = torch.rot90(feed_pixel, k=rok, dims=(2, 3))
            feed_latent = torch.rot90(feed_latent, k=rok, dims=(2, 3))

            eps = self.estimate_single_pass(feed_pixel, feed_latent).clip(0, 1)
            # Invert the augmentation so all views align spatially.
            eps = torch.rot90(eps, k=-rok, dims=(2, 3))

            if flip:
                eps = torch.flip(eps, dims=(3,))

            result += [eps]

        result = torch.stack(result, dim=0)
        median = torch.median(result, dim=0).values
        return median

    @torch.no_grad()
    def decode_pixel(
        self, pixel: torch.Tensor, latent: torch.Tensor
    ) -> torch.Tensor:
        """Decode to a 4-channel image, preserving the input's device/dtype."""
        # pixel.shape = [B, C=3, H, W]
        assert pixel.shape[1] == 3
        pixel_device = pixel.device
        pixel_dtype = pixel.dtype

        pixel = pixel.to(device=self.load_device, dtype=self.dtype)
        latent = latent.to(device=self.load_device, dtype=self.dtype)
        # y.shape = [B, C=4, H, W]
        y = self.estimate_augmented(pixel, latent)
        y = y.clip(0, 1)
        assert y.shape[1] == 4
        # Restore image to original device of input image.
        return y.to(pixel_device, dtype=pixel_dtype)
def calculate_weight_adjust_channel(func):
    """Patches ComfyUI's LoRA weight application to accept multi-channel inputs.

    Wraps ``func`` (ComfyUI's ``calculate_weight``) so that a "diff" patch
    whose 4-D weight is larger than the target weight expands the target
    tensor (zero-padded) instead of failing — used when merging conv layers
    whose input channels were extended (e.g. transparency channels).
    """

    @functools.wraps(func)
    def calculate_weight(
        patches, weight: torch.Tensor, key: str, intermediate_type=torch.float32
    ) -> torch.Tensor:
        # Let the original implementation apply all patches first.
        weight = func(patches, weight, key, intermediate_type)

        for p in patches:
            alpha = p[0]
            v = p[1]

            # The recursion call should be handled in the main func call.
            if isinstance(v, list):
                continue

            # Default to an unknown patch type so tuples of other lengths are
            # skipped instead of raising UnboundLocalError below (matches
            # upstream ComfyUI, which initializes patch_type before the checks).
            patch_type = ""
            if len(v) == 1:
                patch_type = "diff"
            elif len(v) == 2:
                patch_type = v[0]
                v = v[1]

            if patch_type == "diff":
                w1 = v[0]
                # Only expand when the patch is active, shapes differ, and
                # both tensors are 4-D conv weights.
                if all(
                    (
                        alpha != 0.0,
                        w1.shape != weight.shape,
                        w1.ndim == weight.ndim == 4,
                    )
                ):
                    new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)]
                    print(
                        f"Merged with {key} channel changed from {weight.shape} to {new_shape}"
                    )
                    new_diff = alpha * comfy.model_management.cast_to_device(
                        w1, weight.device, weight.dtype
                    )
                    # Zero-padded expansion: copy the old weight into the
                    # top-left corner, then add the diff into its own corner.
                    new_weight = torch.zeros(size=new_shape).to(weight)
                    new_weight[
                        : weight.shape[0],
                        : weight.shape[1],
                        : weight.shape[2],
                        : weight.shape[3],
                    ] = weight
                    new_weight[
                        : new_diff.shape[0],
                        : new_diff.shape[1],
                        : new_diff.shape[2],
                        : new_diff.shape[3],
                    ] += new_diff
                    new_weight = new_weight.contiguous().clone()
                    weight = new_weight
        return weight

    return calculate_weight
except ImportError:
ModelMixin = None
ConfigMixin = None
TransparentVAEDecoder = None
calculate_weight_adjust_channel = None
print("\33[33mModule 'diffusers' load failed. If you don't have it installed, do it:\033[0m")
print("\33[33mpip install diffusers\033[0m")