feat: registry-based dispatch for resolve_lora_variant in Linear layer#3219
feat: registry-based dispatch for resolve_lora_variant in Linear layer#3219isha822 wants to merge 4 commits into
Conversation
|
Thanks for the initial draft. Some comments:
Yes, ideally we only need it once.
I think it pays to be explicit: Each subclass (i.e. the concrete LoRA layers) should require their own property.
Good, let's keep the PR up to date with new LoRA variant additions. In Let's also extend the tests to capture the new functionality. E.g. you could write a test that monkey-patches Also, let's add a docstring to |
|
@BenjaminBossan , thank you for the detailed feedback and architectural guidance. I've updated the PR to implement the centralized routing approach. Here is a breakdown of the changes in this iteration:
Everything is passing cleanly on my end. Let me know if this aligns with what you had in mind, or if you'd like any further refinements! |
BenjaminBossan
left a comment
There was a problem hiding this comment.
Thanks for the recent updates, the PR is taking a good shape. I made a few comments, please check. Also, once you're finished with your changes, please run make style to ensure proper code formatting.
| return {(): None} | ||
|
|
||
| def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]: | ||
| import dataclasses |
There was a problem hiding this comment.
Let's make the import global, as it's a builtin.
| ) | ||
|
|
||
| # 3. Figure out which variants are currently active | ||
| active = tuple(sorted( |
There was a problem hiding this comment.
Let's call it active_variants.
| } | ||
|
|
||
| # 2. SANITY CHECK: Ensure all keys in the layer's dictionary actually exist in the config | ||
| for variant_combo in layer_variants.keys(): |
There was a problem hiding this comment.
Let's rename variant_combo to variant_keys.
| mock_variants.return_value = {("fake_unregistered_variant",): None} | ||
|
|
||
| # 3. Assert that the sanity check catches it and throws the right error | ||
| with pytest.raises(ValueError, match="not tagged with 'is_lora_variant'"): |
There was a problem hiding this comment.
Let's check the full error message here.
|
|
||
| # 3. Assert that the sanity check catches it and throws the right error | ||
| with pytest.raises(ValueError, match="not tagged with 'is_lora_variant'"): | ||
| layer.resolve_lora_variant(config=config) |
There was a problem hiding this comment.
Could you please also add a test that checks for the Invalid or unsupported variant combination error? It'll probably require monkey-patching too.
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
…riant_configs, add tests
|
@BenjaminBossan , all changes addressed:
All 16 tests passing. |
BenjaminBossan
left a comment
There was a problem hiding this comment.
Thanks for the update. I have two more comments.
| with lora_model.disable_adapter(): | ||
| with torch.no_grad(): | ||
| base_out = lora_model(X=input_ids) | ||
| with lora_model.disable_adapter(), torch.no_grad(): |
There was a problem hiding this comment.
Let's not touch unrelated code.
|
|
||
| Given the init arguments of this layer, return the correct LoRA variant, if any. E.g., if `use_dora=True`, this | ||
| method should return the DoRA variant for the given layer. If `use_alora=True`, same for aLoRA. | ||
| # Safely fetch the dictionary (defaults to empty if a subclass forgot to define it) |
There was a problem hiding this comment.
The indentation here and elsewhere is not correct. I think running make style should be enough to fix it. Could you please try that?
|
@BenjaminBossan the unrelated changes in the test file were introduced by |
Related to #3182
Draft implementation of the registry-based dispatch approach
discussed in the RFC.
Changes so far:
LoraConfigwithmetadata={"is_variant": True}(use_dora, arrow_config,use_bdlora, alora_invocation_tokens, velora_config)
Linear.resolve_lora_variantwitha metadata-driven loop +
valid_variantsregistry propertytest_lora_variants.pypassOpen questions for discussion:
resolve_lora_variantmove to theLoraLayerbaseclass so subclasses only define
valid_variants?Embeddingand_ConvNdlayers be handled —do they need their own
valid_variantsor can they inherit?original RFC scope — let me know if it should be excluded