diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py index c487de89cace..766c34ed6934 100644 --- a/src/transformers/modeling_rope_utils.py +++ b/src/transformers/modeling_rope_utils.py @@ -719,7 +719,7 @@ def convert_rope_params_to_dict(self, **kwargs): partial_rotary_factor = kwargs.get("partial_rotary_factor", getattr(self, "partial_rotary_factor", None)) if partial_rotary_factor is not None: self.rope_parameters.setdefault("partial_rotary_factor", partial_rotary_factor) - self.ignore_keys_at_rope_validation = self.ignore_keys_at_rope_validation | {"partial_rotary_factor"} + self.ignore_keys_at_rope_validation = set(self.ignore_keys_at_rope_validation) | {"partial_rotary_factor"} self.standardize_rope_params() return kwargs diff --git a/tests/utils/test_modeling_rope_utils.py b/tests/utils/test_modeling_rope_utils.py index 3240a74bf838..779225f34ed1 100644 --- a/tests/utils/test_modeling_rope_utils.py +++ b/tests/utils/test_modeling_rope_utils.py @@ -136,6 +136,33 @@ def test_yarn_original_original_max_position_embeddings_validation(self): self.assertEqual(len(logs.output), 1) self.assertIn("implicit factor", logs.output[0]) + def test_convert_rope_params_to_dict_with_list_ignore_keys(self): + # Regression test: `ignore_keys_at_rope_validation` becomes a list when loaded from a config.json + # (JSON has no set type). `convert_rope_params_to_dict` used to do `list | set` and crash with + # TypeError when `partial_rotary_factor` was also set. 
+ config = LlamaConfig(partial_rotary_factor=0.25) + config.ignore_keys_at_rope_validation = ["mrope_section", "mrope_interleaved"] + + config.convert_rope_params_to_dict(partial_rotary_factor=0.25) + + self.assertIsInstance(config.ignore_keys_at_rope_validation, set) + self.assertEqual( + config.ignore_keys_at_rope_validation, + {"mrope_section", "mrope_interleaved", "partial_rotary_factor"}, + ) + + # Round-trip through from_dict to mimic the JSON-deserialized path that triggered this in production. + cfg_dict = config.to_dict() + cfg_dict["ignore_keys_at_rope_validation"] = ["mrope_section", "mrope_interleaved"] + reloaded = LlamaConfig.from_dict(cfg_dict) + reloaded.convert_rope_params_to_dict(partial_rotary_factor=0.25) + self.assertIsInstance(reloaded.ignore_keys_at_rope_validation, set) + # Pin the contents too: a type-only check would still pass if the keys loaded from the dict were lost. + self.assertEqual( + reloaded.ignore_keys_at_rope_validation, + {"mrope_section", "mrope_interleaved", "partial_rotary_factor"}, + ) + def test_rope_validation_with_per_attention_type_nested_rope(self): """Mirrors `test_rope_validation` with `config.layer_types` set, so that `rope_parameters` takes the per-attention-type nested shape."""