Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled

Includes 30 custom nodes committed directly, 7 Civitai-exclusive
loras stored via Git LFS, and a setup script that installs all
dependencies and downloads HuggingFace-hosted models on vast.ai.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-09 00:55:26 +00:00
parent 2b70ab9ad0
commit f09734b0ee
2274 changed files with 748556 additions and 3 deletions

View File

@@ -0,0 +1,42 @@
{
"_name_or_path": "THUDM/chatglm3-6b-base",
"model_type": "chatglm",
"architectures": [
"ChatGLMModel"
],
"auto_map": {
"AutoConfig": "configuration_chatglm.ChatGLMConfig",
"AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration",
"AutoModelForCausalLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
"AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
"AutoModelForSequenceClassification": "modeling_chatglm.ChatGLMForSequenceClassification"
},
"add_bias_linear": false,
"add_qkv_bias": true,
"apply_query_key_layer_scaling": true,
"apply_residual_connection_post_layernorm": false,
"attention_dropout": 0.0,
"attention_softmax_in_fp32": true,
"bias_dropout_fusion": true,
"ffn_hidden_size": 13696,
"fp32_residual_connection": false,
"hidden_dropout": 0.0,
"hidden_size": 4096,
"kv_channels": 128,
"layernorm_epsilon": 1e-05,
"multi_query_attention": true,
"multi_query_group_num": 2,
"num_attention_heads": 32,
"num_layers": 28,
"original_rope": true,
"padded_vocab_size": 65024,
"post_layer_norm": true,
"rmsnorm": true,
"seq_length": 32768,
"use_cache": true,
"torch_dtype": "float16",
"transformers_version": "4.30.2",
"tie_word_embeddings": false,
"eos_token_id": 2,
"pad_token_id": 0
}

View File

@@ -0,0 +1,60 @@
from transformers import PretrainedConfig
class ChatGLMConfig(PretrainedConfig):
model_type = "chatglm"
def __init__(
self,
num_layers=28,
padded_vocab_size=65024,
hidden_size=4096,
ffn_hidden_size=13696,
kv_channels=128,
num_attention_heads=32,
seq_length=2048,
hidden_dropout=0.0,
classifier_dropout=None,
attention_dropout=0.0,
layernorm_epsilon=1e-5,
rmsnorm=True,
apply_residual_connection_post_layernorm=False,
post_layer_norm=True,
add_bias_linear=False,
add_qkv_bias=False,
bias_dropout_fusion=True,
multi_query_attention=False,
multi_query_group_num=1,
apply_query_key_layer_scaling=True,
attention_softmax_in_fp32=True,
fp32_residual_connection=False,
quantization_bit=0,
pre_seq_len=None,
prefix_projection=False,
**kwargs
):
self.num_layers = num_layers
self.vocab_size = padded_vocab_size
self.padded_vocab_size = padded_vocab_size
self.hidden_size = hidden_size
self.ffn_hidden_size = ffn_hidden_size
self.kv_channels = kv_channels
self.num_attention_heads = num_attention_heads
self.seq_length = seq_length
self.hidden_dropout = hidden_dropout
self.classifier_dropout = classifier_dropout
self.attention_dropout = attention_dropout
self.layernorm_epsilon = layernorm_epsilon
self.rmsnorm = rmsnorm
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
self.post_layer_norm = post_layer_norm
self.add_bias_linear = add_bias_linear
self.add_qkv_bias = add_qkv_bias
self.bias_dropout_fusion = bias_dropout_fusion
self.multi_query_attention = multi_query_attention
self.multi_query_group_num = multi_query_group_num
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
self.fp32_residual_connection = fp32_residual_connection
self.quantization_bit = quantization_bit
self.pre_seq_len = pre_seq_len
self.prefix_projection = prefix_projection
super().__init__(**kwargs)

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,300 @@
import json
import os
import re
from typing import List, Optional, Union, Dict
from sentencepiece import SentencePieceProcessor
from transformers import PreTrainedTokenizer
from transformers.utils import logging, PaddingStrategy
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
class SPTokenizer:
def __init__(self, model_path: str):
# reload tokenizer
assert os.path.isfile(model_path), model_path
self.sp_model = SentencePieceProcessor(model_file=model_path)
# BOS / EOS token IDs
self.n_words: int = self.sp_model.vocab_size()
self.bos_id: int = self.sp_model.bos_id()
self.eos_id: int = self.sp_model.eos_id()
self.pad_id: int = self.sp_model.unk_id()
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
self.special_tokens = {}
self.index_special_tokens = {}
for token in special_tokens:
self.special_tokens[token] = self.n_words
self.index_special_tokens[self.n_words] = token
self.n_words += 1
self.role_special_token_expression = "|".join([re.escape(token) for token in role_special_tokens])
def tokenize(self, s: str, encode_special_tokens=False):
if encode_special_tokens:
last_index = 0
t = []
for match in re.finditer(self.role_special_token_expression, s):
if last_index < match.start():
t.extend(self.sp_model.EncodeAsPieces(s[last_index:match.start()]))
t.append(s[match.start():match.end()])
last_index = match.end()
if last_index < len(s):
t.extend(self.sp_model.EncodeAsPieces(s[last_index:]))
return t
else:
return self.sp_model.EncodeAsPieces(s)
def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
assert type(s) is str
t = self.sp_model.encode(s)
if bos:
t = [self.bos_id] + t
if eos:
t = t + [self.eos_id]
return t
def decode(self, t: List[int]) -> str:
text, buffer = "", []
for token in t:
if token in self.index_special_tokens:
if buffer:
text += self.sp_model.decode(buffer)
buffer = []
text += self.index_special_tokens[token]
else:
buffer.append(token)
if buffer:
text += self.sp_model.decode(buffer)
return text
def decode_tokens(self, tokens: List[str]) -> str:
text = self.sp_model.DecodePieces(tokens)
return text
def convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """
if token in self.special_tokens:
return self.special_tokens[token]
return self.sp_model.PieceToId(token)
def convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
if index in self.index_special_tokens:
return self.index_special_tokens[index]
if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
return ""
return self.sp_model.IdToPiece(index)
class ChatGLMTokenizer(PreTrainedTokenizer):
vocab_files_names = {"vocab_file": "tokenizer.model"}
model_input_names = ["input_ids", "attention_mask", "position_ids"]
def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, encode_special_tokens=False,
**kwargs):
self.name = "GLMTokenizer"
self.vocab_file = vocab_file
self.tokenizer = SPTokenizer(vocab_file)
self.special_tokens = {
"<bos>": self.tokenizer.bos_id,
"<eos>": self.tokenizer.eos_id,
"<pad>": self.tokenizer.pad_id
}
self.encode_special_tokens = encode_special_tokens
super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
encode_special_tokens=encode_special_tokens,
**kwargs)
def get_command(self, token):
if token in self.special_tokens:
return self.special_tokens[token]
assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
return self.tokenizer.special_tokens[token]
@property
def unk_token(self) -> str:
return "<unk>"
@property
def pad_token(self) -> str:
return "<unk>"
@property
def pad_token_id(self):
return self.get_command("<pad>")
@property
def eos_token(self) -> str:
return "</s>"
@property
def eos_token_id(self):
return self.get_command("<eos>")
@property
def vocab_size(self):
return self.tokenizer.n_words
def get_vocab(self):
""" Returns vocab as a dict """
vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
def _tokenize(self, text, **kwargs):
return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens)
def _convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """
return self.tokenizer.convert_token_to_id(token)
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.tokenizer.convert_id_to_token(index)
def convert_tokens_to_string(self, tokens: List[str]) -> str:
return self.tokenizer.decode_tokens(tokens)
def save_vocabulary(self, save_directory, filename_prefix=None):
"""
Save the vocabulary and special tokens file to a directory.
Args:
save_directory (`str`):
The directory in which to save the vocabulary.
filename_prefix (`str`, *optional*):
An optional prefix to add to the named of the saved files.
Returns:
`Tuple(str)`: Paths to the files saved.
"""
if os.path.isdir(save_directory):
vocab_file = os.path.join(
save_directory, self.vocab_files_names["vocab_file"]
)
else:
vocab_file = save_directory
with open(self.vocab_file, 'rb') as fin:
proto_str = fin.read()
with open(vocab_file, "wb") as writer:
writer.write(proto_str)
return (vocab_file,)
def get_prefix_tokens(self):
prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
return prefix_tokens
def build_single_message(self, role, metadata, message):
assert role in ["system", "user", "assistant", "observation"], role
role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
message_tokens = self.tokenizer.encode(message)
tokens = role_tokens + message_tokens
return tokens
def build_chat_input(self, query, history=None, role="user"):
if history is None:
history = []
input_ids = []
for item in history:
content = item["content"]
if item["role"] == "system" and "tools" in item:
content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
input_ids.extend(self.build_single_message(role, "", query))
input_ids.extend([self.get_command("<|assistant|>")])
return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
adding special tokens. A BERT sequence has the following format:
- single sequence: `[CLS] X [SEP]`
- pair of sequences: `[CLS] A [SEP] B [SEP]`
Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
"""
prefix_tokens = self.get_prefix_tokens()
token_ids_0 = prefix_tokens + token_ids_0
if token_ids_1 is not None:
token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
return token_ids_0
def _pad(
self,
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
**kwargs
) -> dict:
"""
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
Args:
encoded_inputs:
Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
max_length: maximum length of the returned list and optionally padding length (see below).
Will truncate by taking into account the special tokens.
padding_strategy: PaddingStrategy to use for padding.
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
- PaddingStrategy.DO_NOT_PAD: Do not pad
The tokenizer padding sides are defined in self.padding_side:
- 'left': pads on the left of the sequences
- 'right': pads on the right of the sequences
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
`>= 7.5` (Volta).
return_attention_mask:
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
"""
# Load from model defaults
assert self.padding_side == "left"
required_input = encoded_inputs[self.model_input_names[0]]
seq_length = len(required_input)
if padding_strategy == PaddingStrategy.LONGEST:
max_length = len(required_input)
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
# Initialize attention mask if not present.
if "attention_mask" not in encoded_inputs:
encoded_inputs["attention_mask"] = [1] * seq_length
if "position_ids" not in encoded_inputs:
encoded_inputs["position_ids"] = list(range(seq_length))
if needs_to_be_padded:
difference = max_length - len(required_input)
if "attention_mask" in encoded_inputs:
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
if "position_ids" in encoded_inputs:
encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
return encoded_inputs

View File

@@ -0,0 +1,12 @@
{
"name_or_path": "THUDM/chatglm3-6b-base",
"remove_space": false,
"do_lower_case": false,
"tokenizer_class": "ChatGLMTokenizer",
"auto_map": {
"AutoTokenizer": [
"tokenization_chatglm.ChatGLMTokenizer",
null
]
}
}

View File

@@ -0,0 +1,18 @@
{
"attention_dropout": 0.0,
"dropout": 0.0,
"hidden_act": "quick_gelu",
"hidden_size": 1024,
"image_size": 336,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-05,
"model_type": "clip_vision_model",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 24,
"patch_size": 14,
"projection_dim": 768,
"torch_dtype": "float32"
}

View File

@@ -0,0 +1,303 @@
import json
import os
import torch
import subprocess
import sys
import comfy.supported_models
import comfy.model_patcher
import comfy.model_management
import comfy.model_detection as model_detection
import comfy.model_base as model_base
from comfy.model_base import sdxl_pooled, CLIPEmbeddingNoiseAugmentation, Timestep, ModelType
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
from comfy.clip_vision import ClipVisionModel, Output
from comfy.utils import load_torch_file
from .chatglm.modeling_chatglm import ChatGLMModel, ChatGLMConfig
from .chatglm.tokenization_chatglm import ChatGLMTokenizer
class KolorsUNetModel(UNetModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.encoder_hid_proj = torch.nn.Linear(4096, 2048, bias=True)
def forward(self, *args, **kwargs):
with torch.cuda.amp.autocast(enabled=True):
if "context" in kwargs:
kwargs["context"] = self.encoder_hid_proj(kwargs["context"])
result = super().forward(*args, **kwargs)
return result
class KolorsSDXL(model_base.SDXL):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
model_base.BaseModel.__init__(self, model_config, model_type, device=device, unet_model=KolorsUNetModel)
self.embedder = Timestep(256)
self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**{"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1280})
def encode_adm(self, **kwargs):
clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
width = kwargs.get("width", 768)
height = kwargs.get("height", 768)
crop_w = kwargs.get("crop_w", 0)
crop_h = kwargs.get("crop_h", 0)
target_width = kwargs.get("target_width", width)
target_height = kwargs.get("target_height", height)
out = []
out.append(self.embedder(torch.Tensor([height])))
out.append(self.embedder(torch.Tensor([width])))
out.append(self.embedder(torch.Tensor([crop_h])))
out.append(self.embedder(torch.Tensor([crop_w])))
out.append(self.embedder(torch.Tensor([target_height])))
out.append(self.embedder(torch.Tensor([target_width])))
flat = torch.flatten(torch.cat(out)).unsqueeze(
dim=0).repeat(clip_pooled.shape[0], 1)
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
class Kolors(comfy.supported_models.SDXL):
unet_config = {
"model_channels": 320,
"use_linear_in_transformer": True,
"transformer_depth": [0, 0, 2, 2, 10, 10],
"context_dim": 2048,
"adm_in_channels": 5632,
"use_temporal_attention": False,
}
def get_model(self, state_dict, prefix="", device=None):
out = KolorsSDXL(self, model_type=self.model_type(state_dict, prefix), device=device, )
out.__class__ = model_base.SDXL
if self.inpaint_model():
out.set_inpaint()
return out
def kolors_unet_config_from_diffusers_unet(state_dict, dtype=None):
match = {}
transformer_depth = []
attn_res = 1
count_blocks = model_detection.count_blocks
down_blocks = count_blocks(state_dict, "down_blocks.{}")
for i in range(down_blocks):
attn_blocks = count_blocks(
state_dict, "down_blocks.{}.attentions.".format(i) + '{}')
res_blocks = count_blocks(
state_dict, "down_blocks.{}.resnets.".format(i) + '{}')
for ab in range(attn_blocks):
transformer_count = count_blocks(
state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
transformer_depth.append(transformer_count)
if transformer_count > 0:
match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(
i, ab)].shape[1]
attn_res *= 2
if attn_blocks == 0:
for i in range(res_blocks):
transformer_depth.append(0)
match["transformer_depth"] = transformer_depth
match["model_channels"] = state_dict["conv_in.weight"].shape[0]
match["in_channels"] = state_dict["conv_in.weight"].shape[1]
match["adm_in_channels"] = None
if "class_embedding.linear_1.weight" in state_dict:
match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1]
elif "add_embedding.linear_1.weight" in state_dict:
match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
Kolors = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 5632, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
'use_temporal_attention': False, 'use_temporal_resblock': False}
Kolors_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True,
'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 5632, 'dtype': dtype, 'in_channels': 9,
'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 10,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
'use_temporal_attention': False, 'use_temporal_resblock': False}
Kolors_ip2p = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True,
'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 5632, 'dtype': dtype, 'in_channels': 8,
'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 10,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True,
'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4,
'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 10,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True,
'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4,
'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 1, 1], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 1,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'transformer_depth_output': [0, 0, 0, 0, 0, 0, 1, 1, 1],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True,
'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4,
'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 0, 0], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 0,
'use_linear_in_transformer': True, 'num_head_channels': 64, 'context_dim': 1,
'transformer_depth_output': [0, 0, 0, 0, 0, 0, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [Kolors, Kolors_inpaint,
Kolors_ip2p, SDXL, SDXL_mid_cnet, SDXL_small_cnet]
for unet_config in supported_models:
matches = True
for k in match:
if match[k] != unet_config[k]:
# print("key {} does not match".format(k), match[k], "||", unet_config[k])
matches = False
break
if matches:
return model_detection.convert_config(unet_config)
return None
# chatglm3 model
class chatGLM3Model(torch.nn.Module):
def __init__(self, textmodel_json_config=None, device='cpu', offload_device='cpu', model_path=None):
super().__init__()
if model_path is None:
raise ValueError("model_path is required")
self.device = device
if textmodel_json_config is None:
textmodel_json_config = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"chatglm",
"config_chatglm.json"
)
with open(textmodel_json_config, 'r') as file:
config = json.load(file)
textmodel_json_config = ChatGLMConfig(**config)
is_accelerate_available = False
try:
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
is_accelerate_available = True
except:
pass
from contextlib import nullcontext
with (init_empty_weights() if is_accelerate_available else nullcontext()):
with torch.no_grad():
print('torch version:', torch.__version__)
self.text_encoder = ChatGLMModel(textmodel_json_config).eval()
if '4bit' in model_path:
try:
import cpm_kernels
except ImportError:
print("Installing cpm_kernels...")
subprocess.run([sys.executable, "-m", "pip", "install", "cpm_kernels"], check=True)
pass
self.text_encoder.quantize(4)
elif '8bit' in model_path:
self.text_encoder.quantize(8)
sd = load_torch_file(model_path)
if is_accelerate_available:
for key in sd:
set_module_tensor_to_device(self.text_encoder, key, device=offload_device, value=sd[key])
else:
print("WARNING: Accelerate not available, use load_state_dict load model")
self.text_encoder.load_state_dict()
def load_chatglm3(model_path=None):
if model_path is None:
return
load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
glm3model = chatGLM3Model(
device=load_device,
offload_device=offload_device,
model_path=model_path
)
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'chatglm', "tokenizer")
tokenizer = ChatGLMTokenizer.from_pretrained(tokenizer_path)
text_encoder = glm3model.text_encoder
return {"text_encoder":text_encoder, "tokenizer":tokenizer}
# clipvision model
def load_clipvision_vitl_336(path):
sd = load_torch_file(path)
if "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
else:
raise Exception("Unsupported clip vision model")
clip = ClipVisionModel(json_config)
m, u = clip.load_sd(sd)
if len(m) > 0:
print("missing clip vision: {}".format(m))
u = set(u)
keys = list(sd.keys())
for k in keys:
if k not in u:
t = sd.pop(k)
del t
return clip
class applyKolorsUnet:
def __enter__(self):
import comfy.ldm.modules.diffusionmodules.openaimodel
import comfy.utils
import comfy.clip_vision
self.original_UNET_MAP_BASIC = comfy.utils.UNET_MAP_BASIC.copy()
comfy.utils.UNET_MAP_BASIC.add(("encoder_hid_proj.weight", "encoder_hid_proj.weight"),)
comfy.utils.UNET_MAP_BASIC.add(("encoder_hid_proj.bias", "encoder_hid_proj.bias"),)
self.original_unet_config_from_diffusers_unet = model_detection.unet_config_from_diffusers_unet
model_detection.unet_config_from_diffusers_unet = kolors_unet_config_from_diffusers_unet
import comfy.supported_models
self.original_supported_models = comfy.supported_models.models
comfy.supported_models.models = [Kolors]
self.original_load_clipvision_from_sd = comfy.clip_vision.load_clipvision_from_sd
comfy.clip_vision.load_clipvision_from_sd = load_clipvision_vitl_336
def __exit__(self, type, value, traceback):
import comfy.ldm.modules.diffusionmodules.openaimodel
import comfy.utils
import comfy.supported_models
import comfy.clip_vision
comfy.utils.UNET_MAP_BASIC = self.original_UNET_MAP_BASIC
model_detection.unet_config_from_diffusers_unet = self.original_unet_config_from_diffusers_unet
comfy.supported_models.models = self.original_supported_models
comfy.clip_vision.load_clipvision_from_sd = self.original_load_clipvision_from_sd
def is_kolors_model(model):
unet_config = model.model.model_config.unet_config if hasattr(model, 'model') else None
if unet_config and "adm_in_channels" in unet_config and unet_config["adm_in_channels"] == 5632:
return True
else:
return False

View File

@@ -0,0 +1,66 @@
import torch
from torch.nn import Linear
from types import MethodType
import comfy.model_management
import comfy.samplers
from comfy.cldm.cldm import ControlNet
from comfy.controlnet import ControlLora
def patch_controlnet(model, control_net):
import comfy.controlnet
if isinstance(control_net, ControlLora):
del_keys = []
for k in control_net.control_weights:
if k.startswith("label_emb.0.0."):
del_keys.append(k)
for k in del_keys:
control_net.control_weights.pop(k)
super_pre_run = ControlLora.pre_run
super_copy = ControlLora.copy
super_forward = ControlNet.forward
def KolorsControlNet_forward(self, x, hint, timesteps, context, **kwargs):
with torch.cuda.amp.autocast(enabled=True):
context = model.model.diffusion_model.encoder_hid_proj(context)
return super_forward(self, x, hint, timesteps, context, **kwargs)
def KolorsControlLora_pre_run(self, *args, **kwargs):
result = super_pre_run(self, *args, **kwargs)
if hasattr(self, "control_model"):
self.control_model.forward = MethodType(
KolorsControlNet_forward, self.control_model)
return result
control_net.pre_run = MethodType(
KolorsControlLora_pre_run, control_net)
def KolorsControlLora_copy(self, *args, **kwargs):
c = super_copy(self, *args, **kwargs)
c.pre_run = MethodType(
KolorsControlLora_pre_run, c)
return c
control_net.copy = MethodType(KolorsControlLora_copy, control_net)
elif isinstance(control_net, comfy.controlnet.ControlNet):
model_label_emb = model.model.diffusion_model.label_emb
control_net.control_model.label_emb = model_label_emb
control_net.control_model_wrapped.model.label_emb = model_label_emb
super_forward = ControlNet.forward
def KolorsControlNet_forward(self, x, hint, timesteps, context, **kwargs):
with torch.cuda.amp.autocast(enabled=True):
context = model.model.diffusion_model.encoder_hid_proj(context)
return super_forward(self, x, hint, timesteps, context, **kwargs)
control_net.control_model.forward = MethodType(
KolorsControlNet_forward, control_net.control_model)
else:
raise NotImplementedError(f"Type {control_net} not supported for KolorsControlNetPatch")
return control_net

View File

@@ -0,0 +1,105 @@
import re
import random
import gc
import comfy.model_management as mm
from nodes import ConditioningConcat, ConditioningZeroOut, ConditioningSetTimestepRange, ConditioningCombine
def chatglm3_text_encode(chatglm3_model, prompt, clean_gpu=False):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
if clean_gpu:
mm.unload_all_models()
mm.soft_empty_cache()
# Function to randomly select an option from the brackets
def choose_random_option(match):
options = match.group(1).split('|')
return random.choice(options)
prompt = re.sub(r'\{([^{}]*)\}', choose_random_option, prompt)
if "|" in prompt:
prompt = prompt.split("|")
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
# Define tokenizers and text encoders
tokenizer = chatglm3_model['tokenizer']
text_encoder = chatglm3_model['text_encoder']
text_encoder.to(device)
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=256,
truncation=True,
return_tensors="pt",
).to(device)
output = text_encoder(
input_ids=text_inputs['input_ids'],
attention_mask=text_inputs['attention_mask'],
position_ids=text_inputs['position_ids'],
output_hidden_states=True)
# [batch_size, 77, 4096]
prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
text_proj = output.hidden_states[-1][-1, :, :].clone() # [batch_size, 4096]
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, 1, 1)
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
bs_embed = text_proj.shape[0]
text_proj = text_proj.repeat(1, 1).view(bs_embed, -1)
text_encoder.to(offload_device)
if clean_gpu:
mm.soft_empty_cache()
gc.collect()
return [[prompt_embeds, {"pooled_output": text_proj},]]
def chatglm3_adv_text_encode(chatglm3_model, text, clean_gpu=False):
time_start = 0
time_end = 1
match = re.search(r'TIMESTEP.*$', text)
if match:
timestep = match.group()
timestep = timestep.split(' ')
timestep = timestep[0]
text = text.replace(timestep, '')
value = timestep.split(':')
if len(value) >= 3:
time_start = float(value[1])
time_end = float(value[2])
elif len(value) == 2:
time_start = float(value[1])
time_end = 1
elif len(value) == 1:
time_start = 0.1
time_end = 1
pass3 = [x.strip() for x in text.split("BREAK")]
pass3 = [x for x in pass3 if x != '']
if len(pass3) == 0:
pass3 = ['']
conditioning = None
for text in pass3:
cond = chatglm3_text_encode(chatglm3_model, text, clean_gpu)
if conditioning is not None:
conditioning = ConditioningConcat().concat(conditioning, cond)[0]
else:
conditioning = cond
# setTimeStepRange
if time_start > 0 or time_end < 1:
conditioning_2, = ConditioningSetTimestepRange().set_range(conditioning, 0, time_start)
conditioning_1, = ConditioningZeroOut().zero_out(conditioning)
conditioning_1, = ConditioningSetTimestepRange().set_range(conditioning_1, time_start, time_end)
conditioning, = ConditioningCombine().combine(conditioning_1, conditioning_2)
return conditioning