Skip to content

ENH Optimize HRA forward for better memory efficiency#3185

Open
BenjaminBossan wants to merge 1 commit into
huggingface:mainfrom
BenjaminBossan:enh-optimize-hra-forward
Open

ENH Optimize HRA forward for better memory efficiency#3185
BenjaminBossan wants to merge 1 commit into
huggingface:mainfrom
BenjaminBossan:enh-optimize-hra-forward

Conversation

@BenjaminBossan
Copy link
Copy Markdown
Member

@BenjaminBossan BenjaminBossan commented Apr 22, 2026

Description

HRA previously materialized the full dense Householder transform and multiplied it into the base weight on every forward pass. This can be expensive for big weight matrices. Moreover, the base weight was cast to the adapter dtype, which can create an extra copy at higher dtype. The new code is more in line with how LoRA calculates the forward step.

Context

I wanted to run HRA on the new image generation benchmark which is in the works. No matter what I did, it always OOM-ed, so I worked with Codex to find possible optimizations. The new Linear.forward looks quite reasonable to me, for Conv2d.forward, I can't say I fully grasp it but it seems to work.

Results

I ran some regression tests to determine that he results after the change are numerically close to the previous results. There can be deviations, especially at float16/bfloat16, but I think this is acceptable for the memory savings we see. Maybe there is a better way that is more faithful, LMK if you have an idea.

Below are the results of running the new HRA on the MetaMathQA and the image generation benchmark compared to the current implementation. For opt-125, memory usage actually increases, but for more realistic sizes like Llama 3.2 3B or Flux 2 klein 4B, we see an improvement.

Benchmark Model HRA rank Seq len Optimization Memory Runtime
MetaMath Llama 3.2 3B 32 768 no OOM -
MetaMath Llama 3.2 3B 32 768 yes OOM -
MetaMath Llama 3.2 3B 32 128 no OOM -
MetaMath Llama 3.2 3B 32 128 yes 23.5 GB 71 sec / 250 steps
MetaMath Llama 3.2 3B 32 64 no OOM -
MetaMath Llama 3.2 3B 32 64 yes 16 GB 63 sec / 250 steps
MetaMath OPT-125M 32 768 no 4.7 GB 27 sec / 250 steps
MetaMath OPT-125M 32 768 yes 9.9 GB 24 sec / 250 steps
Image generation Flux 2 klein 4B 32 - no OOM -
Image generation Flux 2 klein 4B 32 - yes 20 GB -

Note that for the final benchmarking results, we'll use a GPU with 48 GB, so that should hopefully not OOM.

Description

HRA previously materialized the full dense Householder transform and
multiplied it into the base weight on every forward pass. This can be
expensive for big weight matrices. Moreover, the base weight was cast to
the adapter dtype, which can create an extra copy at higher dtype. The
new code is more in line with how LoRA calculates the forward step.

Context

I wanted to run HRA on the new image generation benchmark which is in
the works. No matter what I did, it always OOM-ed, so I worked with
Codex to find possible optimizations. The new Linear.forward looks quite
reasonable to me, for Conv2d.forward, I can't say I fully grasp it but
it seems to work.

Results

I ran some regression tests to determine that he results after the
change are numerically close to the previous results. There can be
deviations, especially at float16/bfloat16, but I think this is
acceptable for the memory savings we see. Maybe there is a better way
that is more faithful, LMK if you have an idea.

Below are the results of running the new HRA on the MetaMathQA and the
image generation benchmark compared to the current implementation. For
opt-125, memory usage actually increases, but for more realistic sizes
like Llama 3.2 3B or Flux 2 klein 4B, we see an improvement.

| Benchmark        | Model           | HRA rank | Seq len | Optimization | Memory  | Runtime            |
|------------------+-----------------+----------+---------+--------------+---------+--------------------|
| MetaMath         | Llama 3.2 3B    |       32 |     768 | no           | OOM     | -                  |
| MetaMath         | Llama 3.2 3B    |       32 |     768 | yes          | OOM     | -                  |
| MetaMath         | Llama 3.2 3B    |       32 |     128 | no           | OOM     | -                  |
| MetaMath         | Llama 3.2 3B    |       32 |     128 | yes          | 23.5 GB | 71 sec / 250 steps |
| MetaMath         | Llama 3.2 3B    |       32 |      64 | no           | OOM     | -                  |
| MetaMath         | Llama 3.2 3B    |       32 |      64 | yes          | 16 GB   | 63 sec / 250 steps |
| MetaMath         | OPT-125M        |       32 |     768 | no           | 4.7 GB  | 27 sec / 250 steps |
| MetaMath         | OPT-125M        |       32 |     768 | yes          | 9.9 GB  | 24 sec / 250 steps |
| Image generation | Flux 2 klein 4B |       32 |       - | no           | OOM     | -                  |
| Image generation | Flux 2 klein 4B |       32 |       - | yes          | 20 GB   | -                  |
@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan
Copy link
Copy Markdown
Member Author

@DaShenZi721 Could you please take look?

@BenjaminBossan
Copy link
Copy Markdown
Member Author

gentle ping @DaShenZi721

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants