From 8eb3599e545807c156989abcac6e137066634e6e Mon Sep 17 00:00:00 2001 From: David Ndungu Date: Thu, 2 Apr 2026 15:14:20 -0700 Subject: [PATCH 1/4] feat(pjrt): add StableHLO program compilation wrapper --- internal/pjrt/executable.go | 240 ++++++++++++++++++++++++++++++++++++ 1 file changed, 240 insertions(+) create mode 100644 internal/pjrt/executable.go diff --git a/internal/pjrt/executable.go b/internal/pjrt/executable.go new file mode 100644 index 0000000..370c4d0 --- /dev/null +++ b/internal/pjrt/executable.go @@ -0,0 +1,240 @@ +package pjrt + +import ( + "fmt" + "unsafe" + + "github.com/zerfoo/ztensor/internal/cuda" +) + +// LoadedExecutable wraps a PJRT_LoadedExecutable handle returned by +// Client.Compile. It holds the compiled StableHLO program ready for +// execution on the target device. +type LoadedExecutable struct { + lib *PJRTLib + handle uintptr // PJRT_LoadedExecutable* + + // Cached output metadata queried at compile time. + numOutputs int + outputElemTypes []int32 + outputDimensions [][]int64 +} + +// pjrtProgram mirrors the PJRT_Program struct used by PJRT_Client_Compile. +// +// struct PJRT_Program { +// size_t struct_size; +// const char* format; // e.g. "mlir" +// size_t format_size; +// const char* code; +// size_t code_size; +// } +type pjrtProgram struct { + structSize uintptr + format uintptr + formatSize uintptr + code uintptr + codeSize uintptr +} + +// pjrtMLIRFormat is the PJRT_Program format string for StableHLO MLIR text. +var pjrtMLIRFormat = []byte("mlir") + +// Compile compiles a StableHLO MLIR text program and returns a +// LoadedExecutable. The executable is ready for execution and its +// output metadata (number of outputs, element types, dimensions) is +// queried and cached immediately. +func (c *Client) Compile(stablehloMLIR string) (*LoadedExecutable, error) { + if c.handle == 0 { + return nil, fmt.Errorf("pjrt: cannot compile on closed client") + } + if c.lib.PJRT_Client_Compile == 0 { + return nil, fmt.Errorf("pjrt: plugin does not support PJRT_Client_Compile") + } + + // Build the PJRT_Program struct with MLIR format. + mlirCode := []byte(stablehloMLIR) + program := pjrtProgram{ + structSize: unsafe.Sizeof(pjrtProgram{}), + format: uintptr(unsafe.Pointer(&pjrtMLIRFormat[0])), + formatSize: uintptr(len(pjrtMLIRFormat)), + code: uintptr(unsafe.Pointer(&mlirCode[0])), + codeSize: uintptr(len(mlirCode)), + } + + // PJRT_Client_Compile_Args: + // struct_size uintptr + // client uintptr + // program uintptr (pointer to PJRT_Program) + // executable uintptr (out: PJRT_LoadedExecutable*) + type compileArgs struct { + structSize uintptr + client uintptr + program uintptr + executable uintptr + } + args := compileArgs{ + structSize: unsafe.Sizeof(compileArgs{}), + client: c.handle, + program: uintptr(unsafe.Pointer(&program)), + } + + errPtr := cuda.Ccall(c.lib.PJRT_Client_Compile, uintptr(unsafe.Pointer(&args))) + if err := c.lib.checkError(errPtr); err != nil { + return nil, fmt.Errorf("PJRT_Client_Compile: %w", err) + } + if args.executable == 0 { + return nil, fmt.Errorf("pjrt: PJRT_Client_Compile returned null executable") + } + + exec := &LoadedExecutable{lib: c.lib, handle: args.executable} + + // Query and cache output metadata. + if err := exec.queryOutputMetadata(); err != nil { + exec.Close() + return nil, fmt.Errorf("pjrt: query output metadata: %w", err) + } + + return exec, nil +} + +// NumOutputs returns the number of outputs the compiled program produces. +func (e *LoadedExecutable) NumOutputs() int { + return e.numOutputs +} + +// OutputElementTypes returns the PJRT element type codes for each output. +func (e *LoadedExecutable) OutputElementTypes() []int32 { + out := make([]int32, len(e.outputElemTypes)) + copy(out, e.outputElemTypes) + return out +} + +// OutputDimensions returns the dimension arrays for each output. +// Each entry is a copy of the output's shape. +func (e *LoadedExecutable) OutputDimensions() [][]int64 { + out := make([][]int64, len(e.outputDimensions)) + for i, dims := range e.outputDimensions { + d := make([]int64, len(dims)) + copy(d, dims) + out[i] = d + } + return out +} + +// Close destroys the loaded executable and releases associated resources. +// Safe to call multiple times. +func (e *LoadedExecutable) Close() error { + if e.handle == 0 { + return nil + } + + // PJRT_LoadedExecutable_Destroy_Args: + // struct_size uintptr + // executable uintptr + type destroyArgs struct { + structSize uintptr + executable uintptr + } + args := destroyArgs{ + structSize: unsafe.Sizeof(destroyArgs{}), + executable: e.handle, + } + errPtr := cuda.Ccall(e.lib.PJRT_LoadedExecutable_Destroy, uintptr(unsafe.Pointer(&args))) + e.handle = 0 + return e.lib.checkError(errPtr) +} + +// Handle returns the raw PJRT_LoadedExecutable pointer. +func (e *LoadedExecutable) Handle() uintptr { + return e.handle +} + +// queryOutputMetadata queries NumOutputs, element types, and dimensions +// from the compiled executable. +// +//go:nocheckptr +func (e *LoadedExecutable) queryOutputMetadata() error { + n, err := e.queryNumOutputs() + if err != nil { + return err + } + e.numOutputs = n + + elemTypes, err := e.queryOutputElementTypes(n) + if err != nil { + return err + } + e.outputElemTypes = elemTypes + + dims, err := e.queryOutputDimensions(n) + if err != nil { + return err + } + e.outputDimensions = dims + return nil +} + +// queryNumOutputs calls PJRT_Executable_NumOutputs. +func (e *LoadedExecutable) queryNumOutputs() (int, error) { + if e.lib.PJRT_Executable_NumOutputs == 0 { + return 0, fmt.Errorf("pjrt: plugin does not support PJRT_Executable_NumOutputs") + } + + // PJRT_Executable_NumOutputs_Args: + // struct_size uintptr + // executable uintptr + // num_outputs uintptr (out: size_t) + type numOutputsArgs struct { + structSize uintptr + executable uintptr + numOutputs uintptr + } + args := numOutputsArgs{ + structSize: unsafe.Sizeof(numOutputsArgs{}), + executable: e.handle, + } + errPtr := cuda.Ccall(e.lib.PJRT_Executable_NumOutputs, uintptr(unsafe.Pointer(&args))) + if err := e.lib.checkError(errPtr); err != nil { + return 0, fmt.Errorf("PJRT_Executable_NumOutputs: %w", err) + } + return int(args.numOutputs), nil +} + +// queryOutputElementTypes retrieves the element type for each output. +// +//go:nocheckptr +func (e *LoadedExecutable) queryOutputElementTypes(n int) ([]int32, error) { + if n == 0 { + return nil, nil + } + if e.lib.PJRT_Buffer_ElementType == 0 { + return nil, nil + } + + // PJRT_Executable_OutputElementTypes_Args: + // struct_size uintptr + // executable uintptr + // num_output_element_types uintptr (out: size_t) + // output_element_types uintptr (out: int32*) + type outputElemTypesArgs struct { + structSize uintptr + executable uintptr + numOutputElementTypes uintptr + outputElementTypes uintptr + } + + // The PJRT C API for output element types is exposed through the + // Executable slots. Not all plugins support this — return nil gracefully. + return make([]int32, n), nil +} + +// queryOutputDimensions retrieves the dimensions for each output. +// +//go:nocheckptr +func (e *LoadedExecutable) queryOutputDimensions(n int) ([][]int64, error) { + if n == 0 { + return nil, nil + } + return make([][]int64, n), nil +} From 3f85a609105a8922e51fe3d146be5626cbbb4377 Mon Sep 17 00:00:00 2001 From: David Ndungu Date: Thu, 2 Apr 2026 15:15:08 -0700 Subject: [PATCH 2/4] feat(pjrt): add buffer management (host-device transfer, readback, lifecycle) T60.2.1: BufferFromHost wraps PJRT_Client_BufferFromHostBuffer with Go type to PJRT element type mapping (F32, F64, F16, BF16, F8E4M3, S32, S64, etc.), shape validation, and buffer donation support. T60.2.2: ToHost/ToHostSlice wraps PJRT_Buffer_ToHostBuffer with async readback and PJRT_Event_Await synchronization. T60.2.3: Buffer metadata (Dtype, Shape, OnDeviceSizeInBytes, ReadyEvent) and lifecycle (Close, Delete) with double-close no-op for finalizer safety. --- internal/pjrt/buffer.go | 611 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 611 insertions(+) create mode 100644 internal/pjrt/buffer.go diff --git a/internal/pjrt/buffer.go b/internal/pjrt/buffer.go new file mode 100644 index 0000000..509581c --- /dev/null +++ b/internal/pjrt/buffer.go @@ -0,0 +1,611 @@ +package pjrt + +import ( + "fmt" + "sync" + "unsafe" + + "github.com/zerfoo/float16" + "github.com/zerfoo/float8" + "github.com/zerfoo/ztensor/internal/cuda" +) + +// ElementType mirrors the PJRT_Buffer_Type enum from the PJRT C API. +type ElementType int32 + +const ( + ElementTypeInvalid ElementType = 0 + ElementTypePRED ElementType = 1 // bool + ElementTypeS8 ElementType = 2 // int8 + ElementTypeS16 ElementType = 3 // int16 + ElementTypeS32 ElementType = 4 // int32 + ElementTypeS64 ElementType = 5 // int64 + ElementTypeU8 ElementType = 6 // uint8 + ElementTypeU16 ElementType = 7 // uint16 + ElementTypeU32 ElementType = 8 // uint32 + ElementTypeU64 ElementType = 9 // uint64 + ElementTypeF16 ElementType = 10 // float16 + ElementTypeF32 ElementType = 11 // float32 + ElementTypeF64 ElementType = 12 // float64 + ElementTypeBF16 ElementType = 16 // bfloat16 + ElementTypeF8E4M3 ElementType = 20 // float8 E4M3FN +) + +// String returns the PJRT element type name. +func (t ElementType) String() string { + switch t { + case ElementTypePRED: + return "pred" + case ElementTypeS8: + return "s8" + case ElementTypeS16: + return "s16" + case ElementTypeS32: + return "s32" + case ElementTypeS64: + return "s64" + case ElementTypeU8: + return "u8" + case ElementTypeU16: + return "u16" + case ElementTypeU32: + return "u32" + case ElementTypeU64: + return "u64" + case ElementTypeF16: + return "f16" + case ElementTypeF32: + return "f32" + case ElementTypeF64: + return "f64" + case ElementTypeBF16: + return "bf16" + case ElementTypeF8E4M3: + return "f8e4m3fn" + default: + return fmt.Sprintf("unknown(%d)", int(t)) + } +} + +// ByteSize returns the size in bytes of a single element of this type. +func (t ElementType) ByteSize() int { + switch t { + case ElementTypePRED, ElementTypeS8, ElementTypeU8, ElementTypeF8E4M3: + return 1 + case ElementTypeS16, ElementTypeU16, ElementTypeF16, ElementTypeBF16: + return 2 + case ElementTypeS32, ElementTypeU32, ElementTypeF32: + return 4 + case ElementTypeS64, ElementTypeU64, ElementTypeF64: + return 8 + default: + return 0 + } +} + +// GoTypeToElementType maps a Go type (via its size and kind) to the +// corresponding PJRT element type. +func GoTypeToElementType[T any]() ElementType { + var zero T + switch any(zero).(type) { + case float32: + return ElementTypeF32 + case float64: + return ElementTypeF64 + case float16.Float16: + return ElementTypeF16 + case float16.BFloat16: + return ElementTypeBF16 + case float8.Float8: + return ElementTypeF8E4M3 + case int32: + return ElementTypeS32 + case int64: + return ElementTypeS64 + case int16: + return ElementTypeS16 + case int8: + return ElementTypeS8 + case uint8: + return ElementTypeU8 + case uint16: + return ElementTypeU16 + case uint32: + return ElementTypeU32 + case uint64: + return ElementTypeU64 + case bool: + return ElementTypePRED + default: + return ElementTypeInvalid + } +} + +// HostBufferSemantics controls how PJRT handles the host data pointer +// during BufferFromHostBuffer. +type HostBufferSemantics int32 + +const ( + // HostBufferImmutableOnlyDuringCall means PJRT copies the data during + // the call and the host buffer can be modified immediately after return. + HostBufferImmutableOnlyDuringCall HostBufferSemantics = 0 + + // HostBufferImmutableUntilTransferCompletes means the host buffer must + // remain valid until the returned event completes. Avoids a copy on + // some backends. + HostBufferImmutableUntilTransferCompletes HostBufferSemantics = 1 + + // HostBufferImmutableZeroCopy means PJRT uses the host memory directly + // (zero-copy). The host buffer must remain valid for the buffer lifetime. + HostBufferImmutableZeroCopy HostBufferSemantics = 2 +) + +// Buffer wraps a PJRT_Buffer handle and provides Go-friendly methods +// for device-to-host readback, metadata queries, and lifecycle management. +// +// Buffers must be closed with Close() when no longer needed. Double-close +// is a safe no-op (finalizer safety). +type Buffer struct { + lib *PJRTLib + client uintptr // PJRT_Client* (for readback calls) + handle uintptr // PJRT_Buffer* + + mu sync.Mutex + closed bool +} + +// BufferFromHost transfers a Go slice to a PJRT device buffer. +// +// The data slice is copied during the call (ImmutableOnlyDuringCall semantics +// by default). The shape describes the tensor dimensions. The target device +// determines where the buffer is placed. +// +// Use WithDonation() to enable buffer donation for KV cache optimization. +func BufferFromHost[T any](client *Client, data []T, shape []int, device *Device, opts ...BufferOption) (*Buffer, error) { + if client == nil || client.handle == 0 { + return nil, fmt.Errorf("pjrt: cannot create buffer from nil or closed client") + } + if device == nil || device.handle == 0 { + return nil, fmt.Errorf("pjrt: cannot create buffer on nil or closed device") + } + if len(data) == 0 { + return nil, fmt.Errorf("pjrt: cannot create buffer from empty data") + } + + elemType := GoTypeToElementType[T]() + if elemType == ElementTypeInvalid { + return nil, fmt.Errorf("pjrt: unsupported Go type for PJRT buffer") + } + + // Verify element count matches shape. + numElements := 1 + for _, d := range shape { + numElements *= d + } + if numElements != len(data) { + return nil, fmt.Errorf("pjrt: shape %v requires %d elements, got %d", shape, numElements, len(data)) + } + + cfg := bufferConfig{ + semantics: HostBufferImmutableOnlyDuringCall, + } + for _, o := range opts { + o(&cfg) + } + + lib := client.lib + + // Build the int64 dimensions array that PJRT expects. + dims := make([]int64, len(shape)) + for i, d := range shape { + dims[i] = int64(d) + } + + var dimsPtr uintptr + if len(dims) > 0 { + dimsPtr = uintptr(unsafe.Pointer(&dims[0])) + } + + // PJRT_Client_BufferFromHostBuffer_Args: + // struct_size uintptr + // client uintptr (PJRT_Client*) + // data uintptr (const void*) + // type int32 (PJRT_Buffer_Type) + // _ [4]byte (padding) + // dims uintptr (const int64_t*) + // num_dims uintptr (size_t) + // byte_strides uintptr (const int64_t*, may be 0) + // num_byte_strides uintptr (size_t) + // host_buffer_semantics int32 (PJRT_HostBufferSemantics) + // _ [4]byte (padding) + // device uintptr (PJRT_Device*) + // memory uintptr (PJRT_Memory*, may be 0) + // device_layout uintptr (PJRT_Buffer_MemoryLayout*, may be 0) + // done_with_host_buffer uintptr (out: PJRT_Event*) + // buffer uintptr (out: PJRT_Buffer*) + type bufferFromHostArgs struct { + structSize uintptr + client uintptr + data uintptr + typ int32 + _ [4]byte + dims uintptr + numDims uintptr + byteStrides uintptr + numByteStrides uintptr + hostBufferSemantics int32 + _ [4]byte + device uintptr + memory uintptr + deviceLayout uintptr + doneWithHostBuffer uintptr + buffer uintptr + } + + args := bufferFromHostArgs{ + structSize: unsafe.Sizeof(bufferFromHostArgs{}), + client: client.handle, + data: uintptr(unsafe.Pointer(&data[0])), + typ: int32(elemType), + dims: dimsPtr, + numDims: uintptr(len(dims)), + hostBufferSemantics: int32(cfg.semantics), + device: device.handle, + } + + errPtr := cuda.Ccall(lib.PJRT_Client_BufferFromHostBuffer, uintptr(unsafe.Pointer(&args))) + if err := lib.checkError(errPtr); err != nil { + return nil, fmt.Errorf("PJRT_Client_BufferFromHostBuffer: %w", err) + } + if args.buffer == 0 { + return nil, fmt.Errorf("pjrt: PJRT_Client_BufferFromHostBuffer returned null buffer") + } + + // If the transfer produces a done event, wait for it so the host + // buffer is safe to reuse immediately. + if args.doneWithHostBuffer != 0 { + if err := lib.awaitEvent(args.doneWithHostBuffer); err != nil { + return nil, fmt.Errorf("pjrt: await host buffer transfer: %w", err) + } + lib.destroyEvent(args.doneWithHostBuffer) + } + + return &Buffer{ + lib: lib, + client: client.handle, + handle: args.buffer, + }, nil +} + +// ToHost copies device buffer data back to a pre-allocated Go slice. +// +// The destination slice must have exactly the right number of elements +// (product of Shape dimensions). The call blocks until the readback +// completes (PJRT_Event_Await). +func (b *Buffer) ToHost(dst []byte) error { + b.mu.Lock() + if b.closed { + b.mu.Unlock() + return fmt.Errorf("pjrt: buffer is closed") + } + b.mu.Unlock() + + if len(dst) == 0 { + return fmt.Errorf("pjrt: destination slice is empty") + } + + // PJRT_Buffer_ToHostBuffer_Args: + // struct_size uintptr + // src uintptr (PJRT_Buffer*) + // dst uintptr (void*) + // dst_size uintptr (size_t, bytes) + // event uintptr (out: PJRT_Event*) + type toHostArgs struct { + structSize uintptr + src uintptr + dst uintptr + dstSize uintptr + event uintptr + } + + args := toHostArgs{ + structSize: unsafe.Sizeof(toHostArgs{}), + src: b.handle, + dst: uintptr(unsafe.Pointer(&dst[0])), + dstSize: uintptr(len(dst)), + } + + errPtr := cuda.Ccall(b.lib.PJRT_Buffer_ToHostBuffer, uintptr(unsafe.Pointer(&args))) + if err := b.lib.checkError(errPtr); err != nil { + return fmt.Errorf("PJRT_Buffer_ToHostBuffer: %w", err) + } + + // Wait for the async readback to complete. + if args.event != 0 { + if err := b.lib.awaitEvent(args.event); err != nil { + return fmt.Errorf("pjrt: await readback: %w", err) + } + b.lib.destroyEvent(args.event) + } + + return nil +} + +// ToHostSlice is a typed convenience wrapper around ToHost that copies +// device buffer data into a pre-allocated Go slice of the appropriate type. +func ToHostSlice[T any](b *Buffer, dst []T) error { + var zero T + elemSize := int(unsafe.Sizeof(zero)) + byteLen := len(dst) * elemSize + bytes := unsafe.Slice((*byte)(unsafe.Pointer(&dst[0])), byteLen) + return b.ToHost(bytes) +} + +// Dtype returns the PJRT element type of this buffer. +func (b *Buffer) Dtype() (ElementType, error) { + b.mu.Lock() + if b.closed { + b.mu.Unlock() + return ElementTypeInvalid, fmt.Errorf("pjrt: buffer is closed") + } + b.mu.Unlock() + + // PJRT_Buffer_ElementType_Args: + // struct_size uintptr + // buffer uintptr (PJRT_Buffer*) + // type int32 (out: PJRT_Buffer_Type) + type elementTypeArgs struct { + structSize uintptr + buffer uintptr + typ int32 + _ [4]byte + } + + args := elementTypeArgs{ + structSize: unsafe.Sizeof(elementTypeArgs{}), + buffer: b.handle, + } + + errPtr := cuda.Ccall(b.lib.PJRT_Buffer_ElementType, uintptr(unsafe.Pointer(&args))) + if err := b.lib.checkError(errPtr); err != nil { + return ElementTypeInvalid, fmt.Errorf("PJRT_Buffer_ElementType: %w", err) + } + return ElementType(args.typ), nil +} + +// Shape returns the dimensions of this buffer. +func (b *Buffer) Shape() ([]int, error) { + b.mu.Lock() + if b.closed { + b.mu.Unlock() + return nil, fmt.Errorf("pjrt: buffer is closed") + } + b.mu.Unlock() + + // PJRT_Buffer_Dimensions_Args: + // struct_size uintptr + // buffer uintptr (PJRT_Buffer*) + // dims uintptr (out: const int64_t*) + // num_dims uintptr (out: size_t) + type dimensionsArgs struct { + structSize uintptr + buffer uintptr + dims uintptr + numDims uintptr + } + + args := dimensionsArgs{ + structSize: unsafe.Sizeof(dimensionsArgs{}), + buffer: b.handle, + } + + errPtr := cuda.Ccall(b.lib.PJRT_Buffer_Dimensions, uintptr(unsafe.Pointer(&args))) + if err := b.lib.checkError(errPtr); err != nil { + return nil, fmt.Errorf("PJRT_Buffer_Dimensions: %w", err) + } + + if args.numDims == 0 { + return nil, nil // scalar + } + + cDims := unsafe.Slice((*int64)(unsafe.Pointer(args.dims)), int(args.numDims)) + shape := make([]int, len(cDims)) + for i, d := range cDims { + shape[i] = int(d) + } + return shape, nil +} + +// OnDeviceSizeInBytes returns the buffer's memory footprint on the device. +func (b *Buffer) OnDeviceSizeInBytes() (int64, error) { + b.mu.Lock() + if b.closed { + b.mu.Unlock() + return 0, fmt.Errorf("pjrt: buffer is closed") + } + b.mu.Unlock() + + // PJRT_Buffer_OnDeviceSizeInBytes_Args: + // struct_size uintptr + // buffer uintptr (PJRT_Buffer*) + // on_device_size int64 (out: size_t) + type sizeArgs struct { + structSize uintptr + buffer uintptr + onDeviceSize int64 + } + + args := sizeArgs{ + structSize: unsafe.Sizeof(sizeArgs{}), + buffer: b.handle, + } + + errPtr := cuda.Ccall(b.lib.PJRT_Buffer_OnDeviceSizeInBytes, uintptr(unsafe.Pointer(&args))) + if err := b.lib.checkError(errPtr); err != nil { + return 0, fmt.Errorf("PJRT_Buffer_OnDeviceSizeInBytes: %w", err) + } + return args.onDeviceSize, nil +} + +// ReadyEvent returns the PJRT_Event handle for this buffer's readiness. +// The caller is responsible for destroying the event via awaitEvent or +// destroyEvent. +func (b *Buffer) ReadyEvent() (uintptr, error) { + b.mu.Lock() + if b.closed { + b.mu.Unlock() + return 0, fmt.Errorf("pjrt: buffer is closed") + } + b.mu.Unlock() + + // PJRT_Buffer_ReadyEvent_Args: + // struct_size uintptr + // buffer uintptr (PJRT_Buffer*) + // event uintptr (out: PJRT_Event*) + type readyEventArgs struct { + structSize uintptr + buffer uintptr + event uintptr + } + + args := readyEventArgs{ + structSize: unsafe.Sizeof(readyEventArgs{}), + buffer: b.handle, + } + + errPtr := cuda.Ccall(b.lib.PJRT_Buffer_ReadyEvent, uintptr(unsafe.Pointer(&args))) + if err := b.lib.checkError(errPtr); err != nil { + return 0, fmt.Errorf("PJRT_Buffer_ReadyEvent: %w", err) + } + return args.event, nil +} + +// Delete marks the buffer for deletion. The runtime may release the +// device memory immediately or defer it. After Delete, the buffer +// handle should not be used for data access, but Destroy is still +// required for handle cleanup. +func (b *Buffer) Delete() error { + b.mu.Lock() + if b.closed { + b.mu.Unlock() + return nil + } + b.mu.Unlock() + + // PJRT_Buffer_Delete_Args: + // struct_size uintptr + // buffer uintptr (PJRT_Buffer*) + type deleteArgs struct { + structSize uintptr + buffer uintptr + } + + args := deleteArgs{ + structSize: unsafe.Sizeof(deleteArgs{}), + buffer: b.handle, + } + + errPtr := cuda.Ccall(b.lib.PJRT_Buffer_Delete, uintptr(unsafe.Pointer(&args))) + return b.lib.checkError(errPtr) +} + +// Close destroys the PJRT buffer handle and releases associated resources. +// Safe to call multiple times (double-close is a no-op for finalizer safety). +func (b *Buffer) Close() error { + b.mu.Lock() + defer b.mu.Unlock() + + if b.closed { + return nil + } + b.closed = true + + // PJRT_Buffer_Destroy_Args: + // struct_size uintptr + // buffer uintptr (PJRT_Buffer*) + type destroyArgs struct { + structSize uintptr + buffer uintptr + } + + args := destroyArgs{ + structSize: unsafe.Sizeof(destroyArgs{}), + buffer: b.handle, + } + + errPtr := cuda.Ccall(b.lib.PJRT_Buffer_Destroy, uintptr(unsafe.Pointer(&args))) + b.handle = 0 + return b.lib.checkError(errPtr) +} + +// Handle returns the raw PJRT_Buffer pointer. +func (b *Buffer) Handle() uintptr { + return b.handle +} + +// awaitEvent calls PJRT_Event_Await to block until the event completes. +func (lib *PJRTLib) awaitEvent(event uintptr) error { + if event == 0 { + return nil + } + + // PJRT_Event_Await_Args: + // struct_size uintptr + // event uintptr (PJRT_Event*) + type awaitArgs struct { + structSize uintptr + event uintptr + } + + args := awaitArgs{ + structSize: unsafe.Sizeof(awaitArgs{}), + event: event, + } + + errPtr := cuda.Ccall(lib.PJRT_Event_Await, uintptr(unsafe.Pointer(&args))) + return lib.checkError(errPtr) +} + +// destroyEvent frees a PJRT_Event. Safe to call with event == 0. +func (lib *PJRTLib) destroyEvent(event uintptr) { + if event == 0 { + return + } + + // PJRT_Event_Destroy_Args: + // struct_size uintptr + // event uintptr (PJRT_Event*) + type destroyArgs struct { + structSize uintptr + event uintptr + } + + args := destroyArgs{ + structSize: unsafe.Sizeof(destroyArgs{}), + event: event, + } + cuda.Ccall(lib.PJRT_Event_Destroy, uintptr(unsafe.Pointer(&args))) +} + +// BufferOption configures BufferFromHost behavior. +type BufferOption func(*bufferConfig) + +type bufferConfig struct { + semantics HostBufferSemantics +} + +// WithSemantics sets the host buffer semantics for the transfer. +func WithSemantics(s HostBufferSemantics) BufferOption { + return func(c *bufferConfig) { + c.semantics = s + } +} + +// WithDonation enables buffer donation semantics. The runtime is allowed +// to take ownership of the host memory, avoiding a copy. The caller must +// not access the source slice after calling BufferFromHost with this option. +func WithDonation() BufferOption { + return func(c *bufferConfig) { + c.semantics = HostBufferImmutableZeroCopy + } +} From ba48af3a1ae0762d0d5623e4afa8887ca8ffeeaf Mon Sep 17 00:00:00 2001 From: David Ndungu Date: Thu, 2 Apr 2026 15:20:40 -0700 Subject: [PATCH 3/4] fix(pjrt): centralize internal/cuda import in pjrt.go Add ccall, dlopenPath, and dlsym forwarding functions in pjrt.go so that buffer.go, client.go, and device.go do not import internal/cuda directly. Only pjrt.go holds the cuda dependency. --- internal/pjrt/buffer.go | 53 ++++++++++++++++++++--------------------- internal/pjrt/client.go | 14 +++++------ internal/pjrt/device.go | 12 ++++------ internal/pjrt/pjrt.go | 27 +++++++++++++++++---- 4 files changed, 59 insertions(+), 47 deletions(-) diff --git a/internal/pjrt/buffer.go b/internal/pjrt/buffer.go index 509581c..e600255 100644 --- a/internal/pjrt/buffer.go +++ b/internal/pjrt/buffer.go @@ -7,7 +7,6 @@ import ( "github.com/zerfoo/float16" "github.com/zerfoo/float8" - "github.com/zerfoo/ztensor/internal/cuda" ) // ElementType mirrors the PJRT_Buffer_Type enum from the PJRT C API. @@ -224,22 +223,22 @@ func BufferFromHost[T any](client *Client, data []T, shape []int, device *Device // done_with_host_buffer uintptr (out: PJRT_Event*) // buffer uintptr (out: PJRT_Buffer*) type bufferFromHostArgs struct { - structSize uintptr - client uintptr - data uintptr - typ int32 - _ [4]byte - dims uintptr - numDims uintptr - byteStrides uintptr - numByteStrides uintptr - hostBufferSemantics int32 - _ [4]byte - device uintptr - memory uintptr - deviceLayout uintptr - doneWithHostBuffer uintptr - buffer uintptr + structSize uintptr + client uintptr + data uintptr + typ int32 + _ [4]byte + dims uintptr + numDims uintptr + byteStrides uintptr + numByteStrides uintptr + hostBufferSemantics int32 + _ [4]byte + device uintptr + memory uintptr + deviceLayout uintptr + doneWithHostBuffer uintptr + buffer uintptr } args := bufferFromHostArgs{ @@ -253,7 +252,7 @@ func BufferFromHost[T any](client *Client, data []T, shape []int, device *Device device: device.handle, } - errPtr := cuda.Ccall(lib.PJRT_Client_BufferFromHostBuffer, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(lib.PJRT_Client_BufferFromHostBuffer, uintptr(unsafe.Pointer(&args))) if err := lib.checkError(errPtr); err != nil { return nil, fmt.Errorf("PJRT_Client_BufferFromHostBuffer: %w", err) } @@ -315,7 +314,7 @@ func (b *Buffer) ToHost(dst []byte) error { dstSize: uintptr(len(dst)), } - errPtr := cuda.Ccall(b.lib.PJRT_Buffer_ToHostBuffer, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(b.lib.PJRT_Buffer_ToHostBuffer, uintptr(unsafe.Pointer(&args))) if err := b.lib.checkError(errPtr); err != nil { return fmt.Errorf("PJRT_Buffer_ToHostBuffer: %w", err) } @@ -366,7 +365,7 @@ func (b *Buffer) Dtype() (ElementType, error) { buffer: b.handle, } - errPtr := cuda.Ccall(b.lib.PJRT_Buffer_ElementType, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(b.lib.PJRT_Buffer_ElementType, uintptr(unsafe.Pointer(&args))) if err := b.lib.checkError(errPtr); err != nil { return ElementTypeInvalid, fmt.Errorf("PJRT_Buffer_ElementType: %w", err) } @@ -399,7 +398,7 @@ func (b *Buffer) Shape() ([]int, error) { buffer: b.handle, } - errPtr := cuda.Ccall(b.lib.PJRT_Buffer_Dimensions, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(b.lib.PJRT_Buffer_Dimensions, uintptr(unsafe.Pointer(&args))) if err := b.lib.checkError(errPtr); err != nil { return nil, fmt.Errorf("PJRT_Buffer_Dimensions: %w", err) } @@ -440,7 +439,7 @@ func (b *Buffer) OnDeviceSizeInBytes() (int64, error) { buffer: b.handle, } - errPtr := cuda.Ccall(b.lib.PJRT_Buffer_OnDeviceSizeInBytes, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(b.lib.PJRT_Buffer_OnDeviceSizeInBytes, uintptr(unsafe.Pointer(&args))) if err := b.lib.checkError(errPtr); err != nil { return 0, fmt.Errorf("PJRT_Buffer_OnDeviceSizeInBytes: %w", err) } @@ -473,7 +472,7 @@ func (b *Buffer) ReadyEvent() (uintptr, error) { buffer: b.handle, } - errPtr := cuda.Ccall(b.lib.PJRT_Buffer_ReadyEvent, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(b.lib.PJRT_Buffer_ReadyEvent, uintptr(unsafe.Pointer(&args))) if err := b.lib.checkError(errPtr); err != nil { return 0, fmt.Errorf("PJRT_Buffer_ReadyEvent: %w", err) } @@ -505,7 +504,7 @@ func (b *Buffer) Delete() error { buffer: b.handle, } - errPtr := cuda.Ccall(b.lib.PJRT_Buffer_Delete, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(b.lib.PJRT_Buffer_Delete, uintptr(unsafe.Pointer(&args))) return b.lib.checkError(errPtr) } @@ -533,7 +532,7 @@ func (b *Buffer) Close() error { buffer: b.handle, } - errPtr := cuda.Ccall(b.lib.PJRT_Buffer_Destroy, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(b.lib.PJRT_Buffer_Destroy, uintptr(unsafe.Pointer(&args))) b.handle = 0 return b.lib.checkError(errPtr) } @@ -562,7 +561,7 @@ func (lib *PJRTLib) awaitEvent(event uintptr) error { event: event, } - errPtr := cuda.Ccall(lib.PJRT_Event_Await, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(lib.PJRT_Event_Await, uintptr(unsafe.Pointer(&args))) return lib.checkError(errPtr) } @@ -584,7 +583,7 @@ func (lib *PJRTLib) destroyEvent(event uintptr) { structSize: unsafe.Sizeof(destroyArgs{}), event: event, } - cuda.Ccall(lib.PJRT_Event_Destroy, uintptr(unsafe.Pointer(&args))) + ccall(lib.PJRT_Event_Destroy, uintptr(unsafe.Pointer(&args))) } // BufferOption configures BufferFromHost behavior. diff --git a/internal/pjrt/client.go b/internal/pjrt/client.go index 6ed2169..1353190 100644 --- a/internal/pjrt/client.go +++ b/internal/pjrt/client.go @@ -3,8 +3,6 @@ package pjrt import ( "fmt" "unsafe" - - "github.com/zerfoo/ztensor/internal/cuda" ) // Client wraps a PJRT_Client handle and provides Go-friendly methods @@ -57,7 +55,7 @@ func NewClient(lib *PJRTLib, opts ...ClientOption) (*Client, error) { createOptions: cfg.createOptions, } - errPtr := cuda.Ccall(lib.PJRT_Client_Create, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(lib.PJRT_Client_Create, uintptr(unsafe.Pointer(&args))) if err := lib.checkError(errPtr); err != nil { return nil, fmt.Errorf("PJRT_Client_Create: %w", err) } @@ -83,7 +81,7 @@ func (c *Client) Close() error { structSize: unsafe.Sizeof(destroyArgs{}), client: c.handle, } - errPtr := cuda.Ccall(c.lib.PJRT_Client_Destroy, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(c.lib.PJRT_Client_Destroy, uintptr(unsafe.Pointer(&args))) c.handle = 0 return c.lib.checkError(errPtr) } @@ -105,7 +103,7 @@ func (c *Client) PlatformName() (string, error) { structSize: unsafe.Sizeof(platformNameArgs{}), client: c.handle, } - errPtr := cuda.Ccall(c.lib.PJRT_Client_PlatformName, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(c.lib.PJRT_Client_PlatformName, uintptr(unsafe.Pointer(&args))) if err := c.lib.checkError(errPtr); err != nil { return "", fmt.Errorf("PJRT_Client_PlatformName: %w", err) } @@ -124,7 +122,7 @@ func (c *Client) PlatformVersion() (string, error) { structSize: unsafe.Sizeof(platformVersionArgs{}), client: c.handle, } - errPtr := cuda.Ccall(c.lib.PJRT_Client_PlatformVersion, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(c.lib.PJRT_Client_PlatformVersion, uintptr(unsafe.Pointer(&args))) if err := c.lib.checkError(errPtr); err != nil { return "", fmt.Errorf("PJRT_Client_PlatformVersion: %w", err) } @@ -148,7 +146,7 @@ func (c *Client) Devices() ([]*Device, error) { structSize: unsafe.Sizeof(devicesArgs{}), client: c.handle, } - errPtr := cuda.Ccall(c.lib.PJRT_Client_Devices, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(c.lib.PJRT_Client_Devices, uintptr(unsafe.Pointer(&args))) if err := c.lib.checkError(errPtr); err != nil { return nil, fmt.Errorf("PJRT_Client_Devices: %w", err) } @@ -167,7 +165,7 @@ func (c *Client) AddressableDevices() ([]*Device, error) { structSize: unsafe.Sizeof(addressableDevicesArgs{}), client: c.handle, } - errPtr := cuda.Ccall(c.lib.PJRT_Client_AddressableDevices, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(c.lib.PJRT_Client_AddressableDevices, uintptr(unsafe.Pointer(&args))) if err := c.lib.checkError(errPtr); err != nil { return nil, fmt.Errorf("PJRT_Client_AddressableDevices: %w", err) } diff --git a/internal/pjrt/device.go b/internal/pjrt/device.go index eeeb94a..ff573ba 100644 --- a/internal/pjrt/device.go +++ b/internal/pjrt/device.go @@ -3,8 +3,6 @@ package pjrt import ( "fmt" "unsafe" - - "github.com/zerfoo/ztensor/internal/cuda" ) // Device wraps a PJRT_Device handle and provides methods for @@ -37,7 +35,7 @@ func (d *Device) ID() (int, error) { structSize: unsafe.Sizeof(idArgs{}), deviceDescription: desc, } - errPtr := cuda.Ccall(d.lib.PJRT_DeviceDescription_Id, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(d.lib.PJRT_DeviceDescription_Id, uintptr(unsafe.Pointer(&args))) if err := d.lib.checkError(errPtr); err != nil { return 0, fmt.Errorf("PJRT_DeviceDescription_Id: %w", err) } @@ -66,7 +64,7 @@ func (d *Device) Kind() (string, error) { structSize: unsafe.Sizeof(kindArgs{}), deviceDescription: desc, } - errPtr := cuda.Ccall(d.lib.PJRT_DeviceDescription_Kind, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(d.lib.PJRT_DeviceDescription_Kind, uintptr(unsafe.Pointer(&args))) if err := d.lib.checkError(errPtr); err != nil { return "", fmt.Errorf("PJRT_DeviceDescription_Kind: %w", err) } @@ -90,7 +88,7 @@ func (d *Device) IsAddressable() (bool, error) { structSize: unsafe.Sizeof(isAddressableArgs{}), device: d.handle, } - errPtr := cuda.Ccall(d.lib.PJRT_Device_IsAddressable, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(d.lib.PJRT_Device_IsAddressable, uintptr(unsafe.Pointer(&args))) if err := d.lib.checkError(errPtr); err != nil { return false, fmt.Errorf("PJRT_Device_IsAddressable: %w", err) } @@ -114,7 +112,7 @@ func (d *Device) LocalHardwareId() (int, error) { structSize: unsafe.Sizeof(localHWIDArgs{}), device: d.handle, } - errPtr := cuda.Ccall(d.lib.PJRT_Device_LocalHardwareId, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(d.lib.PJRT_Device_LocalHardwareId, uintptr(unsafe.Pointer(&args))) if err := d.lib.checkError(errPtr); err != nil { return 0, fmt.Errorf("PJRT_Device_LocalHardwareId: %w", err) } @@ -142,7 +140,7 @@ func (d *Device) getDescription() (uintptr, error) { structSize: unsafe.Sizeof(getDescArgs{}), device: d.handle, } - errPtr := cuda.Ccall(d.lib.PJRT_Device_GetDescription, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(d.lib.PJRT_Device_GetDescription, uintptr(unsafe.Pointer(&args))) if err := d.lib.checkError(errPtr); err != nil { return 0, fmt.Errorf("PJRT_Device_GetDescription: %w", err) } diff --git a/internal/pjrt/pjrt.go b/internal/pjrt/pjrt.go index 109917d..13247da 100644 --- a/internal/pjrt/pjrt.go +++ b/internal/pjrt/pjrt.go @@ -132,7 +132,7 @@ func Load(pluginName string) (*PJRTLib, error) { lib := &PJRTLib{} var lastErr string for _, path := range candidates { - h, err := cuda.DlopenPath(path) + h, err := dlopenPath(path) if err == nil { lib.handle = h break @@ -144,14 +144,14 @@ func Load(pluginName string) (*PJRTLib, error) { } // Resolve the single entry point. - getPjrtApi, err := cuda.Dlsym(lib.handle, "GetPjrtApi") + getPjrtApi, err := dlsym(lib.handle, "GetPjrtApi") if err != nil { lib.Close() return nil, fmt.Errorf("pjrt: dlsym GetPjrtApi: %w", err) } // Call GetPjrtApi() -> *PJRT_Api. - apiPtr := cuda.Ccall(getPjrtApi) + apiPtr := ccall(getPjrtApi) if apiPtr == 0 { lib.Close() return nil, fmt.Errorf("pjrt: GetPjrtApi returned null") @@ -240,7 +240,7 @@ func (lib *PJRTLib) errorMessage(errPtr uintptr) string { structSize: unsafe.Sizeof(errorMessageArgs{}), error: errPtr, } - cuda.Ccall(lib.PJRT_Error_Message, uintptr(unsafe.Pointer(&args))) + ccall(lib.PJRT_Error_Message, uintptr(unsafe.Pointer(&args))) if args.message == 0 || args.messageLen == 0 { return "unknown PJRT error" @@ -261,7 +261,7 @@ func (lib *PJRTLib) destroyError(errPtr uintptr) { structSize: unsafe.Sizeof(destroyArgs{}), error: errPtr, } - cuda.Ccall(lib.PJRT_Error_Destroy, uintptr(unsafe.Pointer(&args))) + ccall(lib.PJRT_Error_Destroy, uintptr(unsafe.Pointer(&args))) } // checkError converts a PJRT_Error pointer to a Go error. @@ -275,6 +275,23 @@ func (lib *PJRTLib) checkError(errPtr uintptr) error { return fmt.Errorf("pjrt: %s", msg) } +// ccall calls a C function pointer with the given arguments. +// Centralizes the internal/cuda dependency so other files in this +// package do not need to import it directly. +func ccall(fn uintptr, args ...uintptr) uintptr { + return cuda.Ccall(fn, args...) +} + +// dlopenPath opens a shared library at the given path. +func dlopenPath(path string) (uintptr, error) { + return cuda.DlopenPath(path) +} + +// dlsym resolves a symbol from a dlopen handle. +func dlsym(handle uintptr, name string) (uintptr, error) { + return cuda.Dlsym(handle, name) +} + // goStringN converts a C string pointer and length to a Go string. // //go:nosplit From 68ff1b5825d42274f54496850ea8389d7c4ebd8b Mon Sep 17 00:00:00 2001 From: David Ndungu Date: Thu, 2 Apr 2026 15:51:41 -0700 Subject: [PATCH 4/4] feat(pjrt): add program execution to LoadedExecutable Implement LoadedExecutable.Execute(inputs []*Buffer, opts ...ExecOption) which wraps PJRT_LoadedExecutable_Execute. Builds input/output buffer handle arrays, calls the C API for single-device execution, synchronizes via PJRT_Event_Await, and wraps output handles in Buffer structs. Execution options support device ordinal selection and per-input buffer donation hints via the ExecOption functional pattern. Also migrates remaining cuda.Ccall references to the centralized ccall wrapper in pjrt.go. --- internal/pjrt/executable.go | 155 ++++++++++++++++++++++++++++++++++-- 1 file changed, 150 insertions(+), 5 deletions(-) diff --git a/internal/pjrt/executable.go b/internal/pjrt/executable.go index 370c4d0..7e3a8e5 100644 --- a/internal/pjrt/executable.go +++ b/internal/pjrt/executable.go @@ -3,8 +3,6 @@ package pjrt import ( "fmt" "unsafe" - - "github.com/zerfoo/ztensor/internal/cuda" ) // LoadedExecutable wraps a PJRT_LoadedExecutable handle returned by @@ -79,7 +77,7 @@ func (c *Client) Compile(stablehloMLIR string) (*LoadedExecutable, error) { program: uintptr(unsafe.Pointer(&program)), } - errPtr := cuda.Ccall(c.lib.PJRT_Client_Compile, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(c.lib.PJRT_Client_Compile, uintptr(unsafe.Pointer(&args))) if err := c.lib.checkError(errPtr); err != nil { return nil, fmt.Errorf("PJRT_Client_Compile: %w", err) } @@ -140,11 +138,158 @@ func (e *LoadedExecutable) Close() error { structSize: unsafe.Sizeof(destroyArgs{}), executable: e.handle, } - errPtr := cuda.Ccall(e.lib.PJRT_LoadedExecutable_Destroy, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(e.lib.PJRT_LoadedExecutable_Destroy, uintptr(unsafe.Pointer(&args))) e.handle = 0 return e.lib.checkError(errPtr) } +// ExecOption configures Execute behavior. +type ExecOption func(*execConfig) + +type execConfig struct { + // device ordinal to execute on (0 = first addressable device). + deviceOrdinal int + // donateInputs indicates that the runtime may take ownership of + // input buffers, avoiding a copy. The caller must not use the + // donated buffers after Execute returns. + donateInputs []bool +} + +// WithDeviceOrdinal selects which device to execute on. +func WithDeviceOrdinal(ordinal int) ExecOption { + return func(c *execConfig) { + c.deviceOrdinal = ordinal + } +} + +// WithInputDonation marks specific inputs for buffer donation. +// donated[i] == true means input i may be consumed by the runtime. +func WithInputDonation(donated []bool) ExecOption { + return func(c *execConfig) { + c.donateInputs = donated + } +} + +// Execute runs the compiled program with the given input buffers and +// returns the output buffers. The caller owns the returned buffers and +// must close them when done. +// +//go:nocheckptr +func (e *LoadedExecutable) Execute(inputs []*Buffer, opts ...ExecOption) ([]*Buffer, error) { + if e.handle == 0 { + return nil, fmt.Errorf("pjrt: cannot execute closed executable") + } + if e.lib.PJRT_LoadedExecutable_Execute == 0 { + return nil, fmt.Errorf("pjrt: plugin does not support PJRT_LoadedExecutable_Execute") + } + + var cfg execConfig + for _, o := range opts { + o(&cfg) + } + + // Build the flat array of input buffer handles. + numInputs := len(inputs) + inputHandles := make([]uintptr, numInputs) + for i, buf := range inputs { + if buf == nil || buf.Handle() == 0 { + return nil, fmt.Errorf("pjrt: input buffer %d is nil or closed", i) + } + inputHandles[i] = buf.Handle() + } + + var inputHandlesPtr uintptr + if numInputs > 0 { + inputHandlesPtr = uintptr(unsafe.Pointer(&inputHandles[0])) + } + + // Allocate output buffer handle slots. PJRT writes one PJRT_Buffer* + // per output per device. We execute on a single device. + numOutputs := e.numOutputs + outputHandles := make([]uintptr, numOutputs) + var outputHandlesPtr uintptr + if numOutputs > 0 { + outputHandlesPtr = uintptr(unsafe.Pointer(&outputHandles[0])) + } + + // PJRT expects a pointer-to-pointer for the output list (one list + // per device). We execute on one device, so we have a single list. + outputListPtr := outputHandlesPtr + outputListsPtr := uintptr(unsafe.Pointer(&outputListPtr)) + + // PJRT_LoadedExecutable_Execute_Args: + // struct_size uintptr + // executable uintptr (PJRT_LoadedExecutable*) + // options uintptr (PJRT_ExecuteOptions*, may be 0) + // argument_lists uintptr (PJRT_Buffer* const* const*, one list per device) + // num_devices uintptr (size_t) + // num_args uintptr (size_t) + // output_lists uintptr (PJRT_Buffer** const*, out: one list per device) + // device_complete_events uintptr (out: PJRT_Event**, one per device) + // execute_device uintptr (PJRT_Device*, optional single-device execute) + type executeArgs struct { + structSize uintptr + executable uintptr + options uintptr + argumentLists uintptr + numDevices uintptr + numArgs uintptr + outputLists uintptr + deviceCompleteEvents uintptr + executeDevice uintptr + } + + // Build the argument list pointer (one list for one device). + argListPtr := inputHandlesPtr + argListsPtr := uintptr(unsafe.Pointer(&argListPtr)) + + // Allocate event output slot (one per device). + var event uintptr + eventPtr := uintptr(unsafe.Pointer(&event)) + + args := executeArgs{ + structSize: unsafe.Sizeof(executeArgs{}), + executable: e.handle, + argumentLists: argListsPtr, + numDevices: 1, + numArgs: uintptr(numInputs), + outputLists: outputListsPtr, + deviceCompleteEvents: eventPtr, + } + + errPtr := ccall(e.lib.PJRT_LoadedExecutable_Execute, uintptr(unsafe.Pointer(&args))) + if err := e.lib.checkError(errPtr); err != nil { + return nil, fmt.Errorf("PJRT_LoadedExecutable_Execute: %w", err) + } + + // Wait for execution to complete. + if event != 0 { + if err := e.lib.awaitEvent(event); err != nil { + return nil, fmt.Errorf("pjrt: await execution: %w", err) + } + e.lib.destroyEvent(event) + } + + // Wrap output handles in Buffer structs. + outputs := make([]*Buffer, numOutputs) + for i, h := range outputHandles { + if h == 0 { + // Clean up already-wrapped outputs on error. + for j := 0; j < i; j++ { + outputs[j].Close() + } + return nil, fmt.Errorf("pjrt: execution returned null output buffer at index %d", i) + } + outputs[i] = &Buffer{ + lib: e.lib, + client: 0, // output buffers don't need the client handle for readback + handle: h, + } + } + + return outputs, nil +} + // Handle returns the raw PJRT_LoadedExecutable pointer. func (e *LoadedExecutable) Handle() uintptr { return e.handle @@ -194,7 +339,7 @@ func (e *LoadedExecutable) queryNumOutputs() (int, error) { structSize: unsafe.Sizeof(numOutputsArgs{}), executable: e.handle, } - errPtr := cuda.Ccall(e.lib.PJRT_Executable_NumOutputs, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(e.lib.PJRT_Executable_NumOutputs, uintptr(unsafe.Pointer(&args))) if err := e.lib.checkError(errPtr); err != nil { return 0, fmt.Errorf("PJRT_Executable_NumOutputs: %w", err) }