Files
jaidaken f09734b0ee
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Includes 30 custom nodes committed directly, 7 Civitai-exclusive
loras stored via Git LFS, and a setup script that installs all
dependencies and downloads HuggingFace-hosted models on vast.ai.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 00:56:42 +00:00

785 lines
30 KiB
Python

import os
import comfy.model_management as mm
import comfy.model_patcher as mp
from comfy.utils import ProgressBar
from comfy.clip_vision import load as load_clip_vision
from comfy.clip_vision import clip_preprocess, Output
import latent_preview
import copy
import folder_paths
import torch
#from .xflux.src.flux.modules.layers import DoubleStreamBlockLoraProcessor, DoubleStreamBlockProcessor
#from .xflux.src.flux.model import Flux as ModFlux
from .xflux.src.flux.util import (configs, load_ae, load_clip,
load_flow_model, load_t5, load_safetensors, load_from_repo_id,
load_controlnet)
from .utils import (FirstHalfStrengthModel, FluxUpdateModules, LinearStrengthModel,
SecondHalfStrengthModel, SigmoidStrengthModel, attn_processors,
set_attn_processor,
is_model_pathched, merge_loras, LATENT_PROCESSOR_COMFY,
ControlNetContainer,
comfy_to_xlabs_lora, check_is_comfy_lora)
from .layers import (DoubleStreamBlockLoraProcessor,
DoubleStreamBlockProcessor,
DoubleStreamBlockLorasMixerProcessor,
DoubleStreamMixerProcessor,
IPProcessor,
ImageProjModel)
from .xflux.src.flux.model import Flux as ModFlux
#from .model_init import double_blocks_init, single_blocks_init
from comfy.utils import get_attr, set_attr
from .clip import FluxClipViT
# On-disk layout: <models_dir>/xlabs plus one sub-folder per kind of asset.
dir_xlabs = os.path.join(folder_paths.models_dir, "xlabs")
dir_xlabs_loras = os.path.join(dir_xlabs, "loras")
dir_xlabs_controlnets = os.path.join(dir_xlabs, "controlnets")
dir_xlabs_flux = os.path.join(dir_xlabs, "flux")
dir_xlabs_ipadapters = os.path.join(dir_xlabs, "ipadapters")

# Create the folders at import time (root first) so the node file pickers
# always have a directory to list.
for _xlabs_dir in (
    dir_xlabs,
    dir_xlabs_loras,
    dir_xlabs_controlnets,
    dir_xlabs_flux,
    dir_xlabs_ipadapters,
):
    os.makedirs(_xlabs_dir, exist_ok=True)

# Register each folder with ComfyUI's folder_paths registry under its own key,
# accepting the standard checkpoint extensions.
for _xlabs_key, _xlabs_dir in (
    ("xlabs", dir_xlabs),
    ("xlabs_loras", dir_xlabs_loras),
    ("xlabs_controlnets", dir_xlabs_controlnets),
    ("xlabs_ipadapters", dir_xlabs_ipadapters),
    ("xlabs_flux", dir_xlabs_flux),
):
    folder_paths.folder_names_and_paths[_xlabs_key] = (
        [_xlabs_dir],
        folder_paths.supported_pt_extensions,
    )

# The flux folder is registered a second time for its JSON files only.
folder_paths.folder_names_and_paths["xlabs_flux_json"] = ([dir_xlabs_flux], {'.json'})
from .sampling import get_noise, prepare, get_schedule, denoise, denoise_controlnet, unpack
import numpy as np
def load_flux_lora(path):
    """Load a LoRA checkpoint and guess its rank.

    Args:
        path: Location of the checkpoint file, or ``None``.  Paths containing
            ``.safetensors`` are read with ``load_safetensors``; anything else
            goes through ``torch.load`` onto the CPU.

    Returns:
        A ``(checkpoint, rank)`` tuple.  ``checkpoint`` is ``None`` when
        ``path`` is ``None``.  ``rank`` is the smallest dimension of the first
        two entries when both agree, otherwise the fallback value 16.
    """
    default_rank = 16

    if path is None:
        # Nothing to load: report it and let the caller decide what to do,
        # instead of crashing on ``None.keys()`` below.
        print("Invalid path")
        return None, default_rank

    if '.safetensors' in path:
        checkpoint = load_safetensors(path)
    else:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code.  Only load checkpoints from trusted sources (or
        # switch to ``weights_only=True`` once the minimum torch version
        # allows it).
        checkpoint = torch.load(path, map_location='cpu')

    keys = list(checkpoint.keys())
    if len(keys) < 2:
        # Not enough entries to cross-check a rank; use the fallback.
        return checkpoint, default_rank

    # The smallest dimension of each of the first two entries is taken as the
    # candidate rank; 0-d entries have no dimension and yield None.
    a1 = min(checkpoint[keys[0]].shape, default=None)
    a2 = min(checkpoint[keys[1]].shape, default=None)
    if a1 is not None and a1 == a2:
        return checkpoint, int(a1)
    return checkpoint, default_rank
def cleanprint(a):
    """Print ``a`` to stdout and return it unchanged.

    Lets a value be echoed for debugging while still being used inline in a
    larger expression.
    """
    print(a)
    return a
def print_if_not_empty(a):
    """Return the first ``(key, value)`` pair of mapping ``a``.

    When the mapping has no items the placeholder string ``"{}"`` is returned
    instead.  Despite the name, nothing is printed here; callers print the
    result themselves.
    """
    return next(iter(a.items()), "{}")
class LoadFluxLora:
    """ComfyUI node that attaches an XLabs Flux LoRA to a MODEL.

    The LoRA is applied as per-block attention processors registered through
    ``add_object_patch`` on a clone of the incoming model, so the input model
    patcher object itself is not given new patches.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: the model, a LoRA file and its strength."""
        return {"required": {
            "model": ("MODEL",),
            # cleanprint echoes the discovered file list to stdout.
            "lora_name": (cleanprint(folder_paths.get_filename_list("xlabs_loras")), ),
            "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
        }}

    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("MODEL",)
    FUNCTION = "loadmodel"
    CATEGORY = "XLabsNodes"

    def loadmodel(self, model, lora_name, strength_model):
        """Return a clone of ``model`` with the chosen LoRA patched in.

        Args:
            model: Incoming model patcher; only read and cloned here.
            lora_name: File name inside the ``xlabs/loras`` folder.
            strength_model: Weight handed to every LoRA processor.

        Returns:
            A one-element tuple holding the patched clone.
        """
        # Flip to True locally for verbose diagnostics about existing patches.
        debug = False
        device = mm.get_torch_device()

        is_patched = is_model_pathched(model.model)
        print(f"Is model already patched? {is_patched}")

        # Progress budget: five coarse stages, each worth ``mul`` ticks, plus
        # one tick per double block when the model still has to be patched
        # (FluxUpdateModules receives the bar for that purpose).
        mul = 1
        if is_patched:
            pbar = ProgressBar(5)
        else:
            mul = 3
            count = len(model.model.diffusion_model.double_blocks)
            pbar = ProgressBar(5*mul+count)

        patched = model.clone()
        base_model = patched.model

        if debug:
            print("\n", (print_if_not_empty(patched.object_patches_backup)), "\n___\n", (print_if_not_empty(patched.object_patches)), "\n")
            try:
                print(get_attr(base_model, "diffusion_model.double_blocks.0.processor.lora_weight"))
            except Exception:
                # Diagnostics only: the attribute is absent on unpatched models.
                pass

        pbar.update(mul)
        patched.model.to(device)
        checkpoint, lora_rank = load_flux_lora(os.path.join(dir_xlabs_loras, lora_name))
        pbar.update(mul)

        if not is_patched:
            print("We are patching diffusion model, be patient please")
            FluxUpdateModules(base_model, pbar)
        else:
            print("Model already updated")
        pbar.update(mul)

        # One mixer processor (wrapping a freshly loaded LoRA processor) per
        # attention-processor name exposed by the diffusion model.
        lora_attn_procs = {}
        if checkpoint is not None:
            if check_is_comfy_lora(checkpoint):
                checkpoint = comfy_to_xlabs_lora(checkpoint)
            for name, _ in attn_processors(base_model.diffusion_model).items():
                # Keys are expected as "<processor name>.<parameter>"; keep the
                # ones for this processor and strip that prefix.
                prefix = f"{name}."
                lora_state_dict = {
                    key[len(prefix):]: value
                    for key, value in checkpoint.items()
                    if key.startswith(prefix)
                }
                # dim=3072 is hard-coded — presumably the Flux hidden size; TODO confirm.
                lora_proc = DoubleStreamBlockLoraProcessor(
                    dim=3072, rank=lora_rank, lora_weight=strength_model)
                lora_proc.load_state_dict(lora_state_dict)
                lora_proc.to(device)
                mixer = DoubleStreamMixerProcessor()
                mixer.add_lora(lora_proc)
                lora_attn_procs[name] = mixer
        pbar.update(mul)

        if debug:
            try:
                first_proc = get_attr(base_model, "diffusion_model.double_blocks.0.processor")
                if isinstance(first_proc, DoubleStreamMixerProcessor):
                    weights = first_proc.lora_weight
                    if len(weights) > 0:
                        applied = "".join([f"{weight:.2f}, " for weight in weights])
                        print(f"Loras applied: {applied}")
            except Exception:
                # Diagnostics only; never let them break the node.
                pass

        for name, _ in attn_processors(base_model.diffusion_model).items():
            if name not in lora_attn_procs:
                # No checkpoint was loaded, so there is nothing to merge for
                # this processor; leave the clone untouched.
                continue
            attribute = f"diffusion_model.{name}"
            # Merge with whatever the *incoming* model already patched at this
            # attribute so several LoRA / IP-Adapter nodes can be chained.
            if attribute in model.object_patches:
                old = copy.copy(model.object_patches[attribute])
            else:
                old = None
            patched.add_object_patch(attribute, merge_loras(old, lora_attn_procs[name]))

        if debug:
            print("\n", (print_if_not_empty(patched.object_patches_backup)), "\n_b_\n", (print_if_not_empty(patched.object_patches)), "\n")
            print("\n", (print_if_not_empty(model.object_patches_backup)), "\n_m__\n", (print_if_not_empty(model.object_patches)), "\n")
            for _, patch in patched.object_patches.items():
                print(patch.lora_weight)
                break

        pbar.update(mul)
        return (patched,)
def load_checkpoint_controlnet(local_path):
    """Read a ControlNet checkpoint from ``local_path``.

    Paths containing ``.safetensors`` are read with ``load_safetensors``;
    everything else is read with ``torch.load`` onto the CPU.  A ``None`` path
    prints a notice and yields ``None``.
    """
    if local_path is None:
        print("Invalid path")
        return None
    if '.safetensors' in local_path:
        return load_safetensors(local_path)
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    return torch.load(local_path, map_location='cpu')
class LoadFluxControlNet:
    """ComfyUI node that builds a Flux ControlNet and loads its weights."""

    @classmethod
    def INPUT_TYPES(s):
        """Offer the base model flavour and the available ControlNet files."""
        flavours = ["flux-dev", "flux-dev-fp8", "flux-schnell"]
        weight_files = folder_paths.get_filename_list("xlabs_controlnets")
        return {"required": {"model_name": (flavours,),
                             "controlnet_path": (weight_files, ),
                             }}

    RETURN_TYPES = ("FluxControlNet",)
    RETURN_NAMES = ("ControlNet",)
    FUNCTION = "loadmodel"
    CATEGORY = "XLabsNodes"

    def loadmodel(self, model_name, controlnet_path):
        """Instantiate the ControlNet for ``model_name`` and load its weights.

        Returns a one-element tuple holding a dict with the network under
        ``"model"`` and the fixed label ``"canny"`` under ``"control_type"``.
        """
        net = load_controlnet(model_name, mm.get_torch_device())
        weights_file = os.path.join(dir_xlabs_controlnets, controlnet_path)
        state = load_checkpoint_controlnet(weights_file)
        if state is not None:
            net.load_state_dict(state)
        return ({"model": net, "control_type": "canny"},)
class ApplyFluxControlNet:
    """ComfyUI node that turns an image into a ControlNet condition entry.

    The entry is active over the whole schedule (``start`` 0.0, ``end`` 1.0).
    Entries can be chained by feeding a previous ``controlnet_condition`` in.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Describe the required inputs and the optional upstream condition."""
        return {"required": {
                    "controlnet": ("FluxControlNet",),
                    "image": ("IMAGE", ),
                    "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                },
                "optional": {
                    "controlnet_condition": ("ControlNetCondition", {"default": None}),
                }
        }

    RETURN_TYPES = ("ControlNetCondition",)
    RETURN_NAMES = ("controlnet_condition",)
    FUNCTION = "prepare"
    CATEGORY = "XLabsNodes"

    def prepare(self, controlnet, image, strength, controlnet_condition = None):
        """Build the condition list for the sampler.

        Args:
            controlnet: Dict produced by the loader node; its ``"model"`` entry
                is forwarded.
            image: Control image — assumes a (batch, height, width, channels)
                tensor with values in [0, 1]; TODO confirm against callers.
            strength: ControlNet strength stored in the entry.
            controlnet_condition: Optional list of earlier entries to extend.

        Returns:
            A one-element tuple with a new list of condition dicts; the
            upstream list is not modified.
        """
        device = mm.get_torch_device()
        # Rescale x -> 2x - 1 with torch itself: going through numpy fails for
        # tensors that are not on the CPU and costs an extra host copy.
        controlnet_image = torch.as_tensor(image) * 2 - 1
        # Move the last axis to position 1 (channels-first), then cast/move.
        controlnet_image = controlnet_image.permute(0, 3, 1, 2).to(torch.bfloat16).to(device)

        entry = {
            "img": controlnet_image,
            "controlnet_strength": strength,
            "model": controlnet["model"],
            "start": 0.0,
            "end": 1.0
        }
        if controlnet_condition is None:
            return ([entry],)
        return (controlnet_condition + [entry],)
class ApplyAdvancedFluxControlNet:
    """ComfyUI node that builds a ControlNet condition with a custom window.

    Same as the basic apply node, but the ``start``/``end`` fractions of the
    schedule during which the ControlNet is active are user supplied.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Describe the required inputs and the optional upstream condition."""
        return {
            "required": {
                "controlnet": ("FluxControlNet",),
                "image": ("IMAGE", ),
                "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                "start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
            },
            "optional": {
                "controlnet_condition": ("ControlNetCondition", {"default": None}),
            }
        }

    RETURN_TYPES = ("ControlNetCondition",)
    RETURN_NAMES = ("controlnet_condition",)
    FUNCTION = "prepare"
    CATEGORY = "XLabsNodes"

    def prepare(self, controlnet, image, strength, start, end, controlnet_condition = None):
        """Build the condition list for the sampler.

        Args:
            controlnet: Dict produced by the loader node; its ``"model"`` entry
                is forwarded.
            image: Control image — assumes a (batch, height, width, channels)
                tensor with values in [0, 1]; TODO confirm against callers.
            strength: ControlNet strength stored in the entry.
            start: Fraction of the schedule at which the entry becomes active.
            end: Fraction of the schedule at which the entry stops.
            controlnet_condition: Optional list of earlier entries to extend.

        Returns:
            A one-element tuple with a new list of condition dicts; the
            upstream list is not modified.
        """
        device = mm.get_torch_device()
        # Rescale x -> 2x - 1 with torch itself: going through numpy fails for
        # tensors that are not on the CPU and costs an extra host copy.
        controlnet_image = torch.as_tensor(image) * 2 - 1
        # Move the last axis to position 1 (channels-first), then cast/move.
        controlnet_image = controlnet_image.permute(0, 3, 1, 2).to(torch.bfloat16).to(device)

        entry = {
            "img": controlnet_image,
            "controlnet_strength": strength,
            "model": controlnet["model"],
            "start": start,
            "end": end
        }
        if controlnet_condition is None:
            return ([entry],)
        return (controlnet_condition + [entry],)
class XlabsSampler:
    """ComfyUI sampler node for XLabs-patched Flux models.

    Runs the XLabs ``denoise`` loop (or ``denoise_controlnet`` when ControlNet
    conditions are supplied) and returns the result as a LATENT dict.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Describe the sampler inputs shown in the UI."""
        return {
            "required": {
                "model": ("MODEL",),
                "conditioning": ("CONDITIONING",),
                "neg_conditioning": ("CONDITIONING",),
                "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                "steps": ("INT", {"default": 20, "min": 1, "max": 100}),
                "timestep_to_start_cfg": ("INT", {"default": 20, "min": 0, "max": 100}),
                "true_gs": ("FLOAT", {"default": 3, "min": 0, "max": 100}),
                "image_to_image_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
            },
            "optional": {
                "latent_image": ("LATENT", {"default": None}),
                "controlnet_condition": ("ControlNetCondition", {"default": None}),
            }
        }

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("latent",)
    FUNCTION = "sampling"
    CATEGORY = "XLabsNodes"

    def sampling(self, model, conditioning, neg_conditioning,
                 noise_seed, steps, timestep_to_start_cfg, true_gs,
                 image_to_image_strength, denoise_strength,
                 latent_image=None, controlnet_condition=None
                 ):
        """Sample a latent with the XLabs denoising loop.

        Args:
            model: Model patcher whose ``model.diffusion_model`` is sampled.
            conditioning: Positive conditioning; ``[0][0]`` and
                ``[0][1]['pooled_output']`` are read, plus an optional
                ``[0][1]['guidance']``.
            neg_conditioning: Negative conditioning, read the same way.
            noise_seed: Seed for ``torch.manual_seed`` and ``get_noise``.
            steps: Number of schedule steps requested from ``get_schedule``.
            timestep_to_start_cfg: Forwarded to the denoise function.
            true_gs: Forwarded to the denoise function.
            image_to_image_strength: Forwarded as ``image2image_strength``.
            denoise_strength: When <= 0.99 the schedule is cut to this
                fraction of its length.
            latent_image: LATENT dict; the shape of ``['samples']`` fixes the
                output size, and a 16-channel latent is also used as the
                image-to-image source.
            controlnet_condition: Optional list of ControlNet condition dicts.

        Returns:
            A one-element tuple holding ``{"samples": tensor}``.

        Raises:
            ValueError: If ``latent_image`` is not provided.
        """
        if latent_image is None:
            # The latent shape is the only source of the output resolution, so
            # fail early with a readable message instead of a TypeError later.
            raise ValueError(
                "XlabsSampler needs a latent_image input: its shape defines the output size"
            )

        mm.load_model_gpu(model)
        inmodel = model.model

        try:
            guidance = conditioning[0][1]['guidance']
        except (KeyError, IndexError, TypeError):
            # Conditioning carries no guidance value: use the neutral default.
            guidance = 1.0

        device = mm.get_torch_device()
        if torch.backends.mps.is_available():
            device = torch.device("mps")
        if torch.cuda.is_bf16_supported():
            dtype_model = torch.bfloat16
        else:
            dtype_model = torch.float16

        torch.manual_seed(noise_seed)

        bc, c, h, w = latent_image['samples'].shape
        height = (h//2) * 16
        width = (w//2) * 16

        x = get_noise(
            bc, height, width, device=device,
            dtype=dtype_model, seed=noise_seed
        )

        # Only 16-channel latents are converted into an image-to-image source.
        orig_x = None
        if c == 16:
            orig_x = LATENT_PROCESSOR_COMFY().go_back(latent_image['samples'])
            orig_x = orig_x.to(device, dtype=dtype_model)

        timesteps = get_schedule(
            steps,
            (width // 8) * (height // 8) // 4,
            shift=True,
        )

        try:
            inmodel.to(device)
        except Exception as exc:
            # Best effort: the diffusion model itself is moved right below.
            print(f"Could not move the full model to {device}: {exc}")
        x = x.to(device)
        inmodel.diffusion_model.to(device)

        inp_cond = prepare(conditioning[0][0], conditioning[0][1]['pooled_output'], img=x)
        neg_inp_cond = prepare(neg_conditioning[0][0], neg_conditioning[0][1]['pooled_output'], img=x)

        if denoise_strength <= 0.99:
            # Keep only the leading fraction of the schedule.
            timesteps = timesteps[:int(len(timesteps)*denoise_strength)]

        # for sampler preview
        x0_output = {}
        callback = latent_preview.prepare_callback(model, len(timesteps) - 1, x0_output)

        # Keyword arguments shared by both denoising entry points.
        common_kwargs = dict(
            timesteps=timesteps,
            guidance=guidance,
            timestep_to_start_cfg=timestep_to_start_cfg,
            neg_txt=neg_inp_cond['txt'],
            neg_txt_ids=neg_inp_cond['txt_ids'],
            neg_vec=neg_inp_cond['vec'],
            true_gs=true_gs,
            image2image_strength=image_to_image_strength,
            orig_image=orig_x,
            callback=callback,
            width=width,
            height=height,
        )

        if controlnet_condition is None:
            x = denoise(inmodel.diffusion_model, **inp_cond, **common_kwargs)
        else:
            containers = []
            for cond in controlnet_condition:
                controlnet = cond['model']
                # Resize the control image to the output resolution.
                control_image = torch.nn.functional.interpolate(
                    cond['img'], size=(height, width), scale_factor=None, mode='bicubic',)
                controlnet.to(device, dtype=dtype_model)
                control_image = control_image.to(device, dtype=dtype_model)
                # start/end are fractions of the (possibly shortened) schedule.
                start_step = int(cond['start']*len(timesteps))
                end_step = int(cond['end']*len(timesteps))
                containers.append(ControlNetContainer(
                    controlnet, control_image, cond['controlnet_strength'], start_step, end_step))

            mm.load_models_gpu([model,])
            x = denoise_controlnet(
                inmodel.diffusion_model, **inp_cond,
                controlnets_container=containers,
                **common_kwargs,
            )

        x = unpack(x, height, width)
        lat_processor = LATENT_PROCESSOR_COMFY()
        x = lat_processor(x)
        return ({"samples": x},)
class LoadFluxIPAdapter:
    """ComfyUI node that loads an XLabs Flux IP-Adapter and its CLIP vision model."""

    @classmethod
    def INPUT_TYPES(s):
        """Offer the IP-Adapter files, CLIP vision files and a provider choice."""
        return {"required": {
                    # Key spelling ("ipadatper") is part of the node interface.
                    "ipadatper": (folder_paths.get_filename_list("xlabs_ipadapters"),),
                    "clip_vision": (folder_paths.get_filename_list("clip_vision"),),
                    "provider": (["CPU", "GPU",],),
                }
        }

    RETURN_TYPES = ("IP_ADAPTER_FLUX",)
    RETURN_NAMES = ("ipadapterFlux",)
    FUNCTION = "loadmodel"
    CATEGORY = "XLabsNodes"

    def loadmodel(self, ipadatper, clip_vision, provider):
        """Load the adapter checkpoint and split it into its components.

        Args:
            ipadatper: File name inside the ``xlabs/ipadapters`` folder.
            clip_vision: File name inside ComfyUI's ``clip_vision`` folder.
            provider: Accepted from the UI; not read by this method.

        Returns:
            A one-element tuple holding a dict with the keys
            ``"clip_vision"``, ``"ip_adapter_proj_model"`` and
            ``"double_blocks"``.
        """
        pbar = ProgressBar(6)
        pbar.update(1)

        ret_ipa = {}
        ckpt = load_safetensors(os.path.join(dir_xlabs_ipadapters, ipadatper))
        pbar.update(1)

        path_clip = folder_paths.get_full_path("clip_vision", clip_vision)
        try:
            clip = FluxClipViT(path_clip)
        except Exception:
            # Fall back to ComfyUI's own CLIP vision loader when the file is
            # not usable by FluxClipViT.
            clip = load_clip_vision(path_clip).model
        ret_ipa["clip_vision"] = clip

        # Split the flat checkpoint into per-block processor weights and the
        # image projection weights, stripping their prefixes.
        block_prefix = "double_blocks."
        proj_prefix = "ip_adapter_proj_model"
        blocks = {}
        proj = {}
        for key, value in ckpt.items():
            if key.startswith(block_prefix):
                blocks[key[len(block_prefix):].replace('.processor.', '.')] = value
            if key.startswith(proj_prefix):
                # +1 skips the dot that follows the prefix.
                proj[key[len(proj_prefix) + 1:]] = value
        pbar.update(1)

        img_vec_in_dim = 768
        context_in_dim = 4096
        # The projection's output width tells 4-token adapters from 16-token ones.
        proj_rows = ckpt['ip_adapter_proj_model.proj.weight'].shape[0]
        num_ip_tokens = 4 if proj_rows // 4096 == 4 else 16

        improj = ImageProjModel(context_in_dim, img_vec_in_dim, num_ip_tokens)
        improj.load_state_dict(proj)
        pbar.update(1)
        ret_ipa["ip_adapter_proj_model"] = improj

        # 19 processors are created — presumably one per Flux double block; TODO confirm.
        ret_ipa["double_blocks"] = torch.nn.ModuleList([IPProcessor(4096, 3072) for _ in range(19)])
        ret_ipa["double_blocks"].load_state_dict(blocks)
        pbar.update(1)
        return (ret_ipa,)
class ApplyFluxIPAdapter:
    """ComfyUI node that patches a MODEL with an IP-Adapter at a fixed scale."""

    @classmethod
    def INPUT_TYPES(s):
        """Describe the inputs: model, loaded adapter, reference image, scale."""
        required = {
            "model": ("MODEL",),
            "ip_adapter_flux": ("IP_ADAPTER_FLUX",),
            "image": ("IMAGE",),
            "ip_scale": ("FLOAT", {"default": 0.93, "min": 0.0, "max": 1.0, "step": 0.001}),
        }
        return {"required": required}

    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("MODEL",)
    FUNCTION = "applymodel"
    CATEGORY = "XLabsNodes"

    def applymodel(self, model, ip_adapter_flux, image, ip_scale):
        """Return a clone of ``model`` whose processors include the IP-Adapter.

        The reference image (and an all-zero negative) is encoded by the
        adapter's CLIP vision model, projected, and wired into one
        ``IPProcessor`` per adapter block; each is merged with any processor
        the incoming model already patched at the same attribute.
        """
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()  # not used further; call kept as in the original flow

        is_patched = is_model_pathched(model.model)
        print(f"Is model already patched? {is_patched}")
        if is_patched:
            mul = 1
            pbar = ProgressBar(5)
        else:
            mul = 3
            pbar = ProgressBar(5 * mul + len(model.model.diffusion_model.double_blocks))

        patched = model.clone()
        base_model = patched.model

        vision = ip_adapter_flux['clip_vision']
        if isinstance(vision, FluxClipViT):
            vision_device = next(vision.model.parameters()).device
            image = torch.clip(image*255, 0.0, 255)
            out = vision(image).to(dtype=torch.bfloat16)
            neg_out = vision(torch.zeros_like(image)).to(dtype=torch.bfloat16)
        else:
            print("Using old vit clip")
            vision_device = next(vision.parameters()).device
            pixel_values = clip_preprocess(image.to(vision_device)).float()
            pos_result = vision(pixel_values=pixel_values)
            neg_result = vision(pixel_values=torch.zeros_like(pixel_values))
            # Index 2 of the encoder output is what gets projected below.
            neg_out = neg_result[2].to(dtype=torch.bfloat16)
            out = pos_result[2].to(dtype=torch.bfloat16)
        pbar.update(mul)

        if is_patched:
            print("Model already updated")
        else:
            print("We are patching diffusion model, be patient please")
            FluxUpdateModules(base_model, pbar)
            print("Patched succesfully!")
        pbar.update(mul)

        projector = ip_adapter_flux['ip_adapter_proj_model']
        projector_device = next(projector.parameters()).device
        projector.to(dtype=torch.bfloat16)
        ip_projes = projector(out.to(projector_device, dtype=torch.bfloat16)).to(device, dtype=torch.bfloat16)
        ip_neg_pr = projector(neg_out.to(projector_device, dtype=torch.bfloat16)).to(device, dtype=torch.bfloat16)

        # One mixer processor per adapter block, carrying the projected tokens.
        mixers = []
        for block in ip_adapter_flux['double_blocks']:
            ip_proc = IPProcessor(block.context_dim, block.hidden_dim, ip_projes, ip_scale)
            ip_proc.load_state_dict(block.state_dict())
            ip_proc.in_hidden_states_neg = ip_neg_pr
            ip_proc.in_hidden_states_pos = ip_projes
            ip_proc.to(dtype=torch.bfloat16)
            mixer = DoubleStreamMixerProcessor()
            mixer.add_ipadapter(ip_proc)
            mixers.append(mixer)
        pbar.update(mul)

        # Pair the n-th attention processor with the n-th adapter block.
        for index, (name, _) in enumerate(attn_processors(base_model.diffusion_model).items()):
            attribute = f"diffusion_model.{name}"
            previous = None
            if attribute in model.object_patches.keys():
                previous = copy.copy(model.object_patches[attribute])
            processor = merge_loras(previous, mixers[index])
            processor.to(device, dtype=torch.bfloat16)
            patched.add_object_patch(attribute, processor)
        pbar.update(mul)

        return (patched,)
class ApplyAdvancedFluxIPAdapter:
    """ComfyUI node that patches a MODEL with an IP-Adapter whose strength
    varies per block according to a selectable schedule."""

    @classmethod
    def INPUT_TYPES(s):
        """Describe the inputs: model, adapter, image, strength range, schedule."""
        required = {
            "model": ("MODEL",),
            "ip_adapter_flux": ("IP_ADAPTER_FLUX",),
            "image": ("IMAGE",),
            "begin_strength": ("FLOAT", {"default": 0.0, "min": -100.0, "max": 100.0, "step": 0.01}),
            "end_strength": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
            # Key spelling ("smothing_type") is part of the node interface.
            "smothing_type": (["Linear", "First half", "Second half", "Sigmoid"],),
        }
        return {"required": required}

    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("MODEL",)
    FUNCTION = "applymodel"
    CATEGORY = "XLabsNodes"

    def applymodel(self, model, ip_adapter_flux, image, begin_strength, end_strength, smothing_type):
        """Return a clone of ``model`` whose processors include the IP-Adapter.

        The CLIP embeddings of the reference image (and of an all-zero
        negative) are averaged over their first axis, projected, and wired
        into one ``IPProcessor`` per adapter block.  Block ``i`` receives the
        scale ``strength_model[i]`` from the schedule picked by
        ``smothing_type``.

        Raises:
            ValueError: If ``smothing_type`` is not one of the known schedules.
        """
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()  # not used further; call kept as in the original flow

        is_patched = is_model_pathched(model.model)
        print(f"Is model already patched? {is_patched}")
        if is_patched:
            mul = 1
            pbar = ProgressBar(5)
        else:
            mul = 3
            pbar = ProgressBar(5 * mul + len(model.model.diffusion_model.double_blocks))

        patched = model.clone()
        base_model = patched.model

        vision = ip_adapter_flux['clip_vision']
        if isinstance(vision, FluxClipViT):
            vision_device = next(vision.model.parameters()).device
            image = torch.clip(image*255, 0.0, 255)
            out = vision(image).to(dtype=torch.bfloat16)
            neg_out = vision(torch.zeros_like(image)).to(dtype=torch.bfloat16)
        else:
            print("Using old vit clip")
            vision_device = next(vision.parameters()).device
            pixel_values = clip_preprocess(image.to(vision_device)).float()
            pos_result = vision(pixel_values=pixel_values)
            neg_result = vision(pixel_values=torch.zeros_like(pixel_values))
            # Index 2 of the encoder output is what gets projected below.
            neg_out = neg_result[2].to(dtype=torch.bfloat16)
            out = pos_result[2].to(dtype=torch.bfloat16)
        pbar.update(mul)

        if is_patched:
            print("Model already updated")
        else:
            print("We are patching diffusion model, be patient please")
            FluxUpdateModules(base_model, pbar)
            print("Patched succesfully!")
        pbar.update(mul)

        projector = ip_adapter_flux['ip_adapter_proj_model']
        projector_device = next(projector.parameters()).device
        projector.to(dtype=torch.bfloat16)
        # Collapse the first axis of both embeddings before projecting.
        out = torch.mean(out, 0)
        neg_out = torch.mean(neg_out, 0)
        ip_projes = projector(out.to(projector_device, dtype=torch.bfloat16)).to(device, dtype=torch.bfloat16)
        ip_neg_pr = projector(neg_out.to(projector_device, dtype=torch.bfloat16)).to(device, dtype=torch.bfloat16)

        adapter_blocks = ip_adapter_flux['double_blocks']
        schedules = {
            "Linear": LinearStrengthModel,
            "First half": FirstHalfStrengthModel,
            "Second half": SecondHalfStrengthModel,
            "Sigmoid": SigmoidStrengthModel,
        }
        if smothing_type not in schedules:
            raise ValueError("Invalid smothing type")
        strength_model = schedules[smothing_type](begin_strength, end_strength, len(adapter_blocks))

        # One mixer processor per adapter block, each with its own scale.
        mixers = []
        for block_index, block in enumerate(adapter_blocks):
            ip_proc = IPProcessor(block.context_dim, block.hidden_dim, ip_projes, strength_model[block_index])
            ip_proc.load_state_dict(block.state_dict())
            ip_proc.in_hidden_states_neg = ip_neg_pr
            ip_proc.in_hidden_states_pos = ip_projes
            ip_proc.to(dtype=torch.bfloat16)
            mixer = DoubleStreamMixerProcessor()
            mixer.add_ipadapter(ip_proc)
            mixers.append(mixer)
        pbar.update(mul)

        # Pair the n-th attention processor with the n-th adapter block.
        for index, (name, _) in enumerate(attn_processors(base_model.diffusion_model).items()):
            attribute = f"diffusion_model.{name}"
            previous = None
            if attribute in model.object_patches.keys():
                previous = copy.copy(model.object_patches[attribute])
            processor = merge_loras(previous, mixers[index])
            processor.to(device, dtype=torch.bfloat16)
            patched.add_object_patch(attribute, processor)
        pbar.update(mul)

        return (patched,)
# Node identifier -> implementing class.  The identifiers are what this module
# exports as node names, so they are left exactly as they were.
NODE_CLASS_MAPPINGS = {
    "FluxLoraLoader": LoadFluxLora,
    "LoadFluxControlNet": LoadFluxControlNet,
    "ApplyFluxControlNet": ApplyFluxControlNet,
    "ApplyAdvancedFluxControlNet": ApplyAdvancedFluxControlNet,
    "XlabsSampler": XlabsSampler,
    "ApplyFluxIPAdapter": ApplyFluxIPAdapter,
    "LoadFluxIPAdapter": LoadFluxIPAdapter,
    "ApplyAdvancedFluxIPAdapter": ApplyAdvancedFluxIPAdapter,
}

# Node identifier -> human-readable label.
NODE_DISPLAY_NAME_MAPPINGS = {
    "FluxLoraLoader": "Load Flux LoRA",
    "LoadFluxControlNet": "Load Flux ControlNet",
    "ApplyFluxControlNet": "Apply Flux ControlNet",
    "ApplyAdvancedFluxControlNet": "Apply Advanced Flux ControlNet",
    "XlabsSampler": "Xlabs Sampler",
    "ApplyFluxIPAdapter": "Apply Flux IPAdapter",
    # Label typo fixed ("IPAdatpter" -> "IPAdapter"); the identifier is unchanged.
    "LoadFluxIPAdapter": "Load Flux IPAdapter",
    "ApplyAdvancedFluxIPAdapter": "Apply Advanced Flux IPAdapter",
}