Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import torch

from ...image_processor import VaeImageProcessor
from ...models import UNet2DModel
from ...schedulers import CMStochasticIterativeScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
Expand Down Expand Up @@ -89,9 +90,20 @@ def __init__(self, unet: UNet2DModel, scheduler: CMStochasticIterativeScheduler)
scheduler=scheduler,
)

self.image_processor = VaeImageProcessor()
self.safety_checker = None

def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
def prepare_latents(
self,
batch_size: int,
num_channels: int,
height: int,
width: int,
dtype: torch.dtype,
device: torch.device,
generator: torch.Generator | list[torch.Generator] | None,
latents: torch.Tensor | None = None,
) -> torch.Tensor:
shape = (batch_size, num_channels, height, width)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
Expand All @@ -108,44 +120,57 @@ def prepare_latents(self, batch_size, num_channels, height, width, dtype, device
latents = latents * self.scheduler.init_noise_sigma
return latents

# Follows diffusers.VaeImageProcessor.postprocess
def postprocess_image(self, sample: torch.Tensor, output_type: str = "pil"):
if output_type not in ["pt", "np", "pil"]:
raise ValueError(
f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']"
)

# Equivalent to diffusers.VaeImageProcessor.denormalize
sample = (sample / 2 + 0.5).clamp(0, 1)
if output_type == "pt":
return sample

# Equivalent to diffusers.VaeImageProcessor.pt_to_numpy
sample = sample.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "np":
return sample

# Output_type must be 'pil'
sample = self.numpy_to_pil(sample)
return sample

def prepare_class_labels(self, batch_size, device, class_labels=None):
def prepare_class_labels(
self,
batch_size: int,
device: torch.device,
generator: torch.Generator | list[torch.Generator] | None = None,
class_labels: torch.Tensor | list[int] | int | None = None,
) -> torch.Tensor | None:
if self.unet.config.num_class_embeds is not None:
if isinstance(class_labels, list):
class_labels = torch.tensor(class_labels, dtype=torch.int)
elif isinstance(class_labels, int):
assert batch_size == 1, "Batch size must be 1 if classes is an int"
class_labels = torch.tensor([class_labels], dtype=torch.int)
elif class_labels is None:
# Randomly generate batch_size class labels
# TODO: should use generator here? int analogue of randn_tensor is not exposed in ...utils
class_labels = torch.randint(0, self.unet.config.num_class_embeds, size=(batch_size,))
if isinstance(generator, list):
class_labels = torch.cat(
[
torch.randint(
0,
self.unet.config.num_class_embeds,
size=(1,),
generator=g,
device=g.device,
).cpu()
for g in generator
]
)
else:
rand_device = generator.device if generator is not None else torch.device("cpu")
class_labels = torch.randint(
0,
self.unet.config.num_class_embeds,
size=(batch_size,),
generator=generator,
device=rand_device,
)
class_labels = class_labels.to(device)
else:
class_labels = None
return class_labels

def check_inputs(self, num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps):
def check_inputs(
self,
num_inference_steps: int | None,
timesteps: list[int] | None,
latents: torch.Tensor | None,
batch_size: int,
height: int,
width: int,
callback_steps: int,
) -> None:
if num_inference_steps is None and timesteps is None:
raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.")

Expand All @@ -156,7 +181,7 @@ def check_inputs(self, num_inference_steps, timesteps, latents, batch_size, img_
)

if latents is not None:
expected_shape = (batch_size, 3, img_size, img_size)
expected_shape = (batch_size, self.unet.config.in_channels, height, width)
if latents.shape != expected_shape:
raise ValueError(f"The shape of latents is {latents.shape} but is expected to be {expected_shape}.")

Expand All @@ -175,36 +200,37 @@ def __call__(
batch_size: int = 1,
class_labels: torch.Tensor | list[int] | int | None = None,
num_inference_steps: int = 1,
timesteps: list[int] = None,
timesteps: list[int] | None = None,
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.Tensor | None = None,
output_type: str | None = "pil",
output_type: str = "pil",
return_dict: bool = True,
callback: Callable[[int, int, torch.Tensor], None] | None = None,
callback_steps: int = 1,
):
) -> ImagePipelineOutput | tuple:
r"""
Args:
batch_size (`int`, *optional*, defaults to 1):
The number of images to generate.
class_labels (`torch.Tensor` or `list[int]` or `int`, *optional*):
Optional class labels for conditioning class-conditional consistency models. Not used if the model is
not class-conditional.
Optional class labels for conditioning class-conditional consistency models. If not provided for a
class-conditional model, labels are sampled randomly using `generator`. Not used if the model is not
class-conditional.
num_inference_steps (`int`, *optional*, defaults to 1):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`list[int]`, *optional*):
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
generator (`torch.Generator`, *optional*):
generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
The output format of the generated image. Choose between `"pil"`, `"np"`, or `"pt"`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
Expand All @@ -222,27 +248,31 @@ def __call__(
returned where the first element is a list with the generated images.
"""
# 0. Prepare call parameters
img_size = self.unet.config.sample_size
sample_size = self.unet.config.sample_size
if isinstance(sample_size, int):
height = width = sample_size
else:
height, width = sample_size
device = self._execution_device

# 1. Check inputs
self.check_inputs(num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps)
self.check_inputs(num_inference_steps, timesteps, latents, batch_size, height, width, callback_steps)

# 2. Prepare image latents
# Sample image latents x_0 ~ N(0, sigma_0^2 * I)
sample = self.prepare_latents(
batch_size=batch_size,
num_channels=self.unet.config.in_channels,
height=img_size,
width=img_size,
height=height,
width=width,
dtype=self.unet.dtype,
device=device,
generator=generator,
latents=latents,
)

# 3. Handle class_labels for class-conditional models
class_labels = self.prepare_class_labels(batch_size, device, class_labels=class_labels)
class_labels = self.prepare_class_labels(batch_size, device, generator=generator, class_labels=class_labels)

# 4. Prepare timesteps
if timesteps is not None:
Expand Down Expand Up @@ -271,7 +301,11 @@ def __call__(
xm.mark_step()

# 6. Post-process image sample
image = self.postprocess_image(sample, output_type=output_type)
if output_type not in ["pt", "np", "pil"]:
raise ValueError(
f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']"
)
image = self.image_processor.postprocess(sample, output_type=output_type)

# Offload all models
self.maybe_free_model_hooks()
Expand Down
100 changes: 100 additions & 0 deletions tests/pipelines/consistency_models/test_consistency_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,26 @@ def get_dummy_components(self, class_cond=False):

return components

def get_tiny_components(self, sample_size=8, in_channels=3, out_channels=3, num_class_embeds=None):
torch.manual_seed(0)
unet = UNet2DModel(
sample_size=sample_size,
in_channels=in_channels,
out_channels=out_channels,
layers_per_block=1,
block_out_channels=(8,),
down_block_types=("DownBlock2D",),
up_block_types=("UpBlock2D",),
norm_num_groups=4,
num_class_embeds=num_class_embeds,
)
scheduler = CMStochasticIterativeScheduler(
num_train_timesteps=40,
sigma_min=0.002,
sigma_max=80.0,
)
return {"unet": unet, "scheduler": scheduler}

def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
Expand Down Expand Up @@ -168,6 +188,86 @@ def test_consistency_model_pipeline_onestep_class_cond(self):

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

def test_random_class_labels_use_generator(self):
device = "cpu"
components = self.get_tiny_components(num_class_embeds=10)
pipe = ConsistencyModelPipeline(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)

generator = torch.Generator(device=device).manual_seed(123)
labels = pipe.prepare_class_labels(batch_size=4, device=torch.device(device), generator=generator)

expected_generator = torch.Generator(device=device).manual_seed(123)
expected_labels = torch.randint(0, pipe.unet.config.num_class_embeds, size=(4,), generator=expected_generator)
assert torch.equal(labels.cpu(), expected_labels)

list_labels = pipe.prepare_class_labels(batch_size=2, device=torch.device(device), class_labels=[1, 2])
int_labels = pipe.prepare_class_labels(batch_size=1, device=torch.device(device), class_labels=1)
assert list_labels.dtype == torch.int
assert int_labels.dtype == torch.int

latents = torch.zeros((1, 3, 8, 8))
generator = torch.Generator(device=device).manual_seed(123)
image = pipe(
batch_size=1,
class_labels=None,
num_inference_steps=1,
timesteps=None,
generator=generator,
latents=latents,
output_type="pt",
).images

generator = torch.Generator(device=device).manual_seed(123)
image_2 = pipe(
batch_size=1,
class_labels=None,
num_inference_steps=1,
timesteps=None,
generator=generator,
latents=latents,
output_type="pt",
).images

assert torch.equal(image, image_2)

def test_latents_use_unet_in_channels(self):
device = "cpu"
components = self.get_tiny_components(in_channels=1, out_channels=1)
pipe = ConsistencyModelPipeline(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)

latents = torch.zeros((1, 1, 8, 8))
image = pipe(
batch_size=1,
num_inference_steps=1,
timesteps=None,
latents=latents,
output_type="pt",
).images

assert image.shape == (1, 1, 8, 8)

def test_tuple_sample_size(self):
device = "cpu"
components = self.get_tiny_components(sample_size=(8, 10))
pipe = ConsistencyModelPipeline(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)

generator = torch.Generator(device=device).manual_seed(0)
image = pipe(
batch_size=1,
num_inference_steps=1,
timesteps=None,
generator=generator,
output_type="pt",
).images

assert image.shape == (1, 3, 8, 10)


@nightly
@require_torch_accelerator
Expand Down
Loading