Add LoRA support for Cosmos Predict 2.5 and fix pipeline to match official Cosmos repo #13664
terarachang wants to merge 15 commits into huggingface:main
Conversation
…ncoder attention implementation, and timestep scaling
… to device before torch.stack
yiyixuxu
left a comment
thanks, i left a question
Force-pushed from 0cc6351 to c8513c2
```diff
  device = sample.device
- sigma_t, sigma_s0 = self.sigmas[self.step_index + 1].to(device), self.sigmas[self.step_index].to(device)
+ sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
```
ohh I think the change here might not be intended, no?
it seems to have reverted https://github.com/huggingface/diffusers/pull/13489/changes
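A minimal sketch (toy shapes, illustrative values — not the scheduler's actual code) of why the reverted `.to(device)` matters: scheduler sigmas can live on the CPU while the sample lives on an accelerator, and indexing a tensor does not move the result across devices.

```python
import torch

# Scheduler state typically stays on CPU; the sample is on the pipeline device.
sigmas = torch.linspace(1.0, 0.0, steps=10)  # e.g. CPU tensor
sample = torch.randn(1, 4, 8, 8)             # e.g. CUDA/MPS in a real run

device = sample.device
step_index = 3
# Moving both sigmas to the sample's device keeps later arithmetic on one device;
# without `.to(device)`, mixed-device math would raise a runtime error.
sigma_t = sigmas[step_index + 1].to(device)
sigma_s0 = sigmas[step_index].to(device)
out = sample * (sigma_t / sigma_s0)
print(out.shape)
```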
yiyixuxu
left a comment
thanks, looks good to me once we revert the change in the scheduler
sayakpaul
left a comment
I left some questions and suggestions. LMK if anything is unclear.
```md
[[autodoc]] loaders.lora_pipeline.CosmosLoraLoaderMixin

## KandinskyLoraLoaderMixin

[[autodoc]] loaders.lora_pipeline.KandinskyLoraLoaderMixin
```
Figures should be hosted somewhere else.
```python
    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
```
We should also be able to use a "# Copied from ..." comment here?
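For context, a hedged sketch of the "# Copied from ..." convention being suggested: diffusers' copy checker (run via `make fix-copies`) verifies that a method marked with the comment stays byte-identical to the referenced source. The class and method bodies below are illustrative stand-ins, not the actual mixins.

```python
class SourceMixin:
    @classmethod
    def lora_state_dict(cls, pretrained_model_name_or_path, **kwargs):
        # reference implementation
        return {"checkpoint": pretrained_model_name_or_path}


class CosmosStyleMixin:
    # Copied from SourceMixin.lora_state_dict
    @classmethod
    def lora_state_dict(cls, pretrained_model_name_or_path, **kwargs):
        # the body must match the referenced source exactly,
        # otherwise the repo consistency check fails
        return {"checkpoint": pretrained_model_name_or_path}


print(CosmosStyleMixin.lora_state_dict("my-adapter"))  # {'checkpoint': 'my-adapter'}
```

The benefit is that bug fixes to the source method propagate automatically via `make fix-copies` instead of drifting out of sync.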
```python
        else:
            return state_dict

    def load_lora_weights(
```
```python
            safe_serialization=safe_serialization,
        )

    def fuse_lora(
```
```python
        network_alphas = {}
        for k in list(state_dict.keys()):
            if "alpha" in k:
                alpha_value = state_dict.get(k)
                if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
                    alpha_value, float
                ):
                    network_alphas[k] = state_dict.pop(k)
                else:
                    raise ValueError(
                        f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open an issue."
                    )

        if return_alphas or return_lora_metadata:
            return cls._prepare_outputs(
                state_dict,
                metadata=metadata,
                alphas=network_alphas,
                return_alphas=return_alphas,
                return_metadata=return_lora_metadata,
            )
        else:
            return state_dict
```
Do we need this setup in cosmos?
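To make the question concrete, here is a toy run of that alpha-filtering setup on a made-up state dict (keys are illustrative): alpha entries get popped into `network_alphas`, and the LoRA weight matrices stay behind.

```python
import torch

# Hypothetical LoRA state dict; only the "alpha" entry should be extracted.
state_dict = {
    "transformer.blocks.0.attn.lora_A.weight": torch.zeros(4, 8),
    "transformer.blocks.0.attn.lora_B.weight": torch.zeros(8, 4),
    "transformer.blocks.0.attn.alpha": torch.tensor(8.0),
}

network_alphas = {}
for k in list(state_dict.keys()):
    if "alpha" in k:
        alpha_value = state_dict.get(k)
        if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
            alpha_value, float
        ):
            # pop so downstream weight loading never sees alpha keys
            network_alphas[k] = state_dict.pop(k)

print(sorted(network_alphas))  # ['transformer.blocks.0.attn.alpha']
print(len(state_dict))         # 2
```

Whether Cosmos needs it comes down to whether Cosmos LoRA checkpoints in the wild ever carry such alpha keys.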
```python
            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names, **kwargs
        )

    def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
```
Our implementation doesn't need it, but users who trained multiple LoRA adapters may want to use this method. Would you suggest I remove unfuse_lora?
I am not suggesting to remove it. I am suggesting to supplement it with a "# Copied from ..." statement.
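A toy illustration (not the diffusers internals) of what fuse/unfuse do to a weight, which is why keeping `unfuse_lora` is useful for multi-adapter workflows: fusing folds the low-rank update into the base weight, and unfusing subtracts it back out so another adapter can be fused without reloading the model.

```python
import torch

torch.manual_seed(0)
W = torch.randn(8, 8)   # base layer weight
A = torch.randn(4, 8)   # lora_A (rank 4)
B = torch.randn(8, 4)   # lora_B
scale = 1.0             # effective LoRA scale (e.g. alpha / rank)

W_fused = W + scale * (B @ A)          # what fuse_lora does conceptually
W_unfused = W_fused - scale * (B @ A)  # unfuse_lora restores the base weight

print(torch.allclose(W, W_unfused, atol=1e-6))  # True
```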
## What this PR does

Adds LoRA fine-tuning support for Cosmos Predict 2.5 (nvidia/Cosmos-Predict2.5-2B) and fixes the pipeline to match the official Cosmos reference implementation.

### LoRA support

- `CosmosLoraLoaderMixin` in `src/diffusers/loaders/lora_pipeline.py` for LoRA loading/saving on `CosmosTransformer3DModel`
- Training script `examples/cosmos/train_cosmos_predict25_lora.py` using `accelerate` + `peft`
- Evaluation script `examples/cosmos/eval_cosmos_predict25_lora.py`
- `CosmosLoraLoaderMixin` documented in `docs/source/en/api/loaders/lora.md`

### Fixes to match the official Cosmos repo

- `conditional_frame_timestep` scaling by `timestep_scale=0.001`

### Test plan

- Repo consistency checks (`check_copies`, `check_dummies`, `check_support_list`)
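The `timestep_scale` fix above can be sketched in isolation (the raw timestep value below is hypothetical; only the `0.001` factor comes from the PR): the conditional frame's timestep is multiplied by the scale before being used, matching the official Cosmos repo.

```python
# Illustrative only: shows the arithmetic of the conditional-frame fix,
# not the actual pipeline code.
timestep_scale = 0.001               # value per this PR / the official repo
conditional_frame_timestep = 1000.0  # hypothetical raw timestep
scaled = conditional_frame_timestep * timestep_scale
print(scaled)  # 1.0
```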