From 598874427a147e8541da962394f1ddd02a78d14d Mon Sep 17 00:00:00 2001 From: Viktoriia Romanova Date: Sun, 26 Apr 2026 20:04:40 +0000 Subject: [PATCH 1/3] =?UTF-8?q?Remove=20unnecessary=20CUDA=20synchronizati?= =?UTF-8?q?on=20points=20and=20avoid=20CPU=E2=86=92GPU=20tensor=20creation?= =?UTF-8?q?=20across=20the=20LTX2=20pipeline,=20transformer,=20scheduler,?= =?UTF-8?q?=20and=20connector=20logic.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add set_begin_index(0) to schedulers to eliminate DtoH sync in _init_step_index - Replace torch.tensor(..., device=...) with on-device tensor construction for decode scaling - Move RoPE-related tensor creation to GPU to avoid memcpy overhead - Refactor connector padding logic using vectorized masking instead of list-based ops --- .../models/transformers/transformer_ltx2.py | 13 +++++++++---- src/diffusers/pipelines/ltx2/connectors.py | 19 +++++++++---------- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 5 +++++ .../scheduling_flow_match_euler_discrete.py | 6 ++++-- 4 files changed, 27 insertions(+), 16 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index a4915ccfb96a..14a9ea1f012d 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -894,9 +894,14 @@ def prepare_video_coords( grid = torch.stack(grid, dim=0) # [3, N_F, N_H, N_W], where e.g. N_F is the number of temporal patches # 2. Get the patch boundaries with respect to the latent video grid - patch_size = (self.patch_size_t, self.patch_size, self.patch_size) - patch_size_delta = torch.tensor(patch_size, dtype=grid.dtype, device=grid.device) - patch_ends = grid + patch_size_delta.view(3, 1, 1, 1) + patch_size_delta = torch.stack( + [ + grid.new_ones(1) * self.patch_size_t, + grid.new_ones(1) * self.patch_size, + grid.new_ones(1) * self.patch_size, + ] + ).reshape(3, 1, 1, 1) + patch_ends = grid + patch_size_delta # Combine the start (grid) and end (patch_ends) coordinates along new trailing dimension latent_coords = torch.stack([grid, patch_ends], dim=-1) # [3, N_F, N_H, N_W, 2] @@ -905,7 +910,7 @@ def prepare_video_coords( latent_coords = latent_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1) # 3. Calculate the pixel space patch boundaries from the latent boundaries. - scale_tensor = torch.tensor(self.scale_factors, device=latent_coords.device) + scale_tensor = torch.stack([latent_coords.new_ones(1) * factor for factor in self.scale_factors]) # Broadcast the VAE scale factors such that they are compatible with latent_coords's shape broadcast_shape = [1] * latent_coords.ndim broadcast_shape[1] = -1 # This is the (frame, height, width) dim diff --git a/src/diffusers/pipelines/ltx2/connectors.py b/src/diffusers/pipelines/ltx2/connectors.py index a49de4083342..d784933883d7 100644 --- a/src/diffusers/pipelines/ltx2/connectors.py +++ b/src/diffusers/pipelines/ltx2/connectors.py @@ -295,22 +295,21 @@ def forward( ) num_register_repeats = seq_len // self.num_learnable_registers - registers = torch.tile(self.learnable_registers, (num_register_repeats, 1)) # [seq_len, inner_dim] + registers = ( + self.learnable_registers.unsqueeze(0).expand(num_register_repeats, -1, -1).reshape(seq_len, -1) + ) # [seq_len, inner_dim] binary_attn_mask = (attention_mask >= attn_mask_binarize_threshold).int() if binary_attn_mask.ndim == 4: binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L] - hidden_states_non_padded = [hidden_states[i, binary_attn_mask[i].bool(), :] for i in range(batch_size)] - valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded] - pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens] - padded_hidden_states = [ - F.pad(x, pad=(0, 0, 0, p), value=0) for x, p in zip(hidden_states_non_padded, pad_lengths) - ] - padded_hidden_states = torch.cat([x.unsqueeze(0) for x in padded_hidden_states], dim=0) # [B, L, D] + # Replace padding positions with learned registers using vectorized masking + mask = binary_attn_mask.unsqueeze(-1) # [B, L, 1] + registers_expanded = registers.unsqueeze(0).expand(batch_size, -1, -1) # [B, L, D] + hidden_states = mask * hidden_states + (1 - mask) * registers_expanded - flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze(-1) # [B, L, 1] - hidden_states = flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers + # Flip sequence: embeddings move to front, registers to back (from left padding layout) + hidden_states = torch.flip(hidden_states, dims=[1]) # Overwrite attention_mask with an all-zeros mask if using registers. attention_mask = torch.zeros_like(attention_mask) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 73ebac0f173c..342ff74ad7db 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -1189,6 +1189,11 @@ def __call__( num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) + # Set begin index to skip nonzero().item() call in scheduler initialization, which triggers GPU sync + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(0) + audio_scheduler.set_begin_index(0) + # 6. Prepare micro-conditions # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop video_coords = self.transformer.rope.prepare_video_coords( diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index 1021abf0f6f6..1852c529f9bb 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -362,11 +362,13 @@ def set_timesteps( sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) # 5. Convert sigmas and timesteps to tensors and move to specified device - sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + sigmas = torch.from_numpy(sigmas).pin_memory().to(dtype=torch.float32, device=device, non_blocking=True) if not is_timesteps_provided: timesteps = sigmas * self.config.num_train_timesteps else: - timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device) + timesteps = ( + torch.from_numpy(timesteps).pin_memory().to(dtype=torch.float32, device=device, non_blocking=True) + ) # 6. Append the terminal sigma value. # If a model requires inverted sigma schedule for denoising but timesteps without inversion, the From f71dc9f28f9b423c767c6e6948aba42ea8f5ead7 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 6 May 2026 01:07:39 +0000 Subject: [PATCH 2/3] Apply style fixes --- src/diffusers/pipelines/ltx2/connectors.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/ltx2/connectors.py b/src/diffusers/pipelines/ltx2/connectors.py index d784933883d7..8a00a0c6b452 100644 --- a/src/diffusers/pipelines/ltx2/connectors.py +++ b/src/diffusers/pipelines/ltx2/connectors.py @@ -2,7 +2,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin From 8b42fbc1c67697111c1cfe165aef82a077e9cd76 Mon Sep 17 00:00:00 2001 From: Viktoriia Romanova Date: Wed, 6 May 2026 10:46:35 +0000 Subject: [PATCH 3/3] Revert low-impact CUDA synchronization changes and remove redundant `hasattr` check --- .../models/transformers/transformer_ltx2.py | 13 ++++--------- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 5 ++--- .../scheduling_flow_match_euler_discrete.py | 6 ++---- 3 files changed, 8 insertions(+), 16 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index 14a9ea1f012d..a4915ccfb96a 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -894,14 +894,9 @@ def prepare_video_coords( grid = torch.stack(grid, dim=0) # [3, N_F, N_H, N_W], where e.g. N_F is the number of temporal patches # 2. Get the patch boundaries with respect to the latent video grid - patch_size_delta = torch.stack( - [ - grid.new_ones(1) * self.patch_size_t, - grid.new_ones(1) * self.patch_size, - grid.new_ones(1) * self.patch_size, - ] - ).reshape(3, 1, 1, 1) - patch_ends = grid + patch_size_delta + patch_size = (self.patch_size_t, self.patch_size, self.patch_size) + patch_size_delta = torch.tensor(patch_size, dtype=grid.dtype, device=grid.device) + patch_ends = grid + patch_size_delta.view(3, 1, 1, 1) # Combine the start (grid) and end (patch_ends) coordinates along new trailing dimension latent_coords = torch.stack([grid, patch_ends], dim=-1) # [3, N_F, N_H, N_W, 2] @@ -910,7 +905,7 @@ def prepare_video_coords( latent_coords = latent_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1) # 3. Calculate the pixel space patch boundaries from the latent boundaries. - scale_tensor = torch.stack([latent_coords.new_ones(1) * factor for factor in self.scale_factors]) + scale_tensor = torch.tensor(self.scale_factors, device=latent_coords.device) # Broadcast the VAE scale factors such that they are compatible with latent_coords's shape broadcast_shape = [1] * latent_coords.ndim broadcast_shape[1] = -1 # This is the (frame, height, width) dim diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 342ff74ad7db..946360445e61 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -1190,9 +1190,8 @@ def __call__( self._num_timesteps = len(timesteps) # Set begin index to skip nonzero().item() call in scheduler initialization, which triggers GPU sync - if hasattr(self.scheduler, "set_begin_index"): - self.scheduler.set_begin_index(0) - audio_scheduler.set_begin_index(0) + self.scheduler.set_begin_index(0) + audio_scheduler.set_begin_index(0) # 6. Prepare micro-conditions # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index c990afaadef0..7b207f782079 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -363,13 +363,11 @@ def set_timesteps( sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) # 5. Convert sigmas and timesteps to tensors and move to specified device - sigmas = torch.from_numpy(sigmas).pin_memory().to(dtype=torch.float32, device=device, non_blocking=True) + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) if not is_timesteps_provided: timesteps = sigmas * self.config.num_train_timesteps else: - timesteps = ( - torch.from_numpy(timesteps).pin_memory().to(dtype=torch.float32, device=device, non_blocking=True) - ) + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device) # 6. Append the terminal sigma value. # If a model requires inverted sigma schedule for denoising but timesteps without inversion, the