Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 56 additions & 32 deletions tests/models/testing_utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,13 @@
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size

from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant, logging
from diffusers.utils.testing_utils import require_accelerator, require_torch_multi_accelerator

from ...testing_utils import assert_tensors_close, torch_device
from ...testing_utils import (
assert_tensors_close,
require_accelerator,
require_torch_multi_accelerator,
torch_device,
)


def named_persistent_module_tensors(
Expand Down Expand Up @@ -278,8 +282,30 @@ class TestMyModel(MyModelTestConfig, ModelTesterMixin):
pass
"""

@pytest.fixture(scope="class")
def base_model_output(self):
    """Reference forward-pass output, computed once per test class.

    Seeds the global torch RNG with 0, instantiates `self.model_class` from `self.get_init_dict()`,
    puts it in eval mode on `torch_device`, and runs a single gradient-free forward pass on
    `self.get_dummy_inputs()` with `return_dict=False`. The first element of that result is
    returned and, because the fixture is class-scoped, shared by every test in the class that
    requests it.

    The save/load and parallelism tests compare a reloaded model against this value rather than
    each rebuilding the model and repeating the forward pass. That reuse is only valid while model
    construction and `get_dummy_inputs` are deterministic under the fixed seed. Tests that still
    need a live model (e.g. to save it) must build their own after `torch.manual_seed(0)` as well,
    so that its weights match the model used here.

    Note on skipped tests: consumers gated by the `require_*` decorators are expected to be skipped
    before this fixture is set up, so a machine without the required accelerators should not pay
    for this forward pass. This relies on those decorators being `pytest.mark.skipif` marks —
    confirm against their definitions in the testing utilities.
    """
    # Seed first: weight initialisation below draws from the global RNG.
    torch.manual_seed(0)
    init_kwargs = self.get_init_dict()
    reference_model = self.model_class(**init_kwargs)
    reference_model = reference_model.eval().to(torch_device)

    with torch.no_grad():
        dummy_inputs = self.get_dummy_inputs()
        forward_result = reference_model(**dummy_inputs, return_dict=False)
        reference_output = forward_result[0]
    return reference_output

@torch.no_grad()
def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5):
def test_from_save_pretrained(self, base_model_output, tmp_path, atol=5e-5, rtol=5e-5):
torch.manual_seed(0)
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
Expand All @@ -296,13 +322,15 @@ def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5):
f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
)

image = model(**self.get_dummy_inputs(), return_dict=False)[0]
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]

assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
assert_tensors_close(
base_model_output, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes."
)

@torch.no_grad()
def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0):
def test_from_save_pretrained_variant(self, base_model_output, tmp_path, atol=5e-5, rtol=0):
torch.manual_seed(0)
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
Expand All @@ -317,10 +345,11 @@ def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0):

new_model.to(torch_device)

image = model(**self.get_dummy_inputs(), return_dict=False)[0]
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]

assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
assert_tensors_close(
base_model_output, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes."
)

@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
def test_from_save_pretrained_dtype(self, tmp_path, dtype):
Expand Down Expand Up @@ -360,13 +389,8 @@ def test_determinism(self, atol=1e-5, rtol=0):
)

@torch.no_grad()
def test_output(self, expected_output_shape=None):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()

inputs_dict = self.get_dummy_inputs()
output = model(**inputs_dict, return_dict=False)[0]
def test_output(self, base_model_output, expected_output_shape=None):
output = base_model_output

assert output is not None, "Model output is None"
assert output[0].shape == expected_output_shape or self.output_shape, (
Expand Down Expand Up @@ -509,14 +533,12 @@ def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype, atol=1e-4,

@require_accelerator
@torch.no_grad()
def test_sharded_checkpoints(self, tmp_path, atol=1e-5, rtol=0):
def test_sharded_checkpoints(self, base_model_output, tmp_path, atol=1e-5, rtol=0):
torch.manual_seed(0)
config = self.get_init_dict()
model = self.model_class(**config).eval()
model = model.to(torch_device)

base_output = model(**self.get_dummy_inputs(), return_dict=False)[0]

model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small

Expand All @@ -537,19 +559,17 @@ def test_sharded_checkpoints(self, tmp_path, atol=1e-5, rtol=0):
new_output = new_model(**self.get_dummy_inputs(), return_dict=False)[0]

assert_tensors_close(
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after sharded save/load"
base_model_output, new_output, atol=atol, rtol=rtol, msg="Output should match after sharded save/load"
)

@require_accelerator
@torch.no_grad()
def test_sharded_checkpoints_with_variant(self, tmp_path, atol=1e-5, rtol=0):
def test_sharded_checkpoints_with_variant(self, base_model_output, tmp_path, atol=1e-5, rtol=0):
torch.manual_seed(0)
config = self.get_init_dict()
model = self.model_class(**config).eval()
model = model.to(torch_device)

base_output = model(**self.get_dummy_inputs(), return_dict=False)[0]

model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
variant = "fp16"
Expand All @@ -575,20 +595,22 @@ def test_sharded_checkpoints_with_variant(self, tmp_path, atol=1e-5, rtol=0):
new_output = new_model(**self.get_dummy_inputs(), return_dict=False)[0]

assert_tensors_close(
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after variant sharded save/load"
base_model_output,
new_output,
atol=atol,
rtol=rtol,
msg="Output should match after variant sharded save/load",
)

@torch.no_grad()
def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rtol=0):
def test_sharded_checkpoints_with_parallel_loading(self, base_model_output, tmp_path, atol=1e-5, rtol=0):
from diffusers.utils import constants

torch.manual_seed(0)
config = self.get_init_dict()
model = self.model_class(**config).eval()
model = model.to(torch_device)

base_output = model(**self.get_dummy_inputs(), return_dict=False)[0]

model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small

Expand Down Expand Up @@ -624,7 +646,11 @@ def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rt
output_parallel = model_parallel(**self.get_dummy_inputs(), return_dict=False)[0]

assert_tensors_close(
base_output, output_parallel, atol=atol, rtol=rtol, msg="Output should match with parallel loading"
base_model_output,
output_parallel,
atol=atol,
rtol=rtol,
msg="Output should match with parallel loading",
)

finally:
Expand All @@ -635,19 +661,17 @@ def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rt

@require_torch_multi_accelerator
@torch.no_grad()
def test_model_parallelism(self, tmp_path, atol=1e-5, rtol=0):
def test_model_parallelism(self, base_model_output, tmp_path, atol=1e-5, rtol=0):
if self.model_class._no_split_modules is None:
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")

torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()

model = model.to(torch_device)

torch.manual_seed(0)
base_output = model(**inputs_dict, return_dict=False)[0]

model_size = compute_module_sizes(model)[""]
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]

Expand All @@ -665,5 +689,5 @@ def test_model_parallelism(self, tmp_path, atol=1e-5, rtol=0):
new_output = new_model(**inputs_dict, return_dict=False)[0]

assert_tensors_close(
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match with model parallelism"
base_model_output, new_output, atol=atol, rtol=rtol, msg="Output should match with model parallelism"
)
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@ def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:


class TestHunyuanDiT(HunyuanDiTTesterConfig, ModelTesterMixin):
def test_output(self):
def test_output(self, base_model_output):
batch_size = self.get_dummy_inputs()[self.main_input_name].shape[0]
super().test_output(expected_output_shape=(batch_size,) + self.output_shape)
super().test_output(base_model_output, expected_output_shape=(batch_size,) + self.output_shape)


class TestHunyuanDiTTraining(HunyuanDiTTesterConfig, TrainingTesterMixin):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,8 @@ def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:


class TestHunyuanVideoI2VTransformer(HunyuanVideoI2VTransformerTesterConfig, ModelTesterMixin):
def test_output(self):
super().test_output(expected_output_shape=(1, *self.output_shape))
def test_output(self, base_model_output):
super().test_output(base_model_output, expected_output_shape=(1, *self.output_shape))


# ======================== HunyuanVideo Token Replace Image-to-Video ========================
Expand Down Expand Up @@ -299,5 +299,5 @@ def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:


class TestHunyuanVideoTokenReplaceTransformer(HunyuanVideoTokenReplaceTransformerTesterConfig, ModelTesterMixin):
def test_output(self):
super().test_output(expected_output_shape=(1, *self.output_shape))
def test_output(self, base_model_output):
super().test_output(base_model_output, expected_output_shape=(1, *self.output_shape))
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,11 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
class TestWanAnimateTransformer3D(WanAnimateTransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for Wan Animate Transformer 3D."""

def test_output(self):
def test_output(self, base_model_output):
# Override test_output because the transformer output is expected to have less channels
# than the main transformer input.
expected_output_shape = (1, 4, 21, 16, 16)
super().test_output(expected_output_shape=expected_output_shape)
super().test_output(base_model_output, expected_output_shape=expected_output_shape)

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
Expand Down
Loading