@property
def device(self) -> torch.device:
    """
    `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
    device).
    """
    # Deliberately NOT cached: with group offloading, the effective device changes
    # per-forward as parameter groups onload/offload, so a memoized value would lie.
    return get_parameter_device(self)

@property
def dtype(self) -> torch.dtype:
    """
    `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
    """
    # Memoize: `get_parameter_dtype` walks the parameter tree on every call.
    # Read/write through the instance `__dict__` directly to bypass the cost of
    # `nn.Module.__getattr__` attribute resolution.
    cached = self.__dict__.get("_cached_dtype")
    if cached is None:
        cached = get_parameter_dtype(self)
        self.__dict__["_cached_dtype"] = cached
    return cached

def _apply(self, fn, *args, **kwargs):
    # All in-place conversions (`.to()`, `.cpu()`, `.cuda()`, `.half()`, `.float()`,
    # etc.) funnel through `_apply`, making it the single choke point for
    # dtype-changing conversions. Pop BEFORE delegating so an exception raised
    # mid-conversion cannot leave a stale cached dtype behind.
    self.__dict__.pop("_cached_dtype", None)
    return super()._apply(fn, *args, **kwargs)

def __setattr__(self, name, value):
    # Fix: invalidating only in `_apply` misses dtype changes introduced by
    # *assignment* — e.g. attaching a new submodule or parameter of a different
    # dtype after `dtype` was first read. Any assignment that can carry tensors
    # drops the cache; the `isinstance` check is cheap for ordinary attributes.
    # NOTE(review): direct mutation of `_parameters`/`_buffers` and calling
    # `.to()` on a child module directly still bypass invalidation — confirm
    # this staleness window is acceptable upstream.
    if isinstance(value, (torch.Tensor, torch.nn.Module)):
        self.__dict__.pop("_cached_dtype", None)
    super().__setattr__(name, value)