Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Includes 30 custom nodes committed directly, 7 Civitai-exclusive loras stored via Git LFS, and a setup script that installs all dependencies and downloads HuggingFace-hosted models on vast.ai. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
853 lines
30 KiB
Python
853 lines
30 KiB
Python
# SPDX-License-Identifier: AGPL-3.0-only
|
||
# SPDX-FileCopyrightText: 2025 ArtificialSweetener <artificialsweetenerai@proton.me>
|
||
|
||
import torch
|
||
|
||
|
||
class PrepareLoopFrames:
    """Node that extracts the wrap-around seam of a clip.

    Produces a two-frame batch ``[last, first]`` (the pair an interpolator
    needs to bridge the loop point) and forwards the input batch untouched
    as a second output.
    """

    DESCRIPTION = "Prepares the wrap seam: builds a tiny 2-frame batch [last, first] for your interpolator and also passes the original clip through unchanged."

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the single required ``images`` input."""
        return {
            "required": {
                "images": (
                    "IMAGE",
                    {
                        "tooltip": "Your clip as an IMAGE batch (frames×H×W×C, values 0–1). Outputs: [last, first] for the seam, plus the original clip."
                    },
                ),
            }
        }

    RETURN_TYPES = ("IMAGE", "IMAGE")
    RETURN_NAMES = ("interp_batch", "original_images")
    FUNCTION = "prepare"
    CATEGORY = "video utils"

    def prepare(self, images):
        """Return ``(seam_pair, images)``.

        Args:
            images: tensor batch indexed along dim 0 by frame.

        Returns:
            A tuple of the concatenation ``[images[-1], images[0]]`` along
            dim 0 and the unmodified ``images`` object.
        """
        # Slices (not integer indexing) keep the batch dimension so the two
        # frames can be concatenated directly.
        seam_pair = torch.cat((images[-1:], images[:1]), dim=0)
        return (seam_pair, images)
|
||
|
||
|
||
class AssembleLoopFrames:
    """Node that appends interpolated seam frames to the original clip.

    The first and last frames of ``interpolated_frames`` are dropped before
    appending, so only the in-between frames are added to the end of the
    original batch.
    """

    DESCRIPTION = "Builds the final loop: appends only the new in-between seam frames to your original clip—no duplicate of frame 1."

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the two required IMAGE inputs."""
        return {
            "required": {
                "original_images": (
                    "IMAGE",
                    {"tooltip": "Your original clip (frames×H×W×C)."},
                ),
                "interpolated_frames": (
                    "IMAGE",
                    {
                        "tooltip": "Frames that bridge last→first. The first and last of this batch are the originals; only the middle ones get added."
                    },
                ),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "assemble"
    CATEGORY = "video utils"

    def assemble(self, original_images, interpolated_frames):
        """Return a 1-tuple with ``original_images`` followed by the seam in-betweens.

        Args:
            original_images: tensor batch of the source clip.
            interpolated_frames: tensor batch whose interior frames
                (everything except index 0 and index -1) are appended.

        Returns:
            ``(combined,)`` where ``combined`` lives on the device of
            ``interpolated_frames``.
        """
        # torch.cat needs both operands on one device; follow the interpolator's.
        base = original_images.to(interpolated_frames.device)
        # With fewer than three interpolated frames this slice is empty and
        # the clip is returned without additions.
        middle = interpolated_frames[1:-1]
        return (torch.cat((base, middle), dim=0),)
|
||
|
||
|
||
class RollFrames:
    """Node that cyclically shifts a clip along the frame axis.

    The requested offset is echoed back as a second output so a later node
    can undo the shift.
    """

    DESCRIPTION = "Rolls the clip in a loop by an integer amount (cyclic shift). Also returns the same offset so you can undo it later."

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the ``images`` and ``offset`` inputs."""
        return {
            "required": {
                "images": ("IMAGE", {"tooltip": "Your clip (frames×H×W×C)."}),
                "offset": (
                    "INT",
                    {
                        "default": 1,
                        "min": -9999,
                        "max": 9999,
                        "step": 1,
                        "tooltip": "How far to rotate the clip. Positive = forward in time; negative = backward.",
                    },
                ),
            }
        }

    RETURN_TYPES = ("IMAGE", "INT")
    RETURN_NAMES = ("images", "offset_out")
    FUNCTION = "roll"
    CATEGORY = "video utils"

    def roll(self, images, offset):
        """Rotate ``images`` so that frame ``offset`` becomes frame 0.

        Args:
            images: tensor batch indexed along dim 0 by frame.
            offset: integer shift; any value is reduced modulo the frame count.

        Returns:
            ``(rolled, int(offset))``. When the batch is empty or the
            effective shift is zero, the input tensor object itself is
            returned.
        """
        offset = int(offset)
        frame_count = images.shape[0]
        # Guard the modulo against an empty batch; zero shift means "no-op".
        shift = offset % frame_count if frame_count else 0
        if shift == 0:
            return (images, offset)
        # Negative shift moves later frames to the front: +1 -> [2, 3, ..., 1].
        return (torch.roll(images, shifts=-shift, dims=0), offset)
|
||
|
||
|
||
class UnrollFrames:
    """Node that reverses an earlier cyclic shift on an interpolated clip.

    The shift applied is ``base_offset * (m + 1)`` frames (modulo the clip
    length), rolled in the positive direction along the frame axis.
    """

    DESCRIPTION = "Undo a previous roll after interpolation by accounting for the inserted frames (rotate by base_offset × (m+1))."

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the ``images``, ``base_offset`` and ``m`` inputs."""
        return {
            "required": {
                "images": (
                    "IMAGE",
                    {"tooltip": "Clip after interpolation (frames′×H×W×C)."},
                ),
                "base_offset": (
                    "INT",
                    {
                        "default": 1,
                        "min": -9999,
                        "max": 9999,
                        "step": 1,
                        "tooltip": "Use the exact offset_out that came from RollFrames.",
                    },
                ),
                "m": (
                    "INT",
                    {
                        "default": 0,
                        "min": 0,
                        "max": 9999,
                        "step": 1,
                        "tooltip": "How many in-betweens per gap were added (the interpolation multiple).",
                    },
                ),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "unroll"
    CATEGORY = "video utils"

    def unroll(self, images, base_offset, m):
        """Roll ``images`` forward by ``base_offset * (m + 1)`` frames.

        Args:
            images: tensor batch indexed along dim 0 by frame.
            base_offset: the offset originally used for the roll.
            m: number of frames inserted per gap by interpolation.

        Returns:
            ``(unrolled,)``; an empty batch is returned as-is.
        """
        frame_count = images.shape[0]
        if frame_count == 0:
            return (images,)
        # Each original frame step now spans (m + 1) frames.
        shift = (int(base_offset) * (int(m) + 1)) % frame_count
        return (torch.roll(images, shifts=shift, dims=0),)
|
||
|
||
|
||
class AutocropToLoop:
    """
    Finds a natural loop by cropping frames from the END of the batch.
    Returns the cropped clip that makes the seam (last_kept -> first)
    feel like a normal step between real neighbors.

    Score = weighted mix of:
      - step-size match (L1/MSE distance)
      - similarity match (SSIM)
      - exposure continuity (luma)
      - motion consistency (optical flow; optional)

    Each term is a relative error between a seam measurement and a target
    derived from adjacent-frame pairs; the candidate with the lowest score wins.

    Speed: can run metrics on GPU and use mixed precision for SSIM/conv math.
    Progress bar: one tick per candidate crop (0..max_end_crop_frames).
    """

    DESCRIPTION = "Auto-crops the clip to create a smoother loop: tests crops from the end and scores the seam so it feels like a normal step."

    @classmethod
    def INPUT_TYPES(cls):
        """Describe every required input widget, in display order."""

        def flag(default, tooltip):
            return ("BOOLEAN", {"default": default, "tooltip": tooltip})

        def weight(default, tooltip):
            return (
                "FLOAT",
                {
                    "default": default,
                    "min": 0.0,
                    "max": 1.0,
                    "step": 0.01,
                    "tooltip": tooltip,
                },
            )

        return {
            "required": {
                "clip_frames": (
                    "IMAGE",
                    {
                        "tooltip": "Your full clip (NHWC, 0–1). Tries every crop from 0..max_end_crop_frames and returns the best loop."
                    },
                ),
                "max_end_crop_frames": (
                    "INT",
                    {
                        "default": 12,
                        "min": 0,
                        "max": 10000,
                        "tooltip": "Largest crop to test at the END. Higher = more candidates (slower), but potentially better.",
                    },
                ),
                "include_first_step": flag(
                    True,
                    "Use the first neighbor pair (frame 0→1) as a target step size/similarity.",
                ),
                "include_last_step": flag(
                    True,
                    "Use the last neighbor pair inside the KEPT region as a target.",
                ),
                "include_global_median_step": flag(
                    False,
                    "Also use the median step across the KEPT region (needs ≥3 frames). Helps ignore outliers.",
                ),
                "seam_window_frames": (
                    "INT",
                    {
                        "default": 2,
                        "min": 1,
                        "max": 6,
                        "tooltip": "Average over multiple aligned pairs across the seam. Larger = more robust.",
                    },
                ),
                "distance_metric": (
                    ["L1", "MSE"],
                    {
                        "default": "L1",
                        "tooltip": "How to measure step size for matching. L1 is usually more forgiving; MSE penalizes big errors more.",
                    },
                ),
                "score_in_8bit": flag(
                    False,
                    "Score with an 8-bit view (simulate export). Output video still stays float.",
                ),
                "use_ssim_similarity": flag(
                    True,
                    "Include SSIM so the seam ‘looks’ like a normal neighbor—avoid freeze or jump.",
                ),
                "use_exposure_guard": flag(
                    True,
                    "Promote smooth brightness across the seam (reduces flicker pops).",
                ),
                "use_flow_guard": flag(
                    False,
                    "Encourage consistent motion across the seam (needs OpenCV; slower).",
                ),
                "weight_step_size": weight(
                    0.55,
                    "Importance of matching step size. Higher = less freeze/jump risk.",
                ),
                "weight_similarity": weight(
                    0.30,
                    "Importance of visual similarity (SSIM). Helps avoid a frozen-looking seam.",
                ),
                "weight_exposure": weight(
                    0.10,
                    "Importance of even brightness across the seam.",
                ),
                "weight_flow": weight(
                    0.05,
                    "Importance of motion continuity across the seam.",
                ),
                "ssim_downsample_scales": (
                    "STRING",
                    {
                        "default": "1,2",
                        "tooltip": "SSIM scales to average, as a comma list. Example: 1,2 = full-res and half-res.",
                    },
                ),
                "accelerate_with_gpu": flag(
                    True,
                    "If ON and CUDA is available, run scoring on GPU for a big speedup (same results).",
                ),
                "use_mixed_precision": flag(
                    True,
                    "If ON (with GPU), use mixed precision for SSIM/conv math (faster on larger clips).",
                ),
            }
        }

    RETURN_TYPES = ("IMAGE", "INT", "INT", "FLOAT", "STRING")
    RETURN_NAMES = (
        "cropped_clip",
        "end_crop_frames",
        "cropped_length",
        "score",
        "diagnostics_csv",
    )
    FUNCTION = "find_and_crop"
    CATEGORY = "video utils"

    # Gaussian window cache shared by all instances; keyed by
    # (channels, kernel size, sigma, device, dtype).
    _gw_cache = {}

    def _to_nchw(self, x):
        """Permute a 4-D channels-last tensor (1, 3 or 4 channels) to NCHW.

        Anything else is returned unchanged.
        """
        if x.ndim == 4 and x.shape[-1] in (1, 3, 4):
            return x.permute(0, 3, 1, 2).contiguous()
        return x

    def _parse_scales(self, csv):
        """Parse a comma list of integer scales.

        Non-integer tokens, values below 1 and duplicates are ignored; order
        of first appearance is kept. Falls back to ``[1]`` when nothing valid
        remains.
        """
        scales = []
        for token in str(csv).split(","):
            token = token.strip()
            if not token:
                continue
            try:
                value = int(token)
            except ValueError:
                # Unparseable tokens are skipped on purpose (free-text widget).
                continue
            if value >= 1 and value not in scales:
                scales.append(value)
        return scales or [1]

    def _downsample(self, x, s):
        """Area-downsample the last two dims of ``x`` by integer factor ``s``."""
        import torch.nn.functional as F

        if s == 1:
            return x
        H, W = x.shape[-2:]
        # Never shrink below one pixel per side.
        new_size = (max(1, H // s), max(1, W // s))
        return F.interpolate(x, size=new_size, mode="area", align_corners=None)

    def _dist(self, A, B, kind="L1"):
        """Per-sample mean distance over dims 1..3: squared for "MSE", absolute otherwise."""
        diff = A - B
        if kind == "MSE":
            return (diff**2).mean(dim=(1, 2, 3))
        return diff.abs().mean(dim=(1, 2, 3))

    def _luma(self, x_nchw):
        """Return a single-channel luma plane.

        A one-channel input is passed through; otherwise channels 0, 1, 2 are
        combined with weights 0.2126 / 0.7152 / 0.0722.
        """
        if x_nchw.shape[1] == 1:
            return x_nchw[:, 0:1]
        R = x_nchw[:, 0:1]
        G = x_nchw[:, 1:2]
        B = x_nchw[:, 2:3]
        return 0.2126 * R + 0.7152 * G + 0.0722 * B

    def _gaussian_window(self, C, k=7, sigma=1.5, device="cpu", dtype=None):
        """Return a cached, normalized ``(C, 1, k, k)`` Gaussian kernel for grouped conv."""
        key = (int(C), int(k), float(sigma), str(device), str(dtype))
        cached = self._gw_cache.get(key)
        if cached is not None:
            return cached
        ax = torch.arange(k, dtype=dtype, device=device) - (k - 1) / 2.0
        gauss = torch.exp(-0.5 * (ax / sigma) ** 2)
        kernel1d = (gauss / gauss.sum()).unsqueeze(1)
        kernel2d = kernel1d @ kernel1d.t()
        window = kernel2d.expand(C, 1, k, k).contiguous()
        self._gw_cache[key] = window
        return window

    def _ssim_pair_batched(self, x, y, k=7, sigma=1.5, C1=0.01**2, C2=0.03**2):
        """Mean SSIM per sample between NCHW batches ``x`` and ``y``; returns shape ``(N,)``."""
        import torch.nn.functional as F

        C = x.shape[1]
        w = self._gaussian_window(C, k=k, sigma=sigma, device=x.device, dtype=x.dtype)
        pad = k // 2

        def blur(t):
            # Depthwise Gaussian filter (one kernel per channel).
            return F.conv2d(t, w, padding=pad, groups=C)

        mu_x = blur(x)
        mu_y = blur(y)
        mu_x2, mu_y2, mu_xy = mu_x * mu_x, mu_y * mu_y, mu_x * mu_y
        sigma_x2 = blur(x * x) - mu_x2
        sigma_y2 = blur(y * y) - mu_y2
        sigma_xy = blur(x * y) - mu_xy
        ssim_map = ((2.0 * mu_xy + C1) * (2.0 * sigma_xy + C2)) / (
            (mu_x2 + mu_y2 + C1) * (sigma_x2 + sigma_y2 + C2) + 1e-12
        )
        return ssim_map.mean(dim=(1, 2, 3))  # (N,)

    def _ssim_multiscale_batched(self, x, y, scales):
        """Average ``_ssim_pair_batched`` over each downsample factor in ``scales``."""
        per_scale = [
            self._ssim_pair_batched(self._downsample(x, s), self._downsample(y, s))
            for s in scales
        ]
        return sum(per_scale) / float(len(per_scale))  # (N,)

    def _precompute_adjacent_metrics(
        self, clip_nhwc_dev, kind, use_ssim, ds_scales, use_exp, use_flow
    ):
        """
        Compute metrics between each frame and its successor.

        Returns ``(result, x_nchw)`` where ``result`` is a dict of vectors:
          D_adj (torch, length B-1), S_adj / E_adj (torch, length B-1 or
          empty when disabled), F_adj (numpy float64, length B-1; zeros when
          flow is disabled).
        Torch tensors are on the same device as clip_nhwc_dev; ``x_nchw`` is
        the clip after ``_to_nchw``.
        """
        import numpy as np

        device = clip_nhwc_dev.device
        B = int(clip_nhwc_dev.shape[0])
        N = max(0, B - 1)
        x_nchw = self._to_nchw(clip_nhwc_dev)  # B,C,H,W (device)
        result = {}

        if N == 0:
            result["D_adj"] = torch.empty(0, device=device)
            result["S_adj"] = torch.empty(0, device=device)
            result["E_adj"] = torch.empty(0, device=device)
            result["F_adj"] = np.zeros((0,), dtype="float64")
            return result, x_nchw

        X = x_nchw[:-1]
        Y = x_nchw[1:]  # N,C,H,W

        result["D_adj"] = self._dist(X, Y, kind=kind)  # (N,)

        if use_ssim:
            result["S_adj"] = self._ssim_multiscale_batched(X, Y, ds_scales)  # (N,)
        else:
            result["S_adj"] = torch.empty(0, device=x_nchw.device)

        if use_exp:
            Y_luma = self._luma(x_nchw).mean(dim=(1, 2, 3))  # (B,)
            result["E_adj"] = (Y_luma[:-1] - Y_luma[1:]).abs()  # (N,)
        else:
            result["E_adj"] = torch.empty(0, device=x_nchw.device)

        if use_flow:
            # Optical flow runs on CPU through OpenCV, one pair at a time.
            flows = [
                self._flow_mag_mean(
                    clip_nhwc_dev[i : i + 1].detach().cpu(),
                    clip_nhwc_dev[i + 1 : i + 2].detach().cpu(),
                )
                for i in range(N)
            ]
            result["F_adj"] = np.array(flows, dtype="float64")
        else:
            result["F_adj"] = np.zeros((N,), dtype="float64")

        return result, x_nchw

    def _precompute_seam_tables(self, x_nchw_dev, W, kind, use_ssim, ds_scales):
        """
        For k = 0..W-1 (W clamped to [1, B-1]), precompute per-frame metrics vs frame k:
          D_to_firstk[k] : (B,) distances to frame k
          S_to_firstk[k] : (B,) SSIM to frame k (empty tensor if not use_ssim)
          E_to_firstk[k] : (B,) |luma(i)-luma(k)|
        Tensors live on x_nchw_dev.device.
        """
        B = int(x_nchw_dev.shape[0])
        W = max(1, min(int(W), B - 1))
        D_to_firstk, S_to_firstk, E_to_firstk = [], [], []

        Y = self._luma(x_nchw_dev).mean(dim=(1, 2, 3))

        for k in range(W):
            # Broadcast frame k against the whole clip without copying it.
            ref = x_nchw_dev[k : k + 1].expand_as(x_nchw_dev)
            D_to_firstk.append(self._dist(x_nchw_dev, ref, kind=kind))
            if use_ssim:
                S_to_firstk.append(
                    self._ssim_multiscale_batched(x_nchw_dev, ref, ds_scales)
                )
            else:
                S_to_firstk.append(torch.empty(0, device=x_nchw_dev.device))
            E_to_firstk.append((Y - Y[k]).abs())

        return D_to_firstk, S_to_firstk, E_to_firstk

    def _flow_mag_mean(self, a_nhwc, b_nhwc, max_side=256):
        """
        Mean optical-flow magnitude. Accepts NHWC with/without batch,
        RGB/RGBA/Gray. Soft-fails to 0.0 if OpenCV unavailable or if the
        flow computation raises.
        """
        try:
            import cv2
            import numpy as np
        except Exception:
            # Deliberate best-effort: the flow term is optional.
            return 0.0

        a = (a_nhwc.detach().cpu().numpy() * 255.0).clip(0, 255).astype("uint8")
        b = (b_nhwc.detach().cpu().numpy() * 255.0).clip(0, 255).astype("uint8")
        if a.ndim == 4 and a.shape[0] == 1:
            a = a[0]
        if b.ndim == 4 and b.shape[0] == 1:
            b = b[0]

        def to_gray(x):
            """Reduce an HxW or HxWxC uint8 array to HxW; None if not possible."""
            if x.ndim == 2:
                return x
            if x.ndim == 3:
                c = x.shape[-1]
                if c == 1:
                    return x[..., 0]
                if c == 3:
                    return cv2.cvtColor(x, cv2.COLOR_RGB2GRAY)
                if c == 4:
                    return cv2.cvtColor(x, cv2.COLOR_RGBA2GRAY)
                return x.mean(axis=-1).astype(x.dtype)
            x2 = np.squeeze(x)
            if x2.ndim == 2:
                return x2
            if x2.ndim == 3:
                return x2.mean(axis=-1).astype(x2.dtype)
            return None

        a_g, b_g = to_gray(a), to_gray(b)
        if a_g is None or b_g is None or a_g.ndim != 2 or b_g.ndim != 2:
            return 0.0

        H, W = a_g.shape
        scale = max(1.0, max(H, W) / float(max_side))
        if scale > 1.0:
            # Clamp to >= 1 px so extreme aspect ratios cannot request a
            # zero-sized resize.
            newW = max(1, int(round(W / scale)))
            newH = max(1, int(round(H / scale)))
            a_g = cv2.resize(a_g, (newW, newH), interpolation=cv2.INTER_AREA)
            b_g = cv2.resize(b_g, (newW, newH), interpolation=cv2.INTER_AREA)

        try:
            flow = cv2.calcOpticalFlowFarneback(
                a_g, b_g, None, 0.5, 3, 21, 3, 5, 1.1, 0
            )
            mag = (flow[..., 0] ** 2 + flow[..., 1] ** 2) ** 0.5
            return float(mag.mean())
        except Exception:
            return 0.0

    @staticmethod
    def _pick_target(adj, keep, include_first, include_last, include_median):
        """Choose the target value from an adjacent-pair metric vector.

        Collects ``adj[0]`` (first step), ``adj[keep - 2]`` (last step inside
        the kept region) and the median of ``adj[:keep - 1]`` (only when
        ``keep >= 3``) according to the flags, then returns their median as a
        float. When no option contributes, falls back to ``adj[0]`` so the
        target is never an artificial zero.
        """
        picks = []
        if include_first:
            picks.append(adj[0])
        if include_last:
            picks.append(adj[keep - 2])
        if include_median and keep >= 3:
            picks.append(adj[: keep - 1].median())
        if not picks:
            picks = [adj[0]]
        chosen = picks[0] if len(picks) == 1 else torch.stack(picks).median()
        return float(chosen.item())

    @staticmethod
    def _seam_indices(keep, window):
        """Indices of the trailing frames paired with frames 0..w-1 across the seam.

        ``w`` is ``window`` clamped to ``[1, keep - 1]``; the clamp to
        ``keep - 1`` guarantees that no frame is paired with itself.
        """
        w = max(1, min(int(window), keep - 1))
        return list(range(keep - w, keep))

    @staticmethod
    def _seam_mean(tables, idxs):
        """Mean of ``tables[r][idxs[r]]`` over the seam window, as a float."""
        values = [tables[r][i] for r, i in enumerate(idxs)]
        return float(torch.stack(values).mean().item())

    def find_and_crop(
        self,
        clip_frames,
        max_end_crop_frames,
        include_first_step,
        include_last_step,
        include_global_median_step,
        seam_window_frames,
        distance_metric,
        score_in_8bit,
        use_ssim_similarity,
        use_exposure_guard,
        use_flow_guard,
        weight_step_size,
        weight_similarity,
        weight_exposure,
        weight_flow,
        ssim_downsample_scales,
        accelerate_with_gpu,
        use_mixed_precision,
    ):
        """Score every end-crop candidate and return the best-looping clip.

        For each ``extra`` in ``0..max_end_crop_frames`` the first
        ``B - extra`` frames are considered (candidates with fewer than two
        frames are skipped). Seam measurements are compared with targets from
        adjacent pairs, and the weighted sum of relative errors is the score.

        Returns:
            ``(cropped_clip, end_crop_frames, cropped_length, score,
            diagnostics_csv)``. The clip is sliced from the original
            ``clip_frames`` (never the 8-bit scoring view). A clip with fewer
            than two frames is returned unchanged with score 0.0 and a
            header-only CSV.
        """
        import contextlib

        import numpy as np
        from comfy.utils import ProgressBar

        header = "end_crop,score,D_seam,D_target,S_seam,S_target,E_seam,E_target,F_seam,F_target"

        clip_out = clip_frames
        clip_eval = (
            (clip_frames * 255.0).round().clamp(0, 255) / 255.0
            if score_in_8bit
            else clip_frames
        )

        B = int(clip_eval.shape[0])
        if B < 2:
            return (clip_out, 0, B, 0.0, header)

        dev = "cuda" if accelerate_with_gpu and torch.cuda.is_available() else "cpu"
        if dev == "cuda" and use_mixed_precision:
            # torch.autocast supersedes the deprecated torch.cuda.amp.autocast.
            amp_ctx = torch.autocast(device_type="cuda")
        else:
            amp_ctx = contextlib.nullcontext()

        ds_scales = self._parse_scales(ssim_downsample_scales)
        kind = distance_metric
        W = int(seam_window_frames)
        total_candidates = int(max(0, max_end_crop_frames)) + 1
        eps = 1e-12

        with torch.no_grad(), amp_ctx:
            clip_eval_dev = clip_eval.to(dev, non_blocking=True)

            pre, x_nchw_dev = self._precompute_adjacent_metrics(
                clip_nhwc_dev=clip_eval_dev,
                kind=kind,
                use_ssim=use_ssim_similarity,
                ds_scales=ds_scales,
                use_exp=use_exposure_guard,
                use_flow=use_flow_guard,
            )
            D_adj, S_adj, F_adj = pre["D_adj"], pre["S_adj"], pre["F_adj"]

            D_seam_tab, S_seam_tab, E_seam_tab = self._precompute_seam_tables(
                x_nchw_dev=x_nchw_dev,
                W=W,
                kind=kind,
                use_ssim=use_ssim_similarity,
                ds_scales=ds_scales,
            )

            Y = self._luma(x_nchw_dev).mean(dim=(1, 2, 3))

            best_extra = 0
            best_score = float("inf")
            rows = []
            pbar = ProgressBar(total_candidates)

            for extra in range(total_candidates):
                keep = B - extra
                if keep < 2:
                    # Too short to form a loop; still tick the bar.
                    pbar.update(1)
                    continue

                last_idx = keep - 1

                # ---- targets: what a "normal" step looks like ----
                D_target = self._pick_target(
                    D_adj,
                    keep,
                    include_first_step,
                    include_last_step,
                    include_global_median_step,
                )

                if use_ssim_similarity and S_adj.numel() > 0:
                    S_target = self._pick_target(
                        S_adj,
                        keep,
                        include_first_step,
                        include_last_step,
                        include_global_median_step,
                    )
                else:
                    S_target = 0.0

                if use_exposure_guard:
                    e_parts = [
                        (Y[0] - Y[1]).abs(),
                        (Y[last_idx] - Y[last_idx - 1]).abs(),
                    ]
                    if include_global_median_step and keep >= 3:
                        e_parts.append((Y[: keep - 1] - Y[1:keep]).abs().median())
                    E_target = float(torch.stack(e_parts).median().item())
                else:
                    E_target = 0.0

                if use_flow_guard and keep >= 3 and F_adj.size > 0:
                    F_target = float(np.median(F_adj[: keep - 1]))
                else:
                    F_target = 0.0

                # ---- seam measurements ----
                idxs = self._seam_indices(keep, W)

                D_seam = self._seam_mean(D_seam_tab, idxs)

                if use_ssim_similarity and S_seam_tab[0].numel() > 0:
                    S_seam = self._seam_mean(S_seam_tab, idxs)
                else:
                    S_seam = 0.0

                if use_exposure_guard:
                    E_seam = self._seam_mean(E_seam_tab, idxs)
                else:
                    E_seam = 0.0

                if use_flow_guard:
                    # Flow across the actual loop point: last kept frame -> frame 0.
                    F_seam = self._flow_mag_mean(
                        clip_eval[last_idx : last_idx + 1], clip_eval[0:1]
                    )
                else:
                    F_seam = 0.0

                # ---- relative-error costs ----
                cost_step = abs(D_seam - D_target) / (D_target + eps)
                cost_sim = (
                    abs(S_seam - S_target) / (abs(S_target) + eps)
                    if use_ssim_similarity
                    else 0.0
                )
                cost_exp = (
                    abs(E_seam - E_target) / (E_target + eps)
                    if (use_exposure_guard and E_target > 0.0)
                    else 0.0
                )
                cost_flow = (
                    abs(F_seam - F_target) / (F_target + eps)
                    if (use_flow_guard and F_target > 0.0)
                    else 0.0
                )

                score = (
                    weight_step_size * cost_step
                    + weight_similarity * cost_sim
                    + weight_exposure * cost_exp
                    + weight_flow * cost_flow
                )

                rows.append(
                    f"{extra},{score:.6f},{D_seam:.6f},{D_target:.6f},{S_seam:.6f},{S_target:.6f},{E_seam:.6f},{E_target:.6f},{F_seam:.6f},{F_target:.6f}"
                )

                # Strict '<' keeps the smallest crop among equal scores.
                if score < best_score:
                    best_score = score
                    best_extra = extra

                pbar.update(1)

        final_keep = max(2, B - best_extra)
        cropped = clip_out[0:final_keep]
        diagnostics_csv = header + "\n" + "\n".join(rows) if rows else header
        return (
            cropped,
            int(best_extra),
            int(final_keep),
            float(best_score),
            diagnostics_csv,
        )
|
||
|
||
|
||
class TrimBatchEnds:
    """
    Trim frames from the START and/or END of an IMAGE batch (NHWC, [0..1]).
    Both trims are applied in one pass. Always leaves at least one frame.
    """

    DESCRIPTION = "Quickly remove frames from the start and/or end of a clip. Always keeps at least one frame."

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the clip input and the two trim counts."""
        return {
            "required": {
                "clip_frames": ("IMAGE", {"tooltip": "Your clip (frames×H×W×C, 0–1)."}),
                "trim_start_frames": (
                    "INT",
                    {
                        "default": 0,
                        "min": 0,
                        "max": 100000,
                        "tooltip": "Frames to remove from the START.",
                    },
                ),
                "trim_end_frames": (
                    "INT",
                    {
                        "default": 0,
                        "min": 0,
                        "max": 100000,
                        "tooltip": "Frames to remove from the END.",
                    },
                ),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "crop"
    CATEGORY = "video utils"

    def crop(self, clip_frames, trim_start_frames, trim_end_frames):
        """Slice ``clip_frames`` along dim 0, dropping leading/trailing frames.

        Args:
            clip_frames: 4-D tensor; any other value is returned untouched.
            trim_start_frames: frames to drop from the front (negatives act as 0).
            trim_end_frames: frames to drop from the back (negatives act as 0).

        Returns:
            ``(trimmed,)``. When the requested trims would remove everything,
            the start trim is honored first (capped at ``B - 1``) and the end
            trim is reduced so exactly one frame survives. Batches of one
            frame or fewer are returned unchanged.
        """
        if not isinstance(clip_frames, torch.Tensor) or clip_frames.ndim != 4:
            return (clip_frames,)

        total = int(clip_frames.shape[0])
        if total <= 1:
            return (clip_frames,)

        start = max(0, int(trim_start_frames))
        end = max(0, int(trim_end_frames))

        if start + end >= total:
            # Over-trimmed: keep the single frame at the capped start index.
            start = min(start, total - 1)
            end = total - start - 1

        # After clamping, total - end > start always holds, so the slice is
        # never empty.
        return (clip_frames[start : total - end],)
|