Files
ComfyUI/custom_nodes/controlaltai-nodes/region_mask_validator_node.py
jaidaken f09734b0ee
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Includes 30 custom nodes committed directly, 7 Civitai-exclusive
loras stored via Git LFS, and a setup script that installs all
dependencies and downloads HuggingFace-hosted models on vast.ai.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 00:56:42 +00:00

270 lines
11 KiB
Python

import torch
from typing import Tuple, Dict, Optional, List
import numpy as np
from PIL import Image, ImageDraw
def pil2tensor(image):
    """Convert a PIL image (or anything ``np.array`` accepts) to a float32 tensor.

    Pixel values are divided by 255 and a new leading axis of size 1 is
    inserted, so an RGB image of size W x H becomes a ``(1, H, W, 3)`` tensor.
    """
    pixels = np.array(image).astype(np.float32)
    normalized = pixels / 255.0
    return torch.from_numpy(normalized).unsqueeze(0)
class RegionMaskValidator:
    """ComfyUI node that validates up to three (mask, bbox) regions.

    Each bbox is read as a dict holding ``x1``/``y1``/``x2``/``y2`` (multiplied
    by the canvas width/height, so presumably fractions in [0, 1]) and an
    ``active`` flag. A region fails when its pixel width or height is below
    ``min_region_size``; such regions are returned with ``active=False``.
    Any pair of active regions whose overlap (relative to the smaller region)
    exceeds ``max_overlap`` marks the whole result invalid, but neither region
    is deactivated.
    """

    # Preview text sizes in pixels, passed to Pillow as ``font_size``.
    # NOTE(review): honoured by Pillow >= 10.1; older Pillow is expected to
    # ignore the keyword and fall back to its default font — confirm if pinned.
    LABEL_FONT_SIZE = 64
    ERROR_FONT_SIZE = 20
    # Number of (mask, bbox) slots this node exposes (mask1..mask3).
    MAX_REGIONS = 3

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mask1": ("MASK",),
                "bbox1": ("BBOX",),
                "number_of_regions": ("INT", {
                    "default": 1,
                    "min": 1,
                    "max": 3,
                    "step": 1
                }),
                "min_region_size": ("INT", {
                    "default": 64,
                    "min": 32,
                    "max": 512,
                    "step": 32,
                    "display": "Minimum Region Size (px)"
                }),
                "max_overlap": ("FLOAT", {
                    "default": 0.1,
                    "min": 0.0,
                    "max": 0.5,
                    "step": 0.01,
                    "display": "Maximum Region Overlap"
                }),
            },
            "optional": {
                "mask2": ("MASK",),
                "bbox2": ("BBOX",),
                "mask3": ("MASK",),
                "bbox3": ("BBOX",),
            }
        }

    RETURN_TYPES = ("MASK", "BBOX", "MASK", "BBOX", "MASK", "BBOX",
                    "INT", "BOOLEAN", "STRING", "IMAGE")
    RETURN_NAMES = ("valid_mask1", "valid_bbox1",
                    "valid_mask2", "valid_bbox2",
                    "valid_mask3", "valid_bbox3",
                    "valid_region_count", "is_valid", "validation_message",
                    "validation_preview")
    FUNCTION = "validate_regions"
    CATEGORY = "ControlAltAI Nodes/Flux Region"

    @staticmethod
    def _empty_bbox() -> Dict:
        """Return a fresh inactive bbox with all coordinates at 0.0."""
        return {"x1": 0.0, "y1": 0.0, "x2": 0.0, "y2": 0.0, "active": False}

    @staticmethod
    def _to_pixels(bbox: Dict, width: int, height: int) -> Tuple[int, int, int, int]:
        """Scale a bbox's corners by the canvas size and truncate to ints."""
        return (int(bbox["x1"] * width), int(bbox["y1"] * height),
                int(bbox["x2"] * width), int(bbox["y2"] * height))

    def get_region_dimensions(self, bbox: Dict, width: int, height: int) -> Tuple[int, Tuple[int, int]]:
        """Return ``(area, (w, h))`` of ``bbox`` in pixels.

        Inactive bboxes yield ``(0, (0, 0))``. ``w``/``h`` are ``x2 - x1`` and
        ``y2 - y1`` after scaling, so an inverted bbox gives negative values.
        """
        if not bbox["active"]:
            return 0, (0, 0)
        x1, y1, x2, y2 = self._to_pixels(bbox, width, height)
        w = x2 - x1
        h = y2 - y1
        print(f"Region dimensions: {w}x{h} pixels")
        return w * h, (w, h)

    def calculate_overlap(self, bbox1: Dict, bbox2: Dict, width: int, height: int) -> Tuple[Tuple[int, int], float]:
        """Return the pixel size of two bboxes' intersection and its ratio.

        The ratio is the intersection area divided by the smaller bbox's area.
        Returns ``((0, 0), 0.0)`` when either bbox is inactive or they do not
        intersect.
        """
        if not (bbox1["active"] and bbox2["active"]):
            return (0, 0), 0.0
        ax1, ay1, ax2, ay2 = self._to_pixels(bbox1, width, height)
        bx1, by1, bx2, by2 = self._to_pixels(bbox2, width, height)

        x_left = max(ax1, bx1)
        y_top = max(ay1, by1)
        x_right = min(ax2, bx2)
        y_bottom = min(ay2, by2)
        if x_right <= x_left or y_bottom <= y_top:
            return (0, 0), 0.0

        overlap_width = x_right - x_left
        overlap_height = y_bottom - y_top
        # A non-empty intersection implies both boxes have positive width and
        # height, so the smaller area cannot be zero here.
        smaller_area = min((ax2 - ax1) * (ay2 - ay1), (bx2 - bx1) * (by2 - by1))
        overlap_ratio = (overlap_width * overlap_height) / smaller_area
        print(f"Overlap dimensions: {overlap_width}x{overlap_height} pixels ({overlap_ratio:.1%})")
        return (overlap_width, overlap_height), overlap_ratio

    def create_validation_preview(self, masks: List[torch.Tensor], bboxes: List[Dict],
                                  number_of_regions: int, is_valid: bool,
                                  messages: List[str], img_width: int, img_height: int) -> torch.Tensor:
        """Draw each active bbox as an outlined, labelled rectangle on a black canvas.

        Args:
            masks: Per-region masks. They only determine whether anything is drawn
                and how many slots are iterated.
            bboxes: Per-region bboxes aligned with ``masks``; inactive ones are skipped.
            number_of_regions: How many leading slots to draw.
            is_valid: Overall result; selects the green or red palette.
            messages: Per-region error text aligned with ``bboxes`` (index ``i``
                belongs to region ``i + 1``); empty strings mean no error. Drawn
                under the label only when ``is_valid`` is False.
            img_width: Canvas width in pixels.
            img_height: Canvas height in pixels.

        Returns:
            A float32 ``(1, img_height, img_width, 3)`` tensor from ``pil2tensor``,
            or a black ``(1, 64, 64, 3)`` tensor when ``masks`` is empty.
        """
        if not masks:
            # Same (1, H, W, 3) layout that pil2tensor produces for RGB images.
            return torch.zeros((1, 64, 64, 3), dtype=torch.float32)

        preview = Image.new("RGB", (img_width, img_height), (0, 0, 0))
        draw = ImageDraw.Draw(preview)
        colors = {
            'valid': [(0, 255, 0), (0, 200, 0), (0, 150, 0)],    # Green shades
            'invalid': [(255, 0, 0), (200, 0, 0), (150, 0, 0)],  # Red shades
        }
        shadow_offset = 2
        shadow_color = (0, 0, 0)

        def draw_shadowed_text(x, y, text, fill, font_size):
            # Four offset black copies underneath give contrast on any background.
            for dx in (-shadow_offset, shadow_offset):
                for dy in (-shadow_offset, shadow_offset):
                    draw.text((x + dx, y + dy), text, fill=shadow_color, font_size=font_size)
            draw.text((x, y), text, fill=fill, font_size=font_size)

        for i, (_mask, bbox) in enumerate(zip(masks[:number_of_regions], bboxes[:number_of_regions])):
            if not bbox["active"]:
                continue
            x1, y1, x2, y2 = self._to_pixels(bbox, img_width, img_height)
            w = x2 - x1
            h = y2 - y1
            color = colors['valid' if is_valid else 'invalid'][i]

            # Sort corners so an inverted bbox can still be outlined; Pillow's
            # rectangle() rejects x1 < x0 or y1 < y0.
            left, right = sorted((x1, x2))
            top, bottom = sorted((y1, y2))
            draw.rectangle([left, top, right, bottom], outline=color, width=4)

            text_x = left + 10
            text_y = top + 10
            draw_shadowed_text(text_x, text_y, f"R{i+1}: {w}x{h}", color, self.LABEL_FONT_SIZE)

            region_message = messages[i] if i < len(messages) else ""
            if not is_valid and region_message:
                # Place the error text just below the label, whatever its size.
                error_y = text_y + self.LABEL_FONT_SIZE + 10
                draw_shadowed_text(text_x, error_y, region_message, color, self.ERROR_FONT_SIZE)

        return pil2tensor(preview)

    def validate_regions(self,
                         mask1: torch.Tensor,
                         bbox1: Dict,
                         number_of_regions: int,
                         min_region_size: int,
                         max_overlap: float,
                         mask2: Optional[torch.Tensor] = None,
                         bbox2: Optional[Dict] = None,
                         mask3: Optional[torch.Tensor] = None,
                         bbox3: Optional[Dict] = None) -> Tuple:
        """Check size and pairwise overlap of the first ``number_of_regions`` regions.

        Returns:
            ``(mask1, bbox1, mask2, bbox2, mask3, bbox3, valid_region_count,
            is_valid, validation_message, validation_preview)``.
            - Regions that fail the size check keep their mask, but their bbox
              is a copy with ``active=False``.
            - Unused or missing slots get a zero mask shaped like ``mask1`` and
              an inactive bbox.
            - ``valid_region_count`` counts regions that passed the size check.
            - On any exception, empty outputs, ``False`` and the error text are
              returned instead.
        """
        # Fallback preview size in case mask1's shape cannot be read.
        height, width = 64, 64
        try:
            print(f"\nValidating {number_of_regions} regions:")
            messages = []
            # region_messages[i] gathers the failures concerning region i+1 so
            # the preview can print each message under the region it refers to.
            region_messages = [[] for _ in range(self.MAX_REGIONS)]
            is_valid = True
            # Take the last two dims so both (H, W) and batched (..., H, W) masks work.
            height, width = mask1.shape[-2:]
            print(f"Canvas size: {width}x{height} pixels")

            regions = [
                (mask1, bbox1),
                (mask2, bbox2) if mask2 is not None else (None, None),
                (mask3, bbox3) if mask3 is not None else (None, None),
            ]

            valid_regions = []
            # Bboxes used for drawing: regions that fail the size check keep
            # their original bbox here so they remain visible in the preview.
            display_bboxes = []
            valid_count = 0
            for i, (mask, bbox) in enumerate(regions):
                if i < number_of_regions and mask is not None and bbox is not None:
                    print(f"\nValidating Region {i+1}:")
                    _, (w, h) = self.get_region_dimensions(bbox, width, height)
                    display_bboxes.append(bbox)
                    if w < min_region_size or h < min_region_size:
                        message = f"Region {i+1} too small: {w}x{h} pixels (minimum: {min_region_size}x{min_region_size})"
                        print(f"Failed: {message}")
                        messages.append(message)
                        region_messages[i].append(message)
                        is_valid = False
                        # Copy so the caller's bbox dict is not mutated.
                        bbox = bbox.copy()
                        bbox["active"] = False
                    else:
                        print(f"Passed: Region {i+1} size check ({w}x{h} pixels)")
                        valid_count += 1
                    valid_regions.append((mask, bbox))
                else:
                    empty_bbox = self._empty_bbox()
                    valid_regions.append((torch.zeros_like(mask1), empty_bbox))
                    display_bboxes.append(empty_bbox)

            if valid_count > 1:
                print("\nChecking region overlaps:")
                for i in range(len(valid_regions)):
                    for j in range(i + 1, len(valid_regions)):
                        bbox_i = valid_regions[i][1]
                        bbox_j = valid_regions[j][1]
                        if not (bbox_i["active"] and bbox_j["active"]):
                            continue
                        print(f"Checking overlap between regions {i+1} and {j+1}:")
                        (ow, oh), overlap_ratio = self.calculate_overlap(bbox_i, bbox_j, width, height)
                        if overlap_ratio > max_overlap:
                            message = f"Excessive overlap ({ow}x{oh} pixels, {overlap_ratio:.1%}) between regions {i+1} and {j+1}"
                            print(f"Failed: {message}")
                            messages.append(message)
                            region_messages[i].append(message)
                            region_messages[j].append(message)
                            is_valid = False

            validation_message = "All regions valid" if is_valid else "\n".join(messages)
            print(f"\nValidation {'passed' if is_valid else 'failed'}:")
            print(validation_message)

            preview = self.create_validation_preview(
                [mask for mask, _ in valid_regions],
                display_bboxes,
                number_of_regions,
                is_valid,
                ["\n".join(region_msgs) for region_msgs in region_messages],
                width,
                height,
            )
            return (*[item for region in valid_regions for item in region],
                    valid_count, is_valid, validation_message, preview)
        except Exception as e:
            # Node boundary: report the failure as an output instead of crashing the graph.
            print(f"Validation error: {str(e)}")
            empty_mask = torch.zeros_like(mask1)
            empty_preview = torch.zeros((1, height, width, 3), dtype=torch.float32)
            return (empty_mask, self._empty_bbox(), empty_mask, self._empty_bbox(),
                    empty_mask, self._empty_bbox(),
                    0, False, f"Validation error: {str(e)}", empty_preview)
# Node registration tables following the ComfyUI custom-node convention
# (presumably collected by this package's __init__ — not visible here).
# Maps the internal node id to the class that implements it.
NODE_CLASS_MAPPINGS = {
    "RegionMaskValidator": RegionMaskValidator,
}

# Maps the same internal node id to its human-readable display label.
NODE_DISPLAY_NAME_MAPPINGS = {
    "RegionMaskValidator": "Region Mask Validator",
}