Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/diffusers/pipelines/free_noise_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def _prepare_latents_free_noise(
)

context_num_frames = (
self._free_noise_context_length if self._free_noise_context_length == "repeat_context" else num_frames
num_frames if self._free_noise_noise_type == "random" else self._free_noise_context_length
)

shape = (
Expand Down Expand Up @@ -407,7 +407,8 @@ def _prepare_latents_free_noise(
break

indices = torch.LongTensor(list(range(window_start, window_end)))
shuffled_indices = indices[torch.randperm(window_length, generator=generator)]
perm_generator = generator[0] if isinstance(generator, list) else generator
shuffled_indices = indices[torch.randperm(window_length, generator=perm_generator)]

current_start = i
current_end = min(num_frames, current_start + window_length)
Expand Down Expand Up @@ -491,6 +492,8 @@ def enable_free_noise(
allowed_weighting_scheme = ["flat", "pyramid", "delayed_reverse_sawtooth"]
allowed_noise_type = ["shuffle_context", "repeat_context", "random"]

context_length = context_length or self.motion_adapter.config.motion_max_seq_length

if context_length > self.motion_adapter.config.motion_max_seq_length:
logger.warning(
f"You have set {context_length=} which is greater than {self.motion_adapter.config.motion_max_seq_length=}. This can lead to bad generation results."
Expand All @@ -502,7 +505,7 @@ def enable_free_noise(
if noise_type not in allowed_noise_type:
raise ValueError(f"The parameter `noise_type` must be one of {allowed_noise_type}, but got {noise_type=}")

self._free_noise_context_length = context_length or self.motion_adapter.config.motion_max_seq_length
self._free_noise_context_length = context_length
self._free_noise_context_stride = context_stride
self._free_noise_weighting_scheme = weighting_scheme
self._free_noise_noise_type = noise_type
Expand All @@ -525,7 +528,6 @@ def disable_free_noise(self) -> None:
else:
blocks = [*self.unet.down_blocks, *self.unet.up_blocks]

blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
for block in blocks:
self._disable_free_noise_in_block(block)

Expand Down
Loading