diff --git a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py index 351c8b65de0e..5de0bf059987 100644 --- a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py +++ b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py @@ -483,7 +483,7 @@ def prepare_latents( audio[:, :, : min(audio_length, audio_vae_length)] = initial_audio_waveforms[:, :, :audio_vae_length] encoded_audio = self.vae.encode(audio).latent_dist.sample(generator) - encoded_audio = encoded_audio.repeat((num_waveforms_per_prompt, 1, 1)) + encoded_audio = encoded_audio.repeat_interleave(num_waveforms_per_prompt, dim=0) latents = encoded_audio + latents return latents