diff --git a/src/diffusers/pipelines/ltx2/connectors.py b/src/diffusers/pipelines/ltx2/connectors.py
index a49de4083342..8a00a0c6b452 100644
--- a/src/diffusers/pipelines/ltx2/connectors.py
+++ b/src/diffusers/pipelines/ltx2/connectors.py
@@ -2,7 +2,6 @@
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
@@ -295,22 +294,21 @@ def forward(
             )
 
             num_register_repeats = seq_len // self.num_learnable_registers
-            registers = torch.tile(self.learnable_registers, (num_register_repeats, 1))  # [seq_len, inner_dim]
+            registers = (
+                self.learnable_registers.unsqueeze(0).expand(num_register_repeats, -1, -1).reshape(seq_len, -1)
+            )  # [seq_len, inner_dim]
 
             binary_attn_mask = (attention_mask >= attn_mask_binarize_threshold).int()
             if binary_attn_mask.ndim == 4:
                 binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1)  # [B, 1, 1, L] --> [B, L]
 
-            hidden_states_non_padded = [hidden_states[i, binary_attn_mask[i].bool(), :] for i in range(batch_size)]
-            valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded]
-            pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens]
-            padded_hidden_states = [
-                F.pad(x, pad=(0, 0, 0, p), value=0) for x, p in zip(hidden_states_non_padded, pad_lengths)
-            ]
-            padded_hidden_states = torch.cat([x.unsqueeze(0) for x in padded_hidden_states], dim=0)  # [B, L, D]
+            # Replace padding positions with learned registers using vectorized masking
+            mask = binary_attn_mask.unsqueeze(-1)  # [B, L, 1]
+            registers_expanded = registers.unsqueeze(0).expand(batch_size, -1, -1)  # [B, L, D]
+            hidden_states = mask * hidden_states + (1 - mask) * registers_expanded
 
-            flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze(-1)  # [B, L, 1]
-            hidden_states = flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers
+            # Flip sequence: embeddings move to front, registers to back (from left padding layout)
+            hidden_states = torch.flip(hidden_states, dims=[1])
 
             # Overwrite attention_mask with an all-zeros mask if using registers.
             attention_mask = torch.zeros_like(attention_mask)
diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
index 73ebac0f173c..946360445e61 100644
--- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
+++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
@@ -1189,6 +1189,10 @@ def __call__(
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
 
+        # Set begin index to skip nonzero().item() call in scheduler initialization, which triggers GPU sync
+        self.scheduler.set_begin_index(0)
+        audio_scheduler.set_begin_index(0)
+
         # 6. Prepare micro-conditions
         # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop
         video_coords = self.transformer.rope.prepare_video_coords(
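
For reference, the vectorized register substitution from the first hunk can be exercised in isolation. The sketch below is a minimal standalone reproduction, not the connector module itself: all shapes, tensor values, and variable names outside the diff are made up for illustration, and the mask is assumed to be a left-padded binary `[B, L]` mask as in the changed code.

```python
import torch

# Hypothetical toy shapes; the real connector derives these from the text encoder output.
batch_size, seq_len, inner_dim, num_registers = 2, 8, 4, 4

hidden_states = torch.randn(batch_size, seq_len, inner_dim)
learnable_registers = torch.randn(num_registers, inner_dim)
attention_mask = torch.tensor(
    [[0, 0, 0, 1, 1, 1, 1, 1],
     [0, 0, 0, 0, 0, 1, 1, 1]]
)  # left-padded: 1 = real token, 0 = padding

# Tile the registers to cover the full sequence length without torch.tile.
num_repeats = seq_len // num_registers
registers = learnable_registers.unsqueeze(0).expand(num_repeats, -1, -1).reshape(seq_len, -1)

# Blend registers into the padding positions with a single masked select,
# replacing the old per-sample Python loop plus F.pad/concatenate.
mask = attention_mask.unsqueeze(-1)                                       # [B, L, 1]
registers_expanded = registers.unsqueeze(0).expand(batch_size, -1, -1)    # [B, L, D]
blended = mask * hidden_states + (1 - mask) * registers_expanded

# Flip so real embeddings sit at the front and registers at the back,
# matching the layout the previous pad-and-concatenate path produced.
output = torch.flip(blended, dims=[1])
```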