Files
ComfyUI/custom_nodes/controlaltai-nodes/flux_attention_control_node.py
import torch
from torch import Tensor
import torch.nn.functional as F
from typing import List, Dict, Optional, Tuple
from einops import rearrange
import comfy.model_management as model_management
from comfy.ldm.modules import attention as comfy_attention
from comfy.ldm.flux import math as flux_math
from comfy.ldm.flux import layers as flux_layers
import numpy as np
from PIL import Image, ImageFilter, ImageDraw
from functools import partial
# Protected xformers import
try:
    from xformers.ops import memory_efficient_attention as xattention
    has_xformers = True
except ImportError:
    has_xformers = False
    xattention = None
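
# Note: memory_efficient_attention accepts an additive attn_bias tensor; the
# regional mask built below is converted to such a bias (blocked query/key
# pairs get -inf) before being passed through, which is how per-region
# attention is constrained.
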
class FluxAttentionControl:
    def __init__(self):
        self.original_attention = comfy_attention.optimized_attention
        self.original_flux_attention = flux_math.attention
        self.original_flux_layers_attention = flux_layers.attention
        if not has_xformers:
            print("\n" + "="*70)
            print("\033[94mControlAltAI-Nodes: This node requires xformers to function.\033[0m")
            print("\033[33mPlease check \"xformers_instructions.txt\" in ComfyUI\\custom_nodes\\ControlAltAI-Nodes for how to install XFormers\033[0m")
            print("="*70 + "\n")
        print("FluxAttentionControl initialized")
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL",),
                "condition": ("CONDITIONING",),
                "latent_dimensions": ("LATENT",),
                "region1": ("REGION",),
                "number_of_regions": ("INT", {
                    "default": 1,
                    "min": 1,
                    "max": 3,
                    "step": 1,
                    "display": "Number of Regions"
                }),
                "enabled": ("BOOLEAN", {
                    "default": True,
                    "display": "Enable Regional Control"
                }),
                "feather_radius1": ("FLOAT", {
                    "default": 0.0,
                    "min": 0.0,
                    "max": 100.0,
                    "step": 1.0,
                    "display": "Feather Radius for Region 1"
                }),
            },
            "optional": {
                "region2": ("REGION",),
                "feather_radius2": ("FLOAT", {
                    "default": 0.0,
                    "min": 0.0,
                    "max": 100.0,
                    "step": 1.0,
                    "display": "Feather Radius for Region 2"
                }),
                "region3": ("REGION",),
                "feather_radius3": ("FLOAT", {
                    "default": 0.0,
                    "min": 0.0,
                    "max": 100.0,
                    "step": 1.0,
                    "display": "Feather Radius for Region 3"
                }),
            }
        }
    RETURN_TYPES = ("MODEL", "CONDITIONING",)
    RETURN_NAMES = ("model", "conditioning",)
    FUNCTION = "apply_attention_control"
    CATEGORY = "ControlAltAI Nodes/Flux Region"
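
    # The REGION inputs are dictionaries produced by the companion region nodes.
    # As used below, a region is expected to carry:
    #   'conditioning' - a CONDITIONING list whose [0][0] is the sub-prompt embedding
    #   'bbox'         - optional (x1, y1, x2, y2) in normalized 0-1 coordinates
    #   'mask'         - optional mask tensor, used when no bbox is given
    #   'strength'     - optional float weight, defaulting to 1.0
    # (Field list inferred from how the dict is accessed in this file.)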
    def generate_region_mask(self, region: Dict, width: int, height: int, feather_radius: float) -> Image.Image:
        """Build a greyscale PIL mask for a region from either its bbox or an explicit mask."""
        if region.get('bbox') is not None:
            x1, y1, x2, y2 = region['bbox']
            x1_px = int(x1 * width)
            y1_px = int(y1 * height)
            x2_px = int(x2 * width)
            y2_px = int(y2 * height)
            mask = Image.new('L', (width, height), 0)
            mask_draw = ImageDraw.Draw(mask)
            mask_draw.rectangle([x1_px, y1_px, x2_px, y2_px], fill=255)
            if feather_radius > 0:
                mask = mask.filter(ImageFilter.GaussianBlur(radius=feather_radius))
            print(f'Generating masks with {width}x{height} and [{x1}, {y1}, {x2}, {y2}], feather_radius={feather_radius}')
            return mask
        elif region.get('mask') is not None:
            mask = region['mask'][0].cpu().numpy()
            mask = (mask * 255).astype(np.uint8)
            mask = Image.fromarray(mask)
            mask = mask.resize((width, height))
            if feather_radius > 0:
                mask = mask.filter(ImageFilter.GaussianBlur(radius=feather_radius))
            return mask
        else:
            raise ValueError('Unknown region type: expected a bbox or a mask')
    def generate_test_mask(self, masks: List[Image.Image], height: int, width: int):
        """Downsample pixel-space masks to the flux token grid (1/16 of the image) and flatten them."""
        hH, hW = int(height) // 16, int(width) // 16
        print(f'{width} {height} -> {hW} {hH}')
        lin_masks = []
        for mask in masks:
            mask = mask.convert('L')
            mask = torch.tensor(np.array(mask)).unsqueeze(0).unsqueeze(0) / 255.0  # Normalize to 0-1
            mask = F.interpolate(mask, (hH, hW), mode='bilinear', align_corners=False).flatten()
            lin_masks.append(mask)
        return lin_masks, hH, hW
    def prepare_attention_mask(self, lin_masks: List[torch.Tensor], region_strengths: List[float], Nx: int, emb_size: int, emb_len: int):
        """Prepare the attention mask for up to three regions with per-region strengths."""
        total_len = emb_len + Nx
        n_regs = len(lin_masks)
        # Initialize attention mask and scales
        cross_mask = torch.zeros(total_len, total_len)
        q_scale = torch.ones(total_len)
        k_scale = torch.ones(total_len)
        # Indices for embeddings
        main_prompt_start = 0
        main_prompt_end = emb_size
        # Subprompt indices
        subprompt_starts = [emb_size * (i + 1) for i in range(n_regs)]
        subprompt_ends = [emb_size * (i + 2) for i in range(n_regs)]
        # Initialize position masks
        position_masks = torch.stack(lin_masks)  # Shape: [n_regs, Nx]
        # Normalize masks so that overlapping areas sum to 1
        position_masks_sum = position_masks.sum(dim=0)
        position_masks_normalized = position_masks / (position_masks_sum + 1e-8)
        # Build attention masks and scales
        for i in range(n_regs):
            sp_start = subprompt_starts[i]
            sp_end = subprompt_ends[i]
            mask_i = position_masks_normalized[i]
            # Scale embeddings based on mask and per-region strength
            strength = region_strengths[i]
            q_scale[sp_start:sp_end] = mask_i.mean() * strength
            k_scale[sp_start:sp_end] = mask_i.mean() * strength
            # Create mask including tokens and positions
            m_with_tokens = torch.cat([torch.ones(emb_len), mask_i])
            mb = m_with_tokens > 0.0  # Include positions where mask > 0
            # Block attention between positions not in mask and subprompt
            cross_mask[~mb, sp_start:sp_end] = 1
            cross_mask[sp_start:sp_end, ~mb] = 1
            # Block attention between positions in region and main prompt
            positions_idx = (mask_i > 0.0).nonzero(as_tuple=True)[0] + emb_len
            cross_mask[positions_idx[:, None], main_prompt_start:main_prompt_end] = 1
            cross_mask[main_prompt_start:main_prompt_end, positions_idx[None, :]] = 1
            # Block attention between subprompts
            for j in range(n_regs):
                if i != j:
                    other_sp_start = subprompt_starts[j]
                    other_sp_end = subprompt_ends[j]
                    cross_mask[sp_start:sp_end, other_sp_start:other_sp_end] = 1
                    cross_mask[other_sp_start:other_sp_end, sp_start:sp_end] = 1
        # Ensure self-attention is allowed
        cross_mask.fill_diagonal_(0)
        # Prepare scales for GPU
        q_scale = q_scale.reshape(1, 1, -1, 1).cuda()
        k_scale = k_scale.reshape(1, 1, -1, 1).cuda()
        return cross_mask, q_scale, k_scale
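
    # Token layout assumed by prepare_attention_mask (as constructed in
    # apply_attention_control below):
    #
    #   [ main prompt : emb_size ][ region 1 prompt : emb_size ] ... [ region n prompt : emb_size ][ image tokens : Nx ]
    #
    # so emb_len = emb_size * (n_regs + 1) and total_len = emb_len + Nx.
    # A value of 1 in cross_mask marks a (query, key) pair whose attention is
    # blocked; it is converted to an additive -inf bias before being handed to
    # xformers.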
    def xformers_attention(self, q: Tensor, k: Tensor, v: Tensor, pe: Tensor,
                           attn_mask: Optional[Tensor] = None,
                           mask: Optional[Tensor] = None) -> Tensor:  # Added mask parameter
        q, k = flux_math.apply_rope(q, k, pe)
        q = rearrange(q, "B H L D -> B L H D")
        k = rearrange(k, "B H L D -> B L H D")
        v = rearrange(v, "B H L D -> B L H D")
        # Use attn_mask if provided, otherwise use the mask parameter
        attention_bias = attn_mask if attn_mask is not None else mask
        if attention_bias is not None:
            x = xattention(q, k, v, attn_bias=attention_bias)
        else:
            x = xattention(q, k, v)
        x = rearrange(x, "B L H D -> B L (H D)")
        return x
    def apply_attention_control(self,
                                model: object,
                                condition: List,
                                latent_dimensions: Dict,
                                region1: Dict,
                                number_of_regions: int,
                                enabled: bool,
                                feather_radius1: float = 0.0,
                                region2: Optional[Dict] = None,
                                feather_radius2: Optional[float] = 0.0,
                                region3: Optional[Dict] = None,
                                feather_radius3: Optional[float] = 0.0):
        # Extract dimensions and embeddings first (needed before the enabled check)
        latent = latent_dimensions["samples"]
        bs_l, n_ch, lH, lW = latent.shape
        text_emb = condition[0][0].clone()
        clip_emb = condition[0][1]['pooled_output'].clone()
        bs, emb_size, emb_dim = text_emb.shape
        iH, iW = lH * 8, lW * 8
        if not enabled:
            # Restore original attention functions
            flux_math.attention = self.original_flux_attention
            flux_layers.attention = self.original_flux_layers_attention
            print("Regional control disabled. Restored original attention functions.")
            return (model, condition)  # Return original condition when disabled
        if not has_xformers:
            raise RuntimeError("xformers is required for this node when enabled. Please install xformers.")
        print(f'Region attention Node enabled: {enabled}, regions: {number_of_regions}')
        # Process active regions
        subprompts_embeds = []
        masks = []
        region_strengths = []
        # Collect regions and feather radii
        regions = [region1, region2, region3]
        feather_radii = [feather_radius1, feather_radius2, feather_radius3]
        for idx, region in enumerate(regions[:number_of_regions]):
            if region is not None and region.get('conditioning') is not None:
                # Get 'strength' from region or default to 1.0
                strength = region.get('strength', 1.0)
                region_strengths.append(strength)
                subprompt_emb = region['conditioning'][0][0]
                subprompts_embeds.append(subprompt_emb)
                # Use per-region feather_radius
                feather_radius = feather_radii[idx] if feather_radii[idx] is not None else 0.0
                masks.append(self.generate_region_mask(region, iW, iH, feather_radius))
            else:
                print(f"Region {idx+1} is None or has no conditioning")
        if not subprompts_embeds:
            print("No active regions with conditioning found.")
            # Restore original attention functions
            flux_math.attention = self.original_flux_attention
            flux_layers.attention = self.original_flux_layers_attention
            return (model, condition)
        n_regs = len(subprompts_embeds)
        # Generate attention components
        lin_masks, hH, hW = self.generate_test_mask(masks, iH, iW)
        Nx = int(hH * hW)
        emb_len = emb_size * (n_regs + 1)  # +1 for main prompt
        # Create attention mask
        attn_mask, q_scale, k_scale = self.prepare_attention_mask(
            lin_masks, region_strengths, Nx, emb_size, emb_len)
        # Format for xFormers
        device = torch.device('cuda')
        attn_dtype = torch.bfloat16 if model_management.should_use_bf16(device=device) else torch.float16
        if attn_mask is not None:
            print(f'Applying attention masks: torch.Size([{attn_mask.shape[0]}, {attn_mask.shape[1]}])')
            L = attn_mask.shape[0]
            H = 24  # Number of heads in the FLUX model
            pad = (8 - L % 8) % 8  # Ensure pad is between 0 and 7
            pad_L = L + pad
            mask_out = torch.zeros([bs, H, pad_L, pad_L], dtype=attn_dtype, device=device)
            mask_out[:, :, :L, :L] = attn_mask.to(device, dtype=attn_dtype)
            attn_mask = mask_out[:, :, :pad_L, :pad_L]
            # Prepare final mask: blocked pairs become -inf additive bias
            attn_mask_bool = attn_mask > 0.5
            attn_mask.masked_fill_(attn_mask_bool, float('-inf'))
        # Override attention
        attn_mask_arg = attn_mask if enabled else None
        override_attention = partial(self.xformers_attention, attn_mask=attn_mask_arg)
        flux_math.attention = override_attention
        flux_layers.attention = override_attention
        # Create extended conditioning
        extended_condition = torch.cat([text_emb] + subprompts_embeds, dim=1)
        return (model, [[extended_condition, {'pooled_output': clip_emb}]])
# Node class mappings
NODE_CLASS_MAPPINGS = {
    "FluxAttentionControl": FluxAttentionControl
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "FluxAttentionControl": "Flux Attention Control"
}
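
# Minimal usage sketch (hypothetical, for illustration only; not part of the
# node registration). Inside a ComfyUI workflow the REGION input comes from the
# companion region nodes; here a hand-built dict with the fields this file
# reads is assumed:
#
#   node = FluxAttentionControl()
#   region = {
#       'conditioning': region_conditioning,   # CONDITIONING for the sub-prompt
#       'bbox': (0.0, 0.0, 0.5, 1.0),          # left half of the image, normalized
#       'strength': 1.0,
#   }
#   patched_model, extended_cond = node.apply_attention_control(
#       model=model,
#       condition=base_conditioning,
#       latent_dimensions=empty_latent,        # {"samples": tensor [B, C, H//8, W//8]}
#       region1=region,
#       number_of_regions=1,
#       enabled=True,
#       feather_radius1=8.0,
#   )
#   # patched_model and extended_cond are then routed to the sampler as usual.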