From f4e6c97014118906a2edf34035f39b22860c7335 Mon Sep 17 00:00:00 2001 From: Tai An Date: Thu, 30 Apr 2026 21:25:07 -0700 Subject: [PATCH 1/3] fix(unets): preserve scalar float timestep dtype in UNet1DModel --- src/diffusers/models/unets/unet_1d.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/unets/unet_1d.py b/src/diffusers/models/unets/unet_1d.py index 83ffe1f6f8cb..f210f989ecbc 100644 --- a/src/diffusers/models/unets/unet_1d.py +++ b/src/diffusers/models/unets/unet_1d.py @@ -227,7 +227,14 @@ def forward( # 1. time timesteps = timestep if not torch.is_tensor(timesteps): - timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" + if isinstance(timestep, float): + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + else: + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) From bbf4973e0ffefad37acdc7df4463d9cc9ddd402c Mon Sep 17 00:00:00 2001 From: Tai An Date: Thu, 30 Apr 2026 21:25:09 -0700 Subject: [PATCH 2/3] fix(unets): preserve scalar float timestep dtype in UNet2DModel --- src/diffusers/models/unets/unet_2d.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/unets/unet_2d.py b/src/diffusers/models/unets/unet_2d.py index 4e54f757d120..63399f839206 100644 --- a/src/diffusers/models/unets/unet_2d.py +++ b/src/diffusers/models/unets/unet_2d.py @@ -277,7 +277,14 @@ def forward( # 1. time timesteps = timestep if not torch.is_tensor(timesteps): - timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" + if isinstance(timestep, float): + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + else: + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) From a6bf97a30c43660c645abf62c64e1d2919aa1ec8 Mon Sep 17 00:00:00 2001 From: Tai An Date: Thu, 30 Apr 2026 21:25:10 -0700 Subject: [PATCH 3/3] fix(unets): preserve scalar float timestep dtype in FlaxUNet2DConditionModel --- src/diffusers/models/unets/unet_2d_condition_flax.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/unets/unet_2d_condition_flax.py b/src/diffusers/models/unets/unet_2d_condition_flax.py index a361026fc0ea..4ddadc50a4b4 100644 --- a/src/diffusers/models/unets/unet_2d_condition_flax.py +++ b/src/diffusers/models/unets/unet_2d_condition_flax.py @@ -370,7 +370,8 @@ def __call__( """ # 1. time if not isinstance(timesteps, jnp.ndarray): - timesteps = jnp.array([timesteps], dtype=jnp.int32) + dtype = jnp.float32 if isinstance(timesteps, float) else jnp.int32 + timesteps = jnp.array([timesteps], dtype=dtype) elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0: timesteps = timesteps.astype(dtype=jnp.float32) timesteps = jnp.expand_dims(timesteps, 0)