Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Includes 30 custom nodes committed directly, 7 Civitai-exclusive loras stored via Git LFS, and a setup script that installs all dependencies and downloads HuggingFace-hosted models on vast.ai. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
141 lines
5.1 KiB
Python
141 lines
5.1 KiB
Python
# Copyright (c) 2023–2025 Fannovel16 and contributors
|
||
# See LICENSES/MIT-ComfyUI-Frame-Interpolation.txt for the full text.
|
||
import torch
|
||
from torch.utils.data import DataLoader
|
||
import pathlib
|
||
from vfi_utils import (
|
||
load_file_from_github_release,
|
||
preprocess_frames,
|
||
postprocess_frames,
|
||
generic_frame_loop,
|
||
InterpolationStateList,
|
||
)
|
||
import typing
|
||
from comfy.model_management import get_torch_device
|
||
import re
|
||
from functools import cmp_to_key
|
||
from packaging import version
|
||
|
||
# Name of the directory that contains this file. It is handed to
# load_file_from_github_release() together with the checkpoint file name.
MODEL_TYPE = pathlib.Path(__file__).parent.name

# Checkpoint file name -> architecture version string given to IFNet(arch_ver=...).
# Several checkpoints share one architecture version. Insertion order matters:
# the node's checkpoint list is produced by a stable sort on the version, so
# entries with equal versions keep the order they have here.
CKPT_NAME_VER_DICT = {
    "rife40.pth": "4.0",
    "rife41.pth": "4.0",
    "rife42.pth": "4.2",
    "rife43.pth": "4.3",
    "rife44.pth": "4.3",
    "rife45.pth": "4.5",
    "rife46.pth": "4.6",
    "rife47.pth": "4.7",
    "rife48.pth": "4.7",
    "rife49.pth": "4.7",
    "sudo_rife4_269.662_testV1_scale1.pth": "4.0",
    # Architecture 4.10 checkpoints are disabled: loading them fails with a
    # state-dict mismatch.
    # TODO: investigate the mismatch and re-enable these entries.
    # "rife410.pth": "4.10",
    # "rife411.pth": "4.10",
    # "rife412.pth": "4.10"
}
|
||
|
||
|
||
class RIFE_VFI:
    """Frame-interpolation node backed by the RIFE ``IFNet`` checkpoints.

    The class attributes (``INPUT_TYPES``, ``RETURN_TYPES``, ``FUNCTION``,
    ``CATEGORY``) describe the node's inputs, outputs, entry-point method and
    menu category; ``vfi`` is the method named by ``FUNCTION``.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification for this node.

        The selectable checkpoint names are the keys of ``CKPT_NAME_VER_DICT``
        ordered by their parsed architecture version (stable sort, so names
        sharing a version keep their dictionary order).
        """
        ckpt_names = sorted(
            CKPT_NAME_VER_DICT,
            key=lambda ckpt_name: version.parse(CKPT_NAME_VER_DICT[ckpt_name]),
        )
        return {
            "required": {
                "ckpt_name": (ckpt_names, {"default": "rife47.pth"}),
                "frames": ("IMAGE",),
                "clear_cache_after_n_frames": (
                    "INT",
                    {"default": 10, "min": 1, "max": 1000},
                ),
                "multiplier": ("INT", {"default": 2, "min": 1}),
                "fast_mode": ("BOOLEAN", {"default": True}),
                "ensemble": ("BOOLEAN", {"default": True}),
                "scale_factor": ([0.25, 0.5, 1.0, 2.0, 4.0], {"default": 1.0}),
            },
            "optional": {"optional_interpolation_states": ("INTERPOLATION_STATES",)},
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "vfi"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    def vfi(
        self,
        ckpt_name: str,
        frames: torch.Tensor,
        clear_cache_after_n_frames=10,
        multiplier: typing.SupportsInt = 2,
        fast_mode=False,
        ensemble=False,
        scale_factor=1.0,
        optional_interpolation_states: typing.Optional[InterpolationStateList] = None,
        **kwargs
    ):
        """
        Perform video frame interpolation using a given checkpoint model.

        Args:
            ckpt_name (str): The name of the checkpoint model to use. Must be a
                key of ``CKPT_NAME_VER_DICT``.
            frames (torch.Tensor): A tensor containing input video frames.
            clear_cache_after_n_frames (int, optional): The number of frames to
                process before clearing CUDA cache to prevent memory overflow.
                Defaults to 10. Lower numbers are safer but mean more
                processing time. How high you should set it depends on how many
                input frames there are, input resolution (after upscaling), how
                many times you want to multiply them, and how long you're
                willing to wait for the process to complete.
            multiplier (int, optional): The multiplier for each input frame.
                60 input frames * 2 = 120 output frames. Defaults to 2.
            fast_mode (bool, optional): Forwarded unchanged to the model on
                every call. Defaults to False here, although the node's input
                specification defaults it to True.
            ensemble (bool, optional): Forwarded unchanged to the model on
                every call. Defaults to False here, although the node's input
                specification defaults it to True.
            scale_factor (float, optional): Divisor applied to the base scale
                list ``[8, 4, 2, 1]`` that is passed to the model. Defaults to
                1.0.
            optional_interpolation_states (InterpolationStateList, optional):
                Forwarded to ``generic_frame_loop`` as ``interpolation_states``.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            tuple: A one-element tuple containing the output interpolated
            frames.

        Raises:
            KeyError: If ``ckpt_name`` is not a key of ``CKPT_NAME_VER_DICT``.

        Note:
            This method interpolates frames in a video sequence using a
            specified checkpoint model. It processes each frame sequentially,
            generating interpolated frames between them.

            To prevent memory overflow, it clears the CUDA cache after
            processing a specified number of frames.
        """
        from .rife_arch import IFNet

        model_path = load_file_from_github_release(MODEL_TYPE, ckpt_name)
        arch_ver = CKPT_NAME_VER_DICT[ckpt_name]
        interpolation_model = IFNet(arch_ver=arch_ver)

        # Deserialize onto the CPU so a checkpoint saved from a CUDA device
        # still loads on hosts without CUDA; the model is moved to the target
        # device right after the weights are copied in.
        # NOTE(review): torch.load unpickles the downloaded file. Consider
        # weights_only=True once the minimum supported torch version and the
        # checkpoint contents are confirmed to allow it.
        state_dict = torch.load(model_path, map_location="cpu")
        interpolation_model.load_state_dict(state_dict)
        interpolation_model.eval().to(get_torch_device())
        frames = preprocess_frames(frames)

        def return_middle_frame(
            frame_0, frame_1, timestep, model, scale_list, in_fast_mode, in_ensemble
        ):
            # Adapter handed to generic_frame_loop: reorders the arguments into
            # the positional order the model is called with.
            return model(
                frame_0, frame_1, timestep, scale_list, in_fast_mode, in_ensemble
            )

        scale_list = [
            8 / scale_factor,
            4 / scale_factor,
            2 / scale_factor,
            1 / scale_factor,
        ]

        # Extra positional arguments passed through generic_frame_loop; they
        # line up with return_middle_frame's trailing parameters.
        args = [interpolation_model, scale_list, fast_mode, ensemble]
        out = postprocess_frames(
            generic_frame_loop(
                type(self).__name__,
                frames,
                clear_cache_after_n_frames,
                multiplier,
                return_middle_frame,
                *args,
                interpolation_states=optional_interpolation_states,
                dtype=torch.float32
            )
        )
        return (out,)
|