Support the LTXV 2 model. (#11632)
This commit is contained in:
837
comfy/ldm/lightricks/av_model.py
Normal file
837
comfy/ldm/lightricks/av_model.py
Normal file
@@ -0,0 +1,837 @@
|
||||
from typing import Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from comfy.ldm.lightricks.model import (
|
||||
CrossAttention,
|
||||
FeedForward,
|
||||
AdaLayerNormSingle,
|
||||
PixArtAlphaTextProjection,
|
||||
LTXVModel,
|
||||
)
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
class BasicAVTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
v_dim,
|
||||
a_dim,
|
||||
v_heads,
|
||||
a_heads,
|
||||
vd_head,
|
||||
ad_head,
|
||||
v_context_dim=None,
|
||||
a_context_dim=None,
|
||||
attn_precision=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attn_precision = attn_precision
|
||||
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=v_dim,
|
||||
heads=v_heads,
|
||||
dim_head=vd_head,
|
||||
context_dim=None,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.audio_attn1 = CrossAttention(
|
||||
query_dim=a_dim,
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
context_dim=None,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.attn2 = CrossAttention(
|
||||
query_dim=v_dim,
|
||||
context_dim=v_context_dim,
|
||||
heads=v_heads,
|
||||
dim_head=vd_head,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.audio_attn2 = CrossAttention(
|
||||
query_dim=a_dim,
|
||||
context_dim=a_context_dim,
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
# Q: Video, K,V: Audio
|
||||
self.audio_to_video_attn = CrossAttention(
|
||||
query_dim=v_dim,
|
||||
context_dim=a_dim,
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
# Q: Audio, K,V: Video
|
||||
self.video_to_audio_attn = CrossAttention(
|
||||
query_dim=a_dim,
|
||||
context_dim=v_dim,
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.ff = FeedForward(
|
||||
v_dim, dim_out=v_dim, glu=True, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.audio_ff = FeedForward(
|
||||
a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
|
||||
self.audio_scale_shift_table = nn.Parameter(
|
||||
torch.empty(6, a_dim, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
self.scale_shift_table_a2v_ca_audio = nn.Parameter(
|
||||
torch.empty(5, a_dim, device=device, dtype=dtype)
|
||||
)
|
||||
self.scale_shift_table_a2v_ca_video = nn.Parameter(
|
||||
torch.empty(5, v_dim, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
def get_ada_values(
|
||||
self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None)
|
||||
):
|
||||
num_ada_params = scale_shift_table.shape[0]
|
||||
|
||||
ada_values = (
|
||||
scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype)
|
||||
+ timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :]
|
||||
).unbind(dim=2)
|
||||
return ada_values
|
||||
|
||||
def get_av_ca_ada_values(
|
||||
self,
|
||||
scale_shift_table: torch.Tensor,
|
||||
batch_size: int,
|
||||
scale_shift_timestep: torch.Tensor,
|
||||
gate_timestep: torch.Tensor,
|
||||
num_scale_shift_values: int = 4,
|
||||
):
|
||||
scale_shift_ada_values = self.get_ada_values(
|
||||
scale_shift_table[:num_scale_shift_values, :],
|
||||
batch_size,
|
||||
scale_shift_timestep,
|
||||
)
|
||||
gate_ada_values = self.get_ada_values(
|
||||
scale_shift_table[num_scale_shift_values:, :],
|
||||
batch_size,
|
||||
gate_timestep,
|
||||
)
|
||||
|
||||
scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values]
|
||||
gate_ada_values = [t.squeeze(2) for t in gate_ada_values]
|
||||
|
||||
return (*scale_shift_chunks, *gate_ada_values)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tuple[torch.Tensor, torch.Tensor],
|
||||
v_context=None,
|
||||
a_context=None,
|
||||
attention_mask=None,
|
||||
v_timestep=None,
|
||||
a_timestep=None,
|
||||
v_pe=None,
|
||||
a_pe=None,
|
||||
v_cross_pe=None,
|
||||
a_cross_pe=None,
|
||||
v_cross_scale_shift_timestep=None,
|
||||
a_cross_scale_shift_timestep=None,
|
||||
v_cross_gate_timestep=None,
|
||||
a_cross_gate_timestep=None,
|
||||
transformer_options=None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
run_vx = transformer_options.get("run_vx", True)
|
||||
run_ax = transformer_options.get("run_ax", True)
|
||||
|
||||
vx, ax = x
|
||||
run_ax = run_ax and ax.numel() > 0
|
||||
run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
|
||||
run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)
|
||||
|
||||
if run_vx:
|
||||
vshift_msa, vscale_msa, vgate_msa = (
|
||||
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3))
|
||||
)
|
||||
|
||||
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
|
||||
vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa
|
||||
vx += self.attn2(
|
||||
comfy.ldm.common_dit.rms_norm(vx),
|
||||
context=v_context,
|
||||
mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
del vshift_msa, vscale_msa, vgate_msa
|
||||
|
||||
if run_ax:
|
||||
ashift_msa, ascale_msa, agate_msa = (
|
||||
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3))
|
||||
)
|
||||
|
||||
norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
|
||||
ax += (
|
||||
self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
|
||||
* agate_msa
|
||||
)
|
||||
ax += self.audio_attn2(
|
||||
comfy.ldm.common_dit.rms_norm(ax),
|
||||
context=a_context,
|
||||
mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
del ashift_msa, ascale_msa, agate_msa
|
||||
|
||||
# Audio - Video cross attention.
|
||||
if run_a2v or run_v2a:
|
||||
# norm3
|
||||
vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
|
||||
ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)
|
||||
|
||||
(
|
||||
scale_ca_audio_hidden_states_a2v,
|
||||
shift_ca_audio_hidden_states_a2v,
|
||||
scale_ca_audio_hidden_states_v2a,
|
||||
shift_ca_audio_hidden_states_v2a,
|
||||
gate_out_v2a,
|
||||
) = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio,
|
||||
ax.shape[0],
|
||||
a_cross_scale_shift_timestep,
|
||||
a_cross_gate_timestep,
|
||||
)
|
||||
|
||||
(
|
||||
scale_ca_video_hidden_states_a2v,
|
||||
shift_ca_video_hidden_states_a2v,
|
||||
scale_ca_video_hidden_states_v2a,
|
||||
shift_ca_video_hidden_states_v2a,
|
||||
gate_out_a2v,
|
||||
) = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video,
|
||||
vx.shape[0],
|
||||
v_cross_scale_shift_timestep,
|
||||
v_cross_gate_timestep,
|
||||
)
|
||||
|
||||
if run_a2v:
|
||||
vx_scaled = (
|
||||
vx_norm3 * (1 + scale_ca_video_hidden_states_a2v)
|
||||
+ shift_ca_video_hidden_states_a2v
|
||||
)
|
||||
ax_scaled = (
|
||||
ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v)
|
||||
+ shift_ca_audio_hidden_states_a2v
|
||||
)
|
||||
vx += (
|
||||
self.audio_to_video_attn(
|
||||
vx_scaled,
|
||||
context=ax_scaled,
|
||||
pe=v_cross_pe,
|
||||
k_pe=a_cross_pe,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
* gate_out_a2v
|
||||
)
|
||||
|
||||
del gate_out_a2v
|
||||
del scale_ca_video_hidden_states_a2v,\
|
||||
shift_ca_video_hidden_states_a2v,\
|
||||
scale_ca_audio_hidden_states_a2v,\
|
||||
shift_ca_audio_hidden_states_a2v,\
|
||||
|
||||
if run_v2a:
|
||||
ax_scaled = (
|
||||
ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a)
|
||||
+ shift_ca_audio_hidden_states_v2a
|
||||
)
|
||||
vx_scaled = (
|
||||
vx_norm3 * (1 + scale_ca_video_hidden_states_v2a)
|
||||
+ shift_ca_video_hidden_states_v2a
|
||||
)
|
||||
ax += (
|
||||
self.video_to_audio_attn(
|
||||
ax_scaled,
|
||||
context=vx_scaled,
|
||||
pe=a_cross_pe,
|
||||
k_pe=v_cross_pe,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
* gate_out_v2a
|
||||
)
|
||||
|
||||
del gate_out_v2a
|
||||
del scale_ca_video_hidden_states_v2a,\
|
||||
shift_ca_video_hidden_states_v2a,\
|
||||
scale_ca_audio_hidden_states_v2a,\
|
||||
shift_ca_audio_hidden_states_v2a
|
||||
|
||||
if run_vx:
|
||||
vshift_mlp, vscale_mlp, vgate_mlp = (
|
||||
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None))
|
||||
)
|
||||
|
||||
vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
|
||||
vx += self.ff(vx_scaled) * vgate_mlp
|
||||
del vshift_mlp, vscale_mlp, vgate_mlp
|
||||
|
||||
if run_ax:
|
||||
ashift_mlp, ascale_mlp, agate_mlp = (
|
||||
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None))
|
||||
)
|
||||
|
||||
ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
|
||||
ax += self.audio_ff(ax_scaled) * agate_mlp
|
||||
|
||||
del ashift_mlp, ascale_mlp, agate_mlp
|
||||
|
||||
|
||||
return vx, ax
|
||||
|
||||
|
||||
class LTXAVModel(LTXVModel):
|
||||
"""LTXAV model for audio-video generation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=128,
|
||||
audio_in_channels=128,
|
||||
cross_attention_dim=4096,
|
||||
audio_cross_attention_dim=2048,
|
||||
attention_head_dim=128,
|
||||
audio_attention_head_dim=64,
|
||||
num_attention_heads=32,
|
||||
audio_num_attention_heads=32,
|
||||
caption_channels=3840,
|
||||
num_layers=48,
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
audio_positional_embedding_max_pos=[20],
|
||||
causal_temporal_positioning=False,
|
||||
vae_scale_factors=(8, 32, 32),
|
||||
use_middle_indices_grid=False,
|
||||
timestep_scale_multiplier=1000.0,
|
||||
av_ca_timestep_scale_multiplier=1.0,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Store audio-specific parameters
|
||||
self.audio_in_channels = audio_in_channels
|
||||
self.audio_cross_attention_dim = audio_cross_attention_dim
|
||||
self.audio_attention_head_dim = audio_attention_head_dim
|
||||
self.audio_num_attention_heads = audio_num_attention_heads
|
||||
self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
|
||||
|
||||
# Calculate audio dimensions
|
||||
self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
|
||||
self.audio_out_channels = audio_in_channels
|
||||
|
||||
# Audio-specific constants
|
||||
self.num_audio_channels = 8
|
||||
self.audio_frequency_bins = 16
|
||||
|
||||
self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier
|
||||
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
caption_channels=caption_channels,
|
||||
num_layers=num_layers,
|
||||
positional_embedding_theta=positional_embedding_theta,
|
||||
positional_embedding_max_pos=positional_embedding_max_pos,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
vae_scale_factors=vae_scale_factors,
|
||||
use_middle_indices_grid=use_middle_indices_grid,
|
||||
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _init_model_components(self, device, dtype, **kwargs):
|
||||
"""Initialize LTXAV-specific components."""
|
||||
# Audio-specific projections
|
||||
self.audio_patchify_proj = self.operations.Linear(
|
||||
self.audio_in_channels, self.audio_inner_dim, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
# Audio-specific AdaLN
|
||||
self.audio_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
use_additional_conditions=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
num_scale_shift_values = 4
|
||||
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim,
|
||||
use_additional_conditions=False,
|
||||
embedding_coefficient=num_scale_shift_values,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim,
|
||||
use_additional_conditions=False,
|
||||
embedding_coefficient=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
use_additional_conditions=False,
|
||||
embedding_coefficient=num_scale_shift_values,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
use_additional_conditions=False,
|
||||
embedding_coefficient=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
# Audio caption projection
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.audio_inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||
"""Initialize transformer blocks for LTXAV."""
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicAVTransformerBlock(
|
||||
v_dim=self.inner_dim,
|
||||
a_dim=self.audio_inner_dim,
|
||||
v_heads=self.num_attention_heads,
|
||||
a_heads=self.audio_num_attention_heads,
|
||||
vd_head=self.attention_head_dim,
|
||||
ad_head=self.audio_attention_head_dim,
|
||||
v_context_dim=self.cross_attention_dim,
|
||||
a_context_dim=self.audio_cross_attention_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
for _ in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def _init_output_components(self, device, dtype):
|
||||
"""Initialize output components for LTXAV."""
|
||||
# Video output components
|
||||
super()._init_output_components(device, dtype)
|
||||
# Audio output components
|
||||
self.audio_scale_shift_table = nn.Parameter(
|
||||
torch.empty(2, self.audio_inner_dim, dtype=dtype, device=device)
|
||||
)
|
||||
self.audio_norm_out = self.operations.LayerNorm(
|
||||
self.audio_inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
|
||||
)
|
||||
self.audio_proj_out = self.operations.Linear(
|
||||
self.audio_inner_dim, self.audio_out_channels, dtype=dtype, device=device
|
||||
)
|
||||
self.a_patchifier = AudioPatchifier(1, start_end=True)
|
||||
|
||||
def separate_audio_and_video_latents(self, x, audio_length):
|
||||
"""Separate audio and video latents from combined input."""
|
||||
# vx = x[:, : self.in_channels]
|
||||
# ax = x[:, self.in_channels :]
|
||||
#
|
||||
# ax = ax.reshape(ax.shape[0], -1)
|
||||
# ax = ax[:, : audio_length * self.num_audio_channels * self.audio_frequency_bins]
|
||||
#
|
||||
# ax = ax.reshape(
|
||||
# ax.shape[0], self.num_audio_channels, audio_length, self.audio_frequency_bins
|
||||
# )
|
||||
|
||||
vx = x[0]
|
||||
ax = x[1] if len(x) > 1 else torch.zeros(
|
||||
(vx.shape[0], self.num_audio_channels, 0, self.audio_frequency_bins),
|
||||
device=vx.device, dtype=vx.dtype
|
||||
)
|
||||
return vx, ax
|
||||
|
||||
def recombine_audio_and_video_latents(self, vx, ax, target_shape=None):
|
||||
if ax.numel() == 0:
|
||||
return vx
|
||||
else:
|
||||
return [vx, ax]
|
||||
"""Recombine audio and video latents for output."""
|
||||
# if ax.device != vx.device or ax.dtype != vx.dtype:
|
||||
# logging.warning("Audio and video latents are on different devices or dtypes.")
|
||||
# ax = ax.to(device=vx.device, dtype=vx.dtype)
|
||||
# logging.warning(f"Audio audio latent moved to device: {ax.device}, dtype: {ax.dtype}")
|
||||
#
|
||||
# ax = ax.reshape(ax.shape[0], -1)
|
||||
# # pad to f x h x w of the video latents
|
||||
# divisor = vx.shape[-1] * vx.shape[-2] * vx.shape[-3]
|
||||
# if target_shape is None:
|
||||
# repetitions = math.ceil(ax.shape[-1] / divisor)
|
||||
# else:
|
||||
# repetitions = target_shape[1] - vx.shape[1]
|
||||
# padded_len = repetitions * divisor
|
||||
# ax = F.pad(ax, (0, padded_len - ax.shape[-1]))
|
||||
# ax = ax.reshape(ax.shape[0], -1, vx.shape[-3], vx.shape[-2], vx.shape[-1])
|
||||
# return torch.cat([vx, ax], dim=1)
|
||||
|
||||
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
|
||||
"""Process input for LTXAV - separate audio and video, then patchify."""
|
||||
audio_length = kwargs.get("audio_length", 0)
|
||||
# Separate audio and video latents
|
||||
vx, ax = self.separate_audio_and_video_latents(x, audio_length)
|
||||
[vx, v_pixel_coords, additional_args] = super()._process_input(
|
||||
vx, keyframe_idxs, denoise_mask, **kwargs
|
||||
)
|
||||
|
||||
ax, a_latent_coords = self.a_patchifier.patchify(ax)
|
||||
ax = self.audio_patchify_proj(ax)
|
||||
|
||||
# additional_args.update({"av_orig_shape": list(x.shape)})
|
||||
return [vx, ax], [v_pixel_coords, a_latent_coords], additional_args
|
||||
|
||||
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
|
||||
"""Prepare timestep embeddings."""
|
||||
# TODO: some code reuse is needed here.
|
||||
grid_mask = kwargs.get("grid_mask", None)
|
||||
if grid_mask is not None:
|
||||
timestep = timestep[:, grid_mask]
|
||||
|
||||
timestep = timestep * self.timestep_scale_multiplier
|
||||
v_timestep, v_embedded_timestep = self.adaln_single(
|
||||
timestep.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
|
||||
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
||||
v_timestep = v_timestep.view(batch_size, -1, v_timestep.shape[-1])
|
||||
v_embedded_timestep = v_embedded_timestep.view(
|
||||
batch_size, -1, v_embedded_timestep.shape[-1]
|
||||
)
|
||||
|
||||
# Prepare audio timestep
|
||||
a_timestep = kwargs.get("a_timestep")
|
||||
if a_timestep is not None:
|
||||
a_timestep = a_timestep * self.timestep_scale_multiplier
|
||||
av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier
|
||||
|
||||
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
|
||||
a_timestep.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
|
||||
timestep.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
|
||||
timestep.flatten() * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
|
||||
a_timestep.flatten() * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
|
||||
a_timestep, a_embedded_timestep = self.audio_adaln_single(
|
||||
a_timestep.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
|
||||
a_embedded_timestep = a_embedded_timestep.view(
|
||||
batch_size, -1, a_embedded_timestep.shape[-1]
|
||||
)
|
||||
cross_av_timestep_ss = [
|
||||
av_ca_audio_scale_shift_timestep,
|
||||
av_ca_video_scale_shift_timestep,
|
||||
av_ca_a2v_gate_noise_timestep,
|
||||
av_ca_v2a_gate_noise_timestep,
|
||||
]
|
||||
cross_av_timestep_ss = list(
|
||||
[t.view(batch_size, -1, t.shape[-1]) for t in cross_av_timestep_ss]
|
||||
)
|
||||
else:
|
||||
a_timestep = timestep
|
||||
a_embedded_timestep = kwargs.get("embedded_timestep")
|
||||
cross_av_timestep_ss = []
|
||||
|
||||
return [v_timestep, a_timestep, cross_av_timestep_ss], [
|
||||
v_embedded_timestep,
|
||||
a_embedded_timestep,
|
||||
]
|
||||
|
||||
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
||||
vx = x[0]
|
||||
ax = x[1]
|
||||
v_context, a_context = torch.split(
|
||||
context, int(context.shape[-1] / 2), len(context.shape) - 1
|
||||
)
|
||||
|
||||
v_context, attention_mask = super()._prepare_context(
|
||||
v_context, batch_size, vx, attention_mask
|
||||
)
|
||||
if self.audio_caption_projection is not None:
|
||||
a_context = self.audio_caption_projection(a_context)
|
||||
a_context = a_context.view(batch_size, -1, ax.shape[-1])
|
||||
|
||||
return [v_context, a_context], attention_mask
|
||||
|
||||
def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
|
||||
v_pixel_coords = pixel_coords[0]
|
||||
v_pe = super()._prepare_positional_embeddings(v_pixel_coords, frame_rate, x_dtype)
|
||||
|
||||
a_latent_coords = pixel_coords[1]
|
||||
a_pe = self._precompute_freqs_cis(
|
||||
a_latent_coords,
|
||||
dim=self.audio_inner_dim,
|
||||
out_dtype=x_dtype,
|
||||
max_pos=self.audio_positional_embedding_max_pos,
|
||||
use_middle_indices_grid=self.use_middle_indices_grid,
|
||||
num_attention_heads=self.audio_num_attention_heads,
|
||||
)
|
||||
|
||||
# calculate positional embeddings for the middle of the token duration, to use in av cross attention layers.
|
||||
max_pos = max(
|
||||
self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]
|
||||
)
|
||||
v_pixel_coords = v_pixel_coords.to(torch.float32)
|
||||
v_pixel_coords[:, 0] = v_pixel_coords[:, 0] * (1.0 / frame_rate)
|
||||
av_cross_video_freq_cis = self._precompute_freqs_cis(
|
||||
v_pixel_coords[:, 0:1, :],
|
||||
dim=self.audio_cross_attention_dim,
|
||||
out_dtype=x_dtype,
|
||||
max_pos=[max_pos],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=self.audio_num_attention_heads,
|
||||
)
|
||||
av_cross_audio_freq_cis = self._precompute_freqs_cis(
|
||||
a_latent_coords[:, 0:1, :],
|
||||
dim=self.audio_cross_attention_dim,
|
||||
out_dtype=x_dtype,
|
||||
max_pos=[max_pos],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=self.audio_num_attention_heads,
|
||||
)
|
||||
|
||||
return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
|
||||
|
||||
def _process_transformer_blocks(
|
||||
self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
|
||||
):
|
||||
vx = x[0]
|
||||
ax = x[1]
|
||||
v_context = context[0]
|
||||
a_context = context[1]
|
||||
v_timestep = timestep[0]
|
||||
a_timestep = timestep[1]
|
||||
v_pe, av_cross_video_freq_cis = pe[0]
|
||||
a_pe, av_cross_audio_freq_cis = pe[1]
|
||||
|
||||
(
|
||||
av_ca_audio_scale_shift_timestep,
|
||||
av_ca_video_scale_shift_timestep,
|
||||
av_ca_a2v_gate_noise_timestep,
|
||||
av_ca_v2a_gate_noise_timestep,
|
||||
) = timestep[2]
|
||||
|
||||
"""Process transformer blocks for LTXAV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
# Process transformer blocks
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(
|
||||
args["img"],
|
||||
v_context=args["v_context"],
|
||||
a_context=args["a_context"],
|
||||
attention_mask=args["attention_mask"],
|
||||
v_timestep=args["v_timestep"],
|
||||
a_timestep=args["a_timestep"],
|
||||
v_pe=args["v_pe"],
|
||||
a_pe=args["a_pe"],
|
||||
v_cross_pe=args["v_cross_pe"],
|
||||
a_cross_pe=args["a_cross_pe"],
|
||||
v_cross_scale_shift_timestep=args["v_cross_scale_shift_timestep"],
|
||||
a_cross_scale_shift_timestep=args["a_cross_scale_shift_timestep"],
|
||||
v_cross_gate_timestep=args["v_cross_gate_timestep"],
|
||||
a_cross_gate_timestep=args["a_cross_gate_timestep"],
|
||||
transformer_options=args["transformer_options"],
|
||||
)
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)](
|
||||
{
|
||||
"img": (vx, ax),
|
||||
"v_context": v_context,
|
||||
"a_context": a_context,
|
||||
"attention_mask": attention_mask,
|
||||
"v_timestep": v_timestep,
|
||||
"a_timestep": a_timestep,
|
||||
"v_pe": v_pe,
|
||||
"a_pe": a_pe,
|
||||
"v_cross_pe": av_cross_video_freq_cis,
|
||||
"a_cross_pe": av_cross_audio_freq_cis,
|
||||
"v_cross_scale_shift_timestep": av_ca_video_scale_shift_timestep,
|
||||
"a_cross_scale_shift_timestep": av_ca_audio_scale_shift_timestep,
|
||||
"v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
|
||||
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
|
||||
"transformer_options": transformer_options,
|
||||
},
|
||||
{"original_block": block_wrap},
|
||||
)
|
||||
vx, ax = out["img"]
|
||||
else:
|
||||
vx, ax = block(
|
||||
(vx, ax),
|
||||
v_context=v_context,
|
||||
a_context=a_context,
|
||||
attention_mask=attention_mask,
|
||||
v_timestep=v_timestep,
|
||||
a_timestep=a_timestep,
|
||||
v_pe=v_pe,
|
||||
a_pe=a_pe,
|
||||
v_cross_pe=av_cross_video_freq_cis,
|
||||
a_cross_pe=av_cross_audio_freq_cis,
|
||||
v_cross_scale_shift_timestep=av_ca_video_scale_shift_timestep,
|
||||
a_cross_scale_shift_timestep=av_ca_audio_scale_shift_timestep,
|
||||
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
|
||||
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
return [vx, ax]
|
||||
|
||||
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||
vx = x[0]
|
||||
ax = x[1]
|
||||
v_embedded_timestep = embedded_timestep[0]
|
||||
a_embedded_timestep = embedded_timestep[1]
|
||||
vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs)
|
||||
|
||||
# Process audio output
|
||||
a_scale_shift_values = (
|
||||
self.audio_scale_shift_table[None, None].to(device=a_embedded_timestep.device, dtype=a_embedded_timestep.dtype)
|
||||
+ a_embedded_timestep[:, :, None]
|
||||
)
|
||||
a_shift, a_scale = a_scale_shift_values[:, :, 0], a_scale_shift_values[:, :, 1]
|
||||
|
||||
ax = self.audio_norm_out(ax)
|
||||
ax = ax * (1 + a_scale) + a_shift
|
||||
ax = self.audio_proj_out(ax)
|
||||
|
||||
# Unpatchify audio
|
||||
ax = self.a_patchifier.unpatchify(
|
||||
ax, channels=self.num_audio_channels, freq=self.audio_frequency_bins
|
||||
)
|
||||
|
||||
# Recombine audio and video
|
||||
original_shape = kwargs.get("av_orig_shape")
|
||||
return self.recombine_audio_and_video_latents(vx, ax, original_shape)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
timestep,
|
||||
context,
|
||||
attention_mask=None,
|
||||
frame_rate=25,
|
||||
transformer_options={},
|
||||
keyframe_idxs=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Forward pass for LTXAV model.
|
||||
|
||||
Args:
|
||||
x: Combined audio-video input tensor
|
||||
timestep: Tuple of (video_timestep, audio_timestep) or single timestep
|
||||
context: Context tensor (e.g., text embeddings)
|
||||
attention_mask: Attention mask tensor
|
||||
frame_rate: Frame rate for temporal processing
|
||||
transformer_options: Additional options for transformer blocks
|
||||
keyframe_idxs: Keyframe indices for temporal processing
|
||||
**kwargs: Additional keyword arguments including audio_length
|
||||
|
||||
Returns:
|
||||
Combined audio-video output tensor
|
||||
"""
|
||||
# Handle timestep format
|
||||
if isinstance(timestep, (tuple, list)) and len(timestep) == 2:
|
||||
v_timestep, a_timestep = timestep
|
||||
kwargs["a_timestep"] = a_timestep
|
||||
timestep = v_timestep
|
||||
else:
|
||||
kwargs["a_timestep"] = timestep
|
||||
|
||||
# Call parent forward method
|
||||
return super().forward(
|
||||
x,
|
||||
timestep,
|
||||
context,
|
||||
attention_mask,
|
||||
frame_rate,
|
||||
transformer_options,
|
||||
keyframe_idxs,
|
||||
**kwargs,
|
||||
)
|
||||
Reference in New Issue
Block a user