Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 78 additions & 22 deletions src/ltx_vae.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,16 +141,25 @@ namespace LTXVAE {
std::vector<ggml_tensor*>& feat_map,
int& feat_idx,
int chunk_idx,
bool causal = true) {
bool causal = true,
int temporal_pad = 0) {
auto conv = std::dynamic_pointer_cast<Conv3d>(blocks["conv"]);
const int pad = causal ? (time_kernel_size - 1) : (time_kernel_size - 1) / 2;

ggml_tensor* prev = (feat_idx < (int)feat_map.size()) ? feat_map[feat_idx] : nullptr;

GGML_ASSERT(x->ne[2] >= temporal_pad);

int end_idx = x->ne[2] - temporal_pad;
int start_idx = std::max(end_idx - pad, 0);

// Save a contiguous copy of the last `pad` frames so the large `x`
// tensor is not kept alive across iterations by a dangling view.
if (feat_idx < (int)feat_map.size() && pad > 0 && x->ne[2] >= pad) {
auto slice = ggml_ext_slice(ctx->ggml_ctx, x, 2, x->ne[2] - pad, x->ne[2]);
if (feat_idx < (int)feat_map.size() && end_idx - start_idx > 0) {
GGML_ASSERT(start_idx >= 0);
GGML_ASSERT(end_idx > 0);

auto slice = ggml_ext_slice(ctx->ggml_ctx, x, 2, start_idx, end_idx);
feat_map[feat_idx] = ggml_cont(ctx->ggml_ctx, slice);
}
feat_idx++;
Expand Down Expand Up @@ -282,7 +291,8 @@ namespace LTXVAE {
bool causal,
std::vector<ggml_tensor*>& feat_map,
int& feat_idx,
int chunk_idx) {
int chunk_idx,
int temporal_pad = 0) {
auto norm1 = std::dynamic_pointer_cast<PixelNorm3D>(blocks["norm1"]);
auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
auto norm2 = std::dynamic_pointer_cast<PixelNorm3D>(blocks["norm2"]);
Expand All @@ -309,14 +319,14 @@ namespace LTXVAE {
h = apply_scale_shift(ctx->ggml_ctx, h, scale1, shift1);
}
h = ggml_silu_inplace(ctx->ggml_ctx, h);
h = conv1->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal);
h = conv1->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal, temporal_pad);

h = norm2->forward(ctx, h);
if (timestep_conditioning) {
h = apply_scale_shift(ctx->ggml_ctx, h, scale2, shift2);
}
h = ggml_silu_inplace(ctx->ggml_ctx, h);
h = conv2->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal);
h = conv2->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal, temporal_pad);

return ggml_add(ctx->ggml_ctx, h, x);
}
Expand Down Expand Up @@ -365,7 +375,8 @@ namespace LTXVAE {
bool causal,
std::vector<ggml_tensor*>& feat_map,
int& feat_idx,
int chunk_idx) {
int chunk_idx,
int temporal_pad = 0) {
ggml_tensor* timestep_embed = nullptr;
if (timestep_conditioning) {
GGML_ASSERT(timestep != nullptr);
Expand All @@ -374,7 +385,7 @@ namespace LTXVAE {
}
for (int i = 0; i < num_layers; i++) {
auto resnet = std::dynamic_pointer_cast<ResnetBlock3D>(blocks["res_blocks." + std::to_string(i)]);
x = resnet->forward(ctx, x, timestep_embed, causal, feat_map, feat_idx, chunk_idx);
x = resnet->forward(ctx, x, timestep_embed, causal, feat_map, feat_idx, chunk_idx, temporal_pad);
}
return x;
}
Expand Down Expand Up @@ -435,7 +446,8 @@ namespace LTXVAE {
bool causal,
std::vector<ggml_tensor*>& feat_map,
int& feat_idx,
int chunk_idx) {
int chunk_idx,
int temporal_pad = 0) {
auto conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv"]);

bool drop_first = (chunk_idx == 0) && (factor_t > 1);
Expand All @@ -451,7 +463,7 @@ namespace LTXVAE {
x_in = res;
}

x = conv->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal);
x = conv->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal, temporal_pad);
x = depth_to_space_3d(ctx->ggml_ctx, x, get_output_channels(), factor_t, factor_s, drop_first);
if (residual) {
x = ggml_add(ctx->ggml_ctx, x, x_in);
Expand Down Expand Up @@ -889,7 +901,8 @@ namespace LTXVAE {
ggml_tensor* timestep,
std::vector<ggml_tensor*>& feat_map,
int& feat_idx,
int chunk_idx) {
int chunk_idx,
int& temporal_pad) {
auto conv_in = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv_in"]);
auto conv_norm_out = std::dynamic_pointer_cast<PixelNorm3D>(blocks["conv_norm_out"]);
auto conv_out = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv_out"]);
Expand All @@ -901,20 +914,21 @@ namespace LTXVAE {
}

// conv_in with feat_map for left temporal context
x = conv_in->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder);
x = conv_in->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder, temporal_pad);

// up_blocks
int block_idx = 0;
while (blocks.find("up_blocks." + std::to_string(block_idx)) != blocks.end()) {
auto mid_block = std::dynamic_pointer_cast<UNetMidBlock3D>(blocks["up_blocks." + std::to_string(block_idx)]);
if (mid_block) {
x = mid_block->forward(ctx, x, scaled_timestep, causal_decoder,
feat_map, feat_idx, chunk_idx);
feat_map, feat_idx, chunk_idx, temporal_pad);
} else {
auto upsample = std::dynamic_pointer_cast<DepthToSpaceUpsample>(
blocks["up_blocks." + std::to_string(block_idx)]);
x = upsample->forward(ctx, x, causal_decoder,
feat_map, feat_idx, chunk_idx);
feat_map, feat_idx, chunk_idx, temporal_pad);
temporal_pad *= upsample->factor_t;
}
block_idx++;
}
Expand All @@ -931,7 +945,7 @@ namespace LTXVAE {
x = apply_scale_shift(ctx->ggml_ctx, x, scale, shift);
}
x = ggml_silu_inplace(ctx->ggml_ctx, x);
x = conv_out->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder);
x = conv_out->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder, temporal_pad);
return x;
}
};
Expand Down Expand Up @@ -984,7 +998,9 @@ namespace LTXVAE {
// tensors can be freed by GGML before the next iteration starts.
ggml_tensor* decode_tiled(GGMLRunnerContext* ctx,
ggml_tensor* z,
ggml_tensor* timestep) {
ggml_tensor* timestep,
int temporal_window_size = 1,
int temporal_pad = 0) {
auto decoder = std::dynamic_pointer_cast<Decoder>(blocks["decoder"]);
auto processor = std::dynamic_pointer_cast<PerChannelStatistics>(blocks["per_channel_statistics"]);
auto latents = processor->un_normalize(ctx, z);
Expand All @@ -999,13 +1015,38 @@ namespace LTXVAE {
// 128 slots is generous enough for any supported decoder configuration.
std::vector<ggml_tensor*> feat_map(128, nullptr);

// Ensure window size is at least 1
int window = std::max(1, temporal_window_size);

if (temporal_pad >= window) {
LOG_WARN("temporal_pad (%d) is greater than or equal to temporal_window_size (%d), adjusting values to avoid empty decode windows",
temporal_pad, window);
temporal_pad = window - 1;
}
LOG_DEBUG("Using temporal tiling: tile size = %d frames, padding/overlap = %d frames, total frames = %d, resulting in %d tiles", window, temporal_pad, (int)T, (T + window - temporal_pad - 1) / (window - temporal_pad));
ggml_tensor* out = nullptr;
for (int i = 0; i < (int)T; i++) {
for (int i = 0; i < (int)T - temporal_pad; i += (window - temporal_pad)) {
int feat_idx = 0;
auto z_i = ggml_ext_slice(ctx->ggml_ctx, latents, 2, i, i + 1);
auto out_i = decoder->forward_tiled_frame(ctx, z_i, timestep,
feat_map, feat_idx, i);
out = (out == nullptr) ? out_i : ggml_concat(ctx->ggml_ctx, out, out_i, 2);

// Calculate the end index for the current temporal chunk
int end_i = std::min((int)T, i + window);
if (end_i >= (int)T) {
temporal_pad = 0; // to avoid any padding related issue (e.g. padding more than the number of frames in the chunk) in the last chunk
}

int chunck_pad = temporal_pad; // modified by forward_tiled_frame temporal inflation

auto z_chunk = ggml_ext_slice(ctx->ggml_ctx, latents, 2, i, end_i);

auto out_chunk = decoder->forward_tiled_frame(ctx, z_chunk, timestep,
feat_map, feat_idx, i, chunck_pad);

// discard last frames if it's not the final chunk and temporal_pad > 0
if (temporal_pad > 0 && end_i < (int)T) {
out_chunk = ggml_ext_slice(ctx->ggml_ctx, out_chunk, 2, 0, out_chunk->ne[2] - chunck_pad);
}

out = (out == nullptr) ? out_chunk : ggml_concat(ctx->ggml_ctx, out, out_chunk, 2);
}

return WAN::WanVAE::unpatchify(ctx->ggml_ctx, out, patch_size, 1);
Expand Down Expand Up @@ -1083,7 +1124,22 @@ struct LTXVideoVAE : public VAE {
bool use_tiled = decode_graph && temporal_tiling_enabled &&
z_tensor.dim() == 5 && z_tensor.shape()[2] > 1;
if (use_tiled) {
out = vae.decode_tiled(&runner_ctx, z, timestep);
// TODO: pass as args
int tile_frames = 1;
int tile_pad = 0;
// use env variables for now
const char* env_tile_frames = std::getenv("VAE_TILE_FRAMES");
const char* env_tile_pad = std::getenv("VAE_TILE_PAD");
if (env_tile_frames) {
tile_frames = std::max(1, std::atoi(env_tile_frames));
LOG_DEBUG("Using temporal tiling with tile_frames=%d", tile_frames);
}
if (env_tile_pad) {
tile_pad = std::max(0, std::atoi(env_tile_pad));
LOG_DEBUG("Using temporal tiling with tile_pad=%d", tile_pad);
}

out = vae.decode_tiled(&runner_ctx, z, timestep, tile_frames, tile_pad);
} else {
out = decode_graph ? vae.decode(&runner_ctx, z, timestep) : vae.encode(&runner_ctx, z);
}
Expand Down
Loading