Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Includes 30 custom nodes committed directly, 7 Civitai-exclusive loras stored via Git LFS, and a setup script that installs all dependencies and downloads HuggingFace-hosted models on vast.ai. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
784
custom_nodes/x-flux-comfyui/nodes.py
Normal file
784
custom_nodes/x-flux-comfyui/nodes.py
Normal file
@@ -0,0 +1,784 @@
|
||||
import os
|
||||
|
||||
import comfy.model_management as mm
|
||||
import comfy.model_patcher as mp
|
||||
from comfy.utils import ProgressBar
|
||||
from comfy.clip_vision import load as load_clip_vision
|
||||
from comfy.clip_vision import clip_preprocess, Output
|
||||
import latent_preview
|
||||
import copy
|
||||
|
||||
import folder_paths
|
||||
|
||||
import torch
|
||||
#from .xflux.src.flux.modules.layers import DoubleStreamBlockLoraProcessor, DoubleStreamBlockProcessor
|
||||
#from .xflux.src.flux.model import Flux as ModFlux
|
||||
|
||||
from .xflux.src.flux.util import (configs, load_ae, load_clip,
|
||||
load_flow_model, load_t5, load_safetensors, load_from_repo_id,
|
||||
load_controlnet)
|
||||
|
||||
|
||||
from .utils import (FirstHalfStrengthModel, FluxUpdateModules, LinearStrengthModel,
|
||||
SecondHalfStrengthModel, SigmoidStrengthModel, attn_processors,
|
||||
set_attn_processor,
|
||||
is_model_pathched, merge_loras, LATENT_PROCESSOR_COMFY,
|
||||
ControlNetContainer,
|
||||
comfy_to_xlabs_lora, check_is_comfy_lora)
|
||||
from .layers import (DoubleStreamBlockLoraProcessor,
|
||||
DoubleStreamBlockProcessor,
|
||||
DoubleStreamBlockLorasMixerProcessor,
|
||||
DoubleStreamMixerProcessor,
|
||||
IPProcessor,
|
||||
ImageProjModel)
|
||||
from .xflux.src.flux.model import Flux as ModFlux
|
||||
#from .model_init import double_blocks_init, single_blocks_init
|
||||
|
||||
|
||||
from comfy.utils import get_attr, set_attr
|
||||
from .clip import FluxClipViT
|
||||
|
||||
|
||||
dir_xlabs = os.path.join(folder_paths.models_dir, "xlabs")
|
||||
os.makedirs(dir_xlabs, exist_ok=True)
|
||||
dir_xlabs_loras = os.path.join(dir_xlabs, "loras")
|
||||
os.makedirs(dir_xlabs_loras, exist_ok=True)
|
||||
dir_xlabs_controlnets = os.path.join(dir_xlabs, "controlnets")
|
||||
os.makedirs(dir_xlabs_controlnets, exist_ok=True)
|
||||
dir_xlabs_flux = os.path.join(dir_xlabs, "flux")
|
||||
os.makedirs(dir_xlabs_flux, exist_ok=True)
|
||||
dir_xlabs_ipadapters = os.path.join(dir_xlabs, "ipadapters")
|
||||
os.makedirs(dir_xlabs_ipadapters, exist_ok=True)
|
||||
|
||||
|
||||
folder_paths.folder_names_and_paths["xlabs"] = ([dir_xlabs], folder_paths.supported_pt_extensions)
|
||||
folder_paths.folder_names_and_paths["xlabs_loras"] = ([dir_xlabs_loras], folder_paths.supported_pt_extensions)
|
||||
folder_paths.folder_names_and_paths["xlabs_controlnets"] = ([dir_xlabs_controlnets], folder_paths.supported_pt_extensions)
|
||||
folder_paths.folder_names_and_paths["xlabs_ipadapters"] = ([dir_xlabs_ipadapters], folder_paths.supported_pt_extensions)
|
||||
folder_paths.folder_names_and_paths["xlabs_flux"] = ([dir_xlabs_flux], folder_paths.supported_pt_extensions)
|
||||
folder_paths.folder_names_and_paths["xlabs_flux_json"] = ([dir_xlabs_flux], set({'.json',}))
|
||||
|
||||
|
||||
|
||||
from .sampling import get_noise, prepare, get_schedule, denoise, denoise_controlnet, unpack
|
||||
import numpy as np
|
||||
|
||||
def load_flux_lora(path):
|
||||
if path is not None:
|
||||
if '.safetensors' in path:
|
||||
checkpoint = load_safetensors(path)
|
||||
else:
|
||||
checkpoint = torch.load(path, map_location='cpu')
|
||||
else:
|
||||
checkpoint = None
|
||||
print("Invalid path")
|
||||
a1 = sorted(list(checkpoint[list(checkpoint.keys())[0]].shape))[0]
|
||||
a2 = sorted(list(checkpoint[list(checkpoint.keys())[1]].shape))[0]
|
||||
if a1==a2:
|
||||
return checkpoint, int(a1)
|
||||
return checkpoint, 16
|
||||
|
||||
def cleanprint(a):
|
||||
print(a)
|
||||
return a
|
||||
|
||||
def print_if_not_empty(a):
|
||||
b = list(a.items())
|
||||
if len(b)<1:
|
||||
return "{}"
|
||||
return b[0]
|
||||
class LoadFluxLora:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"lora_name": (cleanprint(folder_paths.get_filename_list("xlabs_loras")), ),
|
||||
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
RETURN_NAMES = ("MODEL",)
|
||||
FUNCTION = "loadmodel"
|
||||
CATEGORY = "XLabsNodes"
|
||||
|
||||
def loadmodel(self, model, lora_name, strength_model):
|
||||
debug=False
|
||||
|
||||
|
||||
device=mm.get_torch_device()
|
||||
offload_device=mm.unet_offload_device()
|
||||
|
||||
is_patched = is_model_pathched(model.model)
|
||||
|
||||
print(f"Is model already patched? {is_patched}")
|
||||
mul = 1
|
||||
if is_patched:
|
||||
pbar = ProgressBar(5)
|
||||
else:
|
||||
mul = 3
|
||||
count = len(model.model.diffusion_model.double_blocks)
|
||||
pbar = ProgressBar(5*mul+count)
|
||||
|
||||
bi = model.clone()
|
||||
tyanochky = bi.model
|
||||
|
||||
if debug:
|
||||
print("\n", (print_if_not_empty(bi.object_patches_backup)), "\n___\n", (print_if_not_empty(bi.object_patches)), "\n")
|
||||
try:
|
||||
print(get_attr(tyanochky, "diffusion_model.double_blocks.0.processor.lora_weight"))
|
||||
except:
|
||||
pass
|
||||
|
||||
pbar.update(mul)
|
||||
bi.model.to(device)
|
||||
checkpoint, lora_rank = load_flux_lora(os.path.join(dir_xlabs_loras, lora_name))
|
||||
pbar.update(mul)
|
||||
if not is_patched:
|
||||
print("We are patching diffusion model, be patient please")
|
||||
patches=FluxUpdateModules(tyanochky, pbar)
|
||||
#set_attn_processor(model.model.diffusion_model, DoubleStreamBlockProcessor())
|
||||
else:
|
||||
print("Model already updated")
|
||||
pbar.update(mul)
|
||||
#TYANOCHKYBY=16
|
||||
|
||||
lora_attn_procs = {}
|
||||
if checkpoint is not None:
|
||||
if check_is_comfy_lora(checkpoint):
|
||||
checkpoint = comfy_to_xlabs_lora(checkpoint)
|
||||
#cached_proccesors = attn_processors(tyanochky.diffusion_model).items()
|
||||
for name, _ in attn_processors(tyanochky.diffusion_model).items():
|
||||
lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(
|
||||
dim=3072, rank=lora_rank, lora_weight=strength_model)
|
||||
lora_state_dict = {}
|
||||
for k in checkpoint.keys():
|
||||
if name in k:
|
||||
lora_state_dict[k[len(name) + 1:]] = checkpoint[k]
|
||||
lora_attn_procs[name].load_state_dict(lora_state_dict)
|
||||
lora_attn_procs[name].to(device)
|
||||
tmp=DoubleStreamMixerProcessor()
|
||||
tmp.add_lora(lora_attn_procs[name])
|
||||
lora_attn_procs[name]=tmp
|
||||
pbar.update(mul)
|
||||
#set_attn_processor(tyanochky.diffusion_model, lora_attn_procs)
|
||||
if debug:
|
||||
try:
|
||||
if isinstance(
|
||||
get_attr(tyanochky, "diffusion_model.double_blocks.0.processor"),
|
||||
DoubleStreamMixerProcessor
|
||||
):
|
||||
pedovki = get_attr(tyanochky, "diffusion_model.double_blocks.0.processor").lora_weight
|
||||
if len(pedovki)>0:
|
||||
altushki="".join([f"{pedov:.2f}, " for pedov in pedovki])
|
||||
print(f"Loras applied: {altushki}")
|
||||
except:
|
||||
pass
|
||||
|
||||
for name, _ in attn_processors(tyanochky.diffusion_model).items():
|
||||
attribute = f"diffusion_model.{name}"
|
||||
#old = copy.copy(get_attr(bi.model, attribute))
|
||||
if attribute in model.object_patches.keys():
|
||||
old = copy.copy((model.object_patches[attribute]))
|
||||
else:
|
||||
old = None
|
||||
lora = merge_loras(old, lora_attn_procs[name])
|
||||
bi.add_object_patch(attribute, lora)
|
||||
|
||||
|
||||
if debug:
|
||||
print("\n", (print_if_not_empty(bi.object_patches_backup)), "\n_b_\n", (print_if_not_empty(bi.object_patches)), "\n")
|
||||
print("\n", (print_if_not_empty(model.object_patches_backup)), "\n_m__\n", (print_if_not_empty(model.object_patches)), "\n")
|
||||
|
||||
for _, b in bi.object_patches.items():
|
||||
print(b.lora_weight)
|
||||
break
|
||||
|
||||
#print(get_attr(tyanochky, "diffusion_model.double_blocks.0.processor"))
|
||||
pbar.update(mul)
|
||||
return (bi,)
|
||||
|
||||
def load_checkpoint_controlnet(local_path):
|
||||
if local_path is not None:
|
||||
if '.safetensors' in local_path:
|
||||
checkpoint = load_safetensors(local_path)
|
||||
else:
|
||||
checkpoint = torch.load(local_path, map_location='cpu')
|
||||
else:
|
||||
checkpoint=None
|
||||
print("Invalid path")
|
||||
return checkpoint
|
||||
|
||||
class LoadFluxControlNet:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"model_name": (["flux-dev", "flux-dev-fp8", "flux-schnell"],),
|
||||
"controlnet_path": (folder_paths.get_filename_list("xlabs_controlnets"), ),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("FluxControlNet",)
|
||||
RETURN_NAMES = ("ControlNet",)
|
||||
FUNCTION = "loadmodel"
|
||||
CATEGORY = "XLabsNodes"
|
||||
|
||||
def loadmodel(self, model_name, controlnet_path):
|
||||
device=mm.get_torch_device()
|
||||
|
||||
controlnet = load_controlnet(model_name, device)
|
||||
checkpoint = load_checkpoint_controlnet(os.path.join(dir_xlabs_controlnets, controlnet_path))
|
||||
if checkpoint is not None:
|
||||
controlnet.load_state_dict(checkpoint)
|
||||
control_type = "canny"
|
||||
ret_controlnet = {
|
||||
"model": controlnet,
|
||||
"control_type": control_type,
|
||||
}
|
||||
return (ret_controlnet,)
|
||||
|
||||
class ApplyFluxControlNet:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"controlnet": ("FluxControlNet",),
|
||||
"image": ("IMAGE", ),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
},
|
||||
"optional": {
|
||||
"controlnet_condition": ("ControlNetCondition", {"default": None}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("ControlNetCondition",)
|
||||
RETURN_NAMES = ("controlnet_condition",)
|
||||
FUNCTION = "prepare"
|
||||
CATEGORY = "XLabsNodes"
|
||||
|
||||
def prepare(self, controlnet, image, strength, controlnet_condition = None):
|
||||
device=mm.get_torch_device()
|
||||
controlnet_image = torch.from_numpy((np.array(image) * 2) - 1)
|
||||
controlnet_image = controlnet_image.permute(0, 3, 1, 2).to(torch.bfloat16).to(device)
|
||||
|
||||
if controlnet_condition is None:
|
||||
ret_cont = [{
|
||||
"img": controlnet_image,
|
||||
"controlnet_strength": strength,
|
||||
"model": controlnet["model"],
|
||||
"start": 0.0,
|
||||
"end": 1.0
|
||||
}]
|
||||
else:
|
||||
ret_cont = controlnet_condition+[{
|
||||
"img": controlnet_image,
|
||||
"controlnet_strength": strength,
|
||||
"model": controlnet["model"],
|
||||
"start": 0.0,
|
||||
"end": 1.0
|
||||
}]
|
||||
return (ret_cont,)
|
||||
|
||||
class ApplyAdvancedFluxControlNet:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"controlnet": ("FluxControlNet",),
|
||||
"image": ("IMAGE", ),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
"end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
||||
},
|
||||
"optional": {
|
||||
"controlnet_condition": ("ControlNetCondition", {"default": None}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("ControlNetCondition",)
|
||||
RETURN_NAMES = ("controlnet_condition",)
|
||||
FUNCTION = "prepare"
|
||||
CATEGORY = "XLabsNodes"
|
||||
|
||||
def prepare(self, controlnet, image, strength, start, end, controlnet_condition = None):
|
||||
|
||||
device=mm.get_torch_device()
|
||||
controlnet_image = torch.from_numpy((np.array(image) * 2) - 1)
|
||||
controlnet_image = controlnet_image.permute(0, 3, 1, 2).to(torch.bfloat16).to(device)
|
||||
|
||||
ret_cont = {
|
||||
"img": controlnet_image,
|
||||
"controlnet_strength": strength,
|
||||
"model": controlnet["model"],
|
||||
"start": start,
|
||||
"end": end
|
||||
}
|
||||
if controlnet_condition is None:
|
||||
ret_cont = [ret_cont]
|
||||
else:
|
||||
ret_cont = controlnet_condition+[ret_cont]
|
||||
return (ret_cont,)
|
||||
|
||||
class XlabsSampler:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
"conditioning": ("CONDITIONING",),
|
||||
"neg_conditioning": ("CONDITIONING",),
|
||||
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
||||
"steps": ("INT", {"default": 20, "min": 1, "max": 100}),
|
||||
"timestep_to_start_cfg": ("INT", {"default": 20, "min": 0, "max": 100}),
|
||||
"true_gs": ("FLOAT", {"default": 3, "min": 0, "max": 100}),
|
||||
"image_to_image_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
"denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
},
|
||||
"optional": {
|
||||
"latent_image": ("LATENT", {"default": None}),
|
||||
"controlnet_condition": ("ControlNetCondition", {"default": None}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
RETURN_NAMES = ("latent",)
|
||||
FUNCTION = "sampling"
|
||||
CATEGORY = "XLabsNodes"
|
||||
|
||||
def sampling(self, model, conditioning, neg_conditioning,
|
||||
noise_seed, steps, timestep_to_start_cfg, true_gs,
|
||||
image_to_image_strength, denoise_strength,
|
||||
latent_image=None, controlnet_condition=None
|
||||
):
|
||||
additional_steps = 11 if controlnet_condition is None else 12
|
||||
mm.load_model_gpu(model)
|
||||
inmodel = model.model
|
||||
#print(conditioning[0][0].shape) #//t5
|
||||
#print(conditioning[0][1]['pooled_output'].shape) #//clip
|
||||
#print(latent_image['samples'].shape) #// torch.Size([1, 4, 64, 64]) // bc, 4, w//8, h//8
|
||||
try:
|
||||
guidance = conditioning[0][1]['guidance']
|
||||
except:
|
||||
guidance = 1.0
|
||||
|
||||
device=mm.get_torch_device()
|
||||
if torch.backends.mps.is_available():
|
||||
device = torch.device("mps")
|
||||
if torch.cuda.is_bf16_supported():
|
||||
dtype_model = torch.bfloat16
|
||||
else:
|
||||
dtype_model = torch.float16
|
||||
#dtype_model = torch.bfloat16#model.model.diffusion_model.img_in.weight.dtype
|
||||
offload_device=mm.unet_offload_device()
|
||||
|
||||
torch.manual_seed(noise_seed)
|
||||
|
||||
bc, c, h, w = latent_image['samples'].shape
|
||||
height = (h//2) * 16
|
||||
width = (w//2) * 16
|
||||
|
||||
x = get_noise(
|
||||
bc, height, width, device=device,
|
||||
dtype=dtype_model, seed=noise_seed
|
||||
)
|
||||
orig_x = None
|
||||
if c==16:
|
||||
orig_x=latent_image['samples']
|
||||
lat_processor2 = LATENT_PROCESSOR_COMFY()
|
||||
orig_x=lat_processor2.go_back(orig_x)
|
||||
orig_x=orig_x.to(device, dtype=dtype_model)
|
||||
|
||||
|
||||
timesteps = get_schedule(
|
||||
steps,
|
||||
(width // 8) * (height // 8) // 4,
|
||||
shift=True,
|
||||
)
|
||||
try:
|
||||
inmodel.to(device)
|
||||
except:
|
||||
pass
|
||||
x.to(device)
|
||||
|
||||
inmodel.diffusion_model.to(device)
|
||||
inp_cond = prepare(conditioning[0][0], conditioning[0][1]['pooled_output'], img=x)
|
||||
neg_inp_cond = prepare(neg_conditioning[0][0], neg_conditioning[0][1]['pooled_output'], img=x)
|
||||
|
||||
if denoise_strength<=0.99:
|
||||
try:
|
||||
timesteps=timesteps[:int(len(timesteps)*denoise_strength)]
|
||||
except:
|
||||
pass
|
||||
# for sampler preview
|
||||
x0_output = {}
|
||||
callback = latent_preview.prepare_callback(model, len(timesteps) - 1, x0_output)
|
||||
|
||||
if controlnet_condition is None:
|
||||
x = denoise(
|
||||
inmodel.diffusion_model, **inp_cond, timesteps=timesteps, guidance=guidance,
|
||||
timestep_to_start_cfg=timestep_to_start_cfg,
|
||||
neg_txt=neg_inp_cond['txt'],
|
||||
neg_txt_ids=neg_inp_cond['txt_ids'],
|
||||
neg_vec=neg_inp_cond['vec'],
|
||||
true_gs=true_gs,
|
||||
image2image_strength=image_to_image_strength,
|
||||
orig_image=orig_x,
|
||||
callback=callback,
|
||||
width=width,
|
||||
height=height,
|
||||
)
|
||||
|
||||
else:
|
||||
def prepare_controlnet_condition(controlnet_condition):
|
||||
controlnet = controlnet_condition['model']
|
||||
controlnet_image = controlnet_condition['img']
|
||||
controlnet_image = torch.nn.functional.interpolate(
|
||||
controlnet_image, size=(height, width), scale_factor=None, mode='bicubic',)
|
||||
controlnet_strength = controlnet_condition['controlnet_strength']
|
||||
controlnet_start = controlnet_condition['start']
|
||||
controlnet_end = controlnet_condition['end']
|
||||
controlnet.to(device, dtype=dtype_model)
|
||||
controlnet_image=controlnet_image.to(device, dtype=dtype_model)
|
||||
return {
|
||||
"img": controlnet_image,
|
||||
"controlnet_strength": controlnet_strength,
|
||||
"model": controlnet,
|
||||
"start": controlnet_start,
|
||||
"end": controlnet_end,
|
||||
}
|
||||
|
||||
|
||||
cnet_conditions = [prepare_controlnet_condition(el) for el in controlnet_condition]
|
||||
containers = []
|
||||
for el in cnet_conditions:
|
||||
start_step = int(el['start']*len(timesteps))
|
||||
end_step = int(el['end']*len(timesteps))
|
||||
container = ControlNetContainer(el['model'], el['img'], el['controlnet_strength'], start_step, end_step)
|
||||
containers.append(container)
|
||||
|
||||
mm.load_models_gpu([model,])
|
||||
#mm.load_model_gpu(controlnet)
|
||||
|
||||
total_steps = len(timesteps)
|
||||
|
||||
x = denoise_controlnet(
|
||||
inmodel.diffusion_model, **inp_cond,
|
||||
controlnets_container=containers,
|
||||
timesteps=timesteps, guidance=guidance,
|
||||
#controlnet_cond=controlnet_image,
|
||||
timestep_to_start_cfg=timestep_to_start_cfg,
|
||||
neg_txt=neg_inp_cond['txt'],
|
||||
neg_txt_ids=neg_inp_cond['txt_ids'],
|
||||
neg_vec=neg_inp_cond['vec'],
|
||||
true_gs=true_gs,
|
||||
#controlnet_gs=controlnet_strength,
|
||||
image2image_strength=image_to_image_strength,
|
||||
orig_image=orig_x,
|
||||
callback=callback,
|
||||
width=width,
|
||||
height=height,
|
||||
#controlnet_start_step=start_step,
|
||||
#controlnet_end_step=end_step
|
||||
)
|
||||
#controlnet.to(offload_device)
|
||||
|
||||
x = unpack(x, height, width)
|
||||
lat_processor = LATENT_PROCESSOR_COMFY()
|
||||
x = lat_processor(x)
|
||||
lat_ret = {"samples": x}
|
||||
|
||||
#model.model.to(offload_device)
|
||||
return (lat_ret,)
|
||||
|
||||
|
||||
|
||||
class LoadFluxIPAdapter:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"ipadatper": (folder_paths.get_filename_list("xlabs_ipadapters"),),
|
||||
"clip_vision": (folder_paths.get_filename_list("clip_vision"),),
|
||||
"provider": (["CPU", "GPU",],),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("IP_ADAPTER_FLUX",)
|
||||
RETURN_NAMES = ("ipadapterFlux",)
|
||||
FUNCTION = "loadmodel"
|
||||
CATEGORY = "XLabsNodes"
|
||||
|
||||
def loadmodel(self, ipadatper, clip_vision, provider):
|
||||
pbar = ProgressBar(6)
|
||||
device=mm.get_torch_device()
|
||||
offload_device=mm.unet_offload_device()
|
||||
pbar.update(1)
|
||||
ret_ipa = {}
|
||||
path = os.path.join(dir_xlabs_ipadapters, ipadatper)
|
||||
ckpt = load_safetensors(path)
|
||||
pbar.update(1)
|
||||
path_clip = folder_paths.get_full_path("clip_vision", clip_vision)
|
||||
|
||||
try:
|
||||
clip = FluxClipViT(path_clip)
|
||||
except:
|
||||
clip = load_clip_vision(path_clip).model
|
||||
|
||||
ret_ipa["clip_vision"] = clip
|
||||
prefix = "double_blocks."
|
||||
blocks = {}
|
||||
proj = {}
|
||||
for key, value in ckpt.items():
|
||||
if key.startswith(prefix):
|
||||
blocks[key[len(prefix):].replace('.processor.', '.')] = value
|
||||
if key.startswith("ip_adapter_proj_model"):
|
||||
proj[key[len("ip_adapter_proj_model."):]] = value
|
||||
pbar.update(1)
|
||||
img_vec_in_dim=768
|
||||
context_in_dim=4096
|
||||
num_ip_tokens=16
|
||||
if ckpt['ip_adapter_proj_model.proj.weight'].shape[0]//4096==4:
|
||||
num_ip_tokens=4
|
||||
else:
|
||||
num_ip_tokens=16
|
||||
improj = ImageProjModel(context_in_dim, img_vec_in_dim, num_ip_tokens)
|
||||
improj.load_state_dict(proj)
|
||||
pbar.update(1)
|
||||
ret_ipa["ip_adapter_proj_model"] = improj
|
||||
|
||||
ret_ipa["double_blocks"] = torch.nn.ModuleList([IPProcessor(4096, 3072) for i in range(19)])
|
||||
ret_ipa["double_blocks"].load_state_dict(blocks)
|
||||
pbar.update(1)
|
||||
return (ret_ipa,)
|
||||
|
||||
|
||||
|
||||
class ApplyFluxIPAdapter:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"ip_adapter_flux": ("IP_ADAPTER_FLUX",),
|
||||
"image": ("IMAGE",),
|
||||
"ip_scale": ("FLOAT", {"default": 0.93, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
#"text_scale": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
RETURN_NAMES = ("MODEL",)
|
||||
FUNCTION = "applymodel"
|
||||
CATEGORY = "XLabsNodes"
|
||||
|
||||
def applymodel(self, model, ip_adapter_flux, image, ip_scale):
|
||||
debug=False
|
||||
|
||||
|
||||
device=mm.get_torch_device()
|
||||
offload_device=mm.unet_offload_device()
|
||||
|
||||
is_patched = is_model_pathched(model.model)
|
||||
|
||||
print(f"Is model already patched? {is_patched}")
|
||||
mul = 1
|
||||
if is_patched:
|
||||
pbar = ProgressBar(5)
|
||||
else:
|
||||
mul = 3
|
||||
count = len(model.model.diffusion_model.double_blocks)
|
||||
pbar = ProgressBar(5*mul+count)
|
||||
|
||||
bi = model.clone()
|
||||
tyanochky = bi.model
|
||||
|
||||
clip = ip_adapter_flux['clip_vision']
|
||||
|
||||
if isinstance(clip, FluxClipViT):
|
||||
#torch.Size([1, 526, 526, 3])
|
||||
#image = torch.permute(image, (0, ))
|
||||
#print(image.shape)
|
||||
#print(image)
|
||||
clip_device = next(clip.model.parameters()).device
|
||||
image = torch.clip(image*255, 0.0, 255)
|
||||
out = clip(image).to(dtype=torch.bfloat16)
|
||||
neg_out = clip(torch.zeros_like(image)).to(dtype=torch.bfloat16)
|
||||
else:
|
||||
print("Using old vit clip")
|
||||
clip_device = next(clip.parameters()).device
|
||||
pixel_values = clip_preprocess(image.to(clip_device)).float()
|
||||
out = clip(pixel_values=pixel_values)
|
||||
neg_out = clip(pixel_values=torch.zeros_like(pixel_values))
|
||||
neg_out = neg_out[2].to(dtype=torch.bfloat16)
|
||||
out = out[2].to(dtype=torch.bfloat16)
|
||||
pbar.update(mul)
|
||||
if not is_patched:
|
||||
print("We are patching diffusion model, be patient please")
|
||||
patches=FluxUpdateModules(tyanochky, pbar)
|
||||
print("Patched succesfully!")
|
||||
else:
|
||||
print("Model already updated")
|
||||
pbar.update(mul)
|
||||
|
||||
#TYANOCHKYBY=16
|
||||
ip_projes_dev = next(ip_adapter_flux['ip_adapter_proj_model'].parameters()).device
|
||||
ip_adapter_flux['ip_adapter_proj_model'].to(dtype=torch.bfloat16)
|
||||
ip_projes = ip_adapter_flux['ip_adapter_proj_model'](out.to(ip_projes_dev, dtype=torch.bfloat16)).to(device, dtype=torch.bfloat16)
|
||||
ip_neg_pr = ip_adapter_flux['ip_adapter_proj_model'](neg_out.to(ip_projes_dev, dtype=torch.bfloat16)).to(device, dtype=torch.bfloat16)
|
||||
|
||||
|
||||
ipad_blocks = []
|
||||
for block in ip_adapter_flux['double_blocks']:
|
||||
ipad = IPProcessor(block.context_dim, block.hidden_dim, ip_projes, ip_scale)
|
||||
ipad.load_state_dict(block.state_dict())
|
||||
ipad.in_hidden_states_neg = ip_neg_pr
|
||||
ipad.in_hidden_states_pos = ip_projes
|
||||
ipad.to(dtype=torch.bfloat16)
|
||||
npp = DoubleStreamMixerProcessor()
|
||||
npp.add_ipadapter(ipad)
|
||||
ipad_blocks.append(npp)
|
||||
pbar.update(mul)
|
||||
i=0
|
||||
for name, _ in attn_processors(tyanochky.diffusion_model).items():
|
||||
attribute = f"diffusion_model.{name}"
|
||||
#old = copy.copy(get_attr(bi.model, attribute))
|
||||
if attribute in model.object_patches.keys():
|
||||
old = copy.copy((model.object_patches[attribute]))
|
||||
else:
|
||||
old = None
|
||||
processor = merge_loras(old, ipad_blocks[i])
|
||||
processor.to(device, dtype=torch.bfloat16)
|
||||
bi.add_object_patch(attribute, processor)
|
||||
i+=1
|
||||
pbar.update(mul)
|
||||
return (bi,)
|
||||
|
||||
|
||||
|
||||
class ApplyAdvancedFluxIPAdapter:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"ip_adapter_flux": ("IP_ADAPTER_FLUX",),
|
||||
"image": ("IMAGE",),
|
||||
#"text_scale": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
|
||||
"begin_strength": ("FLOAT", {"default": 0.0, "min": -100.0, "max": 100.0, "step": 0.01}),
|
||||
"end_strength": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
|
||||
"smothing_type": (["Linear", "First half", "Second half", "Sigmoid"],),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
RETURN_NAMES = ("MODEL",)
|
||||
FUNCTION = "applymodel"
|
||||
CATEGORY = "XLabsNodes"
|
||||
|
||||
def applymodel(self, model, ip_adapter_flux, image, begin_strength, end_strength, smothing_type):
|
||||
debug=False
|
||||
|
||||
|
||||
device=mm.get_torch_device()
|
||||
offload_device=mm.unet_offload_device()
|
||||
|
||||
is_patched = is_model_pathched(model.model)
|
||||
|
||||
print(f"Is model already patched? {is_patched}")
|
||||
mul = 1
|
||||
if is_patched:
|
||||
pbar = ProgressBar(5)
|
||||
else:
|
||||
mul = 3
|
||||
count = len(model.model.diffusion_model.double_blocks)
|
||||
pbar = ProgressBar(5*mul+count)
|
||||
|
||||
bi = model.clone()
|
||||
tyanochky = bi.model
|
||||
|
||||
clip = ip_adapter_flux['clip_vision']
|
||||
|
||||
if isinstance(clip, FluxClipViT):
|
||||
#torch.Size([1, 526, 526, 3])
|
||||
#image = torch.permute(image, (0, ))
|
||||
#print(image.shape)
|
||||
#print(image)
|
||||
clip_device = next(clip.model.parameters()).device
|
||||
image = torch.clip(image*255, 0.0, 255)
|
||||
out = clip(image).to(dtype=torch.bfloat16)
|
||||
neg_out = clip(torch.zeros_like(image)).to(dtype=torch.bfloat16)
|
||||
else:
|
||||
print("Using old vit clip")
|
||||
clip_device = next(clip.parameters()).device
|
||||
pixel_values = clip_preprocess(image.to(clip_device)).float()
|
||||
out = clip(pixel_values=pixel_values)
|
||||
neg_out = clip(pixel_values=torch.zeros_like(pixel_values))
|
||||
neg_out = neg_out[2].to(dtype=torch.bfloat16)
|
||||
out = out[2].to(dtype=torch.bfloat16)
|
||||
|
||||
pbar.update(mul)
|
||||
if not is_patched:
|
||||
print("We are patching diffusion model, be patient please")
|
||||
patches=FluxUpdateModules(tyanochky, pbar)
|
||||
print("Patched succesfully!")
|
||||
else:
|
||||
print("Model already updated")
|
||||
pbar.update(mul)
|
||||
|
||||
#TYANOCHKYBY=16
|
||||
ip_projes_dev = next(ip_adapter_flux['ip_adapter_proj_model'].parameters()).device
|
||||
ip_adapter_flux['ip_adapter_proj_model'].to(dtype=torch.bfloat16)
|
||||
out=torch.mean(out, 0)
|
||||
neg_out=torch.mean(neg_out, 0)
|
||||
ip_projes = ip_adapter_flux['ip_adapter_proj_model'](out.to(ip_projes_dev, dtype=torch.bfloat16)).to(device, dtype=torch.bfloat16)
|
||||
ip_neg_pr = ip_adapter_flux['ip_adapter_proj_model'](neg_out.to(ip_projes_dev, dtype=torch.bfloat16)).to(device, dtype=torch.bfloat16)
|
||||
|
||||
|
||||
count = len(ip_adapter_flux['double_blocks'])
|
||||
|
||||
if smothing_type == "Linear":
|
||||
strength_model = LinearStrengthModel(begin_strength, end_strength, count)
|
||||
elif smothing_type == "First half":
|
||||
strength_model = FirstHalfStrengthModel(begin_strength, end_strength, count)
|
||||
elif smothing_type == "Second half":
|
||||
strength_model = SecondHalfStrengthModel(begin_strength, end_strength, count)
|
||||
elif smothing_type == "Sigmoid":
|
||||
strength_model = SigmoidStrengthModel(begin_strength, end_strength, count)
|
||||
else:
|
||||
raise ValueError("Invalid smothing type")
|
||||
|
||||
|
||||
ipad_blocks = []
|
||||
for i, block in enumerate(ip_adapter_flux['double_blocks']):
|
||||
ipad = IPProcessor(block.context_dim, block.hidden_dim, ip_projes, strength_model[i])
|
||||
ipad.load_state_dict(block.state_dict())
|
||||
ipad.in_hidden_states_neg = ip_neg_pr
|
||||
ipad.in_hidden_states_pos = ip_projes
|
||||
ipad.to(dtype=torch.bfloat16)
|
||||
npp = DoubleStreamMixerProcessor()
|
||||
npp.add_ipadapter(ipad)
|
||||
ipad_blocks.append(npp)
|
||||
pbar.update(mul)
|
||||
i=0
|
||||
for name, _ in attn_processors(tyanochky.diffusion_model).items():
|
||||
attribute = f"diffusion_model.{name}"
|
||||
#old = copy.copy(get_attr(bi.model, attribute))
|
||||
if attribute in model.object_patches.keys():
|
||||
old = copy.copy((model.object_patches[attribute]))
|
||||
else:
|
||||
old = None
|
||||
processor = merge_loras(old, ipad_blocks[i])
|
||||
processor.to(device, dtype=torch.bfloat16)
|
||||
bi.add_object_patch(attribute, processor)
|
||||
i+=1
|
||||
pbar.update(mul)
|
||||
return (bi,)
|
||||
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"FluxLoraLoader": LoadFluxLora,
|
||||
"LoadFluxControlNet": LoadFluxControlNet,
|
||||
"ApplyFluxControlNet": ApplyFluxControlNet,
|
||||
"ApplyAdvancedFluxControlNet": ApplyAdvancedFluxControlNet,
|
||||
"XlabsSampler": XlabsSampler,
|
||||
"ApplyFluxIPAdapter": ApplyFluxIPAdapter,
|
||||
"LoadFluxIPAdapter": LoadFluxIPAdapter,
|
||||
"ApplyAdvancedFluxIPAdapter": ApplyAdvancedFluxIPAdapter,
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"FluxLoraLoader": "Load Flux LoRA",
|
||||
"LoadFluxControlNet": "Load Flux ControlNet",
|
||||
"ApplyFluxControlNet": "Apply Flux ControlNet",
|
||||
"ApplyAdvancedFluxControlNet": "Apply Advanced Flux ControlNet",
|
||||
"XlabsSampler": "Xlabs Sampler",
|
||||
"ApplyFluxIPAdapter": "Apply Flux IPAdapter",
|
||||
"LoadFluxIPAdapter": "Load Flux IPAdatpter",
|
||||
"ApplyAdvancedFluxIPAdapter": "Apply Advanced Flux IPAdapter",
|
||||
}
|
||||
Reference in New Issue
Block a user