diff --git a/graph/pjrt_plan.go b/graph/pjrt_plan.go index f8857cc..318da22 100644 --- a/graph/pjrt_plan.go +++ b/graph/pjrt_plan.go @@ -54,7 +54,8 @@ type PJRTPlan[T tensor.Numeric] struct { // Dtype is the MLIR dtype string (e.g. "f32") for this plan. Dtype string - // FrozenSlots are the slot indices holding frozen (weight) tensors. + // FrozenSlots are the slot indices holding frozen (weight) tensors, + // ordered to match the compiled function signature. FrozenSlots []int } @@ -226,10 +227,10 @@ func (p *PJRTPlan[T]) Close() error { p.DecodeExec = nil } - for i, buf := range p.WeightBuffers { + for _, buf := range p.WeightBuffers { if buf != nil { if err := buf.Close(); err != nil && firstErr == nil { - firstErr = fmt.Errorf("close weight buffer %d: %w", i, err) + firstErr = err } } } @@ -249,3 +250,4 @@ func (p *PJRTPlan[T]) firstDevice() (*pjrt.Device, error) { } return devices[0], nil } +