feat: registry-based dispatch for resolve_lora_variant in Linear layer #3219

Draft
isha822 wants to merge 4 commits into huggingface:main from isha822:feature/lora-variant-registry

Conversation

@isha822

@isha822 isha822 commented May 9, 2026

Related to #3182

Draft implementation of the registry-based dispatch approach
discussed in the RFC.

Changes so far:

  • Tagged variant-related fields in LoraConfig with
    metadata={"is_variant": True} (use_dora, arrow_config,
    use_bdlora, alora_invocation_tokens, velora_config)
  • Replaced the if-chain in Linear.resolve_lora_variant with
    a metadata-driven loop + valid_variants registry property
  • All existing tests in test_lora_variants.py pass
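The tagging and the metadata-driven loop described above can be sketched roughly as follows. This is a minimal illustration, not the PR's code: `ToyLoraConfig` and `active_variant_names` are hypothetical stand-ins, the field types are simplified, and treating a field as "active" when its value is truthy is an assumption.

```python
import dataclasses
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ToyLoraConfig:
    # Hypothetical stand-in for LoraConfig; only a few fields are shown.
    r: int = 8  # untagged: not a variant switch
    use_dora: bool = field(default=False, metadata={"is_variant": True})
    arrow_config: Optional[dict] = field(default=None, metadata={"is_variant": True})
    alora_invocation_tokens: Optional[list] = field(default=None, metadata={"is_variant": True})


def active_variant_names(config) -> tuple:
    """Return the sorted names of tagged fields that are switched on.

    Assumption: a tagged field counts as active when its value is truthy.
    """
    return tuple(
        sorted(
            f.name
            for f in dataclasses.fields(config)
            if f.metadata.get("is_variant") and getattr(config, f.name)
        )
    )
```

Sorting the names makes the result independent of field declaration order, so the tuple can serve directly as a lookup key in a per-layer registry.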

Open questions for discussion:

  • Should resolve_lora_variant move to the LoraLayer base
    class so subclasses only define valid_variants?
  • How should Embedding and _ConvNd layers be handled —
    do they need their own valid_variants or can they inherit?
  • velora_config included for completeness but wasn't in the
    original RFC scope — let me know if it should be excluded

@BenjaminBossan
Member

BenjaminBossan commented May 12, 2026

Thanks for the initial draft. Some comments:

  • Let's call it is_lora_variant to be super explicit
  • Let's rename valid_variants to lora_variants

  • Should resolve_lora_variant move to the LoraLayer base class so subclasses only define valid_variants?

Yes, ideally we only need it once.

  • How should Embedding and _ConvNd layers be handled — do they need their own valid_variants or can they inherit?

I think it pays to be explicit: Each subclass (i.e. the concrete LoRA layers) should require their own property.

  • velora_config included for completeness but wasn't in the original RFC scope — let me know if it should be excluded

Good, let's keep the PR up to date with new LoRA variant additions.

In resolve_lora_variant, we should also add a sanity check that the chosen names and the names from the config are the same. Let's load the complexity into this function so that the rest can be as simple as possible.

Let's also extend the tests to capture the new functionality. E.g. you could write a test that monkey-patches valid_variants to add a new variant that doesn't have a corresponding field in the LoraConfig. Then we should expect there to be an error with a helpful error message being raised.

Also, let's add a docstring to valid_variants that explains how to extend it and also why it uses tuples.
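As an illustration of the kind of docstring requested here, the sketch below shows a registry property on a hypothetical `ToyLinear` class, using the `lora_variants` name suggested earlier in this comment. The variant labels, the `use_foo` field, and the stated rationale for tuples are assumptions made for the example, not taken from the PR.

```python
class ToyLinear:
    # Hypothetical stand-in for a concrete LoRA layer.

    @property
    def lora_variants(self) -> dict:
        """Map combinations of active variant fields to the variant to use.

        Each key is a tuple of config field names in sorted order; each value
        is the variant used when exactly those fields are active. The empty
        tuple stands for plain LoRA (no variant).

        Tuples are used (assumed rationale) because a key may need to name
        several fields at once, and dict keys must be hashable.

        To support a new variant, add an entry whose key contains the name of
        the corresponding config field.
        """
        return {
            (): None,
            ("use_dora",): "DoraLinearVariant",  # placeholder label
        }


class ToyLinearWithFoo(ToyLinear):
    @property
    def lora_variants(self) -> dict:
        # Extend the parent mapping with a hypothetical extra variant.
        return {**super().lora_variants, ("use_foo",): "FooLinearVariant"}
```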

@isha822
Author

isha822 commented May 14, 2026

@BenjaminBossan, thank you for the detailed feedback and architectural guidance. I've updated the PR to implement the centralized routing approach.

Here is a breakdown of the changes in this iteration:

  • Base Class Integration: Moved the resolve_lora_variant logic into the LoraLayer base class and implemented the lora_variants property dictionaries across the Linear, Embedding, Conv1d, Conv2d, and Conv3d layers.
  • Boilerplate Removal: Cleaned up the hardcoded VeLoRA/aLoRA ValueError checks from the convolutional layers, as the base class now inherently blocks unsupported variant combinations.
  • Config Validation (Sanity Check): Added a check within the base resolution logic to ensure that any variant mapped in a layer's dictionary explicitly matches an is_lora_variant field in LoraConfig.
  • Testing & Docs: Added the explanatory docstring for the property and included a monkey-patch test (test_unregistered_variant_raises_error) to confirm the sanity check properly catches unregistered variants.
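The centralized flow in this iteration might look roughly like the sketch below: one `resolve_lora_variant` on the base class, a sanity check against the tagged config fields, and a `lora_variants` property per concrete layer. All `Toy*` classes, the string variant labels, and the exact error wording are hypothetical; only the two quoted message fragments come from this thread.

```python
import dataclasses
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ToyConfig:
    # Hypothetical stand-in for LoraConfig.
    use_dora: bool = field(default=False, metadata={"is_lora_variant": True})
    alora_invocation_tokens: Optional[list] = field(default=None, metadata={"is_lora_variant": True})


class ToyLoraLayer:
    @property
    def lora_variants(self) -> dict:
        # Each concrete layer is expected to define its own mapping.
        raise NotImplementedError

    def resolve_lora_variant(self, *, config):
        tagged = {f.name for f in dataclasses.fields(config) if f.metadata.get("is_lora_variant")}

        # Sanity check: every name used as a registry key must be a tagged config field.
        for variant_keys in self.lora_variants:
            unknown = [name for name in variant_keys if name not in tagged]
            if unknown:
                raise ValueError(f"Variant keys {unknown} are not tagged with 'is_lora_variant'.")

        # Assumption: a tagged field is active when its value is truthy.
        active_variants = tuple(sorted(name for name in tagged if getattr(config, name)))
        if active_variants not in self.lora_variants:
            raise ValueError(
                f"Invalid or unsupported variant combination {active_variants} for {type(self).__name__}."
            )
        return self.lora_variants[active_variants]


class ToyLinear(ToyLoraLayer):
    @property
    def lora_variants(self) -> dict:
        return {(): None, ("use_dora",): "dora-linear", ("alora_invocation_tokens",): "alora-linear"}


class ToyConv2d(ToyLoraLayer):
    @property
    def lora_variants(self) -> dict:
        # No aLoRA entry: the base class rejects it without a layer-specific check.
        return {(): None, ("use_dora",): "dora-conv2d"}
```

In this sketch, `ToyConv2d` needs no hand-written `ValueError` for aLoRA; the missing registry entry is enough for the base class to reject that configuration.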

Everything is passing cleanly on my end. Let me know if this aligns with what you had in mind, or if you'd like any further refinements!

Member

@BenjaminBossan BenjaminBossan left a comment

Thanks for the recent updates, the PR is taking good shape. I left a few comments; please take a look. Also, once you're finished with your changes, please run make style to ensure proper code formatting.

Comment thread src/peft/tuners/lora/layer.py Outdated
        return {(): None}

    def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]:
        import dataclasses
Member

Let's make the import global, as it's a builtin.

Comment thread src/peft/tuners/lora/layer.py
Comment thread src/peft/tuners/lora/layer.py Outdated
)

# 3. Figure out which variants are currently active
active = tuple(sorted(
Member

Let's call it active_variants.

Comment thread src/peft/tuners/lora/layer.py Outdated
}

# 2. SANITY CHECK: Ensure all keys in the layer's dictionary actually exist in the config
for variant_combo in layer_variants.keys():
Member

Let's rename variant_combo to variant_keys.

Comment thread src/peft/tuners/lora/layer.py Outdated
Comment thread tests/test_lora_variants.py Outdated
mock_variants.return_value = {("fake_unregistered_variant",): None}

# 3. Assert that the sanity check catches it and throws the right error
with pytest.raises(ValueError, match="not tagged with 'is_lora_variant'"):
Member

Let's check the full error message here.


# 3. Assert that the sanity check catches it and throws the right error
with pytest.raises(ValueError, match="not tagged with 'is_lora_variant'"):
    layer.resolve_lora_variant(config=config)
Member

Could you please also add a test that checks for the Invalid or unsupported variant combination error? It'll probably require monkey-patching too.
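A self-contained sketch of how both error paths could be exercised by monkey-patching the registry property is shown below. It uses only the standard library (`unittest.mock`) and compares the full message rather than a fragment. `ToyLayer`, its `TAGGED_FIELDS` attribute, the `active_variants` argument, and the exact message texts are hypothetical; the real tests in the PR use pytest and the actual PEFT layers.

```python
from unittest.mock import PropertyMock, patch


class ToyLayer:
    # Hypothetical stand-in for a LoRA layer; TAGGED_FIELDS mimics the tagged config fields.
    TAGGED_FIELDS = {"use_dora"}

    @property
    def lora_variants(self) -> dict:
        return {(): None, ("use_dora",): "dora"}

    def resolve_lora_variant(self, *, active_variants=()):
        for variant_keys in self.lora_variants:
            for name in variant_keys:
                if name not in self.TAGGED_FIELDS:
                    raise ValueError(f"Variant key '{name}' is not tagged with 'is_lora_variant'.")
        if active_variants not in self.lora_variants:
            raise ValueError(f"Invalid or unsupported variant combination: {active_variants}.")
        return self.lora_variants[active_variants]


def error_message(layer, **kwargs):
    """Return the ValueError message raised by resolve_lora_variant, or None."""
    try:
        layer.resolve_lora_variant(**kwargs)
    except ValueError as exc:
        return str(exc)
    return None


def test_unregistered_variant_raises_error():
    # Patch the property so the registry names a field the config does not tag.
    with patch.object(ToyLayer, "lora_variants", new_callable=PropertyMock) as mock_variants:
        mock_variants.return_value = {("fake_unregistered_variant",): None}
        msg = error_message(ToyLayer())
    assert msg == "Variant key 'fake_unregistered_variant' is not tagged with 'is_lora_variant'."


def test_invalid_variant_combination_raises_error():
    # Patch the property so an otherwise valid active variant is no longer registered.
    with patch.object(ToyLayer, "lora_variants", new_callable=PropertyMock) as mock_variants:
        mock_variants.return_value = {(): None}
        msg = error_message(ToyLayer(), active_variants=("use_dora",))
    assert msg == "Invalid or unsupported variant combination: ('use_dora',)."
```

With pytest, the same full-message check is usually written as `pytest.raises(ValueError, match=re.escape(expected))`, since `match` is interpreted as a regular expression.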

isha822 and others added 2 commits May 19, 2026 13:15
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
@isha822
Author

isha822 commented May 19, 2026

@BenjaminBossan, all changes addressed:

  • Moved import dataclasses to module level
  • Extracted lora_variant_configs and reused it for both tagged_fields and active_variants
  • Renamed variant_combo to variant_keys and active to active_variants
  • Updated test_unregistered_variant_raises_error to match the full error message
  • Added test_invalid_variant_combination_raises_error for the unsupported combination error

All 16 tests passing.

Member

@BenjaminBossan BenjaminBossan left a comment

Thanks for the update. I have two more comments.

- with lora_model.disable_adapter():
-     with torch.no_grad():
-         base_out = lora_model(X=input_ids)
+ with lora_model.disable_adapter(), torch.no_grad():
Member

Let's not touch unrelated code.


Given the init arguments of this layer, return the correct LoRA variant, if any. E.g., if `use_dora=True`, this
method should return the DoRA variant for the given layer. If `use_alora=True`, same for aLoRA.
# Safely fetch the dictionary (defaults to empty if a subclass forgot to define it)
Member

The indentation here and elsewhere is not correct. I think running make style should be enough to fix it. Could you please try that?

@isha822
Author

isha822 commented May 20, 2026

@BenjaminBossan the unrelated changes in the test file were introduced by make style modifying code outside the scope of this PR. I'll revert those and run make style only on the files I've modified before pushing the fix. Will update you shortly.
