Save checkpoint with TP #3096
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Force-pushed from 045b9b1 to e8aa041
Force-pushed from e8aa041 to 96e7d5c
BenjaminBossan left a comment
Thanks for the PR. I have a few comments, but overall it already looks good.
from typing import Any, Literal, Optional

import torch
from data import get_train_valid_test_datasets, get_wiki_small
Let's undo changes to unrelated files like here and in utils.py.
Done. It is the linter which does that.
from .config import LoraConfig


@dataclass
How about we move this to utils/integration.py and give it a more generic name, as it's not strictly related to LoRA? Then we can also use it in save_and_load.py as the return type for _get_tp_info.
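For illustration, one possible shape of such a generic dataclass; the name and fields below are hypothetical and not taken from the PR:

```python
# Hypothetical sketch only: a generic, LoRA-agnostic container that could live in
# utils/integration.py and double as the return type of _get_tp_info.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class TensorParallelInfo:
    """Illustrative fields describing how a model is sharded under tensor parallelism."""

    device_mesh: Optional[Any] = None  # e.g. a torch.distributed DeviceMesh
    tp_plan: Optional[dict] = None     # mapping of module names to sharding styles
```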
model.to(device)

lora_config = LoraConfig(r=4, target_modules=TARGET_MODULES, init_lora_weights=True)
model = inject_adapter_in_model(lora_config, model)
model.load_adapter(lora_config) should work equally well and is closer to what a normal user would do. The same applies to the inject_adapter_in_model use below.
I did not get this comment.
What I mean is that instead of calling inject_adapter_in_model(lora_config, model), we should call model.load_adapter(lora_config), because this is closer to what a user would normally do.
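For context, a minimal sketch contrasting the two call styles (the model name is a placeholder, and passing the config via the `peft_config=` keyword of `load_adapter` is an assumption about the intended usage):

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, inject_adapter_in_model

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
lora_config = LoraConfig(r=4, target_modules=["q_proj", "v_proj"], init_lora_weights=True)

# What the test currently does: low-level injection of the LoRA layers into the base model.
model = inject_adapter_in_model(lora_config, model)

# Suggested alternative, closer to a typical user flow (assumed keyword form):
# model.load_adapter(peft_config=lora_config, adapter_name="default")
```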
I have worked on it, and after some time I ended up making some changes there as well: huggingface/transformers#45155.
So this PR and the one in transformers are now connected.
@michaelbenayoun What's the state of the PR?

Still working on it, but it's not as trivial as you would think to add support for …

Got it, LMK if I can help.

It is all good on my end. One thing: adding support for …

Great, thanks for the progress. So if my understanding is correct, we should wait for the Transformers PR to land and be released first, otherwise this PR won't work correctly.

No, actually everything will work fine except the test that uses …
Thanks for updating the PR and also cleaning up the TP-related code. Generally, this LGTM.
I ran the tests locally with Transformers from main (I also hard-coded is_transformers_ge_v5_6_0 = True so that tests would not be skipped). There, I got an error:
TypeError: add_tensor_parallel_hooks_to_module() takes 5 positional arguments but 6 were given
This is because of the cleanup in huggingface/transformers#44768. After removing the 2nd tp_plan argument in the add_tensor_parallel_hooks_to_module call, all tests passed. Could you please update the code?
Note: Failing CI is unrelated.
Yes, I had actually created a branch for this specific change, but I will update it here and push.
BenjaminBossan left a comment
Thanks for pushing the fix, LGTM. Failing CI is unrelated.
Enable gathering the state dict before saving the checkpoint when doing TP (see the sketch below).
Should wait for:
ParallelInterface: transformers#44640

cc @3outeille
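A minimal sketch of the idea behind the PR, assuming a recent PyTorch where TP-sharded parameters are DTensors; the helper name and usage are illustrative, not the actual PEFT code:

```python
from torch.distributed.tensor import DTensor


def gather_tp_state_dict(state_dict: dict) -> dict:
    """Replace DTensor shards with full (gathered) tensors so the saved checkpoint is complete."""
    gathered = {}
    for name, value in state_dict.items():
        if isinstance(value, DTensor):
            # full_tensor() all-gathers the shards across the TP device mesh.
            gathered[name] = value.full_tensor()
        else:
            gathered[name] = value
    return gathered


# Usage sketch: gather before handing the state dict to save_pretrained / torch.save.
# peft_state_dict = gather_tp_state_dict(get_peft_model_state_dict(model))
```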