ComfyUI/custom_nodes/whiterabbit/video_loop.py

# SPDX-License-Identifier: AGPL-3.0-only
# SPDX-FileCopyrightText: 2025 ArtificialSweetener <artificialsweetenerai@proton.me>
import torch
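# Typical wiring for a seamless loop (a sketch inferred from the node
# descriptions below; "interpolator" stands for whatever frame-interpolation
# node you use in your graph, it is not defined in this file):
#
#   PrepareLoopFrames --interp_batch-------> interpolator --> AssembleLoopFrames
#   PrepareLoopFrames --original_images--------------------> AssembleLoopFrames
#
# RollFrames/UnrollFrames can move the seam away from frame 0 before
# interpolation and restore the original frame order afterwards.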
class PrepareLoopFrames:
DESCRIPTION = "Prepares the wrap seam: builds a tiny 2-frame batch [last, first] for your interpolator and also passes the original clip through unchanged."
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": (
"IMAGE",
{
"tooltip": "Your clip as an IMAGE batch (frames×H×W×C, values 01). Outputs: [last, first] for the seam, plus the original clip."
},
),
}
}
RETURN_TYPES = ("IMAGE", "IMAGE")
RETURN_NAMES = ("interp_batch", "original_images")
FUNCTION = "prepare"
CATEGORY = "video utils"
def prepare(self, images):
last_frame = images[-1:]
first_frame = images[0:1]
interp_batch = torch.cat((last_frame, first_frame), dim=0)
return (interp_batch, images)
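# Shape sketch: for a 48-frame clip, interp_batch is (2, H, W, C) holding
# [frame 47, frame 0], while original_images is the untouched input batch.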
class AssembleLoopFrames:
DESCRIPTION = "Builds the final loop: appends only the new in-between seam frames to your original clip—no duplicate of frame 1."
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"original_images": (
"IMAGE",
{"tooltip": "Your original clip (frames×H×W×C)."},
),
"interpolated_frames": (
"IMAGE",
{
"tooltip": "Frames that bridge last→first. The first and last of this batch are the originals; only the middle ones get added."
},
),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("images",)
FUNCTION = "assemble"
CATEGORY = "video utils"
def assemble(self, original_images, interpolated_frames):
original_images = original_images.to(interpolated_frames.device)
in_between = interpolated_frames[1:-1]
out = torch.cat((original_images, in_between), dim=0)
return (out,)
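# Frame-count arithmetic: if the interpolator turns the 2-frame seam batch
# into m + 2 frames, slicing [1:-1] keeps only the m new in-betweens, so a
# B-frame clip comes out as B + m frames with no duplicate of frame 1.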
class RollFrames:
DESCRIPTION = "Rolls the clip in a loop by an integer amount (cyclic shift). Also returns the same offset so you can undo it later."
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE", {"tooltip": "Your clip (frames×H×W×C)."}),
"offset": (
"INT",
{
"default": 1,
"min": -9999,
"max": 9999,
"step": 1,
"tooltip": "How far to rotate the clip. Positive = forward in time; negative = backward.",
},
),
}
}
RETURN_TYPES = ("IMAGE", "INT")
RETURN_NAMES = ("images", "offset_out")
FUNCTION = "roll"
CATEGORY = "video utils"
def roll(self, images, offset):
B = images.shape[0]
if B == 0:
return (images, int(offset))
k = int(offset) % B
if k == 0:
return (images, int(offset))
rolled = torch.roll(images, shifts=-k, dims=0) # +1 → [2,3,...,1]
return (rolled, int(offset))
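# Example: offset=2 on [f0, f1, f2, f3, f4] yields [f2, f3, f4, f0, f1];
# offset_out passes the same value through so UnrollFrames can undo it.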
class UnrollFrames:
DESCRIPTION = "Undo a previous roll after interpolation by accounting for the inserted frames (rotate by base_offset × (m+1))."
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": (
"IMAGE",
{"tooltip": "Clip after interpolation (frames×H×W×C)."},
),
"base_offset": (
"INT",
{
"default": 1,
"min": -9999,
"max": 9999,
"step": 1,
"tooltip": "Use the exact offset_out that came from RollFrames.",
},
),
"m": (
"INT",
{
"default": 0,
"min": 0,
"max": 9999,
"step": 1,
"tooltip": "How many in-betweens per gap were added (the interpolation multiple).",
},
),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("images",)
FUNCTION = "unroll"
CATEGORY = "video utils"
def unroll(self, images, base_offset, m):
Bp = images.shape[0]
if Bp == 0:
return (images,)
eff = (int(base_offset) * (int(m) + 1)) % Bp
return (torch.roll(images, shifts=+eff, dims=0),)
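# Worked example (assuming the interpolator inserts m in-betweens after every
# frame including the wrap gap, which is what the (m + 1) factor models):
# a 5-frame clip rolled by base_offset=1 is [f1, f2, f3, f4, f0]; with m=1 it
# grows to 10 frames and f0 sits at index 8. eff = 1 * (1 + 1) % 10 = 2, and
# rolling by +2 returns f0 to index 0.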
class AutocropToLoop:
"""
Finds a natural loop by cropping frames from the END of the batch.
Returns the cropped clip that makes the seam (last_kept -> first)
feel like a normal step between real neighbors.
Score = weighted mix of:
- step-size match (L1/MSE distance)
- similarity match (SSIM)
- exposure continuity (luma)
- motion consistency (optical flow; optional)
Speed: can run metrics on GPU and use mixed precision for SSIM/conv math.
Progress bar: one tick per candidate crop (0..max_end_crop_frames).
"""
DESCRIPTION = "Auto-crops the clip to create a smoother loop: tests crops from the end and scores the seam so it feels like a normal step."
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"clip_frames": (
"IMAGE",
{
"tooltip": "Your full clip (NHWC, 01). Tries every crop from 0..max_end_crop_frames and returns the best loop."
},
),
"max_end_crop_frames": (
"INT",
{
"default": 12,
"min": 0,
"max": 10000,
"tooltip": "Largest crop to test at the END. Higher = more candidates (slower), but potentially better.",
},
),
"include_first_step": (
"BOOLEAN",
{
"default": True,
"tooltip": "Use the first neighbor pair (frame 0→1) as a target step size/similarity.",
},
),
"include_last_step": (
"BOOLEAN",
{
"default": True,
"tooltip": "Use the last neighbor pair inside the KEPT region as a target.",
},
),
"include_global_median_step": (
"BOOLEAN",
{
"default": False,
"tooltip": "Also use the median step across the KEPT region (needs ≥3 frames). Helps ignore outliers.",
},
),
"seam_window_frames": (
"INT",
{
"default": 2,
"min": 1,
"max": 6,
"tooltip": "Average over multiple aligned pairs across the seam. Larger = more robust.",
},
),
"distance_metric": (
["L1", "MSE"],
{
"default": "L1",
"tooltip": "How to measure step size for matching. L1 is usually more forgiving; MSE penalizes big errors more.",
},
),
"score_in_8bit": (
"BOOLEAN",
{
"default": False,
"tooltip": "Score with an 8-bit view (simulate export). Output video still stays float.",
},
),
"use_ssim_similarity": (
"BOOLEAN",
{
"default": True,
"tooltip": "Include SSIM so the seam looks like a normal neighbor—avoid freeze or jump.",
},
),
"use_exposure_guard": (
"BOOLEAN",
{
"default": True,
"tooltip": "Promote smooth brightness across the seam (reduces flicker pops).",
},
),
"use_flow_guard": (
"BOOLEAN",
{
"default": False,
"tooltip": "Encourage consistent motion across the seam (needs OpenCV; slower).",
},
),
"weight_step_size": (
"FLOAT",
{
"default": 0.55,
"min": 0.0,
"max": 1.0,
"step": 0.01,
"tooltip": "Importance of matching step size. Higher = less freeze/jump risk.",
},
),
"weight_similarity": (
"FLOAT",
{
"default": 0.30,
"min": 0.0,
"max": 1.0,
"step": 0.01,
"tooltip": "Importance of visual similarity (SSIM). Helps avoid a frozen-looking seam.",
},
),
"weight_exposure": (
"FLOAT",
{
"default": 0.10,
"min": 0.0,
"max": 1.0,
"step": 0.01,
"tooltip": "Importance of even brightness across the seam.",
},
),
"weight_flow": (
"FLOAT",
{
"default": 0.05,
"min": 0.0,
"max": 1.0,
"step": 0.01,
"tooltip": "Importance of motion continuity across the seam.",
},
),
"ssim_downsample_scales": (
"STRING",
{
"default": "1,2",
"tooltip": "SSIM scales to average, as a comma list. Example: 1,2 = full-res and half-res.",
},
),
"accelerate_with_gpu": (
"BOOLEAN",
{
"default": True,
"tooltip": "If ON and CUDA is available, run scoring on GPU for a big speedup (same results).",
},
),
"use_mixed_precision": (
"BOOLEAN",
{
"default": True,
"tooltip": "If ON (with GPU), use mixed precision for SSIM/conv math (faster on larger clips).",
},
),
}
}
RETURN_TYPES = ("IMAGE", "INT", "INT", "FLOAT", "STRING")
RETURN_NAMES = (
"cropped_clip",
"end_crop_frames",
"cropped_length",
"score",
"diagnostics_csv",
)
FUNCTION = "find_and_crop"
CATEGORY = "video utils"
_gw_cache = {} # gaussian window cache
def _to_nchw(self, x):
import torch
if x.ndim == 4 and x.shape[-1] in (1, 3, 4):
return x.permute(0, 3, 1, 2).contiguous()
return x
def _parse_scales(self, csv):
scales = []
for s in str(csv).split(","):
s = s.strip()
if not s:
continue
try:
v = int(s)
if v >= 1 and v not in scales:
scales.append(v)
except Exception:
pass
return scales or [1]
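# Examples: "1,2" -> [1, 2]; "2, 2, x" -> [2] (duplicates and non-integers
# are dropped); "" -> [1] (falls back to full resolution only).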
def _downsample(self, x, s):
import torch.nn.functional as F
if s == 1:
return x
H, W = x.shape[-2:]
newH = max(1, H // s)
newW = max(1, W // s)
return F.interpolate(x, size=(newH, newW), mode="area", align_corners=None)
def _dist(self, A, B, kind="L1"):
if kind == "MSE":
return ((A - B) ** 2).mean(dim=(1, 2, 3))
return (A - B).abs().mean(dim=(1, 2, 3))
def _luma(self, x_nchw):
if x_nchw.shape[1] == 1:
return x_nchw[:, 0:1]
R = x_nchw[:, 0:1]
G = x_nchw[:, 1:2]
B = x_nchw[:, 2:3]
return 0.2126 * R + 0.7152 * G + 0.0722 * B
def _gaussian_window(self, C, k=7, sigma=1.5, device="cpu", dtype=None):
import torch
key = (int(C), int(k), float(sigma), str(device), str(dtype))
w = self._gw_cache.get(key)
if w is not None:
return w
ax = torch.arange(k, dtype=dtype, device=device) - (k - 1) / 2.0
gauss = torch.exp(-0.5 * (ax / sigma) ** 2)
kernel1d = (gauss / gauss.sum()).unsqueeze(1)
kernel2d = kernel1d @ kernel1d.t()
w = kernel2d.expand(C, 1, k, k).contiguous()
self._gw_cache[key] = w
return w
def _ssim_pair_batched(self, x, y, k=7, sigma=1.5, C1=0.01**2, C2=0.03**2):
import torch
import torch.nn.functional as F
C = x.shape[1]
w = self._gaussian_window(C, k=k, sigma=sigma, device=x.device, dtype=x.dtype)
mu_x = F.conv2d(x, w, padding=k // 2, groups=C)
mu_y = F.conv2d(y, w, padding=k // 2, groups=C)
mu_x2, mu_y2, mu_xy = mu_x * mu_x, mu_y * mu_y, mu_x * mu_y
sigma_x2 = F.conv2d(x * x, w, padding=k // 2, groups=C) - mu_x2
sigma_y2 = F.conv2d(y * y, w, padding=k // 2, groups=C) - mu_y2
sigma_xy = F.conv2d(x * y, w, padding=k // 2, groups=C) - mu_xy
ssim_map = ((2.0 * mu_xy + C1) * (2.0 * sigma_xy + C2)) / (
(mu_x2 + mu_y2 + C1) * (sigma_x2 + sigma_y2 + C2) + 1e-12
)
return ssim_map.mean(dim=(1, 2, 3)) # (N,)
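# Standard Gaussian-windowed SSIM (Wang et al. 2004) with the usual
# constants C1 = 0.01^2 and C2 = 0.03^2 for inputs in [0, 1]; reduced to
# one score per batch item instead of a scalar.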
def _ssim_multiscale_batched(self, x, y, scales):
vecs = []
for s in scales:
xs = self._downsample(x, s)
ys = self._downsample(y, s)
vecs.append(self._ssim_pair_batched(xs, ys))
return sum(vecs) / float(len(vecs)) # (N,)
def _precompute_adjacent_metrics(
self, clip_nhwc_dev, kind, use_ssim, ds_scales, use_exp, use_flow
):
"""
Returns dict with vectors of length (B-1):
D_adj (torch), S_adj (torch), E_adj (torch), F_adj (np, CPU)
All torch tensors are on the same device as clip_nhwc_dev.
"""
import numpy as np
import torch
B = int(clip_nhwc_dev.shape[0])
N = max(0, B - 1)
result = {}
if N == 0:
result["D_adj"] = torch.empty(0, device=clip_nhwc_dev.device)
result["S_adj"] = torch.empty(0, device=clip_nhwc_dev.device)
result["E_adj"] = torch.empty(0, device=clip_nhwc_dev.device)
result["F_adj"] = np.zeros((0,), dtype="float64")
return result, self._to_nchw(clip_nhwc_dev)
x_nchw = self._to_nchw(clip_nhwc_dev) # B,C,H,W (device)
X = x_nchw[:-1]
Y = x_nchw[1:] # N,C,H,W
result["D_adj"] = self._dist(X, Y, kind=kind) # (N,)
if use_ssim:
result["S_adj"] = self._ssim_multiscale_batched(X, Y, ds_scales) # (N,)
else:
result["S_adj"] = torch.empty(0, device=x_nchw.device)
if use_exp:
Y_luma = self._luma(x_nchw).mean(dim=(1, 2, 3)) # (B,)
result["E_adj"] = (Y_luma[:-1] - Y_luma[1:]).abs() # (N,)
else:
result["E_adj"] = torch.empty(0, device=x_nchw.device)
if use_flow:
F_adj = []
for i in range(N):
a = clip_nhwc_dev[i : i + 1].detach().cpu()
b = clip_nhwc_dev[i + 1 : i + 2].detach().cpu()
F_adj.append(self._flow_mag_mean(a, b))
result["F_adj"] = np.array(F_adj, dtype="float64")
else:
result["F_adj"] = np.zeros((N,), dtype="float64")
return result, x_nchw
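# All neighbor metrics are computed once as length-(B-1) vectors, so the
# candidate loop in find_and_crop only does cheap indexing per crop; only
# optical flow (when enabled) needs a per-pair CPU loop.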
def _precompute_seam_tables(self, x_nchw_dev, W, kind, use_ssim, ds_scales):
"""
For k = 0..W-1, precompute per-frame metrics vs first+k:
D_to_firstk[k] : (B,) distances to frame k
S_to_firstk[k] : (B,) SSIM to frame k (if use_ssim)
E_to_firstk[k] : (B,) |luma(i)-luma(k)|
Tensors live on x_nchw_dev.device.
"""
import torch
B = int(x_nchw_dev.shape[0])
W = max(1, min(int(W), B - 1))
D_to_firstk, S_to_firstk, E_to_firstk = [], [], []
Y = self._luma(x_nchw_dev).mean(dim=(1, 2, 3))
for k in range(W):
Bk = x_nchw_dev[k : k + 1].expand_as(x_nchw_dev)
Dk = self._dist(x_nchw_dev, Bk, kind=kind)
D_to_firstk.append(Dk)
if use_ssim:
Sk = self._ssim_multiscale_batched(x_nchw_dev, Bk, ds_scales)
S_to_firstk.append(Sk)
else:
S_to_firstk.append(torch.empty(0, device=x_nchw_dev.device))
Ek = (Y - Y[k]).abs()
E_to_firstk.append(Ek)
return D_to_firstk, S_to_firstk, E_to_firstk
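# Lookup convention: D_to_firstk[k][i] is the distance between frame i and
# frame k, so scoring a seam pair later is a single index_select instead of
# a fresh full-frame comparison per candidate crop.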
def _flow_mag_mean(self, a_nhwc, b_nhwc, max_side=256):
"""
Mean optical-flow magnitude. Accepts NHWC with/without batch,
RGB/RGBA/Gray. Soft-fails to 0.0 if OpenCV unavailable.
"""
try:
import cv2
import numpy as np
except Exception:
return 0.0
a = (a_nhwc.detach().cpu().numpy() * 255.0).clip(0, 255).astype("uint8")
b = (b_nhwc.detach().cpu().numpy() * 255.0).clip(0, 255).astype("uint8")
if a.ndim == 4 and a.shape[0] == 1:
a = a[0]
if b.ndim == 4 and b.shape[0] == 1:
b = b[0]
def to_gray(x: np.ndarray) -> np.ndarray:
if x.ndim == 2:
return x
if x.ndim == 3:
c = x.shape[-1]
if c == 1:
return x[..., 0]
if c == 3:
return cv2.cvtColor(x, cv2.COLOR_RGB2GRAY)
if c == 4:
return cv2.cvtColor(x, cv2.COLOR_RGBA2GRAY)
return x.mean(axis=-1).astype(x.dtype)
x2 = np.squeeze(x)
if x2.ndim == 2:
return x2
if x2.ndim == 3:
return x2.mean(axis=-1).astype(x2.dtype)
return None
a_g, b_g = to_gray(a), to_gray(b)
if a_g is None or b_g is None or a_g.ndim != 2 or b_g.ndim != 2:
return 0.0
H, W = a_g.shape
scale = max(1.0, max(H, W) / float(max_side))
if scale > 1.0:
newW = int(round(W / scale))
newH = int(round(H / scale))
a_g = cv2.resize(a_g, (newW, newH), interpolation=cv2.INTER_AREA)
b_g = cv2.resize(b_g, (newW, newH), interpolation=cv2.INTER_AREA)
try:
flow = cv2.calcOpticalFlowFarneback(
a_g, b_g, None, 0.5, 3, 21, 3, 5, 1.1, 0
)
mag = (flow[..., 0] ** 2 + flow[..., 1] ** 2) ** 0.5
return float(mag.mean())
except Exception:
return 0.0
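# Flow uses OpenCV's Farneback estimator on grayscale frames downscaled so
# the long side is at most max_side pixels; any failure (no cv2, unexpected
# shapes) soft-fails to 0.0 so the flow guard cannot crash a run.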
def find_and_crop(
self,
clip_frames,
max_end_crop_frames,
include_first_step,
include_last_step,
include_global_median_step,
seam_window_frames,
distance_metric,
score_in_8bit,
use_ssim_similarity,
use_exposure_guard,
use_flow_guard,
weight_step_size,
weight_similarity,
weight_exposure,
weight_flow,
ssim_downsample_scales,
accelerate_with_gpu,
use_mixed_precision,
):
import contextlib
import numpy as np
import torch
from comfy.utils import ProgressBar
clip_out = clip_frames
clip_eval = (
(clip_frames * 255.0).round().clamp(0, 255) / 255.0
if score_in_8bit
else clip_frames
)
B = int(clip_eval.shape[0])
if B < 2:
header = "end_crop,score,D_seam,D_target,S_seam,S_target,E_seam,E_target,F_seam,F_target"
return (clip_out, 0, B, 0.0, header)
dev = "cuda" if accelerate_with_gpu and torch.cuda.is_available() else "cpu"
amp_ctx = (
torch.cuda.amp.autocast
if (dev == "cuda" and use_mixed_precision)
else contextlib.nullcontext
)
ds_scales = self._parse_scales(ssim_downsample_scales)
kind = distance_metric
W = int(seam_window_frames)
total_candidates = int(max(0, max_end_crop_frames)) + 1
with torch.no_grad():
with amp_ctx():
clip_eval_dev = clip_eval.to(dev, non_blocking=True)
pre, x_nchw_dev = self._precompute_adjacent_metrics(
clip_nhwc_dev=clip_eval_dev,
kind=kind,
use_ssim=use_ssim_similarity,
ds_scales=ds_scales,
use_exp=use_exposure_guard,
use_flow=use_flow_guard,
)
D_adj, S_adj, E_adj, F_adj = (
pre["D_adj"],
pre["S_adj"],
pre["E_adj"],
pre["F_adj"],
)
D_seam_tab, S_seam_tab, E_seam_tab = self._precompute_seam_tables(
x_nchw_dev=x_nchw_dev,
W=W,
kind=kind,
use_ssim=use_ssim_similarity,
ds_scales=ds_scales,
)
Y = self._luma(x_nchw_dev).mean(dim=(1, 2, 3))
best_extra = 0
best_score = float("inf")
rows = []
pbar = ProgressBar(total_candidates)
for extra in range(0, total_candidates):
keep = B - extra
if keep < 2:
pbar.update(1)
continue
last_idx = keep - 1
W_eff = max(1, min(W, last_idx + 1, B - 1))
chosen_D = []
if include_first_step and keep >= 2:
chosen_D.append(D_adj[0])
if include_last_step and keep >= 2:
chosen_D.append(D_adj[last_idx - 1])
if include_global_median_step and keep >= 3:
chosen_D.append(D_adj[: keep - 1].median())
if not chosen_D and keep >= 2:
chosen_D = [D_adj[0]]
D_target = float(
(
chosen_D[0]
if len(chosen_D) == 1
else torch.stack(chosen_D).median()
).item()
)
if use_ssim_similarity and S_adj.numel() > 0 and keep >= 2:
chosen_S = []
if include_first_step:
chosen_S.append(S_adj[0])
if include_last_step:
chosen_S.append(S_adj[last_idx - 1])
if include_global_median_step and keep >= 3:
chosen_S.append(S_adj[: keep - 1].median())
S_target = float(
(
chosen_S[0]
if (chosen_S and len(chosen_S) == 1)
else (
torch.stack(chosen_S).median()
if chosen_S
else torch.tensor(0.0, device=S_adj.device)
)
).item()
)
else:
S_target = 0.0
if use_exposure_guard and keep >= 2:
e_first = (Y[0] - Y[1]).abs()
e_last = (Y[last_idx] - Y[last_idx - 1]).abs()
if include_global_median_step and keep >= 3:
e_med = (Y[: keep - 1] - Y[1:keep]).abs().median()
E_target = float(
torch.stack([e_first, e_last, e_med]).median().item()
)
else:
E_target = float(torch.stack([e_first, e_last]).median().item())
else:
E_target = 0.0
if use_flow_guard and keep >= 3 and F_adj.size > 0:
F_target = float(np.median(F_adj[: keep - 1]))
else:
F_target = 0.0
idxs = [last_idx - (W_eff - 1 - r) for r in range(W_eff)]
idxs_t = torch.tensor(idxs, device=x_nchw_dev.device, dtype=torch.long)
D_vals = torch.stack(
[
D_seam_tab[r].index_select(0, idxs_t[r : r + 1]).squeeze(0)
for r in range(W_eff)
]
)
D_seam = float(D_vals.mean().item())
if use_ssim_similarity and S_seam_tab[0].numel() > 0:
S_vals = torch.stack(
[
S_seam_tab[r].index_select(0, idxs_t[r : r + 1]).squeeze(0)
for r in range(W_eff)
]
)
S_seam = float(S_vals.mean().item())
else:
S_seam = 0.0
if use_exposure_guard:
E_vals = torch.stack(
[
E_seam_tab[r].index_select(0, idxs_t[r : r + 1]).squeeze(0)
for r in range(W_eff)
]
)
E_seam = float(E_vals.mean().item())
else:
E_seam = 0.0
F_seam = 0.0
eps = 1e-12
cost_step = abs(D_seam - D_target) / (D_target + eps)
cost_sim = (
abs(S_seam - S_target) / (abs(S_target) + eps)
if use_ssim_similarity
else 0.0
)
cost_exp = (
abs(E_seam - E_target) / (E_target + eps)
if (use_exposure_guard and E_target > 0.0)
else 0.0
)
cost_flow = (
abs(F_seam - F_target) / (F_target + eps)
if (use_flow_guard and F_target > 0.0)
else 0.0
)
score = (
weight_step_size * cost_step
+ weight_similarity * cost_sim
+ weight_exposure * cost_exp
+ weight_flow * cost_flow
)
rows.append(
f"{extra},{score:.6f},{D_seam:.6f},{D_target:.6f},{S_seam:.6f},{S_target:.6f},{E_seam:.6f},{E_target:.6f},{F_seam:.6f},{F_target:.6f}"
)
if score < best_score:
best_score = score
best_extra = extra
pbar.update(1)
final_keep = max(2, B - best_extra)
cropped = clip_out[0:final_keep]
header = "end_crop,score,D_seam,D_target,S_seam,S_target,E_seam,E_target,F_seam,F_target"
diagnostics_csv = header + "\n" + "\n".join(rows) if rows else header
return (
cropped,
int(best_extra),
int(final_keep),
float(best_score),
diagnostics_csv,
)
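# The diagnostics_csv return is plain CSV. A quick way to inspect it outside
# the graph (illustrative sketch, not part of the node):
#
#   import csv, io
#   rows = list(csv.DictReader(io.StringIO(diagnostics_csv)))
#   best = min(rows, key=lambda r: float(r["score"]))
#   print(best["end_crop"], best["score"])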
class TrimBatchEnds:
"""
Trim frames from the START and/or END of an IMAGE batch (NHWC, [0..1]).
Both trims are applied in one pass. Always leaves at least one frame.
"""
DESCRIPTION = "Quickly remove frames from the start and/or end of a clip. Always keeps at least one frame."
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"clip_frames": ("IMAGE", {"tooltip": "Your clip (frames×H×W×C, 01)."}),
"trim_start_frames": (
"INT",
{
"default": 0,
"min": 0,
"max": 100000,
"tooltip": "Frames to remove from the START.",
},
),
"trim_end_frames": (
"INT",
{
"default": 0,
"min": 0,
"max": 100000,
"tooltip": "Frames to remove from the END.",
},
),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("images",)
FUNCTION = "crop"
CATEGORY = "video utils"
def crop(self, clip_frames, trim_start_frames, trim_end_frames):
import torch
if not isinstance(clip_frames, torch.Tensor) or clip_frames.ndim != 4:
return (clip_frames,)
B = int(clip_frames.shape[0])
if B <= 1:
return (clip_frames,)
s = max(0, int(trim_start_frames))
e = max(0, int(trim_end_frames))
if s + e >= B:
s = min(s, B - 1)
e = max(0, B - s - 1)
out = clip_frames[s : B - e] if e > 0 else clip_frames[s:]
if out.shape[0] == 0:
out = clip_frames[B - 1 : B]
return (out,)
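# Clamping example: with B=10, trim_start_frames=7, trim_end_frames=5 the
# trims would remove everything, so s stays 7, e is reduced to 2, and the
# single-frame batch clip_frames[7:8] is returned.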