
Add KappaTuneSelector: condition-number-based automatic LoRA target selection#3106

Open
oswaldoludwig wants to merge 56 commits into huggingface:main from oswaldoludwig:main

Conversation

@oswaldoludwig

What this PR does

This PR adds lightweight target-selection tooling: a new utility that lets users automatically select the best LoRA target_modules before creating a LoraConfig, based on the lowest condition number per tensor, the metric used in KappaTune (https://arxiv.org/abs/2506.16289).

New API

from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model
from peft.utils.target_selection import KappaTuneSelector, find_kappa_target_modules

model = AutoModelForCausalLM.from_pretrained(...)   # your base model

# Option 1: class (full control)
selector = KappaTuneSelector(model)
optimal_targets = selector.get_best_targets(top_p=0.2)   # or num_modules=50, threshold=...

# Option 2: one-liner
optimal_targets = find_kappa_target_modules(model, top_p=0.2)

config = LoraConfig(
    target_modules=optimal_targets,
    r=16,
    lora_alpha=32,
    task_type="CAUSAL_LM",
)
peft_model = get_peft_model(model, config)
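For intuition, the selection metric can be sketched in a few lines: per 2D weight, the condition number κ = σ_max / σ_min is computed from the singular values, and the modules with the lowest κ (most isotropic) are kept. The function names below (`kappa_of`, `lowest_kappa_modules`) are illustrative, not the actual PEFT implementation:

```python
# Minimal sketch of condition-number-based target selection, assuming the
# metric kappa = S[0] / (S[-1] + 1e-8) shown later in this thread.
import torch
import torch.nn as nn


def kappa_of(weight: torch.Tensor) -> float:
    # Flatten to 2D; svdvals returns singular values in descending order.
    S = torch.linalg.svdvals(weight.view(weight.size(0), -1))
    return (S[0] / (S[-1] + 1e-8)).item()


def lowest_kappa_modules(model: nn.Module, top_p: float) -> list[str]:
    kappas = {
        name: kappa_of(m.weight)
        for name, m in model.named_modules()
        if isinstance(m, nn.Linear)
    }
    ranked = sorted(kappas, key=kappas.get)  # ascending: best-conditioned first
    k = max(1, int(len(ranked) * top_p))
    return ranked[:k]


# Toy model with deterministic weights: layer "0" is perfectly isotropic
# (kappa ~ 1), layer "2" is badly conditioned (kappa ~ 100).
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))
with torch.no_grad():
    model[0].weight.copy_(torch.eye(4))
    model[2].weight.copy_(torch.diag(torch.tensor([10.0, 1.0, 1.0, 0.1])))
print(lowest_kappa_modules(model, top_p=0.5))  # -> ['0']
```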

Collaborator

@githubnemo left a comment


Thanks for the PR!

I think the interface is fine, KappaTuneSelector as a class to compute once and be able to experiment with different selections is good, having a short-cut one-liner is good as well.

There are only a few things missing:

  • Let's add unit tests, e.g. in tests/test_target_selection.py
  • Let's add an example, if possible re-creating the results from the paper
  • Let's add documentation (entry in _toctree.yml under Utilities and a package reference file in package_reference/target_selection)

Did you test this with other PEFT methods like MiSS or SHiRA to see if it has a similar effect? I wonder if this method generalizes to other methods as well!

Comment thread src/peft/utils/target_selection.py Outdated (4 threads)
oswaldoludwig and others added 8 commits March 17, 2026 14:52
Co-authored-by: githubnemo <githubnemo@users.noreply.github.com>
This script implements Kappa-Selection using the PEFT KappaTuneSelector for identifying higher-entropy, less-anisotropic modules. It includes data preparation, experiment execution, and evaluation of perplexity on IMDB and WikiText datasets.
Updated KappaTuneSelector to support bnb 4-bit models and improved comments (ready for paper experiments with QLoRA / 4-bit models).
The selector picks the layers with the lowest condition number (most isotropic = best for adaptation), exactly as shown in the KappaTune paper (https://arxiv.org/abs/2506.16289).
Add tests for KappaTuneSelector and target selection
Added detailed docstrings for target selection functions, documenting top_p, num_modules and threshold explicitly.
Avoiding a compatibility issue, reported dozens of times on HF discussions, between the DeepSeek-V2-Lite custom modeling code and Transformers v4.40+. The model’s forward pass still calls the old past_key_values.get_usable_length(...) method, but DynamicCache (the default cache class now) no longer has it.
Collaborator

@githubnemo left a comment


Thanks for adding the example and addressing the feedback.

While reviewing, I remembered that SPECTRUM exists (https://huggingface.co/papers/2406.06623) which is quite similar to this approach. Just out of curiosity: have you come across this paper and did you compare the quality of the selection by any chance?

Comment thread docs/source/package_reference/target_selection.md Outdated
Comment thread tests/test_target_selection.py Outdated
Comment thread examples/experiments_SA_kappatune_peft.py Outdated
Comment thread examples/KappaTune/experiments_SA_kappatune_peft.py Outdated
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
LR=2e-4
STP=35
Collaborator


Unless LoRA doesn't improve after step 10 I'd say it is quite unfair to let LoRA + kappa train for 3.5x the time

Author


KappaTune's objective is not to accelerate convergence or reduce VRAM (Spectrum goal); it's to mitigate catastrophic forgetting when fine-tuning on downstream data. By selecting only the most isotropic modules, the model retains far more of its original knowledge even after many epochs (verify the perplexity on the control test set before and after fine-tuning). A truly "fair comparison" would require mixing the full pre-training corpus back in (as done in some continual-pretraining setups). That would be orders of magnitude pricier than 3.5× the current epochs.

Collaborator


Not sure I can follow that logic. We want to show how much original data is forgotten. The reference data for measuring forgetting in this example is perplexity over parts of Wikipedia because we assume that the model has compressed a lot of Wikipedia data.

Is it the case that the effect is only visible at later steps? Then we should train both models for the same number of steps. If that is too costly, we can use a smaller model.

Right now both LoRA and LoRA + Kappa are so different that I am not sure any conclusion can be made (rank, steps, type of layer targeted). I think it is a good idea to have both trained as closely as possible.

Author


@githubnemo Thank you for the detailed feedback.

The different ranks (r=16 for baseline vs. r=190 for KappaTune_LoRA) were chosen deliberately to keep the total number of trainable parameters roughly the same (~191M vs. ~197M). This follows the paper’s comparison style: same parameter budget, but smarter target selection. Therefore, I intentionally matched the training perplexity on the target task rather than the number of steps because KappaTune's goal is continual learning and mitigating catastrophic forgetting, not accelerating convergence. That’s why the number of epochs differs between the two runs.

Including large amounts of pre-training data for a forgetting measurement would be far more expensive and is often not even available, which is exactly why we use Wiki PPL as a cheap proxy for retained general knowledge (I'm assuming Wikipedia is the most broad domain dataset we can get).

Collaborator


Thanks for the explanation. I think it would be good to document this in the example with the same level of detail that you used here for the explanation.

Author


I added guidelines for continual learning experiments + model-size note to the example README (with explanations).

Comment thread src/peft/utils/target_selection.py Outdated (4 threads)
S = torch.linalg.svdvals(w.view(w.size(0), -1))
kappa = (S[0] / (S[-1] + 1e-8)).item()
condition_numbers[module_name] = kappa
except (torch.linalg.LinAlgError, RuntimeError):
Collaborator


Which RuntimeError are we expecting here?

Author


torch.linalg.svdvals (especially on CUDA/certain hardware) can raise RuntimeError in edge cases, like numerical instability on extremely ill-conditioned matrices, CUDA-specific SVD failures (e.g. cuSOLVER error), and very small/zero matrices after the view() operation. LinAlgError alone may not be sufficient in practice for quantized models.

Collaborator


I see, thanks! Of course this captures a whole lot more than that, because RuntimeError is very generic. If you haven't encountered this a lot, I would prefer not to catch RuntimeError here and let the user decide what to do with that information.

Collaborator


This is still open?

Author

@oswaldoludwig Apr 19, 2026


I removed the try/except around torch.linalg.svdvals (target_selection.py) as requested. If SVD ever fails, the error will now be raised directly, i.e. no silent fallback.

Comment thread examples/experiments_SA_kappatune_peft.py
@oswaldoludwig
Author

> Thanks for adding the example and addressing the feedback.
>
> While reviewing, I remembered that SPECTRUM exists (https://huggingface.co/papers/2406.06623) which is quite similar to this approach. Just out of curiosity: have you come across this paper and did you compare the quality of the selection by any chance?

I understand that Spectrum leverages Random Matrix Theory and the Marchenko-Pastur distribution to identify and train high-SNR layers for maximum efficiency, whereas KappaTune uses numerical linear algebra to prioritize knowledge preservation during continual learning. Their selection logic is fundamentally opposed: Spectrum updates the most informative layers to match full fine-tuning performance, while KappaTune targets tensors with low condition numbers to protect the model's most specialized knowledge. By surgically updating these flexible (high differential entropy) tensors, KappaTune provides a level of protection against catastrophic forgetting that Spectrum's layer-wise, VRAM-focused approach does not explicitly address. The idea is to offer superior protection against catastrophic forgetting, so that smaller players can adapt foundation models from industry giants without harming their core capabilities.

oswaldoludwig and others added 12 commits April 7, 2026 20:08
Co-authored-by: githubnemo <githubnemo@users.noreply.github.com>
Co-authored-by: githubnemo <githubnemo@users.noreply.github.com>
Switched to deterministic layer initialization so top_p=0.5 (and num_modules=1) always returns the well-conditioned fc1 and never fc2.
Clarify "quality cutoff" for threshold parameter (flexibility / differential entropy)
Adding bnb int8 quantization support in KappaTuneSelector to address the reviewer request.
Replace the bare exceptions in KappaTuneSelector.
Added detailed description and usage example for KappaTuneSelector.
Co-authored-by: githubnemo <githubnemo@users.noreply.github.com>
Clarified documentation for KappaTuneSelector regarding module sorting and catastrophic forgetting.
@oswaldoludwig oswaldoludwig requested a review from githubnemo April 9, 2026 19:01
Collaborator

@githubnemo left a comment


Thanks for the changes!

Regarding procedure: It would help me a lot if you do not mark review comments as resolved yourself :)

I've tested the example with a smaller model (llama3.2-1B) for quick experimentation and got no significant difference between KappaTune and baseline LoRA. I matched the number of trainable params by matching rank and setting budget_k=64 which is the number of trainable modules in the LoRA setting. This is not exact (16,056,320 kappa vs. 13,893,632 LoRA) but at least roughly comparable.

======================================================================
METHOD          | IMDB PPL (Task train) |  IMDB PPL (Task test) | Wiki PPL (General/control)
----------------------------------------------------------------------
KappaTune       | 18.1814            | 20.4582            | 28.0479           
LoRA_Global     | 18.5284            | 20.5689            | 28.0835           
Baseline        | 30.8990            | 32.3116            | 28.2422           
======================================================================

When training both longer (r=32, steps=30) KappaTune performs worse in Wiki PPL:

======================================================================
METHOD          | IMDB PPL (Task train) |  IMDB PPL (Task test) | Wiki PPL (General/control)
----------------------------------------------------------------------
KappaTune       | 9.6770             | 26.8508            | 40.0022           
LoRA_Global     | 12.0280            | 23.0591            | 36.9354           
Baseline        | 32.1751            | 32.3018            | 28.2422           
======================================================================

This might be a bug?

Comment thread examples/KappaTune/experiments_kappatune_peft.py

Comment thread tests/test_target_selection.py Outdated
Collaborator

@githubnemo Apr 10, 2026


This is still open.

You can use testing_utils.require_bitsandbytes as a test mark for the bitsandbytes tests.

Comment thread src/peft/utils/target_selection.py Outdated

Comment thread src/peft/utils/target_selection.py Outdated
Comment thread examples/KappaTune/experiments_SA_kappatune_peft.py Outdated
Comment thread examples/experiments_SA_kappatune_peft.py Outdated
@oswaldoludwig
Author

oswaldoludwig commented Apr 13, 2026

> Thanks for the changes!
>
> Regarding procedure: It would help me a lot if you do not mark review comments as resolved yourself :)
>
> I've tested the example with a smaller model (llama3.2-1B) for quick experimentation and got no significant difference between KappaTune and baseline LoRA. I matched the number of trainable params by matching rank and setting budget_k=64, which is the number of trainable modules in the LoRA setting. This is not exact (16,056,320 kappa vs. 13,893,632 LoRA) but at least roughly comparable.
>
> ======================================================================
> METHOD          | IMDB PPL (Task train) | IMDB PPL (Task test) | Wiki PPL (General/control)
> ----------------------------------------------------------------------
> KappaTune       | 18.1814            | 20.4582            | 28.0479
> LoRA_Global     | 18.5284            | 20.5689            | 28.0835
> Baseline        | 30.8990            | 32.3116            | 28.2422
> ======================================================================
>
> When training both longer (r=32, steps=30) KappaTune performs worse in Wiki PPL:
>
> ======================================================================
> METHOD          | IMDB PPL (Task train) | IMDB PPL (Task test) | Wiki PPL (General/control)
> ----------------------------------------------------------------------
> KappaTune       | 9.6770             | 26.8508            | 40.0022
> LoRA_Global     | 12.0280            | 23.0591            | 36.9354
> Baseline        | 32.1751            | 32.3018            | 28.2422
> ======================================================================
>
> This might be a bug?

@githubnemo Thank you for running these extra experiments and sharing the numbers.

The difference becomes smaller (or even disappears) on very small models like Llama-3.2-1B. This is expected; it's not a bug at all:

  • A 1B model has very few independent tensors, so the selective nature of KappaTune has limited impact.
  • In your second experiment, KappaTune fits the task data noticeably better (train PPL 9.68 vs 12.03). Therefore, the higher Wiki PPL is the classic overfitting/catastrophic forgetting trade-off when one method adapts more aggressively. A proper experiment on continual learning should fit the training data equally well with both models.

KappaTune was designed and validated primarily on larger models (7B+), where the benefit in mitigating forgetting without replay data is much clearer. As noted in the paper, MoE architectures offer even better conditions for selective fine-tuning, and KappaTune excels particularly well with this kind of model because there are many more independent expert modules to choose from.

oswaldoludwig and others added 8 commits April 21, 2026 13:26
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
The modified version adds full support for fused MoE weights stored as a 3D nn.Parameter, while the original script only handled nn.Linear modules. This new version adds a second pass in _compute_kappas() that iterates over named_parameters() and looks for 3D parameters whose names end with suffixes like .gate_up_proj, .down_proj, etc. For each matching parameter, it computes κ per expert (slicing the 3D tensor along the first dimension) and averages the results. The modified version is backward-compatible with dense models.
…patune_peft.py

New experiment with a MoE model that doesn't need trust_remote_code=true and the new version of the target_selection script, which is ready for modern MoE architectures (Llama-4, Mixtral, Qwen-MoE) that don't store expert weights as individual nn.Linear layers, but as a single 3D tensor for greater efficiency.
Including a section on expected results with KappaTune vs LoRA logs.
Adding the KappaTune section.
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Adding information about expert weights stored as a fused 3D nn.Parameter tensors.
Added a note regarding the intended use of KappaTune.
Adding experimental results for a dense model (Llama 8B) comparing KappaTune vs. LoRA under sustained training effort.
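The fused-MoE handling described in the commits above can be sketched as follows. This is a hedged illustration, not the actual PEFT code: the suffix list, `fused_moe_kappa`, and `moe_condition_numbers` names are assumptions, and only the mechanism (per-expert κ over the first dimension of a 3D parameter, then averaged) comes from the commit description.

```python
# Sketch: per-expert condition number for fused MoE expert weights stored
# as a single 3D nn.Parameter of shape (num_experts, d_out, d_in).
import torch

MOE_SUFFIXES = (".gate_up_proj", ".down_proj")  # example fused-weight suffixes


def fused_moe_kappa(param: torch.Tensor) -> float:
    # Compute kappa per expert slice, then average across experts.
    kappas = []
    for expert_w in param:  # iterate over the first (expert) dimension
        S = torch.linalg.svdvals(expert_w)
        kappas.append((S[0] / (S[-1] + 1e-8)).item())
    return sum(kappas) / len(kappas)


def moe_condition_numbers(model: torch.nn.Module) -> dict[str, float]:
    # Second pass over named_parameters(): only 3D tensors with MoE suffixes.
    return {
        name: fused_moe_kappa(p)
        for name, p in model.named_parameters()
        if p.dim() == 3 and name.endswith(MOE_SUFFIXES)
    }


# Toy module: two identity-like experts, so kappa per expert is ~1.
class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.experts = torch.nn.Module()
        self.experts.down_proj = torch.nn.Parameter(
            torch.stack([torch.eye(4), 2.0 * torch.eye(4)])
        )


toy = Toy()
print(moe_condition_numbers(toy))  # -> {'experts.down_proj': ~1.0}
```

Because dense models simply have no 3D parameters matching the suffixes, this pass is a no-op for them, which is what makes the approach backward-compatible.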
@oswaldoludwig
Author

oswaldoludwig commented Apr 26, 2026

>>> This might be a bug?
>>
>> @githubnemo Thank you for running these extra experiments and sharing the numbers.
>> The difference becomes smaller (or even disappears) on very small models like Llama-3.2-1B. This is expected; it's not a bug at all:
>>
>> * A 1B model has very few independent tensors, so the selective nature of KappaTune has limited impact.
>> * In your second experiment, KappaTune fits the task data noticeably better (train PPL 9.68 vs 12.03). Therefore, the higher Wiki PPL is the classic overfitting/catastrophic forgetting trade-off when one method adapts more aggressively. A proper experiment on continual learning should fit the training data equally well with both models.
>>
>> KappaTune was designed and validated primarily on larger models (7B+), where the benefit in mitigating forgetting without replay data is much clearer. As noted in the paper, MoE architectures offer even better conditions for selective fine-tuning, and KappaTune excels particularly well with this kind of model because there are many more independent expert modules to choose from.
>
> I agree that the behavior of KappaTune resembles overfitting. One interpretation would be that KappaTune selects layers that fit the task best and is therefore more prone to overfitting - that is a valid interpretation. I'm still a bit skeptical, though. I haven't run your updated benchmark (sorry!) but used the previous one with llama3.2-1B and ran a few tests. I kept ranks equal between methods and changed budget_k to match parameters.
>
> Baseline: LoRA 13,893,632 params vs. KappaTune 14,090,240 params, both 30 steps
>
> METHOD          | IMDB PPL (Task train) | IMDB PPL (Task test) | Wiki PPL (General/control)
> ----------------------------------------------------------------------
> KappaTune       | 10.4988            | 24.0050            | 37.6163
> LoRA_Global     | 11.6370            | 22.0064            | 36.9830
>
> Hypothesis: we're seeing overfitting. Let's use r=1 for LoRA (434,176 params) and KappaTune (440,320 params, budget_k=55):
>
> KappaTune       | 11.4146            | 22.7615            | 38.0532
> LoRA_Global     | 12.1635            | 22.4535            | 35.4393
>
> Again, KappaTune has way worse forgetting but the eval ppl is roughly the same. Not sure if we can argue for overfitting here. For fun I switched the sign of kappa when sorting the parameter name list.
>
> Same parameters as before (note: changing the sorting will obviously change the number of trainable params; KappaTune params in this case: 243,712):
>
> KappaTune       | 16.1164            | 20.9910            | 29.3501
>
> I found it surprising that it performs so well; this is by far the best result yet with these hyperparameters.
>
> Experiment: running both with r=32 and steps=30, budget_k=55 -> LoRA 13,893,632 params, KappaTune 7,798,784 params:
>
> KappaTune       | 14.2367            | 22.6476            | 29.2819
> LoRA_Global     | 10.9440            | 24.0303            | 36.6101
>
> KappaTune performs better, but that may be because the parameter counts are not matched.
>
> Experiment: same parameters but budget_k=72, makes for 13.9M (LoRA) vs. 13.4M (kappa):
>
> KappaTune       | 11.8606            | 26.4971            | 31.6223
> LoRA_Global     | 11.7936            | 24.4310            | 36.8024
>
> IMDB test ppl is slightly worse but forgetting is way better than LoRA. This is roughly what I would have expected from KappaTune.
>
> I might be interpreting things wrongly, but could it be that the sorting is wrong?

Hi @githubnemo, I conducted over 20 independent experiments with a dense model (Llama 8B) with increasing levels of adaptation effort using the script from the provided example. This may better explain the context in which KappaTune is important, i.e., for large, real-world datasets that demand strong adaptation effort. See: 1b77aa1

Member

@BenjaminBossan left a comment


Thanks for the recent changes and for including the more extensive experiments. There isn't much that's still missing. I still had a few small comments, please check.

As for moving this to helpers.py, I think it should be quite straightforward, maybe your IDE can even automate most of it. If there are any issues with that, don't hesitate to ask.

Comment thread docs/source/developer_guides/lora.md Outdated (3 threads)
peft_model = get_peft_model(model, config)
```

See a complete example [here](../../../examples/KappaTune/experiments_kappatune_peft.py).
Member

@BenjaminBossan Apr 27, 2026


Let's link the absolute path on GH. The link would be void right now but will work once the PR is merged.

Author


Added a link to the absolute example path in lora.md

Comment thread examples/KappaTune/README.md Outdated
Comment thread src/peft/utils/target_selection.py Outdated
top_p: Optional[float] = None,
num_modules: Optional[int] = None,
threshold: Optional[float] = None,
) -> List[str]:
Member


Since, for many models, there are no MoE layers, let's return None in that case, which is the default in PEFT.

Author


Done. Please check Line 459 of the new helpers.py.

Comment thread tests/test_target_selection.py Outdated
selector = KappaTuneSelector(model)
targets = selector.get_best_targets(top_p=0.5)

assert len(targets) == 1
Member


I think the tests need updating now that the return value of find_kappa_target_modules has changed. Also, let's add a test that includes MoE layers.

Author


Solved.

oswaldoludwig and others added 14 commits April 28, 2026 10:10
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Adding class TestKappaTuneSelector.
Including test with MoE layers.
Moving KappaTuneSelector to helpers.py.
Sync with the latest version of find_kappa_target_modules.
Importing KappaTune from the new location.
Updating the new location of KappaTune.
Updating KappaTune location.
Updating KappaTune location.
Moved to tests/test_helpers.py.
Added a link to the absolute example path.
@oswaldoludwig
Author

> Thanks for the recent changes and for including the more extensive experiments. There isn't much that's still missing. I still had a few small comments, please check.
>
> As for moving this to helpers.py, I think it should be quite straightforward, maybe your IDE can even automate most of it. If there are any issues with that, don't hesitate to ask.

KappaTune was moved to "helpers.py" and all files with "from peft.utils.target_selection import find_kappa_target_modules" were updated (docs/source/package_reference/target_selection.md, src/peft/__init__.py, docs/source/developer_guides/lora.md, examples/KappaTune/experiments_kappatune_peft.py).

I hope everything is in order now for the PR to be merged.
