Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled

Includes 30 custom nodes committed directly, 7 Civitai-exclusive
loras stored via Git LFS, and a setup script that installs all
dependencies and downloads HuggingFace-hosted models on vast.ai.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-09 00:55:26 +00:00
parent 2b70ab9ad0
commit f09734b0ee
2274 changed files with 748556 additions and 3 deletions

View File

@@ -0,0 +1,113 @@
import urllib.parse
from os import PathLike
from aiohttp import web
from aiohttp.web_urldispatcher import AbstractRoute, UrlDispatcher
from server import PromptServer
from pathlib import Path
# Maximum size of a served file, in MB (enforced by filesize_limiter below).
max_size = 50
def suffix_limiter(self: web.StaticResource, request: web.Request):
    """Reject requests whose target file extension is not a known image type.

    Registered on a LimitResource via push_limiter(); runs before aiohttp's
    own static-file handling.

    Raises:
        web.HTTPForbidden: if the path is anchored (absolute / drive-rooted)
            or the existing file's suffix is not in the allow-list.
    """
    suffixes = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".svg", ".ico", ".apng", ".tif", ".hdr", ".exr"}
    rel_url = request.match_info["filename"]
    # The original wrapped this in `try: ... finally: pass`, which does
    # nothing; removed.
    filename = Path(rel_url)
    if filename.anchor:
        # Absolute (or Windows drive-anchored) paths are never allowed.
        raise web.HTTPForbidden()
    filepath = self._directory.joinpath(filename).resolve()
    # NOTE(review): ".." traversal is not rejected here; aiohttp's
    # StaticResource._handle performs its own containment check afterwards —
    # confirm this limiter always runs in front of it.
    if filepath.exists() and filepath.suffix.lower() not in suffixes:
        raise web.HTTPForbidden(reason="File type is not allowed")
def filesize_limiter(self: web.StaticResource, request: web.Request):
    """Reject requests for files larger than the module-level `max_size` MB.

    Raises:
        web.HTTPForbidden: if the file exists and exceeds the size limit.
    """
    rel_url = request.match_info["filename"]
    filepath = self._directory.joinpath(Path(rel_url)).resolve()
    # stat() directly instead of exists()+stat(): avoids the race where the
    # file disappears between the two calls (the original's `finally: pass`
    # did nothing and was removed).
    try:
        size = filepath.stat().st_size
    except OSError:
        # Missing/unreadable file: let aiohttp produce its own 404 later.
        return
    if size > max_size * 1024 * 1024:
        raise web.HTTPForbidden(reason="File size is too large")
class LimitResource(web.StaticResource):
    """StaticResource that runs registered limiter callbacks before serving.

    Each limiter is called as ``limiter(resource, request)`` and may raise an
    aiohttp HTTP error to block the request.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Per-instance list. The original declared `limiters = []` at class
        # level, so every LimitResource shared one list and limiters pushed
        # for one route leaked onto all routes.
        self.limiters = []

    def push_limiter(self, limiter):
        """Append a limiter callback to run on every request."""
        self.limiters.append(limiter)

    async def _handle(self, request: web.Request) -> web.StreamResponse:
        try:
            for limiter in self.limiters:
                limiter(self, request)
        except (ValueError, FileNotFoundError) as error:
            # Path errors raised by a limiter surface as 404, not 500.
            raise web.HTTPNotFound() from error
        return await super()._handle(request)

    def __repr__(self) -> str:
        name = "'" + self.name + "'" if self.name is not None else ""
        return f'<LimitResource {name} {self._prefix} -> {self._directory!r}>'
class LimitRouter(web.StaticDef):
    """StaticDef variant whose register() installs a LimitResource (with the
    suffix and file-size limiters attached) instead of the plain
    StaticResource that aiohttp's add_static would create."""

    def __repr__(self) -> str:
        info = []
        for name, value in sorted(self.kwargs.items()):
            info.append(f", {name}={value!r}")
        return f'<LimitRouter {self.prefix} -> {self.path}{"".join(info)}>'

    def register(self, router: UrlDispatcher) -> list[AbstractRoute]:
        """Register this static route on *router*.

        Mirrors UrlDispatcher.add_static but substitutes LimitResource.
        """
        # resource = router.add_static(self.prefix, self.path, **self.kwargs)
        def add_static(
            self: UrlDispatcher,
            prefix: str,
            path: PathLike,
            *,
            name=None,
            expect_handler=None,
            chunk_size: int = 256 * 1024,
            show_index: bool = False,
            follow_symlinks: bool = False,
            append_version: bool = False,
        ) -> web.AbstractResource:
            # Local re-implementation of UrlDispatcher.add_static that
            # builds a LimitResource with the default limiters pre-installed.
            assert prefix.startswith("/")
            if prefix.endswith("/"):
                prefix = prefix[:-1]
            resource = LimitResource(
                prefix,
                path,
                name=name,
                expect_handler=expect_handler,
                chunk_size=chunk_size,
                show_index=show_index,
                follow_symlinks=follow_symlinks,
                append_version=append_version,
            )
            resource.push_limiter(suffix_limiter)
            resource.push_limiter(filesize_limiter)
            self.register_resource(resource)
            return resource
        resource = add_static(router, self.prefix, self.path, **self.kwargs)
        routes = resource.get_info().get("routes", {})
        return list(routes.values())
def path_to_url(path):
    """Normalize a filesystem-ish path into a URL path.

    Backslashes become forward slashes, a leading slash is ensured, and every
    run of consecutive slashes is collapsed to one. Falsy input (empty string
    or None) is returned unchanged.
    """
    if not path:
        return path
    path = path.replace("\\", "/")
    if not path.startswith("/"):
        path = "/" + path
    # Loop until no doubled slash remains: a single replace() pass turns
    # "///" into "//" and stops, which the original code did.
    while "//" in path:
        path = path.replace("//", "/")
    return path
def add_static_resource(prefix, path, limit=False):
    """Mount *path* as a static route on the running PromptServer app.

    Args:
        prefix: URL prefix; normalized and percent-encoded before mounting.
        path: Filesystem directory to serve.
        limit: When True, serve through LimitRouter so the suffix/size
            limiters apply; otherwise use a plain aiohttp static route.
    """
    app = PromptServer.instance.app
    # Normalize, percent-quote, then normalize again.
    # NOTE(review): the second normalization presumably repairs slashes
    # after quoting — confirm quote() can actually alter them here.
    prefix = path_to_url(prefix)
    prefix = urllib.parse.quote(prefix)
    prefix = path_to_url(prefix)
    if limit:
        route = LimitRouter(prefix, path, {"follow_symlinks": True})
    else:
        route = web.static(prefix, path, follow_symlinks=True)
    app.add_routes([route])

View File

@@ -0,0 +1,427 @@
import torch
import numpy as np
import re
import itertools
from comfy import model_management
from comfy.sdxl_clip import SDXLClipModel, SDXLRefinerClipModel, SDXLClipG
try:
from comfy.text_encoders.sd3_clip import SD3ClipModel, T5XXLModel
except ImportError:
from comfy.sd3_clip import SD3ClipModel, T5XXLModel
from nodes import NODE_CLASS_MAPPINGS, ConditioningConcat, ConditioningZeroOut, ConditioningSetTimestepRange, ConditioningCombine
def _grouper(n, iterable):
it = iter(iterable)
while True:
chunk = list(itertools.islice(it, n))
if not chunk:
return
yield chunk
def _norm_mag(w, n):
d = w - 1
return 1 + np.sign(d) * np.sqrt(np.abs(d) ** 2 / n)
# return np.sign(w) * np.sqrt(np.abs(w)**2 / n)
def divide_length(word_ids, weights):
sums = dict(zip(*np.unique(word_ids, return_counts=True)))
sums[0] = 1
weights = [[_norm_mag(w, sums[id]) if id != 0 else 1.0
for w, id in zip(x, y)] for x, y in zip(weights, word_ids)]
return weights
def shift_mean_weight(word_ids, weights):
    """Shift all non-padding weights by a constant so their mean becomes 1.

    Tokens with word id 0 are excluded from the mean and left unchanged.
    """
    flat = [w for row_w, row_id in zip(weights, word_ids)
            for w, wid in zip(row_w, row_id) if wid != 0]
    delta = 1 - np.mean(flat)
    return [[w if wid == 0 else w + delta
             for w, wid in zip(row_w, row_id)]
            for row_w, row_id in zip(weights, word_ids)]
def scale_to_norm(weights, word_ids, w_max):
    """Rescale weights so the largest becomes min(max(weights), w_max).

    Tokens with word id 0 are pinned to the ceiling value itself.
    """
    peak = np.max(weights)
    ceiling = min(peak, w_max)
    return [[ceiling if wid == 0 else (w / peak) * ceiling
             for w, wid in zip(row_w, row_id)]
            for row_w, row_id in zip(weights, word_ids)]
def from_zero(weights, base_emb):
    """Scale each token's embedding by its weight (broadcast over emb dim)."""
    scale = torch.tensor(weights, dtype=base_emb.dtype, device=base_emb.device)
    scale = scale.reshape(1, -1, 1).expand(base_emb.shape)
    return base_emb * scale
def mask_word_id(tokens, word_ids, target_id, mask_token):
    """Replace every token belonging to word *target_id* with *mask_token*.

    Returns the masked token rows and a boolean array marking the masked
    positions.
    """
    masked = [[mask_token if wid == target_id else tok
               for tok, wid in zip(row_t, row_id)]
              for row_t, row_id in zip(tokens, word_ids)]
    positions = np.array(word_ids) == target_id
    return (masked, positions)
def batched_clip_encode(tokens, length, encode_func, num_chunks):
    """Encode prompt rows in batches of 32 and restitch chunked prompts.

    Each original prompt occupies *num_chunks* consecutive rows of *length*
    tokens; the output folds them back into one row per prompt.
    """
    encoded = []
    for batch in _grouper(32, tokens):
        emb, _pooled = encode_func(batch)
        encoded.append(emb.reshape((len(batch), length, -1)))
    stacked = torch.cat(encoded)
    return stacked.reshape((len(tokens) // num_chunks, length * num_chunks, -1))
def from_masked(tokens, weights, word_ids, base_emb, length, encode_func, m_token=266):
    """Estimate per-word embedding deltas by masking one word at a time.

    For every word whose weight differs from 1.0, that word's tokens are
    replaced by *m_token*, the masked prompt is re-encoded, and the
    difference to *base_emb* is taken as the word's contribution, scaled by
    (weight - 1) and summed.

    Returns:
        (emb_delta, pooled): a delta to add to the base embedding, and an
        adjusted pooled vector.
    """
    # Pooled slice of the unmasked prompt (last token of the first chunk).
    pooled_base = base_emb[0, length - 1:length, :]
    wids, inds = np.unique(np.array(word_ids).reshape(-1), return_index=True)
    # One weight per word id, keeping only words actually re-weighted.
    weight_dict = dict((id, w)
                       for id, w in zip(wids, np.array(weights).reshape(-1)[inds])
                       if w != 1.0)
    if len(weight_dict) == 0:
        # Nothing re-weighted: zero delta, unchanged pooled vector.
        return torch.zeros_like(base_emb), base_emb[0, length - 1:length, :]
    weight_tensor = torch.tensor(weights, dtype=base_emb.dtype, device=base_emb.device)
    weight_tensor = weight_tensor.reshape(1, -1, 1).expand(base_emb.shape)
    # m_token = (clip.tokenizer.end_token, 1.0) if clip.tokenizer.pad_with_end else (0,1.0)
    # TODO: find most suitable masking token here
    m_token = (m_token, 1.0)
    ws = []
    masked_tokens = []
    masks = []
    # create prompts
    for id, w in weight_dict.items():
        masked, m = mask_word_id(tokens, word_ids, id, m_token)
        masked_tokens.extend(masked)
        m = torch.tensor(m, dtype=base_emb.dtype, device=base_emb.device)
        m = m.reshape(1, -1, 1).expand(base_emb.shape)
        masks.append(m)
        ws.append(w)
    # batch process prompts
    embs = batched_clip_encode(masked_tokens, length, encode_func, len(tokens))
    masks = torch.cat(masks)
    # Each word's contribution = base embedding minus its masked encoding.
    embs = (base_emb.expand(embs.shape) - embs)
    pooled = embs[0, length - 1:length, :]
    embs *= masks
    embs = embs.sum(axis=0, keepdim=True)
    pooled_start = pooled_base.expand(len(ws), -1)
    ws = torch.tensor(ws).reshape(-1, 1).expand(pooled_start.shape)
    pooled = (pooled - pooled_start) * (ws - 1)
    pooled = pooled.mean(axis=0, keepdim=True)
    return ((weight_tensor - 1) * embs), pooled_base + pooled
def mask_inds(tokens, inds, mask_token):
    """Replace tokens at the given flat indices with *mask_token*.

    Indices are flat across rows: index = row * row_width + column.
    """
    width = len(tokens[0])
    targets = set(inds)
    return [[mask_token if row * width + col in targets else tok
             for col, tok in enumerate(row_tokens)]
            for row, row_tokens in enumerate(tokens)]
def down_weight(tokens, weights, word_ids, base_emb, length, encode_func, m_token=266):
    """Apply weights below 1.0 by progressively masking tokens and blending.

    Tokens are masked level by level in ascending weight order (masking
    accumulates), each masked prompt is encoded, and the encodings are mixed
    so lower-weighted tokens contribute proportionally less.

    Returns:
        (weighted_emb, masked_current, pooled): blended embedding, the fully
        masked token rows, and the pooled slice of the blended embedding.
    """
    # Unique ascending weights plus, per token, the index of its weight level.
    w, w_inv = np.unique(weights, return_inverse=True)
    if np.sum(w < 1) == 0:
        # No down-weighting requested: return the base embedding untouched.
        return base_emb, tokens, base_emb[0, length - 1:length, :]
    # m_token = (clip.tokenizer.end_token, 1.0) if clip.tokenizer.pad_with_end else (0,1.0)
    # using the comma token as a masking token seems to work better than aos tokens for SD 1.x
    m_token = (m_token, 1.0)
    masked_tokens = []
    masked_current = tokens
    for i in range(len(w)):
        if w[i] >= 1:
            continue
        # Mask every token at this weight level on top of previous masking.
        masked_current = mask_inds(masked_current, np.where(w_inv == i)[0], m_token)
        masked_tokens.extend(masked_current)
    embs = batched_clip_encode(masked_tokens, length, encode_func, len(tokens))
    embs = torch.cat([base_emb, embs])
    w = w[w <= 1.0]
    # Mixing coefficients: differences between consecutive weight levels.
    w_mix = np.diff([0] + w.tolist())
    w_mix = torch.tensor(w_mix, dtype=embs.dtype, device=embs.device).reshape((-1, 1, 1))
    weighted_emb = (w_mix * embs).sum(axis=0, keepdim=True)
    return weighted_emb, masked_current, weighted_emb[0, length - 1:length, :]
def scale_emb_to_mag(base_emb, weighted_emb):
    """Rescale *weighted_emb* so its L2 norm matches that of *base_emb*."""
    ratio = torch.linalg.norm(base_emb) / torch.linalg.norm(weighted_emb)
    return ratio * weighted_emb
def recover_dist(base_emb, weighted_emb):
    """Shift and scale *weighted_emb* to match base_emb's mean and std."""
    rescaled = (base_emb.std() / weighted_emb.std()) * (weighted_emb - weighted_emb.mean())
    return rescaled + (base_emb.mean() - rescaled.mean())
def A1111_renorm(base_emb, weighted_emb):
    """A1111-style renormalization: rescale so the mean matches base_emb's."""
    mean_ratio = base_emb.mean() / weighted_emb.mean()
    return mean_ratio * weighted_emb
def advanced_encode_from_tokens(tokenized, token_normalization, weight_interpretation, encode_func, m_token=266,
                                length=77, w_max=1.0, return_pooled=False, apply_to_pooled=False):
    """Encode tokenized text with configurable weight normalization and
    weight interpretation (comfy / A1111 / compel / comfy++ / down_weight).

    Args:
        tokenized: rows of (token, weight, word_id) triples.
        token_normalization: flag string checked via startswith("length") /
            endswith("mean") below.
        weight_interpretation: which re-weighting algorithm to apply.
        encode_func: callable mapping (token, weight) rows to (emb, pooled).
        return_pooled: when True, also return a pooled vector.
        apply_to_pooled: when True, return the re-weighted pooled vector
            instead of the base one.
    """
    tokens = [[t for t, _, _ in x] for x in tokenized]
    weights = [[w for _, w, _ in x] for x in tokenized]
    word_ids = [[wid for _, _, wid in x] for x in tokenized]
    # weight normalization
    # ====================
    # distribute down/up weights over word lengths
    if token_normalization.startswith("length"):
        weights = divide_length(word_ids, weights)
    # make mean of word tokens 1
    if token_normalization.endswith("mean"):
        weights = shift_mean_weight(word_ids, weights)
    # weight interpretation
    # =====================
    pooled = None
    if weight_interpretation == "comfy":
        # Native ComfyUI behavior: pass weights straight to the encoder.
        weighted_tokens = [[(t, w) for t, w in zip(x, y)] for x, y in zip(tokens, weights)]
        weighted_emb, pooled_base = encode_func(weighted_tokens)
        pooled = pooled_base
    else:
        # Every other strategy starts from an unweighted encoding.
        unweighted_tokens = [[(t, 1.0) for t, _, _ in x] for x in tokenized]
        base_emb, pooled_base = encode_func(unweighted_tokens)
        if weight_interpretation == "A1111":
            weighted_emb = from_zero(weights, base_emb)
            weighted_emb = A1111_renorm(base_emb, weighted_emb)
            pooled = pooled_base
        if weight_interpretation == "compel":
            # Up-weights go through the encoder; down-weights via masking.
            pos_tokens = [[(t, w) if w >= 1.0 else (t, 1.0) for t, w in zip(x, y)] for x, y in zip(tokens, weights)]
            weighted_emb, _ = encode_func(pos_tokens)
            weighted_emb, _, pooled = down_weight(pos_tokens, weights, word_ids, weighted_emb, length, encode_func)
        if weight_interpretation == "comfy++":
            weighted_emb, tokens_down, _ = down_weight(unweighted_tokens, weights, word_ids, base_emb, length, encode_func)
            weights = [[w if w > 1.0 else 1.0 for w in x] for x in weights]
            # unweighted_tokens = [[(t,1.0) for t, _,_ in x] for x in tokens_down]
            embs, pooled = from_masked(unweighted_tokens, weights, word_ids, base_emb, length, encode_func)
            weighted_emb += embs
        if weight_interpretation == "down_weight":
            weights = scale_to_norm(weights, word_ids, w_max)
            weighted_emb, _, pooled = down_weight(unweighted_tokens, weights, word_ids, base_emb, length, encode_func)
        # NOTE(review): an unrecognized weight_interpretation leaves
        # weighted_emb unbound and raises NameError at return — confirm
        # callers only ever pass the values handled above.
    if return_pooled:
        if apply_to_pooled:
            return weighted_emb, pooled
        else:
            return weighted_emb, pooled_base
    return weighted_emb, None
def encode_token_weights_g(model, token_weight_pairs):
    """Encode token/weight pairs with the SDXL CLIP-G text encoder."""
    encoder = model.clip_g
    return encoder.encode_token_weights(token_weight_pairs)
def encode_token_weights_l(model, token_weight_pairs):
    """Encode token/weight pairs with the CLIP-L text encoder."""
    encoded, pooled_output = model.clip_l.encode_token_weights(token_weight_pairs)
    return encoded, pooled_output
def encode_token_weights_t5(model, token_weight_pairs):
    """Encode token/weight pairs with the T5-XXL text encoder."""
    encoder = model.t5xxl
    return encoder.encode_token_weights(token_weight_pairs)
def encode_token_weights(model, token_weight_pairs, encode_func):
    """Prepare the clip model (layer selection + GPU load) and encode.

    Args:
        model: a ComfyUI CLIP wrapper (has cond_stage_model / patcher).
        token_weight_pairs: rows of (token, weight) pairs.
        encode_func: one of the encode_token_weights_* helpers above.
    """
    if model.layer_idx is not None:
        # ComfyUI revision 2016 [c2cb8e88] and later removed the SDXL clip's
        # clip_layer method; set_clip_options is the replacement.
        # if compare_revision(2016):
        model.cond_stage_model.set_clip_options({'layer': model.layer_idx})
        # else:
        #     model.cond_stage_model.clip_layer(model.layer_idx)
    model_management.load_model_gpu(model.patcher)
    return encode_func(model.cond_stage_model, token_weight_pairs)
def prepareXL(embs_l, embs_g, pooled, clip_balance):
    """Blend CLIP-L and CLIP-G embeddings according to *clip_balance*.

    Balance 0.5 keeps both at full strength; moving toward 0 attenuates G,
    toward 1 attenuates L. With no L embedding, G is returned unchanged.
    """
    l_scale = 1 - max(0, clip_balance - .5) * 2
    g_scale = 1 - max(0, .5 - clip_balance) * 2
    if embs_l is None:
        return embs_g, pooled
    return torch.cat([embs_l * l_scale, embs_g * g_scale], dim=-1), pooled
def prepareSD3(out, pooled, clip_balance):
    """Blend the CLIP(l+g) and T5 halves of an SD3 conditioning stack.

    When *out* has more than one leading entry, entry 0 (CLIP) and entry 1
    (T5) are scaled by the balance and concatenated; otherwise *out* passes
    through untouched.
    """
    clip_scale = 1 - max(0, clip_balance - .5) * 2
    t5_scale = 1 - max(0, .5 - clip_balance) * 2
    if out.shape[0] <= 1:
        return out, pooled
    return torch.cat([out[0] * clip_scale, out[1] * t5_scale], dim=-1), pooled
def advanced_encode(clip, text, token_normalization, weight_interpretation, w_max=1.0, clip_balance=.5,
                    apply_to_pooled=True, width=1024, height=1024, crop_w=0, crop_h=0, target_width=1024, target_height=1024, a1111_prompt_style=False, steps=1):
    """Tokenize and encode *text* into a ComfyUI conditioning list.

    Handles SD1.x, SDXL and SD3 clip models, optional A1111-style prompt
    handling (delegated to smzNodes), "BREAK" prompt segmentation, and an
    optional trailing "TIMESTEP[:start[:end]]" marker that limits part of
    the conditioning to a timestep range.
    """
    # Use clip text encode by smzNodes like same as a1111, when if you need installed the smzNodes
    if a1111_prompt_style:
        if "smZ CLIPTextEncode" in NODE_CLASS_MAPPINGS:
            cls = NODE_CLASS_MAPPINGS['smZ CLIPTextEncode']
            embeddings_final, = cls().encode(clip, text, weight_interpretation, True, True, False, False, 6, 1024, 1024, 0, 0, 1024, 1024, '', '', steps)
            return embeddings_final
        else:
            raise Exception(f"[smzNodes Not Found] you need to install 'ComfyUI-smzNodes'")
    # Parse an optional "TIMESTEP[:start[:end]]" marker out of the text.
    time_start = 0
    time_end = 1
    match = re.search(r'TIMESTEP.*$', text)
    if match:
        timestep = match.group()
        timestep = timestep.split(' ')
        timestep = timestep[0]
        text = text.replace(timestep, '')
        value = timestep.split(':')
        if len(value) >= 3:
            time_start = float(value[1])
            time_end = float(value[2])
        elif len(value) == 2:
            time_start = float(value[1])
            time_end = 1
        elif len(value) == 1:
            # Bare "TIMESTEP" with no range: default to 0.1..1.
            time_start = 0.1
            time_end = 1
    # "BREAK" splits the prompt into segments encoded separately and concatenated.
    pass3 = [x.strip() for x in text.split("BREAK")]
    pass3 = [x for x in pass3 if x != '']
    if len(pass3) == 0:
        pass3 = ['']
    # pass3_str = [f'[{x}]' for x in pass3]
    # print(f"CLIP: {str.join(' + ', pass3_str)}")
    conditioning = None
    for text in pass3:
        tokenized = clip.tokenize(text, return_word_ids=True)
        if SD3ClipModel and isinstance(clip.cond_stage_model, SD3ClipModel):
            # SD3: combine CLIP-L, CLIP-G and T5-XXL encodings.
            lg_out = None
            pooled = None
            out = None
            if len(tokenized['l']) > 0 or len(tokenized['g']) > 0:
                if clip.cond_stage_model.clip_l is not None:
                    lg_out, l_pooled = advanced_encode_from_tokens(tokenized['l'],
                                                                   token_normalization,
                                                                   weight_interpretation,
                                                                   lambda x: encode_token_weights(clip, x, encode_token_weights_l),
                                                                   w_max=w_max, return_pooled=True,)
                else:
                    l_pooled = torch.zeros((1, 768), device=model_management.intermediate_device())
                if clip.cond_stage_model.clip_g is not None:
                    g_out, g_pooled = advanced_encode_from_tokens(tokenized['g'],
                                                                  token_normalization,
                                                                  weight_interpretation,
                                                                  lambda x: encode_token_weights(clip, x, encode_token_weights_g),
                                                                  w_max=w_max, return_pooled=True)
                    if lg_out is not None:
                        lg_out = torch.cat([lg_out, g_out], dim=-1)
                    else:
                        # No CLIP-L output: left-pad G with the 768 L channels.
                        lg_out = torch.nn.functional.pad(g_out, (768, 0))
                else:
                    g_out = None
                    g_pooled = torch.zeros((1, 1280), device=model_management.intermediate_device())
                if lg_out is not None:
                    # Pad the L+G channels up to T5's 4096-wide embedding.
                    lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
                    out = lg_out
                pooled = torch.cat((l_pooled, g_pooled), dim=-1)
            # t5xxl
            if 't5xxl' in tokenized:
                t5_out, t5_pooled = advanced_encode_from_tokens(tokenized['t5xxl'],
                                                                token_normalization,
                                                                weight_interpretation,
                                                                lambda x: encode_token_weights(clip, x, encode_token_weights_t5),
                                                                w_max=w_max, return_pooled=True)
                if lg_out is not None:
                    # Stack T5 tokens after the CLIP tokens (sequence dim).
                    out = torch.cat([lg_out, t5_out], dim=-2)
                else:
                    out = t5_out
            if out is None:
                out = torch.zeros((1, 77, 4096), device=model_management.intermediate_device())
            if pooled is None:
                pooled = torch.zeros((1, 768 + 1280), device=model_management.intermediate_device())
            embeddings_final, pooled = prepareSD3(out, pooled, clip_balance)
            cond = [[embeddings_final, {"pooled_output": pooled}]]
        elif isinstance(clip.cond_stage_model, (SDXLClipModel, SDXLRefinerClipModel, SDXLClipG)):
            # SDXL: blend CLIP-L and CLIP-G according to clip_balance.
            embs_l = None
            embs_g = None
            pooled = None
            if 'l' in tokenized and isinstance(clip.cond_stage_model, SDXLClipModel):
                embs_l, _ = advanced_encode_from_tokens(tokenized['l'],
                                                        token_normalization,
                                                        weight_interpretation,
                                                        lambda x: encode_token_weights(clip, x, encode_token_weights_l),
                                                        w_max=w_max,
                                                        return_pooled=False)
            if 'g' in tokenized:
                embs_g, pooled = advanced_encode_from_tokens(tokenized['g'],
                                                             token_normalization,
                                                             weight_interpretation,
                                                             lambda x: encode_token_weights(clip, x,
                                                                                            encode_token_weights_g),
                                                             w_max=w_max,
                                                             return_pooled=True,
                                                             apply_to_pooled=apply_to_pooled)
            embeddings_final, pooled = prepareXL(embs_l, embs_g, pooled, clip_balance)
            cond = [[embeddings_final, {"pooled_output": pooled}]]
            # cond = [[embeddings_final,
            #          {"pooled_output": pooled, "width": width, "height": height, "crop_w": crop_w,
            #           "crop_h": crop_h, "target_width": target_width, "target_height": target_height}]]
        else:
            # SD1.x / plain CLIP-L models.
            embeddings_final, pooled = advanced_encode_from_tokens(tokenized['l'],
                                                                   token_normalization,
                                                                   weight_interpretation,
                                                                   lambda x: encode_token_weights(clip, x, encode_token_weights_l),
                                                                   w_max=w_max, return_pooled=True,)
            cond = [[embeddings_final, {"pooled_output": pooled}]]
        # Concatenate successive BREAK segments into one conditioning.
        if conditioning is not None:
            conditioning = ConditioningConcat().concat(conditioning, cond)[0]
        else:
            conditioning = cond
    # setTimeStepRange
    if time_start > 0 or time_end < 1:
        # Before time_start the original conditioning applies; inside
        # [time_start, time_end] a zeroed-out copy applies; both are combined.
        conditioning_2, = ConditioningSetTimestepRange().set_range(conditioning, 0, time_start)
        conditioning_1, = ConditioningZeroOut().zero_out(conditioning)
        conditioning_1, = ConditioningSetTimestepRange().set_range(conditioning_1, time_start, time_end)
        conditioning, = ConditioningCombine().combine(conditioning_1, conditioning_2)
    return conditioning

View File

@@ -0,0 +1,372 @@
import yaml
import pathlib
import base64
import io
import json
import os
import pickle
import zlib
import urllib.parse
import urllib.request
import urllib.error
from enum import Enum
from functools import singledispatch
from typing import Any, List, Union
import numpy as np
import torch
from PIL import Image
# Repository root (four levels up from this file) and the shared config file.
root_path = pathlib.Path(__file__).parent.parent.parent.parent
config_path = os.path.join(root_path, 'config.yaml')
class BizyAIRAPI:
    """Minimal client for the BizyAir (SiliconFlow) HTTP API."""

    def __init__(self):
        self.base_url = 'https://bizyair-api.siliconflow.cn/x/v1'
        # Lazily loaded from config.yaml on first use.
        self.api_key = None

    def getAPIKey(self):
        """Return the cached API key, loading it from config.yaml on first call.

        Raises:
            Exception: if config.yaml is missing or lacks BIZYAIR_API_KEY.
        """
        if self.api_key is None:
            if os.path.isfile(config_path):
                with open(config_path, 'r') as f:
                    data = yaml.load(f, Loader=yaml.FullLoader)
                if 'BIZYAIR_API_KEY' not in data:
                    raise Exception("Please add BIZYAIR_API_KEY to config.yaml")
                self.api_key = data['BIZYAIR_API_KEY']
            else:
                raise Exception("Please add config.yaml to root path")
        return self.api_key

    def send_post_request(self, url, payload, headers):
        """POST *payload* as JSON to *url* and return the response body text.

        Raises:
            Exception: with a user-facing message on auth failure or any
                network/HTTP error.
        """
        try:
            data = json.dumps(payload).encode("utf-8")
            req = urllib.request.Request(url, data=data, headers=headers, method="POST")
            with urllib.request.urlopen(req) as response:
                response_data = response.read().decode("utf-8")
            return response_data
        except urllib.error.URLError as e:
            if "Unauthorized" in str(e):
                raise Exception(
                    "Key is invalid, please refer to https://cloud.siliconflow.cn to get the API key.\n"
                    "If you have the key, please click the 'BizyAir Key' button at the bottom right to set the key."
                )
            else:
                # The original message ended mid-sentence ("..., if you have
                # no key, "); completed so users know where to obtain one.
                raise Exception(
                    f"Failed to connect to the server: {e}. If you have no key, "
                    "please refer to https://cloud.siliconflow.cn to get one."
                ) from e

    # joycaption
    def joyCaption(self, payload, image, apikey_override=None, API_URL='/supernode/joycaption2'):
        """Run the JoyCaption supernode on *image* and return the caption.

        Args:
            payload: base request dict; the encoded image is added in-place.
            image: an IMAGE tensor, serialized via encode_data().
            apikey_override: optional key used instead of the configured one.
            API_URL: endpoint path appended to base_url.

        Raises:
            Exception: if the response is malformed or reports an error.
        """
        if apikey_override is not None:
            api_key = apikey_override
        else:
            api_key = self.getAPIKey()
        url = f"{self.base_url}{API_URL}"
        print('Sending request to:', url)
        auth = f"Bearer {api_key}"
        headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": auth,
        }
        # disable_image_marker=True: the service expects the raw base64
        # payload without the local "IMAGE:" prefix.
        input_image = encode_data(image, disable_image_marker=True)
        payload["image"] = input_image
        ret: str = self.send_post_request(url=url, payload=payload, headers=headers)
        ret = json.loads(ret)
        try:
            # Some responses wrap the real payload in a JSON-encoded "result".
            if "result" in ret:
                ret = json.loads(ret["result"])
        except Exception as e:
            raise Exception(f"Unexpected response: {ret} {e=}")
        if ret["type"] == "error":
            raise Exception(ret["message"])
        msg = ret["data"]
        if msg["type"] not in ("comfyair", "bizyair",):
            raise Exception(f"Unexpected response type: {msg}")
        caption = msg["data"]
        return caption
# Shared singleton client used by the nodes in this package.
bizyairAPI = BizyAIRAPI()
# NOTE(review): debug printing is always on; presumably meant to be
# toggled off for release — confirm.
BIZYAIR_DEBUG = True
# Marker to identify base64-encoded tensors
TENSOR_MARKER = "TENSOR:"
# Marker prefix for JSON/base64-encoded image batches.
IMAGE_MARKER = "IMAGE:"
class TaskStatus(Enum):
    """Lifecycle states reported for a remote task."""
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
def convert_image_to_rgb(image: Image.Image) -> Image.Image:
    """Return *image* in RGB mode; already-RGB images pass through as-is."""
    if image.mode == "RGB":
        return image
    return image.convert("RGB")
def encode_image_to_base64(
    image: Image.Image, format: str = "png", quality: int = 100, lossless=False
) -> str:
    """Serialize *image* (forced to RGB) into base64 text in *format*."""
    rgb = convert_image_to_rgb(image)
    with io.BytesIO() as buffer:
        rgb.save(buffer, format=format, quality=quality, lossless=lossless)
        raw = buffer.getvalue()
    if BIZYAIR_DEBUG:
        print(f"encode_image_to_base64: {format_bytes(len(raw))}")
    return base64.b64encode(raw).decode("utf-8")
def decode_base64_to_np(img_data: str, format: str = "png") -> np.ndarray:
    """Decode base64 image text into an RGB numpy pixel array."""
    raw = base64.b64decode(img_data)
    if BIZYAIR_DEBUG:
        print(f"decode_base64_to_np: {format_bytes(len(raw))}")
    with io.BytesIO(raw) as buffer:
        # https://github.com/comfyanonymous/ComfyUI/blob/a178e25912b01abf436eba1cfaab316ba02d272d/nodes.py#L1511
        pixels = np.array(Image.open(buffer).convert("RGB"))
    return pixels
def decode_base64_to_image(img_data: str) -> Image.Image:
    """Decode base64 image text into a PIL Image with pixels loaded.

    PIL opens images lazily; the original returned an Image still backed by
    a BytesIO that the `with` block had already closed, so any later pixel
    access could fail. img.load() forces the read while the buffer is open.
    """
    img_bytes = base64.b64decode(img_data)
    with io.BytesIO(img_bytes) as input_buffer:
        img = Image.open(input_buffer)
        if BIZYAIR_DEBUG:
            format_info = img.format.upper() if img.format else "Unknown"
            print(f"decode image format: {format_info}")
        img.load()  # force-read pixel data before the buffer closes
    return img
def format_bytes(num_bytes: int) -> str:
    """
    Converts a number of bytes to a human-readable string with units (B, KB, or MB).
    :param num_bytes: The number of bytes to convert.
    :return: A string representing the number of bytes in a human-readable format.
    """
    kb = 1024
    mb = kb * 1024
    if num_bytes < kb:
        return f"{num_bytes} B"
    if num_bytes < mb:
        return f"{num_bytes / kb:.2f} KB"
    return f"{num_bytes / mb:.2f} MB"
def _legacy_encode_comfy_image(image: torch.Tensor, image_format="png") -> str:
    """Old single-image encoder: first batch entry -> one base64 string."""
    pixels = 255.0 * image.cpu().detach().numpy()[0]
    pixels = np.clip(pixels, 0, 255).astype(np.uint8)
    return encode_image_to_base64(Image.fromarray(pixels), format=image_format)
def _legacy_decode_comfy_image(
    img_data: Union[List, str], image_format="png"
) -> torch.tensor:
    """Old decoder: a base64 string (or list of them) -> image tensor batch."""
    if isinstance(img_data, List):
        # A list holds one entry per batch element; decode and stack them.
        decoded = [decode_comfy_image(item, old_version=True) for item in img_data]
        return torch.cat(decoded, dim=0)
    pixels = np.array(decode_base64_to_np(img_data, format=image_format))
    pixels = pixels.astype(np.float32) / 255.0
    return torch.from_numpy(pixels)[None,]
def _new_encode_comfy_image(images: torch.Tensor, image_format="WEBP", **kwargs) -> str:
    """https://docs.comfy.org/essentials/custom_node_snippets#save-an-image-batch
    Encode a batch of images to base64 strings.
    Args:
        images (torch.Tensor): A batch of images.
        image_format (str, optional): The format of the images. Defaults to "WEBP".
    Returns:
        str: A JSON string containing the base64-encoded images.
    """
    encoded = {}
    for index, image in enumerate(images):
        pixels = np.clip(255.0 * image.cpu().numpy(), 0, 255).astype(np.uint8)
        encoded[index] = encode_image_to_base64(
            Image.fromarray(pixels), format=image_format, **kwargs
        )
    return json.dumps(encoded)
def _new_decode_comfy_image(img_datas: str, image_format="WEBP") -> torch.tensor:
    """
    Decode a batch of base64-encoded images.
    Args:
        img_datas (str): A JSON string containing the base64-encoded images.
        image_format (str, optional): The format of the images. Defaults to "WEBP".
    Returns:
        torch.Tensor: A tensor containing the decoded images.
    """
    batches = []
    for payload in json.loads(img_datas).values():
        pixels = np.array(decode_base64_to_np(payload, format=image_format))
        batches.append(torch.from_numpy(pixels.astype(np.float32) / 255.0)[None,])
    return torch.cat(batches, dim=0)
def encode_comfy_image(
    image: torch.Tensor, image_format="WEBP", old_version=False, lossless=False
) -> str:
    """Dispatch to the legacy or current image-batch encoder."""
    if not old_version:
        return _new_encode_comfy_image(image, image_format, lossless=lossless)
    return _legacy_encode_comfy_image(image, image_format)
def decode_comfy_image(
    img_data: Union[List, str], image_format="WEBP", old_version=False
) -> torch.tensor:
    """Dispatch to the legacy or current image-batch decoder."""
    if not old_version:
        return _new_decode_comfy_image(img_data, image_format)
    return _legacy_decode_comfy_image(img_data, image_format)
def tensor_to_base64(tensor: torch.Tensor, compress=True) -> str:
    """Pickle a tensor (as numpy), optionally zlib-compress, base64-encode."""
    payload = pickle.dumps(tensor.cpu().detach().numpy())
    if compress:
        payload = zlib.compress(payload)
    return base64.b64encode(payload).decode("utf-8")
def base64_to_tensor(tensor_b64: str, compress=True) -> torch.Tensor:
    """Inverse of tensor_to_base64: base64 -> (zlib) -> pickle -> tensor.

    WARNING: uses pickle.loads — feed it trusted data only.
    """
    payload = base64.b64decode(tensor_b64)
    if compress:
        payload = zlib.decompress(payload)
    return torch.from_numpy(pickle.loads(payload))
@singledispatch
def decode_data(input, old_version=False):
    """Recursively decode values received from the remote service.

    Marker-prefixed strings become tensors, containers are decoded
    element-wise, scalars pass through. Dispatch is on the input's type;
    unregistered types are rejected.
    """
    raise NotImplementedError(f"Unsupported type: {type(input)}")


@decode_data.register(int)
@decode_data.register(float)
@decode_data.register(bool)
@decode_data.register(type(None))
def _(input, **kwargs):
    # Plain scalars need no decoding.
    return input


@decode_data.register(dict)
def _(input, **kwargs):
    # Decode each value, preserving keys.
    return {k: decode_data(v, **kwargs) for k, v in input.items()}


@decode_data.register(list)
def _(input, **kwargs):
    return [decode_data(x, **kwargs) for x in input]


@decode_data.register(str)
def _(input: str, **kwargs):
    # Marker-prefixed strings carry base64 payloads; other strings are
    # returned unchanged.
    if input.startswith(TENSOR_MARKER):
        tensor_b64 = input[len(TENSOR_MARKER) :]
        return base64_to_tensor(tensor_b64)
    elif input.startswith(IMAGE_MARKER):
        tensor_b64 = input[len(IMAGE_MARKER) :]
        old_version = kwargs.get("old_version", False)
        return decode_comfy_image(tensor_b64, old_version=old_version)
    return input
@singledispatch
def encode_data(output, disable_image_marker=False, old_version=False):
    """Recursively encode values for transmission to the remote service.

    Tensors become marker-prefixed base64 strings (see the torch.Tensor
    registration below), containers are encoded element-wise, scalars and
    strings pass through. Dispatch is on the output's type.
    """
    raise NotImplementedError(f"Unsupported type: {type(output)}")


@encode_data.register(dict)
def _(output, **kwargs):
    # Encode each value, preserving keys.
    return {k: encode_data(v, **kwargs) for k, v in output.items()}


@encode_data.register(list)
def _(output, **kwargs):
    return [encode_data(x, **kwargs) for x in output]
def is_image_tensor(tensor) -> bool:
    """https://docs.comfy.org/essentials/custom_node_datatypes#image
    Check if the given tensor is in the format of an IMAGE (shape [B, H, W, C] where C=3).
    `Args`:
        tensor (torch.Tensor): The tensor to check.
    `Returns`:
        bool: True if the tensor is in the IMAGE format, False otherwise.
    """
    # The isinstance/shape checks below cannot raise, so the original bare
    # `except:` (which even swallowed KeyboardInterrupt) was removed.
    if not isinstance(tensor, torch.Tensor):
        return False
    if tensor.dim() != 4:
        return False
    return tensor.shape[-1] == 3
@encode_data.register(torch.Tensor)
def _(output, **kwargs):
    # IMAGE-shaped tensors travel as compressed WEBP images unless the
    # caller requests the raw-tensor path via disable_image_marker.
    if is_image_tensor(output) and not kwargs.get("disable_image_marker", False):
        old_version = kwargs.get("old_version", False)
        lossless = kwargs.get("lossless", True)
        return IMAGE_MARKER + encode_comfy_image(
            output, image_format="WEBP", old_version=old_version, lossless=lossless
        )
    return TENSOR_MARKER + tensor_to_base64(output)


@encode_data.register(int)
@encode_data.register(float)
@encode_data.register(bool)
@encode_data.register(type(None))
def _(output, **kwargs):
    # Scalars are JSON-representable as-is.
    return output


@encode_data.register(str)
def _(output, **kwargs):
    # Plain strings pass through unmodified.
    return output

View File

@@ -0,0 +1,51 @@
import json
import os
import yaml
import requests
import pathlib
from aiohttp import web
# Repository root (four levels up from this file) and the shared config file.
root_path = pathlib.Path(__file__).parent.parent.parent.parent
config_path = os.path.join(root_path,'config.yaml')
class FluxAIAPI:
    """Client for fluxaiimagegenerator.com's prompt-generator endpoint."""

    def __init__(self):
        self.api_url = "https://fluxaiimagegenerator.com/api"
        self.origin = "https://fluxaiimagegenerator.com"
        # Both are lazily loaded from config.yaml on first use.
        self.user_agent = None
        self.cookie = None

    def promptGenerate(self, text, cookies=None):
        """Expand *text* into a full image prompt via the remote service.

        Args:
            text: seed text for the prompt generator.
            cookies: optional cookie string overriding the configured one.

        Returns:
            The generated prompt, the service's error message, or None when
            the response carries neither.

        Raises:
            Exception: if config.yaml exists but lacks FLUXAI_COOKIE.
        """
        cookie = self.cookie if cookies is None else cookies
        if cookie is None:
            if os.path.isfile(config_path):
                with open(config_path, 'r') as f:
                    data = yaml.load(f, Loader=yaml.FullLoader)
                if 'FLUXAI_COOKIE' not in data:
                    raise Exception("Please add FLUXAI_COOKIE to config.yaml")
                if "FLUXAI_USER_AGENT" in data:
                    self.user_agent = data["FLUXAI_USER_AGENT"]
                self.cookie = cookie = data['FLUXAI_COOKIE']
            # NOTE(review): when config.yaml is absent, cookie stays None
            # and the request is sent with "Cookie": None — confirm intended.
        headers = {
            "Cookie": cookie,
            "Referer": "https://fluxaiimagegenerator.com/flux-prompt-generator",
            "Origin": self.origin,
            "Content-Type": "application/json",
        }
        if self.user_agent is not None:
            headers['User-Agent'] = self.user_agent
        url = self.api_url + '/prompt'
        # Renamed from `json`: the original local shadowed the imported
        # json module inside this method.
        payload = {
            "prompt": text
        }
        response = requests.post(url, json=payload, headers=headers)
        res = response.json()
        if "error" in res:
            return res['error']
        elif "data" in res and "prompt" in res['data']:
            return res['data']['prompt']


fluxaiAPI = FluxAIAPI()

View File

@@ -0,0 +1,200 @@
import json
import os
import yaml
import requests
import pathlib
from aiohttp import web
from server import PromptServer
from ..image import tensor2pil, pil2tensor, image2base64, pil2byte
from ..log import log_node_error
# Repository root (four levels up from this file) and the shared config file.
root_path = pathlib.Path(__file__).parent.parent.parent.parent
config_path = os.path.join(root_path,'config.yaml')
# Placeholder key list written to config.yaml when no key is configured yet.
default_key = [{'name':'Default', 'key':''}]
class StabilityAPI:
def __init__(self):
self.api_url = "https://api.stability.ai"
self.api_keys = None
self.api_current = 0
self.user_info = {}
def getErrors(self, code):
errors = {
400: "Bad Request",
403: "ApiKey Forbidden",
413: "Your request was larger than 10MiB.",
429: "You have made more than 150 requests in 10 seconds.",
500: "Internal Server Error",
}
return errors.get(code, "Unknown Error")
def getAPIKeys(self):
    """Load the key list and default index from config.yaml, creating or
    repairing the file with the placeholder entry when needed.

    Side effects: may (re)write config.yaml; caches the key list on
    self.api_keys and the active index on self.api_current.

    Returns:
        The list of {'name', 'key'} dicts.
    """
    if os.path.isfile(config_path):
        with open(config_path, 'r') as f:
            data = yaml.load(f, Loader=yaml.FullLoader)
            if not data:
                # Empty/blank config file: seed it with the placeholder key.
                data = {'STABILITY_API_KEY': default_key, 'STABILITY_API_DEFAULT':0}
                with open(config_path, 'w') as f:
                    yaml.dump(data, f)
            if 'STABILITY_API_KEY' not in data:
                # Config exists but has no stability section: add defaults.
                data['STABILITY_API_KEY'] = default_key
                data['STABILITY_API_DEFAULT'] = 0
                with open(config_path, 'w') as f:
                    yaml.dump(data, f)
            api_keys = data['STABILITY_API_KEY']
            self.api_current = data['STABILITY_API_DEFAULT']
            self.api_keys = api_keys
            return api_keys
    else:
        # create a yaml file
        with open(config_path, 'w') as f:
            data = {'STABILITY_API_KEY': default_key, 'STABILITY_API_DEFAULT':0}
            yaml.dump(data, f)
        # NOTE(review): this branch returns without setting self.api_keys —
        # confirm callers always re-invoke getAPIKeys afterwards.
        return data['STABILITY_API_KEY']
    pass
def setAPIKeys(self, api_keys):
    """Replace the stored key list and persist it to config.yaml.

    An empty list is ignored (nothing cached or written).

    Returns:
        True unconditionally.
    """
    if len(api_keys) > 0:
        self.api_keys = api_keys
        # load and save the yaml file
        with open(config_path, 'r') as f:
            data = yaml.load(f, Loader=yaml.FullLoader)
            data['STABILITY_API_KEY'] = api_keys
            with open(config_path, 'w') as f:
                yaml.dump(data, f)
    return True
def setAPIDefault(self, current):
if current is not None:
self.api_current = current
# load and save the yaml file
with open(config_path, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
data['STABILITY_API_DEFAULT'] = current
with open(config_path, 'w') as f:
yaml.dump(data, f)
return True
def generate_sd3_image(self, prompt, negative_prompt, aspect_ratio, model, seed, mode='text-to-image', image=None, strength=1, output_format='png', node_name='easy stableDiffusion3API'):
url = f"{self.api_url}/v2beta/stable-image/generate/sd3"
api_key = self.api_keys[self.api_current]['key']
files = None
data = {
"prompt": prompt,
"mode": mode,
"model": model,
"seed": seed,
"output_format": output_format,
}
if model == 'sd3':
data['negative_prompt'] = negative_prompt
if mode == 'text-to-image':
files = {"none": ''}
data['aspect_ratio'] = aspect_ratio
elif mode == 'image-to-image':
pil_image = tensor2pil(image)
image_byte = pil2byte(pil_image)
files = {"image": ("output.png", image_byte, 'image/png')}
data['strength'] = strength
response = requests.post(url,
headers={"authorization": f"{api_key}", "accept": "application/json"},
files=files,
data=data,
)
if response.status_code == 200:
PromptServer.instance.send_sync('stable-diffusion-api-generate-succeed',{"model":model})
json_data = response.json()
image_base64 = json_data['image']
image_data = image2base64(image_base64)
output_t = pil2tensor(image_data)
return output_t
else:
if 'application/json' in response.headers['Content-Type']:
error_info = response.json()
log_node_error(node_name, error_info.get('name', 'No name provided'))
log_node_error(node_name, error_info.get('errors', ['No details provided']))
error_status_text = self.getErrors(response.status_code)
PromptServer.instance.send_sync('easyuse-toast',{"type": "error", "content": error_status_text})
raise Exception(f"Failed to generate image: {error_status_text}")
# get user account
async def getUserAccount(self, cache=True):
url = f"{self.api_url}/v1/user/account"
api_key = self.api_keys[self.api_current]['key']
name = self.api_keys[self.api_current]['name']
if cache and name in self.user_info:
return self.user_info[name]
else:
response = requests.get(url, headers={"Authorization": f"Bearer {api_key}"})
if response.status_code == 200:
user_info = response.json()
self.user_info[name] = user_info
return user_info
else:
PromptServer.instance.send_sync('easyuse-toast',{'type': 'error', 'content': self.getErrors(response.status_code)})
return None
# get user balance
async def getUserBalance(self):
url = f"{self.api_url}/v1/user/balance"
api_key = self.api_keys[self.api_current]['key']
response = requests.get(url, headers={
"Authorization": f"Bearer {api_key}"
})
if response.status_code == 200:
return response.json()
else:
PromptServer.instance.send_sync('easyuse-toast', {'type': 'error', 'content': self.getErrors(response.status_code)})
return None
stableAPI = StabilityAPI()
@PromptServer.instance.routes.get("/easyuse/stability/api_keys")
async def get_stability_api_keys(request):
    """Return the stored Stability API keys and the active key index."""
    # Refresh the in-memory list from config.yaml before answering.
    stableAPI.getAPIKeys()
    payload = {"keys": stableAPI.api_keys, "current": stableAPI.api_current}
    return web.json_response(payload)
@PromptServer.instance.routes.post("/easyuse/stability/set_api_keys")
async def set_stability_api_keys(request):
    """Persist a new API-key list and optionally switch the active key.

    Form fields:
        api_keys: JSON-encoded list of {'name', 'key'} dicts (required).
        current:  index of the key to activate (optional).

    Returns account + balance info when the active key changes, a simple
    status payload otherwise, and HTTP 400 when ``api_keys`` is missing.
    """
    post = await request.post()
    api_keys = post.get("api_keys")
    current = post.get('current')
    if api_keys is None:
        return web.Response(status=400)
    api_keys = json.loads(api_keys)
    stableAPI.setAPIKeys(api_keys)
    if current is not None:
        # (removed a leftover debug print(current) here)
        stableAPI.setAPIDefault(int(current))
        account = await stableAPI.getUserAccount()
        balance = await stableAPI.getUserBalance()
        return web.json_response({'account': account, 'balance': balance})
    return web.json_response({'status': 'ok'})
@PromptServer.instance.routes.post("/easyuse/stability/set_apikey_default")
async def set_stability_api_default(request):
    """Switch the active API key by index; HTTP 400 on missing/invalid input.

    Bug fix: form values arrive as strings, so the original compared a str
    against ``len(...)`` (TypeError) and stored a str index. Convert to int
    before bounds-checking and assignment.
    """
    post = await request.post()
    current = post.get("current")
    if current is None:
        return web.Response(status=400)
    try:
        index = int(current)
    except (TypeError, ValueError):
        return web.Response(status=400)
    if 0 <= index < len(stableAPI.api_keys):
        stableAPI.api_current = index
        return web.json_response({'status': 'ok'})
    return web.Response(status=400)
@PromptServer.instance.routes.get("/easyuse/stability/user_info")
async def get_account_info(request):
    """Return the Stability account profile together with the credit balance."""
    info = {
        'account': await stableAPI.getUserAccount(),
        'balance': await stableAPI.getUserBalance(),
    }
    return web.json_response(info)
@PromptServer.instance.routes.get("/easyuse/stability/balance")
async def get_balance_info(request):
    """Return only the Stability credit balance."""
    return web.json_response({'balance': await stableAPI.getUserBalance()})

View File

@@ -0,0 +1,86 @@
import itertools
from typing import Optional
class TaggedCache:
    """A key/value store whose entries are grouped into per-tag buckets.

    Values are ``(tag, payload)`` tuples. Each tag gets its own bucket; when
    ``cachetools`` is available the bucket is an LRU cache whose default size
    depends on the tag (checkpoints are heavy, latents/images are cheap),
    overridable via ``tag_settings``. Without cachetools the bucket is a
    plain, unbounded dict.
    """

    def __init__(self, tag_settings: Optional[dict] = None):
        # Per-tag maximum sizes; missing tags use the built-in defaults.
        self._tag_settings = tag_settings or {}
        # tag -> {key: (tag, payload)}
        self._data = {}

    def __getitem__(self, key):
        for bucket in self._data.values():
            if key in bucket:
                return bucket[key]
        raise KeyError(f'Key `{key}` does not exist')

    def __setitem__(self, key, value: tuple):
        # value: (tag: str, (islist: bool, data: *))
        # Evict any stale entry for this key, wherever it currently lives.
        for bucket in self._data.values():
            if key in bucket:
                bucket.pop(key, None)
                break
        tag = value[0]
        if tag not in self._data:
            try:
                from cachetools import LRUCache
                if 'ckpt' in tag:
                    fallback_size = 5
                elif tag in ('latent', 'image'):
                    fallback_size = 100
                else:
                    fallback_size = 20
                self._data[tag] = LRUCache(maxsize=self._tag_settings.get(tag, fallback_size))
            except (ImportError, ModuleNotFoundError):
                # TODO: implement a simple lru dict
                self._data[tag] = {}
        self._data[tag][key] = value

    def __delitem__(self, key):
        for bucket in self._data.values():
            if key in bucket:
                del bucket[key]
                return
        raise KeyError(f'Key `{key}` does not exist')

    def __contains__(self, key):
        for bucket in self._data.values():
            if key in bucket:
                return True
        return False

    def items(self):
        # Chain the items of every bucket, lazily.
        for bucket in self._data.values():
            yield from bucket.items()

    def get(self, key, default=None):
        """D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None."""
        for bucket in self._data.values():
            if key in bucket:
                return bucket[key]
        return default

    def clear(self):
        # Drop every bucket, and with them every cached entry.
        self._data = {}
# Per-tag size overrides for the global cache (empty = use defaults).
cache_settings = {}
# Global cache instance shared by update_cache()/remove_cache() below.
cache = TaggedCache(cache_settings)
# key -> number of times the key has been *re*-cached (0 on first insert).
cache_count = {}
def update_cache(k, tag, v):
    """Store ``v`` under key ``k`` with ``tag`` and bump the refresh counter."""
    cache[k] = (tag, v)
    if k in cache_count:
        cache_count[k] += 1
    else:
        # First time this key is cached.
        cache_count[k] = 0
def remove_cache(key):
    """Drop one cached entry, or reset the whole cache when key == '*'."""
    global cache
    if key == '*':
        # Wildcard: throw the entire cache away and start fresh.
        cache = TaggedCache(cache_settings)
        return
    if key in cache:
        del cache[key]
        return
    print(f"invalid {key}")

View File

@@ -0,0 +1,153 @@
from threading import Event
import torch
from server import PromptServer
from aiohttp import web
from comfy import model_management as mm
from comfy_execution.graph import ExecutionBlocker
import time
class ChooserCancelled(Exception):
    """Raised when the user cancels the manual image selection in the UI."""
    pass
def get_chooser_cache():
    """Return the server-wide chooser state dict, creating it on first use."""
    server = PromptServer.instance
    if not hasattr(server, '_easyuse_chooser_node'):
        server._easyuse_chooser_node = {}
    return server._easyuse_chooser_node
def cleanup_session_data(node_id):
    """Strip the transient per-run keys from a node's chooser state.

    Persistent keys (e.g. ``last_selection``) are deliberately kept.
    """
    node_data = get_chooser_cache()
    entry = node_data.get(node_id)
    if entry is None:
        return
    for key in ("event", "selected", "images", "total_count", "cancelled"):
        entry.pop(key, None)
def wait_for_chooser(id, images, mode, period=0.1):
    """Block the execution thread until the user picks image(s) in the UI.

    Args:
        id: unique node id used as the key into the chooser cache.
        images: batched image tensor (B, H, W, C); split into per-image
            tensors so the UI can reference them by index.
        mode: "Keep Last Selection" replays the previous choice when one
            exists; any other mode waits for a fresh selection.
        period: polling interval in seconds while waiting for the UI.

    Returns:
        {"result": (tensor,)} with the selected images concatenated on the
        batch axis, or an ExecutionBlocker when nothing can be returned.

    Raises:
        comfy.model_management.InterruptProcessingException: on user cancel.
    """
    try:
        node_data = get_chooser_cache()
        # Split the batch into single-image tensors for per-index selection.
        images = [images[i:i + 1, ...] for i in range(images.shape[0])]
        if mode == "Keep Last Selection":
            # Replay the previous selection without waiting, when one exists.
            if id in node_data and "last_selection" in node_data[id]:
                last_selection = node_data[id]["last_selection"]
                if last_selection and len(last_selection) > 0:
                    valid_indices = [idx for idx in last_selection if 0 <= idx < len(images)]
                    if valid_indices:
                        try:
                            # Best effort: tell the UI which thumbnails to highlight.
                            PromptServer.instance.send_sync("easyuse-image-keep-selection", {
                                "id": id,
                                "selected": valid_indices
                            })
                        except Exception:
                            pass
                        cleanup_session_data(id)
                        images = [images[idx] for idx in valid_indices]
                        images = torch.cat(images, dim=0)
                        return {"result": (images,)}
        # Fresh session: discard any stale state for this node id.
        if id in node_data:
            del node_data[id]
        event = Event()
        node_data[id] = {
            "event": event,
            "images": images,
            "selected": None,
            "total_count": len(images),
            "cancelled": False,
        }
        # Poll until the HTTP handler records a selection or a cancel.
        while id in node_data:
            node_info = node_data[id]
            if node_info.get("cancelled", False):
                cleanup_session_data(id)
                raise ChooserCancelled("Manual selection cancelled")
            if "selected" in node_info and node_info["selected"] is not None:
                break
            time.sleep(period)
        if id in node_data:
            node_info = node_data[id]
            selected_indices = node_info.get("selected")
            if selected_indices is not None and len(selected_indices) > 0:
                valid_indices = [idx for idx in selected_indices if 0 <= idx < len(images)]
                if valid_indices:
                    selected_images = [images[idx] for idx in valid_indices]
                    # Remember the choice so "Keep Last Selection" can replay it.
                    if id not in node_data:
                        node_data[id] = {}
                    node_data[id]["last_selection"] = valid_indices
                    cleanup_session_data(id)
                    selected_images = torch.cat(selected_images, dim=0)
                    return {"result": (selected_images,)}
                else:
                    cleanup_session_data(id)
                    return {"result": (images[0] if len(images) > 0 else ExecutionBlocker(None),)}
            else:
                cleanup_session_data(id)
                return {"result": (images[0] if len(images) > 0 else ExecutionBlocker(None),)}
        else:
            return {"result": (images[0] if len(images) > 0 else ExecutionBlocker(None),)}
    except ChooserCancelled:
        # Surface the cancel as Comfy's standard interruption.
        raise mm.InterruptProcessingException()
    except Exception:
        node_data = get_chooser_cache()
        if id in node_data:
            cleanup_session_data(id)
        # Bug fix: the original tested `'image_list' in locals()` — a name
        # that never exists, so this fallback could never fire — and then
        # returned a bare tensor instead of a 1-tuple.
        if isinstance(images, list) and len(images) > 0:
            return {"result": (images[0],)}
        return {"result": (ExecutionBlocker(None),)}
@PromptServer.instance.routes.post('/easyuse/image_chooser_message')
async def handle_image_selection(request):
    """Receive the user's selection (or cancel) for a waiting chooser node."""
    try:
        payload = await request.json()
        node_id = payload.get("node_id")
        selected = payload.get("selected", [])
        action = payload.get("action")
        sessions = get_chooser_cache()
        if node_id not in sessions:
            return web.json_response({"code": -1, "error": "Node data does not exist"})
        try:
            session = sessions[node_id]
            if "total_count" not in session:
                return web.json_response({"code": -1, "error": "The node has been processed"})
            if action == "cancel":
                session["cancelled"] = True
                session["selected"] = []
            elif action == "select" and isinstance(selected, list):
                total = session["total_count"]
                picked = [i for i in selected if isinstance(i, int) and 0 <= i < total]
                if not picked:
                    return web.json_response({"code": -1, "error": "Invalid Selection Index"})
                session["selected"] = picked
                session["cancelled"] = False
            else:
                return web.json_response({"code": -1, "error": "Invalid operation"})
            # Wake the worker thread blocked in wait_for_chooser().
            session["event"].set()
            return web.json_response({"code": 1})
        except Exception:
            # Best effort: never leave the worker blocked forever.
            if node_id in sessions and "event" in sessions[node_id]:
                sessions[node_id]["event"].set()
            return web.json_response({"code": -1, "message": "Processing Failed"})
    except Exception:
        return web.json_response({"code": -1, "message": "Request Failed"})

View File

@@ -0,0 +1,115 @@
import torch
from PIL import Image
from torch import Tensor
from torch.nn import functional as F
from torchvision.transforms import ToTensor, ToPILImage
def adain_color_fix(target: Image, source: Image):
    """Transfer the color statistics of ``source`` onto ``target`` via AdaIN."""
    # Convert both PIL images to (1, C, H, W) float tensors.
    to_tensor = ToTensor()
    target_t = to_tensor(target).unsqueeze(0)
    source_t = to_tensor(source).unsqueeze(0)
    # Re-normalize the target's channel statistics to match the source.
    fixed = adaptive_instance_normalization(target_t, source_t)
    # Clamp to valid range and convert back to a PIL image.
    return ToPILImage()(fixed.squeeze(0).clamp_(0.0, 1.0))
def wavelet_color_fix(target: Image, source: Image):
    """Graft the low-frequency (color) content of ``source`` onto ``target``."""
    # Match the source resolution to the target before decomposing.
    source = source.resize(target.size, resample=Image.Resampling.LANCZOS)
    to_tensor = ToTensor()
    target_t = to_tensor(target).unsqueeze(0)
    source_t = to_tensor(source).unsqueeze(0)
    # Swap in the source's low-frequency bands via wavelet reconstruction.
    fixed = wavelet_reconstruction(target_t, source_t)
    return ToPILImage()(fixed.squeeze(0).clamp_(0.0, 1.0))
def calc_mean_std(feat: Tensor, eps=1e-5):
    """Per-channel mean and (eps-stabilized) std of a 4D feature map.

    Args:
        feat (Tensor): 4D tensor of shape (B, C, H, W).
        eps (float): small variance floor to avoid divide-by-zero.

    Returns:
        Tuple of (mean, std), each shaped (B, C, 1, 1).
    """
    size = feat.size()
    assert len(size) == 4, 'The input feature should be 4D tensor.'
    b, c = size[:2]
    flat = feat.view(b, c, -1)
    mean = flat.mean(dim=2).view(b, c, 1, 1)
    std = (flat.var(dim=2) + eps).sqrt().view(b, c, 1, 1)
    return mean, std
def adaptive_instance_normalization(content_feat: Tensor, style_feat: Tensor):
    """Adaptive instance normalization.

    Re-colors ``content_feat`` so its per-channel statistics match those of
    ``style_feat``.

    Args:
        content_feat (Tensor): The reference feature.
        style_feat (Tensor): The degradate features.
    """
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    # Whiten the content, then re-apply the style's scale and shift
    # (broadcasting over H and W is equivalent to the explicit expand).
    whitened = (content_feat - content_mean) / content_std
    return whitened * style_std + style_mean
def wavelet_blur(image: Tensor, radius: int):
    """Blur a (1, 3, H, W) image with a 3x3 binomial kernel dilated by ``radius``."""
    # 3x3 binomial (wavelet) smoothing kernel.
    kernel = torch.tensor(
        [
            [0.0625, 0.125, 0.0625],
            [0.125, 0.25, 0.125],
            [0.0625, 0.125, 0.0625],
        ],
        dtype=image.dtype,
        device=image.device,
    )
    # Shape (3, 1, 3, 3): one independent kernel per RGB channel.
    kernel = kernel[None, None].repeat(3, 1, 1, 1)
    # Replicate-pad so the dilated convolution preserves spatial size.
    padded = F.pad(image, (radius, radius, radius, radius), mode='replicate')
    return F.conv2d(padded, kernel, groups=3, dilation=radius)
def wavelet_decomposition(image: Tensor, levels=5):
    """Split an image into high- and low-frequency parts by iterated blurs.

    Returns (high_freq, low_freq); their sum reconstructs the input exactly.
    """
    high_freq = torch.zeros_like(image)
    for level in range(levels):
        # Blur radius doubles each level: 1, 2, 4, ...
        low_freq = wavelet_blur(image, 2 ** level)
        high_freq += image - low_freq
        image = low_freq
    return high_freq, low_freq
def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor):
    """Combine the content's high frequencies with the style's low frequencies.

    The result keeps the detail of ``content_feat`` but takes on the overall
    color of ``style_feat``.
    """
    content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
    del content_low_freq
    style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
    del style_high_freq
    return content_high_freq + style_low_freq

View File

@@ -0,0 +1,57 @@
from .utils import find_wildcards_seed, find_nearest_steps, is_linked_styles_selector
from .log import log_node_warn
from .translate import zh_to_en, has_chinese
from .wildcards import process_with_loras
from .adv_encode import advanced_encode
from nodes import ConditioningConcat, ConditioningCombine, ConditioningAverage, ConditioningSetTimestepRange, CLIPTextEncode
def prompt_to_cond(type, model, clip, clip_skip, lora_stack, text, prompt_token_normalization, prompt_weight_interpretation, a1111_prompt_style ,my_unique_id, prompt, easyCache, can_load_lora=True, steps=None, model_type=None):
    """Encode one prompt (positive or negative) into conditioning.

    Handles Chinese-to-English translation, wildcard expansion with LoRA
    loading, optional clip-skip, and advanced (A1111-style) weighting.

    Returns:
        (conditioning, wildcard_prompt, model, clip) — ``conditioning`` is
        None when ``text`` is None; ``wildcard_prompt`` is the decoded
        wildcard text when it should be surfaced in the UI, else "".
    """
    styles_selector = is_linked_styles_selector(prompt, my_unique_id, type)
    title = "Positive encoding" if type == 'positive' else "Negative encoding"
    # Translate cn to en
    if model_type not in ['hydit'] and text is not None and has_chinese(text):
        text = zh_to_en([text])[0]
    # These model families take the plain CLIPTextEncode path
    # (no wildcards, weighting, or clip-skip).
    if model_type in ['hydit', 'flux', 'mochi']:
        log_node_warn(title + "...")
        embeddings_final, = CLIPTextEncode().encode(clip, text) if text is not None else (None,)
        return (embeddings_final, "", model, clip)
    log_node_warn(title + "...")
    # Expand wildcards (seeded for reproducibility) and load any LoRAs
    # referenced inline in the prompt.
    positive_seed = find_wildcards_seed(my_unique_id, text, prompt)
    model, clip, text, cond_decode, show_prompt, pipe_lora_stack = process_with_loras(
        text, model, clip, type, positive_seed, can_load_lora, lora_stack, easyCache)
    wildcard_prompt = cond_decode if show_prompt or styles_selector else ""
    clipped = clip.clone()
    # Clip-skip can only be applied when the clip model has no t5xxl component.
    if not hasattr(clip.cond_stage_model, 't5xxl'):
        if clip_skip != 0:
            clipped.clip_layer(clip_skip)
    steps = steps if steps is not None else find_nearest_steps(my_unique_id, prompt)
    return (advanced_encode(clipped, text, prompt_token_normalization,
                            prompt_weight_interpretation, w_max=1.0,
                            apply_to_pooled='enable',
                            a1111_prompt_style=a1111_prompt_style, steps=steps) if text is not None else None, wildcard_prompt, model, clipped)
def set_cond(old_cond, new_cond, mode, average_strength, old_cond_start, old_cond_end, new_cond_start, new_cond_end):
    """Merge a new conditioning into an existing one according to ``mode``.

    Modes: replace, concat, combine, average (weighted by
    ``average_strength``), timestep (each conditioning limited to its own
    start/end range, then combined). With no ``old_cond`` the new one is
    returned as-is; an unrecognized mode returns None (as in the original).
    """
    if not old_cond or mode == "replace":
        return new_cond
    if mode == "concat":
        return ConditioningConcat().concat(new_cond, old_cond)[0]
    if mode == "combine":
        return ConditioningCombine().combine(old_cond, new_cond)[0]
    if mode == 'average':
        return ConditioningAverage().addWeighted(new_cond, old_cond, average_strength)[0]
    if mode == 'timestep':
        ranged_old = ConditioningSetTimestepRange().set_range(old_cond, old_cond_start, old_cond_end)[0]
        ranged_new = ConditioningSetTimestepRange().set_range(new_cond, new_cond_start, new_cond_end)[0]
        return ConditioningCombine().combine(ranged_old, ranged_new)[0]

View File

@@ -0,0 +1,93 @@
import folder_paths
import comfy.controlnet
import comfy.model_management
from nodes import NODE_CLASS_MAPPINGS
union_controlnet_types = {"auto": -1, "openpose": 0, "depth": 1, "hed/pidi/scribble/ted": 2, "canny/lineart/anime_lineart/mlsd": 3, "normal": 4, "segment": 5, "tile": 6, "repaint": 7}
class easyControlnet:
    """Helper that applies a ControlNet hint to positive/negative conditioning."""

    def __init__(self):
        # Bug fix: apply() moves the mask with ``self.device``, but the
        # original __init__ was empty, so any call with a mask raised
        # AttributeError. Resolve the torch device up front.
        self.device = comfy.model_management.get_torch_device()

    def apply(self, control_net_name, image, positive, negative, strength, start_percent=0, end_percent=1, control_net=None, scale_soft_weights=1, mask=None, union_type=None, easyCache=None, use_cache=True, model=None, vae=None):
        """Attach a ControlNet (with optional union type and mask) to the conds.

        Args:
            control_net_name: checkpoint name; used only when ``control_net``
                is not supplied directly.
            strength: 0 disables the ControlNet entirely (fast path).
            start_percent / end_percent: step range over which it applies.
            union_type: key into ``union_controlnet_types`` for Union models.

        Returns:
            (positive, negative) with the ControlNet hint chained in.
        """
        if strength == 0:
            return (positive, negative)

        # kolors controlnet patch
        from ..modules.kolors.loader import is_kolors_model, applyKolorsUnet
        if is_kolors_model(model):
            from ..modules.kolors.model_patch import patch_controlnet
            if control_net is None:
                with applyKolorsUnet():
                    control_net = easyCache.load_controlnet(control_net_name, scale_soft_weights, use_cache)
            control_net = patch_controlnet(model, control_net)
        else:
            if control_net is None:
                if easyCache is not None:
                    control_net = easyCache.load_controlnet(control_net_name, scale_soft_weights, use_cache)
                else:
                    controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
                    control_net = comfy.controlnet.load_controlnet(controlnet_path)

        # union controlnet: select the control type index (-1 = auto -> empty list)
        if union_type is not None:
            control_net = control_net.copy()
            type_number = union_controlnet_types[union_type]
            if type_number >= 0:
                control_net.set_extra_arg("control_type", [type_number])
            else:
                control_net.set_extra_arg("control_type", [])

        if mask is not None:
            mask = mask.to(self.device)
        if mask is not None and len(mask.shape) < 3:
            # Ensure the mask has a leading batch dimension.
            mask = mask.unsqueeze(0)

        # Image arrives channels-last; ControlNet wants channels-first.
        control_hint = image.movedim(-1, 1)
        is_cond = True
        if negative is None:
            # Positive-only path: apply to uncond as well.
            p = []
            for t in positive:
                n = [t[0], t[1].copy()]
                c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent))
                if 'control' in t[1]:
                    c_net.set_previous_controlnet(t[1]['control'])
                n[1]['control'] = c_net
                n[1]['control_apply_to_uncond'] = True
                if mask is not None:
                    n[1]['mask'] = mask
                    n[1]['set_area_to_bounds'] = False
                p.append(n)
            positive = p
        else:
            # Apply to both conds, sharing one c_net per previous-chain entry.
            cnets = {}
            out = []
            for conditioning in [positive, negative]:
                c = []
                for t in conditioning:
                    d = t[1].copy()
                    prev_cnet = d.get('control', None)
                    if prev_cnet in cnets:
                        c_net = cnets[prev_cnet]
                    else:
                        c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae)
                        c_net.set_previous_controlnet(prev_cnet)
                        cnets[prev_cnet] = c_net
                    d['control'] = c_net
                    d['control_apply_to_uncond'] = False
                    if mask is not None:
                        d['mask'] = mask
                        d['set_area_to_bounds'] = False
                    n = [t[0], d]
                    c.append(n)
                out.append(c)
            positive = out[0]
            negative = out[1]
        return (positive, negative)

View File

@@ -0,0 +1,167 @@
import torch, math
######################### DynThresh Core #########################
class DynThresh:
    """Dynamic thresholding for classifier-free guidance.

    Rescales the high-CFG result so that its value distribution mimics what a
    lower "mimic" scale would have produced, counteracting the burn and
    oversaturation of large CFG scales.

    NOTE(review): ``self.step`` is read by interpret_scale() but never set
    here — the sampling loop is expected to assign it each step; confirm
    against the caller.
    """

    # Supported schedules for varying the mimic/cfg scales over the run.
    Modes = ["Constant", "Linear Down", "Cosine Down", "Half Cosine Down", "Linear Up", "Cosine Up", "Half Cosine Up", "Power Up", "Power Down", "Linear Repeating", "Cosine Repeating", "Sawtooth"]
    # Reference point options for the rescaling ("MEAN" recenters first).
    Startpoints = ["MEAN", "ZERO"]
    # Spread measures: absolute deviation ("AD") or standard deviation ("STD").
    Variabilities = ["AD", "STD"]

    def __init__(self, mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min, cfg_mode, cfg_scale_min, sched_val, experiment_mode, max_steps, separate_feature_channels, scaling_startpoint, variability_measure, interpolate_phi):
        # Target scale whose value distribution the CFG result should mimic.
        self.mimic_scale = mimic_scale
        # Percentile used when clamping by absolute deviation ("AD").
        self.threshold_percentile = threshold_percentile
        self.mimic_mode = mimic_mode
        self.cfg_mode = cfg_mode
        self.max_steps = max_steps
        # Lower bounds the scheduled scales interpolate down to.
        self.cfg_scale_min = cfg_scale_min
        self.mimic_scale_min = mimic_scale_min
        # Extra experimental post-processing (0 = off; see dynthresh()).
        self.experiment_mode = experiment_mode
        # Exponent / repeat count used by the Power* and *Repeating modes.
        self.sched_val = sched_val
        # Compute the spread per feature channel instead of globally.
        self.sep_feat_channels = separate_feature_channels
        self.scaling_startpoint = scaling_startpoint
        self.variability_measure = variability_measure
        # Blend factor between the thresholded and the plain CFG result.
        self.interpolate_phi = interpolate_phi

    def interpret_scale(self, scale, mode, min):
        """Schedule ``scale`` over the run per ``mode``, never dropping below ``min``.

        Uses ``self.step / (self.max_steps - 1)`` as the progress fraction.
        """
        scale -= min
        max = self.max_steps - 1
        frac = self.step / max
        if mode == "Constant":
            pass
        elif mode == "Linear Down":
            scale *= 1.0 - frac
        elif mode == "Half Cosine Down":
            scale *= math.cos(frac)
        elif mode == "Cosine Down":
            scale *= math.cos(frac * 1.5707)
        elif mode == "Linear Up":
            scale *= frac
        elif mode == "Half Cosine Up":
            scale *= 1.0 - math.cos(frac)
        elif mode == "Cosine Up":
            scale *= 1.0 - math.cos(frac * 1.5707)
        elif mode == "Power Up":
            scale *= math.pow(frac, self.sched_val)
        elif mode == "Power Down":
            scale *= 1.0 - math.pow(frac, self.sched_val)
        elif mode == "Linear Repeating":
            # Triangle wave repeated sched_val times over the run.
            portion = (frac * self.sched_val) % 1.0
            scale *= (0.5 - portion) * 2 if portion < 0.5 else (portion - 0.5) * 2
        elif mode == "Cosine Repeating":
            scale *= math.cos(frac * 6.28318 * self.sched_val) * 0.5 + 0.5
        elif mode == "Sawtooth":
            scale *= (frac * self.sched_val) % 1.0
        scale += min
        return scale

    def dynthresh(self, cond, uncond, cfg_scale, weights):
        """Apply dynamically-thresholded CFG and return the guided latent."""
        mimic_scale = self.interpret_scale(self.mimic_scale, self.mimic_mode, self.mimic_scale_min)
        cfg_scale = self.interpret_scale(cfg_scale, self.cfg_mode, self.cfg_scale_min)
        # uncond shape is (batch, 4, height, width)
        conds_per_batch = cond.shape[0] / uncond.shape[0]
        assert conds_per_batch == int(conds_per_batch), "Expected # of conds per batch to be constant across batches"
        cond_stacked = cond.reshape((-1, int(conds_per_batch)) + uncond.shape[1:])
        ### Normal first part of the CFG Scale logic, basically
        diff = cond_stacked - uncond.unsqueeze(1)
        if weights is not None:
            diff = diff * weights
        relative = diff.sum(1)
        ### Get the normal result for both mimic and normal scale
        mim_target = uncond + relative * mimic_scale
        cfg_target = uncond + relative * cfg_scale
        ### If we weren't doing mimic scale, we'd just return cfg_target here
        ### Now recenter the values relative to their average rather than absolute, to allow scaling from average
        mim_flattened = mim_target.flatten(2)
        cfg_flattened = cfg_target.flatten(2)
        mim_means = mim_flattened.mean(dim=2).unsqueeze(2)
        cfg_means = cfg_flattened.mean(dim=2).unsqueeze(2)
        mim_centered = mim_flattened - mim_means
        cfg_centered = cfg_flattened - cfg_means
        if self.sep_feat_channels:
            # Per-channel spread estimates (dim=2 is the flattened H*W axis).
            if self.variability_measure == 'STD':
                mim_scaleref = mim_centered.std(dim=2).unsqueeze(2)
                cfg_scaleref = cfg_centered.std(dim=2).unsqueeze(2)
            else: # 'AD'
                mim_scaleref = mim_centered.abs().max(dim=2).values.unsqueeze(2)
                cfg_scaleref = torch.quantile(cfg_centered.abs(), self.threshold_percentile, dim=2).unsqueeze(2)
        else:
            # Single global spread estimate across all channels and pixels.
            if self.variability_measure == 'STD':
                mim_scaleref = mim_centered.std()
                cfg_scaleref = cfg_centered.std()
            else: # 'AD'
                mim_scaleref = mim_centered.abs().max()
                cfg_scaleref = torch.quantile(cfg_centered.abs(), self.threshold_percentile)
        if self.scaling_startpoint == 'ZERO':
            # Scale the raw (un-recentered) values about zero.
            scaling_factor = mim_scaleref / cfg_scaleref
            result = cfg_flattened * scaling_factor
        else: # 'MEAN'
            if self.variability_measure == 'STD':
                cfg_renormalized = (cfg_centered / cfg_scaleref) * mim_scaleref
            else: # 'AD'
                ### Get the maximum value of all datapoints (with an optional threshold percentile on the uncond)
                max_scaleref = torch.maximum(mim_scaleref, cfg_scaleref)
                ### Clamp to the max
                cfg_clamped = cfg_centered.clamp(-max_scaleref, max_scaleref)
                ### Now shrink from the max to normalize and grow to the mimic scale (instead of the CFG scale)
                cfg_renormalized = (cfg_clamped / max_scaleref) * mim_scaleref
            ### Now add it back onto the averages to get into real scale again and return
            result = cfg_renormalized + cfg_means
        actual_res = result.unflatten(2, mim_target.shape[2:])
        if self.interpolate_phi != 1.0:
            # Blend between the thresholded result and the plain CFG target.
            actual_res = actual_res * self.interpolate_phi + cfg_target * (1.0 - self.interpolate_phi)
        if self.experiment_mode == 1:
            # Experimental: damp over-bright values channel-by-channel.
            num = actual_res.cpu().numpy()
            for y in range(0, 64):
                for x in range (0, 64):
                    if num[0][0][y][x] > 1.0:
                        num[0][1][y][x] *= 0.5
                    if num[0][1][y][x] > 1.0:
                        num[0][1][y][x] *= 0.5
                    if num[0][2][y][x] > 1.5:
                        num[0][2][y][x] *= 0.5
            actual_res = torch.from_numpy(num).to(device=uncond.device)
        elif self.experiment_mode == 2:
            # Experimental: damp any pixel where a channel exceeds 1.5.
            num = actual_res.cpu().numpy()
            for y in range(0, 64):
                for x in range (0, 64):
                    over_scale = False
                    for z in range(0, 4):
                        if abs(num[0][z][y][x]) > 1.5:
                            over_scale = True
                    if over_scale:
                        for z in range(0, 4):
                            num[0][z][y][x] *= 0.7
            actual_res = torch.from_numpy(num).to(device=uncond.device)
        elif self.experiment_mode == 3:
            # Experimental: normalize in an approximate latent->RGB space.
            coefs = torch.tensor([
                #   R       G        B      W
                [0.298, 0.207, 0.208, 0.0],  # L1
                [0.187, 0.286, 0.173, 0.0],  # L2
                [-0.158, 0.189, 0.264, 0.0],  # L3
                [-0.184, -0.271, -0.473, 1.0],  # L4
            ], device=uncond.device)
            res_rgb = torch.einsum("laxy,ab -> lbxy", actual_res, coefs)
            max_r, max_g, max_b, max_w = res_rgb[0][0].max(), res_rgb[0][1].max(), res_rgb[0][2].max(), res_rgb[0][3].max()
            max_rgb = max(max_r, max_g, max_b)
            print(f"test max = r={max_r}, g={max_g}, b={max_b}, w={max_w}, rgb={max_rgb}")
            if self.step / (self.max_steps - 1) > 0.2:
                if max_rgb < 2.0 and max_w < 3.0:
                    res_rgb /= max_rgb / 2.4
            else:
                if max_rgb > 2.4 and max_w > 3.0:
                    res_rgb /= max_rgb / 2.4
            actual_res = torch.einsum("laxy,ab -> lbxy", res_rgb, coefs.inverse())
        return actual_res

View File

@@ -0,0 +1,27 @@
@staticmethod
def easyIn(t: float)-> float:
return t*t
@staticmethod
def easyOut(t: float)-> float:
return -(t * (t - 2))
@staticmethod
def easyInOut(t: float)-> float:
if t < 0.5:
return 2*t*t
else:
return (-2*t*t) + (4*t) - 1
class EasingBase:
    """Dispatch of easing curves plus linear interpolation between endpoints."""

    def easing(self, t: float, function='linear') -> float:
        """Map progress ``t`` (0..1) through the named easing curve.

        NOTE(review): easyIn/easyOut/easyInOut are referenced here as bare
        names; they appear to be defined elsewhere in this module — confirm
        they resolve at module scope, otherwise non-linear curves raise
        NameError.
        """
        if function == 'easyIn':
            return easyIn(t)
        elif function == 'easyOut':
            return easyOut(t)
        elif function == 'easyInOut':
            return easyInOut(t)
        else:
            # Unknown names (including 'linear') fall back to identity.
            return t

    def ease(self, start, end, t) -> float:
        """Linear interpolation from ``start`` to ``end`` at progress ``t``."""
        return end * t + start * (1 - t)

View File

@@ -0,0 +1,273 @@
import torch
from torchvision.transforms.functional import gaussian_blur
from comfy.k_diffusion.sampling import default_noise_sampler, get_ancestral_step, to_d, BrownianTreeNoiseSampler
from tqdm.auto import trange
@torch.no_grad()
def sample_euler_ancestral(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    eta=1.0,
    s_noise=1.0,
    noise_sampler=None,
    upscale_ratio=2.0,
    start_step=5,
    end_step=15,
    upscale_n_step=3,
    unsharp_kernel_size=3,
    unsharp_sigma=0.5,
    unsharp_strength=0.0,
):
    """Ancestral sampling with Euler method steps.

    Variant that progressively enlarges the latent (bicubic) between
    ``start_step`` and ``end_step`` every ``upscale_n_step`` steps until it
    reaches ``upscale_ratio`` times its original size, with an optional
    unsharp-mask pass after each resize when ``unsharp_strength`` > 0.

    NOTE(review): any caller-supplied ``noise_sampler`` is replaced by a
    default sampler every step, since the latent size can change.
    """
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    # make upscale info: the step indices at which a resize happens ...
    upscale_steps = []
    step = start_step - 1
    while step < end_step - 1:
        upscale_steps.append(step)
        step += upscale_n_step
    height, width = x.shape[2:]
    # ... and the matching target shapes, growing toward upscale_ratio.
    upscale_shapes = [
        (int(height * (((upscale_ratio - 1) / i) + 1)), int(width * (((upscale_ratio - 1) / i) + 1)))
        for i in reversed(range(1, len(upscale_steps) + 1))
    ]
    upscale_info = {k: v for k, v in zip(upscale_steps, upscale_shapes)}
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
        d = to_d(x, sigmas[i], denoised)
        # Euler method
        dt = sigma_down - sigmas[i]
        x = x + d * dt
        if sigmas[i + 1] > 0:
            # Resize
            if i in upscale_info:
                x = torch.nn.functional.interpolate(x, size=upscale_info[i], mode="bicubic", align_corners=False)
                if unsharp_strength > 0:
                    blurred = gaussian_blur(x, kernel_size=unsharp_kernel_size, sigma=unsharp_sigma)
                    x = x + unsharp_strength * (x - blurred)
            # Rebuild the sampler against the (possibly resized) latent.
            noise_sampler = default_noise_sampler(x)
            noise = noise_sampler(sigmas[i], sigmas[i + 1])
            x = x + noise * sigma_up * s_noise
    return x
@torch.no_grad()
def sample_dpmpp_2s_ancestral(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    eta=1.0,
    s_noise=1.0,
    noise_sampler=None,
    upscale_ratio=2.0,
    start_step=5,
    end_step=15,
    upscale_n_step=3,
    unsharp_kernel_size=3,
    unsharp_sigma=0.5,
    unsharp_strength=0.0,
):
    """Ancestral sampling with DPM-Solver++(2S) second-order steps.

    Variant that progressively enlarges the latent (bicubic) between
    ``start_step`` and ``end_step`` every ``upscale_n_step`` steps, with an
    optional unsharp-mask pass after each resize when
    ``unsharp_strength`` > 0.

    NOTE(review): the ``noise_sampler`` argument is never read; a default
    sampler is rebuilt every step because the latent size can change.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    # DPM-Solver++ works in log-sigma space: sigma = exp(-t).
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()
    # make upscale info
    upscale_steps = []
    step = start_step - 1
    while step < end_step - 1:
        upscale_steps.append(step)
        step += upscale_n_step
    height, width = x.shape[2:]
    upscale_shapes = [
        (int(height * (((upscale_ratio - 1) / i) + 1)), int(width * (((upscale_ratio - 1) / i) + 1)))
        for i in reversed(range(1, len(upscale_steps) + 1))
    ]
    upscale_info = {k: v for k, v in zip(upscale_steps, upscale_shapes)}
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
        if sigma_down == 0:
            # Euler method
            d = to_d(x, sigmas[i], denoised)
            dt = sigma_down - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver++(2S): one midpoint model evaluation (r = 1/2).
            t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
            r = 1 / 2
            h = t_next - t
            s = t + r * h
            x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised
            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2
        # Noise addition
        if sigmas[i + 1] > 0:
            # Resize
            if i in upscale_info:
                x = torch.nn.functional.interpolate(x, size=upscale_info[i], mode="bicubic", align_corners=False)
                if unsharp_strength > 0:
                    blurred = gaussian_blur(x, kernel_size=unsharp_kernel_size, sigma=unsharp_sigma)
                    x = x + unsharp_strength * (x - blurred)
            # Rebuild the sampler against the (possibly resized) latent.
            noise_sampler = default_noise_sampler(x)
            noise = noise_sampler(sigmas[i], sigmas[i + 1])
            x = x + noise * sigma_up * s_noise
    return x
@torch.no_grad()
def sample_dpmpp_2m_sde(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    eta=1.0,
    s_noise=1.0,
    noise_sampler=None,
    solver_type="midpoint",
    upscale_ratio=2.0,
    start_step=5,
    end_step=15,
    upscale_n_step=3,
    unsharp_kernel_size=3,
    unsharp_sigma=0.5,
    unsharp_strength=0.0,
):
    """DPM-Solver++(2M) SDE with progressive latent upscaling.

    Between ``start_step`` and ``end_step`` the latent is bicubically
    enlarged every ``upscale_n_step`` steps up to ``upscale_ratio`` times its
    original size, optionally followed by an unsharp mask when
    ``unsharp_strength`` > 0.

    Raises:
        ValueError: when ``solver_type`` is not 'heun' or 'midpoint'.
    """
    if solver_type not in {"heun", "midpoint"}:
        raise ValueError("solver_type must be 'heun' or 'midpoint'")
    # Bug fix: normalize extra_args BEFORE reading "seed" from it. The
    # original called extra_args.get("seed") first, which raised
    # AttributeError whenever the caller left extra_args=None.
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    s_in = x.new_ones([x.shape[0]])
    # Multistep state: previous denoised prediction and step sizes.
    old_denoised = None
    h_last = None
    h = None
    # Steps at which the latent is resized, and the matching target shapes.
    upscale_steps = []
    step = start_step - 1
    while step < end_step - 1:
        upscale_steps.append(step)
        step += upscale_n_step
    height, width = x.shape[2:]
    upscale_shapes = [
        (int(height * (((upscale_ratio - 1) / i) + 1)), int(width * (((upscale_ratio - 1) / i) + 1)))
        for i in reversed(range(1, len(upscale_steps) + 1))
    ]
    upscale_info = {k: v for k, v in zip(upscale_steps, upscale_shapes)}
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            # DPM-Solver++(2M) SDE
            t, s = -sigmas[i].log(), -sigmas[i + 1].log()
            h = s - t
            eta_h = eta * h
            x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
            if old_denoised is not None:
                # Second-order correction using the previous prediction.
                r = h_last / h
                if solver_type == "heun":
                    x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
                elif solver_type == "midpoint":
                    x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
            if eta:
                # Resize
                if i in upscale_info:
                    x = torch.nn.functional.interpolate(x, size=upscale_info[i], mode="bicubic", align_corners=False)
                    if unsharp_strength > 0:
                        blurred = gaussian_blur(x, kernel_size=unsharp_kernel_size, sigma=unsharp_sigma)
                        x = x + unsharp_strength * (x - blurred)
                    # The prediction no longer matches the resized latent's
                    # shape, so drop it for the next multistep update.
                    denoised = None
                noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True)
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
        old_denoised = denoised
        h_last = h
    return x
@torch.no_grad()
def sample_lcm(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    noise_sampler=None,
    eta=None,
    s_noise=None,
    upscale_ratio=2.0,
    start_step=5,
    end_step=15,
    upscale_n_step=3,
    unsharp_kernel_size=3,
    unsharp_sigma=0.5,
    unsharp_strength=0.0,
):
    """LCM sampler with progressive latent upscaling.

    Each step takes the model's denoised prediction directly (LCM update),
    and at selected steps between ``start_step`` and ``end_step`` resizes the
    latent toward ``upscale_ratio`` times its original size, optionally
    sharpening with an unsharp mask before re-noising.

    ``noise_sampler``, ``eta`` and ``s_noise`` are accepted for interface
    compatibility but not used: fresh default noise is drawn every step
    because the latent shape can change mid-sampling.
    """
    if extra_args is None:
        extra_args = {}
    s_in = x.new_ones([x.shape[0]])
    # Steps at which the latent is resized.
    resize_steps = []
    cursor = start_step - 1
    while cursor < end_step - 1:
        resize_steps.append(cursor)
        cursor += upscale_n_step
    # Target (height, width) per resize step, growing toward upscale_ratio.
    base_h, base_w = x.shape[2:]
    growth = [((upscale_ratio - 1) / k) + 1 for k in reversed(range(1, len(resize_steps) + 1))]
    upscale_info = {step: (int(base_h * f), int(base_w * f)) for step, f in zip(resize_steps, growth)}
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
        x = denoised
        if sigmas[i + 1] > 0:
            if i in upscale_info:
                x = torch.nn.functional.interpolate(x, size=upscale_info[i], mode="bicubic", align_corners=False)
                if unsharp_strength > 0:
                    blurred = gaussian_blur(x, kernel_size=unsharp_kernel_size, sigma=unsharp_sigma)
                    x = x + unsharp_strength * (x - blurred)
            # Shape may have just changed, so build a fresh noise sampler for x.
            noise_sampler = default_noise_sampler(x)
            x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
    return x

View File

@@ -0,0 +1,227 @@
import os
import base64
import torch
import numpy as np
from enum import Enum
from PIL import Image
from io import BytesIO
from typing import List, Union
import folder_paths
from .utils import install_package
# PIL to Tensor
def pil2tensor(image):
    """Convert a PIL image (or array-like) to a float32 tensor in [0, 1] with a leading batch dim."""
    arr = np.array(image).astype(np.float32) / 255.0
    return torch.from_numpy(arr).unsqueeze(0)
# Tensor to PIL
def tensor2pil(image):
    """Convert a single image tensor in [0, 1] back to a uint8 PIL image."""
    arr = image.cpu().numpy().squeeze()
    arr = np.clip(255. * arr, 0, 255).astype(np.uint8)
    return Image.fromarray(arr)
# np to Tensor
def np2tensor(img_np: Union[np.ndarray, List[np.ndarray]]) -> torch.Tensor:
    """Convert a uint8 ndarray (or a list of them) to a batched float tensor in [0, 1]."""
    if isinstance(img_np, list):
        # Recurse on each item and stack along the batch dimension.
        return torch.cat([np2tensor(single) for single in img_np], dim=0)
    return torch.from_numpy(img_np.astype(np.float32) / 255.0).unsqueeze(0)
# Tensor to np
def tensor2np(tensor: torch.Tensor) -> List[np.ndarray]:
    """Convert an image tensor in [0, 1] to uint8 ndarray(s).

    A 3-D tensor yields one ndarray; a batched (4-D) tensor yields a list.
    (Despite the annotation, the single-image case is not wrapped in a list —
    callers rely on this.)
    """
    def _to_uint8(t):
        return np.clip(255.0 * t.cpu().numpy(), 0, 255).astype(np.uint8)

    if tensor.ndim == 3:  # single image
        return _to_uint8(tensor)
    return [_to_uint8(single) for single in tensor]  # batch of images
def pil2byte(pil_image, format='PNG'):
    """Serialize a PIL image into an in-memory byte stream, rewound to position 0."""
    buffer = BytesIO()
    pil_image.save(buffer, format=format)
    buffer.seek(0)
    return buffer
def image2base64(image_base64):
    """Decode a base64 string into a PIL image.

    NOTE: despite the name, this converts base64 -> image (name kept for
    backward compatibility with existing callers).
    """
    raw = base64.b64decode(image_base64)
    return Image.open(BytesIO(raw))
# Get new bounds
def get_new_bounds(width, height, left, right, top, bottom):
    """Returns the new bounds for an image with inset crop data."""
    return (left, width - right, top, height - bottom)
def RGB2RGBA(image: Image, mask: Image) -> Image:
    """Attach *mask* (converted to luminance) to *image* as its alpha channel."""
    r, g, b = image.convert('RGB').split()
    return Image.merge('RGBA', (r, g, b, mask.convert('L')))
def image2mask(image: Image) -> torch.Tensor:
    """Build a single-channel mask tensor of shape (1, H, W) from a PIL image.

    NOTE(review): ``split()[0]`` on an RGBA image is the *red* channel, not
    alpha — presumably intentional for grayscale mask images where R holds the
    mask value; confirm behavior for inputs with a meaningful alpha channel.
    """
    _image = image.convert('RGBA')
    alpha = _image.split()[0]
    bg = Image.new("L", _image.size)
    # Rebuild an RGBA image whose alpha channel carries the mask, then read
    # channel 3 of the tensor conversion.
    _image = Image.merge('RGBA', (bg, bg, bg, alpha))
    ret_mask = torch.tensor([pil2tensor(_image)[0, :, :, 3].tolist()])
    return ret_mask
def mask2image(mask: torch.Tensor) -> Image:
    """Render a mask tensor as a white-on-black RGBA PIL image.

    Iterates the batch but only the image built from the *last* mask is
    returned; an empty batch raises (``_image`` is never bound).
    """
    masks = tensor2np(mask)
    for m in masks:
        _mask = Image.fromarray(m).convert("L")
        _image = Image.new("RGBA", _mask.size, color='white')
        # White where the mask is opaque, black elsewhere.
        _image = Image.composite(
            _image, Image.new("RGBA", _mask.size, color='black'), _mask)
    return _image
# Image blending
class blendImage:
    """Blend two image tensors (BHWC, values in [0, 1]) with Photoshop-style modes."""

    def g(self, x):
        # Piecewise helper used by the soft-light formula.
        return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))

    def blend_mode(self, img1, img2, mode):
        """Composite img2 onto img1 with *mode*; opacity is applied by the caller."""
        if mode == "normal":
            return img2
        if mode == "multiply":
            return img1 * img2
        if mode == "screen":
            return 1 - (1 - img1) * (1 - img2)
        if mode == "overlay":
            return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
        if mode == "soft_light":
            return torch.where(
                img2 <= 0.5,
                img1 - (1 - 2 * img2) * img1 * (1 - img1),
                img1 + (2 * img2 - 1) * (self.g(img1) - img1),
            )
        if mode == "difference":
            return img1 - img2
        raise ValueError(f"Unsupported blend mode: {mode}")

    def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str = 'normal'):
        """Blend image2 over image1 at *blend_factor* opacity, resizing image2 if shapes differ."""
        image2 = image2.to(image1.device)
        if image1.shape != image2.shape:
            # common_upscale expects BCHW; the tensors here are BHWC.
            image2 = image2.permute(0, 3, 1, 2)
            image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic',
                                                crop='center')
            image2 = image2.permute(0, 2, 3, 1)
        blended = self.blend_mode(image1, image2, blend_mode)
        blended = image1 * (1 - blend_factor) + blended * blend_factor
        return torch.clamp(blended, 0, 1)
def empty_image(width, height, batch_size=1, color=0):
    """Create a solid-color float image tensor (batch, height, width, 3) from a 0xRRGGBB int."""
    channels = [
        torch.full([batch_size, height, width, 1], ((color >> shift) & 0xFF) / 0xFF)
        for shift in (16, 8, 0)  # red, green, blue
    ]
    return torch.cat(channels, dim=-1)
class ResizeMode(Enum):
    """Resize strategies, labelled to match the A1111/ControlNet UI wording."""
    RESIZE = "Just Resize"
    INNER_FIT = "Crop and Resize"
    OUTER_FIT = "Resize and Fill"

    def int_value(self):
        """Integer code used by downstream preprocessors (0/1/2)."""
        codes = {ResizeMode.RESIZE: 0, ResizeMode.INNER_FIT: 1, ResizeMode.OUTER_FIT: 2}
        if self in codes:
            return codes[self]
        assert False, "NOTREACHED"
# credit by https://github.com/chflame163/ComfyUI_LayerStyle/blob/main/py/imagefunc.py#L591C1-L617C22
def fit_resize_image(image: Image, target_width: int, target_height: int, fit: str, resize_sampler: str,
                     background_color: str = '#000000') -> Image:
    """Resize *image* to (target_width, target_height) with one of three fit modes.

    fit='letterbox' scales to fit and pads with *background_color*;
    fit='crop' center-crops to the target aspect ratio before resizing;
    any other value stretches to the target size.
    *resize_sampler* is forwarded to PIL as the resampling filter.

    Note: the original guarded on ``image is not None`` *after* already
    calling ``image.convert``, so the check was dead (and would have left
    ``ret_image`` unbound had it ever been false); it has been removed.
    """
    image = image.convert('RGB')
    orig_width, orig_height = image.size
    if fit == 'letterbox':
        if orig_width / orig_height > target_width / target_height:  # wider than target: bars top/bottom
            fit_width = target_width
            fit_height = int(target_width / orig_width * orig_height)
        else:  # taller than target: bars left/right
            fit_height = target_height
            fit_width = int(target_height / orig_height * orig_width)
        fit_image = image.resize((fit_width, fit_height), resize_sampler)
        ret_image = Image.new('RGB', size=(target_width, target_height), color=background_color)
        ret_image.paste(fit_image, box=((target_width - fit_width) // 2, (target_height - fit_height) // 2))
    elif fit == 'crop':
        if orig_width / orig_height > target_width / target_height:  # wider than target: crop left/right
            fit_width = int(orig_height * target_width / target_height)
            fit_image = image.crop(
                ((orig_width - fit_width) // 2, 0, (orig_width - fit_width) // 2 + fit_width, orig_height))
        else:  # taller than target: crop top/bottom
            fit_height = int(orig_width * target_height / target_width)
            fit_image = image.crop(
                (0, (orig_height - fit_height) // 2, orig_width, (orig_height - fit_height) // 2 + fit_height))
        ret_image = fit_image.resize((target_width, target_height), resize_sampler)
    else:
        ret_image = image.resize((target_width, target_height), resize_sampler)
    return ret_image
# CLIP interrogation (image-to-prompt)
import comfy.utils
from torchvision import transforms
Config, Interrogator = None, None
class CI_Inference:
    """Lazy clip-interrogator wrapper: loads the model on first use and caches it."""
    ci_model = None
    cache_path: str

    def __init__(self):
        self.ci_model = None
        self.low_vram = False
        # Interrogator weights are cached under models/clip_interrogator.
        self.cache_path = os.path.join(folder_paths.models_dir, "clip_interrogator")

    def _load_model(self, model_name, low_vram=False):
        """(Re)load the Interrogator only when the model name or VRAM mode changed."""
        if not (self.ci_model and model_name == self.ci_model.config.clip_model_name and self.low_vram == low_vram):
            self.low_vram = low_vram
            print(f"Load model: {model_name}")
            config = Config(
                device="cuda" if torch.cuda.is_available() else "cpu",
                download_cache=True,
                clip_model_name=model_name,
                clip_model_path=self.cache_path,
                cache_path=self.cache_path,
                caption_model_name='blip-large'
            )
            if low_vram:
                config.apply_low_vram_defaults()
            self.ci_model = Interrogator(config)

    def _interrogate(self, image, mode, caption=None):
        """Run one interrogation pass; mode is 'best', 'classic', 'fast' or 'negative'."""
        if mode == 'best':
            prompt = self.ci_model.interrogate(image, caption=caption)
        elif mode == 'classic':
            prompt = self.ci_model.interrogate_classic(image, caption=caption)
        elif mode == 'fast':
            prompt = self.ci_model.interrogate_fast(image, caption=caption)
        elif mode == 'negative':
            prompt = self.ci_model.interrogate_negative(image)
        else:
            raise Exception(f"Unknown mode {mode}")
        return prompt

    def image_to_prompt(self, image, mode, model_name='ViT-L-14/openai', low_vram=False):
        """Interrogate a batch of image tensors and return one prompt string per image.

        Imports clip_interrogator lazily, installing the package on first failure.
        """
        # BUG FIX: the `global` statement must precede the binding. The original
        # declared `global Config, Interrogator` *after* the import that assigned
        # them, which is a SyntaxError at module load time ("name is assigned to
        # before global declaration") and prevented the module from importing.
        global Config, Interrogator
        try:
            from clip_interrogator import Config, Interrogator
        except ImportError:
            install_package("clip_interrogator", "0.6.0")
            from clip_interrogator import Config, Interrogator
        pbar = comfy.utils.ProgressBar(len(image))
        self._load_model(model_name, low_vram)
        prompt = []
        for i in range(len(image)):
            im = image[i]
            im = tensor2pil(im)
            im = im.convert('RGB')
            _prompt = self._interrogate(im, mode)
            pbar.update(1)
            prompt.append(_prompt)
        return prompt
ci = CI_Inference()

View File

@@ -0,0 +1,237 @@
import math
import torch
import comfy
def extra_options_to_module_prefix(extra_options):
    """Map ComfyUI attention-patch extra_options to an LLLite module name prefix.

    extra_options carries e.g. {'block': ('input', 7), 'block_index': 8, ...}.
    block is one of: [('input', 4), ('input', 5), ('input', 7), ('input', 8),
    ('middle', 0), ('output', 0) .. ('output', 5)]; block_index ranges 0-1 or
    0-9 depending on the block (input 7/8 and middle have 10 blocks).
    """
    block_kind, block_id = extra_options["block"]
    block_index = extra_options["block_index"]
    if block_kind == "input":
        return f"lllite_unet_input_blocks_{block_id}_1_transformer_blocks_{block_index}"
    if block_kind == "middle":
        return f"lllite_unet_middle_block_1_transformer_blocks_{block_index}"
    if block_kind == "output":
        return f"lllite_unet_output_blocks_{block_id}_1_transformer_blocks_{block_index}"
    raise Exception("invalid block name")
def load_control_net_lllite_patch(path, cond_image, multiplier, num_steps, start_percent, end_percent):
    """Load a ControlNet-LLLite safetensors file and build an attention patch.

    The returned callable is invoked by ComfyUI's attention patch hook with
    (q, k, v, extra_options) and adds each LLLite module's residual to the
    matching q/k/v projection. start/end percent (0-100) are converted to
    step indices over *num_steps*.
    """
    # calculate start and end step
    start_step = math.floor(num_steps * start_percent * 0.01) if start_percent > 0 else 0
    end_step = math.floor(num_steps * end_percent * 0.01) if end_percent > 0 else num_steps
    # load weights
    ctrl_sd = comfy.utils.load_torch_file(path, safe_load=True)
    # split each weights for each module: keys look like "<module_name>.<param path>"
    module_weights = {}
    for key, value in ctrl_sd.items():
        fragments = key.split(".")
        module_name = fragments[0]
        weight_name = ".".join(fragments[1:])
        if module_name not in module_weights:
            module_weights[module_name] = {}
        module_weights[module_name][weight_name] = value
    # load each module
    modules = {}
    for module_name, weights in module_weights.items():
        # TODO: improve this automatic depth detection (translated from the
        # original Japanese note). Depth is inferred from the conditioning
        # layer shapes in the checkpoint.
        if "conditioning1.4.weight" in weights:
            depth = 3
        elif weights["conditioning1.2.weight"].shape[-1] == 4:
            depth = 2
        else:
            depth = 1
        module = LLLiteModule(
            name=module_name,
            is_conv2d=weights["down.0.weight"].ndim == 4,
            in_dim=weights["down.0.weight"].shape[1],
            depth=depth,
            cond_emb_dim=weights["conditioning1.0.weight"].shape[0] * 2,
            mlp_dim=weights["down.0.weight"].shape[0],
            multiplier=multiplier,
            num_steps=num_steps,
            start_step=start_step,
            end_step=end_step,
        )
        # load_state_dict's return (missing/unexpected keys) is intentionally unused.
        info = module.load_state_dict(weights)
        modules[module_name] = module
        if len(modules) == 1:
            module.is_first = True  # the sole/first module logs step transitions
    print(f"loaded {path} successfully, {len(modules)} modules")
    # set the cond image on every module
    cond_image = cond_image.permute(0, 3, 1, 2)  # b,h,w,3 -> b,3,h,w
    cond_image = cond_image * 2.0 - 1.0  # 0-1 -> -1-+1
    for module in modules.values():
        module.set_cond_image(cond_image)

    class control_net_lllite_patch:
        """Callable patch object applied to attention q/k/v projections."""
        def __init__(self, modules):
            self.modules = modules

        def __call__(self, q, k, v, extra_options):
            module_pfx = extra_options_to_module_prefix(extra_options)
            # matching q/k feature dims means self-attention (attn1), else cross (attn2)
            is_attn1 = q.shape[-1] == k.shape[-1]  # self attention
            if is_attn1:
                module_pfx = module_pfx + "_attn1"
            else:
                module_pfx = module_pfx + "_attn2"
            module_pfx_to_q = module_pfx + "_to_q"
            module_pfx_to_k = module_pfx + "_to_k"
            module_pfx_to_v = module_pfx + "_to_v"
            # Add each residual only when the checkpoint provides that projection.
            if module_pfx_to_q in self.modules:
                q = q + self.modules[module_pfx_to_q](q)
            if module_pfx_to_k in self.modules:
                k = k + self.modules[module_pfx_to_k](k)
            if module_pfx_to_v in self.modules:
                v = v + self.modules[module_pfx_to_v](v)
            return q, k, v

        def to(self, device):
            for d in self.modules.keys():
                self.modules[d] = self.modules[d].to(device)
            return self

    return control_net_lllite_patch(modules)
class LLLiteModule(torch.nn.Module):
    """One ControlNet-LLLite patch module.

    Embeds the conditioning image once (cached in ``cond_emb``) and produces
    an additive residual for a single attention projection, gated by a step
    window [start_step, end_step) when ``num_steps`` > 0.
    """
    def __init__(
        self,
        name: str,
        is_conv2d: bool,
        in_dim: int,
        depth: int,
        cond_emb_dim: int,
        mlp_dim: int,
        multiplier: int,
        num_steps: int,
        start_step: int,
        end_step: int,
    ):
        super().__init__()
        self.name = name
        self.is_conv2d = is_conv2d
        self.multiplier = multiplier
        self.num_steps = num_steps
        self.start_step = start_step
        self.end_step = end_step
        # Only the module flagged first prints start/end step transitions.
        self.is_first = False
        # conditioning1 downsamples the cond image; total factor depends on depth.
        modules = []
        modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))  # to latent (from VAE) size*2
        if depth == 1:
            modules.append(torch.nn.ReLU(inplace=True))
            modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
        elif depth == 2:
            modules.append(torch.nn.ReLU(inplace=True))
            modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
        elif depth == 3:
            # kernel size 8 is too large, so use 4 instead (translated note)
            modules.append(torch.nn.ReLU(inplace=True))
            modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
            modules.append(torch.nn.ReLU(inplace=True))
            modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
        self.conditioning1 = torch.nn.Sequential(*modules)
        # down -> (concat cond emb) -> mid -> up: a small bottleneck adapter,
        # conv for conv2d projections, linear for attention projections.
        if self.is_conv2d:
            self.down = torch.nn.Sequential(
                torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
                torch.nn.ReLU(inplace=True),
            )
            self.mid = torch.nn.Sequential(
                torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
                torch.nn.ReLU(inplace=True),
            )
            self.up = torch.nn.Sequential(
                torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0),
            )
        else:
            self.down = torch.nn.Sequential(
                torch.nn.Linear(in_dim, mlp_dim),
                torch.nn.ReLU(inplace=True),
            )
            self.mid = torch.nn.Sequential(
                torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim),
                torch.nn.ReLU(inplace=True),
            )
            self.up = torch.nn.Sequential(
                torch.nn.Linear(mlp_dim, in_dim),
            )
        self.depth = depth
        # cond_image is set later via set_cond_image; cond_emb is computed lazily.
        self.cond_image = None
        self.cond_emb = None
        self.current_step = 0

    # @torch.inference_mode()
    def set_cond_image(self, cond_image):
        """Set a new conditioning image and invalidate the cached embedding/step counter."""
        # print("set_cond_image", self.name)
        self.cond_image = cond_image
        self.cond_emb = None
        self.current_step = 0

    def forward(self, x):
        """Return the additive residual for *x*, or zeros outside the step window."""
        if self.num_steps > 0:
            # Step gating: before start_step and from end_step on, contribute zeros.
            if self.current_step < self.start_step:
                self.current_step += 1
                return torch.zeros_like(x)
            elif self.current_step >= self.end_step:
                if self.is_first and self.current_step == self.end_step:
                    print(f"end LLLite: step {self.current_step}")
                self.current_step += 1
                if self.current_step >= self.num_steps:
                    self.current_step = 0  # reset
                return torch.zeros_like(x)
            else:
                if self.is_first and self.current_step == self.start_step:
                    print(f"start LLLite: step {self.current_step}")
                self.current_step += 1
                if self.current_step >= self.num_steps:
                    self.current_step = 0  # reset
        if self.cond_emb is None:
            # Embed the cond image once; reused until set_cond_image invalidates it.
            # print(f"cond_emb is None, {self.name}")
            cx = self.conditioning1(self.cond_image.to(x.device, dtype=x.dtype))
            if not self.is_conv2d:
                # reshape / b,c,h,w -> b,h*w,c
                n, c, h, w = cx.shape
                cx = cx.view(n, c, h * w).permute(0, 2, 1)
            self.cond_emb = cx
        cx = self.cond_emb
        # print(f"forward {self.name}, {cx.shape}, {x.shape}")
        # with cond/uncond batching, x has twice the batch size of cx (translated note)
        if x.shape[0] != cx.shape[0]:
            if self.is_conv2d:
                cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1, 1)
            else:
                # print("x.shape[0] != cx.shape[0]", x.shape[0], cx.shape[0])
                cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1)
        cx = torch.cat([cx, self.down(x)], dim=1 if self.is_conv2d else 2)
        cx = self.mid(cx)
        cx = self.up(cx)
        return cx * self.multiplier

View File

@@ -0,0 +1,570 @@
import re, time, os, psutil
import folder_paths
import comfy.utils
import comfy.sd
import comfy.controlnet
from comfy.model_patcher import ModelPatcher
from nodes import NODE_CLASS_MAPPINGS
from collections import defaultdict
from .log import log_node_info, log_node_error
from ..modules.dit.pixArt.loader import load_pixart
# Node-type groups consulted by easyLoader.update_loaded_objects to work out
# which cached models a queued workflow still needs.
diffusion_loaders = ["easy fullLoader", "easy a1111Loader", "easy fluxLoader", "easy comfyLoader", "easy hunyuanDiTLoader", "easy zero123Loader", "easy svdLoader"]
stable_cascade_loaders = ["easy cascadeLoader"]
dit_loaders = ['easy pixArtLoader']
controlnet_loaders = ["easy controlnetLoader", "easy controlnetLoaderADV", "easy controlnetLoader++"]
instant_loaders = ["easy instantIDApply", "easy instantIDApplyADV"]
cascade_vae_node = ["easy preSamplingCascade", "easy fullCascadeKSampler"]
model_merge_node = ["easy XYInputs: ModelMergeBlocks"]
# Loaders whose inline lora widget (name + model/clip strengths) must be tracked.
lora_widget = ["easy fullLoader", "easy a1111Loader", "easy comfyLoader", "easy fluxLoader"]
class easyLoader:
def __init__(self):
self.loaded_objects = {
"ckpt": defaultdict(tuple), # {ckpt_name: (model, ...)}
"unet": defaultdict(tuple),
"clip": defaultdict(tuple),
"clip_vision": defaultdict(tuple),
"bvae": defaultdict(tuple),
"vae": defaultdict(object),
"lora": defaultdict(dict), # {lora_name: {UID: (model_lora, clip_lora)}}
"controlnet": defaultdict(dict),
"t5": defaultdict(tuple),
"chatglm3": defaultdict(tuple),
}
self.memory_threshold = self.determine_memory_threshold(1)
self.lora_name_cache = []
def clean_values(self, values: str):
original_values = values.split("; ")
cleaned_values = []
for value in original_values:
cleaned_value = value.strip(';').strip()
if cleaned_value == "":
continue
try:
cleaned_value = int(cleaned_value)
except ValueError:
try:
cleaned_value = float(cleaned_value)
except ValueError:
pass
cleaned_values.append(cleaned_value)
return cleaned_values
def clear_unused_objects(self, desired_names: set, object_type: str):
keys = set(self.loaded_objects[object_type].keys())
for key in keys - desired_names:
del self.loaded_objects[object_type][key]
def get_input_value(self, entry, key, prompt=None):
val = entry["inputs"][key]
if isinstance(val, str):
return val
elif isinstance(val, list):
if prompt is not None and val[0]:
return prompt[val[0]]['inputs'][key]
else:
return val[0]
else:
return str(val)
def process_pipe_loader(self, entry, desired_ckpt_names, desired_vae_names, desired_lora_names, desired_lora_settings, num_loras=3, suffix=""):
for idx in range(1, num_loras + 1):
lora_name_key = f"{suffix}lora{idx}_name"
desired_lora_names.add(self.get_input_value(entry, lora_name_key))
setting = f'{self.get_input_value(entry, lora_name_key)};{entry["inputs"][f"{suffix}lora{idx}_model_strength"]};{entry["inputs"][f"{suffix}lora{idx}_clip_strength"]}'
desired_lora_settings.add(setting)
desired_ckpt_names.add(self.get_input_value(entry, f"{suffix}ckpt_name"))
desired_vae_names.add(self.get_input_value(entry, f"{suffix}vae_name"))
    def update_loaded_objects(self, prompt):
        """Scan *prompt* (the workflow graph) and evict cached objects it no longer needs.

        Walks every node, collects the names of checkpoints / unets / clips /
        vaes / loras / controlnets / t5s / chatglm3s the workflow references,
        then clears any cached entry not in those sets.
        """
        desired_ckpt_names = set()
        desired_unet_names = set()
        desired_clip_names = set()
        desired_vae_names = set()
        desired_lora_names = set()
        desired_lora_settings = set()
        desired_controlnet_names = set()
        desired_t5_names = set()
        desired_glm3_names = set()
        for entry in prompt.values():
            class_type = entry["class_type"]
            # Loaders with an inline lora widget: track name plus strengths.
            if class_type in lora_widget:
                lora_name = self.get_input_value(entry, "lora_name")
                desired_lora_names.add(lora_name)
                setting = f'{lora_name};{entry["inputs"]["lora_model_strength"]};{entry["inputs"]["lora_clip_strength"]}'
                desired_lora_settings.add(setting)
            if class_type in diffusion_loaders:
                desired_ckpt_names.add(self.get_input_value(entry, "ckpt_name", prompt))
                desired_vae_names.add(self.get_input_value(entry, "vae_name"))
            elif class_type in ['easy kolorsLoader']:
                desired_unet_names.add(self.get_input_value(entry, "unet_name"))
                desired_vae_names.add(self.get_input_value(entry, "vae_name"))
                desired_glm3_names.add(self.get_input_value(entry, "chatglm3_name"))
            elif class_type in dit_loaders:
                t5_name = self.get_input_value(entry, "mt5_name") if "mt5_name" in entry["inputs"] else None
                clip_name = self.get_input_value(entry, "clip_name") if "clip_name" in entry["inputs"] else None
                model_name = self.get_input_value(entry, "model_name")
                ckpt_name = self.get_input_value(entry, "ckpt_name", prompt)
                if t5_name:
                    desired_t5_names.add(t5_name)
                if clip_name:
                    desired_clip_names.add(clip_name)
                # DiT checkpoints are cached under "<ckpt>_<model>".
                desired_ckpt_names.add(ckpt_name+'_'+model_name)
            elif class_type in stable_cascade_loaders:
                desired_unet_names.add(self.get_input_value(entry, "stage_c"))
                desired_unet_names.add(self.get_input_value(entry, "stage_b"))
                desired_clip_names.add(self.get_input_value(entry, "clip_name"))
                desired_vae_names.add(self.get_input_value(entry, "stage_a"))
            elif class_type in cascade_vae_node:
                encode_vae_name = self.get_input_value(entry, "encode_vae_name")
                decode_vae_name = self.get_input_value(entry, "decode_vae_name")
                if encode_vae_name and encode_vae_name != 'None':
                    desired_vae_names.add(encode_vae_name)
                if decode_vae_name and decode_vae_name != 'None':
                    desired_vae_names.add(decode_vae_name)
            elif class_type in controlnet_loaders:
                # Controlnets are cached under "<name>;<soft weights>".
                control_net_name = self.get_input_value(entry, "control_net_name", prompt)
                scale_soft_weights = self.get_input_value(entry, "scale_soft_weights")
                desired_controlnet_names.add(f'{control_net_name};{scale_soft_weights}')
            elif class_type in instant_loaders:
                control_net_name = self.get_input_value(entry, "control_net_name", prompt)
                scale_soft_weights = self.get_input_value(entry, "cn_soft_weights")
                desired_controlnet_names.add(f'{control_net_name};{scale_soft_weights}')
            elif class_type in model_merge_node:
                desired_ckpt_names.add(self.get_input_value(entry, "ckpt_name_1"))
                desired_ckpt_names.add(self.get_input_value(entry, "ckpt_name_2"))
                vae_use = self.get_input_value(entry, "vae_use")
                if vae_use != 'Use Model 1' and vae_use != 'Use Model 2':
                    desired_vae_names.add(vae_use)
        # NOTE(review): "chatglm3" is absent from this list, so the chatglm3
        # branch below is dead and chatglm3 entries are never evicted here,
        # even though desired_glm3_names is collected above — confirm intent.
        object_types = ["ckpt", "unet", "clip", "bvae", "vae", "lora", "controlnet", "t5"]
        for object_type in object_types:
            if object_type == 'unet':
                desired_names = desired_unet_names
            elif object_type in ["ckpt", "clip", "bvae"]:
                # clip entries may come from checkpoints or standalone clip files.
                if object_type == 'clip':
                    desired_names = desired_ckpt_names.union(desired_clip_names)
                else:
                    desired_names = desired_ckpt_names
            elif object_type == "vae":
                desired_names = desired_vae_names
            elif object_type == "controlnet":
                desired_names = desired_controlnet_names
            elif object_type == "t5":
                desired_names = desired_t5_names
            elif object_type == "chatglm3":
                desired_names = desired_glm3_names
            else:
                desired_names = desired_lora_names
            self.clear_unused_objects(desired_names, object_type)
def add_to_cache(self, obj_type, key, value):
"""
Add an item to the cache with the current timestamp.
"""
timestamped_value = (value, time.time())
self.loaded_objects[obj_type][key] = timestamped_value
def determine_memory_threshold(self, percentage=0.8):
"""
Determines the memory threshold as a percentage of the total available memory.
Args:
- percentage (float): The fraction of total memory to use as the threshold.
Should be a value between 0 and 1. Default is 0.8 (80%).
Returns:
- memory_threshold (int): Memory threshold in bytes.
"""
total_memory = psutil.virtual_memory().total
memory_threshold = total_memory * percentage
return memory_threshold
def get_memory_usage(self):
"""
Returns the memory usage of the current process in bytes.
"""
process = psutil.Process(os.getpid())
return process.memory_info().rss
    def eviction_based_on_memory(self):
        """
        Evicts objects from cache based on memory usage and priority.

        No-op while the process RSS is below ``memory_threshold``; otherwise
        evicts oldest-first within each type, cheapest-to-reload types first,
        re-checking memory after every deletion.
        """
        current_memory = self.get_memory_usage()
        if current_memory < self.memory_threshold:
            return
        # Cheapest-to-reload object types are evicted first.
        eviction_order = ["vae", "lora", "bvae", "clip", "ckpt", "controlnet", "unet", "t5", "chatglm3"]
        for obj_type in eviction_order:
            if current_memory < self.memory_threshold:
                break
            # Sort items based on age (using the timestamp)
            items = list(self.loaded_objects[obj_type].items())
            items.sort(key=lambda x: x[1][1])  # Sorting by timestamp
            for item in items:
                if current_memory < self.memory_threshold:
                    break
                del self.loaded_objects[obj_type][item[0]]
                current_memory = self.get_memory_usage()
    def load_checkpoint(self, ckpt_name, config_name=None, load_vision=False):
        """Load (or fetch from cache) a checkpoint as (model, clip, vae, clip_vision).

        Exactly one of clip / clip_vision is populated depending on
        *load_vision*; the other is None. Checkpoints loaded with an explicit
        config are cached under "<ckpt>_<config>".
        """
        cache_name = ckpt_name
        if config_name not in [None, "Default"]:
            cache_name = ckpt_name + "_" + config_name
        if cache_name in self.loaded_objects["ckpt"]:
            clip_vision = self.loaded_objects["clip_vision"][cache_name][0] if load_vision else None
            clip = self.loaded_objects["clip"][cache_name][0] if not load_vision else None
            return self.loaded_objects["ckpt"][cache_name][0], clip, self.loaded_objects["bvae"][cache_name][0], clip_vision
        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        output_clip = False if load_vision else True
        output_clipvision = True if load_vision else False
        if config_name not in [None, "Default"]:
            config_path = folder_paths.get_full_path("configs", config_name)
            loaded_ckpt = comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
        else:
            model_options = {}
            # NF4-quantized checkpoints (detected by filename) need the
            # bitsandbytes custom operations.
            if re.search("nf4", ckpt_name):
                from ..modules.bitsandbytes_NF4 import OPS
                model_options = {"custom_operations": OPS}
            loaded_ckpt = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=output_clip, output_clipvision=output_clipvision, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options=model_options)
        # loaded_ckpt tuple layout: (model, clip, vae, clip_vision)
        self.add_to_cache("ckpt", cache_name, loaded_ckpt[0])
        self.add_to_cache("bvae", cache_name, loaded_ckpt[2])
        clip = loaded_ckpt[1]
        clip_vision = loaded_ckpt[3]
        if clip:
            self.add_to_cache("clip", cache_name, clip)
        if clip_vision:
            self.add_to_cache("clip_vision", cache_name, clip_vision)
        self.eviction_based_on_memory()
        return loaded_ckpt[0], clip, loaded_ckpt[2], clip_vision
def load_vae(self, vae_name):
if vae_name in self.loaded_objects["vae"]:
return self.loaded_objects["vae"][vae_name][0]
vae_path = folder_paths.get_full_path("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
loaded_vae = comfy.sd.VAE(sd=sd)
self.add_to_cache("vae", vae_name, loaded_vae)
self.eviction_based_on_memory()
return loaded_vae
def load_unet(self, unet_name):
if unet_name in self.loaded_objects["unet"]:
log_node_info("Load UNet", f"{unet_name} cached")
return self.loaded_objects["unet"][unet_name][0]
unet_path = folder_paths.get_full_path("unet", unet_name)
model = comfy.sd.load_unet(unet_path)
self.add_to_cache("unet", unet_name, model)
self.eviction_based_on_memory()
return model
    def load_controlnet(self, control_net_name, scale_soft_weights=1, use_cache=True):
        """Load (or fetch from cache) a ControlNet, optionally with scaled soft weights.

        scale_soft_weights < 1 requires the ComfyUI-Advanced-ControlNet custom
        nodes. The cache key combines the name and the soft-weight scale.
        """
        unique_id = f'{control_net_name};{str(scale_soft_weights)}'
        if use_cache and unique_id in self.loaded_objects["controlnet"]:
            return self.loaded_objects["controlnet"][unique_id][0]
        if scale_soft_weights < 1:
            if "ScaledSoftControlNetWeights" in NODE_CLASS_MAPPINGS:
                soft_weight_cls = NODE_CLASS_MAPPINGS['ScaledSoftControlNetWeights']
                (weights, timestep_keyframe) = soft_weight_cls().load_weights(scale_soft_weights, False)
                cn_adv_cls = NODE_CLASS_MAPPINGS['ControlNetLoaderAdvanced']
                # load_controlnet returns a 1-tuple; unpack the single element.
                control_net, = cn_adv_cls().load_controlnet(control_net_name, timestep_keyframe)
            else:
                raise Exception(f"[Advanced-ControlNet Not Found] you need to install 'COMFYUI-Advanced-ControlNet'")
        else:
            controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
            control_net = comfy.controlnet.load_controlnet(controlnet_path)
        if use_cache:
            self.add_to_cache("controlnet", unique_id, control_net)
            self.eviction_based_on_memory()
        return control_net
def load_clip(self, clip_name, type='stable_diffusion', load_clip=None):
if clip_name in self.loaded_objects["clip"]:
return self.loaded_objects["clip"][clip_name][0]
if type == 'stable_diffusion':
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
elif type == 'stable_cascade':
clip_type = comfy.sd.CLIPType.STABLE_CASCADE
elif type == 'sd3':
clip_type = comfy.sd.CLIPType.SD3
elif type == 'flux':
clip_type = comfy.sd.CLIPType.FLUX
elif type == 'stable_audio':
clip_type = comfy.sd.CLIPType.STABLE_AUDIO
clip_path = folder_paths.get_full_path("clip", clip_name)
load_clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
self.add_to_cache("clip", clip_name, load_clip)
self.eviction_based_on_memory()
return load_clip
    def load_lora(self, lora, model=None, clip=None, type=None , use_cache=True):
        """Apply the lora described by dict *lora* to *model*/*clip*, with caching.

        *lora* carries: lora_name, model/clip (fallbacks when the args are
        None), model_strength, clip_strength, and optional lbw/lbw_a/lbw_b
        block-weight settings (requires ComfyUI-Inspire-Pack). Returns
        (model, clip); a cache hit returns the previously patched pair.
        """
        lora_name = lora["lora_name"]
        model = model if model is not None else lora["model"]
        clip = clip if clip is not None else lora["clip"]
        model_strength = lora["model_strength"]
        clip_strength = lora["clip_strength"]
        lbw = lora["lbw"] if "lbw" in lora else None
        lbw_a = lora["lbw_a"] if "lbw_a" in lora else None
        lbw_b = lora["lbw_b"] if "lbw_b" in lora else None
        # Identify the model/clip instances via a slice of their repr so the
        # cache key is unique per patched object.
        model_hash = str(model)[44:-1]
        clip_hash = str(clip)[25:-1] if clip else ''
        unique_id = f'{model_hash};{clip_hash};{lora_name};{model_strength};{clip_strength}'
        if use_cache and unique_id in self.loaded_objects["lora"]:
            log_node_info("Load LORA",f"{lora_name} cached")
            return self.loaded_objects["lora"][unique_id][0]
        orig_lora_name = lora_name
        lora_name = self.resolve_lora_name(lora_name)
        if lora_name is not None:
            lora_path = folder_paths.get_full_path("loras", lora_name)
        else:
            lora_path = None
        if lora_path is not None:
            log_node_info("Load LORA",f"{lora_name}: model={model_strength:.3f}, clip={clip_strength:.3f}, LBW={lbw}, A={lbw_a}, B={lbw_b}")
            if lbw:
                # Block-weight lora: delegate to the Inspire-Pack node.
                lbw = lora["lbw"]
                lbw_a = lora["lbw_a"]
                lbw_b = lora["lbw_b"]
                if 'LoraLoaderBlockWeight //Inspire' not in NODE_CLASS_MAPPINGS:
                    raise Exception('[InspirePack Not Found] you need to install ComfyUI-Inspire-Pack')
                cls = NODE_CLASS_MAPPINGS['LoraLoaderBlockWeight //Inspire']
                model, clip, _ = cls().doit(model, clip, lora_name, model_strength, clip_strength, False, 0,
                                            lbw_a, lbw_b, "", lbw)
            else:
                _lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
                keys = _lora.keys()
                # ResAdapter loras store raw norm weights/biases that must be
                # remapped onto the unet state dict and loaded directly.
                if "down_blocks.0.resnets.0.norm1.bias" in keys:
                    print('Using LORA for Resadapter')
                    key_map = {}
                    key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
                    mapping_norm = {}
                    for key in keys:
                        if ".weight" in key:
                            key_name_in_ori_sd = key_map[key.replace(".weight", "")]
                            mapping_norm[key_name_in_ori_sd] = _lora[key]
                        elif ".bias" in key:
                            key_name_in_ori_sd = key_map[key.replace(".bias", "")]
                            mapping_norm[key_name_in_ori_sd.replace(".weight", ".bias")] = _lora[
                                key
                            ]
                        else:
                            print("===>Unexpected key", key)
                            mapping_norm[key] = _lora[key]
                    for k in mapping_norm.keys():
                        if k not in model.model.state_dict():
                            print("===>Missing key:", k)
                    model.model.load_state_dict(mapping_norm, strict=False)
                    # Note: the ResAdapter path returns early and is never cached.
                    return (model, clip)
                # PixArt
                if type is not None and type == 'PixArt':
                    from ..modules.dit.pixArt.loader import load_pixart_lora
                    model = load_pixart_lora(model, _lora, lora_path, model_strength)
                else:
                    model, clip = comfy.sd.load_lora_for_models(model, clip, _lora, model_strength, clip_strength)
            if use_cache:
                self.add_to_cache("lora", unique_id, (model, clip))
                self.eviction_based_on_memory()
        else:
            log_node_error(f"LORA NOT FOUND", orig_lora_name)
        return model, clip
def resolve_lora_name(self, name):
if os.path.exists(name):
return name
else:
if len(self.lora_name_cache) == 0:
loras = folder_paths.get_filename_list("loras")
self.lora_name_cache.extend(loras)
for x in self.lora_name_cache:
if x.endswith(name):
return x
# 如果刷新网页后新添加的lora走这个逻辑
log_node_info("LORA NOT IN CACHE", f"{name}")
loras = folder_paths.get_filename_list("loras")
for x in loras:
if x.endswith(name):
self.lora_name_cache.append(x)
return x
return None
    def load_main(self, ckpt_name, config_name, vae_name, lora_name, lora_model_strength, lora_clip_strength, optional_lora_stack, model_override, clip_override, vae_override, prompt, nf4=False):
        """Resolve model/clip/vae for a loader node, honoring overrides, an
        optional lora stack, and any XY-plot nodes present in the prompt graph.

        Returns (model, clip, vae, clip_vision, lora_stack).
        Raises Exception when no CLIP could be resolved.
        """
        model: ModelPatcher | None = None
        clip: comfy.sd.CLIP | None = None
        vae: comfy.sd.VAE | None = None
        clip_vision = None
        lora_stack = []
        # Check for model override
        can_load_lora = True
        # If a model or lora XY-plot node exists, cache its first model up front
        # and skip normal lora loading (the plot applies its own variants).
        xy_model_id = next((x for x in prompt if str(prompt[x]["class_type"]) in ["easy XYInputs: ModelMergeBlocks",
                                                                                 "easy XYInputs: Checkpoint"]), None)
        # This will find nodes that aren't actively connected to anything, and skip loading lora's for them.
        xy_lora_id = next((x for x in prompt if str(prompt[x]["class_type"]) == "easy XYInputs: Lora"), None)
        if xy_lora_id is not None:
            can_load_lora = False
        if xy_model_id is not None:
            node = prompt[xy_model_id]
            if "ckpt_name_1" in node["inputs"]:
                ckpt_name_1 = node["inputs"]["ckpt_name_1"]
                model, clip, vae, clip_vision = self.load_checkpoint(ckpt_name_1)
                can_load_lora = False
        elif model_override is not None and clip_override is not None and vae_override is not None:
            # All three overrides supplied: skip checkpoint loading entirely.
            model = model_override
            clip = clip_override
            vae = vae_override
        else:
            model, clip, vae, clip_vision = self.load_checkpoint(ckpt_name, config_name)
        # Partial overrides still replace individual components.
        if model_override is not None:
            model = model_override
        if vae_override is not None:
            vae = vae_override
        # NOTE(review): clip_override is only applied when vae_override is
        # None (elif below) — looks unintended; confirm.
        elif clip_override is not None:
            clip = clip_override
        if optional_lora_stack is not None and can_load_lora:
            for lora in optional_lora_stack:
                # This is a subtle bit of code because it uses the model created by the last call, and passes it to the next call.
                lora = {"lora_name": lora[0], "model": model, "clip": clip, "model_strength": lora[1],
                        "clip_strength": lora[2]}
                model, clip = self.load_lora(lora)
                lora['model'] = model
                lora['clip'] = clip
                lora_stack.append(lora)
        if lora_name != "None" and can_load_lora:
            lora = {"lora_name": lora_name, "model": model, "clip": clip, "model_strength": lora_model_strength,
                    "clip_strength": lora_clip_strength}
            model, clip = self.load_lora(lora)
            lora_stack.append(lora)
        # Check for custom VAE
        if vae_name not in ["Baked VAE", "Baked-VAE"]:
            vae = self.load_vae(vae_name)
        # CLIP skip
        if not clip:
            raise Exception("No CLIP found")
        return model, clip, vae, clip_vision, lora_stack
# Kolors
def load_kolors_unet(self, unet_name):
if unet_name in self.loaded_objects["unet"]:
log_node_info("Load Kolors UNet", f"{unet_name} cached")
return self.loaded_objects["unet"][unet_name][0]
else:
from ..modules.kolors.loader import applyKolorsUnet
with applyKolorsUnet():
unet_path = folder_paths.get_full_path("unet", unet_name)
sd = comfy.utils.load_torch_file(unet_path)
model = comfy.sd.load_unet_state_dict(sd)
if model is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
self.add_to_cache("unet", unet_name, model)
self.eviction_based_on_memory()
return model
def load_chatglm3(self, chatglm3_name):
from ..modules.kolors.loader import load_chatglm3
if chatglm3_name in self.loaded_objects["chatglm3"]:
log_node_info("Load ChatGLM3", f"{chatglm3_name} cached")
return self.loaded_objects["chatglm3"][chatglm3_name][0]
chatglm_model = load_chatglm3(model_path=folder_paths.get_full_path("llm", chatglm3_name))
self.add_to_cache("chatglm3", chatglm3_name, chatglm_model)
self.eviction_based_on_memory()
return chatglm_model
# DiT
def load_dit_ckpt(self, ckpt_name, model_name, **kwargs):
if (ckpt_name+'_'+model_name) in self.loaded_objects["ckpt"]:
return self.loaded_objects["ckpt"][ckpt_name+'_'+model_name][0]
model = None
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
model_type = kwargs['model_type'] if "model_type" in kwargs else 'PixArt'
if model_type == 'PixArt':
pixart_conf = kwargs['pixart_conf']
model_conf = pixart_conf[model_name]
model = load_pixart(ckpt_path, model_conf)
if model:
self.add_to_cache("ckpt", ckpt_name + '_' + model_name, model)
self.eviction_based_on_memory()
return model
    def load_t5_from_sd3_clip(self, sd3_clip, padding):
        """Extract a standalone T5-only CLIP object from a loaded SD3 clip.

        Clones `sd3_clip`, deep-copies its t5xxl wrapper while sharing the
        heavy transformer weights, and installs a tokenizer whose T5
        `min_length` is set to `padding`.
        """
        try:
            from comfy.text_encoders.sd3_clip import SD3Tokenizer, SD3ClipModel
        except:
            # Older ComfyUI versions exposed these under comfy.sd3_clip.
            from comfy.sd3_clip import SD3Tokenizer, SD3ClipModel
        import copy
        clip = sd3_clip.clone()
        assert clip.cond_stage_model.t5xxl is not None, "CLIP must have T5 loaded!"
        # remove transformer so the deepcopy below doesn't duplicate weights
        transformer = clip.cond_stage_model.t5xxl.transformer
        clip.cond_stage_model.t5xxl.transformer = None
        # clone object
        tmp = SD3ClipModel(clip_l=False, clip_g=False, t5=False)
        tmp.t5xxl = copy.deepcopy(clip.cond_stage_model.t5xxl)
        # put transformer back (shared by both the clone and the original)
        clip.cond_stage_model.t5xxl.transformer = transformer
        tmp.t5xxl.transformer = transformer
        # override special tokens
        tmp.t5xxl.special_tokens = copy.deepcopy(clip.cond_stage_model.t5xxl.special_tokens)
        tmp.t5xxl.special_tokens.pop("end")  # make sure empty tokens match
        # tokenizer
        tok = SD3Tokenizer()
        tok.t5xxl.min_length = padding
        clip.cond_stage_model = tmp
        clip.tokenizer = tok
        return clip

View File

@@ -0,0 +1,77 @@
# ANSI escape sequences for colored console logging.
# Foreground colors.
COLORS_FG = {
    'BLACK': '\33[30m',
    'RED': '\33[31m',
    'GREEN': '\33[32m',
    'YELLOW': '\33[33m',
    'BLUE': '\33[34m',
    'MAGENTA': '\33[35m',
    'CYAN': '\33[36m',
    'WHITE': '\33[37m',
    'GREY': '\33[90m',
    'BRIGHT_RED': '\33[91m',
    'BRIGHT_GREEN': '\33[92m',
    'BRIGHT_YELLOW': '\33[93m',
    'BRIGHT_BLUE': '\33[94m',
    'BRIGHT_MAGENTA': '\33[95m',
    'BRIGHT_CYAN': '\33[96m',
    'BRIGHT_WHITE': '\33[97m',
}
# Text style attributes.
COLORS_STYLE = {
    'RESET': '\33[0m',
    'BOLD': '\33[1m',
    'NORMAL': '\33[22m',
    'ITALIC': '\33[3m',
    'UNDERLINE': '\33[4m',
    'BLINK': '\33[5m',
    'BLINK2': '\33[6m',
    'SELECTED': '\33[7m',
}
# Background colors.
COLORS_BG = {
    'BLACK': '\33[40m',
    'RED': '\33[41m',
    'GREEN': '\33[42m',
    'YELLOW': '\33[43m',
    'BLUE': '\33[44m',
    'MAGENTA': '\33[45m',
    'CYAN': '\33[46m',
    'WHITE': '\33[47m',
    'GREY': '\33[100m',
    'BRIGHT_RED': '\33[101m',
    'BRIGHT_GREEN': '\33[102m',
    'BRIGHT_YELLOW': '\33[103m',
    'BRIGHT_BLUE': '\33[104m',
    'BRIGHT_MAGENTA': '\33[105m',
    'BRIGHT_CYAN': '\33[106m',
    'BRIGHT_WHITE': '\33[107m',
}
def log_node_success(node_name, message=None):
    """Logs a success message (green)."""
    _log_node(COLORS_FG["GREEN"], node_name, message)
def log_node_info(node_name, message=None):
    """Logs an info message (cyan)."""
    _log_node(COLORS_FG["CYAN"], node_name, message)
def log_node_warn(node_name, message=None):
    """Logs a warning message (yellow)."""
    _log_node(COLORS_FG["YELLOW"], node_name, message)
def log_node_error(node_name, message=None):
    """Logs an error message (red)."""
    _log_node(COLORS_FG["RED"], node_name, message)
def log_node(node_name, message=None):
    """Logs a plain message (cyan)."""
    _log_node(COLORS_FG["CYAN"], node_name, message)
def _log_node(color, node_name, message=None, prefix=''):
    """Print a formatted log line for *node_name* in the given color."""
    print(_get_log_msg(color, node_name, message, prefix=prefix))
def _get_log_msg(color, node_name, message=None, prefix=''):
    """Build the colored "[EasyUse] <node>: <message>" log line."""
    display_name = node_name.replace(" (EasyUse)", "")
    parts = [COLORS_STYLE["BOLD"], color, prefix, '[EasyUse] ', display_name]
    if message is not None:
        parts.append(':' + COLORS_STYLE["RESET"] + ' ' + str(message))
    else:
        parts.append(COLORS_STYLE["RESET"])
    return ''.join(parts)

View File

@@ -0,0 +1,133 @@
"""
Math utility functions for formula evaluation
"""
import math
import re
def evaluate_formula(formula: str, a=0, b=0, c=0, d=0) -> float:
    """Evaluate a math formula given as a string.

    Supported operators and functions:
      - arithmetic: +, -, *, /, //, %, **
      - comparisons: >, <, >=, <=, ==, != (booleans coerce to 0.0 / 1.0)
      - math functions: abs, pow, round, ceil, floor, sqrt, exp, log, log10
      - trigonometry: sin, cos, tan, asin, acos, atan
      - constants: pi, e

    Args:
        formula: formula string; may reference the variables a, b, c, d.
        a, b, c, d: values substituted for the corresponding variables.

    Returns:
        The result as a float.

    Raises:
        ValueError: when the formula cannot be evaluated (original error
        chained as __cause__).

    Examples:
        >>> evaluate_formula("a + b", 1, 2)
        3.0
        >>> evaluate_formula("pow(a, 2)", 5)
        25.0
        >>> evaluate_formula("ceil(a / b)", 5, 2)
        3.0
        >>> evaluate_formula("(a>b)*b+(a<=b)*a", 5, 3)
        3.0
        >>> evaluate_formula("(a>b)*b+(a<=b)*a", 2, 3)
        2.0
    """
    # Whitelist of names visible to the evaluated expression.
    safe_dict = {
        # basic operations
        'abs': abs,
        'pow': pow,
        'round': round,
        # math functions
        'ceil': math.ceil,
        'floor': math.floor,
        'sqrt': math.sqrt,
        'exp': math.exp,
        'log': math.log,
        'log10': math.log10,
        # trigonometry
        'sin': math.sin,
        'cos': math.cos,
        'tan': math.tan,
        'asin': math.asin,
        'acos': math.acos,
        'atan': math.atan,
        # constants
        'pi': math.pi,
        'e': math.e,
        # variables
        'a': float(a),
        'b': float(b),
        'c': float(c),
        'd': float(d),
    }
    try:
        # SECURITY: eval() with empty __builtins__ plus a whitelist limits what
        # a formula can reach, but eval on untrusted input is still risky
        # (e.g. resource-exhaustion expressions). Formulas are expected to come
        # from the local UI, not from untrusted remote input.
        result = eval(formula, {"__builtins__": {}}, safe_dict)
        return float(result)
    except Exception as e:
        # Chain the original error so the real cause stays visible.
        raise ValueError(f"公式计算错误: {str(e)}") from e
def ceil_value(value: float) -> int:
    """Round *value* up to the nearest integer."""
    rounded_up = math.ceil(value)
    return rounded_up
def floor_value(value: float) -> int:
    """Round *value* down to the nearest integer."""
    rounded_down = math.floor(value)
    return rounded_down
def round_value(value: float, decimals: int = 0) -> float:
    """Round *value* to *decimals* decimal places.

    Uses Python's built-in round (banker's rounding at exact halves).

    Args:
        value: value to round.
        decimals: number of decimal places to keep.

    Returns:
        The rounded value.
    """
    result = round(value, decimals)
    return result
def power(base: float, exponent: float) -> float:
    """Return base raised to exponent (always as a float, via math.pow)."""
    result = math.pow(base, exponent)
    return result
def sqrt_value(value: float) -> float:
    """Return the square root of *value*; raises ValueError for negatives."""
    if value < 0:
        raise ValueError("不能对负数求平方根")
    return math.sqrt(value)
def add(a: float, b: float) -> float:
    """Return the sum of a and b."""
    total = a + b
    return total
def subtract(a: float, b: float) -> float:
    """Return a minus b."""
    difference = a - b
    return difference
def multiply(a: float, b: float) -> float:
    """Return the product of a and b."""
    product = a * b
    return product
def divide(a: float, b: float) -> float:
    """Return a / b; raises ValueError when b is zero."""
    if b == 0:
        raise ValueError("除数不能为零")
    return a / b

View File

@@ -0,0 +1,55 @@
from server import PromptServer
from aiohttp import web
import time
import json
class MessageCancelled(Exception):
    """Raised when a pending `Message.waitForMessage` wait is cancelled."""
    pass
class Message:
    """In-memory mailbox passing messages from web routes to waiting nodes.

    Class-level state (shared across the process):
      stash     -- scratch storage, cleared on '__start__'
      messages  -- pending messages keyed by stringified id
      cancelled -- set when a '__cancel__' control message arrives
    """
    stash = {}
    messages = {}
    cancelled = False
    @classmethod
    def addMessage(cls, id, message):
        """Store a payload for `id`; '__cancel__' / '__start__' are control
        messages that reset state instead of being stored."""
        if message == '__cancel__':
            cls.messages = {}
            cls.cancelled = True
        elif message == '__start__':
            cls.messages = {}
            cls.stash = {}
            cls.cancelled = False
        else:
            cls.messages[str(id)] = message
    @classmethod
    def waitForMessage(cls, id, period=0.1, asList=False):
        """Block until a message for `id` (or the broadcast id "-1") arrives.

        Polls every `period` seconds and raises MessageCancelled when a cancel
        arrives. The payload is split on commas when `asList` is true,
        otherwise JSON-decoded when possible (falling back to the raw string).
        """
        sid = str(id)
        while not (sid in cls.messages) and not ("-1" in cls.messages):
            if cls.cancelled:
                cls.cancelled = False
                raise MessageCancelled()
            time.sleep(period)
        if cls.cancelled:
            cls.cancelled = False
            raise MessageCancelled()
        message = cls.messages.pop(str(id), None) or cls.messages.pop("-1")
        try:
            if asList:
                return [str(x.strip()) for x in message.split(",")]
            else:
                try:
                    return json.loads(message)
                except ValueError:
                    return message
        except ValueError:
            # BUG FIX: the original used JS-style ${...} placeholders inside an
            # f-string, printing literal dollar signs instead of the values.
            print(f"ERROR IN MESSAGE - failed to parse '{message}' as {'comma separated list of strings' if asList else 'string'}")
            return [message] if asList else message
@PromptServer.instance.routes.post('/easyuse/message_callback')
async def message_callback(request):
    """Web endpoint the frontend posts to; forwards (id, message) to Message."""
    post = await request.post()
    Message.addMessage(post.get("id"), post.get("message"))
    return web.json_response({})

View File

@@ -0,0 +1,58 @@
import json
import os
import folder_paths
import server
from .utils import find_tags
class easyModelManager:
    """Collects model files (checkpoints / loras / unet) with thumbnails and
    directory tags for the model-manager UI."""

    def __init__(self):
        # Thumbnail image extensions searched next to a model file.
        # (de-duplicated: the original listed ".tiff" twice)
        self.img_suffixes = [".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".tiff", ".svg", ".tif"]
        self.default_suffixes = [".ckpt", ".pt", ".bin", ".pth", ".safetensors"]
        self.models_config = {
            "checkpoints": {"suffix": self.default_suffixes},
            "loras": {"suffix": self.default_suffixes},
            "unet": {"suffix": self.default_suffixes},
        }
        self.model_lists = {}

    def find_thumbnail(self, model_type, name):
        """Return the path of an image sharing the model's basename, or None."""
        file_no_ext = os.path.splitext(name)[0]
        for ext in self.img_suffixes:
            full_path = folder_paths.get_full_path(model_type, file_no_ext + ext)
            # get_full_path returns None for unknown files; guard explicitly
            # instead of the original str(None) round-trip.
            if full_path and os.path.isfile(full_path):
                return full_path
        return None

    def get_model_lists(self, model_type):
        """Return a list of metadata dicts for every model of `model_type`.

        Unknown model types yield an empty list; files whose extension is not
        in the configured suffix list are skipped.
        """
        if model_type not in self.models_config:
            return []
        allowed = self.models_config[model_type]["suffix"]
        model_lists = []
        for name in folder_paths.get_filename_list(model_type):
            model_suffix = os.path.splitext(name)[-1]
            if model_suffix not in allowed:
                continue
            model_lists.append({
                "name": os.path.basename(os.path.splitext(name)[0]),
                "full_name": name,
                "remark": '',
                "file_path": folder_paths.get_full_path(model_type, name),
                "type": model_type,
                "suffix": model_suffix,
                "dir_tags": find_tags(name),
                "cover": self.find_thumbnail(model_type, name),
                "metadata": None,
                "sha256": None
            })
        return model_lists

    def get_model_info(self, model_type, model_name):
        # Not implemented yet.
        pass
# if __name__ == "__main__":
# manager = easyModelManager()
# print(manager.get_model_lists("checkpoints"))

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,148 @@
import torch
import torch.nn as nn
from comfy.model_patcher import ModelPatcher
from typing import Union
T = torch.Tensor
def exists(val):
    """Return True when *val* is not None."""
    return val is not None
def default(val, d):
    """Return *val* unless it is None, in which case return the fallback *d*."""
    return d if val is None else val
class StyleAlignedArgs:
    """Configuration for StyleAligned shared attention.

    `share_attn` is a string such as "q+k" or "q+k+v"; the presence of each
    letter enables AdaIN on that projection.
    NOTE(review): for "disabled" every adain_* flag becomes False but
    share_attention stays True (class default below, never set in __init__)
    — confirm this is intended.
    """
    def __init__(self, share_attn: str) -> None:
        self.adain_keys = "k" in share_attn
        self.adain_values = "v" in share_attn
        self.adain_queries = "q" in share_attn
    # Class-level defaults; only the adain_* flags are overridden per instance.
    share_attention: bool = True
    adain_queries: bool = True
    adain_keys: bool = True
    adain_values: bool = True
def expand_first(
    feat: T,
    scale=1.0,
) -> T:
    """
    Expand the first element so it has the same shape as the rest of the batch.

    Takes the entries at index 0 and b // 2 as the style reference — assumes
    the batch stacks two halves (e.g. cond/uncond) with the reference first in
    each half; TODO confirm against the calling sampler. With `scale` != 1
    every expanded entry except the reference itself is scaled.
    """
    b = feat.shape[0]
    feat_style = torch.stack((feat[0], feat[b // 2])).unsqueeze(1)
    if scale == 1:
        feat_style = feat_style.expand(2, b // 2, *feat.shape[1:])
    else:
        feat_style = feat_style.repeat(1, b // 2, 1, 1, 1)
        feat_style = torch.cat([feat_style[:, :1], scale * feat_style[:, 1:]], dim=1)
    return feat_style.reshape(*feat.shape)
def concat_first(feat: T, dim=2, scale=1.0) -> T:
    """Concatenate *feat* with its expanded style reference along *dim*."""
    style_ref = expand_first(feat, scale=scale)
    return torch.cat((feat, style_ref), dim=dim)
def calc_mean_std(feat, eps: float = 1e-5) -> "tuple[T, T]":
    """Mean and (eps-stabilized) std of *feat* over its second-to-last dim."""
    variance = feat.var(dim=-2, keepdims=True)
    feat_std = (variance + eps).sqrt()
    feat_mean = feat.mean(dim=-2, keepdims=True)
    return feat_mean, feat_std
def adain(feat: T) -> T:
    """Apply AdaIN: renormalize *feat* to the style reference's statistics."""
    mean, std = calc_mean_std(feat)
    style_mean = expand_first(mean)
    style_std = expand_first(std)
    normalized = (feat - mean) / std
    return normalized * style_std + style_mean
class SharedAttentionProcessor:
    """attn1 patch applying AdaIN and key/value sharing for StyleAligned."""
    def __init__(self, args: StyleAlignedArgs, scale: float):
        self.args = args
        self.scale = scale
    def __call__(self, q, k, v, extra_options):
        opts = self.args
        if opts.adain_queries:
            q = adain(q)
        if opts.adain_keys:
            k = adain(k)
        if opts.adain_values:
            v = adain(v)
        if opts.share_attention:
            k = concat_first(k, -2, scale=self.scale)
            v = concat_first(v, -2)
        return q, k, v
def get_norm_layers(
    layer: nn.Module,
    norm_layers_: "dict[str, list[Union[nn.GroupNorm, nn.LayerNorm]]]",
    share_layer_norm: bool,
    share_group_norm: bool,
):
    """Recursively collect GroupNorm/LayerNorm modules into `norm_layers_`
    ("group" / "layer" buckets), honoring the share flags."""
    if isinstance(layer, nn.LayerNorm) and share_layer_norm:
        norm_layers_["layer"].append(layer)
    # NOTE: the else below pairs with this second `if` only — children are
    # recursed unless the layer itself is a collected GroupNorm.
    if isinstance(layer, nn.GroupNorm) and share_group_norm:
        norm_layers_["group"].append(layer)
    else:
        for child_layer in layer.children():
            get_norm_layers(
                child_layer, norm_layers_, share_layer_norm, share_group_norm
            )
def register_norm_forward(
    norm_layer: Union[nn.GroupNorm, nn.LayerNorm],
) -> Union[nn.GroupNorm, nn.LayerNorm]:
    """Wrap a norm layer's forward so the input is normalized together with the
    expanded style reference, then the extra entries are dropped.

    The pristine forward is stashed once on `orig_forward`, making repeated
    patching idempotent.
    """
    if not hasattr(norm_layer, "orig_forward"):
        setattr(norm_layer, "orig_forward", norm_layer.forward)
    orig_forward = norm_layer.orig_forward
    def forward_(hidden_states: T) -> T:
        n = hidden_states.shape[-2]
        hidden_states = concat_first(hidden_states, dim=-2)
        hidden_states = orig_forward(hidden_states)  # type: ignore
        return hidden_states[..., :n, :]
    norm_layer.forward = forward_  # type: ignore
    return norm_layer
def register_shared_norm(
    model: ModelPatcher,
    share_group_norm: bool = True,
    share_layer_norm: bool = True,
):
    """Patch every selected GroupNorm/LayerNorm in *model* for style sharing.

    Returns the list of patched norm layers.
    """
    collected = {"group": [], "layer": []}
    get_norm_layers(model.model, collected, share_layer_norm, share_group_norm)
    print(
        f"Patching {len(collected['group'])} group norms, {len(collected['layer'])} layer norms."
    )
    patched = [register_norm_forward(layer) for layer in collected["group"]]
    patched += [register_norm_forward(layer) for layer in collected["layer"]]
    return patched
# UI option lists for the StyleAligned nodes.
SHARE_NORM_OPTIONS = ["both", "group", "layer", "disabled"]
SHARE_ATTN_OPTIONS = ["q+k", "q+k+v", "disabled"]
def styleAlignBatch(model, share_norm, share_attn, scale=1.0):
    """Return a cloned model patched for StyleAligned batch generation.

    share_norm: one of SHARE_NORM_OPTIONS; share_attn: one of SHARE_ATTN_OPTIONS.
    NOTE(review): register_shared_norm is called on `model`, not the clone —
    the clone shares the same underlying nn module, so the norm patch applies
    to both; confirm this is intended.
    """
    m = model.clone()
    share_group_norm = share_norm in ["group", "both"]
    share_layer_norm = share_norm in ["layer", "both"]
    register_shared_norm(model, share_group_norm, share_layer_norm)
    args = StyleAlignedArgs(share_attn)
    m.set_model_attn1_patch(SharedAttentionProcessor(args, scale))
    return m

View File

@@ -0,0 +1,247 @@
#credit to shadowcz007 for this module
#from https://github.com/shadowcz007/comfyui-mixlab-nodes/blob/main/nodes/TextGenerateNode.py
import re
import os
import folder_paths
import comfy.utils
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from .utils import install_package
try:
from lark import Lark, Transformer, v_args
except:
print('install lark...')
install_package('lark')
from lark import Lark, Transformer, v_args
# Local directory for translation models (ComfyUI models/prompt_generator).
model_path = os.path.join(folder_paths.models_dir, 'prompt_generator')
zh_en_model_path = os.path.join(model_path, 'opus-mt-zh-en')
# Lazily-initialized zh->en translation model/tokenizer (see translate()).
zh_en_model, zh_en_tokenizer = None, None
def correct_prompt_syntax(prompt=""):
    """Normalize punctuation and balance brackets in a prompt string."""
    # print("input prompt",prompt)
    corrected_elements = []
    # Normalize punctuation to ASCII equivalents.
    # NOTE(review): the first arguments of several replace() calls appear to be
    # empty strings here — fullwidth CJK punctuation characters were likely
    # lost in transport; verify against the original file before relying on it.
    prompt = prompt.replace('', '(').replace('', ')').replace('', ',').replace(';', ',').replace('', '.').replace('',':').replace('\\',',')
    # Collapse redundant whitespace.
    prompt = re.sub(r'\s+', ' ', prompt).strip()
    prompt = prompt.replace("< ","<").replace(" >",">").replace("( ","(").replace(" )",")").replace("[ ","[").replace(' ]',']')
    # Split into comma-separated elements.
    prompt_elements = prompt.split(',')
    def balance_brackets(element, open_bracket, close_bracket):
        # Append closing brackets until the pair counts match.
        open_brackets_count = element.count(open_bracket)
        close_brackets_count = element.count(close_bracket)
        return element + close_bracket * (open_brackets_count - close_brackets_count)
    for element in prompt_elements:
        element = element.strip()
        # Skip empty elements.
        if not element:
            continue
        # Balance parentheses, square brackets, and angle brackets.
        if element[0] in '([':
            corrected_element = balance_brackets(element, '(', ')') if element[0] == '(' else balance_brackets(element, '[', ']')
        elif element[0] == '<':
            corrected_element = balance_brackets(element, '<', '>')
        else:
            # Strip leading unmatched closing brackets.
            corrected_element = element.lstrip(')]')
        corrected_elements.append(corrected_element)
    # Reassemble the corrected prompt.
    return ','.join(corrected_elements)
def detect_language(input_str):
    """Guess the dominant language of *input_str*.

    Counts CJK vs alphabetic characters; returns "cn", "en", or "unknow"
    (sic — callers compare against this exact string) on a tie.
    """
    count_cn = 0
    count_en = 0
    for ch in input_str:
        if '\u4e00' <= ch <= '\u9fff':
            count_cn += 1
        elif ch.isalpha():
            count_en += 1
    if count_cn > count_en:
        return "cn"
    if count_en > count_cn:
        return "en"
    return "unknow"
def has_chinese(text):
    """Return True when *text* contains CJK characters outside of lora
    (<...>), wildcard (__...__), and trailing embedding: syntax."""
    cleaned = re.sub(r'<.*?>', '', text)
    cleaned = re.sub(r'__.*?__', '', cleaned)
    cleaned = re.sub(r'embedding:.*?$', '', cleaned)
    return any('\u4e00' <= ch <= '\u9fff' for ch in cleaned)
def translate(text):
    """Translate Chinese *text* to English with the opus-mt-zh-en model.

    Loads the model lazily (local prompt_generator/opus-mt-zh-en folder when
    present, else the HuggingFace hub id) and caches it in module globals.
    """
    global zh_en_model_path, zh_en_model, zh_en_tokenizer
    if not os.path.exists(zh_en_model_path):
        # Fall back to downloading from the HuggingFace hub.
        zh_en_model_path = 'Helsinki-NLP/opus-mt-zh-en'
    if zh_en_model is None:
        zh_en_model = AutoModelForSeq2SeqLM.from_pretrained(zh_en_model_path).eval()
        zh_en_tokenizer = AutoTokenizer.from_pretrained(zh_en_model_path, padding=True, truncation=True)
        zh_en_model.to("cuda" if torch.cuda.is_available() else "cpu")
    with torch.no_grad():
        encoded = zh_en_tokenizer([text], return_tensors="pt")
        encoded.to(zh_en_model.device)
        sequences = zh_en_model.generate(**encoded)
        return zh_en_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
@v_args(inline=True)  # Decorator to flatten the tree directly into the function arguments
class ChinesePromptTranslate(Transformer):
    """Lark transformer rebuilding a prompt string with Chinese words
    translated to English while preserving prompt syntax (weights, loras,
    embeddings, schedules, wildcards)."""
    def sentence(self, *args):
        # Rejoin translated phrases with the standard ", " separator.
        return ", ".join(args)
    def phrase(self, *args):
        return "".join(args)
    def emphasis(self, *args):
        # Reconstruct the emphasis with translated content
        return "(" + "".join(args) + ")"
    def weak_emphasis(self, *args):
        print('weak_emphasis:', args)
        return "[" + "".join(args) + "]"
    def embedding(self, *args):
        # Rebuild "embedding:name[:num[:num]]" untouched by translation.
        print('prompt embedding', args[0])
        if len(args) == 1:
            embedding_name = str(args[0])
            return f"embedding:{embedding_name}"
        elif len(args) > 1:
            embedding_name, *numbers = args
            if len(numbers) == 2:
                return f"embedding:{embedding_name}:{numbers[0]}:{numbers[1]}"
            elif len(numbers) == 1:
                return f"embedding:{embedding_name}:{numbers[0]}"
            else:
                return f"embedding:{embedding_name}"
    def lora(self, *args):
        # Rebuild "<lora:name[:num[:num]]>" untouched by translation.
        if len(args) == 1:
            return f"<lora:{args[0]}>"
        elif len(args) > 1:
            # print('lora', args)
            _, loar_name, *numbers = args
            loar_name = str(loar_name).strip()
            if len(numbers) == 2:
                return f"<lora:{loar_name}:{numbers[0]}:{numbers[1]}>"
            elif len(numbers) == 1:
                return f"<lora:{loar_name}:{numbers[0]}>"
            else:
                return f"<lora:{loar_name}>"
    def weight(self, word, number):
        # Translate the word but keep its "(word:number)" weighting intact.
        translated_word = translate(str(word)).rstrip('.')
        return f"({translated_word}:{str(number).strip()})"
    def schedule(self, *args):
        print('prompt schedule', args)
        data = [str(arg).strip() for arg in args]
        return f"[{':'.join(data)}]"
    def word(self, word):
        # Translate each word using the dictionary
        word = str(word)
        match_cn = re.search(r'@.*?@', word)
        if re.search(r'__.*?__', word):
            # Wildcard token: leave untouched.
            return word.rstrip('.')
        elif match_cn:
            # Text between @...@ is kept untranslated (@ markers stripped);
            # only the surrounding parts are translated.
            chinese = match_cn.group()
            before = word.split('@', 1)
            before = before[0] if len(before) > 0 else ''
            before = translate(str(before)).rstrip('.') if before else ''
            after = word.rsplit('@', 1)
            after = after[len(after)-1] if len(after) > 1 else ''
            after = translate(after).rstrip('.') if after else ''
            return before + chinese.replace('@', '').rstrip('.') + after
        elif detect_language(word) == "cn":
            return translate(word).rstrip('.')
        else:
            return word.rstrip('.')
#定义Prompt文法
grammar = r"""
start: sentence
sentence: phrase ("," phrase)*
phrase: emphasis | weight | word | lora | embedding | schedule
emphasis: "(" sentence ")" -> emphasis
| "[" sentence "]" -> weak_emphasis
weight: "(" word ":" NUMBER ")"
schedule: "[" word ":" word ":" NUMBER "]"
lora: "<" WORD ":" WORD (":" NUMBER)? (":" NUMBER)? ">"
embedding: "embedding" ":" WORD (":" NUMBER)? (":" NUMBER)?
word: WORD
NUMBER: /\s*-?\d+(\.\d+)?\s*/
WORD: /[^,:\(\)\[\]<>]+/
"""
def zh_to_en(text):
    """Translate a list of prompt strings from Chinese to English.

    Each string is syntax-corrected, parsed with the prompt grammar, and
    rebuilt by ChinesePromptTranslate with Chinese words translated.
    Returns a list of translated prompts ([""] when nothing was produced).
    """
    global zh_en_model_path, zh_en_model, zh_en_tokenizer
    # progress bar
    pbar = comfy.utils.ProgressBar(len(text) + 1)
    texts = [correct_prompt_syntax(t) for t in text]
    # sentencepiece is required by the opus-mt tokenizer.
    install_package('sentencepiece', '0.2.0')
    if not os.path.exists(zh_en_model_path):
        zh_en_model_path = 'Helsinki-NLP/opus-mt-zh-en'
    if zh_en_model is None:
        zh_en_model = AutoModelForSeq2SeqLM.from_pretrained(zh_en_model_path).eval()
        zh_en_tokenizer = AutoTokenizer.from_pretrained(zh_en_model_path, padding=True, truncation=True)
        zh_en_model.to("cuda" if torch.cuda.is_available() else "cpu")
    prompt_result = []
    en_texts = []
    for t in texts:
        if t:
            # translated_text = translated_word = translate(zh_en_tokenizer,zh_en_model,str(t))
            parser = Lark(grammar, start="start", parser="lalr", transformer=ChinesePromptTranslate())
            # print('t',t)
            result = parser.parse(t).children
            # print('en_result',result)
            # en_text=translate(zh_en_tokenizer,zh_en_model,text_without_syntax)
            en_texts.append(result[0])
    # Move the model off the GPU once translation is done.
    zh_en_model.to('cpu')
    # print("test en_text", en_texts)
    # en_text.to("cuda" if torch.cuda.is_available() else "cpu")
    pbar.update(1)
    for t in en_texts:
        prompt_result.append(t)
        pbar.update(1)
    # print('prompt_result', prompt_result, )
    if len(prompt_result) == 0:
        prompt_result = [""]
    return prompt_result

View File

@@ -0,0 +1,282 @@
class AlwaysEqualProxy(str):
    """String subclass whose equality checks always succeed (and != always
    fails) — used as a wildcard "any type" marker."""
    def __eq__(self, other):
        return True
    def __ne__(self, other):
        return False
class TautologyStr(str):
    """String that never compares unequal (permissive type matching)."""
    def __ne__(self, other):
        return False


class ByPassTypeTuple(tuple):
    """Tuple that clamps positive indexes to 0 and wraps str items in
    TautologyStr, so any slot "matches" the first declared type."""
    def __getitem__(self, index):
        clamped = 0 if index > 0 else index
        item = super().__getitem__(clamped)
        if isinstance(item, str):
            return TautologyStr(item)
        return item
# Cached ComfyUI revision number; populated lazily by compare_revision().
comfy_ui_revision = None
def get_comfyui_revision():
    """Return the ComfyUI git commit count, or "Unknown" when unavailable.

    Requires gitpython and a git checkout of ComfyUI; any failure (missing
    package, not a git repo, etc.) yields "Unknown".
    """
    try:
        import git
        import os
        import folder_paths
        repo = git.Repo(os.path.dirname(folder_paths.__file__))
        comfy_ui_revision = len(list(repo.iter_commits('HEAD')))
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit propagate.
        comfy_ui_revision = "Unknown"
    return comfy_ui_revision
import sys
import importlib.util
import importlib.metadata
import comfy.model_management as mm
import gc
from packaging import version
from server import PromptServer
def is_package_installed(package):
    """Return True when *package* can be located as an importable module."""
    try:
        return importlib.util.find_spec(package) is not None
    except ImportError as e:
        print(e)
        return False
def install_package(package, v=None, compare=True, compare_version=None):
    """Ensure *package* is installed, pip-installing it when needed.

    Args:
        package: importable module name.
        v: version to pin (installs ``package==v``) when installing.
        compare: when True, skip installing if the installed version already
            satisfies ``compare_version`` (defaults to ``v``).
        compare_version: minimum acceptable installed version.

    Returns:
        True when pip ran and succeeded; False otherwise (already satisfied
        or pip failed). Progress is mirrored to the frontend via toasts.
    """
    run_install = True
    if is_package_installed(package):
        try:
            installed_version = importlib.metadata.version(package)
            if v is not None:
                if compare_version is None:
                    compare_version = v
                if not compare or version.parse(installed_version) >= version.parse(compare_version):
                    run_install = False
            else:
                run_install = False
        except Exception:
            # Narrowed from a bare `except:`; e.g. modules without installed
            # metadata raise PackageNotFoundError — treat as satisfied.
            run_install = False
    if run_install:
        import subprocess
        package_command = package + '==' + v if v is not None else package
        PromptServer.instance.send_sync("easyuse-toast", {'content': f"Installing {package_command}...", 'duration': 5000})
        result = subprocess.run([sys.executable, '-s', '-m', 'pip', 'install', package_command], capture_output=True, text=True)
        if result.returncode == 0:
            PromptServer.instance.send_sync("easyuse-toast", {'content': f"{package} installed successfully", 'type': 'success', 'duration': 5000})
            print(f"Package {package} installed successfully")
            return True
        else:
            PromptServer.instance.send_sync("easyuse-toast", {'content': f"{package} installed failed", 'type': 'error', 'duration': 5000})
            print(f"Package {package} installed failed")
            return False
    else:
        return False
def compare_revision(num):
    """Return True when the detected ComfyUI revision is unknown or >= *num*."""
    global comfy_ui_revision
    if not comfy_ui_revision:
        comfy_ui_revision = get_comfyui_revision()
    return comfy_ui_revision == 'Unknown' or int(comfy_ui_revision) >= num
def find_tags(string: str, sep="/") -> list[str]:
    """Split a model's relative path into its directory tags.

    Backslashes are normalized to forward slashes and duplicate separators
    collapsed; the final path component (the filename) is dropped.
    """
    if not string:
        return []
    normalized = string.replace("\\", "/")
    while "//" in normalized:
        normalized = normalized.replace("//", "/")
    if normalized and sep in normalized:
        return normalized.split(sep)[:-1]
    return []
from comfy.model_base import BaseModel
import comfy.supported_models
import comfy.supported_models_base
def get_sd_version(model):
    """Map a loaded ModelPatcher to a short architecture tag.

    Returns one of: 'sdxl', 'sdxl_refiner', 'sd1' (covers SD 1.x and 2.x),
    'svd', 'sd3', 'hydit', 'flux', 'mochi', or 'unknown'.
    """
    base: BaseModel = model.model
    model_config: comfy.supported_models.supported_models_base.BASE = base.model_config
    if isinstance(model_config, comfy.supported_models.SDXL):
        return 'sdxl'
    elif isinstance(model_config, comfy.supported_models.SDXLRefiner):
        return 'sdxl_refiner'
    elif isinstance(
        model_config, (comfy.supported_models.SD15, comfy.supported_models.SD20)
    ):
        return 'sd1'
    elif isinstance(
        model_config, (comfy.supported_models.SVD_img2vid)
    ):
        return 'svd'
    elif isinstance(model_config, comfy.supported_models.SD3):
        return 'sd3'
    elif isinstance(model_config, comfy.supported_models.HunyuanDiT):
        return 'hydit'
    elif isinstance(model_config, comfy.supported_models.Flux):
        return 'flux'
    elif isinstance(model_config, comfy.supported_models.GenmoMochi):
        return 'mochi'
    else:
        return 'unknown'
def find_nearest_steps(clip_id, prompt):
    """Return the `steps` value of the first sampler/sampling node whose
    `pipe` input links to *clip_id*; defaults to 1 when none matches."""
    target = str(clip_id)

    def links_to_clip(node):
        # A node "links" when its pipe input references the target id.
        if "pipe" not in node["inputs"]:
            return False
        return any(link != 0 and link == target for link in node["inputs"]["pipe"])

    for node_id, node in prompt.items():
        class_type = node["class_type"]
        if "Sampler" in class_type or "sampler" in class_type or "Sampling" in class_type:
            # Check if this KSampler node directly references the given node id.
            if links_to_clip(node):
                return node["inputs"].get("steps", 1)
    return 1
def find_wildcards_seed(clip_id, text, prompt):
    """ Find easy wildcards seed value

    When *text* contains a wildcard token ("__"), look for a node whose
    class_type contains "wildcards" and follow `positive` links from
    `clip_id` to it, returning its "seed" (or legacy "seed_num") input.
    Returns None when *text* has no wildcard or no seed can be resolved.
    NOTE(review): the loop returns for the FIRST wildcards node found in the
    graph, regardless of any later ones — confirm intended.
    """
    def find_link_clip_id(id, seed, wildcard_id):
        # Recursively follow the `positive` input links looking for wildcard_id.
        node = prompt[id]
        if "positive" in node['inputs']:
            link_ids = node["inputs"]["positive"]
            if type(link_ids) == list:
                for id in link_ids:
                    if id != 0:
                        if id == wildcard_id:
                            wildcard_node = prompt[wildcard_id]
                            seed = wildcard_node["inputs"]["seed"] if "seed" in wildcard_node["inputs"] else None
                            if seed is None:
                                # Older wildcard nodes named the input "seed_num".
                                seed = wildcard_node["inputs"]["seed_num"] if "seed_num" in wildcard_node["inputs"] else None
                            return seed
                        else:
                            return find_link_clip_id(id, seed, wildcard_id)
            else:
                return None
        else:
            return None
    if "__" in text:
        seed = None
        for id in prompt:
            node = prompt[id]
            if "wildcards" in node["class_type"]:
                wildcard_id = id
                return find_link_clip_id(str(clip_id), seed, wildcard_id)
        return seed
    else:
        return None
def is_linked_styles_selector(prompt, unique_id, prompt_type='positive'):
    """Return True when the node's `prompt_type` input is fed by an
    'easy stylesSelector' node."""
    # Subgraph ids look like "parent.child"; keep only the last segment.
    if "." in unique_id:
        unique_id = unique_id.split('.')[-1]
    node_inputs = prompt[unique_id]['inputs']
    link = node_inputs[prompt_type] if prompt_type in node_inputs else None
    if not (type(link) == list and link != 'undefined' and link[0]):
        return False
    upstream = prompt[link[0]]
    return bool(upstream and upstream['class_type'] == 'easy stylesSelector')
# Sticky flag: once huggingface.co fails we keep using the hf-mirror.com mirror.
use_mirror = False
def get_local_filepath(url, dirname, local_file_name=None):
    """Get local file path when is already downloaded or download it

    Downloads `url` into `dirname` (created when missing); the filename
    defaults to the URL's basename. Falls back to the hf-mirror.com mirror
    when huggingface.co is unreachable and remembers that via `use_mirror`.
    Raises Exception when both the original URL and the mirror fail.
    """
    import os
    from server import PromptServer
    from urllib.parse import urlparse
    from torch.hub import download_url_to_file
    global use_mirror
    if not os.path.exists(dirname):
        os.makedirs(dirname)
    if not local_file_name:
        parsed_url = urlparse(url)
        local_file_name = os.path.basename(parsed_url.path)
    destination = os.path.join(dirname, local_file_name)
    if not os.path.exists(destination):
        try:
            if use_mirror:
                url = url.replace('huggingface.co', 'hf-mirror.com')
            print(f'downloading {url} to {destination}')
            PromptServer.instance.send_sync("easyuse-toast", {'content': f'Downloading model to {destination}, please wait...', 'duration': 10000})
            download_url_to_file(url, destination)
        except Exception as e:
            # First attempt failed: switch to the mirror and retry once.
            use_mirror = True
            url = url.replace('huggingface.co', 'hf-mirror.com')
            print(f'Unable to download from huggingface, trying mirror: {url}')
            PromptServer.instance.send_sync("easyuse-toast", {'content': f'Unable to connect to huggingface, trying mirror: {url}', 'duration': 10000})
            try:
                download_url_to_file(url, destination)
            except Exception as err:
                error_msg = str(err.args[0]) if err.args else str(err)
                PromptServer.instance.send_sync("easyuse-toast",
                                                {'content': f'Unable to download model from {url}', 'type':'error'})
                raise Exception(f'Download failed. Original URL and mirror both failed.\nError: {error_msg}')
    return destination
def to_lora_patch_dict(state_dict: dict) -> dict:
    """Convert a raw lora state_dict (keys "model_key::patch_type::index")
    into the {model_key: (patch_type, weights[16])} form a modelpatcher
    can apply."""
    grouped = {}
    for full_key, weight in state_dict.items():
        model_key, patch_type, weight_index = full_key.split('::')
        slots = grouped.setdefault(model_key, {}).setdefault(patch_type, [None] * 16)
        slots[int(weight_index)] = weight
    patch_flat = {}
    for model_key, by_type in grouped.items():
        for patch_type, weight_list in by_type.items():
            patch_flat[model_key] = (patch_type, weight_list)
    return patch_flat
def easySave(images, filename_prefix, output_type, prompt=None, extra_pnginfo=None):
    """Save or Preview Image

    Dispatches on `output_type`: "Hide"/"None" returns an empty list,
    "Preview"/"Preview&Choose" writes temporary previews (prefix forced to
    'easyPreview'), anything else saves normally. Returns the UI image
    descriptors produced by the underlying node.
    """
    from nodes import PreviewImage, SaveImage
    if output_type in ["Hide", "None"]:
        return list()
    elif output_type in ["Preview", "Preview&Choose"]:
        filename_prefix = 'easyPreview'
        results = PreviewImage().save_images(images, filename_prefix, prompt, extra_pnginfo)
        return results['ui']['images']
    else:
        results = SaveImage().save_images(images, filename_prefix, prompt, extra_pnginfo)
        return results['ui']['images']
def getMetadata(filepath):
    """Return the raw JSON header bytes of a .safetensors file.

    Format: https://github.com/huggingface/safetensors#format
    (8 bytes: N, an unsigned little-endian 64-bit header size, then the
    N-byte JSON header).

    Raises:
        BufferError: when the header size is invalid or the file is truncated.
    """
    with open(filepath, "rb") as file:
        header_size = int.from_bytes(file.read(8), "little", signed=False)
        if header_size <= 0:
            raise BufferError("Invalid header size")
        header = file.read(header_size)
        # BUG FIX: the original re-checked header_size here, so a truncated
        # file silently returned a short header instead of failing.
        if len(header) < header_size:
            raise BufferError("Invalid header")
        return header
def cleanGPUUsedForce():
    """Aggressively free memory: GC, unload all cached models, empty the cache."""
    gc.collect()
    mm.unload_all_models()
    mm.soft_empty_cache()

View File

@@ -0,0 +1,476 @@
import json
import os
import random
import re
from math import prod
import yaml
import folder_paths
from .log import log_node_info
# Global registry: wildcard keyword -> list of replacement options.
easy_wildcard_dict = {}
def get_wildcard_list():
    """Return every registered wildcard keyword wrapped as ``__keyword__``."""
    return [f"__{keyword}__" for keyword in easy_wildcard_dict]
def wildcard_normalize(x):
    """Canonicalize a wildcard key: backslashes become forward slashes, lower case."""
    normalized = x.replace("\\", "/")
    return normalized.lower()
def read_wildcard(k, v):
    """Register wildcard entries from a (possibly nested) mapping.

    A list is stored under the normalized key; a dict recurses with
    ``parent/child`` keys. Other value types are ignored.
    """
    if isinstance(v, dict):
        for child_key, child_value in v.items():
            read_wildcard(wildcard_normalize(f"{k}/{child_key}"), child_value)
    elif isinstance(v, list):
        easy_wildcard_dict[wildcard_normalize(k)] = v
def read_wildcard_dict(wildcard_path):
    """Populate ``easy_wildcard_dict`` from every wildcard file under *wildcard_path*.

    Supported formats: ``.txt`` (one option per line), ``.yaml`` (nested
    mappings registered via ``read_wildcard``) and ``.json`` (flat mapping of
    keyword -> options). Symlinks are followed. Returns the shared dict.
    """
    global easy_wildcard_dict
    for root, directories, files in os.walk(wildcard_path, followlinks=True):
        for file in files:
            file_path = os.path.join(root, file)
            if file.endswith('.txt'):
                rel_path = os.path.relpath(file_path, wildcard_path)
                key = os.path.splitext(rel_path)[0].replace('\\', '/').lower()
                try:
                    with open(file_path, 'r', encoding="UTF-8", errors="ignore") as f:
                        lines = f.read().splitlines()
                        easy_wildcard_dict[key] = lines
                except UnicodeDecodeError:
                    # Legacy wildcard files that are not valid UTF-8.
                    with open(file_path, 'r', encoding="ISO-8859-1") as f:
                        lines = f.read().splitlines()
                        easy_wildcard_dict[key] = lines
            elif file.endswith('.yaml'):
                # Bug fix: read YAML as UTF-8 instead of the platform default
                # encoding, which broke non-ASCII wildcards (e.g. on Windows).
                with open(file_path, 'r', encoding="UTF-8") as f:
                    yaml_data = yaml.load(f, Loader=yaml.FullLoader)
                    for k, v in yaml_data.items():
                        read_wildcard(k, v)
            elif file.endswith('.json'):
                try:
                    # Bug fix: JSON is UTF-8 by specification; don't rely on locale.
                    with open(file_path, 'r', encoding="UTF-8") as f:
                        json_data = json.load(f)
                        for key, value in json_data.items():
                            key = wildcard_normalize(key)
                            easy_wildcard_dict[key] = value
                except ValueError:
                    print('json files load error')
    return easy_wildcard_dict
def process(text, seed=None):
    """Expand ``{a|b|...}`` option groups and ``__keyword__`` wildcards in *text*.

    Option groups support weighted options (``2::a|b``) and multi-select
    prefixes (``N$$``, ``min-max$$``, ``N$$sep$$``). Expansion repeats until
    nothing changes, bounded to avoid infinite recursion through
    self-referencing wildcards. *seed* makes the random choices reproducible.
    """
    if seed is not None:
        random.seed(seed)
    def replace_options(string):
        replacements_found = False
        def replace_option(match):
            nonlocal replacements_found
            options = match.group(1).split('|')
            multi_select_pattern = options[0].split('$$')
            select_range = None
            select_sep = ' '
            range_pattern = r'(\d+)(-(\d+))?'
            range_pattern2 = r'-(\d+)'
            if len(multi_select_pattern) > 1:
                # Parse the "N$$" / "min-max$$" / "-max$$" multi-select prefix.
                a = None
                b = None
                r = re.match(range_pattern, options[0])
                if r is not None:
                    a = r.group(1).strip()
                    # Bug fix: group(3) is None for the plain "N$$" form; the
                    # original called .strip() on it and raised AttributeError.
                    b = r.group(3).strip() if r.group(3) is not None else None
                else:
                    # Bug fix: guard the "-max$$" fallback too — the original
                    # dereferenced a failed match and crashed on e.g. "$$x".
                    r = re.match(range_pattern2, options[0])
                    if r is not None:
                        a = '1'
                        b = r.group(1).strip()
                if r is not None:
                    if b is not None and is_numeric_string(a) and is_numeric_string(b):
                        # PATTERN: num1-num2
                        select_range = int(a), int(b)
                    elif is_numeric_string(a):
                        # PATTERN: num
                        x = int(a)
                        select_range = (x, x)
                if select_range is not None and len(multi_select_pattern) == 2:
                    # PATTERN: count$$
                    options[0] = multi_select_pattern[1]
                elif select_range is not None and len(multi_select_pattern) == 3:
                    # PATTERN: count$$ sep $$
                    select_sep = multi_select_pattern[1]
                    options[0] = multi_select_pattern[2]
            # Collect per-option weights from an optional "weight::" prefix.
            adjusted_probabilities = []
            total_prob = 0
            for option in options:
                parts = option.split('::', 1)
                if len(parts) == 2 and is_numeric_string(parts[0].strip()):
                    config_value = float(parts[0].strip())
                else:
                    config_value = 1  # Default weight when none is given
                adjusted_probabilities.append(config_value)
                total_prob += config_value
            normalized_probabilities = [prob / total_prob for prob in adjusted_probabilities]
            if select_range is None:
                select_count = 1
            else:
                select_count = random.randint(select_range[0], select_range[1])
            if select_count > len(options):
                selected_items = options
            else:
                # random.choices samples with replacement; retry a bounded
                # number of times to reach the requested distinct count.
                selected_items = random.choices(options, weights=normalized_probabilities, k=select_count)
                selected_items = set(selected_items)
                try_count = 0
                while len(selected_items) < select_count and try_count < 10:
                    remaining_count = select_count - len(selected_items)
                    additional_items = random.choices(options, weights=normalized_probabilities, k=remaining_count)
                    selected_items |= set(additional_items)
                    try_count += 1
            # Strip the "weight::" prefix from the chosen options.
            selected_items2 = [re.sub(r'^\s*[0-9.]+::', '', x, count=1) for x in selected_items]
            replacement = select_sep.join(selected_items2)
            replacements_found = True
            return replacement
        pattern = r'{([^{}]*?)}'
        replaced_string = re.sub(pattern, replace_option, string)
        return replaced_string, replacements_found
    def replace_wildcard(string):
        global easy_wildcard_dict
        pattern = r"__([\w\s.\-+/*\\]+?)__"
        matches = re.findall(pattern, string)
        replacements_found = False
        for match in matches:
            keyword = match.lower()
            keyword = wildcard_normalize(keyword)
            if keyword in easy_wildcard_dict:
                replacement = random.choice(easy_wildcard_dict[keyword])
                replacements_found = True
                string = string.replace(f"__{match}__", replacement, 1)
            elif '*' in keyword:
                # Glob over registered keys; pool every matching option list.
                subpattern = keyword.replace('*', '.*').replace('+', r'\+')
                total_patterns = []
                found = False
                for k, v in easy_wildcard_dict.items():
                    if re.match(subpattern, k) is not None:
                        total_patterns += v
                        found = True
                if found:
                    replacement = random.choice(total_patterns)
                    replacements_found = True
                    string = string.replace(f"__{match}__", replacement, 1)
            elif '/' not in keyword:
                # Bare keyword: retry as a "*/keyword" glob.
                string_fallback = string.replace(f"__{match}__", f"__*/{match}__", 1)
                string, replacements_found = replace_wildcard(string_fallback)
        return string, replacements_found
    replace_depth = 100
    stop_unwrap = False
    while not stop_unwrap and replace_depth > 1:
        replace_depth -= 1  # prevent infinite loop
        # pass1: replace options
        pass1, is_replaced1 = replace_options(text)
        while is_replaced1:
            pass1, is_replaced1 = replace_options(pass1)
        # pass2: replace wildcards
        text, is_replaced2 = replace_wildcard(pass1)
        stop_unwrap = not is_replaced1 and not is_replaced2
    return text
def is_numeric_string(input_str):
    """Return True when *input_str* is an optionally signed integer or decimal."""
    match = re.match(r'^-?\d+(\.\d+)?$', input_str)
    return match is not None
def safe_float(x):
    """Parse *x* as a float, falling back to 1.0 for non-numeric input."""
    return float(x) if is_numeric_string(x) else 1.0
def extract_lora_values(string):
    """Parse every ``<lora:...>`` tag in *string*.

    Returns a list of tuples ``(lora_name, model_weight, clip_weight, lbw,
    lbw_a, lbw_b)``. Weights default to 1.0; only the first occurrence of a
    given lora name is kept. ``LBW=`` segments carry a block-weight preset
    with optional ``A=`` / ``B=`` numeric parameters separated by ``;``.
    """
    pattern = r'<lora:([^>]+)>'
    matches = re.findall(pattern, string)
    def touch_lbw(text):
        # Drop a leading preset name in "LBW=name:..." so only the preset body
        # follows "LBW=".
        return re.sub(r'LBW=[A-Za-z][A-Za-z0-9_-]*:', r'LBW=', text)
    items = [touch_lbw(match.strip(':')) for match in matches]
    added = set()
    result = []
    for item in items:
        item = item.split(':')
        lora = None
        a = None  # model strength (first bare number)
        b = None  # clip strength (second bare number)
        lbw = None
        lbw_a = None
        lbw_b = None
        if len(item) > 0:
            lora = item[0]
            for sub_item in item[1:]:
                if is_numeric_string(sub_item):
                    if a is None:
                        a = float(sub_item)
                    elif b is None:
                        b = float(sub_item)
                elif sub_item.startswith("LBW="):
                    for lbw_item in sub_item[4:].split(';'):
                        if lbw_item.startswith("A="):
                            lbw_a = safe_float(lbw_item[2:].strip())
                        elif lbw_item.startswith("B="):
                            lbw_b = safe_float(lbw_item[2:].strip())
                        elif lbw_item.strip() != '':
                            lbw = lbw_item
        if a is None:
            a = 1.0
        if b is None:
            b = 1.0
        if lora is not None and lora not in added:
            result.append((lora, a, b, lbw, lbw_a, lbw_b))
            added.add(lora)
    return result
def remove_lora_tags(string):
    """Strip every ``<lora:...>`` tag from *string*."""
    return re.sub(r'<lora:[^>]+>', '', string)
def process_with_loras(wildcard_opt, model, clip, title="Positive", seed=None, can_load_lora=True, pipe_lora_stack=None, easyCache=None):
    """Expand wildcards in *wildcard_opt*, then load any ``<lora:...>`` tags it contains.

    Returns ``(model, clip, prompt_without_lora_tags, expanded_prompt,
    show_wildcard_prompt, pipe_lora_stack)``.
    """
    # Bug fix: the mutable default ([]) was shared across calls, so loras
    # loaded for one prompt leaked into every later pipe_lora_stack.
    if pipe_lora_stack is None:
        pipe_lora_stack = []
    pass1 = process(wildcard_opt, seed)
    loras = extract_lora_values(pass1)
    pass2 = remove_lora_tags(pass1)
    has_noodle_key = "__" in wildcard_opt
    has_loras = loras != []
    show_wildcard_prompt = has_noodle_key or has_loras
    if can_load_lora and has_loras:
        for lora_name, model_weight, clip_weight, lbw, lbw_a, lbw_b in loras:
            # Bug fix: supported_pt_extensions holds dotted suffixes
            # ('.safetensors', ...); comparing the bare suffix always failed
            # and re-appended ".safetensors" to names that already had one.
            if os.path.splitext(lora_name)[1] not in folder_paths.supported_pt_extensions:
                lora_name = lora_name + ".safetensors"
            lora = {
                "lora_name": lora_name, "model": model, "clip": clip, "model_strength": model_weight,
                "clip_strength": clip_weight,
                "lbw_a": lbw_a,
                "lbw_b": lbw_b,
                "lbw": lbw
            }
            model, clip = easyCache.load_lora(lora)
            # Keep references to the newly patched model/clip on the stack entry.
            lora["model"] = model
            lora["clip"] = clip
            pipe_lora_stack.append(lora)
    log_node_info("easy wildcards", f"{title}: {pass2}")
    if pass1 != pass2:
        log_node_info("easy wildcards", f'{title}_decode: {pass1}')
    return model, clip, pass2, pass1, show_wildcard_prompt, pipe_lora_stack
def expand_wildcard(keyword: str) -> "tuple[str, ...] | None":
    """Resolve a wildcard keyword to the tuple of all its options.

    Looks up *keyword* in ``easy_wildcard_dict``. ``*`` acts as a glob over
    registered keys (pooling every match), and a bare keyword without ``/``
    falls back to ``*/keyword``. Returns None when nothing matches
    (implicit fall-through at the end of the function).
    """
    global easy_wildcard_dict
    if keyword in easy_wildcard_dict:
        return tuple(easy_wildcard_dict[keyword])
    elif '*' in keyword:
        subpattern = keyword.replace('*', '.*').replace('+', r"\+")
        total_pattern = []
        for k, v in easy_wildcard_dict.items():
            if re.match(subpattern, k) is not None:
                total_pattern.extend(v)
        if total_pattern:
            return tuple(total_pattern)
    elif '/' not in keyword:
        return expand_wildcard(f"*/{keyword}")
def expand_options(options: str) -> "tuple[str, ...]":
    """Split an option group body (braces already removed) on ``|``.

    Every piece is returned verbatim — whitespace and special characters are
    preserved, and an empty piece stands for an empty replacement.
    """
    parts = options.split("|")
    return tuple(parts)
def decimal_to_irregular(n, bases):
    """Convert a decimal number to a mixed-radix ("irregular base") representation.

    :param n: decimal number
    :param bases: radix of each position, least significant first
    :return: list of digits, least significant first
    """
    if n == 0:
        return [0] * len(bases) if bases else [0]
    digits = []
    quotient = n
    for base in bases:
        # Peel off the current position's digit, then shift down.
        quotient, digit = divmod(quotient, base)
        digits.append(digit)
    return digits
class WildcardProcessor:
    """Wildcard processor.

    Wildcard syntax:
    + option  : ``{a|b}``
    + wildcard: ``__keyword__`` — options are looked up in the
      ``easy_wildcard_dict`` provided by the Easy-Use plugin.
    """
    RE_OPTIONS = re.compile(r"{([^{}]*?)}")
    RE_WILDCARD = re.compile(r"__([\w\s.\-+/*\\]+?)__")
    RE_REPLACER = re.compile(r"{([^{}]*?)}|__([\w\s.\-+/*\\]+?)__")
    # The prompt converted into a str.format-compatible template, with
    # {r0}, {r1}, ... placeholders left where each option/wildcard appeared.
    template: str
    # Replacement tuples in order of first appearance; identical tuples are
    # stored only once.
    replacers: dict[int, tuple[str]]
    # Maps each placeholder (in template order) to the index of its
    # replacement tuple, so duplicate tuples can be shared.
    placeholder_mapping: dict[str, int]  # placeholder_id => replacer_id
    # Number of options per placeholder, precomputed for later use.
    placeholder_choices: dict[str, int]  # placeholder_id => len(replacer)
    def __init__(self, text: str):
        self.__make_template(text)
        self.__total = None
    def random(self, seed=None) -> str:
        "Pick one of all the possibilities at random."
        if seed is not None:
            random.seed(seed)
        return self.getn(random.randint(0, self.total() - 1))
    def getn(self, n: int) -> str:
        "Return the n-th possibility, cycling with period self.total()."
        n = n % self.total()
        indice = decimal_to_irregular(n, self.placeholder_choices.values())
        replacements = {
            placeholder_id: self.replacers[self.placeholder_mapping[placeholder_id]][i]
            for placeholder_id, i in zip(self.placeholder_mapping.keys(), indice)
        }
        return self.template.format(**replacements)
    def getmany(self, limit: int, offset: int = 0) -> list[str]:
        """Return a list of possibilities. *limit* caps the list length to keep
        memory bounded; *offset* shifts the starting index. Indices running past
        the total wrap around to the beginning.
        """
        return [self.getn(n) for n in range(offset, offset + limit)]
    def total(self) -> int:
        "Count the number of possibilities (cached after the first call)."
        if self.__total is None:
            self.__total = prod(self.placeholder_choices.values())
        return self.__total
    def __make_template(self, text: str):
        """Convert the prompt into a str.format-compatible template, leaving
        {r0}, {r1}, ... placeholders where options and wildcards appeared.
        Identical options/wildcards still receive distinct placeholder ids so
        every occurrence varies independently.
        """
        self.placeholder_mapping = {}
        placeholder_id = 0
        replacer_id = 0
        replacers_rev = {}  # replacers => id
        blocks = []
        # Index just past the previously handled wildcard, used to stitch the
        # full template back together.
        tail = 0
        for match in self.RE_REPLACER.finditer(text):
            # Extract and expand the option/wildcard content.
            m = match.group(0)
            if m.startswith("{"):
                choices = expand_options(m[1:-1])
            elif m.startswith("__"):
                keyword = m[2:-2].lower()
                keyword = wildcard_normalize(keyword)
                choices = expand_wildcard(keyword)
            else:
                raise ValueError(f"{m!r} is not a wildcard or option")
            # Record the replacement tuple and its id; identical tuples keep
            # only the first registration.
            if choices not in replacers_rev:
                replacers_rev[choices] = replacer_id
                replacer_id += 1
            # Append the text preceding this wildcard.
            start, end = match.span()
            blocks.append(text[tail:start])
            tail = end
            # Replace the wildcard with a placeholder and record the mapping
            # from placeholder to replacement-tuple index.
            blocks.append(f"{{r{placeholder_id}}}")
            self.placeholder_mapping[f"r{placeholder_id}"] = replacers_rev[choices]
            placeholder_id += 1
        if tail < len(text):
            blocks.append(text[tail:])
        self.template = "".join(blocks)
        self.replacers = {v: k for k, v in replacers_rev.items()}
        self.placeholder_choices = {
            placeholder_id: len(self.replacers[replacer_id])
            for placeholder_id, replacer_id in self.placeholder_mapping.items()
        }
def test_option():
    """An option group enumerates each alternative exactly once, in order."""
    p = WildcardProcessor("{|a|b|c}")
    expected = ["", "a", "b", "c"]
    assert p.total() == len(expected)
    assert p.getn(0) == expected[0]
    assert p.getmany(4) == expected
    assert p.getmany(4, 1) == expected[1:]
def test_same():
    """Two identical option groups still vary independently of each other."""
    p = WildcardProcessor("{a|b},{a|b}")
    expected = ["a,a", "b,a", "a,b", "b,b"]
    assert p.total() == len(expected)
    assert p.getn(0) == expected[0]
    assert p.getmany(4) == expected
    assert p.getmany(4, 1) == expected[1:]

View File

@@ -0,0 +1,697 @@
import os, torch
from pathlib import Path
from PIL import Image, ImageDraw, ImageFont
from .utils import easySave, get_sd_version
from .adv_encode import advanced_encode
from .controlnet import easyControlnet
from .log import log_node_warn
from ..modules.layer_diffuse import LayerDiffuse
from ..config import RESOURCES_DIR
from nodes import CLIPTextEncode
import pprint
# FluxGuidance only exists on newer ComfyUI builds; degrade gracefully when
# the extra node module is missing or fails to import.
try:
    from comfy_extras.nodes_flux import FluxGuidance
except Exception:
    # Bug fix: the original bare `except:` also swallowed SystemExit and
    # KeyboardInterrupt raised during import.
    FluxGuidance = None
class easyXYPlot():
    def __init__(self, xyPlotData, save_prefix, image_output, prompt, extra_pnginfo, my_unique_id, sampler, easyCache):
        """Capture axis/grid settings plus the sampler and cache handles used
        while generating each cell of the XY plot."""
        self.x_node_type, self.x_type = sampler.safe_split(xyPlotData.get("x_axis"), ': ')
        self.y_node_type, self.y_type = sampler.safe_split(xyPlotData.get("y_axis"), ': ')
        # Axis values only apply when that axis has a plot type selected.
        self.x_values = xyPlotData.get("x_vals") if self.x_type != "None" else []
        self.y_values = xyPlotData.get("y_vals") if self.y_type != "None" else []
        self.custom_font = xyPlotData.get("custom_font")
        self.grid_spacing = xyPlotData.get("grid_spacing")
        self.latent_id = 0
        self.output_individuals = xyPlotData.get("output_individuals")
        self.x_label, self.y_label = [], []
        # Track the largest decoded image so every grid cell gets a uniform size.
        self.max_width, self.max_height = 0, 0
        self.latents_plot = []
        self.image_list = []
        self.num_cols = len(self.x_values) if len(self.x_values) > 0 else 1
        self.num_rows = len(self.y_values) if len(self.y_values) > 0 else 1
        self.total = self.num_cols * self.num_rows
        self.num = 0
        self.save_prefix = save_prefix
        self.image_output = image_output
        self.prompt = prompt
        self.extra_pnginfo = extra_pnginfo
        self.my_unique_id = my_unique_id
        self.sampler = sampler
        self.easyCache = easyCache
# Helper Functions
@staticmethod
def define_variable(plot_image_vars, value_type, value, index):
plot_image_vars[value_type] = value
if value_type in ["seed", "Seeds++ Batch"]:
value_label = f"seed: {value}"
else:
value_label = f"{value_type}: {value}"
if "ControlNet" in value_type:
value_label = f"ControlNet {index + 1}"
if value_type in ['Lora', 'Checkpoint']:
arr = value.split(',')
model_name = os.path.basename(os.path.splitext(arr[0])[0])
trigger_words = ' ' + arr[3] if value_type == 'Lora' and len(arr[3]) > 2 else ''
lora_weight = float(arr[1]) if value_type == 'Lora' and len(arr) > 1 else 0
lora_weight_desc = f"({lora_weight:.2f})" if lora_weight > 0 else ''
value_label = f"{model_name[:30]}{lora_weight_desc} {trigger_words}"
if value_type in ["ModelMergeBlocks"]:
if ":" in value:
line = value.split(':')
value_label = f"{line[0]}"
elif len(value) > 16:
value_label = f"ModelMergeBlocks {index + 1}"
else:
value_label = f"MMB: {value}"
if value_type in ["Pos Condition"]:
value_label = f"pos cond {index + 1}" if index>0 else f"pos cond"
if value_type in ["Neg Condition"]:
value_label = f"neg cond {index + 1}" if index>0 else f"neg cond"
if value_type in ["Positive Prompt S/R"]:
value_label = f"pos prompt {index + 1}" if index>0 else f"pos prompt"
if value_type in ["Negative Prompt S/R"]:
value_label = f"neg prompt {index + 1}" if index>0 else f"neg prompt"
if value_type in ["steps", "cfg", "denoise", "clip_skip",
"lora_model_strength", "lora_clip_strength"]:
value_label = f"{value_type}: {value}"
if value_type == "positive":
value_label = f"pos prompt {index + 1}"
elif value_type == "negative":
value_label = f"neg prompt {index + 1}"
return plot_image_vars, value_label
@staticmethod
def get_font(font_size, font_path=None):
if font_path is None:
font_path = str(Path(os.path.join(RESOURCES_DIR, 'OpenSans-Medium.ttf')))
return ImageFont.truetype(font_path, font_size)
@staticmethod
def update_label(label, value, num_items):
if len(label) < num_items:
return [*label, value]
return label
@staticmethod
def rearrange_tensors(latent, num_cols, num_rows):
new_latent = []
for i in range(num_rows):
for j in range(num_cols):
index = j * num_rows + i
new_latent.append(latent[index])
return new_latent
def calculate_background_dimensions(self):
border_size = int((self.max_width // 8) * 1.5) if self.y_type != "None" or self.x_type != "None" else 0
bg_width = self.num_cols * (self.max_width + self.grid_spacing) - self.grid_spacing + border_size * (
self.y_type != "None")
bg_height = self.num_rows * (self.max_height + self.grid_spacing) - self.grid_spacing + border_size * (
self.x_type != "None")
# Add space at the bottom of the image for common informaiton about the image
bg_height = bg_height + (border_size*2)
# print(f"Grid Size: width = {bg_width} height = {bg_height} border_size = {border_size}")
x_offset_initial = border_size if self.y_type != "None" else 0
y_offset = border_size if self.x_type != "None" else 0
return bg_width, bg_height, x_offset_initial, y_offset
def adjust_font_size(self, text, initial_font_size, label_width):
font = self.get_font(initial_font_size, self.custom_font)
text_width = font.getbbox(text)
# pprint.pp(f"Initial font size: {initial_font_size}, text: {text}, text_width: {text_width}")
if text_width and text_width[2]:
text_width = text_width[2]
scaling_factor = 0.9
if text_width > (label_width * scaling_factor):
# print(f"Adjusting font size from {initial_font_size} to fit text width {text_width} into label width {label_width} scaling_factor {scaling_factor}")
return int(initial_font_size * (label_width / text_width) * scaling_factor)
else:
return initial_font_size
def textsize(self, d, text, font):
_, _, width, height = d.textbbox((0, 0), text=text, font=font)
return width, height
    def create_label(self, img, text, initial_font_size, is_x_label=True, max_font_size=70, min_font_size=10, label_width=0, label_height=0):
        """Render *text* onto a transparent RGBA label strip for one grid axis.

        The font size is auto-fitted between *min_font_size* and
        *max_font_size*; overlong text is truncated with an ellipsis, and
        multi-line text is centered line by line.
        """
        # If label_width is specified, leave it alone. Otherwise fall back to
        # the image's width (x labels) or height (y labels).
        if label_width == 0:
            label_width = img.width if is_x_label else img.height
        text_lines = text.split('\n')
        longest_line = max(text_lines, key=len)
        # Adjust font size to fit the longest line, clamped to the allowed range.
        font_size = self.adjust_font_size(longest_line, initial_font_size, label_width)
        font_size = min(max_font_size, font_size)  # Ensure font isn't too large
        font_size = max(min_font_size, font_size)  # Ensure font isn't too small
        if label_height == 0:
            label_height = int(font_size * 1.5) if is_x_label else font_size
        label_bg = Image.new('RGBA', (label_width, label_height), color=(255, 255, 255, 0))
        d = ImageDraw.Draw(label_bg)
        font = self.get_font(font_size, self.custom_font)
        # Check if the text will fit; if not, trim characters and add an ellipsis.
        if self.textsize(d, text, font=font)[0] > label_width:
            while self.textsize(d, text + '...', font=font)[0] > label_width and len(text) > 0:
                text = text[:-1]
            text = text + '...'
        # Compute text width and height for multi-line text.
        text_widths, text_heights = zip(*[self.textsize(d, line, font=font) for line in text_lines])
        max_text_width = max(text_widths)
        total_text_height = sum(text_heights)
        # Compute the centered position for each line of text.
        lines_positions = []
        current_y = 0
        for line, line_width, line_height in zip(text_lines, text_widths, text_heights):
            text_x = (label_width - line_width) // 2
            text_y = current_y + (label_height - total_text_height) // 2
            current_y += line_height
            lines_positions.append((line, (text_x, text_y)))
        # Draw each line of text.
        for line, (text_x, text_y) in lines_positions:
            d.text((text_x, text_y), line, fill='black', font=font)
        return label_bg
    def sample_plot_image(self, plot_image_vars, samples, preview_latent, latents_plot, image_list, disable_noise,
                          start_step, last_step, force_full_denoise, x_value=None, y_value=None):
        """Generate one grid cell: apply the current x/y axis values on top of
        the baseline settings in *plot_image_vars*, sample, decode, and record
        the resulting image/latent plus the running max cell dimensions.

        Returns ``(image_list, max_width, max_height, latents_plot)``.
        """
        model, clip, vae, positive, negative, seed, steps, cfg = None, None, None, None, None, None, None, None
        sampler_name, scheduler, denoise = None, None, None
        a1111_prompt_style = plot_image_vars['a1111_prompt_style'] if "a1111_prompt_style" in plot_image_vars else False
        clip = clip if clip is not None else plot_image_vars["clip"]
        steps = plot_image_vars['steps'] if "steps" in plot_image_vars else 1
        sd_version = get_sd_version(plot_image_vars['model'])
        # Advanced usage (axis values coming from the advanced XY input node)
        if plot_image_vars["x_node_type"] == "advanced" or plot_image_vars["y_node_type"] == "advanced":
            if self.x_type == "Seeds++ Batch" or self.y_type == "Seeds++ Batch":
                seed = int(x_value) if self.x_type == "Seeds++ Batch" else int(y_value)
            if self.x_type == "Steps" or self.y_type == "Steps":
                steps = int(x_value) if self.x_type == "Steps" else int(y_value)
            if self.x_type == "StartStep" or self.y_type == "StartStep":
                start_step = int(x_value) if self.x_type == "StartStep" else int(y_value)
            if self.x_type == "EndStep" or self.y_type == "EndStep":
                last_step = int(x_value) if self.x_type == "EndStep" else int(y_value)
            if self.x_type == "CFG Scale" or self.y_type == "CFG Scale":
                cfg = float(x_value) if self.x_type == "CFG Scale" else float(y_value)
            if self.x_type == "Sampler" or self.y_type == "Sampler":
                sampler_name = x_value if self.x_type == "Sampler" else y_value
            if self.x_type == "Scheduler" or self.y_type == "Scheduler":
                scheduler = x_value if self.x_type == "Scheduler" else y_value
            if self.x_type == "Sampler&Scheduler" or self.y_type == "Sampler&Scheduler":
                arr = x_value.split(',') if self.x_type == "Sampler&Scheduler" else y_value.split(',')
                if arr[0] and arr[0]!= 'None':
                    sampler_name = arr[0]
                if arr[1] and arr[1]!= 'None':
                    scheduler = arr[1]
            if self.x_type == "Denoise" or self.y_type == "Denoise":
                denoise = float(x_value) if self.x_type == "Denoise" else float(y_value)
            if self.x_type == "Pos Condition" or self.y_type == "Pos Condition":
                positive = plot_image_vars['positive_cond_stack'][int(x_value)] if self.x_type == "Pos Condition" else plot_image_vars['positive_cond_stack'][int(y_value)]
            if self.x_type == "Neg Condition" or self.y_type == "Neg Condition":
                negative = plot_image_vars['negative_cond_stack'][int(x_value)] if self.x_type == "Neg Condition" else plot_image_vars['negative_cond_stack'][int(y_value)]
            # Model merging
            if self.x_type == "ModelMergeBlocks" or self.y_type == "ModelMergeBlocks":
                ckpt_name_1, ckpt_name_2 = plot_image_vars['models']
                model1, clip1, vae1, clip_vision = self.easyCache.load_checkpoint(ckpt_name_1)
                model2, clip2, vae2, clip_vision = self.easyCache.load_checkpoint(ckpt_name_2)
                xy_values = x_value if self.x_type == "ModelMergeBlocks" else y_value
                # A "name:weights" value carries a display name before the weights.
                if ":" in xy_values:
                    xy_line = xy_values.split(':')
                    xy_values = xy_line[1]
                xy_arrs = xy_values.split(',')
                # ModelMergeBlocks: either 3 coarse ratios (input/middle/out)
                # or 30 per-block ratios.
                if len(xy_arrs) == 3:
                    input, middle, out = xy_arrs
                    kwargs = {
                        "input": input,
                        "middle": middle,
                        "out": out
                    }
                elif len(xy_arrs) == 30:
                    kwargs = {}
                    kwargs["time_embed."] = xy_arrs[0]
                    kwargs["label_emb."] = xy_arrs[1]
                    for i in range(12):
                        kwargs["input_blocks.{}.".format(i)] = xy_arrs[2+i]
                    for i in range(3):
                        kwargs["middle_block.{}.".format(i)] = xy_arrs[14+i]
                    for i in range(12):
                        kwargs["output_blocks.{}.".format(i)] = xy_arrs[17+i]
                    kwargs["out."] = xy_arrs[29]
                else:
                    raise Exception("ModelMergeBlocks weight length error")
                default_ratio = next(iter(kwargs.values()))
                m = model1.clone()
                kp = model2.get_key_patches("diffusion_model.")
                # For each key pick the ratio of the longest matching prefix.
                for k in kp:
                    ratio = float(default_ratio)
                    k_unet = k[len("diffusion_model."):]
                    last_arg_size = 0
                    for arg in kwargs:
                        if k_unet.startswith(arg) and last_arg_size < len(arg):
                            ratio = float(kwargs[arg])
                            last_arg_size = len(arg)
                    m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
                vae_use = plot_image_vars['vae_use']
                clip = clip2 if vae_use == 'Use Model 2' else clip1
                if vae_use == 'Use Model 2':
                    vae = vae2
                elif vae_use == 'Use Model 1':
                    vae = vae1
                else:
                    vae = self.easyCache.load_vae(vae_use)
                model = m
                # If a lora stack exists, apply the loras on top
                optional_lora_stack = plot_image_vars['lora_stack']
                if optional_lora_stack is not None and optional_lora_stack != []:
                    for lora in optional_lora_stack:
                        model, clip = self.easyCache.load_lora(lora)
                # Handle clip
                clip = clip.clone()
                if plot_image_vars['clip_skip'] != 0:
                    clip.clip_layer(plot_image_vars['clip_skip'])
            # CheckPoint
            if self.x_type == "Checkpoint" or self.y_type == "Checkpoint":
                xy_values = x_value if self.x_type == "Checkpoint" else y_value
                ckpt_name, clip_skip, vae_name = xy_values.split(",")
                # '*' stands in for commas escaped inside names.
                ckpt_name = ckpt_name.replace('*', ',')
                vae_name = vae_name.replace('*', ',')
                model, clip, vae, clip_vision = self.easyCache.load_checkpoint(ckpt_name)
                if vae_name != 'None':
                    vae = self.easyCache.load_vae(vae_name)
                # If a lora stack exists, apply the loras on top
                optional_lora_stack = plot_image_vars['lora_stack']
                if optional_lora_stack is not None and optional_lora_stack != []:
                    for lora in optional_lora_stack:
                        lora['model'] = model
                        lora['clip'] = clip
                        model, clip = self.easyCache.load_lora(lora)
                # Handle clip
                clip = clip.clone()
                if clip_skip != 'None':
                    clip.clip_layer(int(clip_skip))
                positive = plot_image_vars['positive']
                negative = plot_image_vars['negative']
                a1111_prompt_style = plot_image_vars['a1111_prompt_style']
                steps = plot_image_vars['steps']
                clip = clip if clip is not None else plot_image_vars["clip"]
                positive = advanced_encode(clip, positive,
                                           plot_image_vars['positive_token_normalization'],
                                           plot_image_vars['positive_weight_interpretation'],
                                           w_max=1.0,
                                           apply_to_pooled="enable",
                                           a1111_prompt_style=a1111_prompt_style, steps=steps)
                negative = advanced_encode(clip, negative,
                                           plot_image_vars['negative_token_normalization'],
                                           plot_image_vars['negative_weight_interpretation'],
                                           w_max=1.0,
                                           apply_to_pooled="enable",
                                           a1111_prompt_style=a1111_prompt_style, steps=steps)
                if "positive_cond" in plot_image_vars:
                    positive = positive + plot_image_vars["positive_cond"]
                if "negative_cond" in plot_image_vars:
                    negative = negative + plot_image_vars["negative_cond"]
            # Lora
            if self.x_type == "Lora" or self.y_type == "Lora":
                model = model if model is not None else plot_image_vars["model"]
                clip = clip if clip is not None else plot_image_vars["clip"]
                xy_values = x_value if self.x_type == "Lora" else y_value
                lora_name, lora_model_strength, lora_clip_strength, _ = xy_values.split(",")
                lora_stack = [{"lora_name": lora_name, "model": model, "clip" :clip, "model_strength": float(lora_model_strength), "clip_strength": float(lora_clip_strength)}]
                if 'lora_stack' in plot_image_vars:
                    lora_stack = lora_stack + plot_image_vars['lora_stack']
                if lora_stack is not None and lora_stack != []:
                    for lora in lora_stack:
                        # Each generation of the model must use the reference to
                        # the previously created model / clip objects.
                        lora['model'] = model
                        lora['clip'] = clip
                        model, clip = self.easyCache.load_lora(lora)
            # Prompts
            if "Positive" in self.x_type or "Positive" in self.y_type:
                if self.x_type == 'Positive Prompt S/R' or self.y_type == 'Positive Prompt S/R':
                    positive = x_value if self.x_type == "Positive Prompt S/R" else y_value
                if sd_version == 'flux':
                    positive, = CLIPTextEncode().encode(clip, positive)
                else:
                    positive = advanced_encode(clip, positive,
                                               plot_image_vars['positive_token_normalization'],
                                               plot_image_vars['positive_weight_interpretation'],
                                               w_max=1.0,
                                               apply_to_pooled="enable", a1111_prompt_style=a1111_prompt_style, steps=steps)
            if "Negative" in self.x_type or "Negative" in self.y_type:
                if self.x_type == 'Negative Prompt S/R' or self.y_type == 'Negative Prompt S/R':
                    negative = x_value if self.x_type == "Negative Prompt S/R" else y_value
                if sd_version == 'flux':
                    negative, = CLIPTextEncode().encode(clip, negative)
                else:
                    negative = advanced_encode(clip, negative,
                                               plot_image_vars['negative_token_normalization'],
                                               plot_image_vars['negative_weight_interpretation'],
                                               w_max=1.0,
                                               apply_to_pooled="enable", a1111_prompt_style=a1111_prompt_style, steps=steps)
            # ControlNet
            if "ControlNet" in self.x_type or "ControlNet" in self.y_type:
                cnet = plot_image_vars["cnet"] if "cnet" in plot_image_vars else None
                # NOTE(review): these test for "positive"/"negative" but read
                # "positive_cond"/"negative_cond" — looks inconsistent; confirm
                # which key the upstream node actually sets.
                positive = plot_image_vars["positive_cond"] if "positive" in plot_image_vars else None
                negative = plot_image_vars["negative_cond"] if "negative" in plot_image_vars else None
                if cnet:
                    index = x_value if "ControlNet" in self.x_type else y_value
                    controlnet = cnet[index]
                    for index, item in enumerate(controlnet):
                        control_net_name = item[0]
                        image = item[1]
                        strength = item[2]
                        start_percent = item[3]
                        end_percent = item[4]
                        provided_control_net = item[5] if len(item) > 5 else None
                        positive, negative = easyControlnet().apply(control_net_name, image, positive, negative, strength, start_percent, end_percent, provided_control_net, 1)
            # Flux guidance
            if self.x_type == "Flux Guidance" or self.y_type == "Flux Guidance":
                # NOTE(review): same "positive" vs "positive_cond" key mismatch
                # as in the ControlNet branch above.
                positive = plot_image_vars["positive_cond"] if "positive" in plot_image_vars else None
                flux_guidance = float(x_value) if self.x_type == "Flux Guidance" else float(y_value)
                positive, = FluxGuidance().append(positive, flux_guidance)
        # Simple usage (axis values coming from the loader node)
        if plot_image_vars["x_node_type"] == "loader" or plot_image_vars["y_node_type"] == "loader":
            if self.x_type == 'ckpt_name' or self.y_type == 'ckpt_name':
                ckpt_name = x_value if self.x_type == "ckpt_name" else y_value
                model, clip, vae, clip_vision = self.easyCache.load_checkpoint(ckpt_name)
            if self.x_type == 'lora_name' or self.y_type == 'lora_name':
                model, clip, vae, clip_vision = self.easyCache.load_checkpoint(plot_image_vars['ckpt_name'])
                lora_name = x_value if self.x_type == "lora_name" else y_value
                lora = {"lora_name": lora_name, "model": model, "clip": clip, "model_strength": 1, "clip_strength": 1}
                model, clip = self.easyCache.load_lora(lora)
            if self.x_type == 'lora_model_strength' or self.y_type == 'lora_model_strength':
                model, clip, vae, clip_vision = self.easyCache.load_checkpoint(plot_image_vars['ckpt_name'])
                lora_model_strength = float(x_value) if self.x_type == "lora_model_strength" else float(y_value)
                lora = {"lora_name": plot_image_vars['lora_name'], "model": model, "clip": clip, "model_strength": lora_model_strength, "clip_strength": plot_image_vars['lora_clip_strength']}
                model, clip = self.easyCache.load_lora(lora)
            if self.x_type == 'lora_clip_strength' or self.y_type == 'lora_clip_strength':
                model, clip, vae, clip_vision = self.easyCache.load_checkpoint(plot_image_vars['ckpt_name'])
                lora_clip_strength = float(x_value) if self.x_type == "lora_clip_strength" else float(y_value)
                lora = {"lora_name": plot_image_vars['lora_name'], "model": model, "clip": clip, "model_strength": plot_image_vars['lora_model_strength'], "clip_strength": lora_clip_strength}
                model, clip = self.easyCache.load_lora(lora)
            # Check for custom VAE
            if self.x_type == 'vae_name' or self.y_type == 'vae_name':
                vae_name = x_value if self.x_type == "vae_name" else y_value
                vae = self.easyCache.load_vae(vae_name)
            # CLIP skip
            if not clip:
                raise Exception("No CLIP found")
            clip = clip.clone()
            clip.clip_layer(plot_image_vars['clip_skip'])
            if sd_version == 'flux':
                positive, = CLIPTextEncode().encode(clip, positive)
            else:
                positive = advanced_encode(clip, plot_image_vars['positive'],
                                           plot_image_vars['positive_token_normalization'],
                                           plot_image_vars['positive_weight_interpretation'], w_max=1.0,
                                           apply_to_pooled="enable",a1111_prompt_style=a1111_prompt_style, steps=steps)
            if sd_version == 'flux':
                negative, = CLIPTextEncode().encode(clip, negative)
            else:
                negative = advanced_encode(clip, plot_image_vars['negative'],
                                           plot_image_vars['negative_token_normalization'],
                                           plot_image_vars['negative_weight_interpretation'], w_max=1.0,
                                           apply_to_pooled="enable", a1111_prompt_style=a1111_prompt_style, steps=steps)
        # Fall back to the baseline settings for everything an axis did not override.
        model = model if model is not None else plot_image_vars["model"]
        vae = vae if vae is not None else plot_image_vars["vae"]
        positive = positive if positive is not None else plot_image_vars["positive_cond"]
        negative = negative if negative is not None else plot_image_vars["negative_cond"]
        seed = seed if seed is not None else plot_image_vars["seed"]
        steps = steps if steps is not None else plot_image_vars["steps"]
        cfg = cfg if cfg is not None else plot_image_vars["cfg"]
        sampler_name = sampler_name if sampler_name is not None else plot_image_vars["sampler_name"]
        scheduler = scheduler if scheduler is not None else plot_image_vars["scheduler"]
        denoise = denoise if denoise is not None else plot_image_vars["denoise"]
        noise_device = plot_image_vars["noise_device"] if "noise_device" in plot_image_vars else 'cpu'
        # LayerDiffuse
        layer_diffusion_method = plot_image_vars["layer_diffusion_method"] if "layer_diffusion_method" in plot_image_vars else None
        empty_samples = plot_image_vars["empty_samples"] if "empty_samples" in plot_image_vars else None
        if layer_diffusion_method:
            samp_blend_samples = plot_image_vars["blend_samples"] if "blend_samples" in plot_image_vars else None
            additional_cond = plot_image_vars["layer_diffusion_cond"] if "layer_diffusion_cond" in plot_image_vars else None
            images = plot_image_vars["images"].movedim(-1, 1) if "images" in plot_image_vars else None
            weight = plot_image_vars['layer_diffusion_weight'] if 'layer_diffusion_weight' in plot_image_vars else 1.0
            model, positive, negative = LayerDiffuse().apply_layer_diffusion(model, layer_diffusion_method, weight, samples,
                                                                            samp_blend_samples, positive,
                                                                            negative, images, additional_cond)
        samples = empty_samples if layer_diffusion_method is not None and empty_samples is not None else samples
        # Sample
        samples = self.sampler.common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, samples,
                                               denoise=denoise, disable_noise=disable_noise, preview_latent=preview_latent,
                                               start_step=start_step, last_step=last_step,
                                               force_full_denoise=force_full_denoise, noise_device=noise_device)
        # Decode images and store
        latent = samples["samples"]
        # Add the latent tensor to the tensors list
        latents_plot.append(latent)
        # Decode the image
        image = vae.decode(latent).cpu()
        if self.output_individuals in [True, "True"]:
            easySave(image, self.save_prefix, self.image_output)
        # Convert the image from tensor to PIL Image and add it to the list
        pil_image = self.sampler.tensor2pil(image)
        image_list.append(pil_image)
        # Update max dimensions
        self.max_width = max(self.max_width, pil_image.width)
        self.max_height = max(self.max_height, pil_image.height)
        # Return the touched variables
        return image_list, self.max_width, self.max_height, latents_plot
# Process Functions
def validate_xy_plot(self):
    """Return True when at least one plot axis is configured, else warn and return False."""
    if self.x_type != 'None' or self.y_type != 'None':
        return True
    # Neither axis selected: plotting is pointless, caller falls back to plain sampling.
    log_node_warn(f'#{self.my_unique_id}', f'No Valid Plot Types - Reverting to default sampling...')
    return False
def get_latent(self, samples):
    """Select one latent out of a batched latent dict, clamping the index into range.

    Splits the batch into per-image latents, then returns the one at
    ``self.latent_id``; an out-of-range id is warned about and clamped to
    the last image.
    """
    batch_tensor = samples["samples"]
    # One {'samples': tensor} dict per batch element (each tensor keeps a batch dim of 1).
    singles = [{'samples': single} for single in torch.split(batch_tensor, 1, dim=0)]
    last_index = len(singles) - 1
    if self.latent_id > last_index:
        log_node_warn(f'#{self.my_unique_id}',f'The selected latent_id ({self.latent_id}) is out of range.')
        log_node_warn(f'#{self.my_unique_id}', f'Automatically setting the latent_id to the last image in the list (index: {last_index}).')
        self.latent_id = last_index
    return singles[self.latent_id]
def get_labels_and_sample(self, plot_image_vars, latent_image, preview_latent, start_step, last_step,
                          force_full_denoise, disable_noise):
    """Sample one image per X (and optionally Y) value combination.

    Mutates self: accumulates axis labels (x_label/y_label), rendered images
    (image_list), sampled latents (latents_plot) and the running counter (num),
    then reorders the latents to match the preview grid and returns them as a
    single batched tensor.
    """
    for x_index, x_value in enumerate(self.x_values):
        # Apply this X value to the sampling variables and record its label.
        plot_image_vars, x_value_label = self.define_variable(plot_image_vars, self.x_type, x_value,
                                                              x_index)
        self.x_label = self.update_label(self.x_label, x_value_label, len(self.x_values))
        if self.y_type == 'None':
            # Single-axis plot: one sample per X value.
            self.image_list, self.max_width, self.max_height, self.latents_plot = self.sample_plot_image(
                plot_image_vars, latent_image, preview_latent, self.latents_plot, self.image_list,
                disable_noise, start_step, last_step, force_full_denoise, x_value)
            self.num += 1
        else:
            for y_index, y_value in enumerate(self.y_values):
                # Layer the Y value on top of the current X value before sampling.
                plot_image_vars, y_value_label = self.define_variable(plot_image_vars, self.y_type, y_value,
                                                                      y_index)
                self.y_label = self.update_label(self.y_label, y_value_label, len(self.y_values))
                self.image_list, self.max_width, self.max_height, self.latents_plot = self.sample_plot_image(
                    plot_image_vars, latent_image, preview_latent, self.latents_plot, self.image_list,
                    disable_noise, start_step, last_step, force_full_denoise, x_value, y_value)
                self.num += 1
    # Reorder the flat latent list so its order matches the preview image grid,
    # then concatenate into one batch along dim 0.
    self.latents_plot = self.rearrange_tensors(self.latents_plot, self.num_cols, self.num_rows)
    self.latents_plot = torch.cat(self.latents_plot, dim=0)
    return self.latents_plot
def plot_images_and_labels(self, plot_image_vars):
    """Compose the sampled images into a labeled XY grid image.

    Pastes every image in self.image_list onto a white background, draws the
    X-axis labels above the top row and rotated Y-axis labels left of the
    first column, then appends a footer listing the settings shared by every
    cell (checkpoint, sampler, LORAs, ...).

    Returns:
        tuple: (grid_tensor, output_image) — the full annotated grid as a
        tensor plus the list of individual image tensors in paste order.
    """
    bg_width, bg_height, x_offset_initial, y_offset = self.calculate_background_dimensions()
    background = Image.new('RGBA', (int(bg_width), int(bg_height)), color=(255, 255, 255, 255))
    output_image = []
    for row_index in range(self.num_rows):
        x_offset = x_offset_initial
        for col_index in range(self.num_cols):
            # image_list is column-major: advance num_rows entries per column.
            index = col_index * self.num_rows + row_index
            img = self.image_list[index]
            output_image.append(self.sampler.pil2tensor(img))
            background.paste(img, (x_offset, y_offset))
            # X labels are drawn once, centered above the first row.
            if row_index == 0 and self.x_type != "None":
                label_bg = self.create_label(img, self.x_label[col_index], int(48 * img.width / 512))
                label_y = (y_offset - label_bg.height) // 2
                background.alpha_composite(label_bg, (x_offset, label_y))
            # Y labels are drawn once, rotated, left of the first column.
            if col_index == 0 and self.y_type != "None":
                label_bg = self.create_label(img, self.y_label[row_index], int(48 * img.height / 512), False)
                label_bg = label_bg.rotate(90, expand=True)
                label_x = (x_offset - label_bg.width) // 2
                label_y = y_offset + (img.height - label_bg.height) // 2
                background.alpha_composite(label_bg, (label_x, label_y))
            x_offset += img.width + self.grid_spacing
        y_offset += img.height + self.grid_spacing
    # Footer: settings common to every cell of the grid. LORAs are handled
    # separately below because a stack can contain several of them.
    common_label = ""
    labels = [
        {"id": "ckpt_name", "id_desc": "ckpt", "axis_type": "Checkpoint"},
        {"id": "vae_name", "id_desc": '', "axis_type": "vae_name"},
        {"id": "sampler_name", "id_desc": "sampler", "axis_type": "Sampler"},
        {"id": "scheduler", "id_desc": '', "axis_type": "Scheduler"},
        {"id": "steps", "id_desc": '', "axis_type": "Steps"},
        {"id": "Flux Guidance", "id_desc": 'guidance', "axis_type": "Flux Guidance"},
        {"id": "seed", "id_desc": '', "axis_type": "Seeds++ Batch"}
    ]
    for item in labels:
        # Skip any setting that is already one of the plot axes (its value
        # varies per cell and is shown by the axis labels instead).
        if self.x_type != item['axis_type'] and self.y_type != item['axis_type']:
            common_label += self.add_common_label(item['id'], plot_image_vars, item['id_desc'])
    common_label += "\n"
    # .get() tolerates a missing 'lora_stack' key (direct indexing raised
    # KeyError when callers never set it); a truthy check also skips [].
    lora_stack = plot_image_vars.get('lora_stack')
    if lora_stack:
        for lora in lora_stack:
            lora_name = lora['lora_name']
            lora_weight = lora['model_strength']
            if lora_name is not None and len(lora_name) > 0 and lora_weight > 0:
                common_label += f"LORA: {lora_name} weight: {lora_weight:.2f} \n"
    common_label = common_label.strip()
    if len(common_label) > 0:
        label_height = background.height - y_offset
        label_bg = self.create_label(background, common_label, int(48 * background.width / 512), label_width=background.width, label_height=label_height)
        label_x = (background.width - label_bg.width) // 2
        label_y = y_offset
        background.alpha_composite(label_bg, (label_x, label_y))
    return (self.sampler.pil2tensor(background), output_image)
def add_common_label(self, tag, plot_image_vars, description=''):
    """Format one "name: value " fragment for the grid footer.

    Returns the empty string when the setting is absent or is the string
    'None', so callers can concatenate results unconditionally. When no
    description is supplied, the tag itself is used as the display name.
    """
    display_name = description if description else tag
    value = plot_image_vars.get(tag)
    if value is None or value == 'None':
        return ''
    return f"{display_name}: {value} "