Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Includes 30 custom nodes committed directly, 7 Civitai-exclusive loras stored via Git LFS, and a setup script that installs all dependencies and downloads HuggingFace-hosted models on vast.ai. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
822
custom_nodes/ComfyUI-Easy-Use/py/modules/brushnet/__init__.py
Normal file
822
custom_nodes/ComfyUI-Easy-Use/py/modules/brushnet/__init__.py
Normal file
@@ -0,0 +1,822 @@
|
||||
#credit to nullquant for this module
|
||||
#from https://github.com/nullquant/ComfyUI-BrushNet
|
||||
|
||||
import os
import types

import torch

# accelerate is treated as optional at import time: when it cannot be imported
# the two helpers are left as None.  ``except Exception`` (rather than a bare
# ``except:``) keeps that best-effort fallback without also swallowing
# KeyboardInterrupt / SystemExit.
try:
    from accelerate import init_empty_weights, load_checkpoint_and_dispatch
except Exception:
    init_empty_weights, load_checkpoint_and_dispatch = None, None

import comfy

# The BrushNet / PowerPaint implementation modules are likewise optional; every
# name falls back to None when any of the three imports fails.
try:
    from .model import BrushNetModel, PowerPaintModel
    from .model_patch import add_model_patch_option, patch_model_function_wrapper
    from .powerpaint_utils import TokenizerWrapper, add_tokens
except Exception:
    BrushNetModel, PowerPaintModel = None, None
    add_model_patch_option, patch_model_function_wrapper = None, None
    TokenizerWrapper, add_tokens = None, None

# Model configuration files live in the "config" directory next to this file.
cwd_path = os.path.dirname(os.path.realpath(__file__))
brushnet_config_file = os.path.join(cwd_path, 'config', 'brushnet.json')
brushnet_xl_config_file = os.path.join(cwd_path, 'config', 'brushnet_xl.json')
powerpaint_config_file = os.path.join(cwd_path, 'config', 'powerpaint.json')

# Fallback latent scaling factors, used when the model config does not expose
# ``latent_format.scale_factor``.
sd15_scaling_factor = 0.18215
sdxl_scaling_factor = 0.13025

# Only referenced by the commented-out unloading code further down in this file.
ModelsToUnload = [comfy.sd1_clip.SD1ClipModel, comfy.ldm.models.autoencoder.AutoencoderKL]
|
||||
|
||||
class BrushNet:
    """Helpers for loading BrushNet / PowerPaint checkpoints and patching a model with them."""

    # Check models compatibility
    def check_compatibilty(self, model, brushnet):
        """Check that the base model and the loaded BrushNet dict belong together.

        Args:
            model: object whose ``model.model.model_config`` is tested against
                ``comfy.supported_models.SD15`` and ``comfy.supported_models.SDXL``.
            brushnet: dict with truthy/falsy ``"SDXL"`` and ``"PP"`` entries
                (the shape returned by ``load_brushnet_model``).

        Returns:
            Tuple ``(is_SDXL, is_PP)``.

        Raises:
            Exception: if the base model and BrushNet types do not match, or the
                base model is neither SD1.5 nor SDXL.
        """
        is_SDXL = False
        is_PP = False
        if isinstance(model.model.model_config, comfy.supported_models.SD15):
            print('Base model type: SD1.5')
            is_SDXL = False
            if brushnet["SDXL"]:
                raise Exception("Base model is SD15, but BrushNet is SDXL type")
            if brushnet["PP"]:
                is_PP = True
        elif isinstance(model.model.model_config, comfy.supported_models.SDXL):
            print('Base model type: SDXL')
            is_SDXL = True
            if not brushnet["SDXL"]:
                raise Exception("Base model is SDXL, but BrushNet is SD15 type")
        else:
            print('Base model type: ', type(model.model.model_config))
            raise Exception("Unsupported model type: " + str(type(model.model.model_config)))

        return (is_SDXL, is_PP)

    def check_image_mask(self, image, mask, name):
        """Normalise image/mask ranks and make the mask batch match the image batch.

        Args:
            image: tensor expected as [B, H, W, C]; a missing batch dim is added.
            mask: tensor expected as [B, H, W]; a 4-D mask is reduced to its first
                channel and a missing batch dim is added.
            name: label used as the prefix of the printed diagnostics.

        Returns:
            Tuple ``(image, mask)`` with ``mask.shape[0] == image.shape[0]``.
        """
        if len(image.shape) < 4:
            # image tensor shape should be [B, H, W, C], but batch somehow is missing
            image = image[None, :, :, :]

        if len(mask.shape) > 3:
            # mask tensor shape should be [B, H, W] but we get [B, H, W, C], image may be?
            # take first mask, red channel
            mask = (mask[:, :, :, 0])[:, :, :]
        elif len(mask.shape) < 3:
            # mask tensor shape should be [B, H, W] but batch somehow is missing
            mask = mask[None, :, :]

        if image.shape[0] > mask.shape[0]:
            print(name, "gets batch of images (%d) but only %d masks" % (image.shape[0], mask.shape[0]))
            if mask.shape[0] == 1:
                print(name, "will copy the mask to fill batch")
                mask = torch.cat([mask] * image.shape[0], dim=0)
            else:
                print(name, "will add empty masks to fill batch")
                # Create the padding on the mask's own device and dtype so the
                # concatenation neither fails across devices nor upcasts the mask.
                empty_mask = torch.zeros(
                    [image.shape[0] - mask.shape[0], mask.shape[1], mask.shape[2]],
                    dtype=mask.dtype,
                    device=mask.device,
                )
                mask = torch.cat([mask, empty_mask], dim=0)
        elif image.shape[0] < mask.shape[0]:
            print(name, "gets batch of images (%d) but too many (%d) masks" % (image.shape[0], mask.shape[0]))
            mask = mask[:image.shape[0], :, :]

        return (image, mask)

    # Prepare image and mask
    def prepare_image(self, image, mask):
        """Binarise the mask and blank out the masked region of the image.

        Returns:
            Tuple ``(masked_image, mask)`` where ``mask`` has been rounded and
            ``masked_image = image * (1 - mask)``.

        Raises:
            Exception: if image and mask differ in height or width.
        """
        image, mask = self.check_image_mask(image, mask, 'BrushNet')

        print("BrushNet image.shape =", image.shape, "mask.shape =", mask.shape)

        if mask.shape[2] != image.shape[2] or mask.shape[1] != image.shape[1]:
            raise Exception("Image and mask should be the same size")

        # As a suggestion of inferno46n2 (https://github.com/nullquant/ComfyUI-BrushNet/issues/64)
        mask = mask.round()

        masked_image = image * (1.0 - mask[:, :, :, None])

        return (masked_image, mask)

    # Get origin of the mask
    def cut_with_mask(self, mask, width, height):
        """Compute a crop window of at most ``width`` x ``height`` centred on the mask.

        Args:
            mask: 2-D tensor [H, W]; pixels equal to 1 are treated as masked.
            width: requested crop width.
            height: requested crop height.

        Returns:
            Tuple ``(x0, y0, w, h)`` of ints. When the mask has no pixel equal
            to 1 the window is centred on the mask tensor itself.

        Raises:
            Exception: if the masked extent exceeds the requested dimensions.
        """
        iy, ix = (mask == 1).nonzero(as_tuple=True)

        h0, w0 = mask.shape

        if iy.numel() == 0:
            x_c = w0 / 2.0
            y_c = h0 / 2.0
        else:
            x_min = ix.min().item()
            x_max = ix.max().item()
            y_min = iy.min().item()
            y_max = iy.max().item()

            if x_max - x_min > width or y_max - y_min > height:
                raise Exception("Mask is bigger than provided dimensions")

            x_c = (x_min + x_max) / 2.0
            y_c = (y_min + y_max) / 2.0

        width2 = width / 2.0
        height2 = height / 2.0

        if w0 <= width:
            x0 = 0
            w = w0
        else:
            x0 = max(0, x_c - width2)
            w = width
            # keep the window inside the right edge
            if x0 + width > w0:
                x0 = w0 - width

        if h0 <= height:
            y0 = 0
            h = h0
        else:
            y0 = max(0, y_c - height2)
            h = height
            # keep the window inside the bottom edge
            if y0 + height > h0:
                y0 = h0 - height

        return (int(x0), int(y0), int(w), int(h))

    # Prepare conditioning_latents
    @torch.inference_mode()
    def get_image_latents(self, masked_image, mask, vae, scaling_factor):
        """Encode the masked image and resize the inverted mask to latent resolution.

        Args:
            masked_image: image tensor; only its first three channels are encoded.
            mask: 3-D mask tensor (a channel dim is inserted at position 1).
            vae: object providing ``device`` and ``encode``.
            scaling_factor: multiplier applied to the encoded latents.

        Returns:
            List ``[image_latents, interpolated_mask]``.
        """
        processed_image = masked_image.to(vae.device)
        image_latents = vae.encode(processed_image[:, :, :, :3]) * scaling_factor
        processed_mask = 1. - mask[:, None, :, :]
        interpolated_mask = torch.nn.functional.interpolate(
            processed_mask,
            size=(
                image_latents.shape[-2],
                image_latents.shape[-1]
            )
        )
        interpolated_mask = interpolated_mask.to(image_latents.device)

        conditioning_latents = [image_latents, interpolated_mask]

        print('BrushNet CL: image_latents shape =', image_latents.shape, 'interpolated_mask shape =',
              interpolated_mask.shape)

        return conditioning_latents

    def brushnet_blocks(self, sd):
        """Count state-dict keys mentioning the BrushNet down/mid/up blocks.

        Returns:
            Tuple ``(down_count, mid_count, up_count, len(sd))``.
        """
        brushnet_down_block = 0
        brushnet_mid_block = 0
        brushnet_up_block = 0
        for key in sd:
            if 'brushnet_down_block' in key:
                brushnet_down_block += 1
            if 'brushnet_mid_block' in key:
                brushnet_mid_block += 1
            if 'brushnet_up_block' in key:
                brushnet_up_block += 1
        return (brushnet_down_block, brushnet_mid_block, brushnet_up_block, len(sd))

    def get_model_type(self, brushnet_file):
        """Classify a checkpoint as SD1.5 BrushNet, SD1.5 PowerPaint or SDXL BrushNet.

        The decision is made from the key counts reported by ``brushnet_blocks``.

        Returns:
            Tuple ``(is_SDXL, is_PP)``.

        Raises:
            Exception: if the block counts match no known layout.
        """
        sd = comfy.utils.load_torch_file(brushnet_file)
        brushnet_down_block, brushnet_mid_block, brushnet_up_block, keys = self.brushnet_blocks(sd)
        del sd
        if brushnet_down_block == 24 and brushnet_mid_block == 2 and brushnet_up_block == 30:
            is_SDXL = False
            if keys == 322:
                is_PP = False
                print('BrushNet model type: SD1.5')
            else:
                is_PP = True
                print('PowerPaint model type: SD1.5')
        elif brushnet_down_block == 18 and brushnet_mid_block == 2 and brushnet_up_block == 22:
            print('BrushNet model type: Loading SDXL')
            is_SDXL = True
            is_PP = False
        else:
            raise Exception("Unknown BrushNet model")
        return is_SDXL, is_PP

    def load_brushnet_model(self, brushnet_file, dtype='float16'):
        """Instantiate the matching model class and load the checkpoint into it.

        Args:
            brushnet_file: path of the checkpoint file.
            dtype: ``'float16'``, ``'bfloat16'`` or ``'float32'``; any other
                value selects ``torch.float64``.

        Returns:
            One-element tuple holding a dict with keys ``"brushnet"`` (the
            model), ``"SDXL"``, ``"PP"`` and ``"dtype"`` (the torch dtype).

        Raises:
            Exception: if accelerate or the model classes failed to import, or
                the checkpoint type is unknown.
        """
        # The module-level imports fall back to None on failure; report that
        # clearly instead of failing later with "'NoneType' object is not callable".
        if init_empty_weights is None or load_checkpoint_and_dispatch is None:
            raise Exception("BrushNet requires the 'accelerate' package, which could not be imported")
        if BrushNetModel is None or PowerPaintModel is None:
            raise Exception("BrushNet model classes could not be imported")

        is_SDXL, is_PP = self.get_model_type(brushnet_file)
        with init_empty_weights():
            if is_SDXL:
                brushnet_config = BrushNetModel.load_config(brushnet_xl_config_file)
                brushnet_model = BrushNetModel.from_config(brushnet_config)
            elif is_PP:
                brushnet_config = PowerPaintModel.load_config(powerpaint_config_file)
                brushnet_model = PowerPaintModel.from_config(brushnet_config)
            else:
                brushnet_config = BrushNetModel.load_config(brushnet_config_file)
                brushnet_model = BrushNetModel.from_config(brushnet_config)

        if is_PP:
            print("PowerPaint model file:", brushnet_file)
        else:
            print("BrushNet model file:", brushnet_file)

        known_dtypes = {
            'float16': torch.float16,
            'bfloat16': torch.bfloat16,
            'float32': torch.float32,
        }
        torch_dtype = known_dtypes.get(dtype, torch.float64)

        brushnet_model = load_checkpoint_and_dispatch(
            brushnet_model,
            brushnet_file,
            device_map="sequential",
            max_memory=None,
            offload_folder=None,
            offload_state_dict=False,
            dtype=torch_dtype,
            force_hooks=False,
        )

        if is_PP:
            print("PowerPaint model is loaded")
        elif is_SDXL:
            print("BrushNet SDXL model is loaded")
        else:
            print("BrushNet SD1.5 model is loaded")

        return ({"brushnet": brushnet_model, "SDXL": is_SDXL, "PP": is_PP, "dtype": torch_dtype},)

    def brushnet_model_update(self, model, vae, image, mask, brushnet, positive, negative, scale, start_at, end_at):
        """Clone ``model`` and attach a BrushNet patch built from image, mask and prompts.

        Args:
            model: base model; it is cloned before being patched.
            vae: used to encode the masked image.
            image, mask: inputs passed through ``prepare_image``.
            brushnet: dict as returned by ``load_brushnet_model``.
            positive, negative: conditioning lists; ``[0][0]`` is used as the
                embedding tensor and ``[0][1]['pooled_output']`` when present.
            scale, start_at, end_at: forwarded to the patch as its controls.

        Returns:
            Tuple ``(model, positive, negative, {"samples": latent})`` where
            ``latent`` is a zero tensor with the conditioning latents' spatial size.

        Raises:
            Exception: if the models are incompatible or a PowerPaint model was loaded.
        """
        is_SDXL, is_PP = self.check_compatibilty(model, brushnet)

        if is_PP:
            raise Exception("PowerPaint model was loaded, please use PowerPaint node")

        # Make a copy of the model so that we're not patching it everywhere in the workflow.
        model = model.clone()

        # prepare image and mask
        # no batches for original image and mask
        masked_image, mask = self.prepare_image(image, mask)

        batch = masked_image.shape[0]
        width = masked_image.shape[2]
        height = masked_image.shape[1]

        if hasattr(model.model.model_config, 'latent_format') and hasattr(model.model.model_config.latent_format,
                                                                          'scale_factor'):
            scaling_factor = model.model.model_config.latent_format.scale_factor
        elif is_SDXL:
            scaling_factor = sdxl_scaling_factor
        else:
            scaling_factor = sd15_scaling_factor

        torch_dtype = brushnet['dtype']

        # prepare conditioning latents
        conditioning_latents = self.get_image_latents(masked_image, mask, vae, scaling_factor)
        conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
        conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(brushnet['brushnet'].device)

        # unload vae
        del vae
        # for loaded_model in comfy.model_management.current_loaded_models:
        #     if type(loaded_model.model.model) in ModelsToUnload:
        #         comfy.model_management.current_loaded_models.remove(loaded_model)
        #         loaded_model.model_unload()
        #         del loaded_model

        # prepare embeddings
        prompt_embeds = positive[0][0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
        negative_prompt_embeds = negative[0][0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)

        # Pad the shorter embedding by repeating its last 77-token chunk so both
        # have the same token length.
        max_tokens = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
        if prompt_embeds.shape[1] < max_tokens:
            multiplier = max_tokens // 77 - prompt_embeds.shape[1] // 77
            prompt_embeds = torch.concat([prompt_embeds] + [prompt_embeds[:, -77:, :]] * multiplier, dim=1)
            print('BrushNet: negative prompt more than 75 tokens:', negative_prompt_embeds.shape,
                  'multiplying prompt_embeds')
        if negative_prompt_embeds.shape[1] < max_tokens:
            multiplier = max_tokens // 77 - negative_prompt_embeds.shape[1] // 77
            negative_prompt_embeds = torch.concat(
                [negative_prompt_embeds] + [negative_prompt_embeds[:, -77:, :]] * multiplier, dim=1)
            print('BrushNet: positive prompt more than 75 tokens:', prompt_embeds.shape,
                  'multiplying negative_prompt_embeds')

        if len(positive[0]) > 1 and 'pooled_output' in positive[0][1] and positive[0][1]['pooled_output'] is not None:
            pooled_prompt_embeds = positive[0][1]['pooled_output'].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
        else:
            print('BrushNet: positive conditioning has not pooled_output')
            if is_SDXL:
                print('BrushNet will not produce correct results')
            # Placeholder only; zeros rather than uninitialised memory so no
            # stray NaN/Inf values can reach the model.
            pooled_prompt_embeds = torch.zeros([2, 1280], device=brushnet['brushnet'].device).to(dtype=torch_dtype)

        if len(negative[0]) > 1 and 'pooled_output' in negative[0][1] and negative[0][1]['pooled_output'] is not None:
            negative_pooled_prompt_embeds = negative[0][1]['pooled_output'].to(dtype=torch_dtype).to(
                brushnet['brushnet'].device)
        else:
            print('BrushNet: negative conditioning has not pooled_output')
            if is_SDXL:
                print('BrushNet will not produce correct results')
            negative_pooled_prompt_embeds = torch.zeros([1, pooled_prompt_embeds.shape[1]],
                                                        device=brushnet['brushnet'].device).to(dtype=torch_dtype)

        time_ids = torch.FloatTensor([[height, width, 0., 0., height, width]]).to(dtype=torch_dtype).to(
            brushnet['brushnet'].device)

        # The pooled embeddings and time ids are only passed on for SDXL.
        if not is_SDXL:
            pooled_prompt_embeds = None
            negative_pooled_prompt_embeds = None
            time_ids = None

        # apply patch to model
        brushnet_conditioning_scale = scale
        control_guidance_start = start_at
        control_guidance_end = end_at

        add_brushnet_patch(model,
                           brushnet['brushnet'],
                           torch_dtype,
                           conditioning_latents,
                           (brushnet_conditioning_scale, control_guidance_start, control_guidance_end),
                           prompt_embeds, negative_prompt_embeds,
                           pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids,
                           False)

        latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]],
                             device=brushnet['brushnet'].device)

        return (model, positive, negative, {"samples": latent},)

    # powerpaint
    def load_powerpaint_clip(self, base_clip_file, pp_clip_file):
        """Load a base CLIP, add the PowerPaint placeholder tokens and load PowerPaint weights.

        Args:
            base_clip_file: checkpoint path passed to ``comfy.sd.load_clip``.
            pp_clip_file: checkpoint whose state dict is loaded (non-strict)
                into the text encoder after the tokens were added.

        Returns:
            One-element tuple holding the modified CLIP object.
        """
        pp_clip = comfy.sd.load_clip(ckpt_paths=[base_clip_file])

        print('PowerPaint base CLIP file: ', base_clip_file)

        pp_tokenizer = TokenizerWrapper(pp_clip.tokenizer.clip_l.tokenizer)
        pp_text_encoder = pp_clip.patcher.model.clip_l.transformer

        add_tokens(
            tokenizer=pp_tokenizer,
            text_encoder=pp_text_encoder,
            placeholder_tokens=["P_ctxt", "P_shape", "P_obj"],
            initialize_tokens=["a", "a", "a"],
            num_vectors_per_token=10,
        )

        pp_text_encoder.load_state_dict(comfy.utils.load_torch_file(pp_clip_file), strict=False)

        print('PowerPaint CLIP file: ', pp_clip_file)

        pp_clip.tokenizer.clip_l.tokenizer = pp_tokenizer
        pp_clip.patcher.model.clip_l.transformer = pp_text_encoder

        return (pp_clip,)

    def powerpaint_model_update(self, model, vae, image, mask, powerpaint, clip, positive, negative, fitting, function, scale, start_at, end_at, save_memory):
        """Clone ``model`` and attach a PowerPaint patch for the selected task.

        Args:
            model: base model; it is cloned before being patched.
            vae: used to encode the masked image.
            image, mask: inputs passed through ``prepare_image``.
            powerpaint: dict as returned by ``load_brushnet_model``.
            clip: object used to tokenize and encode the task prompts.
            positive, negative: returned unchanged.
            fitting: blend weight between the "A" and "B" task prompt embeddings.
            function: task name; ``"object removal"``, ``"context aware"``,
                ``"shape guided"`` and ``"image outpainting"`` are handled
                explicitly, anything else uses ``"P_obj"`` for all prompts.
            scale, start_at, end_at: forwarded to the patch as its controls.
            save_memory: passed to ``set_attention_slice`` unless it is ``'none'``.

        Returns:
            Tuple ``(model, positive, negative, {"samples": latent})``.

        Raises:
            Exception: if the models are incompatible or a plain BrushNet model was loaded.
        """
        is_SDXL, is_PP = self.check_compatibilty(model, powerpaint)
        if not is_PP:
            raise Exception("BrushNet model was loaded, please use BrushNet node")

        # Make a copy of the model so that we're not patching it everywhere in the workflow.
        model = model.clone()

        # prepare image and mask
        # no batches for original image and mask
        masked_image, mask = self.prepare_image(image, mask)

        batch = masked_image.shape[0]
        # width = masked_image.shape[2]
        # height = masked_image.shape[1]

        if hasattr(model.model.model_config, 'latent_format') and hasattr(model.model.model_config.latent_format,
                                                                          'scale_factor'):
            scaling_factor = model.model.model_config.latent_format.scale_factor
        else:
            scaling_factor = sd15_scaling_factor

        torch_dtype = powerpaint['dtype']

        # prepare conditioning latents
        conditioning_latents = self.get_image_latents(masked_image, mask, vae, scaling_factor)
        conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
        conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(powerpaint['brushnet'].device)

        # prepare embeddings

        if function == "object removal":
            promptA = "P_ctxt"
            promptB = "P_ctxt"
            negative_promptA = "P_obj"
            negative_promptB = "P_obj"
            print('You should add to positive prompt: "empty scene blur"')
            # positive = positive + " empty scene blur"
        elif function == "context aware":
            promptA = "P_ctxt"
            promptB = "P_ctxt"
            negative_promptA = ""
            negative_promptB = ""
            # positive = positive + " empty scene"
            print('You should add to positive prompt: "empty scene"')
        elif function == "shape guided":
            promptA = "P_shape"
            promptB = "P_ctxt"
            negative_promptA = "P_shape"
            negative_promptB = "P_ctxt"
        elif function == "image outpainting":
            promptA = "P_ctxt"
            promptB = "P_ctxt"
            negative_promptA = "P_obj"
            negative_promptB = "P_obj"
            # positive = positive + " empty scene"
            print('You should add to positive prompt: "empty scene"')
        else:
            promptA = "P_obj"
            promptB = "P_obj"
            negative_promptA = "P_obj"
            negative_promptB = "P_obj"

        tokens = clip.tokenize(promptA)
        prompt_embedsA = clip.encode_from_tokens(tokens, return_pooled=False)

        tokens = clip.tokenize(negative_promptA)
        negative_prompt_embedsA = clip.encode_from_tokens(tokens, return_pooled=False)

        tokens = clip.tokenize(promptB)
        prompt_embedsB = clip.encode_from_tokens(tokens, return_pooled=False)

        tokens = clip.tokenize(negative_promptB)
        negative_prompt_embedsB = clip.encode_from_tokens(tokens, return_pooled=False)

        # Blend the A and B task embeddings with weight ``fitting``.
        prompt_embeds_pp = (prompt_embedsA * fitting + (1.0 - fitting) * prompt_embedsB).to(dtype=torch_dtype).to(
            powerpaint['brushnet'].device)
        negative_prompt_embeds_pp = (negative_prompt_embedsA * fitting + (1.0 - fitting) * negative_prompt_embedsB).to(
            dtype=torch_dtype).to(powerpaint['brushnet'].device)

        # unload vae and CLIPs
        del vae
        del clip
        # for loaded_model in comfy.model_management.current_loaded_models:
        #     if type(loaded_model.model.model) in ModelsToUnload:
        #         comfy.model_management.current_loaded_models.remove(loaded_model)
        #         loaded_model.model_unload()
        #         del loaded_model

        # apply patch to model

        brushnet_conditioning_scale = scale
        control_guidance_start = start_at
        control_guidance_end = end_at

        if save_memory != 'none':
            powerpaint['brushnet'].set_attention_slice(save_memory)

        # NOTE(review): the negative embeddings are passed in the prompt_embeds
        # slot and the positive ones in the negative slot, unchanged from the
        # original code - confirm this ordering is intended for PowerPaint.
        add_brushnet_patch(model,
                           powerpaint['brushnet'],
                           torch_dtype,
                           conditioning_latents,
                           (brushnet_conditioning_scale, control_guidance_start, control_guidance_end),
                           negative_prompt_embeds_pp, prompt_embeds_pp,
                           None, None, None,
                           False)

        latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]],
                             device=powerpaint['brushnet'].device)

        return (model, positive, negative, {"samples": latent},)
|
||||
@torch.inference_mode()
def brushnet_inference(x, timesteps, transformer_options, debug):
    """Run the BrushNet model stored in ``transformer_options`` for the current step.

    Reads ``transformer_options['model_patch']['brushnet']`` (keys ``model``,
    ``dtype``, ``latents``, ``controls``, ``prompt_embeds``,
    ``negative_prompt_embeds``, ``add_embeds``, ``latent_id``),
    ``transformer_options['model_patch']['step']`` and
    ``transformer_options['cond_or_uncond']``; optionally
    ``transformer_options['ad_params']['sub_idxs']``. Updates ``latent_id`` in
    the brushnet options when no ``sub_idxs`` are present.

    Args:
        x: incoming latents; cloned and moved to the BrushNet dtype/device.
        timesteps: timestep tensor; cloned and moved likewise.
        transformer_options: options dict described above.
        debug: when truthy the step is printed; also forwarded to the model.

    Returns:
        The result of calling the BrushNet model with ``return_dict=False``,
        or ``([], 0, [])`` when the patch data is missing or the stored model
        is not a BrushNet/PowerPaint model. The caller in this file unpacks the
        result as ``(input_samples, mid_sample, output_samples)``.
    """
    if 'model_patch' not in transformer_options:
        print('BrushNet inference: there is no model_patch key in transformer_options')
        return ([], 0, [])
    mp = transformer_options['model_patch']
    if 'brushnet' not in mp:
        print('BrushNet inference: there is no brushnet key in model_patch')
        return ([], 0, [])
    bo = mp['brushnet']
    if 'model' not in bo:
        print('BrushNet inference: there is no model key in brushnet')
        return ([], 0, [])
    brushnet = bo['model']
    # The model classes fall back to None when their import failed; filter
    # those out so isinstance() does not raise TypeError on a None class.
    model_classes = tuple(cls for cls in (BrushNetModel, PowerPaintModel) if cls is not None)
    if not isinstance(brushnet, model_classes):
        print('BrushNet model is not a BrushNetModel class')
        return ([], 0, [])

    torch_dtype = bo['dtype']
    cl_list = bo['latents']
    brushnet_conditioning_scale, control_guidance_start, control_guidance_end = bo['controls']
    pe = bo['prompt_embeds']
    npe = bo['negative_prompt_embeds']
    ppe, nppe, time_ids = bo['add_embeds']

    #do_classifier_free_guidance = mp['free_guidance']
    do_classifier_free_guidance = len(transformer_options['cond_or_uncond']) > 1

    x = x.detach().clone()
    x = x.to(torch_dtype).to(brushnet.device)

    timesteps = timesteps.detach().clone()
    timesteps = timesteps.to(torch_dtype).to(brushnet.device)

    step = mp['step']

    if do_classifier_free_guidance and step == 0:
        print('BrushNet inference: do_classifier_free_guidance is True')

    sub_idx = None
    if 'ad_params' in transformer_options and 'sub_idxs' in transformer_options['ad_params']:
        sub_idx = transformer_options['ad_params']['sub_idxs']

    # we have batch input images
    batch = cl_list[0].shape[0]
    # we have incoming latents
    latents_incoming = x.shape[0]
    # and we already got some
    latents_got = bo['latent_id']
    if step == 0 or batch > 1:
        print('BrushNet inference, step = %d: image batch = %d, got %d latents, starting from %d' \
              % (step, batch, latents_incoming, latents_got))

    image_latents = []
    masks = []
    prompt_embeds = []
    negative_prompt_embeds = []
    pooled_prompt_embeds = []
    negative_pooled_prompt_embeds = []
    if sub_idx:
        # AnimateDiff indexes detected
        if step == 0:
            print('BrushNet inference: AnimateDiff indexes detected and applied')

        batch = len(sub_idx)

        if do_classifier_free_guidance:
            for i in sub_idx:
                image_latents.append(cl_list[0][i][None,:,:,:])
                masks.append(cl_list[1][i][None,:,:,:])
                prompt_embeds.append(pe)
                negative_prompt_embeds.append(npe)
                pooled_prompt_embeds.append(ppe)
                negative_pooled_prompt_embeds.append(nppe)
            # second copy of the latents/masks for the uncond half
            for i in sub_idx:
                image_latents.append(cl_list[0][i][None,:,:,:])
                masks.append(cl_list[1][i][None,:,:,:])
        else:
            for i in sub_idx:
                image_latents.append(cl_list[0][i][None,:,:,:])
                masks.append(cl_list[1][i][None,:,:,:])
                prompt_embeds.append(pe)
                pooled_prompt_embeds.append(ppe)
    else:
        # do_classifier_free_guidance = 2 passes, 1st pass is cond, 2nd is uncond
        continue_batch = True
        for i in range(latents_incoming):
            number = latents_got + i
            if number < batch:
                # 1st pass, cond
                image_latents.append(cl_list[0][number][None,:,:,:])
                masks.append(cl_list[1][number][None,:,:,:])
                prompt_embeds.append(pe)
                pooled_prompt_embeds.append(ppe)
            elif do_classifier_free_guidance and number < batch * 2:
                # 2nd pass, uncond
                image_latents.append(cl_list[0][number-batch][None,:,:,:])
                masks.append(cl_list[1][number-batch][None,:,:,:])
                negative_prompt_embeds.append(npe)
                negative_pooled_prompt_embeds.append(nppe)
            else:
                # latent batch
                image_latents.append(cl_list[0][0][None,:,:,:])
                masks.append(cl_list[1][0][None,:,:,:])
                prompt_embeds.append(pe)
                pooled_prompt_embeds.append(ppe)
                latents_got = -i
                continue_batch = False

        # Remember where the next call should resume inside the image batch.
        if continue_batch:
            # we don't have full batch yet
            if do_classifier_free_guidance:
                if number < batch * 2 - 1:
                    bo['latent_id'] = number + 1
                else:
                    bo['latent_id'] = 0
            else:
                if number < batch - 1:
                    bo['latent_id'] = number + 1
                else:
                    bo['latent_id'] = 0
        else:
            bo['latent_id'] = 0

    # Join each image latent with its mask along the channel dim, then stack the batch.
    cl = []
    for il, m in zip(image_latents, masks):
        cl.append(torch.concat([il, m], dim=1))
    cl2apply = torch.concat(cl, dim=0)

    conditioning_latents = cl2apply.to(torch_dtype).to(brushnet.device)

    prompt_embeds.extend(negative_prompt_embeds)
    prompt_embeds = torch.concat(prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)

    if ppe is not None:
        added_cond_kwargs = {}
        added_cond_kwargs['time_ids'] = torch.concat([time_ids] * latents_incoming, dim = 0).to(torch_dtype).to(brushnet.device)

        pooled_prompt_embeds.extend(negative_pooled_prompt_embeds)
        pooled_prompt_embeds = torch.concat(pooled_prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)
        added_cond_kwargs['text_embeds'] = pooled_prompt_embeds
    else:
        added_cond_kwargs = None

    if x.shape[2] != conditioning_latents.shape[2] or x.shape[3] != conditioning_latents.shape[3]:
        if step == 0:
            print('BrushNet inference: image', conditioning_latents.shape, 'and latent', x.shape, 'have different size, resizing image')
        conditioning_latents = torch.nn.functional.interpolate(
            conditioning_latents, size=(
                x.shape[2],
                x.shape[3],
            ), mode='bicubic',
        ).to(torch_dtype).to(brushnet.device)

    if step == 0:
        print('BrushNet inference: sample', x.shape, ', CL', conditioning_latents.shape, 'dtype', torch_dtype)

    if debug: print('BrushNet: step =', step)

    # Outside the [start, end] step window the conditioning scale is zero.
    if step < control_guidance_start or step > control_guidance_end:
        cond_scale = 0.0
    else:
        cond_scale = brushnet_conditioning_scale

    return brushnet(x,
                    encoder_hidden_states=prompt_embeds,
                    brushnet_cond=conditioning_latents,
                    timestep = timesteps,
                    conditioning_scale=cond_scale,
                    guess_mode=False,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                    debug=debug,
                    )
|
||||
|
||||
def add_brushnet_patch(model, brushnet, torch_dtype, conditioning_latents,
|
||||
controls,
|
||||
prompt_embeds, negative_prompt_embeds,
|
||||
pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids,
|
||||
debug):
|
||||
|
||||
is_SDXL = isinstance(model.model.model_config, comfy.supported_models.SDXL)
|
||||
|
||||
if model.model.model_config.custom_operations is None:
|
||||
fp8 = model.model.model_config.optimizations.get("fp8", model.model.model_config.scaled_fp8 is not None)
|
||||
operations = comfy.ops.pick_operations(model.model.model_config.unet_config.get("dtype", None), model.model.manual_cast_dtype,
|
||||
fp8_optimizations=fp8, scaled_fp8=model.model.model_config.scaled_fp8)
|
||||
else:
|
||||
# such as gguf
|
||||
operations = model.model.model_config.custom_operations
|
||||
|
||||
if is_SDXL:
|
||||
input_blocks = [[0, operations.Conv2d],
|
||||
[1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
|
||||
[2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
|
||||
[3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
|
||||
[4, comfy.ldm.modules.attention.SpatialTransformer],
|
||||
[5, comfy.ldm.modules.attention.SpatialTransformer],
|
||||
[6, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
|
||||
[7, comfy.ldm.modules.attention.SpatialTransformer],
|
||||
[8, comfy.ldm.modules.attention.SpatialTransformer]]
|
||||
middle_block = [0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]
|
||||
output_blocks = [[0, comfy.ldm.modules.attention.SpatialTransformer],
|
||||
[1, comfy.ldm.modules.attention.SpatialTransformer],
|
||||
[2, comfy.ldm.modules.attention.SpatialTransformer],
|
||||
[2, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
|
||||
[3, comfy.ldm.modules.attention.SpatialTransformer],
|
||||
[4, comfy.ldm.modules.attention.SpatialTransformer],
|
||||
[5, comfy.ldm.modules.attention.SpatialTransformer],
|
||||
[5, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
|
||||
[6, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
|
||||
[7, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
|
||||
[8, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]]
|
||||
else:
|
||||
input_blocks = [[0, operations.Conv2d],
|
||||
[1, comfy.ldm.modules.attention.SpatialTransformer],
|
||||
[2, comfy.ldm.modules.attention.SpatialTransformer],
|
||||
[3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
|
||||
[4, comfy.ldm.modules.attention.SpatialTransformer],
|
||||
[5, comfy.ldm.modules.attention.SpatialTransformer],
|
||||
[6, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
|
||||
[7, comfy.ldm.modules.attention.SpatialTransformer],
|
||||
[8, comfy.ldm.modules.attention.SpatialTransformer],
|
||||
[9, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
|
||||
[10, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
|
||||
[11, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]]
|
||||
middle_block = [0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]
|
||||
output_blocks = [[0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
|
||||
[1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
|
||||
[2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
|
||||
[2, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
|
||||
[3, comfy.ldm.modules.attention.SpatialTransformer],
|
||||
[4, comfy.ldm.modules.attention.SpatialTransformer],
|
||||
[5, comfy.ldm.modules.attention.SpatialTransformer],
|
||||
[5, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
|
||||
[6, comfy.ldm.modules.attention.SpatialTransformer],
|
||||
[7, comfy.ldm.modules.attention.SpatialTransformer],
|
||||
[8, comfy.ldm.modules.attention.SpatialTransformer],
|
||||
[8, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
|
||||
[9, comfy.ldm.modules.attention.SpatialTransformer],
|
||||
[10, comfy.ldm.modules.attention.SpatialTransformer],
|
||||
[11, comfy.ldm.modules.attention.SpatialTransformer]]
|
||||
|
||||
def last_layer_index(block, tp):
    """Locate the last layer of exact type ``tp`` inside ``block``.

    Args:
        block: An iterable of layers (anything that can be iterated once).
        tp: The type to look for. Matching is by exact type equality, not
            ``isinstance``, so subclasses of ``tp`` do not match.

    Returns:
        tuple: ``(index, layer_types)`` where ``index`` is the position of the
        last layer whose type equals ``tp`` (``-1`` when there is none) and
        ``layer_types`` is the list of layer types in block order. The list is
        always returned (never ``None``) so callers can print it in their
        diagnostics.
    """
    layer_types = [type(layer) for layer in block]
    # Walk backwards so the first hit is the last occurrence.
    for idx in range(len(layer_types) - 1, -1, -1):
        if layer_types[idx] == tp:
            return idx, layer_types
    return -1, layer_types
|
||||
|
||||
def brushnet_forward(model, x, timesteps, transformer_options, control):
    """Distribute BrushNet samples onto the layers of ``model``.

    Runs BrushNet inference when a ``'brushnet'`` entry is present in
    ``transformer_options['model_patch']`` and stores the resulting samples
    in the ``add_sample_after`` attribute of the selected layers. When no
    BrushNet entry is present, every selected layer gets ``0`` instead.

    Args:
        model: Object exposing ``input_blocks``, ``middle_block`` and
            ``output_blocks`` (presumably the UNet - confirm against caller).
        x: Input forwarded to ``brushnet_inference``.
        timesteps: Timesteps forwarded to ``brushnet_inference``.
        transformer_options: Options dict; must contain ``'model_patch'``.
        control: Not used by this function.
    """
    if 'brushnet' in transformer_options['model_patch']:
        # brushnet inference
        input_samples, mid_sample, output_samples = brushnet_inference(x, timesteps, transformer_options, debug)
    else:
        input_samples, mid_sample, output_samples = [], 0, []

    # give additional samples to blocks
    for i, tp in input_blocks:
        idx, layer_list = last_layer_index(model.input_blocks[i], tp)
        if idx < 0:
            print("BrushNet can't find", tp, "layer in", i, "input block:", layer_list)
            continue
        # Consume samples in order; fall back to 0 once the list is exhausted.
        sample = input_samples.pop(0) if input_samples else 0
        model.input_blocks[i][idx].add_sample_after = sample

    mid_type = middle_block[1]
    idx, layer_list = last_layer_index(model.middle_block, mid_type)
    if idx < 0:
        print("BrushNet can't find", mid_type, "layer in middle block", layer_list)
    # With idx == -1 this addresses the last layer of the middle block.
    model.middle_block[idx].add_sample_after = mid_sample

    for i, tp in output_blocks:
        idx, layer_list = last_layer_index(model.output_blocks[i], tp)
        if idx < 0:
            print("BrushNet can't find", tp, "layer in", i, "outnput block:", layer_list)
            continue
        sample = output_samples.pop(0) if output_samples else 0
        model.output_blocks[i][idx].add_sample_after = sample
|
||||
|
||||
patch_model_function_wrapper(model, brushnet_forward)
|
||||
|
||||
to = add_model_patch_option(model)
|
||||
mp = to['model_patch']
|
||||
if 'brushnet' not in mp:
|
||||
mp['brushnet'] = {}
|
||||
bo = mp['brushnet']
|
||||
|
||||
bo['model'] = brushnet
|
||||
bo['dtype'] = torch_dtype
|
||||
bo['latents'] = conditioning_latents
|
||||
bo['controls'] = controls
|
||||
bo['prompt_embeds'] = prompt_embeds
|
||||
bo['negative_prompt_embeds'] = negative_prompt_embeds
|
||||
bo['add_embeds'] = (pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids)
|
||||
bo['latent_id'] = 0
|
||||
|
||||
# patch layers `forward` so we can apply brushnet
|
||||
def forward_patched_by_brushnet(self, x, *args, **kwargs):
    """Run the layer's original forward and add the pending BrushNet sample.

    Calls ``self.original_forward`` and, when the layer carries an
    ``add_sample_after`` attribute, adds it in place to the result and then
    resets the attribute to ``0``.

    Args:
        self: The patched layer.
        x: Input passed to ``self.original_forward``.
        *args, **kwargs: Forwarded unchanged to ``self.original_forward``.

    Returns:
        The output of the original forward, plus the pending sample if any.
    """
    h = self.original_forward(x, *args, **kwargs)
    if not (hasattr(self, 'add_sample_after') and type(self)):
        return h

    extra = self.add_sample_after
    if not torch.is_tensor(extra):
        # Plain number: add it as is.
        h += extra
    else:
        # interpolate due to RAUNet
        height, width = h.shape[2], h.shape[3]
        if height != extra.shape[2] or width != extra.shape[3]:
            extra = torch.nn.functional.interpolate(extra, size=(height, width), mode='bicubic')
        h += extra.to(h.dtype).to(h.device)

    # The sample is single-use: clear it once applied.
    self.add_sample_after = 0
    return h
|
||||
|
||||
for i, block in enumerate(model.model.diffusion_model.input_blocks):
|
||||
for j, layer in enumerate(block):
|
||||
if not hasattr(layer, 'original_forward'):
|
||||
layer.original_forward = layer.forward
|
||||
layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
|
||||
layer.add_sample_after = 0
|
||||
|
||||
for j, layer in enumerate(model.model.diffusion_model.middle_block):
|
||||
if not hasattr(layer, 'original_forward'):
|
||||
layer.original_forward = layer.forward
|
||||
layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
|
||||
layer.add_sample_after = 0
|
||||
|
||||
for i, block in enumerate(model.model.diffusion_model.output_blocks):
|
||||
for j, layer in enumerate(block):
|
||||
if not hasattr(layer, 'original_forward'):
|
||||
layer.original_forward = layer.forward
|
||||
layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
|
||||
layer.add_sample_after = 0
|
||||
@@ -0,0 +1,58 @@
|
||||
{
|
||||
"_class_name": "BrushNetModel",
|
||||
"_diffusers_version": "0.27.0.dev0",
|
||||
"_name_or_path": "runs/logs/brushnet_randommask/checkpoint-100000",
|
||||
"act_fn": "silu",
|
||||
"addition_embed_type": null,
|
||||
"addition_embed_type_num_heads": 64,
|
||||
"addition_time_embed_dim": null,
|
||||
"attention_head_dim": 8,
|
||||
"block_out_channels": [
|
||||
320,
|
||||
640,
|
||||
1280,
|
||||
1280
|
||||
],
|
||||
"brushnet_conditioning_channel_order": "rgb",
|
||||
"class_embed_type": null,
|
||||
"conditioning_channels": 5,
|
||||
"conditioning_embedding_out_channels": [
|
||||
16,
|
||||
32,
|
||||
96,
|
||||
256
|
||||
],
|
||||
"cross_attention_dim": 768,
|
||||
"down_block_types": [
|
||||
"DownBlock2D",
|
||||
"DownBlock2D",
|
||||
"DownBlock2D",
|
||||
"DownBlock2D"
|
||||
],
|
||||
"downsample_padding": 1,
|
||||
"encoder_hid_dim": null,
|
||||
"encoder_hid_dim_type": null,
|
||||
"flip_sin_to_cos": true,
|
||||
"freq_shift": 0,
|
||||
"global_pool_conditions": false,
|
||||
"in_channels": 4,
|
||||
"layers_per_block": 2,
|
||||
"mid_block_scale_factor": 1,
|
||||
"mid_block_type": "MidBlock2D",
|
||||
"norm_eps": 1e-05,
|
||||
"norm_num_groups": 32,
|
||||
"num_attention_heads": null,
|
||||
"num_class_embeds": null,
|
||||
"only_cross_attention": false,
|
||||
"projection_class_embeddings_input_dim": null,
|
||||
"resnet_time_scale_shift": "default",
|
||||
"transformer_layers_per_block": 1,
|
||||
"up_block_types": [
|
||||
"UpBlock2D",
|
||||
"UpBlock2D",
|
||||
"UpBlock2D",
|
||||
"UpBlock2D"
|
||||
],
|
||||
"upcast_attention": false,
|
||||
"use_linear_projection": false
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
{
|
||||
"_class_name": "BrushNetModel",
|
||||
"_diffusers_version": "0.27.0.dev0",
|
||||
"_name_or_path": "runs/logs/brushnetsdxl_randommask/checkpoint-80000",
|
||||
"act_fn": "silu",
|
||||
"addition_embed_type": "text_time",
|
||||
"addition_embed_type_num_heads": 64,
|
||||
"addition_time_embed_dim": 256,
|
||||
"attention_head_dim": [
|
||||
5,
|
||||
10,
|
||||
20
|
||||
],
|
||||
"block_out_channels": [
|
||||
320,
|
||||
640,
|
||||
1280
|
||||
],
|
||||
"brushnet_conditioning_channel_order": "rgb",
|
||||
"class_embed_type": null,
|
||||
"conditioning_channels": 5,
|
||||
"conditioning_embedding_out_channels": [
|
||||
16,
|
||||
32,
|
||||
96,
|
||||
256
|
||||
],
|
||||
"cross_attention_dim": 2048,
|
||||
"down_block_types": [
|
||||
"DownBlock2D",
|
||||
"DownBlock2D",
|
||||
"DownBlock2D"
|
||||
],
|
||||
"downsample_padding": 1,
|
||||
"encoder_hid_dim": null,
|
||||
"encoder_hid_dim_type": null,
|
||||
"flip_sin_to_cos": true,
|
||||
"freq_shift": 0,
|
||||
"global_pool_conditions": false,
|
||||
"in_channels": 4,
|
||||
"layers_per_block": 2,
|
||||
"mid_block_scale_factor": 1,
|
||||
"mid_block_type": "MidBlock2D",
|
||||
"norm_eps": 1e-05,
|
||||
"norm_num_groups": 32,
|
||||
"num_attention_heads": null,
|
||||
"num_class_embeds": null,
|
||||
"only_cross_attention": false,
|
||||
"projection_class_embeddings_input_dim": 2816,
|
||||
"resnet_time_scale_shift": "default",
|
||||
"transformer_layers_per_block": [
|
||||
1,
|
||||
2,
|
||||
10
|
||||
],
|
||||
"up_block_types": [
|
||||
"UpBlock2D",
|
||||
"UpBlock2D",
|
||||
"UpBlock2D"
|
||||
],
|
||||
"upcast_attention": null,
|
||||
"use_linear_projection": true
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
{
|
||||
"_class_name": "BrushNetModel",
|
||||
"_diffusers_version": "0.27.2",
|
||||
"act_fn": "silu",
|
||||
"addition_embed_type": null,
|
||||
"addition_embed_type_num_heads": 64,
|
||||
"addition_time_embed_dim": null,
|
||||
"attention_head_dim": 8,
|
||||
"block_out_channels": [
|
||||
320,
|
||||
640,
|
||||
1280,
|
||||
1280
|
||||
],
|
||||
"brushnet_conditioning_channel_order": "rgb",
|
||||
"class_embed_type": null,
|
||||
"conditioning_channels": 5,
|
||||
"conditioning_embedding_out_channels": [
|
||||
16,
|
||||
32,
|
||||
96,
|
||||
256
|
||||
],
|
||||
"cross_attention_dim": 768,
|
||||
"down_block_types": [
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D"
|
||||
],
|
||||
"downsample_padding": 1,
|
||||
"encoder_hid_dim": null,
|
||||
"encoder_hid_dim_type": null,
|
||||
"flip_sin_to_cos": true,
|
||||
"freq_shift": 0,
|
||||
"global_pool_conditions": false,
|
||||
"in_channels": 4,
|
||||
"layers_per_block": 2,
|
||||
"mid_block_scale_factor": 1,
|
||||
"mid_block_type": "UNetMidBlock2DCrossAttn",
|
||||
"norm_eps": 1e-05,
|
||||
"norm_num_groups": 32,
|
||||
"num_attention_heads": null,
|
||||
"num_class_embeds": null,
|
||||
"only_cross_attention": false,
|
||||
"projection_class_embeddings_input_dim": null,
|
||||
"resnet_time_scale_shift": "default",
|
||||
"transformer_layers_per_block": 1,
|
||||
"up_block_types": [
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D"
|
||||
],
|
||||
"upcast_attention": false,
|
||||
"use_linear_projection": false
|
||||
}
|
||||
1688
custom_nodes/ComfyUI-Easy-Use/py/modules/brushnet/model.py
Normal file
1688
custom_nodes/ComfyUI-Easy-Use/py/modules/brushnet/model.py
Normal file
File diff suppressed because it is too large
Load Diff
137
custom_nodes/ComfyUI-Easy-Use/py/modules/brushnet/model_patch.py
Normal file
137
custom_nodes/ComfyUI-Easy-Use/py/modules/brushnet/model_patch.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import torch
|
||||
import comfy
|
||||
|
||||
# Check and add 'model_patch' to model.model_options['transformer_options']
|
||||
def add_model_patch_option(model):
    """Return ``model.model_options['transformer_options']`` with a ``'model_patch'`` entry.

    Missing ``'transformer_options'`` and ``'model_patch'`` dictionaries are
    created in place; existing ones are left untouched.

    Args:
        model: Object with a ``model_options`` dictionary.

    Returns:
        dict: The ``transformer_options`` dictionary stored on the model.
    """
    transformer_options = model.model_options.setdefault('transformer_options', {})
    transformer_options.setdefault("model_patch", {})
    return transformer_options
|
||||
|
||||
|
||||
# Patch model with model_function_wrapper
|
||||
def patch_model_function_wrapper(model, forward_patch, remove=False):
    """Install the BrushNet model-function wrapper and register a forward patch.

    Sets a wrapper on ``model`` via ``set_model_unet_function_wrapper`` that
    runs every callable stored in ``model_patch['forward']`` before delegating
    to the real ``apply_model`` method. It also records model information in
    ``model_patch`` and replaces ``comfy.samplers.sample`` and
    ``openaimodel.apply_control`` with the BrushNet versions if that has not
    been done already.

    Args:
        model: The model patcher object to modify.
        forward_patch: Callable invoked as
            ``forward_patch(unet, xc, t, transformer_options, control)``.
        remove: When true, ``forward_patch`` is removed from the registered
            patches (if present) instead of being appended.

    Raises:
        Exception: If the model config is neither SD15 nor SDXL.
    """

    def brushnet_model_function_wrapper(apply_model_method, options_dict):
        cond = options_dict['c']
        to = cond['transformer_options']
        control = cond.get('control')

        x = options_dict['input']
        timestep = options_dict['timestep']

        # check if there are patches to execute
        if 'model_patch' not in to or 'forward' not in to['model_patch']:
            return apply_model_method(x, timestep, **options_dict['c'])

        mp = to['model_patch']
        unet = mp['unet']

        # The current step is the index of the scheduled sigma closest to
        # the sigma of this call.
        all_sigmas = mp['all_sigmas']
        sigma = to['sigmas'][0].item()
        mp['step'] = torch.argmin((all_sigmas - sigma).abs()).item()
        mp['total_steps'] = all_sigmas.shape[0] - 1

        # comfy.model_base.apply_model
        model_sampling = model.model.model_sampling
        xc = model_sampling.calculate_input(timestep, x)
        c_concat = cond.get('c_concat')
        if c_concat is not None:
            xc = torch.cat([xc, c_concat], dim=1)
        t = model_sampling.timestep(timestep).float()

        # execute all patches
        for method in mp['forward']:
            method(unet, xc, t, to, control)

        return apply_model_method(x, timestep, **options_dict['c'])

    existing_wrapper = model.model_options.get("model_function_wrapper")
    if existing_wrapper:
        print('BrushNet is going to replace existing model_function_wrapper:',
              existing_wrapper)
    model.set_model_unet_function_wrapper(brushnet_model_function_wrapper)

    mp = add_model_patch_option(model)['model_patch']

    model_config = model.model.model_config
    if isinstance(model_config, comfy.supported_models.SD15):
        mp['SDXL'] = False
    elif isinstance(model_config, comfy.supported_models.SDXL):
        mp['SDXL'] = True
    else:
        print('Base model type: ', type(model_config))
        raise Exception("Unsupported model type: ", type(model_config))

    patches = mp.setdefault('forward', [])
    if not remove:
        patches.append(forward_patch)
    elif forward_patch in patches:
        patches.remove(forward_patch)

    mp['unet'] = model.model.diffusion_model
    mp['step'] = 0
    mp['total_steps'] = 1

    # apply patches to code
    # The replacement functions carry 'BrushNet' in their docstrings, which
    # is how an already-applied patch is recognised.
    sample_doc = comfy.samplers.sample.__doc__
    if sample_doc is None or 'BrushNet' not in sample_doc:
        comfy.samplers.original_sample = comfy.samplers.sample
        comfy.samplers.sample = modified_sample

    openaimodel = comfy.ldm.modules.diffusionmodules.openaimodel
    control_doc = openaimodel.apply_control.__doc__
    if control_doc is None or 'BrushNet' not in control_doc:
        openaimodel.original_apply_control = openaimodel.apply_control
        openaimodel.apply_control = modified_apply_control
|
||||
|
||||
|
||||
# Model needs current step number and cfg at inference step. It is possible to write a custom KSampler but I'd like to use ComfyUI's one.
|
||||
# The first versions had modified_common_ksampler, but it broke custom KSampler nodes
|
||||
def modified_sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={},
                    latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
    ''' Modified by BrushNet nodes'''
    # NOTE: the docstring above doubles as a runtime marker -
    # patch_model_function_wrapper looks for 'BrushNet' in
    # comfy.samplers.sample.__doc__ to decide whether the sampler has already
    # been replaced, so it must keep containing that word.
    # `device` and `model_options` are accepted but not used in this body.
    cfg_guider = comfy.samplers.CFGGuider(model)
    cfg_guider.set_conds(positive, negative)
    cfg_guider.set_cfg(cfg)

    ### Modified part ######################################################################
    # Store the full sigma schedule in model_patch so that the model function
    # wrapper can derive the current step and the total step count from it.
    to = add_model_patch_option(model)
    to['model_patch']['all_sigmas'] = sigmas
    #######################################################################################

    return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
|
||||
# To use Controlnet with RAUNet it is much easier to modify apply_control a little
|
||||
def modified_apply_control(h, control, name):
    '''Modified by BrushNet nodes

    Pop the last entry of ``control[name]`` and add it in place to ``h``,
    resizing it to the spatial size of ``h`` first when the two differ.

    NOTE: the word 'BrushNet' in this docstring is a runtime marker checked by
    patch_model_function_wrapper; do not remove it.

    Args:
        h: Tensor to modify; dimensions 2 and 3 are treated as its spatial size.
        control: Mapping from block name to a list of control tensors, or None.
        name: Key of the list to take the control tensor from.

    Returns:
        ``h`` (the same object), with the control added when it could be applied.
    '''
    if control is None or name not in control or len(control[name]) == 0:
        return h

    ctrl = control[name].pop()
    if ctrl is None:
        return h

    if h.shape[2] != ctrl.shape[2] or h.shape[3] != ctrl.shape[3]:
        ctrl = torch.nn.functional.interpolate(ctrl, size=(h.shape[2], h.shape[3]), mode='bicubic').to(
            h.dtype).to(h.device)
    try:
        h += ctrl
    except Exception:
        # Best effort: an incompatible control is reported and skipped.
        # (The original called the non-existent `print.warning`, which turned
        # this warning into an AttributeError.)
        print("warning control could not be applied {} {}".format(h.shape, ctrl.shape))
    return h
|
||||
|
||||
def add_model_patch(model):
    # Refresh the BrushNet bookkeeping stored under
    # model.model_options['transformer_options']['model_patch'].
    # Raises Exception when the model config is neither SD15 nor SDXL.
    to = add_model_patch_option(model)
    mp = to['model_patch']
    # Only models that already carry a 'brushnet' entry are touched.
    if "brushnet" in mp:
        if isinstance(model.model.model_config, comfy.supported_models.SD15):
            mp['SDXL'] = False
        elif isinstance(model.model.model_config, comfy.supported_models.SDXL):
            mp['SDXL'] = True
        else:
            print('Base model type: ', type(model.model.model_config))
            raise Exception("Unsupported model type: ", type(model.model.model_config))

        # NOTE(review): these three assignments are assumed to belong to the
        # `"brushnet" in mp` branch - confirm the nesting against the project file.
        mp['unet'] = model.model.diffusion_model
        mp['step'] = 0
        mp['total_steps'] = 1
|
||||
@@ -0,0 +1,467 @@
|
||||
import copy
|
||||
import random
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import CLIPTokenizer
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
|
||||
class TokenizerWrapper:
    """Wrap a CLIPTokenizer and add multi-vector placeholder tokens.

    A placeholder token is a keyword that is expanded into one or more real
    tokenizer tokens before tokenization (and collapsed back on decode).
    Attributes not defined on the wrapper are looked up on the wrapped
    tokenizer. Adapted from the diffusers textual-inversion loader.

    Args:
        tokenizer (CLIPTokenizer): The tokenizer to wrap.
    """

    def __init__(self, tokenizer: "CLIPTokenizer"):
        self.wrapped = tokenizer
        # Maps each placeholder token to the list of real tokens it expands to.
        self.token_map = {}

    def __getattr__(self, name: str) -> Any:
        """Delegate unknown attributes to the wrapped tokenizer."""
        # __getattr__ only runs after normal lookup failed, so 'wrapped' is
        # read straight from the instance dict: going through `self.wrapped`
        # would recurse forever when the attribute is not set yet.
        wrapped = self.__dict__.get("wrapped")
        try:
            if wrapped is None:
                raise AttributeError(name)
            return getattr(wrapped, name)
        except AttributeError:
            raise AttributeError(
                f"'{name}' cannot be found in both "
                f"'{self.__class__.__name__}' and "
                f"'{self.__class__.__name__}.tokenizer'."
            ) from None

    def try_adding_tokens(self, tokens: Union[str, List[str]], *args, **kwargs):
        """Attempt to add tokens to the tokenizer.

        Args:
            tokens (Union[str, List[str]]): The tokens to be added.
            *args, **kwargs: The arguments for `self.wrapped.add_tokens`.

        Raises:
            AssertionError: If the tokenizer reports that nothing was added.
        """
        num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs)
        assert num_added_tokens != 0, (
            f"The tokenizer already contains the token {tokens}. Please pass "
            "a different `placeholder_token` that is not already in the "
            "tokenizer."
        )

    def get_token_info(self, token: str) -> dict:
        """Get the information of a token, including its start and end index in
        the current tokenizer.

        Args:
            token (str): The token to be queried.

        Returns:
            dict: ``{"name": token, "start": start, "end": end}`` where
            ``start`` is the second id and ``end`` is the second-to-last id
            plus one of the tokenized ``token``.
        """
        token_ids = self(token).input_ids
        start, end = token_ids[1], token_ids[-2] + 1
        return {"name": token, "start": start, "end": end}

    def add_placeholder_token(self, placeholder_token: str, *args, num_vec_per_token: int = 1, **kwargs):
        """Add placeholder tokens to the tokenizer.

        Args:
            placeholder_token (str): The placeholder token to be added.
            num_vec_per_token (int, optional): The number of vectors of
                the added placeholder token.
            *args, **kwargs: The arguments for `self.wrapped.add_tokens`.

        Raises:
            ValueError: If an already registered placeholder is a substring of
                ``placeholder_token``.
        """
        # Check for conflicts first so a rejected placeholder does not leave
        # half of its tokens registered in the wrapped tokenizer.
        for token in self.token_map:
            if token in placeholder_token:
                raise ValueError(
                    f"The tokenizer already has placeholder token {token} "
                    f"that can get confused with {placeholder_token} "
                    "keep placeholder tokens independent"
                )

        output = []
        if num_vec_per_token == 1:
            self.try_adding_tokens(placeholder_token, *args, **kwargs)
            output.append(placeholder_token)
        else:
            for i in range(num_vec_per_token):
                ith_token = placeholder_token + f"_{i}"
                self.try_adding_tokens(ith_token, *args, **kwargs)
                output.append(ith_token)

        self.token_map[placeholder_token] = output

    def replace_placeholder_tokens_in_text(
        self, text: Union[str, List[str]], vector_shuffle: bool = False, prop_tokens_to_load: float = 1.0
    ) -> Union[str, List[str]]:
        """Replace the keywords in text with placeholder tokens. This function
        will be called in `self.__call__` and `self.encode`.

        Args:
            text (Union[str, List[str]]): The text to be processed.
            vector_shuffle (bool, optional): Whether to shuffle the vectors.
                Defaults to False.
            prop_tokens_to_load (float, optional): The proportion of tokens to
                be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.

        Returns:
            Union[str, List[str]]: The processed text.
        """
        if isinstance(text, list):
            # Both options are forwarded so lists behave like single strings.
            return [
                self.replace_placeholder_tokens_in_text(
                    item, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
                )
                for item in text
            ]

        for placeholder_token, mapped_tokens in self.token_map.items():
            if placeholder_token not in text:
                continue
            tokens = mapped_tokens[: 1 + int(len(mapped_tokens) * prop_tokens_to_load)]
            if vector_shuffle:
                tokens = copy.copy(tokens)
                random.shuffle(tokens)
            text = text.replace(placeholder_token, " ".join(tokens))
        return text

    def replace_text_with_placeholder_tokens(self, text: Union[str, List[str]]) -> Union[str, List[str]]:
        """Replace the placeholder tokens in text with the original keywords.
        This function will be called in `self.decode`.

        Args:
            text (Union[str, List[str]]): The text to be processed.

        Returns:
            Union[str, List[str]]: The processed text.
        """
        if isinstance(text, list):
            return [self.replace_text_with_placeholder_tokens(item) for item in text]

        for placeholder_token, tokens in self.token_map.items():
            merged_tokens = " ".join(tokens)
            if merged_tokens in text:
                text = text.replace(merged_tokens, placeholder_token)
        return text

    def __call__(
        self,
        text: Union[str, List[str]],
        *args,
        vector_shuffle: bool = False,
        prop_tokens_to_load: float = 1.0,
        **kwargs,
    ):
        """The call function of the wrapper.

        Args:
            text (Union[str, List[str]]): The text to be tokenized.
            vector_shuffle (bool, optional): Whether to shuffle the vectors.
                Defaults to False.
            prop_tokens_to_load (float, optional): The proportion of tokens to
                be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0
            *args, **kwargs: The arguments for `self.wrapped.__call__`.
        """
        replaced_text = self.replace_placeholder_tokens_in_text(
            text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
        )

        return self.wrapped.__call__(replaced_text, *args, **kwargs)

    def encode(self, text: Union[str, List[str]], *args, **kwargs):
        """Encode the passed text to token index.

        Args:
            text (Union[str, List[str]]): The text to be encode.
            *args, **kwargs: The arguments for `self.wrapped.__call__`.
        """
        replaced_text = self.replace_placeholder_tokens_in_text(text)
        return self.wrapped(replaced_text, *args, **kwargs)

    def decode(self, token_ids, return_raw: bool = False, *args, **kwargs) -> Union[str, List[str]]:
        """Decode the token index to text.

        Args:
            token_ids: The token index to be decoded.
            return_raw: Whether keep the placeholder token in the text.
                Defaults to False.
            *args, **kwargs: The arguments for `self.wrapped.decode`.

        Returns:
            Union[str, List[str]]: The decoded text.
        """
        text = self.wrapped.decode(token_ids, *args, **kwargs)
        if return_raw:
            return text
        return self.replace_text_with_placeholder_tokens(text)

    def __repr__(self):
        """The representation of the wrapper."""
        # The previous implementation read `_module_cls`, `_module_name` and
        # `_from_pretrained`, none of which this class defines, so repr()
        # always raised AttributeError.
        return (
            f"{self.__class__.__name__}("
            f"wrapped={self.__dict__.get('wrapped')!r}, "
            f"token_map={self.__dict__.get('token_map')!r})"
        )
|
||||
|
||||
|
||||
class EmbeddingLayerWithFixes(nn.Module):
    """Embedding layer wrapper that splices in external embeddings.

    Token ids at or above the size of the wrapped embedding table are looked
    up as id 0 and then overwritten with the matching external embedding. The
    design is inspired by the AUTOMATIC1111 stable-diffusion-webui hijack
    module.

    Args:
        wrapped (nn.Embedding): The embedding layer to be wrapped.
        external_embeddings (Union[dict, List[dict]], optional): The external
            embeddings added to this layer. Defaults to None.
    """

    def __init__(self, wrapped: nn.Embedding, external_embeddings: Optional[Union[dict, List[dict]]] = None):
        super().__init__()
        self.wrapped = wrapped
        self.num_embeddings = wrapped.weight.shape[0]

        self.external_embeddings = []
        # Must exist before add_embeddings runs: trainable embeddings are
        # registered into it there.
        self.trainable_embeddings = nn.ParameterDict()
        if external_embeddings:
            self.add_embeddings(external_embeddings)

    @property
    def weight(self):
        """Get the weight of wrapped embedding layer."""
        return self.wrapped.weight

    def check_duplicate_names(self, embeddings: List[dict]):
        """Check whether duplicate names exist in list of 'external
        embeddings'.

        Args:
            embeddings (List[dict]): A list of embedding to be check.

        Raises:
            AssertionError: If two embeddings share a name.
        """
        names = [emb["name"] for emb in embeddings]
        assert len(names) == len(set(names)), (
            "Found duplicated names in 'external_embeddings'. Name list: " f"'{names}'"
        )

    def check_ids_overlap(self, embeddings):
        """Check whether overlap exist in token ids of 'external_embeddings'.

        Args:
            embeddings (List[dict]): A list of embedding to be check.

        Raises:
            AssertionError: If two id ranges overlap.
        """
        ids_range = [[emb["start"], emb["end"], emb["name"]] for emb in embeddings]
        ids_range.sort()  # sort by 'start'
        # After sorting, each range must end before the next one starts.
        for current, following in zip(ids_range, ids_range[1:]):
            name1, name2 = current[-1], following[-1]
            assert current[1] <= following[0], (
                f"Found ids overlapping between embeddings '{name1}' " f"and '{name2}'."
            )

    def add_embeddings(self, embeddings: Optional[Union[dict, List[dict]]]):
        """Add external embeddings to this layer.

        Args:
            embeddings (Union[dict, list[dict]]): The external embeddings to
                be added. Each dict must contain the following 4 fields: 'name'
                (the name of this embedding), 'embedding' (the embedding
                tensor), 'start' (the start token id of this embedding), 'end'
                (the end token id of this embedding). For example:
                `{name: NAME, start: START, end: END, embedding: torch.Tensor}`
                An optional truthy 'trainable' field turns the tensor into an
                `nn.Parameter` registered on this module. ``None`` is a no-op.

        Raises:
            AssertionError: If names are duplicated or id ranges overlap.
        """
        if embeddings is None:
            return
        if isinstance(embeddings, dict):
            embeddings = [embeddings]

        # Validate the combined list before committing, so a rejected batch
        # does not stay registered.
        combined = self.external_embeddings + list(embeddings)
        self.check_duplicate_names(combined)
        self.check_ids_overlap(combined)
        self.external_embeddings.extend(embeddings)

        # set for trainable
        added_trainable_emb_info = []
        for embedding in embeddings:
            if embedding.get("trainable", False):
                name = embedding["name"]
                embedding["embedding"] = torch.nn.Parameter(embedding["embedding"])
                self.trainable_embeddings[name] = embedding["embedding"]
                added_trainable_emb_info.append(name)

        added_emb_info = ", ".join(emb["name"] for emb in embeddings)
        print(f"Successfully add external embeddings: {added_emb_info}.", "current")

        if added_trainable_emb_info:
            added_trainable_emb_info = ", ".join(added_trainable_emb_info)
            print("Successfully add trainable external embeddings: " f"{added_trainable_emb_info}", "current")

    def replace_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Replace external input ids to 0.

        Args:
            input_ids (torch.Tensor): The input ids to be replaced.

        Returns:
            torch.Tensor: A copy of ``input_ids`` where every id at or above
            ``self.num_embeddings`` is set to 0.
        """
        input_ids_fwd = input_ids.clone()
        input_ids_fwd[input_ids_fwd >= self.num_embeddings] = 0
        return input_ids_fwd

    def replace_embeddings(
        self, input_ids: torch.Tensor, embedding: torch.Tensor, external_embedding: dict
    ) -> torch.Tensor:
        """Replace external embedding to the embedding layer. Noted that, in
        this function we use `torch.cat` to avoid inplace modification.

        Args:
            input_ids (torch.Tensor): The original token ids. Shape like
                [LENGTH, ].
            embedding (torch.Tensor): The embedding of token ids after
                `replace_input_ids` function.
            external_embedding (dict): The external embedding to be replaced.

        Returns:
            torch.Tensor: The replaced embedding.

        Raises:
            AssertionError: If a start id is not followed by the full expected
                id run ``start .. end - 1``.
        """
        name = external_embedding["name"]
        start = external_embedding["start"]
        end = external_embedding["end"]
        span = end - start
        target_ids_to_replace = list(range(start, end))
        ext_emb = external_embedding["embedding"].to(embedding.device)

        # do not need to replace
        if not (input_ids == start).any():
            return embedding

        # start replace
        new_embedding = []
        seq_len = len(input_ids)
        s_idx, e_idx = 0, 0
        while e_idx < seq_len:
            if input_ids[e_idx] != start:
                e_idx += 1
                continue

            if e_idx > s_idx:
                # add embedding do not need to replace
                new_embedding.append(embedding[s_idx:e_idx])

            # check if the next embedding need to replace is valid
            actually_ids_to_replace = [int(i) for i in input_ids[e_idx: e_idx + span]]
            assert actually_ids_to_replace == target_ids_to_replace, (
                f"Invalid 'input_ids' in position: {s_idx} to {e_idx}. "
                f"Expect '{target_ids_to_replace}' for embedding "
                f"'{name}' but found '{actually_ids_to_replace}'."
            )

            new_embedding.append(ext_emb)

            # Resume scanning right after the replaced run. The position that
            # immediately follows must be examined too, otherwise a second,
            # adjacent occurrence of the placeholder is left unreplaced.
            s_idx = e_idx + span
            e_idx = s_idx

        if s_idx < seq_len:
            new_embedding.append(embedding[s_idx:])

        return torch.cat(new_embedding, dim=0)

    def forward(self, input_ids: torch.Tensor, external_embeddings: Optional[List[dict]] = None, out_dtype = None):
        """The forward function.

        Args:
            input_ids (torch.Tensor): The token ids shape like [bz, LENGTH] or
                [LENGTH, ].
            external_embeddings (Optional[List[dict]]): The external
                embeddings. If not passed, only `self.external_embeddings`
                will be used. Defaults to None.
            out_dtype: Dtype of the returned tensor. When None the result is
                returned without an explicit cast.

        Returns:
            torch.Tensor: The embeddings, always with a leading batch dimension.
        """
        assert input_ids.ndim in [1, 2]
        if input_ids.ndim == 1:
            input_ids = input_ids.unsqueeze(0)

        if external_embeddings is None and not self.external_embeddings:
            # NOTE(review): this path assumes the wrapped layer accepts an
            # `out_dtype` keyword (a plain nn.Embedding does not) - confirm.
            return self.wrapped(input_ids, out_dtype=out_dtype)

        input_ids_fwd = self.replace_input_ids(input_ids)
        inputs_embeds = self.wrapped(input_ids_fwd)

        if external_embeddings is None:
            external_embeddings = []
        elif isinstance(external_embeddings, dict):
            external_embeddings = [external_embeddings]
        embeddings = self.external_embeddings + external_embeddings

        vecs = []
        for input_id, embedding in zip(input_ids, inputs_embeds):
            new_embedding = embedding
            for external_embedding in embeddings:
                new_embedding = self.replace_embeddings(input_id, new_embedding, external_embedding)
            vecs.append(new_embedding)

        stacked = torch.stack(vecs)
        if out_dtype is None:
            return stacked
        return stacked.to(out_dtype)
|
||||
|
||||
|
||||
def add_tokens(
    tokenizer, text_encoder, placeholder_tokens: list, initialize_tokens: Optional[list] = None,
    num_vectors_per_token: int = 1
):
    """Add placeholder tokens for training and attach initial embeddings.

    Each placeholder is registered on the tokenizer, the text encoder's
    token embedding is wrapped in `EmbeddingLayerWithFixes`, and one
    trainable embedding per placeholder is handed to the wrapper through
    `add_embeddings`.

    Args:
        tokenizer: Must provide `add_placeholder_token`, `get_token_info`
            and be callable on a string, returning an object with
            `input_ids`.
        text_encoder: Must expose `text_model.embeddings.token_embedding`;
            that attribute is replaced in place by the wrapper.
        placeholder_tokens (list): The placeholder strings to add.
        initialize_tokens (Optional[list]): One existing token per
            placeholder whose embedding row seeds the new embedding. If
            None, embeddings are drawn uniformly from [-0.25, 0.25).
            Defaults to None.
        num_vectors_per_token (int): Number of embedding rows created for
            each placeholder. Defaults to 1.

    Raises:
        AssertionError: If `initialize_tokens` is given with a different
            length than `placeholder_tokens`, or if the text encoder has
            no token embedding layer.

    # TODO: support add tokens as dict, then we can load pretrained tokens.
    """
    if initialize_tokens is not None:
        assert len(initialize_tokens) == len(
            placeholder_tokens
        ), "placeholder_token should be the same length as initialize_token"
    for placeholder in placeholder_tokens:
        tokenizer.add_placeholder_token(placeholder, num_vec_per_token=num_vectors_per_token)

    # Validate the raw layer BEFORE wrapping it: the wrapper object is never
    # None, so checking after the wrap could not detect a missing layer.
    raw_embedding_layer = text_encoder.text_model.embeddings.token_embedding
    assert raw_embedding_layer is not None, (
        "Do not support get embedding layer for current text encoder. " "Please check your configuration."
    )
    text_encoder.text_model.embeddings.token_embedding = EmbeddingLayerWithFixes(raw_embedding_layer)
    # Read the attribute back so we operate on exactly what the encoder holds.
    embedding_layer = text_encoder.text_model.embeddings.token_embedding

    initialize_embedding = []
    if initialize_tokens is not None:
        for init_token in initialize_tokens:
            # Index 1 presumably skips a leading special token - TODO confirm
            # against the tokenizer in use.
            init_id = tokenizer(init_token).input_ids[1]
            seed_row = embedding_layer.weight[init_id]
            initialize_embedding.append(seed_row[None, ...].repeat(num_vectors_per_token, 1))
    elif placeholder_tokens:
        # The reference token is only used to read the embedding width, so
        # look it up once instead of once per placeholder. Assumes calling
        # the tokenizer has no side effects.
        init_id = tokenizer("a").input_ids[1]
        len_emb = embedding_layer.weight[init_id].shape[0]
        initialize_embedding = [
            (torch.rand(num_vectors_per_token, len_emb) - 0.5) / 2.0
            for _ in placeholder_tokens
        ]

    token_info_all = []
    for placeholder, init_embedding in zip(placeholder_tokens, initialize_embedding):
        token_info = tokenizer.get_token_info(placeholder)
        token_info["embedding"] = init_embedding
        token_info["trainable"] = True
        token_info_all.append(token_info)
    embedding_layer.add_embeddings(token_info_all)
|
||||
3908
custom_nodes/ComfyUI-Easy-Use/py/modules/brushnet/unet_2d_blocks.py
Normal file
3908
custom_nodes/ComfyUI-Easy-Use/py/modules/brushnet/unet_2d_blocks.py
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user