Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled

Includes 30 custom nodes committed directly, 7 Civitai-exclusive
loras stored via Git LFS, and a setup script that installs all
dependencies and downloads HuggingFace-hosted models on vast.ai.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-09 00:55:26 +00:00
parent 2b70ab9ad0
commit f09734b0ee
2274 changed files with 748556 additions and 3 deletions

View File

@@ -0,0 +1,38 @@
from .context_utils import is_context_empty
from .constants import get_category, get_name
from .utils import FlexibleOptionalInputType, any_type
def is_none(value):
  """Determines whether a value should be treated as "None".

  A dict carrying both 'model' and 'clip' keys is treated as a context, and counts as
  None when the context is empty; anything else is None only when it literally is.
  """
  if isinstance(value, dict) and 'model' in value and 'clip' in value:
    return is_context_empty(value)
  return value is None
class RgthreeAnySwitch:
  """The dynamic Any Switch node; outputs the first connected, non-empty `any_*` input."""

  NAME = get_name("Any Switch")
  CATEGORY = get_category()

  @classmethod
  def INPUT_TYPES(cls):  # pylint: disable = invalid-name, missing-function-docstring
    return {
      "required": {},
      "optional": FlexibleOptionalInputType(any_type),
    }

  RETURN_TYPES = (any_type,)
  RETURN_NAMES = ('*',)
  FUNCTION = "switch"

  def switch(self, **kwargs):
    """Chooses the first non-empty item to output."""
    chosen = next(
      (value for key, value in kwargs.items() if key.startswith('any_') and not is_none(value)),
      None)
    return (chosen,)

View File

@@ -0,0 +1,111 @@
import os
import json
from .utils import get_dict_value, set_dict_value, dict_has_key, load_json_file
from .pyproject import VERSION
def get_config_value(key, default=None):
  """Looks up `key` (dot-path supported via get_dict_value) in the combined rgthree config."""
  return get_dict_value(RGTHREE_CONFIG, key, default)
def extend_config(default_config, user_config):
  """Returns a new config dict combining user_config into defined keys for default_config.

  Only keys present in `default_config` are kept; nested dicts are merged recursively.
  Keys in `user_config` that are not in `default_config` are dropped.
  """
  cfg = {}
  for key, default_value in default_config.items():
    if key not in user_config:
      cfg[key] = default_value
    elif isinstance(default_value, dict):
      cfg[key] = extend_config(default_value, user_config[key])
    else:
      # We already know `key` is in user_config here; the original re-checked redundantly.
      cfg[key] = user_config[key]
  return cfg
def set_user_config(data: dict):
  """Sets the user configuration.

  Only keys already declared in the default config are accepted; both the user config
  and the live combined config are updated, and the user config file is rewritten if
  anything changed.
  """
  changed = 0
  for key, value in data.items():
    if not dict_has_key(DEFAULT_CONFIG, key):
      continue
    set_dict_value(USER_CONFIG, key, value)
    set_dict_value(RGTHREE_CONFIG, key, value)
    changed += 1
  if changed:
    write_user_config()
def get_rgthree_default_config():
  """ Gets the default configuration from the bundled `.default` JSON file."""
  return load_json_file(DEFAULT_CONFIG_FILE, default={})
def get_rgthree_user_config():
  """ Gets the user configuration; empty dict if the user file doesn't exist."""
  return load_json_file(USER_CONFIG_FILE, default={})
def write_user_config():
  """ Writes the user configuration to disk as sorted, indented JSON."""
  with open(USER_CONFIG_FILE, 'w+', encoding='UTF-8') as file:
    json.dump(USER_CONFIG, file, sort_keys=True, indent=2, separators=(",", ": "))
# Paths to the default (shipped) and user (editable) config files, one directory up.
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
DEFAULT_CONFIG_FILE = os.path.join(THIS_DIR, '..', 'rgthree_config.json.default')
USER_CONFIG_FILE = os.path.join(THIS_DIR, '..', 'rgthree_config.json')

# Populated by refresh_config(); RGTHREE_CONFIG is defaults overlaid with user values.
DEFAULT_CONFIG = {}
USER_CONFIG = {}
RGTHREE_CONFIG = {}
def refresh_config():
  """Refreshes the config.

  Reloads both config files, migrates legacy top-level options into the "features"
  sub-dict (persisting the migration), then rebuilds the combined RGTHREE_CONFIG.
  """
  global DEFAULT_CONFIG, USER_CONFIG, RGTHREE_CONFIG
  DEFAULT_CONFIG = get_rgthree_default_config()
  USER_CONFIG = get_rgthree_user_config()
  # Migrate old config options into "features"
  needs_to_write_user_config = False
  # 'patch_recursive_execution' is removed entirely rather than migrated.
  if 'patch_recursive_execution' in USER_CONFIG:
    del USER_CONFIG['patch_recursive_execution']
    needs_to_write_user_config = True
  if 'features' in USER_CONFIG and 'patch_recursive_execution' in USER_CONFIG['features']:
    del USER_CONFIG['features']['patch_recursive_execution']
    needs_to_write_user_config = True
  # Move legacy top-level flags under 'features'.
  if 'show_alerts_for_corrupt_workflows' in USER_CONFIG:
    if 'features' not in USER_CONFIG:
      USER_CONFIG['features'] = {}
    USER_CONFIG['features']['show_alerts_for_corrupt_workflows'] = USER_CONFIG[
      'show_alerts_for_corrupt_workflows']
    del USER_CONFIG['show_alerts_for_corrupt_workflows']
    needs_to_write_user_config = True
  if 'monitor_for_corrupt_links' in USER_CONFIG:
    if 'features' not in USER_CONFIG:
      USER_CONFIG['features'] = {}
    USER_CONFIG['features']['monitor_for_corrupt_links'] = USER_CONFIG['monitor_for_corrupt_links']
    del USER_CONFIG['monitor_for_corrupt_links']
    needs_to_write_user_config = True
  if needs_to_write_user_config is True:
    print('writing new user config.')
    write_user_config()
  # Combined config: defaults overlaid with user values, plus the current version.
  RGTHREE_CONFIG = {"version": VERSION} | extend_config(DEFAULT_CONFIG, USER_CONFIG)
  # 'unreleased' and 'debug' are allowed through even without a default entry.
  if "unreleased" in USER_CONFIG and "unreleased" not in RGTHREE_CONFIG:
    RGTHREE_CONFIG["unreleased"] = USER_CONFIG["unreleased"]
  if "debug" in USER_CONFIG and "debug" not in RGTHREE_CONFIG:
    RGTHREE_CONFIG["debug"] = USER_CONFIG["debug"]
def get_config():
  """Returns the combined rgthree config dict."""
  return RGTHREE_CONFIG

# Populate the config at import time so module-level consumers see values immediately.
refresh_config()

View File

@@ -0,0 +1,11 @@
NAMESPACE = 'rgthree'

def get_name(name):
  """Returns the node display name suffixed with the rgthree namespace."""
  return '{} ({})'.format(name, NAMESPACE)

def get_category(sub_dirs=None):
  """Returns the menu category, optionally nested under `sub_dirs`.

  Fix: the original ignored `sub_dirs` and always returned '<namespace>/utils'.
  Behavior with no argument (the only usage visible in this file) is unchanged.
  """
  if sub_dirs is None:
    return NAMESPACE
  return '{}/{}'.format(NAMESPACE, sub_dirs)

View File

@@ -0,0 +1,33 @@
"""The Context node."""
from .context_utils import (ORIG_CTX_OPTIONAL_INPUTS, ORIG_CTX_RETURN_NAMES, ORIG_CTX_RETURN_TYPES,
get_orig_context_return_tuple, new_context)
from .constants import get_category, get_name
class RgthreeContext:
  """The initial Context node.

  For now, this node's outputs will remain as-is, as they are perfect for most 1.5
  applications, and remain backwards compatible with other Context nodes.
  """

  NAME = get_name("Context")
  CATEGORY = get_category()

  @classmethod
  def INPUT_TYPES(cls):  # pylint: disable = invalid-name, missing-function-docstring
    return {
      "required": {},
      "optional": ORIG_CTX_OPTIONAL_INPUTS,
      "hidden": {
        "version": "FLOAT"
      },
    }

  RETURN_TYPES = ORIG_CTX_RETURN_TYPES
  RETURN_NAMES = ORIG_CTX_RETURN_NAMES
  FUNCTION = "convert"

  def convert(self, base_ctx=None, **kwargs):  # pylint: disable = missing-function-docstring
    return get_orig_context_return_tuple(new_context(base_ctx, **kwargs))

View File

@@ -0,0 +1,31 @@
"""The Conmtext big node."""
from .constants import get_category, get_name
from .context_utils import (ALL_CTX_OPTIONAL_INPUTS, ALL_CTX_RETURN_NAMES, ALL_CTX_RETURN_TYPES,
new_context, get_context_return_tuple)
class RgthreeBigContext:
  """The Context Big node.

  Exposes all context fields as inputs and outputs. Backwards compatible with other
  context nodes and can be intertwined with them.
  """

  NAME = get_name("Context Big")
  CATEGORY = get_category()

  @classmethod
  def INPUT_TYPES(cls):  # pylint: disable = invalid-name,missing-function-docstring
    return {
      "required": {},
      "optional": ALL_CTX_OPTIONAL_INPUTS,
      "hidden": {},
    }

  RETURN_TYPES = ALL_CTX_RETURN_TYPES
  RETURN_NAMES = ALL_CTX_RETURN_NAMES
  FUNCTION = "convert"

  def convert(self, base_ctx=None, **kwargs):  # pylint: disable = missing-function-docstring
    return get_context_return_tuple(new_context(base_ctx, **kwargs))

View File

@@ -0,0 +1,37 @@
"""The Context Switch (Big)."""
from .constants import get_category, get_name
from .context_utils import (ORIG_CTX_RETURN_TYPES, ORIG_CTX_RETURN_NAMES, merge_new_context,
get_orig_context_return_tuple, is_context_empty)
from .utils import FlexibleOptionalInputType
class RgthreeContextMerge:
  """The Context Merge node; combines any number of contexts, later ones winning."""

  NAME = get_name("Context Merge")
  CATEGORY = get_category()

  @classmethod
  def INPUT_TYPES(cls):  # pylint: disable = invalid-name, missing-function-docstring
    return {
      "required": {},
      "optional": FlexibleOptionalInputType("RGTHREE_CONTEXT"),
    }

  RETURN_TYPES = ORIG_CTX_RETURN_TYPES
  RETURN_NAMES = ORIG_CTX_RETURN_NAMES
  FUNCTION = "merge"

  def get_return_tuple(self, ctx):
    """Returns the context data. Separated so it can be overridden."""
    return get_orig_context_return_tuple(ctx)

  def merge(self, **kwargs):
    """Merges any non-null passed contexts; later ones overriding earlier."""
    contexts = []
    for key, value in kwargs.items():
      if key.startswith('ctx_') and not is_context_empty(value):
        contexts.append(value)
    return self.get_return_tuple(merge_new_context(*contexts))

View File

@@ -0,0 +1,16 @@
"""The Context Switch (Big)."""
from .constants import get_category, get_name
from .context_utils import (ALL_CTX_RETURN_TYPES, ALL_CTX_RETURN_NAMES, get_context_return_tuple)
from .context_merge import RgthreeContextMerge
class RgthreeContextMergeBig(RgthreeContextMerge):
  """The Context Merge Big node; like Context Merge but outputs all context fields."""

  NAME = get_name("Context Merge Big")
  RETURN_TYPES = ALL_CTX_RETURN_TYPES
  RETURN_NAMES = ALL_CTX_RETURN_NAMES

  def get_return_tuple(self, ctx):
    """Overrides the base class to return the full (big) context data."""
    return get_context_return_tuple(ctx)

View File

@@ -0,0 +1,36 @@
"""The original Context Switch."""
from .constants import get_category, get_name
from .context_utils import (ORIG_CTX_RETURN_TYPES, ORIG_CTX_RETURN_NAMES, is_context_empty,
get_orig_context_return_tuple)
from .utils import FlexibleOptionalInputType
class RgthreeContextSwitch:
  """The (original) Context Switch node; outputs the first non-empty `ctx_*` input."""

  NAME = get_name("Context Switch")
  CATEGORY = get_category()

  @classmethod
  def INPUT_TYPES(cls):  # pylint: disable = invalid-name, missing-function-docstring
    return {
      "required": {},
      "optional": FlexibleOptionalInputType("RGTHREE_CONTEXT"),
    }

  RETURN_TYPES = ORIG_CTX_RETURN_TYPES
  RETURN_NAMES = ORIG_CTX_RETURN_NAMES
  FUNCTION = "switch"

  def get_return_tuple(self, ctx):
    """Returns the context data. Separated so it can be overridden."""
    return get_orig_context_return_tuple(ctx)

  def switch(self, **kwargs):
    """Chooses the first non-empty Context to output."""
    chosen = next(
      (value for key, value in kwargs.items()
       if key.startswith('ctx_') and not is_context_empty(value)), None)
    return self.get_return_tuple(chosen)

View File

@@ -0,0 +1,16 @@
"""The Context Switch (Big)."""
from .constants import get_category, get_name
from .context_utils import (ALL_CTX_RETURN_TYPES, ALL_CTX_RETURN_NAMES, get_context_return_tuple)
from .context_switch import RgthreeContextSwitch
class RgthreeContextSwitchBig(RgthreeContextSwitch):
  """The Context Switch Big node; like Context Switch but outputs all context fields."""

  NAME = get_name("Context Switch Big")
  RETURN_TYPES = ALL_CTX_RETURN_TYPES
  RETURN_NAMES = ALL_CTX_RETURN_NAMES

  def get_return_tuple(self, ctx):
    """Overrides RgthreeContextSwitch's `get_return_tuple` to return big context data."""
    return get_context_return_tuple(ctx)

View File

@@ -0,0 +1,118 @@
"""A set of constants and utilities for handling contexts.
Sets up the inputs and outputs for the Context going forward, with additional functions for
creating and exporting context objects.
"""
import comfy.samplers
import folder_paths
# Master table of all context fields. Each entry is a 3-tuple consumed by
# _create_context_data below: (input name, ComfyUI io type, output display label).
_all_context_input_output_data = {
  "base_ctx": ("base_ctx", "RGTHREE_CONTEXT", "CONTEXT"),
  "model": ("model", "MODEL", "MODEL"),
  "clip": ("clip", "CLIP", "CLIP"),
  "vae": ("vae", "VAE", "VAE"),
  "positive": ("positive", "CONDITIONING", "POSITIVE"),
  "negative": ("negative", "CONDITIONING", "NEGATIVE"),
  "latent": ("latent", "LATENT", "LATENT"),
  "images": ("images", "IMAGE", "IMAGE"),
  "seed": ("seed", "INT", "SEED"),
  "steps": ("steps", "INT", "STEPS"),
  "step_refiner": ("step_refiner", "INT", "STEP_REFINER"),
  "cfg": ("cfg", "FLOAT", "CFG"),
  "ckpt_name": ("ckpt_name", folder_paths.get_filename_list("checkpoints"), "CKPT_NAME"),
  "sampler": ("sampler", comfy.samplers.KSampler.SAMPLERS, "SAMPLER"),
  "scheduler": ("scheduler", comfy.samplers.KSampler.SCHEDULERS, "SCHEDULER"),
  "clip_width": ("clip_width", "INT", "CLIP_WIDTH"),
  "clip_height": ("clip_height", "INT", "CLIP_HEIGHT"),
  "text_pos_g": ("text_pos_g", "STRING", "TEXT_POS_G"),
  "text_pos_l": ("text_pos_l", "STRING", "TEXT_POS_L"),
  "text_neg_g": ("text_neg_g", "STRING", "TEXT_NEG_G"),
  "text_neg_l": ("text_neg_l", "STRING", "TEXT_NEG_L"),
  "mask": ("mask", "MASK", "MASK"),
  "control_net": ("control_net", "CONTROL_NET", "CONTROL_NET"),
}
# Inputs of these types/names get {"forceInput": True} so they render as connection
# sockets rather than widgets.
force_input_types = ["INT", "STRING", "FLOAT"]
force_input_names = ["sampler", "scheduler", "ckpt_name"]
def _create_context_data(input_list=None):
  """Builds (optional inputs, return types, return names) for a context node definition.

  Defaults to every key in `_all_context_input_output_data` when no list is given.
  """
  keys = input_list if input_list is not None else _all_context_input_output_data.keys()
  optional_inputs = {}
  return_types = []
  return_names = []
  for key in keys:
    name, io_type, output_label = _all_context_input_output_data[key]
    return_types.append(io_type)
    return_names.append(output_label)
    spec = [io_type]
    # Force a connection socket (not a widget) for primitive types and named combos.
    if io_type in force_input_types or name in force_input_names:
      spec.append({"forceInput": True})
    optional_inputs[name] = tuple(spec)
  return (optional_inputs, tuple(return_types), tuple(return_names))
# Full field set, used by the "Big" context nodes.
ALL_CTX_OPTIONAL_INPUTS, ALL_CTX_RETURN_TYPES, ALL_CTX_RETURN_NAMES = _create_context_data()
# The original (smaller) context field set, kept for backwards compatibility.
_original_ctx_inputs_list = [
  "base_ctx", "model", "clip", "vae", "positive", "negative", "latent", "images", "seed"
]
ORIG_CTX_OPTIONAL_INPUTS, ORIG_CTX_RETURN_TYPES, ORIG_CTX_RETURN_NAMES = _create_context_data(
  _original_ctx_inputs_list)
def new_context(base_ctx, **kwargs):
  """Creates a new context from the provided data, with an optional base ctx to start.

  Explicit kwargs win; otherwise values fall back to `base_ctx`, then None.
  """
  new_ctx = {}
  for key in _all_context_input_output_data:
    if key == "base_ctx":
      continue
    value = kwargs.get(key)
    if value is None and base_ctx is not None and key in base_ctx:
      value = base_ctx[key]
    new_ctx[key] = value
  return new_ctx
def merge_new_context(*args):
  """Creates a new context by merging provided contexts with the latter overriding same fields."""
  merged = {}
  for key in _all_context_input_output_data:
    if key == "base_ctx":
      continue
    merged[key] = None
    # Walk backwards through the passed contexts until a non-None value is found.
    for ctx in reversed(args):
      if is_context_empty(ctx) or key not in ctx:
        continue
      candidate = ctx[key]
      if candidate is not None:
        merged[key] = candidate
        break
  return merged
def get_context_return_tuple(ctx, inputs_list=None):
  """Returns a node output tuple: the ctx itself, then each field in inputs_list order."""
  keys = inputs_list if inputs_list is not None else _all_context_input_output_data.keys()
  values = [
    ctx[key] if ctx is not None and key in ctx else None
    for key in keys
    if key != "base_ctx"
  ]
  return tuple([ctx] + values)
def get_orig_context_return_tuple(ctx):
  """Returns a tuple for returning from a node with only the original context keys."""
  return get_context_return_tuple(ctx, _original_ctx_inputs_list)
def is_context_empty(ctx):
  """Checks if the provided ctx is None, falsy, or holds nothing but None values."""
  if not ctx:
    return True
  return all(value is None for value in ctx.values())

View File

@@ -0,0 +1,77 @@
import json
from .constants import get_category, get_name
from .utils import any_type, get_dict_value
class RgthreeDisplayAny:
  """Display any data node; serializes the input and shows it as text in the UI."""

  NAME = get_name('Display Any')
  CATEGORY = get_category()

  @classmethod
  def INPUT_TYPES(cls):  # pylint: disable = invalid-name, missing-function-docstring
    return {
      "required": {
        "source": (any_type, {}),
      },
      "hidden": {
        "unique_id": "UNIQUE_ID",
        "extra_pnginfo": "EXTRA_PNGINFO",
      },
    }

  RETURN_TYPES = ()
  FUNCTION = "main"
  OUTPUT_NODE = True

  def main(self, source=None, unique_id=None, extra_pnginfo=None):
    """Serializes `source` to a string (JSON, falling back to str) and sends it to the UI."""
    if isinstance(source, str):
      value = source
    elif isinstance(source, (int, float, bool)):
      value = str(source)
    elif source is None:
      value = 'None'
    else:
      try:
        value = json.dumps(source)
      except Exception:
        try:
          value = str(source)
        except Exception:
          value = 'source exists, but could not be serialized.'
    # Save the output to the pnginfo so it's pre-filled when loading the data.
    if extra_pnginfo and unique_id:
      for node in get_dict_value(extra_pnginfo, 'workflow.nodes', []):
        if str(node['id']) == str(unique_id):
          node['widgets_values'] = [value]
          break
    return {"ui": {"text": (value,)}}
class RgthreeDisplayInt:
  """Old DisplayInt node.

  Can be ported over to DisplayAny if https://github.com/comfyanonymous/ComfyUI/issues/1527 fixed.
  """

  NAME = get_name('Display Int')
  CATEGORY = get_category()

  @classmethod
  def INPUT_TYPES(cls):  # pylint: disable = invalid-name, missing-function-docstring
    return {
      "required": {
        "input": ("INT", {"forceInput": True}),
      },
    }

  RETURN_TYPES = ()
  FUNCTION = "main"
  OUTPUT_NODE = True

  def main(self, input=None):  # pylint: disable = redefined-builtin
    """Passes the int straight through to the UI text widget."""
    return {"ui": {"text": (input,)}}

View File

@@ -0,0 +1,56 @@
"""The Dynamic Context node."""
from mimetypes import add_type
from .constants import get_category, get_name
from .utils import ByPassTypeTuple, FlexibleOptionalInputType
class RgthreeDynamicContext:
  """The Dynamic Context node.

  Similar to the static Context and Context Big nodes, this allows users to add any number and
  variety of inputs to a Dynamic Context node, and return the outputs by key name.
  """

  NAME = get_name("Dynamic Context")
  CATEGORY = get_category()

  @classmethod
  def INPUT_TYPES(cls):  # pylint: disable = invalid-name,missing-function-docstring
    # NOTE(review): `add_type` here resolves to `mimetypes.add_type` (see module imports),
    # which looks like an accidental auto-import; sibling nodes pass `any_type` from
    # .utils — confirm before changing, as the frontend may only use this as an opaque
    # marker object.
    return {
      "required": {},
      "optional": FlexibleOptionalInputType(add_type),
      "hidden": {},
    }

  RETURN_TYPES = ByPassTypeTuple(("RGTHREE_DYNAMIC_CONTEXT",))
  RETURN_NAMES = ByPassTypeTuple(("CONTEXT",))
  FUNCTION = "main"

  def main(self, **kwargs):
    """Creates a new context from the provided data, with an optional base ctx to start.

    This node takes a list of named inputs that are the named keys (with an optional "+ " prefix)
    which are to be stored within the ctx dict as well as a comma-separated `output_keys` string
    to determine the list of output data.
    """
    base_ctx = kwargs.get('base_ctx', None)
    output_keys = kwargs.get('output_keys', None)
    new_ctx = base_ctx.copy() if base_ctx is not None else {}
    for key_raw, value in kwargs.items():
      if key_raw in ('base_ctx', 'output_keys'):
        continue
      # Keys are stored uppercased, with any "+ " (newly-added marker) prefix stripped.
      key = key_raw.upper()
      if key.startswith('+ '):
        key = key[2:]
      new_ctx[key] = value
    # Fix: removed stray debug `print(new_ctx)` left in the original.
    res = [new_ctx]
    output_keys = output_keys.split(',') if output_keys is not None else []
    for key in output_keys:
      res.append(new_ctx[key] if key in new_ctx else None)
    return tuple(res)

View File

@@ -0,0 +1,39 @@
"""The original Context Switch."""
from .constants import get_category, get_name
from .context_utils import is_context_empty
from .utils import ByPassTypeTuple, FlexibleOptionalInputType
class RgthreeDynamicContextSwitch:
  """The Dynamic Context Switch node; outputs the first non-empty `ctx_*` input."""

  NAME = get_name("Dynamic Context Switch")
  CATEGORY = get_category()

  @classmethod
  def INPUT_TYPES(cls):  # pylint: disable = invalid-name, missing-function-docstring
    return {
      "required": {},
      "optional": FlexibleOptionalInputType("RGTHREE_DYNAMIC_CONTEXT"),
    }

  RETURN_TYPES = ByPassTypeTuple(("RGTHREE_DYNAMIC_CONTEXT",))
  RETURN_NAMES = ByPassTypeTuple(("CONTEXT",))
  FUNCTION = "switch"

  def switch(self, **kwargs):
    """Chooses the first non-empty Context to output, plus any requested output keys."""
    output_keys = kwargs.get('output_keys', None)
    ctx = next(
      (value for key, value in kwargs.items()
       if key.startswith('ctx_') and not is_context_empty(value)), None)
    res = [ctx]
    keys = output_keys.split(',') if output_keys is not None else []
    for key in keys:
      res.append(ctx[key] if ctx is not None and key in ctx else None)
    return tuple(res)

View File

@@ -0,0 +1,42 @@
from nodes import PreviewImage
from .constants import get_category, get_name
class RgthreeImageComparer(PreviewImage):
  """A node that compares two images in the UI."""

  NAME = get_name('Image Comparer')
  CATEGORY = get_category()
  FUNCTION = "compare_images"
  DESCRIPTION = "Compares two images with a hover slider, or click from properties."

  @classmethod
  def INPUT_TYPES(cls):  # pylint: disable = invalid-name, missing-function-docstring
    return {
      "required": {},
      "optional": {
        "image_a": ("IMAGE",),
        "image_b": ("IMAGE",),
      },
      "hidden": {
        "prompt": "PROMPT",
        "extra_pnginfo": "EXTRA_PNGINFO"
      },
    }

  def compare_images(self,
                     image_a=None,
                     image_b=None,
                     filename_prefix="rgthree.compare.",
                     prompt=None,
                     extra_pnginfo=None):
    """Saves each provided image batch (via PreviewImage) and returns both for the UI slider."""
    ui = {"a_images": [], "b_images": []}
    for slot, images in (("a_images", image_a), ("b_images", image_b)):
      if images is not None and len(images) > 0:
        saved = self.save_images(images, filename_prefix, prompt, extra_pnginfo)
        ui[slot] = saved['ui']['images']
    return {"ui": ui}

View File

@@ -0,0 +1,93 @@
"""Image Inset Crop, with percentages."""
from .log import log_node_info
from .constants import get_category, get_name
from nodes import MAX_RESOLUTION
def get_new_bounds(width, height, left, right, top, bottom):
  """Returns (left, right, top, bottom) crop bounds after insetting each edge.

  `left`/`top` are offsets from the origin; `right`/`bottom` are insets from the
  far edges, so the returned right/bottom are absolute coordinates.
  """
  return (left, width - right, top, height - bottom)
class RgthreeImageInsetCrop:
  """Image Inset Crop, with percentages."""

  NAME = get_name('Image Inset Crop')
  CATEGORY = get_category()

  @classmethod
  def INPUT_TYPES(cls):  # pylint: disable = invalid-name, missing-function-docstring
    return {
      "required": {
        "image": ("IMAGE",),
        "measurement": (['Pixels', 'Percentage'],),
        "left": ("INT", {
          "default": 0,
          "min": 0,
          "max": MAX_RESOLUTION,
          "step": 8
        }),
        "right": ("INT", {
          "default": 0,
          "min": 0,
          "max": MAX_RESOLUTION,
          "step": 8
        }),
        "top": ("INT", {
          "default": 0,
          "min": 0,
          "max": MAX_RESOLUTION,
          "step": 8
        }),
        "bottom": ("INT", {
          "default": 0,
          "min": 0,
          "max": MAX_RESOLUTION,
          "step": 8
        }),
      },
    }

  RETURN_TYPES = ("IMAGE",)
  FUNCTION = "crop"

  # pylint: disable = too-many-arguments
  def crop(self, measurement, left, right, top, bottom, image=None):
    """Crops the image inward by the given insets (pixels or percentages of each dimension).

    Insets are snapped down to multiples of 8; all-zero insets return the image unchanged.
    Raises ValueError when opposing insets cross each other.
    """
    # Image tensor is (batch, height, width, channels).
    _, height, width, _ = image.shape

    if measurement == 'Percentage':
      # Convert each percentage inset into a pixel count of its own dimension.
      left = int(width - (width * (100 - left) / 100))
      right = int(width - (width * (100 - right) / 100))
      top = int(height - (height * (100 - top) / 100))
      bottom = int(height - (height * (100 - bottom) / 100))

    # Snap to 8 pixels
    left = left // 8 * 8
    right = right // 8 * 8
    top = top // 8 * 8
    bottom = bottom // 8 * 8

    if left == 0 and right == 0 and bottom == 0 and top == 0:
      return (image,)

    # Convert insets to absolute crop bounds.
    inset_left, inset_right, inset_top, inset_bottom = get_new_bounds(width, height, left, right,
                                                                      top, bottom)
    if inset_top > inset_bottom:
      raise ValueError(
        f"Invalid cropping dimensions top ({inset_top}) exceeds bottom ({inset_bottom})")
    if inset_left > inset_right:
      raise ValueError(
        f"Invalid cropping dimensions left ({inset_left}) exceeds right ({inset_right})")

    log_node_info(
      self.NAME, f'Cropping image {width}x{height} width inset by {inset_left},{inset_right}, ' +
      f'and height inset by {inset_top}, {inset_bottom}')
    image = image[:, inset_top:inset_bottom, inset_left:inset_right, :]
    return (image,)

View File

@@ -0,0 +1,31 @@
from .utils import FlexibleOptionalInputType, any_type
from .constants import get_category, get_name
class RgthreeImageOrLatentSize:
  """The ImageOrLatentSize Node; reports pixel dimensions of an IMAGE, MASK, or LATENT."""

  NAME = get_name('Image or Latent Size')
  CATEGORY = get_category()

  @classmethod
  def INPUT_TYPES(cls):  # pylint: disable = invalid-name, missing-function-docstring
    return {
      "required": {},
      "optional": FlexibleOptionalInputType(any_type),
    }

  RETURN_TYPES = ("INT", "INT")
  RETURN_NAMES = ('WIDTH', 'HEIGHT')
  FUNCTION = "main"

  def main(self, **kwargs):
    """Returns (width, height) in pixels for the connected input.

    Accepts a LATENT dict (scaled by the 8x VAE factor), a 4-dim IMAGE tensor
    (batch, height, width, channels), or — fix over the original, which crashed
    unpacking four values — a 3-dim MASK tensor (batch, height, width).
    """
    value = kwargs.get('input', None)
    if isinstance(value, dict) and 'samples' in value:
      # Latents are (batch, channels, height, width) in latent space; multiply by the
      # VAE's 8x downscale factor to report pixel dimensions.
      _, _, height, width = value['samples'].shape
      return (width * 8, height * 8)
    shape = value.shape
    if len(shape) == 3:
      # Masks carry no channel dimension — presumably (batch, height, width); confirm.
      _, height, width = shape
      return (width, height)
    _, height, width, _ = shape
    return (width, height)

View File

@@ -0,0 +1,117 @@
import torch
import comfy.utils
import nodes
from .constants import get_category, get_name
class RgthreeImageResize:
  """Image Resize node supporting pixel/percentage measurement and crop/pad/contain fits."""

  NAME = get_name("Image Resize")
  CATEGORY = get_category()

  @classmethod
  def INPUT_TYPES(cls):  # pylint: disable = invalid-name, missing-function-docstring
    return {
      "required": {
        "image": ("IMAGE",),
        "measurement": (["pixels", "percentage"],),
        "width": (
          "INT", {
            "default": 0,
            "min": 0,
            "max": nodes.MAX_RESOLUTION,
            "step": 1,
            "tooltip": (
              "The width of the desired resize. A pixel value if measurement is 'pixels' or a"
              " 100% scale percentage value if measurement is 'percentage'. Passing '0' will"
              " calculate the dimension based on the height."
            ),
          },
        ),
        "height": ("INT", {
          "default": 0,
          "min": 0,
          "max": nodes.MAX_RESOLUTION,
          "step": 1
        }),
        # Fix: corrected "it's" -> "its" and "liekly" -> "likely" in the user-facing tooltip.
        "fit": (["crop", "pad", "contain"], {
          "tooltip": (
            "'crop' resizes so the image covers the desired width and height, and center-crops the"
            " excess, returning exactly the desired width and height."
            "\n'pad' resizes so the image fits inside the desired width and height, and fills the"
            " empty space returning exactly the desired width and height."
            "\n'contain' resizes so the image fits inside the desired width and height, and"
            " returns the image with its new size, with one side likely smaller than the desired."
            "\n\nNote, if either width or height is '0', the effective fit is 'contain'."
          )
        },
        ),
        "method": (nodes.ImageScale.upscale_methods,),
      },
    }

  RETURN_TYPES = ("IMAGE", "INT", "INT",)
  RETURN_NAMES = ("IMAGE", "WIDTH", "HEIGHT",)
  FUNCTION = "main"
  DESCRIPTION = """Resize the image."""

  def main(self, image, measurement, width, height, method, fit):
    """Resizes the image.

    Returns (image, width, height) where width/height are the final tensor dimensions.
    Image tensors are (batch, height, width, channels).
    """
    _, H, W, _ = image.shape
    if measurement == "percentage":
      width = round(width * W / 100)
      height = round(height * H / 100)
    # No-op when nothing was requested or the size already matches.
    if (width == 0 and height == 0) or (width == W and height == H):
      return (image, W, H)
    # If one dimension is 0, then calculate the desired value from the ratio of the set dimension.
    # This also implies a 'contain' fit since the width and height will be scaled with a locked
    # aspect ratio.
    if width == 0 or height == 0:
      width = round(height / H * W) if width == 0 else width
      height = round(width / W * H) if height == 0 else height
      fit = "contain"
    # At this point, width and height are our output size, but our resize sizes will be different.
    resized_width = width
    resized_height = height
    if fit == "crop":
      # If we resize against the opposite ratio, then choose the ratio that has the overhang.
      if (height / H * W) > width:
        resized_width = round(height / H * W)
      elif (width / W * H) > height:
        resized_height = round(width / W * H)
    elif fit == "contain" or fit == "pad":
      # If we resize against the opposite ratio, then choose the ratio that has the overhang.
      if (height / H * W) > width:
        resized_height = round(width / W * H)
      elif (width / W * H) > height:
        resized_width = round(height / H * W)
    # common_upscale expects (batch, channels, height, width); move channels back after.
    out_image = comfy.utils.common_upscale(
      image.clone().movedim(-1, 1), resized_width, resized_height, method, crop="disabled"
    ).movedim(1, -1)
    OB, OH, OW, OC = out_image.shape
    if fit != "contain":
      # First, we crop, then we pad; no need to check fit (other than not 'contain') since the size
      # should already be correct.
      if OW > width:
        out_image = out_image.narrow(-2, (OW - width) // 2, width)
      if OH > height:
        out_image = out_image.narrow(-3, (OH - height) // 2, height)
      OB, OH, OW, OC = out_image.shape
      if width != OW or height != OH:
        # Center the (smaller) resized image on a zero-filled canvas of the target size.
        padded_image = torch.zeros((OB, height, width, OC), dtype=image.dtype, device=image.device)
        x = (width - OW) // 2
        y = (height - OH) // 2
        for b in range(OB):
          padded_image[b, y:y + OH, x:x + OW, :] = out_image[b]
        out_image = padded_image
    return (out_image, out_image.shape[2], out_image.shape[1])

View File

@@ -0,0 +1,56 @@
"""Some basic config stuff I use for SDXL."""
from .constants import get_category, get_name
from nodes import MAX_RESOLUTION
import comfy.samplers
class RgthreeKSamplerConfig:
  """Basic sampler config bundle (started for SDXL, but useful elsewhere too)."""

  NAME = get_name('KSampler Config')
  CATEGORY = get_category()

  @classmethod
  def INPUT_TYPES(cls):  # pylint: disable = invalid-name, missing-function-docstring
    int_opts = {"min": 1, "max": MAX_RESOLUTION, "step": 1}
    return {
      "required": {
        "steps_total": ("INT", dict(int_opts, default=30)),
        "refiner_step": ("INT", dict(int_opts, default=24)),
        "cfg": ("FLOAT", {
          "default": 8.0,
          "min": 0.0,
          "max": 100.0,
          "step": 0.5,
        }),
        "sampler_name": (comfy.samplers.KSampler.SAMPLERS,),
        "scheduler": (comfy.samplers.KSampler.SCHEDULERS,),
        #"refiner_ascore_pos": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
        #"refiner_ascore_neg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
      },
    }

  RETURN_TYPES = ("INT", "INT", "FLOAT", comfy.samplers.KSampler.SAMPLERS,
                  comfy.samplers.KSampler.SCHEDULERS)
  RETURN_NAMES = ("STEPS", "REFINER_STEP", "CFG", "SAMPLER", "SCHEDULER")
  FUNCTION = "main"

  def main(self, steps_total, refiner_step, cfg, sampler_name, scheduler):
    """Passes the configured values straight through as outputs."""
    return (steps_total, refiner_step, cfg, sampler_name, scheduler)

View File

@@ -0,0 +1,100 @@
import datetime
import time
from .pyproject import NAME
# ANSI escape codes used to colorize console log output.
# https://stackoverflow.com/questions/4842424/list-of-ansi-color-escape-sequences
# https://en.wikipedia.org/wiki/ANSI_escape_code#3-bit_and_4-bit
COLORS = {
  # Foreground colors.
  'BLACK': '\33[30m',
  'RED': '\33[31m',
  'GREEN': '\33[32m',
  'YELLOW': '\33[33m',
  'BLUE': '\33[34m',
  'MAGENTA': '\33[35m',
  'CYAN': '\33[36m',
  'WHITE': '\33[37m',
  'GREY': '\33[90m',
  'BRIGHT_RED': '\33[91m',
  'BRIGHT_GREEN': '\33[92m',
  'BRIGHT_YELLOW': '\33[93m',
  'BRIGHT_BLUE': '\33[94m',
  'BRIGHT_MAGENTA': '\33[95m',
  'BRIGHT_CYAN': '\33[96m',
  'BRIGHT_WHITE': '\33[97m',
  # Styles.
  'RESET': '\33[0m',  # Note, Portainer doesn't like 00 here, so we'll use 0. Should be fine...
  'BOLD': '\33[01m',
  'NORMAL': '\33[22m',
  'ITALIC': '\33[03m',
  'UNDERLINE': '\33[04m',
  'BLINK': '\33[05m',
  'BLINK2': '\33[06m',
  'SELECTED': '\33[07m',
  # Backgrounds
  'BG_BLACK': '\33[40m',
  'BG_RED': '\33[41m',
  'BG_GREEN': '\33[42m',
  'BG_YELLOW': '\33[43m',
  'BG_BLUE': '\33[44m',
  'BG_MAGENTA': '\33[45m',
  'BG_CYAN': '\33[46m',
  'BG_WHITE': '\33[47m',
  'BG_GREY': '\33[100m',
  'BG_BRIGHT_RED': '\33[101m',
  'BG_BRIGHT_GREEN': '\33[102m',
  'BG_BRIGHT_YELLOW': '\33[103m',
  'BG_BRIGHT_BLUE': '\33[104m',
  'BG_BRIGHT_MAGENTA': '\33[105m',
  'BG_BRIGHT_CYAN': '\33[106m',
  'BG_BRIGHT_WHITE': '\33[107m',
}
def log_node_success(node_name, message, msg_color='RESET'):
  """Logs a success message."""
  _log_node("BRIGHT_GREEN", node_name, message, msg_color=msg_color)

def log_node_info(node_name, message, msg_color='RESET'):
  """Logs an info message."""
  _log_node("CYAN", node_name, message, msg_color=msg_color)

def log_node_error(node_name, message, msg_color='RESET'):
  """Logs an error message."""
  _log_node("RED", node_name, message, msg_color=msg_color)

def log_node_warn(node_name, message, msg_color='RESET'):
  """Logs a warning message."""
  _log_node("YELLOW", node_name, message, msg_color=msg_color)

def log_node(node_name, message, msg_color='RESET'):
  """Logs a message."""
  _log_node("CYAN", node_name, message, msg_color=msg_color)

def _log_node(color, node_name, message, msg_color='RESET'):
  """Logs for a node message, stripping the namespace suffix from the prefix."""
  log(message, color=color, prefix=node_name.replace(" (rgthree)", ""), msg_color=msg_color)
# Tracks the last time each rate-limited id was logged.
LOGGED = {}

def log(message, color=None, msg_color=None, prefix=None, id=None, at_most_secs=None):
  """Basic colored logging with optional rate limiting.

  When `id` is given, `at_most_secs` must also be given; repeated messages for the same
  id within that window are dropped.
  """
  now = int(time.time())
  if id:
    if at_most_secs is None:
      raise ValueError('at_most_secs should be set if an id is set.')
    last_logged = LOGGED.get(id)
    if last_logged is not None and now < last_logged + at_most_secs:
      return
    LOGGED[id] = now
  color_code = COLORS[color] if color is not None and color in COLORS else COLORS["BRIGHT_GREEN"]
  msg_code = COLORS[msg_color] if msg_color is not None and msg_color in COLORS else ''
  prefix_part = f'[{prefix}]' if prefix is not None else ''
  print(f'{color_code}[{NAME}]{prefix_part}{msg_code} {message}{COLORS["RESET"]}')

View File

@@ -0,0 +1,46 @@
from .constants import get_category, get_name
from nodes import LoraLoader
import folder_paths
class RgthreeLoraLoaderStack:
  """Loads up to four loras, in order, onto the provided model and clip."""

  NAME = get_name('Lora Loader Stack')
  CATEGORY = get_category()

  @classmethod
  def INPUT_TYPES(cls):  # pylint: disable = invalid-name, missing-function-docstring
    lora_list = ['None'] + folder_paths.get_filename_list("loras")
    strength_input = ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01})
    return {
      "required": {
        "model": ("MODEL",),
        "clip": ("CLIP", ),
        "lora_01": (lora_list, ),
        "strength_01": strength_input,
        "lora_02": (lora_list, ),
        "strength_02": strength_input,
        "lora_03": (lora_list, ),
        "strength_03": strength_input,
        "lora_04": (lora_list, ),
        "strength_04": strength_input,
      }
    }

  RETURN_TYPES = ("MODEL", "CLIP")
  FUNCTION = "load_lora"

  def load_lora(self, model, clip, lora_01, strength_01, lora_02, strength_02, lora_03,
                strength_03, lora_04, strength_04):
    """Applies each selected lora with a non-zero strength to the model and clip.

    Args:
      model: The model to apply loras to.
      clip: The clip to apply loras to.
      lora_01..lora_04: Lora filenames, or the literal "None" to skip a slot.
      strength_01..strength_04: Strength per slot; a strength of 0 skips the slot.

    Returns:
      A tuple of the (possibly modified) model and clip.
    """
    # The four slots were previously four copy-pasted if-blocks; iterate instead.
    slots = ((lora_01, strength_01), (lora_02, strength_02), (lora_03, strength_03),
             (lora_04, strength_04))
    for lora, strength in slots:
      if lora != "None" and strength != 0:
        # The same strength is applied to both the model and the clip.
        model, clip = LoraLoader().load_lora(model, clip, lora, strength, strength)
    return (model, clip)

View File

@@ -0,0 +1,101 @@
import folder_paths
from typing import Union
from nodes import LoraLoader
from .constants import get_category, get_name
from .power_prompt_utils import get_lora_by_filename
from .utils import FlexibleOptionalInputType, any_type
from .server.utils_info import get_model_info_file_data
from .log import log_node_warn
# Display name for the Power Lora Loader node; also used as its log prefix.
NODE_NAME = get_name('Power Lora Loader')
class RgthreePowerLoraLoader:
  """ The Power Lora Loader is a powerful, flexible node to add multiple loras to a model/clip."""

  NAME = NODE_NAME
  CATEGORY = get_category()

  @classmethod
  def INPUT_TYPES(cls):  # pylint: disable = invalid-name, missing-function-docstring
    return {
      "required": {
      },
      # Since we will pass any number of loras in from the UI, this needs to always allow an
      # arbitrary set of optional inputs (the UI adds `lora_N` widgets dynamically).
      "optional": FlexibleOptionalInputType(type=any_type, data={
        "model": ("MODEL",),
        "clip": ("CLIP",),
      }),
      "hidden": {},
    }

  RETURN_TYPES = ("MODEL", "CLIP")
  RETURN_NAMES = ("MODEL", "CLIP")
  FUNCTION = "load_loras"

  def load_loras(self, model=None, clip=None, **kwargs):
    """Loops over the provided loras in kwargs and applies valid ones.

    Args:
      model: Optional model to apply loras to.
      clip: Optional clip to apply loras to.
      **kwargs: Dynamic widget values; lora entries arrive as `lora_N` dicts containing
        `on`, `lora`, `strength` and, optionally, `strengthTwo` keys.

    Returns:
      A tuple of the (possibly modified) model and clip.
    """
    for key, value in kwargs.items():
      key = key.upper()
      if key.startswith('LORA_') and 'on' in value and 'lora' in value and 'strength' in value:
        strength_model = value['strength']
        # If we just passed one strength value, then use it for both, if we passed a strengthTwo
        # as well, then our `strength` will be for the model, and `strengthTwo` for clip.
        strength_clip = value['strengthTwo'] if 'strengthTwo' in value else None
        if clip is None:
          if strength_clip is not None and strength_clip != 0:
            log_node_warn(NODE_NAME, 'Recieved clip strength eventhough no clip supplied!')
          strength_clip = 0
        else:
          strength_clip = strength_clip if strength_clip is not None else strength_model
        if value['on'] and (strength_model != 0 or strength_clip != 0):
          lora = get_lora_by_filename(value['lora'], log_node=self.NAME)
          if model is not None and lora is not None:
            model, clip = LoraLoader().load_lora(model, clip, lora, strength_model, strength_clip)
    return (model, clip)

  @classmethod
  def get_enabled_loras_from_prompt_node(cls,
                                         prompt_node: dict) -> list[dict[str, Union[str, float]]]:
    """Gets enabled loras of a node within a server prompt.

    Args:
      prompt_node: The serialized node entry from the server prompt.

    Returns:
      A list of dicts with `name`, `strength`, `path` and, when present, `strength_clip`.
    """
    result = []
    for name, lora in prompt_node['inputs'].items():
      if name.startswith('lora_') and lora['on']:
        lora_file = get_lora_by_filename(lora['lora'], log_node=cls.NAME)
        if lora_file is not None:  # Add the same safety check
          lora_dict = {
            'name': lora['lora'],
            'strength': lora['strength'],
            'path': folder_paths.get_full_path("loras", lora_file)
          }
          if 'strengthTwo' in lora:
            lora_dict['strength_clip'] = lora['strengthTwo']
          result.append(lora_dict)
    return result

  @classmethod
  def get_enabled_triggers_from_prompt_node(cls, prompt_node: dict, max_each: int = 1):
    """Gets trigger words up to the max for enabled loras of a node within a server prompt.

    Args:
      prompt_node: The serialized node entry from the server prompt.
      max_each: Maximum trained-word entries to take per lora.

    Returns:
      A flat list of trigger word strings.
    """
    loras = [l['name'] for l in cls.get_enabled_loras_from_prompt_node(prompt_node)]
    trained_words = []
    for lora in loras:
      info = get_model_info_file_data(lora, 'loras', default={})
      if not info or not info.keys():
        log_node_warn(
          NODE_NAME,
          f'No info found for lora {lora} when grabbing triggers. Have you generated an info file'
          ' from the Power Lora Loader "Show Info" dialog?'
        )
        continue
      if 'trainedWords' not in info or not info['trainedWords']:
        log_node_warn(
          NODE_NAME,
          f'No trained words for lora {lora} when grabbing triggers. Have you fetched data from'
          'civitai or manually added words?'
        )
        continue
      # Take the first `max_each` entries, keeping only truthy entries' non-empty 'word' values.
      trained_words += [w for wi in info['trainedWords'][:max_each] if (wi and (w := wi['word']))]
    return trained_words

View File

@@ -0,0 +1,83 @@
import re
from .utils import FlexibleOptionalInputType, any_type
from .constants import get_category, get_name
def cast_to_str(x):
  """Casts a value to a string, mapping None and failed casts to the empty string."""
  if x is None:
    return ''
  try:
    result = str(x)
  except (ValueError, TypeError):
    result = ''
  return result
def cast_to_float(x):
  """Casts a value to a float, mapping failed casts (including None) to 0.0."""
  try:
    result = float(x)
  except (ValueError, TypeError):
    result = 0.0
  return result
def cast_to_bool(x):
  """Casts to bool: numeric-likes use float truthiness; otherwise falsy-looking strings are False."""
  try:
    return float(x) != 0.0
  except (ValueError, TypeError):
    pass
  return str(x).lower() not in ['0', 'false', 'null', 'none', '']
# Maps a node output-type label to its cast function and its "null"/empty value.
output_to_type = {
  'STRING': {
    'cast': cast_to_str,
    'null': '',
  },
  'FLOAT': {
    'cast': cast_to_float,
    'null': 0.0,
  },
  'INT': {
    'cast': lambda x: int(cast_to_float(x)),
    'null': 0,
  },
  'BOOLEAN': {
    'cast': cast_to_bool,
    'null': False,
  },
  # This can be removed soon, there was a bug where this should have been BOOLEAN
  'BOOL': {
    'cast': cast_to_bool,
    'null': False,
  },
}
class RgthreePowerPrimitive:
  """The Power Primitive Node."""

  NAME = get_name('Power Primitive')
  CATEGORY = get_category()

  @classmethod
  def INPUT_TYPES(cls):  # pylint: disable = invalid-name, missing-function-docstring
    return {
      "required": {},
      "optional": FlexibleOptionalInputType(any_type),
    }

  RETURN_TYPES = (any_type,)
  RETURN_NAMES = ('*',)
  FUNCTION = "main"

  def main(self, **kwargs):
    """Outputs the provided value cast to the selected primitive type."""
    # The UI appends the current value in parens to the type label; strip that suffix off.
    type_label = re.sub(r'\s*\([^\)]*\)\s*$', '', kwargs.get('type', ''))
    caster = output_to_type[type_label]['cast']
    return (caster(kwargs.get('value', None)),)

View File

@@ -0,0 +1,95 @@
import os
from .log import log_node_warn, log_node_info, log_node_success
from .constants import get_category, get_name
from .power_prompt_utils import get_and_strip_loras
from nodes import LoraLoader, CLIPTextEncode
import folder_paths
# Display name for the Power Prompt node; also used as its log prefix.
NODE_NAME = get_name('Power Prompt')
class RgthreePowerPrompt:
  """The Power Prompt node; encodes prompt text and can load `<lora:...>` tags onto a model/clip."""

  NAME = NODE_NAME
  CATEGORY = get_category()

  @classmethod
  def INPUT_TYPES(cls):  # pylint: disable = invalid-name, missing-function-docstring
    # Removed Saved Prompts feature; Not sure it worked any longer. UI should fail gracefully,
    # TODO: Rip out saved prompt input data
    SAVED_PROMPTS_FILES=[]
    SAVED_PROMPTS_CONTENT=[]
    return {
      'required': {
        'prompt': ('STRING', {
          'multiline': True,
          'dynamicPrompts': True
        }),
      },
      'optional': {
        "opt_model": ("MODEL",),
        "opt_clip": ("CLIP",),
        # The insert_* inputs drive client-side text insertion; the server ignores them.
        'insert_lora': (['CHOOSE', 'DISABLE LORAS'] +
                        [os.path.splitext(x)[0] for x in folder_paths.get_filename_list('loras')],),
        'insert_embedding': ([
          'CHOOSE',
        ] + [os.path.splitext(x)[0] for x in folder_paths.get_filename_list('embeddings')],),
        'insert_saved': ([
          'CHOOSE',
        ] + SAVED_PROMPTS_FILES,),
      },
      'hidden': {
        'values_insert_saved': (['CHOOSE'] + SAVED_PROMPTS_CONTENT,),
      }
    }

  RETURN_TYPES = (
    'CONDITIONING',
    'MODEL',
    'CLIP',
    'STRING',
  )
  RETURN_NAMES = (
    'CONDITIONING',
    'MODEL',
    'CLIP',
    'TEXT',
  )
  FUNCTION = 'main'

  def main(self,
           prompt,
           opt_model=None,
           opt_clip=None,
           insert_lora=None,
           insert_embedding=None,
           insert_saved=None,
           values_insert_saved=None):
    """Encodes the prompt, handling any `<lora:...>` tags within it.

    Args:
      prompt: The prompt text, possibly containing `<lora:name:strength>` tags.
      opt_model: Optional model to apply prompt loras to.
      opt_clip: Optional clip, used both for lora loading and for text encoding.
      insert_lora: 'DISABLE LORAS' strips lora tags without loading them.
      insert_embedding: UI-only; insertion happens client-side.
      insert_saved: UI-only; insertion happens client-side.
      values_insert_saved: UI-only hidden input.

    Returns:
      A tuple of (CONDITIONING or None, model, clip, stripped prompt text).
    """
    if insert_lora == 'DISABLE LORAS':
      prompt, loras, skipped, unfound = get_and_strip_loras(prompt, log_node=NODE_NAME, silent=True)
      log_node_info(
        NODE_NAME,
        f'Disabling all found loras ({len(loras)}) and stripping lora tags for TEXT output.')
    elif opt_model is not None and opt_clip is not None:
      # Model and clip supplied; actually load the loras found in the prompt.
      prompt, loras, skipped, unfound = get_and_strip_loras(prompt, log_node=NODE_NAME)
      if len(loras) > 0:
        for lora in loras:
          opt_model, opt_clip = LoraLoader().load_lora(opt_model, opt_clip, lora['lora'],
                                                       lora['strength'], lora['strength'])
          log_node_success(NODE_NAME, f'Loaded "{lora["lora"]}" from prompt')
        log_node_info(NODE_NAME, f'{len(loras)} Loras processed; stripping tags for TEXT output.')
    elif '<lora:' in prompt:
      # Lora tags present but we can't load them; keep the tags in the TEXT output.
      prompt, loras, skipped, unfound = get_and_strip_loras(prompt, log_node=NODE_NAME, silent=True)
      total_loras = len(loras) + len(skipped) + len(unfound)
      if total_loras:
        log_node_warn(
          NODE_NAME, f'Found {len(loras)} lora tags in prompt but model & clip were not supplied!')
        log_node_info(NODE_NAME, 'Loras not processed, keeping for TEXT output.')

    conditioning = None
    if opt_clip is not None:
      conditioning = CLIPTextEncode().encode(opt_clip, prompt)[0]

    return (conditioning, opt_model, opt_clip, prompt)

View File

@@ -0,0 +1,42 @@
import os
import folder_paths
from nodes import CLIPTextEncode
from .constants import get_category, get_name
from .power_prompt import RgthreePowerPrompt
class RgthreePowerPromptSimple(RgthreePowerPrompt):
  """A clip-only Power Prompt; encodes text without any model/lora handling."""

  NAME = get_name('Power Prompt - Simple')
  CATEGORY = get_category()

  @classmethod
  def INPUT_TYPES(cls):  # pylint: disable = invalid-name, missing-function-docstring
    # Removed Saved Prompts feature; Not sure it worked any longer. UI should fail gracefully,
    # TODO: Rip out saved prompt input data
    SAVED_PROMPTS_FILES = []
    SAVED_PROMPTS_CONTENT = []
    return {
      'required': {
        'prompt': ('STRING', {'multiline': True, 'dynamicPrompts': True}),
      },
      'optional': {
        "opt_clip": ("CLIP", ),
        'insert_embedding': (['CHOOSE',] + [os.path.splitext(x)[0] for x in folder_paths.get_filename_list('embeddings')],),
        'insert_saved': (['CHOOSE',] + SAVED_PROMPTS_FILES,),
      },
      'hidden': {
        'values_insert_saved': (['CHOOSE'] + SAVED_PROMPTS_CONTENT,),
      }
    }

  RETURN_TYPES = ('CONDITIONING', 'STRING',)
  RETURN_NAMES = ('CONDITIONING', 'TEXT',)
  FUNCTION = 'main'

  def main(self, prompt, opt_clip=None, insert_embedding=None, insert_saved=None,
           values_insert_saved=None):
    """Encodes the prompt with the clip, when one is provided.

    Args:
      prompt: The prompt text to encode.
      opt_clip: Optional CLIP; when omitted the CONDITIONING output is None.
      insert_embedding: UI-only; insertion happens client-side.
      insert_saved: UI-only; insertion happens client-side.
      values_insert_saved: UI-only hidden input.

    Returns:
      A tuple of (CONDITIONING or None, prompt text).
    """
    conditioning = None
    # Fix: compare to None with `is not None` (PEP 8) instead of `!= None`.
    if opt_clip is not None:
      conditioning = CLIPTextEncode().encode(opt_clip, prompt)[0]
    return (conditioning, prompt)

View File

@@ -0,0 +1,104 @@
"""Utilities for Power Prompt nodes."""
import re
import os
import folder_paths
from .log import log_node_warn, log_node_info
def get_and_strip_loras(prompt, silent=False, log_node="Power Prompt"):
  """Collects and strips lora tags from a prompt.

  Returns a tuple of (stripped prompt, found loras, skipped loras, unfound loras),
  where each lora is a dict of `lora` and `strength`.
  """
  pattern = r'<lora:([^:>]*?)(?::(-?\d*(?:\.\d*)?))?>'
  lora_paths = folder_paths.get_filename_list('loras')
  found = []
  skipped = []
  unfound = []
  for tag_path, raw_strength in re.findall(pattern, prompt):
    # A missing strength group defaults to 1.0.
    strength = float(raw_strength) if raw_strength else 1.0
    if strength == 0:
      if not silent:
        log_node_info(log_node, f'Skipping "{tag_path}" with strength of zero')
      skipped.append({'lora': tag_path, 'strength': strength})
      continue
    lora_path = get_lora_by_filename(tag_path, lora_paths, log_node=None if silent else log_node)
    if lora_path is None:
      unfound.append({'lora': tag_path, 'strength': strength})
      continue
    found.append({'lora': lora_path, 'strength': strength})
  return (re.sub(pattern, '', prompt), found, skipped, unfound)
# pylint: disable = too-many-return-statements, too-many-branches
def get_lora_by_filename(file_path, lora_paths=None, log_node=None):
  """Returns a lora by filename, looking for exact paths and then fuzzier matching.

  Match order: exact path; path without extension; bare filename; bare filename without
  extension; and finally a fuzzy substring match anywhere within a path.

  Args:
    file_path: The user-supplied lora path/name to resolve.
    lora_paths: Optional pre-fetched list of lora paths; fetched from folder_paths if None.
    log_node: Optional node name to log matches/warnings under; None silences logging.

  Returns:
    The matched path from `lora_paths`, or None when nothing matches.
  """
  lora_paths = lora_paths if lora_paths is not None else folder_paths.get_filename_list('loras')

  if file_path in lora_paths:
    return file_path

  lora_paths_no_ext = [os.path.splitext(x)[0] for x in lora_paths]

  # See if we've entered the exact path, but without the extension
  if file_path in lora_paths_no_ext:
    return lora_paths[lora_paths_no_ext.index(file_path)]

  # Same check, but ensure file_path is without extension.
  file_path_force_no_ext = os.path.splitext(file_path)[0]
  if file_path_force_no_ext in lora_paths_no_ext:
    return lora_paths[lora_paths_no_ext.index(file_path_force_no_ext)]

  # Fix: this list was previously computed twice; hoist it for the two filename checks.
  lora_filenames_only = [os.path.basename(x) for x in lora_paths]

  # See if we passed just the name, without paths.
  if file_path in lora_filenames_only:
    found = lora_paths[lora_filenames_only.index(file_path)]
    if log_node is not None:
      log_node_info(log_node, f'Matched Lora input "{file_path}" to "{found}".')
    return found

  # Same, but force the input to be without paths
  file_path_force_filename = os.path.basename(file_path)
  if file_path_force_filename in lora_filenames_only:
    found = lora_paths[lora_filenames_only.index(file_path_force_filename)]
    if log_node is not None:
      log_node_info(log_node, f'Matched Lora input "{file_path}" to "{found}".')
    return found

  # Check the filenames and without extension.
  lora_filenames_and_no_ext = [os.path.splitext(os.path.basename(x))[0] for x in lora_paths]
  if file_path in lora_filenames_and_no_ext:
    found = lora_paths[lora_filenames_and_no_ext.index(file_path)]
    if log_node is not None:
      log_node_info(log_node, f'Matched Lora input "{file_path}" to "{found}".')
    return found

  # And, one last forcing the input to be the same
  file_path_force_filename_and_no_ext = os.path.splitext(os.path.basename(file_path))[0]
  if file_path_force_filename_and_no_ext in lora_filenames_and_no_ext:
    found = lora_paths[lora_filenames_and_no_ext.index(file_path_force_filename_and_no_ext)]
    if log_node is not None:
      log_node_info(log_node, f'Matched Lora input "{file_path}" to "{found}".')
    return found

  # Finally, super fuzzy, we'll just check if the input exists in the path at all.
  for index, lora_path in enumerate(lora_paths):
    if file_path in lora_path:
      found = lora_paths[index]
      if log_node is not None:
        log_node_warn(log_node, f'Fuzzy-matched Lora input "{file_path}" to "{found}".')
      return found

  if log_node is not None:
    log_node_warn(log_node, f'Lora "{file_path}" not found, skipping.')
  return None

View File

@@ -0,0 +1,842 @@
"""The Power Puter is a powerful node that can compute and evaluate Python-like code safely allowing
for complex operations for primitives and workflow items for output. From string concatenation, to
math operations, list comprehension, and node value output.
Originally based off https://github.com/pythongosssss/ComfyUI-Custom-Scripts/blob/aac13aa7ce35b07d43633c3bbe654a38c00d74f5/py/math_expression.py
under an MIT License https://github.com/pythongosssss/ComfyUI-Custom-Scripts/blob/aac13aa7ce35b07d43633c3bbe654a38c00d74f5/LICENSE
"""
import math
import ast
import json
import random
import dataclasses
import re
import time
import operator as op
import datetime
import numpy as np
from typing import Any, Callable, Iterable, Optional, Union
from types import MappingProxyType
from .constants import get_category, get_name
from .utils import ByPassTypeTuple, FlexibleOptionalInputType, any_type, get_dict_value
from .log import log_node_error, log_node_warn, log_node_info
from .power_lora_loader import RgthreePowerLoraLoader
from nodes import ImageBatch
from comfy_extras.nodes_latent import LatentBatch
class LoopBreak(Exception):
  """A special error type that is caught in a loop for correct breaking behavior."""

  def __init__(self):
    message = 'Cannot use "break" outside of a loop.'
    super().__init__(message)
class LoopContinue(Exception):
  """A special error type that is caught in a loop for correct continue behavior."""

  def __init__(self):
    message = 'Cannot use "continue" outside of a loop.'
    super().__init__(message)
@dataclasses.dataclass(frozen=True)  # Note, kw_only=True is only python 3.10+
class Function():
  """Function data.

  Attributes:
    name: The name of the function as called from the node.
    call: The callable (reference, lambda, etc), or a string if on _Puter instance.
    args: A tuple that represents the minimum and maximum number of args (None for no limit).
  """
  name: str
  call: Union[Callable, str]
  args: tuple[int, Optional[int]]
def purge_vram(purge_models=True):
  """Frees VRAM by garbage-collecting and emptying CUDA caches; optionally unloads all models."""
  # Imported locally to keep module import light; gc must run before emptying caches.
  import gc
  import torch
  gc.collect()
  cuda_ready = torch.cuda.is_available()
  if cuda_ready:
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
  if not purge_models:
    return
  import comfy
  comfy.model_management.unload_all_models()
  comfy.model_management.soft_empty_cache()
def batch(*args):
  """Batches multiple image or latents together.

  All arguments must be of the same kind (all latents or all images); the matching comfy
  batching node is applied pairwise across the sequence.

  Raises:
    ValueError: If the arguments mix latents and images.
  """
  def check_is_latent(item) -> bool:
    # Latents flow through comfy as dicts containing a 'samples' tensor.
    return isinstance(item, dict) and 'samples' in item

  args = list(args)
  result = args.pop(0)
  # The first argument decides which batching node handles the whole sequence.
  is_latent = check_is_latent(result)
  node = LatentBatch() if is_latent else ImageBatch()
  for arg in args:
    if is_latent != check_is_latent(arg):
      raise ValueError(
        f'batch() error: Expecting "{"LATENT" if is_latent else "IMAGE"}"'
        f' but got "{"IMAGE" if is_latent else "LATENT"}".'
      )
    result = node.batch(result, arg)[0]
  return result
# Prefix used to namespace built-in function keys so user code cannot collide with them.
_BUILTIN_FN_PREFIX = '__rgthreefn.'


def _get_built_in_fn_key(fn: Function) -> str:
  """Returns a key for a built-in function."""
  return f'{_BUILTIN_FN_PREFIX}{hash(fn.name)}'
def _get_built_in_fn_by_key(fn_key: str):
  """Returns the `Function` for the provided key (purposefully, not name).

  Raises:
    ValueError: If the key is not a namespaced built-in key.
  """
  if not fn_key.startswith(_BUILTIN_FN_PREFIX) or fn_key not in _BUILT_INS_BY_NAME_AND_KEY:
    raise ValueError('No built in function found.')
  return _BUILT_INS_BY_NAME_AND_KEY[fn_key]
# The table of functions user code may call; `args` bounds are enforced at call time.
_BUILT_IN_FNS_LIST = [
  Function(name="round", call=round, args=(1, 2)),
  Function(name="ceil", call=math.ceil, args=(1, 1)),
  Function(name="floor", call=math.floor, args=(1, 1)),
  Function(name="sqrt", call=math.sqrt, args=(1, 1)),
  Function(name="min", call=min, args=(2, None)),
  Function(name="max", call=max, args=(2, None)),
  Function(name=".random_int", call=random.randint, args=(2, 2)),
  Function(name=".random_choice", call=random.choice, args=(1, 1)),
  Function(name=".random_seed", call=random.seed, args=(1, 1)),
  Function(name="re", call=re.compile, args=(1, 1)),
  Function(name="len", call=len, args=(1, 1)),
  Function(name="enumerate", call=enumerate, args=(1, 1)),
  Function(name="range", call=range, args=(1, 3)),
  # Casts
  Function(name="int", call=int, args=(1, 1)),
  Function(name="float", call=float, args=(1, 1)),
  Function(name="str", call=str, args=(1, 1)),
  Function(name="bool", call=bool, args=(1, 1)),
  Function(name="list", call=list, args=(1, 1)),
  Function(name="tuple", call=tuple, args=(1, 1)),
  # Special
  Function(name="dir", call=dir, args=(1, 1)),
  Function(name="type", call=type, args=(1, 1)),
  Function(name="print", call=print, args=(0, None)),
  # Comfy Specials; string `call`s are resolved as method names on the _Puter instance.
  Function(name="node", call='_get_node', args=(0, 1)),
  Function(name="nodes", call='_get_nodes', args=(0, 1)),
  Function(name="input_node", call='_get_input_node', args=(0, 1)),
  Function(name="purge_vram", call=purge_vram, args=(0, 1)),
  Function(name="batch", call=batch, args=(2, None)),
]

# Lookup by both plain name and the namespaced key.
_BUILT_INS_BY_NAME_AND_KEY = {
  fn.name: fn for fn in _BUILT_IN_FNS_LIST
} | {
  key: fn for fn in _BUILT_IN_FNS_LIST if (key := _get_built_in_fn_key(fn))
}

# The read-only symbol table exposed to user code; values are namespaced keys (or nested
# namespaces, like `random`).
_BUILT_INS = MappingProxyType(
  {fn.name: key for fn in _BUILT_IN_FNS_LIST if (key := _get_built_in_fn_key(fn))} | {
    'random':
      MappingProxyType({
        'int': _get_built_in_fn_key(_BUILT_INS_BY_NAME_AND_KEY['.random_int']),
        'choice': _get_built_in_fn_key(_BUILT_INS_BY_NAME_AND_KEY['.random_choice']),
        'seed': _get_built_in_fn_key(_BUILT_INS_BY_NAME_AND_KEY['.random_seed']),
      }),
  }
)

# A dict of types to blocked attributes/methods. Used to disallow file system access or other
# invocations we may want to block. Necessary for any instance type that is possible to create from
# the code or standard ComfyUI inputs.
#
# For instance, a user does not have access to the numpy module directly, so they cannot invoke
# `numpy.save`. However, a user can access a numpy.ndarray instance from a tensor and, from there,
# an attempt to call `tofile` or `dump` etc. would need to be blocked.
_BLOCKED_METHODS_OR_ATTRS = MappingProxyType({np.ndarray: ['tofile', 'dump']})

# Special functions by class type (called from the Attrs.)
_SPECIAL_FUNCTIONS = {
  RgthreePowerLoraLoader.NAME: {
    # Get a list of the enabled loras from a power lora loader.
    "loras": RgthreePowerLoraLoader.get_enabled_loras_from_prompt_node,
    "triggers": RgthreePowerLoraLoader.get_enabled_triggers_from_prompt_node,
  }
}

# Series of regex checks for usage of a non-deterministic function. Using these is fine, but means
# the output can't be cached because it's either random, or is associated with another node that is
# not connected to ours (like looking up a node in the prompt). Using these means downstream nodes
# would always be run; that is fine for something like a final JSON output, but less so for a prompt
# text.
_NON_DETERMINISTIC_FUNCTION_CHECKS = [r'(?<!input_)(nodes?)\(',]

# Maps ast operator node types to their implementations for safe evaluation.
_OPERATORS = {
  # operator
  ast.Add: op.add,
  ast.Sub: op.sub,
  ast.Mult: op.mul,
  ast.MatMult: op.matmul,
  ast.Div: op.truediv,
  ast.Mod: op.mod,
  ast.Pow: op.pow,
  ast.RShift: op.rshift,
  ast.LShift: op.lshift,
  ast.BitOr: op.or_,
  ast.BitXor: op.xor,
  ast.BitAnd: op.and_,
  ast.FloorDiv: op.floordiv,
  # boolop
  ast.And: lambda a, b: a and b,
  ast.Or: lambda a, b: a or b,
  # unaryop
  ast.Invert: op.invert,
  ast.Not: lambda a: 0 if a else 1,
  ast.USub: op.neg,
  # cmpop
  ast.Eq: op.eq,
  ast.NotEq: op.ne,
  ast.Lt: op.lt,
  ast.LtE: op.le,
  ast.Gt: op.gt,
  ast.GtE: op.ge,
  ast.Is: op.is_,
  ast.IsNot: op.is_not,
  ast.In: lambda a, b: a in b,
  ast.NotIn: lambda a, b: a not in b,
}

# Display name for the Power Puter node; also used as its log prefix.
_NODE_NAME = get_name("Power Puter")
def _update_code(code: str, unique_id: str, log=False):
"""Updates the code to either newer syntax or general cleaning."""
# Change usage of `input_node` so the passed variable is a string, if it isn't. So, instead of
# `input_node(a)` it needs to be `input_node('a')`
code = re.sub(r'input_node\(([^\'"].*?)\)', r'input_node("\1")', code)
# Update use of `random_int` to `random.int`
srch = re.compile(r'random_int\(')
if re.search(srch, code):
if log:
log_node_warn(
_NODE_NAME, f"Power Puter node #{unique_id} should update to use the `random.int`"
" built-in instead of `random_int`."
)
code = re.sub(srch, 'random.int(', code)
# Update use of `random_choice` to `random.choice`
srch = re.compile(r'random_choice\(')
if re.search(srch, code):
if log:
log_node_warn(
_NODE_NAME, f"Power Puter node #{unique_id} should update to use the `random.choice`"
" built-in instead of `random_choice`."
)
code = re.sub(srch, 'random.choice(', code)
return code
class RgthreePowerPuter:
  """A powerful node that can compute and evaluate expressions and output as various types."""

  NAME = _NODE_NAME
  CATEGORY = get_category()

  @classmethod
  def INPUT_TYPES(cls):  # pylint: disable = invalid-name, missing-function-docstring
    return {
      "required": {},
      "optional": FlexibleOptionalInputType(any_type),
      "hidden": {
        "unique_id": "UNIQUE_ID",
        "extra_pnginfo": "EXTRA_PNGINFO",
        "prompt": "PROMPT",
        "dynprompt": "DYNPROMPT"
      },
    }

  RETURN_TYPES = ByPassTypeTuple((any_type,))
  RETURN_NAMES = ByPassTypeTuple(("*",))
  FUNCTION = "main"

  @classmethod
  def IS_CHANGED(cls, **kwargs):
    """Forces a changed state if we could be unaware of data changes (like using `node()`).

    Returns:
      `time.time()` (always changed) when the code is non-deterministic, else the constant 42.
    """
    code = _update_code(kwargs['code'], unique_id=kwargs['unique_id'])
    # Strip string literals and comments.
    code = re.sub(r"'[^']+?'", "''", code)
    code = re.sub(r'"[^"]+?"', '""', code)
    code = re.sub(r'#.*\n', '\n', code)
    # If we have a non-deterministic function, then we'll always consider ourself changed since we
    # cannot be sure that the data would be the same (random, another unconnected node, etc).
    for check in _NON_DETERMINISTIC_FUNCTION_CHECKS:
      matches = re.search(check, code)
      if matches:
        log_node_warn(
          _NODE_NAME,
          f"Note, Power Puter (node #{kwargs['unique_id']}) cannot be cached b/c it's using a"
          f" non-deterministic function call. Matches function call for '{matches.group(1)}'."
        )
        return time.time()
    # Advanced checks: random calls are cacheable only when preceded by a `random.seed` call.
    has_rand_seed = re.search(r'random\.seed\(', code)
    has_rand_int_or_choice = re.search(r'(?<!\.)(random\.(int|choice))\(', code)
    if has_rand_int_or_choice:
      if not has_rand_seed or has_rand_seed.span()[0] > has_rand_int_or_choice.span()[0]:
        log_node_warn(
          _NODE_NAME,
          f"Note, Power Puter (node #{kwargs['unique_id']}) cannot be cached b/c it's using a"
          " non-deterministic function call. Matches function call for"
          f" `{has_rand_int_or_choice.group(1)}`."
        )
        return time.time()
      if has_rand_seed:
        log_node_info(
          _NODE_NAME,
          f"Power Puter node #{kwargs['unique_id']} WILL be cached eventhough it's using"
          f" a non-deterministic random call `{has_rand_int_or_choice.group(1)}` because it also"
          f" calls `random.seed` first. NOTE: Please ensure that the seed value is deterministic."
        )
    return 42

  def main(self, **kwargs):
    """Does the nodes' work.

    Evaluates the user code via `_Puter` and casts the resulting value(s) to the node's
    configured output types.

    Returns:
      A tuple with one (possibly None) entry per configured output.
    """
    code = kwargs['code']
    unique_id = kwargs['unique_id']
    pnginfo = kwargs['extra_pnginfo']
    workflow = pnginfo["workflow"] if "workflow" in pnginfo else {"nodes": []}
    prompt = kwargs['prompt']
    dynprompt = kwargs['dynprompt']
    outputs = get_dict_value(kwargs, 'outputs.outputs', None)
    if not outputs:
      # Older workflows send a single `output` type; default to STRING.
      output = kwargs.get('output', None)
      if not output:
        output = 'STRING'
      outputs = [output]
    ctx = {}
    # Set variable names, defaulting to None instead of KeyErrors
    for c in list('abcdefghijklmnopqrstuvwxyz'):
      ctx[c] = kwargs[c] if c in kwargs else None
    code = _update_code(kwargs['code'], unique_id=kwargs['unique_id'], log=True)
    eva = _Puter(
      code=code,
      ctx=ctx,
      workflow=workflow,
      prompt=prompt,
      dynprompt=dynprompt,
      unique_id=unique_id
    )
    values = eva.execute()
    # Check if we have multiple outputs that the returned value is a tuple and raise if not.
    if len(outputs) > 1 and not isinstance(values, tuple):
      t = re.sub(r'^<[a-z]*\s(.*?)>$', r'\1', str(type(values)))
      msg = (
        f"When using multiple node outputs, the value from the code should be a 'tuple' with the"
        f" number of items equal to the number of outputs. But value from code was of type {t}."
      )
      log_node_error(_NODE_NAME, f'{msg}\n')
      raise ValueError(msg)
    if len(outputs) == 1:
      values = (values,)
    if len(values) > len(outputs):
      log_node_warn(
        _NODE_NAME,
        f"Expected value from code to be tuple with {len(outputs)} items, but value from code had"
        f" {len(values)} items. Extra values will be dropped."
      )
    elif len(values) < len(outputs):
      log_node_warn(
        _NODE_NAME,
        f"Expected value from code to be tuple with {len(outputs)} items, but value from code had"
        f" {len(values)} items. Extra outputs will be null."
      )
    # Now, we'll go over out return tuple, and cast as the output types.
    response = []
    for i, output in enumerate(outputs):
      value = values[i] if len(values) > i else None
      if value is not None:
        if output == 'INT':
          value = int(value)
        elif output == 'FLOAT':
          value = float(value)
        # Accidentally defined "BOOL" when should have been "BOOLEAN."
        # TODO: Can prob get rid of BOOl after a bit when UIs would be updated from sending
        # BOOL incorrectly.
        elif output in ('BOOL', 'BOOLEAN'):
          value = bool(value)
        elif output == 'STRING':
          # Dicts and lists serialize to pretty-printed JSON rather than repr().
          if isinstance(value, (dict, list)):
            value = json.dumps(value, indent=2)
          else:
            value = str(value)
        elif output == '*':
          # Do nothing, the output will be passed as-is. This could be anything and it's up to the
          # user to control the intended output, like passing through an input value, etc.
          pass
      response.append(value)
    return tuple(response)
class _Puter:
"""The main computation evaluator, using ast.parse the code.
See https://www.basicexamples.com/example/python/ast for examples.
"""
  def __init__(self, *, code: str, ctx: dict[str, Any], workflow, prompt, dynprompt, unique_id):
    """Initializes the evaluator.

    Args:
      code: The user-provided code to evaluate.
      ctx: Initial variable context (copied; the caller's dict is not mutated).
      workflow: The workflow data from the png info.
      prompt: The hidden server prompt.
      dynprompt: The dynamic prompt, used to lazily expand the full node list.
      unique_id: This node's id within the prompt.
    """
    ctx = ctx or {}
    self._ctx = {**ctx}
    self._code = code
    self._workflow = workflow
    self._prompt = prompt
    self._unique_id = unique_id
    self._dynprompt = dynprompt
    # These are now expanded lazily when needed.
    self._prompt_nodes = None
    self._prompt_node = None
def execute(self, code=Optional[str]) -> Any:
"""Evaluates a the code block."""
# Always store random state and initialize a new seed. We'll restore the state later.
initial_random_state = random.getstate()
random.seed(datetime.datetime.now().timestamp())
last_value = None
try:
code = code or self._code
node = ast.parse(self._code)
ctx = {**self._ctx}
for body in node.body:
last_value = self._eval_statement(body, ctx)
# If we got a return, then that's it folks.
if isinstance(body, ast.Return):
break
except:
random.setstate(initial_random_state)
raise
random.setstate(initial_random_state)
return last_value
  def _get_prompt_nodes(self):
    """Expands the prompt nodes lazily from the dynamic prompt.

    https://github.com/comfyanonymous/ComfyUI/blob/fc657f471a29d07696ca16b566000e8e555d67d1/comfy_execution/graph.py#L22

    Returns:
      A list of node dicts, each with an added 'id' key.
    """
    if self._prompt_nodes is None:
      self._prompt_nodes = []
      if self._dynprompt:
        all_ids = self._dynprompt.all_node_ids()
        # Merge the id into a copy of each node's data.
        self._prompt_nodes = [{'id': k} | {**self._dynprompt.get_node(k)} for k in all_ids]
    return self._prompt_nodes
def _get_prompt_node(self):
if self._prompt_nodes is None:
self._prompt_node = [n for n in self._get_prompt_nodes() if n['id'] == self._unique_id][0]
return self._prompt_node
  def _get_nodes(self, node_id: Union[int, str, re.Pattern, None] = None) -> list[Any]:
    """Get a list of the nodes that match the node_id, or all the nodes in the prompt.

    Args:
      node_id: Optional id (int or numeric string), title string, or compiled regex matched
        against node titles. None returns every node.

    Returns:
      A list of matching prompt-node dicts.
    """
    nodes = self._get_prompt_nodes().copy()
    if not node_id:
      return nodes
    if isinstance(node_id, re.Pattern):
      found = [n for n in nodes if re.search(node_id, get_dict_value(n, '_meta.title', ''))]
    else:
      node_id = str(node_id)
      found = None
      # A purely numeric id is tried as a node id first, falling back to a title match.
      if re.match(r'\d+$', node_id):
        found = [n for n in nodes if node_id == n['id']]
      if not found:
        found = [n for n in nodes if node_id == get_dict_value(n, '_meta.title', '')]
    return found
  def _get_node(self, node_id: Union[int, str, re.Pattern, None] = None) -> Union[Any, None]:
    """Returns a prompt-node from the hidden prompt.

    Args:
      node_id: Optional id/title/regex; None returns this node's own prompt entry.

    Returns:
      The first matching node dict, or None when nothing matches.
    """
    if node_id is None:
      return self._get_prompt_node()
    nodes = self._get_nodes(node_id)
    if nodes and len(nodes) > 1:
      log_node_warn(_NODE_NAME, f"More than one node found for '{node_id}'. Returning first.")
    return nodes[0] if nodes else None
  def _get_input_node(self, input_name, node=None):
    """Gets the (non-muted) node of an input connection from a node (default to the power puter).

    Args:
      input_name: The name of the input whose upstream node to resolve.
      node: Optional node dict to start from; defaults to this node's prompt entry.

    Returns:
      The connected upstream node dict, or None when the input isn't connected.
    """
    node = node if node else self._get_prompt_node()
    try:
      # An input connection is serialized as [node_id, output_index].
      connected_node_id = node['inputs'][input_name][0]
      return [n for n in self._get_prompt_nodes() if n['id'] == connected_node_id][0]
    except (TypeError, IndexError, KeyError):
      log_node_warn(_NODE_NAME, f'No input node found for "{input_name}". ')
      return None
def _eval_statement(self, stmt: ast.AST, ctx: dict, prev_stmt: Union[ast.AST, None] = None):
  """Evaluates an ast.stmt.

  This is the core of the sandboxed mini-interpreter: each supported AST node type gets a branch
  below, evaluated recursively. Unsupported node types raise TypeError at the bottom.

  Args:
    stmt: The AST node to evaluate.
    ctx: The variable scope (name -> value) for the evaluation; mutated by assignments and loops.
    prev_stmt: The statement that triggered this evaluation, when relevant. Used by the
      Attribute/Subscript branch to special-case a special-function lookup that is about to be
      called (see _SPECIAL_FUNCTIONS handling below).

  Returns:
    The evaluated value of the statement; None for statements with no value (loops, pass, etc.).
  """
  # Once a `return` has executed (possibly deep inside an if/loop body), short-circuit all
  # further evaluation and keep yielding the returned value.
  if '__returned__' in ctx:
    return ctx['__returned__']
  # print('\n\n----: _eval_statement')
  # print(type(stmt))
  # print(ctx)
  if isinstance(stmt, (ast.FormattedValue, ast.Expr)):
    return self._eval_statement(stmt.value, ctx=ctx)
  if isinstance(stmt, (ast.Constant, ast.Num)):
    # NOTE(review): `.n` is the deprecated alias for `Constant.value` (ast.Num itself is
    # deprecated since 3.8) — consider migrating to `stmt.value`; confirm target Python versions.
    return stmt.n
  if isinstance(stmt, ast.BinOp):
    left = self._eval_statement(stmt.left, ctx=ctx)
    right = self._eval_statement(stmt.right, ctx=ctx)
    return _OPERATORS[type(stmt.op)](left, right)
  if isinstance(stmt, ast.BoolOp):
    is_and = isinstance(stmt.op, ast.And)
    is_or = isinstance(stmt.op, ast.Or)
    stmt_value_eval = None
    for stmt_value in stmt.values:
      stmt_value_eval = self._eval_statement(stmt_value, ctx=ctx)
      # If we're an and operator and have a falsy value, then we stop and return. Likewise, if
      # we're an or operator and have a truthy value, we can stop and return.
      if (is_and and not stmt_value_eval) or (is_or and stmt_value_eval):
        return stmt_value_eval
    # Always return the last if we made it here w/o success.
    return stmt_value_eval
  if isinstance(stmt, ast.UnaryOp):
    return _OPERATORS[type(stmt.op)](self._eval_statement(stmt.operand, ctx=ctx))
  if isinstance(stmt, (ast.Attribute, ast.Subscript)):
    # Like: node(14).inputs.sampler_name (Attribute)
    # Like: node(14)['inputs']['sampler_name'] (Subscript)
    item = self._eval_statement(stmt.value, ctx=ctx)
    attr = None
    # if hasattr(stmt, 'attr'):
    if isinstance(stmt, ast.Attribute):
      attr = stmt.attr
    else:
      # Slice could be a name or a constant; evaluate it
      attr = self._eval_statement(stmt.slice, ctx=ctx)
    # Check if we're blocking access to this attribute/method on this item type.
    for typ, names in _BLOCKED_METHODS_OR_ATTRS.items():
      if isinstance(item, typ) and isinstance(attr, str) and attr in names:
        raise ValueError(f'Disallowed access to "{attr}" for type {typ}.')
    # Try subscript access first, then attribute access (EAFP over both shapes).
    try:
      val = item[attr]
    except (TypeError, IndexError, KeyError):
      try:
        val = getattr(item, attr)
      except AttributeError:
        # If we're a dict, then just return None instead of error; saves time.
        if isinstance(item, dict):
          # Any special cases in the _SPECIAL_FUNCTIONS
          class_type = get_dict_value(item, "class_type")
          if class_type in _SPECIAL_FUNCTIONS and attr in _SPECIAL_FUNCTIONS[class_type]:
            val = _SPECIAL_FUNCTIONS[class_type][attr]
            # If our previous statement was a Call, then send back a tuple of the callable and
            # the evaluated item, and it will make the call; perhaps also adding other arguments
            # only it knows about.
            if isinstance(prev_stmt, ast.Call):
              return (val, item)
            val = val(item)
          else:
            val = None
        else:
          raise
    return val
  if isinstance(stmt, (ast.List, ast.Tuple)):
    value = []
    for elt in stmt.elts:
      value.append(self._eval_statement(elt, ctx=ctx))
    return tuple(value) if isinstance(stmt, ast.Tuple) else value
  if isinstance(stmt, ast.Dict):
    the_dict = {}
    if stmt.keys:
      if len(stmt.keys) != len(stmt.values):
        raise ValueError('Expected same number of keys as values for dict.')
      for i, k in enumerate(stmt.keys):
        item_key = self._eval_statement(k, ctx=ctx)
        item_value = self._eval_statement(stmt.values[i], ctx=ctx)
        the_dict[item_key] = item_value
    return the_dict
  # f-strings: https://www.basicexamples.com/example/python/ast-JoinedStr
  # Note, this will str() all evaluated items in the fstrings, and doesn't handle f-string
  # directives, like padding, etc.
  if isinstance(stmt, ast.JoinedStr):
    vals = [str(self._eval_statement(v, ctx=ctx)) for v in stmt.values]
    val = ''.join(vals)
    return val
  if isinstance(stmt, ast.Slice):
    if not stmt.lower or not stmt.upper:
      raise ValueError('Unhandled Slice w/o lower or upper.')
    slice_lower = self._eval_statement(stmt.lower, ctx=ctx)
    slice_upper = self._eval_statement(stmt.upper, ctx=ctx)
    if stmt.step:
      slice_step = self._eval_statement(stmt.step, ctx=ctx)
      return slice(slice_lower, slice_upper, slice_step)
    return slice(slice_lower, slice_upper)
  if isinstance(stmt, ast.Name):
    # Name resolution: user context first, then the exposed built-ins.
    if stmt.id in ctx:
      val = ctx[stmt.id]
      return val
    if stmt.id in _BUILT_INS:
      val = _BUILT_INS[stmt.id]
      return val
    raise NameError(f"Name not found: {stmt.id}")
  if isinstance(stmt, ast.For):
    for_iter = self._eval_statement(stmt.iter, ctx=ctx)
    for item in for_iter:
      # Set the for var(s)
      if isinstance(stmt.target, ast.Name):
        ctx[stmt.target.id] = item
      elif isinstance(stmt.target, ast.Tuple):  # dict, like `for k, v in d.entries()`
        for i, elt in enumerate(stmt.target.elts):
          ctx[elt.id] = item[i]
      bodies = stmt.body if isinstance(stmt.body, list) else [stmt.body]
      breaked = False
      for body in bodies:
        # Catch any breaks or continues and handle inside the loop normally.
        try:
          value = self._eval_statement(body, ctx=ctx)
        except (LoopBreak, LoopContinue) as e:
          breaked = isinstance(e, LoopBreak)
          break
      if breaked:
        break
    return None
  if isinstance(stmt, ast.While):
    while self._eval_statement(stmt.test, ctx=ctx):
      bodies = stmt.body if isinstance(stmt.body, list) else [stmt.body]
      breaked = False
      for body in bodies:
        # Catch any breaks or continues and handle inside the loop normally.
        try:
          value = self._eval_statement(body, ctx=ctx)
        except (LoopBreak, LoopContinue) as e:
          breaked = isinstance(e, LoopBreak)
          break
      if breaked:
        break
    return None
  if isinstance(stmt, ast.ListComp):
    # Like: [v.lora for name, v in node(19).inputs.items() if name.startswith('lora_')]
    # Like: [v.lower() for v in lora_list]
    # Like: [v for v in l if v.startswith('B')]
    # Like: [v.lower() for v in l if v.startswith('B') or v.startswith('F')]
    # ---
    # Like: [l for n in nodes(re('Loras')).values() if (l := n.loras)]
    final_list = []
    gen_ctx = {**ctx}
    generators = [*stmt.generators]

    # Recursively consumes the generator list (supporting nested `for` clauses); appends the
    # evaluated element expression to final_list for every iteration that passes all `if`s.
    def handle_gen(generators: list[ast.comprehension]):
      gen = generators.pop(0)
      if isinstance(gen.target, ast.Name):
        gen_ctx[gen.target.id] = None
      elif isinstance(gen.target, ast.Tuple):  # dict, like `for k, v in d.entries()`
        for elt in gen.target.elts:
          gen_ctx[elt.id] = None
      else:
        raise ValueError('Na')
      gen_iters = None
      # A call, like my_dct.items(), or a named ctx list
      if isinstance(gen.iter, ast.Call):
        gen_iters = self._eval_statement(gen.iter, ctx=gen_ctx)
      elif isinstance(gen.iter, (ast.Name, ast.Attribute, ast.List, ast.Tuple)):
        gen_iters = self._eval_statement(gen.iter, ctx=gen_ctx)
      if not isinstance(gen_iters, Iterable):
        raise ValueError('No iteraors found for list comprehension')
      for gen_iter in gen_iters:
        if_ctx = {**gen_ctx}
        if isinstance(gen.target, ast.Tuple):  # dict, like `for k, v in d.entries()`
          for i, elt in enumerate(gen.target.elts):
            if_ctx[elt.id] = gen_iter[i]
        else:
          if_ctx[gen.target.id] = gen_iter
        good = True
        for ifcall in gen.ifs:
          if not self._eval_statement(ifcall, ctx=if_ctx):
            good = False
            break
        if not good:
          continue
        gen_ctx.update(if_ctx)
        if len(generators):
          handle_gen(generators)
        else:
          final_list.append(self._eval_statement(stmt.elt, gen_ctx))
      # Restore the generator we consumed so sibling iterations of an outer `for` can reuse it.
      generators.insert(0, gen)

    handle_gen(generators)
    return final_list
  if isinstance(stmt, ast.Call):
    call = None
    args = []
    kwargs = {}
    if isinstance(stmt.func, ast.Attribute):
      # Method-style call; evaluating the attribute may hand back (callable, receiver).
      call = self._eval_statement(stmt.func, prev_stmt=stmt, ctx=ctx)
      if isinstance(call, tuple):
        args.append(call[1])
        call = call[0]
      if not call:
        raise ValueError(f'No call for ast.Call {stmt.func}')
    name = ''
    if isinstance(stmt.func, ast.Name):
      name = stmt.func.id
      if name in _BUILT_INS:
        call = _BUILT_INS[name]
        if isinstance(call, str) and call.startswith(_BUILTIN_FN_PREFIX):
          fn = _get_built_in_fn_by_key(call)
          call = fn.call
          if isinstance(call, str):
            call = getattr(self, call)
          # Validate arity against the built-in's declared (min, max) arg counts.
          num_args = len(stmt.args)
          if num_args < fn.args[0] or (fn.args[1] is not None and num_args > fn.args[1]):
            toErr = " or more" if fn.args[1] is None else f" to {fn.args[1]}"
            raise SyntaxError(f"Invalid function call: {fn.name} requires {fn.args[0]}{toErr} args")
      if not call:
        raise ValueError(f'No call for ast.Call {name}')
    for arg in stmt.args:
      args.append(self._eval_statement(arg, ctx=ctx))
    for kwarg in stmt.keywords:
      kwargs[kwarg.arg] = self._eval_statement(kwarg.value, ctx=ctx)
    return call(*args, **kwargs)
  if isinstance(stmt, ast.Compare):
    # Comparisons return 1/0 rather than True/False; only the first comparator is supported
    # (no chained `a < b < c`).
    l = self._eval_statement(stmt.left, ctx=ctx)
    r = self._eval_statement(stmt.comparators[0], ctx=ctx)
    if isinstance(stmt.ops[0], ast.Eq):
      return 1 if l == r else 0
    if isinstance(stmt.ops[0], ast.NotEq):
      return 1 if l != r else 0
    if isinstance(stmt.ops[0], ast.Gt):
      return 1 if l > r else 0
    if isinstance(stmt.ops[0], ast.GtE):
      return 1 if l >= r else 0
    if isinstance(stmt.ops[0], ast.Lt):
      return 1 if l < r else 0
    if isinstance(stmt.ops[0], ast.LtE):
      return 1 if l <= r else 0
    if isinstance(stmt.ops[0], ast.In):
      return 1 if l in r else 0
    if isinstance(stmt.ops[0], ast.Is):
      return 1 if l is r else 0
    if isinstance(stmt.ops[0], ast.IsNot):
      return 1 if l is not r else 0
    raise NotImplementedError("Operator " + stmt.ops[0].__class__.__name__ + " not supported.")
  if isinstance(stmt, (ast.If, ast.IfExp)):
    value = self._eval_statement(stmt.test, ctx=ctx)
    if value:
      # ast.If is a list, ast.IfExp is an object.
      bodies = stmt.body if isinstance(stmt.body, list) else [stmt.body]
      for body in bodies:
        value = self._eval_statement(body, ctx=ctx)
    elif stmt.orelse:
      # ast.If is a list, ast.IfExp is an object. TBH, I don't know why the If is a list, it's
      # only ever one item AFAICT.
      orelses = stmt.orelse if isinstance(stmt.orelse, list) else [stmt.orelse]
      for orelse in orelses:
        value = self._eval_statement(orelse, ctx=ctx)
    return value
  # Assign a variable and add it to our ctx.
  if isinstance(stmt, (ast.Assign, ast.AugAssign)):
    if isinstance(stmt, ast.AugAssign):
      left = self._eval_statement(stmt.target, ctx=ctx)
      right = self._eval_statement(stmt.value, ctx=ctx)
      value = _OPERATORS[type(stmt.op)](left, right)
      target = stmt.target
    else:
      value = self._eval_statement(stmt.value, ctx=ctx)
      if len(stmt.targets) != 1:
        raise ValueError('Expected length of assign targets to be 1')
      target = stmt.targets[0]
    if isinstance(target, ast.Tuple):  # like `a, z = (1,2)` (ast.Assign only)
      for i, elt in enumerate(target.elts):
        ctx[elt.id] = value[i]
    elif isinstance(target, ast.Name):  # like `a = 1``
      ctx[target.id] = value
    elif isinstance(target, ast.Subscript) and isinstance(target.value, ast.Name):  # `a[0] = 1`
      ctx[target.value.id][self._eval_statement(target.slice, ctx=ctx)] = value
    else:
      raise ValueError('Unhandled target type for Assign.')
    return value
  # For assigning a var in a list comprehension.
  # Like [name for node in node_list if (name := node.name)]
  if isinstance(stmt, ast.NamedExpr):
    value = self._eval_statement(stmt.value, ctx=ctx)
    ctx[stmt.target.id] = value
    return value
  if isinstance(stmt, ast.Return):
    if stmt.value is None:
      value = None
    else:
      value = self._eval_statement(stmt.value, ctx=ctx)
    # Mark that we have a return value, as we may be deeper in evaluation, like going through an
    # if condition's body.
    ctx['__returned__'] = value
    return value
  # Raise an error for break or continue, which should be caught and handled inside of loops,
  # otherwise the error will be raised (which is desired when used outside of a loop).
  if isinstance(stmt, ast.Break):
    raise LoopBreak()
  if isinstance(stmt, ast.Continue):
    raise LoopContinue()
  # Literally nothing.
  if isinstance(stmt, ast.Pass):
    return None
  raise TypeError(stmt)

View File

@@ -0,0 +1,70 @@
import os
import re
import json
from .utils import set_dict_value
# Directory containing this module; the package's pyproject.toml lives one level up.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_FILE_PY_PROJECT = os.path.join(_THIS_DIR, '..', 'pyproject.toml')
def read_pyproject():
  """Reads the pyproject.toml file.

  A deliberately minimal TOML reader: tracks the current `[section]` header and parses each
  `key = value` line's value as JSON. Multiline arrays and inline tables are silently skipped
  (we control the file, so that's acceptable).
  """
  parsed = {}
  section = ''
  # I'd like to use tomllib/tomli, but I'd much rather not introduce dependencies since I've yet to
  # need to and not everyone may have 3.11. We've got a controlled config file anyway.
  with open(_FILE_PY_PROJECT, "r", encoding='utf-8') as f:
    raw_lines = f.readlines()
  for raw in raw_lines:
    stripped = raw.strip()
    if re.match(r'\[([^\]]+)\]$', stripped):
      section = stripped[1:-1]
      parsed_section = parsed[section] if section in parsed else {}
      set_dict_value(parsed, section, parsed_section)
      continue
    kv = re.match(r'^([^\s\=]+)\s*=\s*(.*)$', stripped)
    if kv:
      try:
        set_dict_value(parsed, f'{section}.{kv[1]}', json.loads(kv[2]))
      except json.decoder.JSONDecodeError:
        # We don't handle multiline arrays or curly brackets; that's ok, we know the file.
        pass
  return parsed
_DATA = read_pyproject()
# We would want these to fail if they don't exist, so assume they do.
VERSION: str = _DATA['project']['version']
NAME: str = _DATA['project']['name']
LOGO_URL: str = _DATA['tool']['comfy']['Icon']
# Sanity-check the configured icon; downstream code assumes SVG markup.
if not LOGO_URL.endswith('.svg'):
  raise ValueError('Bad logo url.')
# Lazily-fetched logo markup cache; populated by get_logo_svg().
LOGO_SVG = None
async def get_logo_svg():
  """Fetches (once) and returns the rgthree logo SVG markup.

  The hard-coded fills are rewritten to `{bg}`/`{fg}` placeholders so callers can `str.format`
  their own colors in. On any failure a placeholder empty svg is cached instead (best-effort).
  """
  import aiohttp
  global LOGO_SVG
  if LOGO_SVG is not None:
    return LOGO_SVG
  # Fetch the logo so we have any updated markup.
  request_headers = {
    "user-agent": f"rgthree-comfy/{VERSION}",
    'Cache-Control': 'no-cache',
    'Pragma': 'no-cache',
    'Expires': '0'
  }
  try:
    connector = aiohttp.TCPConnector(verify_ssl=True)
    async with aiohttp.ClientSession(trust_env=True, connector=connector) as session:
      async with session.get(LOGO_URL, headers=request_headers) as resp:
        markup = await resp.text()
    markup = re.sub(r'(id="bg".*fill=)"[^\"]+"', r'\1"{bg}"', markup)
    LOGO_SVG = re.sub(r'(id="fg".*fill=)"[^\"]+"', r'\1"{fg}"', markup)
  except Exception:
    LOGO_SVG = '<svg></svg>'
  return LOGO_SVG

View File

@@ -0,0 +1,63 @@
from nodes import EmptyLatentImage
from .constants import get_category, get_name
class RgthreeSDXLEmptyLatentImage:
  """An empty latent image with SDXL preset dimensions, also outputting clip width/height.

  Fix: removed a leftover `if True:` wrapper in `generate` that served no purpose.
  """

  NAME = get_name('SDXL Empty Latent Image')
  CATEGORY = get_category()

  @classmethod
  def INPUT_TYPES(cls):  # pylint: disable = invalid-name, missing-function-docstring
    return {
      "required": {
        "dimensions": (
          [
            # 'Custom',
            '1536 x 640 (landscape)',
            '1344 x 768 (landscape)',
            '1216 x 832 (landscape)',
            '1152 x 896 (landscape)',
            '1024 x 1024 (square)',
            ' 896 x 1152 (portrait)',
            ' 832 x 1216 (portrait)',
            ' 768 x 1344 (portrait)',
            ' 640 x 1536 (portrait)',
          ],
          {
            "default": '1024 x 1024 (square)'
          }),
        "clip_scale": ("FLOAT", {
          "default": 2.0,
          "min": 1.0,
          "max": 10.0,
          "step": .5
        }),
        "batch_size": ("INT", {
          "default": 1,
          "min": 1,
          "max": 64
        }),
      },
      # "optional": {
      #   "custom_width": ("INT", {"min": 1, "max": MAX_RESOLUTION, "step": 64}),
      #   "custom_height": ("INT", {"min": 1, "max": MAX_RESOLUTION, "step": 64}),
      # }
    }

  RETURN_TYPES = ("LATENT", "INT", "INT")
  RETURN_NAMES = ("LATENT", "CLIP_WIDTH", "CLIP_HEIGHT")
  FUNCTION = "generate"

  def generate(self, dimensions, clip_scale, batch_size):
    """Generates the latent and exposes the clip_width and clip_height.

    `dimensions` is one of the preset strings above, e.g. '1024 x 1024 (square)'; width and
    height are parsed out of it.
    """
    result = [x.strip() for x in dimensions.split('x')]
    width = int(result[0])
    height = int(result[1].split(' ')[0])
    latent = EmptyLatentImage().generate(width, height, batch_size)[0]
    return (
      latent,
      int(width * clip_scale),
      int(height * clip_scale),
    )

View File

@@ -0,0 +1,178 @@
import os
import re
from nodes import MAX_RESOLUTION
from comfy_extras.nodes_clip_sdxl import CLIPTextEncodeSDXL
from .log import log_node_warn, log_node_info, log_node_success
from .constants import get_category, get_name
from .power_prompt_utils import get_and_strip_loras
from nodes import LoraLoader, CLIPTextEncode
import folder_paths
NODE_NAME = get_name('SDXL Power Prompt - Positive')


class RgthreeSDXLPowerPromptPositive:
  """The Power Prompt for positive conditioning.

  Fix: in `get_conditioning`, an encode failure previously set the flag to True, which caused the
  standard-encoding fallback to be skipped and a None conditioning to be returned — despite the
  log message promising a fallback. The flag is now cleared so the fallback actually runs.
  """

  NAME = NODE_NAME
  CATEGORY = get_category()

  @classmethod
  def INPUT_TYPES(cls):  # pylint: disable = invalid-name, missing-function-docstring
    # Removed Saved Prompts feature; No sure it worked any longer. UI should fail gracefully,
    # TODO: Rip out saved prompt input data
    SAVED_PROMPTS_FILES = []
    SAVED_PROMPTS_CONTENT = []
    return {
      'required': {
        'prompt_g': ('STRING', {
          'multiline': True,
          'dynamicPrompts': True
        }),
        'prompt_l': ('STRING', {
          'multiline': True,
          'dynamicPrompts': True
        }),
      },
      'optional': {
        "opt_model": ("MODEL",),
        "opt_clip": ("CLIP",),
        "opt_clip_width": ("INT", {
          "forceInput": True,
          "default": 1024.0,
          "min": 0,
          "max": MAX_RESOLUTION
        }),
        "opt_clip_height": ("INT", {
          "forceInput": True,
          "default": 1024.0,
          "min": 0,
          "max": MAX_RESOLUTION
        }),
        'insert_lora': (['CHOOSE', 'DISABLE LORAS'] +
                        [os.path.splitext(x)[0] for x in folder_paths.get_filename_list('loras')],),
        'insert_embedding': ([
          'CHOOSE',
        ] + [os.path.splitext(x)[0] for x in folder_paths.get_filename_list('embeddings')],),
        'insert_saved': ([
          'CHOOSE',
        ] + SAVED_PROMPTS_FILES,),
        # We'll hide these in the UI for now.
        "target_width": ("INT", {
          "default": -1,
          "min": -1,
          "max": MAX_RESOLUTION
        }),
        "target_height": ("INT", {
          "default": -1,
          "min": -1,
          "max": MAX_RESOLUTION
        }),
        "crop_width": ("INT", {
          "default": -1,
          "min": -1,
          "max": MAX_RESOLUTION
        }),
        "crop_height": ("INT", {
          "default": -1,
          "min": -1,
          "max": MAX_RESOLUTION
        }),
      },
      'hidden': {
        'values_insert_saved': (['CHOOSE'] + SAVED_PROMPTS_CONTENT,),
      }
    }

  RETURN_TYPES = ('CONDITIONING', 'MODEL', 'CLIP', 'STRING', 'STRING')
  RETURN_NAMES = ('CONDITIONING', 'MODEL', 'CLIP', 'TEXT_G', 'TEXT_L')
  FUNCTION = 'main'

  def main(self,
           prompt_g,
           prompt_l,
           opt_model=None,
           opt_clip=None,
           opt_clip_width=None,
           opt_clip_height=None,
           insert_lora=None,
           insert_embedding=None,
           insert_saved=None,
           target_width=-1,
           target_height=-1,
           crop_width=-1,
           crop_height=-1,
           values_insert_saved=None):
    """Processes `<lora:...>` tags from the prompts (when model & clip are supplied), then builds
    and returns the conditioning along with the (possibly updated) model/clip and stripped text.
    """
    if insert_lora == 'DISABLE LORAS':
      # Strip lora tags without loading any loras.
      prompt_g, loras_g, _skipped, _unfound = get_and_strip_loras(prompt_g,
                                                                 True,
                                                                 log_node=self.NAME)
      prompt_l, loras_l, _skipped, _unfound = get_and_strip_loras(prompt_l,
                                                                 True,
                                                                 log_node=self.NAME)
      loras = loras_g + loras_l
      log_node_info(
        NODE_NAME,
        f'Disabling all found loras ({len(loras)}) and stripping lora tags for TEXT output.')
    elif opt_model is not None and opt_clip is not None:
      # Load each found lora into the supplied model/clip.
      prompt_g, loras_g, _skipped, _unfound = get_and_strip_loras(prompt_g, log_node=self.NAME)
      prompt_l, loras_l, _skipped, _unfound = get_and_strip_loras(prompt_l, log_node=self.NAME)
      loras = loras_g + loras_l
      if len(loras) > 0:
        for lora in loras:
          opt_model, opt_clip = LoraLoader().load_lora(opt_model, opt_clip, lora['lora'],
                                                       lora['strength'], lora['strength'])
          log_node_success(NODE_NAME, f'Loaded "{lora["lora"]}" from prompt')
        log_node_info(NODE_NAME, f'{len(loras)} Loras processed; stripping tags for TEXT output.')
    elif '<lora:' in prompt_g or '<lora:' in prompt_l:
      # Tags present but no model/clip; warn and leave the tags in the TEXT output.
      _prompt_g, loras_g, _skipped, _unfound = get_and_strip_loras(prompt_g,
                                                                  True,
                                                                  log_node=self.NAME)
      _prompt_l, loras_l, _skipped, _unfound = get_and_strip_loras(prompt_l,
                                                                  True,
                                                                  log_node=self.NAME)
      loras = loras_g + loras_l
      if len(loras):
        log_node_warn(
          NODE_NAME, f'Found {len(loras)} lora tags in prompt but model & clip were not supplied!')
        log_node_info(NODE_NAME, 'Loras not processed, keeping for TEXT output.')
    conditioning = self.get_conditioning(prompt_g, prompt_l, opt_clip, opt_clip_width,
                                         opt_clip_height, target_width, target_height, crop_width,
                                         crop_height)
    return (conditioning, opt_model, opt_clip, prompt_g, prompt_l)

  def get_conditioning(self, prompt_g, prompt_l, opt_clip, opt_clip_width, opt_clip_height,
                       target_width, target_height, crop_width, crop_height):
    """Checks the inputs and gets the conditioning.

    Uses the SDXL encoder when clip dimensions were supplied; otherwise (or when the SDXL encode
    fails) falls back to standard encoding of the concatenated prompts.
    """
    conditioning = None
    if opt_clip is not None:
      # Only attempt the SDXL encoder when we have explicit clip dimensions.
      use_sdxl_encode = bool(opt_clip_width and opt_clip_height)
      if use_sdxl_encode:
        target_width = target_width if target_width and target_width > 0 else opt_clip_width
        target_height = target_height if target_height and target_height > 0 else opt_clip_height
        crop_width = crop_width if crop_width and crop_width > 0 else 0
        crop_height = crop_height if crop_height and crop_height > 0 else 0
        try:
          conditioning = CLIPTextEncodeSDXL().encode(opt_clip, opt_clip_width, opt_clip_height,
                                                     crop_width, crop_height, target_width,
                                                     target_height, prompt_g, prompt_l)[0]
        except Exception:
          # BUG FIX: the flag was previously set to True here, which skipped the fallback below
          # and returned None. Clear it so the standard encoding fallback actually runs.
          use_sdxl_encode = False
          log_node_info(
            self.NAME,
            'Exception while attempting to CLIPTextEncodeSDXL, will fall back to standard encoding.'
          )
      else:
        log_node_info(
          self.NAME,
          'CLIP supplied, but not CLIP_WIDTH and CLIP_HEIGHT. Text encoding will use standard ' +
          'encoding with prompt_g and prompt_l concatenated.')
      if not use_sdxl_encode:
        conditioning = CLIPTextEncode().encode(
          opt_clip, f'{prompt_g if prompt_g else ""}\n{prompt_l if prompt_l else ""}')[0]
    return conditioning

View File

@@ -0,0 +1,106 @@
"""A simpler SDXL Power Prompt that doesn't load Loras, like for negative."""
import os
import re
import folder_paths
from nodes import MAX_RESOLUTION, LoraLoader
from comfy_extras.nodes_clip_sdxl import CLIPTextEncodeSDXL
from .sdxl_power_prompt_postive import RgthreeSDXLPowerPromptPositive
from .log import log_node_warn, log_node_info, log_node_success
from .constants import get_category, get_name
NODE_NAME = get_name('SDXL Power Prompt - Simple / Negative')


class RgthreeSDXLPowerPromptSimple(RgthreeSDXLPowerPromptPositive):
  """A simpler SDXL Power Prompt that doesn't handle Loras."""

  NAME = NODE_NAME
  CATEGORY = get_category()

  @classmethod
  def INPUT_TYPES(cls):  # pylint: disable = invalid-name, missing-function-docstring
    # Removed Saved Prompts feature; No sure it worked any longer. UI should fail gracefully,
    # TODO: Rip out saved prompt input data
    saved_prompts_files = []
    saved_prompts_content = []
    # Shared widget options; copied per-entry so each input owns its own dict.
    prompt_opts = {'multiline': True, 'dynamicPrompts': True}
    clip_dim_opts = {"forceInput": True, "default": 1024.0, "min": 0, "max": MAX_RESOLUTION}
    hidden_dim_opts = {"default": -1, "min": -1, "max": MAX_RESOLUTION}
    embeddings = [os.path.splitext(x)[0] for x in folder_paths.get_filename_list('embeddings')]
    return {
      'required': {
        'prompt_g': ('STRING', dict(prompt_opts)),
        'prompt_l': ('STRING', dict(prompt_opts)),
      },
      'optional': {
        "opt_clip": ("CLIP",),
        "opt_clip_width": ("INT", dict(clip_dim_opts)),
        "opt_clip_height": ("INT", dict(clip_dim_opts)),
        'insert_embedding': (['CHOOSE'] + embeddings,),
        'insert_saved': (['CHOOSE'] + saved_prompts_files,),
        # We'll hide these in the UI for now.
        "target_width": ("INT", dict(hidden_dim_opts)),
        "target_height": ("INT", dict(hidden_dim_opts)),
        "crop_width": ("INT", dict(hidden_dim_opts)),
        "crop_height": ("INT", dict(hidden_dim_opts)),
      },
      'hidden': {
        'values_insert_saved': (['CHOOSE'] + saved_prompts_content,),
      }
    }

  RETURN_TYPES = ('CONDITIONING', 'STRING', 'STRING')
  RETURN_NAMES = ('CONDITIONING', 'TEXT_G', 'TEXT_L')
  FUNCTION = 'main'

  def main(self,
           prompt_g,
           prompt_l,
           opt_clip=None,
           opt_clip_width=None,
           opt_clip_height=None,
           insert_embedding=None,
           insert_saved=None,
           target_width=-1,
           target_height=-1,
           crop_width=-1,
           crop_height=-1,
           values_insert_saved=None):
    """Builds the conditioning only; lora tags are intentionally not processed here."""
    conditioning = self.get_conditioning(prompt_g, prompt_l, opt_clip, opt_clip_width,
                                         opt_clip_height, target_width, target_height,
                                         crop_width, crop_height)
    return (conditioning, prompt_g, prompt_l)

View File

@@ -0,0 +1,123 @@
"""See node."""
import random
from datetime import datetime
from .constants import get_category, get_name
from .log import log_node_warn, log_node_info
# Some extension must be setting a seed as server-generated seeds were not random. We'll set a new
# seed and use that state going forward, restoring whatever global state was already in place.
initial_random_state = random.getstate()
random.seed(datetime.now().timestamp())
rgthree_seed_random_state = random.getstate()
random.setstate(initial_random_state)


def new_random_seed():
  """Gets a new random seed from the rgthree_seed_random_state and resetting the previous state.

  Swaps in our private random state, draws a seed, saves the advanced state back, and restores
  the interpreter-wide state so other users of `random` are unaffected.
  """
  global rgthree_seed_random_state
  outer_state = random.getstate()
  random.setstate(rgthree_seed_random_state)
  seed = random.randint(1, 1125899906842624)
  rgthree_seed_random_state = random.getstate()
  random.setstate(outer_state)
  return seed
class RgthreeSeed:
  """Seed node.

  Fix: the "Re-queues passing in ..." warning was missing its f-prefix, logging the literal text
  `{seed}`; it now interpolates the originally passed special seed.
  """

  NAME = get_name('Seed')
  CATEGORY = get_category()

  @classmethod
  def INPUT_TYPES(cls):  # pylint: disable = invalid-name, missing-function-docstring
    return {
      "required": {
        "seed": ("INT", {
          "default": 0,
          "min": -1125899906842624,
          "max": 1125899906842624
        }),
      },
      "hidden": {
        "prompt": "PROMPT",
        "extra_pnginfo": "EXTRA_PNGINFO",
        "unique_id": "UNIQUE_ID",
      },
    }

  RETURN_TYPES = ("INT",)
  RETURN_NAMES = ("SEED",)
  FUNCTION = "main"

  @classmethod
  def IS_CHANGED(cls, seed, prompt=None, extra_pnginfo=None, unique_id=None):
    """Forces a changed state if we happen to get a special seed, as if from the API directly."""
    if seed in (-1, -2, -3):
      # This isn't used, but a different value than previous will force it to be "changed"
      return new_random_seed()
    return seed

  def main(self, seed=0, prompt=None, extra_pnginfo=None, unique_id=None):
    """Returns the passed seed on execution.

    When a special seed (-1/-2/-3) arrives — only possible when queueing via the API directly —
    a server-side random seed is generated and, where possible, written back into the workflow
    and prompt metadata so saved images reflect the seed actually used.
    """
    # We generate random seeds on the frontend in the seed node before sending the workflow in for
    # many reasons. However, if we want to use this in an API call without changing the seed before
    # sending, then users _could_ pass in "-1" and get a random seed used and added to the metadata.
    # Though, this should likely be discouraged for several reasons (thus, a lot of logging).
    if seed in (-1, -2, -3):
      log_node_warn(self.NAME,
                    f'Got "{seed}" as passed seed. ' +
                    'This shouldn\'t happen when queueing from the ComfyUI frontend.',
                    msg_color="YELLOW")
      if seed in (-2, -3):
        log_node_warn(self.NAME,
                      f'Cannot {"increment" if seed == -2 else "decrement"} seed from ' +
                      'server, but will generate a new random seed.',
                      msg_color="YELLOW")
      original_seed = seed
      seed = new_random_seed()
      log_node_info(self.NAME, f'Server-generated random seed {seed} and saving to workflow.')
      # BUG FIX: this message was missing its f-prefix and logged the literal "{seed}"; we now
      # interpolate the special seed value that was originally passed in.
      log_node_warn(
        self.NAME,
        f'NOTE: Re-queues passing in "{original_seed}" and server-generated random seed ' +
        'won\'t be cached.',
        msg_color="YELLOW")
      if unique_id is None:
        log_node_warn(
          self.NAME, 'Cannot save server-generated seed to image metadata because ' +
          'the node\'s id was not provided.')
      else:
        if extra_pnginfo is None:
          log_node_warn(
            self.NAME, 'Cannot save server-generated seed to image workflow ' +
            'metadata because workflow was not provided.')
        else:
          workflow_node = next(
            (x for x in extra_pnginfo['workflow']['nodes'] if str(x['id']) == str(unique_id)), None)
          if workflow_node is None or 'widgets_values' not in workflow_node:
            log_node_warn(
              self.NAME, 'Cannot save server-generated seed to image workflow ' +
              'metadata because node was not found in the provided workflow.')
          else:
            # Replace the special seed widget value with the seed actually used.
            for index, widget_value in enumerate(workflow_node['widgets_values']):
              if widget_value == original_seed:
                workflow_node['widgets_values'][index] = seed
        if prompt is None:
          log_node_warn(
            self.NAME, 'Cannot save server-generated seed to image API prompt ' +
            'metadata because prompt was not provided.')
        else:
          prompt_node = prompt[str(unique_id)]
          if prompt_node is None or 'inputs' not in prompt_node or 'seed' not in prompt_node[
              'inputs']:
            log_node_warn(
              self.NAME, 'Cannot save server-generated seed to image workflow ' +
              'metadata because node was not found in the provided workflow.')
          else:
            prompt_node['inputs']['seed'] = seed
    return (seed,)

View File

@@ -0,0 +1,48 @@
import os
from aiohttp import web
from server import PromptServer
from ..config import get_config_value
from ..log import log
from .utils_server import set_default_page_resources, set_default_page_routes, get_param
from .routes_config import *
from .routes_model_info import *
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
# Root of the bundled web assets, two levels up from this server module.
DIR_WEB = os.path.abspath(f'{THIS_DIR}/../../web/')
routes = PromptServer.instance.routes
# Sometimes other pages (link_fixer, etc.) may want to import JS from the comfyui
# directory. To allow TS to resolve like '../comfyui/file.js', we'll also resolve any module HTTP
# to these routes.
set_default_page_resources("comfyui", routes)
set_default_page_resources("common", routes)
set_default_page_resources("lib", routes)
set_default_page_routes("link_fixer", routes)
# The models page is gated behind an explicitly-enabled, unreleased config flag.
if get_config_value('unreleased.models_page.enabled') is True:
  set_default_page_routes("models", routes)
@routes.get('/rgthree/api/print')
async def api_print(request):
  """Logs a user message to the terminal."""
  message_type = get_param(request, 'type')
  if message_type != 'PRIMITIVE_REROUTE':
    # Anything we don't recognize gets a generic note.
    log("Unknown log type from api", prefix="rgthree-comfy", color="YELLOW")
    return web.json_response({})
  log(
    "You are using rgthree-comfy reroutes with a ComfyUI Primitive node. Unfortunately, ComfyUI "
    "has removed support for this. While rgthree-comfy has a best-effort support fallback for "
    "now, it may no longer work as expected and is strongly recommended you either replace the "
    "Reroute node using ComfyUI's reroute node, or refrain from using the Primitive node "
    "(you can always use the rgthree-comfy \"Power Primitive\" for non-combo primitives).",
    prefix="Reroute",
    color="YELLOW",
    id=message_type,
    at_most_secs=20
  )
  return web.json_response({})

View File

@@ -0,0 +1,67 @@
import json
import re
from aiohttp import web
from server import PromptServer
from ..pyproject import get_logo_svg
from .utils_server import is_param_truthy, get_param
from ..config import get_config, set_user_config, refresh_config
routes = PromptServer.instance.routes


@routes.get('/rgthree/config.js')
def api_get_user_config_file(request):
  """Returns the user configuration as a javascript file (an ES module export)."""
  config_json = json.dumps(get_config(), sort_keys=True, indent=2, separators=(",", ": "))
  return web.Response(
    text=f'export const rgthreeConfig = {config_json}',
    content_type='application/javascript')
@routes.get('/rgthree/api/config')
def api_get_user_config(request):
  """Returns the user configuration, optionally refreshing it from disk first."""
  should_refresh = is_param_truthy(request, 'refresh')
  if should_refresh:
    refresh_config()
  return web.json_response(get_config())
@routes.post('/rgthree/api/config')
async def api_set_user_config(request):
  """Saves the posted user configuration (a JSON payload under the "json" form field)."""
  post = await request.post()
  config_data = json.loads(post.get("json"))
  set_user_config(config_data)
  return web.json_response({"status": "ok"})
@routes.get('/rgthree/logo.svg')
async def get_logo(request, as_markup=False):
  """Returns the rgthree logo with color config.

  Query params: bg/fg colors, w/h dimensions (numeric only), cssClass. When `as_markup` is set,
  anything before the opening <svg> tag (doctype etc.) is stripped.
  """
  bg = get_param(request, 'bg', 'transparent')
  fg = get_param(request, 'fg', '#111111')
  width = get_param(request, 'w')
  height = get_param(request, 'h')
  css_class = get_param(request, 'cssClass')
  markup = (await get_logo_svg()).format(bg=bg, fg=fg)
  if width is not None:
    # Strip any hard-coded width, then apply the requested one (numeric only).
    markup = re.sub(r'(<svg[^\>]*?)width="[^\"]+"', r'\1', markup)
    if str(width).isnumeric():
      markup = re.sub(r'<svg', f'<svg width="{width}"', markup)
  if height is not None:
    markup = re.sub(r'(<svg[^\>]*?)height="[^\"]+"', r'\1', markup)
    if str(height).isnumeric():
      markup = re.sub(r'<svg', f'<svg height="{height}"', markup)
  if css_class is not None:
    markup = re.sub(r'<svg', f'<svg class="{css_class}"', markup)
  if as_markup:
    markup = re.sub(r'^.*?<svg', r'<svg', markup, flags=re.DOTALL)
  return web.Response(text=markup, content_type='image/svg+xml')
@routes.get('/rgthree/logo_markup.svg')
async def get_logo_markup(request):
  """Returns the rgthree logo svg markup (no doctype) with options; delegates to get_logo."""
  response = await get_logo(request, as_markup=True)
  return response

View File

@@ -0,0 +1,199 @@
import os
import json
from aiohttp import web
from ..log import log
from server import PromptServer
import folder_paths
from ..utils import abspath, path_exists
from .utils_server import get_param, is_param_falsy
from .utils_info import delete_model_info, get_model_info, set_model_info_partial, get_file_info
routes = PromptServer.instance.routes


def _check_valid_model_type(request):
  """Returns a 404 json response when the {type} path segment is unsupported, else None."""
  model_type = request.match_info['type']
  if model_type in ('loras', 'checkpoints'):
    return None
  return web.json_response({'status': 404, 'error': f'Invalid model type: {model_type}'})
@routes.get('/rgthree/api/{type}')
async def api_get_models_list(request):
  """Returns a list of model types from user configuration.

  By default, a list of filenames are provided. If `format=details` is specified, a list of objects
  with additional _file info_ is provided. This includes modified time, hasInfoFile, and imageLocal
  among others.
  """
  # Fix: call the validator once; the original called it twice, building the error response twice.
  error_response = _check_valid_model_type(request)
  if error_response:
    return error_response
  model_type = request.match_info['type']
  files = folder_paths.get_filename_list(model_type)
  if get_param(request, 'format') != 'details':
    return web.json_response(list(files))
  response = []
  bad_files_first = None
  bad_files_num = 0
  for file in files:
    file_info = get_file_info(file, model_type)
    # Some folks were seeing null in this list, which is odd since it's coming from ComfyUI files.
    # See https://github.com/rgthree/rgthree-comfy/issues/574#issuecomment-3494629132 We'll check
    # and log if we haven't found, maybe someone will have more info.
    if file_info is not None:
      response.append(file_info)
    else:
      bad_files_num += 1
      if not bad_files_first:
        bad_files_first = file
  if bad_files_first:
    log(
      f"Couldn't get file info for {bad_files_first}"
      f"{f' and {bad_files_num} other {model_type}.' if bad_files_num > 1 else '.'} "
      "ComfyUI thinks they exist, but they were not found on the filesystem.",
      prefix="Power Lora Loader",
      color="YELLOW",
      id=f'no_file_details_{model_type}',
      at_most_secs=30
    )
  return web.json_response(response)
@routes.get('/rgthree/api/{type}/info')
async def api_get_models_info(request):
  """Returns a list of model info; either all, or specific ones if provided a 'files' param.

  If a `light` param is specified and not falsy, no metadata will be fetched.
  """
  # Validate once instead of calling the validator twice.
  error_response = _check_valid_model_type(request)
  if error_response:
    return error_response
  model_type = request.match_info['type']
  files_param = get_param(request, 'files')
  # Only consider fetching metadata when specific files were requested AND the
  # caller explicitly passed a falsy `light` param.
  maybe_fetch_metadata = files_param is not None and is_param_falsy(request, 'light')
  api_response = await models_info_response(
    request, model_type, maybe_fetch_metadata=maybe_fetch_metadata
  )
  return web.json_response(api_response)
@routes.get('/rgthree/api/{type}/info/refresh')
async def api_get_refresh_get_models_info(request):
  """Refreshes model info; either all or specific ones if provided a 'files' param."""
  # Validate once instead of calling the validator twice.
  error_response = _check_valid_model_type(request)
  if error_response:
    return error_response
  model_type = request.match_info['type']
  api_response = await models_info_response(
    request, model_type, maybe_fetch_civitai=True, maybe_fetch_metadata=True
  )
  return web.json_response(api_response)
@routes.get('/rgthree/api/{type}/info/clear')
async def api_get_delete_model_info(request):
  """Clears model info from the filesystem for the provided file(s).

  Pass `files=ALL` to clear every model of the type. The `del_info`,
  `del_metadata` and `del_civitai` params each default to on unless explicitly
  falsy ("0"/"false").
  """
  # Validate once instead of calling the validator twice.
  error_response = _check_valid_model_type(request)
  if error_response:
    return error_response
  api_response = {'status': 200}
  model_type = request.match_info['type']
  files_param = get_param(request, 'files')
  if files_param is not None:
    files_param = files_param.split(',')
  del_info = not is_param_falsy(request, 'del_info')
  del_metadata = not is_param_falsy(request, 'del_metadata')
  del_civitai = not is_param_falsy(request, 'del_civitai')
  if not files_param:
    # NOTE(review): status is a string here but an int elsewhere (e.g. the 200
    # above, and _check_valid_model_type) — confirm what the frontend expects.
    api_response['status'] = '404'
    api_response['error'] = f'No file provided. Please pass files=ALL to clear {model_type} info.'
  else:
    # Force the user to explicitly supply files=ALL before clearing everything.
    if len(files_param) == 1 and files_param[0] == "ALL":
      files_param = folder_paths.get_filename_list(model_type)
    for file_param in files_param:
      await delete_model_info(
        file_param,
        model_type,
        del_info=del_info,
        del_metadata=del_metadata,
        del_civitai=del_civitai
      )
  return web.json_response(api_response)
@routes.post('/rgthree/api/{type}/info')
async def api_post_save_model_data(request):
  """Saves partial data (from the posted `json` field) to a model by name."""
  # Validate once instead of calling the validator twice.
  error_response = _check_valid_model_type(request)
  if error_response:
    return error_response
  model_type = request.match_info['type']
  api_response = {'status': 200}
  file_param = get_param(request, 'file')
  if file_param is None:
    # NOTE(review): status is a string here but an int for success — confirm
    # what the frontend expects.
    api_response['status'] = '404'
    api_response['error'] = 'No model found at path'
  else:
    post = await request.post()
    await set_model_info_partial(file_param, model_type, json.loads(post.get("json")))
    # Return the freshly-compiled info so the client can update immediately.
    api_response['data'] = await get_model_info(file_param, model_type)
  return web.json_response(api_response)
@routes.get('/rgthree/api/{type}/img')
async def api_get_models_info_img(request):
  """Returns an image response if one exists next to the model file."""
  # Validate once instead of calling the validator twice.
  error_response = _check_valid_model_type(request)
  if error_response:
    return error_response
  model_type = request.match_info['type']
  file_param = get_param(request, 'file')
  file_path = folder_paths.get_full_path(model_type, file_param)
  if not path_exists(file_path):
    file_path = abspath(file_path)
  img_path = None
  # 'webp' added for consistency with utils_info.get_img_file, which already
  # recognizes webp siblings (so hasInfoFile/imageLocal and this route agree).
  for ext in ['jpg', 'png', 'jpeg', 'webp']:
    try_path = f'{os.path.splitext(file_path)[0]}.{ext}'
    if path_exists(try_path):
      img_path = try_path
      break
  if not path_exists(img_path):
    return web.json_response({'status': '404', 'error': 'No model found at path'})
  return web.FileResponse(img_path)
async def models_info_response(
  request, model_type, maybe_fetch_civitai=False, maybe_fetch_metadata=False
):
  """Builds the standard info response for all models of a type, or a subset.

  When a comma-delimited `files` query param is present only those files are
  compiled; otherwise every known file of `model_type` is included. A `light`
  response is produced unless the request explicitly passes a falsy `light`.
  """
  light = not is_param_falsy(request, 'light')
  files_param = get_param(request, 'files')
  if files_param is None:
    file_names = folder_paths.get_filename_list(model_type)
  else:
    file_names = files_param.split(',')
  data = []
  for file_name in file_names:
    data.append(
      await get_model_info(
        file_name,
        model_type,
        maybe_fetch_civitai=maybe_fetch_civitai,
        maybe_fetch_metadata=maybe_fetch_metadata,
        light=light
      )
    )
  return {'status': 200, 'data': data}

View File

@@ -0,0 +1,452 @@
import hashlib
import json
import os
import re
from datetime import datetime
import requests
from server import PromptServer
import folder_paths
from ..utils import abspath, get_dict_value, load_json_file, file_exists, remove_path, save_json_file
from ..utils_userdata import read_userdata_json, save_userdata_json, delete_userdata_file
def _get_info_cache_file(file_hash: str, data_type: str):
  """Returns the userdata-relative cache path for a model's fetched data.

  Note: the parameters previously read `(data_type, file_hash)` while every call
  site passes `(file_hash, 'civitai'|'metadata')`. The names are corrected here
  and the produced path is kept byte-identical (`info/<type>.<hash>.json`) so
  existing on-disk caches remain valid.
  """
  return f'info/{data_type}.{file_hash}.json'
async def delete_model_info(
  file: str, model_type, del_info=True, del_metadata=True, del_civitai=True
):
  """Deletes the sidecar info json and/or the civitai & metadata userdata caches."""
  file_path = get_folder_path(file, model_type)
  if file_path is None:
    return
  if del_info:
    remove_path(get_info_file(file_path))
  if not del_civitai and not del_metadata:
    return
  # Hash only when we actually need to locate a cache file; hashing is expensive.
  file_hash = _get_sha256_hash(file_path)
  if del_civitai:
    delete_userdata_file(_get_info_cache_file(file_hash, 'civitai'))
  if del_metadata:
    delete_userdata_file(_get_info_cache_file(file_hash, 'metadata'))
def get_file_info(file: str, model_type):
  """Gets basic file info for a model: path, modified time, sidecar presence.

  Returns None when the file cannot be located on the filesystem.
  """
  file_path = get_folder_path(file, model_type)
  if file_path is None:
    return None
  has_local_img = get_img_file(file_path) is not None
  return {
    'file': file,
    'path': file_path,
    # Seconds -> milliseconds for the frontend.
    'modified': os.path.getmtime(file_path) * 1000,
    'imageLocal': f'/rgthree/api/{model_type}/img?file={file}' if has_local_img else None,
    'hasInfoFile': get_info_file(file_path) is not None,
  }
def get_info_file(file_path: str, force=False):
  """Returns the path of the model's sidecar `.rgthree-info.json`, or None if absent.

  With `force=True` the path is returned even when the file does not yet exist
  (used when about to write a new info file).
  """
  info_path = f'{file_path}.rgthree-info.json'
  if force or file_exists(info_path):
    return info_path
  return None
def get_img_file(file_path: str, force=False):
  """Returns the first sibling image (jpg/png/jpeg/webp) for the model, or None.

  `force` is currently unused; kept for signature parity with get_info_file.
  """
  base = os.path.splitext(file_path)[0]
  for ext in ('jpg', 'png', 'jpeg', 'webp'):
    candidate = f'{base}.{ext}'
    if file_exists(candidate):
      return candidate
  return None
def get_model_info_file_data(file: str, model_type, default=None):
  """Returns the parsed sidecar info file data, or `default` when unavailable."""
  file_path = get_folder_path(file, model_type)
  if file_path is not None:
    return load_json_file(get_info_file(file_path), default=default)
  return default
async def get_model_info(
  file: str,
  model_type,
  default=None,
  maybe_fetch_civitai=False,
  force_fetch_civitai=False,
  maybe_fetch_metadata=False,
  force_fetch_metadata=False,
  light=False
):
  """Compiles a model info given a stored file next to the model, and/or metadata/civitai.

  Args:
    file: The model filename as known to ComfyUI's folder_paths.
    model_type: The folder_paths type (e.g. 'loras', 'checkpoints').
    default: Returned when the model file cannot be located.
    maybe_fetch_civitai: Fetch civitai data only when not already in the raw cache.
    force_fetch_civitai: Always (re)fetch civitai data.
    maybe_fetch_metadata: Read file-header metadata only when not already cached.
    force_fetch_metadata: Always (re)read file-header metadata.
    light: When True and no fetches are requested, returns the quick data only
      (no hashing, no header read, no network).

  Returns:
    The compiled info dict; also persisted to the sidecar file when changed.
  """
  file_path = get_folder_path(file, model_type)
  if file_path is None:
    return default
  # Tracks whether anything changed so we only write the sidecar file when needed.
  should_save = False
  # basic data
  basic_data = get_file_info(file, model_type)
  # Try to load a rgthree-info.json file next to the file.
  info_data = get_model_info_file_data(file, model_type, default={})
  # Sync the basic filesystem fields into the stored info when stale or missing.
  for key in ['file', 'path', 'modified', 'imageLocal', 'hasInfoFile']:
    if key in basic_data and basic_data[key] and (
      key not in info_data or info_data[key] != basic_data[key]
    ):
      info_data[key] = basic_data[key]
      should_save = True
  # Check if we have an image next to the file and, if so, add it to the front of the images
  # (if it isn't already).
  img_next_to_file = basic_data['imageLocal']
  if 'images' not in info_data:
    info_data['images'] = []
    should_save = True
  if img_next_to_file:
    if len(info_data['images']) == 0 or info_data['images'][0]['url'] != img_next_to_file:
      info_data['images'].insert(0, {'url': img_next_to_file})
      should_save = True
  # If we just want light data then bail now with just existing data, plus file, path and img if
  # next to the file.
  if light and not maybe_fetch_metadata and not force_fetch_metadata and not maybe_fetch_civitai and not force_fetch_civitai:
    return info_data
  if 'raw' not in info_data:
    info_data['raw'] = {}
    should_save = True
  # Migrate any old-format fields (e.g. "triggerWords") forward.
  should_save = _update_data(info_data) or should_save
  should_fetch_civitai = force_fetch_civitai is True or (
    maybe_fetch_civitai is True and 'civitai' not in info_data['raw']
  )
  should_fetch_metadata = force_fetch_metadata is True or (
    maybe_fetch_metadata is True and 'metadata' not in info_data['raw']
  )
  if should_fetch_metadata:
    data_meta = _get_model_metadata(file, model_type, default={}, refresh=force_fetch_metadata)
    should_save = _merge_metadata(info_data, data_meta) or should_save
  if should_fetch_civitai:
    data_civitai = _get_model_civitai_data(
      file, model_type, default={}, refresh=force_fetch_civitai
    )
    should_save = _merge_civitai_data(info_data, data_civitai) or should_save
  # Hashing is expensive; only compute when we don't already have it stored.
  if 'sha256' not in info_data:
    file_hash = _get_sha256_hash(file_path)
    if file_hash is not None:
      info_data['sha256'] = file_hash
      should_save = True
  if should_save:
    if 'trainedWords' in info_data:
      # Sort by count; if it doesn't exist, then assume it's a top item from civitai or elsewhere.
      info_data['trainedWords'] = sorted(
        info_data['trainedWords'],
        key=lambda w: w['count'] if 'count' in w else 99999,
        reverse=True
      )
    save_model_info(file, info_data, model_type)
    # If we're saving, then the UI is likely waiting to see if the refreshed data is coming in.
    await PromptServer.instance.send(f"rgthree-refreshed-{model_type}-info", {"data": info_data})
  return info_data
def _update_data(info_data: dict) -> bool:
  """Ports old data to new data if necessary.

  Moves a legacy "triggerWords" list into "trainedWords" entries, marking each
  word as coming from civitai or from the user.

  Returns:
    True when info_data changed and should be persisted.
  """
  should_save = False
  # If we have "triggerWords" then move them over to "trainedWords"
  if 'triggerWords' in info_data and len(info_data['triggerWords']) > 0:
    raw_civitai_words = (
      get_dict_value(info_data, 'raw.civitai.triggerWords', default=[]) +
      get_dict_value(info_data, 'raw.civitai.trainedWords', default=[])
    )
    # Civitai entries may themselves be comma-packed ("word1, word2"); split
    # them into an exact-match set. The previous substring test against the
    # comma-joined string could incorrectly mark e.g. "cat" as a civitai word
    # merely because "catastrophe" was present.
    civitai_words = {
      word.strip() for entry in raw_civitai_words for word in entry.split(',') if word.strip()
    }
    if 'trainedWords' not in info_data:
      info_data['trainedWords'] = []
    for trigger_word in info_data['triggerWords']:
      word_data = next(
        (data for data in info_data['trainedWords'] if data['word'] == trigger_word), None
      )
      if word_data is None:
        word_data = {'word': trigger_word}
        info_data['trainedWords'].append(word_data)
      if trigger_word in civitai_words:
        word_data['civitai'] = True
      else:
        word_data['user'] = True
    del info_data['triggerWords']
    should_save = True
  return should_save
def _merge_metadata(info_data: dict, data_meta: dict) -> bool:
  """Merges file-header (safetensors training) metadata into the compiled info.

  Args:
    info_data: The compiled info dict; mutated in place. Callers are expected to
      have initialized info_data['raw'].
    data_meta: The raw metadata read from the model file header.

  Returns:
    True if data changed and should be persisted.
  """
  should_save = False
  base_model_file = get_dict_value(data_meta, 'ss_sd_model_name', None)
  if base_model_file:
    info_data['baseModelFile'] = base_model_file
  # Loop over metadata tags. ss_tag_frequency buckets are keyed by dataset
  # directory; each bucket maps tag -> occurrence count. Counts for the same
  # tag are summed across buckets.
  trained_words = {}
  if 'ss_tag_frequency' in data_meta and isinstance(data_meta['ss_tag_frequency'], dict):
    for bucket_value in data_meta['ss_tag_frequency'].values():
      if isinstance(bucket_value, dict):
        for tag, count in bucket_value.items():
          if tag not in trained_words:
            trained_words[tag] = {'word': tag, 'count': 0, 'metadata': True}
          trained_words[tag]['count'] = trained_words[tag]['count'] + count
  if 'trainedWords' not in info_data:
    info_data['trainedWords'] = list(trained_words.values())
    should_save = True
  else:
    # We can't merge, because the list may have other data, like it's part of civitaidata.
    merged_dict = {}
    for existing_word_data in info_data['trainedWords']:
      merged_dict[existing_word_data['word']] = existing_word_data
    for new_key, new_word_data in trained_words.items():
      if new_key not in merged_dict:
        merged_dict[new_key] = {}
      # Metadata fields win over existing fields of the same name.
      merged_dict[new_key] = {**merged_dict[new_key], **new_word_data}
    info_data['trainedWords'] = list(merged_dict.values())
    should_save = True
  # Keep the raw metadata so later calls can skip re-reading the file header.
  info_data['raw']['metadata'] = data_meta
  should_save = True
  if 'sha256' not in info_data and '_sha256' in data_meta:
    info_data['sha256'] = data_meta['_sha256']
    should_save = True
  return should_save
def _merge_civitai_data(info_data: dict, data_civitai: dict) -> bool:
  """Merges fetched civitai api data into the compiled info.

  Args:
    info_data: The compiled info dict; mutated in place. Callers are expected to
      have initialized info_data['images'] and info_data['raw'].
    data_civitai: The civitai response; may be an empty dict when the fetch failed.

  Returns:
    True if data changed and should be persisted.
  """
  should_save = False
  if 'name' not in info_data:
    info_data['name'] = get_dict_value(data_civitai, 'model.name', '')
    should_save = True
    version_name = get_dict_value(data_civitai, 'name')
    if version_name is not None:
      info_data['name'] += f' - {version_name}'
  if 'type' not in info_data:
    info_data['type'] = get_dict_value(data_civitai, 'model.type')
    should_save = True
  if 'baseModel' not in info_data:
    info_data['baseModel'] = get_dict_value(data_civitai, 'baseModel')
    should_save = True
  # We always want to merge triggerword.
  civitai_trigger = get_dict_value(data_civitai, 'triggerWords', default=[])
  civitai_trained = get_dict_value(data_civitai, 'trainedWords', default=[])
  civitai_words = ','.join(civitai_trigger + civitai_trained)
  if civitai_words:
    # Normalize whitespace and stray/duplicate commas before splitting.
    civitai_words = re.sub(r"\s*,\s*", ",", civitai_words)
    civitai_words = re.sub(r",+", ",", civitai_words)
    civitai_words = re.sub(r"^,", "", civitai_words)
    civitai_words = re.sub(r",$", "", civitai_words)
  if civitai_words:
    civitai_words = civitai_words.split(',')
    if 'trainedWords' not in info_data:
      info_data['trainedWords'] = []
    for trigger_word in civitai_words:
      word_data = next(
        (data for data in info_data['trainedWords'] if data['word'] == trigger_word), None
      )
      if word_data is None:
        word_data = {'word': trigger_word}
        info_data['trainedWords'].append(word_data)
      word_data['civitai'] = True
  # Guard against a failed fetch: the default empty response has no '_sha256',
  # which previously raised a KeyError here.
  if 'sha256' not in info_data and '_sha256' in data_civitai:
    info_data['sha256'] = data_civitai['_sha256']
    should_save = True
  if 'modelId' in data_civitai:
    info_data['links'] = info_data['links'] if 'links' in info_data else []
    civitai_link = f'https://civitai.com/models/{get_dict_value(data_civitai, "modelId")}'
    if get_dict_value(data_civitai, "id"):
      civitai_link += f'?modelVersionId={get_dict_value(data_civitai, "id")}'
    info_data['links'].append(civitai_link)
    info_data['links'].append(data_civitai['_civitai_api'])
    should_save = True
  # Take images from civitai that we don't already have.
  if 'images' in data_civitai:
    info_data_image_urls = list(
      map(lambda i: i['url'] if 'url' in i else None, info_data['images'])
    )
    for img in data_civitai['images']:
      img_url = get_dict_value(img, 'url')
      if img_url is not None and img_url not in info_data_image_urls:
        # The image id is the url's basename without extension.
        img_id = os.path.splitext(os.path.basename(img_url))[0]
        img_data = {
          'url': img_url,
          'civitaiUrl': f'https://civitai.com/images/{img_id}',
          'width': get_dict_value(img, 'width'),
          'height': get_dict_value(img, 'height'),
          'type': get_dict_value(img, 'type'),
          'nsfwLevel': get_dict_value(img, 'nsfwLevel'),
          'seed': get_dict_value(img, 'meta.seed'),
          'positive': get_dict_value(img, 'meta.prompt'),
          'negative': get_dict_value(img, 'meta.negativePrompt'),
          'steps': get_dict_value(img, 'meta.steps'),
          'sampler': get_dict_value(img, 'meta.sampler'),
          'cfg': get_dict_value(img, 'meta.cfgScale'),
          'model': get_dict_value(img, 'meta.Model'),
          'resources': get_dict_value(img, 'meta.resources'),
        }
        info_data['images'].append(img_data)
        should_save = True
  # The raw data, cached so later calls can skip the network fetch.
  if 'civitai' not in info_data['raw']:
    info_data['raw']['civitai'] = data_civitai
    should_save = True
  return should_save
def _get_model_civitai_data(file: str, model_type, default=None, refresh=False):
  """Gets the civitai data, either cached from the user directory, or from the civitai api.

  Returns None when the file can't be hashed, otherwise the (possibly cached)
  response annotated with '_sha256' and '_civitai_api', or `default` when no
  response is available.
  """
  file_hash = _get_sha256_hash(get_folder_path(file, model_type))
  if file_hash is None:
    return None
  json_file_path = _get_info_cache_file(file_hash, 'civitai')
  api_url = f'https://civitai.com/api/v1/model-versions/by-hash/{file_hash}'
  file_data = read_userdata_json(json_file_path)
  if file_data is None or refresh is True:
    try:
      # `timeout` is in seconds; the previous value of 5000 (likely meant as
      # milliseconds) allowed a hung request to block for over an hour.
      response = requests.get(api_url, timeout=30)
      data = response.json()
      save_userdata_json(
        json_file_path, {
          'url': api_url,
          'timestamp': datetime.now().timestamp(),
          'response': data
        }
      )
      file_data = read_userdata_json(json_file_path)
    except requests.exceptions.RequestException as e:
      # Best-effort: fall through to whatever cached data (if any) we have.
      print(e)
  response = file_data['response'] if file_data is not None and 'response' in file_data else None
  if response is not None:
    response['_sha256'] = file_hash
    response['_civitai_api'] = api_url
  return response if response is not None else default
def _get_model_metadata(file: str, model_type, default=None, refresh=False):
  """Gets the header metadata from the model file itself, with a userdata cache.

  Returns `default` when the file can't be hashed or no metadata is available;
  otherwise the metadata dict annotated with '_sha256'.
  """
  file_path = get_folder_path(file, model_type)
  file_hash = _get_sha256_hash(file_path)
  if file_hash is None:
    return default
  json_file_path = _get_info_cache_file(file_hash, 'metadata')
  file_data = read_userdata_json(json_file_path)
  if file_data is None or refresh is True:
    # Read the header directly from the file and cache the result.
    data = _read_file_metadata_from_header(file_path)
    if data is not None:
      file_data = {'url': file, 'timestamp': datetime.now().timestamp(), 'response': data}
      save_userdata_json(json_file_path, file_data)
  response = None
  if file_data is not None and 'response' in file_data:
    response = file_data['response']
  if response is None:
    return default
  response['_sha256'] = file_hash
  return response
def _read_file_metadata_from_header(file_path: str) -> dict:
  """Reads the file's header and returns a JSON dict metadata if available.

  Only `.safetensors` files are supported; returns None for other files or when
  the header is missing or unparsable.
  """
  data = None
  try:
    if file_path.endswith('.safetensors'):
      with open(file_path, "rb") as file:
        # https://github.com/huggingface/safetensors#format
        # 8 bytes: N, an unsigned little-endian 64-bit integer, containing the size of the header
        header_size = int.from_bytes(file.read(8), "little", signed=False)
        if header_size <= 0:
          raise BufferError("Invalid header size")
        header = file.read(header_size)
        if header is None:
          raise BufferError("Invalid header")
        header_json = json.loads(header)
        data = header_json["__metadata__"] if "__metadata__" in header_json else None
        if data is not None:
          # Metadata values are often JSON themselves, stored as strings.
          for key, value in data.items():
            if isinstance(value, str) and value.startswith('{') and value.endswith('}'):
              try:
                data[key] = json.loads(value)
              except json.JSONDecodeError:
                print(f'metadata for field {key} did not parse as json')
  except (OSError, BufferError, ValueError) as e:
    # The previous version caught requests.exceptions.RequestException, which
    # can never be raised by local file reads, so OS/parse errors escaped to
    # callers. Catch the errors this code can actually raise (ValueError
    # covers json.JSONDecodeError from a corrupt header).
    print(e)
    data = None
  return data
def get_folder_path(file: str, model_type) -> str | None:
  """Gets the model's file path, returning None unless it exists on disk."""
  file_path = folder_paths.get_full_path(model_type, file)
  if file_exists(file_path):
    return file_path
  # Fall back to resolving symlinks/user dirs before giving up.
  resolved = abspath(file_path)
  return resolved if file_exists(resolved) else None
def _get_sha256_hash(file_path: str | None):
  """Returns the sha256 hex digest for the file, or None when it doesn't exist."""
  if not file_path or not file_exists(file_path):
    return None
  chunk_size = 1024 * 128  # Read in 128kb chunks to bound memory usage.
  sha256_hash = hashlib.sha256()
  with open(file_path, "rb") as f:
    for chunk in iter(lambda: f.read(chunk_size), b""):
      sha256_hash.update(chunk)
  return sha256_hash.hexdigest()
async def set_model_info_partial(file: str, model_type: str, info_data_partial):
  """Merges partial data over the existing model info and persists the result."""
  existing = await get_model_info(file, model_type, default={})
  save_model_info(file, {**existing, **info_data_partial}, model_type)
def save_model_info(file: str, info_data, model_type):
  """Saves the model info json alongside the model itself (no-op if the model is missing)."""
  file_path = get_folder_path(file, model_type)
  if file_path is not None:
    save_json_file(get_info_file(file_path, force=True), info_data)

View File

@@ -0,0 +1,56 @@
import os
from aiohttp import web
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
DIR_WEB = os.path.abspath(f'{THIS_DIR}/../../web/')
def get_param(request, param, default=None):
  """Gets a query param from a request, falling back to `default` when absent."""
  return request.rel_url.query.get(param, default)
def is_param_falsy(request, param):
  """Determines if a param is explicitly 0 or false (case-insensitive)."""
  val = get_param(request, param)
  if val is None:
    return False
  return val == "0" or val.upper() == "FALSE"
def is_param_truthy(request, param):
  """Determines if a param is present and not explicitly falsy ("0"/"false")."""
  return get_param(request, param) is not None and not is_param_falsy(request, param)
def set_default_page_resources(path, routes):
  """Sets up routes for handling static files under a path."""

  def _safe_web_path(*parts):
    # Resolve the requested file and ensure it stays inside DIR_WEB; route
    # segments come from the client and could contain traversal components
    # like "..".
    resolved = os.path.abspath(os.path.join(DIR_WEB, *parts))
    if os.path.commonpath([DIR_WEB, resolved]) != DIR_WEB:
      raise web.HTTPNotFound()
    return resolved

  @routes.get(f'/rgthree/{path}/{{file}}')
  async def get_resource(request):
    """ Returns a resource file. """
    return web.FileResponse(_safe_web_path(path, request.match_info['file']))

  @routes.get(f'/rgthree/{path}/{{subdir}}/{{file}}')
  async def get_resource_subdir(request):
    """ Returns a resource file from a subdirectory. """
    return web.FileResponse(
      _safe_web_path(path, request.match_info['subdir'], request.match_info['file']))
def set_default_page_routes(path, routes):
  """ Sets default path handling for a hosted rgthree page. """

  @routes.get(f'/rgthree/{path}')
  async def get_path_redir(request):
    """ Redirects to the path adding a trailing slash. """
    raise web.HTTPFound(f'{request.path}/')

  @routes.get(f'/rgthree/{path}/')
  async def get_path_index(request):
    """ Handles the page's index loading. """
    index_file = os.path.join(DIR_WEB, path, 'index.html')
    with open(index_file, 'r', encoding='UTF-8') as file:
      html = file.read()
    return web.Response(text=html, content_type='text/html')

  # Also serve the page's static resources (js/css/etc) under the same path.
  set_default_page_resources(path, routes)

View File

@@ -0,0 +1,168 @@
import json
import os
import re
from typing import Union
class AnyType(str):
  """A string subclass that never reports inequality, acting as a wildcard type.

  Credit to pythongosssss.
  """

  def __ne__(self, _other: object) -> bool:
    # Always claiming "not unequal" makes this pass any `!=` type comparison.
    return False
class FlexibleOptionalInputType(dict):
  """A special class to make flexible nodes that pass data to our python handlers.

  Enables both flexible/dynamic input types (like for Any Switch) or a dynamic number of inputs
  (like for Any Switch, Context Switch, Context Merge, Power Lora Loader, etc).

  Initially, ComfyUI only needed `__contains__` to return True, which told ComfyUI that
  our node will handle the input, regardless of what it is.

  However, after https://github.com/comfyanonymous/ComfyUI/pull/2666 ComfyUI's execution also
  checks the data for the key; specifically, the type which is the first tuple entry. That
  type is supplied to our FlexibleOptionalInputType and returned for any non-data key. This can
  be a real type, or use the AnyType for additional flexibility.
  """

  def __init__(self, type, data: Union[dict, None] = None):
    """Initializes the FlexibleOptionalInputType.

    Args:
      type: The flexible type to use when ComfyUI retrieves an unknown key (via `__getitem__`).
      data: An optional dict used as the basis. It is kept on a `data` attribute (so lookups can
        bypass our overrides) and its entries are also copied onto `self` so that, when iterated,
        we appear to represent this data. When used in an "optional" INPUT_TYPES, these are the
        starting optional node types.
    """
    self.type = type
    self.data = data
    if data is not None:
      for key, value in data.items():
        self[key] = value

  def __getitem__(self, key):
    """Returns the real entry for known keys; otherwise a tuple of the flexible type."""
    known = self.data
    if known is not None and key in known:
      return known[key]
    return (self.type,)

  def __contains__(self, key):
    """Reports every key as present; unknown keys resolve via __getitem__ above."""
    return True
# Shared wildcard instance: its inequality checks always report "equal", letting
# node inputs/outputs match any connection type in ComfyUI.
any_type = AnyType("*")
def is_dict_value_falsy(data: dict, dict_key: str):
  """Checks if a deeply nested, dot-delimited dict value is falsy (or missing)."""
  return not get_dict_value(data, dict_key)
def get_dict_value(data: dict, dict_key: str, default=None):
  """Gets a deeply nested value given a dot-delimited key.

  Returns `default` when any path segment is missing or maps to None, or when an
  intermediate value is not a dict (previously that case raised a TypeError,
  e.g. get_dict_value({'a': 1}, 'a.b')).
  """
  keys = dict_key.split('.')
  key = keys.pop(0) if len(keys) > 0 else None
  # Only index when `data` is actually a dict; non-dict intermediates fall
  # through to the default.
  found = data[key] if isinstance(data, dict) and key in data else None
  if found is not None and len(keys) > 0:
    return get_dict_value(found, '.'.join(keys), default)
  return found if found is not None else default
def set_dict_value(data: dict, dict_key: str, value, create_missing_objects=True):
  """Sets a deeply nested value given a dot-delimited key.

  Intermediate dicts are created unless `create_missing_objects` is False, in
  which case the data is returned unchanged when a segment is missing.
  """
  first, _, rest = dict_key.partition('.')
  if first not in data:
    if not create_missing_objects:
      return data
    data[first] = {}
  if not rest:
    data[first] = value
  else:
    set_dict_value(data[first], rest, value, create_missing_objects)
  return data
def dict_has_key(data: dict, dict_key):
  """Checks if a dict has a deeply nested dot-delimited key.

  Returns False when an intermediate value is not a dict (previously that case
  raised a TypeError, e.g. dict_has_key({'a': 1}, 'a.b')).
  """
  keys = dict_key.split('.')
  key = keys.pop(0) if len(keys) > 0 else None
  if key is None or not isinstance(data, dict) or key not in data:
    return False
  if len(keys) == 0:
    return True
  return dict_has_key(data[key], '.'.join(keys))
def load_json_file(file: str, default=None):
  """Reads a json file and returns the json dict, stripping out "//" comments first."""
  if not path_exists(file):
    return default
  with open(file, 'r', encoding='UTF-8') as json_file:
    config = json_file.read()
  # Try the raw content first, then progressively strip "//" comment styles and
  # retry, since config files may contain non-standard JSON comments. The
  # substitutions are cumulative, matching the original behavior.
  for comment_pattern in (None, r"^\s*//\s.*", r"(?:^|\s)//.*"):
    if comment_pattern is not None:
      config = re.sub(comment_pattern, "", config, flags=re.MULTILINE)
    try:
      return json.loads(config)
    except json.decoder.JSONDecodeError:
      continue
  return default
def save_json_file(file_path: str, data: dict):
  """Saves a json file, creating parent directories as needed."""
  directory = os.path.dirname(file_path)
  # Guard the empty dirname case for bare filenames; os.makedirs('') raises.
  if directory:
    os.makedirs(directory, exist_ok=True)
  with open(file_path, 'w+', encoding='UTF-8') as file:
    json.dump(data, file, sort_keys=False, indent=2, separators=(",", ": "))
def path_exists(path):
  """Checks if a path exists, accepting None type (returns False for None)."""
  return path is not None and os.path.exists(path)
def file_exists(path):
  """Checks if a regular file exists, accepting None type (returns False for None)."""
  return path is not None and os.path.isfile(path)
def remove_path(path):
  """Removes a path if it exists; returns True when something was removed."""
  if not path_exists(path):
    return False
  os.remove(path)
  return True
def abspath(file_path: str):
  """Resolves the abspath of a file, resolving symlinks and user dirs.

  Resolution is only attempted when the given path does not already exist, and
  the resolved path is only adopted when it does exist.
  """
  if not file_path or path_exists(file_path):
    return file_path
  resolved = os.path.abspath(os.path.realpath(os.path.expanduser(file_path)))
  return resolved if path_exists(resolved) else file_path
class ByPassTypeTuple(tuple):
  """A tuple that yields wildcard "AnyType" strings beyond its defined values.

  Credit to Trung0246
  """

  def __getitem__(self, index):
    # Indexes past the end resolve to the wildcard type instead of raising.
    if index <= len(self) - 1:
      return super().__getitem__(index)
    return AnyType("*")

View File

@@ -0,0 +1,50 @@
import os
from .utils import load_json_file, path_exists, save_json_file
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
USERDATA = os.path.join(THIS_DIR, '..', 'userdata')
def read_userdata_file(rel_path: str):
  """Reads a file from the userdata directory, returning None when absent."""
  file_path = clean_path(rel_path)
  if not path_exists(file_path):
    return None
  with open(file_path, 'r', encoding='UTF-8') as file:
    return file.read()
def save_userdata_file(rel_path: str, content: str):
  """Saves a file to the userdata directory, creating parent directories as needed."""
  file_path = clean_path(rel_path)
  # Paths like 'info/<hash>.json' need their parent directory to exist first;
  # previously the open() would fail when it didn't (unlike save_userdata_json,
  # whose save_json_file path creates directories).
  directory = os.path.dirname(file_path)
  if directory:
    os.makedirs(directory, exist_ok=True)
  with open(file_path, 'w+', encoding='UTF-8') as file:
    file.write(content)
def delete_userdata_file(rel_path: str):
  """Deletes a file from the userdata directory (no-op when missing)."""
  target = clean_path(rel_path)
  if not os.path.isfile(target):
    return
  os.remove(target)
def read_userdata_json(rel_path: str):
  """Reads a json file from the userdata directory (None when missing/unparsable)."""
  return load_json_file(clean_path(rel_path))
def save_userdata_json(rel_path: str, data: dict):
  """Saves a json file to the userdata directory."""
  return save_json_file(clean_path(rel_path), data)
def clean_path(rel_path: str):
  """Converts a forward-slash relative path into an os-native path under USERDATA."""
  return os.path.join(USERDATA, *rel_path.split('/'))