
Add LoRA support for Cosmos Predict 2.5 and fix pipeline to match official Cosmos repo #13664

Open
terarachang wants to merge 15 commits into huggingface:main from terarachang:cosmos_predict_2.5_lora_clean

Conversation

@terarachang

What this PR does

Adds LoRA fine-tuning support for Cosmos Predict 2.5 (nvidia/Cosmos-Predict2.5-2B) and fixes the pipeline to match the official Cosmos reference implementation.

LoRA support

  • CosmosLoraLoaderMixin in src/diffusers/loaders/lora_pipeline.py for loading/saving LoRA weights on CosmosTransformer3DModel (usage sketch below)
  • Training script examples/cosmos/train_cosmos_predict25_lora.py using accelerate + peft
  • Inference script examples/cosmos/eval_cosmos_predict25_lora.py
  • Added CosmosLoraLoaderMixin to docs/source/en/api/loaders/lora.md
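A minimal usage sketch (not from the PR): loading trained LoRA weights through the new mixin. The output path and adapter name are placeholders; `DiffusionPipeline.from_pretrained` resolves the concrete Cosmos pipeline class from the checkpoint.

```python
import torch
from diffusers import DiffusionPipeline

# Checkpoint id comes from the PR description; the LoRA path is illustrative.
pipe = DiffusionPipeline.from_pretrained(
    "nvidia/Cosmos-Predict2.5-2B", torch_dtype=torch.bfloat16
)
pipe.load_lora_weights("path/to/lora_output", adapter_name="cosmos_lora")  # via CosmosLoraLoaderMixin
pipe.to("cuda")
```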

Fixes to match the official Cosmos repo

  • Fix conditional_frame_timestep scaling by timestep_scale=0.001 (sketched below, together with several other fixes)
  • Auto-cast AdaLN and the DiT final layer to fp32 for training
  • Deterministic VAE encode (take the mode of the latent distribution instead of sampling)
  • Flash Attention 2 as the default attention implementation for the text encoder
  • Seed-invariant noise via numpy sampling
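For illustration only, minimal sketches of several of these fixes; all variable names, shapes, and the module-name filter are assumptions, not the PR's exact code:

```python
import numpy as np
import torch
import torch.nn as nn
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution

# 1) Conditioning-frame timestep scaling: multiply by timestep_scale before
#    feeding the timestep to the transformer.
timestep_scale = 0.001
cond_frame_timestep = torch.full((1,), 1000.0)
scaled_timestep = cond_frame_timestep * timestep_scale  # -> 1.0

# 2) Keep AdaLN and the final layer in fp32 during training
#    (module names and the filter are hypothetical).
transformer = nn.ModuleDict({"adaln": nn.Linear(16, 32), "final_layer": nn.Linear(32, 16)})
for name, module in transformer.named_modules():
    if "adaln" in name or "final_layer" in name:
        module.to(torch.float32)

# 3) Deterministic VAE encode: take the mode of the latent distribution
#    instead of drawing a sample.
dist = DiagonalGaussianDistribution(torch.randn(1, 8, 4, 4))  # mean/logvar packed on dim 1
latents = dist.mode()  # rather than dist.sample()

# 4) Seed-invariant noise: sample via numpy so a fixed seed reproduces the
#    same noise regardless of torch version or device.
seed, latent_shape = 0, (1, 16, 8, 44, 80)
noise = torch.from_numpy(
    np.random.default_rng(seed).standard_normal(latent_shape).astype(np.float32)
)
```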

Test plan

  • All repository consistency checks pass (check_copies, check_dummies, check_support_list)

@yiyixuxu (Collaborator) left a comment:


thanks, i left a question

Outdated comment threads:
  • src/diffusers/models/transformers/transformer_cosmos.py
  • src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py (4 threads)
@terarachang force-pushed the cosmos_predict_2.5_lora_clean branch from 0cc6351 to c8513c2 (May 5, 2026 19:25)

- device = sample.device
- sigma_t, sigma_s0 = self.sigmas[self.step_index + 1].to(device), self.sigmas[self.step_index].to(device)
+ sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
Collaborator:

ohh I think the change here might not be intended, no?
It seems to have reverted https://github.com/huggingface/diffusers/pull/13489/changes


@yiyixuxu (Collaborator) left a comment:


thanks, looks good to me once we revert the change in the scheduler

@yiyixuxu requested a review from sayakpaul (May 5, 2026 23:06)

@sayakpaul (Member) left a comment:


I left some questions and suggestions. LMK if anything is unclear.

Context (docs/source/en/api/loaders/lora.md):

[[autodoc]] loaders.lora_pipeline.CosmosLoraLoaderMixin

## KandinskyLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.KandinskyLoraLoaderMixin
Member:

Figures should be hosted somewhere else.


@classmethod
@validate_hf_hub_args
def lora_state_dict(
Member:

We should be able to also use a "# Copied from ..." comment here?

# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.lora_state_dict
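For context, a sketch of where the suggested marker would sit; the signature is abbreviated and the class scaffolding is illustrative:

```python
from diffusers.loaders.lora_base import LoraBaseMixin
from huggingface_hub.utils import validate_hf_hub_args


class CosmosLoraLoaderMixin(LoraBaseMixin):
    @classmethod
    @validate_hf_hub_args
    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.lora_state_dict
    def lora_state_dict(cls, pretrained_model_name_or_path_or_dict, **kwargs):
        ...
```

check_copies then enforces that the marked method body stays in sync with the referenced implementation.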

        else:
            return state_dict

    def load_lora_weights(
Member:

Same as above.

            safe_serialization=safe_serialization,
        )

    def fuse_lora(
Member:

Same as above.

Comment on lines +6096 to +6118:

        network_alphas = {}
        for k in list(state_dict.keys()):
            if "alpha" in k:
                alpha_value = state_dict.get(k)
                if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
                    alpha_value, float
                ):
                    network_alphas[k] = state_dict.pop(k)
                else:
                    raise ValueError(
                        f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open an issue."
                    )

        if return_alphas or return_lora_metadata:
            return cls._prepare_outputs(
                state_dict,
                metadata=metadata,
                alphas=network_alphas,
                return_alphas=return_alphas,
                return_metadata=return_lora_metadata,
            )
        else:
            return state_dict
Member:

Do we need this setup in cosmos?
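For context, a sketch (with made-up keys, shapes, and rank) of the PEFT-style state dict that loop filters: "alpha" entries are popped into network_alphas while the LoRA A/B weights stay in place.

```python
import torch

# Hypothetical PEFT-formatted LoRA state dict; rank 8 is arbitrary.
state_dict = {
    "transformer.blocks.0.attn.to_q.lora_A.weight": torch.randn(8, 1024),
    "transformer.blocks.0.attn.to_q.lora_B.weight": torch.randn(1024, 8),
    "transformer.blocks.0.attn.to_q.alpha": torch.tensor(8.0),
}

# Equivalent of the loop above: separate scaling factors from weights.
network_alphas = {k: state_dict.pop(k) for k in list(state_dict) if "alpha" in k}
print(network_alphas)  # {'transformer.blocks.0.attn.to_q.alpha': tensor(8.)}
```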

            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names, **kwargs
        )

    def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
Member:

Same as above.

@terarachang (Author):

Our implementation doesn't need it. But if users trained multiple LoRA adapters they may want to use this method. Would you suggest I remove unfuse_lora?
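For reference, a sketch of the multi-adapter flow the author describes; the checkpoint id comes from the PR description, while the LoRA path and adapter name are placeholders:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "nvidia/Cosmos-Predict2.5-2B", torch_dtype=torch.bfloat16
)
pipe.load_lora_weights("path/to/style_lora", adapter_name="style")  # placeholder path
pipe.fuse_lora(components=["transformer"], lora_scale=1.0)
# ... run inference with the fused weights ...
pipe.unfuse_lora(components=["transformer"])  # restore base weights before swapping adapters
```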

Member:

I am not suggesting to remove it. I am suggesting to supplement it with a "# Copied from ..." statement.


Labels

documentation (Improvements or additions to documentation), examples, loaders, lora, models, pipelines, schedulers, size/L (PR with diff > 200 LOC)


3 participants