diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index 85e59adc39a4..f07b9a5b2078 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -16,6 +16,7 @@ import torch +from ...image_processor import VaeImageProcessor from ...models import UNet2DModel from ...schedulers import CMStochasticIterativeScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring @@ -89,9 +90,20 @@ def __init__(self, unet: UNet2DModel, scheduler: CMStochasticIterativeScheduler) scheduler=scheduler, ) + self.image_processor = VaeImageProcessor() self.safety_checker = None - def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None): + def prepare_latents( + self, + batch_size: int, + num_channels: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: torch.Generator | list[torch.Generator] | None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: shape = (batch_size, num_channels, height, width) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -108,28 +120,13 @@ def prepare_latents(self, batch_size, num_channels, height, width, dtype, device latents = latents * self.scheduler.init_noise_sigma return latents - # Follows diffusers.VaeImageProcessor.postprocess - def postprocess_image(self, sample: torch.Tensor, output_type: str = "pil"): - if output_type not in ["pt", "np", "pil"]: - raise ValueError( - f"output_type={output_type} is not supported. 
def prepare_class_labels(
    self,
    batch_size: int,
    device: torch.device,
    generator: torch.Generator | list[torch.Generator] | None = None,
    class_labels: torch.Tensor | list[int] | int | None = None,
) -> torch.Tensor | None:
    """Normalize `class_labels` for a class-conditional consistency model.

    Args:
        batch_size: Number of images being generated; the returned tensor has this length.
        device: Device the labels are moved to before being returned.
        generator: Optional generator (or per-sample list of generators) used to sample
            random labels when `class_labels` is None.
        class_labels: Explicit labels as a tensor, a list of ints, or a single int
            (single int requires `batch_size == 1`). If None, labels are sampled
            uniformly from `[0, num_class_embeds)`.

    Returns:
        An int tensor of labels on `device`, or None when the UNet is not
        class-conditional (`num_class_embeds` is None).
    """
    if self.unet.config.num_class_embeds is not None:
        if isinstance(class_labels, list):
            class_labels = torch.tensor(class_labels, dtype=torch.int)
        elif isinstance(class_labels, int):
            assert batch_size == 1, "Batch size must be 1 if classes is an int"
            class_labels = torch.tensor([class_labels], dtype=torch.int)
        elif class_labels is None:
            if isinstance(generator, list):
                # Mirror prepare_latents: a list of generators must match the batch
                # size, otherwise we would silently return a wrong-sized label tensor
                # (one label is drawn per generator below).
                if len(generator) != batch_size:
                    raise ValueError(
                        f"You have passed a list of generators of length {len(generator)}, but requested an"
                        f" effective batch size of {batch_size}. Make sure the batch size matches the length"
                        " of the generators."
                    )
                # Draw each label on its own generator's device for determinism,
                # then gather on CPU before the final move to `device`.
                class_labels = torch.cat(
                    [
                        torch.randint(
                            0,
                            self.unet.config.num_class_embeds,
                            size=(1,),
                            generator=g,
                            device=g.device,
                        ).cpu()
                        for g in generator
                    ]
                )
            else:
                # A single generator must be used on its own device; fall back to CPU
                # when no generator is supplied.
                rand_device = generator.device if generator is not None else torch.device("cpu")
                class_labels = torch.randint(
                    0,
                    self.unet.config.num_class_embeds,
                    size=(batch_size,),
                    generator=generator,
                    device=rand_device,
                )
        class_labels = class_labels.to(device)
    else:
        # Model is not class-conditional; any supplied labels are ignored.
        class_labels = None
    return class_labels
None, callback_steps: int = 1, - ): + ) -> ImagePipelineOutput | tuple: r""" Args: batch_size (`int`, *optional*, defaults to 1): The number of images to generate. class_labels (`torch.Tensor` or `list[int]` or `int`, *optional*): - Optional class labels for conditioning class-conditional consistency models. Not used if the model is - not class-conditional. + Optional class labels for conditioning class-conditional consistency models. If not provided for a + class-conditional model, labels are sampled randomly using `generator`. Not used if the model is not + class-conditional. num_inference_steps (`int`, *optional*, defaults to 1): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`list[int]`, *optional*): Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` timesteps are used. Must be in descending order. - generator (`torch.Generator`, *optional*): + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): @@ -204,7 +230,7 @@ def __call__( generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated image. Choose between `PIL.Image` or `np.array`. + The output format of the generated image. Choose between `"pil"`, `"np"`, or `"pt"`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): @@ -222,19 +248,23 @@ def __call__( returned where the first element is a list with the generated images. """ # 0. 
Prepare call parameters - img_size = self.unet.config.sample_size + sample_size = self.unet.config.sample_size + if isinstance(sample_size, int): + height = width = sample_size + else: + height, width = sample_size device = self._execution_device # 1. Check inputs - self.check_inputs(num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps) + self.check_inputs(num_inference_steps, timesteps, latents, batch_size, height, width, callback_steps) # 2. Prepare image latents # Sample image latents x_0 ~ N(0, sigma_0^2 * I) sample = self.prepare_latents( batch_size=batch_size, num_channels=self.unet.config.in_channels, - height=img_size, - width=img_size, + height=height, + width=width, dtype=self.unet.dtype, device=device, generator=generator, @@ -242,7 +272,7 @@ def __call__( ) # 3. Handle class_labels for class-conditional models - class_labels = self.prepare_class_labels(batch_size, device, class_labels=class_labels) + class_labels = self.prepare_class_labels(batch_size, device, generator=generator, class_labels=class_labels) # 4. Prepare timesteps if timesteps is not None: @@ -271,7 +301,11 @@ def __call__( xm.mark_step() # 6. Post-process image sample - image = self.postprocess_image(sample, output_type=output_type) + if output_type not in ["pt", "np", "pil"]: + raise ValueError( + f"output_type={output_type} is not supported. 
def get_tiny_components(self, sample_size=8, in_channels=3, out_channels=3, num_class_embeds=None):
    """Build a minimal single-block UNet and CM scheduler for fast pipeline tests."""
    torch.manual_seed(0)
    unet_kwargs = dict(
        sample_size=sample_size,
        in_channels=in_channels,
        out_channels=out_channels,
        layers_per_block=1,
        block_out_channels=(8,),
        down_block_types=("DownBlock2D",),
        up_block_types=("UpBlock2D",),
        norm_num_groups=4,
        num_class_embeds=num_class_embeds,
    )
    scheduler_kwargs = dict(num_train_timesteps=40, sigma_min=0.002, sigma_max=80.0)
    return {
        "unet": UNet2DModel(**unet_kwargs),
        "scheduler": CMStochasticIterativeScheduler(**scheduler_kwargs),
    }
def test_tuple_sample_size(self):
    """A non-square `sample_size` tuple on the UNet config must yield matching output shape."""
    device = "cpu"
    pipe = ConsistencyModelPipeline(**self.get_tiny_components(sample_size=(8, 10)))
    pipe = pipe.to(device)
    pipe.set_progress_bar_config(disable=None)

    images = pipe(
        batch_size=1,
        num_inference_steps=1,
        timesteps=None,
        generator=torch.Generator(device=device).manual_seed(0),
        output_type="pt",
    ).images

    assert images.shape == (1, 3, 8, 10)