From 0e3fcb26a0fee8128ae93dd98f98f0694e80cc43 Mon Sep 17 00:00:00 2001
From: David Ndungu
Date: Thu, 2 Apr 2026 18:02:32 -0700
Subject: [PATCH] feat(graph): add PJRTPlan execution wrapper with KV cache state management

Add RunPrefill, RunDecode, Reset, and Close methods to PJRTPlan[T] for
executing compiled PJRT programs with automatic KV cache buffer lifecycle
management. RunPrefill stores KV outputs for subsequent decode steps,
RunDecode donates previous KV buffers and captures new ones, and Reset
clears KV state for new generation sequences.
---
 graph/pjrt_plan.go | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/graph/pjrt_plan.go b/graph/pjrt_plan.go
index f8857cc..318da22 100644
--- a/graph/pjrt_plan.go
+++ b/graph/pjrt_plan.go
@@ -54,7 +54,8 @@ type PJRTPlan[T tensor.Numeric] struct {
 	// Dtype is the MLIR dtype string (e.g. "f32") for this plan.
 	Dtype string
 
-	// FrozenSlots are the slot indices holding frozen (weight) tensors.
+	// FrozenSlots are the slot indices holding frozen (weight) tensors,
+	// ordered to match the compiled function signature.
 	FrozenSlots []int
 }
 
@@ -226,10 +227,10 @@ func (p *PJRTPlan[T]) Close() error {
 		p.DecodeExec = nil
 	}
 
-	for i, buf := range p.WeightBuffers {
+	for _, buf := range p.WeightBuffers {
 		if buf != nil {
 			if err := buf.Close(); err != nil && firstErr == nil {
-				firstErr = fmt.Errorf("close weight buffer %d: %w", i, err)
+				firstErr = err
 			}
 		}
 	}
@@ -249,3 +250,4 @@ func (p *PJRTPlan[T]) firstDevice() (*pjrt.Device, error) {
 	}
 	return devices[0], nil
 }
+