Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 205 additions & 0 deletions docs/source/en/optimization/tpu.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# TorchTPU

[TorchTPU](https://github.com/pytorch/tpu) registers a `"tpu"` device type with PyTorch, enabling you to run
diffusers pipelines on Google Cloud TPUs (v4, v5p, v5e, …) with minimal code changes.

Three execution modes are available:

| Mode | How to activate | Speed | Notes |
|---|---|---|---|
| **Lazy** (default) | just `import torch_tpu` | baseline | XLA traces the graph lazily |
| **Eager** | `set_eager_mode(EagerMode.DEFER_NEVER)` | medium | dispatch ops eagerly |
| **Compile** | `pipe.enable_tpu_compile()` | fastest (~4–6×) | static compilation with `TpuBackend` |

## Installation

Follow the [TorchTPU installation guide](https://github.com/pytorch/tpu). After installation,
`import torch_tpu` registers the `"tpu"` device automatically.

## Text encoders always stay on CPU

XLA's static graph compiler does not support certain dynamic ops used in text encoders (notably
`index_select` on large embedding tables). Text encoders must therefore remain on CPU. Their
output embeddings are moved to the TPU after encoding.

Diffusers handles this transparently:
- `_execution_device` detects any component on TPU and returns that device.
- `encode_prompt` runs the text encoder on its own device (`cpu`) and moves the resulting
embeddings to the execution device (TPU).
- `randn_tensor` generates initial noise on CPU and moves it to TPU, avoiding a TPU RNG
unaligned DUS (dynamic-update-slice) bug.

## Basic usage (lazy mode)

```python
import torch
import torch_tpu # noqa: F401 — registers torch.tpu

from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.bfloat16,
)

# Move only the denoising components to TPU; text encoders stay on CPU.
pipe.transformer.to("tpu")
pipe.vae.to("tpu")

# _execution_device is now "tpu" automatically.
image = pipe(
prompt="a golden retriever surfing a wave, photorealistic",
height=1024,
width=1024,
num_inference_steps=4,
guidance_scale=0.0,
).images[0]

image.save("output.png")
```

## Compiled mode (recommended for production)

`torch.compile` with `TpuBackend` traces the transformer statically and gives the largest
speedup. The first call (warmup) is slow because it triggers compilation; subsequent calls
reuse the compiled graph.

```python
import torch
import torch_tpu # noqa: F401

from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.bfloat16,
)
pipe.transformer.to("tpu")
pipe.vae.to("tpu")

# Compile TPU components with TpuBackend.
# Also applies AttnProcessor to replace SDP-based attention (required for XLA).
pipe.enable_tpu_compile()

# Warmup — triggers static graph compilation.
pipe.tpu_warmup(
prompt="warmup",
height=1024,
width=1024,
num_inference_steps=4,
guidance_scale=0.0,
)

# Timed inference reuses the compiled graph.
image = pipe(
prompt="a golden retriever surfing a wave, photorealistic",
height=1024,
width=1024,
num_inference_steps=4,
guidance_scale=0.0,
).images[0]

image.save("output.png")
```

## SDXL

SDXL uses a UNet instead of a transformer. The same approach applies.

```python
import torch
import torch_tpu # noqa: F401

from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.bfloat16,
use_safetensors=True,
)

pipe.unet.to("tpu")
pipe.vae.to("tpu")

pipe.enable_tpu_compile()
pipe.tpu_warmup(
prompt="warmup",
height=1024,
width=1024,
num_inference_steps=20,
guidance_scale=7.5,
)

image = pipe(
prompt="a golden retriever surfing a wave, photorealistic",
height=1024,
width=1024,
num_inference_steps=20,
guidance_scale=7.5,
).images[0]

image.save("output.png")
```

> [!NOTE]
> In SDXL **lazy/eager mode** (without `enable_tpu_compile`), `time_proj` inside the UNet
> runs on CPU automatically to avoid an XLA unaligned DUS crash. `enable_tpu_compile` uses
> `TpuBackend` which handles the layout internally, so no wrapper is needed in compiled mode.

## Eager mode

Eager mode dispatches ops immediately instead of accumulating a lazy graph. Enter it
**before loading or moving models** to TPU:

```python
import torch
import torch_tpu # noqa: F401
from torch_tpu._internal.execution_mode import EagerMode, set_eager_mode

eager_ctx = set_eager_mode(EagerMode.DEFER_NEVER)
eager_ctx.__enter__()

from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.transformer.to("tpu")
pipe.vae.to("tpu")

image = pipe(prompt="a cat", height=1024, width=1024, num_inference_steps=4, guidance_scale=0.0).images[0]
image.save("output.png")

eager_ctx.__exit__(None, None, None)
```

## Performance benchmarks (v5p, BF16)

| Model | Mode | Steps | Resolution | Time/iter |
|---|---|---|---|---|
| FLUX.2-klein-9B | Lazy | 4 | 1024×1024 | 7.82 s |
| FLUX.2-klein-9B | Compile | 4 | 1024×1024 | 1.94 s |
| ERNIE-Image-Turbo | Lazy | 8 | 1024×1024 | 5.97 s |
| ERNIE-Image-Turbo | Compile | 8 | 1024×1024 | 2.24 s |
| Wan2.2-TI2V (video) | Eager | 50 | 480×832 | 82.2 s |
| Wan2.2-TI2V (video) | Compile | 50 | 480×832 | 14.2 s |

## API reference

### `enable_tpu_compile`

[[autodoc]] diffusers.DiffusionPipeline.enable_tpu_compile

### `tpu_warmup`

[[autodoc]] diffusers.DiffusionPipeline.tpu_warmup
15 changes: 12 additions & 3 deletions src/diffusers/models/unets/unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
deprecate,
logging,
)
from ...utils.torch_utils import is_compiled_module
from ..activations import get_activation
from ..attention import AttentionMixin
from ..attention_processor import (
Expand Down Expand Up @@ -855,18 +856,26 @@ def get_time_embed(self, sample: torch.Tensor, timestep: torch.Tensor | float |
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
is_npu = sample.device.type == "npu"
is_tpu = sample.device.type == "tpu"
if isinstance(timestep, float):
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
dtype = torch.float32 if (is_mps or is_npu or is_tpu) else torch.float64
else:
dtype = torch.int32 if (is_mps or is_npu) else torch.int64
dtype = torch.int32 if (is_mps or is_npu or is_tpu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)

# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])

t_emb = self.time_proj(timesteps)
# On TPU in eager/lazy mode, torch.cat([sin, cos], dim=-1) inside time_proj
# lands at an unaligned offset in the XLA DUS fusion emitter → crash.
# torch.compile with TpuBackend handles this internally, so skip the CPU
# workaround when we're inside a compiled graph.
if sample.device.type == "tpu" and not torch.compiler.is_compiling():
t_emb = self.time_proj(timesteps.cpu()).to(sample.device)
else:
t_emb = self.time_proj(timesteps)
# `Timesteps` does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _enhance_prompt_with_pe(
tokenize=False,
add_generation_prompt=False, # "Output:" is already in the user block
)
inputs = self.pe_tokenizer(input_text, return_tensors="pt").to(device)
inputs = self.pe_tokenizer(input_text, return_tensors="pt").to(self.pe.device)
output_ids = self.pe.generate(
**inputs,
max_new_tokens=self.pe_tokenizer.model_max_length,
Expand Down Expand Up @@ -155,7 +155,7 @@ def encode_prompt(
else:
ids = [0]

input_ids = torch.tensor([ids], device=device)
input_ids = torch.tensor([ids], device=self.text_encoder.device)
with torch.no_grad():
outputs = self.text_encoder(
input_ids=input_ids,
Expand Down
6 changes: 4 additions & 2 deletions src/diffusers/pipelines/flux/pipeline_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,8 @@ def _get_t5_prompt_embeds(
f" {max_sequence_length} tokens: {removed_text}"
)

prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
model_device = self.text_encoder_2.device
prompt_embeds = self.text_encoder_2(text_input_ids.to(model_device), output_hidden_states=False)[0]

dtype = self.text_encoder_2.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
Expand Down Expand Up @@ -296,7 +297,8 @@ def _get_clip_prompt_embeds(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
model_device = self.text_encoder.device
prompt_embeds = self.text_encoder(text_input_ids.to(model_device), output_hidden_states=False)

# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
Expand Down
6 changes: 4 additions & 2 deletions src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,8 @@ def _get_t5_prompt_embeds(
f" {max_sequence_length} tokens: {removed_text}"
)

prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
model_device = self.text_encoder_2.device
prompt_embeds = self.text_encoder_2(text_input_ids.to(model_device), output_hidden_states=False)[0]

dtype = self.text_encoder_2.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
Expand Down Expand Up @@ -327,7 +328,8 @@ def _get_clip_prompt_embeds(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
model_device = self.text_encoder.device
prompt_embeds = self.text_encoder(text_input_ids.to(model_device), output_hidden_states=False)

# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
Expand Down
6 changes: 4 additions & 2 deletions src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,8 @@ def _get_t5_prompt_embeds(
f" {max_sequence_length} tokens: {removed_text}"
)

prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
model_device = self.text_encoder_2.device
prompt_embeds = self.text_encoder_2(text_input_ids.to(model_device), output_hidden_states=False)[0]

dtype = self.text_encoder_2.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
Expand Down Expand Up @@ -321,7 +322,8 @@ def _get_clip_prompt_embeds(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
model_device = self.text_encoder.device
prompt_embeds = self.text_encoder(text_input_ids.to(model_device), output_hidden_states=False)

# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
Expand Down
6 changes: 4 additions & 2 deletions src/diffusers/pipelines/flux/pipeline_flux_kontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,8 @@ def _get_t5_prompt_embeds(
f" {max_sequence_length} tokens: {removed_text}"
)

prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
model_device = self.text_encoder_2.device
prompt_embeds = self.text_encoder_2(text_input_ids.to(model_device), output_hidden_states=False)[0]

dtype = self.text_encoder_2.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
Expand Down Expand Up @@ -343,7 +344,8 @@ def _get_clip_prompt_embeds(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
model_device = self.text_encoder.device
prompt_embeds = self.text_encoder(text_input_ids.to(model_device), output_hidden_states=False)

# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
Expand Down
6 changes: 4 additions & 2 deletions src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,8 @@ def _get_t5_prompt_embeds(
f" {max_sequence_length} tokens: {removed_text}"
)

prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
model_device = self.text_encoder_2.device
prompt_embeds = self.text_encoder_2(text_input_ids.to(model_device), output_hidden_states=False)[0]

dtype = self.text_encoder_2.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
Expand Down Expand Up @@ -376,7 +377,8 @@ def _get_clip_prompt_embeds(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
model_device = self.text_encoder.device
prompt_embeds = self.text_encoder(text_input_ids.to(model_device), output_hidden_states=False)

# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
Expand Down
5 changes: 3 additions & 2 deletions src/diffusers/pipelines/flux2/pipeline_flux2_klein.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,9 @@ def _get_qwen3_prompt_embeds(
all_input_ids.append(inputs["input_ids"])
all_attention_masks.append(inputs["attention_mask"])

input_ids = torch.cat(all_input_ids, dim=0).to(device)
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
model_device = text_encoder.device
input_ids = torch.cat(all_input_ids, dim=0).to(model_device)
attention_mask = torch.cat(all_attention_masks, dim=0).to(model_device)

# Forward pass through the model
output = text_encoder(
Expand Down
Loading
Loading