diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md index c921e82f5e0d..c6113f8df023 100644 --- a/docs/source/en/api/loaders/lora.md +++ b/docs/source/en/api/loaders/lora.md @@ -132,6 +132,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi [[autodoc]] loaders.lora_pipeline.ZImageLoraLoaderMixin +## CosmosLoraLoaderMixin + +[[autodoc]] loaders.lora_pipeline.CosmosLoraLoaderMixin + ## KandinskyLoraLoaderMixin [[autodoc]] loaders.lora_pipeline.KandinskyLoraLoaderMixin diff --git a/examples/cosmos/README.md b/examples/cosmos/README.md new file mode 100644 index 000000000000..e89b986e3fcc --- /dev/null +++ b/examples/cosmos/README.md @@ -0,0 +1,97 @@ +# LoRA fine-tuning for Cosmos Predict 2.5 + +This example shows how to fine-tune [Cosmos Predict 2.5](https://huggingface.co/nvidia/Cosmos-Predict2.5-2B) using LoRA on a custom video dataset. + +## Requirements + +Install the library from source and the example-specific dependencies: + +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install -e ".[dev]" +cd examples/cosmos +pip install -r requirements.txt +``` + +## Data preparation + +The training script expects a dataset directory with the following layout: + +``` +/ +├── videos/ # .mp4 files +└── metas/ # one .txt prompt file per video (same stem) + ├── 0.txt + ├── 1.txt + └── ... +``` + +### GR1 dataset (quick start) + +The `download_and_preprocess_datasets.sh` script downloads the GR1-100 training set and the EVAL-175 test set, then runs the preprocessing script to create the per-video prompt files. + +```bash +bash download_and_preprocess_datasets.sh +``` + +This produces: +- `gr1_dataset/train/` — training videos + prompts +- `gr1_dataset/test/` — evaluation images + prompts + +## Training + +Launch LoRA training with `accelerate`: + +```bash +export MODEL_NAME="nvidia/Cosmos-Predict2.5-2B" +export DATA_DIR="gr1_dataset/train" +export OUT_DIR="lora-output" + +accelerate launch --mixed_precision="bf16" train_cosmos_predict25_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --revision diffusers/base/post-trained \ + --train_data_dir=$DATA_DIR \ + --output_dir=$OUT_DIR \ + --train_batch_size=1 \ + --num_train_epochs=500 \ + --checkpointing_epochs=100 \ + --seed=0 \ + --height 432 --width 768 \ + --allow_tf32 \ + --gradient_checkpointing \ + --lora_rank 32 --lora_alpha 32 \ + --report_to=wandb +``` + +Or use the provided shell script: + +```bash +bash train_lora.sh +``` + +## Evaluation + +Run inference with the trained LoRA adapter: + +```bash +export DATA_DIR="gr1_dataset/test" +export LORA_DIR="lora-output" +export OUT_DIR="eval-output" + +python eval_cosmos_predict25_lora.py \ + --data_dir $DATA_DIR \ + --output_dir $OUT_DIR \ + --lora_dir $LORA_DIR \ + --revision diffusers/base/post-trained \ + --height 432 --width 768 \ + --num_output_frames 93 \ + --num_steps 36 \ + --seed 0 +``` + +Or use the provided shell script: + +```bash +bash eval_lora.sh +``` diff --git a/examples/cosmos/assets/figures/plot_IF.png b/examples/cosmos/assets/figures/plot_IF.png new file mode 100644 index 000000000000..5b4fa8f9d5a0 Binary files /dev/null and b/examples/cosmos/assets/figures/plot_IF.png differ diff --git a/examples/cosmos/assets/figures/plot_physics.png b/examples/cosmos/assets/figures/plot_physics.png new file mode 100644 index 000000000000..e9a4c540ea47 Binary files /dev/null and b/examples/cosmos/assets/figures/plot_physics.png differ diff --git 
a/examples/cosmos/assets/figures/plot_sampson.png b/examples/cosmos/assets/figures/plot_sampson.png new file mode 100644 index 000000000000..859a3f0dd963 Binary files /dev/null and b/examples/cosmos/assets/figures/plot_sampson.png differ diff --git a/examples/cosmos/assets/generated_videos/backbone_ex0.mp4 b/examples/cosmos/assets/generated_videos/backbone_ex0.mp4 new file mode 100644 index 000000000000..d38d465f6955 Binary files /dev/null and b/examples/cosmos/assets/generated_videos/backbone_ex0.mp4 differ diff --git a/examples/cosmos/assets/generated_videos/backbone_ex1.mp4 b/examples/cosmos/assets/generated_videos/backbone_ex1.mp4 new file mode 100644 index 000000000000..656dc7402963 Binary files /dev/null and b/examples/cosmos/assets/generated_videos/backbone_ex1.mp4 differ diff --git a/examples/cosmos/assets/generated_videos/dora_r32_ex0.mp4 b/examples/cosmos/assets/generated_videos/dora_r32_ex0.mp4 new file mode 100644 index 000000000000..64cf77d7dbec Binary files /dev/null and b/examples/cosmos/assets/generated_videos/dora_r32_ex0.mp4 differ diff --git a/examples/cosmos/assets/generated_videos/dora_r32_ex1.mp4 b/examples/cosmos/assets/generated_videos/dora_r32_ex1.mp4 new file mode 100644 index 000000000000..ad6c82aa20cc Binary files /dev/null and b/examples/cosmos/assets/generated_videos/dora_r32_ex1.mp4 differ diff --git a/examples/cosmos/assets/generated_videos/lora_r32_ex0.mp4 b/examples/cosmos/assets/generated_videos/lora_r32_ex0.mp4 new file mode 100644 index 000000000000..e84451dcce12 Binary files /dev/null and b/examples/cosmos/assets/generated_videos/lora_r32_ex0.mp4 differ diff --git a/examples/cosmos/assets/generated_videos/lora_r32_ex1.mp4 b/examples/cosmos/assets/generated_videos/lora_r32_ex1.mp4 new file mode 100644 index 000000000000..7d10e0c5c283 Binary files /dev/null and b/examples/cosmos/assets/generated_videos/lora_r32_ex1.mp4 differ diff --git a/examples/cosmos/create_prompts_for_gr1_dataset.py b/examples/cosmos/create_prompts_for_gr1_dataset.py new file mode 100644 index 000000000000..771cf4eda5b7 --- /dev/null +++ b/examples/cosmos/create_prompts_for_gr1_dataset.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +from tqdm import tqdm + + +"""example command +python create_prompts_for_gr1_dataset.py --dataset_path datasets/benchmark_train/gr1 +""" + + +def parse_args() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Create text prompts for GR1 dataset") + parser.add_argument( + "--dataset_path", type=str, default="datasets/benchmark_train/gr1", help="Root path to the dataset" + ) + parser.add_argument( + "--prompt_prefix", type=str, default="The robot arm is performing a task. 
", help="Prefix of the prompt" + ) + parser.add_argument( + "--meta_csv", type=str, default=None, help="Metadata csv file (defaults to /metadata.csv)" + ) + return parser.parse_args() + + +def main(args) -> None: + meta_csv = args.meta_csv or os.path.join(args.dataset_path, "metadata.csv") + meta_lines = open(meta_csv).readlines()[1:] + meta_txt_dir = os.path.join(args.dataset_path, "metas") + os.makedirs(meta_txt_dir, exist_ok=True) + + for meta_line in tqdm(meta_lines): + video_filename, prompt = meta_line.split(",", 1) + prompt = prompt.strip("\n") + if prompt.startswith('"') and prompt.endswith('"'): + # Remove the quotes + prompt = prompt[1:-1] + prompt = args.prompt_prefix + prompt + meta_txt_filename = os.path.join(meta_txt_dir, os.path.basename(video_filename).replace(".mp4", ".txt")) + with open(meta_txt_filename, "w") as fp: + fp.write(prompt) + + print(f"encoding prompt: {prompt}") + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/cosmos/download_and_preprocess_datasets.sh b/examples/cosmos/download_and_preprocess_datasets.sh new file mode 100644 index 000000000000..e43259f7a8af --- /dev/null +++ b/examples/cosmos/download_and_preprocess_datasets.sh @@ -0,0 +1,25 @@ +dataset_dir='gr1_dataset' +train_dir=$dataset_dir/train +test_dir=$dataset_dir/test + +# Download and Preprocess Training Dataset +hf download nvidia/GR1-100 --repo-type dataset --local-dir datasets/benchmark_train/hf_gr1/ && \ +mkdir -p datasets/benchmark_train/gr1/videos && \ +mv datasets/benchmark_train/hf_gr1/gr1/*mp4 datasets/benchmark_train/gr1/videos && \ +mv datasets/benchmark_train/hf_gr1/metadata.csv datasets/benchmark_train/gr1/ + +python create_prompts_for_gr1_dataset.py --dataset_path datasets/benchmark_train/gr1 + +# Download Eval Dataset +hf download nvidia/EVAL-175 --repo-type dataset --local-dir dream_gen_benchmark + + +# Rename dataset directory +mkdir $dataset_dir +mv datasets/benchmark_train/gr1 $train_dir +mv dream_gen_benchmark/gr1_object $test_dir +echo Download training data to $train_dir +echo Download test data to $test_dir + +# Clean up staging directories +rm -rf datasets/ dream_gen_benchmark/ diff --git a/examples/cosmos/eval_cosmos_predict25_lora.py b/examples/cosmos/eval_cosmos_predict25_lora.py new file mode 100644 index 000000000000..24072b40a78e --- /dev/null +++ b/examples/cosmos/eval_cosmos_predict25_lora.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import os + +import torch +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm + +from diffusers import Cosmos2_5_PredictBasePipeline +from diffusers.utils import export_to_video, load_image + + +IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"} + + +class ImageDataset(Dataset): + """Dataset that loads images and their corresponding text prompts. 
+ + Expects a directory with: + .jpg / .jpeg / .png — the conditioning image + .txt — the prompt text + """ + + def __init__(self, data_dir: str): + self.data_dir = data_dir + self.samples = [] + + for filename in sorted(os.listdir(data_dir)): + stem, ext = os.path.splitext(filename) + if ext.lower() not in IMAGE_EXTENSIONS: + continue + img_path = os.path.join(data_dir, filename) + txt_path = os.path.join(data_dir, stem + ".txt") + if not os.path.exists(txt_path): + print(f"WARNING: no prompt file found for {img_path}, skipping.") + continue + self.samples.append((img_path, txt_path, stem)) + + if len(self.samples) == 0: + raise ValueError(f"No valid image/prompt pairs found in {data_dir}") + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + img_path, txt_path, stem = self.samples[idx] + image = load_image(img_path) + with open(txt_path) as f: + prompt = f.read().strip() + return { + "image": image, + "prompt": prompt, + "stem": stem, + } + + +def collate_fn(batch): + """Keep images as a list (PIL images can't be stacked into a tensor).""" + return { + "images": [item["image"] for item in batch], + "prompts": [item["prompt"] for item in batch], + "stems": [item["stem"] for item in batch], + } + + +def parse_args(): + parser = argparse.ArgumentParser(description="Eval Cosmos Predict 2.5 with optional LoRA weights.") + + parser.add_argument("--data_dir", type=str, required=True, help="Directory with image/prompt pairs.") + parser.add_argument("--output_dir", type=str, required=True, help="Directory to save generated outputs.") + parser.add_argument( + "--model_id", type=str, default="nvidia/Cosmos-Predict2.5-2B", help="HuggingFace model repository." + ) + parser.add_argument( + "--revision", + type=str, + default="diffusers/base/post-trained", + choices=["diffusers/base/post-trained", "diffusers/base/pre-trained"], + ) + parser.add_argument("--lora_dir", type=str, default=None, help="Path to LoRA weights directory.") + parser.add_argument("--num_output_frames", type=int, default=93, help="1 for image output, 93 for video output.") + parser.add_argument("--num_steps", type=int, default=36, help="Number of inference steps.") + parser.add_argument("--height", type=int, default=704, help="Output height in pixels (must be divisible by 16).") + parser.add_argument("--width", type=int, default=1280, help="Output width in pixels (must be divisible by 16).") + parser.add_argument("--seed", type=int, default=0, help="Random seed.") + parser.add_argument("--device", type=str, default="cuda", help="Device to use.") + parser.add_argument("--batch_size", type=int, default=1, help="Number of samples per batch.") + parser.add_argument("--num_workers", type=int, default=4, help="DataLoader worker processes.") + parser.add_argument( + "--negative_prompt", + type=str, + default=None, + help="Negative prompt. 
Defaults to the pipeline's built-in negative prompt.", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + os.makedirs(args.output_dir, exist_ok=True) + + dataset = ImageDataset(args.data_dir) + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.num_workers, + collate_fn=collate_fn, + ) + + print(f"Found {len(dataset)} examples.") + + class MockSafetyChecker: + def to(self, *args, **kwargs): + return self + + def check_text_safety(self, *args, **kwargs): + return True + + def check_video_safety(self, video): + return video + + pipe = Cosmos2_5_PredictBasePipeline.from_pretrained( + args.model_id, + revision=args.revision, + device_map=args.device, + torch_dtype=torch.bfloat16, + safety_checker=MockSafetyChecker(), + ) + + if args.lora_dir is not None: + pipe.load_lora_weights(args.lora_dir) + pipe.fuse_lora(lora_scale=1.0) + print(f"Loaded LoRA weights from {args.lora_dir}") + + progress = tqdm(total=len(dataset), desc="Generating") + for batch in dataloader: + images = batch["images"] + prompts = batch["prompts"] + stems = batch["stems"] + + for image, prompt, stem in zip(images, prompts, stems): + frames = pipe( + image=image, + prompt=prompt, + negative_prompt=args.negative_prompt, + num_frames=args.num_output_frames, + num_inference_steps=args.num_steps, + height=args.height, + width=args.width, + ).frames[0] # NOTE: batch_size == 1 + + out_path = os.path.join(args.output_dir, f"{stem}.mp4") + export_to_video(frames, out_path, fps=16) + + tqdm.write(f" Saved to: {out_path}") + progress.update(1) + + +if __name__ == "__main__": + main() diff --git a/examples/cosmos/eval_lora.sh b/examples/cosmos/eval_lora.sh new file mode 100644 index 000000000000..07e79a421238 --- /dev/null +++ b/examples/cosmos/eval_lora.sh @@ -0,0 +1,15 @@ +export DATA_DIR="gr1_dataset/test" +export LORA_DIR=YOUR_ADAPTER_DIR +export OUT_DIR=YOUR_EVAL_OUTPUT_DIR +revision="post-trained" + +export TOKENIZERS_PARALLELISM=false +python eval_cosmos_predict25_lora.py \ + --data_dir $DATA_DIR \ + --output_dir $OUT_DIR \ + --lora_dir $LORA_DIR \ + --revision diffusers/base/$revision \ + --height 432 --width 768 \ + --num_output_frames 93 \ + --num_steps 36 \ + --seed 0 diff --git a/examples/cosmos/llm_judge_prompts/video_IF.yaml b/examples/cosmos/llm_judge_prompts/video_IF.yaml new file mode 100644 index 000000000000..6c76004d5e64 --- /dev/null +++ b/examples/cosmos/llm_judge_prompts/video_IF.yaml @@ -0,0 +1,28 @@ +system_prompt: "You are a helpful assistant." +user_prompt: | + You are a helpful video analyzer. Evaluate whether the video follows the given instruction. + + Instruction: {instruction} + + Evaluation Criteria: + 1. **Task Completion:** Does the video show the task described in the instruction being completed? + 2. **Action Accuracy:** Are the actions performed in the video consistent with what the instruction specifies? + 3. **Object Interaction:** Does the robot or agent interact with the correct objects as described in the instruction? + 4. **Goal Achievement:** Is the final state of the video consistent with the expected outcome of the instruction? + 5. **Correct Hand Usage:** Does the video show the correct hand performing the action? + + Instructions for Scoring: + - **1:** No adherence to the instruction. The video shows actions completely unrelated to the instruction. + - **2:** Poor adherence. Some elements match the instruction, but major deviations are present. + - **3:** Moderate adherence. 
The video follows the instruction for the most part but contains noticeable deviations. + - **4:** Good adherence. Most elements in the video match the instruction, with only minor issues. + - **5:** Perfect adherence. The video fully follows the instruction with no deviations. + + Response Template: + Analyze the video carefully and answer the question according to the following template: + [Score between 1 and 5.] + + Example Response: + 2 + + Does this video follow the instruction? diff --git a/examples/cosmos/llm_judge_prompts/video_physics.yaml b/examples/cosmos/llm_judge_prompts/video_physics.yaml new file mode 100644 index 000000000000..4a87a0f102d3 --- /dev/null +++ b/examples/cosmos/llm_judge_prompts/video_physics.yaml @@ -0,0 +1,25 @@ +system_prompt: "You are a helpful assistant." +user_prompt: | + You are a helpful video analyzer. Evaluate whether the video follows physical commonsense. + + Evaluation Criteria: + 1. **Object Behavior:** Do objects behave according to their expected physical properties (e.g., rigid objects do not deform unnaturally, fluids flow naturally)? + 2. **Motion and Forces:** Are motions and forces depicted in the video consistent with real-world physics (e.g., gravity, inertia, conservation of momentum)? + 3. **Interactions:** Do objects interact with each other and their environment in a plausible manner (e.g., no unnatural penetration, appropriate reactions on impact)? + 4. **Consistency Over Time:** Does the video maintain consistency across frames without abrupt, unexplainable changes in object behavior or motion? + + Instructions for Scoring: + - **1:** No adherence to physical commonsense. The video contains numerous violations of fundamental physical laws. + - **2:** Poor adherence. Some elements follow physics, but major violations are present. + - **3:** Moderate adherence. The video follows physics for the most part but contains noticeable inconsistencies. + - **4:** Good adherence. Most elements in the video follow physical laws, with only minor issues. + - **5:** Perfect adherence. The video demonstrates a strong understanding of physical commonsense with no violations. + + Response Template: + Analyze the video carefully and answer the question according to the following template: + [Score between 1 and 5.] + + Example Response: + 2 + + Does this video adhere to the physical laws? 
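The two YAML files above are prompt templates for an LLM-judge style evaluation of generated videos (instruction following and physical plausibility); this PR ships only the templates, not a judge runner. As a minimal sketch of how they could be consumed, the snippet below loads `video_IF.yaml` and fills its `{instruction}` placeholder to build chat-style messages. It assumes PyYAML is installed (not listed in `requirements.txt`), and the judge model or API that would receive the messages together with the video frames is left entirely to the user.

```python
# Minimal sketch (not part of this PR): load a judge prompt template and fill it in.
# Assumes PyYAML; how the messages are sent to a vision-language judge is not shown.
import yaml

with open("llm_judge_prompts/video_IF.yaml") as f:
    template = yaml.safe_load(f)

# Hypothetical instruction; in practice this would be the prompt used to generate the video.
instruction = "The robot arm picks up the red cup and places it on the tray."

messages = [
    {"role": "system", "content": template["system_prompt"]},
    # video_IF.yaml expects an {instruction} placeholder; video_physics.yaml has none.
    {"role": "user", "content": template["user_prompt"].format(instruction=instruction)},
]
print(messages[1]["content"][:200])
```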
diff --git a/examples/cosmos/requirements.txt b/examples/cosmos/requirements.txt new file mode 100644 index 000000000000..7fb57273e4c6 --- /dev/null +++ b/examples/cosmos/requirements.txt @@ -0,0 +1,15 @@ +--extra-index-url https://download.pytorch.org/whl/cu130 +torch +torchvision +accelerate>=0.31.0 +huggingface_hub +imageio +imageio-ffmpeg +transformers>=4.41.2 +peft>=0.11.1 +datasets +numpy +tqdm +sentencepiece +tensorboard +wandb diff --git a/examples/cosmos/train_cosmos_predict25_lora.py b/examples/cosmos/train_cosmos_predict25_lora.py new file mode 100644 index 000000000000..a4a6d9d637b6 --- /dev/null +++ b/examples/cosmos/train_cosmos_predict25_lora.py @@ -0,0 +1,751 @@ +import argparse +import json +import logging +import math +import os +import random +from pathlib import Path +from typing import Any, Optional + +import datasets +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from peft import LoraConfig +from peft.utils import get_peft_model_state_dict +from torch.utils.data import DataLoader, Dataset +from tqdm.auto import tqdm + +import diffusers +from diffusers import Cosmos2_5_PredictBasePipeline +from diffusers.optimization import get_linear_schedule_with_warmup +from diffusers.training_utils import cast_training_params +from diffusers.utils import ( + convert_state_dict_to_diffusers, + export_to_video, + load_video, +) +from diffusers.video_processor import VideoProcessor + + +logger = get_logger(__name__, log_level="INFO") + + +class MockSafetyChecker: + def to(self, *args, **kwargs): + return self + + def check_text_safety(self, *args, **kwargs): + return True + + def check_video_safety(self, video): + return video + + +def arch_invariant_rand(shape, dtype, device, seed=None): + rng = np.random.RandomState(seed) + random_array = rng.standard_normal(shape).astype(np.float32) + return torch.from_numpy(random_array).to(dtype=dtype, device=device) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default="nvidia/Cosmos-Predict2.5-2B", + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default="diffusers/base/post-trained", + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default="datasets/cosmos_nemo_assets", + help=("A folder containing the training data."), + ) + parser.add_argument( + "--output_dir", + type=str, + default="finetuned-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader." 
+ ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=4, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--conditional_frame_timestep", + type=float, + default=0.0001, + help="0.0001 for post-trained model. Set to < 0 to disable.", + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 
+ ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_epochs", + type=int, + default=20, + help="Save a checkpoint of the training state every X epochs.", + ) + parser.add_argument( + "--lora_rank", + type=int, + default=32, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--lora_alpha", + type=int, + default=32, + help=("The alpha parameter for Lora scaling."), + ) + parser.add_argument( + "--use_dora", + action="store_true", + help="Whether or not to use DoRA (Weight-Decomposed Low-Rank Adaptation).", + ) + parser.add_argument( + "--num_inference_steps", + type=int, + default=36, + help="Number of denoising steps during final eval inference.", + ) + parser.add_argument("--height", type=int, default=704, help="Height of the training videos in pixels.") + parser.add_argument("--width", type=int, default=1280, help="Width of the training videos in pixels.") + parser.add_argument("--num_frames", type=int, default=93, help="Number of frames per training video.") + parser.add_argument( + "--cfg_dropout_prob", + type=float, + default=0.2, + help="Probability of dropping text or video conditioning per sample for CFG training.", + ) + parser.add_argument( + "--conditional_frames_probs", + type=json.loads, + default={1: 0.5, 2: 0.5}, + help=( + "JSON dict mapping number of conditional frames to sampling probability. " + "Default {1: 0.5, 2: 0.5} trains Image2World and Video2World equally." + ), + ) + parser.add_argument( + "--learning_rate", + type=float, + default=2 ** (-14.5), + help="Learning rate for the AdamW optimizer used in build_optimizer_and_scheduler.", + ) + parser.add_argument( + "--weight_decay", + type=float, + default=0.001, + help="Weight decay for the AdamW optimizer used in build_optimizer_and_scheduler.", + ) + parser.add_argument( + "--scheduler_warm_up_steps", + type=int, + default=1000, + help="Number of warmup steps for the linear LR scheduler.", + ) + parser.add_argument( + "--num_training_steps", + type=int, + default=100000, + help="Total number of training steps for the LR scheduler.", + ) + parser.add_argument( + "--scheduler_f_max", + type=float, + default=0.5, + help="Maximum LR multiplier (peak after warmup) for the linear scheduler.", + ) + parser.add_argument( + "--scheduler_f_min", + type=float, + default=0.2, + help="Minimum LR multiplier (floor of linear decay) for the linear scheduler.", + ) + parser.add_argument( + "--do_final_eval", + action="store_true", + help="Whether to run inference on a training sample after training completes.", + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.use_dora: + args.output_dir = args.output_dir + "-dora" + + return args + + +class VideoDataset(Dataset): + def __init__( + self, + dataset_dir: str, + num_frames: int, + video_size: tuple[int, int], + prompt_type: str | None = None, # "long", "short", "medium", or None for auto + caption_format: str = "auto", # "text", "json", or "auto" + video_paths: Optional[list[str]] = None, + ) -> None: + super().__init__() + self.dataset_dir = dataset_dir + self.num_frames = num_frames + self.prompt_type = prompt_type + self.caption_format = caption_format + + # Determine caption format and directory + self._setup_caption_format() + + video_dir = os.path.join(self.dataset_dir, "videos") + + if 
video_paths is None: + self.video_paths = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith(".mp4")] + self.video_paths = sorted(self.video_paths) + else: + self.video_paths = video_paths + logger.info(f"{len(self.video_paths)} videos in total", main_process_only=True) + + self.video_size = video_size + self.video_processor = VideoProcessor(vae_scale_factor=8, resample="bilinear") + self.num_failed_loads = 0 + + def __str__(self) -> str: + return f"{len(self.video_paths)} samples from {self.dataset_dir}" + + def __len__(self) -> int: + return len(self.video_paths) + + def _load_video(self, video_path: str) -> list: + frames = load_video(video_path) + total_frames = len(frames) + if total_frames < self.num_frames: + raise ValueError( + f"Video {video_path} has only {total_frames} frames, at least {self.num_frames} frames are required." + ) + + # randomly sample a consecutive window of frames + max_start_idx = total_frames - self.num_frames + start_frame = np.random.randint(0, max_start_idx + 1) + return frames[start_frame : start_frame + self.num_frames] + + def _setup_caption_format(self) -> None: + """Determine the caption format and set up the caption directory.""" + metas_dir = os.path.join(self.dataset_dir, "metas") + captions_dir = os.path.join(self.dataset_dir, "captions") + + if self.caption_format == "auto": + # Auto-detect based on directory existence + if os.path.exists(captions_dir) and any(f.endswith(".json") for f in os.listdir(captions_dir)): + self.caption_format = "json" + self.caption_dir = captions_dir + elif os.path.exists(metas_dir) and any(f.endswith(".txt") for f in os.listdir(metas_dir)): + self.caption_format = "text" + self.caption_dir = metas_dir + else: + raise ValueError( + f"Could not auto-detect caption format. Neither 'metas/*.txt' nor 'captions/*.json' found in {self.dataset_dir}" + ) + elif self.caption_format == "json": + if not os.path.exists(captions_dir): + raise ValueError(f"JSON format specified but 'captions' directory not found in {self.dataset_dir}") + self.caption_dir = captions_dir + elif self.caption_format == "text": + if not os.path.exists(metas_dir): + raise ValueError(f"Text format specified but 'metas' directory not found in {self.dataset_dir}") + self.caption_dir = metas_dir + else: + raise ValueError(f"Invalid caption_format: {self.caption_format}. Must be 'text', 'json', or 'auto'") + + def _load_text(self, text_source: Path) -> str: + """Load text caption from file.""" + try: + return text_source.read_text().strip() + except Exception as e: + print(f"Failed to read caption file {text_source}: {e}") + return "" + + def _load_json_caption(self, json_path: Path) -> str: + """Load caption from JSON file with prompt type selection.""" + try: + with open(json_path, "r") as f: + data = json.load(f) + + # Get the first model's captions (e.g., "qwen3_vl_30b_a3b") + model_key = next(iter(data.keys())) + captions = data[model_key] + + if self.prompt_type: + # Use specified prompt type + if self.prompt_type in captions: + return captions[self.prompt_type] + else: + print( + f"Prompt type '{self.prompt_type}' not found in {json_path}. " + f"Available: {list(captions.keys())}. Using first available." 
+ ) + + # Use first available prompt type + first_prompt = next(iter(captions.values())) + return first_prompt + + except Exception as e: + print(f"Failed to read JSON caption file {json_path}: {e}") + return "" + + def _get_frames(self, video_path: str) -> torch.Tensor: + frames = self._load_video(video_path) # list of PIL images + video = self.video_processor.preprocess_video(frames, height=self.video_size[0], width=self.video_size[1]) + # video: [1, C, T, H, W] in [-1, 1] + return video.squeeze(0) # [C, T, H, W] + + def __getitem__(self, index: int) -> dict | Any: + try: + data = {} + video = self._get_frames(self.video_paths[index]) # [C, T, H, W] + + # Load caption based on format + video_path = self.video_paths[index] + video_basename = os.path.splitext(os.path.basename(video_path))[0] + + if self.caption_format == "json": + caption_path = os.path.join(self.caption_dir, f"{video_basename}.json") + caption = self._load_json_caption(Path(caption_path)) + else: # text format + caption_path = os.path.join(self.caption_dir, f"{video_basename}.txt") + caption = self._load_text(Path(caption_path)) + + data["video"] = video + data["caption"] = caption + + return data + except Exception as e: + self.num_failed_loads += 1 + print(f"Failed to load video {self.video_paths[index]} (total failures: {self.num_failed_loads}): {e}\n") + # Randomly sample another video + return self[np.random.randint(len(self.video_paths))] + + +def build_dataloader(args): + dataset = VideoDataset( + video_paths=None, + num_frames=args.num_frames, + video_size=[args.height, args.width], + dataset_dir=args.train_data_dir, + ) + + dataloader = DataLoader( + dataset=dataset, + shuffle=True, + batch_size=args.train_batch_size, + drop_last=False, + num_workers=args.dataloader_num_workers, + pin_memory=True, + ) + return dataloader + + +def get_flow_xt_and_target_v(clean_latent, t, cond_mask): + # https://github.com/nvidia-cosmos/cosmos-predict2.5/blob/main/cosmos_predict2/_src/predict2/models/text2world_model_rectified_flow.py#L779 + noise = torch.randn_like(clean_latent) + target_velocity = noise - clean_latent + xt_B_C_T_H_W = noise * t + clean_latent * (1 - t) + + # https://github.com/nvidia-cosmos/cosmos-predict2.5/blob/main/cosmos_predict2/_src/predict2/models/video2world_model_rectified_flow.py#L104 + xt_B_C_T_H_W = clean_latent * cond_mask + xt_B_C_T_H_W * (1 - cond_mask) + return xt_B_C_T_H_W, target_velocity + + +def sample_train_sigma_t(batch_size, distribution, device, dtype=torch.float32, shift=5): + if distribution == "uniform": + t = torch.rand((batch_size,)).to(device=device, dtype=dtype) + elif distribution == "logitnormal": + t = torch.sigmoid(torch.randn((batch_size,))).to(device=device, dtype=dtype) + else: + raise NotImplementedError(f"Time distribution {distribution} is not implemented.") + sigma_t = shift * t / (1 + (shift - 1) * t) # 0.0 <= sigma_t <= 1.0 + return sigma_t.view(batch_size, 1, 1, 1, 1) + + +def main(): + args = parse_args() + + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `hf auth login` to authenticate with the Hub." 
+ ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + print("-" * 100) + print(args) + print("-" * 100) + + # Initialize models + pipe = Cosmos2_5_PredictBasePipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + torch_dtype=torch.bfloat16, + safety_checker=MockSafetyChecker(), + ) + + dit = pipe.transformer + vae = pipe.vae + text_encoder = pipe.text_encoder + + dit.requires_grad_(False) + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + + target_modules_list = ["to_q", "to_k", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2"] + dit_lora_config = LoraConfig( + r=args.lora_rank, + lora_alpha=args.lora_alpha, + init_lora_weights=True, + target_modules=target_modules_list, + use_dora=args.use_dora, + ) + logger.info( + f"Add LoRA: rank={args.lora_rank}, alpha={args.lora_alpha}, targets={target_modules_list}, use_dora={args.use_dora}" + ) + + device = accelerator.device + dit.to(device) + vae.to(device) + text_encoder.to(device) + dit_dtype = dit.dtype + + # Add adapter and make sure the trainable params are in float32. + dit.add_adapter(dit_lora_config) + + if accelerator.mixed_precision in ["fp16", "bf16"]: + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(dit, dtype=torch.float32) + + lora_params = [p for p in dit.parameters() if p.requires_grad] + num_trainable_params = sum(p.numel() for p in lora_params) + + if args.gradient_checkpointing: + dit.enable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + optimizer = torch.optim.AdamW(lora_params, lr=args.learning_rate, weight_decay=args.weight_decay) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=args.scheduler_warm_up_steps, + num_training_steps=args.num_training_steps, + f_min=args.scheduler_f_min, + f_max=args.scheduler_f_max, + ) + + train_dataloader = build_dataloader(args) + + # Prepare everything with our `accelerator`. 
+ dit, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + dit, optimizer, train_dataloader, lr_scheduler + ) + + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + assert len(models) == 1, f"Expected only one model to save, got {len(models)}" + dit_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(models[0])) + weights.pop() + Cosmos2_5_PredictBasePipeline.save_lora_weights( + save_directory=output_dir, + transformer_lora_layers=dit_lora_state_dict, + safe_serialization=True, + ) + + accelerator.register_save_state_pre_hook(save_model_hook) + + if accelerator.is_main_process: + accelerator.init_trackers("diffusers-lora", config=vars(args)) + + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataloader.dataset)}") + logger.info(f" Video shape = {(args.height, args.width, args.num_frames)}") + logger.info(f" Total Trainable Parameters: {num_trainable_params / 10**9:.2f}B") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Gradient Checkpointing = {args.gradient_checkpointing}, allow_tf32 = {args.allow_tf32}") + logger.info(f" Total optimization steps = {max_train_steps}") + global_step = 0 + first_epoch = 0 + initial_global_step = 0 + progress_bar = tqdm( + range(0, max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. 
+ disable=not accelerator.is_local_main_process, + ) + + padding_mask = torch.zeros(1, 1, args.height, args.width, dtype=dit_dtype, device=device) + latent_shape = ( + pipe.vae.config.z_dim, + (args.num_frames - 1) // pipe.vae_scale_factor_temporal + 1, + args.height // pipe.vae_scale_factor_spatial, + args.width // pipe.vae_scale_factor_spatial, + ) + latents_mean = pipe.latents_mean.float().to(device) + latents_std = pipe.latents_std.float().to(device) # 1/σ + # Start training + torch.set_grad_enabled(True) # re-enable grad disabled by Cosmos2_5_PredictBasePipeline + for epoch in range(first_epoch, args.num_train_epochs): + dit.train() + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(dit): + # Encode ground-truth video to latents + # https://github.com/nvidia-cosmos/cosmos-predict2.5/blob/main/cosmos_predict2/_src/predict2/tokenizers/wan2pt1.py#L532 + raw_state = batch["video"].to(device=device, dtype=vae.dtype) + mu = vae.encode(raw_state).latent_dist.mean # deterministic + clean_latent = ((mu - latents_mean) * latents_std).contiguous().float() + assert not clean_latent.requires_grad + torch.cuda.empty_cache() + + # Encode text to text embeddings + prompt_embeds = pipe._get_prompt_embeds( + prompt=batch["caption"], + device=device, + ) + assert not prompt_embeds.requires_grad + + # CFG dropout: independently zero out text conditioning per sample + bsz = clean_latent.shape[0] + is_drop = torch.rand(bsz, device=device) < args.cfg_dropout_prob + prompt_embeds[is_drop] = 0.0 + + # Create indicator and mask to make the first few frames of x_t be the ground truth frames + frames_options = list(args.conditional_frames_probs.keys()) + weights = list(args.conditional_frames_probs.values()) + num_conditional_frames = random.choices(frames_options, weights=weights, k=bsz) + cond_indicator, cond_mask = pipe.create_condition_mask( + (bsz, *latent_shape), + device=device, + dtype=torch.float32, + num_cond_latent_frames=num_conditional_frames, + ) + + # Sample a random timestep + sigma_t = sample_train_sigma_t(bsz, distribution="logitnormal", device=device) + # 1. Sample noise 2. Get the target velocity 3. Get xt by interpolation between noise and clean + xt_B_C_T_H_W, target_velocity = get_flow_xt_and_target_v(clean_latent, sigma_t, cond_mask) + + # Denoise + if args.conditional_frame_timestep >= 0: + in_timestep = cond_indicator * args.conditional_frame_timestep + (1 - cond_indicator) * sigma_t + + pred_velocity = dit( + hidden_states=xt_B_C_T_H_W, + condition_mask=cond_mask, + timestep=in_timestep, + encoder_hidden_states=prompt_embeds, + padding_mask=padding_mask, + return_dict=False, + )[0] + # Loss is only calculated on the non-conditioned frames + pred_velocity = target_velocity * cond_mask + pred_velocity * (1 - cond_mask) + loss = F.mse_loss(pred_velocity.float(), target_velocity.float(), reduction="mean") + + # Gather the losses across all processes for logging (if we use distributed training). 
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = lora_params + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= max_train_steps: + break + + if (epoch + 1) % args.checkpointing_epochs == 0 and (epoch + 1) < args.num_train_epochs: + if accelerator.is_main_process: + save_path = os.path.join(args.output_dir, f"checkpoint-{epoch}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + # After Training + accelerator.wait_for_everyone() + if accelerator.is_main_process: + # Save the lora layers + unwrapped_dit = accelerator.unwrap_model(dit) + dit_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_dit)) + Cosmos2_5_PredictBasePipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=dit_lora_state_dict, + safe_serialization=True, + ) + + if args.do_final_eval: + noises = arch_invariant_rand((1, *latent_shape), dtype=torch.float32, device=device, seed=args.seed) + inputs = train_dataloader.dataset[0] + + pipe.transformer.eval() + with torch.inference_mode(): + frames = pipe( + image=None, + video=inputs["video"].unsqueeze(0).to(device), + prompt=inputs["caption"], + num_frames=args.num_frames, + num_inference_steps=args.num_inference_steps, + latents=noises, # ensure architecture invariant generation + height=args.height, + width=args.width, + ).frames[0] + + export_to_video(frames, os.path.join(args.output_dir, "eval_output.mp4"), fps=16) + + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/examples/cosmos/train_lora.sh b/examples/cosmos/train_lora.sh new file mode 100644 index 000000000000..813bd4938d08 --- /dev/null +++ b/examples/cosmos/train_lora.sh @@ -0,0 +1,18 @@ +export MODEL_NAME="nvidia/Cosmos-Predict2.5-2B" +export DATA_DIR="gr1_dataset/train" +export OUT_DIR=YOUR_OUTPUT_DIR +lora_rank=32 +revision="diffusers/base/post-trained" + +export TOKENIZERS_PARALLELISM=false +accelerate launch --mixed_precision="bf16" train_cosmos_predict25_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME --revision $revision \ + --train_data_dir=$DATA_DIR \ + --train_batch_size=1 \ + --num_train_epochs=500 --checkpointing_epochs=100 \ + --seed=0 \ + --output_dir=$OUT_DIR \ + --report_to=wandb \ + --height 432 --width 768 \ + --allow_tf32 --gradient_checkpointing \ + --lora_rank $lora_rank --lora_alpha $lora_rank diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index f6a070682168..488f77422dcd 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -86,6 +86,7 @@ def text_encoder_attn_modules(text_encoder): "ZImageLoraLoaderMixin", "Flux2LoraLoaderMixin", "ErnieImageLoraLoaderMixin", + "CosmosLoraLoaderMixin", ] _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] _import_structure["ip_adapter"] = [ @@ -118,6 +119,7 @@ def 
text_encoder_attn_modules(text_encoder): AuraFlowLoraLoaderMixin, CogVideoXLoraLoaderMixin, CogView4LoraLoaderMixin, + CosmosLoraLoaderMixin, ErnieImageLoraLoaderMixin, Flux2LoraLoaderMixin, FluxLoraLoaderMixin, diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index ac9383728802..ca4296699faa 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -6040,6 +6040,238 @@ def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): super().unfuse_lora(components=components, **kwargs) +class CosmosLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`CosmosTransformer3DModel`], Specific to [`Cosmos2_5_PredictBasePipeline`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + text_encoder_name = TEXT_ENCODER_NAME + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], + return_alphas: bool = False, + **kwargs, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details. + """ + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + + state_dict, metadata = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + network_alphas = {} + for k in list(state_dict.keys()): + if "alpha" in k: + alpha_value = state_dict.get(k) + if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance( + alpha_value, float + ): + network_alphas[k] = state_dict.pop(k) + else: + raise ValueError( + f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue." + ) + + if return_alphas or return_lora_metadata: + return cls._prepare_outputs( + state_dict, + metadata=metadata, + alphas=network_alphas, + return_alphas=return_alphas, + return_metadata=return_lora_metadata, + ) + else: + return state_dict + + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], + adapter_name: str | None = None, + hotswap: bool = False, + **kwargs, + ): + """ + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. 
Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + kwargs["return_lora_metadata"] = True + state_dict, network_alphas, metadata = self.lora_state_dict( + pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs + ) + + if not any("lora" in key for key in state_dict.keys()): + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") + + transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + + self.load_lora_into_transformer( + state_dict, + network_alphas=network_alphas, + transformer=transformer, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer + @classmethod + def load_lora_into_transformer( + cls, + state_dict, + network_alphas, + transformer, + adapter_name=None, + metadata=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + ): + """ + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details. + """ + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=network_alphas, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.save_lora_weights + @classmethod + def save_lora_weights( + cls, + save_directory: str | os.PathLike, + transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + transformer_lora_adapter_metadata=None, + ): + r""" + Save the LoRA parameters corresponding to the transformer. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`dict[str, torch.nn.Module]` or `dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. 
+ transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. + """ + if not transformer_lora_layers: + raise ValueError("You must pass `transformer_lora_layers`.") + + cls._save_lora_weights( + save_directory=save_directory, + lora_layers={cls.transformer_name: transformer_lora_layers}, + lora_metadata={cls.transformer_name: transformer_lora_adapter_metadata}, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + def fuse_lora( + self, + components: list[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: list[str] | None = None, + **kwargs, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details. + """ + super().fuse_lora( + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names, **kwargs + ) + + def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details. + """ + super().unfuse_lora(components=components, **kwargs) + + # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin._prepare_outputs + @staticmethod + def _prepare_outputs(state_dict, metadata, alphas=None, return_alphas=False, return_metadata=False): + outputs = [state_dict] + if return_alphas: + outputs.append(alphas) + if return_metadata: + outputs.append(metadata) + return tuple(outputs) if (return_alphas or return_metadata) else state_dict + + class LoraLoaderMixin(StableDiffusionLoraLoaderMixin): def __init__(self, *args, **kwargs): deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead." 
diff --git a/src/diffusers/models/transformers/transformer_cosmos.py b/src/diffusers/models/transformers/transformer_cosmos.py index 46746a19a678..0bb1e40e8b66 100644 --- a/src/diffusers/models/transformers/transformer_cosmos.py +++ b/src/diffusers/models/transformers/transformer_cosmos.py @@ -17,7 +17,7 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import FromOriginalModelMixin +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import is_torchvision_available from ..attention import FeedForward from ..attention_dispatch import dispatch_attention_fn @@ -74,8 +74,8 @@ def __init__(self, embedding_dim: int, condition_dim: int) -> None: self.t_embedder = CosmosTimestepEmbedding(embedding_dim, condition_dim) self.norm = RMSNorm(embedding_dim, eps=1e-6, elementwise_affine=True) - def forward(self, hidden_states: torch.Tensor, timestep: torch.LongTensor) -> torch.Tensor: - timesteps_proj = self.time_proj(timestep).type_as(hidden_states) + def forward(self, hidden_states: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep.float()) temb = self.t_embedder(timesteps_proj) embedded_timestep = self.norm(timesteps_proj) return temb, embedded_timestep @@ -102,6 +102,7 @@ def forward( embedded_timestep = embedded_timestep + temb[..., : 2 * self.embedding_dim] shift, scale = embedded_timestep.chunk(2, dim=-1) + hidden_states = self.norm(hidden_states) if embedded_timestep.ndim == 2: @@ -131,14 +132,16 @@ def forward( embedded_timestep: torch.Tensor, temb: torch.Tensor | None = None, ) -> torch.Tensor: - embedded_timestep = self.activation(embedded_timestep) + original_dtype = hidden_states.dtype + embedded_timestep = self.activation(embedded_timestep.float()) embedded_timestep = self.linear_1(embedded_timestep) embedded_timestep = self.linear_2(embedded_timestep) - if temb is not None: - embedded_timestep = embedded_timestep + temb - + embedded_timestep = embedded_timestep + temb.float() shift, scale, gate = embedded_timestep.chunk(3, dim=-1) + shift = shift.to(original_dtype) + scale = scale.to(original_dtype) + gate = gate.to(original_dtype) hidden_states = self.norm(hidden_states) if embedded_timestep.ndim == 2: @@ -181,8 +184,11 @@ def __call__( if image_rotary_emb is not None: from ..embeddings import apply_rotary_emb - query = apply_rotary_emb(query, image_rotary_emb, use_real=True, use_real_unbind_dim=-2) - key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2) + original_dtype = query.dtype + query = apply_rotary_emb(query.to(torch.float32), image_rotary_emb, use_real=True, use_real_unbind_dim=-2) + key = apply_rotary_emb(key.to(torch.float32), image_rotary_emb, use_real=True, use_real_unbind_dim=-2) + query = query.to(original_dtype) + key = key.to(original_dtype) # 4. 
Prepare for GQA if torch.onnx.is_in_onnx_export(): @@ -248,8 +254,11 @@ def __call__( if image_rotary_emb is not None: from ..embeddings import apply_rotary_emb - query = apply_rotary_emb(query, image_rotary_emb, use_real=True, use_real_unbind_dim=-2) - key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2) + original_dtype = query.dtype + query = apply_rotary_emb(query.to(torch.float32), image_rotary_emb, use_real=True, use_real_unbind_dim=-2) + key = apply_rotary_emb(key.to(torch.float32), image_rotary_emb, use_real=True, use_real_unbind_dim=-2) + query = query.to(original_dtype) + key = key.to(original_dtype) if torch.onnx.is_in_onnx_export(): query_idx = torch.tensor(query.size(3), device=query.device) @@ -551,7 +560,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return (emb / norm).type_as(hidden_states) -class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): +class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin): r""" A Transformer model for video-like data used in [Cosmos](https://github.com/NVIDIA/Cosmos). @@ -599,7 +608,7 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["patch_embed", "final_layer", "norm"] _no_split_modules = ["CosmosTransformerBlock"] - _keep_in_fp32_modules = ["learnable_pos_embed"] + _keep_in_fp32_modules = ["learnable_pos_embed", "time_embed", "norm1", "norm2", "norm3", "norm_out", "proj_out"] @register_to_config def __init__( @@ -797,7 +806,7 @@ def forward( ) # 8. Output norm & projection & unpatchify - hidden_states = self.norm_out(hidden_states, embedded_timestep, temb) + hidden_states = self.norm_out(hidden_states.float(), embedded_timestep, temb) hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.unflatten(2, (p_h, p_w, p_t, -1)) hidden_states = hidden_states.unflatten(1, (post_patch_num_frames, post_patch_height, post_patch_width)) diff --git a/src/diffusers/optimization.py b/src/diffusers/optimization.py index 044bb0db1908..a4b03bf469e4 100644 --- a/src/diffusers/optimization.py +++ b/src/diffusers/optimization.py @@ -120,7 +120,12 @@ def rule_func(steps: int) -> float: def get_linear_schedule_with_warmup( - optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, last_epoch: int = -1 + optimizer: Optimizer, + num_warmup_steps: int, + num_training_steps: int, + last_epoch: int = -1, + f_min: float = 0.0, + f_max: float = 1.0, ) -> LambdaLR: """ Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after @@ -135,6 +140,10 @@ def get_linear_schedule_with_warmup( The total number of training steps. last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training. + f_min (`float`, *optional*, defaults to 0.0): + Minimum lr multiplier (floor of the linear decay). The lr will not fall below `f_min * initial_lr`. + f_max (`float`, *optional*, defaults to 1.0): + Maximum lr multiplier (peak reached after warmup). The lr peaks at `f_max * initial_lr`. Return: `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
@@ -142,10 +151,9 @@ def get_linear_schedule_with_warmup( def lr_lambda(current_step: int): if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) - return max( - 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) - ) + return f_max * float(current_step) / float(max(1, num_warmup_steps)) + progress = float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) + return f_min + (f_max - f_min) * max(0.0, progress) return LambdaLR(optimizer, lr_lambda, last_epoch) diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py index 581711205814..680a219101d2 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py @@ -20,6 +20,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput +from ...loaders import CosmosLoraLoaderMixin from ...models import AutoencoderKLWan, CosmosTransformer3DModel from ...schedulers import UniPCMultistepScheduler from ...utils import ( @@ -181,7 +182,7 @@ def retrieve_latents( """ -class Cosmos2_5_PredictBasePipeline(DiffusionPipeline): +class Cosmos2_5_PredictBasePipeline(DiffusionPipeline, CosmosLoraLoaderMixin): r""" Pipeline for [Cosmos Predict2.5](https://github.com/nvidia-cosmos/cosmos-predict2.5) base model. @@ -233,23 +234,22 @@ def __init__( self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 - self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial, resample="bilinear") - latents_mean = ( - torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).float() - if getattr(self.vae.config, "latents_mean", None) is not None - else None - ) - latents_std = ( - torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).float() - if getattr(self.vae.config, "latents_std", None) is not None - else None - ) + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).float() + latents_std = torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).float() self.latents_mean = latents_mean - self.latents_std = latents_std - - if self.latents_mean is None or self.latents_std is None: - raise ValueError("VAE configuration must define both `latents_mean` and `latents_std`.") + self.latents_std = 1.0 / latents_std + + def create_condition_mask(self, latent_shape, device, dtype, num_cond_latent_frames): + bsz, C, T, H, W = latent_shape + cond_indicator = torch.zeros(bsz, 1, T, 1, 1, dtype=dtype, device=device) + if isinstance(num_cond_latent_frames, int): + num_cond_latent_frames = [num_cond_latent_frames] * bsz + for idx in range(bsz): + cond_indicator[idx, :, : num_cond_latent_frames[idx], :, :] = 1.0 + cond_mask = cond_indicator.expand(-1, -1, -1, H, W) + return cond_indicator, cond_mask def _get_prompt_embeds( self, @@ -455,34 +455,33 @@ def prepare_latents( needs_preprocessing = not (isinstance(video, torch.Tensor) and video.ndim == 5 and video.shape[1] == 3) if needs_preprocessing: video = self.video_processor.preprocess_video(video, 
height, width) - video = video.to(device=device, dtype=self.vae.dtype) + if isinstance(generator, list): cond_latents = [ - retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i]) + retrieve_latents( + self.vae.encode(video[i].unsqueeze(0)), generator=generator[i], sample_mode="argmax" + ) for i in range(batch_size) ] else: - cond_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + cond_latents = [ + retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator, sample_mode="argmax") + for vid in video + ] cond_latents = torch.cat(cond_latents, dim=0).to(dtype) latents_mean = self.latents_mean.to(device=device, dtype=dtype) latents_std = self.latents_std.to(device=device, dtype=dtype) - cond_latents = (cond_latents - latents_mean) / latents_std + cond_latents = (cond_latents - latents_mean) * latents_std if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device=device, dtype=dtype) - padding_shape = (B, 1, T, H, W) - ones_padding = latents.new_ones(padding_shape) - zeros_padding = latents.new_zeros(padding_shape) - num_cond_latent_frames = (num_frames_in - 1) // self.vae_scale_factor_temporal + 1 - cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1) - cond_indicator[:, :, 0:num_cond_latent_frames] = 1.0 - cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding + cond_indicator, cond_mask = self.create_condition_mask(shape, device, dtype, num_cond_latent_frames) return ( latents, @@ -565,7 +564,7 @@ def __call__( callback_on_step_end: Callable[[int, int, None], PipelineCallback | MultiPipelineCallbacks] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], max_sequence_length: int = 512, - conditional_frame_timestep: float = 0.1, + conditional_frame_timestep: float = 0.0001, num_latent_conditional_frames: int = 2, ): r""" @@ -700,20 +699,17 @@ def __call__( vae_dtype = self.vae.dtype transformer_dtype = self.transformer.dtype + is_video = video is not None + is_image = image is not None - num_frames_in = None - if image is not None: - if batch_size != 1: - raise ValueError(f"batch_size must be 1 for image input (given {batch_size})") - + if is_image: image = torchvision.transforms.functional.to_tensor(image).unsqueeze(0) video = torch.cat([image, torch.zeros_like(image).repeat(num_frames - 1, 1, 1, 1)], dim=0) video = video.unsqueeze(0) + video = self.video_processor.preprocess_video(video, height, width) num_frames_in = 1 - elif video is None: - video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8) - num_frames_in = 0 - else: + + elif is_video: if batch_size != 1: raise ValueError(f"batch_size must be 1 for video input (given {batch_size})") @@ -722,34 +718,31 @@ def __call__( f"num_latent_conditional_frames must be 1 or 2, but got {num_latent_conditional_frames}" ) - frames_to_extract = 4 * (num_latent_conditional_frames - 1) + 1 - - total_input_frames = len(video) + # List of num_frames images -> tensor of shape [B, C, T, H, W] + needs_preprocessing = not (isinstance(video, torch.Tensor) and video.ndim == 5 and video.shape[1] == 3) + if needs_preprocessing: + video = self.video_processor.preprocess_video(video, height, width) + # For Video2World: extract last frames_to_extract frames from input, then pad + frames_to_extract = 4 * (num_latent_conditional_frames - 1) + 1 + total_input_frames = video.shape[2] if total_input_frames < frames_to_extract: raise 
ValueError( f"Input video has only {total_input_frames} frames but Video2World requires at least " f"{frames_to_extract} frames for conditioning." ) + video = video[:, :, -frames_to_extract:, :, :] + if video.shape[2] < num_frames: + n_pad_frames = num_frames - video.shape[2] + last_frame = video[:, :, -1:, :, :] # [B, C, T==1, H, W] + pad_frames = last_frame.repeat(1, 1, n_pad_frames, 1, 1) # [B, C, T, H, W] + video = torch.cat((video, pad_frames), dim=2) num_frames_in = frames_to_extract - assert video is not None - video = self.video_processor.preprocess_video(video, height, width) - - # For Video2World: extract last frames_to_extract frames from input, then pad - if image is None and num_frames_in > 0 and num_frames_in < video.shape[2]: - video = video[:, :, -num_frames_in:, :, :] - - num_frames_out = num_frames - - if video.shape[2] < num_frames_out: - n_pad_frames = num_frames_out - video.shape[2] - last_frame = video[:, :, -1:, :, :] # [B, C, T==1, H, W] - pad_frames = last_frame.repeat(1, 1, n_pad_frames, 1, 1) # [B, C, T, H, W] - video = torch.cat((video, pad_frames), dim=2) - - assert num_frames_in <= num_frames_out, f"expected ({num_frames_in=}) <= ({num_frames_out=})" + else: + video = torch.zeros(batch_size, 3, num_frames, height, width, dtype=torch.uint8) + num_frames_in = 0 video = video.to(device=device, dtype=vae_dtype) @@ -768,9 +761,6 @@ def __call__( generator=generator, latents=latents, ) - cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep - cond_mask = cond_mask.to(transformer_dtype) - padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype) # Denoising loop @@ -778,8 +768,9 @@ def __call__( timesteps = self.scheduler.timesteps self._num_timesteps = len(timesteps) num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - + cond_mask = cond_mask.to(transformer_dtype) gt_velocity = (latents - cond_latent) * cond_mask + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: @@ -788,15 +779,14 @@ def __call__( self._current_timestep = t.cpu().item() # NOTE: assumes sigma(t) \in [0, 1] - sigma_t = ( - torch.tensor(self.scheduler.sigmas[i].item()) - .unsqueeze(0) - .to(device=device, dtype=transformer_dtype) - ) - + sigma_t = self.scheduler.sigmas[i].expand(batch_size).to(device=device, dtype=torch.float32) + if conditional_frame_timestep >= 0: + in_timestep = cond_indicator * conditional_frame_timestep + (1 - cond_indicator) * sigma_t + else: + in_timestep = sigma_t in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents in_latents = in_latents.to(transformer_dtype) - in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t + noise_pred = self.transformer( hidden_states=in_latents, condition_mask=cond_mask, @@ -805,7 +795,7 @@ def __call__( padding_mask=padding_mask, return_dict=False, )[0] - # NOTE: replace velocity (noise_pred) with gt_velocity for conditioning inputs only + # NOTE: replace velocity with gt_velocity for conditioning inputs only noise_pred = gt_velocity + noise_pred * (1 - cond_mask) if self.do_classifier_free_guidance: @@ -817,7 +807,7 @@ def __call__( padding_mask=padding_mask, return_dict=False, )[0] - # NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only + # NOTE: replace velocity with gt_velocity for conditioning inputs only noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask) noise_pred = noise_pred + self.guidance_scale * (noise_pred - 
noise_pred_neg) @@ -845,20 +835,20 @@ def __call__( if not output_type == "latent": latents_mean = self.latents_mean.to(latents.device, latents.dtype) latents_std = self.latents_std.to(latents.device, latents.dtype) - latents = latents * latents_std + latents_mean + latents = latents / latents_std + latents_mean video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0] video = self._match_num_frames(video, num_frames) - assert self.safety_checker is not None - self.safety_checker.to(device) - video = self.video_processor.postprocess_video(video, output_type="np") - video = (video * 255).astype(np.uint8) - video_batch = [] - for vid in video: - vid = self.safety_checker.check_video_safety(vid) - video_batch.append(vid) - video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1 - video = torch.from_numpy(video).permute(0, 4, 1, 2, 3) + if isinstance(self.safety_checker, CosmosSafetyChecker): + self.safety_checker.to(device) + video = self.video_processor.postprocess_video(video, output_type="np") + video = (video * 255).astype(np.uint8) + video_batch = [] + for vid in video: + vid = self.safety_checker.check_video_safety(vid) + video_batch.append(vid) + video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1 + video = torch.from_numpy(video).permute(0, 4, 1, 2, 3) video = self.video_processor.postprocess_video(video, output_type=output_type) else: video = latents diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 5c2cbcc13ff1..71a5444491ed 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -882,8 +882,7 @@ def multistep_uni_p_bh_update( x_t = self.solver_p.step(model_output, s0, x).prev_sample return x_t - device = sample.device - sigma_t, sigma_s0 = self.sigmas[self.step_index + 1].to(device), self.sigmas[self.step_index].to(device) + sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) @@ -891,20 +890,21 @@ def multistep_uni_p_bh_update( lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) h = lambda_t - lambda_s0 + device = sample.device rks = [] D1s = [] for i in range(1, order): si = self.step_index - i mi = model_output_list[-(i + 1)] - alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si].to(device)) + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) lambda_si = torch.log(alpha_si) - torch.log(sigma_si) rk = (lambda_si - lambda_s0) / h rks.append(rk) D1s.append((mi - m0) / rk) - rks.append(torch.ones((), device=device)) - rks = torch.stack(rks) + rks.append(1.0) + rks = torch.tensor(rks, device=device) R = [] b = [] @@ -929,13 +929,13 @@ def multistep_uni_p_bh_update( h_phi_k = h_phi_k / hh - 1 / factorial_i R = torch.stack(R) - b = torch.stack(b) if len(b) > 0 else torch.tensor(b, device=device) + b = torch.tensor(b, device=device) if len(D1s) > 0: D1s = torch.stack(D1s, dim=1) # (B, K) # for order 2, we use a simplified version if order == 2: - rhos_p = torch.ones(1, dtype=x.dtype, device=device) * 0.5 + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) else: rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype) else: @@ -1017,8 +1017,7 @@ def multistep_uni_c_bh_update( x_t = this_sample model_t = this_model_output - device = this_sample.device - sigma_t, sigma_s0 = 
self.sigmas[self.step_index].to(device), self.sigmas[self.step_index - 1].to(device) + sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) @@ -1026,20 +1025,21 @@ def multistep_uni_c_bh_update( lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) h = lambda_t - lambda_s0 + device = this_sample.device rks = [] D1s = [] for i in range(1, order): si = self.step_index - (i + 1) mi = model_output_list[-(i + 1)] - alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si].to(device)) + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) lambda_si = torch.log(alpha_si) - torch.log(sigma_si) rk = (lambda_si - lambda_s0) / h rks.append(rk) D1s.append((mi - m0) / rk) - rks.append(torch.ones((), device=device)) - rks = torch.stack(rks) + rks.append(1.0) + rks = torch.tensor(rks, device=device) R = [] b = [] @@ -1064,7 +1064,7 @@ def multistep_uni_c_bh_update( h_phi_k = h_phi_k / hh - 1 / factorial_i R = torch.stack(R) - b = torch.stack(b) if len(b) > 0 else torch.tensor(b, device=device) + b = torch.tensor(b, device=device) if len(D1s) > 0: D1s = torch.stack(D1s, dim=1) @@ -1073,7 +1073,7 @@ def multistep_uni_c_bh_update( # for order 1, we use a simplified version if order == 1: - rhos_c = torch.ones(1, dtype=x.dtype, device=device) * 0.5 + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) else: rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
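For the `get_linear_schedule_with_warmup` change in `src/diffusers/optimization.py` above, the sketch below shows the intended behavior of the new `f_min`/`f_max` keywords: the lr multiplier warms up to `f_max`, then decays linearly toward the `f_min` floor instead of 0. It assumes the function is imported from `diffusers.optimization` with the signature introduced in this patch; the hyperparameter values are illustrative only.

```python
# Minimal sketch of the extended linear warmup schedule: the lr multiplier
# rises to f_max over num_warmup_steps, then decays linearly toward f_min
# (a floor) rather than 0. Values below are illustrative only.
import torch

from diffusers.optimization import get_linear_schedule_with_warmup

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,
    num_training_steps=1000,
    f_min=0.1,  # lr never falls below 0.1 * 1e-4
    f_max=1.0,  # lr peaks at 1.0 * 1e-4 at the end of warmup
)

lrs = []
for _ in range(1000):
    optimizer.step()
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])

print(lrs[99], lrs[-1])  # ~1e-4 at the end of warmup, ~1e-5 at the end of training
```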
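The conditioning-mask logic added to the pipeline (`create_condition_mask`) is hard to read inline in the diff. The standalone sketch below mirrors that logic under a hypothetical helper name with illustrative shapes: the first `num_cond_latent_frames` latent frames are flagged as conditioning and the indicator is broadcast over the spatial dimensions.

```python
# Standalone sketch (hypothetical helper name) mirroring the pipeline's
# create_condition_mask: frames [0, num_cond_latent_frames) of the latent
# video are marked as conditioning (1.0); the remaining frames are generated.
import torch


def make_condition_mask(latent_shape, num_cond_latent_frames, device="cpu", dtype=torch.float32):
    bsz, _, num_latent_frames, height, width = latent_shape
    cond_indicator = torch.zeros(bsz, 1, num_latent_frames, 1, 1, dtype=dtype, device=device)
    if isinstance(num_cond_latent_frames, int):
        num_cond_latent_frames = [num_cond_latent_frames] * bsz
    for idx in range(bsz):
        cond_indicator[idx, :, : num_cond_latent_frames[idx], :, :] = 1.0
    # Broadcast the per-frame indicator over the spatial dimensions.
    cond_mask = cond_indicator.expand(-1, -1, -1, height, width)
    return cond_indicator, cond_mask


# Illustrative latent shape [B, C, T, H, W]; only the first 2 latent frames condition the model.
indicator, mask = make_condition_mask((1, 16, 8, 6, 6), num_cond_latent_frames=2)
print(indicator.shape, mask.shape)  # torch.Size([1, 1, 8, 1, 1]) torch.Size([1, 1, 8, 6, 6])
```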