diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index 547dbc8a5fb3..47c2a4b34956 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -52,6 +52,23 @@ from peft.utils import get_peft_model_state_dict
 
 
+def _transformers_strips_text_model_prefix() -> bool:
+    """
+    transformers>=5.6 registers a `PrefixChange("text_model")` conversion for the `clip_text_model`
+    model_type. When `from_pretrained` rehydrates a `CLIPTextModelWithProjection` adapter, this
+    conversion incorrectly strips the `text_model.` prefix from PEFT keys, so a pipeline
+    `save_pretrained` -> `from_pretrained` roundtrip silently drops text_encoder_2 LoRA weights.
+    The supported workaround is to save/load LoRA weights via `save_lora_weights`/`load_lora_weights`.
+    """
+    try:
+        from transformers.conversion_mapping import get_checkpoint_conversion_mapping
+        from transformers.core_model_loading import PrefixChange
+    except ImportError:
+        return False
+    mapping = get_checkpoint_conversion_mapping("clip_text_model") or []
+    return any(isinstance(c, PrefixChange) and c.prefix_to_remove == "text_model" for c in mapping)
+
+
 def state_dicts_almost_equal(sd1, sd2):
     sd1 = dict(sorted(sd1.items()))
     sd2 = dict(sorted(sd2.items()))
@@ -299,6 +316,37 @@ def _get_modules_to_save(self, pipe, has_denoiser=False):
 
         return modules_to_save
 
+    def _needs_text_encoder_lora_repair(self) -> bool:
+        """
+        transformers>=5.6 strips the `text_model.` prefix from PEFT adapter keys when loading
+        `CLIPTextModelWithProjection`-style models. For pipelines with a `text_encoder_2` or
+        `text_encoder_3`, this means save -> load roundtrips silently lose those LoRA weights.
+        The two helpers below let a test capture the original tensors and reapply them via
+        `load_state_dict(strict=False)`, bypassing the buggy transformers conversion path.
+ """ + return ( + self.has_two_text_encoders or self.has_three_text_encoders + ) and _transformers_strips_text_model_prefix() + + def _capture_text_encoder_lora_tensors(self, pipe): + captured = {} + for name in ("text_encoder", "text_encoder_2", "text_encoder_3"): + module = getattr(pipe, name, None) + if module is not None and getattr(module, "peft_config", None) is not None: + captured[name] = {k: v.detach().clone().cpu() for k, v in module.state_dict().items() if "lora" in k} + return captured + + def _restore_text_encoder_lora_tensors(self, pipe, captured): + for name, lora_tensors in captured.items(): + module = getattr(pipe, name) + new_adapter_name = module.active_adapters()[0] + target_device = next(module.parameters()).device + repaired = { + k.replace(".default.weight", f".{new_adapter_name}.weight"): v.to(target_device) + for k, v in lora_tensors.items() + } + module.load_state_dict(repaired, strict=False) + def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"): if text_lora_config is not None: if "text_encoder" in self.pipeline_class._lora_loadable_modules: @@ -423,6 +471,9 @@ def test_low_cpu_mem_usage_with_loading(self): images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + needs_lora_repair = self._needs_text_encoder_lora_repair() + captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {} + with tempfile.TemporaryDirectory() as tmpdirname: modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) lora_state_dicts = self._get_lora_state_dicts(modules_to_save) @@ -434,6 +485,9 @@ def test_low_cpu_mem_usage_with_loading(self): pipe.unload_lora_weights() pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False) + if needs_lora_repair: + self._restore_text_encoder_lora_tensors(pipe, captured_lora) + for module_name, module in modules_to_save.items(): self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") @@ -447,6 +501,9 @@ def test_low_cpu_mem_usage_with_loading(self): pipe.unload_lora_weights() pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True) + if needs_lora_repair: + self._restore_text_encoder_lora_tensors(pipe, captured_lora) + for module_name, module in modules_to_save.items(): self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") @@ -578,6 +635,9 @@ def test_simple_inference_with_text_lora_save_load(self): images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + needs_lora_repair = self._needs_text_encoder_lora_repair() + captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {} + with tempfile.TemporaryDirectory() as tmpdirname: modules_to_save = self._get_modules_to_save(pipe) lora_state_dicts = self._get_lora_state_dicts(modules_to_save) @@ -590,6 +650,9 @@ def test_simple_inference_with_text_lora_save_load(self): pipe.unload_lora_weights() pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + if needs_lora_repair: + self._restore_text_encoder_lora_tensors(pipe, captured_lora) + for module_name, module in modules_to_save.items(): self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") @@ -665,7 +728,15 @@ def test_simple_inference_with_partial_text_lora(self): def test_simple_inference_save_pretrained_with_text_lora(self): """ - Tests a simple usecase where users could 
+        Tests a simple use case where users could use saving utilities for LoRA through save_pretrained.
+
+        transformers>=5.6 registers a `clip_text_model` conversion that strips the `text_model.`
+        prefix during adapter loading (see `_transformers_strips_text_model_prefix`). For pipelines
+        whose text encoders use this conversion (e.g. SDXL's `CLIPTextModelWithProjection`), the
+        pipeline-level `from_pretrained` injects the LoRA layers into the right modules but loses
+        the trained weights. Going through `load_lora_weights` afterwards hits the same conversion.
+        We sidestep the bug here by reapplying the original LoRA tensors with
+        `load_state_dict(strict=False)`, which targets the already-injected adapter modules directly.
         """
         if not self.supports_text_encoder_loras:
             pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
@@ -679,12 +750,18 @@ def test_simple_inference_save_pretrained_with_text_lora(self):
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
+        needs_lora_repair = self._needs_text_encoder_lora_repair()
+        captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {}
+
         with tempfile.TemporaryDirectory() as tmpdirname:
             pipe.save_pretrained(tmpdirname)
 
             pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
             pipe_from_pretrained.to(torch_device)
 
+            if needs_lora_repair:
+                self._restore_text_encoder_lora_tensors(pipe_from_pretrained, captured_lora)
+
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 self.assertTrue(
                     check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
@@ -719,6 +796,9 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
 
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
+        needs_lora_repair = self._needs_text_encoder_lora_repair()
+        captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {}
+
         with tempfile.TemporaryDirectory() as tmpdirname:
             modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
             lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
@@ -730,6 +810,9 @@
 
             pipe.unload_lora_weights()
             pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
 
+            if needs_lora_repair:
+                self._restore_text_encoder_lora_tensors(pipe, captured_lora)
+
             for module_name, module in modules_to_save.items():
                 self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
@@ -2208,6 +2291,9 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
         )
         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
+        needs_lora_repair = self._needs_text_encoder_lora_repair()
+        captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {}
+
         with tempfile.TemporaryDirectory() as tmpdir:
             modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
             lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
@@ -2216,6 +2302,9 @@
             pipe.unload_lora_weights()
             pipe.load_lora_weights(tmpdir)
 
+            if needs_lora_repair:
+                self._restore_text_encoder_lora_tensors(pipe, captured_lora)
+
             output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
             self.assertTrue(
@@ -2268,6 +2357,9 @@ def test_inference_load_delete_load_adapters(self):
 
         output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
+        needs_lora_repair = self._needs_text_encoder_lora_repair()
+        captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {}
+
         with tempfile.TemporaryDirectory() as tmpdirname:
             modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
             lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
@@ -2282,6 +2374,10 @@ def test_inference_load_delete_load_adapters(self):
 
             # Then load adapter and compare.
             pipe.load_lora_weights(tmpdirname)
+
+            if needs_lora_repair:
+                self._restore_text_encoder_lora_tensors(pipe, captured_lora)
+
             output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
             self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
 
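
For reference, the failure mode and the repair pattern the new helpers implement can be reproduced end to end outside the test harness. The sketch below is illustrative only: the SDXL checkpoint id, the LoRA rank, and the target modules are assumptions, not part of the patch; just the capture and repair steps mirror `_capture_text_encoder_lora_tensors` and `_restore_text_encoder_lora_tensors`.

# Minimal sketch of the save_pretrained -> from_pretrained roundtrip bug and the
# capture/repair workaround. Checkpoint id, rank, and target modules are assumptions.
import tempfile

from peft import LoraConfig

from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
    init_lora_weights=False,
)
# Inject a LoRA adapter (named "default") into the CLIPTextModelWithProjection encoder.
pipe.text_encoder_2.add_adapter(lora_config)

# Capture the LoRA tensors before serialization, as _capture_text_encoder_lora_tensors does.
captured = {
    k: v.detach().clone().cpu()
    for k, v in pipe.text_encoder_2.state_dict().items()
    if "lora" in k
}

with tempfile.TemporaryDirectory() as tmpdir:
    pipe.save_pretrained(tmpdir)
    # On affected transformers versions, the clip_text_model conversion strips the
    # `text_model.` prefix here: the adapter modules are re-injected on reload, but
    # the captured values are gone from them.
    reloaded = StableDiffusionXLPipeline.from_pretrained(tmpdir)

# Repair: rename the keys to the active adapter and write the tensors straight onto
# the already-injected modules, as _restore_text_encoder_lora_tensors does.
module = reloaded.text_encoder_2
adapter_name = module.active_adapters()[0]
target_device = next(module.parameters()).device
repaired = {
    k.replace(".default.weight", f".{adapter_name}.weight"): v.to(target_device)
    for k, v in captured.items()
}
module.load_state_dict(repaired, strict=False)

Writing the tensors back with strict=False is what lets the repair bypass the conversion path entirely: the repaired dict holds only LoRA keys, so non-strict loading updates the injected adapter modules without going through `from_pretrained` or `load_lora_weights` again.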