Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
610 changes: 610 additions & 0 deletions internal/pjrt/buffer.go

Large diffs are not rendered by default.

14 changes: 6 additions & 8 deletions internal/pjrt/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import (
"fmt"
"unsafe"

"github.com/zerfoo/ztensor/internal/cuda"
)

// Client wraps a PJRT_Client handle and provides Go-friendly methods
Expand Down Expand Up @@ -57,7 +55,7 @@
createOptions: cfg.createOptions,
}

errPtr := cuda.Ccall(lib.PJRT_Client_Create, uintptr(unsafe.Pointer(&args)))
errPtr := ccall(lib.PJRT_Client_Create, uintptr(unsafe.Pointer(&args)))
if err := lib.checkError(errPtr); err != nil {
return nil, fmt.Errorf("PJRT_Client_Create: %w", err)
}
Expand All @@ -83,7 +81,7 @@
structSize: unsafe.Sizeof(destroyArgs{}),
client: c.handle,
}
errPtr := cuda.Ccall(c.lib.PJRT_Client_Destroy, uintptr(unsafe.Pointer(&args)))
errPtr := ccall(c.lib.PJRT_Client_Destroy, uintptr(unsafe.Pointer(&args)))
c.handle = 0
return c.lib.checkError(errPtr)
}
Expand All @@ -105,7 +103,7 @@
structSize: unsafe.Sizeof(platformNameArgs{}),
client: c.handle,
}
errPtr := cuda.Ccall(c.lib.PJRT_Client_PlatformName, uintptr(unsafe.Pointer(&args)))
errPtr := ccall(c.lib.PJRT_Client_PlatformName, uintptr(unsafe.Pointer(&args)))
if err := c.lib.checkError(errPtr); err != nil {
return "", fmt.Errorf("PJRT_Client_PlatformName: %w", err)
}
Expand All @@ -124,7 +122,7 @@
structSize: unsafe.Sizeof(platformVersionArgs{}),
client: c.handle,
}
errPtr := cuda.Ccall(c.lib.PJRT_Client_PlatformVersion, uintptr(unsafe.Pointer(&args)))
errPtr := ccall(c.lib.PJRT_Client_PlatformVersion, uintptr(unsafe.Pointer(&args)))
if err := c.lib.checkError(errPtr); err != nil {
return "", fmt.Errorf("PJRT_Client_PlatformVersion: %w", err)
}
Expand All @@ -148,7 +146,7 @@
structSize: unsafe.Sizeof(devicesArgs{}),
client: c.handle,
}
errPtr := cuda.Ccall(c.lib.PJRT_Client_Devices, uintptr(unsafe.Pointer(&args)))
errPtr := ccall(c.lib.PJRT_Client_Devices, uintptr(unsafe.Pointer(&args)))
if err := c.lib.checkError(errPtr); err != nil {
return nil, fmt.Errorf("PJRT_Client_Devices: %w", err)
}
Expand All @@ -167,7 +165,7 @@
structSize: unsafe.Sizeof(addressableDevicesArgs{}),
client: c.handle,
}
errPtr := cuda.Ccall(c.lib.PJRT_Client_AddressableDevices, uintptr(unsafe.Pointer(&args)))
errPtr := ccall(c.lib.PJRT_Client_AddressableDevices, uintptr(unsafe.Pointer(&args)))
if err := c.lib.checkError(errPtr); err != nil {
return nil, fmt.Errorf("PJRT_Client_AddressableDevices: %w", err)
}
Expand All @@ -182,7 +180,7 @@
return nil
}
// arrayPtr is a PJRT_Device** — an array of n pointers.
ptrs := unsafe.Slice((*uintptr)(unsafe.Pointer(arrayPtr)), n)

Check failure on line 183 in internal/pjrt/client.go

View workflow job for this annotation

GitHub Actions / test

possible misuse of unsafe.Pointer
devices := make([]*Device, n)
for i, ptr := range ptrs {
devices[i] = &Device{lib: c.lib, handle: ptr}
Expand Down
12 changes: 5 additions & 7 deletions internal/pjrt/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package pjrt
import (
"fmt"
"unsafe"

"github.com/zerfoo/ztensor/internal/cuda"
)

// Device wraps a PJRT_Device handle and provides methods for
Expand Down Expand Up @@ -37,7 +35,7 @@ func (d *Device) ID() (int, error) {
structSize: unsafe.Sizeof(idArgs{}),
deviceDescription: desc,
}
errPtr := cuda.Ccall(d.lib.PJRT_DeviceDescription_Id, uintptr(unsafe.Pointer(&args)))
errPtr := ccall(d.lib.PJRT_DeviceDescription_Id, uintptr(unsafe.Pointer(&args)))
if err := d.lib.checkError(errPtr); err != nil {
return 0, fmt.Errorf("PJRT_DeviceDescription_Id: %w", err)
}
Expand Down Expand Up @@ -66,7 +64,7 @@ func (d *Device) Kind() (string, error) {
structSize: unsafe.Sizeof(kindArgs{}),
deviceDescription: desc,
}
errPtr := cuda.Ccall(d.lib.PJRT_DeviceDescription_Kind, uintptr(unsafe.Pointer(&args)))
errPtr := ccall(d.lib.PJRT_DeviceDescription_Kind, uintptr(unsafe.Pointer(&args)))
if err := d.lib.checkError(errPtr); err != nil {
return "", fmt.Errorf("PJRT_DeviceDescription_Kind: %w", err)
}
Expand All @@ -90,7 +88,7 @@ func (d *Device) IsAddressable() (bool, error) {
structSize: unsafe.Sizeof(isAddressableArgs{}),
device: d.handle,
}
errPtr := cuda.Ccall(d.lib.PJRT_Device_IsAddressable, uintptr(unsafe.Pointer(&args)))
errPtr := ccall(d.lib.PJRT_Device_IsAddressable, uintptr(unsafe.Pointer(&args)))
if err := d.lib.checkError(errPtr); err != nil {
return false, fmt.Errorf("PJRT_Device_IsAddressable: %w", err)
}
Expand All @@ -114,7 +112,7 @@ func (d *Device) LocalHardwareId() (int, error) {
structSize: unsafe.Sizeof(localHWIDArgs{}),
device: d.handle,
}
errPtr := cuda.Ccall(d.lib.PJRT_Device_LocalHardwareId, uintptr(unsafe.Pointer(&args)))
errPtr := ccall(d.lib.PJRT_Device_LocalHardwareId, uintptr(unsafe.Pointer(&args)))
if err := d.lib.checkError(errPtr); err != nil {
return 0, fmt.Errorf("PJRT_Device_LocalHardwareId: %w", err)
}
Expand Down Expand Up @@ -142,7 +140,7 @@ func (d *Device) getDescription() (uintptr, error) {
structSize: unsafe.Sizeof(getDescArgs{}),
device: d.handle,
}
errPtr := cuda.Ccall(d.lib.PJRT_Device_GetDescription, uintptr(unsafe.Pointer(&args)))
errPtr := ccall(d.lib.PJRT_Device_GetDescription, uintptr(unsafe.Pointer(&args)))
if err := d.lib.checkError(errPtr); err != nil {
return 0, fmt.Errorf("PJRT_Device_GetDescription: %w", err)
}
Expand Down
27 changes: 22 additions & 5 deletions internal/pjrt/pjrt.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@
lib := &PJRTLib{}
var lastErr string
for _, path := range candidates {
h, err := cuda.DlopenPath(path)
h, err := dlopenPath(path)
if err == nil {
lib.handle = h
break
Expand All @@ -144,19 +144,19 @@
}

// Resolve the single entry point.
getPjrtApi, err := cuda.Dlsym(lib.handle, "GetPjrtApi")
getPjrtApi, err := dlsym(lib.handle, "GetPjrtApi")
if err != nil {
lib.Close()
return nil, fmt.Errorf("pjrt: dlsym GetPjrtApi: %w", err)
}

// Call GetPjrtApi() -> *PJRT_Api.
apiPtr := cuda.Ccall(getPjrtApi)
apiPtr := ccall(getPjrtApi)
if apiPtr == 0 {
lib.Close()
return nil, fmt.Errorf("pjrt: GetPjrtApi returned null")
}
lib.api = unsafe.Pointer(apiPtr) //nolint:govet // apiPtr is a valid C pointer from GetPjrtApi

Check failure on line 159 in internal/pjrt/pjrt.go

View workflow job for this annotation

GitHub Actions / test

possible misuse of unsafe.Pointer

// Read struct_size and pjrt_api_version from the header.
header := (*pjrtApiHeader)(lib.api)
Expand Down Expand Up @@ -240,7 +240,7 @@
structSize: unsafe.Sizeof(errorMessageArgs{}),
error: errPtr,
}
cuda.Ccall(lib.PJRT_Error_Message, uintptr(unsafe.Pointer(&args)))
ccall(lib.PJRT_Error_Message, uintptr(unsafe.Pointer(&args)))

if args.message == 0 || args.messageLen == 0 {
return "unknown PJRT error"
Expand All @@ -261,7 +261,7 @@
structSize: unsafe.Sizeof(destroyArgs{}),
error: errPtr,
}
cuda.Ccall(lib.PJRT_Error_Destroy, uintptr(unsafe.Pointer(&args)))
ccall(lib.PJRT_Error_Destroy, uintptr(unsafe.Pointer(&args)))
}

// checkError converts a PJRT_Error pointer to a Go error.
Expand All @@ -275,6 +275,23 @@
return fmt.Errorf("pjrt: %s", msg)
}

// ccall calls a C function pointer with the given arguments.
// Centralizes the internal/cuda dependency so other files in this
// package do not need to import it directly.
func ccall(fn uintptr, args ...uintptr) uintptr {
return cuda.Ccall(fn, args...)
}

// dlopenPath opens a shared library at the given path.
func dlopenPath(path string) (uintptr, error) {
return cuda.DlopenPath(path)
}

// dlsym resolves a symbol from a dlopen handle.
func dlsym(handle uintptr, name string) (uintptr, error) {
return cuda.Dlsym(handle, name)
}

// goStringN converts a C string pointer and length to a Go string.
//
//go:nosplit
Expand All @@ -283,5 +300,5 @@
if p == 0 || n == 0 {
return ""
}
return string(unsafe.Slice((*byte)(unsafe.Pointer(p)), n))

Check failure on line 303 in internal/pjrt/pjrt.go

View workflow job for this annotation

GitHub Actions / test

possible misuse of unsafe.Pointer
}
Loading
Loading