diff --git a/internal/pjrt/buffer.go b/internal/pjrt/buffer.go new file mode 100644 index 0000000..e600255 --- /dev/null +++ b/internal/pjrt/buffer.go @@ -0,0 +1,610 @@ +package pjrt + +import ( + "fmt" + "sync" + "unsafe" + + "github.com/zerfoo/float16" + "github.com/zerfoo/float8" +) + +// ElementType mirrors the PJRT_Buffer_Type enum from the PJRT C API. +type ElementType int32 + +const ( + ElementTypeInvalid ElementType = 0 + ElementTypePRED ElementType = 1 // bool + ElementTypeS8 ElementType = 2 // int8 + ElementTypeS16 ElementType = 3 // int16 + ElementTypeS32 ElementType = 4 // int32 + ElementTypeS64 ElementType = 5 // int64 + ElementTypeU8 ElementType = 6 // uint8 + ElementTypeU16 ElementType = 7 // uint16 + ElementTypeU32 ElementType = 8 // uint32 + ElementTypeU64 ElementType = 9 // uint64 + ElementTypeF16 ElementType = 10 // float16 + ElementTypeF32 ElementType = 11 // float32 + ElementTypeF64 ElementType = 12 // float64 + ElementTypeBF16 ElementType = 16 // bfloat16 + ElementTypeF8E4M3 ElementType = 20 // float8 E4M3FN +) + +// String returns the PJRT element type name. +func (t ElementType) String() string { + switch t { + case ElementTypePRED: + return "pred" + case ElementTypeS8: + return "s8" + case ElementTypeS16: + return "s16" + case ElementTypeS32: + return "s32" + case ElementTypeS64: + return "s64" + case ElementTypeU8: + return "u8" + case ElementTypeU16: + return "u16" + case ElementTypeU32: + return "u32" + case ElementTypeU64: + return "u64" + case ElementTypeF16: + return "f16" + case ElementTypeF32: + return "f32" + case ElementTypeF64: + return "f64" + case ElementTypeBF16: + return "bf16" + case ElementTypeF8E4M3: + return "f8e4m3fn" + default: + return fmt.Sprintf("unknown(%d)", int(t)) + } +} + +// ByteSize returns the size in bytes of a single element of this type. 
+func (t ElementType) ByteSize() int { + switch t { + case ElementTypePRED, ElementTypeS8, ElementTypeU8, ElementTypeF8E4M3: + return 1 + case ElementTypeS16, ElementTypeU16, ElementTypeF16, ElementTypeBF16: + return 2 + case ElementTypeS32, ElementTypeU32, ElementTypeF32: + return 4 + case ElementTypeS64, ElementTypeU64, ElementTypeF64: + return 8 + default: + return 0 + } +} + +// GoTypeToElementType maps a Go type (via its size and kind) to the +// corresponding PJRT element type. +func GoTypeToElementType[T any]() ElementType { + var zero T + switch any(zero).(type) { + case float32: + return ElementTypeF32 + case float64: + return ElementTypeF64 + case float16.Float16: + return ElementTypeF16 + case float16.BFloat16: + return ElementTypeBF16 + case float8.Float8: + return ElementTypeF8E4M3 + case int32: + return ElementTypeS32 + case int64: + return ElementTypeS64 + case int16: + return ElementTypeS16 + case int8: + return ElementTypeS8 + case uint8: + return ElementTypeU8 + case uint16: + return ElementTypeU16 + case uint32: + return ElementTypeU32 + case uint64: + return ElementTypeU64 + case bool: + return ElementTypePRED + default: + return ElementTypeInvalid + } +} + +// HostBufferSemantics controls how PJRT handles the host data pointer +// during BufferFromHostBuffer. +type HostBufferSemantics int32 + +const ( + // HostBufferImmutableOnlyDuringCall means PJRT copies the data during + // the call and the host buffer can be modified immediately after return. + HostBufferImmutableOnlyDuringCall HostBufferSemantics = 0 + + // HostBufferImmutableUntilTransferCompletes means the host buffer must + // remain valid until the returned event completes. Avoids a copy on + // some backends. + HostBufferImmutableUntilTransferCompletes HostBufferSemantics = 1 + + // HostBufferImmutableZeroCopy means PJRT uses the host memory directly + // (zero-copy). The host buffer must remain valid for the buffer lifetime. 
+ HostBufferImmutableZeroCopy HostBufferSemantics = 2 +) + +// Buffer wraps a PJRT_Buffer handle and provides Go-friendly methods +// for device-to-host readback, metadata queries, and lifecycle management. +// +// Buffers must be closed with Close() when no longer needed. Double-close +// is a safe no-op (finalizer safety). +type Buffer struct { + lib *PJRTLib + client uintptr // PJRT_Client* (for readback calls) + handle uintptr // PJRT_Buffer* + + mu sync.Mutex + closed bool +} + +// BufferFromHost transfers a Go slice to a PJRT device buffer. +// +// The data slice is copied during the call (ImmutableOnlyDuringCall semantics +// by default). The shape describes the tensor dimensions. The target device +// determines where the buffer is placed. +// +// Use WithDonation() to enable buffer donation for KV cache optimization. +func BufferFromHost[T any](client *Client, data []T, shape []int, device *Device, opts ...BufferOption) (*Buffer, error) { + if client == nil || client.handle == 0 { + return nil, fmt.Errorf("pjrt: cannot create buffer from nil or closed client") + } + if device == nil || device.handle == 0 { + return nil, fmt.Errorf("pjrt: cannot create buffer on nil or closed device") + } + if len(data) == 0 { + return nil, fmt.Errorf("pjrt: cannot create buffer from empty data") + } + + elemType := GoTypeToElementType[T]() + if elemType == ElementTypeInvalid { + return nil, fmt.Errorf("pjrt: unsupported Go type for PJRT buffer") + } + + // Verify element count matches shape. + numElements := 1 + for _, d := range shape { + numElements *= d + } + if numElements != len(data) { + return nil, fmt.Errorf("pjrt: shape %v requires %d elements, got %d", shape, numElements, len(data)) + } + + cfg := bufferConfig{ + semantics: HostBufferImmutableOnlyDuringCall, + } + for _, o := range opts { + o(&cfg) + } + + lib := client.lib + + // Build the int64 dimensions array that PJRT expects. 
+ dims := make([]int64, len(shape)) + for i, d := range shape { + dims[i] = int64(d) + } + + var dimsPtr uintptr + if len(dims) > 0 { + dimsPtr = uintptr(unsafe.Pointer(&dims[0])) + } + + // PJRT_Client_BufferFromHostBuffer_Args: + // struct_size uintptr + // client uintptr (PJRT_Client*) + // data uintptr (const void*) + // type int32 (PJRT_Buffer_Type) + // _ [4]byte (padding) + // dims uintptr (const int64_t*) + // num_dims uintptr (size_t) + // byte_strides uintptr (const int64_t*, may be 0) + // num_byte_strides uintptr (size_t) + // host_buffer_semantics int32 (PJRT_HostBufferSemantics) + // _ [4]byte (padding) + // device uintptr (PJRT_Device*) + // memory uintptr (PJRT_Memory*, may be 0) + // device_layout uintptr (PJRT_Buffer_MemoryLayout*, may be 0) + // done_with_host_buffer uintptr (out: PJRT_Event*) + // buffer uintptr (out: PJRT_Buffer*) + type bufferFromHostArgs struct { + structSize uintptr + client uintptr + data uintptr + typ int32 + _ [4]byte + dims uintptr + numDims uintptr + byteStrides uintptr + numByteStrides uintptr + hostBufferSemantics int32 + _ [4]byte + device uintptr + memory uintptr + deviceLayout uintptr + doneWithHostBuffer uintptr + buffer uintptr + } + + args := bufferFromHostArgs{ + structSize: unsafe.Sizeof(bufferFromHostArgs{}), + client: client.handle, + data: uintptr(unsafe.Pointer(&data[0])), + typ: int32(elemType), + dims: dimsPtr, + numDims: uintptr(len(dims)), + hostBufferSemantics: int32(cfg.semantics), + device: device.handle, + } + + errPtr := ccall(lib.PJRT_Client_BufferFromHostBuffer, uintptr(unsafe.Pointer(&args))) + if err := lib.checkError(errPtr); err != nil { + return nil, fmt.Errorf("PJRT_Client_BufferFromHostBuffer: %w", err) + } + if args.buffer == 0 { + return nil, fmt.Errorf("pjrt: PJRT_Client_BufferFromHostBuffer returned null buffer") + } + + // If the transfer produces a done event, wait for it so the host + // buffer is safe to reuse immediately. 
+ if args.doneWithHostBuffer != 0 { + if err := lib.awaitEvent(args.doneWithHostBuffer); err != nil { + return nil, fmt.Errorf("pjrt: await host buffer transfer: %w", err) + } + lib.destroyEvent(args.doneWithHostBuffer) + } + + return &Buffer{ + lib: lib, + client: client.handle, + handle: args.buffer, + }, nil +} + +// ToHost copies device buffer data back to a pre-allocated Go slice. +// +// The destination slice must have exactly the right number of elements +// (product of Shape dimensions). The call blocks until the readback +// completes (PJRT_Event_Await). +func (b *Buffer) ToHost(dst []byte) error { + b.mu.Lock() + if b.closed { + b.mu.Unlock() + return fmt.Errorf("pjrt: buffer is closed") + } + b.mu.Unlock() + + if len(dst) == 0 { + return fmt.Errorf("pjrt: destination slice is empty") + } + + // PJRT_Buffer_ToHostBuffer_Args: + // struct_size uintptr + // src uintptr (PJRT_Buffer*) + // dst uintptr (void*) + // dst_size uintptr (size_t, bytes) + // event uintptr (out: PJRT_Event*) + type toHostArgs struct { + structSize uintptr + src uintptr + dst uintptr + dstSize uintptr + event uintptr + } + + args := toHostArgs{ + structSize: unsafe.Sizeof(toHostArgs{}), + src: b.handle, + dst: uintptr(unsafe.Pointer(&dst[0])), + dstSize: uintptr(len(dst)), + } + + errPtr := ccall(b.lib.PJRT_Buffer_ToHostBuffer, uintptr(unsafe.Pointer(&args))) + if err := b.lib.checkError(errPtr); err != nil { + return fmt.Errorf("PJRT_Buffer_ToHostBuffer: %w", err) + } + + // Wait for the async readback to complete. + if args.event != 0 { + if err := b.lib.awaitEvent(args.event); err != nil { + return fmt.Errorf("pjrt: await readback: %w", err) + } + b.lib.destroyEvent(args.event) + } + + return nil +} + +// ToHostSlice is a typed convenience wrapper around ToHost that copies +// device buffer data into a pre-allocated Go slice of the appropriate type. 
+func ToHostSlice[T any](b *Buffer, dst []T) error { + var zero T + elemSize := int(unsafe.Sizeof(zero)) + byteLen := len(dst) * elemSize + bytes := unsafe.Slice((*byte)(unsafe.Pointer(&dst[0])), byteLen) + return b.ToHost(bytes) +} + +// Dtype returns the PJRT element type of this buffer. +func (b *Buffer) Dtype() (ElementType, error) { + b.mu.Lock() + if b.closed { + b.mu.Unlock() + return ElementTypeInvalid, fmt.Errorf("pjrt: buffer is closed") + } + b.mu.Unlock() + + // PJRT_Buffer_ElementType_Args: + // struct_size uintptr + // buffer uintptr (PJRT_Buffer*) + // type int32 (out: PJRT_Buffer_Type) + type elementTypeArgs struct { + structSize uintptr + buffer uintptr + typ int32 + _ [4]byte + } + + args := elementTypeArgs{ + structSize: unsafe.Sizeof(elementTypeArgs{}), + buffer: b.handle, + } + + errPtr := ccall(b.lib.PJRT_Buffer_ElementType, uintptr(unsafe.Pointer(&args))) + if err := b.lib.checkError(errPtr); err != nil { + return ElementTypeInvalid, fmt.Errorf("PJRT_Buffer_ElementType: %w", err) + } + return ElementType(args.typ), nil +} + +// Shape returns the dimensions of this buffer. 
+func (b *Buffer) Shape() ([]int, error) { + b.mu.Lock() + if b.closed { + b.mu.Unlock() + return nil, fmt.Errorf("pjrt: buffer is closed") + } + b.mu.Unlock() + + // PJRT_Buffer_Dimensions_Args: + // struct_size uintptr + // buffer uintptr (PJRT_Buffer*) + // dims uintptr (out: const int64_t*) + // num_dims uintptr (out: size_t) + type dimensionsArgs struct { + structSize uintptr + buffer uintptr + dims uintptr + numDims uintptr + } + + args := dimensionsArgs{ + structSize: unsafe.Sizeof(dimensionsArgs{}), + buffer: b.handle, + } + + errPtr := ccall(b.lib.PJRT_Buffer_Dimensions, uintptr(unsafe.Pointer(&args))) + if err := b.lib.checkError(errPtr); err != nil { + return nil, fmt.Errorf("PJRT_Buffer_Dimensions: %w", err) + } + + if args.numDims == 0 { + return nil, nil // scalar + } + + cDims := unsafe.Slice((*int64)(unsafe.Pointer(args.dims)), int(args.numDims)) + shape := make([]int, len(cDims)) + for i, d := range cDims { + shape[i] = int(d) + } + return shape, nil +} + +// OnDeviceSizeInBytes returns the buffer's memory footprint on the device. +func (b *Buffer) OnDeviceSizeInBytes() (int64, error) { + b.mu.Lock() + if b.closed { + b.mu.Unlock() + return 0, fmt.Errorf("pjrt: buffer is closed") + } + b.mu.Unlock() + + // PJRT_Buffer_OnDeviceSizeInBytes_Args: + // struct_size uintptr + // buffer uintptr (PJRT_Buffer*) + // on_device_size int64 (out: size_t) + type sizeArgs struct { + structSize uintptr + buffer uintptr + onDeviceSize int64 + } + + args := sizeArgs{ + structSize: unsafe.Sizeof(sizeArgs{}), + buffer: b.handle, + } + + errPtr := ccall(b.lib.PJRT_Buffer_OnDeviceSizeInBytes, uintptr(unsafe.Pointer(&args))) + if err := b.lib.checkError(errPtr); err != nil { + return 0, fmt.Errorf("PJRT_Buffer_OnDeviceSizeInBytes: %w", err) + } + return args.onDeviceSize, nil +} + +// ReadyEvent returns the PJRT_Event handle for this buffer's readiness. +// The caller is responsible for destroying the event via awaitEvent or +// destroyEvent. 
+func (b *Buffer) ReadyEvent() (uintptr, error) { + b.mu.Lock() + if b.closed { + b.mu.Unlock() + return 0, fmt.Errorf("pjrt: buffer is closed") + } + b.mu.Unlock() + + // PJRT_Buffer_ReadyEvent_Args: + // struct_size uintptr + // buffer uintptr (PJRT_Buffer*) + // event uintptr (out: PJRT_Event*) + type readyEventArgs struct { + structSize uintptr + buffer uintptr + event uintptr + } + + args := readyEventArgs{ + structSize: unsafe.Sizeof(readyEventArgs{}), + buffer: b.handle, + } + + errPtr := ccall(b.lib.PJRT_Buffer_ReadyEvent, uintptr(unsafe.Pointer(&args))) + if err := b.lib.checkError(errPtr); err != nil { + return 0, fmt.Errorf("PJRT_Buffer_ReadyEvent: %w", err) + } + return args.event, nil +} + +// Delete marks the buffer for deletion. The runtime may release the +// device memory immediately or defer it. After Delete, the buffer +// handle should not be used for data access, but Destroy is still +// required for handle cleanup. +func (b *Buffer) Delete() error { + b.mu.Lock() + if b.closed { + b.mu.Unlock() + return nil + } + b.mu.Unlock() + + // PJRT_Buffer_Delete_Args: + // struct_size uintptr + // buffer uintptr (PJRT_Buffer*) + type deleteArgs struct { + structSize uintptr + buffer uintptr + } + + args := deleteArgs{ + structSize: unsafe.Sizeof(deleteArgs{}), + buffer: b.handle, + } + + errPtr := ccall(b.lib.PJRT_Buffer_Delete, uintptr(unsafe.Pointer(&args))) + return b.lib.checkError(errPtr) +} + +// Close destroys the PJRT buffer handle and releases associated resources. +// Safe to call multiple times (double-close is a no-op for finalizer safety). 
+func (b *Buffer) Close() error { + b.mu.Lock() + defer b.mu.Unlock() + + if b.closed { + return nil + } + b.closed = true + + // PJRT_Buffer_Destroy_Args: + // struct_size uintptr + // buffer uintptr (PJRT_Buffer*) + type destroyArgs struct { + structSize uintptr + buffer uintptr + } + + args := destroyArgs{ + structSize: unsafe.Sizeof(destroyArgs{}), + buffer: b.handle, + } + + errPtr := ccall(b.lib.PJRT_Buffer_Destroy, uintptr(unsafe.Pointer(&args))) + b.handle = 0 + return b.lib.checkError(errPtr) +} + +// Handle returns the raw PJRT_Buffer pointer. +func (b *Buffer) Handle() uintptr { + return b.handle +} + +// awaitEvent calls PJRT_Event_Await to block until the event completes. +func (lib *PJRTLib) awaitEvent(event uintptr) error { + if event == 0 { + return nil + } + + // PJRT_Event_Await_Args: + // struct_size uintptr + // event uintptr (PJRT_Event*) + type awaitArgs struct { + structSize uintptr + event uintptr + } + + args := awaitArgs{ + structSize: unsafe.Sizeof(awaitArgs{}), + event: event, + } + + errPtr := ccall(lib.PJRT_Event_Await, uintptr(unsafe.Pointer(&args))) + return lib.checkError(errPtr) +} + +// destroyEvent frees a PJRT_Event. Safe to call with event == 0. +func (lib *PJRTLib) destroyEvent(event uintptr) { + if event == 0 { + return + } + + // PJRT_Event_Destroy_Args: + // struct_size uintptr + // event uintptr (PJRT_Event*) + type destroyArgs struct { + structSize uintptr + event uintptr + } + + args := destroyArgs{ + structSize: unsafe.Sizeof(destroyArgs{}), + event: event, + } + ccall(lib.PJRT_Event_Destroy, uintptr(unsafe.Pointer(&args))) +} + +// BufferOption configures BufferFromHost behavior. +type BufferOption func(*bufferConfig) + +type bufferConfig struct { + semantics HostBufferSemantics +} + +// WithSemantics sets the host buffer semantics for the transfer. +func WithSemantics(s HostBufferSemantics) BufferOption { + return func(c *bufferConfig) { + c.semantics = s + } +} + +// WithDonation enables buffer donation semantics. 
The runtime is allowed +// to take ownership of the host memory, avoiding a copy. The caller must +// not access the source slice after calling BufferFromHost with this option. +func WithDonation() BufferOption { + return func(c *bufferConfig) { + c.semantics = HostBufferImmutableZeroCopy + } +} diff --git a/internal/pjrt/client.go b/internal/pjrt/client.go index 6ed2169..1353190 100644 --- a/internal/pjrt/client.go +++ b/internal/pjrt/client.go @@ -3,8 +3,6 @@ package pjrt import ( "fmt" "unsafe" - - "github.com/zerfoo/ztensor/internal/cuda" ) // Client wraps a PJRT_Client handle and provides Go-friendly methods @@ -57,7 +55,7 @@ func NewClient(lib *PJRTLib, opts ...ClientOption) (*Client, error) { createOptions: cfg.createOptions, } - errPtr := cuda.Ccall(lib.PJRT_Client_Create, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(lib.PJRT_Client_Create, uintptr(unsafe.Pointer(&args))) if err := lib.checkError(errPtr); err != nil { return nil, fmt.Errorf("PJRT_Client_Create: %w", err) } @@ -83,7 +81,7 @@ func (c *Client) Close() error { structSize: unsafe.Sizeof(destroyArgs{}), client: c.handle, } - errPtr := cuda.Ccall(c.lib.PJRT_Client_Destroy, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(c.lib.PJRT_Client_Destroy, uintptr(unsafe.Pointer(&args))) c.handle = 0 return c.lib.checkError(errPtr) } @@ -105,7 +103,7 @@ func (c *Client) PlatformName() (string, error) { structSize: unsafe.Sizeof(platformNameArgs{}), client: c.handle, } - errPtr := cuda.Ccall(c.lib.PJRT_Client_PlatformName, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(c.lib.PJRT_Client_PlatformName, uintptr(unsafe.Pointer(&args))) if err := c.lib.checkError(errPtr); err != nil { return "", fmt.Errorf("PJRT_Client_PlatformName: %w", err) } @@ -124,7 +122,7 @@ func (c *Client) PlatformVersion() (string, error) { structSize: unsafe.Sizeof(platformVersionArgs{}), client: c.handle, } - errPtr := cuda.Ccall(c.lib.PJRT_Client_PlatformVersion, uintptr(unsafe.Pointer(&args))) + errPtr := 
ccall(c.lib.PJRT_Client_PlatformVersion, uintptr(unsafe.Pointer(&args))) if err := c.lib.checkError(errPtr); err != nil { return "", fmt.Errorf("PJRT_Client_PlatformVersion: %w", err) } @@ -148,7 +146,7 @@ func (c *Client) Devices() ([]*Device, error) { structSize: unsafe.Sizeof(devicesArgs{}), client: c.handle, } - errPtr := cuda.Ccall(c.lib.PJRT_Client_Devices, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(c.lib.PJRT_Client_Devices, uintptr(unsafe.Pointer(&args))) if err := c.lib.checkError(errPtr); err != nil { return nil, fmt.Errorf("PJRT_Client_Devices: %w", err) } @@ -167,7 +165,7 @@ func (c *Client) AddressableDevices() ([]*Device, error) { structSize: unsafe.Sizeof(addressableDevicesArgs{}), client: c.handle, } - errPtr := cuda.Ccall(c.lib.PJRT_Client_AddressableDevices, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(c.lib.PJRT_Client_AddressableDevices, uintptr(unsafe.Pointer(&args))) if err := c.lib.checkError(errPtr); err != nil { return nil, fmt.Errorf("PJRT_Client_AddressableDevices: %w", err) } diff --git a/internal/pjrt/device.go b/internal/pjrt/device.go index eeeb94a..ff573ba 100644 --- a/internal/pjrt/device.go +++ b/internal/pjrt/device.go @@ -3,8 +3,6 @@ package pjrt import ( "fmt" "unsafe" - - "github.com/zerfoo/ztensor/internal/cuda" ) // Device wraps a PJRT_Device handle and provides methods for @@ -37,7 +35,7 @@ func (d *Device) ID() (int, error) { structSize: unsafe.Sizeof(idArgs{}), deviceDescription: desc, } - errPtr := cuda.Ccall(d.lib.PJRT_DeviceDescription_Id, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(d.lib.PJRT_DeviceDescription_Id, uintptr(unsafe.Pointer(&args))) if err := d.lib.checkError(errPtr); err != nil { return 0, fmt.Errorf("PJRT_DeviceDescription_Id: %w", err) } @@ -66,7 +64,7 @@ func (d *Device) Kind() (string, error) { structSize: unsafe.Sizeof(kindArgs{}), deviceDescription: desc, } - errPtr := cuda.Ccall(d.lib.PJRT_DeviceDescription_Kind, uintptr(unsafe.Pointer(&args))) + errPtr := 
ccall(d.lib.PJRT_DeviceDescription_Kind, uintptr(unsafe.Pointer(&args))) if err := d.lib.checkError(errPtr); err != nil { return "", fmt.Errorf("PJRT_DeviceDescription_Kind: %w", err) } @@ -90,7 +88,7 @@ func (d *Device) IsAddressable() (bool, error) { structSize: unsafe.Sizeof(isAddressableArgs{}), device: d.handle, } - errPtr := cuda.Ccall(d.lib.PJRT_Device_IsAddressable, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(d.lib.PJRT_Device_IsAddressable, uintptr(unsafe.Pointer(&args))) if err := d.lib.checkError(errPtr); err != nil { return false, fmt.Errorf("PJRT_Device_IsAddressable: %w", err) } @@ -114,7 +112,7 @@ func (d *Device) LocalHardwareId() (int, error) { structSize: unsafe.Sizeof(localHWIDArgs{}), device: d.handle, } - errPtr := cuda.Ccall(d.lib.PJRT_Device_LocalHardwareId, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(d.lib.PJRT_Device_LocalHardwareId, uintptr(unsafe.Pointer(&args))) if err := d.lib.checkError(errPtr); err != nil { return 0, fmt.Errorf("PJRT_Device_LocalHardwareId: %w", err) } @@ -142,7 +140,7 @@ func (d *Device) getDescription() (uintptr, error) { structSize: unsafe.Sizeof(getDescArgs{}), device: d.handle, } - errPtr := cuda.Ccall(d.lib.PJRT_Device_GetDescription, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(d.lib.PJRT_Device_GetDescription, uintptr(unsafe.Pointer(&args))) if err := d.lib.checkError(errPtr); err != nil { return 0, fmt.Errorf("PJRT_Device_GetDescription: %w", err) } diff --git a/internal/pjrt/pjrt.go b/internal/pjrt/pjrt.go index 109917d..13247da 100644 --- a/internal/pjrt/pjrt.go +++ b/internal/pjrt/pjrt.go @@ -132,7 +132,7 @@ func Load(pluginName string) (*PJRTLib, error) { lib := &PJRTLib{} var lastErr string for _, path := range candidates { - h, err := cuda.DlopenPath(path) + h, err := dlopenPath(path) if err == nil { lib.handle = h break @@ -144,14 +144,14 @@ func Load(pluginName string) (*PJRTLib, error) { } // Resolve the single entry point. 
- getPjrtApi, err := cuda.Dlsym(lib.handle, "GetPjrtApi") + getPjrtApi, err := dlsym(lib.handle, "GetPjrtApi") if err != nil { lib.Close() return nil, fmt.Errorf("pjrt: dlsym GetPjrtApi: %w", err) } // Call GetPjrtApi() -> *PJRT_Api. - apiPtr := cuda.Ccall(getPjrtApi) + apiPtr := ccall(getPjrtApi) if apiPtr == 0 { lib.Close() return nil, fmt.Errorf("pjrt: GetPjrtApi returned null") @@ -240,7 +240,7 @@ func (lib *PJRTLib) errorMessage(errPtr uintptr) string { structSize: unsafe.Sizeof(errorMessageArgs{}), error: errPtr, } - cuda.Ccall(lib.PJRT_Error_Message, uintptr(unsafe.Pointer(&args))) + ccall(lib.PJRT_Error_Message, uintptr(unsafe.Pointer(&args))) if args.message == 0 || args.messageLen == 0 { return "unknown PJRT error" @@ -261,7 +261,7 @@ func (lib *PJRTLib) destroyError(errPtr uintptr) { structSize: unsafe.Sizeof(destroyArgs{}), error: errPtr, } - cuda.Ccall(lib.PJRT_Error_Destroy, uintptr(unsafe.Pointer(&args))) + ccall(lib.PJRT_Error_Destroy, uintptr(unsafe.Pointer(&args))) } // checkError converts a PJRT_Error pointer to a Go error. @@ -275,6 +275,23 @@ func (lib *PJRTLib) checkError(errPtr uintptr) error { return fmt.Errorf("pjrt: %s", msg) } +// ccall calls a C function pointer with the given arguments. +// Centralizes the internal/cuda dependency so other files in this +// package do not need to import it directly. +func ccall(fn uintptr, args ...uintptr) uintptr { + return cuda.Ccall(fn, args...) +} + +// dlopenPath opens a shared library at the given path. +func dlopenPath(path string) (uintptr, error) { + return cuda.DlopenPath(path) +} + +// dlsym resolves a symbol from a dlopen handle. +func dlsym(handle uintptr, name string) (uintptr, error) { + return cuda.Dlsym(handle, name) +} + // goStringN converts a C string pointer and length to a Go string. 
// //go:nosplit diff --git a/internal/stablehlo/emit.go b/internal/stablehlo/emit.go new file mode 100644 index 0000000..ccb0165 --- /dev/null +++ b/internal/stablehlo/emit.go @@ -0,0 +1,272 @@ +package stablehlo + +import "fmt" + +// Emitter generates StableHLO MLIR text from operation inputs. +// Each emit method takes SSA input names, tensor shapes, and a dtype, +// and returns the emitted MLIR line(s) plus the output SSA name. +type Emitter struct { + Namer *SSANamer +} + +// NewEmitter creates an Emitter with a fresh SSANamer. +func NewEmitter() *Emitter { + return &Emitter{Namer: &SSANamer{}} +} + +// EmitBinaryElementwise emits a binary element-wise op (add, subtract, multiply, divide, power). +// Both inputs must have the same shape and dtype. +func (e *Emitter) EmitBinaryElementwise(opName, lhs, rhs string, shape []int, dtype string) (mlir, outName string) { + outName = e.Namer.NextName() + ty := FormatTensorType(shape, dtype) + mlir = fmt.Sprintf("%s = %s %s, %s : %s", outName, opName, lhs, rhs, ty) + return mlir, outName +} + +// EmitAdd emits stablehlo.add. +func (e *Emitter) EmitAdd(lhs, rhs string, shape []int, dtype string) (string, string) { + return e.EmitBinaryElementwise(OpAdd, lhs, rhs, shape, dtype) +} + +// EmitSub emits stablehlo.subtract. +func (e *Emitter) EmitSub(lhs, rhs string, shape []int, dtype string) (string, string) { + return e.EmitBinaryElementwise(OpSubtract, lhs, rhs, shape, dtype) +} + +// EmitMul emits stablehlo.multiply. +func (e *Emitter) EmitMul(lhs, rhs string, shape []int, dtype string) (string, string) { + return e.EmitBinaryElementwise(OpMultiply, lhs, rhs, shape, dtype) +} + +// EmitDiv emits stablehlo.divide. +func (e *Emitter) EmitDiv(lhs, rhs string, shape []int, dtype string) (string, string) { + return e.EmitBinaryElementwise(OpDivide, lhs, rhs, shape, dtype) +} + +// EmitPow emits stablehlo.power. 
+func (e *Emitter) EmitPow(lhs, rhs string, shape []int, dtype string) (string, string) { + return e.EmitBinaryElementwise(OpPower, lhs, rhs, shape, dtype) +} + +// EmitUnary emits a unary element-wise op (exponential, log, sine, cosine, tanh, sqrt, rsqrt, negate). +func (e *Emitter) EmitUnary(opName, input string, shape []int, dtype string) (mlir, outName string) { + outName = e.Namer.NextName() + ty := FormatTensorType(shape, dtype) + mlir = fmt.Sprintf("%s = %s %s : %s", outName, opName, input, ty) + return mlir, outName +} + +// EmitExp emits stablehlo.exponential. +func (e *Emitter) EmitExp(input string, shape []int, dtype string) (string, string) { + return e.EmitUnary(OpExp, input, shape, dtype) +} + +// EmitLog emits stablehlo.log. +func (e *Emitter) EmitLog(input string, shape []int, dtype string) (string, string) { + return e.EmitUnary(OpLog, input, shape, dtype) +} + +// EmitSin emits stablehlo.sine. +func (e *Emitter) EmitSin(input string, shape []int, dtype string) (string, string) { + return e.EmitUnary(OpSin, input, shape, dtype) +} + +// EmitCos emits stablehlo.cosine. +func (e *Emitter) EmitCos(input string, shape []int, dtype string) (string, string) { + return e.EmitUnary(OpCos, input, shape, dtype) +} + +// EmitTanh emits stablehlo.tanh. +func (e *Emitter) EmitTanh(input string, shape []int, dtype string) (string, string) { + return e.EmitUnary(OpTanh, input, shape, dtype) +} + +// EmitSqrt emits stablehlo.sqrt. +func (e *Emitter) EmitSqrt(input string, shape []int, dtype string) (string, string) { + return e.EmitUnary(OpSqrt, input, shape, dtype) +} + +// EmitRsqrt emits stablehlo.rsqrt. +func (e *Emitter) EmitRsqrt(input string, shape []int, dtype string) (string, string) { + return e.EmitUnary(OpRsqrt, input, shape, dtype) +} + +// EmitNeg emits stablehlo.negate. 
+func (e *Emitter) EmitNeg(input string, shape []int, dtype string) (string, string) { + return e.EmitUnary(OpNegate, input, shape, dtype) +} + +// EmitScalarOp emits a scalar operation as three MLIR instructions: +// 1. stablehlo.constant for the scalar value +// 2. stablehlo.broadcast_in_dim to broadcast to the tensor shape +// 3. The element-wise binary op +// +// Returns all three lines (newline-separated) and the final output SSA name. +func (e *Emitter) EmitScalarOp(elemOp, input string, scalar float64, shape []int, dtype string) (mlir, outName string) { + ty := FormatTensorType(shape, dtype) + scalarTy := FormatTensorType(nil, dtype) + + // 1. Constant + constName := e.Namer.NextName() + constLine := fmt.Sprintf("%s = %s dense<%v> : %s", constName, OpConstant, scalar, scalarTy) + + // 2. Broadcast + bcastName := e.Namer.NextName() + bcastLine := fmt.Sprintf("%s = %s %s, dims = [] : (%s) -> %s", bcastName, OpBroadcastIn, constName, scalarTy, ty) + + // 3. Element-wise op + outName = e.Namer.NextName() + opLine := fmt.Sprintf("%s = %s %s, %s : %s", outName, elemOp, input, bcastName, ty) + + mlir = constLine + "\n" + bcastLine + "\n" + opLine + return mlir, outName +} + +// EmitMulScalar emits stablehlo.constant + broadcast_in_dim + multiply. +func (e *Emitter) EmitMulScalar(input string, scalar float64, shape []int, dtype string) (string, string) { + return e.EmitScalarOp(OpMultiply, input, scalar, shape, dtype) +} + +// EmitAddScalar emits stablehlo.constant + broadcast_in_dim + add. +func (e *Emitter) EmitAddScalar(input string, scalar float64, shape []int, dtype string) (string, string) { + return e.EmitScalarOp(OpAdd, input, scalar, shape, dtype) +} + +// EmitDivScalar emits stablehlo.constant + broadcast_in_dim + divide. 
+func (e *Emitter) EmitDivScalar(input string, scalar float64, shape []int, dtype string) (string, string) { + return e.EmitScalarOp(OpDivide, input, scalar, shape, dtype) +} + +// EmitOp dispatches to the appropriate emit function based on the engine op name. +// For binary ops, inputs should be [lhs, rhs]. For unary ops, inputs should be [input]. +// For scalar ops, inputs should be [input] and attrs must contain "scalar" (float64). +// Returns the emitted MLIR text and the output SSA name. +func (e *Emitter) EmitOp(opName string, inputs []string, shape []int, dtype string, attrs map[string]any) (string, string, error) { + switch opName { + case "Add": + if len(inputs) != 2 { + return "", "", fmt.Errorf("EmitOp(%s): expected 2 inputs, got %d", opName, len(inputs)) + } + mlir, out := e.EmitAdd(inputs[0], inputs[1], shape, dtype) + return mlir, out, nil + case "Sub": + if len(inputs) != 2 { + return "", "", fmt.Errorf("EmitOp(%s): expected 2 inputs, got %d", opName, len(inputs)) + } + mlir, out := e.EmitSub(inputs[0], inputs[1], shape, dtype) + return mlir, out, nil + case "Mul": + if len(inputs) != 2 { + return "", "", fmt.Errorf("EmitOp(%s): expected 2 inputs, got %d", opName, len(inputs)) + } + mlir, out := e.EmitMul(inputs[0], inputs[1], shape, dtype) + return mlir, out, nil + case "Div": + if len(inputs) != 2 { + return "", "", fmt.Errorf("EmitOp(%s): expected 2 inputs, got %d", opName, len(inputs)) + } + mlir, out := e.EmitDiv(inputs[0], inputs[1], shape, dtype) + return mlir, out, nil + case "Pow": + if len(inputs) != 2 { + return "", "", fmt.Errorf("EmitOp(%s): expected 2 inputs, got %d", opName, len(inputs)) + } + mlir, out := e.EmitPow(inputs[0], inputs[1], shape, dtype) + return mlir, out, nil + case "Exp": + if len(inputs) != 1 { + return "", "", fmt.Errorf("EmitOp(%s): expected 1 input, got %d", opName, len(inputs)) + } + mlir, out := e.EmitExp(inputs[0], shape, dtype) + return mlir, out, nil + case "Log": + if len(inputs) != 1 { + return "", "", 
fmt.Errorf("EmitOp(%s): expected 1 input, got %d", opName, len(inputs)) + } + mlir, out := e.EmitLog(inputs[0], shape, dtype) + return mlir, out, nil + case "Sin": + if len(inputs) != 1 { + return "", "", fmt.Errorf("EmitOp(%s): expected 1 input, got %d", opName, len(inputs)) + } + mlir, out := e.EmitSin(inputs[0], shape, dtype) + return mlir, out, nil + case "Cos": + if len(inputs) != 1 { + return "", "", fmt.Errorf("EmitOp(%s): expected 1 input, got %d", opName, len(inputs)) + } + mlir, out := e.EmitCos(inputs[0], shape, dtype) + return mlir, out, nil + case "Tanh": + if len(inputs) != 1 { + return "", "", fmt.Errorf("EmitOp(%s): expected 1 input, got %d", opName, len(inputs)) + } + mlir, out := e.EmitTanh(inputs[0], shape, dtype) + return mlir, out, nil + case "Sqrt": + if len(inputs) != 1 { + return "", "", fmt.Errorf("EmitOp(%s): expected 1 input, got %d", opName, len(inputs)) + } + mlir, out := e.EmitSqrt(inputs[0], shape, dtype) + return mlir, out, nil + case "Rsqrt": + if len(inputs) != 1 { + return "", "", fmt.Errorf("EmitOp(%s): expected 1 input, got %d", opName, len(inputs)) + } + mlir, out := e.EmitRsqrt(inputs[0], shape, dtype) + return mlir, out, nil + case "Neg": + if len(inputs) != 1 { + return "", "", fmt.Errorf("EmitOp(%s): expected 1 input, got %d", opName, len(inputs)) + } + mlir, out := e.EmitNeg(inputs[0], shape, dtype) + return mlir, out, nil + case "MulScalar": + if len(inputs) != 1 { + return "", "", fmt.Errorf("EmitOp(%s): expected 1 input, got %d", opName, len(inputs)) + } + s, err := scalarAttr(opName, attrs) + if err != nil { + return "", "", err + } + mlir, out := e.EmitMulScalar(inputs[0], s, shape, dtype) + return mlir, out, nil + case "AddScalar": + if len(inputs) != 1 { + return "", "", fmt.Errorf("EmitOp(%s): expected 1 input, got %d", opName, len(inputs)) + } + s, err := scalarAttr(opName, attrs) + if err != nil { + return "", "", err + } + mlir, out := e.EmitAddScalar(inputs[0], s, shape, dtype) + return mlir, out, nil + case 
"DivScalar": + if len(inputs) != 1 { + return "", "", fmt.Errorf("EmitOp(%s): expected 1 input, got %d", opName, len(inputs)) + } + s, err := scalarAttr(opName, attrs) + if err != nil { + return "", "", err + } + mlir, out := e.EmitDivScalar(inputs[0], s, shape, dtype) + return mlir, out, nil + default: + return "", "", fmt.Errorf("EmitOp: unsupported op %q", opName) + } +} + +func scalarAttr(opName string, attrs map[string]any) (float64, error) { + if attrs == nil { + return 0, fmt.Errorf("EmitOp(%s): attrs map is nil, need \"scalar\" key", opName) + } + v, ok := attrs["scalar"] + if !ok { + return 0, fmt.Errorf("EmitOp(%s): missing \"scalar\" in attrs", opName) + } + s, ok := v.(float64) + if !ok { + return 0, fmt.Errorf("EmitOp(%s): \"scalar\" attr is %T, want float64", opName, v) + } + return s, nil +} diff --git a/internal/stablehlo/emit_structural.go b/internal/stablehlo/emit_structural.go new file mode 100644 index 0000000..8b2bb58 --- /dev/null +++ b/internal/stablehlo/emit_structural.go @@ -0,0 +1,207 @@ +package stablehlo + +import ( + "fmt" + "strings" +) + +// EmitMatMul emits a stablehlo.dot_general operation for matrix multiplication. +// Handles 2D (MxK @ KxN) and batched (BxMxK @ BxKxN) cases. +// Returns the MLIR line and the SSA name assigned to the result. +func EmitMatMul(namer *SSANamer, lhs, rhs string, lhsShape, rhsShape []int, dtype string) (string, string, error) { + if len(lhsShape) < 2 || len(rhsShape) < 2 { + return "", "", fmt.Errorf("stablehlo.EmitMatMul: inputs must be at least rank 2, got rank %d and %d", len(lhsShape), len(rhsShape)) + } + if len(lhsShape) != len(rhsShape) { + return "", "", fmt.Errorf("stablehlo.EmitMatMul: rank mismatch: %d vs %d", len(lhsShape), len(rhsShape)) + } + + rank := len(lhsShape) + // Contraction dimension: last axis of LHS, second-to-last axis of RHS. 
+ lhsContract := rank - 1 + rhsContract := rank - 2 + + if lhsShape[lhsContract] != rhsShape[rhsContract] { + return "", "", fmt.Errorf("stablehlo.EmitMatMul: contraction dimension mismatch: %d vs %d", lhsShape[lhsContract], rhsShape[rhsContract]) + } + + outShape, err := InferStructuralShape("MatMul", [][]int{lhsShape, rhsShape}, nil) + if err != nil { + return "", "", err + } + + result := namer.NextName() + lhsType := FormatTensorType(lhsShape, dtype) + rhsType := FormatTensorType(rhsShape, dtype) + outType := FormatTensorType(outShape, dtype) + + // Build batch dimensions list. + var batchDims []string + for i := 0; i < rank-2; i++ { + batchDims = append(batchDims, fmt.Sprintf("%d", i)) + } + + var b strings.Builder + fmt.Fprintf(&b, "%s = %s %s, %s, batching_dims = [%s] x [%s], contracting_dims = [%d] x [%d] : (%s, %s) -> %s", + result, OpDotGeneral, lhs, rhs, + strings.Join(batchDims, ", "), strings.Join(batchDims, ", "), + lhsContract, rhsContract, + lhsType, rhsType, outType, + ) + + return b.String(), result, nil +} + +// EmitTranspose emits a stablehlo.transpose operation. +// perm specifies the axis permutation (e.g., [2, 0, 1]). 
+func EmitTranspose(namer *SSANamer, operand string, shape []int, perm []int, dtype string) (string, string, error) { + if len(perm) != len(shape) { + return "", "", fmt.Errorf("stablehlo.EmitTranspose: perm length %d does not match rank %d", len(perm), len(shape)) + } + + outShape, err := InferStructuralShape("Transpose", [][]int{shape}, map[string]any{"perm": perm}) + if err != nil { + return "", "", err + } + + result := namer.NextName() + inType := FormatTensorType(shape, dtype) + outType := FormatTensorType(outShape, dtype) + + permStrs := make([]string, len(perm)) + for i, p := range perm { + permStrs[i] = fmt.Sprintf("%d", p) + } + + line := fmt.Sprintf("%s = %s %s, permutation = [%s] : (%s) -> %s", + result, OpTranspose, operand, + strings.Join(permStrs, ", "), + inType, outType, + ) + + return line, result, nil +} + +// EmitReshape emits a stablehlo.reshape operation. +// targetShape is the desired output shape. +func EmitReshape(namer *SSANamer, operand string, inShape, targetShape []int, dtype string) (string, string, error) { + outShape, err := InferStructuralShape("Reshape", [][]int{inShape}, map[string]any{"shape": targetShape}) + if err != nil { + return "", "", err + } + + result := namer.NextName() + inType := FormatTensorType(inShape, dtype) + outType := FormatTensorType(outShape, dtype) + + line := fmt.Sprintf("%s = %s %s : (%s) -> %s", + result, OpReshape, operand, + inType, outType, + ) + + return line, result, nil +} + +// EmitConcat emits a stablehlo.concatenate operation along the given axis. +// operands are the SSA names, shapes are the corresponding tensor shapes. 
func EmitConcat(namer *SSANamer, operands []string, shapes [][]int, axis int, dtype string) (string, string, error) {
	if len(operands) != len(shapes) {
		return "", "", fmt.Errorf("stablehlo.EmitConcat: operand count %d does not match shape count %d", len(operands), len(shapes))
	}

	// NOTE(review): an empty operand list is not rejected here — presumably
	// InferStructuralShape("Concat", ...) errors on it; confirm.
	outShape, err := InferStructuralShape("Concat", shapes, map[string]any{"axis": axis})
	if err != nil {
		return "", "", err
	}

	result := namer.NextName()
	outType := FormatTensorType(outShape, dtype)

	// concatenate uses a single trailing result type (no (in) -> out form).
	line := fmt.Sprintf("%s = %s %s, dimension = %d : %s",
		result, OpConcatenate, strings.Join(operands, ", "),
		axis, outType,
	)

	return line, result, nil
}

// EmitSlice emits a stablehlo.slice operation with start, limit, and stride indices.
// strides may be nil, in which case all strides default to 1.
func EmitSlice(namer *SSANamer, operand string, shape, start, limit, strides []int, dtype string) (string, string, error) {
	if len(start) != len(shape) || len(limit) != len(shape) {
		return "", "", fmt.Errorf("stablehlo.EmitSlice: start/limit length must match rank %d", len(shape))
	}
	if strides == nil {
		strides = make([]int, len(shape))
		for i := range strides {
			strides[i] = 1
		}
	}
	if len(strides) != len(shape) {
		return "", "", fmt.Errorf("stablehlo.EmitSlice: strides length %d must match rank %d", len(strides), len(shape))
	}

	// Compute output shape: ceil((limit[i] - start[i]) / strides[i]).
	outShape := make([]int, len(shape))
	for i := range shape {
		// Per-dimension validation: in-bounds half-open range and positive stride.
		if start[i] < 0 || limit[i] > shape[i] || start[i] > limit[i] || strides[i] <= 0 {
			return "", "", fmt.Errorf("stablehlo.EmitSlice: invalid range [%d:%d] stride %d for dimension %d (size %d)", start[i], limit[i], strides[i], i, shape[i])
		}
		outShape[i] = (limit[i] - start[i] + strides[i] - 1) / strides[i]
	}

	result := namer.NextName()
	inType := FormatTensorType(shape, dtype)
	outType := FormatTensorType(outShape, dtype)

	line := fmt.Sprintf("%s = %s %s, starts = [%s], limits = [%s], strides = [%s] : (%s) -> %s",
		result, OpSlice, operand,
		formatIntSlice(start), formatIntSlice(limit), formatIntSlice(strides),
		inType, outType,
	)

	return line, result, nil
}

// EmitGather emits a stablehlo.gather operation.
// operandShape is the shape of the data tensor, indicesShape is the shape of the index tensor.
// sliceSizes specifies the size of each gathered slice.
// offsetDims, collapsedSliceDims, startIndexMap are the gather dimension numbers.
// indexVectorDim is the dimension in the indices tensor that contains the index vector.
func EmitGather(namer *SSANamer, operand, indices string,
	operandShape, indicesShape, sliceSizes []int,
	offsetDims, collapsedSliceDims, startIndexMap []int,
	indexVectorDim int,
	dtype string,
) (string, string, error) {
	// Compute output shape from the gather semantics.
	// NOTE(review): only sliceSizes is forwarded to shape inference;
	// offsetDims/collapsedSliceDims/startIndexMap/indexVectorDim are emitted
	// verbatim but never validated or used for the output shape here —
	// confirm InferStructuralShape's Gather rule matches the printed attrs.
	outShape, err := InferStructuralShape("Gather", [][]int{operandShape, indicesShape}, map[string]any{"sliceSizes": sliceSizes})
	if err != nil {
		return "", "", err
	}

	result := namer.NextName()
	outType := FormatTensorType(outShape, dtype)

	var b strings.Builder
	fmt.Fprintf(&b, "%s = %s %s, %s, offset_dims = [%s], collapsed_slice_dims = [%s], start_index_map = [%s], index_vector_dim = %d, slice_sizes = [%s] : %s",
		result, OpGather, operand, indices,
		formatIntSlice(offsetDims),
		formatIntSlice(collapsedSliceDims),
		formatIntSlice(startIndexMap),
		indexVectorDim,
		formatIntSlice(sliceSizes),
		outType,
	)

	return b.String(), result, nil
}

// formatIntSlice formats an int slice as a comma-separated string (e.g., "0, 1, 2").
func formatIntSlice(s []int) string {
	parts := make([]string, len(s))
	for i, v := range s {
		parts[i] = fmt.Sprintf("%d", v)
	}
	return strings.Join(parts, ", ")
}
diff --git a/internal/stablehlo/emit_structural_test.go b/internal/stablehlo/emit_structural_test.go
new file mode 100644
index 0000000..117ed3d
--- /dev/null
+++ b/internal/stablehlo/emit_structural_test.go
package stablehlo

import (
	"strings"
	"testing"
)

// Tests for the structural emitters; exact-string matches pin the MLIR
// text format, Contains-checks pin individual attributes.

func TestEmitMatMul2D(t *testing.T) {
	namer := &SSANamer{}
	line, result, err := EmitMatMul(namer, "%a", "%b", []int{4, 3}, []int{3, 5}, DTypeF32)
	if err != nil {
		t.Fatal(err)
	}
	if result != "%v0" {
		t.Errorf("expected result %%v0, got %s", result)
	}
	want := `%v0 = stablehlo.dot_general %a, %b, batching_dims = [] x [], contracting_dims = [1] x [0] : (tensor<4x3xf32>, tensor<3x5xf32>) -> tensor<4x5xf32>`
	if line != want {
		t.Errorf("mismatch:\ngot: %s\nwant: %s", line, want)
	}
}

func TestEmitMatMulBatched(t *testing.T) {
	namer := &SSANamer{}
	line, result, err := EmitMatMul(namer, "%a", "%b", []int{2, 4, 3}, []int{2, 3, 5}, DTypeF32)
	if err != nil {
		t.Fatal(err)
	}
	if result != "%v0" {
		t.Errorf("expected result %%v0, got %s", result)
	}
	want := `%v0 = stablehlo.dot_general %a, %b, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<2x4x3xf32>, tensor<2x3x5xf32>) -> tensor<2x4x5xf32>`
	if line != want {
		t.Errorf("mismatch:\ngot: %s\nwant: %s", line, want)
	}
}

func TestEmitMatMulContractionMismatch(t *testing.T) {
	namer := &SSANamer{}
	_, _, err := EmitMatMul(namer, "%a", "%b", []int{4, 3}, []int{7, 5}, DTypeF32)
	if err == nil {
		t.Fatal("expected error for contraction dimension mismatch")
	}
}

func TestEmitMatMulRank1(t *testing.T) {
	namer := &SSANamer{}
	_, _, err := EmitMatMul(namer, "%a", "%b", []int{4}, []int{4}, DTypeF32)
	if err == nil {
		t.Fatal("expected error for rank-1 inputs")
	}
}

func TestEmitTranspose(t *testing.T) {
	namer := &SSANamer{}
	line, result, err := EmitTranspose(namer, "%a", []int{2, 3, 4}, []int{2, 0, 1}, DTypeF32)
	if err != nil {
		t.Fatal(err)
	}
	if result != "%v0" {
		t.Errorf("expected result %%v0, got %s", result)
	}
	want := `%v0 = stablehlo.transpose %a, permutation = [2, 0, 1] : (tensor<2x3x4xf32>) -> tensor<4x2x3xf32>`
	if line != want {
		t.Errorf("mismatch:\ngot: %s\nwant: %s", line, want)
	}
}

func TestEmitTransposeInvalidPerm(t *testing.T) {
	namer := &SSANamer{}
	_, _, err := EmitTranspose(namer, "%a", []int{2, 3}, []int{0}, DTypeF32)
	if err == nil {
		t.Fatal("expected error for perm length mismatch")
	}
}

func TestEmitReshape(t *testing.T) {
	namer := &SSANamer{}
	line, result, err := EmitReshape(namer, "%a", []int{2, 3, 4}, []int{6, 4}, DTypeF32)
	if err != nil {
		t.Fatal(err)
	}
	if result != "%v0" {
		t.Errorf("expected result %%v0, got %s", result)
	}
	want := `%v0 = stablehlo.reshape %a : (tensor<2x3x4xf32>) -> tensor<6x4xf32>`
	if line != want {
		t.Errorf("mismatch:\ngot: %s\nwant: %s", line, want)
	}
}

func TestEmitReshapeElementMismatch(t *testing.T) {
	namer := &SSANamer{}
	_, _, err := EmitReshape(namer, "%a", []int{2, 3}, []int{7}, DTypeF32)
	if err == nil {
		t.Fatal("expected error for element count mismatch")
	}
}

func TestEmitConcat(t *testing.T) {
	namer := &SSANamer{}
	line, result, err := EmitConcat(namer,
		[]string{"%a", "%b"},
		[][]int{{2, 3}, {2, 5}},
		1, DTypeF32,
	)
	if err != nil {
		t.Fatal(err)
	}
	if result != "%v0" {
		t.Errorf("expected result %%v0, got %s", result)
	}
	want := `%v0 = stablehlo.concatenate %a, %b, dimension = 1 : tensor<2x8xf32>`
	if line != want {
		t.Errorf("mismatch:\ngot: %s\nwant: %s", line, want)
	}
}

func TestEmitConcatThreeInputs(t *testing.T) {
	namer := &SSANamer{}
	line, _, err := EmitConcat(namer,
		[]string{"%a", "%b", "%c"},
		[][]int{{4, 2}, {4, 3}, {4, 1}},
		1, DTypeF64,
	)
	if err != nil {
		t.Fatal(err)
	}
	want := `%v0 = stablehlo.concatenate %a, %b, %c, dimension = 1 : tensor<4x6xf64>`
	if line != want {
		t.Errorf("mismatch:\ngot: %s\nwant: %s", line, want)
	}
}

func TestEmitSlice(t *testing.T) {
	namer := &SSANamer{}
	line, result, err := EmitSlice(namer, "%a",
		[]int{8, 6},
		[]int{1, 0}, []int{5, 4}, nil,
		DTypeF32,
	)
	if err != nil {
		t.Fatal(err)
	}
	if result != "%v0" {
		t.Errorf("expected result %%v0, got %s", result)
	}
	want := `%v0 = stablehlo.slice %a, starts = [1, 0], limits = [5, 4], strides = [1, 1] : (tensor<8x6xf32>) -> tensor<4x4xf32>`
	if line != want {
		t.Errorf("mismatch:\ngot: %s\nwant: %s", line, want)
	}
}

func TestEmitSliceWithStrides(t *testing.T) {
	namer := &SSANamer{}
	line, _, err := EmitSlice(namer, "%a",
		[]int{10},
		[]int{0}, []int{10}, []int{2},
		DTypeF32,
	)
	if err != nil {
		t.Fatal(err)
	}
	want := `%v0 = stablehlo.slice %a, starts = [0], limits = [10], strides = [2] : (tensor<10xf32>) -> tensor<5xf32>`
	if line != want {
		t.Errorf("mismatch:\ngot: %s\nwant: %s", line, want)
	}
}

func TestEmitSliceInvalidRange(t *testing.T) {
	namer := &SSANamer{}
	_, _, err := EmitSlice(namer, "%a", []int{4}, []int{3}, []int{1}, nil, DTypeF32)
	if err == nil {
		t.Fatal("expected error for invalid range (start > limit)")
	}
}

func TestEmitGather(t *testing.T) {
	namer := &SSANamer{}
	line, result, err := EmitGather(namer, "%data", "%indices",
		[]int{10, 8}, // operand shape
		[]int{3, 1}, // indices shape
		[]int{1, 8}, // slice sizes
		[]int{1}, // offset dims
		[]int{0}, // collapsed slice dims
		[]int{0}, // start index map
		1, // index vector dim
		DTypeF32,
	)
	if err != nil {
		t.Fatal(err)
	}
	if result != "%v0" {
		t.Errorf("expected result %%v0, got %s", result)
	}
	// Output shape from InferStructuralShape(Gather): indices[:-1] + sliceSizes = [3] + [1, 8] = [3, 1, 8]
	if !strings.Contains(line, "stablehlo.gather") {
		t.Errorf("expected stablehlo.gather in output, got: %s", line)
	}
	if !strings.Contains(line, "offset_dims = [1]") {
		t.Errorf("expected offset_dims = [1], got: %s", line)
	}
	if !strings.Contains(line, "collapsed_slice_dims = [0]") {
		t.Errorf("expected collapsed_slice_dims = [0], got: %s", line)
	}
	if !strings.Contains(line, "start_index_map = [0]") {
		t.Errorf("expected start_index_map = [0], got: %s", line)
	}
	if !strings.Contains(line, "index_vector_dim = 1") {
		t.Errorf("expected index_vector_dim = 1, got: %s", line)
	}
	if !strings.Contains(line, "slice_sizes = [1, 8]") {
		t.Errorf("expected slice_sizes = [1, 8], got: %s", line)
	}
}

func TestEmitMatMulF16(t *testing.T) {
	namer := &SSANamer{}
	line, _, err := EmitMatMul(namer, "%a", "%b", []int{8, 16}, []int{16, 32}, DTypeF16)
	if err != nil {
		t.Fatal(err)
	}
	if !strings.Contains(line, "f16") {
		t.Errorf("expected f16 dtype in output, got: %s", line)
	}
	if !strings.Contains(line, "tensor<8x32xf16>") {
		t.Errorf("expected output type tensor<8x32xf16>, got: %s", line)
	}
}

func TestSSANamerCounterAdvances(t *testing.T) {
	namer := &SSANamer{}

	// Three emits through the same namer must produce consecutive SSA names.
	_, r0, _ := EmitReshape(namer, "%a", []int{6}, []int{2, 3}, DTypeF32)
	_, r1, _ := EmitReshape(namer, "%b", []int{6}, []int{3, 2}, DTypeF32)
	_, r2, _ := EmitTranspose(namer, "%c", []int{2, 3}, []int{1, 0}, DTypeF32)

	if r0 != "%v0" || r1 != "%v1" || r2 != "%v2" {
		t.Errorf("expected %%v0, %%v1, %%v2 but got %s, %s, %s", r0, r1, r2)
	}
}
diff --git a/internal/stablehlo/emit_test.go b/internal/stablehlo/emit_test.go
new file mode 100644
index 0000000..688e6e2
--- /dev/null
+++ b/internal/stablehlo/emit_test.go
package stablehlo

import (
	"strings"
	"testing"
)

// Tests for the element-wise emitter (Emitter methods and EmitOp dispatch).

func TestEmitBinaryElementwise(t *testing.T) {
	tests := []struct {
		name string
		op string
		lhs string
		rhs string
		shape []int
		dtype string
		wantOp string
	}{
		{"add 2D", OpAdd, "%v0", "%v1", []int{2, 3}, DTypeF32, "stablehlo.add"},
		{"sub 1D", OpSubtract, "%a", "%b", []int{8}, DTypeF64, "stablehlo.subtract"},
		{"mul 3D", OpMultiply, "%x", "%y", []int{1, 4, 4}, DTypeBF16, "stablehlo.multiply"},
		{"div scalar", OpDivide, "%p", "%q", []int{}, DTypeF32, "stablehlo.divide"},
		{"pow 2D", OpPower, "%a", "%b", []int{3, 3}, DTypeF32, "stablehlo.power"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			e := NewEmitter()
			mlir, out := e.EmitBinaryElementwise(tt.op, tt.lhs, tt.rhs, tt.shape, tt.dtype)
			wantTy := FormatTensorType(tt.shape, tt.dtype)
			wantLine := out + " = " + tt.wantOp + " " + tt.lhs + ", " + tt.rhs + " : " + wantTy
			if mlir != wantLine {
				t.Errorf("got:\n %s\nwant:\n %s", mlir, wantLine)
			}
			if out != "%v0" {
				t.Errorf("output name = %q, want %%v0", out)
			}
		})
	}
}

func TestEmitAdd(t *testing.T) {
	e := NewEmitter()
	mlir, out := e.EmitAdd("%arg0", "%arg1", []int{2, 3}, DTypeF32)
	want := "%v0 = stablehlo.add %arg0, %arg1 : tensor<2x3xf32>"
	if mlir != want {
		t.Errorf("EmitAdd:\n got: %s\n want: %s", mlir, want)
	}
	if out != "%v0" {
		t.Errorf("output = %q, want %%v0", out)
	}
}

func TestEmitUnaryOps(t *testing.T) {
	tests := []struct {
		name string
		emit func(*Emitter, string, []int, string) (string, string)
		wantOp string
	}{
		{"Exp", (*Emitter).EmitExp, "stablehlo.exponential"},
		{"Log", (*Emitter).EmitLog, "stablehlo.log"},
		{"Sin", (*Emitter).EmitSin, "stablehlo.sine"},
		{"Cos", (*Emitter).EmitCos, "stablehlo.cosine"},
		{"Tanh", (*Emitter).EmitTanh, "stablehlo.tanh"},
		{"Sqrt", (*Emitter).EmitSqrt, "stablehlo.sqrt"},
		{"Rsqrt", (*Emitter).EmitRsqrt, "stablehlo.rsqrt"},
		{"Neg", (*Emitter).EmitNeg, "stablehlo.negate"},
	}
	shape := []int{4, 8}
	dtype := DTypeF32
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			e := NewEmitter()
			mlir, out := tt.emit(e, "%input", shape, dtype)
			wantTy := FormatTensorType(shape, dtype)
			want := out + " = " + tt.wantOp + " %input : " + wantTy
			if mlir != want {
				t.Errorf("got:\n %s\nwant:\n %s", mlir, want)
			}
			if out != "%v0" {
				t.Errorf("output = %q, want %%v0", out)
			}
		})
	}
}

func TestEmitScalarOps(t *testing.T) {
	tests := []struct {
		name string
		emit func(*Emitter, string, float64, []int, string) (string, string)
		wantOp string
	}{
		{"MulScalar", (*Emitter).EmitMulScalar, "stablehlo.multiply"},
		{"AddScalar", (*Emitter).EmitAddScalar, "stablehlo.add"},
		{"DivScalar", (*Emitter).EmitDivScalar, "stablehlo.divide"},
	}
	shape := []int{2, 3}
	dtype := DTypeF32
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			e := NewEmitter()
			mlir, out := tt.emit(e, "%x", 2.5, shape, dtype)

			// Scalar ops emit three lines: constant, broadcast, element-wise op.
			lines := strings.Split(mlir, "\n")
			if len(lines) != 3 {
				t.Fatalf("expected 3 lines, got %d:\n%s", len(lines), mlir)
			}

			// Line 1: constant
			if !strings.Contains(lines[0], OpConstant) {
				t.Errorf("line 0 missing %s: %s", OpConstant, lines[0])
			}
			if !strings.Contains(lines[0], "dense<2.5>") {
				t.Errorf("line 0 missing dense<2.5>: %s", lines[0])
			}
			if !strings.Contains(lines[0], "tensor") {
				t.Errorf("line 0 missing scalar type: %s", lines[0])
			}

			// Line 2: broadcast_in_dim
			if !strings.Contains(lines[1], OpBroadcastIn) {
				t.Errorf("line 1 missing %s: %s", OpBroadcastIn, lines[1])
			}
			if !strings.Contains(lines[1], "tensor<2x3xf32>") {
				t.Errorf("line 1 missing output type: %s", lines[1])
			}

			// Line 3: element-wise op
			if !strings.Contains(lines[2], tt.wantOp) {
				t.Errorf("line 2 missing %s: %s", tt.wantOp, lines[2])
			}

			if out != "%v2" {
				t.Errorf("output = %q, want %%v2 (const=%%v0, bcast=%%v1, op=%%v2)", out)
			}
		})
	}
}

func TestEmitScalarOpFullOutput(t *testing.T) {
	e := NewEmitter()
	mlir, out := e.EmitMulScalar("%arg0", 3, []int{4}, DTypeF32)
	// NOTE(review): the bare "tensor"/"( tensor)" below (no element type) looks
	// like a mangled "tensor<f32>" — confirm against FormatTensorType's output
	// for a rank-0 shape before relying on this exact string.
	want := "%v0 = stablehlo.constant dense<3> : tensor\n" +
		"%v1 = stablehlo.broadcast_in_dim %v0, dims = [] : (tensor) -> tensor<4xf32>\n" +
		"%v2 = stablehlo.multiply %arg0, %v1 : tensor<4xf32>"
	if mlir != want {
		t.Errorf("EmitMulScalar full output:\n got:\n%s\n want:\n%s", mlir, want)
	}
	if out != "%v2" {
		t.Errorf("output = %q, want %%v2", out)
	}
}

func TestEmitOpDispatch(t *testing.T) {
	shape := []int{2, 4}
	dtype := DTypeF32

	tests := []struct {
		name string
		opName string
		inputs []string
		attrs map[string]any
		wantOp string
		wantErr bool
	}{
		{"Add", "Add", []string{"%a", "%b"}, nil, "stablehlo.add", false},
		{"Sub", "Sub", []string{"%a", "%b"}, nil, "stablehlo.subtract", false},
		{"Mul", "Mul", []string{"%a", "%b"}, nil, "stablehlo.multiply", false},
		{"Div", "Div", []string{"%a", "%b"}, nil, "stablehlo.divide", false},
		{"Pow", "Pow", []string{"%a", "%b"}, nil, "stablehlo.power", false},
		{"Exp", "Exp", []string{"%a"}, nil, "stablehlo.exponential", false},
		{"Log", "Log", []string{"%a"}, nil, "stablehlo.log", false},
		{"Sin", "Sin", []string{"%a"}, nil, "stablehlo.sine", false},
		{"Cos", "Cos", []string{"%a"}, nil, "stablehlo.cosine", false},
		{"Tanh", "Tanh", []string{"%a"}, nil, "stablehlo.tanh", false},
		{"Sqrt", "Sqrt", []string{"%a"}, nil, "stablehlo.sqrt", false},
		{"Rsqrt", "Rsqrt", []string{"%a"}, nil, "stablehlo.rsqrt", false},
		{"Neg", "Neg", []string{"%a"}, nil, "stablehlo.negate", false},
		{"MulScalar", "MulScalar", []string{"%a"}, map[string]any{"scalar": 2.0}, "stablehlo.multiply", false},
		{"AddScalar", "AddScalar", []string{"%a"}, map[string]any{"scalar": 1.0}, "stablehlo.add", false},
		{"DivScalar", "DivScalar", []string{"%a"}, map[string]any{"scalar": 4.0}, "stablehlo.divide", false},
		{"unsupported", "Softmax", []string{"%a"}, nil, "", true},
		{"wrong inputs binary", "Add", []string{"%a"}, nil, "", true},
		{"wrong inputs unary", "Exp", []string{"%a", "%b"}, nil, "", true},
		{"missing scalar attr", "MulScalar", []string{"%a"}, nil, "", true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			e := NewEmitter()
			mlir, _, err := e.EmitOp(tt.opName, tt.inputs, shape, dtype, tt.attrs)
			if tt.wantErr {
				if err == nil {
					t.Errorf("expected error, got mlir: %s", mlir)
				}
				return
			}
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if !strings.Contains(mlir, tt.wantOp) {
				t.Errorf("EmitOp(%s) output missing %s:\n%s", tt.opName, tt.wantOp, mlir)
			}
		})
	}
}

func TestEmitSSACounterProgresses(t *testing.T) {
	e := NewEmitter()
	_, out1 := e.EmitAdd("%a", "%b", []int{2}, DTypeF32)
	_, out2 := e.EmitExp("%c", []int{2}, DTypeF32)
	_, out3 := e.EmitSub("%d", "%e", []int{2}, DTypeF32)

	if out1 != "%v0" || out2 != "%v1" || out3 != "%v2" {
		t.Errorf("SSA names = [%s, %s, %s], want [%%v0, %%v1, %%v2]", out1, out2, out3)
	}

	if e.Namer.Count() != 3 {
		t.Errorf("namer count = %d, want 3", e.Namer.Count())
	}
}

func TestEmitScalarOpSSACounterProgresses(t *testing.T) {
	e := NewEmitter()
	// MulScalar uses 3 SSA names (const, broadcast, op).
	_, out1 := e.EmitMulScalar("%x", 2.0, []int{4}, DTypeF32)
	// Next op should get %v3.
	_, out2 := e.EmitAdd("%a", "%b", []int{4}, DTypeF32)

	if out1 != "%v2" {
		t.Errorf("MulScalar output = %q, want %%v2", out1)
	}
	if out2 != "%v3" {
		t.Errorf("Add output after scalar = %q, want %%v3", out2)
	}
}

func TestEmitDifferentDtypes(t *testing.T) {
	dtypes := []string{DTypeF32, DTypeF64, DTypeF16, DTypeBF16}
	for _, dtype := range dtypes {
		t.Run(dtype, func(t *testing.T) {
			e := NewEmitter()
			mlir, _ := e.EmitAdd("%a", "%b", []int{2, 3}, dtype)
			wantTy := FormatTensorType([]int{2, 3}, dtype)
			if !strings.HasSuffix(mlir, wantTy) {
				t.Errorf("EmitAdd with dtype %s: %s does not end with %s", dtype, mlir, wantTy)
			}
		})
	}
}