
Commit 9c1237a

claude authored and brendanlong committed
Replace hardcoded Gemma config dicts with HF config cache entries
Move all Gemma 1, 2, and 3 models to _HF_CONFIG_CACHE using their native HF config classes (GemmaConfig, Gemma2Config, Gemma3TextConfig, Gemma3Config). Add generic architecture handlers for GemmaForCausalLM, Gemma2ForCausalLM, Gemma3ForCausalLM, and Gemma3ForConditionalGeneration that read from the HF config objects, eliminating ~460 lines of hardcoded dicts and the name-based architecture detection.

Fixes google/gemma-2b incorrectly getting the Gemma2ForCausalLM architecture (it now correctly gets GemmaForCausalLM from the cache). Gemma 2 attn_types now exactly match n_layers instead of having extra unused entries.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 83602aa commit 9c1237a
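
As a rough illustration of the pattern the commit message describes — a minimal sketch, not the repo's actual code: only _HF_CONFIG_CACHE and the HF config classes are taken from the message, while the handler name, the field mapping, and the config values shown are assumptions (the values match Gemma 1 2B as published on HF, but are illustrative here):

    # Hypothetical sketch: cache entries hold native HF config objects,
    # and one generic handler per architecture reads fields off them.
    from transformers import GemmaConfig

    _HF_CONFIG_CACHE = {
        # Illustrative entry; field values assumed from the public Gemma 1 2B config.
        "google/gemma-2b": GemmaConfig(
            num_hidden_layers=18,
            num_attention_heads=8,
            num_key_value_heads=1,
            hidden_size=2048,
            intermediate_size=16384,
        ),
    }

    def _convert_gemma_config(hf_config: GemmaConfig) -> dict:
        # Replaces per-model hardcoded dicts: every field is derived from
        # the HF config object instead of being matched on the model name.
        return {
            "n_layers": hf_config.num_hidden_layers,
            "n_heads": hf_config.num_attention_heads,
            "n_key_value_heads": hf_config.num_key_value_heads,
            "d_model": hf_config.hidden_size,
            "original_architecture": "GemmaForCausalLM",
        }

Because the architecture now comes from the cached config class rather than from the model name, a Gemma 1 checkpoint can no longer be routed to a Gemma 2 handler.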

3 files changed

Lines changed: 299 additions & 457 deletions

File tree

tests/unit/test_convert_hf_model_config.py

Lines changed: 4 additions & 7 deletions
@@ -403,10 +403,7 @@ def test_gemma_2b(self):
                 "n_key_value_heads": 1,
                 "gated_mlp": True,
                 "final_rms": True,
-                # NOTE: "google/gemma-2b" contains "gemma-2" so the architecture
-                # detection assigns "Gemma2ForCausalLM" even though this is a Gemma 1 model.
-                # The config values are still correct (matched by model name prefix).
-                "original_architecture": "Gemma2ForCausalLM",
+                "original_architecture": "GemmaForCausalLM",
             },
         )
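
The removed NOTE documents the root cause: the old name-based detection did a substring match, and "google/gemma-2b" contains "gemma-2". A two-line illustration of the misfire (the exact matching code is not shown in this diff):

    model_name = "google/gemma-2b"   # a Gemma 1 model
    print("gemma-2" in model_name)   # True -> was misread as Gemma 2

Looking the model up in _HF_CONFIG_CACHE instead returns its stored GemmaConfig, so the architecture is taken from the config rather than guessed from the name.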

@@ -457,7 +454,7 @@ def test_gemma_2_2b(self):
                 "n_key_value_heads": 4,
                 "window_size": 4096,
                 "use_local_attn": True,
-                "attn_types": ["global", "local"] * 21,
+                "attn_types": ["global", "local"] * 13,  # 26 layers
                 "attn_scores_soft_cap": 50.0,
                 "output_logits_soft_cap": 30.0,
                 "gated_mlp": True,
@@ -488,7 +485,7 @@ def test_gemma_2_9b(self):
                 "n_key_value_heads": 8,
                 "window_size": 4096,
                 "use_local_attn": True,
-                "attn_types": ["global", "local"] * 21,
+                "attn_types": ["global", "local"] * 21,  # 42 layers
                 "attn_scores_soft_cap": 50.0,
                 "output_logits_soft_cap": 30.0,
                 "gated_mlp": True,
@@ -520,7 +517,7 @@ def test_gemma_2_27b(self):
                 "n_key_value_heads": 16,
                 "window_size": 4096,
                 "use_local_attn": True,
-                "attn_types": ["global", "local"] * 23,
+                "attn_types": ["global", "local"] * 23,  # 46 layers
                 "attn_scores_soft_cap": 50.0,
                 "output_logits_soft_cap": 30.0,
                 "gated_mlp": True,
