Support the LTXV 2 model. (#11632)
This commit is contained in:
@@ -112,7 +112,7 @@ class VAEDecodeAudio(IO.ComfyNode):
|
||||
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
|
||||
std[std < 1.0] = 1.0
|
||||
audio /= std
|
||||
return IO.NodeOutput({"waveform": audio, "sample_rate": 44100})
|
||||
return IO.NodeOutput({"waveform": audio, "sample_rate": 44100 if "sample_rate" not in samples else samples["sample_rate"]})
|
||||
|
||||
decode = execute # TODO: remove
|
||||
|
||||
|
||||
@@ -5,7 +5,9 @@ import comfy.model_management
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy.ldm.hunyuan_video.upsampler import HunyuanVideo15SRModel
|
||||
from comfy.ldm.lightricks.latent_upsampler import LatentUpsampler
|
||||
import folder_paths
|
||||
import json
|
||||
|
||||
class CLIPTextEncodeHunyuanDiT(io.ComfyNode):
|
||||
@classmethod
|
||||
@@ -186,7 +188,7 @@ class LatentUpscaleModelLoader(io.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, model_name) -> io.NodeOutput:
|
||||
model_path = folder_paths.get_full_path_or_raise("latent_upscale_models", model_name)
|
||||
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
||||
sd, metadata = comfy.utils.load_torch_file(model_path, safe_load=True, return_metadata=True)
|
||||
|
||||
if "blocks.0.block.0.conv.weight" in sd:
|
||||
config = {
|
||||
@@ -197,6 +199,8 @@ class LatentUpscaleModelLoader(io.ComfyNode):
|
||||
"global_residual": False,
|
||||
}
|
||||
model_type = "720p"
|
||||
model = HunyuanVideo15SRModel(model_type, config)
|
||||
model.load_sd(sd)
|
||||
elif "up.0.block.0.conv1.conv.weight" in sd:
|
||||
sd = {key.replace("nin_shortcut", "nin_shortcut.conv", 1): value for key, value in sd.items()}
|
||||
config = {
|
||||
@@ -205,9 +209,12 @@ class LatentUpscaleModelLoader(io.ComfyNode):
|
||||
"block_out_channels": tuple(sd[f"up.{i}.block.0.conv1.conv.weight"].shape[0] for i in range(len([k for k in sd.keys() if k.startswith("up.") and k.endswith(".block.0.conv1.conv.weight")]))),
|
||||
}
|
||||
model_type = "1080p"
|
||||
|
||||
model = HunyuanVideo15SRModel(model_type, config)
|
||||
model.load_sd(sd)
|
||||
model = HunyuanVideo15SRModel(model_type, config)
|
||||
model.load_sd(sd)
|
||||
elif "post_upsample_res_blocks.0.conv2.bias" in sd:
|
||||
config = json.loads(metadata["config"])
|
||||
model = LatentUpsampler.from_config(config).to(dtype=comfy.model_management.vae_dtype(allowed_dtypes=[torch.bfloat16, torch.float32]))
|
||||
model.load_state_dict(sd)
|
||||
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
@@ -81,6 +81,59 @@ class LTXVImgToVideo(io.ComfyNode):
|
||||
generate = execute # TODO: remove
|
||||
|
||||
|
||||
class LTXVImgToVideoInplace(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LTXVImgToVideoInplace",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Vae.Input("vae"),
|
||||
io.Image.Input("image"),
|
||||
io.Latent.Input("latent"),
|
||||
io.Float.Input("strength", default=1.0, min=0.0, max=1.0),
|
||||
io.Boolean.Input("bypass", default=False, tooltip="Bypass the conditioning.")
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, vae, image, latent, strength, bypass=False) -> io.NodeOutput:
|
||||
if bypass:
|
||||
return (latent,)
|
||||
|
||||
samples = latent["samples"]
|
||||
_, height_scale_factor, width_scale_factor = (
|
||||
vae.downscale_index_formula
|
||||
)
|
||||
|
||||
batch, _, latent_frames, latent_height, latent_width = samples.shape
|
||||
width = latent_width * width_scale_factor
|
||||
height = latent_height * height_scale_factor
|
||||
|
||||
if image.shape[1] != height or image.shape[2] != width:
|
||||
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
else:
|
||||
pixels = image
|
||||
encode_pixels = pixels[:, :, :, :3]
|
||||
t = vae.encode(encode_pixels)
|
||||
|
||||
samples[:, :, :t.shape[2]] = t
|
||||
|
||||
conditioning_latent_frames_mask = torch.ones(
|
||||
(batch, 1, latent_frames, 1, 1),
|
||||
dtype=torch.float32,
|
||||
device=samples.device,
|
||||
)
|
||||
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength
|
||||
|
||||
return io.NodeOutput({"samples": samples, "noise_mask": conditioning_latent_frames_mask})
|
||||
|
||||
generate = execute # TODO: remove
|
||||
|
||||
|
||||
def conditioning_get_any_value(conditioning, key, default=None):
|
||||
for t in conditioning:
|
||||
if key in t[1]:
|
||||
@@ -106,12 +159,12 @@ def get_keyframe_idxs(cond):
|
||||
keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
|
||||
if keyframe_idxs is None:
|
||||
return None, 0
|
||||
num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0]
|
||||
# keyframe_idxs contains start/end positions (last dimension), checking for unqiue values only for start
|
||||
num_keyframes = torch.unique(keyframe_idxs[:, 0, :, 0]).shape[0]
|
||||
return keyframe_idxs, num_keyframes
|
||||
|
||||
class LTXVAddGuide(io.ComfyNode):
|
||||
NUM_PREFIX_FRAMES = 2
|
||||
PATCHIFIER = SymmetricPatchifier(1)
|
||||
PATCHIFIER = SymmetricPatchifier(1, start_end=True)
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@@ -182,26 +235,35 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
|
||||
|
||||
@classmethod
|
||||
def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
|
||||
_, latent_idx = cls.get_latent_index(
|
||||
cond=positive,
|
||||
latent_length=latent_image.shape[2],
|
||||
guide_length=guiding_latent.shape[2],
|
||||
frame_idx=frame_idx,
|
||||
scale_factors=scale_factors,
|
||||
)
|
||||
noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0
|
||||
def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128):
|
||||
if latent_image.shape[1] != in_channels or guiding_latent.shape[1] != in_channels:
|
||||
raise ValueError("Adding guide to a combined AV latent is not supported.")
|
||||
|
||||
positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
|
||||
negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
|
||||
|
||||
mask = torch.full(
|
||||
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
|
||||
1.0 - strength,
|
||||
dtype=noise_mask.dtype,
|
||||
device=noise_mask.device,
|
||||
)
|
||||
if guide_mask is not None:
|
||||
target_h = max(noise_mask.shape[3], guide_mask.shape[3])
|
||||
target_w = max(noise_mask.shape[4], guide_mask.shape[4])
|
||||
|
||||
if noise_mask.shape[3] == 1 or noise_mask.shape[4] == 1:
|
||||
noise_mask = noise_mask.expand(-1, -1, -1, target_h, target_w)
|
||||
|
||||
if guide_mask.shape[3] == 1 or guide_mask.shape[4] == 1:
|
||||
guide_mask = guide_mask.expand(-1, -1, -1, target_h, target_w)
|
||||
mask = guide_mask - strength
|
||||
else:
|
||||
mask = torch.full(
|
||||
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
|
||||
1.0 - strength,
|
||||
dtype=noise_mask.dtype,
|
||||
device=noise_mask.device,
|
||||
)
|
||||
# This solves audio video combined latent case where latent_image has audio latent concatenated
|
||||
# in channel dimension with video latent. The solution is to pad guiding latent accordingly.
|
||||
if latent_image.shape[1] > guiding_latent.shape[1]:
|
||||
pad_len = latent_image.shape[1] - guiding_latent.shape[1]
|
||||
guiding_latent = torch.nn.functional.pad(guiding_latent, pad=(0, 0, 0, 0, 0, 0, 0, pad_len), value=0)
|
||||
latent_image = torch.cat([latent_image, guiding_latent], dim=2)
|
||||
noise_mask = torch.cat([noise_mask, mask], dim=2)
|
||||
return positive, negative, latent_image, noise_mask
|
||||
@@ -238,33 +300,17 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
|
||||
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
|
||||
|
||||
num_prefix_frames = min(cls.NUM_PREFIX_FRAMES, t.shape[2])
|
||||
|
||||
positive, negative, latent_image, noise_mask = cls.append_keyframe(
|
||||
positive,
|
||||
negative,
|
||||
frame_idx,
|
||||
latent_image,
|
||||
noise_mask,
|
||||
t[:, :, :num_prefix_frames],
|
||||
t,
|
||||
strength,
|
||||
scale_factors,
|
||||
)
|
||||
|
||||
latent_idx += num_prefix_frames
|
||||
|
||||
t = t[:, :, num_prefix_frames:]
|
||||
if t.shape[2] == 0:
|
||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
||||
|
||||
latent_image, noise_mask = cls.replace_latent_frames(
|
||||
latent_image,
|
||||
noise_mask,
|
||||
t,
|
||||
latent_idx,
|
||||
strength,
|
||||
)
|
||||
|
||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
||||
|
||||
generate = execute # TODO: remove
|
||||
@@ -507,18 +553,90 @@ class LTXVPreprocess(io.ComfyNode):
|
||||
|
||||
preprocess = execute # TODO: remove
|
||||
|
||||
|
||||
import comfy.nested_tensor
|
||||
class LTXVConcatAVLatent(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LTXVConcatAVLatent",
|
||||
category="latent/video/ltxv",
|
||||
inputs=[
|
||||
io.Latent.Input("video_latent"),
|
||||
io.Latent.Input("audio_latent"),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, video_latent, audio_latent) -> io.NodeOutput:
|
||||
output = {}
|
||||
output.update(video_latent)
|
||||
output.update(audio_latent)
|
||||
video_noise_mask = video_latent.get("noise_mask", None)
|
||||
audio_noise_mask = audio_latent.get("noise_mask", None)
|
||||
|
||||
if video_noise_mask is not None or audio_noise_mask is not None:
|
||||
if video_noise_mask is None:
|
||||
video_noise_mask = torch.ones_like(video_latent["samples"])
|
||||
if audio_noise_mask is None:
|
||||
audio_noise_mask = torch.ones_like(audio_latent["samples"])
|
||||
output["noise_mask"] = comfy.nested_tensor.NestedTensor((video_noise_mask, audio_noise_mask))
|
||||
|
||||
output["samples"] = comfy.nested_tensor.NestedTensor((video_latent["samples"], audio_latent["samples"]))
|
||||
|
||||
return io.NodeOutput(output)
|
||||
|
||||
|
||||
class LTXVSeparateAVLatent(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LTXVSeparateAVLatent",
|
||||
category="latent/video/ltxv",
|
||||
description="LTXV Separate AV Latent",
|
||||
inputs=[
|
||||
io.Latent.Input("av_latent"),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(display_name="video_latent"),
|
||||
io.Latent.Output(display_name="audio_latent"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, av_latent) -> io.NodeOutput:
|
||||
latents = av_latent["samples"].unbind()
|
||||
video_latent = av_latent.copy()
|
||||
video_latent["samples"] = latents[0]
|
||||
audio_latent = av_latent.copy()
|
||||
audio_latent["samples"] = latents[1]
|
||||
if "noise_mask" in av_latent:
|
||||
masks = av_latent["noise_mask"]
|
||||
if masks is not None:
|
||||
masks = masks.unbind()
|
||||
video_latent["noise_mask"] = masks[0]
|
||||
audio_latent["noise_mask"] = masks[1]
|
||||
return io.NodeOutput(video_latent, audio_latent)
|
||||
|
||||
|
||||
class LtxvExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
EmptyLTXVLatentVideo,
|
||||
LTXVImgToVideo,
|
||||
LTXVImgToVideoInplace,
|
||||
ModelSamplingLTXV,
|
||||
LTXVConditioning,
|
||||
LTXVScheduler,
|
||||
LTXVAddGuide,
|
||||
LTXVPreprocess,
|
||||
LTXVCropGuides,
|
||||
LTXVConcatAVLatent,
|
||||
LTXVSeparateAVLatent,
|
||||
]
|
||||
|
||||
|
||||
|
||||
183
comfy_extras/nodes_lt_audio.py
Normal file
183
comfy_extras/nodes_lt_audio.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import folder_paths
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import torch
|
||||
|
||||
from comfy.ldm.lightricks.vae.audio_vae import AudioVAE
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class LTXVAudioVAELoader(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="LTXVAudioVAELoader",
|
||||
display_name="LTXV Audio VAE Loader",
|
||||
category="audio",
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
"ckpt_name",
|
||||
options=folder_paths.get_filename_list("checkpoints"),
|
||||
tooltip="Audio VAE checkpoint to load.",
|
||||
)
|
||||
],
|
||||
outputs=[io.Vae.Output(display_name="Audio VAE")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, ckpt_name: str) -> io.NodeOutput:
|
||||
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
|
||||
return io.NodeOutput(AudioVAE(sd, metadata))
|
||||
|
||||
|
||||
class LTXVAudioVAEEncode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="LTXVAudioVAEEncode",
|
||||
display_name="LTXV Audio VAE Encode",
|
||||
category="audio",
|
||||
inputs=[
|
||||
io.Audio.Input("audio", tooltip="The audio to be encoded."),
|
||||
io.Vae.Input(
|
||||
id="audio_vae",
|
||||
display_name="Audio VAE",
|
||||
tooltip="The Audio VAE model to use for encoding.",
|
||||
),
|
||||
],
|
||||
outputs=[io.Latent.Output(display_name="Audio Latent")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput:
|
||||
audio_latents = audio_vae.encode(audio)
|
||||
return io.NodeOutput(
|
||||
{
|
||||
"samples": audio_latents,
|
||||
"sample_rate": int(audio_vae.sample_rate),
|
||||
"type": "audio",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class LTXVAudioVAEDecode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="LTXVAudioVAEDecode",
|
||||
display_name="LTXV Audio VAE Decode",
|
||||
category="audio",
|
||||
inputs=[
|
||||
io.Latent.Input("samples", tooltip="The latent to be decoded."),
|
||||
io.Vae.Input(
|
||||
id="audio_vae",
|
||||
display_name="Audio VAE",
|
||||
tooltip="The Audio VAE model used for decoding the latent.",
|
||||
),
|
||||
],
|
||||
outputs=[io.Audio.Output(display_name="Audio")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput:
|
||||
audio_latent = samples["samples"]
|
||||
if audio_latent.is_nested:
|
||||
audio_latent = audio_latent.unbind()[-1]
|
||||
audio = audio_vae.decode(audio_latent).to(audio_latent.device)
|
||||
output_audio_sample_rate = audio_vae.output_sample_rate
|
||||
return io.NodeOutput(
|
||||
{
|
||||
"waveform": audio,
|
||||
"sample_rate": int(output_audio_sample_rate),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class LTXVEmptyLatentAudio(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="LTXVEmptyLatentAudio",
|
||||
display_name="LTXV Empty Latent Audio",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
io.Int.Input(
|
||||
"frames_number",
|
||||
default=97,
|
||||
min=1,
|
||||
max=1000,
|
||||
step=1,
|
||||
display_mode=io.NumberDisplay.number,
|
||||
tooltip="Number of frames.",
|
||||
),
|
||||
io.Int.Input(
|
||||
"frame_rate",
|
||||
default=25,
|
||||
min=1,
|
||||
max=1000,
|
||||
step=1,
|
||||
display_mode=io.NumberDisplay.number,
|
||||
tooltip="Number of frames per second.",
|
||||
),
|
||||
io.Int.Input(
|
||||
"batch_size",
|
||||
default=1,
|
||||
min=1,
|
||||
max=4096,
|
||||
display_mode=io.NumberDisplay.number,
|
||||
tooltip="The number of latent audio samples in the batch.",
|
||||
),
|
||||
io.Vae.Input(
|
||||
id="audio_vae",
|
||||
display_name="Audio VAE",
|
||||
tooltip="The Audio VAE model to get configuration from.",
|
||||
),
|
||||
],
|
||||
outputs=[io.Latent.Output(display_name="Latent")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls,
|
||||
frames_number: int,
|
||||
frame_rate: int,
|
||||
batch_size: int,
|
||||
audio_vae: AudioVAE,
|
||||
) -> io.NodeOutput:
|
||||
"""Generate empty audio latents matching the reference pipeline structure."""
|
||||
|
||||
assert audio_vae is not None, "Audio VAE model is required"
|
||||
|
||||
z_channels = audio_vae.latent_channels
|
||||
audio_freq = audio_vae.latent_frequency_bins
|
||||
sampling_rate = int(audio_vae.sample_rate)
|
||||
|
||||
num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate)
|
||||
|
||||
audio_latents = torch.zeros(
|
||||
(batch_size, z_channels, num_audio_latents, audio_freq),
|
||||
device=comfy.model_management.intermediate_device(),
|
||||
)
|
||||
|
||||
return io.NodeOutput(
|
||||
{
|
||||
"samples": audio_latents,
|
||||
"sample_rate": sampling_rate,
|
||||
"type": "audio",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class LTXVAudioExtension(ComfyExtension):
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
LTXVAudioVAELoader,
|
||||
LTXVAudioVAEEncode,
|
||||
LTXVAudioVAEDecode,
|
||||
LTXVEmptyLatentAudio,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ComfyExtension:
|
||||
return LTXVAudioExtension()
|
||||
75
comfy_extras/nodes_lt_upsampler.py
Normal file
75
comfy_extras/nodes_lt_upsampler.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from comfy import model_management
|
||||
import math
|
||||
|
||||
class LTXVLatentUpsampler:
|
||||
"""
|
||||
Upsamples a video latent by a factor of 2.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"samples": ("LATENT",),
|
||||
"upscale_model": ("LATENT_UPSCALE_MODEL",),
|
||||
"vae": ("VAE",),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "upsample_latent"
|
||||
CATEGORY = "latent/video"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def upsample_latent(
|
||||
self,
|
||||
samples: dict,
|
||||
upscale_model,
|
||||
vae,
|
||||
) -> tuple:
|
||||
"""
|
||||
Upsample the input latent using the provided model.
|
||||
|
||||
Args:
|
||||
samples (dict): Input latent samples
|
||||
upscale_model (LatentUpsampler): Loaded upscale model
|
||||
vae: VAE model for normalization
|
||||
auto_tiling (bool): Whether to automatically tile the input for processing
|
||||
|
||||
Returns:
|
||||
tuple: Tuple containing the upsampled latent
|
||||
"""
|
||||
device = model_management.get_torch_device()
|
||||
memory_required = model_management.module_size(upscale_model)
|
||||
|
||||
model_dtype = next(upscale_model.parameters()).dtype
|
||||
latents = samples["samples"]
|
||||
input_dtype = latents.dtype
|
||||
|
||||
memory_required += math.prod(latents.shape) * 3000.0 # TODO: more accurate
|
||||
model_management.free_memory(memory_required, device)
|
||||
|
||||
try:
|
||||
upscale_model.to(device) # TODO: use the comfy model management system.
|
||||
|
||||
latents = latents.to(dtype=model_dtype, device=device)
|
||||
|
||||
"""Upsample latents without tiling."""
|
||||
latents = vae.first_stage_model.per_channel_statistics.un_normalize(latents)
|
||||
upsampled_latents = upscale_model(latents)
|
||||
finally:
|
||||
upscale_model.cpu()
|
||||
|
||||
upsampled_latents = vae.first_stage_model.per_channel_statistics.normalize(
|
||||
upsampled_latents
|
||||
)
|
||||
upsampled_latents = upsampled_latents.to(dtype=input_dtype, device=model_management.intermediate_device())
|
||||
return_dict = samples.copy()
|
||||
return_dict["samples"] = upsampled_latents
|
||||
return_dict.pop("noise_mask", None)
|
||||
return (return_dict,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"LTXVLatentUpsampler": LTXVLatentUpsampler,
|
||||
}
|
||||
Reference in New Issue
Block a user