feat(lora): Add FIM-guided adaptive LoRA rank allocation (FimConfig + initialize_lora_fim_ranks) #3204
Introduces FimConfig and initialize_lora_fim_ranks(), a calibration-based method that redistributes LoRA ranks across layers using the diagonal of the empirical Fisher Information Matrix (eFIM), subject to a global rank budget. Layers with a large mean squared gradient (high eFIM score) receive higher rank; layers with low sensitivity receive lower rank. This keeps the total rank the same as fixed-rank LoRA (and hence the same parameter count when the adapted layers have equal dimensions), while concentrating capacity where the estimated loss curvature is highest.

Algorithm:
  F_ii ≈ (1/T) Σ_t (∂ℓ_t/∂θ_i)²   (eFIM diagonal, mean squared gradient)
  rank_i ∝ mean(F_i) / Σ_j mean(F_j) × budget, clamped to [r_min, r_max]
  budget = n_layers × r   (mean rank preserved)

Files changed:
  - src/peft/tuners/lora/fim.py: FimConfig + initialize_lora_fim_ranks
  - src/peft/tuners/lora/config.py: fim_config field + 'fim' init mode
  - src/peft/tuners/lora/layer.py: allow 'fim' in reset_lora_parameters
  - src/peft/tuners/lora/__init__.py: export FimConfig, initialize_lora_fim_ranks
  - src/peft/tuners/__init__.py: propagate exports
  - src/peft/__init__.py: top-level export
  - tests/test_lora_fim.py: 23 unit tests

Relates to: huggingface#3203

Reference: LeCun et al., Optimal Brain Damage, NeurIPS 1990.

Signed-off-by: Ramakrishnan Sathyavageeswaran <ramkrishs@outlook.com>
Thanks for providing this PR @ramkrishs. Is there any paper that shows that this initialization works well with LoRA? Did you run any of your own experiments? Usually, we don't add new methods to PEFT only on the theoretical assumption that they could work.
Thank you for the feedback, Benjamin. I'm currently running a structured comparison on GLUE (DeBERTaV3-base) and commonsense reasoning (LLaMA-3-8B) against LoRA, AdaLoRA, and EVA across rank budgets r ∈ {2, 4, 8, 16}. The experiment harness is set up and the first results should be ready within 2–3 weeks.

This work is also being written up as a short paper. The closest prior work (AdaLoRA, ICLR 2023) shows that non-uniform rank allocation consistently outperforms fixed-rank LoRA, particularly at low budgets, and our hypothesis is that eFIM-based allocation is more directly tied to the fine-tuning objective than SVD-based signals.

I'll update this PR with the results table and a link to the arXiv preprint once the experiments are complete. Happy to keep this as a draft in the meantime; no action is needed from your side until then.
Thanks, then let's pick this PR up again at that point. You could also check this init on the PEFT MetaMath benchmark; it'll probably just require a couple of lines of extra code.
Summary
Adds `FimConfig` and `initialize_lora_fim_ranks()`, a data-driven method that redistributes LoRA ranks across layers using the diagonal of the empirical Fisher Information Matrix (eFIM), concentrating rank budget on layers that are most sensitive to the loss.

Proposal issue: #3203
Motivation
LoRA uses a fixed rank `r` for all adapter matrices. Different layers have different sensitivity to fine-tuning data: early attention layers often require less capacity than later layers, and q/v projections often differ from k projections. A fixed-rank allocation wastes capacity on insensitive layers.

EVA (already in PEFT) addresses this via SVD of input activations. This PR uses a complementary signal: the eFIM diagonal (mean squared gradient), which directly measures per-parameter loss sensitivity rather than activation variance. The two are orthogonal: EVA optimizes initialization directions; FIM optimizes rank allocation by sensitivity.
Algorithm
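The eFIM diagonal for parameter θᵢ:

$$F_{ii} \approx \frac{1}{T} \sum_{t=1}^{T} \left( \frac{\partial \ell_t}{\partial \theta_i} \right)^2$$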
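To make the "mean squared gradient" reading of the eFIM diagonal concrete, here is a minimal pure-Python sketch. It assumes per-sample gradients are already available as plain lists; the helper names `efim_diagonal` and `layer_importance` and the toy numbers are hypothetical and are not taken from the PR's `fim.py`.

```python
from statistics import fmean


def efim_diagonal(per_sample_grads):
    """Average the squared per-sample gradients, one entry per parameter."""
    num_samples = len(per_sample_grads)
    num_params = len(per_sample_grads[0])
    return [
        sum(grad[i] ** 2 for grad in per_sample_grads) / num_samples
        for i in range(num_params)
    ]


def layer_importance(grads_by_layer):
    """Collapse each layer's eFIM diagonal to one score (its mean entry)."""
    return {
        name: fmean(efim_diagonal(grads))
        for name, grads in grads_by_layer.items()
    }


# Toy per-sample gradients: T = 2 samples, 2 parameters per layer (made up).
grads_by_layer = {
    "q_proj": [[1.0, -1.0], [3.0, 1.0]],
    "k_proj": [[0.5, 0.5], [0.5, -0.5]],
}
scores = layer_importance(grads_by_layer)
print(scores)  # {'q_proj': 3.0, 'k_proj': 0.25}
```

Note that squaring a gradient that has already been averaged over a mini-batch is not the same quantity as averaging squared per-sample gradients; the sketch uses the per-sample form.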
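Rank allocation:

$$\mathrm{rank}_i \propto \frac{\operatorname{mean}(F_i)}{\sum_j \operatorname{mean}(F_j)} \times \mathrm{budget}, \qquad \mathrm{budget} = n_{\mathrm{layers}} \times r$$

Each rank is clamped to $[r_{\min}, r_{\max}]$. Here $i$ and $j$ index LoRA layers, and $\operatorname{mean}(F_i)$ is the average eFIM diagonal entry of layer $i$.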
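The allocation rule can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not the PR's `_allocate_ranks`: the function name `allocate_ranks`, the rounding, and the one-unit-at-a-time repair step that restores the budget after clamping are all hypothetical choices.

```python
def allocate_ranks(scores, r, r_min, r_max):
    """Split a budget of len(scores) * r ranks in proportion to the scores."""
    names = list(scores)
    budget = len(names) * r
    total = sum(scores.values())
    if total <= 0:
        return {name: r for name in names}  # no signal: keep uniform ranks

    # Proportional share of the budget, rounded and clamped to [r_min, r_max].
    ranks = {
        name: min(r_max, max(r_min, round(budget * scores[name] / total)))
        for name in names
    }

    # Rounding and clamping can break the budget. Repair it one unit at a
    # time: add rank to the most important layers first, or remove rank
    # from the least important layers first (assumed strategy).
    by_importance = sorted(names, key=lambda n: scores[n], reverse=True)
    while sum(ranks.values()) != budget:
        surplus = sum(ranks.values()) - budget
        step = 1 if surplus < 0 else -1
        order = by_importance if surplus < 0 else reversed(by_importance)
        for name in order:
            if r_min <= ranks[name] + step <= r_max:
                ranks[name] += step
                break
        else:
            break  # the bounds make the budget unreachable
    return ranks


ranks = allocate_ranks(
    {"q_proj": 6.0, "k_proj": 1.0, "v_proj": 3.0}, r=8, r_min=2, r_max=16
)
print(ranks)  # {'q_proj': 15, 'k_proj': 2, 'v_proj': 7}
```

In this toy case the proportional shares round to 14, 2, and 7 (sum 23), so the repair step gives the one remaining unit to the highest-scoring layer and the total returns to 3 × 8 = 24.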
Total rank budget is preserved: mean rank across layers equals the original `r`.

API
Follows the EVA pattern exactly: `FimConfig` dataclass + `initialize_lora_fim_ranks()` public function + `init_lora_weights='fim'` trigger in `LoraConfig`.

Files changed

| File | Change |
| --- | --- |
| `src/peft/tuners/lora/fim.py` | `FimConfig`, `initialize_lora_fim_ranks`, internal helpers |
| `src/peft/tuners/lora/config.py` | `fim_config` field, `'fim'` to `init_lora_weights` Literal, validation |
| `src/peft/tuners/lora/layer.py` | `'fim'` in `reset_lora_parameters` (treated as standard init; rank redistribution happens post-construction) |
| `src/peft/tuners/lora/__init__.py` | `FimConfig`, `initialize_lora_fim_ranks` |
| `src/peft/tuners/__init__.py` | Propagate exports |
| `src/peft/__init__.py` | Top-level export |
| `tests/test_lora_fim.py` | 23 unit tests |

Tests
Covers: `FimConfig` construction/validation, `_compute_layer_importance`, `_allocate_ranks` (budget preservation, clamping, monotonicity), `_resize_lora_layer` (increase/decrease/noop, scaling adjustment), `initialize_lora_fim_ranks` end-to-end (with dataloader and pre-computed scores), `LoraConfig` validation warnings, and top-level import.

No GPU required. All tests run on CPU.
Reference