Support the LTXV 2 model. (#11632)
This commit is contained in:
@@ -81,6 +81,59 @@ class LTXVImgToVideo(io.ComfyNode):
|
||||
generate = execute # TODO: remove
|
||||
|
||||
|
||||
class LTXVImgToVideoInplace(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LTXVImgToVideoInplace",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Vae.Input("vae"),
|
||||
io.Image.Input("image"),
|
||||
io.Latent.Input("latent"),
|
||||
io.Float.Input("strength", default=1.0, min=0.0, max=1.0),
|
||||
io.Boolean.Input("bypass", default=False, tooltip="Bypass the conditioning.")
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, vae, image, latent, strength, bypass=False) -> io.NodeOutput:
|
||||
if bypass:
|
||||
return (latent,)
|
||||
|
||||
samples = latent["samples"]
|
||||
_, height_scale_factor, width_scale_factor = (
|
||||
vae.downscale_index_formula
|
||||
)
|
||||
|
||||
batch, _, latent_frames, latent_height, latent_width = samples.shape
|
||||
width = latent_width * width_scale_factor
|
||||
height = latent_height * height_scale_factor
|
||||
|
||||
if image.shape[1] != height or image.shape[2] != width:
|
||||
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
else:
|
||||
pixels = image
|
||||
encode_pixels = pixels[:, :, :, :3]
|
||||
t = vae.encode(encode_pixels)
|
||||
|
||||
samples[:, :, :t.shape[2]] = t
|
||||
|
||||
conditioning_latent_frames_mask = torch.ones(
|
||||
(batch, 1, latent_frames, 1, 1),
|
||||
dtype=torch.float32,
|
||||
device=samples.device,
|
||||
)
|
||||
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength
|
||||
|
||||
return io.NodeOutput({"samples": samples, "noise_mask": conditioning_latent_frames_mask})
|
||||
|
||||
generate = execute # TODO: remove
|
||||
|
||||
|
||||
def conditioning_get_any_value(conditioning, key, default=None):
|
||||
for t in conditioning:
|
||||
if key in t[1]:
|
||||
@@ -106,12 +159,12 @@ def get_keyframe_idxs(cond):
|
||||
keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
|
||||
if keyframe_idxs is None:
|
||||
return None, 0
|
||||
num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0]
|
||||
# keyframe_idxs contains start/end positions (last dimension), checking for unqiue values only for start
|
||||
num_keyframes = torch.unique(keyframe_idxs[:, 0, :, 0]).shape[0]
|
||||
return keyframe_idxs, num_keyframes
|
||||
|
||||
class LTXVAddGuide(io.ComfyNode):
|
||||
NUM_PREFIX_FRAMES = 2
|
||||
PATCHIFIER = SymmetricPatchifier(1)
|
||||
PATCHIFIER = SymmetricPatchifier(1, start_end=True)
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@@ -182,26 +235,35 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
|
||||
|
||||
@classmethod
|
||||
def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
|
||||
_, latent_idx = cls.get_latent_index(
|
||||
cond=positive,
|
||||
latent_length=latent_image.shape[2],
|
||||
guide_length=guiding_latent.shape[2],
|
||||
frame_idx=frame_idx,
|
||||
scale_factors=scale_factors,
|
||||
)
|
||||
noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0
|
||||
def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128):
|
||||
if latent_image.shape[1] != in_channels or guiding_latent.shape[1] != in_channels:
|
||||
raise ValueError("Adding guide to a combined AV latent is not supported.")
|
||||
|
||||
positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
|
||||
negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
|
||||
|
||||
mask = torch.full(
|
||||
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
|
||||
1.0 - strength,
|
||||
dtype=noise_mask.dtype,
|
||||
device=noise_mask.device,
|
||||
)
|
||||
if guide_mask is not None:
|
||||
target_h = max(noise_mask.shape[3], guide_mask.shape[3])
|
||||
target_w = max(noise_mask.shape[4], guide_mask.shape[4])
|
||||
|
||||
if noise_mask.shape[3] == 1 or noise_mask.shape[4] == 1:
|
||||
noise_mask = noise_mask.expand(-1, -1, -1, target_h, target_w)
|
||||
|
||||
if guide_mask.shape[3] == 1 or guide_mask.shape[4] == 1:
|
||||
guide_mask = guide_mask.expand(-1, -1, -1, target_h, target_w)
|
||||
mask = guide_mask - strength
|
||||
else:
|
||||
mask = torch.full(
|
||||
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
|
||||
1.0 - strength,
|
||||
dtype=noise_mask.dtype,
|
||||
device=noise_mask.device,
|
||||
)
|
||||
# This solves audio video combined latent case where latent_image has audio latent concatenated
|
||||
# in channel dimension with video latent. The solution is to pad guiding latent accordingly.
|
||||
if latent_image.shape[1] > guiding_latent.shape[1]:
|
||||
pad_len = latent_image.shape[1] - guiding_latent.shape[1]
|
||||
guiding_latent = torch.nn.functional.pad(guiding_latent, pad=(0, 0, 0, 0, 0, 0, 0, pad_len), value=0)
|
||||
latent_image = torch.cat([latent_image, guiding_latent], dim=2)
|
||||
noise_mask = torch.cat([noise_mask, mask], dim=2)
|
||||
return positive, negative, latent_image, noise_mask
|
||||
@@ -238,33 +300,17 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
|
||||
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
|
||||
|
||||
num_prefix_frames = min(cls.NUM_PREFIX_FRAMES, t.shape[2])
|
||||
|
||||
positive, negative, latent_image, noise_mask = cls.append_keyframe(
|
||||
positive,
|
||||
negative,
|
||||
frame_idx,
|
||||
latent_image,
|
||||
noise_mask,
|
||||
t[:, :, :num_prefix_frames],
|
||||
t,
|
||||
strength,
|
||||
scale_factors,
|
||||
)
|
||||
|
||||
latent_idx += num_prefix_frames
|
||||
|
||||
t = t[:, :, num_prefix_frames:]
|
||||
if t.shape[2] == 0:
|
||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
||||
|
||||
latent_image, noise_mask = cls.replace_latent_frames(
|
||||
latent_image,
|
||||
noise_mask,
|
||||
t,
|
||||
latent_idx,
|
||||
strength,
|
||||
)
|
||||
|
||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
||||
|
||||
generate = execute # TODO: remove
|
||||
@@ -507,18 +553,90 @@ class LTXVPreprocess(io.ComfyNode):
|
||||
|
||||
preprocess = execute # TODO: remove
|
||||
|
||||
|
||||
import comfy.nested_tensor
|
||||
class LTXVConcatAVLatent(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LTXVConcatAVLatent",
|
||||
category="latent/video/ltxv",
|
||||
inputs=[
|
||||
io.Latent.Input("video_latent"),
|
||||
io.Latent.Input("audio_latent"),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, video_latent, audio_latent) -> io.NodeOutput:
|
||||
output = {}
|
||||
output.update(video_latent)
|
||||
output.update(audio_latent)
|
||||
video_noise_mask = video_latent.get("noise_mask", None)
|
||||
audio_noise_mask = audio_latent.get("noise_mask", None)
|
||||
|
||||
if video_noise_mask is not None or audio_noise_mask is not None:
|
||||
if video_noise_mask is None:
|
||||
video_noise_mask = torch.ones_like(video_latent["samples"])
|
||||
if audio_noise_mask is None:
|
||||
audio_noise_mask = torch.ones_like(audio_latent["samples"])
|
||||
output["noise_mask"] = comfy.nested_tensor.NestedTensor((video_noise_mask, audio_noise_mask))
|
||||
|
||||
output["samples"] = comfy.nested_tensor.NestedTensor((video_latent["samples"], audio_latent["samples"]))
|
||||
|
||||
return io.NodeOutput(output)
|
||||
|
||||
|
||||
class LTXVSeparateAVLatent(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LTXVSeparateAVLatent",
|
||||
category="latent/video/ltxv",
|
||||
description="LTXV Separate AV Latent",
|
||||
inputs=[
|
||||
io.Latent.Input("av_latent"),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(display_name="video_latent"),
|
||||
io.Latent.Output(display_name="audio_latent"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, av_latent) -> io.NodeOutput:
|
||||
latents = av_latent["samples"].unbind()
|
||||
video_latent = av_latent.copy()
|
||||
video_latent["samples"] = latents[0]
|
||||
audio_latent = av_latent.copy()
|
||||
audio_latent["samples"] = latents[1]
|
||||
if "noise_mask" in av_latent:
|
||||
masks = av_latent["noise_mask"]
|
||||
if masks is not None:
|
||||
masks = masks.unbind()
|
||||
video_latent["noise_mask"] = masks[0]
|
||||
audio_latent["noise_mask"] = masks[1]
|
||||
return io.NodeOutput(video_latent, audio_latent)
|
||||
|
||||
|
||||
class LtxvExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
EmptyLTXVLatentVideo,
|
||||
LTXVImgToVideo,
|
||||
LTXVImgToVideoInplace,
|
||||
ModelSamplingLTXV,
|
||||
LTXVConditioning,
|
||||
LTXVScheduler,
|
||||
LTXVAddGuide,
|
||||
LTXVPreprocess,
|
||||
LTXVCropGuides,
|
||||
LTXVConcatAVLatent,
|
||||
LTXVSeparateAVLatent,
|
||||
]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user