diff --git a/src/diffusers/models/unets/unet_1d.py b/src/diffusers/models/unets/unet_1d.py index 83ffe1f6f8cb..f210f989ecbc 100644 --- a/src/diffusers/models/unets/unet_1d.py +++ b/src/diffusers/models/unets/unet_1d.py @@ -227,7 +227,14 @@ def forward( # 1. time timesteps = timestep if not torch.is_tensor(timesteps): - timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" + if isinstance(timestep, float): + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + else: + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/models/unets/unet_2d.py b/src/diffusers/models/unets/unet_2d.py index 4e54f757d120..63399f839206 100644 --- a/src/diffusers/models/unets/unet_2d.py +++ b/src/diffusers/models/unets/unet_2d.py @@ -277,7 +277,14 @@ def forward( # 1. time timesteps = timestep if not torch.is_tensor(timesteps): - timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" + if isinstance(timestep, float): + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + else: + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/models/unets/unet_2d_condition_flax.py b/src/diffusers/models/unets/unet_2d_condition_flax.py index a361026fc0ea..4ddadc50a4b4 100644 --- a/src/diffusers/models/unets/unet_2d_condition_flax.py +++ b/src/diffusers/models/unets/unet_2d_condition_flax.py @@ -370,7 +370,8 @@ def __call__( """ # 1. time if not isinstance(timesteps, jnp.ndarray): - timesteps = jnp.array([timesteps], dtype=jnp.int32) + dtype = jnp.float32 if isinstance(timesteps, float) else jnp.int32 + timesteps = jnp.array([timesteps], dtype=dtype) elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0: timesteps = timesteps.astype(dtype=jnp.float32) timesteps = jnp.expand_dims(timesteps, 0)