Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Includes 30 custom nodes committed directly, 7 Civitai-exclusive loras stored via Git LFS, and a setup script that installs all dependencies and downloads HuggingFace-hosted models on vast.ai. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1169 lines
41 KiB
Python
1169 lines
41 KiB
Python
# SPDX-License-Identifier: AGPL-3.0-only
|
||
# SPDX-FileCopyrightText: 2025 ArtificialSweetener <artificialsweetenerai@proton.me>
|
||
#
|
||
# Portions adapted from "ComfyUI-Frame-Interpolation" (MIT License)
|
||
# Copyright (c) 2023–2025 Fannovel16 and contributors
|
||
# See LICENSES/MIT-ComfyUI-Frame-Interpolation.txt for the full text.
|
||
|
||
# Imports for: RIFE_FPS_Resample, RIFE_SeamTimingAnalyzer_Advanced, RIFE_VFI_Advanced
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
# Point Python at the sibling base pack
|
||
# Directory that holds this file, and the sibling base pack next to it.
_PACK_DIR = Path(__file__).resolve().parent
BASE = _PACK_DIR.parent / "comfyui-frame-interpolation"
VFI_MODELS = BASE / "vfi_models"

# String forms of the three import roots:
#   p_models -> .../custom_nodes/comfyui-frame-interpolation/vfi_models (external 'rife')
#   p_base   -> .../custom_nodes/comfyui-frame-interpolation (external 'vfi_utils')
#   p_vendor -> repo-local vendor/ (vendored 'rife' and 'vfi_utils')
p_models = str(VFI_MODELS)
p_base = str(BASE)
p_vendor = str(_PACK_DIR / "vendor")

# External copies go to the FRONT of sys.path (only when the directory exists),
# so they win over the vendored fallback, which is appended at the END.
for _candidate, _entry in ((VFI_MODELS, p_models), (BASE, p_base)):
    if _candidate.is_dir() and _entry not in sys.path:
        sys.path.insert(0, _entry)
if p_vendor not in sys.path:
    sys.path.append(p_vendor)
|
||
|
||
|
||
import math
|
||
import re
|
||
import typing
|
||
from fractions import Fraction
|
||
|
||
# === Cached RIFE model loader (shared by all nodes) ===
|
||
from functools import lru_cache
|
||
|
||
import torch
|
||
import torch.nn.functional as F
|
||
|
||
# --- Comfy + stdlib used by the three nodes ---
|
||
from comfy.model_management import get_torch_device
|
||
from comfy.utils import ProgressBar as _ComfyProgressBar
|
||
from rife import CKPT_NAME_VER_DICT, MODEL_TYPE
|
||
|
||
# --- imports from the base pack ---
|
||
from vfi_utils import (
|
||
InterpolationStateList,
|
||
generic_frame_loop,
|
||
load_file_from_github_release,
|
||
postprocess_frames,
|
||
preprocess_frames,
|
||
)
|
||
|
||
|
||
@lru_cache(maxsize=4)
def _get_rife_model(ckpt_name: str):
    """Build, load and return the RIFE network for ``ckpt_name``.

    Results are memoized per checkpoint name (up to 4 entries), so every node
    that asks for the same checkpoint shares one model instance.

    Args:
        ckpt_name: Key into ``CKPT_NAME_VER_DICT`` and the file name handed to
            ``load_file_from_github_release``.

    Returns:
        The ``IFNet`` in eval mode, moved to Comfy's torch device and converted
        to channels-last memory format.
    """
    import torch
    from comfy.model_management import get_torch_device
    from rife import CKPT_NAME_VER_DICT, MODEL_TYPE
    from rife.rife_arch import IFNet
    from vfi_utils import load_file_from_github_release

    checkpoint_path = load_file_from_github_release(MODEL_TYPE, ckpt_name)
    net = IFNet(arch_ver=CKPT_NAME_VER_DICT[ckpt_name])
    # Weights are read onto the CPU first; the move to the target device happens below.
    net.load_state_dict(torch.load(checkpoint_path, map_location="cpu"), strict=True)
    net.eval()
    net = net.to(get_torch_device())
    return net.to(memory_format=torch.channels_last)
|
||
|
||
|
||
# ---------- shared helpers (for all nodes) ----------
|
||
def _scale_list(scale_factor: float):
|
||
s = float(scale_factor)
|
||
return [8.0 / s, 4.0 / s, 2.0 / s, 1.0 / s]
|
||
|
||
|
||
def _prep_frames(frames: torch.Tensor) -> torch.Tensor:
    """Run the base pack's ``preprocess_frames`` and pin the result when useful.

    The tensor is pinned only when it lives on the CPU and CUDA is available;
    otherwise the preprocessed tensor is returned as-is.
    """
    prepared = preprocess_frames(frames)
    on_cpu = prepared.device.type == "cpu"
    return prepared.pin_memory() if on_cpu and torch.cuda.is_available() else prepared
|
||
|
||
|
||
def _progress(total_ticks: int):
    """Create a Comfy progress bar with at least one tick."""
    ticks = int(max(1, total_ticks))
    return _ComfyProgressBar(ticks)
|
||
|
||
|
||
def _count_synth(frames: torch.Tensor, multiplier: int) -> int:
|
||
N = int(frames.shape[0])
|
||
m = max(0, int(multiplier))
|
||
return max(0, N - 1) * m
|
||
|
||
|
||
def _make_rife_callback():
|
||
def _cb(frame_0, frame_1, timestep, model, scale_list, ensemble, t_mapper):
|
||
if torch.is_tensor(timestep):
|
||
t_scalar = float(timestep.reshape(-1)[0].item())
|
||
else:
|
||
t_scalar = float(timestep)
|
||
t = t_mapper(t_scalar) if t_mapper is not None else t_scalar
|
||
# fast_mode is a no-op for 4.7/4.9; pass False
|
||
return model(frame_0, frame_1, t, scale_list, False, ensemble)
|
||
|
||
return _cb
|
||
|
||
|
||
def _with_progress(cb, pbar):
|
||
if pbar is None:
|
||
return cb
|
||
|
||
def _wrapped(*args, **kwargs):
|
||
y = cb(*args, **kwargs)
|
||
pbar.update(1)
|
||
return y
|
||
|
||
return _wrapped
|
||
|
||
|
||
# ----------------------------------------------------
|
||
|
||
|
||
class RIFE_VFI_Opt:
    """ComfyUI node: interpolate a clip by an integer multiple with RIFE 4.7/4.9."""

    DESCRIPTION = (
        "Interpolate a clip by a chosen multiple using RIFE 4.7/4.9 — inserts evenly spaced in-between frames "
        "between every pair (e.g., ×2 adds 1 frame per pair)."
    )

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's input sockets/widgets for ComfyUI."""
        return {
            "required": {
                "ckpt_name": (
                    ["rife47.pth", "rife49.pth"],
                    {
                        "default": "rife47.pth",
                        "tooltip": "Choose the RIFE 4.7 or 4.9 model file from the base pack.",
                    },
                ),
                "frames": (
                    "IMAGE",
                    {"tooltip": "Your input clip: one image per frame."},
                ),
                "multiplier": (
                    "INT",
                    {
                        "default": 2,
                        "min": 1,
                        "tooltip": "Adds extra frames to smooth motion: 2 adds 1 new frame per pair; 4 adds 3.",
                    },
                ),
                "scale_factor": (
                    [0.25, 0.5, 1.0, 2.0, 4.0],
                    {
                        "default": 1.0,
                        "tooltip": "Quality vs speed. 1.0 recommended. Lower = faster/softer; higher = sharper/slower.",
                    },
                ),
                "ensemble": (
                    "BOOLEAN",
                    {
                        "default": True,
                        "tooltip": "Blend forward & backward predictions to reduce artifacts (slower).",
                    },
                ),
                "clear_cache_after_n_frames": (
                    "INT",
                    {
                        "default": 10,
                        "min": 0,
                        "max": 1000,
                        "tooltip": "Free up GPU memory every N generated frames (advanced). Set 0 to never.",
                    },
                ),
            },
            "optional": {
                "optional_interpolation_states": (
                    "INTERPOLATION_STATES",
                    {
                        "tooltip": "Don’t create in-between frames for selected frame pairs (e.g., scene cuts). Timing stays the same."
                    },
                ),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "vfi"
    CATEGORY = "video utils"

    def vfi(
        self,
        ckpt_name: str,
        frames: torch.Tensor,
        multiplier: int = 2,
        scale_factor: float = 1.0,
        ensemble: bool = True,
        clear_cache_after_n_frames: int = 10,
        optional_interpolation_states: "InterpolationStateList | None" = None,
        **kwargs,
    ):
        """Interpolate ``frames`` by ``multiplier`` and return a 1-tuple with the result.

        Args:
            ckpt_name: RIFE checkpoint name passed to ``_get_rife_model``.
            frames: Input clip; fed through ``_prep_frames`` before use.
            multiplier: Interpolation multiple; values <= 1 are a passthrough.
            scale_factor: Divisor used to build the model's scale list.
            ensemble: Forwarded to the model call.
            clear_cache_after_n_frames: Forwarded to ``generic_frame_loop``.
            optional_interpolation_states: Forwarded to ``generic_frame_loop``
                as ``interpolation_states``.
            **kwargs: Ignored; accepted so extra inputs do not break the call.

        Returns:
            ``(images,)`` where ``images`` is the output of ``postprocess_frames``.
        """
        frames = _prep_frames(frames)

        m_effective = int(multiplier)
        if m_effective <= 1:
            # Passthrough: return before touching the model so no checkpoint is
            # downloaded/loaded when nothing will be synthesized.
            return (postprocess_frames(frames),)

        model = _get_rife_model(ckpt_name)
        scale_list = _scale_list(scale_factor)

        # NOTE(review): the tick estimate is pairs * multiplier; whether that matches
        # the number of callback invocations depends on generic_frame_loop — confirm.
        pbar = _progress(_count_synth(frames, m_effective))
        cbp = _with_progress(_make_rife_callback(), pbar)

        args = [model, scale_list, ensemble, None]  # t_mapper=None → identity

        with torch.inference_mode():
            out = postprocess_frames(
                generic_frame_loop(
                    type(self).__name__,
                    frames,
                    clear_cache_after_n_frames,
                    m_effective,
                    cbp,
                    *args,
                    interpolation_states=optional_interpolation_states,
                    dtype=torch.float32,
                )
            )
        return (out,)
|
||
|
||
|
||
class RIFE_VFI_Advanced:
    """
    Advanced RIFE node exposing timestep controls (by multiple, custom t-schedules).

    Each timestep produced by the frame loop is remapped through ``_map_t``
    before it reaches the model.
    """

    DESCRIPTION = (
        "Custom timing for RIFE 4.7/4.9 — still “interpolate by multiple,” but you control where the in-betweens land "
        "(ease in/out, clamps, or your own t-list)."
    )

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's input sockets/widgets for ComfyUI."""
        return {
            "required": {
                "ckpt_name": (
                    ["rife47.pth", "rife49.pth"],
                    {
                        "default": "rife47.pth",
                        "tooltip": "Choose the RIFE 4.7 or 4.9 model file from the base pack.",
                    },
                ),
                "frames": (
                    "IMAGE",
                    {"tooltip": "Your input clip: one image per frame."},
                ),
                "multiplier": (
                    "INT",
                    {
                        "default": 2,
                        "min": 0,
                        "tooltip": "How many new frames to create between each pair. 0 = passthrough (no new frames).",
                    },
                ),
                "t_mode": (
                    [
                        "linear",
                        "gamma_in",
                        "gamma_out",
                        "gamma_in_out",
                        "bounded_linear",
                        "custom_list",
                    ],
                    {
                        "default": "linear",
                        "tooltip": "How to spread the new frames over time: straight line, ease in/out, limit the range, or provide your own list.",
                    },
                ),
                "t_gamma": (
                    "FLOAT",
                    {
                        "default": 1.0,
                        "min": 0.05,
                        "max": 10.0,
                        "step": 0.05,
                        "tooltip": "Easing strength for gamma modes. Higher = more easing.",
                    },
                ),
                "t_min": (
                    "FLOAT",
                    {
                        "default": 0.0,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.001,
                        "tooltip": "Earliest allowed position between the two frames (0 = exactly the first frame). Use with bounded_linear.",
                    },
                ),
                "t_max": (
                    "FLOAT",
                    {
                        "default": 1.0,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.001,
                        "tooltip": "Latest allowed position between the two frames (1 = exactly the next frame). Use with bounded_linear.",
                    },
                ),
                "scale_factor": (
                    [0.25, 0.5, 1.0, 2.0, 4.0],
                    {
                        "default": 1.0,
                        "tooltip": "Quality vs speed. 1.0 recommended. Lower = faster/softer; higher = sharper/slower.",
                    },
                ),
                "ensemble": (
                    "BOOLEAN",
                    {
                        "default": True,
                        "tooltip": "Blend forward & backward predictions to reduce artifacts (slower).",
                    },
                ),
                "clear_cache_after_n_frames": (
                    "INT",
                    {
                        "default": 10,
                        "min": 0,
                        "max": 1000,
                        "tooltip": "Free up GPU memory every N generated frames (advanced). Set 0 to never.",
                    },
                ),
            },
            "optional": {
                "custom_t_list_csv": (
                    "STRING",
                    {
                        "default": "",
                        "tooltip": "Exact positions between the two frames (0–1), comma-separated, e.g. 0.18,0.41,0.66. Overrides the schedule.",
                    },
                ),
                "optional_interpolation_states": (
                    "INTERPOLATION_STATES",
                    {
                        "tooltip": "Don’t create in-between frames for selected frame pairs (e.g., scene cuts). Timing stays the same."
                    },
                ),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "vfi_advanced"
    CATEGORY = "video utils"

    def _map_t(
        self, t_scalar, t_mode, t_gamma, t_min, t_max, custom_sorted, m_effective
    ):
        """Remap a loop timestep into the position actually sent to the model.

        Args:
            t_scalar: Incoming timestep; clamped to [0, 1] first.
            t_mode: One of the ``t_mode`` choices from ``INPUT_TYPES``.
            t_gamma: Exponent for the gamma modes (floored at 1e-6).
            t_min, t_max: Output range the shaped value is scaled into.
            custom_sorted: Sorted custom positions, or ``None``/empty.
            m_effective: Multiplier used to turn ``t`` into a list index.

        Returns:
            A float in [0, 1]. In ``custom_list`` mode with a non-empty list the
            list entry is returned directly (``t_min``/``t_max`` are not applied).
        """
        t = max(0.0, min(1.0, float(t_scalar)))
        if t_mode == "custom_list" and custom_sorted:
            # NOTE(review): this index assumes timesteps of the form k / (m + 1);
            # confirm against the timesteps generic_frame_loop actually emits.
            idx = int(round(t * (m_effective + 1))) - 1
            idx = max(0, min(len(custom_sorted) - 1, idx))
            return float(custom_sorted[idx])

        if t_mode == "gamma_in":
            t = t ** max(1e-6, t_gamma)
        elif t_mode == "gamma_out":
            t = 1.0 - (1.0 - t) ** max(1e-6, t_gamma)
        elif t_mode == "gamma_in_out":
            g = max(1e-6, t_gamma)
            if t < 0.5:
                t = 0.5 * (2.0 * t) ** g
            else:
                t = 1.0 - 0.5 * (2.0 * (1.0 - t)) ** g
        # "linear", "bounded_linear" and an empty custom list all reach this point
        # unshaped; every mode is then scaled into [t_min, t_max] and clamped.
        return max(0.0, min(1.0, t_min + (t_max - t_min) * t))

    def vfi_advanced(
        self,
        ckpt_name,
        frames,
        clear_cache_after_n_frames=10,
        multiplier=2,
        ensemble=True,
        scale_factor=1.0,
        t_mode="linear",
        t_gamma=1.0,
        t_min=0.0,
        t_max=1.0,
        custom_t_list_csv="",
        optional_interpolation_states=None,
        **kwargs,
    ):
        """Interpolate ``frames`` with a custom timestep schedule.

        Args:
            ckpt_name: RIFE checkpoint name passed to ``_get_rife_model``.
            frames: Input clip; fed through ``_prep_frames`` before use.
            clear_cache_after_n_frames: Forwarded to ``generic_frame_loop``.
            multiplier: Forwarded to ``generic_frame_loop``; <= 0 is a passthrough.
            ensemble: Forwarded to the model call.
            scale_factor: Divisor used to build the model's scale list.
            t_mode, t_gamma, t_min, t_max: Schedule controls, see ``_map_t``.
            custom_t_list_csv: Comma/whitespace separated floats, used only when
                ``t_mode == "custom_list"``; each value is clamped to [0, 1].
            optional_interpolation_states: Forwarded as ``interpolation_states``.
            **kwargs: Ignored.

        Returns:
            ``(images,)`` where ``images`` is the output of ``postprocess_frames``.

        Raises:
            ValueError: If a token of ``custom_t_list_csv`` is not a float.
        """
        frames = _prep_frames(frames)

        custom_sorted = None
        if custom_t_list_csv and t_mode == "custom_list":
            toks = [x for x in re.split(r"[,\s]+", custom_t_list_csv.strip()) if x]
            custom_sorted = sorted(max(0.0, min(1.0, float(x))) for x in toks)

        m_effective = int(multiplier)
        if m_effective <= 0:
            # Passthrough: skip loading the model when nothing is synthesized.
            return (postprocess_frames(frames),)

        model = _get_rife_model(ckpt_name)

        def _adv_t_mapper(t_scalar: float) -> float:
            return self._map_t(
                t_scalar, t_mode, t_gamma, t_min, t_max, custom_sorted, m_effective
            )

        scale_list = _scale_list(scale_factor)
        pbar = _progress(_count_synth(frames, m_effective))
        cbp = _with_progress(_make_rife_callback(), pbar)

        args = [model, scale_list, ensemble, _adv_t_mapper]

        with torch.inference_mode():
            out = postprocess_frames(
                generic_frame_loop(
                    type(self).__name__,
                    frames,
                    clear_cache_after_n_frames,
                    m_effective,
                    cbp,
                    *args,
                    interpolation_states=optional_interpolation_states,
                    dtype=torch.float32,
                )
            )
        return (out,)
|
||
|
||
|
||
class RIFE_FPS_Resample:
    """ComfyUI node: resample a clip from ``fps_in`` to ``fps_out`` with RIFE 4.7/4.9.

    Output frame ``k`` sits at source position ``k * fps_in / fps_out``. Positions
    that land exactly on a source frame copy it; the rest are synthesized from
    the surrounding source pair and optionally post-processed (linear-light
    conversion, low-frequency guardrail, exposure match, edge-band lock).
    """

    DESCRIPTION = (
        "Convert a clip from one FPS to another using RIFE 4.7/4.9. Non-integer changes synthesize in-betweens; "
        "exact integer downscales just decimate. Includes optional stabilizers to reduce flicker and protect edges."
    )

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's input sockets/widgets for ComfyUI."""
        return {
            "required": {
                "ckpt_name": (
                    ["rife47.pth", "rife49.pth"],
                    {
                        "default": "rife47.pth",
                        "tooltip": "Choose the RIFE 4.7 or 4.9 model file from the base pack.",
                    },
                ),
                "frames": (
                    "IMAGE",
                    {"tooltip": "Your input clip: one image per frame."},
                ),
                "fps_in": (
                    "FLOAT",
                    {
                        "default": 24.0,
                        "min": 1e-6,
                        "max": 1000.0,
                        "step": 0.01,
                        "tooltip": "Current frame rate of your clip (frames per second).",
                    },
                ),
                "fps_out": (
                    "FLOAT",
                    {
                        "default": 60.0,
                        "min": 1e-6,
                        "max": 2000.0,
                        "step": 0.01,
                        "tooltip": "Target frame rate you want (frames per second).",
                    },
                ),
                "scale_factor": (
                    [0.25, 0.5, 1.0, 2.0, 4.0],
                    {
                        "default": 1.0,
                        "tooltip": "Quality vs speed. 1.0 recommended. Lower = faster/softer; higher = sharper/slower.",
                    },
                ),
                "ensemble": (
                    "BOOLEAN",
                    {
                        "default": True,
                        "tooltip": "Blend forward & backward predictions to reduce artifacts (slower).",
                    },
                ),
                "linearize": (
                    "BOOLEAN",
                    {
                        "default": False,
                        "tooltip": "Work in linear light for more accurate brightness and gradients (slower).",
                    },
                ),
                "lf_guardrail": (
                    "BOOLEAN",
                    {
                        "default": False,
                        "tooltip": "Keep overall brightness and gradients close to the originals to reduce flicker.",
                    },
                ),
                "lf_sigma": (
                    "FLOAT",
                    {
                        "default": 13.0,
                        "min": 0.0,
                        "max": 64.0,
                        "step": 0.5,
                        "tooltip": "How strong the low-frequency smoothing is. Higher = smoother changes.",
                    },
                ),
                "source_pair_match": (
                    "BOOLEAN",
                    {
                        "default": False,
                        "tooltip": "Match exposure and contrast to the source pair to reduce flicker.",
                    },
                ),
                "match_a_cap": (
                    "FLOAT",
                    {
                        "default": 0.02,
                        "min": 0.0,
                        "max": 0.2,
                        "step": 0.001,
                        "tooltip": "Maximum change allowed for exposure scale.",
                    },
                ),
                "match_b_cap": (
                    "FLOAT",
                    {
                        "default": 2.0 / 255.0,
                        "min": 0.0,
                        "max": 0.1,
                        "step": 0.0005,
                        "tooltip": "Maximum change allowed for brightness offset.",
                    },
                ),
                "edge_band_lock": (
                    "BOOLEAN",
                    {
                        "default": False,
                        "tooltip": "Protect sharp edges: near edges, mix in more of the nearest real frame to avoid smearing.",
                    },
                ),
                "tau_low": (
                    "FLOAT",
                    {
                        "default": 1.5 / 255.0,
                        "min": 0.0,
                        "max": 0.25,
                        "step": 0.0005,
                        "tooltip": "Edge sensitivity: lower threshold (smaller finds more edges).",
                    },
                ),
                "tau_high": (
                    "FLOAT",
                    {
                        "default": 6.0 / 255.0,
                        "min": 0.0,
                        "max": 0.5,
                        "step": 0.0005,
                        "tooltip": "Edge sensitivity: higher threshold (larger finds only strong edges).",
                    },
                ),
                "band_radius": (
                    "INT",
                    {
                        "default": 4,
                        "min": 0,
                        "max": 64,
                        "tooltip": "Width of the edge protection band (pixels).",
                    },
                ),
                "band_soft_sigma": (
                    "FLOAT",
                    {
                        "default": 2.0,
                        "min": 0.0,
                        "max": 16.0,
                        "step": 0.5,
                        "tooltip": "Soften the edge band. Higher = smoother.",
                    },
                ),
                "clear_cache_after_n_frames": (
                    "INT",
                    {
                        "default": 10,
                        # 0 must be selectable: the tooltip promises it and
                        # resample() treats any value <= 0 as "never clear".
                        "min": 0,
                        "max": 1000,
                        "tooltip": "Free up GPU memory every N output frames (advanced). Set 0 to never.",
                    },
                ),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "resample"
    CATEGORY = "video utils"

    @staticmethod
    def _srgb_to_linear(x):
        """Apply the piecewise sRGB → linear transfer function element-wise."""
        return torch.where(
            x <= 0.04045, x / 12.92, ((x + 0.055) / 1.055).clamp(min=0) ** 2.4
        )

    @staticmethod
    def _linear_to_srgb(x):
        """Apply the piecewise linear → sRGB transfer function element-wise."""
        return torch.where(
            x <= 0.0031308, 12.92 * x, 1.055 * x.clamp(min=0) ** (1.0 / 2.4) - 0.055
        )

    @staticmethod
    def _gaussian_kernel1d(sigma, radius):
        """Return a normalized float32 1-D Gaussian of length ``2 * radius + 1``.

        A non-positive ``radius`` yields the identity kernel ``[1.0]``.
        """
        if radius <= 0:
            return torch.tensor([1.0], dtype=torch.float32)
        xs = torch.arange(-radius, radius + 1, dtype=torch.float32)
        k = torch.exp(-(xs * xs) / (2.0 * sigma * sigma))
        return k / k.sum()

    _GK_CACHE = {}  # (float(sigma), int(radius), int(C)) -> (kH, kW)

    @staticmethod
    def _gaussian_blur_nhwc(x, sigma):
        """Separable Gaussian blur of an ``[N, H, W, C]`` tensor.

        The radius is ``ceil(3 * sigma)`` capped at ``min(H, W) // 2 - 1`` so the
        reflect padding stays valid. The input is returned unchanged when
        ``sigma`` or the resulting radius is not positive.
        """
        if sigma <= 0:
            return x

        _, H, W, C = x.shape
        max_radius = max(0, min(H, W) // 2 - 1)
        radius = min(int(math.ceil(3.0 * sigma)), max_radius)
        if radius <= 0:
            return x

        key = (float(sigma), int(radius), int(C))
        kernels = RIFE_FPS_Resample._GK_CACHE.get(key)
        if kernels is None:
            g = RIFE_FPS_Resample._gaussian_kernel1d(sigma, radius)
            # One depthwise kernel per channel (used with groups=C below).
            kernels = (
                g.view(1, 1, -1, 1).repeat(C, 1, 1, 1),
                g.view(1, 1, 1, -1).repeat(C, 1, 1, 1),
            )
            RIFE_FPS_Resample._GK_CACHE[key] = kernels
        kH, kW = kernels

        x_nchw = x.permute(0, 3, 1, 2).contiguous()
        # Vertical pass, then horizontal pass.
        x1 = F.pad(x_nchw, (0, 0, radius, radius), mode="reflect")
        x1 = F.conv2d(x1, kH, groups=C)
        x2 = F.pad(x1, (radius, radius, 0, 0), mode="reflect")
        x2 = F.conv2d(x2, kW, groups=C)
        return x2.permute(0, 2, 3, 1).contiguous()

    @staticmethod
    def _luma_linear(x_lin):
        """Weighted sum ``0.2126 R + 0.7152 G + 0.0722 B`` over the last axis."""
        r, g, b = x_lin[..., 0], x_lin[..., 1], x_lin[..., 2]
        return 0.2126 * r + 0.7152 * g + 0.0722 * b

    @staticmethod
    def _exposure_match_ab(out_lin, tgt_lin, a_cap, b_cap):
        """Fit a capped gain/offset ``(a, b)`` that maps ``out_lin`` luma to ``tgt_lin``.

        ``a`` is the ratio of the 5–95 percentile luma spreads, clamped to
        ``1 ± a_cap``; ``b`` aligns the medians and is clamped to ``± b_cap``.

        Returns:
            ``(a, b)`` as Python floats.
        """
        y_o = RIFE_FPS_Resample._luma_linear(out_lin).flatten()
        y_t = RIFE_FPS_Resample._luma_linear(tgt_lin).flatten()
        p = torch.tensor([0.05, 0.5, 0.95], dtype=torch.float32)
        p5_o, p50_o, p95_o = torch.quantile(y_o, p)
        p5_t, p50_t, p95_t = torch.quantile(y_t, p)
        eps = 1e-6  # guards against a zero luma spread
        a = ((p95_t - p5_t) / (p95_o - p5_o + eps)).item()
        a = max(1.0 - a_cap, min(1.0 + a_cap, a))
        b = (p50_t - a * p50_o).item()
        b = max(-b_cap, min(b_cap, b))
        return a, b

    @staticmethod
    def _dilate_mask_nhwc(mask01, radius):
        """Grow an ``[N, H, W, C]`` mask by ``radius`` pixels via a max filter."""
        if radius <= 0:
            return mask01
        x = mask01.permute(0, 3, 1, 2)
        x = F.pad(x, (radius, radius, radius, radius), mode="replicate")
        x = F.max_pool2d(x, kernel_size=2 * radius + 1, stride=1)
        return x.permute(0, 2, 3, 1)

    @torch.inference_mode()
    def resample(
        self,
        ckpt_name,
        fps_in,
        fps_out,
        frames,
        scale_factor=1.0,
        ensemble=True,
        linearize=False,
        lf_guardrail=False,
        lf_sigma=13.0,
        source_pair_match=False,
        match_a_cap=0.02,
        match_b_cap=2.0 / 255.0,
        edge_band_lock=False,
        tau_low=1.5 / 255.0,
        tau_high=6.0 / 255.0,
        band_radius=4,
        band_soft_sigma=2.0,
        clear_cache_after_n_frames=10,
        **kwargs,
    ):
        """Resample ``frames`` from ``fps_in`` to ``fps_out``.

        Args:
            ckpt_name: RIFE checkpoint name passed to ``_get_rife_model``.
            fps_in, fps_out: Source and target frame rates; both must be > 0.
            frames: Input clip, indexed by frame along dimension 0.
            scale_factor: Divisor used to build the model's scale list.
            ensemble: Forwarded to the model call.
            linearize: Run the stabilizers in linear light.
            lf_guardrail, lf_sigma: Replace the synthesized frame's blurred
                component with a blend of the source pair's blurred components.
            source_pair_match, match_a_cap, match_b_cap: Capped gain/offset
                match against the linear blend of the source pair.
            edge_band_lock, tau_low, tau_high, band_radius, band_soft_sigma:
                Blend the nearer source frame back in inside a band mask.
            clear_cache_after_n_frames: Call ``torch.cuda.empty_cache()`` every
                N output frames; <= 0 disables it.
            **kwargs: Ignored.

        Returns:
            ``(images,)``. Equal rates return ``frames`` untouched; an exact
            integer downscale returns ``frames[::step]``.

        Raises:
            ValueError: If ``fps_in`` or ``fps_out`` is not positive.
        """
        if fps_in <= 0 or fps_out <= 0:
            raise ValueError("fps_in and fps_out must be > 0")

        # Going through str() turns e.g. 23.98 into exactly 2398/100 instead of
        # the binary float's long exact expansion.
        F_in = Fraction(str(float(fps_in)))
        F_out = Fraction(str(float(fps_out)))
        ratio = F_in / F_out

        if F_in == F_out:
            return (frames,)
        if ratio.denominator == 1 and ratio.numerator > 1:
            return (frames[:: int(ratio.numerator)].contiguous(),)

        model = _get_rife_model(ckpt_name)
        first_param = next(model.parameters())
        device, dtype = first_param.device, first_param.dtype

        x_nhwc = _prep_frames(frames)

        n_in = int(x_nhwc.shape[0])
        if n_in <= 1:
            return (postprocess_frames(x_nhwc[:1]),)

        # Last output index whose source position does not pass the final frame.
        K_max = int(Fraction(n_in - 1) * F_out / F_in)
        n_out = K_max + 1
        scale_list = _scale_list(scale_factor)

        def to_nchw_on_device(idx: int):
            f = x_nhwc[idx : idx + 1].permute(0, 3, 1, 2).contiguous()
            return f.to(
                device=device,
                dtype=dtype,
                non_blocking=True,
                memory_format=torch.channels_last,
            )

        cache = {}  # source index -> frame already moved to the model's device

        def synth(i: int, t_local: float):
            # Source indices only ever increase across calls, so anything before
            # ``i`` is dead. Dropping it keeps at most two frames resident
            # instead of accumulating the whole clip on the device.
            for stale in [idx for idx in cache if idx < i]:
                del cache[stale]
            for idx in (i, i + 1):
                if idx not in cache:
                    cache[idx] = to_nchw_on_device(idx)
            # fast_mode is a no-op for 4.7/4.9; pass False explicitly
            y = model(
                cache[i], cache[i + 1], float(t_local), scale_list, False, ensemble
            )
            return y.permute(0, 2, 3, 1).contiguous().detach().cpu().to(torch.float32)

        to_lin = self._srgb_to_linear if linearize else (lambda z: z)
        to_srgb = self._linear_to_srgb if linearize else (lambda z: z)
        blur = self._gaussian_blur_nhwc
        needs_post = linearize or lf_guardrail or source_pair_match or edge_band_lock

        y0 = x_nhwc[0:1]
        H, W, C = int(y0.shape[1]), int(y0.shape[2]), int(y0.shape[3])
        y_buf = torch.empty((n_out, H, W, C), dtype=torch.float32, device="cpu")

        def _write(idx: int, frame_nhwc_cpu: torch.Tensor):
            y_buf[idx : idx + 1].copy_(frame_nhwc_cpu)

        # Exact integer timeline: output k sits at source position acc / den,
        # with acc advancing by num per output frame (num / den == fps_in / fps_out).
        num = F_in.numerator * F_out.denominator
        den = F_in.denominator * F_out.numerator
        acc = 0

        clear_every = int(clear_cache_after_n_frames)
        can_clear = torch.cuda.is_available() and clear_every > 0
        frames_since_clear = 0

        pbar = _progress(int(n_out))

        for k in range(n_out):
            pbar.update(1)
            i, frac_num = divmod(acc, den)

            if i >= (n_in - 1):
                _write(k, x_nhwc[n_in - 1 : n_in])
            elif frac_num == 0:
                # Lands exactly on a source frame: copy, no synthesis.
                _write(k, x_nhwc[i : i + 1])
            else:
                frac = float(frac_num) / den
                y = synth(i, frac)

                if needs_post:
                    A = x_nhwc[i : i + 1]
                    B = x_nhwc[i + 1 : i + 2]
                    y_lin, A_lin, B_lin = to_lin(y), to_lin(A), to_lin(B)

                    if lf_guardrail and lf_sigma > 0:
                        # Keep the synthesized detail, swap in the pair's blended blur.
                        HF = y_lin - blur(y_lin, lf_sigma)
                        LF_target = (1.0 - frac) * blur(
                            A_lin, lf_sigma
                        ) + frac * blur(B_lin, lf_sigma)
                        y_lin = HF + LF_target

                    if source_pair_match:
                        tgt = (1.0 - frac) * A_lin + frac * B_lin
                        a, b = self._exposure_match_ab(
                            y_lin, tgt, match_a_cap, match_b_cap
                        )
                        y_lin = (a * y_lin + b).clamp(0.0, 1.0)

                    if edge_band_lock:
                        d = (
                            (self._luma_linear(A_lin) - self._luma_linear(B_lin))
                            .abs()
                            .unsqueeze(-1)
                        )
                        high = (d > tau_high).float()
                        low = (d < tau_low).float()
                        # Band = pixels within band_radius of a strong change that
                        # themselves changed less than tau_low.
                        band = self._dilate_mask_nhwc(high, band_radius) * low
                        if band_soft_sigma > 0:
                            band = blur(band, band_soft_sigma).clamp(0.0, 1.0)
                        near = A_lin if frac < 0.5 else B_lin
                        y_lin = band * near + (1.0 - band) * y_lin

                    y = to_srgb(y_lin).clamp(0.0, 1.0)

                _write(k, y)

            # NOTE(review): counted per OUTPUT frame, matching the tooltip; the
            # original nesting of this counter could not be confirmed.
            frames_since_clear += 1
            if can_clear and frames_since_clear >= clear_every:
                torch.cuda.empty_cache()
                frames_since_clear = 0

            acc += num

        return (postprocess_frames(y_buf),)
|
||
|
||
|
||
class RIFE_SeamTimingAnalyzer:
    """
    Batch-in, timings-out. Chooses a t-schedule for the seam [last, first] using
    visual step sizes measured from REAL frames in the input batch.

    Inputs:
      - full_clip (IMAGE, NHWC [N,H,W,C], required): real frames (e.g., 1..10)
      - multiplier (INT): number of in-betweens to synthesize for the seam
      - use_first_two (BOOL): include dist(clip[0], clip[1]) as a target
      - use_last_two (BOOL): include dist(clip[-2], clip[-1]) as a target
      - use_global_median (BOOL): include median of ALL adjacent deltas (N>=3)
      - ckpt_name / ensemble / scale_factor: RIFE runtime knobs (only used to test t)
      - calibrate_metric: "MSE" or "L1" visual distance
      - calibrate_iters: binary-search iterations
      - t_min / t_max: clamp t search range (keep t_max < 1.0 to avoid hugging frame 1)

    Outputs:
      - t_list_csv (STRING): e.g. "0.183112, 0.408217, 0.653991, 0.913541"
      - multiplier (INT): echo of input multiplier for wiring convenience
    """

    DESCRIPTION = (
        "Finds a smooth loop timing: measures motion in your clip and solves a set of t-values across the wrap "
        "[last→first] so the seam blends naturally."
    )

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's input sockets/widgets for ComfyUI."""
        return {
            "required": {
                "ckpt_name": (
                    ["rife47.pth", "rife49.pth"],
                    {
                        "default": "rife47.pth",
                        "tooltip": "Choose the RIFE 4.7 or 4.9 model file (used only to test candidate timings).",
                    },
                ),
                "scale_factor": (
                    [0.25, 0.5, 1.0, 2.0, 4.0],
                    {
                        "default": 1.0,
                        "tooltip": "Quality vs speed for the probe renders. 1.0 recommended.",
                    },
                ),
                "ensemble": (
                    "BOOLEAN",
                    {
                        "default": True,
                        "tooltip": "Blend forward & backward predictions to reduce artifacts (slower).",
                    },
                ),
                "full_clip": (
                    "IMAGE",
                    {
                        "tooltip": "Your input clip (≥2 frames). Real motion here decides the loop seam timing."
                    },
                ),
                "multiplier": (
                    "INT",
                    {
                        "default": 4,
                        "min": 0,
                        "tooltip": "How many new frames you plan to create at the loop seam [last→first]. Set 0 to skip.",
                    },
                ),
                "use_first_two": (
                    "BOOLEAN",
                    {
                        "default": True,
                        "tooltip": "Match the motion between the first two frames in your clip.",
                    },
                ),
                "use_last_two": (
                    "BOOLEAN",
                    {
                        "default": True,
                        "tooltip": "Match the motion between the last two frames in your clip.",
                    },
                ),
                "use_global_median": (
                    "BOOLEAN",
                    {
                        "default": False,
                        "tooltip": "Use the median motion across the whole clip (needs ≥3 frames). Helps ignore outliers.",
                    },
                ),
                "calibrate_metric": (
                    ["MSE", "L1"],
                    {
                        "default": "MSE",
                        "tooltip": "How we compare frames while solving: MSE (more sensitive) or L1 (more forgiving).",
                    },
                ),
                "calibrate_iters": (
                    "INT",
                    {
                        "default": 12,
                        "min": 4,
                        "max": 24,
                        "tooltip": "Search depth per solve. Higher = slower, but a tighter match.",
                    },
                ),
                "t_min": (
                    "FLOAT",
                    {
                        "default": 0.00,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.001,
                        "tooltip": "Earliest allowed blend point at the seam (0 = exactly the last frame).",
                    },
                ),
                "t_max": (
                    "FLOAT",
                    {
                        "default": 0.96,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.001,
                        "tooltip": "Latest allowed blend point at the seam (keep below 1.0 to avoid sticking to the first frame).",
                    },
                ),
            },
            "optional": {
                "auto_tmax": (
                    "BOOLEAN",
                    {
                        "default": False,
                        "tooltip": "Automatically push the upper limit closer to the next frame to hit the target motion step.",
                    },
                ),
                "t_cap": (
                    "FLOAT",
                    {
                        "default": 0.995,
                        "min": 0.5,
                        "max": 0.9999,
                        "step": 0.0001,
                        "tooltip": "Safety cap used with the auto upper limit (keeps it just shy of 1.0).",
                    },
                ),
            },
        }

    RETURN_TYPES = ("STRING", "INT")
    RETURN_NAMES = ("t_list_csv", "multiplier")
    FUNCTION = "analyze_wrapper"
    CATEGORY = "video utils"

    # --- helpers ---
    @staticmethod
    def _normalize_multiplier(value):
        """Coerce ``value`` to a non-negative int; invalid or negative input becomes 0."""
        try:
            return max(0, int(value))
        except (TypeError, ValueError):
            return 0

    def _to_nchw(self, x):
        """Permute a 4-D tensor whose last dim is 1, 3 or 4 to NCHW; otherwise return it as-is."""
        if x.ndim == 4 and x.shape[-1] in (1, 3, 4):
            return x.permute(0, 3, 1, 2).contiguous()
        return x

    def _dist(self, A, B, kind="MSE"):
        """Mean absolute difference for ``kind == "L1"``, mean squared difference otherwise."""
        if kind == "L1":
            return (A - B).abs().mean()
        return ((A - B) ** 2).mean()

    def _adjacent_deltas(self, clip_nchw, kind="MSE"):
        """Return the distance between every pair of consecutive frames (list of scalar tensors)."""
        N = clip_nchw.shape[0]
        return [
            self._dist(clip_nchw[i : i + 1], clip_nchw[i + 1 : i + 2], kind)
            for i in range(N - 1)
        ]

    def _solve_t(self, synth_at, reference, d_target, lo, hi, iters, metric):
        """Bisect ``[lo, hi]`` for the t whose render sits ``d_target`` away from ``reference``.

        A render farther than the target moves the lower bound up, otherwise the
        upper bound moves down; the final midpoint is returned.
        """
        for _ in range(iters):
            mid = 0.5 * (lo + hi)
            if float(self._dist(synth_at(mid), reference, metric).item()) > d_target:
                lo = mid
            else:
                hi = mid
        return 0.5 * (lo + hi)

    def analyze_wrapper(self, **kwargs):
        """
        Safe entrypoint for Comfy mapping:
        - If 'multiplier' is omitted: default to 0 (passthrough).
        - If 'multiplier' is invalid/negative: normalize to 0 (passthrough).
        """
        kwargs["multiplier"] = self._normalize_multiplier(kwargs.get("multiplier", 0))
        return self.analyze(**kwargs)

    def analyze(
        self,
        full_clip,
        multiplier,
        use_first_two,
        use_last_two,
        use_global_median,
        ckpt_name,
        ensemble,
        scale_factor,
        calibrate_metric,
        calibrate_iters,
        t_min,
        t_max,
        auto_tmax=False,
        t_cap=0.995,
    ):
        """Solve ``multiplier`` seam timesteps between the last and first frame.

        The target step size is the median of the enabled measurements (first
        pair, last pair, global median of adjacent deltas). The first solve
        finds the t whose render is one target step before the first frame;
        each further solve finds a t, below the previous one, whose render is
        one target step from the previous render.

        Returns:
            ``(csv, m)``: the solved t values sorted ascending and formatted
            with six decimals, plus the normalized multiplier. ``("", 0)`` when
            the multiplier normalizes to 0.

        Raises:
            ValueError: If the clip has fewer than 2 frames, ``use_global_median``
                is set with fewer than 3 frames, no measurement is enabled, or
                the clamped range does not satisfy ``0 <= t_min < t_max <= 1``.
        """
        m = self._normalize_multiplier(multiplier)
        if m == 0:
            return ("", 0)

        # Shared helper: preprocess and pin (when on CPU with CUDA available).
        clip_nhwc = _prep_frames(full_clip)

        N = int(clip_nhwc.shape[0])
        if N < 2:
            raise ValueError("full_clip must contain at least 2 frames.")

        metric = "L1" if calibrate_metric == "L1" else "MSE"
        deltas = self._adjacent_deltas(self._to_nchw(clip_nhwc), metric)

        chosen = []
        if use_first_two:
            chosen.append(deltas[0])
        if use_last_two:
            chosen.append(deltas[-1])
        if use_global_median:
            if N < 3:
                raise ValueError(
                    "use_global_median requires full_clip with >= 3 frames."
                )
            chosen.append(torch.stack(deltas).median())

        if not chosen:
            raise ValueError(
                "Enable at least one of: use_first_two, use_last_two, use_global_median."
            )

        d_target = (
            chosen[0] if len(chosen) == 1 else torch.stack(chosen).median()
        ).item()
        iters = int(calibrate_iters)

        model = _get_rife_model(ckpt_name)
        first_param = next(model.parameters())
        to_device = dict(
            device=first_param.device,
            dtype=first_param.dtype,
            non_blocking=True,
            memory_format=torch.channels_last,
        )
        last_nchw = self._to_nchw(clip_nhwc[N - 1 : N]).to(**to_device)
        first_nchw = self._to_nchw(clip_nhwc[0:1]).to(**to_device)

        t_min = float(max(0.0, min(1.0, t_min)))
        t_max = float(max(0.0, min(1.0, t_max)))
        if not (0.0 <= t_min < t_max <= 1.0):
            raise ValueError("Require 0 <= t_min < t_max <= 1.0")

        hi_eff = t_max
        if auto_tmax:
            # Raise the upper bracket to a safe cap near 1.0 so we can hit the target step size.
            hi_eff = max(hi_eff, min(max(t_min + 1e-6, float(t_cap)), 0.9999))

        scale_list = _scale_list(scale_factor)

        @torch.no_grad()
        def synth_at(t_scalar: float):
            # Probe render between the last and first frame, t clamped to the search range.
            t_scalar = float(max(t_min, min(hi_eff, t_scalar)))
            return model(last_nchw, first_nchw, t_scalar, scale_list, False, ensemble)

        with torch.inference_mode():
            # Anchor: the timestep one target step away from the first frame.
            t_last = self._solve_t(
                synth_at, first_nchw, d_target, t_min, hi_eff, iters, metric
            )
            f_prev = synth_at(t_last)

            ts = [t_last]
            for _ in range(m - 1):
                # Walk backwards: each new t is searched below the previous one.
                tk = self._solve_t(
                    synth_at, f_prev, d_target, t_min, ts[-1], iters, metric
                )
                ts.append(tk)
                f_prev = synth_at(tk)

        csv = ", ".join(f"{t:.6f}" for t in sorted(ts))
        return (csv, m)
|