From b3a515080752a3ba7ca92161e25530c7f280f629 Mon Sep 17 00:00:00 2001
From: hlky
Date: Wed, 6 May 2026 05:37:46 +0100
Subject: [PATCH] Fix `stable_video_diffusion`

---
 .../unets/unet_spatio_temporal_condition.py   |   9 +-
 .../pipeline_stable_video_diffusion.py        | 205 +++++++++++-------
 .../test_stable_video_diffusion.py            |   6 -
 tests/pipelines/test_pipelines_common.py      |   2 +-
 4 files changed, 134 insertions(+), 88 deletions(-)

diff --git a/src/diffusers/models/unets/unet_spatio_temporal_condition.py b/src/diffusers/models/unets/unet_spatio_temporal_condition.py
index eddeb9826b0c..fad781ffe343 100644
--- a/src/diffusers/models/unets/unet_spatio_temporal_condition.py
+++ b/src/diffusers/models/unets/unet_spatio_temporal_condition.py
@@ -114,7 +114,7 @@ def __init__(
             f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
         )

-        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+        if not isinstance(cross_attention_dim, int) and len(cross_attention_dim) != len(down_block_types):
             raise ValueError(
                 f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
             )
@@ -124,6 +124,13 @@ def __init__(
             f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
         )

+        if not isinstance(transformer_layers_per_block, int) and len(transformer_layers_per_block) != len(
+            down_block_types
+        ):
+            raise ValueError(
+                f"Must provide the same number of `transformer_layers_per_block` as `down_block_types`. `transformer_layers_per_block`: {transformer_layers_per_block}. `down_block_types`: {down_block_types}."
+            )
+
         # input
         self.conv_in = nn.Conv2d(
             in_channels,
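With the checks widened from `isinstance(..., list)` to `not isinstance(..., int)`, wrong-length tuples for `cross_attention_dim` and `transformer_layers_per_block` now fail fast in `__init__` instead of erroring deep inside the block constructors. A minimal sketch of the new behaviour; the three-entry tuple against the four default down blocks is a hypothetical repro:

    from diffusers import UNetSpatioTemporalConditionModel

    # Four down blocks by default, but only three entries here; the new
    # check raises before any weights are allocated.
    try:
        UNetSpatioTemporalConditionModel(transformer_layers_per_block=(1, 1, 1))
    except ValueError as err:
        print(err)  # Must provide the same number of `transformer_layers_per_block` ...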
diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
index 05877f69d403..856677daffc1 100644
--- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
+++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
@@ -21,7 +21,7 @@
 import torch
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

-from ...image_processor import PipelineImageInput
+from ...image_processor import PipelineImageInput, is_valid_image_imagelist
 from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
 from ...schedulers import EulerDiscreteScheduler
 from ...utils import BaseOutput, is_torch_xla_available, logging, replace_example_docstring
@@ -62,14 +62,6 @@
 """


-def _append_dims(x, target_dims):
-    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
-    dims_to_append = target_dims - x.ndim
-    if dims_to_append < 0:
-        raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
-    return x[(...,) + (None,) * dims_to_append]
-
-
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
@@ -187,6 +179,7 @@ def __init__(
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(do_resize=True, vae_scale_factor=self.vae_scale_factor)
+        self.clip_image_processor = VideoProcessor(do_resize=False, do_normalize=False)

     def _encode_image(
         self,
@@ -196,16 +189,14 @@ def _encode_image(
         do_classifier_free_guidance: bool,
     ) -> torch.Tensor:
         dtype = next(self.image_encoder.parameters()).dtype
+        image = self.clip_image_processor.preprocess(image)

-        if not isinstance(image, torch.Tensor):
-            image = self.video_processor.pil_to_numpy(image)
-            image = self.video_processor.numpy_to_pt(image)
-
-            # We normalize the image before resizing to match with the original implementation.
-            # Then we unnormalize it after resizing.
-            image = image * 2.0 - 1.0
-            image = _resize_with_antialiasing(image, (224, 224))
-            image = (image + 1.0) / 2.0
+        # We normalize the image before resizing to match with the original implementation.
+        # Then we unnormalize it after resizing.
+        image = self.video_processor.normalize(image)
+        image_size = self.image_encoder.config.image_size
+        image = _resize_with_antialiasing(image, (image_size, image_size))
+        image = self.video_processor.denormalize(image)

         # Normalize the image for CLIP input
         image = self.feature_extractor(
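Reading the target size from `image_encoder.config.image_size` also drops the hard-coded 224. Assuming `VideoProcessor` keeps the `VaeImageProcessor` `normalize`/`denormalize` semantics, the new calls are the same `[0, 1]` <-> `[-1, 1]` maps the removed inline arithmetic wrote by hand; a quick equivalence check:

    import torch
    from diffusers.video_processor import VideoProcessor

    vp = VideoProcessor(do_resize=False, do_normalize=False)
    x = torch.rand(1, 3, 224, 224)  # dummy image batch in [0, 1]
    assert torch.allclose(vp.normalize(x), x * 2.0 - 1.0)
    assert torch.allclose(vp.denormalize(x * 2.0 - 1.0), x)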
@@ -221,10 +212,13 @@ def _encode_image(
         image_embeddings = self.image_encoder(image).image_embeds
         image_embeddings = image_embeddings.unsqueeze(1)

-        # duplicate image embeddings for each generation per prompt, using mps friendly method
-        bs_embed, seq_len, _ = image_embeddings.shape
-        image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
-        image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
+        # duplicate image embeddings for each generation per prompt
+        image_embeddings = torch.repeat_interleave(
+            image_embeddings,
+            repeats=num_videos_per_prompt,
+            dim=0,
+            output_size=image_embeddings.shape[0] * num_videos_per_prompt,
+        )

         if do_classifier_free_guidance:
             negative_image_embeddings = torch.zeros_like(image_embeddings)
@@ -242,12 +236,17 @@ def _encode_vae_image(
         device: str | torch.device,
         num_videos_per_prompt: int,
         do_classifier_free_guidance: bool,
-    ):
+    ) -> torch.Tensor:
         image = image.to(device=device)
         image_latents = self.vae.encode(image).latent_dist.mode()

-        # duplicate image_latents for each generation per prompt, using mps friendly method
-        image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
+        # duplicate image latents for each generation per prompt
+        image_latents = torch.repeat_interleave(
+            image_latents,
+            repeats=num_videos_per_prompt,
+            dim=0,
+            output_size=image_latents.shape[0] * num_videos_per_prompt,
+        )

         if do_classifier_free_guidance:
             negative_image_latents = torch.zeros_like(image_latents)
@@ -268,7 +267,7 @@ def _get_add_time_ids(
         batch_size: int,
         num_videos_per_prompt: int,
         do_classifier_free_guidance: bool,
-    ):
+    ) -> torch.Tensor:
         add_time_ids = [fps, motion_bucket_id, noise_aug_strength]

         passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
@@ -280,14 +279,19 @@ def _get_add_time_ids(
         )

         add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
-        add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
+        add_time_ids = torch.repeat_interleave(
+            add_time_ids,
+            repeats=batch_size * num_videos_per_prompt,
+            dim=0,
+            output_size=batch_size * num_videos_per_prompt,
+        )

         if do_classifier_free_guidance:
             add_time_ids = torch.cat([add_time_ids, add_time_ids])

         return add_time_ids

-    def decode_latents(self, latents: torch.Tensor, num_frames: int, decode_chunk_size: int = 14):
+    def decode_latents(self, latents: torch.Tensor, num_frames: int, decode_chunk_size: int = 14) -> torch.Tensor:
         # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
         latents = latents.flatten(0, 1)
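The old mps-friendly `repeat().view()` pattern and `torch.repeat_interleave` produce the same interleaved ordering for this duplication; a quick sanity check on a dummy embedding batch (shapes are illustrative):

    import torch

    emb = torch.arange(6.0).reshape(2, 1, 3)  # [batch, seq_len, dim]
    bs_embed, seq_len, _ = emb.shape
    old = emb.repeat(1, 2, 1).view(bs_embed * 2, seq_len, -1)
    new = torch.repeat_interleave(emb, repeats=2, dim=0)
    assert torch.equal(old, new)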
@@ -316,15 +320,22 @@ def decode_latents(self, latents: torch.Tensor, num_frames: int, decode_chunk_si
         frames = frames.float()
         return frames

-    def check_inputs(self, image, height, width):
-        if (
-            not isinstance(image, torch.Tensor)
-            and not isinstance(image, PIL.Image.Image)
-            and not isinstance(image, list)
-        ):
+    def check_inputs(self, image: PipelineImageInput, height: int, width: int) -> None:
+        if isinstance(image, list):
+            invalid_image = len(image) == 0 or any(
+                not isinstance(i, (PIL.Image.Image, np.ndarray, torch.Tensor))
+                or isinstance(i, (np.ndarray, torch.Tensor))
+                and i.ndim not in (3, 4)
+                for i in image
+            )
+        else:
+            invalid_image = not is_valid_image_imagelist(image) or (
+                isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim not in (3, 4)
+            )
+        if invalid_image:
             raise ValueError(
-                "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `list[PIL.Image.Image]` but is"
-                f" {type(image)}"
+                "`image` has to be a PIL image, NumPy array, PyTorch tensor, or a list of PIL images, NumPy arrays,"
+                f" or PyTorch tensors, but is {type(image)}."
             )

         if height % 8 != 0 or width % 8 != 0:
@@ -339,9 +350,9 @@ def prepare_latents(
         width: int,
         dtype: torch.dtype,
         device: str | torch.device,
-        generator: torch.Generator,
+        generator: torch.Generator | list[torch.Generator] | None,
         latents: torch.Tensor | None = None,
-    ):
+    ) -> torch.Tensor:
         shape = (
             batch_size,
             num_frames,
@@ -358,34 +369,34 @@ def prepare_latents(
         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
-            latents = latents.to(device)
+            latents = latents.to(device=device, dtype=dtype)

         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
         return latents

     @property
-    def guidance_scale(self):
+    def guidance_scale(self) -> float | torch.Tensor:
         return self._guidance_scale

     # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
     # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
     # corresponds to doing no classifier free guidance.
     @property
-    def do_classifier_free_guidance(self):
+    def do_classifier_free_guidance(self) -> bool:
         if isinstance(self.guidance_scale, (int, float)):
             return self.guidance_scale > 1
-        return self.guidance_scale.max() > 1
+        return bool(self.guidance_scale.max() > 1)

     @property
-    def num_timesteps(self):
+    def num_timesteps(self) -> int:
         return self._num_timesteps

     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
-        image: PIL.Image.Image | list[PIL.Image.Image] | torch.Tensor,
+        image: PipelineImageInput,
         height: int = 576,
         width: int = 1024,
         num_frames: int | None = None,
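A sketch of what the broadened `check_inputs` now accepts and rejects, assuming `pipe` is a loaded `StableVideoDiffusionPipeline` (shapes are illustrative):

    import numpy as np
    import torch

    pipe.check_inputs(np.zeros((576, 1024, 3)), 576, 1024)      # single HWC array
    pipe.check_inputs(torch.zeros(2, 3, 576, 1024), 576, 1024)  # batched CHW tensor
    try:
        pipe.check_inputs(np.zeros((576, 1024)), 576, 1024)     # 2-D: rejected
    except ValueError as err:
        print(err)  # the clearer error message above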
@@ -397,21 +408,26 @@ def __call__(
         motion_bucket_id: int = 127,
         noise_aug_strength: float = 0.02,
         decode_chunk_size: int | None = None,
-        num_videos_per_prompt: int | None = 1,
+        num_videos_per_prompt: int = 1,
         generator: torch.Generator | list[torch.Generator] | None = None,
         latents: torch.Tensor | None = None,
         output_type: str | None = "pil",
-        callback_on_step_end: Callable[[int, int], None] | None = None,
+        callback_on_step_end: Callable[
+            [DiffusionPipeline, int, int | torch.Tensor, dict[str, torch.Tensor]], dict[str, torch.Tensor]
+        ]
+        | None = None,
         callback_on_step_end_tensor_inputs: list[str] = ["latents"],
         return_dict: bool = True,
-    ):
+    ) -> StableVideoDiffusionPipelineOutput | tuple[list[list[PIL.Image.Image]] | np.ndarray | torch.Tensor]:
         r"""
         The call function to the pipeline for generation.

         Args:
-            image (`PIL.Image.Image` or `list[PIL.Image.Image]` or `torch.Tensor`):
-                Image(s) to guide image generation. If you provide a tensor, the expected value range is between `[0,
-                1]`.
+            image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`,
+            or `list[torch.Tensor]`):
+                Image(s) to guide video generation. NumPy arrays and tensors are expected to have values in `[0, 1]`.
+                Tensor inputs can use shape `(batch, channels, height, width)` or `(channels, height, width)`.
+                NumPy inputs can use shape `(batch, height, width, channels)` or `(height, width, channels)`.
             height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                 The width in pixels of the generated image.
@@ -453,7 +469,7 @@ def __call__(
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor is generated by sampling using the supplied random `generator`.
             output_type (`str`, *optional*, defaults to `"pil"`):
-                The output format of the generated image. Choose between `pil`, `np` or `pt`.
+                The output format of the generated video. Choose between `pil`, `np`, `pt`, or `latent`.
             callback_on_step_end (`Callable`, *optional*):
                 A function that is called at the end of each denoising step during inference. The function is called
                 with the following arguments:
@@ -465,14 +481,14 @@ def __call__(
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
-                plain tuple.
+                Whether or not to return a [`~pipelines.stable_video_diffusion.StableVideoDiffusionPipelineOutput`]
+                instead of a plain tuple.

         Examples:

         Returns:
-            [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
-                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is
+            [`~pipelines.stable_video_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.stable_video_diffusion.StableVideoDiffusionPipelineOutput`] is
                 returned, otherwise a `tuple` of (`list[list[PIL.Image.Image]]` or `np.ndarray` or `torch.Tensor`) is
                 returned.
         """
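Per the updated `image` docs, a couple of now-documented input forms, assuming `pipe` is a loaded `StableVideoDiffusionPipeline` (values in `[0, 1]`; shapes illustrative):

    import numpy as np
    import torch

    frames_np = pipe(np.random.rand(576, 1024, 3).astype(np.float32)).frames  # HWC array
    frames_pt = pipe(torch.rand(3, 576, 1024)).frames                         # CHW tensor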
@@ -487,45 +503,56 @@ def __call__(
         self.check_inputs(image, height, width)

         # 2. Define call parameters
-        if isinstance(image, PIL.Image.Image):
+        if isinstance(image, PIL.Image.Image) or isinstance(image, np.ndarray) and image.ndim == 3:
+            batch_size = 1
+        elif isinstance(image, torch.Tensor) and image.ndim == 3:
             batch_size = 1
         elif isinstance(image, list):
-            batch_size = len(image)
+            batch_size = (
+                sum(i.shape[0] for i in image)
+                if isinstance(image[0], (np.ndarray, torch.Tensor)) and image[0].ndim == 4
+                else len(image)
+            )
         else:
             batch_size = image.shape[0]
         device = self._execution_device
+        denoising_dtype = self.unet.dtype

         # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
-        self._guidance_scale = max_guidance_scale
+        self._guidance_scale = max(min_guidance_scale, max_guidance_scale)

         # 3. Encode input image
         image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
+        image_embeddings = image_embeddings.to(dtype=denoising_dtype)

         # NOTE: Stable Video Diffusion was conditioned on fps - 1, which is why it is reduced here.
         # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
         fps = fps - 1

         # 4. Encode input image using VAE
-        image = self.video_processor.preprocess(image, height=height, width=width).to(device)
-        noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
-        image = image + noise_aug_strength * noise
-
-        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+        vae_dtype = self.vae.dtype
+        needs_upcasting = vae_dtype == torch.float16 and self.vae.config.force_upcast
         if needs_upcasting:
             self.vae.to(dtype=torch.float32)

+        image = self.video_processor.preprocess(image, height=height, width=width).to(
+            device=device, dtype=torch.float32 if needs_upcasting else vae_dtype
+        )
+        noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
+        image = image + noise_aug_strength * noise
+
         image_latents = self._encode_vae_image(
             image,
             device=device,
             num_videos_per_prompt=num_videos_per_prompt,
             do_classifier_free_guidance=self.do_classifier_free_guidance,
         )
-        image_latents = image_latents.to(image_embeddings.dtype)
+        image_latents = image_latents.to(dtype=denoising_dtype)

-        # cast back to fp16 if needed
+        # restore the VAE dtype before the denoising loop
         if needs_upcasting:
-            self.vae.to(dtype=torch.float16)
+            self.vae.to(dtype=vae_dtype)

         # Repeat the image latents for each frame so we can concatenate them with the noise
         # image_latents [batch, channels, height, width] -> [batch, num_frames, channels, height, width]
@@ -536,7 +563,7 @@ def __call__(
             fps,
             motion_bucket_id,
             noise_aug_strength,
-            image_embeddings.dtype,
+            denoising_dtype,
             batch_size,
             num_videos_per_prompt,
             self.do_classifier_free_guidance,
         )
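The list branch now counts leading (batch) dimensions of 4-D entries instead of the list length; a sketch with hypothetical shapes:

    import numpy as np

    clips = [np.zeros((2, 576, 1024, 3)), np.zeros((3, 576, 1024, 3))]
    # old behaviour: batch_size = len(clips) == 2
    # new behaviour: batch_size = sum(c.shape[0] for c in clips) == 5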
@@ -560,17 +587,22 @@ def __call__(
             num_channels_latents,
             height,
             width,
-            image_embeddings.dtype,
+            torch.float32,
             device,
             generator,
             latents,
-        )
+        ).to(denoising_dtype)

         # 8. Prepare guidance scale
         guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
         guidance_scale = guidance_scale.to(device, latents.dtype)
-        guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
-        guidance_scale = _append_dims(guidance_scale, latents.ndim)
+        guidance_scale = torch.repeat_interleave(
+            guidance_scale,
+            repeats=batch_size * num_videos_per_prompt,
+            dim=0,
+            output_size=batch_size * num_videos_per_prompt,
+        )
+        guidance_scale = guidance_scale[:, :, None, None, None]

         self._guidance_scale = guidance_scale

@@ -582,6 +614,7 @@ def __call__(
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+                latent_model_input = latent_model_input.to(dtype=denoising_dtype)

                 # Concatenate image_latents over channels dimension
                 latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
@@ -602,6 +635,7 @@ def __call__(

                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+                latents = latents.to(dtype=denoising_dtype)

                 if callback_on_step_end is not None:
                     callback_kwargs = {}
@@ -618,10 +652,12 @@ def __call__(
                     xm.mark_step()

         if not output_type == "latent":
-            # cast back to fp16 if needed
+            # upcast the VAE for decoding if needed, then restore its original dtype
+            if needs_upcasting:
+                self.vae.to(dtype=torch.float32)
+            frames = self.decode_latents(latents.to(dtype=self.vae.dtype), num_frames, decode_chunk_size)
             if needs_upcasting:
-                self.vae.to(dtype=torch.float16)
-            frames = self.decode_latents(latents, num_frames, decode_chunk_size)
+                self.vae.to(dtype=vae_dtype)
             frames = self.video_processor.postprocess_video(video=frames, output_type=output_type)
         else:
             frames = latents
@@ -629,14 +665,19 @@ def __call__(
         self.maybe_free_model_hooks()

         if not return_dict:
-            return frames
+            return (frames,)

         return StableVideoDiffusionPipelineOutput(frames=frames)


 # resizing utils
 # TODO: clean up later
-def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
+def _resize_with_antialiasing(
+    input: torch.Tensor,
+    size: tuple[int, int],
+    interpolation: str = "bicubic",
+    align_corners: bool = True,
+) -> torch.Tensor:
     h, w = input.shape[-2:]
     factors = (h / size[0], w / size[1])
@@ -665,7 +706,7 @@ def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corner
     return output


-def _compute_padding(kernel_size):
+def _compute_padding(kernel_size: list[int] | tuple[int, ...]) -> list[int]:
     """Compute padding tuple."""
     # 4 or 6 ints: (padding_left, padding_right, padding_top, padding_bottom)
     # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
@@ -688,7 +729,7 @@ def _compute_padding(kernel_size):
     return out_padding


-def _filter2d(input, kernel):
+def _filter2d(input: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
     # prepare kernel
     b, c, h, w = input.shape
     tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
@@ -711,7 +752,7 @@ def _filter2d(input, kernel):
     return out


-def _gaussian(window_size: int, sigma):
+def _gaussian(window_size: int, sigma: float | torch.Tensor) -> torch.Tensor:
     if isinstance(sigma, float):
         sigma = torch.tensor([[sigma]])
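The removed `_append_dims` helper only appended trailing singleton dimensions; for the 5-D video latents, the explicit indexing in the guidance-scale hunk above is equivalent:

    import torch

    g = torch.linspace(1.0, 3.0, 25).unsqueeze(0)  # [1, num_frames]
    # same result as _append_dims(g, 5): pad with trailing singleton dims up to 5-D
    assert g[:, :, None, None, None].shape == (1, 25, 1, 1, 1)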
@@ -727,7 +768,11 @@ def _gaussian(window_size: int, sigma):
     return gauss / gauss.sum(-1, keepdim=True)


-def _gaussian_blur2d(input, kernel_size, sigma):
+def _gaussian_blur2d(
+    input: torch.Tensor,
+    kernel_size: tuple[int, int],
+    sigma: tuple[float, float] | torch.Tensor,
+) -> torch.Tensor:
     if isinstance(sigma, tuple):
         sigma = torch.tensor([sigma], dtype=input.dtype)
     else:
diff --git a/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py
index 52595f7a8cd9..b28e76129d4a 100644
--- a/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py
+++ b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py
@@ -151,7 +151,6 @@ def get_dummy_inputs(self, device, seed=0):
     def test_attention_slicing_forward_pass(self):
         pass

-    @unittest.skip("Batched inference works and outputs look correct, but the test is failing")
     def test_inference_batch_single_identical(
         self,
         batch_size=2,
@@ -188,10 +187,6 @@ def test_inference_batch_single_identical(
         max_diff = np.abs(to_np(output_batch[0]) - to_np(output[0])).max()
         assert max_diff < expected_max_diff

-    @unittest.skip("Test is similar to test_inference_batch_single_identical")
-    def test_inference_batch_consistent(self):
-        pass
-
     def test_np_output_type(self):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
@@ -226,7 +221,6 @@ def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4):
         max_diff = np.abs(to_np(output) - to_np(output_tuple)).max()
         self.assertLess(max_diff, expected_max_difference)

-    @unittest.skip("Test is currently failing")
     def test_float16_inference(self, expected_max_diff=5e-2):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 010a5176c684..957a526754d5 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -2081,7 +2081,7 @@ def is_nan(tensor):
             has_nan = torch.isnan(tensor).any()
             return has_nan

-        with tempfile.TemporaryDirectory() as tmpdir:
+        with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
             pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)

             pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, variant=variant)
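`ignore_cleanup_errors` is available from Python 3.10 (which the PEP 604 `X | Y` annotations elsewhere in this patch already require at runtime); it keeps transient file locks, e.g. on Windows, from turning temporary-directory teardown into a test failure:

    import tempfile

    with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
        ...  # cleanup errors inside tmpdir are swallowed instead of raised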