From beea18fab789676ccce20460a29a530270af9985 Mon Sep 17 00:00:00 2001 From: Tai An Date: Thu, 30 Apr 2026 21:23:25 -0700 Subject: [PATCH] fix(free_noise): resolve None default, repeat_context/shuffle_context, list generator, mid-block disable --- src/diffusers/pipelines/free_noise_utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index 5990e680ba07..71890597cbe1 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -372,7 +372,7 @@ def _prepare_latents_free_noise( ) context_num_frames = ( - self._free_noise_context_length if self._free_noise_context_length == "repeat_context" else num_frames + num_frames if self._free_noise_noise_type == "random" else self._free_noise_context_length ) shape = ( @@ -407,7 +407,8 @@ def _prepare_latents_free_noise( break indices = torch.LongTensor(list(range(window_start, window_end))) - shuffled_indices = indices[torch.randperm(window_length, generator=generator)] + perm_generator = generator[0] if isinstance(generator, list) else generator + shuffled_indices = indices[torch.randperm(window_length, generator=perm_generator)] current_start = i current_end = min(num_frames, current_start + window_length) @@ -491,6 +492,8 @@ def enable_free_noise( allowed_weighting_scheme = ["flat", "pyramid", "delayed_reverse_sawtooth"] allowed_noise_type = ["shuffle_context", "repeat_context", "random"] + context_length = context_length or self.motion_adapter.config.motion_max_seq_length + if context_length > self.motion_adapter.config.motion_max_seq_length: logger.warning( f"You have set {context_length=} which is greater than {self.motion_adapter.config.motion_max_seq_length=}. This can lead to bad generation results." @@ -502,7 +505,7 @@ def enable_free_noise( if noise_type not in allowed_noise_type: raise ValueError(f"The parameter `noise_type` must be one of {allowed_noise_type}, but got {noise_type=}") - self._free_noise_context_length = context_length or self.motion_adapter.config.motion_max_seq_length + self._free_noise_context_length = context_length self._free_noise_context_stride = context_stride self._free_noise_weighting_scheme = weighting_scheme self._free_noise_noise_type = noise_type @@ -525,7 +528,6 @@ def disable_free_noise(self) -> None: else: blocks = [*self.unet.down_blocks, *self.unet.up_blocks] - blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks] for block in blocks: self._disable_free_noise_in_block(block)