diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 080f852e2490..44773100995e 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -14,10 +14,15 @@
 import torch.nn.functional as F
 
-if getattr(torch, "distributed", None) is not None:
+if torch.distributed.is_available():
     from torch.distributed.fsdp import CPUOffload, ShardingStrategy
     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
     from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+else:
+    CPUOffload = None
+    ShardingStrategy = None
+    FSDP = None
+    transformer_auto_wrap_policy = None
 
 from .models import UNet2DConditionModel
 from .pipelines import DiffusionPipeline
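
For context, a minimal sketch of how downstream code can consume the guarded names: because the patch binds the FSDP symbols to None instead of letting the import fail, callers only need a None check before using them. The maybe_wrap_fsdp helper and the nn.TransformerEncoderLayer wrap target below are illustrative assumptions, not part of diffusers.

    # Sketch only: maybe_wrap_fsdp is a hypothetical helper, not a diffusers API.
    import functools

    import torch
    import torch.nn as nn

    if torch.distributed.is_available():
        from torch.distributed.fsdp import CPUOffload, ShardingStrategy
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
    else:
        # Mirrors the fallback in the patch: the names exist but are inert.
        CPUOffload = None
        ShardingStrategy = None
        FSDP = None
        transformer_auto_wrap_policy = None


    def maybe_wrap_fsdp(model: nn.Module) -> nn.Module:
        """Wrap `model` in FSDP when distributed support is usable, else return it unchanged."""
        # On builds without torch.distributed, FSDP is None rather than raising at
        # import time, so single-process code keeps working. The None check also
        # short-circuits before touching torch.distributed.is_initialized(), which
        # is only defined when distributed support is available.
        if FSDP is None or not torch.distributed.is_initialized():
            return model
        policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={nn.TransformerEncoderLayer},  # assumed layer type
        )
        return FSDP(
            model,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            cpu_offload=CPUOffload(offload_params=False),
            auto_wrap_policy=policy,
        )

The design choice the patch makes is sentinel bindings over try/except at every call site: torch.distributed.is_available() is the documented way to probe for distributed support, and assigning None keeps module import cheap while pushing the availability decision to the one place that actually wraps a model.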