From 998af1c98aa4f3914b8538d62a79877ea8b13870 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Tue, 12 May 2026 23:59:09 +0200 Subject: [PATCH] Temporal tile size + overlap --- src/ltx_vae.hpp | 100 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 78 insertions(+), 22 deletions(-) diff --git a/src/ltx_vae.hpp b/src/ltx_vae.hpp index 8bcc1ca83..10d306ac9 100644 --- a/src/ltx_vae.hpp +++ b/src/ltx_vae.hpp @@ -141,16 +141,25 @@ namespace LTXVAE { std::vector& feat_map, int& feat_idx, int chunk_idx, - bool causal = true) { + bool causal = true, + int temporal_pad = 0) { auto conv = std::dynamic_pointer_cast(blocks["conv"]); const int pad = causal ? (time_kernel_size - 1) : (time_kernel_size - 1) / 2; ggml_tensor* prev = (feat_idx < (int)feat_map.size()) ? feat_map[feat_idx] : nullptr; + GGML_ASSERT(x->ne[2] >= temporal_pad); + + int end_idx = x->ne[2] - temporal_pad; + int start_idx = std::max(end_idx - pad, 0); + // Save a contiguous copy of the last `pad` frames so the large `x` // tensor is not kept alive across iterations by a dangling view. 
- if (feat_idx < (int)feat_map.size() && pad > 0 && x->ne[2] >= pad) { - auto slice = ggml_ext_slice(ctx->ggml_ctx, x, 2, x->ne[2] - pad, x->ne[2]); + if (feat_idx < (int)feat_map.size() && end_idx - start_idx > 0) { + GGML_ASSERT(start_idx >= 0); + GGML_ASSERT(end_idx > 0); + + auto slice = ggml_ext_slice(ctx->ggml_ctx, x, 2, start_idx, end_idx); feat_map[feat_idx] = ggml_cont(ctx->ggml_ctx, slice); } feat_idx++; @@ -282,7 +291,8 @@ namespace LTXVAE { bool causal, std::vector& feat_map, int& feat_idx, - int chunk_idx) { + int chunk_idx, + int temporal_pad = 0) { auto norm1 = std::dynamic_pointer_cast(blocks["norm1"]); auto conv1 = std::dynamic_pointer_cast(blocks["conv1"]); auto norm2 = std::dynamic_pointer_cast(blocks["norm2"]); @@ -309,14 +319,14 @@ namespace LTXVAE { h = apply_scale_shift(ctx->ggml_ctx, h, scale1, shift1); } h = ggml_silu_inplace(ctx->ggml_ctx, h); - h = conv1->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal); + h = conv1->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal, temporal_pad); h = norm2->forward(ctx, h); if (timestep_conditioning) { h = apply_scale_shift(ctx->ggml_ctx, h, scale2, shift2); } h = ggml_silu_inplace(ctx->ggml_ctx, h); - h = conv2->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal); + h = conv2->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal, temporal_pad); return ggml_add(ctx->ggml_ctx, h, x); } @@ -365,7 +375,8 @@ namespace LTXVAE { bool causal, std::vector& feat_map, int& feat_idx, - int chunk_idx) { + int chunk_idx, + int temporal_pad = 0) { ggml_tensor* timestep_embed = nullptr; if (timestep_conditioning) { GGML_ASSERT(timestep != nullptr); @@ -374,7 +385,7 @@ namespace LTXVAE { } for (int i = 0; i < num_layers; i++) { auto resnet = std::dynamic_pointer_cast(blocks["res_blocks." 
+ std::to_string(i)]); - x = resnet->forward(ctx, x, timestep_embed, causal, feat_map, feat_idx, chunk_idx); + x = resnet->forward(ctx, x, timestep_embed, causal, feat_map, feat_idx, chunk_idx, temporal_pad); } return x; } @@ -435,7 +446,8 @@ namespace LTXVAE { bool causal, std::vector& feat_map, int& feat_idx, - int chunk_idx) { + int chunk_idx, + int temporal_pad = 0) { auto conv = std::dynamic_pointer_cast(blocks["conv"]); bool drop_first = (chunk_idx == 0) && (factor_t > 1); @@ -451,7 +463,7 @@ namespace LTXVAE { x_in = res; } - x = conv->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal); + x = conv->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal, temporal_pad); x = depth_to_space_3d(ctx->ggml_ctx, x, get_output_channels(), factor_t, factor_s, drop_first); if (residual) { x = ggml_add(ctx->ggml_ctx, x, x_in); @@ -889,7 +901,8 @@ namespace LTXVAE { ggml_tensor* timestep, std::vector& feat_map, int& feat_idx, - int chunk_idx) { + int chunk_idx, + int& temporal_pad) { auto conv_in = std::dynamic_pointer_cast(blocks["conv_in"]); auto conv_norm_out = std::dynamic_pointer_cast(blocks["conv_norm_out"]); auto conv_out = std::dynamic_pointer_cast(blocks["conv_out"]); @@ -901,7 +914,7 @@ namespace LTXVAE { } // conv_in with feat_map for left temporal context - x = conv_in->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder); + x = conv_in->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder, temporal_pad); // up_blocks int block_idx = 0; @@ -909,12 +922,13 @@ namespace LTXVAE { auto mid_block = std::dynamic_pointer_cast(blocks["up_blocks." + std::to_string(block_idx)]); if (mid_block) { x = mid_block->forward(ctx, x, scaled_timestep, causal_decoder, - feat_map, feat_idx, chunk_idx); + feat_map, feat_idx, chunk_idx, temporal_pad); } else { auto upsample = std::dynamic_pointer_cast( blocks["up_blocks." 
+ std::to_string(block_idx)]); x = upsample->forward(ctx, x, causal_decoder, - feat_map, feat_idx, chunk_idx); + feat_map, feat_idx, chunk_idx, temporal_pad); + temporal_pad *= upsample->factor_t; } block_idx++; } @@ -931,7 +945,7 @@ namespace LTXVAE { x = apply_scale_shift(ctx->ggml_ctx, x, scale, shift); } x = ggml_silu_inplace(ctx->ggml_ctx, x); - x = conv_out->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder); + x = conv_out->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder, temporal_pad); return x; } }; @@ -984,7 +998,9 @@ namespace LTXVAE { // tensors can be freed by GGML before the next iteration starts. ggml_tensor* decode_tiled(GGMLRunnerContext* ctx, ggml_tensor* z, - ggml_tensor* timestep) { + ggml_tensor* timestep, + int temporal_window_size = 1, + int temporal_pad = 0) { auto decoder = std::dynamic_pointer_cast(blocks["decoder"]); auto processor = std::dynamic_pointer_cast(blocks["per_channel_statistics"]); auto latents = processor->un_normalize(ctx, z); @@ -999,13 +1015,38 @@ namespace LTXVAE { // 128 slots is generous enough for any supported decoder configuration. 
std::vector feat_map(128, nullptr); + // Ensure window size is at least 1 + int window = std::max(1, temporal_window_size); + + if (temporal_pad >= window) { + LOG_WARN("temporal_pad (%d) is greater than or equal to temporal_window_size (%d), adjusting values to avoid empty decode windows", + temporal_pad, window); + temporal_pad = window - 1; + } + LOG_DEBUG("Using temporal tiling: tile size = %d frames, padding/overlap = %d frames, total frames = %d, resulting in %d tiles", window, temporal_pad, (int)T, (T + window - temporal_pad - 1) / (window - temporal_pad)); ggml_tensor* out = nullptr; - for (int i = 0; i < (int)T; i++) { + for (int i = 0; i < (int)T - temporal_pad; i += (window - temporal_pad)) { int feat_idx = 0; - auto z_i = ggml_ext_slice(ctx->ggml_ctx, latents, 2, i, i + 1); - auto out_i = decoder->forward_tiled_frame(ctx, z_i, timestep, - feat_map, feat_idx, i); - out = (out == nullptr) ? out_i : ggml_concat(ctx->ggml_ctx, out, out_i, 2); + + // Calculate the end index for the current temporal chunk + int end_i = std::min((int)T, i + window); + if (end_i >= (int)T) { + temporal_pad = 0; // to avoid any padding-related issue (e.g. padding more than the number of frames in the chunk) in the last chunk + } + + int chunk_pad = temporal_pad; // passed by reference below; forward_tiled_frame inflates it by the temporal upsampling factor + + auto z_chunk = ggml_ext_slice(ctx->ggml_ctx, latents, 2, i, end_i); + + auto out_chunk = decoder->forward_tiled_frame(ctx, z_chunk, timestep, + feat_map, feat_idx, i, chunk_pad); + + // discard last frames if it's not the final chunk and temporal_pad > 0 + if (temporal_pad > 0 && end_i < (int)T) { + out_chunk = ggml_ext_slice(ctx->ggml_ctx, out_chunk, 2, 0, out_chunk->ne[2] - chunk_pad); + } + + out = (out == nullptr) ?
out_chunk : ggml_concat(ctx->ggml_ctx, out, out_chunk, 2); } return WAN::WanVAE::unpatchify(ctx->ggml_ctx, out, patch_size, 1); @@ -1083,7 +1124,22 @@ struct LTXVideoVAE : public VAE { bool use_tiled = decode_graph && temporal_tiling_enabled && z_tensor.dim() == 5 && z_tensor.shape()[2] > 1; if (use_tiled) { - out = vae.decode_tiled(&runner_ctx, z, timestep); + // TODO: pass as args + int tile_frames = 1; + int tile_pad = 0; + // use env variables for now + const char* env_tile_frames = std::getenv("VAE_TILE_FRAMES"); + const char* env_tile_pad = std::getenv("VAE_TILE_PAD"); + if (env_tile_frames) { + tile_frames = std::max(1, std::atoi(env_tile_frames)); + LOG_DEBUG("Using temporal tiling with tile_frames=%d", tile_frames); + } + if (env_tile_pad) { + tile_pad = std::max(0, std::atoi(env_tile_pad)); + LOG_DEBUG("Using temporal tiling with tile_pad=%d", tile_pad); + } + + out = vae.decode_tiled(&runner_ctx, z, timestep, tile_frames, tile_pad); } else { out = decode_graph ? vae.decode(&runner_ctx, z, timestep) : vae.encode(&runner_ctx, z); }