Add Cayley-Hamilton triangular inverse example#357
Conversation
New intermediate example computing inv(I - A) for strict-lower-triangular A via the Cayley-Hamilton doubling algorithm (X = X + X@Y, Y = Y@Y for ceil(log2(n)) iterations). Targets the chunked-GDN tri-inverse used in Qwen3-Next. Validates against torch.linalg.inv(I - A) on Ascend 910B2 (n=128, FP32). Structure follows the canonical gemm.py pattern: M-parallel pl.parallel inside pl.at(chunked_loop_optimizer) with K-blocked pl.matmul + pl.matmul_acc to keep each InCore function inside the 192 KB AIV Vec budget. Y is snapshotted to Y_temp before the Y = Y@Y step so the matmul reads from a stable tensor while writing Y_state. Test setup: - Seeded random init (local torch.Generator(seed=0)) for reproducibility. - Init scale 1/(4*sqrt(n)) targets ||A||_op ~= 0.5, bounding the intermediate-tile magnitudes through the 7 doubling iterations. - Tolerance rtol = atol = 2e-2: 14 chained matmuls on the Ascend 910B2 cube (FP16 multiply + FP32 accumulate) compound to ~1e-2 relative error, and matmul reduction ordering is not bit-deterministic across runs. 2e-2 keeps the test green across 5+ consecutive invocations.
|
Important Review skippedDraft detected. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Organization UI Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request introduces a new example script, examples/intermediate/tri_inverse.py, which implements the Cayley-Hamilton doubling algorithm to compute the inverse of a strict-lower-triangular matrix. The implementation utilizes tiled matrix multiplications and additions within a specialized program structure. Feedback highlights that the current logic assumes the matrix dimension is a multiple of the tile sizes, which should be explicitly validated to prevent out-of-bounds access or incorrect results. Additionally, the reviewer suggests using double buffering for the state tensors to avoid inefficient global memory copies during each doubling iteration.
| n_steps = max(1, (n - 1).bit_length()) # 7 for n=128 | ||
| k_blocks = n // k_tile # 2 for n=128, k_tile=64 |
There was a problem hiding this comment.
The current implementation assumes that n is a multiple of both m_tile and k_tile. If n is not a multiple, the pl.slice operations will either access out-of-bounds memory or the k_blocks loop will miss the remainder of the K dimension, leading to incorrect results or runtime errors. It is recommended to add an explicit check or handle the remainder.
| n_steps = max(1, (n - 1).bit_length()) # 7 for n=128 | |
| k_blocks = n // k_tile # 2 for n=128, k_tile=64 | |
| n_steps = max(1, (n - 1).bit_length()) # 7 for n=128 | |
| if n % m_tile != 0 or n % k_tile != 0: | |
| raise ValueError(f"n ({n}) must be a multiple of m_tile ({m_tile}) and k_tile ({k_tile})") | |
| k_blocks = n // k_tile # 2 for n=128, k_tile=64 |
| with pl.at(level=pl.Level.CORE_GROUP, | ||
| optimization=pl.chunked_loop_optimizer, | ||
| name_hint="cayley_y_snapshot"): | ||
| for mb in pl.parallel(0, n, m_tile, chunk=m_chunk): | ||
| y_row = pl.slice(Y_state, [m_tile, n], [mb, 0]) | ||
| Y_temp = pl.assemble(Y_temp, y_row, [mb, 0]) |
There was a problem hiding this comment.
The cayley_y_snapshot scope performs a full GM-to-GM copy of Y_state to Y_temp in every doubling iteration. This is inefficient as it consumes significant global memory bandwidth. A more efficient approach is to use double buffering: create two Y tensors and alternate between them as source and destination in each iteration. This avoids the explicit copy scope entirely.
Add Cayley-Hamilton triangular inverse example
Summary
Adds
examples/intermediate/tri_inverse.py— a new intermediate example thatcomputes
inv(I - A)for strict-lower-triangularAvia the Cayley-Hamiltondoubling algorithm. Targets the chunked-GDN triangular-inverse used in
Qwen3-Next gated delta-rule attention.
Algorithm
Ais strict-lower-triangular, therefore nilpotent (A^n = 0), so thedoubling iteration terminates exactly in
ceil(log2(n))steps:For
n = 128this is 7 doubling steps → 14 matmuls + 7 adds, all on[128, 128]FP32 tiles.Sign convention solves
inv(I - A), matching pypto2's Qweninverse_ptoinmodels/qwen3_next/gated_delta_rule_impl.py.Kernel structure
Each iteration is laid out as four
pl.at(CORE_GROUP, chunked_loop_optimizer)scopes, all M-parallel (
pl.parallel(0, n, m_tile, chunk=m_chunk)) withK-blocked
pl.matmul+pl.matmul_accinside each scope:cayley_x_matmulX_state,Y_stateX_acccayley_x_addX_state,X_accX_statecayley_y_snapshotY_stateY_tempcayley_y_squareY_tempY_stateState (
X_state,Y_state,Y_temp,X_acc) lives inpl.create_tensorGMbuffers and flows across scopes.
Why X-update is split into matmul + add. Keeping
x_row(read),acc(matmul output), andx_new(add output) all alive in onepl.atputsthree
[m_tile, n]tiles in the live set per scope. With the cube/Vecpipeline doublers this just reaches the 192 KB AIV Vec budget; the Vec
allocator now spills the last M-tile and corrupts its output. Splitting
matmul into one scope (writes a fresh GM tensor
X_acc) and add into asecond (reads
x_row + X_acc, writesX_state) halves the live set perscope and keeps the kernel valid with the same
m_tile=32, k_tile=64config.Y snapshot.
Y_temp = Y_statebeforeY = Y @ Yso the matmul readsfrom a stable tensor while another parallel iter is writing its row slab to
Y_state.Validation
Validates against
torch.linalg.inv(I - A)on a real Ascend 910B2 NPU withrtol = atol = 2e-2. RandomAseeded for reproducibility with a localtorch.Generator(seed=0); init scale1 / (4 * sqrt(n))targets||A||_op ≈ 0.5so the intermediate-tile magnitudes through 7 doublingiterations stay bounded.
Tolerance rationale: 14 chained matmuls on the 910B2 cube (FP16 multiply +
FP32 accumulate) compound to ~1e-2 relative error, and matmul reduction
ordering is not bit-deterministic across runs.
2e-2keeps the test greenacross 5+ consecutive invocations.
Test plan
rtol=atol=2e-2python examples/intermediate/tri_inverse.py -p a2a3 -d <device>gemm.pystyle(M-parallel
pl.parallelinsidepl.at(chunked_loop_optimizer),K-blocked matmul/matmul_acc, GM state across
pl.atscopes)How to run