ENH Optimize HRA forward for better memory efficiency #3185
Open
BenjaminBossan wants to merge 1 commit into
Description
HRA previously materialized the full dense Householder transform and multiplied it into the base weight on every forward pass. This can be expensive for big weight matrices. Moreover, the base weight was cast to the adapter dtype, which can create an extra copy at a higher-precision dtype. The new code is more in line with how LoRA calculates the forward step.
Context
I wanted to run HRA on the new image generation benchmark, which is in the works. No matter what I did, it always OOM-ed, so I worked with Codex to find possible optimizations. The new Linear.forward looks quite reasonable to me. As for Conv2d.forward, I can't say I fully grasp it, but it seems to work.
Results
I ran some regression tests to confirm that the results after the change are numerically close to the previous results. There can be deviations, especially at float16/bfloat16, but I think this is acceptable for the memory savings we see. Maybe there is a better way that is more faithful, LMK if you have an idea.
Below are the results of running the new HRA on the MetaMathQA and image generation benchmarks compared to the current implementation. For OPT-125M, memory usage actually increases, but for more realistic sizes like Llama 3.2 3B or Flux 2 klein 4B, we see an improvement.
| Benchmark        | Model           | HRA rank | Seq len | Optimization | Memory  | Runtime            |
|------------------|-----------------|----------|---------|--------------|---------|--------------------|
| MetaMath         | Llama 3.2 3B    | 32       | 768     | no           | OOM     | -                  |
| MetaMath         | Llama 3.2 3B    | 32       | 768     | yes          | OOM     | -                  |
| MetaMath         | Llama 3.2 3B    | 32       | 128     | no           | OOM     | -                  |
| MetaMath         | Llama 3.2 3B    | 32       | 128     | yes          | 23.5 GB | 71 sec / 250 steps |
| MetaMath         | Llama 3.2 3B    | 32       | 64      | no           | OOM     | -                  |
| MetaMath         | Llama 3.2 3B    | 32       | 64      | yes          | 16 GB   | 63 sec / 250 steps |
| MetaMath         | OPT-125M        | 32       | 768     | no           | 4.7 GB  | 27 sec / 250 steps |
| MetaMath         | OPT-125M        | 32       | 768     | yes          | 9.9 GB  | 24 sec / 250 steps |
| Image generation | Flux 2 klein 4B | 32       | -       | no           | OOM     | -                  |
| Image generation | Flux 2 klein 4B | 32       | -       | yes          | 20 GB   | -                  |
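
To illustrate the idea from the description, here is a minimal sketch (not the code in this PR) of the two ways to compute the forward pass for a linear layer. It assumes a chain of `r` Householder reflections `H_i = I - 2 u_i u_i^T` with unit-norm columns `u_i`. The function names and the `(in_features, r)` layout of `u` are illustrative assumptions, and bias, dropout, and the Conv2d case are left out.

```python
import torch
import torch.nn.functional as F


def dense_householder(u: torch.Tensor) -> torch.Tensor:
    """Materialize H = H_1 H_2 ... H_r as a dense (in, in) matrix."""
    in_features, rank = u.shape
    u = u / u.norm(dim=0)  # unit-norm reflection vectors
    h = torch.eye(in_features, dtype=u.dtype, device=u.device)
    for i in range(rank):
        ui = u[:, i : i + 1]  # column vector, shape (in, 1)
        h = h - 2 * (h @ ui) @ ui.t()  # h <- h @ (I - 2 u_i u_i^T)
    return h


def forward_dense(x: torch.Tensor, weight: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    # Dense path: builds an (in, in) transform and an (out, in) product per call.
    return F.linear(x, weight @ dense_householder(u))


def forward_on_activations(x: torch.Tensor, weight: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    # x @ (W H)^T = (x @ H^T) @ W^T and H^T = H_r ... H_1, so the reflections
    # are applied to x in reverse order and the base weight is left untouched.
    u = u / u.norm(dim=0)
    for i in reversed(range(u.shape[1])):
        ui = u[:, i]
        x = x - 2 * (x @ ui).unsqueeze(-1) * ui
    return F.linear(x, weight)


# Small float64 check that both paths agree up to rounding error.
torch.manual_seed(0)
x = torch.randn(4, 16, dtype=torch.float64)
weight = torch.randn(8, 16, dtype=torch.float64)
u = torch.randn(16, 4, dtype=torch.float64)
out_dense = forward_dense(x, weight, u)
out_act = forward_on_activations(x, weight, u)
max_diff = (out_dense - out_act).abs().max().item()
```

In float64 the two outputs should match up to rounding error. At float16/bfloat16 the order of operations differs between the two paths, which is one plausible source of the small deviations mentioned in the results.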
@DaShenZi721 Could you please take a look?
gentle ping @DaShenZi721
Note that for the final benchmarking results, we'll use a GPU with 48 GB of memory, so these runs should hopefully not OOM.
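
For a rough sense of scale, the arithmetic below estimates the temporaries the dense path needs for a single layer. The layer shape (8192 input features, 3072 output features), the float32 adapter dtype, and the 128-token batch are illustrative assumptions, not figures from this PR. The estimate also ignores what autograd keeps for the backward pass, so it is not meant to reproduce the table numbers.

```python
def tensor_mib(rows: int, cols: int, bytes_per_elem: int) -> float:
    """Size of a dense (rows, cols) tensor in MiB."""
    return rows * cols * bytes_per_elem / 2**20


# Assumed, illustrative shapes (not taken from the PR).
in_features, out_features = 8192, 3072
batch_tokens = 128
FP32, BF16 = 4, 2  # bytes per element

base_weight = tensor_mib(out_features, in_features, BF16)      # 48.0 MiB, stored once
dense_transform = tensor_mib(in_features, in_features, FP32)   # 256.0 MiB, (in, in) transform
upcast_weight = tensor_mib(out_features, in_features, FP32)    # 96.0 MiB, float32 copy of the weight
merged_weight = tensor_mib(out_features, in_features, FP32)    # 96.0 MiB, product W @ H
activation_buf = tensor_mib(batch_tokens, in_features, FP32)   # 4.0 MiB, one (tokens, in) buffer

dense_path_extra = dense_transform + upcast_weight + merged_weight  # 448.0 MiB per layer
```

Under these assumptions, the dense path allocates temporaries several times the size of the stored weight, while a reflection applied to the activations only needs buffers that scale with the number of tokens.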