Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled

Includes 30 custom nodes committed directly, 7 Civitai-exclusive
loras stored via Git LFS, and a setup script that installs all
dependencies and downloads HuggingFace-hosted models on vast.ai.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-09 00:55:26 +00:00
parent 2b70ab9ad0
commit f09734b0ee
2274 changed files with 748556 additions and 3 deletions

View File

@@ -0,0 +1,822 @@
#credit to nullquant for this module
#from https://github.com/nullquant/ComfyUI-BrushNet
import os
import types
import torch
try:
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
except:
init_empty_weights, load_checkpoint_and_dispatch = None, None
import comfy
try:
from .model import BrushNetModel, PowerPaintModel
from .model_patch import add_model_patch_option, patch_model_function_wrapper
from .powerpaint_utils import TokenizerWrapper, add_tokens
except:
BrushNetModel, PowerPaintModel = None, None
add_model_patch_option, patch_model_function_wrapper = None, None
TokenizerWrapper, add_tokens = None, None
# Directory that contains this module; the bundled config files are resolved relative to it.
cwd_path = os.path.dirname(os.path.realpath(__file__))
# JSON config files handed to `load_config` in `BrushNet.load_brushnet_model`
# (SD1.5 BrushNet, SDXL BrushNet and PowerPaint respectively).
brushnet_config_file = os.path.join(cwd_path, 'config', 'brushnet.json')
brushnet_xl_config_file = os.path.join(cwd_path, 'config', 'brushnet_xl.json')
powerpaint_config_file = os.path.join(cwd_path, 'config', 'powerpaint.json')
# Fallback latent scaling factors, used only when the base model's config does not
# expose `latent_format.scale_factor`.
sd15_scaling_factor = 0.18215
sdxl_scaling_factor = 0.13025
# Model classes targeted by the model-unloading code in the update methods; that code is
# currently commented out, so within this view the list is not read anywhere.
ModelsToUnload = [comfy.sd1_clip.SD1ClipModel, comfy.ldm.models.autoencoder.AutoencoderKL]
class BrushNet:
# Check models compatibility
def check_compatibilty(self, model, brushnet):
    """Verify that the base model and the loaded BrushNet checkpoint match.

    Args:
        model: model wrapper; ``model.model.model_config`` is inspected.
        brushnet: dict with (at least) the truthy/falsy flags ``"SDXL"`` and
            ``"PP"`` describing the loaded checkpoint.

    Returns:
        Tuple ``(is_SDXL, is_PP)`` of booleans.

    Raises:
        Exception: if the base model and checkpoint families differ, or if
            the base model is neither SD1.5 nor SDXL.
    """
    config = model.model.model_config

    if isinstance(config, comfy.supported_models.SD15):
        print('Base model type: SD1.5')
        if brushnet["SDXL"]:
            raise Exception("Base model is SD15, but BrushNet is SDXL type")
        # The PowerPaint flag is only honoured for SD1.5 base models.
        return (False, True if brushnet["PP"] else False)

    if isinstance(config, comfy.supported_models.SDXL):
        print('Base model type: SDXL')
        if not brushnet["SDXL"]:
            raise Exception("Base model is SDXL, but BrushNet is SD15 type")
        return (True, False)

    print('Base model type: ', type(config))
    raise Exception("Unsupported model type: " + str(type(config)))
def check_image_mask(self, image, mask, name):
    """Normalise an image/mask pair so both are batched and batch sizes agree.

    Args:
        image: tensor expected as ``[B, H, W, C]``; a missing batch axis is added.
        mask: tensor expected as ``[B, H, W]``; a 4-D mask is reduced to its
            first channel and a 2-D mask gets a batch axis.
        name: label used as a prefix for the diagnostic messages.

    Returns:
        Tuple ``(image, mask)`` where ``mask.shape[0] == image.shape[0]``.
    """
    if len(image.shape) < 4:
        # image tensor shape should be [B, H, W, C], but batch somehow is missing
        image = image[None, :, :, :]

    if len(mask.shape) > 3:
        # mask should be [B, H, W] but arrived as [B, H, W, C] (probably an image):
        # keep the first (red) channel only
        mask = mask[:, :, :, 0]
    elif len(mask.shape) < 3:
        # mask tensor shape should be [B, H, W] but batch somehow is missing
        mask = mask[None, :, :]

    image_count = image.shape[0]
    mask_count = mask.shape[0]

    if image_count > mask_count:
        print(name, "gets batch of images (%d) but only %d masks" % (image_count, mask_count))
        if mask_count == 1:
            print(name, "will copy the mask to fill batch")
            mask = torch.cat([mask] * image_count, dim=0)
        else:
            print(name, "will add empty masks to fill batch")
            # Create the padding with the dtype/device of the real masks; otherwise
            # torch.cat fails across devices and silently promotes the dtype.
            empty_mask = torch.zeros(
                [image_count - mask_count, mask.shape[1], mask.shape[2]],
                dtype=mask.dtype,
                device=mask.device,
            )
            mask = torch.cat([mask, empty_mask], dim=0)
    elif image_count < mask_count:
        print(name, "gets batch of images (%d) but too many (%d) masks" % (image_count, mask_count))
        mask = mask[:image_count, :, :]

    return (image, mask)
# Prepare image and mask
def prepare_image(self, image, mask):
    """Build the masked image and the binarised mask used for conditioning.

    Args:
        image: image tensor, normalised via ``check_image_mask``.
        mask: mask tensor, normalised via ``check_image_mask``.

    Returns:
        Tuple ``(masked_image, mask)``: the image with masked pixels zeroed
        and the mask rounded to 0/1.

    Raises:
        Exception: if image and mask differ in height or width.
    """
    image, mask = self.check_image_mask(image, mask, 'BrushNet')
    print("BrushNet image.shape =", image.shape, "mask.shape =", mask.shape)

    same_height = mask.shape[1] == image.shape[1]
    same_width = mask.shape[2] == image.shape[2]
    if not (same_width and same_height):
        raise Exception("Image and mask should be the same size")

    # Rounding the mask was suggested by inferno46n2
    # (https://github.com/nullquant/ComfyUI-BrushNet/issues/64)
    binary_mask = mask.round()
    # Add a trailing channel axis so the mask broadcasts over the image channels.
    keep = 1.0 - binary_mask[:, :, :, None]
    return (image * keep, binary_mask)
# Get origin of the mask
def cut_with_mask(self, mask, width, height):
    """Compute a crop window (at most ``width`` x ``height``) that contains the mask.

    The window is centred on the bounding box of the pixels equal to 1 and
    then shifted so it stays inside the mask tensor. If no pixel is set, the
    window is centred on the tensor itself.

    Args:
        mask: 2-D tensor ``[H, W]``.
        width: requested crop width in pixels.
        height: requested crop height in pixels.

    Returns:
        Tuple ``(x0, y0, w, h)`` of ints: crop origin and crop size. ``w``/``h``
        shrink to the tensor size when the tensor is smaller than requested.

    Raises:
        Exception: if the mask bounding box does not fit into ``width`` x ``height``.
    """
    iy, ix = (mask == 1).nonzero(as_tuple=True)
    h0, w0 = mask.shape

    if iy.numel() == 0:
        x_c = w0 / 2.0
        y_c = h0 / 2.0
    else:
        x_min = ix.min().item()
        x_max = ix.max().item()
        y_min = iy.min().item()
        y_max = iy.max().item()

        # Pixel ranges are inclusive, so the extent is max - min + 1.
        if x_max - x_min + 1 > width or y_max - y_min + 1 > height:
            raise Exception("Mask is bigger than provided dimensions")

        # Centre of the inclusive range [min, max] measured on pixel edges; using
        # (min + max) / 2 is half a pixel short and, after int() truncation below,
        # could leave the last mask column/row outside of the crop.
        x_c = (x_min + x_max + 1) / 2.0
        y_c = (y_min + y_max + 1) / 2.0

    width2 = width / 2.0
    height2 = height / 2.0

    if w0 <= width:
        x0 = 0
        w = w0
    else:
        x0 = max(0, x_c - width2)
        w = width
        if x0 + width > w0:
            x0 = w0 - width

    if h0 <= height:
        y0 = 0
        h = h0
    else:
        y0 = max(0, y_c - height2)
        h = height
        if y0 + height > h0:
            y0 = h0 - height

    return (int(x0), int(y0), int(w), int(h))
# Prepare conditioning_latents
@torch.inference_mode()
def get_image_latents(self, masked_image, mask, vae, scaling_factor):
    """Encode the masked image and resize the inverted mask to latent size.

    Args:
        masked_image: image tensor indexed as ``[B, H, W, C]``; only the first
            three channels are encoded.
        mask: mask tensor indexed as ``[B, H, W]``.
        vae: object exposing ``device`` and ``encode``.
        scaling_factor: multiplier applied to the encoded latents.

    Returns:
        List ``[image_latents, interpolated_mask]``; the mask is ``1 - mask``
        resized to the last two dimensions of ``image_latents``.
    """
    rgb = masked_image.to(vae.device)[:, :, :, :3]
    image_latents = vae.encode(rgb) * scaling_factor

    # Invert the mask and insert a channel axis before resizing it.
    inverted_mask = 1. - mask[:, None, :, :]
    latent_h = image_latents.shape[-2]
    latent_w = image_latents.shape[-1]
    interpolated_mask = torch.nn.functional.interpolate(inverted_mask, size=(latent_h, latent_w))
    interpolated_mask = interpolated_mask.to(image_latents.device)

    print('BrushNet CL: image_latents shape =', image_latents.shape, 'interpolated_mask shape =',
          interpolated_mask.shape)
    return [image_latents, interpolated_mask]
def brushnet_blocks(self, sd):
    """Count state-dict keys belonging to each BrushNet block group.

    Args:
        sd: mapping (or other sized iterable) of parameter names.

    Returns:
        Tuple ``(down, mid, up, total)``: number of keys containing
        ``brushnet_down_block``, ``brushnet_mid_block``, ``brushnet_up_block``,
        and the total number of keys.
    """
    markers = ('brushnet_down_block', 'brushnet_mid_block', 'brushnet_up_block')
    counts = dict.fromkeys(markers, 0)
    for key in sd:
        # A key may match more than one marker, so every marker is tested.
        for marker in markers:
            if marker in key:
                counts[marker] += 1
    return (counts[markers[0]], counts[markers[1]], counts[markers[2]], len(sd))
def get_model_type(self, brushnet_file):
    """Identify the checkpoint family from the key counts of its state dict.

    Args:
        brushnet_file: path of the checkpoint, loaded with
            ``comfy.utils.load_torch_file``.

    Returns:
        Tuple ``(is_SDXL, is_PP)``.

    Raises:
        Exception: if the block-key counts match no known layout.
    """
    sd = comfy.utils.load_torch_file(brushnet_file)
    down, mid, up, keys = self.brushnet_blocks(sd)
    del sd  # only the counts are needed from here on

    layout = (down, mid, up)
    if layout == (24, 2, 30):
        # SD1.5 layout: exactly 322 keys is treated as plain BrushNet,
        # any other key count as PowerPaint.
        is_SDXL = False
        is_PP = keys != 322
        print('PowerPaint model type: SD1.5' if is_PP else 'BrushNet model type: SD1.5')
    elif layout == (18, 2, 22):
        print('BrushNet model type: Loading SDXL')
        is_SDXL, is_PP = True, False
    else:
        raise Exception("Unknown BrushNet model")

    return is_SDXL, is_PP
def load_brushnet_model(self, brushnet_file, dtype='float16'):
    """Instantiate a BrushNet/PowerPaint model and load its checkpoint.

    The model is created from the bundled config inside accelerate's
    ``init_empty_weights`` context and the weights are then loaded with
    ``load_checkpoint_and_dispatch``.

    Args:
        brushnet_file: path of the checkpoint file.
        dtype: ``'float16'``, ``'bfloat16'`` or ``'float32'``; any other value
            falls back to ``torch.float64``.

    Returns:
        One-element tuple holding a dict with keys ``"brushnet"`` (the model),
        ``"SDXL"``, ``"PP"`` and ``"dtype"`` (the torch dtype).

    Raises:
        Exception: if the optional dependencies failed to import, or if the
            checkpoint type is unknown (raised by ``get_model_type``).
    """
    # The module-level imports fall back to None on failure; fail with a clear
    # message instead of "'NoneType' object is not callable" further down.
    if init_empty_weights is None or load_checkpoint_and_dispatch is None:
        raise Exception("BrushNet: the 'accelerate' package is required to load models")
    if BrushNetModel is None or PowerPaintModel is None:
        raise Exception("BrushNet: model classes could not be imported")

    is_SDXL, is_PP = self.get_model_type(brushnet_file)

    if is_SDXL:
        model_class, config_file = BrushNetModel, brushnet_xl_config_file
    elif is_PP:
        model_class, config_file = PowerPaintModel, powerpaint_config_file
    else:
        model_class, config_file = BrushNetModel, brushnet_config_file

    with init_empty_weights():
        brushnet_config = model_class.load_config(config_file)
        brushnet_model = model_class.from_config(brushnet_config)

    if is_PP:
        print("PowerPaint model file:", brushnet_file)
    else:
        print("BrushNet model file:", brushnet_file)

    known_dtypes = {
        'float16': torch.float16,
        'bfloat16': torch.bfloat16,
        'float32': torch.float32,
    }
    # Unrecognised names keep the historical fallback to float64.
    torch_dtype = known_dtypes.get(dtype, torch.float64)

    brushnet_model = load_checkpoint_and_dispatch(
        brushnet_model,
        brushnet_file,
        device_map="sequential",
        max_memory=None,
        offload_folder=None,
        offload_state_dict=False,
        dtype=torch_dtype,
        force_hooks=False,
    )

    if is_PP:
        print("PowerPaint model is loaded")
    elif is_SDXL:
        print("BrushNet SDXL model is loaded")
    else:
        print("BrushNet SD1.5 model is loaded")

    return ({"brushnet": brushnet_model, "SDXL": is_SDXL, "PP": is_PP, "dtype": torch_dtype},)
def brushnet_model_update(self, model, vae, image, mask, brushnet, positive, negative, scale, start_at, end_at):
    """Clone the base model and attach a BrushNet patch for inpainting.

    Args:
        model: base model wrapper; it is cloned before being patched.
        vae: VAE used to encode the masked image.
        image: source image tensor.
        mask: inpainting mask tensor.
        brushnet: dict as returned by ``load_brushnet_model``.
        positive: conditioning; ``positive[0][0]`` is read as the prompt
            embedding and ``positive[0][1]`` as a dict that may hold
            ``'pooled_output'``.
        negative: conditioning, same layout as ``positive``.
        scale: BrushNet conditioning scale.
        start_at: first step at which BrushNet is applied.
        end_at: last step at which BrushNet is applied.

    Returns:
        Tuple ``(model, positive, negative, {"samples": latent})`` where
        ``latent`` is a zero tensor matching the conditioning latents.

    Raises:
        Exception: if a PowerPaint checkpoint was passed, on base-model
            mismatch, or if image and mask sizes differ.
    """
    is_SDXL, is_PP = self.check_compatibilty(model, brushnet)
    if is_PP:
        raise Exception("PowerPaint model was loaded, please use PowerPaint node")

    # Make a copy of the model so that we're not patching it everywhere in the workflow.
    model = model.clone()

    # prepare image and mask
    masked_image, mask = self.prepare_image(image, mask)
    batch = masked_image.shape[0]
    width = masked_image.shape[2]
    height = masked_image.shape[1]

    # Prefer the scale factor declared by the model config; fall back to the defaults.
    model_config = model.model.model_config
    if hasattr(model_config, 'latent_format') and hasattr(model_config.latent_format, 'scale_factor'):
        scaling_factor = model_config.latent_format.scale_factor
    elif is_SDXL:
        scaling_factor = sdxl_scaling_factor
    else:
        scaling_factor = sd15_scaling_factor

    torch_dtype = brushnet['dtype']
    device = brushnet['brushnet'].device

    # prepare conditioning latents
    conditioning_latents = self.get_image_latents(masked_image, mask, vae, scaling_factor)
    conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(device)
    conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(device)

    # drop the local reference to the vae
    del vae

    # prepare embeddings
    prompt_embeds = positive[0][0].to(dtype=torch_dtype).to(device)
    negative_prompt_embeds = negative[0][0].to(dtype=torch_dtype).to(device)

    # Both embeddings must have the same token length: the shorter one is padded
    # by repeating its last 77-token chunk.
    max_tokens = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
    if prompt_embeds.shape[1] < max_tokens:
        multiplier = max_tokens // 77 - prompt_embeds.shape[1] // 77
        prompt_embeds = torch.concat([prompt_embeds] + [prompt_embeds[:, -77:, :]] * multiplier, dim=1)
        print('BrushNet: negative prompt more than 75 tokens:', negative_prompt_embeds.shape,
              'multiplying prompt_embeds')
    if negative_prompt_embeds.shape[1] < max_tokens:
        multiplier = max_tokens // 77 - negative_prompt_embeds.shape[1] // 77
        negative_prompt_embeds = torch.concat(
            [negative_prompt_embeds] + [negative_prompt_embeds[:, -77:, :]] * multiplier, dim=1)
        print('BrushNet: positive prompt more than 75 tokens:', prompt_embeds.shape,
              'multiplying negative_prompt_embeds')

    if len(positive[0]) > 1 and 'pooled_output' in positive[0][1] and positive[0][1]['pooled_output'] is not None:
        pooled_prompt_embeds = positive[0][1]['pooled_output'].to(dtype=torch_dtype).to(device)
    else:
        print('BrushNet: positive conditioning has not pooled_output')
        if is_SDXL:
            print('BrushNet will not produce correct results')
        # zeros instead of torch.empty: never feed uninitialised memory to the model
        pooled_prompt_embeds = torch.zeros([2, 1280], device=device).to(dtype=torch_dtype)

    if len(negative[0]) > 1 and 'pooled_output' in negative[0][1] and negative[0][1]['pooled_output'] is not None:
        negative_pooled_prompt_embeds = negative[0][1]['pooled_output'].to(dtype=torch_dtype).to(device)
    else:
        print('BrushNet: negative conditioning has not pooled_output')
        if is_SDXL:
            print('BrushNet will not produce correct results')
        # zeros instead of torch.empty, see above
        negative_pooled_prompt_embeds = torch.zeros([1, pooled_prompt_embeds.shape[1]],
                                                    device=device).to(dtype=torch_dtype)

    time_ids = torch.FloatTensor([[height, width, 0., 0., height, width]]).to(dtype=torch_dtype).to(device)

    # The additional embeddings are only passed on for SDXL.
    if not is_SDXL:
        pooled_prompt_embeds = None
        negative_pooled_prompt_embeds = None
        time_ids = None

    # apply patch to model
    brushnet_conditioning_scale = scale
    control_guidance_start = start_at
    control_guidance_end = end_at

    add_brushnet_patch(model,
                       brushnet['brushnet'],
                       torch_dtype,
                       conditioning_latents,
                       (brushnet_conditioning_scale, control_guidance_start, control_guidance_end),
                       prompt_embeds, negative_prompt_embeds,
                       pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids,
                       False)

    latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]],
                         device=device)

    return (model, positive, negative, {"samples": latent},)
# PowerPaint
def load_powerpaint_clip(self, base_clip_file, pp_clip_file):
    """Load a base CLIP and extend it with the PowerPaint task tokens.

    Args:
        base_clip_file: path of the base CLIP checkpoint.
        pp_clip_file: path of the PowerPaint text-encoder weights, loaded
            non-strictly on top of the base encoder.

    Returns:
        One-element tuple holding the modified CLIP object.
    """
    pp_clip = comfy.sd.load_clip(ckpt_paths=[base_clip_file])
    print('PowerPaint base CLIP file: ', base_clip_file)

    clip_l_tokenizer = pp_clip.tokenizer.clip_l
    clip_l_model = pp_clip.patcher.model.clip_l

    pp_tokenizer = TokenizerWrapper(clip_l_tokenizer.tokenizer)
    pp_text_encoder = clip_l_model.transformer

    # Register the three PowerPaint placeholder tokens, each initialised from "a".
    placeholders = ["P_ctxt", "P_shape", "P_obj"]
    add_tokens(
        tokenizer=pp_tokenizer,
        text_encoder=pp_text_encoder,
        placeholder_tokens=placeholders,
        initialize_tokens=["a"] * len(placeholders),
        num_vectors_per_token=10,
    )

    pp_weights = comfy.utils.load_torch_file(pp_clip_file)
    pp_text_encoder.load_state_dict(pp_weights, strict=False)
    print('PowerPaint CLIP file: ', pp_clip_file)

    # Put the wrapped tokenizer and the updated encoder back into the CLIP object.
    clip_l_tokenizer.tokenizer = pp_tokenizer
    clip_l_model.transformer = pp_text_encoder

    return (pp_clip,)
def powerpaint_model_update(self, model, vae, image, mask, powerpaint, clip, positive, negative, fitting, function, scale, start_at, end_at, save_memory):
    """Clone the base model and attach a PowerPaint patch for inpainting.

    Args:
        model: base model wrapper; it is cloned before being patched.
        vae: VAE used to encode the masked image.
        image: source image tensor.
        mask: inpainting mask tensor.
        powerpaint: dict as returned by ``load_brushnet_model`` (PowerPaint type).
        clip: CLIP object used to encode the task prompts.
        positive: conditioning, returned unchanged.
        negative: conditioning, returned unchanged.
        fitting: blend weight between the "A" and "B" task prompts.
        function: task name selecting the task prompts.
        scale: conditioning scale.
        start_at: first step at which the patch is applied.
        end_at: last step at which the patch is applied.
        save_memory: value passed to ``set_attention_slice`` unless ``'none'``.

    Returns:
        Tuple ``(model, positive, negative, {"samples": latent})``.

    Raises:
        Exception: if a non-PowerPaint checkpoint was passed, on base-model
            mismatch, or if image and mask sizes differ.
    """
    is_SDXL, is_PP = self.check_compatibilty(model, powerpaint)
    if not is_PP:
        raise Exception("BrushNet model was loaded, please use BrushNet node")

    # Make a copy of the model so that we're not patching it everywhere in the workflow.
    model = model.clone()

    masked_image, mask = self.prepare_image(image, mask)
    batch = masked_image.shape[0]

    model_config = model.model.model_config
    if hasattr(model_config, 'latent_format') and hasattr(model_config.latent_format, 'scale_factor'):
        scaling_factor = model_config.latent_format.scale_factor
    else:
        scaling_factor = sd15_scaling_factor

    torch_dtype = powerpaint['dtype']
    device = powerpaint['brushnet'].device

    # prepare conditioning latents
    conditioning_latents = self.get_image_latents(masked_image, mask, vae, scaling_factor)
    conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(device)
    conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(device)

    # Task prompts per function: (promptA, promptB, negative A, negative B, hint to print).
    task_prompts = {
        "object removal": ("P_ctxt", "P_ctxt", "P_obj", "P_obj",
                           'You should add to positive prompt: "empty scene blur"'),
        "context aware": ("P_ctxt", "P_ctxt", "", "",
                          'You should add to positive prompt: "empty scene"'),
        "shape guided": ("P_shape", "P_ctxt", "P_shape", "P_ctxt", None),
        "image outpainting": ("P_ctxt", "P_ctxt", "P_obj", "P_obj",
                              'You should add to positive prompt: "empty scene"'),
    }
    fallback = ("P_obj", "P_obj", "P_obj", "P_obj", None)
    promptA, promptB, negative_promptA, negative_promptB, hint = task_prompts.get(function, fallback)
    if hint is not None:
        print(hint)

    def encode(text):
        return clip.encode_from_tokens(clip.tokenize(text), return_pooled=False)

    prompt_embedsA = encode(promptA)
    negative_prompt_embedsA = encode(negative_promptA)
    prompt_embedsB = encode(promptB)
    negative_prompt_embedsB = encode(negative_promptB)

    # Linear blend of the A/B embeddings controlled by `fitting`.
    blended_positive = prompt_embedsA * fitting + (1.0 - fitting) * prompt_embedsB
    blended_negative = negative_prompt_embedsA * fitting + (1.0 - fitting) * negative_prompt_embedsB
    prompt_embeds_pp = blended_positive.to(dtype=torch_dtype).to(device)
    negative_prompt_embeds_pp = blended_negative.to(dtype=torch_dtype).to(device)

    # drop the local references to the vae and CLIP
    del vae
    del clip

    if save_memory != 'none':
        powerpaint['brushnet'].set_attention_slice(save_memory)

    controls = (scale, start_at, end_at)
    # NOTE(review): the negative embedding is passed in the `prompt_embeds` slot and the
    # positive one in the `negative_prompt_embeds` slot. This mirrors the existing
    # behaviour; confirm whether the swap is intentional.
    add_brushnet_patch(model,
                       powerpaint['brushnet'],
                       torch_dtype,
                       conditioning_latents,
                       controls,
                       negative_prompt_embeds_pp, prompt_embeds_pp,
                       None, None, None,
                       False)

    latent_shape = [batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]]
    latent = torch.zeros(latent_shape, device=device)

    return (model, positive, negative, {"samples": latent},)
@torch.inference_mode()
def brushnet_inference(x, timesteps, transformer_options, debug):
    """Run the BrushNet model stored in ``transformer_options`` for one model call.

    Reads the state saved under ``transformer_options['model_patch']['brushnet']``
    (model, dtype, conditioning latents, controls and embeddings), assembles the
    per-sample conditioning for the incoming latent batch and calls the model.

    Args:
        x: incoming latents; indexed as a 4-D tensor (batch first).
        timesteps: timestep tensor for the call.
        transformer_options: options dict; must hold ``'cond_or_uncond'`` and,
            under ``'model_patch'``, the keys ``'step'`` and ``'total_steps'``.
        debug: if truthy, prints the step and is forwarded to the model.

    Returns:
        ``([], 0, [])`` when no usable BrushNet entry is found, otherwise the
        value returned by the BrushNet model (the caller unpacks three values).
    """
    # Bail out with "nothing to add" whenever the patch state is missing or malformed.
    if 'model_patch' not in transformer_options:
        print('BrushNet inference: there is no model_patch key in transformer_options')
        return ([], 0, [])
    mp = transformer_options['model_patch']
    if 'brushnet' not in mp:
        print('BrushNet inference: there is no brushnet key in mdel_patch')
        return ([], 0, [])
    bo = mp['brushnet']
    if 'model' not in bo:
        print('BrushNet inference: there is no model key in brushnet')
        return ([], 0, [])
    brushnet = bo['model']
    if not (isinstance(brushnet, BrushNetModel) or isinstance(brushnet, PowerPaintModel)):
        print('BrushNet model is not a BrushNetModel class')
        return ([], 0, [])
    torch_dtype = bo['dtype']
    # cl_list[0] holds the image latents, cl_list[1] the resized masks
    cl_list = bo['latents']
    brushnet_conditioning_scale, control_guidance_start, control_guidance_end = bo['controls']
    pe = bo['prompt_embeds']
    npe = bo['negative_prompt_embeds']
    ppe, nppe, time_ids = bo['add_embeds']
    #do_classifier_free_guidance = mp['free_guidance']
    # more than one entry in cond_or_uncond is treated as classifier-free guidance
    do_classifier_free_guidance = len(transformer_options['cond_or_uncond']) > 1
    # work on detached copies converted to the BrushNet dtype/device
    x = x.detach().clone()
    x = x.to(torch_dtype).to(brushnet.device)
    timesteps = timesteps.detach().clone()
    timesteps = timesteps.to(torch_dtype).to(brushnet.device)
    total_steps = mp['total_steps']
    step = mp['step']
    added_cond_kwargs = {}
    if do_classifier_free_guidance and step == 0:
        print('BrushNet inference: do_classifier_free_guidance is True')
    # optional frame indexes provided through 'ad_params' (AnimateDiff)
    sub_idx = None
    if 'ad_params' in transformer_options and 'sub_idxs' in transformer_options['ad_params']:
        sub_idx = transformer_options['ad_params']['sub_idxs']
    # we have batch input images
    batch = cl_list[0].shape[0]
    # we have incoming latents
    latents_incoming = x.shape[0]
    # and we already got some
    latents_got = bo['latent_id']
    if step == 0 or batch > 1:
        print('BrushNet inference, step = %d: image batch = %d, got %d latents, starting from %d' \
              % (step, batch, latents_incoming, latents_got))
    # per-sample pieces, concatenated along the batch axis further down
    image_latents = []
    masks = []
    prompt_embeds = []
    negative_prompt_embeds = []
    pooled_prompt_embeds = []
    negative_pooled_prompt_embeds = []
    if sub_idx:
        # AnimateDiff indexes detected
        if step == 0:
            print('BrushNet inference: AnimateDiff indexes detected and applied')
        batch = len(sub_idx)
        if do_classifier_free_guidance:
            # the selected frames are appended twice: once here together with both
            # embeddings, and once more in the second loop
            for i in sub_idx:
                image_latents.append(cl_list[0][i][None,:,:,:])
                masks.append(cl_list[1][i][None,:,:,:])
                prompt_embeds.append(pe)
                negative_prompt_embeds.append(npe)
                pooled_prompt_embeds.append(ppe)
                negative_pooled_prompt_embeds.append(nppe)
            for i in sub_idx:
                image_latents.append(cl_list[0][i][None,:,:,:])
                masks.append(cl_list[1][i][None,:,:,:])
        else:
            for i in sub_idx:
                image_latents.append(cl_list[0][i][None,:,:,:])
                masks.append(cl_list[1][i][None,:,:,:])
                prompt_embeds.append(pe)
                pooled_prompt_embeds.append(ppe)
    else:
        # do_classifier_free_guidance = 2 passes, 1st pass is cond, 2nd is uncond
        continue_batch = True
        for i in range(latents_incoming):
            # position of this latent, counting the ones consumed by earlier calls
            number = latents_got + i
            if number < batch:
                # 1st pass, cond
                image_latents.append(cl_list[0][number][None,:,:,:])
                masks.append(cl_list[1][number][None,:,:,:])
                prompt_embeds.append(pe)
                pooled_prompt_embeds.append(ppe)
            elif do_classifier_free_guidance and number < batch * 2:
                # 2nd pass, uncond
                image_latents.append(cl_list[0][number-batch][None,:,:,:])
                masks.append(cl_list[1][number-batch][None,:,:,:])
                negative_prompt_embeds.append(npe)
                negative_pooled_prompt_embeds.append(nppe)
            else:
                # latent batch
                image_latents.append(cl_list[0][0][None,:,:,:])
                masks.append(cl_list[1][0][None,:,:,:])
                prompt_embeds.append(pe)
                pooled_prompt_embeds.append(ppe)
                latents_got = -i
                continue_batch = False
        # remember where the next call has to continue, or start over at 0
        if continue_batch:
            # we don't have full batch yet
            if do_classifier_free_guidance:
                if number < batch * 2 - 1:
                    bo['latent_id'] = number + 1
                else:
                    bo['latent_id'] = 0
            else:
                if number < batch - 1:
                    bo['latent_id'] = number + 1
                else:
                    bo['latent_id'] = 0
        else:
            bo['latent_id'] = 0
    # each conditioning sample is the image latent with its mask appended on dim 1
    cl = []
    for il, m in zip(image_latents, masks):
        cl.append(torch.concat([il, m], dim=1))
    cl2apply = torch.concat(cl, dim=0)
    conditioning_latents = cl2apply.to(torch_dtype).to(brushnet.device)
    # positive embeddings first, negative ones after them
    prompt_embeds.extend(negative_prompt_embeds)
    prompt_embeds = torch.concat(prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)
    if ppe is not None:
        # pooled embeddings were provided: build the additional conditioning
        added_cond_kwargs = {}
        added_cond_kwargs['time_ids'] = torch.concat([time_ids] * latents_incoming, dim = 0).to(torch_dtype).to(brushnet.device)
        pooled_prompt_embeds.extend(negative_pooled_prompt_embeds)
        pooled_prompt_embeds = torch.concat(pooled_prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)
        added_cond_kwargs['text_embeds'] = pooled_prompt_embeds
    else:
        added_cond_kwargs = None
    # resize the conditioning to the spatial size of the incoming latents if they differ
    if x.shape[2] != conditioning_latents.shape[2] or x.shape[3] != conditioning_latents.shape[3]:
        if step == 0:
            print('BrushNet inference: image', conditioning_latents.shape, 'and latent', x.shape, 'have different size, resizing image')
        conditioning_latents = torch.nn.functional.interpolate(
            conditioning_latents, size=(
                x.shape[2],
                x.shape[3],
            ), mode='bicubic',
        ).to(torch_dtype).to(brushnet.device)
    if step == 0:
        print('BrushNet inference: sample', x.shape, ', CL', conditioning_latents.shape, 'dtype', torch_dtype)
    if debug: print('BrushNet: step =', step)
    # outside of [control_guidance_start, control_guidance_end] the scale is forced to 0.0
    if step < control_guidance_start or step > control_guidance_end:
        cond_scale = 0.0
    else:
        cond_scale = brushnet_conditioning_scale
    return brushnet(x,
                    encoder_hidden_states=prompt_embeds,
                    brushnet_cond=conditioning_latents,
                    timestep = timesteps,
                    conditioning_scale=cond_scale,
                    guess_mode=False,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                    debug=debug,
                    )
def add_brushnet_patch(model, brushnet, torch_dtype, conditioning_latents,
                       controls,
                       prompt_embeds, negative_prompt_embeds,
                       pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids,
                       debug):
    """Attach BrushNet to ``model``: store its state and patch the UNet layers.

    The BrushNet state is saved under
    ``model.model_options['transformer_options']['model_patch']['brushnet']`` and
    every layer of the diffusion model's input, middle and output blocks gets a
    ``forward`` wrapper that adds ``add_sample_after`` to the layer output.

    Args:
        model: model wrapper to patch (modified in place).
        brushnet: BrushNet/PowerPaint model instance.
        torch_dtype: dtype BrushNet runs in.
        conditioning_latents: ``[image_latents, mask]`` list.
        controls: tuple ``(conditioning_scale, start_step, end_step)``.
        prompt_embeds: embedding stored as ``'prompt_embeds'``.
        negative_prompt_embeds: embedding stored as ``'negative_prompt_embeds'``.
        pooled_prompt_embeds: pooled embedding or ``None``.
        negative_pooled_prompt_embeds: pooled negative embedding or ``None``.
        time_ids: additional time ids or ``None``.
        debug: forwarded to ``brushnet_inference``.
    """
    is_SDXL = isinstance(model.model.model_config, comfy.supported_models.SDXL)

    if model.model.model_config.custom_operations is None:
        fp8 = model.model.model_config.optimizations.get("fp8", model.model.model_config.scaled_fp8 is not None)
        operations = comfy.ops.pick_operations(model.model.model_config.unet_config.get("dtype", None),
                                               model.model.manual_cast_dtype,
                                               fp8_optimizations=fp8,
                                               scaled_fp8=model.model.model_config.scaled_fp8)
    else:
        # such as gguf
        operations = model.model.model_config.custom_operations

    ResBlock = comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock
    Downsample = comfy.ldm.modules.diffusionmodules.openaimodel.Downsample
    Upsample = comfy.ldm.modules.diffusionmodules.openaimodel.Upsample
    SpatialTransformer = comfy.ldm.modules.attention.SpatialTransformer

    # Each entry is [block index, layer type]: the BrushNet sample is added after
    # the last layer of that type inside the block.
    if is_SDXL:
        input_blocks = [[0, operations.Conv2d],
                        [1, ResBlock],
                        [2, ResBlock],
                        [3, Downsample],
                        [4, SpatialTransformer],
                        [5, SpatialTransformer],
                        [6, Downsample],
                        [7, SpatialTransformer],
                        [8, SpatialTransformer]]
        middle_block = [0, ResBlock]
        output_blocks = [[0, SpatialTransformer],
                         [1, SpatialTransformer],
                         [2, SpatialTransformer],
                         [2, Upsample],
                         [3, SpatialTransformer],
                         [4, SpatialTransformer],
                         [5, SpatialTransformer],
                         [5, Upsample],
                         [6, ResBlock],
                         [7, ResBlock],
                         [8, ResBlock]]
    else:
        input_blocks = [[0, operations.Conv2d],
                        [1, SpatialTransformer],
                        [2, SpatialTransformer],
                        [3, Downsample],
                        [4, SpatialTransformer],
                        [5, SpatialTransformer],
                        [6, Downsample],
                        [7, SpatialTransformer],
                        [8, SpatialTransformer],
                        [9, Downsample],
                        [10, ResBlock],
                        [11, ResBlock]]
        middle_block = [0, ResBlock]
        output_blocks = [[0, ResBlock],
                         [1, ResBlock],
                         [2, ResBlock],
                         [2, Upsample],
                         [3, SpatialTransformer],
                         [4, SpatialTransformer],
                         [5, SpatialTransformer],
                         [5, Upsample],
                         [6, SpatialTransformer],
                         [7, SpatialTransformer],
                         [8, SpatialTransformer],
                         [8, Upsample],
                         [9, SpatialTransformer],
                         [10, SpatialTransformer],
                         [11, SpatialTransformer]]

    def last_layer_index(block, tp):
        """Return ``(index of the last layer of type tp, list of layer types)``.

        The index is -1 when the block holds no layer of that exact type. The
        type list is always returned (in block order) so callers can print it;
        previously the not-found case returned ``None`` in its place.
        """
        layer_list = [type(layer) for layer in block]
        if tp not in layer_list:
            return -1, layer_list
        return len(layer_list) - 1 - layer_list[::-1].index(tp), layer_list

    def brushnet_forward(model, x, timesteps, transformer_options, control):
        if 'brushnet' not in transformer_options['model_patch']:
            input_samples = []
            mid_sample = 0
            output_samples = []
        else:
            # brushnet inference
            input_samples, mid_sample, output_samples = brushnet_inference(x, timesteps, transformer_options, debug)

        # give additional samples to blocks
        for i, tp in input_blocks:
            idx, layer_list = last_layer_index(model.input_blocks[i], tp)
            if idx < 0:
                print("BrushNet can't find", tp, "layer in", i, "input block:", layer_list)
                continue
            model.input_blocks[i][idx].add_sample_after = input_samples.pop(0) if input_samples else 0

        idx, layer_list = last_layer_index(model.middle_block, middle_block[1])
        if idx < 0:
            print("BrushNet can't find", middle_block[1], "layer in middle block", layer_list)
        # When the layer was not found idx is -1, so the sample is attached to the
        # last layer of the middle block.
        model.middle_block[idx].add_sample_after = mid_sample

        for i, tp in output_blocks:
            idx, layer_list = last_layer_index(model.output_blocks[i], tp)
            if idx < 0:
                print("BrushNet can't find", tp, "layer in", i, "outnput block:", layer_list)
                continue
            model.output_blocks[i][idx].add_sample_after = output_samples.pop(0) if output_samples else 0

    patch_model_function_wrapper(model, brushnet_forward)

    to = add_model_patch_option(model)
    mp = to['model_patch']
    if 'brushnet' not in mp:
        mp['brushnet'] = {}
    bo = mp['brushnet']

    bo['model'] = brushnet
    bo['dtype'] = torch_dtype
    bo['latents'] = conditioning_latents
    bo['controls'] = controls
    bo['prompt_embeds'] = prompt_embeds
    bo['negative_prompt_embeds'] = negative_prompt_embeds
    bo['add_embeds'] = (pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids)
    bo['latent_id'] = 0

    # patch layers `forward` so we can apply brushnet
    def forward_patched_by_brushnet(self, x, *args, **kwargs):
        h = self.original_forward(x, *args, **kwargs)
        if hasattr(self, 'add_sample_after'):
            to_add = self.add_sample_after
            if torch.is_tensor(to_add):
                # interpolate due to RAUNet
                if h.shape[2] != to_add.shape[2] or h.shape[3] != to_add.shape[3]:
                    to_add = torch.nn.functional.interpolate(to_add, size=(h.shape[2], h.shape[3]), mode='bicubic')
                h += to_add.to(h.dtype).to(h.device)
            else:
                h += self.add_sample_after
            # the sample is consumed once; reset it for the next call
            self.add_sample_after = 0
        return h

    def patch_layer(layer):
        # Keep the pristine forward only the first time a layer is patched.
        if not hasattr(layer, 'original_forward'):
            layer.original_forward = layer.forward
        layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
        layer.add_sample_after = 0

    diffusion_model = model.model.diffusion_model
    for block in diffusion_model.input_blocks:
        for layer in block:
            patch_layer(layer)
    for layer in diffusion_model.middle_block:
        patch_layer(layer)
    for block in diffusion_model.output_blocks:
        for layer in block:
            patch_layer(layer)

View File

@@ -0,0 +1,58 @@
{
"_class_name": "BrushNetModel",
"_diffusers_version": "0.27.0.dev0",
"_name_or_path": "runs/logs/brushnet_randommask/checkpoint-100000",
"act_fn": "silu",
"addition_embed_type": null,
"addition_embed_type_num_heads": 64,
"addition_time_embed_dim": null,
"attention_head_dim": 8,
"block_out_channels": [
320,
640,
1280,
1280
],
"brushnet_conditioning_channel_order": "rgb",
"class_embed_type": null,
"conditioning_channels": 5,
"conditioning_embedding_out_channels": [
16,
32,
96,
256
],
"cross_attention_dim": 768,
"down_block_types": [
"DownBlock2D",
"DownBlock2D",
"DownBlock2D",
"DownBlock2D"
],
"downsample_padding": 1,
"encoder_hid_dim": null,
"encoder_hid_dim_type": null,
"flip_sin_to_cos": true,
"freq_shift": 0,
"global_pool_conditions": false,
"in_channels": 4,
"layers_per_block": 2,
"mid_block_scale_factor": 1,
"mid_block_type": "MidBlock2D",
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_attention_heads": null,
"num_class_embeds": null,
"only_cross_attention": false,
"projection_class_embeddings_input_dim": null,
"resnet_time_scale_shift": "default",
"transformer_layers_per_block": 1,
"up_block_types": [
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
"UpBlock2D"
],
"upcast_attention": false,
"use_linear_projection": false
}

View File

@@ -0,0 +1,63 @@
{
"_class_name": "BrushNetModel",
"_diffusers_version": "0.27.0.dev0",
"_name_or_path": "runs/logs/brushnetsdxl_randommask/checkpoint-80000",
"act_fn": "silu",
"addition_embed_type": "text_time",
"addition_embed_type_num_heads": 64,
"addition_time_embed_dim": 256,
"attention_head_dim": [
5,
10,
20
],
"block_out_channels": [
320,
640,
1280
],
"brushnet_conditioning_channel_order": "rgb",
"class_embed_type": null,
"conditioning_channels": 5,
"conditioning_embedding_out_channels": [
16,
32,
96,
256
],
"cross_attention_dim": 2048,
"down_block_types": [
"DownBlock2D",
"DownBlock2D",
"DownBlock2D"
],
"downsample_padding": 1,
"encoder_hid_dim": null,
"encoder_hid_dim_type": null,
"flip_sin_to_cos": true,
"freq_shift": 0,
"global_pool_conditions": false,
"in_channels": 4,
"layers_per_block": 2,
"mid_block_scale_factor": 1,
"mid_block_type": "MidBlock2D",
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_attention_heads": null,
"num_class_embeds": null,
"only_cross_attention": false,
"projection_class_embeddings_input_dim": 2816,
"resnet_time_scale_shift": "default",
"transformer_layers_per_block": [
1,
2,
10
],
"up_block_types": [
"UpBlock2D",
"UpBlock2D",
"UpBlock2D"
],
"upcast_attention": null,
"use_linear_projection": true
}

View File

@@ -0,0 +1,57 @@
{
"_class_name": "BrushNetModel",
"_diffusers_version": "0.27.2",
"act_fn": "silu",
"addition_embed_type": null,
"addition_embed_type_num_heads": 64,
"addition_time_embed_dim": null,
"attention_head_dim": 8,
"block_out_channels": [
320,
640,
1280,
1280
],
"brushnet_conditioning_channel_order": "rgb",
"class_embed_type": null,
"conditioning_channels": 5,
"conditioning_embedding_out_channels": [
16,
32,
96,
256
],
"cross_attention_dim": 768,
"down_block_types": [
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D"
],
"downsample_padding": 1,
"encoder_hid_dim": null,
"encoder_hid_dim_type": null,
"flip_sin_to_cos": true,
"freq_shift": 0,
"global_pool_conditions": false,
"in_channels": 4,
"layers_per_block": 2,
"mid_block_scale_factor": 1,
"mid_block_type": "UNetMidBlock2DCrossAttn",
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_attention_heads": null,
"num_class_embeds": null,
"only_cross_attention": false,
"projection_class_embeddings_input_dim": null,
"resnet_time_scale_shift": "default",
"transformer_layers_per_block": 1,
"up_block_types": [
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D"
],
"upcast_attention": false,
"use_linear_projection": false
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,137 @@
import torch
import comfy
# Check and add 'model_patch' to model.model_options['transformer_options']
def add_model_patch_option(model):
    """Return the model's ``transformer_options`` dict, creating it if absent.

    The returned dict is the very object stored under
    ``model.model_options['transformer_options']`` and is guaranteed to hold a
    ``'model_patch'`` entry (an empty dict when there was none before).
    """
    transformer_options = model.model_options.setdefault('transformer_options', {})
    transformer_options.setdefault("model_patch", {})
    return transformer_options
# Patch model with model_function_wrapper
# Installs brushnet_model_function_wrapper on `model` via
# set_model_unet_function_wrapper and registers `forward_patch` in
# model.model_options['transformer_options']['model_patch']['forward']
# (or removes it from that list when `remove` is true).
# It also replaces comfy.samplers.sample and openaimodel.apply_control with
# modified_sample / modified_apply_control, saving the previous functions as
# `original_sample` / `original_apply_control`.
# Raises Exception when the base model config is neither SD15 nor SDXL.
# NOTE(review): indentation is not visible in this view of the file, so the
# exact nesting of some statements below should be confirmed against upstream.
def patch_model_function_wrapper(model, forward_patch, remove=False):
# The wrapper receives the model's apply function plus a dict holding
# 'input', 'timestep' and the conditioning dict 'c'.
def brushnet_model_function_wrapper(apply_model_method, options_dict):
to = options_dict['c']['transformer_options']
# `control` stays None when the conditioning dict carries no 'control' entry
control = None
if 'control' in options_dict['c']:
control = options_dict['c']['control']
x = options_dict['input']
timestep = options_dict['timestep']
# check if there are patches to execute
if 'model_patch' not in to or 'forward' not in to['model_patch']:
return apply_model_method(x, timestep, **options_dict['c'])
mp = to['model_patch']
unet = mp['unet']
# 'all_sigmas' is written by modified_sample before sampling starts
all_sigmas = mp['all_sigmas']
sigma = to['sigmas'][0].item()
total_steps = all_sigmas.shape[0] - 1
# current step = index of the entry in all_sigmas closest to the current sigma
step = torch.argmin((all_sigmas - sigma).abs()).item()
mp['step'] = step
mp['total_steps'] = total_steps
# comfy.model_base.apply_model
xc = model.model.model_sampling.calculate_input(timestep, x)
# append the concat conditioning along dim 1 when one is supplied
if 'c_concat' in options_dict['c'] and options_dict['c']['c_concat'] is not None:
xc = torch.cat([xc] + [options_dict['c']['c_concat']], dim=1)
t = model.model.model_sampling.timestep(timestep).float()
# execute all patches
# (their return values are ignored; the model is then called with the original x and timestep)
for method in mp['forward']:
method(unet, xc, t, to, control)
return apply_model_method(x, timestep, **options_dict['c'])
# Only one model_function_wrapper can be set; report when an existing one is overwritten.
if "model_function_wrapper" in model.model_options and model.model_options["model_function_wrapper"]:
print('BrushNet is going to replace existing model_function_wrapper:',
model.model_options["model_function_wrapper"])
model.set_model_unet_function_wrapper(brushnet_model_function_wrapper)
to = add_model_patch_option(model)
mp = to['model_patch']
# Record whether the base model is SDXL; any config other than SD15/SDXL is rejected.
if isinstance(model.model.model_config, comfy.supported_models.SD15):
mp['SDXL'] = False
elif isinstance(model.model.model_config, comfy.supported_models.SDXL):
mp['SDXL'] = True
else:
print('Base model type: ', type(model.model.model_config))
raise Exception("Unsupported model type: ", type(model.model.model_config))
if 'forward' not in mp:
mp['forward'] = []
if remove:
if forward_patch in mp['forward']:
mp['forward'].remove(forward_patch)
else:
mp['forward'].append(forward_patch)
mp['unet'] = model.model.diffusion_model
# placeholder values; the wrapper above overwrites 'step' and 'total_steps' when it runs
mp['step'] = 0
mp['total_steps'] = 1
# apply patches to code
# A docstring containing 'BrushNet' marks a function that is already replaced,
# so the swap (and the saving of the original) is skipped in that case.
if comfy.samplers.sample.__doc__ is None or 'BrushNet' not in comfy.samplers.sample.__doc__:
comfy.samplers.original_sample = comfy.samplers.sample
comfy.samplers.sample = modified_sample
if comfy.ldm.modules.diffusionmodules.openaimodel.apply_control.__doc__ is None or \
'BrushNet' not in comfy.ldm.modules.diffusionmodules.openaimodel.apply_control.__doc__:
comfy.ldm.modules.diffusionmodules.openaimodel.original_apply_control = comfy.ldm.modules.diffusionmodules.openaimodel.apply_control
comfy.ldm.modules.diffusionmodules.openaimodel.apply_control = modified_apply_control
# Model needs current step number and cfg at inference step. It is possible to write a custom KSampler but I'd like to use ComfyUI's one.
# The first versions had modified_common_ksampler, but it broke custom KSampler nodes
def modified_sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={},
                    latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
    '''Modified by BrushNet nodes

    Builds a ``comfy.samplers.CFGGuider`` for ``model``, configures it with the
    given conditionings and cfg scale, stores the full ``sigmas`` schedule
    under ``transformer_options['model_patch']['all_sigmas']`` and then runs
    the guider's ``sample``.

    ``device`` and ``model_options`` are accepted but not used by this body.
    The word "BrushNet" must stay in this docstring: the patching code checks
    ``__doc__`` for it to detect that the sampler is already replaced.
    '''
    guider = comfy.samplers.CFGGuider(model)
    guider.set_conds(positive, negative)
    guider.set_cfg(cfg)

    # Modified part: expose the complete sigma schedule to the model patches.
    patch_options = add_model_patch_option(model)['model_patch']
    patch_options['all_sigmas'] = sigmas

    sample_args = (noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
    return guider.sample(*sample_args)
# To use Controlnet with RAUNet it is much easier to modify apply_control a little
def modified_apply_control(h, control, name):
    '''Modified by BrushNet nodes

    Pop the last control tensor stored under ``control[name]`` and add it to
    ``h`` in place. When the two tensors differ in size along dims 2 and 3,
    the control tensor is first resized to ``h``'s size with bicubic
    interpolation and moved to ``h``'s dtype and device.

    Args:
        h: Tensor to receive the control signal (modified in place).
        control: Mapping from block name to a list of control tensors, or None.
        name: Key selecting which list in ``control`` to pop from.

    Returns:
        ``h`` (the same object that was passed in).

    If the addition fails (e.g. mismatching channel counts) a warning is
    printed and ``h`` is returned unchanged. The word "BrushNet" must stay in
    this docstring: the patching code checks ``__doc__`` for it.
    '''
    if control is not None and name in control and len(control[name]) > 0:
        ctrl = control[name].pop()
        if ctrl is not None:
            if h.shape[2] != ctrl.shape[2] or h.shape[3] != ctrl.shape[3]:
                ctrl = torch.nn.functional.interpolate(
                    ctrl, size=(h.shape[2], h.shape[3]), mode='bicubic'
                ).to(h.dtype).to(h.device)
            try:
                h += ctrl
            except Exception:
                # Best effort: a control that cannot be applied is reported and skipped.
                # (The previous `print.warning(...)` raised AttributeError here.)
                print("warning control could not be applied {} {}".format(h.shape, ctrl.shape))
    return h
# Refreshes BrushNet bookkeeping in the model's 'model_patch' options:
# the 'SDXL' flag, the 'unet' reference and the 'step'/'total_steps' counters.
# Raises Exception when the base model config is neither SD15 nor SDXL.
# NOTE(review): indentation is not visible in this view of the file; whether
# the trailing unet/step/total_steps assignments sit inside the
# `if "brushnet" in mp:` branch should be confirmed against upstream.
def add_model_patch(model):
to = add_model_patch_option(model)
mp = to['model_patch']
if "brushnet" in mp:
# derive the 'SDXL' flag from the type of the base model config
if isinstance(model.model.model_config, comfy.supported_models.SD15):
mp['SDXL'] = False
elif isinstance(model.model.model_config, comfy.supported_models.SDXL):
mp['SDXL'] = True
else:
print('Base model type: ', type(model.model.model_config))
raise Exception("Unsupported model type: ", type(model.model.model_config))
mp['unet'] = model.model.diffusion_model
# reset the step counters to their initial values
mp['step'] = 0
mp['total_steps'] = 1

View File

@@ -0,0 +1,467 @@
import copy
import random
import torch
import torch.nn as nn
from transformers import CLIPTokenizer
from typing import Any, List, Optional, Union
class TokenizerWrapper:
    """Tokenizer wrapper for CLIPTokenizer that supports multi-vector
    placeholder tokens.

    A placeholder word registered through :meth:`add_placeholder_token` is
    expanded into one or more concrete tokens before the text reaches the
    wrapped tokenizer. Attributes not defined on the wrapper are looked up on
    the wrapped tokenizer.

    This wrapper is modified from
    https://github.com/huggingface/diffusers/blob/e51f19aee82c8dd874b715a09dbc521d88835d68/src/diffusers/loaders.py#L358

    Args:
        tokenizer (CLIPTokenizer): The tokenizer to wrap. Only CLIPTokenizer
            is supported currently.
    """

    def __init__(self, tokenizer: "CLIPTokenizer"):
        self.wrapped = tokenizer
        # placeholder word -> list of concrete token strings standing for it
        self.token_map = {}

    def __getattr__(self, name: str) -> Any:
        """Delegate unknown attributes to the wrapped tokenizer.

        Only called when normal attribute lookup fails.

        Raises:
            AttributeError: If neither the wrapper nor the wrapped tokenizer
                has the attribute.
        """
        # Read `wrapped` straight from the instance dict: going through
        # `self.wrapped` would re-enter __getattr__ and recurse forever when
        # the attribute is not set yet (e.g. while the object is being copied).
        try:
            wrapped = self.__dict__["wrapped"]
        except KeyError:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{name}'"
            ) from None
        try:
            return getattr(wrapped, name)
        except AttributeError:
            raise AttributeError(
                f"'{name}' cannot be found in both "
                f"'{self.__class__.__name__}' and "
                f"'{self.__class__.__name__}.tokenizer'."
            ) from None

    def try_adding_tokens(self, tokens: Union[str, List[str]], *args, **kwargs):
        """Attempt to add tokens to the tokenizer.

        Args:
            tokens (Union[str, List[str]]): The tokens to be added.
            *args, **kwargs: Forwarded to `self.wrapped.add_tokens`.

        Raises:
            AssertionError: If the tokenizer reports that no token was added.
        """
        num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs)
        assert num_added_tokens != 0, (
            f"The tokenizer already contains the token {tokens}. Please pass "
            "a different `placeholder_token` that is not already in the "
            "tokenizer."
        )

    def get_token_info(self, token: str) -> dict:
        """Get the information of a token, including its start and end index in
        the current tokenizer.

        Args:
            token (str): The token to be queried.

        Returns:
            dict: ``{"name": token, "start": ..., "end": ...}`` where `start`
            is the second id of the tokenized text and `end` is the
            second-to-last id plus one.
        """
        token_ids = self(token).input_ids
        # the first and the last id are skipped; `end` is exclusive
        start, end = token_ids[1], token_ids[-2] + 1
        return {"name": token, "start": start, "end": end}

    def add_placeholder_token(self, placeholder_token: str, *args, num_vec_per_token: int = 1, **kwargs):
        """Add placeholder tokens to the tokenizer.

        Args:
            placeholder_token (str): The placeholder token to be added.
            num_vec_per_token (int, optional): The number of vectors of
                the added placeholder token. With 1 the placeholder itself is
                added; otherwise tokens `<placeholder>_<i>` are added.
            *args, **kwargs: The arguments for `self.wrapped.add_tokens`.

        Raises:
            ValueError: If an already registered placeholder is a substring
                of `placeholder_token`.
        """
        # Validate before touching the tokenizer so that a rejected
        # placeholder does not leave half-registered tokens behind.
        for token in self.token_map:
            if token in placeholder_token:
                raise ValueError(
                    f"The tokenizer already has placeholder token {token} "
                    f"that can get confused with {placeholder_token} "
                    "keep placeholder tokens independent"
                )

        if num_vec_per_token == 1:
            names = [placeholder_token]
        else:
            names = [f"{placeholder_token}_{i}" for i in range(num_vec_per_token)]
        for token_name in names:
            self.try_adding_tokens(token_name, *args, **kwargs)
        self.token_map[placeholder_token] = names

    def replace_placeholder_tokens_in_text(
        self, text: Union[str, List[str]], vector_shuffle: bool = False, prop_tokens_to_load: float = 1.0
    ) -> Union[str, List[str]]:
        """Replace the keywords in text with placeholder tokens. This function
        will be called in `self.__call__` and `self.encode`.

        Args:
            text (Union[str, List[str]]): The text to be processed.
            vector_shuffle (bool, optional): Whether to shuffle the vectors.
                Defaults to False.
            prop_tokens_to_load (float, optional): The proportion of tokens to
                be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.

        Returns:
            Union[str, List[str]]: The processed text.
        """
        if isinstance(text, list):
            # forward both options so that every list item is treated like a single string
            return [
                self.replace_placeholder_tokens_in_text(
                    item, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
                )
                for item in text
            ]

        for placeholder_token, all_tokens in self.token_map.items():
            if placeholder_token in text:
                tokens = all_tokens[: 1 + int(len(all_tokens) * prop_tokens_to_load)]
                if vector_shuffle:
                    # shuffle a copy so the registered order is kept
                    tokens = copy.copy(tokens)
                    random.shuffle(tokens)
                text = text.replace(placeholder_token, " ".join(tokens))
        return text

    def replace_text_with_placeholder_tokens(self, text: Union[str, List[str]]) -> Union[str, List[str]]:
        """Replace the placeholder tokens in text with the original keywords.
        This function will be called in `self.decode`.

        Args:
            text (Union[str, List[str]]): The text to be processed.

        Returns:
            Union[str, List[str]]: The processed text.
        """
        if isinstance(text, list):
            return [self.replace_text_with_placeholder_tokens(item) for item in text]

        for placeholder_token, tokens in self.token_map.items():
            merged_tokens = " ".join(tokens)
            if merged_tokens in text:
                text = text.replace(merged_tokens, placeholder_token)
        return text

    def __call__(
        self,
        text: Union[str, List[str]],
        *args,
        vector_shuffle: bool = False,
        prop_tokens_to_load: float = 1.0,
        **kwargs,
    ):
        """The call function of the wrapper.

        Args:
            text (Union[str, List[str]]): The text to be tokenized.
            vector_shuffle (bool, optional): Whether to shuffle the vectors.
                Defaults to False.
            prop_tokens_to_load (float, optional): The proportion of tokens to
                be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0
            *args, **kwargs: The arguments for `self.wrapped.__call__`.

        Returns:
            Whatever the wrapped tokenizer returns for the expanded text.
        """
        replaced_text = self.replace_placeholder_tokens_in_text(
            text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
        )
        return self.wrapped.__call__(replaced_text, *args, **kwargs)

    def encode(self, text: Union[str, List[str]], *args, **kwargs):
        """Encode the passed text to token index.

        Note that this calls the wrapped tokenizer itself (its `__call__`),
        not its `encode` method.

        Args:
            text (Union[str, List[str]]): The text to be encode.
            *args, **kwargs: The arguments for `self.wrapped.__call__`.
        """
        replaced_text = self.replace_placeholder_tokens_in_text(text)
        return self.wrapped(replaced_text, *args, **kwargs)

    def decode(self, token_ids, return_raw: bool = False, *args, **kwargs) -> Union[str, List[str]]:
        """Decode the token index to text.

        Args:
            token_ids: The token index to be decoded.
            return_raw: Whether keep the placeholder token in the text.
                Defaults to False.
            *args, **kwargs: The arguments for `self.wrapped.decode`.

        Returns:
            Union[str, List[str]]: The decoded text.
        """
        text = self.wrapped.decode(token_ids, *args, **kwargs)
        if return_raw:
            return text
        return self.replace_text_with_placeholder_tokens(text)

    def __repr__(self):
        """The representation of the wrapper.

        Shows the wrapped tokenizer and the registered placeholder words.
        Uses only attributes this class sets itself, read from the instance
        dict so that a partially initialised object can still be printed.
        """
        state = self.__dict__
        return (
            f"{self.__class__.__name__}("
            f"wrapped={state.get('wrapped')!r}, "
            f"placeholder_tokens={list(state.get('token_map', {}))!r})"
        )
class EmbeddingLayerWithFixes(nn.Module):
    """The revised embedding layer to support external embeddings.

    Token ids at or beyond the wrapped layer's vocabulary size are looked up
    in externally supplied embeddings instead of the wrapped weight matrix.
    This design of this class is inspired by
    https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/22bcc7be428c94e9408f589966c2040187245d81/modules/sd_hijack.py#L224

    Args:
        wrapped (nn.Embedding): The embedding layer to be wrapped.
        external_embeddings (Union[dict, List[dict]], optional): The external
            embeddings added to this layer. Defaults to None.
    """

    def __init__(self, wrapped: nn.Embedding, external_embeddings: Optional[Union[dict, List[dict]]] = None):
        super().__init__()
        self.wrapped = wrapped
        self.num_embeddings = wrapped.weight.shape[0]
        self.external_embeddings = []
        # Must exist before add_embeddings() runs: trainable embeddings are
        # registered in this dict.
        self.trainable_embeddings = nn.ParameterDict()
        if external_embeddings:
            self.add_embeddings(external_embeddings)

    @property
    def weight(self):
        """Get the weight of wrapped embedding layer."""
        return self.wrapped.weight

    def check_duplicate_names(self, embeddings: List[dict]):
        """Check whether duplicate names exist in list of 'external
        embeddings'.

        Args:
            embeddings (List[dict]): A list of embedding to be check.

        Raises:
            AssertionError: If two embeddings share a name.
        """
        names = [emb["name"] for emb in embeddings]
        assert len(names) == len(set(names)), (
            "Found duplicated names in 'external_embeddings'. Name list: " f"'{names}'"
        )

    def check_ids_overlap(self, embeddings):
        """Check whether overlap exist in token ids of 'external_embeddings'.

        Args:
            embeddings (List[dict]): A list of embedding to be check.

        Raises:
            AssertionError: If two id ranges overlap.
        """
        # sort by 'start' (ties broken by 'end', then 'name')
        ids_range = sorted([emb["start"], emb["end"], emb["name"]] for emb in embeddings)
        # after sorting, an overlap shows up as an 'end' beyond the next 'start'
        for current, following in zip(ids_range, ids_range[1:]):
            name1, name2 = current[-1], following[-1]
            assert current[1] <= following[0], (
                f"Found ids overlapping between embeddings '{name1}' " f"and '{name2}'."
            )

    def add_embeddings(self, embeddings: Optional[Union[dict, List[dict]]]):
        """Add external embeddings to this layer.

        Args:
            embeddings (Union[dict, list[dict]]): The external embeddings to
                be added. Each dict must contain the following 4 fields: 'name'
                (the name of this embedding), 'embedding' (the embedding
                tensor), 'start' (the start token id of this embedding), 'end'
                (the end token id of this embedding). For example:
                `{name: NAME, start: START, end: END, embedding: torch.Tensor}`
                An optional truthy 'trainable' field turns the tensor into an
                `nn.Parameter` registered in `self.trainable_embeddings`.

        Raises:
            AssertionError: On duplicated names or overlapping id ranges; the
                layer is left unchanged in that case.
        """
        if not embeddings:
            return
        if isinstance(embeddings, dict):
            embeddings = [embeddings]

        # Validate the combined list first so a rejected call does not leave
        # invalid entries behind.
        combined = self.external_embeddings + list(embeddings)
        self.check_duplicate_names(combined)
        self.check_ids_overlap(combined)
        self.external_embeddings = combined

        # set for trainable
        added_trainable_emb_info = []
        for embedding in embeddings:
            if embedding.get("trainable", False):
                name = embedding["name"]
                embedding["embedding"] = torch.nn.Parameter(embedding["embedding"])
                self.trainable_embeddings[name] = embedding["embedding"]
                added_trainable_emb_info.append(name)

        added_emb_info = ", ".join(emb["name"] for emb in embeddings)
        print(f"Successfully add external embeddings: {added_emb_info}.")

        if added_trainable_emb_info:
            added_trainable_emb_info = ", ".join(added_trainable_emb_info)
            print("Successfully add trainable external embeddings: " f"{added_trainable_emb_info}")

    def replace_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Replace external input ids to 0.

        Args:
            input_ids (torch.Tensor): The input ids to be replaced.

        Returns:
            torch.Tensor: A copy in which every id >= `self.num_embeddings`
            is set to 0, so the wrapped layer can look it up.
        """
        input_ids_fwd = input_ids.clone()
        input_ids_fwd[input_ids_fwd >= self.num_embeddings] = 0
        return input_ids_fwd

    def replace_embeddings(
        self, input_ids: torch.Tensor, embedding: torch.Tensor, external_embedding: dict
    ) -> torch.Tensor:
        """Replace external embedding to the embedding layer. Noted that, in
        this function we use `torch.cat` to avoid inplace modification.

        Args:
            input_ids (torch.Tensor): The original token ids. Shape like
                [LENGTH, ].
            embedding (torch.Tensor): The embedding of token ids after
                `replace_input_ids` function.
            external_embedding (dict): The external embedding to be replaced.

        Returns:
            torch.Tensor: The replaced embedding.

        Raises:
            AssertionError: If the start id is found but is not followed by
                the full id range `start .. end - 1`.
        """
        name = external_embedding["name"]
        start = external_embedding["start"]
        end = external_embedding["end"]
        span = end - start
        target_ids_to_replace = list(range(start, end))
        ext_emb = external_embedding["embedding"].to(embedding.device)

        ids = input_ids.tolist()
        # do not need to replace
        if start not in ids:
            return embedding

        pieces = []
        seg_start, pos = 0, 0  # seg_start: first position not yet copied
        while pos < len(ids):
            if ids[pos] != start:
                pos += 1
                continue
            if pos > seg_start:
                # keep the untouched embeddings in front of the placeholder
                pieces.append(embedding[seg_start:pos])
            # check if the next embedding need to replace is valid
            actually_ids_to_replace = ids[pos: pos + span]
            assert actually_ids_to_replace == target_ids_to_replace, (
                f"Invalid 'input_ids' in position: {pos} to {pos + span}. "
                f"Expect '{target_ids_to_replace}' for embedding "
                f"'{name}' but found '{actually_ids_to_replace}'."
            )
            pieces.append(ext_emb)
            seg_start = pos + span
            # Resume right behind the replaced span so that a placeholder that
            # follows immediately is replaced too; always advance at least one.
            pos = max(seg_start, pos + 1)

        if seg_start < len(ids):
            pieces.append(embedding[seg_start:])
        return torch.cat(pieces, dim=0)

    def forward(self, input_ids: torch.Tensor, external_embeddings: Optional[List[dict]] = None, out_dtype = None):
        """The forward function.

        Args:
            input_ids (torch.Tensor): The token ids shape like [bz, LENGTH] or
                [LENGTH, ].
            external_embeddings (Optional[List[dict]]): The external
                embeddings. If not passed, only `self.external_embeddings`
                will be used. Defaults to None.
            out_dtype: Optional dtype for the result. It is forwarded to the
                wrapped layer when no external embedding is involved and
                applied to the stacked result otherwise.

        Returns:
            torch.Tensor: The embeddings; a 1-D input is treated as a batch
            of one.
        """
        assert input_ids.ndim in [1, 2]
        if input_ids.ndim == 1:
            input_ids = input_ids.unsqueeze(0)

        if external_embeddings is None and not self.external_embeddings:
            return self.wrapped(input_ids, out_dtype=out_dtype)

        input_ids_fwd = self.replace_input_ids(input_ids)
        inputs_embeds = self.wrapped(input_ids_fwd)

        if external_embeddings is None:
            external_embeddings = []
        elif isinstance(external_embeddings, dict):
            external_embeddings = [external_embeddings]
        embeddings = self.external_embeddings + external_embeddings

        vecs = []
        for input_id, embedding in zip(input_ids, inputs_embeds):
            new_embedding = embedding
            for external_embedding in embeddings:
                new_embedding = self.replace_embeddings(input_id, new_embedding, external_embedding)
            vecs.append(new_embedding)

        stacked = torch.stack(vecs)
        # only cast when a target dtype was actually requested
        return stacked if out_dtype is None else stacked.to(out_dtype)
def add_tokens(
    tokenizer, text_encoder, placeholder_tokens: list, initialize_tokens: list = None,
    num_vectors_per_token: int = 1
):
    """Add token for training.

    Registers every placeholder token on `tokenizer`, wraps the text
    encoder's token embedding in an `EmbeddingLayerWithFixes`, and adds one
    trainable external embedding per placeholder. The initial vectors are
    copied from the embedding of the matching entry in `initialize_tokens`
    when given, otherwise drawn uniformly from [-0.25, 0.25).

    # TODO: support add tokens as dict, then we can load pretrained tokens.
    """
    if initialize_tokens is not None:
        assert len(initialize_tokens) == len(
            placeholder_tokens
        ), "placeholder_token should be the same length as initialize_token"

    for placeholder in placeholder_tokens:
        tokenizer.add_placeholder_token(placeholder, num_vec_per_token=num_vectors_per_token)

    # swap the encoder's token embedding for the wrapping layer
    embeddings_module = text_encoder.text_model.embeddings
    original_layer = embeddings_module.token_embedding
    embeddings_module.token_embedding = EmbeddingLayerWithFixes(original_layer)
    embedding_layer = embeddings_module.token_embedding

    assert embedding_layer is not None, (
        "Do not support get embedding layer for current text encoder. " "Please check your configuration."
    )

    initialize_embedding = []
    for index, _ in enumerate(placeholder_tokens):
        if initialize_tokens is not None:
            # start from the embedding of the given initializer token, repeated per vector
            init_id = tokenizer(initialize_tokens[index]).input_ids[1]
            source_vector = embedding_layer.weight[init_id]
            initial = source_vector[None, ...].repeat(num_vectors_per_token, 1)
        else:
            # no initializer: random values, sized like an existing embedding row
            init_id = tokenizer("a").input_ids[1]
            source_vector = embedding_layer.weight[init_id]
            embedding_dim = source_vector.shape[0]
            initial = (torch.rand(num_vectors_per_token, embedding_dim) - 0.5) / 2.0
        initialize_embedding.append(initial)

    token_info_all = []
    for placeholder, initial in zip(placeholder_tokens, initialize_embedding):
        info = tokenizer.get_token_info(placeholder)
        info["embedding"] = initial
        info["trainable"] = True
        token_info_all.append(info)
    embedding_layer.add_embeddings(token_info_all)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff