Skip to content

Cosmos3 context parallel#14054

Draft
atharvajoshi10 wants to merge 2 commits into
huggingface:mainfrom
atharvajoshi10:cosmos3-context-parallel
Draft

Cosmos3 context parallel#14054
atharvajoshi10 wants to merge 2 commits into
huggingface:mainfrom
atharvajoshi10:cosmos3-context-parallel

Conversation

@atharvajoshi10

Copy link
Copy Markdown
Contributor

What does this PR do?

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Under sharded placement (device_map="balanced"), vae.encode() runs on the
VAE's own device while the mean/inv_std buffers were pinned to x.device,
causing a cross-device RuntimeError. Compute raw_mu first, then pin the
normalization buffers to its device so all tensors share one device.
@atharvajoshi10 atharvajoshi10 marked this pull request as draft June 23, 2026 17:31
@github-actions github-actions Bot added documentation Improvements or additions to documentation models pipelines examples size/L PR with diff > 200 LOC labels Jun 23, 2026
@atharvajoshi10 atharvajoshi10 force-pushed the cosmos3-context-parallel branch 3 times, most recently from 962f513 to 67fb9ec Compare June 23, 2026 18:03
Cosmos 3 cannot use diffusers' declarative `_cp_plan` CP path: it is grouped-query
attention (the shared Ulysses kernel assumes K/V share the query head count), its
understanding (causal) and generation (full) streams are separate packed sequences
(gen attends to cat(und, gen)), and per-pathway lengths are ragged. The model carries
no parallelism logic -- it exposes only small, CP-agnostic seams; all sharding lives
outside it, in a reusable example module.

Model (transformer_cosmos3.py): adds two default-None `forward` seams -- `_cp_shard_fn`
(shards und/gen + rotary before the decoder layers) and `_cp_gather_fn` (gathers/unpads
after the final norm) -- and extracts `Cosmos3AttnProcessor._run_attention` as an
override point. The non-parallel path is unchanged.

Helpers (examples/cosmos3/cosmos_parallel.py): one importable module, two orthogonal
and composable axes:
  * Context parallelism (Ulysses) -- `enable_cosmos3_context_parallel`. Shards the
    sequence; brackets the two attention pathways with all-to-all (DTensor redistribute),
    repeats GQA KV heads, pads ragged lengths and masks padded generation keys.
  * Tensor parallelism (Megatron) -- `enable_cosmos3_tensor_parallel`. Column/row-shards
    the attention + MLP weights so a checkpoint that does not fit one GPU (Super, ~120 GB)
    loads across several; weights load to CPU then shard layer by layer.
Both expand KV heads to the query-head count and call SDPA with enable_gqa=False so it
dispatches to the flash kernel; enable_gqa=True forces the math path, which materializes
the full [S, S] score matrix and OOMs on long videos. A dense `Cosmos3FlashAttnProcessor`
(`enable_cosmos3_flash_attention`) provides the same for TP without CP.

CLI (examples/cosmos3/inference_cosmos3.py): imports these helpers, so any modality
(text-to-image/video, image-to-video, sound, action) runs single- or multi-GPU via
`--tp-degree` / `--cp-degree` (their product must equal --nproc_per_node). Single-GPU
behavior is unchanged.

Docs + example README updated. Verified: CP attention core is bit-exact vs non-CP in
fp32 (max|d|=0), and a full 36-layer forward matches CP-on vs CP-off to ~1e-6 in fp32
(bf16 differs only by floating-point rounding).
@atharvajoshi10 atharvajoshi10 force-pushed the cosmos3-context-parallel branch from 67fb9ec to 6edc5fd Compare June 24, 2026 00:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

documentation Improvements or additions to documentation examples models pipelines size/L PR with diff > 200 LOC

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant