Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Includes 30 custom nodes committed directly, 7 Civitai-exclusive loras stored via Git LFS, and a setup script that installs all dependencies and downloads HuggingFace-hosted models on vast.ai. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
113
custom_nodes/ComfyUI-Easy-Use/py/libs/add_resources.py
Normal file
113
custom_nodes/ComfyUI-Easy-Use/py/libs/add_resources.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import urllib.parse
|
||||
from os import PathLike
|
||||
from aiohttp import web
|
||||
from aiohttp.web_urldispatcher import AbstractRoute, UrlDispatcher
|
||||
from server import PromptServer
|
||||
from pathlib import Path
|
||||
|
||||
# 文件限制大小(MB)
|
||||
max_size = 50
|
||||
def suffix_limiter(self: web.StaticResource, request: web.Request):
|
||||
suffixes = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".svg", ".ico", ".apng", ".tif", ".hdr", ".exr"}
|
||||
rel_url = request.match_info["filename"]
|
||||
try:
|
||||
filename = Path(rel_url)
|
||||
if filename.anchor:
|
||||
raise web.HTTPForbidden()
|
||||
filepath = self._directory.joinpath(filename).resolve()
|
||||
if filepath.exists() and filepath.suffix.lower() not in suffixes:
|
||||
raise web.HTTPForbidden(reason="File type is not allowed")
|
||||
finally:
|
||||
pass
|
||||
|
||||
def filesize_limiter(self: web.StaticResource, request: web.Request):
|
||||
rel_url = request.match_info["filename"]
|
||||
try:
|
||||
filename = Path(rel_url)
|
||||
filepath = self._directory.joinpath(filename).resolve()
|
||||
if filepath.exists() and filepath.stat().st_size > max_size * 1024 * 1024:
|
||||
raise web.HTTPForbidden(reason="File size is too large")
|
||||
finally:
|
||||
pass
|
||||
class LimitResource(web.StaticResource):
|
||||
limiters = []
|
||||
|
||||
def push_limiter(self, limiter):
|
||||
self.limiters.append(limiter)
|
||||
|
||||
async def _handle(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
for limiter in self.limiters:
|
||||
limiter(self, request)
|
||||
except (ValueError, FileNotFoundError) as error:
|
||||
raise web.HTTPNotFound() from error
|
||||
|
||||
return await super()._handle(request)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
name = "'" + self.name + "'" if self.name is not None else ""
|
||||
return f'<LimitResource {name} {self._prefix} -> {self._directory!r}>'
|
||||
|
||||
class LimitRouter(web.StaticDef):
|
||||
def __repr__(self) -> str:
|
||||
info = []
|
||||
for name, value in sorted(self.kwargs.items()):
|
||||
info.append(f", {name}={value!r}")
|
||||
return f'<LimitRouter {self.prefix} -> {self.path}{"".join(info)}>'
|
||||
|
||||
def register(self, router: UrlDispatcher) -> list[AbstractRoute]:
|
||||
# resource = router.add_static(self.prefix, self.path, **self.kwargs)
|
||||
def add_static(
|
||||
self: UrlDispatcher,
|
||||
prefix: str,
|
||||
path: PathLike,
|
||||
*,
|
||||
name=None,
|
||||
expect_handler=None,
|
||||
chunk_size: int = 256 * 1024,
|
||||
show_index: bool = False,
|
||||
follow_symlinks: bool = False,
|
||||
append_version: bool = False,
|
||||
) -> web.AbstractResource:
|
||||
assert prefix.startswith("/")
|
||||
if prefix.endswith("/"):
|
||||
prefix = prefix[:-1]
|
||||
resource = LimitResource(
|
||||
prefix,
|
||||
path,
|
||||
name=name,
|
||||
expect_handler=expect_handler,
|
||||
chunk_size=chunk_size,
|
||||
show_index=show_index,
|
||||
follow_symlinks=follow_symlinks,
|
||||
append_version=append_version,
|
||||
)
|
||||
resource.push_limiter(suffix_limiter)
|
||||
resource.push_limiter(filesize_limiter)
|
||||
self.register_resource(resource)
|
||||
return resource
|
||||
resource = add_static(router, self.prefix, self.path, **self.kwargs)
|
||||
routes = resource.get_info().get("routes", {})
|
||||
return list(routes.values())
|
||||
|
||||
def path_to_url(path):
|
||||
if not path:
|
||||
return path
|
||||
path = path.replace("\\", "/")
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
while path.startswith("//"):
|
||||
path = path[1:]
|
||||
path = path.replace("//", "/")
|
||||
return path
|
||||
|
||||
def add_static_resource(prefix, path,limit=False):
|
||||
app = PromptServer.instance.app
|
||||
prefix = path_to_url(prefix)
|
||||
prefix = urllib.parse.quote(prefix)
|
||||
prefix = path_to_url(prefix)
|
||||
if limit:
|
||||
route = LimitRouter(prefix, path, {"follow_symlinks": True})
|
||||
else:
|
||||
route = web.static(prefix, path, follow_symlinks=True)
|
||||
app.add_routes([route])
|
||||
427
custom_nodes/ComfyUI-Easy-Use/py/libs/adv_encode.py
Normal file
427
custom_nodes/ComfyUI-Easy-Use/py/libs/adv_encode.py
Normal file
@@ -0,0 +1,427 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import re
|
||||
import itertools
|
||||
|
||||
from comfy import model_management
|
||||
from comfy.sdxl_clip import SDXLClipModel, SDXLRefinerClipModel, SDXLClipG
|
||||
try:
|
||||
from comfy.text_encoders.sd3_clip import SD3ClipModel, T5XXLModel
|
||||
except ImportError:
|
||||
from comfy.sd3_clip import SD3ClipModel, T5XXLModel
|
||||
|
||||
from nodes import NODE_CLASS_MAPPINGS, ConditioningConcat, ConditioningZeroOut, ConditioningSetTimestepRange, ConditioningCombine
|
||||
|
||||
def _grouper(n, iterable):
|
||||
it = iter(iterable)
|
||||
while True:
|
||||
chunk = list(itertools.islice(it, n))
|
||||
if not chunk:
|
||||
return
|
||||
yield chunk
|
||||
|
||||
|
||||
def _norm_mag(w, n):
|
||||
d = w - 1
|
||||
return 1 + np.sign(d) * np.sqrt(np.abs(d) ** 2 / n)
|
||||
# return np.sign(w) * np.sqrt(np.abs(w)**2 / n)
|
||||
|
||||
|
||||
def divide_length(word_ids, weights):
|
||||
sums = dict(zip(*np.unique(word_ids, return_counts=True)))
|
||||
sums[0] = 1
|
||||
weights = [[_norm_mag(w, sums[id]) if id != 0 else 1.0
|
||||
for w, id in zip(x, y)] for x, y in zip(weights, word_ids)]
|
||||
return weights
|
||||
|
||||
|
||||
def shift_mean_weight(word_ids, weights):
|
||||
delta = 1 - np.mean([w for x, y in zip(weights, word_ids) for w, id in zip(x, y) if id != 0])
|
||||
weights = [[w if id == 0 else w + delta
|
||||
for w, id in zip(x, y)] for x, y in zip(weights, word_ids)]
|
||||
return weights
|
||||
|
||||
|
||||
def scale_to_norm(weights, word_ids, w_max):
|
||||
top = np.max(weights)
|
||||
w_max = min(top, w_max)
|
||||
weights = [[w_max if id == 0 else (w / top) * w_max
|
||||
for w, id in zip(x, y)] for x, y in zip(weights, word_ids)]
|
||||
return weights
|
||||
|
||||
|
||||
def from_zero(weights, base_emb):
|
||||
weight_tensor = torch.tensor(weights, dtype=base_emb.dtype, device=base_emb.device)
|
||||
weight_tensor = weight_tensor.reshape(1, -1, 1).expand(base_emb.shape)
|
||||
return base_emb * weight_tensor
|
||||
|
||||
|
||||
def mask_word_id(tokens, word_ids, target_id, mask_token):
|
||||
new_tokens = [[mask_token if wid == target_id else t
|
||||
for t, wid in zip(x, y)] for x, y in zip(tokens, word_ids)]
|
||||
mask = np.array(word_ids) == target_id
|
||||
return (new_tokens, mask)
|
||||
|
||||
|
||||
def batched_clip_encode(tokens, length, encode_func, num_chunks):
|
||||
embs = []
|
||||
for e in _grouper(32, tokens):
|
||||
enc, pooled = encode_func(e)
|
||||
enc = enc.reshape((len(e), length, -1))
|
||||
embs.append(enc)
|
||||
embs = torch.cat(embs)
|
||||
embs = embs.reshape((len(tokens) // num_chunks, length * num_chunks, -1))
|
||||
return embs
|
||||
|
||||
|
||||
def from_masked(tokens, weights, word_ids, base_emb, length, encode_func, m_token=266):
|
||||
pooled_base = base_emb[0, length - 1:length, :]
|
||||
wids, inds = np.unique(np.array(word_ids).reshape(-1), return_index=True)
|
||||
weight_dict = dict((id, w)
|
||||
for id, w in zip(wids, np.array(weights).reshape(-1)[inds])
|
||||
if w != 1.0)
|
||||
|
||||
if len(weight_dict) == 0:
|
||||
return torch.zeros_like(base_emb), base_emb[0, length - 1:length, :]
|
||||
|
||||
weight_tensor = torch.tensor(weights, dtype=base_emb.dtype, device=base_emb.device)
|
||||
weight_tensor = weight_tensor.reshape(1, -1, 1).expand(base_emb.shape)
|
||||
|
||||
# m_token = (clip.tokenizer.end_token, 1.0) if clip.tokenizer.pad_with_end else (0,1.0)
|
||||
# TODO: find most suitable masking token here
|
||||
m_token = (m_token, 1.0)
|
||||
|
||||
ws = []
|
||||
masked_tokens = []
|
||||
masks = []
|
||||
|
||||
# create prompts
|
||||
for id, w in weight_dict.items():
|
||||
masked, m = mask_word_id(tokens, word_ids, id, m_token)
|
||||
masked_tokens.extend(masked)
|
||||
|
||||
m = torch.tensor(m, dtype=base_emb.dtype, device=base_emb.device)
|
||||
m = m.reshape(1, -1, 1).expand(base_emb.shape)
|
||||
masks.append(m)
|
||||
|
||||
ws.append(w)
|
||||
|
||||
# batch process prompts
|
||||
embs = batched_clip_encode(masked_tokens, length, encode_func, len(tokens))
|
||||
masks = torch.cat(masks)
|
||||
|
||||
embs = (base_emb.expand(embs.shape) - embs)
|
||||
pooled = embs[0, length - 1:length, :]
|
||||
|
||||
embs *= masks
|
||||
embs = embs.sum(axis=0, keepdim=True)
|
||||
|
||||
pooled_start = pooled_base.expand(len(ws), -1)
|
||||
ws = torch.tensor(ws).reshape(-1, 1).expand(pooled_start.shape)
|
||||
pooled = (pooled - pooled_start) * (ws - 1)
|
||||
pooled = pooled.mean(axis=0, keepdim=True)
|
||||
|
||||
return ((weight_tensor - 1) * embs), pooled_base + pooled
|
||||
|
||||
|
||||
def mask_inds(tokens, inds, mask_token):
|
||||
clip_len = len(tokens[0])
|
||||
inds_set = set(inds)
|
||||
new_tokens = [[mask_token if i * clip_len + j in inds_set else t
|
||||
for j, t in enumerate(x)] for i, x in enumerate(tokens)]
|
||||
return new_tokens
|
||||
|
||||
|
||||
def down_weight(tokens, weights, word_ids, base_emb, length, encode_func, m_token=266):
|
||||
w, w_inv = np.unique(weights, return_inverse=True)
|
||||
|
||||
if np.sum(w < 1) == 0:
|
||||
return base_emb, tokens, base_emb[0, length - 1:length, :]
|
||||
# m_token = (clip.tokenizer.end_token, 1.0) if clip.tokenizer.pad_with_end else (0,1.0)
|
||||
# using the comma token as a masking token seems to work better than aos tokens for SD 1.x
|
||||
m_token = (m_token, 1.0)
|
||||
|
||||
masked_tokens = []
|
||||
|
||||
masked_current = tokens
|
||||
for i in range(len(w)):
|
||||
if w[i] >= 1:
|
||||
continue
|
||||
masked_current = mask_inds(masked_current, np.where(w_inv == i)[0], m_token)
|
||||
masked_tokens.extend(masked_current)
|
||||
|
||||
embs = batched_clip_encode(masked_tokens, length, encode_func, len(tokens))
|
||||
embs = torch.cat([base_emb, embs])
|
||||
w = w[w <= 1.0]
|
||||
w_mix = np.diff([0] + w.tolist())
|
||||
w_mix = torch.tensor(w_mix, dtype=embs.dtype, device=embs.device).reshape((-1, 1, 1))
|
||||
|
||||
weighted_emb = (w_mix * embs).sum(axis=0, keepdim=True)
|
||||
return weighted_emb, masked_current, weighted_emb[0, length - 1:length, :]
|
||||
|
||||
|
||||
def scale_emb_to_mag(base_emb, weighted_emb):
|
||||
norm_base = torch.linalg.norm(base_emb)
|
||||
norm_weighted = torch.linalg.norm(weighted_emb)
|
||||
embeddings_final = (norm_base / norm_weighted) * weighted_emb
|
||||
return embeddings_final
|
||||
|
||||
|
||||
def recover_dist(base_emb, weighted_emb):
|
||||
fixed_std = (base_emb.std() / weighted_emb.std()) * (weighted_emb - weighted_emb.mean())
|
||||
embeddings_final = fixed_std + (base_emb.mean() - fixed_std.mean())
|
||||
return embeddings_final
|
||||
|
||||
|
||||
def A1111_renorm(base_emb, weighted_emb):
|
||||
embeddings_final = (base_emb.mean() / weighted_emb.mean()) * weighted_emb
|
||||
return embeddings_final
|
||||
|
||||
|
||||
def advanced_encode_from_tokens(tokenized, token_normalization, weight_interpretation, encode_func, m_token=266,
|
||||
length=77, w_max=1.0, return_pooled=False, apply_to_pooled=False):
|
||||
tokens = [[t for t, _, _ in x] for x in tokenized]
|
||||
weights = [[w for _, w, _ in x] for x in tokenized]
|
||||
word_ids = [[wid for _, _, wid in x] for x in tokenized]
|
||||
|
||||
# weight normalization
|
||||
# ====================
|
||||
|
||||
# distribute down/up weights over word lengths
|
||||
if token_normalization.startswith("length"):
|
||||
weights = divide_length(word_ids, weights)
|
||||
|
||||
# make mean of word tokens 1
|
||||
if token_normalization.endswith("mean"):
|
||||
weights = shift_mean_weight(word_ids, weights)
|
||||
|
||||
# weight interpretation
|
||||
# =====================
|
||||
pooled = None
|
||||
|
||||
if weight_interpretation == "comfy":
|
||||
weighted_tokens = [[(t, w) for t, w in zip(x, y)] for x, y in zip(tokens, weights)]
|
||||
weighted_emb, pooled_base = encode_func(weighted_tokens)
|
||||
pooled = pooled_base
|
||||
else:
|
||||
unweighted_tokens = [[(t, 1.0) for t, _, _ in x] for x in tokenized]
|
||||
base_emb, pooled_base = encode_func(unweighted_tokens)
|
||||
|
||||
if weight_interpretation == "A1111":
|
||||
weighted_emb = from_zero(weights, base_emb)
|
||||
weighted_emb = A1111_renorm(base_emb, weighted_emb)
|
||||
pooled = pooled_base
|
||||
|
||||
if weight_interpretation == "compel":
|
||||
pos_tokens = [[(t, w) if w >= 1.0 else (t, 1.0) for t, w in zip(x, y)] for x, y in zip(tokens, weights)]
|
||||
weighted_emb, _ = encode_func(pos_tokens)
|
||||
weighted_emb, _, pooled = down_weight(pos_tokens, weights, word_ids, weighted_emb, length, encode_func)
|
||||
|
||||
if weight_interpretation == "comfy++":
|
||||
weighted_emb, tokens_down, _ = down_weight(unweighted_tokens, weights, word_ids, base_emb, length, encode_func)
|
||||
weights = [[w if w > 1.0 else 1.0 for w in x] for x in weights]
|
||||
# unweighted_tokens = [[(t,1.0) for t, _,_ in x] for x in tokens_down]
|
||||
embs, pooled = from_masked(unweighted_tokens, weights, word_ids, base_emb, length, encode_func)
|
||||
weighted_emb += embs
|
||||
|
||||
if weight_interpretation == "down_weight":
|
||||
weights = scale_to_norm(weights, word_ids, w_max)
|
||||
weighted_emb, _, pooled = down_weight(unweighted_tokens, weights, word_ids, base_emb, length, encode_func)
|
||||
|
||||
if return_pooled:
|
||||
if apply_to_pooled:
|
||||
return weighted_emb, pooled
|
||||
else:
|
||||
return weighted_emb, pooled_base
|
||||
return weighted_emb, None
|
||||
|
||||
|
||||
def encode_token_weights_g(model, token_weight_pairs):
|
||||
return model.clip_g.encode_token_weights(token_weight_pairs)
|
||||
|
||||
|
||||
def encode_token_weights_l(model, token_weight_pairs):
|
||||
l_out, pooled = model.clip_l.encode_token_weights(token_weight_pairs)
|
||||
return l_out, pooled
|
||||
|
||||
def encode_token_weights_t5(model, token_weight_pairs):
|
||||
return model.t5xxl.encode_token_weights(token_weight_pairs)
|
||||
|
||||
|
||||
def encode_token_weights(model, token_weight_pairs, encode_func):
|
||||
if model.layer_idx is not None:
|
||||
# 2016 [c2cb8e88] 及以上版本去除了sdxl clip的clip_layer方法
|
||||
# if compare_revision(2016):
|
||||
model.cond_stage_model.set_clip_options({'layer': model.layer_idx})
|
||||
# else:
|
||||
# model.cond_stage_model.clip_layer(model.layer_idx)
|
||||
|
||||
model_management.load_model_gpu(model.patcher)
|
||||
return encode_func(model.cond_stage_model, token_weight_pairs)
|
||||
|
||||
def prepareXL(embs_l, embs_g, pooled, clip_balance):
|
||||
l_w = 1 - max(0, clip_balance - .5) * 2
|
||||
g_w = 1 - max(0, .5 - clip_balance) * 2
|
||||
if embs_l is not None:
|
||||
return torch.cat([embs_l * l_w, embs_g * g_w], dim=-1), pooled
|
||||
else:
|
||||
return embs_g, pooled
|
||||
|
||||
def prepareSD3(out, pooled, clip_balance):
|
||||
lg_w = 1 - max(0, clip_balance - .5) * 2
|
||||
t5_w = 1 - max(0, .5 - clip_balance) * 2
|
||||
if out.shape[0] > 1:
|
||||
return torch.cat([out[0] * lg_w, out[1] * t5_w], dim=-1), pooled
|
||||
else:
|
||||
return out, pooled
|
||||
|
||||
def advanced_encode(clip, text, token_normalization, weight_interpretation, w_max=1.0, clip_balance=.5,
|
||||
apply_to_pooled=True, width=1024, height=1024, crop_w=0, crop_h=0, target_width=1024, target_height=1024, a1111_prompt_style=False, steps=1):
|
||||
|
||||
# Use clip text encode by smzNodes like same as a1111, when if you need installed the smzNodes
|
||||
if a1111_prompt_style:
|
||||
if "smZ CLIPTextEncode" in NODE_CLASS_MAPPINGS:
|
||||
cls = NODE_CLASS_MAPPINGS['smZ CLIPTextEncode']
|
||||
embeddings_final, = cls().encode(clip, text, weight_interpretation, True, True, False, False, 6, 1024, 1024, 0, 0, 1024, 1024, '', '', steps)
|
||||
return embeddings_final
|
||||
else:
|
||||
raise Exception(f"[smzNodes Not Found] you need to install 'ComfyUI-smzNodes'")
|
||||
|
||||
time_start = 0
|
||||
time_end = 1
|
||||
match = re.search(r'TIMESTEP.*$', text)
|
||||
if match:
|
||||
timestep = match.group()
|
||||
timestep = timestep.split(' ')
|
||||
timestep = timestep[0]
|
||||
text = text.replace(timestep, '')
|
||||
value = timestep.split(':')
|
||||
if len(value) >= 3:
|
||||
time_start = float(value[1])
|
||||
time_end = float(value[2])
|
||||
elif len(value) == 2:
|
||||
time_start = float(value[1])
|
||||
time_end = 1
|
||||
elif len(value) == 1:
|
||||
time_start = 0.1
|
||||
time_end = 1
|
||||
|
||||
pass3 = [x.strip() for x in text.split("BREAK")]
|
||||
pass3 = [x for x in pass3 if x != '']
|
||||
|
||||
if len(pass3) == 0:
|
||||
pass3 = ['']
|
||||
|
||||
# pass3_str = [f'[{x}]' for x in pass3]
|
||||
# print(f"CLIP: {str.join(' + ', pass3_str)}")
|
||||
|
||||
conditioning = None
|
||||
|
||||
for text in pass3:
|
||||
tokenized = clip.tokenize(text, return_word_ids=True)
|
||||
if SD3ClipModel and isinstance(clip.cond_stage_model, SD3ClipModel):
|
||||
lg_out = None
|
||||
pooled = None
|
||||
out = None
|
||||
|
||||
if len(tokenized['l']) > 0 or len(tokenized['g']) > 0:
|
||||
if clip.cond_stage_model.clip_l is not None:
|
||||
lg_out, l_pooled = advanced_encode_from_tokens(tokenized['l'],
|
||||
token_normalization,
|
||||
weight_interpretation,
|
||||
lambda x: encode_token_weights(clip, x, encode_token_weights_l),
|
||||
w_max=w_max, return_pooled=True,)
|
||||
else:
|
||||
l_pooled = torch.zeros((1, 768), device=model_management.intermediate_device())
|
||||
|
||||
if clip.cond_stage_model.clip_g is not None:
|
||||
g_out, g_pooled = advanced_encode_from_tokens(tokenized['g'],
|
||||
token_normalization,
|
||||
weight_interpretation,
|
||||
lambda x: encode_token_weights(clip, x, encode_token_weights_g),
|
||||
w_max=w_max, return_pooled=True)
|
||||
if lg_out is not None:
|
||||
lg_out = torch.cat([lg_out, g_out], dim=-1)
|
||||
else:
|
||||
lg_out = torch.nn.functional.pad(g_out, (768, 0))
|
||||
else:
|
||||
g_out = None
|
||||
g_pooled = torch.zeros((1, 1280), device=model_management.intermediate_device())
|
||||
|
||||
if lg_out is not None:
|
||||
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
|
||||
out = lg_out
|
||||
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
|
||||
|
||||
# t5xxl
|
||||
if 't5xxl' in tokenized:
|
||||
t5_out, t5_pooled = advanced_encode_from_tokens(tokenized['t5xxl'],
|
||||
token_normalization,
|
||||
weight_interpretation,
|
||||
lambda x: encode_token_weights(clip, x, encode_token_weights_t5),
|
||||
w_max=w_max, return_pooled=True)
|
||||
if lg_out is not None:
|
||||
out = torch.cat([lg_out, t5_out], dim=-2)
|
||||
else:
|
||||
out = t5_out
|
||||
|
||||
if out is None:
|
||||
out = torch.zeros((1, 77, 4096), device=model_management.intermediate_device())
|
||||
|
||||
if pooled is None:
|
||||
pooled = torch.zeros((1, 768 + 1280), device=model_management.intermediate_device())
|
||||
|
||||
embeddings_final, pooled = prepareSD3(out, pooled, clip_balance)
|
||||
cond = [[embeddings_final, {"pooled_output": pooled}]]
|
||||
|
||||
elif isinstance(clip.cond_stage_model, (SDXLClipModel, SDXLRefinerClipModel, SDXLClipG)):
|
||||
embs_l = None
|
||||
embs_g = None
|
||||
pooled = None
|
||||
if 'l' in tokenized and isinstance(clip.cond_stage_model, SDXLClipModel):
|
||||
embs_l, _ = advanced_encode_from_tokens(tokenized['l'],
|
||||
token_normalization,
|
||||
weight_interpretation,
|
||||
lambda x: encode_token_weights(clip, x, encode_token_weights_l),
|
||||
w_max=w_max,
|
||||
return_pooled=False)
|
||||
if 'g' in tokenized:
|
||||
embs_g, pooled = advanced_encode_from_tokens(tokenized['g'],
|
||||
token_normalization,
|
||||
weight_interpretation,
|
||||
lambda x: encode_token_weights(clip, x,
|
||||
encode_token_weights_g),
|
||||
w_max=w_max,
|
||||
return_pooled=True,
|
||||
apply_to_pooled=apply_to_pooled)
|
||||
|
||||
embeddings_final, pooled = prepareXL(embs_l, embs_g, pooled, clip_balance)
|
||||
|
||||
cond = [[embeddings_final, {"pooled_output": pooled}]]
|
||||
# cond = [[embeddings_final,
|
||||
# {"pooled_output": pooled, "width": width, "height": height, "crop_w": crop_w,
|
||||
# "crop_h": crop_h, "target_width": target_width, "target_height": target_height}]]
|
||||
else:
|
||||
embeddings_final, pooled = advanced_encode_from_tokens(tokenized['l'],
|
||||
token_normalization,
|
||||
weight_interpretation,
|
||||
lambda x: encode_token_weights(clip, x, encode_token_weights_l),
|
||||
w_max=w_max,return_pooled=True,)
|
||||
cond = [[embeddings_final, {"pooled_output": pooled}]]
|
||||
|
||||
if conditioning is not None:
|
||||
conditioning = ConditioningConcat().concat(conditioning, cond)[0]
|
||||
else:
|
||||
conditioning = cond
|
||||
|
||||
# setTimeStepRange
|
||||
if time_start > 0 or time_end < 1:
|
||||
conditioning_2, = ConditioningSetTimestepRange().set_range(conditioning, 0, time_start)
|
||||
conditioning_1, = ConditioningZeroOut().zero_out(conditioning)
|
||||
conditioning_1, = ConditioningSetTimestepRange().set_range(conditioning_1, time_start, time_end)
|
||||
conditioning, = ConditioningCombine().combine(conditioning_1, conditioning_2)
|
||||
|
||||
return conditioning
|
||||
|
||||
|
||||
|
||||
372
custom_nodes/ComfyUI-Easy-Use/py/libs/api/bizyair.py
Normal file
372
custom_nodes/ComfyUI-Easy-Use/py/libs/api/bizyair.py
Normal file
@@ -0,0 +1,372 @@
|
||||
import yaml
|
||||
import pathlib
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import zlib
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
from enum import Enum
|
||||
from functools import singledispatch
|
||||
from typing import Any, List, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
root_path = pathlib.Path(__file__).parent.parent.parent.parent
|
||||
config_path = os.path.join(root_path, 'config.yaml')
|
||||
|
||||
class BizyAIRAPI:
|
||||
def __init__(self):
|
||||
self.base_url = 'https://bizyair-api.siliconflow.cn/x/v1'
|
||||
self.api_key = None
|
||||
|
||||
|
||||
def getAPIKey(self):
|
||||
if self.api_key is None:
|
||||
if os.path.isfile(config_path):
|
||||
with open(config_path, 'r') as f:
|
||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||
if 'BIZYAIR_API_KEY' not in data:
|
||||
raise Exception("Please add BIZYAIR_API_KEY to config.yaml")
|
||||
self.api_key = data['BIZYAIR_API_KEY']
|
||||
else:
|
||||
raise Exception("Please add config.yaml to root path")
|
||||
return self.api_key
|
||||
|
||||
def send_post_request(self, url, payload, headers):
|
||||
try:
|
||||
data = json.dumps(payload).encode("utf-8")
|
||||
req = urllib.request.Request(url, data=data, headers=headers, method="POST")
|
||||
with urllib.request.urlopen(req) as response:
|
||||
response_data = response.read().decode("utf-8")
|
||||
return response_data
|
||||
except urllib.error.URLError as e:
|
||||
if "Unauthorized" in str(e):
|
||||
raise Exception(
|
||||
"Key is invalid, please refer to https://cloud.siliconflow.cn to get the API key.\n"
|
||||
"If you have the key, please click the 'BizyAir Key' button at the bottom right to set the key."
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
f"Failed to connect to the server: {e}, if you have no key, "
|
||||
)
|
||||
|
||||
# joycaption
|
||||
def joyCaption(self, payload, image, apikey_override=None, API_URL='/supernode/joycaption2'):
|
||||
if apikey_override is not None:
|
||||
api_key = apikey_override
|
||||
else:
|
||||
api_key = self.getAPIKey()
|
||||
url = f"{self.base_url}{API_URL}"
|
||||
print('Sending request to:', url)
|
||||
auth = f"Bearer {api_key}"
|
||||
headers = {
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
"authorization": auth,
|
||||
}
|
||||
input_image = encode_data(image, disable_image_marker=True)
|
||||
payload["image"] = input_image
|
||||
|
||||
ret: str = self.send_post_request(url=url, payload=payload, headers=headers)
|
||||
ret = json.loads(ret)
|
||||
|
||||
try:
|
||||
if "result" in ret:
|
||||
ret = json.loads(ret["result"])
|
||||
except Exception as e:
|
||||
raise Exception(f"Unexpected response: {ret} {e=}")
|
||||
|
||||
if ret["type"] == "error":
|
||||
raise Exception(ret["message"])
|
||||
|
||||
msg = ret["data"]
|
||||
if msg["type"] not in ("comfyair", "bizyair",):
|
||||
raise Exception(f"Unexpected response type: {msg}")
|
||||
|
||||
caption = msg["data"]
|
||||
|
||||
return caption
|
||||
|
||||
bizyairAPI = BizyAIRAPI()
|
||||
|
||||
|
||||
|
||||
BIZYAIR_DEBUG = True
|
||||
# Marker to identify base64-encoded tensors
|
||||
TENSOR_MARKER = "TENSOR:"
|
||||
IMAGE_MARKER = "IMAGE:"
|
||||
|
||||
|
||||
class TaskStatus(Enum):
|
||||
PENDING = "pending"
|
||||
PROCESSING = "processing"
|
||||
COMPLETED = "completed"
|
||||
|
||||
|
||||
def convert_image_to_rgb(image: Image.Image) -> Image.Image:
|
||||
if image.mode != "RGB":
|
||||
return image.convert("RGB")
|
||||
return image
|
||||
|
||||
|
||||
def encode_image_to_base64(
|
||||
image: Image.Image, format: str = "png", quality: int = 100, lossless=False
|
||||
) -> str:
|
||||
image = convert_image_to_rgb(image)
|
||||
with io.BytesIO() as output:
|
||||
image.save(output, format=format, quality=quality, lossless=lossless)
|
||||
output.seek(0)
|
||||
img_bytes = output.getvalue()
|
||||
if BIZYAIR_DEBUG:
|
||||
print(f"encode_image_to_base64: {format_bytes(len(img_bytes))}")
|
||||
return base64.b64encode(img_bytes).decode("utf-8")
|
||||
|
||||
|
||||
def decode_base64_to_np(img_data: str, format: str = "png") -> np.ndarray:
|
||||
img_bytes = base64.b64decode(img_data)
|
||||
if BIZYAIR_DEBUG:
|
||||
print(f"decode_base64_to_np: {format_bytes(len(img_bytes))}")
|
||||
with io.BytesIO(img_bytes) as input_buffer:
|
||||
img = Image.open(input_buffer)
|
||||
# https://github.com/comfyanonymous/ComfyUI/blob/a178e25912b01abf436eba1cfaab316ba02d272d/nodes.py#L1511
|
||||
img = img.convert("RGB")
|
||||
return np.array(img)
|
||||
|
||||
|
||||
def decode_base64_to_image(img_data: str) -> Image.Image:
|
||||
img_bytes = base64.b64decode(img_data)
|
||||
with io.BytesIO(img_bytes) as input_buffer:
|
||||
img = Image.open(input_buffer)
|
||||
if BIZYAIR_DEBUG:
|
||||
format_info = img.format.upper() if img.format else "Unknown"
|
||||
print(f"decode image format: {format_info}")
|
||||
return img
|
||||
|
||||
|
||||
def format_bytes(num_bytes: int) -> str:
|
||||
"""
|
||||
Converts a number of bytes to a human-readable string with units (B, KB, or MB).
|
||||
|
||||
:param num_bytes: The number of bytes to convert.
|
||||
:return: A string representing the number of bytes in a human-readable format.
|
||||
"""
|
||||
if num_bytes < 1024:
|
||||
return f"{num_bytes} B"
|
||||
elif num_bytes < 1024 * 1024:
|
||||
return f"{num_bytes / 1024:.2f} KB"
|
||||
else:
|
||||
return f"{num_bytes / (1024 * 1024):.2f} MB"
|
||||
|
||||
|
||||
def _legacy_encode_comfy_image(image: torch.Tensor, image_format="png") -> str:
|
||||
input_image = image.cpu().detach().numpy()
|
||||
i = 255.0 * input_image[0]
|
||||
input_image = np.clip(i, 0, 255).astype(np.uint8)
|
||||
base64ed_image = encode_image_to_base64(
|
||||
Image.fromarray(input_image), format=image_format
|
||||
)
|
||||
return base64ed_image
|
||||
|
||||
|
||||
def _legacy_decode_comfy_image(
|
||||
img_data: Union[List, str], image_format="png"
|
||||
) -> torch.tensor:
|
||||
if isinstance(img_data, List):
|
||||
decoded_imgs = [decode_comfy_image(x, old_version=True) for x in img_data]
|
||||
|
||||
combined_imgs = torch.cat(decoded_imgs, dim=0)
|
||||
return combined_imgs
|
||||
|
||||
out = decode_base64_to_np(img_data, format=image_format)
|
||||
out = np.array(out).astype(np.float32) / 255.0
|
||||
output = torch.from_numpy(out)[None,]
|
||||
return output
|
||||
|
||||
|
||||
def _new_encode_comfy_image(images: torch.Tensor, image_format="WEBP", **kwargs) -> str:
|
||||
"""https://docs.comfy.org/essentials/custom_node_snippets#save-an-image-batch
|
||||
Encode a batch of images to base64 strings.
|
||||
|
||||
Args:
|
||||
images (torch.Tensor): A batch of images.
|
||||
image_format (str, optional): The format of the images. Defaults to "WEBP".
|
||||
|
||||
Returns:
|
||||
str: A JSON string containing the base64-encoded images.
|
||||
"""
|
||||
results = {}
|
||||
for batch_number, image in enumerate(images):
|
||||
i = 255.0 * image.cpu().numpy()
|
||||
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
||||
base64ed_image = encode_image_to_base64(img, format=image_format, **kwargs)
|
||||
results[batch_number] = base64ed_image
|
||||
|
||||
return json.dumps(results)
|
||||
|
||||
|
||||
def _new_decode_comfy_image(img_datas: str, image_format="WEBP") -> torch.tensor:
|
||||
"""
|
||||
Decode a batch of base64-encoded images.
|
||||
|
||||
Args:
|
||||
img_datas (str): A JSON string containing the base64-encoded images.
|
||||
image_format (str, optional): The format of the images. Defaults to "WEBP".
|
||||
|
||||
Returns:
|
||||
torch.Tensor: A tensor containing the decoded images.
|
||||
"""
|
||||
img_datas = json.loads(img_datas)
|
||||
|
||||
decoded_imgs = []
|
||||
for img_data in img_datas.values():
|
||||
decoded_image = decode_base64_to_np(img_data, format=image_format)
|
||||
decoded_image = np.array(decoded_image).astype(np.float32) / 255.0
|
||||
decoded_imgs.append(torch.from_numpy(decoded_image)[None,])
|
||||
|
||||
return torch.cat(decoded_imgs, dim=0)
|
||||
|
||||
|
||||
def encode_comfy_image(
|
||||
image: torch.Tensor, image_format="WEBP", old_version=False, lossless=False
|
||||
) -> str:
|
||||
if old_version:
|
||||
return _legacy_encode_comfy_image(image, image_format)
|
||||
return _new_encode_comfy_image(image, image_format, lossless=lossless)
|
||||
|
||||
|
||||
def decode_comfy_image(
|
||||
img_data: Union[List, str], image_format="WEBP", old_version=False
|
||||
) -> torch.tensor:
|
||||
if old_version:
|
||||
return _legacy_decode_comfy_image(img_data, image_format)
|
||||
return _new_decode_comfy_image(img_data, image_format)
|
||||
|
||||
|
||||
def tensor_to_base64(tensor: torch.Tensor, compress=True) -> str:
|
||||
tensor_np = tensor.cpu().detach().numpy()
|
||||
|
||||
tensor_bytes = pickle.dumps(tensor_np)
|
||||
if compress:
|
||||
tensor_bytes = zlib.compress(tensor_bytes)
|
||||
|
||||
tensor_b64 = base64.b64encode(tensor_bytes).decode("utf-8")
|
||||
return tensor_b64
|
||||
|
||||
|
||||
def base64_to_tensor(tensor_b64: str, compress=True) -> torch.Tensor:
|
||||
tensor_bytes = base64.b64decode(tensor_b64)
|
||||
|
||||
if compress:
|
||||
tensor_bytes = zlib.decompress(tensor_bytes)
|
||||
|
||||
tensor_np = pickle.loads(tensor_bytes)
|
||||
|
||||
tensor = torch.from_numpy(tensor_np)
|
||||
return tensor
|
||||
|
||||
|
||||
@singledispatch
|
||||
def decode_data(input, old_version=False):
|
||||
raise NotImplementedError(f"Unsupported type: {type(input)}")
|
||||
|
||||
|
||||
@decode_data.register(int)
|
||||
@decode_data.register(float)
|
||||
@decode_data.register(bool)
|
||||
@decode_data.register(type(None))
|
||||
def _(input, **kwargs):
|
||||
return input
|
||||
|
||||
|
||||
@decode_data.register(dict)
|
||||
def _(input, **kwargs):
|
||||
return {k: decode_data(v, **kwargs) for k, v in input.items()}
|
||||
|
||||
|
||||
@decode_data.register(list)
|
||||
def _(input, **kwargs):
|
||||
return [decode_data(x, **kwargs) for x in input]
|
||||
|
||||
|
||||
@decode_data.register(str)
|
||||
def _(input: str, **kwargs):
|
||||
if input.startswith(TENSOR_MARKER):
|
||||
tensor_b64 = input[len(TENSOR_MARKER) :]
|
||||
return base64_to_tensor(tensor_b64)
|
||||
elif input.startswith(IMAGE_MARKER):
|
||||
tensor_b64 = input[len(IMAGE_MARKER) :]
|
||||
old_version = kwargs.get("old_version", False)
|
||||
return decode_comfy_image(tensor_b64, old_version=old_version)
|
||||
return input
|
||||
|
||||
|
||||
@singledispatch
|
||||
def encode_data(output, disable_image_marker=False, old_version=False):
|
||||
raise NotImplementedError(f"Unsupported type: {type(output)}")
|
||||
|
||||
|
||||
@encode_data.register(dict)
|
||||
def _(output, **kwargs):
|
||||
return {k: encode_data(v, **kwargs) for k, v in output.items()}
|
||||
|
||||
|
||||
@encode_data.register(list)
|
||||
def _(output, **kwargs):
|
||||
return [encode_data(x, **kwargs) for x in output]
|
||||
|
||||
|
||||
def is_image_tensor(tensor) -> bool:
|
||||
"""https://docs.comfy.org/essentials/custom_node_datatypes#image
|
||||
|
||||
Check if the given tensor is in the format of an IMAGE (shape [B, H, W, C] where C=3).
|
||||
|
||||
`Args`:
|
||||
tensor (torch.Tensor): The tensor to check.
|
||||
|
||||
`Returns`:
|
||||
bool: True if the tensor is in the IMAGE format, False otherwise.
|
||||
"""
|
||||
try:
|
||||
if not isinstance(tensor, torch.Tensor):
|
||||
return False
|
||||
|
||||
if len(tensor.shape) != 4:
|
||||
return False
|
||||
|
||||
B, H, W, C = tensor.shape
|
||||
if C != 3:
|
||||
return False
|
||||
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
@encode_data.register(torch.Tensor)
|
||||
def _(output, **kwargs):
|
||||
if is_image_tensor(output) and not kwargs.get("disable_image_marker", False):
|
||||
old_version = kwargs.get("old_version", False)
|
||||
lossless = kwargs.get("lossless", True)
|
||||
return IMAGE_MARKER + encode_comfy_image(
|
||||
output, image_format="WEBP", old_version=old_version, lossless=lossless
|
||||
)
|
||||
return TENSOR_MARKER + tensor_to_base64(output)
|
||||
|
||||
|
||||
@encode_data.register(int)
|
||||
@encode_data.register(float)
|
||||
@encode_data.register(bool)
|
||||
@encode_data.register(type(None))
|
||||
def _(output, **kwargs):
|
||||
return output
|
||||
|
||||
|
||||
@encode_data.register(str)
|
||||
def _(output, **kwargs):
|
||||
return output
|
||||
51
custom_nodes/ComfyUI-Easy-Use/py/libs/api/fluxai.py
Normal file
51
custom_nodes/ComfyUI-Easy-Use/py/libs/api/fluxai.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import json
|
||||
import os
|
||||
import yaml
|
||||
import requests
|
||||
import pathlib
|
||||
from aiohttp import web
|
||||
|
||||
root_path = pathlib.Path(__file__).parent.parent.parent.parent
|
||||
config_path = os.path.join(root_path,'config.yaml')
|
||||
class FluxAIAPI:
|
||||
def __init__(self):
|
||||
self.api_url = "https://fluxaiimagegenerator.com/api"
|
||||
self.origin = "https://fluxaiimagegenerator.com"
|
||||
self.user_agent = None
|
||||
self.cookie = None
|
||||
|
||||
def promptGenerate(self, text, cookies=None):
|
||||
cookie = self.cookie if cookies is None else cookies
|
||||
if cookie is None:
|
||||
if os.path.isfile(config_path):
|
||||
with open(config_path, 'r') as f:
|
||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||
if 'FLUXAI_COOKIE' not in data:
|
||||
raise Exception("Please add FLUXAI_COOKIE to config.yaml")
|
||||
if "FLUXAI_USER_AGENT" in data:
|
||||
self.user_agent = data["FLUXAI_USER_AGENT"]
|
||||
self.cookie = cookie = data['FLUXAI_COOKIE']
|
||||
|
||||
headers = {
|
||||
"Cookie": cookie,
|
||||
"Referer": "https://fluxaiimagegenerator.com/flux-prompt-generator",
|
||||
"Origin": self.origin,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if self.user_agent is not None:
|
||||
headers['User-Agent'] = self.user_agent
|
||||
|
||||
url = self.api_url + '/prompt'
|
||||
json = {
|
||||
"prompt": text
|
||||
}
|
||||
|
||||
response = requests.post(url, json=json, headers=headers)
|
||||
res = response.json()
|
||||
if "error" in res:
|
||||
return res['error']
|
||||
elif "data" in res and "prompt" in res['data']:
|
||||
return res['data']['prompt']
|
||||
|
||||
fluxaiAPI = FluxAIAPI()
|
||||
|
||||
200
custom_nodes/ComfyUI-Easy-Use/py/libs/api/stability.py
Normal file
200
custom_nodes/ComfyUI-Easy-Use/py/libs/api/stability.py
Normal file
@@ -0,0 +1,200 @@
|
||||
import json
|
||||
import os
|
||||
import yaml
|
||||
import requests
|
||||
import pathlib
|
||||
from aiohttp import web
|
||||
from server import PromptServer
|
||||
from ..image import tensor2pil, pil2tensor, image2base64, pil2byte
|
||||
from ..log import log_node_error
|
||||
|
||||
|
||||
root_path = pathlib.Path(__file__).parent.parent.parent.parent
|
||||
config_path = os.path.join(root_path,'config.yaml')
|
||||
default_key = [{'name':'Default', 'key':''}]
|
||||
|
||||
|
||||
class StabilityAPI:
|
||||
def __init__(self):
|
||||
self.api_url = "https://api.stability.ai"
|
||||
self.api_keys = None
|
||||
self.api_current = 0
|
||||
self.user_info = {}
|
||||
|
||||
def getErrors(self, code):
|
||||
errors = {
|
||||
400: "Bad Request",
|
||||
403: "ApiKey Forbidden",
|
||||
413: "Your request was larger than 10MiB.",
|
||||
429: "You have made more than 150 requests in 10 seconds.",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
return errors.get(code, "Unknown Error")
|
||||
|
||||
def getAPIKeys(self):
|
||||
if os.path.isfile(config_path):
|
||||
with open(config_path, 'r') as f:
|
||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||
if not data:
|
||||
data = {'STABILITY_API_KEY': default_key, 'STABILITY_API_DEFAULT':0}
|
||||
with open(config_path, 'w') as f:
|
||||
yaml.dump(data, f)
|
||||
if 'STABILITY_API_KEY' not in data:
|
||||
data['STABILITY_API_KEY'] = default_key
|
||||
data['STABILITY_API_DEFAULT'] = 0
|
||||
with open(config_path, 'w') as f:
|
||||
yaml.dump(data, f)
|
||||
api_keys = data['STABILITY_API_KEY']
|
||||
self.api_current = data['STABILITY_API_DEFAULT']
|
||||
self.api_keys = api_keys
|
||||
return api_keys
|
||||
else:
|
||||
# create a yaml file
|
||||
with open(config_path, 'w') as f:
|
||||
data = {'STABILITY_API_KEY': default_key, 'STABILITY_API_DEFAULT':0}
|
||||
yaml.dump(data, f)
|
||||
return data['STABILITY_API_KEY']
|
||||
pass
|
||||
|
||||
def setAPIKeys(self, api_keys):
|
||||
if len(api_keys) > 0:
|
||||
self.api_keys = api_keys
|
||||
# load and save the yaml file
|
||||
with open(config_path, 'r') as f:
|
||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||
data['STABILITY_API_KEY'] = api_keys
|
||||
with open(config_path, 'w') as f:
|
||||
yaml.dump(data, f)
|
||||
return True
|
||||
|
||||
def setAPIDefault(self, current):
|
||||
if current is not None:
|
||||
self.api_current = current
|
||||
# load and save the yaml file
|
||||
with open(config_path, 'r') as f:
|
||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||
data['STABILITY_API_DEFAULT'] = current
|
||||
with open(config_path, 'w') as f:
|
||||
yaml.dump(data, f)
|
||||
return True
|
||||
|
||||
def generate_sd3_image(self, prompt, negative_prompt, aspect_ratio, model, seed, mode='text-to-image', image=None, strength=1, output_format='png', node_name='easy stableDiffusion3API'):
|
||||
url = f"{self.api_url}/v2beta/stable-image/generate/sd3"
|
||||
api_key = self.api_keys[self.api_current]['key']
|
||||
files = None
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"mode": mode,
|
||||
"model": model,
|
||||
"seed": seed,
|
||||
"output_format": output_format,
|
||||
}
|
||||
if model == 'sd3':
|
||||
data['negative_prompt'] = negative_prompt
|
||||
|
||||
if mode == 'text-to-image':
|
||||
files = {"none": ''}
|
||||
data['aspect_ratio'] = aspect_ratio
|
||||
elif mode == 'image-to-image':
|
||||
pil_image = tensor2pil(image)
|
||||
image_byte = pil2byte(pil_image)
|
||||
files = {"image": ("output.png", image_byte, 'image/png')}
|
||||
data['strength'] = strength
|
||||
|
||||
response = requests.post(url,
|
||||
headers={"authorization": f"{api_key}", "accept": "application/json"},
|
||||
files=files,
|
||||
data=data,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
PromptServer.instance.send_sync('stable-diffusion-api-generate-succeed',{"model":model})
|
||||
json_data = response.json()
|
||||
image_base64 = json_data['image']
|
||||
image_data = image2base64(image_base64)
|
||||
output_t = pil2tensor(image_data)
|
||||
return output_t
|
||||
else:
|
||||
if 'application/json' in response.headers['Content-Type']:
|
||||
error_info = response.json()
|
||||
log_node_error(node_name, error_info.get('name', 'No name provided'))
|
||||
log_node_error(node_name, error_info.get('errors', ['No details provided']))
|
||||
error_status_text = self.getErrors(response.status_code)
|
||||
PromptServer.instance.send_sync('easyuse-toast',{"type": "error", "content": error_status_text})
|
||||
raise Exception(f"Failed to generate image: {error_status_text}")
|
||||
|
||||
# get user account
|
||||
async def getUserAccount(self, cache=True):
|
||||
url = f"{self.api_url}/v1/user/account"
|
||||
api_key = self.api_keys[self.api_current]['key']
|
||||
name = self.api_keys[self.api_current]['name']
|
||||
if cache and name in self.user_info:
|
||||
return self.user_info[name]
|
||||
else:
|
||||
response = requests.get(url, headers={"Authorization": f"Bearer {api_key}"})
|
||||
if response.status_code == 200:
|
||||
user_info = response.json()
|
||||
self.user_info[name] = user_info
|
||||
return user_info
|
||||
else:
|
||||
PromptServer.instance.send_sync('easyuse-toast',{'type': 'error', 'content': self.getErrors(response.status_code)})
|
||||
return None
|
||||
|
||||
# get user balance
|
||||
async def getUserBalance(self):
|
||||
url = f"{self.api_url}/v1/user/balance"
|
||||
api_key = self.api_keys[self.api_current]['key']
|
||||
response = requests.get(url, headers={
|
||||
"Authorization": f"Bearer {api_key}"
|
||||
})
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
PromptServer.instance.send_sync('easyuse-toast', {'type': 'error', 'content': self.getErrors(response.status_code)})
|
||||
return None
|
||||
|
||||
stableAPI = StabilityAPI()
|
||||
|
||||
@PromptServer.instance.routes.get("/easyuse/stability/api_keys")
|
||||
async def get_stability_api_keys(request):
|
||||
stableAPI.getAPIKeys()
|
||||
return web.json_response({"keys": stableAPI.api_keys, "current": stableAPI.api_current})
|
||||
|
||||
@PromptServer.instance.routes.post("/easyuse/stability/set_api_keys")
|
||||
async def set_stability_api_keys(request):
|
||||
post = await request.post()
|
||||
api_keys = post.get("api_keys")
|
||||
current = post.get('current')
|
||||
if api_keys is not None:
|
||||
api_keys = json.loads(api_keys)
|
||||
stableAPI.setAPIKeys(api_keys)
|
||||
if current is not None:
|
||||
print(current)
|
||||
stableAPI.setAPIDefault(int(current))
|
||||
account = await stableAPI.getUserAccount()
|
||||
balance = await stableAPI.getUserBalance()
|
||||
return web.json_response({'account': account, 'balance': balance})
|
||||
else:
|
||||
return web.json_response({'status': 'ok'})
|
||||
else:
|
||||
return web.Response(status=400)
|
||||
|
||||
@PromptServer.instance.routes.post("/easyuse/stability/set_apikey_default")
|
||||
async def set_stability_api_default(request):
|
||||
post = await request.post()
|
||||
current = post.get("current")
|
||||
if current is not None and current < len(stableAPI.api_keys):
|
||||
stableAPI.api_current = current
|
||||
return web.json_response({'status': 'ok'})
|
||||
else:
|
||||
return web.Response(status=400)
|
||||
|
||||
@PromptServer.instance.routes.get("/easyuse/stability/user_info")
|
||||
async def get_account_info(request):
|
||||
account = await stableAPI.getUserAccount()
|
||||
balance = await stableAPI.getUserBalance()
|
||||
return web.json_response({'account': account, 'balance': balance})
|
||||
|
||||
@PromptServer.instance.routes.get("/easyuse/stability/balance")
|
||||
async def get_balance_info(request):
|
||||
balance = await stableAPI.getUserBalance()
|
||||
return web.json_response({'balance': balance})
|
||||
86
custom_nodes/ComfyUI-Easy-Use/py/libs/cache.py
Normal file
86
custom_nodes/ComfyUI-Easy-Use/py/libs/cache.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import itertools
|
||||
from typing import Optional
|
||||
|
||||
class TaggedCache:
|
||||
def __init__(self, tag_settings: Optional[dict]=None):
|
||||
self._tag_settings = tag_settings or {} # tag cache size
|
||||
self._data = {}
|
||||
|
||||
def __getitem__(self, key):
|
||||
for tag_data in self._data.values():
|
||||
if key in tag_data:
|
||||
return tag_data[key]
|
||||
raise KeyError(f'Key `{key}` does not exist')
|
||||
|
||||
def __setitem__(self, key, value: tuple):
|
||||
# value: (tag: str, (islist: bool, data: *))
|
||||
|
||||
# if key already exists, pop old value
|
||||
for tag_data in self._data.values():
|
||||
if key in tag_data:
|
||||
tag_data.pop(key, None)
|
||||
break
|
||||
|
||||
tag = value[0]
|
||||
if tag not in self._data:
|
||||
|
||||
try:
|
||||
from cachetools import LRUCache
|
||||
|
||||
default_size = 20
|
||||
if 'ckpt' in tag:
|
||||
default_size = 5
|
||||
elif tag in ['latent', 'image']:
|
||||
default_size = 100
|
||||
|
||||
self._data[tag] = LRUCache(maxsize=self._tag_settings.get(tag, default_size))
|
||||
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
# TODO: implement a simple lru dict
|
||||
self._data[tag] = {}
|
||||
self._data[tag][key] = value
|
||||
|
||||
def __delitem__(self, key):
|
||||
for tag_data in self._data.values():
|
||||
if key in tag_data:
|
||||
del tag_data[key]
|
||||
return
|
||||
raise KeyError(f'Key `{key}` does not exist')
|
||||
|
||||
def __contains__(self, key):
|
||||
return any(key in tag_data for tag_data in self._data.values())
|
||||
|
||||
def items(self):
|
||||
yield from itertools.chain(*map(lambda x :x.items(), self._data.values()))
|
||||
|
||||
def get(self, key, default=None):
|
||||
"""D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None."""
|
||||
for tag_data in self._data.values():
|
||||
if key in tag_data:
|
||||
return tag_data[key]
|
||||
return default
|
||||
|
||||
def clear(self):
|
||||
# clear all cache
|
||||
self._data = {}
|
||||
|
||||
cache_settings = {}
|
||||
cache = TaggedCache(cache_settings)
|
||||
cache_count = {}
|
||||
|
||||
def update_cache(k, tag, v):
|
||||
cache[k] = (tag, v)
|
||||
cnt = cache_count.get(k)
|
||||
if cnt is None:
|
||||
cnt = 0
|
||||
cache_count[k] = cnt
|
||||
else:
|
||||
cache_count[k] += 1
|
||||
def remove_cache(key):
|
||||
global cache
|
||||
if key == '*':
|
||||
cache = TaggedCache(cache_settings)
|
||||
elif key in cache:
|
||||
del cache[key]
|
||||
else:
|
||||
print(f"invalid {key}")
|
||||
153
custom_nodes/ComfyUI-Easy-Use/py/libs/chooser.py
Normal file
153
custom_nodes/ComfyUI-Easy-Use/py/libs/chooser.py
Normal file
@@ -0,0 +1,153 @@
|
||||
from threading import Event
|
||||
|
||||
import torch
|
||||
|
||||
from server import PromptServer
|
||||
from aiohttp import web
|
||||
from comfy import model_management as mm
|
||||
from comfy_execution.graph import ExecutionBlocker
|
||||
import time
|
||||
|
||||
class ChooserCancelled(Exception):
|
||||
pass
|
||||
|
||||
def get_chooser_cache():
|
||||
"""获取选择器缓存"""
|
||||
if not hasattr(PromptServer.instance, '_easyuse_chooser_node'):
|
||||
PromptServer.instance._easyuse_chooser_node = {}
|
||||
return PromptServer.instance._easyuse_chooser_node
|
||||
|
||||
def cleanup_session_data(node_id):
|
||||
"""清理会话数据"""
|
||||
node_data = get_chooser_cache()
|
||||
if node_id in node_data:
|
||||
session_keys = ["event", "selected", "images", "total_count", "cancelled"]
|
||||
for key in session_keys:
|
||||
if key in node_data[node_id]:
|
||||
del node_data[node_id][key]
|
||||
|
||||
def wait_for_chooser(id, images, mode, period=0.1):
|
||||
try:
|
||||
node_data = get_chooser_cache()
|
||||
images = [images[i:i + 1, ...] for i in range(images.shape[0])]
|
||||
if mode == "Keep Last Selection":
|
||||
if id in node_data and "last_selection" in node_data[id]:
|
||||
last_selection = node_data[id]["last_selection"]
|
||||
if last_selection and len(last_selection) > 0:
|
||||
valid_indices = [idx for idx in last_selection if 0 <= idx < len(images)]
|
||||
if valid_indices:
|
||||
try:
|
||||
PromptServer.instance.send_sync("easyuse-image-keep-selection", {
|
||||
"id": id,
|
||||
"selected": valid_indices
|
||||
})
|
||||
except Exception as e:
|
||||
pass
|
||||
cleanup_session_data(id)
|
||||
indices_str = ','.join(str(i) for i in valid_indices)
|
||||
images = [images[idx] for idx in valid_indices]
|
||||
images = torch.cat(images, dim=0)
|
||||
return {"result": (images,)}
|
||||
|
||||
if id in node_data:
|
||||
del node_data[id]
|
||||
|
||||
event = Event()
|
||||
node_data[id] = {
|
||||
"event": event,
|
||||
"images": images,
|
||||
"selected": None,
|
||||
"total_count": len(images),
|
||||
"cancelled": False,
|
||||
}
|
||||
|
||||
while id in node_data:
|
||||
node_info = node_data[id]
|
||||
if node_info.get("cancelled", False):
|
||||
cleanup_session_data(id)
|
||||
raise ChooserCancelled("Manual selection cancelled")
|
||||
|
||||
if "selected" in node_info and node_info["selected"] is not None:
|
||||
break
|
||||
|
||||
time.sleep(period)
|
||||
|
||||
if id in node_data:
|
||||
node_info = node_data[id]
|
||||
selected_indices = node_info.get("selected")
|
||||
|
||||
if selected_indices is not None and len(selected_indices) > 0:
|
||||
valid_indices = [idx for idx in selected_indices if 0 <= idx < len(images)]
|
||||
if valid_indices:
|
||||
selected_images = [images[idx] for idx in valid_indices]
|
||||
|
||||
if id not in node_data:
|
||||
node_data[id] = {}
|
||||
node_data[id]["last_selection"] = valid_indices
|
||||
cleanup_session_data(id)
|
||||
selected_images = torch.cat(selected_images, dim=0)
|
||||
return {"result": (selected_images,)}
|
||||
else:
|
||||
cleanup_session_data(id)
|
||||
return {"result": (images[0] if len(images) > 0 else ExecutionBlocker(None),)}
|
||||
else:
|
||||
cleanup_session_data(id)
|
||||
return {
|
||||
"result": (images[0] if len(images) > 0 else ExecutionBlocker(None),)}
|
||||
else:
|
||||
return {"result": (images[0] if len(images) > 0 else ExecutionBlocker(None),)}
|
||||
|
||||
except ChooserCancelled:
|
||||
raise mm.InterruptProcessingException()
|
||||
except Exception as e:
|
||||
node_data = get_chooser_cache()
|
||||
if id in node_data:
|
||||
cleanup_session_data(id)
|
||||
if 'image_list' in locals() and len(images) > 0:
|
||||
return {"result": (images[0])}
|
||||
else:
|
||||
return {"result": (ExecutionBlocker(None),)}
|
||||
|
||||
|
||||
@PromptServer.instance.routes.post('/easyuse/image_chooser_message')
|
||||
async def handle_image_selection(request):
|
||||
try:
|
||||
data = await request.json()
|
||||
node_id = data.get("node_id")
|
||||
selected = data.get("selected", [])
|
||||
action = data.get("action")
|
||||
|
||||
node_data = get_chooser_cache()
|
||||
|
||||
if node_id not in node_data:
|
||||
return web.json_response({"code": -1, "error": "Node data does not exist"})
|
||||
|
||||
try:
|
||||
node_info = node_data[node_id]
|
||||
|
||||
if "total_count" not in node_info:
|
||||
return web.json_response({"code": -1, "error": "The node has been processed"})
|
||||
|
||||
if action == "cancel":
|
||||
node_info["cancelled"] = True
|
||||
node_info["selected"] = []
|
||||
elif action == "select" and isinstance(selected, list):
|
||||
valid_indices = [idx for idx in selected if isinstance(idx, int) and 0 <= idx < node_info["total_count"]]
|
||||
if valid_indices:
|
||||
node_info["selected"] = valid_indices
|
||||
node_info["cancelled"] = False
|
||||
else:
|
||||
return web.json_response({"code": -1, "error": "Invalid Selection Index"})
|
||||
else:
|
||||
return web.json_response({"code": -1, "error": "Invalid operation"})
|
||||
|
||||
node_info["event"].set()
|
||||
return web.json_response({"code": 1})
|
||||
|
||||
except Exception as e:
|
||||
if node_id in node_data and "event" in node_data[node_id]:
|
||||
node_data[node_id]["event"].set()
|
||||
return web.json_response({"code": -1, "message": "Processing Failed"})
|
||||
|
||||
except Exception as e:
|
||||
return web.json_response({"code": -1, "message": "Request Failed"})
|
||||
115
custom_nodes/ComfyUI-Easy-Use/py/libs/colorfix.py
Normal file
115
custom_nodes/ComfyUI-Easy-Use/py/libs/colorfix.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
|
||||
from torchvision.transforms import ToTensor, ToPILImage
|
||||
|
||||
def adain_color_fix(target: Image, source: Image):
|
||||
# Convert images to tensors
|
||||
to_tensor = ToTensor()
|
||||
target_tensor = to_tensor(target).unsqueeze(0)
|
||||
source_tensor = to_tensor(source).unsqueeze(0)
|
||||
|
||||
# Apply adaptive instance normalization
|
||||
result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)
|
||||
|
||||
# Convert tensor back to image
|
||||
to_image = ToPILImage()
|
||||
result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
|
||||
|
||||
return result_image
|
||||
|
||||
def wavelet_color_fix(target: Image, source: Image):
|
||||
source = source.resize(target.size, resample=Image.Resampling.LANCZOS)
|
||||
|
||||
# Convert images to tensors
|
||||
to_tensor = ToTensor()
|
||||
target_tensor = to_tensor(target).unsqueeze(0)
|
||||
source_tensor = to_tensor(source).unsqueeze(0)
|
||||
|
||||
# Apply wavelet reconstruction
|
||||
result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
|
||||
|
||||
# Convert tensor back to image
|
||||
to_image = ToPILImage()
|
||||
result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
|
||||
|
||||
return result_image
|
||||
|
||||
def calc_mean_std(feat: Tensor, eps=1e-5):
|
||||
"""Calculate mean and std for adaptive_instance_normalization.
|
||||
Args:
|
||||
feat (Tensor): 4D tensor.
|
||||
eps (float): A small value added to the variance to avoid
|
||||
divide-by-zero. Default: 1e-5.
|
||||
"""
|
||||
size = feat.size()
|
||||
assert len(size) == 4, 'The input feature should be 4D tensor.'
|
||||
b, c = size[:2]
|
||||
feat_var = feat.view(b, c, -1).var(dim=2) + eps
|
||||
feat_std = feat_var.sqrt().view(b, c, 1, 1)
|
||||
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
|
||||
return feat_mean, feat_std
|
||||
|
||||
def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
|
||||
"""Adaptive instance normalization.
|
||||
Adjust the reference features to have the similar color and illuminations
|
||||
as those in the degradate features.
|
||||
Args:
|
||||
content_feat (Tensor): The reference feature.
|
||||
style_feat (Tensor): The degradate features.
|
||||
"""
|
||||
size = content_feat.size()
|
||||
style_mean, style_std = calc_mean_std(style_feat)
|
||||
content_mean, content_std = calc_mean_std(content_feat)
|
||||
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
||||
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
||||
|
||||
def wavelet_blur(image: Tensor, radius: int):
|
||||
"""
|
||||
Apply wavelet blur to the input tensor.
|
||||
"""
|
||||
# input shape: (1, 3, H, W)
|
||||
# convolution kernel
|
||||
kernel_vals = [
|
||||
[0.0625, 0.125, 0.0625],
|
||||
[0.125, 0.25, 0.125],
|
||||
[0.0625, 0.125, 0.0625],
|
||||
]
|
||||
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
|
||||
# add channel dimensions to the kernel to make it a 4D tensor
|
||||
kernel = kernel[None, None]
|
||||
# repeat the kernel across all input channels
|
||||
kernel = kernel.repeat(3, 1, 1, 1)
|
||||
image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
|
||||
# apply convolution
|
||||
output = F.conv2d(image, kernel, groups=3, dilation=radius)
|
||||
return output
|
||||
|
||||
def wavelet_decomposition(image: Tensor, levels=5):
|
||||
"""
|
||||
Apply wavelet decomposition to the input tensor.
|
||||
This function only returns the low frequency & the high frequency.
|
||||
"""
|
||||
high_freq = torch.zeros_like(image)
|
||||
for i in range(levels):
|
||||
radius = 2 ** i
|
||||
low_freq = wavelet_blur(image, radius)
|
||||
high_freq += (image - low_freq)
|
||||
image = low_freq
|
||||
|
||||
return high_freq, low_freq
|
||||
|
||||
def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
|
||||
"""
|
||||
Apply wavelet decomposition, so that the content will have the same color as the style.
|
||||
"""
|
||||
# calculate the wavelet decomposition of the content feature
|
||||
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
|
||||
del content_low_freq
|
||||
# calculate the wavelet decomposition of the style feature
|
||||
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
|
||||
del style_high_freq
|
||||
# reconstruct the content feature with the style's high frequency
|
||||
return content_high_freq + style_low_freq
|
||||
57
custom_nodes/ComfyUI-Easy-Use/py/libs/conditioning.py
Normal file
57
custom_nodes/ComfyUI-Easy-Use/py/libs/conditioning.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from .utils import find_wildcards_seed, find_nearest_steps, is_linked_styles_selector
|
||||
from .log import log_node_warn
|
||||
from .translate import zh_to_en, has_chinese
|
||||
from .wildcards import process_with_loras
|
||||
from .adv_encode import advanced_encode
|
||||
|
||||
from nodes import ConditioningConcat, ConditioningCombine, ConditioningAverage, ConditioningSetTimestepRange, CLIPTextEncode
|
||||
|
||||
def prompt_to_cond(type, model, clip, clip_skip, lora_stack, text, prompt_token_normalization, prompt_weight_interpretation, a1111_prompt_style ,my_unique_id, prompt, easyCache, can_load_lora=True, steps=None, model_type=None):
|
||||
styles_selector = is_linked_styles_selector(prompt, my_unique_id, type)
|
||||
title = "Positive encoding" if type == 'positive' else "Negative encoding"
|
||||
|
||||
# Translate cn to en
|
||||
if model_type not in ['hydit'] and text is not None and has_chinese(text):
|
||||
text = zh_to_en([text])[0]
|
||||
|
||||
if model_type in ['hydit', 'flux', 'mochi']:
|
||||
log_node_warn(title + "...")
|
||||
embeddings_final, = CLIPTextEncode().encode(clip, text) if text is not None else (None,)
|
||||
|
||||
return (embeddings_final, "", model, clip)
|
||||
|
||||
log_node_warn(title + "...")
|
||||
|
||||
positive_seed = find_wildcards_seed(my_unique_id, text, prompt)
|
||||
model, clip, text, cond_decode, show_prompt, pipe_lora_stack = process_with_loras(
|
||||
text, model, clip, type, positive_seed, can_load_lora, lora_stack, easyCache)
|
||||
wildcard_prompt = cond_decode if show_prompt or styles_selector else ""
|
||||
|
||||
clipped = clip.clone()
|
||||
# 当clip模型不存在t5xxl时,可执行跳过层
|
||||
if not hasattr(clip.cond_stage_model, 't5xxl'):
|
||||
if clip_skip != 0:
|
||||
clipped.clip_layer(clip_skip)
|
||||
|
||||
steps = steps if steps is not None else find_nearest_steps(my_unique_id, prompt)
|
||||
return (advanced_encode(clipped, text, prompt_token_normalization,
|
||||
prompt_weight_interpretation, w_max=1.0,
|
||||
apply_to_pooled='enable',
|
||||
a1111_prompt_style=a1111_prompt_style, steps=steps) if text is not None else None, wildcard_prompt, model, clipped)
|
||||
|
||||
def set_cond(old_cond, new_cond, mode, average_strength, old_cond_start, old_cond_end, new_cond_start, new_cond_end):
|
||||
if not old_cond:
|
||||
return new_cond
|
||||
else:
|
||||
if mode == "replace":
|
||||
return new_cond
|
||||
elif mode == "concat":
|
||||
return ConditioningConcat().concat(new_cond, old_cond)[0]
|
||||
elif mode == "combine":
|
||||
return ConditioningCombine().combine(old_cond, new_cond)[0]
|
||||
elif mode == 'average':
|
||||
return ConditioningAverage().addWeighted(new_cond, old_cond, average_strength)[0]
|
||||
elif mode == 'timestep':
|
||||
cond_1 = ConditioningSetTimestepRange().set_range(old_cond, old_cond_start, old_cond_end)[0]
|
||||
cond_2 = ConditioningSetTimestepRange().set_range(new_cond, new_cond_start, new_cond_end)[0]
|
||||
return ConditioningCombine().combine(cond_1, cond_2)[0]
|
||||
93
custom_nodes/ComfyUI-Easy-Use/py/libs/controlnet.py
Normal file
93
custom_nodes/ComfyUI-Easy-Use/py/libs/controlnet.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import folder_paths
|
||||
import comfy.controlnet
|
||||
import comfy.model_management
|
||||
from nodes import NODE_CLASS_MAPPINGS
|
||||
|
||||
union_controlnet_types = {"auto": -1, "openpose": 0, "depth": 1, "hed/pidi/scribble/ted": 2, "canny/lineart/anime_lineart/mlsd": 3, "normal": 4, "segment": 5, "tile": 6, "repaint": 7}
|
||||
|
||||
class easyControlnet:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def apply(self, control_net_name, image, positive, negative, strength, start_percent=0, end_percent=1, control_net=None, scale_soft_weights=1, mask=None, union_type=None, easyCache=None, use_cache=True, model=None, vae=None):
|
||||
if strength == 0:
|
||||
return (positive, negative)
|
||||
|
||||
# kolors controlnet patch
|
||||
from ..modules.kolors.loader import is_kolors_model, applyKolorsUnet
|
||||
if is_kolors_model(model):
|
||||
from ..modules.kolors.model_patch import patch_controlnet
|
||||
if control_net is None:
|
||||
with applyKolorsUnet():
|
||||
control_net = easyCache.load_controlnet(control_net_name, scale_soft_weights, use_cache)
|
||||
control_net = patch_controlnet(model, control_net)
|
||||
else:
|
||||
if control_net is None:
|
||||
if easyCache is not None:
|
||||
control_net = easyCache.load_controlnet(control_net_name, scale_soft_weights, use_cache)
|
||||
else:
|
||||
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
|
||||
control_net = comfy.controlnet.load_controlnet(controlnet_path)
|
||||
|
||||
# union controlnet
|
||||
if union_type is not None:
|
||||
control_net = control_net.copy()
|
||||
type_number = union_controlnet_types[union_type]
|
||||
if type_number >= 0:
|
||||
control_net.set_extra_arg("control_type", [type_number])
|
||||
else:
|
||||
control_net.set_extra_arg("control_type", [])
|
||||
|
||||
if mask is not None:
|
||||
mask = mask.to(self.device)
|
||||
|
||||
if mask is not None and len(mask.shape) < 3:
|
||||
mask = mask.unsqueeze(0)
|
||||
|
||||
control_hint = image.movedim(-1, 1)
|
||||
|
||||
is_cond = True
|
||||
if negative is None:
|
||||
p = []
|
||||
for t in positive:
|
||||
n = [t[0], t[1].copy()]
|
||||
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent))
|
||||
if 'control' in t[1]:
|
||||
c_net.set_previous_controlnet(t[1]['control'])
|
||||
n[1]['control'] = c_net
|
||||
n[1]['control_apply_to_uncond'] = True
|
||||
if mask is not None:
|
||||
n[1]['mask'] = mask
|
||||
n[1]['set_area_to_bounds'] = False
|
||||
p.append(n)
|
||||
positive = p
|
||||
else:
|
||||
cnets = {}
|
||||
out = []
|
||||
for conditioning in [positive, negative]:
|
||||
c = []
|
||||
for t in conditioning:
|
||||
d = t[1].copy()
|
||||
|
||||
prev_cnet = d.get('control', None)
|
||||
if prev_cnet in cnets:
|
||||
c_net = cnets[prev_cnet]
|
||||
else:
|
||||
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae)
|
||||
c_net.set_previous_controlnet(prev_cnet)
|
||||
cnets[prev_cnet] = c_net
|
||||
|
||||
d['control'] = c_net
|
||||
d['control_apply_to_uncond'] = False
|
||||
|
||||
if mask is not None:
|
||||
d['mask'] = mask
|
||||
d['set_area_to_bounds'] = False
|
||||
|
||||
n = [t[0], d]
|
||||
c.append(n)
|
||||
out.append(c)
|
||||
positive = out[0]
|
||||
negative = out[1]
|
||||
|
||||
return (positive, negative)
|
||||
167
custom_nodes/ComfyUI-Easy-Use/py/libs/dynthres_core.py
Normal file
167
custom_nodes/ComfyUI-Easy-Use/py/libs/dynthres_core.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import torch, math
|
||||
|
||||
######################### DynThresh Core #########################
|
||||
|
||||
class DynThresh:
|
||||
|
||||
Modes = ["Constant", "Linear Down", "Cosine Down", "Half Cosine Down", "Linear Up", "Cosine Up", "Half Cosine Up", "Power Up", "Power Down", "Linear Repeating", "Cosine Repeating", "Sawtooth"]
|
||||
Startpoints = ["MEAN", "ZERO"]
|
||||
Variabilities = ["AD", "STD"]
|
||||
|
||||
def __init__(self, mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min, cfg_mode, cfg_scale_min, sched_val, experiment_mode, max_steps, separate_feature_channels, scaling_startpoint, variability_measure, interpolate_phi):
|
||||
self.mimic_scale = mimic_scale
|
||||
self.threshold_percentile = threshold_percentile
|
||||
self.mimic_mode = mimic_mode
|
||||
self.cfg_mode = cfg_mode
|
||||
self.max_steps = max_steps
|
||||
self.cfg_scale_min = cfg_scale_min
|
||||
self.mimic_scale_min = mimic_scale_min
|
||||
self.experiment_mode = experiment_mode
|
||||
self.sched_val = sched_val
|
||||
self.sep_feat_channels = separate_feature_channels
|
||||
self.scaling_startpoint = scaling_startpoint
|
||||
self.variability_measure = variability_measure
|
||||
self.interpolate_phi = interpolate_phi
|
||||
|
||||
def interpret_scale(self, scale, mode, min):
|
||||
scale -= min
|
||||
max = self.max_steps - 1
|
||||
frac = self.step / max
|
||||
if mode == "Constant":
|
||||
pass
|
||||
elif mode == "Linear Down":
|
||||
scale *= 1.0 - frac
|
||||
elif mode == "Half Cosine Down":
|
||||
scale *= math.cos(frac)
|
||||
elif mode == "Cosine Down":
|
||||
scale *= math.cos(frac * 1.5707)
|
||||
elif mode == "Linear Up":
|
||||
scale *= frac
|
||||
elif mode == "Half Cosine Up":
|
||||
scale *= 1.0 - math.cos(frac)
|
||||
elif mode == "Cosine Up":
|
||||
scale *= 1.0 - math.cos(frac * 1.5707)
|
||||
elif mode == "Power Up":
|
||||
scale *= math.pow(frac, self.sched_val)
|
||||
elif mode == "Power Down":
|
||||
scale *= 1.0 - math.pow(frac, self.sched_val)
|
||||
elif mode == "Linear Repeating":
|
||||
portion = (frac * self.sched_val) % 1.0
|
||||
scale *= (0.5 - portion) * 2 if portion < 0.5 else (portion - 0.5) * 2
|
||||
elif mode == "Cosine Repeating":
|
||||
scale *= math.cos(frac * 6.28318 * self.sched_val) * 0.5 + 0.5
|
||||
elif mode == "Sawtooth":
|
||||
scale *= (frac * self.sched_val) % 1.0
|
||||
scale += min
|
||||
return scale
|
||||
|
||||
def dynthresh(self, cond, uncond, cfg_scale, weights):
|
||||
mimic_scale = self.interpret_scale(self.mimic_scale, self.mimic_mode, self.mimic_scale_min)
|
||||
cfg_scale = self.interpret_scale(cfg_scale, self.cfg_mode, self.cfg_scale_min)
|
||||
# uncond shape is (batch, 4, height, width)
|
||||
conds_per_batch = cond.shape[0] / uncond.shape[0]
|
||||
assert conds_per_batch == int(conds_per_batch), "Expected # of conds per batch to be constant across batches"
|
||||
cond_stacked = cond.reshape((-1, int(conds_per_batch)) + uncond.shape[1:])
|
||||
|
||||
### Normal first part of the CFG Scale logic, basically
|
||||
diff = cond_stacked - uncond.unsqueeze(1)
|
||||
if weights is not None:
|
||||
diff = diff * weights
|
||||
relative = diff.sum(1)
|
||||
|
||||
### Get the normal result for both mimic and normal scale
|
||||
mim_target = uncond + relative * mimic_scale
|
||||
cfg_target = uncond + relative * cfg_scale
|
||||
### If we weren't doing mimic scale, we'd just return cfg_target here
|
||||
|
||||
### Now recenter the values relative to their average rather than absolute, to allow scaling from average
|
||||
mim_flattened = mim_target.flatten(2)
|
||||
cfg_flattened = cfg_target.flatten(2)
|
||||
mim_means = mim_flattened.mean(dim=2).unsqueeze(2)
|
||||
cfg_means = cfg_flattened.mean(dim=2).unsqueeze(2)
|
||||
mim_centered = mim_flattened - mim_means
|
||||
cfg_centered = cfg_flattened - cfg_means
|
||||
|
||||
if self.sep_feat_channels:
|
||||
if self.variability_measure == 'STD':
|
||||
mim_scaleref = mim_centered.std(dim=2).unsqueeze(2)
|
||||
cfg_scaleref = cfg_centered.std(dim=2).unsqueeze(2)
|
||||
else: # 'AD'
|
||||
mim_scaleref = mim_centered.abs().max(dim=2).values.unsqueeze(2)
|
||||
cfg_scaleref = torch.quantile(cfg_centered.abs(), self.threshold_percentile, dim=2).unsqueeze(2)
|
||||
|
||||
else:
|
||||
if self.variability_measure == 'STD':
|
||||
mim_scaleref = mim_centered.std()
|
||||
cfg_scaleref = cfg_centered.std()
|
||||
else: # 'AD'
|
||||
mim_scaleref = mim_centered.abs().max()
|
||||
cfg_scaleref = torch.quantile(cfg_centered.abs(), self.threshold_percentile)
|
||||
|
||||
if self.scaling_startpoint == 'ZERO':
|
||||
scaling_factor = mim_scaleref / cfg_scaleref
|
||||
result = cfg_flattened * scaling_factor
|
||||
|
||||
else: # 'MEAN'
|
||||
if self.variability_measure == 'STD':
|
||||
cfg_renormalized = (cfg_centered / cfg_scaleref) * mim_scaleref
|
||||
else: # 'AD'
|
||||
### Get the maximum value of all datapoints (with an optional threshold percentile on the uncond)
|
||||
max_scaleref = torch.maximum(mim_scaleref, cfg_scaleref)
|
||||
### Clamp to the max
|
||||
cfg_clamped = cfg_centered.clamp(-max_scaleref, max_scaleref)
|
||||
### Now shrink from the max to normalize and grow to the mimic scale (instead of the CFG scale)
|
||||
cfg_renormalized = (cfg_clamped / max_scaleref) * mim_scaleref
|
||||
|
||||
### Now add it back onto the averages to get into real scale again and return
|
||||
result = cfg_renormalized + cfg_means
|
||||
|
||||
actual_res = result.unflatten(2, mim_target.shape[2:])
|
||||
|
||||
if self.interpolate_phi != 1.0:
|
||||
actual_res = actual_res * self.interpolate_phi + cfg_target * (1.0 - self.interpolate_phi)
|
||||
|
||||
if self.experiment_mode == 1:
|
||||
num = actual_res.cpu().numpy()
|
||||
for y in range(0, 64):
|
||||
for x in range (0, 64):
|
||||
if num[0][0][y][x] > 1.0:
|
||||
num[0][1][y][x] *= 0.5
|
||||
if num[0][1][y][x] > 1.0:
|
||||
num[0][1][y][x] *= 0.5
|
||||
if num[0][2][y][x] > 1.5:
|
||||
num[0][2][y][x] *= 0.5
|
||||
actual_res = torch.from_numpy(num).to(device=uncond.device)
|
||||
elif self.experiment_mode == 2:
|
||||
num = actual_res.cpu().numpy()
|
||||
for y in range(0, 64):
|
||||
for x in range (0, 64):
|
||||
over_scale = False
|
||||
for z in range(0, 4):
|
||||
if abs(num[0][z][y][x]) > 1.5:
|
||||
over_scale = True
|
||||
if over_scale:
|
||||
for z in range(0, 4):
|
||||
num[0][z][y][x] *= 0.7
|
||||
actual_res = torch.from_numpy(num).to(device=uncond.device)
|
||||
elif self.experiment_mode == 3:
|
||||
coefs = torch.tensor([
|
||||
# R G B W
|
||||
[0.298, 0.207, 0.208, 0.0], # L1
|
||||
[0.187, 0.286, 0.173, 0.0], # L2
|
||||
[-0.158, 0.189, 0.264, 0.0], # L3
|
||||
[-0.184, -0.271, -0.473, 1.0], # L4
|
||||
], device=uncond.device)
|
||||
res_rgb = torch.einsum("laxy,ab -> lbxy", actual_res, coefs)
|
||||
max_r, max_g, max_b, max_w = res_rgb[0][0].max(), res_rgb[0][1].max(), res_rgb[0][2].max(), res_rgb[0][3].max()
|
||||
max_rgb = max(max_r, max_g, max_b)
|
||||
print(f"test max = r={max_r}, g={max_g}, b={max_b}, w={max_w}, rgb={max_rgb}")
|
||||
if self.step / (self.max_steps - 1) > 0.2:
|
||||
if max_rgb < 2.0 and max_w < 3.0:
|
||||
res_rgb /= max_rgb / 2.4
|
||||
else:
|
||||
if max_rgb > 2.4 and max_w > 3.0:
|
||||
res_rgb /= max_rgb / 2.4
|
||||
actual_res = torch.einsum("laxy,ab -> lbxy", res_rgb, coefs.inverse())
|
||||
|
||||
return actual_res
|
||||
27
custom_nodes/ComfyUI-Easy-Use/py/libs/easing.py
Normal file
27
custom_nodes/ComfyUI-Easy-Use/py/libs/easing.py
Normal file
@@ -0,0 +1,27 @@
|
||||
@staticmethod
|
||||
def easyIn(t: float)-> float:
|
||||
return t*t
|
||||
@staticmethod
|
||||
def easyOut(t: float)-> float:
|
||||
return -(t * (t - 2))
|
||||
@staticmethod
|
||||
def easyInOut(t: float)-> float:
|
||||
if t < 0.5:
|
||||
return 2*t*t
|
||||
else:
|
||||
return (-2*t*t) + (4*t) - 1
|
||||
|
||||
class EasingBase:
|
||||
|
||||
def easing(self, t: float, function='linear') -> float:
|
||||
if function == 'easyIn':
|
||||
return easyIn(t)
|
||||
elif function == 'easyOut':
|
||||
return easyOut(t)
|
||||
elif function == 'easyInOut':
|
||||
return easyInOut(t)
|
||||
else:
|
||||
return t
|
||||
|
||||
def ease(self, start, end, t) -> float:
|
||||
return end * t + start * (1 - t)
|
||||
@@ -0,0 +1,273 @@
|
||||
import torch
|
||||
from torchvision.transforms.functional import gaussian_blur
|
||||
from comfy.k_diffusion.sampling import default_noise_sampler, get_ancestral_step, to_d, BrownianTreeNoiseSampler
|
||||
from tqdm.auto import trange
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_ancestral(
|
||||
model,
|
||||
x,
|
||||
sigmas,
|
||||
extra_args=None,
|
||||
callback=None,
|
||||
disable=None,
|
||||
eta=1.0,
|
||||
s_noise=1.0,
|
||||
noise_sampler=None,
|
||||
upscale_ratio=2.0,
|
||||
start_step=5,
|
||||
end_step=15,
|
||||
upscale_n_step=3,
|
||||
unsharp_kernel_size=3,
|
||||
unsharp_sigma=0.5,
|
||||
unsharp_strength=0.0,
|
||||
):
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
# make upscale info
|
||||
upscale_steps = []
|
||||
step = start_step - 1
|
||||
while step < end_step - 1:
|
||||
upscale_steps.append(step)
|
||||
step += upscale_n_step
|
||||
height, width = x.shape[2:]
|
||||
upscale_shapes = [
|
||||
(int(height * (((upscale_ratio - 1) / i) + 1)), int(width * (((upscale_ratio - 1) / i) + 1)))
|
||||
for i in reversed(range(1, len(upscale_steps) + 1))
|
||||
]
|
||||
upscale_info = {k: v for k, v in zip(upscale_steps, upscale_shapes)}
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt
|
||||
if sigmas[i + 1] > 0:
|
||||
# Resize
|
||||
if i in upscale_info:
|
||||
x = torch.nn.functional.interpolate(x, size=upscale_info[i], mode="bicubic", align_corners=False)
|
||||
if unsharp_strength > 0:
|
||||
blurred = gaussian_blur(x, kernel_size=unsharp_kernel_size, sigma=unsharp_sigma)
|
||||
x = x + unsharp_strength * (x - blurred)
|
||||
|
||||
noise_sampler = default_noise_sampler(x)
|
||||
noise = noise_sampler(sigmas[i], sigmas[i + 1])
|
||||
x = x + noise * sigma_up * s_noise
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2s_ancestral(
|
||||
model,
|
||||
x,
|
||||
sigmas,
|
||||
extra_args=None,
|
||||
callback=None,
|
||||
disable=None,
|
||||
eta=1.0,
|
||||
s_noise=1.0,
|
||||
noise_sampler=None,
|
||||
upscale_ratio=2.0,
|
||||
start_step=5,
|
||||
end_step=15,
|
||||
upscale_n_step=3,
|
||||
unsharp_kernel_size=3,
|
||||
unsharp_sigma=0.5,
|
||||
unsharp_strength=0.0,
|
||||
):
|
||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
t_fn = lambda sigma: sigma.log().neg()
|
||||
|
||||
# make upscale info
|
||||
upscale_steps = []
|
||||
step = start_step - 1
|
||||
while step < end_step - 1:
|
||||
upscale_steps.append(step)
|
||||
step += upscale_n_step
|
||||
height, width = x.shape[2:]
|
||||
upscale_shapes = [
|
||||
(int(height * (((upscale_ratio - 1) / i) + 1)), int(width * (((upscale_ratio - 1) / i) + 1)))
|
||||
for i in reversed(range(1, len(upscale_steps) + 1))
|
||||
]
|
||||
upscale_info = {k: v for k, v in zip(upscale_steps, upscale_shapes)}
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
|
||||
if sigma_down == 0:
|
||||
# Euler method
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt
|
||||
else:
|
||||
# DPM-Solver++(2S)
|
||||
t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
|
||||
r = 1 / 2
|
||||
h = t_next - t
|
||||
s = t + r * h
|
||||
x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised
|
||||
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
|
||||
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2
|
||||
# Noise addition
|
||||
if sigmas[i + 1] > 0:
|
||||
# Resize
|
||||
if i in upscale_info:
|
||||
x = torch.nn.functional.interpolate(x, size=upscale_info[i], mode="bicubic", align_corners=False)
|
||||
if unsharp_strength > 0:
|
||||
blurred = gaussian_blur(x, kernel_size=unsharp_kernel_size, sigma=unsharp_sigma)
|
||||
x = x + unsharp_strength * (x - blurred)
|
||||
noise_sampler = default_noise_sampler(x)
|
||||
noise = noise_sampler(sigmas[i], sigmas[i + 1])
|
||||
x = x + noise * sigma_up * s_noise
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde(
|
||||
model,
|
||||
x,
|
||||
sigmas,
|
||||
extra_args=None,
|
||||
callback=None,
|
||||
disable=None,
|
||||
eta=1.0,
|
||||
s_noise=1.0,
|
||||
noise_sampler=None,
|
||||
solver_type="midpoint",
|
||||
upscale_ratio=2.0,
|
||||
start_step=5,
|
||||
end_step=15,
|
||||
upscale_n_step=3,
|
||||
unsharp_kernel_size=3,
|
||||
unsharp_sigma=0.5,
|
||||
unsharp_strength=0.0,
|
||||
):
|
||||
"""DPM-Solver++(2M) SDE."""
|
||||
|
||||
if solver_type not in {"heun", "midpoint"}:
|
||||
raise ValueError("solver_type must be 'heun' or 'midpoint'")
|
||||
|
||||
seed = extra_args.get("seed", None)
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
old_denoised = None
|
||||
h_last = None
|
||||
h = None
|
||||
|
||||
# make upscale info
|
||||
upscale_steps = []
|
||||
step = start_step - 1
|
||||
while step < end_step - 1:
|
||||
upscale_steps.append(step)
|
||||
step += upscale_n_step
|
||||
height, width = x.shape[2:]
|
||||
upscale_shapes = [
|
||||
(int(height * (((upscale_ratio - 1) / i) + 1)), int(width * (((upscale_ratio - 1) / i) + 1)))
|
||||
for i in reversed(range(1, len(upscale_steps) + 1))
|
||||
]
|
||||
upscale_info = {k: v for k, v in zip(upscale_steps, upscale_shapes)}
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
# DPM-Solver++(2M) SDE
|
||||
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
h = s - t
|
||||
eta_h = eta * h
|
||||
|
||||
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
|
||||
|
||||
if old_denoised is not None:
|
||||
r = h_last / h
|
||||
if solver_type == "heun":
|
||||
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
|
||||
elif solver_type == "midpoint":
|
||||
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
|
||||
|
||||
if eta:
|
||||
# Resize
|
||||
if i in upscale_info:
|
||||
x = torch.nn.functional.interpolate(x, size=upscale_info[i], mode="bicubic", align_corners=False)
|
||||
if unsharp_strength > 0:
|
||||
blurred = gaussian_blur(x, kernel_size=unsharp_kernel_size, sigma=unsharp_sigma)
|
||||
x = x + unsharp_strength * (x - blurred)
|
||||
denoised = None # 次ステップとサイズがあわないのでとりあえずNoneにしておく。
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True)
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
|
||||
|
||||
old_denoised = denoised
|
||||
h_last = h
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_lcm(
|
||||
model,
|
||||
x,
|
||||
sigmas,
|
||||
extra_args=None,
|
||||
callback=None,
|
||||
disable=None,
|
||||
noise_sampler=None,
|
||||
eta=None,
|
||||
s_noise=None,
|
||||
upscale_ratio=2.0,
|
||||
start_step=5,
|
||||
end_step=15,
|
||||
upscale_n_step=3,
|
||||
unsharp_kernel_size=3,
|
||||
unsharp_sigma=0.5,
|
||||
unsharp_strength=0.0,
|
||||
):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
# make upscale info
|
||||
upscale_steps = []
|
||||
step = start_step - 1
|
||||
while step < end_step - 1:
|
||||
upscale_steps.append(step)
|
||||
step += upscale_n_step
|
||||
height, width = x.shape[2:]
|
||||
upscale_shapes = [
|
||||
(int(height * (((upscale_ratio - 1) / i) + 1)), int(width * (((upscale_ratio - 1) / i) + 1)))
|
||||
for i in reversed(range(1, len(upscale_steps) + 1))
|
||||
]
|
||||
upscale_info = {k: v for k, v in zip(upscale_steps, upscale_shapes)}
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
|
||||
|
||||
x = denoised
|
||||
if sigmas[i + 1] > 0:
|
||||
# Resize
|
||||
if i in upscale_info:
|
||||
x = torch.nn.functional.interpolate(x, size=upscale_info[i], mode="bicubic", align_corners=False)
|
||||
if unsharp_strength > 0:
|
||||
blurred = gaussian_blur(x, kernel_size=unsharp_kernel_size, sigma=unsharp_sigma)
|
||||
x = x + unsharp_strength * (x - blurred)
|
||||
noise_sampler = default_noise_sampler(x)
|
||||
x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
|
||||
|
||||
return x
|
||||
227
custom_nodes/ComfyUI-Easy-Use/py/libs/image.py
Normal file
227
custom_nodes/ComfyUI-Easy-Use/py/libs/image.py
Normal file
@@ -0,0 +1,227 @@
|
||||
import os
|
||||
import base64
|
||||
import torch
|
||||
import numpy as np
|
||||
from enum import Enum
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from typing import List, Union
|
||||
|
||||
import folder_paths
|
||||
from .utils import install_package
|
||||
|
||||
# PIL to Tensor
|
||||
def pil2tensor(image):
|
||||
return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
|
||||
# Tensor to PIL
|
||||
def tensor2pil(image):
|
||||
return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
|
||||
# np to Tensor
|
||||
def np2tensor(img_np: Union[np.ndarray, List[np.ndarray]]) -> torch.Tensor:
|
||||
if isinstance(img_np, list):
|
||||
return torch.cat([np2tensor(img) for img in img_np], dim=0)
|
||||
return torch.from_numpy(img_np.astype(np.float32) / 255.0).unsqueeze(0)
|
||||
# Tensor to np
|
||||
def tensor2np(tensor: torch.Tensor) -> List[np.ndarray]:
|
||||
if len(tensor.shape) == 3: # Single image
|
||||
return np.clip(255.0 * tensor.cpu().numpy(), 0, 255).astype(np.uint8)
|
||||
else: # Batch of images
|
||||
return [np.clip(255.0 * t.cpu().numpy(), 0, 255).astype(np.uint8) for t in tensor]
|
||||
|
||||
def pil2byte(pil_image, format='PNG'):
|
||||
byte_arr = BytesIO()
|
||||
pil_image.save(byte_arr, format=format)
|
||||
byte_arr.seek(0)
|
||||
return byte_arr
|
||||
|
||||
def image2base64(image_base64):
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_data = Image.open(BytesIO(image_bytes))
|
||||
return image_data
|
||||
|
||||
# Get new bounds
|
||||
def get_new_bounds(width, height, left, right, top, bottom):
|
||||
"""Returns the new bounds for an image with inset crop data."""
|
||||
left = 0 + left
|
||||
right = width - right
|
||||
top = 0 + top
|
||||
bottom = height - bottom
|
||||
return (left, right, top, bottom)
|
||||
|
||||
def RGB2RGBA(image: Image, mask: Image) -> Image:
|
||||
(R, G, B) = image.convert('RGB').split()
|
||||
return Image.merge('RGBA', (R, G, B, mask.convert('L')))
|
||||
|
||||
def image2mask(image: Image) -> torch.Tensor:
|
||||
_image = image.convert('RGBA')
|
||||
alpha = _image.split()[0]
|
||||
bg = Image.new("L", _image.size)
|
||||
_image = Image.merge('RGBA', (bg, bg, bg, alpha))
|
||||
ret_mask = torch.tensor([pil2tensor(_image)[0, :, :, 3].tolist()])
|
||||
return ret_mask
|
||||
|
||||
def mask2image(mask: torch.Tensor) -> Image:
|
||||
masks = tensor2np(mask)
|
||||
for m in masks:
|
||||
_mask = Image.fromarray(m).convert("L")
|
||||
_image = Image.new("RGBA", _mask.size, color='white')
|
||||
_image = Image.composite(
|
||||
_image, Image.new("RGBA", _mask.size, color='black'), _mask)
|
||||
return _image
|
||||
|
||||
# 图像融合
|
||||
class blendImage:
|
||||
def g(self, x):
|
||||
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
|
||||
|
||||
def blend_mode(self, img1, img2, mode):
|
||||
if mode == "normal":
|
||||
return img2
|
||||
elif mode == "multiply":
|
||||
return img1 * img2
|
||||
elif mode == "screen":
|
||||
return 1 - (1 - img1) * (1 - img2)
|
||||
elif mode == "overlay":
|
||||
return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
|
||||
elif mode == "soft_light":
|
||||
return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1),
|
||||
img1 + (2 * img2 - 1) * (self.g(img1) - img1))
|
||||
elif mode == "difference":
|
||||
return img1 - img2
|
||||
else:
|
||||
raise ValueError(f"Unsupported blend mode: {mode}")
|
||||
|
||||
def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str = 'normal'):
|
||||
image2 = image2.to(image1.device)
|
||||
if image1.shape != image2.shape:
|
||||
image2 = image2.permute(0, 3, 1, 2)
|
||||
image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic',
|
||||
crop='center')
|
||||
image2 = image2.permute(0, 2, 3, 1)
|
||||
|
||||
blended_image = self.blend_mode(image1, image2, blend_mode)
|
||||
blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
|
||||
blended_image = torch.clamp(blended_image, 0, 1)
|
||||
return blended_image
|
||||
|
||||
|
||||
def empty_image(width, height, batch_size=1, color=0):
|
||||
r = torch.full([batch_size, height, width, 1], ((color >> 16) & 0xFF) / 0xFF)
|
||||
g = torch.full([batch_size, height, width, 1], ((color >> 8) & 0xFF) / 0xFF)
|
||||
b = torch.full([batch_size, height, width, 1], ((color) & 0xFF) / 0xFF)
|
||||
return torch.cat((r, g, b), dim=-1)
|
||||
|
||||
|
||||
class ResizeMode(Enum):
|
||||
RESIZE = "Just Resize"
|
||||
INNER_FIT = "Crop and Resize"
|
||||
OUTER_FIT = "Resize and Fill"
|
||||
def int_value(self):
|
||||
if self == ResizeMode.RESIZE:
|
||||
return 0
|
||||
elif self == ResizeMode.INNER_FIT:
|
||||
return 1
|
||||
elif self == ResizeMode.OUTER_FIT:
|
||||
return 2
|
||||
assert False, "NOTREACHED"
|
||||
|
||||
# credit by https://github.com/chflame163/ComfyUI_LayerStyle/blob/main/py/imagefunc.py#L591C1-L617C22
|
||||
def fit_resize_image(image: Image, target_width: int, target_height: int, fit: str, resize_sampler: str,
|
||||
background_color: str = '#000000') -> Image:
|
||||
image = image.convert('RGB')
|
||||
orig_width, orig_height = image.size
|
||||
if image is not None:
|
||||
if fit == 'letterbox':
|
||||
if orig_width / orig_height > target_width / target_height: # 更宽,上下留黑
|
||||
fit_width = target_width
|
||||
fit_height = int(target_width / orig_width * orig_height)
|
||||
else: # 更瘦,左右留黑
|
||||
fit_height = target_height
|
||||
fit_width = int(target_height / orig_height * orig_width)
|
||||
fit_image = image.resize((fit_width, fit_height), resize_sampler)
|
||||
ret_image = Image.new('RGB', size=(target_width, target_height), color=background_color)
|
||||
ret_image.paste(fit_image, box=((target_width - fit_width) // 2, (target_height - fit_height) // 2))
|
||||
elif fit == 'crop':
|
||||
if orig_width / orig_height > target_width / target_height: # 更宽,裁左右
|
||||
fit_width = int(orig_height * target_width / target_height)
|
||||
fit_image = image.crop(
|
||||
((orig_width - fit_width) // 2, 0, (orig_width - fit_width) // 2 + fit_width, orig_height))
|
||||
else: # 更瘦,裁上下
|
||||
fit_height = int(orig_width * target_height / target_width)
|
||||
fit_image = image.crop(
|
||||
(0, (orig_height - fit_height) // 2, orig_width, (orig_height - fit_height) // 2 + fit_height))
|
||||
ret_image = fit_image.resize((target_width, target_height), resize_sampler)
|
||||
else:
|
||||
ret_image = image.resize((target_width, target_height), resize_sampler)
|
||||
return ret_image
|
||||
|
||||
# CLIP反推
|
||||
import comfy.utils
|
||||
from torchvision import transforms
|
||||
Config, Interrogator = None, None
|
||||
class CI_Inference:
|
||||
ci_model = None
|
||||
cache_path: str
|
||||
|
||||
def __init__(self):
|
||||
self.ci_model = None
|
||||
self.low_vram = False
|
||||
self.cache_path = os.path.join(folder_paths.models_dir, "clip_interrogator")
|
||||
|
||||
def _load_model(self, model_name, low_vram=False):
|
||||
if not (self.ci_model and model_name == self.ci_model.config.clip_model_name and self.low_vram == low_vram):
|
||||
self.low_vram = low_vram
|
||||
print(f"Load model: {model_name}")
|
||||
|
||||
config = Config(
|
||||
device="cuda" if torch.cuda.is_available() else "cpu",
|
||||
download_cache=True,
|
||||
clip_model_name=model_name,
|
||||
clip_model_path=self.cache_path,
|
||||
cache_path=self.cache_path,
|
||||
caption_model_name='blip-large'
|
||||
)
|
||||
|
||||
if low_vram:
|
||||
config.apply_low_vram_defaults()
|
||||
|
||||
self.ci_model = Interrogator(config)
|
||||
|
||||
def _interrogate(self, image, mode, caption=None):
|
||||
if mode == 'best':
|
||||
prompt = self.ci_model.interrogate(image, caption=caption)
|
||||
elif mode == 'classic':
|
||||
prompt = self.ci_model.interrogate_classic(image, caption=caption)
|
||||
elif mode == 'fast':
|
||||
prompt = self.ci_model.interrogate_fast(image, caption=caption)
|
||||
elif mode == 'negative':
|
||||
prompt = self.ci_model.interrogate_negative(image)
|
||||
else:
|
||||
raise Exception(f"Unknown mode {mode}")
|
||||
return prompt
|
||||
|
||||
def image_to_prompt(self, image, mode, model_name='ViT-L-14/openai', low_vram=False):
|
||||
try:
|
||||
from clip_interrogator import Config, Interrogator
|
||||
global Config, Interrogator
|
||||
except:
|
||||
install_package("clip_interrogator", "0.6.0")
|
||||
from clip_interrogator import Config, Interrogator
|
||||
|
||||
pbar = comfy.utils.ProgressBar(len(image))
|
||||
|
||||
self._load_model(model_name, low_vram)
|
||||
prompt = []
|
||||
for i in range(len(image)):
|
||||
im = image[i]
|
||||
|
||||
im = tensor2pil(im)
|
||||
im = im.convert('RGB')
|
||||
|
||||
_prompt = self._interrogate(im, mode)
|
||||
pbar.update(1)
|
||||
prompt.append(_prompt)
|
||||
|
||||
return prompt
|
||||
|
||||
ci = CI_Inference()
|
||||
237
custom_nodes/ComfyUI-Easy-Use/py/libs/lllite.py
Normal file
237
custom_nodes/ComfyUI-Easy-Use/py/libs/lllite.py
Normal file
@@ -0,0 +1,237 @@
|
||||
import math
|
||||
import torch
|
||||
import comfy
|
||||
|
||||
|
||||
def extra_options_to_module_prefix(extra_options):
|
||||
# extra_options = {'transformer_index': 2, 'block_index': 8, 'original_shape': [2, 4, 128, 128], 'block': ('input', 7), 'n_heads': 20, 'dim_head': 64}
|
||||
|
||||
# block is: [('input', 4), ('input', 5), ('input', 7), ('input', 8), ('middle', 0),
|
||||
# ('output', 0), ('output', 1), ('output', 2), ('output', 3), ('output', 4), ('output', 5)]
|
||||
# transformer_index is: [0, 1, 2, 3, 4, 5, 6, 7, 8], for each block
|
||||
# block_index is: 0-1 or 0-9, depends on the block
|
||||
# input 7 and 8, middle has 10 blocks
|
||||
|
||||
# make module name from extra_options
|
||||
block = extra_options["block"]
|
||||
block_index = extra_options["block_index"]
|
||||
if block[0] == "input":
|
||||
module_pfx = f"lllite_unet_input_blocks_{block[1]}_1_transformer_blocks_{block_index}"
|
||||
elif block[0] == "middle":
|
||||
module_pfx = f"lllite_unet_middle_block_1_transformer_blocks_{block_index}"
|
||||
elif block[0] == "output":
|
||||
module_pfx = f"lllite_unet_output_blocks_{block[1]}_1_transformer_blocks_{block_index}"
|
||||
else:
|
||||
raise Exception("invalid block name")
|
||||
return module_pfx
|
||||
|
||||
|
||||
def load_control_net_lllite_patch(path, cond_image, multiplier, num_steps, start_percent, end_percent):
|
||||
# calculate start and end step
|
||||
start_step = math.floor(num_steps * start_percent * 0.01) if start_percent > 0 else 0
|
||||
end_step = math.floor(num_steps * end_percent * 0.01) if end_percent > 0 else num_steps
|
||||
|
||||
# load weights
|
||||
ctrl_sd = comfy.utils.load_torch_file(path, safe_load=True)
|
||||
|
||||
# split each weights for each module
|
||||
module_weights = {}
|
||||
for key, value in ctrl_sd.items():
|
||||
fragments = key.split(".")
|
||||
module_name = fragments[0]
|
||||
weight_name = ".".join(fragments[1:])
|
||||
|
||||
if module_name not in module_weights:
|
||||
module_weights[module_name] = {}
|
||||
module_weights[module_name][weight_name] = value
|
||||
|
||||
# load each module
|
||||
modules = {}
|
||||
for module_name, weights in module_weights.items():
|
||||
# ここの自動判定を何とかしたい
|
||||
if "conditioning1.4.weight" in weights:
|
||||
depth = 3
|
||||
elif weights["conditioning1.2.weight"].shape[-1] == 4:
|
||||
depth = 2
|
||||
else:
|
||||
depth = 1
|
||||
|
||||
module = LLLiteModule(
|
||||
name=module_name,
|
||||
is_conv2d=weights["down.0.weight"].ndim == 4,
|
||||
in_dim=weights["down.0.weight"].shape[1],
|
||||
depth=depth,
|
||||
cond_emb_dim=weights["conditioning1.0.weight"].shape[0] * 2,
|
||||
mlp_dim=weights["down.0.weight"].shape[0],
|
||||
multiplier=multiplier,
|
||||
num_steps=num_steps,
|
||||
start_step=start_step,
|
||||
end_step=end_step,
|
||||
)
|
||||
info = module.load_state_dict(weights)
|
||||
modules[module_name] = module
|
||||
if len(modules) == 1:
|
||||
module.is_first = True
|
||||
|
||||
print(f"loaded {path} successfully, {len(modules)} modules")
|
||||
|
||||
# cond imageをセットする
|
||||
cond_image = cond_image.permute(0, 3, 1, 2) # b,h,w,3 -> b,3,h,w
|
||||
cond_image = cond_image * 2.0 - 1.0 # 0-1 -> -1-+1
|
||||
|
||||
for module in modules.values():
|
||||
module.set_cond_image(cond_image)
|
||||
|
||||
class control_net_lllite_patch:
|
||||
def __init__(self, modules):
|
||||
self.modules = modules
|
||||
|
||||
def __call__(self, q, k, v, extra_options):
|
||||
module_pfx = extra_options_to_module_prefix(extra_options)
|
||||
|
||||
is_attn1 = q.shape[-1] == k.shape[-1] # self attention
|
||||
if is_attn1:
|
||||
module_pfx = module_pfx + "_attn1"
|
||||
else:
|
||||
module_pfx = module_pfx + "_attn2"
|
||||
|
||||
module_pfx_to_q = module_pfx + "_to_q"
|
||||
module_pfx_to_k = module_pfx + "_to_k"
|
||||
module_pfx_to_v = module_pfx + "_to_v"
|
||||
|
||||
if module_pfx_to_q in self.modules:
|
||||
q = q + self.modules[module_pfx_to_q](q)
|
||||
if module_pfx_to_k in self.modules:
|
||||
k = k + self.modules[module_pfx_to_k](k)
|
||||
if module_pfx_to_v in self.modules:
|
||||
v = v + self.modules[module_pfx_to_v](v)
|
||||
|
||||
return q, k, v
|
||||
|
||||
def to(self, device):
|
||||
for d in self.modules.keys():
|
||||
self.modules[d] = self.modules[d].to(device)
|
||||
return self
|
||||
|
||||
return control_net_lllite_patch(modules)
|
||||
|
||||
class LLLiteModule(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
is_conv2d: bool,
|
||||
in_dim: int,
|
||||
depth: int,
|
||||
cond_emb_dim: int,
|
||||
mlp_dim: int,
|
||||
multiplier: int,
|
||||
num_steps: int,
|
||||
start_step: int,
|
||||
end_step: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
self.is_conv2d = is_conv2d
|
||||
self.multiplier = multiplier
|
||||
self.num_steps = num_steps
|
||||
self.start_step = start_step
|
||||
self.end_step = end_step
|
||||
self.is_first = False
|
||||
|
||||
modules = []
|
||||
modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size*2
|
||||
if depth == 1:
|
||||
modules.append(torch.nn.ReLU(inplace=True))
|
||||
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
|
||||
elif depth == 2:
|
||||
modules.append(torch.nn.ReLU(inplace=True))
|
||||
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
|
||||
elif depth == 3:
|
||||
# kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4
|
||||
modules.append(torch.nn.ReLU(inplace=True))
|
||||
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
|
||||
modules.append(torch.nn.ReLU(inplace=True))
|
||||
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
|
||||
|
||||
self.conditioning1 = torch.nn.Sequential(*modules)
|
||||
|
||||
if self.is_conv2d:
|
||||
self.down = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
self.mid = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
self.up = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0),
|
||||
)
|
||||
else:
|
||||
self.down = torch.nn.Sequential(
|
||||
torch.nn.Linear(in_dim, mlp_dim),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
self.mid = torch.nn.Sequential(
|
||||
torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
self.up = torch.nn.Sequential(
|
||||
torch.nn.Linear(mlp_dim, in_dim),
|
||||
)
|
||||
|
||||
self.depth = depth
|
||||
self.cond_image = None
|
||||
self.cond_emb = None
|
||||
self.current_step = 0
|
||||
|
||||
# @torch.inference_mode()
|
||||
def set_cond_image(self, cond_image):
|
||||
# print("set_cond_image", self.name)
|
||||
self.cond_image = cond_image
|
||||
self.cond_emb = None
|
||||
self.current_step = 0
|
||||
|
||||
def forward(self, x):
|
||||
if self.num_steps > 0:
|
||||
if self.current_step < self.start_step:
|
||||
self.current_step += 1
|
||||
return torch.zeros_like(x)
|
||||
elif self.current_step >= self.end_step:
|
||||
if self.is_first and self.current_step == self.end_step:
|
||||
print(f"end LLLite: step {self.current_step}")
|
||||
self.current_step += 1
|
||||
if self.current_step >= self.num_steps:
|
||||
self.current_step = 0 # reset
|
||||
return torch.zeros_like(x)
|
||||
else:
|
||||
if self.is_first and self.current_step == self.start_step:
|
||||
print(f"start LLLite: step {self.current_step}")
|
||||
self.current_step += 1
|
||||
if self.current_step >= self.num_steps:
|
||||
self.current_step = 0 # reset
|
||||
|
||||
if self.cond_emb is None:
|
||||
# print(f"cond_emb is None, {self.name}")
|
||||
cx = self.conditioning1(self.cond_image.to(x.device, dtype=x.dtype))
|
||||
if not self.is_conv2d:
|
||||
# reshape / b,c,h,w -> b,h*w,c
|
||||
n, c, h, w = cx.shape
|
||||
cx = cx.view(n, c, h * w).permute(0, 2, 1)
|
||||
self.cond_emb = cx
|
||||
|
||||
cx = self.cond_emb
|
||||
# print(f"forward {self.name}, {cx.shape}, {x.shape}")
|
||||
|
||||
# uncond/condでxはバッチサイズが2倍
|
||||
if x.shape[0] != cx.shape[0]:
|
||||
if self.is_conv2d:
|
||||
cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1, 1)
|
||||
else:
|
||||
# print("x.shape[0] != cx.shape[0]", x.shape[0], cx.shape[0])
|
||||
cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1)
|
||||
|
||||
cx = torch.cat([cx, self.down(x)], dim=1 if self.is_conv2d else 2)
|
||||
cx = self.mid(cx)
|
||||
cx = self.up(cx)
|
||||
return cx * self.multiplier
|
||||
570
custom_nodes/ComfyUI-Easy-Use/py/libs/loader.py
Normal file
570
custom_nodes/ComfyUI-Easy-Use/py/libs/loader.py
Normal file
@@ -0,0 +1,570 @@
|
||||
import re, time, os, psutil
|
||||
import folder_paths
|
||||
import comfy.utils
|
||||
import comfy.sd
|
||||
import comfy.controlnet
|
||||
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from nodes import NODE_CLASS_MAPPINGS
|
||||
from collections import defaultdict
|
||||
from .log import log_node_info, log_node_error
|
||||
from ..modules.dit.pixArt.loader import load_pixart
|
||||
|
||||
diffusion_loaders = ["easy fullLoader", "easy a1111Loader", "easy fluxLoader", "easy comfyLoader", "easy hunyuanDiTLoader", "easy zero123Loader", "easy svdLoader"]
|
||||
stable_cascade_loaders = ["easy cascadeLoader"]
|
||||
dit_loaders = ['easy pixArtLoader']
|
||||
controlnet_loaders = ["easy controlnetLoader", "easy controlnetLoaderADV", "easy controlnetLoader++"]
|
||||
instant_loaders = ["easy instantIDApply", "easy instantIDApplyADV"]
|
||||
cascade_vae_node = ["easy preSamplingCascade", "easy fullCascadeKSampler"]
|
||||
model_merge_node = ["easy XYInputs: ModelMergeBlocks"]
|
||||
lora_widget = ["easy fullLoader", "easy a1111Loader", "easy comfyLoader", "easy fluxLoader"]
|
||||
|
||||
class easyLoader:
|
||||
def __init__(self):
|
||||
self.loaded_objects = {
|
||||
"ckpt": defaultdict(tuple), # {ckpt_name: (model, ...)}
|
||||
"unet": defaultdict(tuple),
|
||||
"clip": defaultdict(tuple),
|
||||
"clip_vision": defaultdict(tuple),
|
||||
"bvae": defaultdict(tuple),
|
||||
"vae": defaultdict(object),
|
||||
"lora": defaultdict(dict), # {lora_name: {UID: (model_lora, clip_lora)}}
|
||||
"controlnet": defaultdict(dict),
|
||||
"t5": defaultdict(tuple),
|
||||
"chatglm3": defaultdict(tuple),
|
||||
}
|
||||
self.memory_threshold = self.determine_memory_threshold(1)
|
||||
self.lora_name_cache = []
|
||||
|
||||
def clean_values(self, values: str):
|
||||
original_values = values.split("; ")
|
||||
cleaned_values = []
|
||||
|
||||
for value in original_values:
|
||||
cleaned_value = value.strip(';').strip()
|
||||
if cleaned_value == "":
|
||||
continue
|
||||
try:
|
||||
cleaned_value = int(cleaned_value)
|
||||
except ValueError:
|
||||
try:
|
||||
cleaned_value = float(cleaned_value)
|
||||
except ValueError:
|
||||
pass
|
||||
cleaned_values.append(cleaned_value)
|
||||
|
||||
return cleaned_values
|
||||
|
||||
def clear_unused_objects(self, desired_names: set, object_type: str):
|
||||
keys = set(self.loaded_objects[object_type].keys())
|
||||
for key in keys - desired_names:
|
||||
del self.loaded_objects[object_type][key]
|
||||
|
||||
def get_input_value(self, entry, key, prompt=None):
|
||||
val = entry["inputs"][key]
|
||||
if isinstance(val, str):
|
||||
return val
|
||||
elif isinstance(val, list):
|
||||
if prompt is not None and val[0]:
|
||||
return prompt[val[0]]['inputs'][key]
|
||||
else:
|
||||
return val[0]
|
||||
else:
|
||||
return str(val)
|
||||
|
||||
def process_pipe_loader(self, entry, desired_ckpt_names, desired_vae_names, desired_lora_names, desired_lora_settings, num_loras=3, suffix=""):
|
||||
for idx in range(1, num_loras + 1):
|
||||
lora_name_key = f"{suffix}lora{idx}_name"
|
||||
desired_lora_names.add(self.get_input_value(entry, lora_name_key))
|
||||
setting = f'{self.get_input_value(entry, lora_name_key)};{entry["inputs"][f"{suffix}lora{idx}_model_strength"]};{entry["inputs"][f"{suffix}lora{idx}_clip_strength"]}'
|
||||
desired_lora_settings.add(setting)
|
||||
|
||||
desired_ckpt_names.add(self.get_input_value(entry, f"{suffix}ckpt_name"))
|
||||
desired_vae_names.add(self.get_input_value(entry, f"{suffix}vae_name"))
|
||||
|
||||
def update_loaded_objects(self, prompt):
|
||||
desired_ckpt_names = set()
|
||||
desired_unet_names = set()
|
||||
desired_clip_names = set()
|
||||
desired_vae_names = set()
|
||||
desired_lora_names = set()
|
||||
desired_lora_settings = set()
|
||||
desired_controlnet_names = set()
|
||||
desired_t5_names = set()
|
||||
desired_glm3_names = set()
|
||||
|
||||
for entry in prompt.values():
|
||||
class_type = entry["class_type"]
|
||||
if class_type in lora_widget:
|
||||
lora_name = self.get_input_value(entry, "lora_name")
|
||||
desired_lora_names.add(lora_name)
|
||||
setting = f'{lora_name};{entry["inputs"]["lora_model_strength"]};{entry["inputs"]["lora_clip_strength"]}'
|
||||
desired_lora_settings.add(setting)
|
||||
|
||||
if class_type in diffusion_loaders:
|
||||
desired_ckpt_names.add(self.get_input_value(entry, "ckpt_name", prompt))
|
||||
desired_vae_names.add(self.get_input_value(entry, "vae_name"))
|
||||
|
||||
elif class_type in ['easy kolorsLoader']:
|
||||
desired_unet_names.add(self.get_input_value(entry, "unet_name"))
|
||||
desired_vae_names.add(self.get_input_value(entry, "vae_name"))
|
||||
desired_glm3_names.add(self.get_input_value(entry, "chatglm3_name"))
|
||||
|
||||
elif class_type in dit_loaders:
|
||||
t5_name = self.get_input_value(entry, "mt5_name") if "mt5_name" in entry["inputs"] else None
|
||||
clip_name = self.get_input_value(entry, "clip_name") if "clip_name" in entry["inputs"] else None
|
||||
model_name = self.get_input_value(entry, "model_name")
|
||||
ckpt_name = self.get_input_value(entry, "ckpt_name", prompt)
|
||||
if t5_name:
|
||||
desired_t5_names.add(t5_name)
|
||||
if clip_name:
|
||||
desired_clip_names.add(clip_name)
|
||||
desired_ckpt_names.add(ckpt_name+'_'+model_name)
|
||||
|
||||
elif class_type in stable_cascade_loaders:
|
||||
desired_unet_names.add(self.get_input_value(entry, "stage_c"))
|
||||
desired_unet_names.add(self.get_input_value(entry, "stage_b"))
|
||||
desired_clip_names.add(self.get_input_value(entry, "clip_name"))
|
||||
desired_vae_names.add(self.get_input_value(entry, "stage_a"))
|
||||
|
||||
elif class_type in cascade_vae_node:
|
||||
encode_vae_name = self.get_input_value(entry, "encode_vae_name")
|
||||
decode_vae_name = self.get_input_value(entry, "decode_vae_name")
|
||||
if encode_vae_name and encode_vae_name != 'None':
|
||||
desired_vae_names.add(encode_vae_name)
|
||||
if decode_vae_name and decode_vae_name != 'None':
|
||||
desired_vae_names.add(decode_vae_name)
|
||||
|
||||
elif class_type in controlnet_loaders:
|
||||
control_net_name = self.get_input_value(entry, "control_net_name", prompt)
|
||||
scale_soft_weights = self.get_input_value(entry, "scale_soft_weights")
|
||||
desired_controlnet_names.add(f'{control_net_name};{scale_soft_weights}')
|
||||
|
||||
elif class_type in instant_loaders:
|
||||
control_net_name = self.get_input_value(entry, "control_net_name", prompt)
|
||||
scale_soft_weights = self.get_input_value(entry, "cn_soft_weights")
|
||||
desired_controlnet_names.add(f'{control_net_name};{scale_soft_weights}')
|
||||
|
||||
elif class_type in model_merge_node:
|
||||
desired_ckpt_names.add(self.get_input_value(entry, "ckpt_name_1"))
|
||||
desired_ckpt_names.add(self.get_input_value(entry, "ckpt_name_2"))
|
||||
vae_use = self.get_input_value(entry, "vae_use")
|
||||
if vae_use != 'Use Model 1' and vae_use != 'Use Model 2':
|
||||
desired_vae_names.add(vae_use)
|
||||
|
||||
object_types = ["ckpt", "unet", "clip", "bvae", "vae", "lora", "controlnet", "t5"]
|
||||
for object_type in object_types:
|
||||
if object_type == 'unet':
|
||||
desired_names = desired_unet_names
|
||||
elif object_type in ["ckpt", "clip", "bvae"]:
|
||||
if object_type == 'clip':
|
||||
desired_names = desired_ckpt_names.union(desired_clip_names)
|
||||
else:
|
||||
desired_names = desired_ckpt_names
|
||||
elif object_type == "vae":
|
||||
desired_names = desired_vae_names
|
||||
elif object_type == "controlnet":
|
||||
desired_names = desired_controlnet_names
|
||||
elif object_type == "t5":
|
||||
desired_names = desired_t5_names
|
||||
elif object_type == "chatglm3":
|
||||
desired_names = desired_glm3_names
|
||||
else:
|
||||
desired_names = desired_lora_names
|
||||
self.clear_unused_objects(desired_names, object_type)
|
||||
|
||||
def add_to_cache(self, obj_type, key, value):
|
||||
"""
|
||||
Add an item to the cache with the current timestamp.
|
||||
"""
|
||||
timestamped_value = (value, time.time())
|
||||
self.loaded_objects[obj_type][key] = timestamped_value
|
||||
|
||||
def determine_memory_threshold(self, percentage=0.8):
|
||||
"""
|
||||
Determines the memory threshold as a percentage of the total available memory.
|
||||
Args:
|
||||
- percentage (float): The fraction of total memory to use as the threshold.
|
||||
Should be a value between 0 and 1. Default is 0.8 (80%).
|
||||
Returns:
|
||||
- memory_threshold (int): Memory threshold in bytes.
|
||||
"""
|
||||
total_memory = psutil.virtual_memory().total
|
||||
memory_threshold = total_memory * percentage
|
||||
return memory_threshold
|
||||
|
||||
def get_memory_usage(self):
|
||||
"""
|
||||
Returns the memory usage of the current process in bytes.
|
||||
"""
|
||||
process = psutil.Process(os.getpid())
|
||||
return process.memory_info().rss
|
||||
|
||||
def eviction_based_on_memory(self):
|
||||
"""
|
||||
Evicts objects from cache based on memory usage and priority.
|
||||
"""
|
||||
current_memory = self.get_memory_usage()
|
||||
if current_memory < self.memory_threshold:
|
||||
return
|
||||
eviction_order = ["vae", "lora", "bvae", "clip", "ckpt", "controlnet", "unet", "t5", "chatglm3"]
|
||||
for obj_type in eviction_order:
|
||||
if current_memory < self.memory_threshold:
|
||||
break
|
||||
# Sort items based on age (using the timestamp)
|
||||
items = list(self.loaded_objects[obj_type].items())
|
||||
items.sort(key=lambda x: x[1][1]) # Sorting by timestamp
|
||||
|
||||
for item in items:
|
||||
if current_memory < self.memory_threshold:
|
||||
break
|
||||
del self.loaded_objects[obj_type][item[0]]
|
||||
current_memory = self.get_memory_usage()
|
||||
|
||||
def load_checkpoint(self, ckpt_name, config_name=None, load_vision=False):
|
||||
cache_name = ckpt_name
|
||||
if config_name not in [None, "Default"]:
|
||||
cache_name = ckpt_name + "_" + config_name
|
||||
if cache_name in self.loaded_objects["ckpt"]:
|
||||
clip_vision = self.loaded_objects["clip_vision"][cache_name][0] if load_vision else None
|
||||
clip = self.loaded_objects["clip"][cache_name][0] if not load_vision else None
|
||||
return self.loaded_objects["ckpt"][cache_name][0], clip, self.loaded_objects["bvae"][cache_name][0], clip_vision
|
||||
|
||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
||||
|
||||
output_clip = False if load_vision else True
|
||||
output_clipvision = True if load_vision else False
|
||||
if config_name not in [None, "Default"]:
|
||||
config_path = folder_paths.get_full_path("configs", config_name)
|
||||
loaded_ckpt = comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
else:
|
||||
model_options = {}
|
||||
if re.search("nf4", ckpt_name):
|
||||
from ..modules.bitsandbytes_NF4 import OPS
|
||||
model_options = {"custom_operations": OPS}
|
||||
loaded_ckpt = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=output_clip, output_clipvision=output_clipvision, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options=model_options)
|
||||
|
||||
self.add_to_cache("ckpt", cache_name, loaded_ckpt[0])
|
||||
self.add_to_cache("bvae", cache_name, loaded_ckpt[2])
|
||||
|
||||
clip = loaded_ckpt[1]
|
||||
clip_vision = loaded_ckpt[3]
|
||||
if clip:
|
||||
self.add_to_cache("clip", cache_name, clip)
|
||||
if clip_vision:
|
||||
self.add_to_cache("clip_vision", cache_name, clip_vision)
|
||||
|
||||
self.eviction_based_on_memory()
|
||||
|
||||
return loaded_ckpt[0], clip, loaded_ckpt[2], clip_vision
|
||||
|
||||
def load_vae(self, vae_name):
|
||||
if vae_name in self.loaded_objects["vae"]:
|
||||
return self.loaded_objects["vae"][vae_name][0]
|
||||
|
||||
vae_path = folder_paths.get_full_path("vae", vae_name)
|
||||
sd = comfy.utils.load_torch_file(vae_path)
|
||||
loaded_vae = comfy.sd.VAE(sd=sd)
|
||||
self.add_to_cache("vae", vae_name, loaded_vae)
|
||||
self.eviction_based_on_memory()
|
||||
|
||||
return loaded_vae
|
||||
|
||||
def load_unet(self, unet_name):
|
||||
if unet_name in self.loaded_objects["unet"]:
|
||||
log_node_info("Load UNet", f"{unet_name} cached")
|
||||
return self.loaded_objects["unet"][unet_name][0]
|
||||
|
||||
unet_path = folder_paths.get_full_path("unet", unet_name)
|
||||
model = comfy.sd.load_unet(unet_path)
|
||||
self.add_to_cache("unet", unet_name, model)
|
||||
self.eviction_based_on_memory()
|
||||
|
||||
return model
|
||||
|
||||
def load_controlnet(self, control_net_name, scale_soft_weights=1, use_cache=True):
|
||||
unique_id = f'{control_net_name};{str(scale_soft_weights)}'
|
||||
if use_cache and unique_id in self.loaded_objects["controlnet"]:
|
||||
return self.loaded_objects["controlnet"][unique_id][0]
|
||||
if scale_soft_weights < 1:
|
||||
if "ScaledSoftControlNetWeights" in NODE_CLASS_MAPPINGS:
|
||||
soft_weight_cls = NODE_CLASS_MAPPINGS['ScaledSoftControlNetWeights']
|
||||
(weights, timestep_keyframe) = soft_weight_cls().load_weights(scale_soft_weights, False)
|
||||
cn_adv_cls = NODE_CLASS_MAPPINGS['ControlNetLoaderAdvanced']
|
||||
control_net, = cn_adv_cls().load_controlnet(control_net_name, timestep_keyframe)
|
||||
else:
|
||||
raise Exception(f"[Advanced-ControlNet Not Found] you need to install 'COMFYUI-Advanced-ControlNet'")
|
||||
else:
|
||||
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
|
||||
control_net = comfy.controlnet.load_controlnet(controlnet_path)
|
||||
if use_cache:
|
||||
self.add_to_cache("controlnet", unique_id, control_net)
|
||||
self.eviction_based_on_memory()
|
||||
|
||||
return control_net
|
||||
def load_clip(self, clip_name, type='stable_diffusion', load_clip=None):
|
||||
if clip_name in self.loaded_objects["clip"]:
|
||||
return self.loaded_objects["clip"][clip_name][0]
|
||||
|
||||
if type == 'stable_diffusion':
|
||||
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
||||
elif type == 'stable_cascade':
|
||||
clip_type = comfy.sd.CLIPType.STABLE_CASCADE
|
||||
elif type == 'sd3':
|
||||
clip_type = comfy.sd.CLIPType.SD3
|
||||
elif type == 'flux':
|
||||
clip_type = comfy.sd.CLIPType.FLUX
|
||||
elif type == 'stable_audio':
|
||||
clip_type = comfy.sd.CLIPType.STABLE_AUDIO
|
||||
clip_path = folder_paths.get_full_path("clip", clip_name)
|
||||
load_clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
|
||||
self.add_to_cache("clip", clip_name, load_clip)
|
||||
self.eviction_based_on_memory()
|
||||
|
||||
return load_clip
|
||||
|
||||
def load_lora(self, lora, model=None, clip=None, type=None , use_cache=True):
|
||||
lora_name = lora["lora_name"]
|
||||
model = model if model is not None else lora["model"]
|
||||
clip = clip if clip is not None else lora["clip"]
|
||||
model_strength = lora["model_strength"]
|
||||
clip_strength = lora["clip_strength"]
|
||||
lbw = lora["lbw"] if "lbw" in lora else None
|
||||
lbw_a = lora["lbw_a"] if "lbw_a" in lora else None
|
||||
lbw_b = lora["lbw_b"] if "lbw_b" in lora else None
|
||||
|
||||
model_hash = str(model)[44:-1]
|
||||
clip_hash = str(clip)[25:-1] if clip else ''
|
||||
|
||||
unique_id = f'{model_hash};{clip_hash};{lora_name};{model_strength};{clip_strength}'
|
||||
|
||||
if use_cache and unique_id in self.loaded_objects["lora"]:
|
||||
log_node_info("Load LORA",f"{lora_name} cached")
|
||||
return self.loaded_objects["lora"][unique_id][0]
|
||||
|
||||
orig_lora_name = lora_name
|
||||
lora_name = self.resolve_lora_name(lora_name)
|
||||
|
||||
if lora_name is not None:
|
||||
lora_path = folder_paths.get_full_path("loras", lora_name)
|
||||
else:
|
||||
lora_path = None
|
||||
|
||||
if lora_path is not None:
|
||||
log_node_info("Load LORA",f"{lora_name}: model={model_strength:.3f}, clip={clip_strength:.3f}, LBW={lbw}, A={lbw_a}, B={lbw_b}")
|
||||
if lbw:
|
||||
lbw = lora["lbw"]
|
||||
lbw_a = lora["lbw_a"]
|
||||
lbw_b = lora["lbw_b"]
|
||||
if 'LoraLoaderBlockWeight //Inspire' not in NODE_CLASS_MAPPINGS:
|
||||
raise Exception('[InspirePack Not Found] you need to install ComfyUI-Inspire-Pack')
|
||||
cls = NODE_CLASS_MAPPINGS['LoraLoaderBlockWeight //Inspire']
|
||||
model, clip, _ = cls().doit(model, clip, lora_name, model_strength, clip_strength, False, 0,
|
||||
lbw_a, lbw_b, "", lbw)
|
||||
else:
|
||||
_lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
|
||||
keys = _lora.keys()
|
||||
if "down_blocks.0.resnets.0.norm1.bias" in keys:
|
||||
print('Using LORA for Resadapter')
|
||||
key_map = {}
|
||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||
mapping_norm = {}
|
||||
|
||||
for key in keys:
|
||||
if ".weight" in key:
|
||||
key_name_in_ori_sd = key_map[key.replace(".weight", "")]
|
||||
mapping_norm[key_name_in_ori_sd] = _lora[key]
|
||||
elif ".bias" in key:
|
||||
key_name_in_ori_sd = key_map[key.replace(".bias", "")]
|
||||
mapping_norm[key_name_in_ori_sd.replace(".weight", ".bias")] = _lora[
|
||||
key
|
||||
]
|
||||
else:
|
||||
print("===>Unexpected key", key)
|
||||
mapping_norm[key] = _lora[key]
|
||||
|
||||
for k in mapping_norm.keys():
|
||||
if k not in model.model.state_dict():
|
||||
print("===>Missing key:", k)
|
||||
model.model.load_state_dict(mapping_norm, strict=False)
|
||||
return (model, clip)
|
||||
|
||||
# PixArt
|
||||
if type is not None and type == 'PixArt':
|
||||
from ..modules.dit.pixArt.loader import load_pixart_lora
|
||||
model = load_pixart_lora(model, _lora, lora_path, model_strength)
|
||||
else:
|
||||
model, clip = comfy.sd.load_lora_for_models(model, clip, _lora, model_strength, clip_strength)
|
||||
|
||||
if use_cache:
|
||||
self.add_to_cache("lora", unique_id, (model, clip))
|
||||
self.eviction_based_on_memory()
|
||||
else:
|
||||
log_node_error(f"LORA NOT FOUND", orig_lora_name)
|
||||
|
||||
return model, clip
|
||||
|
||||
def resolve_lora_name(self, name):
|
||||
if os.path.exists(name):
|
||||
return name
|
||||
else:
|
||||
if len(self.lora_name_cache) == 0:
|
||||
loras = folder_paths.get_filename_list("loras")
|
||||
self.lora_name_cache.extend(loras)
|
||||
for x in self.lora_name_cache:
|
||||
if x.endswith(name):
|
||||
return x
|
||||
|
||||
# 如果刷新网页后新添加的lora走这个逻辑
|
||||
log_node_info("LORA NOT IN CACHE", f"{name}")
|
||||
loras = folder_paths.get_filename_list("loras")
|
||||
for x in loras:
|
||||
if x.endswith(name):
|
||||
self.lora_name_cache.append(x)
|
||||
return x
|
||||
|
||||
return None
|
||||
|
||||
def load_main(self, ckpt_name, config_name, vae_name, lora_name, lora_model_strength, lora_clip_strength, optional_lora_stack, model_override, clip_override, vae_override, prompt, nf4=False):
|
||||
model: ModelPatcher | None = None
|
||||
clip: comfy.sd.CLIP | None = None
|
||||
vae: comfy.sd.VAE | None = None
|
||||
clip_vision = None
|
||||
lora_stack = []
|
||||
|
||||
# Check for model override
|
||||
can_load_lora = True
|
||||
# 判断是否存在 模型或Lora叠加xyplot, 若存在优先缓存第一个模型
|
||||
# Determine whether there is a model or Lora overlapping xyplot, and if there is, prioritize caching the first model.
|
||||
xy_model_id = next((x for x in prompt if str(prompt[x]["class_type"]) in ["easy XYInputs: ModelMergeBlocks",
|
||||
"easy XYInputs: Checkpoint"]), None)
|
||||
# This will find nodes that aren't actively connected to anything, and skip loading lora's for them.
|
||||
xy_lora_id = next((x for x in prompt if str(prompt[x]["class_type"]) == "easy XYInputs: Lora"), None)
|
||||
if xy_lora_id is not None:
|
||||
can_load_lora = False
|
||||
if xy_model_id is not None:
|
||||
node = prompt[xy_model_id]
|
||||
if "ckpt_name_1" in node["inputs"]:
|
||||
ckpt_name_1 = node["inputs"]["ckpt_name_1"]
|
||||
model, clip, vae, clip_vision = self.load_checkpoint(ckpt_name_1)
|
||||
can_load_lora = False
|
||||
elif model_override is not None and clip_override is not None and vae_override is not None:
|
||||
model = model_override
|
||||
clip = clip_override
|
||||
vae = vae_override
|
||||
else:
|
||||
model, clip, vae, clip_vision = self.load_checkpoint(ckpt_name, config_name)
|
||||
if model_override is not None:
|
||||
model = model_override
|
||||
if vae_override is not None:
|
||||
vae = vae_override
|
||||
elif clip_override is not None:
|
||||
clip = clip_override
|
||||
|
||||
|
||||
if optional_lora_stack is not None and can_load_lora:
|
||||
for lora in optional_lora_stack:
|
||||
# This is a subtle bit of code because it uses the model created by the last call, and passes it to the next call.
|
||||
lora = {"lora_name": lora[0], "model": model, "clip": clip, "model_strength": lora[1],
|
||||
"clip_strength": lora[2]}
|
||||
model, clip = self.load_lora(lora)
|
||||
lora['model'] = model
|
||||
lora['clip'] = clip
|
||||
lora_stack.append(lora)
|
||||
|
||||
if lora_name != "None" and can_load_lora:
|
||||
lora = {"lora_name": lora_name, "model": model, "clip": clip, "model_strength": lora_model_strength,
|
||||
"clip_strength": lora_clip_strength}
|
||||
model, clip = self.load_lora(lora)
|
||||
lora_stack.append(lora)
|
||||
|
||||
# Check for custom VAE
|
||||
if vae_name not in ["Baked VAE", "Baked-VAE"]:
|
||||
vae = self.load_vae(vae_name)
|
||||
# CLIP skip
|
||||
if not clip:
|
||||
raise Exception("No CLIP found")
|
||||
|
||||
return model, clip, vae, clip_vision, lora_stack
|
||||
|
||||
# Kolors
|
||||
def load_kolors_unet(self, unet_name):
|
||||
if unet_name in self.loaded_objects["unet"]:
|
||||
log_node_info("Load Kolors UNet", f"{unet_name} cached")
|
||||
return self.loaded_objects["unet"][unet_name][0]
|
||||
else:
|
||||
from ..modules.kolors.loader import applyKolorsUnet
|
||||
with applyKolorsUnet():
|
||||
unet_path = folder_paths.get_full_path("unet", unet_name)
|
||||
sd = comfy.utils.load_torch_file(unet_path)
|
||||
model = comfy.sd.load_unet_state_dict(sd)
|
||||
if model is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||
|
||||
self.add_to_cache("unet", unet_name, model)
|
||||
self.eviction_based_on_memory()
|
||||
|
||||
return model
|
||||
|
||||
def load_chatglm3(self, chatglm3_name):
|
||||
from ..modules.kolors.loader import load_chatglm3
|
||||
if chatglm3_name in self.loaded_objects["chatglm3"]:
|
||||
log_node_info("Load ChatGLM3", f"{chatglm3_name} cached")
|
||||
return self.loaded_objects["chatglm3"][chatglm3_name][0]
|
||||
|
||||
chatglm_model = load_chatglm3(model_path=folder_paths.get_full_path("llm", chatglm3_name))
|
||||
self.add_to_cache("chatglm3", chatglm3_name, chatglm_model)
|
||||
self.eviction_based_on_memory()
|
||||
|
||||
return chatglm_model
|
||||
|
||||
|
||||
# DiT
|
||||
def load_dit_ckpt(self, ckpt_name, model_name, **kwargs):
|
||||
if (ckpt_name+'_'+model_name) in self.loaded_objects["ckpt"]:
|
||||
return self.loaded_objects["ckpt"][ckpt_name+'_'+model_name][0]
|
||||
model = None
|
||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
||||
model_type = kwargs['model_type'] if "model_type" in kwargs else 'PixArt'
|
||||
if model_type == 'PixArt':
|
||||
pixart_conf = kwargs['pixart_conf']
|
||||
model_conf = pixart_conf[model_name]
|
||||
model = load_pixart(ckpt_path, model_conf)
|
||||
if model:
|
||||
self.add_to_cache("ckpt", ckpt_name + '_' + model_name, model)
|
||||
self.eviction_based_on_memory()
|
||||
return model
|
||||
|
||||
def load_t5_from_sd3_clip(self, sd3_clip, padding):
|
||||
try:
|
||||
from comfy.text_encoders.sd3_clip import SD3Tokenizer, SD3ClipModel
|
||||
except:
|
||||
from comfy.sd3_clip import SD3Tokenizer, SD3ClipModel
|
||||
import copy
|
||||
|
||||
clip = sd3_clip.clone()
|
||||
assert clip.cond_stage_model.t5xxl is not None, "CLIP must have T5 loaded!"
|
||||
|
||||
# remove transformer
|
||||
transformer = clip.cond_stage_model.t5xxl.transformer
|
||||
clip.cond_stage_model.t5xxl.transformer = None
|
||||
|
||||
# clone object
|
||||
tmp = SD3ClipModel(clip_l=False, clip_g=False, t5=False)
|
||||
tmp.t5xxl = copy.deepcopy(clip.cond_stage_model.t5xxl)
|
||||
# put transformer back
|
||||
clip.cond_stage_model.t5xxl.transformer = transformer
|
||||
tmp.t5xxl.transformer = transformer
|
||||
|
||||
# override special tokens
|
||||
tmp.t5xxl.special_tokens = copy.deepcopy(clip.cond_stage_model.t5xxl.special_tokens)
|
||||
tmp.t5xxl.special_tokens.pop("end") # make sure empty tokens match
|
||||
|
||||
# tokenizer
|
||||
tok = SD3Tokenizer()
|
||||
tok.t5xxl.min_length = padding
|
||||
|
||||
clip.cond_stage_model = tmp
|
||||
clip.tokenizer = tok
|
||||
|
||||
return clip
|
||||
77
custom_nodes/ComfyUI-Easy-Use/py/libs/log.py
Normal file
77
custom_nodes/ComfyUI-Easy-Use/py/libs/log.py
Normal file
@@ -0,0 +1,77 @@
|
||||
COLORS_FG = {
|
||||
'BLACK': '\33[30m',
|
||||
'RED': '\33[31m',
|
||||
'GREEN': '\33[32m',
|
||||
'YELLOW': '\33[33m',
|
||||
'BLUE': '\33[34m',
|
||||
'MAGENTA': '\33[35m',
|
||||
'CYAN': '\33[36m',
|
||||
'WHITE': '\33[37m',
|
||||
'GREY': '\33[90m',
|
||||
'BRIGHT_RED': '\33[91m',
|
||||
'BRIGHT_GREEN': '\33[92m',
|
||||
'BRIGHT_YELLOW': '\33[93m',
|
||||
'BRIGHT_BLUE': '\33[94m',
|
||||
'BRIGHT_MAGENTA': '\33[95m',
|
||||
'BRIGHT_CYAN': '\33[96m',
|
||||
'BRIGHT_WHITE': '\33[97m',
|
||||
}
|
||||
COLORS_STYLE = {
|
||||
'RESET': '\33[0m',
|
||||
'BOLD': '\33[1m',
|
||||
'NORMAL': '\33[22m',
|
||||
'ITALIC': '\33[3m',
|
||||
'UNDERLINE': '\33[4m',
|
||||
'BLINK': '\33[5m',
|
||||
'BLINK2': '\33[6m',
|
||||
'SELECTED': '\33[7m',
|
||||
}
|
||||
COLORS_BG = {
|
||||
'BLACK': '\33[40m',
|
||||
'RED': '\33[41m',
|
||||
'GREEN': '\33[42m',
|
||||
'YELLOW': '\33[43m',
|
||||
'BLUE': '\33[44m',
|
||||
'MAGENTA': '\33[45m',
|
||||
'CYAN': '\33[46m',
|
||||
'WHITE': '\33[47m',
|
||||
'GREY': '\33[100m',
|
||||
'BRIGHT_RED': '\33[101m',
|
||||
'BRIGHT_GREEN': '\33[102m',
|
||||
'BRIGHT_YELLOW': '\33[103m',
|
||||
'BRIGHT_BLUE': '\33[104m',
|
||||
'BRIGHT_MAGENTA': '\33[105m',
|
||||
'BRIGHT_CYAN': '\33[106m',
|
||||
'BRIGHT_WHITE': '\33[107m',
|
||||
}
|
||||
|
||||
def log_node_success(node_name, message=None):
|
||||
"""Logs a success message."""
|
||||
_log_node(COLORS_FG["GREEN"], node_name, message)
|
||||
|
||||
def log_node_info(node_name, message=None):
|
||||
"""Logs an info message."""
|
||||
_log_node(COLORS_FG["CYAN"], node_name, message)
|
||||
|
||||
|
||||
def log_node_warn(node_name, message=None):
|
||||
"""Logs an warn message."""
|
||||
_log_node(COLORS_FG["YELLOW"], node_name, message)
|
||||
|
||||
def log_node_error(node_name, message=None):
|
||||
"""Logs an warn message."""
|
||||
_log_node(COLORS_FG["RED"], node_name, message)
|
||||
|
||||
def log_node(node_name, message=None):
|
||||
"""Logs a message."""
|
||||
_log_node(COLORS_FG["CYAN"], node_name, message)
|
||||
|
||||
|
||||
def _log_node(color, node_name, message=None, prefix=''):
|
||||
print(_get_log_msg(color, node_name, message, prefix=prefix))
|
||||
|
||||
def _get_log_msg(color, node_name, message=None, prefix=''):
|
||||
msg = f'{COLORS_STYLE["BOLD"]}{color}{prefix}[EasyUse] {node_name.replace(" (EasyUse)", "")}'
|
||||
msg += f':{COLORS_STYLE["RESET"]} {message}' if message is not None else f'{COLORS_STYLE["RESET"]}'
|
||||
return msg
|
||||
|
||||
133
custom_nodes/ComfyUI-Easy-Use/py/libs/math.py
Normal file
133
custom_nodes/ComfyUI-Easy-Use/py/libs/math.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
Math utility functions for formula evaluation
|
||||
"""
|
||||
import math
|
||||
import re
|
||||
|
||||
def evaluate_formula(formula: str, a=0, b=0, c=0, d=0) -> float:
|
||||
"""
|
||||
计算字符串数学公式
|
||||
|
||||
支持的运算符和函数:
|
||||
- 基本运算:+, -, *, /, //, %, **
|
||||
- 比较运算:>, <, >=, <=, ==, !=
|
||||
- 数学函数:abs, pow, round, ceil, floor, sqrt, exp, log, log10
|
||||
- 三角函数:sin, cos, tan, asin, acos, atan
|
||||
- 常量:pi, e
|
||||
|
||||
Args:
|
||||
formula: 数学公式字符串,可以使用变量a、b、c、d
|
||||
a: 变量a的值
|
||||
b: 变量b的值
|
||||
c: 变量c的值
|
||||
d: 变量d的值
|
||||
|
||||
Returns:
|
||||
计算结果
|
||||
|
||||
Examples:
|
||||
>>> evaluate_formula("a + b", 1, 2)
|
||||
3.0
|
||||
>>> evaluate_formula("pow(a, 2)", 5)
|
||||
25.0
|
||||
>>> evaluate_formula("ceil(a / b)", 5, 2)
|
||||
3.0
|
||||
>>> evaluate_formula("(a>b)*b+(a<=b)*a", 5, 3)
|
||||
3.0
|
||||
>>> evaluate_formula("(a>b)*b+(a<=b)*a", 2, 3)
|
||||
2.0
|
||||
"""
|
||||
# 安全的数学函数白名单
|
||||
safe_dict = {
|
||||
# 基本运算
|
||||
'abs': abs,
|
||||
'pow': pow,
|
||||
'round': round,
|
||||
# 数学函数
|
||||
'ceil': math.ceil,
|
||||
'floor': math.floor,
|
||||
'sqrt': math.sqrt,
|
||||
'exp': math.exp,
|
||||
'log': math.log,
|
||||
'log10': math.log10,
|
||||
# 三角函数
|
||||
'sin': math.sin,
|
||||
'cos': math.cos,
|
||||
'tan': math.tan,
|
||||
'asin': math.asin,
|
||||
'acos': math.acos,
|
||||
'atan': math.atan,
|
||||
# 常量
|
||||
'pi': math.pi,
|
||||
'e': math.e,
|
||||
# 变量
|
||||
'a': float(a),
|
||||
'b': float(b),
|
||||
'c': float(c),
|
||||
'd': float(d),
|
||||
}
|
||||
|
||||
try:
|
||||
# 使用eval计算公式,限制可用的函数和变量
|
||||
result = eval(formula, {"__builtins__": {}}, safe_dict)
|
||||
return float(result)
|
||||
except Exception as e:
|
||||
raise ValueError(f"公式计算错误: {str(e)}")
|
||||
|
||||
|
||||
def ceil_value(value: float) -> int:
|
||||
"""向上取整"""
|
||||
return math.ceil(value)
|
||||
|
||||
|
||||
def floor_value(value: float) -> int:
|
||||
"""向下取整"""
|
||||
return math.floor(value)
|
||||
|
||||
|
||||
def round_value(value: float, decimals: int = 0) -> float:
|
||||
"""
|
||||
四舍五入
|
||||
|
||||
Args:
|
||||
value: 要取整的值
|
||||
decimals: 保留小数位数
|
||||
|
||||
Returns:
|
||||
四舍五入后的值
|
||||
"""
|
||||
return round(value, decimals)
|
||||
|
||||
|
||||
def power(base: float, exponent: float) -> float:
|
||||
"""计算幂运算"""
|
||||
return math.pow(base, exponent)
|
||||
|
||||
|
||||
def sqrt_value(value: float) -> float:
|
||||
"""计算平方根"""
|
||||
if value < 0:
|
||||
raise ValueError("不能对负数求平方根")
|
||||
return math.sqrt(value)
|
||||
|
||||
|
||||
def add(a: float, b: float) -> float:
|
||||
"""加法"""
|
||||
return a + b
|
||||
|
||||
|
||||
def subtract(a: float, b: float) -> float:
|
||||
"""减法"""
|
||||
return a - b
|
||||
|
||||
|
||||
def multiply(a: float, b: float) -> float:
|
||||
"""乘法"""
|
||||
return a * b
|
||||
|
||||
|
||||
def divide(a: float, b: float) -> float:
|
||||
"""除法"""
|
||||
if b == 0:
|
||||
raise ValueError("除数不能为零")
|
||||
return a / b
|
||||
55
custom_nodes/ComfyUI-Easy-Use/py/libs/messages.py
Normal file
55
custom_nodes/ComfyUI-Easy-Use/py/libs/messages.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from server import PromptServer
|
||||
from aiohttp import web
|
||||
import time
|
||||
import json
|
||||
|
||||
class MessageCancelled(Exception):
|
||||
pass
|
||||
|
||||
class Message:
|
||||
stash = {}
|
||||
messages = {}
|
||||
cancelled = False
|
||||
|
||||
@classmethod
|
||||
def addMessage(cls, id, message):
|
||||
if message == '__cancel__':
|
||||
cls.messages = {}
|
||||
cls.cancelled = True
|
||||
elif message == '__start__':
|
||||
cls.messages = {}
|
||||
cls.stash = {}
|
||||
cls.cancelled = False
|
||||
else:
|
||||
cls.messages[str(id)] = message
|
||||
|
||||
@classmethod
|
||||
def waitForMessage(cls, id, period=0.1, asList=False):
|
||||
sid = str(id)
|
||||
while not (sid in cls.messages) and not ("-1" in cls.messages):
|
||||
if cls.cancelled:
|
||||
cls.cancelled = False
|
||||
raise MessageCancelled()
|
||||
time.sleep(period)
|
||||
if cls.cancelled:
|
||||
cls.cancelled = False
|
||||
raise MessageCancelled()
|
||||
message = cls.messages.pop(str(id), None) or cls.messages.pop("-1")
|
||||
try:
|
||||
if asList:
|
||||
return [str(x.strip()) for x in message.split(",")]
|
||||
else:
|
||||
try:
|
||||
return json.loads(message)
|
||||
except ValueError:
|
||||
return message
|
||||
except ValueError:
|
||||
print( f"ERROR IN MESSAGE - failed to parse '${message}' as ${'comma separated list of strings' if asList else 'string'}")
|
||||
return [message] if asList else message
|
||||
|
||||
|
||||
@PromptServer.instance.routes.post('/easyuse/message_callback')
|
||||
async def message_callback(request):
|
||||
post = await request.post()
|
||||
Message.addMessage(post.get("id"), post.get("message"))
|
||||
return web.json_response({})
|
||||
58
custom_nodes/ComfyUI-Easy-Use/py/libs/model.py
Normal file
58
custom_nodes/ComfyUI-Easy-Use/py/libs/model.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import json
|
||||
import os
|
||||
import folder_paths
|
||||
import server
|
||||
from .utils import find_tags
|
||||
|
||||
class easyModelManager:
|
||||
|
||||
def __init__(self):
|
||||
self.img_suffixes = [".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".tiff", ".svg", ".tif", ".tiff"]
|
||||
self.default_suffixes = [".ckpt", ".pt", ".bin", ".pth", ".safetensors"]
|
||||
self.models_config = {
|
||||
"checkpoints": {"suffix": self.default_suffixes},
|
||||
"loras": {"suffix": self.default_suffixes},
|
||||
"unet": {"suffix": self.default_suffixes},
|
||||
}
|
||||
self.model_lists = {}
|
||||
|
||||
def find_thumbnail(self, model_type, name):
|
||||
file_no_ext = os.path.splitext(name)[0]
|
||||
for ext in self.img_suffixes:
|
||||
full_path = folder_paths.get_full_path(model_type, file_no_ext + ext)
|
||||
if os.path.isfile(str(full_path)):
|
||||
return full_path
|
||||
return None
|
||||
|
||||
def get_model_lists(self, model_type):
|
||||
if model_type not in self.models_config:
|
||||
return []
|
||||
filenames = folder_paths.get_filename_list(model_type)
|
||||
model_lists = []
|
||||
for name in filenames:
|
||||
model_suffix = os.path.splitext(name)[-1]
|
||||
if model_suffix not in self.models_config[model_type]["suffix"]:
|
||||
continue
|
||||
else:
|
||||
cfg = {
|
||||
"name": os.path.basename(os.path.splitext(name)[0]),
|
||||
"full_name": name,
|
||||
"remark": '',
|
||||
"file_path": folder_paths.get_full_path(model_type, name),
|
||||
"type": model_type,
|
||||
"suffix": model_suffix,
|
||||
"dir_tags": find_tags(name),
|
||||
"cover": self.find_thumbnail(model_type, name),
|
||||
"metadata": None,
|
||||
"sha256": None
|
||||
}
|
||||
model_lists.append(cfg)
|
||||
|
||||
return model_lists
|
||||
|
||||
def get_model_info(self, model_type, model_name):
|
||||
pass
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# manager = easyModelManager()
|
||||
# print(manager.get_model_lists("checkpoints"))
|
||||
1053
custom_nodes/ComfyUI-Easy-Use/py/libs/sampler.py
Normal file
1053
custom_nodes/ComfyUI-Easy-Use/py/libs/sampler.py
Normal file
File diff suppressed because it is too large
Load Diff
148
custom_nodes/ComfyUI-Easy-Use/py/libs/styleAlign.py
Normal file
148
custom_nodes/ComfyUI-Easy-Use/py/libs/styleAlign.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from typing import Union
|
||||
|
||||
T = torch.Tensor
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d
|
||||
|
||||
|
||||
class StyleAlignedArgs:
|
||||
def __init__(self, share_attn: str) -> None:
|
||||
self.adain_keys = "k" in share_attn
|
||||
self.adain_values = "v" in share_attn
|
||||
self.adain_queries = "q" in share_attn
|
||||
|
||||
share_attention: bool = True
|
||||
adain_queries: bool = True
|
||||
adain_keys: bool = True
|
||||
adain_values: bool = True
|
||||
|
||||
|
||||
def expand_first(
|
||||
feat: T,
|
||||
scale=1.0,
|
||||
) -> T:
|
||||
"""
|
||||
Expand the first element so it has the same shape as the rest of the batch.
|
||||
"""
|
||||
b = feat.shape[0]
|
||||
feat_style = torch.stack((feat[0], feat[b // 2])).unsqueeze(1)
|
||||
if scale == 1:
|
||||
feat_style = feat_style.expand(2, b // 2, *feat.shape[1:])
|
||||
else:
|
||||
feat_style = feat_style.repeat(1, b // 2, 1, 1, 1)
|
||||
feat_style = torch.cat([feat_style[:, :1], scale * feat_style[:, 1:]], dim=1)
|
||||
return feat_style.reshape(*feat.shape)
|
||||
|
||||
|
||||
def concat_first(feat: T, dim=2, scale=1.0) -> T:
|
||||
"""
|
||||
concat the the feature and the style feature expanded above
|
||||
"""
|
||||
feat_style = expand_first(feat, scale=scale)
|
||||
return torch.cat((feat, feat_style), dim=dim)
|
||||
|
||||
|
||||
def calc_mean_std(feat, eps: float = 1e-5) -> "tuple[T, T]":
|
||||
feat_std = (feat.var(dim=-2, keepdims=True) + eps).sqrt()
|
||||
feat_mean = feat.mean(dim=-2, keepdims=True)
|
||||
return feat_mean, feat_std
|
||||
|
||||
def adain(feat: T) -> T:
|
||||
feat_mean, feat_std = calc_mean_std(feat)
|
||||
feat_style_mean = expand_first(feat_mean)
|
||||
feat_style_std = expand_first(feat_std)
|
||||
feat = (feat - feat_mean) / feat_std
|
||||
feat = feat * feat_style_std + feat_style_mean
|
||||
return feat
|
||||
|
||||
class SharedAttentionProcessor:
|
||||
def __init__(self, args: StyleAlignedArgs, scale: float):
|
||||
self.args = args
|
||||
self.scale = scale
|
||||
|
||||
def __call__(self, q, k, v, extra_options):
|
||||
if self.args.adain_queries:
|
||||
q = adain(q)
|
||||
if self.args.adain_keys:
|
||||
k = adain(k)
|
||||
if self.args.adain_values:
|
||||
v = adain(v)
|
||||
if self.args.share_attention:
|
||||
k = concat_first(k, -2, scale=self.scale)
|
||||
v = concat_first(v, -2)
|
||||
|
||||
return q, k, v
|
||||
|
||||
|
||||
def get_norm_layers(
|
||||
layer: nn.Module,
|
||||
norm_layers_: "dict[str, list[Union[nn.GroupNorm, nn.LayerNorm]]]",
|
||||
share_layer_norm: bool,
|
||||
share_group_norm: bool,
|
||||
):
|
||||
if isinstance(layer, nn.LayerNorm) and share_layer_norm:
|
||||
norm_layers_["layer"].append(layer)
|
||||
if isinstance(layer, nn.GroupNorm) and share_group_norm:
|
||||
norm_layers_["group"].append(layer)
|
||||
else:
|
||||
for child_layer in layer.children():
|
||||
get_norm_layers(
|
||||
child_layer, norm_layers_, share_layer_norm, share_group_norm
|
||||
)
|
||||
|
||||
|
||||
def register_norm_forward(
|
||||
norm_layer: Union[nn.GroupNorm, nn.LayerNorm],
|
||||
) -> Union[nn.GroupNorm, nn.LayerNorm]:
|
||||
if not hasattr(norm_layer, "orig_forward"):
|
||||
setattr(norm_layer, "orig_forward", norm_layer.forward)
|
||||
orig_forward = norm_layer.orig_forward
|
||||
|
||||
def forward_(hidden_states: T) -> T:
|
||||
n = hidden_states.shape[-2]
|
||||
hidden_states = concat_first(hidden_states, dim=-2)
|
||||
hidden_states = orig_forward(hidden_states) # type: ignore
|
||||
return hidden_states[..., :n, :]
|
||||
|
||||
norm_layer.forward = forward_ # type: ignore
|
||||
return norm_layer
|
||||
|
||||
|
||||
def register_shared_norm(
|
||||
model: ModelPatcher,
|
||||
share_group_norm: bool = True,
|
||||
share_layer_norm: bool = True,
|
||||
):
|
||||
norm_layers = {"group": [], "layer": []}
|
||||
get_norm_layers(model.model, norm_layers, share_layer_norm, share_group_norm)
|
||||
print(
|
||||
f"Patching {len(norm_layers['group'])} group norms, {len(norm_layers['layer'])} layer norms."
|
||||
)
|
||||
return [register_norm_forward(layer) for layer in norm_layers["group"]] + [
|
||||
register_norm_forward(layer) for layer in norm_layers["layer"]
|
||||
]
|
||||
|
||||
|
||||
SHARE_NORM_OPTIONS = ["both", "group", "layer", "disabled"]
|
||||
SHARE_ATTN_OPTIONS = ["q+k", "q+k+v", "disabled"]
|
||||
|
||||
|
||||
def styleAlignBatch(model, share_norm, share_attn, scale=1.0):
|
||||
m = model.clone()
|
||||
share_group_norm = share_norm in ["group", "both"]
|
||||
share_layer_norm = share_norm in ["layer", "both"]
|
||||
register_shared_norm(model, share_group_norm, share_layer_norm)
|
||||
args = StyleAlignedArgs(share_attn)
|
||||
m.set_model_attn1_patch(SharedAttentionProcessor(args, scale))
|
||||
return m
|
||||
247
custom_nodes/ComfyUI-Easy-Use/py/libs/translate.py
Normal file
247
custom_nodes/ComfyUI-Easy-Use/py/libs/translate.py
Normal file
@@ -0,0 +1,247 @@
|
||||
#credit to shadowcz007 for this module
|
||||
#from https://github.com/shadowcz007/comfyui-mixlab-nodes/blob/main/nodes/TextGenerateNode.py
|
||||
import re
|
||||
import os
|
||||
import folder_paths
|
||||
|
||||
import comfy.utils
|
||||
import torch
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
from .utils import install_package
|
||||
try:
|
||||
from lark import Lark, Transformer, v_args
|
||||
except:
|
||||
print('install lark...')
|
||||
install_package('lark')
|
||||
from lark import Lark, Transformer, v_args
|
||||
|
||||
model_path = os.path.join(folder_paths.models_dir, 'prompt_generator')
|
||||
zh_en_model_path = os.path.join(model_path, 'opus-mt-zh-en')
|
||||
zh_en_model, zh_en_tokenizer = None, None
|
||||
|
||||
def correct_prompt_syntax(prompt=""):
|
||||
# print("input prompt",prompt)
|
||||
corrected_elements = []
|
||||
# 处理成统一的英文标点
|
||||
prompt = prompt.replace('(', '(').replace(')', ')').replace(',', ',').replace(';', ',').replace('。', '.').replace(':',':').replace('\\',',')
|
||||
# 删除多余的空格
|
||||
prompt = re.sub(r'\s+', ' ', prompt).strip()
|
||||
prompt = prompt.replace("< ","<").replace(" >",">").replace("( ","(").replace(" )",")").replace("[ ","[").replace(' ]',']')
|
||||
|
||||
# 分词
|
||||
prompt_elements = prompt.split(',')
|
||||
|
||||
def balance_brackets(element, open_bracket, close_bracket):
|
||||
open_brackets_count = element.count(open_bracket)
|
||||
close_brackets_count = element.count(close_bracket)
|
||||
return element + close_bracket * (open_brackets_count - close_brackets_count)
|
||||
|
||||
for element in prompt_elements:
|
||||
element = element.strip()
|
||||
|
||||
# 处理空元素
|
||||
if not element:
|
||||
continue
|
||||
|
||||
# 检查并处理圆括号、方括号、尖括号
|
||||
if element[0] in '([':
|
||||
corrected_element = balance_brackets(element, '(', ')') if element[0] == '(' else balance_brackets(element, '[', ']')
|
||||
elif element[0] == '<':
|
||||
corrected_element = balance_brackets(element, '<', '>')
|
||||
else:
|
||||
# 删除开头的右括号或右方括号
|
||||
corrected_element = element.lstrip(')]')
|
||||
|
||||
corrected_elements.append(corrected_element)
|
||||
|
||||
# 重组修正后的prompt
|
||||
return ','.join(corrected_elements)
|
||||
|
||||
def detect_language(input_str):
|
||||
# 统计中文和英文字符的数量
|
||||
count_cn = count_en = 0
|
||||
for char in input_str:
|
||||
if '\u4e00' <= char <= '\u9fff':
|
||||
count_cn += 1
|
||||
elif char.isalpha():
|
||||
count_en += 1
|
||||
|
||||
# 根据统计的字符数量判断主要语言
|
||||
if count_cn > count_en:
|
||||
return "cn"
|
||||
elif count_en > count_cn:
|
||||
return "en"
|
||||
else:
|
||||
return "unknow"
|
||||
|
||||
def has_chinese(text):
|
||||
has_cn = False
|
||||
_text = text
|
||||
_text = re.sub(r'<.*?>', '', _text)
|
||||
_text = re.sub(r'__.*?__', '', _text)
|
||||
_text = re.sub(r'embedding:.*?$', '', _text)
|
||||
for char in _text:
|
||||
if '\u4e00' <= char <= '\u9fff':
|
||||
has_cn = True
|
||||
break
|
||||
elif char.isalpha():
|
||||
continue
|
||||
return has_cn
|
||||
|
||||
def translate(text):
|
||||
global zh_en_model_path, zh_en_model, zh_en_tokenizer
|
||||
|
||||
if not os.path.exists(zh_en_model_path):
|
||||
zh_en_model_path = 'Helsinki-NLP/opus-mt-zh-en'
|
||||
|
||||
if zh_en_model is None:
|
||||
|
||||
zh_en_model = AutoModelForSeq2SeqLM.from_pretrained(zh_en_model_path).eval()
|
||||
zh_en_tokenizer = AutoTokenizer.from_pretrained(zh_en_model_path, padding=True, truncation=True)
|
||||
|
||||
zh_en_model.to("cuda" if torch.cuda.is_available() else "cpu")
|
||||
with torch.no_grad():
|
||||
encoded = zh_en_tokenizer([text], return_tensors="pt")
|
||||
encoded.to(zh_en_model.device)
|
||||
sequences = zh_en_model.generate(**encoded)
|
||||
return zh_en_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
|
||||
|
||||
@v_args(inline=True) # Decorator to flatten the tree directly into the function arguments
|
||||
class ChinesePromptTranslate(Transformer):
|
||||
|
||||
def sentence(self, *args):
|
||||
return ", ".join(args)
|
||||
|
||||
def phrase(self, *args):
|
||||
return "".join(args)
|
||||
|
||||
def emphasis(self, *args):
|
||||
# Reconstruct the emphasis with translated content
|
||||
return "(" + "".join(args) + ")"
|
||||
|
||||
def weak_emphasis(self, *args):
|
||||
print('weak_emphasis:', args)
|
||||
return "[" + "".join(args) + "]"
|
||||
|
||||
def embedding(self, *args):
|
||||
print('prompt embedding', args[0])
|
||||
if len(args) == 1:
|
||||
embedding_name = str(args[0])
|
||||
return f"embedding:{embedding_name}"
|
||||
elif len(args) > 1:
|
||||
embedding_name, *numbers = args
|
||||
|
||||
if len(numbers) == 2:
|
||||
return f"embedding:{embedding_name}:{numbers[0]}:{numbers[1]}"
|
||||
elif len(numbers) == 1:
|
||||
return f"embedding:{embedding_name}:{numbers[0]}"
|
||||
else:
|
||||
return f"embedding:{embedding_name}"
|
||||
|
||||
def lora(self, *args):
|
||||
if len(args) == 1:
|
||||
return f"<lora:{args[0]}>"
|
||||
elif len(args) > 1:
|
||||
# print('lora', args)
|
||||
_, loar_name, *numbers = args
|
||||
loar_name = str(loar_name).strip()
|
||||
if len(numbers) == 2:
|
||||
return f"<lora:{loar_name}:{numbers[0]}:{numbers[1]}>"
|
||||
elif len(numbers) == 1:
|
||||
return f"<lora:{loar_name}:{numbers[0]}>"
|
||||
else:
|
||||
return f"<lora:{loar_name}>"
|
||||
|
||||
def weight(self, word, number):
|
||||
translated_word = translate(str(word)).rstrip('.')
|
||||
return f"({translated_word}:{str(number).strip()})"
|
||||
|
||||
def schedule(self, *args):
|
||||
print('prompt schedule', args)
|
||||
data = [str(arg).strip() for arg in args]
|
||||
|
||||
return f"[{':'.join(data)}]"
|
||||
|
||||
def word(self, word):
|
||||
# Translate each word using the dictionary
|
||||
word = str(word)
|
||||
match_cn = re.search(r'@.*?@', word)
|
||||
if re.search(r'__.*?__', word):
|
||||
return word.rstrip('.')
|
||||
elif match_cn:
|
||||
chinese = match_cn.group()
|
||||
before = word.split('@', 1)
|
||||
before = before[0] if len(before) > 0 else ''
|
||||
before = translate(str(before)).rstrip('.') if before else ''
|
||||
after = word.rsplit('@', 1)
|
||||
after = after[len(after)-1] if len(after) > 1 else ''
|
||||
after = translate(after).rstrip('.') if after else ''
|
||||
return before + chinese.replace('@', '').rstrip('.') + after
|
||||
elif detect_language(word) == "cn":
|
||||
return translate(word).rstrip('.')
|
||||
else:
|
||||
return word.rstrip('.')
|
||||
|
||||
|
||||
#定义Prompt文法
|
||||
grammar = r"""
|
||||
start: sentence
|
||||
sentence: phrase ("," phrase)*
|
||||
phrase: emphasis | weight | word | lora | embedding | schedule
|
||||
emphasis: "(" sentence ")" -> emphasis
|
||||
| "[" sentence "]" -> weak_emphasis
|
||||
weight: "(" word ":" NUMBER ")"
|
||||
schedule: "[" word ":" word ":" NUMBER "]"
|
||||
lora: "<" WORD ":" WORD (":" NUMBER)? (":" NUMBER)? ">"
|
||||
embedding: "embedding" ":" WORD (":" NUMBER)? (":" NUMBER)?
|
||||
word: WORD
|
||||
|
||||
NUMBER: /\s*-?\d+(\.\d+)?\s*/
|
||||
WORD: /[^,:\(\)\[\]<>]+/
|
||||
"""
|
||||
def zh_to_en(text):
|
||||
global zh_en_model_path, zh_en_model, zh_en_tokenizer
|
||||
# 进度条
|
||||
pbar = comfy.utils.ProgressBar(len(text) + 1)
|
||||
texts = [correct_prompt_syntax(t) for t in text]
|
||||
|
||||
install_package('sentencepiece', '0.2.0')
|
||||
|
||||
if not os.path.exists(zh_en_model_path):
|
||||
zh_en_model_path = 'Helsinki-NLP/opus-mt-zh-en'
|
||||
|
||||
if zh_en_model is None:
|
||||
zh_en_model = AutoModelForSeq2SeqLM.from_pretrained(zh_en_model_path).eval()
|
||||
zh_en_tokenizer = AutoTokenizer.from_pretrained(zh_en_model_path, padding=True, truncation=True)
|
||||
|
||||
zh_en_model.to("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
prompt_result = []
|
||||
|
||||
en_texts = []
|
||||
|
||||
for t in texts:
|
||||
if t:
|
||||
# translated_text = translated_word = translate(zh_en_tokenizer,zh_en_model,str(t))
|
||||
parser = Lark(grammar, start="start", parser="lalr", transformer=ChinesePromptTranslate())
|
||||
# print('t',t)
|
||||
result = parser.parse(t).children
|
||||
# print('en_result',result)
|
||||
# en_text=translate(zh_en_tokenizer,zh_en_model,text_without_syntax)
|
||||
en_texts.append(result[0])
|
||||
|
||||
zh_en_model.to('cpu')
|
||||
# print("test en_text", en_texts)
|
||||
# en_text.to("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
pbar.update(1)
|
||||
for t in en_texts:
|
||||
prompt_result.append(t)
|
||||
pbar.update(1)
|
||||
|
||||
# print('prompt_result', prompt_result, )
|
||||
if len(prompt_result) == 0:
|
||||
prompt_result = [""]
|
||||
|
||||
return prompt_result
|
||||
282
custom_nodes/ComfyUI-Easy-Use/py/libs/utils.py
Normal file
282
custom_nodes/ComfyUI-Easy-Use/py/libs/utils.py
Normal file
@@ -0,0 +1,282 @@
|
||||
class AlwaysEqualProxy(str):
|
||||
def __eq__(self, _):
|
||||
return True
|
||||
|
||||
def __ne__(self, _):
|
||||
return False
|
||||
|
||||
class TautologyStr(str):
|
||||
def __ne__(self, other):
|
||||
return False
|
||||
|
||||
class ByPassTypeTuple(tuple):
|
||||
def __getitem__(self, index):
|
||||
if index>0:
|
||||
index=0
|
||||
item = super().__getitem__(index)
|
||||
if isinstance(item, str):
|
||||
return TautologyStr(item)
|
||||
return item
|
||||
|
||||
comfy_ui_revision = None
|
||||
def get_comfyui_revision():
|
||||
try:
|
||||
import git
|
||||
import os
|
||||
import folder_paths
|
||||
repo = git.Repo(os.path.dirname(folder_paths.__file__))
|
||||
comfy_ui_revision = len(list(repo.iter_commits('HEAD')))
|
||||
except:
|
||||
comfy_ui_revision = "Unknown"
|
||||
return comfy_ui_revision
|
||||
|
||||
|
||||
import sys
|
||||
import importlib.util
|
||||
import importlib.metadata
|
||||
import comfy.model_management as mm
|
||||
import gc
|
||||
from packaging import version
|
||||
from server import PromptServer
|
||||
def is_package_installed(package):
|
||||
try:
|
||||
module = importlib.util.find_spec(package)
|
||||
return module is not None
|
||||
except ImportError as e:
|
||||
print(e)
|
||||
return False
|
||||
|
||||
def install_package(package, v=None, compare=True, compare_version=None):
|
||||
run_install = True
|
||||
if is_package_installed(package):
|
||||
try:
|
||||
installed_version = importlib.metadata.version(package)
|
||||
if v is not None:
|
||||
if compare_version is None:
|
||||
compare_version = v
|
||||
if not compare or version.parse(installed_version) >= version.parse(compare_version):
|
||||
run_install = False
|
||||
else:
|
||||
run_install = False
|
||||
except:
|
||||
run_install = False
|
||||
|
||||
if run_install:
|
||||
import subprocess
|
||||
package_command = package + '==' + v if v is not None else package
|
||||
PromptServer.instance.send_sync("easyuse-toast", {'content': f"Installing {package_command}...", 'duration': 5000})
|
||||
result = subprocess.run([sys.executable, '-s', '-m', 'pip', 'install', package_command], capture_output=True, text=True)
|
||||
if result.returncode == 0:
|
||||
PromptServer.instance.send_sync("easyuse-toast", {'content': f"{package} installed successfully", 'type': 'success', 'duration': 5000})
|
||||
print(f"Package {package} installed successfully")
|
||||
return True
|
||||
else:
|
||||
PromptServer.instance.send_sync("easyuse-toast", {'content': f"{package} installed failed", 'type': 'error', 'duration': 5000})
|
||||
print(f"Package {package} installed failed")
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
def compare_revision(num):
|
||||
global comfy_ui_revision
|
||||
if not comfy_ui_revision:
|
||||
comfy_ui_revision = get_comfyui_revision()
|
||||
return True if comfy_ui_revision == 'Unknown' or int(comfy_ui_revision) >= num else False
|
||||
|
||||
def find_tags(string: str, sep="/") -> list[str]:
|
||||
"""
|
||||
find tags from string use the sep for split
|
||||
Note: string may contain the \\ or / for path separator
|
||||
"""
|
||||
if not string:
|
||||
return []
|
||||
string = string.replace("\\", "/")
|
||||
while "//" in string:
|
||||
string = string.replace("//", "/")
|
||||
if string and sep in string:
|
||||
return string.split(sep)[:-1]
|
||||
return []
|
||||
|
||||
|
||||
from comfy.model_base import BaseModel
|
||||
import comfy.supported_models
|
||||
import comfy.supported_models_base
|
||||
def get_sd_version(model):
|
||||
base: BaseModel = model.model
|
||||
model_config: comfy.supported_models.supported_models_base.BASE = base.model_config
|
||||
if isinstance(model_config, comfy.supported_models.SDXL):
|
||||
return 'sdxl'
|
||||
elif isinstance(model_config, comfy.supported_models.SDXLRefiner):
|
||||
return 'sdxl_refiner'
|
||||
elif isinstance(
|
||||
model_config, (comfy.supported_models.SD15, comfy.supported_models.SD20)
|
||||
):
|
||||
return 'sd1'
|
||||
elif isinstance(
|
||||
model_config, (comfy.supported_models.SVD_img2vid)
|
||||
):
|
||||
return 'svd'
|
||||
elif isinstance(model_config, comfy.supported_models.SD3):
|
||||
return 'sd3'
|
||||
elif isinstance(model_config, comfy.supported_models.HunyuanDiT):
|
||||
return 'hydit'
|
||||
elif isinstance(model_config, comfy.supported_models.Flux):
|
||||
return 'flux'
|
||||
elif isinstance(model_config, comfy.supported_models.GenmoMochi):
|
||||
return 'mochi'
|
||||
else:
|
||||
return 'unknown'
|
||||
|
||||
def find_nearest_steps(clip_id, prompt):
|
||||
"""Find the nearest KSampler or preSampling node that references the given id."""
|
||||
def check_link_to_clip(node_id, clip_id, visited=None, node=None):
|
||||
"""Check if a given node links directly or indirectly to a loader node."""
|
||||
if visited is None:
|
||||
visited = set()
|
||||
|
||||
if node_id in visited:
|
||||
return False
|
||||
visited.add(node_id)
|
||||
if "pipe" in node["inputs"]:
|
||||
link_ids = node["inputs"]["pipe"]
|
||||
for id in link_ids:
|
||||
if id != 0 and id == str(clip_id):
|
||||
return True
|
||||
return False
|
||||
|
||||
for id in prompt:
|
||||
node = prompt[id]
|
||||
if "Sampler" in node["class_type"] or "sampler" in node["class_type"] or "Sampling" in node["class_type"]:
|
||||
# Check if this KSampler node directly or indirectly references the given CLIPTextEncode node
|
||||
if check_link_to_clip(id, clip_id, None, node):
|
||||
steps = node["inputs"]["steps"] if "steps" in node["inputs"] else 1
|
||||
return steps
|
||||
return 1
|
||||
|
||||
def find_wildcards_seed(clip_id, text, prompt):
|
||||
""" Find easy wildcards seed value"""
|
||||
def find_link_clip_id(id, seed, wildcard_id):
|
||||
node = prompt[id]
|
||||
if "positive" in node['inputs']:
|
||||
link_ids = node["inputs"]["positive"]
|
||||
if type(link_ids) == list:
|
||||
for id in link_ids:
|
||||
if id != 0:
|
||||
if id == wildcard_id:
|
||||
wildcard_node = prompt[wildcard_id]
|
||||
seed = wildcard_node["inputs"]["seed"] if "seed" in wildcard_node["inputs"] else None
|
||||
if seed is None:
|
||||
seed = wildcard_node["inputs"]["seed_num"] if "seed_num" in wildcard_node["inputs"] else None
|
||||
return seed
|
||||
else:
|
||||
return find_link_clip_id(id, seed, wildcard_id)
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
if "__" in text:
|
||||
seed = None
|
||||
for id in prompt:
|
||||
node = prompt[id]
|
||||
if "wildcards" in node["class_type"]:
|
||||
wildcard_id = id
|
||||
return find_link_clip_id(str(clip_id), seed, wildcard_id)
|
||||
return seed
|
||||
else:
|
||||
return None
|
||||
|
||||
def is_linked_styles_selector(prompt, unique_id, prompt_type='positive'):
|
||||
unique_id = unique_id.split('.')[len(unique_id.split('.')) - 1] if "." in unique_id else unique_id
|
||||
inputs_values = prompt[unique_id]['inputs'][prompt_type] if prompt_type in prompt[unique_id][
|
||||
'inputs'] else None
|
||||
if type(inputs_values) == list and inputs_values != 'undefined' and inputs_values[0]:
|
||||
return True if prompt[inputs_values[0]] and prompt[inputs_values[0]]['class_type'] == 'easy stylesSelector' else False
|
||||
else:
|
||||
return False
|
||||
|
||||
use_mirror = False
|
||||
def get_local_filepath(url, dirname, local_file_name=None):
|
||||
"""Get local file path when is already downloaded or download it"""
|
||||
import os
|
||||
from server import PromptServer
|
||||
from urllib.parse import urlparse
|
||||
from torch.hub import download_url_to_file
|
||||
global use_mirror
|
||||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
if not local_file_name:
|
||||
parsed_url = urlparse(url)
|
||||
local_file_name = os.path.basename(parsed_url.path)
|
||||
destination = os.path.join(dirname, local_file_name)
|
||||
if not os.path.exists(destination):
|
||||
try:
|
||||
if use_mirror:
|
||||
url = url.replace('huggingface.co', 'hf-mirror.com')
|
||||
print(f'downloading {url} to {destination}')
|
||||
PromptServer.instance.send_sync("easyuse-toast", {'content': f'Downloading model to {destination}, please wait...', 'duration': 10000})
|
||||
download_url_to_file(url, destination)
|
||||
except Exception as e:
|
||||
use_mirror = True
|
||||
url = url.replace('huggingface.co', 'hf-mirror.com')
|
||||
print(f'Unable to download from huggingface, trying mirror: {url}')
|
||||
PromptServer.instance.send_sync("easyuse-toast", {'content': f'Unable to connect to huggingface, trying mirror: {url}', 'duration': 10000})
|
||||
try:
|
||||
download_url_to_file(url, destination)
|
||||
except Exception as err:
|
||||
error_msg = str(err.args[0]) if err.args else str(err)
|
||||
PromptServer.instance.send_sync("easyuse-toast",
|
||||
{'content': f'Unable to download model from {url}', 'type':'error'})
|
||||
raise Exception(f'Download failed. Original URL and mirror both failed.\nError: {error_msg}')
|
||||
return destination
|
||||
|
||||
def to_lora_patch_dict(state_dict: dict) -> dict:
|
||||
""" Convert raw lora state_dict to patch_dict that can be applied on
|
||||
modelpatcher."""
|
||||
patch_dict = {}
|
||||
for k, w in state_dict.items():
|
||||
model_key, patch_type, weight_index = k.split('::')
|
||||
if model_key not in patch_dict:
|
||||
patch_dict[model_key] = {}
|
||||
if patch_type not in patch_dict[model_key]:
|
||||
patch_dict[model_key][patch_type] = [None] * 16
|
||||
patch_dict[model_key][patch_type][int(weight_index)] = w
|
||||
|
||||
patch_flat = {}
|
||||
for model_key, v in patch_dict.items():
|
||||
for patch_type, weight_list in v.items():
|
||||
patch_flat[model_key] = (patch_type, weight_list)
|
||||
|
||||
return patch_flat
|
||||
|
||||
def easySave(images, filename_prefix, output_type, prompt=None, extra_pnginfo=None):
|
||||
"""Save or Preview Image"""
|
||||
from nodes import PreviewImage, SaveImage
|
||||
if output_type in ["Hide", "None"]:
|
||||
return list()
|
||||
elif output_type in ["Preview", "Preview&Choose"]:
|
||||
filename_prefix = 'easyPreview'
|
||||
results = PreviewImage().save_images(images, filename_prefix, prompt, extra_pnginfo)
|
||||
return results['ui']['images']
|
||||
else:
|
||||
results = SaveImage().save_images(images, filename_prefix, prompt, extra_pnginfo)
|
||||
return results['ui']['images']
|
||||
|
||||
def getMetadata(filepath):
|
||||
with open(filepath, "rb") as file:
|
||||
# https://github.com/huggingface/safetensors#format
|
||||
# 8 bytes: N, an unsigned little-endian 64-bit integer, containing the size of the header
|
||||
header_size = int.from_bytes(file.read(8), "little", signed=False)
|
||||
|
||||
if header_size <= 0:
|
||||
raise BufferError("Invalid header size")
|
||||
|
||||
header = file.read(header_size)
|
||||
if header_size <= 0:
|
||||
raise BufferError("Invalid header")
|
||||
|
||||
return header
|
||||
|
||||
def cleanGPUUsedForce():
|
||||
gc.collect()
|
||||
mm.unload_all_models()
|
||||
mm.soft_empty_cache()
|
||||
476
custom_nodes/ComfyUI-Easy-Use/py/libs/wildcards.py
Normal file
476
custom_nodes/ComfyUI-Easy-Use/py/libs/wildcards.py
Normal file
@@ -0,0 +1,476 @@
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from math import prod
|
||||
|
||||
import yaml
|
||||
|
||||
import folder_paths
|
||||
|
||||
from .log import log_node_info
|
||||
|
||||
easy_wildcard_dict = {}
|
||||
|
||||
def get_wildcard_list():
|
||||
return [f"__{x}__" for x in easy_wildcard_dict.keys()]
|
||||
|
||||
def wildcard_normalize(x):
|
||||
return x.replace("\\", "/").lower()
|
||||
|
||||
def read_wildcard(k, v):
|
||||
if isinstance(v, list):
|
||||
k = wildcard_normalize(k)
|
||||
easy_wildcard_dict[k] = v
|
||||
elif isinstance(v, dict):
|
||||
for k2, v2 in v.items():
|
||||
new_key = f"{k}/{k2}"
|
||||
new_key = wildcard_normalize(new_key)
|
||||
read_wildcard(new_key, v2)
|
||||
|
||||
def read_wildcard_dict(wildcard_path):
|
||||
global easy_wildcard_dict
|
||||
for root, directories, files in os.walk(wildcard_path, followlinks=True):
|
||||
for file in files:
|
||||
if file.endswith('.txt'):
|
||||
file_path = os.path.join(root, file)
|
||||
rel_path = os.path.relpath(file_path, wildcard_path)
|
||||
key = os.path.splitext(rel_path)[0].replace('\\', '/').lower()
|
||||
|
||||
try:
|
||||
with open(file_path, 'r', encoding="UTF-8", errors="ignore") as f:
|
||||
lines = f.read().splitlines()
|
||||
easy_wildcard_dict[key] = lines
|
||||
except UnicodeDecodeError:
|
||||
with open(file_path, 'r', encoding="ISO-8859-1") as f:
|
||||
lines = f.read().splitlines()
|
||||
easy_wildcard_dict[key] = lines
|
||||
elif file.endswith('.yaml'):
|
||||
file_path = os.path.join(root, file)
|
||||
with open(file_path, 'r') as f:
|
||||
yaml_data = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
for k, v in yaml_data.items():
|
||||
read_wildcard(k, v)
|
||||
elif file.endswith('.json'):
|
||||
file_path = os.path.join(root, file)
|
||||
try:
|
||||
with open(file_path, 'r') as f:
|
||||
json_data = json.load(f)
|
||||
for key, value in json_data.items():
|
||||
key = wildcard_normalize(key)
|
||||
easy_wildcard_dict[key] = value
|
||||
except ValueError:
|
||||
print('json files load error')
|
||||
return easy_wildcard_dict
|
||||
|
||||
|
||||
def process(text, seed=None):
|
||||
|
||||
if seed is not None:
|
||||
random.seed(seed)
|
||||
|
||||
def replace_options(string):
|
||||
replacements_found = False
|
||||
|
||||
def replace_option(match):
|
||||
nonlocal replacements_found
|
||||
options = match.group(1).split('|')
|
||||
|
||||
multi_select_pattern = options[0].split('$$')
|
||||
select_range = None
|
||||
select_sep = ' '
|
||||
range_pattern = r'(\d+)(-(\d+))?'
|
||||
range_pattern2 = r'-(\d+)'
|
||||
|
||||
if len(multi_select_pattern) > 1:
|
||||
r = re.match(range_pattern, options[0])
|
||||
|
||||
if r is None:
|
||||
r = re.match(range_pattern2, options[0])
|
||||
a = '1'
|
||||
b = r.group(1).strip()
|
||||
else:
|
||||
a = r.group(1).strip()
|
||||
b = r.group(3).strip()
|
||||
|
||||
if r is not None:
|
||||
if b is not None and is_numeric_string(a) and is_numeric_string(b):
|
||||
# PATTERN: num1-num2
|
||||
select_range = int(a), int(b)
|
||||
elif is_numeric_string(a):
|
||||
# PATTERN: num
|
||||
x = int(a)
|
||||
select_range = (x, x)
|
||||
|
||||
if select_range is not None and len(multi_select_pattern) == 2:
|
||||
# PATTERN: count$$
|
||||
options[0] = multi_select_pattern[1]
|
||||
elif select_range is not None and len(multi_select_pattern) == 3:
|
||||
# PATTERN: count$$ sep $$
|
||||
select_sep = multi_select_pattern[1]
|
||||
options[0] = multi_select_pattern[2]
|
||||
|
||||
adjusted_probabilities = []
|
||||
|
||||
total_prob = 0
|
||||
|
||||
for option in options:
|
||||
parts = option.split('::', 1)
|
||||
if len(parts) == 2 and is_numeric_string(parts[0].strip()):
|
||||
config_value = float(parts[0].strip())
|
||||
else:
|
||||
config_value = 1 # Default value if no configuration is provided
|
||||
|
||||
adjusted_probabilities.append(config_value)
|
||||
total_prob += config_value
|
||||
|
||||
normalized_probabilities = [prob / total_prob for prob in adjusted_probabilities]
|
||||
|
||||
if select_range is None:
|
||||
select_count = 1
|
||||
else:
|
||||
select_count = random.randint(select_range[0], select_range[1])
|
||||
|
||||
if select_count > len(options):
|
||||
selected_items = options
|
||||
else:
|
||||
selected_items = random.choices(options, weights=normalized_probabilities, k=select_count)
|
||||
selected_items = set(selected_items)
|
||||
|
||||
try_count = 0
|
||||
while len(selected_items) < select_count and try_count < 10:
|
||||
remaining_count = select_count - len(selected_items)
|
||||
additional_items = random.choices(options, weights=normalized_probabilities, k=remaining_count)
|
||||
selected_items |= set(additional_items)
|
||||
try_count += 1
|
||||
|
||||
selected_items2 = [re.sub(r'^\s*[0-9.]+::', '', x, 1) for x in selected_items]
|
||||
replacement = select_sep.join(selected_items2)
|
||||
if '::' in replacement:
|
||||
pass
|
||||
|
||||
replacements_found = True
|
||||
return replacement
|
||||
|
||||
pattern = r'{([^{}]*?)}'
|
||||
replaced_string = re.sub(pattern, replace_option, string)
|
||||
|
||||
return replaced_string, replacements_found
|
||||
|
||||
def replace_wildcard(string):
|
||||
global easy_wildcard_dict
|
||||
pattern = r"__([\w\s.\-+/*\\]+?)__"
|
||||
matches = re.findall(pattern, string)
|
||||
replacements_found = False
|
||||
|
||||
for match in matches:
|
||||
keyword = match.lower()
|
||||
keyword = wildcard_normalize(keyword)
|
||||
if keyword in easy_wildcard_dict:
|
||||
replacement = random.choice(easy_wildcard_dict[keyword])
|
||||
replacements_found = True
|
||||
string = string.replace(f"__{match}__", replacement, 1)
|
||||
elif '*' in keyword:
|
||||
subpattern = keyword.replace('*', '.*').replace('+', r'\+')
|
||||
total_patterns = []
|
||||
found = False
|
||||
for k, v in easy_wildcard_dict.items():
|
||||
if re.match(subpattern, k) is not None:
|
||||
total_patterns += v
|
||||
found = True
|
||||
|
||||
if found:
|
||||
replacement = random.choice(total_patterns)
|
||||
replacements_found = True
|
||||
string = string.replace(f"__{match}__", replacement, 1)
|
||||
elif '/' not in keyword:
|
||||
string_fallback = string.replace(f"__{match}__", f"__*/{match}__", 1)
|
||||
string, replacements_found = replace_wildcard(string_fallback)
|
||||
|
||||
return string, replacements_found
|
||||
|
||||
replace_depth = 100
|
||||
stop_unwrap = False
|
||||
while not stop_unwrap and replace_depth > 1:
|
||||
replace_depth -= 1 # prevent infinite loop
|
||||
|
||||
# pass1: replace options
|
||||
pass1, is_replaced1 = replace_options(text)
|
||||
|
||||
while is_replaced1:
|
||||
pass1, is_replaced1 = replace_options(pass1)
|
||||
|
||||
# pass2: replace wildcards
|
||||
text, is_replaced2 = replace_wildcard(pass1)
|
||||
stop_unwrap = not is_replaced1 and not is_replaced2
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def is_numeric_string(input_str):
|
||||
return re.match(r'^-?\d+(\.\d+)?$', input_str) is not None
|
||||
|
||||
|
||||
def safe_float(x):
|
||||
if is_numeric_string(x):
|
||||
return float(x)
|
||||
else:
|
||||
return 1.0
|
||||
|
||||
|
||||
def extract_lora_values(string):
|
||||
pattern = r'<lora:([^>]+)>'
|
||||
matches = re.findall(pattern, string)
|
||||
|
||||
def touch_lbw(text):
|
||||
return re.sub(r'LBW=[A-Za-z][A-Za-z0-9_-]*:', r'LBW=', text)
|
||||
|
||||
items = [touch_lbw(match.strip(':')) for match in matches]
|
||||
|
||||
added = set()
|
||||
result = []
|
||||
for item in items:
|
||||
item = item.split(':')
|
||||
|
||||
lora = None
|
||||
a = None
|
||||
b = None
|
||||
lbw = None
|
||||
lbw_a = None
|
||||
lbw_b = None
|
||||
|
||||
if len(item) > 0:
|
||||
lora = item[0]
|
||||
|
||||
for sub_item in item[1:]:
|
||||
if is_numeric_string(sub_item):
|
||||
if a is None:
|
||||
a = float(sub_item)
|
||||
elif b is None:
|
||||
b = float(sub_item)
|
||||
elif sub_item.startswith("LBW="):
|
||||
for lbw_item in sub_item[4:].split(';'):
|
||||
if lbw_item.startswith("A="):
|
||||
lbw_a = safe_float(lbw_item[2:].strip())
|
||||
elif lbw_item.startswith("B="):
|
||||
lbw_b = safe_float(lbw_item[2:].strip())
|
||||
elif lbw_item.strip() != '':
|
||||
lbw = lbw_item
|
||||
|
||||
if a is None:
|
||||
a = 1.0
|
||||
if b is None:
|
||||
b = 1.0
|
||||
|
||||
if lora is not None and lora not in added:
|
||||
result.append((lora, a, b, lbw, lbw_a, lbw_b))
|
||||
added.add(lora)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def remove_lora_tags(string):
|
||||
pattern = r'<lora:[^>]+>'
|
||||
result = re.sub(pattern, '', string)
|
||||
|
||||
return result
|
||||
|
||||
def process_with_loras(wildcard_opt, model, clip, title="Positive", seed=None, can_load_lora=True, pipe_lora_stack=[], easyCache=None):
|
||||
pass1 = process(wildcard_opt, seed)
|
||||
loras = extract_lora_values(pass1)
|
||||
pass2 = remove_lora_tags(pass1)
|
||||
|
||||
has_noodle_key = True if "__" in wildcard_opt else False
|
||||
has_loras = True if loras != [] else False
|
||||
show_wildcard_prompt = True if has_noodle_key or has_loras else False
|
||||
|
||||
if can_load_lora and has_loras:
|
||||
for lora_name, model_weight, clip_weight, lbw, lbw_a, lbw_b in loras:
|
||||
if (lora_name.split('.')[-1]) not in folder_paths.supported_pt_extensions:
|
||||
lora_name = lora_name+".safetensors"
|
||||
lora = {
|
||||
"lora_name": lora_name, "model": model, "clip": clip, "model_strength": model_weight,
|
||||
"clip_strength": clip_weight,
|
||||
"lbw_a": lbw_a,
|
||||
"lbw_b": lbw_b,
|
||||
"lbw": lbw
|
||||
}
|
||||
model, clip = easyCache.load_lora(lora)
|
||||
lora["model"] = model
|
||||
lora["clip"] = clip
|
||||
pipe_lora_stack.append(lora)
|
||||
|
||||
log_node_info("easy wildcards",f"{title}: {pass2}")
|
||||
if pass1 != pass2:
|
||||
log_node_info("easy wildcards",f'{title}_decode: {pass1}')
|
||||
|
||||
return model, clip, pass2, pass1, show_wildcard_prompt, pipe_lora_stack
|
||||
|
||||
|
||||
def expand_wildcard(keyword: str) -> tuple[str]:
|
||||
"""传入文件通配符的关键词,从 easy_wildcard_dict 中获取通配符的所有选项。"""
|
||||
global easy_wildcard_dict
|
||||
if keyword in easy_wildcard_dict:
|
||||
return tuple(easy_wildcard_dict[keyword])
|
||||
elif '*' in keyword:
|
||||
subpattern = keyword.replace('*', '.*').replace('+', r"\+")
|
||||
total_pattern = []
|
||||
for k, v in easy_wildcard_dict.items():
|
||||
if re.match(subpattern, k) is not None:
|
||||
total_pattern.extend(v)
|
||||
if total_pattern:
|
||||
return tuple(total_pattern)
|
||||
elif '/' not in keyword:
|
||||
return expand_wildcard(f"*/{keyword}")
|
||||
|
||||
def expand_options(options: str) -> tuple[str]:
|
||||
"""传入去掉 {} 的选项。
|
||||
展开选项通配符,返回该选项中的每一项,这里的每一项都是一个替换项。
|
||||
不会对选项内容进行任何处理,即便存在空格或特殊符号,也会原样返回。"""
|
||||
return tuple(options.split("|"))
|
||||
|
||||
|
||||
def decimal_to_irregular(n, bases):
|
||||
"""
|
||||
将十进制数转换为不规则进制
|
||||
|
||||
:param n: 十进制数
|
||||
:param bases: 各位置的基数列表,从低位到高位
|
||||
:return: 不规则进制表示的列表,从低位到高位
|
||||
"""
|
||||
if n == 0:
|
||||
return [0] * len(bases) if bases else [0]
|
||||
|
||||
digits = []
|
||||
remaining = n
|
||||
|
||||
# 从低位到高位处理
|
||||
for base in bases:
|
||||
digit = remaining % base
|
||||
digits.append(digit)
|
||||
remaining = remaining // base
|
||||
|
||||
return digits
|
||||
|
||||
|
||||
class WildcardProcessor:
|
||||
"""通配符处理器
|
||||
|
||||
通配符格式:
|
||||
+ option : {a|b}
|
||||
+ wildcard: __keyword__ 通配符内容将从 Easy-Use 插件提供的 easy_wildcard_dict 中获取
|
||||
"""
|
||||
|
||||
RE_OPTIONS = re.compile(r"{([^{}]*?)}")
|
||||
RE_WILDCARD = re.compile(r"__([\w\s.\-+/*\\]+?)__")
|
||||
RE_REPLACER = re.compile(r"{([^{}]*?)}|__([\w\s.\-+/*\\]+?)__")
|
||||
|
||||
# 将输入的提示词转化成符合 python str.format 要求格式的模板,并将 option 和 wildcard 按照顺序在模板中留下 {0}, {1} 等占位符
|
||||
template: str
|
||||
# option、wildcard 的替换项列表,按照在模板中出现的顺序排列,相同的替换项列表只保留第一份
|
||||
replacers: dict[int, tuple[str]]
|
||||
# 占位符的编号和替换项列表的索引的映射,占位符编号按照在模板中出现的顺序排列,方便减少替换项的存储占用
|
||||
placeholder_mapping: dict[str, int] # placeholder_id => replacer_id
|
||||
# 各替换项列表的项数,按照在模板中出现的顺序排列,提前计算,方便后续使用
|
||||
placeholder_choices: dict[str, int] # placeholder_id => len(replacer)
|
||||
|
||||
def __init__(self, text: str):
|
||||
self.__make_template(text)
|
||||
self.__total = None
|
||||
|
||||
def random(self, seed=None) -> str:
|
||||
"从所有可能性中随机获取一个"
|
||||
if seed is not None:
|
||||
random.seed(seed)
|
||||
return self.getn(random.randint(0, self.total() - 1))
|
||||
|
||||
def getn(self, n: int) -> str:
|
||||
"从所有可能性中获取第 n 个,以 self.total() 为周期循环"
|
||||
n = n % self.total()
|
||||
indice = decimal_to_irregular(n, self.placeholder_choices.values())
|
||||
replacements = {
|
||||
placeholder_id: self.replacers[self.placeholder_mapping[placeholder_id]][i]
|
||||
for placeholder_id, i in zip(self.placeholder_mapping.keys(), indice)
|
||||
}
|
||||
return self.template.format(**replacements)
|
||||
|
||||
def getmany(self, limit: int, offset: int = 0) -> list[str]:
|
||||
"""返回一组可能性组成的列表,为了避免结果太长导致内存占用超限,使用 limit 限制列表的长度,使用 offset 调整偏移。
|
||||
若 limit 和 offset 的设置导致预期的结果长度超过剩下的实际长度,则会回到开头。
|
||||
"""
|
||||
return [self.getn(n) for n in range(offset, offset + limit)]
|
||||
|
||||
def total(self) -> int:
|
||||
"计算可能性的数目"
|
||||
if self.__total is None:
|
||||
self.__total = prod(self.placeholder_choices.values())
|
||||
return self.__total
|
||||
|
||||
def __make_template(self, text: str):
|
||||
"""将输入的提示词转化成符合 python str.format 要求格式的模板,
|
||||
并将 option 和 wildcard 按照顺序在模板中留下 {r0}, {r1} 等占位符,
|
||||
即使遇到相同的 option 或 wildcard,留下的占位符编号也不同,从而使每项都独立变化。
|
||||
"""
|
||||
self.placeholder_mapping = {}
|
||||
placeholder_id = 0
|
||||
replacer_id = 0
|
||||
replacers_rev = {} # replacers => id
|
||||
blocks = []
|
||||
# 记录所处理过的通配符末尾在文本中的位置,用于拼接完整的模板
|
||||
tail = 0
|
||||
for match in self.RE_REPLACER.finditer(text):
|
||||
# 提取并展开通配符内容
|
||||
m = match.group(0)
|
||||
if m.startswith("{"):
|
||||
choices = expand_options(m[1:-1])
|
||||
elif m.startswith("__"):
|
||||
keyword = m[2:-2].lower()
|
||||
keyword = wildcard_normalize(keyword)
|
||||
choices = expand_wildcard(keyword)
|
||||
else:
|
||||
raise ValueError(f"{m!r} is not a wildcard or option")
|
||||
|
||||
# 记录通配符的替换项列表和ID,相同的通配符只保留第一个
|
||||
if choices not in replacers_rev:
|
||||
replacers_rev[choices] = replacer_id
|
||||
replacer_id += 1
|
||||
|
||||
# 拼接通配符前方文本
|
||||
start, end = match.span()
|
||||
blocks.append(text[tail:start])
|
||||
tail = end
|
||||
# 将通配符替换为占位符,并记录占位符和替换项列表的索引的映射
|
||||
blocks.append(f"{{r{placeholder_id}}}")
|
||||
self.placeholder_mapping[f"r{placeholder_id}"] = replacers_rev[choices]
|
||||
placeholder_id += 1
|
||||
|
||||
if tail < len(text):
|
||||
blocks.append(text[tail:])
|
||||
self.template = "".join(blocks)
|
||||
self.replacers = {v: k for k, v in replacers_rev.items()}
|
||||
self.placeholder_choices = {
|
||||
placeholder_id: len(self.replacers[replacer_id])
|
||||
for placeholder_id, replacer_id in self.placeholder_mapping.items()
|
||||
}
|
||||
|
||||
|
||||
def test_option():
|
||||
text = "{|a|b|c}"
|
||||
answer = ["", "a", "b", "c"]
|
||||
p = WildcardProcessor(text)
|
||||
assert p.total() == len(answer)
|
||||
assert p.getn(0) == answer[0]
|
||||
assert p.getmany(4) == answer
|
||||
assert p.getmany(4, 1) == answer[1:]
|
||||
|
||||
|
||||
def test_same():
|
||||
text = "{a|b},{a|b}"
|
||||
answer = ["a,a", "b,a", "a,b", "b,b"]
|
||||
p = WildcardProcessor(text)
|
||||
assert p.total() == len(answer)
|
||||
assert p.getn(0) == answer[0]
|
||||
assert p.getmany(4) == answer
|
||||
assert p.getmany(4, 1) == answer[1:]
|
||||
|
||||
697
custom_nodes/ComfyUI-Easy-Use/py/libs/xyplot.py
Normal file
697
custom_nodes/ComfyUI-Easy-Use/py/libs/xyplot.py
Normal file
@@ -0,0 +1,697 @@
|
||||
import os, torch
|
||||
from pathlib import Path
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from .utils import easySave, get_sd_version
|
||||
from .adv_encode import advanced_encode
|
||||
from .controlnet import easyControlnet
|
||||
from .log import log_node_warn
|
||||
from ..modules.layer_diffuse import LayerDiffuse
|
||||
from ..config import RESOURCES_DIR
|
||||
from nodes import CLIPTextEncode
|
||||
import pprint
|
||||
try:
|
||||
from comfy_extras.nodes_flux import FluxGuidance
|
||||
except:
|
||||
FluxGuidance = None
|
||||
|
||||
class easyXYPlot():
|
||||
|
||||
def __init__(self, xyPlotData, save_prefix, image_output, prompt, extra_pnginfo, my_unique_id, sampler, easyCache):
|
||||
self.x_node_type, self.x_type = sampler.safe_split(xyPlotData.get("x_axis"), ': ')
|
||||
self.y_node_type, self.y_type = sampler.safe_split(xyPlotData.get("y_axis"), ': ')
|
||||
self.x_values = xyPlotData.get("x_vals") if self.x_type != "None" else []
|
||||
self.y_values = xyPlotData.get("y_vals") if self.y_type != "None" else []
|
||||
self.custom_font = xyPlotData.get("custom_font")
|
||||
|
||||
self.grid_spacing = xyPlotData.get("grid_spacing")
|
||||
self.latent_id = 0
|
||||
self.output_individuals = xyPlotData.get("output_individuals")
|
||||
|
||||
self.x_label, self.y_label = [], []
|
||||
self.max_width, self.max_height = 0, 0
|
||||
self.latents_plot = []
|
||||
self.image_list = []
|
||||
|
||||
self.num_cols = len(self.x_values) if len(self.x_values) > 0 else 1
|
||||
self.num_rows = len(self.y_values) if len(self.y_values) > 0 else 1
|
||||
|
||||
self.total = self.num_cols * self.num_rows
|
||||
self.num = 0
|
||||
|
||||
self.save_prefix = save_prefix
|
||||
self.image_output = image_output
|
||||
self.prompt = prompt
|
||||
self.extra_pnginfo = extra_pnginfo
|
||||
self.my_unique_id = my_unique_id
|
||||
|
||||
self.sampler = sampler
|
||||
self.easyCache = easyCache
|
||||
|
||||
# Helper Functions
|
||||
@staticmethod
|
||||
def define_variable(plot_image_vars, value_type, value, index):
|
||||
|
||||
plot_image_vars[value_type] = value
|
||||
if value_type in ["seed", "Seeds++ Batch"]:
|
||||
value_label = f"seed: {value}"
|
||||
else:
|
||||
value_label = f"{value_type}: {value}"
|
||||
|
||||
if "ControlNet" in value_type:
|
||||
value_label = f"ControlNet {index + 1}"
|
||||
|
||||
if value_type in ['Lora', 'Checkpoint']:
|
||||
arr = value.split(',')
|
||||
model_name = os.path.basename(os.path.splitext(arr[0])[0])
|
||||
trigger_words = ' ' + arr[3] if value_type == 'Lora' and len(arr[3]) > 2 else ''
|
||||
lora_weight = float(arr[1]) if value_type == 'Lora' and len(arr) > 1 else 0
|
||||
lora_weight_desc = f"({lora_weight:.2f})" if lora_weight > 0 else ''
|
||||
value_label = f"{model_name[:30]}{lora_weight_desc} {trigger_words}"
|
||||
|
||||
if value_type in ["ModelMergeBlocks"]:
|
||||
if ":" in value:
|
||||
line = value.split(':')
|
||||
value_label = f"{line[0]}"
|
||||
elif len(value) > 16:
|
||||
value_label = f"ModelMergeBlocks {index + 1}"
|
||||
else:
|
||||
value_label = f"MMB: {value}"
|
||||
|
||||
if value_type in ["Pos Condition"]:
|
||||
value_label = f"pos cond {index + 1}" if index>0 else f"pos cond"
|
||||
if value_type in ["Neg Condition"]:
|
||||
value_label = f"neg cond {index + 1}" if index>0 else f"neg cond"
|
||||
|
||||
if value_type in ["Positive Prompt S/R"]:
|
||||
value_label = f"pos prompt {index + 1}" if index>0 else f"pos prompt"
|
||||
if value_type in ["Negative Prompt S/R"]:
|
||||
value_label = f"neg prompt {index + 1}" if index>0 else f"neg prompt"
|
||||
|
||||
if value_type in ["steps", "cfg", "denoise", "clip_skip",
|
||||
"lora_model_strength", "lora_clip_strength"]:
|
||||
value_label = f"{value_type}: {value}"
|
||||
|
||||
if value_type == "positive":
|
||||
value_label = f"pos prompt {index + 1}"
|
||||
elif value_type == "negative":
|
||||
value_label = f"neg prompt {index + 1}"
|
||||
|
||||
return plot_image_vars, value_label
|
||||
|
||||
@staticmethod
|
||||
def get_font(font_size, font_path=None):
|
||||
if font_path is None:
|
||||
font_path = str(Path(os.path.join(RESOURCES_DIR, 'OpenSans-Medium.ttf')))
|
||||
return ImageFont.truetype(font_path, font_size)
|
||||
|
||||
@staticmethod
|
||||
def update_label(label, value, num_items):
|
||||
if len(label) < num_items:
|
||||
return [*label, value]
|
||||
return label
|
||||
|
||||
@staticmethod
|
||||
def rearrange_tensors(latent, num_cols, num_rows):
|
||||
new_latent = []
|
||||
for i in range(num_rows):
|
||||
for j in range(num_cols):
|
||||
index = j * num_rows + i
|
||||
new_latent.append(latent[index])
|
||||
return new_latent
|
||||
|
||||
def calculate_background_dimensions(self):
|
||||
border_size = int((self.max_width // 8) * 1.5) if self.y_type != "None" or self.x_type != "None" else 0
|
||||
|
||||
bg_width = self.num_cols * (self.max_width + self.grid_spacing) - self.grid_spacing + border_size * (
|
||||
self.y_type != "None")
|
||||
bg_height = self.num_rows * (self.max_height + self.grid_spacing) - self.grid_spacing + border_size * (
|
||||
self.x_type != "None")
|
||||
|
||||
# Add space at the bottom of the image for common informaiton about the image
|
||||
bg_height = bg_height + (border_size*2)
|
||||
# print(f"Grid Size: width = {bg_width} height = {bg_height} border_size = {border_size}")
|
||||
|
||||
x_offset_initial = border_size if self.y_type != "None" else 0
|
||||
y_offset = border_size if self.x_type != "None" else 0
|
||||
|
||||
return bg_width, bg_height, x_offset_initial, y_offset
|
||||
|
||||
|
||||
def adjust_font_size(self, text, initial_font_size, label_width):
|
||||
font = self.get_font(initial_font_size, self.custom_font)
|
||||
text_width = font.getbbox(text)
|
||||
# pprint.pp(f"Initial font size: {initial_font_size}, text: {text}, text_width: {text_width}")
|
||||
if text_width and text_width[2]:
|
||||
text_width = text_width[2]
|
||||
|
||||
scaling_factor = 0.9
|
||||
if text_width > (label_width * scaling_factor):
|
||||
# print(f"Adjusting font size from {initial_font_size} to fit text width {text_width} into label width {label_width} scaling_factor {scaling_factor}")
|
||||
return int(initial_font_size * (label_width / text_width) * scaling_factor)
|
||||
else:
|
||||
return initial_font_size
|
||||
|
||||
def textsize(self, d, text, font):
|
||||
_, _, width, height = d.textbbox((0, 0), text=text, font=font)
|
||||
return width, height
|
||||
|
||||
def create_label(self, img, text, initial_font_size, is_x_label=True, max_font_size=70, min_font_size=10, label_width=0, label_height=0):
|
||||
|
||||
# if the label_width is specified, leave it along. Otherwise do the old logic.
|
||||
if label_width == 0:
|
||||
label_width = img.width if is_x_label else img.height
|
||||
|
||||
text_lines = text.split('\n')
|
||||
longest_line = max(text_lines, key=len)
|
||||
|
||||
# Adjust font size
|
||||
font_size = self.adjust_font_size(longest_line, initial_font_size, label_width)
|
||||
font_size = min(max_font_size, font_size) # Ensure font isn't too large
|
||||
font_size = max(min_font_size, font_size) # Ensure font isn't too small
|
||||
|
||||
if label_height == 0:
|
||||
label_height = int(font_size * 1.5) if is_x_label else font_size
|
||||
|
||||
label_bg = Image.new('RGBA', (label_width, label_height), color=(255, 255, 255, 0))
|
||||
d = ImageDraw.Draw(label_bg)
|
||||
|
||||
font = self.get_font(font_size, self.custom_font)
|
||||
|
||||
# Check if text will fit, if not insert ellipsis and reduce text
|
||||
if self.textsize(d, text, font=font)[0] > label_width:
|
||||
while self.textsize(d, text + '...', font=font)[0] > label_width and len(text) > 0:
|
||||
text = text[:-1]
|
||||
text = text + '...'
|
||||
|
||||
# Compute text width and height for multi-line text
|
||||
|
||||
text_widths, text_heights = zip(*[self.textsize(d, line, font=font) for line in text_lines])
|
||||
max_text_width = max(text_widths)
|
||||
total_text_height = sum(text_heights)
|
||||
|
||||
# Compute position for each line of text
|
||||
lines_positions = []
|
||||
current_y = 0
|
||||
for line, line_width, line_height in zip(text_lines, text_widths, text_heights):
|
||||
text_x = (label_width - line_width) // 2
|
||||
text_y = current_y + (label_height - total_text_height) // 2
|
||||
current_y += line_height
|
||||
lines_positions.append((line, (text_x, text_y)))
|
||||
|
||||
# Draw each line of text
|
||||
for line, (text_x, text_y) in lines_positions:
|
||||
d.text((text_x, text_y), line, fill='black', font=font)
|
||||
|
||||
return label_bg
|
||||
|
||||
def sample_plot_image(self, plot_image_vars, samples, preview_latent, latents_plot, image_list, disable_noise,
|
||||
start_step, last_step, force_full_denoise, x_value=None, y_value=None):
|
||||
model, clip, vae, positive, negative, seed, steps, cfg = None, None, None, None, None, None, None, None
|
||||
sampler_name, scheduler, denoise = None, None, None
|
||||
|
||||
a1111_prompt_style = plot_image_vars['a1111_prompt_style'] if "a1111_prompt_style" in plot_image_vars else False
|
||||
clip = clip if clip is not None else plot_image_vars["clip"]
|
||||
steps = plot_image_vars['steps'] if "steps" in plot_image_vars else 1
|
||||
|
||||
sd_version = get_sd_version(plot_image_vars['model'])
|
||||
# 高级用法
|
||||
if plot_image_vars["x_node_type"] == "advanced" or plot_image_vars["y_node_type"] == "advanced":
|
||||
if self.x_type == "Seeds++ Batch" or self.y_type == "Seeds++ Batch":
|
||||
seed = int(x_value) if self.x_type == "Seeds++ Batch" else int(y_value)
|
||||
if self.x_type == "Steps" or self.y_type == "Steps":
|
||||
steps = int(x_value) if self.x_type == "Steps" else int(y_value)
|
||||
if self.x_type == "StartStep" or self.y_type == "StartStep":
|
||||
start_step = int(x_value) if self.x_type == "StartStep" else int(y_value)
|
||||
if self.x_type == "EndStep" or self.y_type == "EndStep":
|
||||
last_step = int(x_value) if self.x_type == "EndStep" else int(y_value)
|
||||
if self.x_type == "CFG Scale" or self.y_type == "CFG Scale":
|
||||
cfg = float(x_value) if self.x_type == "CFG Scale" else float(y_value)
|
||||
if self.x_type == "Sampler" or self.y_type == "Sampler":
|
||||
sampler_name = x_value if self.x_type == "Sampler" else y_value
|
||||
if self.x_type == "Scheduler" or self.y_type == "Scheduler":
|
||||
scheduler = x_value if self.x_type == "Scheduler" else y_value
|
||||
if self.x_type == "Sampler&Scheduler" or self.y_type == "Sampler&Scheduler":
|
||||
arr = x_value.split(',') if self.x_type == "Sampler&Scheduler" else y_value.split(',')
|
||||
if arr[0] and arr[0]!= 'None':
|
||||
sampler_name = arr[0]
|
||||
if arr[1] and arr[1]!= 'None':
|
||||
scheduler = arr[1]
|
||||
if self.x_type == "Denoise" or self.y_type == "Denoise":
|
||||
denoise = float(x_value) if self.x_type == "Denoise" else float(y_value)
|
||||
if self.x_type == "Pos Condition" or self.y_type == "Pos Condition":
|
||||
positive = plot_image_vars['positive_cond_stack'][int(x_value)] if self.x_type == "Pos Condition" else plot_image_vars['positive_cond_stack'][int(y_value)]
|
||||
if self.x_type == "Neg Condition" or self.y_type == "Neg Condition":
|
||||
negative = plot_image_vars['negative_cond_stack'][int(x_value)] if self.x_type == "Neg Condition" else plot_image_vars['negative_cond_stack'][int(y_value)]
|
||||
# 模型叠加
|
||||
if self.x_type == "ModelMergeBlocks" or self.y_type == "ModelMergeBlocks":
|
||||
ckpt_name_1, ckpt_name_2 = plot_image_vars['models']
|
||||
model1, clip1, vae1, clip_vision = self.easyCache.load_checkpoint(ckpt_name_1)
|
||||
model2, clip2, vae2, clip_vision = self.easyCache.load_checkpoint(ckpt_name_2)
|
||||
xy_values = x_value if self.x_type == "ModelMergeBlocks" else y_value
|
||||
if ":" in xy_values:
|
||||
xy_line = xy_values.split(':')
|
||||
xy_values = xy_line[1]
|
||||
|
||||
xy_arrs = xy_values.split(',')
|
||||
# ModelMergeBlocks
|
||||
if len(xy_arrs) == 3:
|
||||
input, middle, out = xy_arrs
|
||||
kwargs = {
|
||||
"input": input,
|
||||
"middle": middle,
|
||||
"out": out
|
||||
}
|
||||
elif len(xy_arrs) == 30:
|
||||
kwargs = {}
|
||||
kwargs["time_embed."] = xy_arrs[0]
|
||||
kwargs["label_emb."] = xy_arrs[1]
|
||||
|
||||
for i in range(12):
|
||||
kwargs["input_blocks.{}.".format(i)] = xy_arrs[2+i]
|
||||
|
||||
for i in range(3):
|
||||
kwargs["middle_block.{}.".format(i)] = xy_arrs[14+i]
|
||||
|
||||
for i in range(12):
|
||||
kwargs["output_blocks.{}.".format(i)] = xy_arrs[17+i]
|
||||
|
||||
kwargs["out."] = xy_arrs[29]
|
||||
else:
|
||||
raise Exception("ModelMergeBlocks weight length error")
|
||||
default_ratio = next(iter(kwargs.values()))
|
||||
|
||||
m = model1.clone()
|
||||
kp = model2.get_key_patches("diffusion_model.")
|
||||
|
||||
for k in kp:
|
||||
ratio = float(default_ratio)
|
||||
k_unet = k[len("diffusion_model."):]
|
||||
|
||||
last_arg_size = 0
|
||||
for arg in kwargs:
|
||||
if k_unet.startswith(arg) and last_arg_size < len(arg):
|
||||
ratio = float(kwargs[arg])
|
||||
last_arg_size = len(arg)
|
||||
|
||||
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
|
||||
|
||||
vae_use = plot_image_vars['vae_use']
|
||||
|
||||
clip = clip2 if vae_use == 'Use Model 2' else clip1
|
||||
if vae_use == 'Use Model 2':
|
||||
vae = vae2
|
||||
elif vae_use == 'Use Model 1':
|
||||
vae = vae1
|
||||
else:
|
||||
vae = self.easyCache.load_vae(vae_use)
|
||||
model = m
|
||||
|
||||
# 如果存在lora_stack叠加lora
|
||||
optional_lora_stack = plot_image_vars['lora_stack']
|
||||
if optional_lora_stack is not None and optional_lora_stack != []:
|
||||
for lora in optional_lora_stack:
|
||||
model, clip = self.easyCache.load_lora(lora)
|
||||
|
||||
# 处理clip
|
||||
clip = clip.clone()
|
||||
if plot_image_vars['clip_skip'] != 0:
|
||||
clip.clip_layer(plot_image_vars['clip_skip'])
|
||||
|
||||
# CheckPoint
|
||||
if self.x_type == "Checkpoint" or self.y_type == "Checkpoint":
|
||||
xy_values = x_value if self.x_type == "Checkpoint" else y_value
|
||||
ckpt_name, clip_skip, vae_name = xy_values.split(",")
|
||||
ckpt_name = ckpt_name.replace('*', ',')
|
||||
vae_name = vae_name.replace('*', ',')
|
||||
model, clip, vae, clip_vision = self.easyCache.load_checkpoint(ckpt_name)
|
||||
if vae_name != 'None':
|
||||
vae = self.easyCache.load_vae(vae_name)
|
||||
|
||||
# 如果存在lora_stack叠加lora
|
||||
optional_lora_stack = plot_image_vars['lora_stack']
|
||||
if optional_lora_stack is not None and optional_lora_stack != []:
|
||||
for lora in optional_lora_stack:
|
||||
lora['model'] = model
|
||||
lora['clip'] = clip
|
||||
model, clip = self.easyCache.load_lora(lora)
|
||||
|
||||
# 处理clip
|
||||
clip = clip.clone()
|
||||
if clip_skip != 'None':
|
||||
clip.clip_layer(int(clip_skip))
|
||||
positive = plot_image_vars['positive']
|
||||
negative = plot_image_vars['negative']
|
||||
a1111_prompt_style = plot_image_vars['a1111_prompt_style']
|
||||
steps = plot_image_vars['steps']
|
||||
clip = clip if clip is not None else plot_image_vars["clip"]
|
||||
positive = advanced_encode(clip, positive,
|
||||
plot_image_vars['positive_token_normalization'],
|
||||
plot_image_vars['positive_weight_interpretation'],
|
||||
w_max=1.0,
|
||||
apply_to_pooled="enable",
|
||||
a1111_prompt_style=a1111_prompt_style, steps=steps)
|
||||
|
||||
negative = advanced_encode(clip, negative,
|
||||
plot_image_vars['negative_token_normalization'],
|
||||
plot_image_vars['negative_weight_interpretation'],
|
||||
w_max=1.0,
|
||||
apply_to_pooled="enable",
|
||||
a1111_prompt_style=a1111_prompt_style, steps=steps)
|
||||
if "positive_cond" in plot_image_vars:
|
||||
positive = positive + plot_image_vars["positive_cond"]
|
||||
if "negative_cond" in plot_image_vars:
|
||||
negative = negative + plot_image_vars["negative_cond"]
|
||||
|
||||
# Lora
|
||||
if self.x_type == "Lora" or self.y_type == "Lora":
|
||||
# print(f"Lora: {x_value} {y_value}")
|
||||
model = model if model is not None else plot_image_vars["model"]
|
||||
clip = clip if clip is not None else plot_image_vars["clip"]
|
||||
xy_values = x_value if self.x_type == "Lora" else y_value
|
||||
lora_name, lora_model_strength, lora_clip_strength, _ = xy_values.split(",")
|
||||
lora_stack = [{"lora_name": lora_name, "model": model, "clip" :clip, "model_strength": float(lora_model_strength), "clip_strength": float(lora_clip_strength)}]
|
||||
|
||||
# print(f"new_lora_stack: {new_lora_stack}")
|
||||
|
||||
|
||||
if 'lora_stack' in plot_image_vars:
|
||||
lora_stack = lora_stack + plot_image_vars['lora_stack']
|
||||
|
||||
if lora_stack is not None and lora_stack != []:
|
||||
for lora in lora_stack:
|
||||
# Each generation of the model, must use the reference to previously created model / clip objects.
|
||||
lora['model'] = model
|
||||
lora['clip'] = clip
|
||||
model, clip = self.easyCache.load_lora(lora)
|
||||
|
||||
# 提示词
|
||||
if "Positive" in self.x_type or "Positive" in self.y_type:
|
||||
if self.x_type == 'Positive Prompt S/R' or self.y_type == 'Positive Prompt S/R':
|
||||
positive = x_value if self.x_type == "Positive Prompt S/R" else y_value
|
||||
|
||||
if sd_version == 'flux':
|
||||
positive, = CLIPTextEncode().encode(clip, positive)
|
||||
else:
|
||||
positive = advanced_encode(clip, positive,
|
||||
plot_image_vars['positive_token_normalization'],
|
||||
plot_image_vars['positive_weight_interpretation'],
|
||||
w_max=1.0,
|
||||
apply_to_pooled="enable", a1111_prompt_style=a1111_prompt_style, steps=steps)
|
||||
|
||||
# if "positive_cond" in plot_image_vars:
|
||||
# positive = positive + plot_image_vars["positive_cond"]
|
||||
|
||||
if "Negative" in self.x_type or "Negative" in self.y_type:
|
||||
if self.x_type == 'Negative Prompt S/R' or self.y_type == 'Negative Prompt S/R':
|
||||
negative = x_value if self.x_type == "Negative Prompt S/R" else y_value
|
||||
|
||||
if sd_version == 'flux':
|
||||
negative, = CLIPTextEncode().encode(clip, negative)
|
||||
else:
|
||||
negative = advanced_encode(clip, negative,
|
||||
plot_image_vars['negative_token_normalization'],
|
||||
plot_image_vars['negative_weight_interpretation'],
|
||||
w_max=1.0,
|
||||
apply_to_pooled="enable", a1111_prompt_style=a1111_prompt_style, steps=steps)
|
||||
# if "negative_cond" in plot_image_vars:
|
||||
# negative = negative + plot_image_vars["negative_cond"]
|
||||
|
||||
# ControlNet
|
||||
if "ControlNet" in self.x_type or "ControlNet" in self.y_type:
|
||||
cnet = plot_image_vars["cnet"] if "cnet" in plot_image_vars else None
|
||||
positive = plot_image_vars["positive_cond"] if "positive" in plot_image_vars else None
|
||||
negative = plot_image_vars["negative_cond"] if "negative" in plot_image_vars else None
|
||||
if cnet:
|
||||
index = x_value if "ControlNet" in self.x_type else y_value
|
||||
controlnet = cnet[index]
|
||||
for index, item in enumerate(controlnet):
|
||||
control_net_name = item[0]
|
||||
image = item[1]
|
||||
strength = item[2]
|
||||
start_percent = item[3]
|
||||
end_percent = item[4]
|
||||
provided_control_net = item[5] if len(item) > 5 else None
|
||||
positive, negative = easyControlnet().apply(control_net_name, image, positive, negative, strength, start_percent, end_percent, provided_control_net, 1)
|
||||
# Flux guidance
|
||||
if self.x_type == "Flux Guidance" or self.y_type == "Flux Guidance":
|
||||
positive = plot_image_vars["positive_cond"] if "positive" in plot_image_vars else None
|
||||
flux_guidance = float(x_value) if self.x_type == "Flux Guidance" else float(y_value)
|
||||
positive, = FluxGuidance().append(positive, flux_guidance)
|
||||
|
||||
# 简单用法
|
||||
if plot_image_vars["x_node_type"] == "loader" or plot_image_vars["y_node_type"] == "loader":
|
||||
if self.x_type == 'ckpt_name' or self.y_type == 'ckpt_name':
|
||||
ckpt_name = x_value if self.x_type == "ckpt_name" else y_value
|
||||
model, clip, vae, clip_vision = self.easyCache.load_checkpoint(ckpt_name)
|
||||
|
||||
if self.x_type == 'lora_name' or self.y_type == 'lora_name':
|
||||
model, clip, vae, clip_vision = self.easyCache.load_checkpoint(plot_image_vars['ckpt_name'])
|
||||
lora_name = x_value if self.x_type == "lora_name" else y_value
|
||||
lora = {"lora_name": lora_name, "model": model, "clip": clip, "model_strength": 1, "clip_strength": 1}
|
||||
model, clip = self.easyCache.load_lora(lora)
|
||||
|
||||
if self.x_type == 'lora_model_strength' or self.y_type == 'lora_model_strength':
|
||||
model, clip, vae, clip_vision = self.easyCache.load_checkpoint(plot_image_vars['ckpt_name'])
|
||||
lora_model_strength = float(x_value) if self.x_type == "lora_model_strength" else float(y_value)
|
||||
lora = {"lora_name": plot_image_vars['lora_name'], "model": model, "clip": clip, "model_strength": lora_model_strength, "clip_strength": plot_image_vars['lora_clip_strength']}
|
||||
model, clip = self.easyCache.load_lora(lora)
|
||||
|
||||
if self.x_type == 'lora_clip_strength' or self.y_type == 'lora_clip_strength':
|
||||
model, clip, vae, clip_vision = self.easyCache.load_checkpoint(plot_image_vars['ckpt_name'])
|
||||
lora_clip_strength = float(x_value) if self.x_type == "lora_clip_strength" else float(y_value)
|
||||
lora = {"lora_name": plot_image_vars['lora_name'], "model": model, "clip": clip, "model_strength": plot_image_vars['lora_model_strength'], "clip_strength": lora_clip_strength}
|
||||
model, clip = self.easyCache.load_lora(lora)
|
||||
|
||||
# Check for custom VAE
|
||||
if self.x_type == 'vae_name' or self.y_type == 'vae_name':
|
||||
vae_name = x_value if self.x_type == "vae_name" else y_value
|
||||
vae = self.easyCache.load_vae(vae_name)
|
||||
|
||||
# CLIP skip
|
||||
if not clip:
|
||||
raise Exception("No CLIP found")
|
||||
clip = clip.clone()
|
||||
clip.clip_layer(plot_image_vars['clip_skip'])
|
||||
|
||||
if sd_version == 'flux':
|
||||
positive, = CLIPTextEncode().encode(clip, positive)
|
||||
else:
|
||||
positive = advanced_encode(clip, plot_image_vars['positive'],
|
||||
plot_image_vars['positive_token_normalization'],
|
||||
plot_image_vars['positive_weight_interpretation'], w_max=1.0,
|
||||
apply_to_pooled="enable",a1111_prompt_style=a1111_prompt_style, steps=steps)
|
||||
|
||||
if sd_version == 'flux':
|
||||
negative, = CLIPTextEncode().encode(clip, negative)
|
||||
else:
|
||||
negative = advanced_encode(clip, plot_image_vars['negative'],
|
||||
plot_image_vars['negative_token_normalization'],
|
||||
plot_image_vars['negative_weight_interpretation'], w_max=1.0,
|
||||
apply_to_pooled="enable", a1111_prompt_style=a1111_prompt_style, steps=steps)
|
||||
|
||||
|
||||
model = model if model is not None else plot_image_vars["model"]
|
||||
vae = vae if vae is not None else plot_image_vars["vae"]
|
||||
positive = positive if positive is not None else plot_image_vars["positive_cond"]
|
||||
negative = negative if negative is not None else plot_image_vars["negative_cond"]
|
||||
|
||||
seed = seed if seed is not None else plot_image_vars["seed"]
|
||||
steps = steps if steps is not None else plot_image_vars["steps"]
|
||||
cfg = cfg if cfg is not None else plot_image_vars["cfg"]
|
||||
sampler_name = sampler_name if sampler_name is not None else plot_image_vars["sampler_name"]
|
||||
scheduler = scheduler if scheduler is not None else plot_image_vars["scheduler"]
|
||||
denoise = denoise if denoise is not None else plot_image_vars["denoise"]
|
||||
|
||||
noise_device = plot_image_vars["noise_device"] if "noise_device" in plot_image_vars else 'cpu'
|
||||
|
||||
# LayerDiffuse
|
||||
layer_diffusion_method = plot_image_vars["layer_diffusion_method"] if "layer_diffusion_method" in plot_image_vars else None
|
||||
empty_samples = plot_image_vars["empty_samples"] if "empty_samples" in plot_image_vars else None
|
||||
|
||||
if layer_diffusion_method:
|
||||
samp_blend_samples = plot_image_vars["blend_samples"] if "blend_samples" in plot_image_vars else None
|
||||
additional_cond = plot_image_vars["layer_diffusion_cond"] if "layer_diffusion_cond" in plot_image_vars else None
|
||||
|
||||
images = plot_image_vars["images"].movedim(-1, 1) if "images" in plot_image_vars else None
|
||||
weight = plot_image_vars['layer_diffusion_weight'] if 'layer_diffusion_weight' in plot_image_vars else 1.0
|
||||
model, positive, negative = LayerDiffuse().apply_layer_diffusion(model, layer_diffusion_method, weight, samples,
|
||||
samp_blend_samples, positive,
|
||||
negative, images, additional_cond)
|
||||
|
||||
samples = empty_samples if layer_diffusion_method is not None and empty_samples is not None else samples
|
||||
# Sample
|
||||
samples = self.sampler.common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, samples,
|
||||
denoise=denoise, disable_noise=disable_noise, preview_latent=preview_latent,
|
||||
start_step=start_step, last_step=last_step,
|
||||
force_full_denoise=force_full_denoise, noise_device=noise_device)
|
||||
|
||||
# Decode images and store
|
||||
latent = samples["samples"]
|
||||
|
||||
# Add the latent tensor to the tensors list
|
||||
latents_plot.append(latent)
|
||||
|
||||
# Decode the image
|
||||
image = vae.decode(latent).cpu()
|
||||
|
||||
if self.output_individuals in [True, "True"]:
|
||||
easySave(image, self.save_prefix, self.image_output)
|
||||
|
||||
# Convert the image from tensor to PIL Image and add it to the list
|
||||
pil_image = self.sampler.tensor2pil(image)
|
||||
image_list.append(pil_image)
|
||||
|
||||
# Update max dimensions
|
||||
self.max_width = max(self.max_width, pil_image.width)
|
||||
self.max_height = max(self.max_height, pil_image.height)
|
||||
|
||||
# Return the touched variables
|
||||
return image_list, self.max_width, self.max_height, latents_plot
|
||||
|
||||
# Process Functions
|
||||
def validate_xy_plot(self):
|
||||
if self.x_type == 'None' and self.y_type == 'None':
|
||||
log_node_warn(f'#{self.my_unique_id}','No Valid Plot Types - Reverting to default sampling...')
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def get_latent(self, samples):
|
||||
# Extract the 'samples' tensor from the dictionary
|
||||
latent_image_tensor = samples["samples"]
|
||||
|
||||
# Split the tensor into individual image tensors
|
||||
image_tensors = torch.split(latent_image_tensor, 1, dim=0)
|
||||
|
||||
# Create a list of dictionaries containing the individual image tensors
|
||||
latent_list = [{'samples': image} for image in image_tensors]
|
||||
|
||||
# Set latent only to the first latent of batch
|
||||
if self.latent_id >= len(latent_list):
|
||||
log_node_warn(f'#{self.my_unique_id}',f'The selected latent_id ({self.latent_id}) is out of range.')
|
||||
log_node_warn(f'#{self.my_unique_id}', f'Automatically setting the latent_id to the last image in the list (index: {len(latent_list) - 1}).')
|
||||
|
||||
self.latent_id = len(latent_list) - 1
|
||||
|
||||
return latent_list[self.latent_id]
|
||||
|
||||
def get_labels_and_sample(self, plot_image_vars, latent_image, preview_latent, start_step, last_step,
|
||||
force_full_denoise, disable_noise):
|
||||
for x_index, x_value in enumerate(self.x_values):
|
||||
plot_image_vars, x_value_label = self.define_variable(plot_image_vars, self.x_type, x_value,
|
||||
x_index)
|
||||
self.x_label = self.update_label(self.x_label, x_value_label, len(self.x_values))
|
||||
if self.y_type != 'None':
|
||||
for y_index, y_value in enumerate(self.y_values):
|
||||
plot_image_vars, y_value_label = self.define_variable(plot_image_vars, self.y_type, y_value,
|
||||
y_index)
|
||||
self.y_label = self.update_label(self.y_label, y_value_label, len(self.y_values))
|
||||
# ttNl(f'{CC.GREY}X: {x_value_label}, Y: {y_value_label}').t(
|
||||
# f'Plot Values {self.num}/{self.total} ->').p()
|
||||
|
||||
self.image_list, self.max_width, self.max_height, self.latents_plot = self.sample_plot_image(
|
||||
plot_image_vars, latent_image, preview_latent, self.latents_plot, self.image_list,
|
||||
disable_noise, start_step, last_step, force_full_denoise, x_value, y_value)
|
||||
self.num += 1
|
||||
else:
|
||||
# ttNl(f'{CC.GREY}X: {x_value_label}').t(f'Plot Values {self.num}/{self.total} ->').p()
|
||||
self.image_list, self.max_width, self.max_height, self.latents_plot = self.sample_plot_image(
|
||||
plot_image_vars, latent_image, preview_latent, self.latents_plot, self.image_list, disable_noise,
|
||||
start_step, last_step, force_full_denoise, x_value)
|
||||
self.num += 1
|
||||
|
||||
# Rearrange latent array to match preview image grid
|
||||
self.latents_plot = self.rearrange_tensors(self.latents_plot, self.num_cols, self.num_rows)
|
||||
|
||||
# Concatenate the tensors along the first dimension (dim=0)
|
||||
self.latents_plot = torch.cat(self.latents_plot, dim=0)
|
||||
|
||||
return self.latents_plot
|
||||
|
||||
def plot_images_and_labels(self, plot_image_vars):
|
||||
|
||||
bg_width, bg_height, x_offset_initial, y_offset = self.calculate_background_dimensions()
|
||||
|
||||
background = Image.new('RGBA', (int(bg_width), int(bg_height)), color=(255, 255, 255, 255))
|
||||
|
||||
output_image = []
|
||||
for row_index in range(self.num_rows):
|
||||
x_offset = x_offset_initial
|
||||
|
||||
for col_index in range(self.num_cols):
|
||||
index = col_index * self.num_rows + row_index
|
||||
img = self.image_list[index]
|
||||
output_image.append(self.sampler.pil2tensor(img))
|
||||
background.paste(img, (x_offset, y_offset))
|
||||
|
||||
# Handle X label
|
||||
if row_index == 0 and self.x_type != "None":
|
||||
label_bg = self.create_label(img, self.x_label[col_index], int(48 * img.width / 512))
|
||||
label_y = (y_offset - label_bg.height) // 2
|
||||
background.alpha_composite(label_bg, (x_offset, label_y))
|
||||
|
||||
# Handle Y label
|
||||
if col_index == 0 and self.y_type != "None":
|
||||
label_bg = self.create_label(img, self.y_label[row_index], int(48 * img.height / 512), False)
|
||||
label_bg = label_bg.rotate(90, expand=True)
|
||||
|
||||
label_x = (x_offset - label_bg.width) // 2
|
||||
label_y = y_offset + (img.height - label_bg.height) // 2
|
||||
background.alpha_composite(label_bg, (label_x, label_y))
|
||||
|
||||
x_offset += img.width + self.grid_spacing
|
||||
|
||||
y_offset += img.height + self.grid_spacing
|
||||
|
||||
# lookup used models in the image
|
||||
common_label = ""
|
||||
# Update to add a function to do the heavy lifting. Parameters are plot_image_vars name, label to use, names of the axis,
|
||||
|
||||
# pprint.pp(plot_image_vars)
|
||||
|
||||
# We don't process LORAs here because there can be multiple of them.
|
||||
labels = [
|
||||
{"id": "ckpt_name", "id_desc": "ckpt", "axis_type" : "Checkpoint"},
|
||||
{"id": "vae_name", "id_desc": '', "axis_type" : "vae_name"},
|
||||
{"id": "sampler_name", "id_desc": "sampler", "axis_type" : "Sampler"},
|
||||
{"id": "scheduler", "id_desc": '', "axis_type" : "Scheduler"},
|
||||
{"id": "steps", "id_desc": '', "axis_type" : "Steps"},
|
||||
{"id": "Flux Guidance", "id_desc": 'guidance', "axis_type" : "Flux Guidance"},
|
||||
{"id": "seed", "id_desc": '', "axis_type" : "Seeds++ Batch"}
|
||||
]
|
||||
|
||||
for item in labels:
|
||||
# Only add the label if it's not one of the axis
|
||||
# print(f"Checking item: {item['id']} axis_type {item['axis_type']} x_type: {self.x_type} y_type: {self.y_type}")
|
||||
if self.x_type != item['axis_type'] and self.y_type != item['axis_type']:
|
||||
common_label += self.add_common_label(item['id'], plot_image_vars, item['id_desc'])
|
||||
common_label += f"\n"
|
||||
|
||||
if plot_image_vars['lora_stack'] is not None and plot_image_vars['lora_stack'] != []:
|
||||
# print(f"lora_stack: {plot_image_vars['lora_stack']}")
|
||||
for lora in plot_image_vars['lora_stack']:
|
||||
|
||||
lora_name = lora['lora_name']
|
||||
lora_weight = lora['model_strength']
|
||||
if lora_name is not None and len(lora_name) > 0 and lora_weight > 0:
|
||||
common_label += f"LORA: {lora_name} weight: {lora_weight:.2f} \n"
|
||||
|
||||
common_label = common_label.strip()
|
||||
|
||||
if len(common_label) > 0:
|
||||
label_height = background.height - y_offset
|
||||
label_bg = self.create_label(background, common_label, int(48 * background.width / 512), label_width=background.width, label_height=label_height)
|
||||
label_x = (background.width - label_bg.width) // 2
|
||||
label_y = y_offset
|
||||
# print(f"Adding common label: {common_label} x = {label_x} y = {label_y}")
|
||||
background.alpha_composite(label_bg, (label_x, label_y))
|
||||
|
||||
return (self.sampler.pil2tensor(background), output_image)
|
||||
|
||||
def add_common_label(self, tag, plot_image_vars, description = ''):
|
||||
label = ''
|
||||
if description == '': description = tag
|
||||
if tag in plot_image_vars and plot_image_vars[tag] is not None and plot_image_vars[tag] != 'None':
|
||||
label += f"{description}: {plot_image_vars[tag]} "
|
||||
# print(f"add_common_label: {tag} description: {description} label: {label}" )
|
||||
return label
|
||||
Reference in New Issue
Block a user