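"""Flux Attention Control node for ComfyUI (ControlAltAI-Nodes).

Patches Flux's attention functions via xformers so that up to three regional
prompts attend only to their masked image areas, while the main prompt handles
the rest of the canvas.
"""
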
import torch
from torch import Tensor
import torch.nn.functional as F
from typing import List, Dict, Optional, Tuple
from einops import rearrange
import comfy.model_management as model_management
from comfy.ldm.modules import attention as comfy_attention
from comfy.ldm.flux import math as flux_math
from comfy.ldm.flux import layers as flux_layers
import numpy as np
from PIL import Image, ImageFilter, ImageDraw
from functools import partial

# Protected xformers import
try:
    from xformers.ops import memory_efficient_attention as xattention
    has_xformers = True
except ImportError:
    has_xformers = False
    xattention = None


class FluxAttentionControl:
    def __init__(self):
        # Keep references to the stock attention functions so they can be
        # restored when regional control is disabled.
        self.original_attention = comfy_attention.optimized_attention
        self.original_flux_attention = flux_math.attention
        self.original_flux_layers_attention = flux_layers.attention
        if not has_xformers:
            print("\n" + "=" * 70)
            print("\033[94mControlAltAI-Nodes: This node requires xformers to function.\033[0m")
            print("\033[33mPlease check \"xformers_instructions.txt\" in ComfyUI\\custom_nodes\\ControlAltAI-Nodes for how to install xformers\033[0m")
            print("=" * 70 + "\n")
        print("FluxAttentionControl initialized")

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL",),
                "condition": ("CONDITIONING",),
                "latent_dimensions": ("LATENT",),
                "region1": ("REGION",),
                "number_of_regions": ("INT", {
                    "default": 1,
                    "min": 1,
                    "max": 3,
                    "step": 1,
                    "display": "Number of Regions"
                }),
                "enabled": ("BOOLEAN", {
                    "default": True,
                    "display": "Enable Regional Control"
                }),
                "feather_radius1": ("FLOAT", {
                    "default": 0.0,
                    "min": 0.0,
                    "max": 100.0,
                    "step": 1.0,
                    "display": "Feather Radius for Region 1"
                }),
            },
            "optional": {
                "region2": ("REGION",),
                "feather_radius2": ("FLOAT", {
                    "default": 0.0,
                    "min": 0.0,
                    "max": 100.0,
                    "step": 1.0,
                    "display": "Feather Radius for Region 2"
                }),
                "region3": ("REGION",),
                "feather_radius3": ("FLOAT", {
                    "default": 0.0,
                    "min": 0.0,
                    "max": 100.0,
                    "step": 1.0,
                    "display": "Feather Radius for Region 3"
                }),
            }
        }

    RETURN_TYPES = ("MODEL", "CONDITIONING",)
    RETURN_NAMES = ("model", "conditioning",)
    FUNCTION = "apply_attention_control"
    CATEGORY = "ControlAltAI Nodes/Flux Region"

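    # NOTE: the REGION payload shape is inferred from the keys accessed below,
    # not a ComfyUI-wide contract: a dict carrying 'conditioning' (a
    # CONDITIONING list), an optional 'strength' float, and either 'bbox'
    # (normalized [x1, y1, x2, y2]) or 'mask' (a tensor batch).
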
    def generate_region_mask(self, region: Dict, width: int, height: int, feather_radius: float) -> Image.Image:
        """Rasterize a region's bbox or mask into a (width x height) grayscale PIL mask."""
        if region.get('bbox') is not None:
            # Normalized bbox coordinates -> pixel coordinates.
            x1, y1, x2, y2 = region['bbox']
            x1_px = int(x1 * width)
            y1_px = int(y1 * height)
            x2_px = int(x2 * width)
            y2_px = int(y2 * height)
            mask = Image.new('L', (width, height), 0)
            mask_draw = ImageDraw.Draw(mask)
            mask_draw.rectangle([x1_px, y1_px, x2_px, y2_px], fill=255)
            if feather_radius > 0:
                mask = mask.filter(ImageFilter.GaussianBlur(radius=feather_radius))
            print(f'Generating masks with {width}x{height} and [{x1}, {y1}, {x2}, {y2}], feather_radius={feather_radius}')
            return mask
        elif region.get('mask') is not None:
            # Convert the first mask in the batch to an 8-bit PIL image and resize.
            mask = region['mask'][0].cpu().numpy()
            mask = (mask * 255).astype(np.uint8)
            mask = Image.fromarray(mask)
            mask = mask.resize((width, height))
            if feather_radius > 0:
                mask = mask.filter(ImageFilter.GaussianBlur(radius=feather_radius))
            return mask
        else:
            raise ValueError('Region must provide either a bbox or a mask')

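    # Flux operates on 2x2 patches of the 8x-downsampled latent, so each image
    # token covers a 16x16 pixel area; masks are therefore resampled to a
    # (height // 16) x (width // 16) grid, one value per image token.
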
    def generate_test_mask(self, masks: List[Image.Image], height: int, width: int):
        """Downsample each PIL mask to the token grid and flatten it per image token."""
        hH, hW = int(height) // 16, int(width) // 16
        print(f'{width} {height} -> {hW} {hH}')

        lin_masks = []
        for mask in masks:
            mask = mask.convert('L')
            mask = torch.tensor(np.array(mask)).unsqueeze(0).unsqueeze(0) / 255.0  # Normalize to 0-1
            mask = F.interpolate(mask, (hH, hW), mode='bilinear', align_corners=False).flatten()
            lin_masks.append(mask)
        return lin_masks, hH, hW

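    # Token layout assumed throughout the attention mask:
    #
    #   [ main prompt | subprompt 1 | ... | subprompt n | image tokens ]
    #     emb_size      emb_size           emb_size       Nx = hH * hW
    #
    # so emb_len = emb_size * (n_regs + 1) text tokens precede the Nx image
    # tokens, and total_len = emb_len + Nx.
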
    def prepare_attention_mask(self, lin_masks: List[torch.Tensor], region_strengths: List[float], Nx: int, emb_size: int, emb_len: int):
        """Prepare attention mask for up to three regions with per-region strengths."""
        total_len = emb_len + Nx
        n_regs = len(lin_masks)

        # Initialize attention mask and scales. A 1 in cross_mask marks a
        # (query, key) pair whose attention will be blocked.
        cross_mask = torch.zeros(total_len, total_len)
        q_scale = torch.ones(total_len)
        k_scale = torch.ones(total_len)

        # Indices for the main prompt embeddings
        main_prompt_start = 0
        main_prompt_end = emb_size

        # Subprompt index ranges
        subprompt_starts = [emb_size * (i + 1) for i in range(n_regs)]
        subprompt_ends = [emb_size * (i + 2) for i in range(n_regs)]

        # Stack position masks; shape: [n_regs, Nx]
        position_masks = torch.stack(lin_masks)

        # Normalize masks so that overlapping areas sum to 1
        position_masks_sum = position_masks.sum(dim=0)
        position_masks_normalized = position_masks / (position_masks_sum + 1e-8)

        # Build attention masks and scales
        for i in range(n_regs):
            sp_start = subprompt_starts[i]
            sp_end = subprompt_ends[i]
            mask_i = position_masks_normalized[i]

            # Scale embeddings based on mask coverage and per-region strength
            strength = region_strengths[i]
            q_scale[sp_start:sp_end] = mask_i.mean() * strength
            k_scale[sp_start:sp_end] = mask_i.mean() * strength

            # Create mask covering text tokens and image positions
            m_with_tokens = torch.cat([torch.ones(emb_len), mask_i])
            mb = m_with_tokens > 0.0  # Include positions where mask > 0

            # Block attention between positions outside the mask and this subprompt
            cross_mask[~mb, sp_start:sp_end] = 1
            cross_mask[sp_start:sp_end, ~mb] = 1

            # Block attention between positions inside the region and the main prompt
            positions_idx = (mask_i > 0.0).nonzero(as_tuple=True)[0] + emb_len
            cross_mask[positions_idx[:, None], main_prompt_start:main_prompt_end] = 1
            cross_mask[main_prompt_start:main_prompt_end, positions_idx[None, :]] = 1

            # Block attention between different subprompts
            for j in range(n_regs):
                if i != j:
                    other_sp_start = subprompt_starts[j]
                    other_sp_end = subprompt_ends[j]
                    cross_mask[sp_start:sp_end, other_sp_start:other_sp_end] = 1
                    cross_mask[other_sp_start:other_sp_end, sp_start:sp_end] = 1

        # Ensure self-attention is always allowed
        cross_mask.fill_diagonal_(0)

        # Prepare scales for the GPU (xformers requires CUDA here)
        q_scale = q_scale.reshape(1, 1, -1, 1).cuda()
        k_scale = k_scale.reshape(1, 1, -1, 1).cuda()

        return cross_mask, q_scale, k_scale

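    # xformers' memory_efficient_attention treats attn_bias as an additive term
    # on the attention logits: 0.0 leaves a pair untouched and -inf blocks it.
    # The binary cross_mask built above is converted to that form in
    # apply_attention_control before it reaches this wrapper.
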
    def xformers_attention(self, q: Tensor, k: Tensor, v: Tensor, pe: Tensor,
                           attn_mask: Optional[Tensor] = None,
                           mask: Optional[Tensor] = None) -> Tensor:
        """Drop-in replacement for Flux attention that routes through xformers."""
        q, k = flux_math.apply_rope(q, k, pe)
        # xformers expects [batch, length, heads, dim] rather than Flux's [B, H, L, D]
        q = rearrange(q, "B H L D -> B L H D")
        k = rearrange(k, "B H L D -> B L H D")
        v = rearrange(v, "B H L D -> B L H D")

        # Use attn_mask if provided, otherwise fall back to the mask parameter
        # (Flux passes its own mask keyword when calling this function)
        attention_bias = attn_mask if attn_mask is not None else mask

        if attention_bias is not None:
            x = xattention(q, k, v, attn_bias=attention_bias)
        else:
            x = xattention(q, k, v)

        x = rearrange(x, "B L H D -> B L (H D)")
        return x

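    # NOTE: apply_attention_control patches flux_math.attention and
    # flux_layers.attention globally, so the regional mask affects every Flux
    # sampling run until the node runs again with enabled=False, which restores
    # the originals saved in __init__.
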
    def apply_attention_control(self,
                                model: object,
                                condition: List,
                                latent_dimensions: Dict,
                                region1: Dict,
                                number_of_regions: int,
                                enabled: bool,
                                feather_radius1: float = 0.0,
                                region2: Optional[Dict] = None,
                                feather_radius2: Optional[float] = 0.0,
                                region3: Optional[Dict] = None,
                                feather_radius3: Optional[float] = 0.0):

        if not enabled:
            # Restore original attention functions and pass the inputs through
            flux_math.attention = self.original_flux_attention
            flux_layers.attention = self.original_flux_layers_attention
            print("Regional control disabled. Restored original attention functions.")
            return (model, condition)

        if not has_xformers:
            raise RuntimeError("xformers is required for this node when enabled. Please install xformers.")

        print(f'Region attention Node enabled: {enabled}, regions: {number_of_regions}')

        # Extract dimensions and embeddings
        latent = latent_dimensions["samples"]
        bs_l, n_ch, lH, lW = latent.shape
        text_emb = condition[0][0].clone()
        clip_emb = condition[0][1]['pooled_output'].clone()
        bs, emb_size, emb_dim = text_emb.shape
        iH, iW = lH * 8, lW * 8  # latent is 1/8 of the image resolution

        # Process active regions
        subprompts_embeds = []
        masks = []
        region_strengths = []

        # Collect regions and feather radii
        regions = [region1, region2, region3]
        feather_radii = [feather_radius1, feather_radius2, feather_radius3]

        for idx, region in enumerate(regions[:number_of_regions]):
            if region is not None and region.get('conditioning') is not None:
                # Get 'strength' from the region or default to 1.0
                strength = region.get('strength', 1.0)
                region_strengths.append(strength)
                subprompt_emb = region['conditioning'][0][0]
                subprompts_embeds.append(subprompt_emb)
                # Use the per-region feather radius
                feather_radius = feather_radii[idx] if feather_radii[idx] is not None else 0.0
                masks.append(self.generate_region_mask(region, iW, iH, feather_radius))
            else:
                print(f"Region {idx + 1} is None or has no conditioning")

        if not subprompts_embeds:
            print("No active regions with conditioning found.")
            # Restore original attention functions
            flux_math.attention = self.original_flux_attention
            flux_layers.attention = self.original_flux_layers_attention
            return (model, condition)

        n_regs = len(subprompts_embeds)

        # Generate attention components
        lin_masks, hH, hW = self.generate_test_mask(masks, iH, iW)
        Nx = int(hH * hW)
        emb_len = emb_size * (n_regs + 1)  # +1 for the main prompt

        # Create the attention mask (q_scale/k_scale are computed but not
        # currently applied to the projections)
        attn_mask, q_scale, k_scale = self.prepare_attention_mask(
            lin_masks, region_strengths, Nx, emb_size, emb_len)

        # Format for xformers
        device = torch.device('cuda')
        attn_dtype = torch.bfloat16 if model_management.should_use_bf16(device=device) else torch.float16

        if attn_mask is not None:
            print(f'Applying attention mask: {tuple(attn_mask.shape)}')
            L = attn_mask.shape[0]
            H = 24  # Number of attention heads in the FLUX model
            pad = (8 - L % 8) % 8  # pad the sequence length to a multiple of 8 for xformers
            pad_L = L + pad
            mask_out = torch.zeros([bs, H, pad_L, pad_L], dtype=attn_dtype, device=device)
            mask_out[:, :, :L, :L] = attn_mask.to(device, dtype=attn_dtype)
            attn_mask = mask_out[:, :, :pad_L, :pad_L]

            # Convert the binary mask to an additive bias: blocked pairs -> -inf
            attn_mask_bool = attn_mask > 0.5
            attn_mask.masked_fill_(attn_mask_bool, float('-inf'))

        # Override attention
        override_attention = partial(self.xformers_attention, attn_mask=attn_mask)
        flux_math.attention = override_attention
        flux_layers.attention = override_attention

        # Create extended conditioning: [main prompt | subprompt embeddings]
        extended_condition = torch.cat([text_emb] + subprompts_embeds, dim=1)

        return (model, [[extended_condition, {'pooled_output': clip_emb}]])


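# Minimal usage sketch (hypothetical; outside ComfyUI the inputs would have to
# be constructed by hand, so treat this as an illustration of the call shape):
#
#   node = FluxAttentionControl()
#   model, cond = node.apply_attention_control(
#       model, conditioning, latent, region1,
#       number_of_regions=1, enabled=True, feather_radius1=8.0)
#
# The returned conditioning is the main prompt concatenated with the regional
# subprompts; feed it to the sampler in place of the original conditioning.
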
# Node class mappings
NODE_CLASS_MAPPINGS = {
    "FluxAttentionControl": FluxAttentionControl
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "FluxAttentionControl": "Flux Attention Control"
}