ComfyUI/custom_nodes/controlaltai-nodes/region_mask_processor_node.py

import torch
import torch.nn.functional as F
from typing import Tuple, Dict, Optional, List
import numpy as np
from PIL import Image, ImageDraw


def pil2tensor(image):
    """Convert a PIL image to a PyTorch tensor in the expected format."""
    return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
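

# Note: pil2tensor produces a (1, H, W, 3) float tensor in [0, 1], the
# batch-first, channels-last layout ComfyUI uses for IMAGE outputs.
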
class RegionMaskProcessor:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mask1": ("MASK",),
                "bbox1": ("BBOX",),
                "blur_radius": ("INT", {
                    "default": 5,
                    "min": 0,
                    "max": 32,
                    "step": 1,
                    "display": "Blur Radius"
                }),
                "threshold": ("FLOAT", {
                    "default": 0.5,
                    "min": 0.0,
                    "max": 1.0,
                    "step": 0.1,
                    "display": "Mask Threshold"
                }),
                "feather_edges": ("BOOLEAN", {
                    "default": True,
                    "display": "Feather Edges"
                }),
                "number_of_regions": ("INT", {
                    "default": 1,
                    "min": 1,
                    "max": 3,
                    "display": "Number of Regions"
                }),
            },
            "optional": {
                "mask2": ("MASK",),
                "bbox2": ("BBOX",),
                "mask3": ("MASK",),
                "bbox3": ("BBOX",),
            }
        }

    RETURN_TYPES = ("MASK", "BBOX", "MASK", "BBOX", "MASK", "BBOX", "IMAGE", "INT")
    RETURN_NAMES = ("processed_mask1", "processed_bbox1",
                    "processed_mask2", "processed_bbox2",
                    "processed_mask3", "processed_bbox3",
                    "preview_image", "region_count")
    FUNCTION = "process_regions"
    CATEGORY = "ControlAltAI Nodes/Flux Region"
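
    # The blur below is a separable Gaussian: two 1-D convolutions (horizontal
    # then vertical) with a normalized kernel, which is equivalent to a full
    # 2-D Gaussian blur but cheaper to compute.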
    def apply_gaussian_blur(self, mask: torch.Tensor, radius: int) -> torch.Tensor:
        """Apply Gaussian blur to mask edges"""
        if radius <= 0:
            return mask
        kernel_size = 2 * radius + 1
        sigma = radius / 3.0
        if len(mask.shape) == 2:
            mask = mask.unsqueeze(0).unsqueeze(0)
        kernel_1d = torch.exp(torch.linspace(-radius, radius, kernel_size).pow(2) / (-2 * sigma ** 2))
        kernel_1d = kernel_1d / kernel_1d.sum()
        padding = radius
        kernel_h = kernel_1d.unsqueeze(0).unsqueeze(0).unsqueeze(0).to(mask.device)
        kernel_v = kernel_1d.unsqueeze(0).unsqueeze(0).unsqueeze(-1).to(mask.device)
        mask = F.pad(mask, (padding, padding, 0, 0), mode='reflect')
        mask = F.conv2d(mask, kernel_h)
        mask = F.pad(mask, (0, 0, padding, padding), mode='reflect')
        mask = F.conv2d(mask, kernel_v)
        return mask.squeeze()
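
    # Feathering blurs only the part of the mask that lies outside the bbox
    # interior shrunk by `radius`, so the core of the region stays untouched
    # while its border fades out.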
    def apply_feathering(self, mask: torch.Tensor, bbox: Dict, radius: int) -> Tuple[torch.Tensor, Dict]:
        """Apply feathering to mask edges while preserving bbox boundaries"""
        if radius <= 0 or not bbox["active"]:
            return mask, bbox
        height, width = mask.shape
        x1 = int(bbox["x1"] * width)
        y1 = int(bbox["y1"] * height)
        x2 = int(bbox["x2"] * width)
        y2 = int(bbox["y2"] * height)
        inner_mask = torch.zeros_like(mask)
        inner_mask[y1 + radius:y2 - radius, x1 + radius:x2 - radius] = 1.0
        edge_mask = mask - inner_mask
        if edge_mask.any():
            blurred = self.apply_gaussian_blur(mask, radius)
            result = mask.clone()
            result[edge_mask > 0] = blurred[edge_mask > 0]
        else:
            result = mask
        return result, bbox
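
    # Per-region pipeline: binarize the mask at `threshold`, then either feather
    # it against its bbox or apply a plain Gaussian blur, depending on the
    # `feather_edges` toggle.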
    def process_single_region(self,
                              mask: torch.Tensor,
                              bbox: Dict,
                              blur_radius: int,
                              threshold: float,
                              feather_edges: bool) -> Tuple[torch.Tensor, Dict]:
        """Process a single mask-bbox pair"""
        if mask is None or not bbox["active"]:
            return mask, bbox
        try:
            processed = (mask > threshold).float()
            if feather_edges and blur_radius > 0:
                processed, bbox = self.apply_feathering(processed, bbox, blur_radius)
            elif blur_radius > 0:
                processed = self.apply_gaussian_blur(processed, blur_radius)
            return processed, bbox
        except Exception as e:
            print(f"Error processing region: {str(e)}")
            return mask, bbox
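
    # The preview paints each active region in a fixed color (red, green,
    # yellow) and composites them back-to-front: region 3 is drawn first and
    # region 1 last, so region 1 stays on top where regions overlap.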
    def create_preview(self, masks: List[torch.Tensor], bboxes: List[Dict],
                       number_of_regions: int) -> torch.Tensor:
        """Create preview of processed regions with PIL for consistent coloring"""
        if not masks:
            # Fall back to an empty (1, H, W, 3) image so the IMAGE output keeps
            # the same layout as pil2tensor.
            return torch.zeros((1, 64, 64, 3), dtype=torch.float32)
        height, width = masks[0].shape
        # Create PIL Image for preview
        preview = Image.new("RGB", (width, height), (0, 0, 0))
        colors = [
            (255, 0, 0),    # Red - Region 1
            (0, 255, 0),    # Green - Region 2
            (255, 255, 0),  # Yellow - Region 3
        ]
        # Store regions for ordered preview
        preview_regions = []
        for i in range(number_of_regions):
            if bboxes[i]["active"] and masks[i] is not None:
                mask_np = masks[i].cpu().numpy() > 0.5
                preview_regions.append((i, mask_np))
        # Draw regions in reverse order (Region 3 first, Region 1 last)
        for i, mask_np in sorted(preview_regions, reverse=True):
            color_array = np.zeros((height, width, 3), dtype=np.uint8)
            color_array[mask_np] = colors[i]
            # Convert to PIL and composite
            region_img = Image.fromarray(color_array, 'RGB')
            preview = Image.alpha_composite(
                preview.convert('RGBA'),
                Image.merge('RGBA', (*region_img.split(), Image.fromarray((mask_np * 255).astype(np.uint8))))
            )
        return pil2tensor(preview.convert('RGB'))
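
    # Main entry point. Always returns three mask/bbox pairs: slots beyond
    # `number_of_regions` (or with missing inputs) are filled with an empty mask
    # and an inactive bbox so the output signature stays fixed.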
    def process_regions(self,
                        mask1: torch.Tensor,
                        bbox1: Dict,
                        blur_radius: int,
                        threshold: float,
                        feather_edges: bool,
                        number_of_regions: int,
                        mask2: Optional[torch.Tensor] = None,
                        bbox2: Optional[Dict] = None,
                        mask3: Optional[torch.Tensor] = None,
                        bbox3: Optional[Dict] = None) -> Tuple:
        try:
            # Process each mask-bbox pair
            mask_bbox_pairs = [
                (mask1, bbox1),
                (mask2, bbox2) if mask2 is not None else (None, None),
                (mask3, bbox3) if mask3 is not None else (None, None),
            ]
            processed_masks = []
            processed_bboxes = []
            active_count = 0
            for i, (mask, bbox) in enumerate(mask_bbox_pairs):
                if i < number_of_regions and mask is not None and bbox is not None:
                    proc_mask, proc_bbox = self.process_single_region(
                        mask, bbox, blur_radius, threshold, feather_edges
                    )
                    if proc_bbox["active"]:
                        active_count += 1
                    processed_masks.append(proc_mask)
                    processed_bboxes.append(proc_bbox)
                else:
                    empty_mask = torch.zeros_like(mask1)
                    empty_bbox = {"x1": 0.0, "y1": 0.0, "x2": 0.0, "y2": 0.0, "active": False}
                    processed_masks.append(empty_mask)
                    processed_bboxes.append(empty_bbox)
            # Create preview
            preview = self.create_preview(processed_masks, processed_bboxes, number_of_regions)
            return (*[item for pair in zip(processed_masks, processed_bboxes) for item in pair],
                    preview, active_count)
        except Exception as e:
            print(f"Error processing regions: {str(e)}")
            empty_mask = torch.zeros_like(mask1)
            empty_bbox = {"x1": 0.0, "y1": 0.0, "x2": 0.0, "y2": 0.0, "active": False}
            # Match the (1, H, W, 3) layout produced by pil2tensor for the IMAGE output.
            empty_preview = torch.zeros((1, mask1.shape[-2], mask1.shape[-1], 3), dtype=torch.float32)
            return (empty_mask, empty_bbox, empty_mask, empty_bbox,
                    empty_mask, empty_bbox,
                    empty_preview, 0)


# Node class mappings
NODE_CLASS_MAPPINGS = {
    "RegionMaskProcessor": RegionMaskProcessor
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "RegionMaskProcessor": "Region Mask Processor"
}
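

# Illustrative smoke test, not part of the node registration above: it builds a
# synthetic mask and a normalized-coordinate bbox dict (the format this node
# reads) and runs process_regions once so the output shapes can be inspected
# outside ComfyUI. All values below are made-up example inputs.
if __name__ == "__main__":
    h, w = 128, 128
    # Synthetic 2-D mask covering the upper-left quarter of the canvas.
    demo_mask = torch.zeros((h, w), dtype=torch.float32)
    demo_mask[8:64, 8:64] = 1.0
    # BBOX dict with normalized coordinates plus the "active" flag the node checks.
    demo_bbox = {"x1": 8 / w, "y1": 8 / h, "x2": 64 / w, "y2": 64 / h, "active": True}

    node = RegionMaskProcessor()
    outputs = node.process_regions(
        mask1=demo_mask,
        bbox1=demo_bbox,
        blur_radius=5,
        threshold=0.5,
        feather_edges=True,
        number_of_regions=1,
    )
    m1, b1, m2, b2, m3, b3, preview, count = outputs
    print("mask1:", tuple(m1.shape), "preview:", tuple(preview.shape), "active regions:", count)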