From 50f2ca6bdc361f3108ef108cfe7378231d9f04a8 Mon Sep 17 00:00:00 2001 From: David Ndungu Date: Thu, 2 Apr 2026 15:14:33 -0700 Subject: [PATCH 1/4] feat(stablehlo): add emitter for MatMul and structural ops Emit StableHLO MLIR text for structural operations: - MatMul -> stablehlo.dot_general with batching/contracting dims - Transpose -> stablehlo.transpose with permutation - Reshape -> stablehlo.reshape with shape validation - Concat -> stablehlo.concatenate along specified axis - Slice -> stablehlo.slice with starts/limits/strides - Gather -> stablehlo.gather with full dimension numbers All emitters use SSANamer and FormatTensorType from types.go and delegate shape inference to InferStructuralShape. --- internal/stablehlo/emit_structural.go | 207 ++++++++++++++++++ internal/stablehlo/emit_structural_test.go | 241 +++++++++++++++++++++ 2 files changed, 448 insertions(+) create mode 100644 internal/stablehlo/emit_structural.go create mode 100644 internal/stablehlo/emit_structural_test.go diff --git a/internal/stablehlo/emit_structural.go b/internal/stablehlo/emit_structural.go new file mode 100644 index 0000000..8b2bb58 --- /dev/null +++ b/internal/stablehlo/emit_structural.go @@ -0,0 +1,207 @@ +package stablehlo + +import ( + "fmt" + "strings" +) + +// EmitMatMul emits a stablehlo.dot_general operation for matrix multiplication. +// Handles 2D (MxK @ KxN) and batched (BxMxK @ BxKxN) cases. +// Returns the MLIR line and the SSA name assigned to the result. +func EmitMatMul(namer *SSANamer, lhs, rhs string, lhsShape, rhsShape []int, dtype string) (string, string, error) { + if len(lhsShape) < 2 || len(rhsShape) < 2 { + return "", "", fmt.Errorf("stablehlo.EmitMatMul: inputs must be at least rank 2, got rank %d and %d", len(lhsShape), len(rhsShape)) + } + if len(lhsShape) != len(rhsShape) { + return "", "", fmt.Errorf("stablehlo.EmitMatMul: rank mismatch: %d vs %d", len(lhsShape), len(rhsShape)) + } + + rank := len(lhsShape) + // Contraction dimension: last axis of LHS, second-to-last axis of RHS. + lhsContract := rank - 1 + rhsContract := rank - 2 + + if lhsShape[lhsContract] != rhsShape[rhsContract] { + return "", "", fmt.Errorf("stablehlo.EmitMatMul: contraction dimension mismatch: %d vs %d", lhsShape[lhsContract], rhsShape[rhsContract]) + } + + outShape, err := InferStructuralShape("MatMul", [][]int{lhsShape, rhsShape}, nil) + if err != nil { + return "", "", err + } + + result := namer.NextName() + lhsType := FormatTensorType(lhsShape, dtype) + rhsType := FormatTensorType(rhsShape, dtype) + outType := FormatTensorType(outShape, dtype) + + // Build batch dimensions list. + var batchDims []string + for i := 0; i < rank-2; i++ { + batchDims = append(batchDims, fmt.Sprintf("%d", i)) + } + + var b strings.Builder + fmt.Fprintf(&b, "%s = %s %s, %s, batching_dims = [%s] x [%s], contracting_dims = [%d] x [%d] : (%s, %s) -> %s", + result, OpDotGeneral, lhs, rhs, + strings.Join(batchDims, ", "), strings.Join(batchDims, ", "), + lhsContract, rhsContract, + lhsType, rhsType, outType, + ) + + return b.String(), result, nil +} + +// EmitTranspose emits a stablehlo.transpose operation. +// perm specifies the axis permutation (e.g., [2, 0, 1]). +func EmitTranspose(namer *SSANamer, operand string, shape []int, perm []int, dtype string) (string, string, error) { + if len(perm) != len(shape) { + return "", "", fmt.Errorf("stablehlo.EmitTranspose: perm length %d does not match rank %d", len(perm), len(shape)) + } + + outShape, err := InferStructuralShape("Transpose", [][]int{shape}, map[string]any{"perm": perm}) + if err != nil { + return "", "", err + } + + result := namer.NextName() + inType := FormatTensorType(shape, dtype) + outType := FormatTensorType(outShape, dtype) + + permStrs := make([]string, len(perm)) + for i, p := range perm { + permStrs[i] = fmt.Sprintf("%d", p) + } + + line := fmt.Sprintf("%s = %s %s, permutation = [%s] : (%s) -> %s", + result, OpTranspose, operand, + strings.Join(permStrs, ", "), + inType, outType, + ) + + return line, result, nil +} + +// EmitReshape emits a stablehlo.reshape operation. +// targetShape is the desired output shape. +func EmitReshape(namer *SSANamer, operand string, inShape, targetShape []int, dtype string) (string, string, error) { + outShape, err := InferStructuralShape("Reshape", [][]int{inShape}, map[string]any{"shape": targetShape}) + if err != nil { + return "", "", err + } + + result := namer.NextName() + inType := FormatTensorType(inShape, dtype) + outType := FormatTensorType(outShape, dtype) + + line := fmt.Sprintf("%s = %s %s : (%s) -> %s", + result, OpReshape, operand, + inType, outType, + ) + + return line, result, nil +} + +// EmitConcat emits a stablehlo.concatenate operation along the given axis. +// operands are the SSA names, shapes are the corresponding tensor shapes. +func EmitConcat(namer *SSANamer, operands []string, shapes [][]int, axis int, dtype string) (string, string, error) { + if len(operands) != len(shapes) { + return "", "", fmt.Errorf("stablehlo.EmitConcat: operand count %d does not match shape count %d", len(operands), len(shapes)) + } + + outShape, err := InferStructuralShape("Concat", shapes, map[string]any{"axis": axis}) + if err != nil { + return "", "", err + } + + result := namer.NextName() + outType := FormatTensorType(outShape, dtype) + + line := fmt.Sprintf("%s = %s %s, dimension = %d : %s", + result, OpConcatenate, strings.Join(operands, ", "), + axis, outType, + ) + + return line, result, nil +} + +// EmitSlice emits a stablehlo.slice operation with start, limit, and stride indices. +// strides may be nil, in which case all strides default to 1. +func EmitSlice(namer *SSANamer, operand string, shape, start, limit, strides []int, dtype string) (string, string, error) { + if len(start) != len(shape) || len(limit) != len(shape) { + return "", "", fmt.Errorf("stablehlo.EmitSlice: start/limit length must match rank %d", len(shape)) + } + if strides == nil { + strides = make([]int, len(shape)) + for i := range strides { + strides[i] = 1 + } + } + if len(strides) != len(shape) { + return "", "", fmt.Errorf("stablehlo.EmitSlice: strides length %d must match rank %d", len(strides), len(shape)) + } + + // Compute output shape: ceil((limit[i] - start[i]) / strides[i]). + outShape := make([]int, len(shape)) + for i := range shape { + if start[i] < 0 || limit[i] > shape[i] || start[i] > limit[i] || strides[i] <= 0 { + return "", "", fmt.Errorf("stablehlo.EmitSlice: invalid range [%d:%d] stride %d for dimension %d (size %d)", start[i], limit[i], strides[i], i, shape[i]) + } + outShape[i] = (limit[i] - start[i] + strides[i] - 1) / strides[i] + } + + result := namer.NextName() + inType := FormatTensorType(shape, dtype) + outType := FormatTensorType(outShape, dtype) + + line := fmt.Sprintf("%s = %s %s, starts = [%s], limits = [%s], strides = [%s] : (%s) -> %s", + result, OpSlice, operand, + formatIntSlice(start), formatIntSlice(limit), formatIntSlice(strides), + inType, outType, + ) + + return line, result, nil +} + +// EmitGather emits a stablehlo.gather operation. +// operandShape is the shape of the data tensor, indicesShape is the shape of the index tensor. +// sliceSizes specifies the size of each gathered slice. +// offsetDims, collapsedSliceDims, startIndexMap are the gather dimension numbers. +// indexVectorDim is the dimension in the indices tensor that contains the index vector. +func EmitGather(namer *SSANamer, operand, indices string, + operandShape, indicesShape, sliceSizes []int, + offsetDims, collapsedSliceDims, startIndexMap []int, + indexVectorDim int, + dtype string, +) (string, string, error) { + // Compute output shape from the gather semantics. + outShape, err := InferStructuralShape("Gather", [][]int{operandShape, indicesShape}, map[string]any{"sliceSizes": sliceSizes}) + if err != nil { + return "", "", err + } + + result := namer.NextName() + outType := FormatTensorType(outShape, dtype) + + var b strings.Builder + fmt.Fprintf(&b, "%s = %s %s, %s, offset_dims = [%s], collapsed_slice_dims = [%s], start_index_map = [%s], index_vector_dim = %d, slice_sizes = [%s] : %s", + result, OpGather, operand, indices, + formatIntSlice(offsetDims), + formatIntSlice(collapsedSliceDims), + formatIntSlice(startIndexMap), + indexVectorDim, + formatIntSlice(sliceSizes), + outType, + ) + + return b.String(), result, nil +} + +// formatIntSlice formats an int slice as a comma-separated string (e.g., "0, 1, 2"). +func formatIntSlice(s []int) string { + parts := make([]string, len(s)) + for i, v := range s { + parts[i] = fmt.Sprintf("%d", v) + } + return strings.Join(parts, ", ") +} diff --git a/internal/stablehlo/emit_structural_test.go b/internal/stablehlo/emit_structural_test.go new file mode 100644 index 0000000..117ed3d --- /dev/null +++ b/internal/stablehlo/emit_structural_test.go @@ -0,0 +1,241 @@ +package stablehlo + +import ( + "strings" + "testing" +) + +func TestEmitMatMul2D(t *testing.T) { + namer := &SSANamer{} + line, result, err := EmitMatMul(namer, "%a", "%b", []int{4, 3}, []int{3, 5}, DTypeF32) + if err != nil { + t.Fatal(err) + } + if result != "%v0" { + t.Errorf("expected result %%v0, got %s", result) + } + want := `%v0 = stablehlo.dot_general %a, %b, batching_dims = [] x [], contracting_dims = [1] x [0] : (tensor<4x3xf32>, tensor<3x5xf32>) -> tensor<4x5xf32>` + if line != want { + t.Errorf("mismatch:\ngot: %s\nwant: %s", line, want) + } +} + +func TestEmitMatMulBatched(t *testing.T) { + namer := &SSANamer{} + line, result, err := EmitMatMul(namer, "%a", "%b", []int{2, 4, 3}, []int{2, 3, 5}, DTypeF32) + if err != nil { + t.Fatal(err) + } + if result != "%v0" { + t.Errorf("expected result %%v0, got %s", result) + } + want := `%v0 = stablehlo.dot_general %a, %b, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<2x4x3xf32>, tensor<2x3x5xf32>) -> tensor<2x4x5xf32>` + if line != want { + t.Errorf("mismatch:\ngot: %s\nwant: %s", line, want) + } +} + +func TestEmitMatMulContractionMismatch(t *testing.T) { + namer := &SSANamer{} + _, _, err := EmitMatMul(namer, "%a", "%b", []int{4, 3}, []int{7, 5}, DTypeF32) + if err == nil { + t.Fatal("expected error for contraction dimension mismatch") + } +} + +func TestEmitMatMulRank1(t *testing.T) { + namer := &SSANamer{} + _, _, err := EmitMatMul(namer, "%a", "%b", []int{4}, []int{4}, DTypeF32) + if err == nil { + t.Fatal("expected error for rank-1 inputs") + } +} + +func TestEmitTranspose(t *testing.T) { + namer := &SSANamer{} + line, result, err := EmitTranspose(namer, "%a", []int{2, 3, 4}, []int{2, 0, 1}, DTypeF32) + if err != nil { + t.Fatal(err) + } + if result != "%v0" { + t.Errorf("expected result %%v0, got %s", result) + } + want := `%v0 = stablehlo.transpose %a, permutation = [2, 0, 1] : (tensor<2x3x4xf32>) -> tensor<4x2x3xf32>` + if line != want { + t.Errorf("mismatch:\ngot: %s\nwant: %s", line, want) + } +} + +func TestEmitTransposeInvalidPerm(t *testing.T) { + namer := &SSANamer{} + _, _, err := EmitTranspose(namer, "%a", []int{2, 3}, []int{0}, DTypeF32) + if err == nil { + t.Fatal("expected error for perm length mismatch") + } +} + +func TestEmitReshape(t *testing.T) { + namer := &SSANamer{} + line, result, err := EmitReshape(namer, "%a", []int{2, 3, 4}, []int{6, 4}, DTypeF32) + if err != nil { + t.Fatal(err) + } + if result != "%v0" { + t.Errorf("expected result %%v0, got %s", result) + } + want := `%v0 = stablehlo.reshape %a : (tensor<2x3x4xf32>) -> tensor<6x4xf32>` + if line != want { + t.Errorf("mismatch:\ngot: %s\nwant: %s", line, want) + } +} + +func TestEmitReshapeElementMismatch(t *testing.T) { + namer := &SSANamer{} + _, _, err := EmitReshape(namer, "%a", []int{2, 3}, []int{7}, DTypeF32) + if err == nil { + t.Fatal("expected error for element count mismatch") + } +} + +func TestEmitConcat(t *testing.T) { + namer := &SSANamer{} + line, result, err := EmitConcat(namer, + []string{"%a", "%b"}, + [][]int{{2, 3}, {2, 5}}, + 1, DTypeF32, + ) + if err != nil { + t.Fatal(err) + } + if result != "%v0" { + t.Errorf("expected result %%v0, got %s", result) + } + want := `%v0 = stablehlo.concatenate %a, %b, dimension = 1 : tensor<2x8xf32>` + if line != want { + t.Errorf("mismatch:\ngot: %s\nwant: %s", line, want) + } +} + +func TestEmitConcatThreeInputs(t *testing.T) { + namer := &SSANamer{} + line, _, err := EmitConcat(namer, + []string{"%a", "%b", "%c"}, + [][]int{{4, 2}, {4, 3}, {4, 1}}, + 1, DTypeF64, + ) + if err != nil { + t.Fatal(err) + } + want := `%v0 = stablehlo.concatenate %a, %b, %c, dimension = 1 : tensor<4x6xf64>` + if line != want { + t.Errorf("mismatch:\ngot: %s\nwant: %s", line, want) + } +} + +func TestEmitSlice(t *testing.T) { + namer := &SSANamer{} + line, result, err := EmitSlice(namer, "%a", + []int{8, 6}, + []int{1, 0}, []int{5, 4}, nil, + DTypeF32, + ) + if err != nil { + t.Fatal(err) + } + if result != "%v0" { + t.Errorf("expected result %%v0, got %s", result) + } + want := `%v0 = stablehlo.slice %a, starts = [1, 0], limits = [5, 4], strides = [1, 1] : (tensor<8x6xf32>) -> tensor<4x4xf32>` + if line != want { + t.Errorf("mismatch:\ngot: %s\nwant: %s", line, want) + } +} + +func TestEmitSliceWithStrides(t *testing.T) { + namer := &SSANamer{} + line, _, err := EmitSlice(namer, "%a", + []int{10}, + []int{0}, []int{10}, []int{2}, + DTypeF32, + ) + if err != nil { + t.Fatal(err) + } + want := `%v0 = stablehlo.slice %a, starts = [0], limits = [10], strides = [2] : (tensor<10xf32>) -> tensor<5xf32>` + if line != want { + t.Errorf("mismatch:\ngot: %s\nwant: %s", line, want) + } +} + +func TestEmitSliceInvalidRange(t *testing.T) { + namer := &SSANamer{} + _, _, err := EmitSlice(namer, "%a", []int{4}, []int{3}, []int{1}, nil, DTypeF32) + if err == nil { + t.Fatal("expected error for invalid range (start > limit)") + } +} + +func TestEmitGather(t *testing.T) { + namer := &SSANamer{} + line, result, err := EmitGather(namer, "%data", "%indices", + []int{10, 8}, // operand shape + []int{3, 1}, // indices shape + []int{1, 8}, // slice sizes + []int{1}, // offset dims + []int{0}, // collapsed slice dims + []int{0}, // start index map + 1, // index vector dim + DTypeF32, + ) + if err != nil { + t.Fatal(err) + } + if result != "%v0" { + t.Errorf("expected result %%v0, got %s", result) + } + // Output shape from InferStructuralShape(Gather): indices[:-1] + sliceSizes = [3] + [1, 8] = [3, 1, 8] + if !strings.Contains(line, "stablehlo.gather") { + t.Errorf("expected stablehlo.gather in output, got: %s", line) + } + if !strings.Contains(line, "offset_dims = [1]") { + t.Errorf("expected offset_dims = [1], got: %s", line) + } + if !strings.Contains(line, "collapsed_slice_dims = [0]") { + t.Errorf("expected collapsed_slice_dims = [0], got: %s", line) + } + if !strings.Contains(line, "start_index_map = [0]") { + t.Errorf("expected start_index_map = [0], got: %s", line) + } + if !strings.Contains(line, "index_vector_dim = 1") { + t.Errorf("expected index_vector_dim = 1, got: %s", line) + } + if !strings.Contains(line, "slice_sizes = [1, 8]") { + t.Errorf("expected slice_sizes = [1, 8], got: %s", line) + } +} + +func TestEmitMatMulF16(t *testing.T) { + namer := &SSANamer{} + line, _, err := EmitMatMul(namer, "%a", "%b", []int{8, 16}, []int{16, 32}, DTypeF16) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(line, "f16") { + t.Errorf("expected f16 dtype in output, got: %s", line) + } + if !strings.Contains(line, "tensor<8x32xf16>") { + t.Errorf("expected output type tensor<8x32xf16>, got: %s", line) + } +} + +func TestSSANamerCounterAdvances(t *testing.T) { + namer := &SSANamer{} + + _, r0, _ := EmitReshape(namer, "%a", []int{6}, []int{2, 3}, DTypeF32) + _, r1, _ := EmitReshape(namer, "%b", []int{6}, []int{3, 2}, DTypeF32) + _, r2, _ := EmitTranspose(namer, "%c", []int{2, 3}, []int{1, 0}, DTypeF32) + + if r0 != "%v0" || r1 != "%v1" || r2 != "%v2" { + t.Errorf("expected %%v0, %%v1, %%v2 but got %s, %s, %s", r0, r1, r2) + } +} From f436e131ed23bacc79dcd42aa83d0752b7616efb Mon Sep 17 00:00:00 2001 From: David Ndungu Date: Thu, 2 Apr 2026 15:14:42 -0700 Subject: [PATCH 2/4] feat(stablehlo): add emitter for element-wise and unary ops Implement Emitter struct that generates StableHLO MLIR text from SSA inputs, shapes, and dtypes. Covers binary ops (add, subtract, multiply, divide, power), unary ops (exponential, log, sine, cosine, tanh, sqrt, rsqrt, negate), and scalar ops (MulScalar, AddScalar, DivScalar) which emit constant + broadcast_in_dim + element-wise op. Includes EmitOp dispatch function for use by the program assembler. --- internal/stablehlo/emit.go | 272 ++++++++++++++++++++++++++++++++ internal/stablehlo/emit_test.go | 247 +++++++++++++++++++++++++++++ 2 files changed, 519 insertions(+) create mode 100644 internal/stablehlo/emit.go create mode 100644 internal/stablehlo/emit_test.go diff --git a/internal/stablehlo/emit.go b/internal/stablehlo/emit.go new file mode 100644 index 0000000..ccb0165 --- /dev/null +++ b/internal/stablehlo/emit.go @@ -0,0 +1,272 @@ +package stablehlo + +import "fmt" + +// Emitter generates StableHLO MLIR text from operation inputs. +// Each emit method takes SSA input names, tensor shapes, and a dtype, +// and returns the emitted MLIR line(s) plus the output SSA name. +type Emitter struct { + Namer *SSANamer +} + +// NewEmitter creates an Emitter with a fresh SSANamer. +func NewEmitter() *Emitter { + return &Emitter{Namer: &SSANamer{}} +} + +// EmitBinaryElementwise emits a binary element-wise op (add, subtract, multiply, divide, power). +// Both inputs must have the same shape and dtype. +func (e *Emitter) EmitBinaryElementwise(opName, lhs, rhs string, shape []int, dtype string) (mlir, outName string) { + outName = e.Namer.NextName() + ty := FormatTensorType(shape, dtype) + mlir = fmt.Sprintf("%s = %s %s, %s : %s", outName, opName, lhs, rhs, ty) + return mlir, outName +} + +// EmitAdd emits stablehlo.add. +func (e *Emitter) EmitAdd(lhs, rhs string, shape []int, dtype string) (string, string) { + return e.EmitBinaryElementwise(OpAdd, lhs, rhs, shape, dtype) +} + +// EmitSub emits stablehlo.subtract. +func (e *Emitter) EmitSub(lhs, rhs string, shape []int, dtype string) (string, string) { + return e.EmitBinaryElementwise(OpSubtract, lhs, rhs, shape, dtype) +} + +// EmitMul emits stablehlo.multiply. +func (e *Emitter) EmitMul(lhs, rhs string, shape []int, dtype string) (string, string) { + return e.EmitBinaryElementwise(OpMultiply, lhs, rhs, shape, dtype) +} + +// EmitDiv emits stablehlo.divide. +func (e *Emitter) EmitDiv(lhs, rhs string, shape []int, dtype string) (string, string) { + return e.EmitBinaryElementwise(OpDivide, lhs, rhs, shape, dtype) +} + +// EmitPow emits stablehlo.power. +func (e *Emitter) EmitPow(lhs, rhs string, shape []int, dtype string) (string, string) { + return e.EmitBinaryElementwise(OpPower, lhs, rhs, shape, dtype) +} + +// EmitUnary emits a unary element-wise op (exponential, log, sine, cosine, tanh, sqrt, rsqrt, negate). +func (e *Emitter) EmitUnary(opName, input string, shape []int, dtype string) (mlir, outName string) { + outName = e.Namer.NextName() + ty := FormatTensorType(shape, dtype) + mlir = fmt.Sprintf("%s = %s %s : %s", outName, opName, input, ty) + return mlir, outName +} + +// EmitExp emits stablehlo.exponential. +func (e *Emitter) EmitExp(input string, shape []int, dtype string) (string, string) { + return e.EmitUnary(OpExp, input, shape, dtype) +} + +// EmitLog emits stablehlo.log. +func (e *Emitter) EmitLog(input string, shape []int, dtype string) (string, string) { + return e.EmitUnary(OpLog, input, shape, dtype) +} + +// EmitSin emits stablehlo.sine. +func (e *Emitter) EmitSin(input string, shape []int, dtype string) (string, string) { + return e.EmitUnary(OpSin, input, shape, dtype) +} + +// EmitCos emits stablehlo.cosine. +func (e *Emitter) EmitCos(input string, shape []int, dtype string) (string, string) { + return e.EmitUnary(OpCos, input, shape, dtype) +} + +// EmitTanh emits stablehlo.tanh. +func (e *Emitter) EmitTanh(input string, shape []int, dtype string) (string, string) { + return e.EmitUnary(OpTanh, input, shape, dtype) +} + +// EmitSqrt emits stablehlo.sqrt. +func (e *Emitter) EmitSqrt(input string, shape []int, dtype string) (string, string) { + return e.EmitUnary(OpSqrt, input, shape, dtype) +} + +// EmitRsqrt emits stablehlo.rsqrt. +func (e *Emitter) EmitRsqrt(input string, shape []int, dtype string) (string, string) { + return e.EmitUnary(OpRsqrt, input, shape, dtype) +} + +// EmitNeg emits stablehlo.negate. +func (e *Emitter) EmitNeg(input string, shape []int, dtype string) (string, string) { + return e.EmitUnary(OpNegate, input, shape, dtype) +} + +// EmitScalarOp emits a scalar operation as three MLIR instructions: +// 1. stablehlo.constant for the scalar value +// 2. stablehlo.broadcast_in_dim to broadcast to the tensor shape +// 3. The element-wise binary op +// +// Returns all three lines (newline-separated) and the final output SSA name. +func (e *Emitter) EmitScalarOp(elemOp, input string, scalar float64, shape []int, dtype string) (mlir, outName string) { + ty := FormatTensorType(shape, dtype) + scalarTy := FormatTensorType(nil, dtype) + + // 1. Constant + constName := e.Namer.NextName() + constLine := fmt.Sprintf("%s = %s dense<%v> : %s", constName, OpConstant, scalar, scalarTy) + + // 2. Broadcast + bcastName := e.Namer.NextName() + bcastLine := fmt.Sprintf("%s = %s %s, dims = [] : (%s) -> %s", bcastName, OpBroadcastIn, constName, scalarTy, ty) + + // 3. Element-wise op + outName = e.Namer.NextName() + opLine := fmt.Sprintf("%s = %s %s, %s : %s", outName, elemOp, input, bcastName, ty) + + mlir = constLine + "\n" + bcastLine + "\n" + opLine + return mlir, outName +} + +// EmitMulScalar emits stablehlo.constant + broadcast_in_dim + multiply. +func (e *Emitter) EmitMulScalar(input string, scalar float64, shape []int, dtype string) (string, string) { + return e.EmitScalarOp(OpMultiply, input, scalar, shape, dtype) +} + +// EmitAddScalar emits stablehlo.constant + broadcast_in_dim + add. +func (e *Emitter) EmitAddScalar(input string, scalar float64, shape []int, dtype string) (string, string) { + return e.EmitScalarOp(OpAdd, input, scalar, shape, dtype) +} + +// EmitDivScalar emits stablehlo.constant + broadcast_in_dim + divide. +func (e *Emitter) EmitDivScalar(input string, scalar float64, shape []int, dtype string) (string, string) { + return e.EmitScalarOp(OpDivide, input, scalar, shape, dtype) +} + +// EmitOp dispatches to the appropriate emit function based on the engine op name. +// For binary ops, inputs should be [lhs, rhs]. For unary ops, inputs should be [input]. +// For scalar ops, inputs should be [input] and attrs must contain "scalar" (float64). +// Returns the emitted MLIR text and the output SSA name. +func (e *Emitter) EmitOp(opName string, inputs []string, shape []int, dtype string, attrs map[string]any) (string, string, error) { + switch opName { + case "Add": + if len(inputs) != 2 { + return "", "", fmt.Errorf("EmitOp(%s): expected 2 inputs, got %d", opName, len(inputs)) + } + mlir, out := e.EmitAdd(inputs[0], inputs[1], shape, dtype) + return mlir, out, nil + case "Sub": + if len(inputs) != 2 { + return "", "", fmt.Errorf("EmitOp(%s): expected 2 inputs, got %d", opName, len(inputs)) + } + mlir, out := e.EmitSub(inputs[0], inputs[1], shape, dtype) + return mlir, out, nil + case "Mul": + if len(inputs) != 2 { + return "", "", fmt.Errorf("EmitOp(%s): expected 2 inputs, got %d", opName, len(inputs)) + } + mlir, out := e.EmitMul(inputs[0], inputs[1], shape, dtype) + return mlir, out, nil + case "Div": + if len(inputs) != 2 { + return "", "", fmt.Errorf("EmitOp(%s): expected 2 inputs, got %d", opName, len(inputs)) + } + mlir, out := e.EmitDiv(inputs[0], inputs[1], shape, dtype) + return mlir, out, nil + case "Pow": + if len(inputs) != 2 { + return "", "", fmt.Errorf("EmitOp(%s): expected 2 inputs, got %d", opName, len(inputs)) + } + mlir, out := e.EmitPow(inputs[0], inputs[1], shape, dtype) + return mlir, out, nil + case "Exp": + if len(inputs) != 1 { + return "", "", fmt.Errorf("EmitOp(%s): expected 1 input, got %d", opName, len(inputs)) + } + mlir, out := e.EmitExp(inputs[0], shape, dtype) + return mlir, out, nil + case "Log": + if len(inputs) != 1 { + return "", "", fmt.Errorf("EmitOp(%s): expected 1 input, got %d", opName, len(inputs)) + } + mlir, out := e.EmitLog(inputs[0], shape, dtype) + return mlir, out, nil + case "Sin": + if len(inputs) != 1 { + return "", "", fmt.Errorf("EmitOp(%s): expected 1 input, got %d", opName, len(inputs)) + } + mlir, out := e.EmitSin(inputs[0], shape, dtype) + return mlir, out, nil + case "Cos": + if len(inputs) != 1 { + return "", "", fmt.Errorf("EmitOp(%s): expected 1 input, got %d", opName, len(inputs)) + } + mlir, out := e.EmitCos(inputs[0], shape, dtype) + return mlir, out, nil + case "Tanh": + if len(inputs) != 1 { + return "", "", fmt.Errorf("EmitOp(%s): expected 1 input, got %d", opName, len(inputs)) + } + mlir, out := e.EmitTanh(inputs[0], shape, dtype) + return mlir, out, nil + case "Sqrt": + if len(inputs) != 1 { + return "", "", fmt.Errorf("EmitOp(%s): expected 1 input, got %d", opName, len(inputs)) + } + mlir, out := e.EmitSqrt(inputs[0], shape, dtype) + return mlir, out, nil + case "Rsqrt": + if len(inputs) != 1 { + return "", "", fmt.Errorf("EmitOp(%s): expected 1 input, got %d", opName, len(inputs)) + } + mlir, out := e.EmitRsqrt(inputs[0], shape, dtype) + return mlir, out, nil + case "Neg": + if len(inputs) != 1 { + return "", "", fmt.Errorf("EmitOp(%s): expected 1 input, got %d", opName, len(inputs)) + } + mlir, out := e.EmitNeg(inputs[0], shape, dtype) + return mlir, out, nil + case "MulScalar": + if len(inputs) != 1 { + return "", "", fmt.Errorf("EmitOp(%s): expected 1 input, got %d", opName, len(inputs)) + } + s, err := scalarAttr(opName, attrs) + if err != nil { + return "", "", err + } + mlir, out := e.EmitMulScalar(inputs[0], s, shape, dtype) + return mlir, out, nil + case "AddScalar": + if len(inputs) != 1 { + return "", "", fmt.Errorf("EmitOp(%s): expected 1 input, got %d", opName, len(inputs)) + } + s, err := scalarAttr(opName, attrs) + if err != nil { + return "", "", err + } + mlir, out := e.EmitAddScalar(inputs[0], s, shape, dtype) + return mlir, out, nil + case "DivScalar": + if len(inputs) != 1 { + return "", "", fmt.Errorf("EmitOp(%s): expected 1 input, got %d", opName, len(inputs)) + } + s, err := scalarAttr(opName, attrs) + if err != nil { + return "", "", err + } + mlir, out := e.EmitDivScalar(inputs[0], s, shape, dtype) + return mlir, out, nil + default: + return "", "", fmt.Errorf("EmitOp: unsupported op %q", opName) + } +} + +func scalarAttr(opName string, attrs map[string]any) (float64, error) { + if attrs == nil { + return 0, fmt.Errorf("EmitOp(%s): attrs map is nil, need \"scalar\" key", opName) + } + v, ok := attrs["scalar"] + if !ok { + return 0, fmt.Errorf("EmitOp(%s): missing \"scalar\" in attrs", opName) + } + s, ok := v.(float64) + if !ok { + return 0, fmt.Errorf("EmitOp(%s): \"scalar\" attr is %T, want float64", opName, v) + } + return s, nil +} diff --git a/internal/stablehlo/emit_test.go b/internal/stablehlo/emit_test.go new file mode 100644 index 0000000..688e6e2 --- /dev/null +++ b/internal/stablehlo/emit_test.go @@ -0,0 +1,247 @@ +package stablehlo + +import ( + "strings" + "testing" +) + +func TestEmitBinaryElementwise(t *testing.T) { + tests := []struct { + name string + op string + lhs string + rhs string + shape []int + dtype string + wantOp string + }{ + {"add 2D", OpAdd, "%v0", "%v1", []int{2, 3}, DTypeF32, "stablehlo.add"}, + {"sub 1D", OpSubtract, "%a", "%b", []int{8}, DTypeF64, "stablehlo.subtract"}, + {"mul 3D", OpMultiply, "%x", "%y", []int{1, 4, 4}, DTypeBF16, "stablehlo.multiply"}, + {"div scalar", OpDivide, "%p", "%q", []int{}, DTypeF32, "stablehlo.divide"}, + {"pow 2D", OpPower, "%a", "%b", []int{3, 3}, DTypeF32, "stablehlo.power"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := NewEmitter() + mlir, out := e.EmitBinaryElementwise(tt.op, tt.lhs, tt.rhs, tt.shape, tt.dtype) + wantTy := FormatTensorType(tt.shape, tt.dtype) + wantLine := out + " = " + tt.wantOp + " " + tt.lhs + ", " + tt.rhs + " : " + wantTy + if mlir != wantLine { + t.Errorf("got:\n %s\nwant:\n %s", mlir, wantLine) + } + if out != "%v0" { + t.Errorf("output name = %q, want %%v0", out) + } + }) + } +} + +func TestEmitAdd(t *testing.T) { + e := NewEmitter() + mlir, out := e.EmitAdd("%arg0", "%arg1", []int{2, 3}, DTypeF32) + want := "%v0 = stablehlo.add %arg0, %arg1 : tensor<2x3xf32>" + if mlir != want { + t.Errorf("EmitAdd:\n got: %s\n want: %s", mlir, want) + } + if out != "%v0" { + t.Errorf("output = %q, want %%v0", out) + } +} + +func TestEmitUnaryOps(t *testing.T) { + tests := []struct { + name string + emit func(*Emitter, string, []int, string) (string, string) + wantOp string + }{ + {"Exp", (*Emitter).EmitExp, "stablehlo.exponential"}, + {"Log", (*Emitter).EmitLog, "stablehlo.log"}, + {"Sin", (*Emitter).EmitSin, "stablehlo.sine"}, + {"Cos", (*Emitter).EmitCos, "stablehlo.cosine"}, + {"Tanh", (*Emitter).EmitTanh, "stablehlo.tanh"}, + {"Sqrt", (*Emitter).EmitSqrt, "stablehlo.sqrt"}, + {"Rsqrt", (*Emitter).EmitRsqrt, "stablehlo.rsqrt"}, + {"Neg", (*Emitter).EmitNeg, "stablehlo.negate"}, + } + shape := []int{4, 8} + dtype := DTypeF32 + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := NewEmitter() + mlir, out := tt.emit(e, "%input", shape, dtype) + wantTy := FormatTensorType(shape, dtype) + want := out + " = " + tt.wantOp + " %input : " + wantTy + if mlir != want { + t.Errorf("got:\n %s\nwant:\n %s", mlir, want) + } + if out != "%v0" { + t.Errorf("output = %q, want %%v0", out) + } + }) + } +} + +func TestEmitScalarOps(t *testing.T) { + tests := []struct { + name string + emit func(*Emitter, string, float64, []int, string) (string, string) + wantOp string + }{ + {"MulScalar", (*Emitter).EmitMulScalar, "stablehlo.multiply"}, + {"AddScalar", (*Emitter).EmitAddScalar, "stablehlo.add"}, + {"DivScalar", (*Emitter).EmitDivScalar, "stablehlo.divide"}, + } + shape := []int{2, 3} + dtype := DTypeF32 + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := NewEmitter() + mlir, out := tt.emit(e, "%x", 2.5, shape, dtype) + + lines := strings.Split(mlir, "\n") + if len(lines) != 3 { + t.Fatalf("expected 3 lines, got %d:\n%s", len(lines), mlir) + } + + // Line 1: constant + if !strings.Contains(lines[0], OpConstant) { + t.Errorf("line 0 missing %s: %s", OpConstant, lines[0]) + } + if !strings.Contains(lines[0], "dense<2.5>") { + t.Errorf("line 0 missing dense<2.5>: %s", lines[0]) + } + if !strings.Contains(lines[0], "tensor") { + t.Errorf("line 0 missing scalar type: %s", lines[0]) + } + + // Line 2: broadcast_in_dim + if !strings.Contains(lines[1], OpBroadcastIn) { + t.Errorf("line 1 missing %s: %s", OpBroadcastIn, lines[1]) + } + if !strings.Contains(lines[1], "tensor<2x3xf32>") { + t.Errorf("line 1 missing output type: %s", lines[1]) + } + + // Line 3: element-wise op + if !strings.Contains(lines[2], tt.wantOp) { + t.Errorf("line 2 missing %s: %s", tt.wantOp, lines[2]) + } + + if out != "%v2" { + t.Errorf("output = %q, want %%v2 (const=%%v0, bcast=%%v1, op=%%v2)", out) + } + }) + } +} + +func TestEmitScalarOpFullOutput(t *testing.T) { + e := NewEmitter() + mlir, out := e.EmitMulScalar("%arg0", 3, []int{4}, DTypeF32) + want := "%v0 = stablehlo.constant dense<3> : tensor\n" + + "%v1 = stablehlo.broadcast_in_dim %v0, dims = [] : (tensor) -> tensor<4xf32>\n" + + "%v2 = stablehlo.multiply %arg0, %v1 : tensor<4xf32>" + if mlir != want { + t.Errorf("EmitMulScalar full output:\n got:\n%s\n want:\n%s", mlir, want) + } + if out != "%v2" { + t.Errorf("output = %q, want %%v2", out) + } +} + +func TestEmitOpDispatch(t *testing.T) { + shape := []int{2, 4} + dtype := DTypeF32 + + tests := []struct { + name string + opName string + inputs []string + attrs map[string]any + wantOp string + wantErr bool + }{ + {"Add", "Add", []string{"%a", "%b"}, nil, "stablehlo.add", false}, + {"Sub", "Sub", []string{"%a", "%b"}, nil, "stablehlo.subtract", false}, + {"Mul", "Mul", []string{"%a", "%b"}, nil, "stablehlo.multiply", false}, + {"Div", "Div", []string{"%a", "%b"}, nil, "stablehlo.divide", false}, + {"Pow", "Pow", []string{"%a", "%b"}, nil, "stablehlo.power", false}, + {"Exp", "Exp", []string{"%a"}, nil, "stablehlo.exponential", false}, + {"Log", "Log", []string{"%a"}, nil, "stablehlo.log", false}, + {"Sin", "Sin", []string{"%a"}, nil, "stablehlo.sine", false}, + {"Cos", "Cos", []string{"%a"}, nil, "stablehlo.cosine", false}, + {"Tanh", "Tanh", []string{"%a"}, nil, "stablehlo.tanh", false}, + {"Sqrt", "Sqrt", []string{"%a"}, nil, "stablehlo.sqrt", false}, + {"Rsqrt", "Rsqrt", []string{"%a"}, nil, "stablehlo.rsqrt", false}, + {"Neg", "Neg", []string{"%a"}, nil, "stablehlo.negate", false}, + {"MulScalar", "MulScalar", []string{"%a"}, map[string]any{"scalar": 2.0}, "stablehlo.multiply", false}, + {"AddScalar", "AddScalar", []string{"%a"}, map[string]any{"scalar": 1.0}, "stablehlo.add", false}, + {"DivScalar", "DivScalar", []string{"%a"}, map[string]any{"scalar": 4.0}, "stablehlo.divide", false}, + {"unsupported", "Softmax", []string{"%a"}, nil, "", true}, + {"wrong inputs binary", "Add", []string{"%a"}, nil, "", true}, + {"wrong inputs unary", "Exp", []string{"%a", "%b"}, nil, "", true}, + {"missing scalar attr", "MulScalar", []string{"%a"}, nil, "", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := NewEmitter() + mlir, _, err := e.EmitOp(tt.opName, tt.inputs, shape, dtype, tt.attrs) + if tt.wantErr { + if err == nil { + t.Errorf("expected error, got mlir: %s", mlir) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(mlir, tt.wantOp) { + t.Errorf("EmitOp(%s) output missing %s:\n%s", tt.opName, tt.wantOp, mlir) + } + }) + } +} + +func TestEmitSSACounterProgresses(t *testing.T) { + e := NewEmitter() + _, out1 := e.EmitAdd("%a", "%b", []int{2}, DTypeF32) + _, out2 := e.EmitExp("%c", []int{2}, DTypeF32) + _, out3 := e.EmitSub("%d", "%e", []int{2}, DTypeF32) + + if out1 != "%v0" || out2 != "%v1" || out3 != "%v2" { + t.Errorf("SSA names = [%s, %s, %s], want [%%v0, %%v1, %%v2]", out1, out2, out3) + } + + if e.Namer.Count() != 3 { + t.Errorf("namer count = %d, want 3", e.Namer.Count()) + } +} + +func TestEmitScalarOpSSACounterProgresses(t *testing.T) { + e := NewEmitter() + // MulScalar uses 3 SSA names (const, broadcast, op). + _, out1 := e.EmitMulScalar("%x", 2.0, []int{4}, DTypeF32) + // Next op should get %v3. + _, out2 := e.EmitAdd("%a", "%b", []int{4}, DTypeF32) + + if out1 != "%v2" { + t.Errorf("MulScalar output = %q, want %%v2", out1) + } + if out2 != "%v3" { + t.Errorf("Add output after scalar = %q, want %%v3", out2) + } +} + +func TestEmitDifferentDtypes(t *testing.T) { + dtypes := []string{DTypeF32, DTypeF64, DTypeF16, DTypeBF16} + for _, dtype := range dtypes { + t.Run(dtype, func(t *testing.T) { + e := NewEmitter() + mlir, _ := e.EmitAdd("%a", "%b", []int{2, 3}, dtype) + wantTy := FormatTensorType([]int{2, 3}, dtype) + if !strings.HasSuffix(mlir, wantTy) { + t.Errorf("EmitAdd with dtype %s: %s does not end with %s", dtype, mlir, wantTy) + } + }) + } +} From 6bb003c6d244b088831203bfb253c298b8b6d6a2 Mon Sep 17 00:00:00 2001 From: David Ndungu Date: Thu, 2 Apr 2026 15:15:08 -0700 Subject: [PATCH 3/4] feat(pjrt): add buffer management (host-device transfer, readback, lifecycle) T60.2.1: BufferFromHost wraps PJRT_Client_BufferFromHostBuffer with Go type to PJRT element type mapping (F32, F64, F16, BF16, F8E4M3, S32, S64, etc.), shape validation, and buffer donation support. T60.2.2: ToHost/ToHostSlice wraps PJRT_Buffer_ToHostBuffer with async readback and PJRT_Event_Await synchronization. T60.2.3: Buffer metadata (Dtype, Shape, OnDeviceSizeInBytes, ReadyEvent) and lifecycle (Close, Delete) with double-close no-op for finalizer safety. --- internal/pjrt/buffer.go | 611 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 611 insertions(+) create mode 100644 internal/pjrt/buffer.go diff --git a/internal/pjrt/buffer.go b/internal/pjrt/buffer.go new file mode 100644 index 0000000..509581c --- /dev/null +++ b/internal/pjrt/buffer.go @@ -0,0 +1,611 @@ +package pjrt + +import ( + "fmt" + "sync" + "unsafe" + + "github.com/zerfoo/float16" + "github.com/zerfoo/float8" + "github.com/zerfoo/ztensor/internal/cuda" +) + +// ElementType mirrors the PJRT_Buffer_Type enum from the PJRT C API. +type ElementType int32 + +const ( + ElementTypeInvalid ElementType = 0 + ElementTypePRED ElementType = 1 // bool + ElementTypeS8 ElementType = 2 // int8 + ElementTypeS16 ElementType = 3 // int16 + ElementTypeS32 ElementType = 4 // int32 + ElementTypeS64 ElementType = 5 // int64 + ElementTypeU8 ElementType = 6 // uint8 + ElementTypeU16 ElementType = 7 // uint16 + ElementTypeU32 ElementType = 8 // uint32 + ElementTypeU64 ElementType = 9 // uint64 + ElementTypeF16 ElementType = 10 // float16 + ElementTypeF32 ElementType = 11 // float32 + ElementTypeF64 ElementType = 12 // float64 + ElementTypeBF16 ElementType = 16 // bfloat16 + ElementTypeF8E4M3 ElementType = 20 // float8 E4M3FN +) + +// String returns the PJRT element type name. +func (t ElementType) String() string { + switch t { + case ElementTypePRED: + return "pred" + case ElementTypeS8: + return "s8" + case ElementTypeS16: + return "s16" + case ElementTypeS32: + return "s32" + case ElementTypeS64: + return "s64" + case ElementTypeU8: + return "u8" + case ElementTypeU16: + return "u16" + case ElementTypeU32: + return "u32" + case ElementTypeU64: + return "u64" + case ElementTypeF16: + return "f16" + case ElementTypeF32: + return "f32" + case ElementTypeF64: + return "f64" + case ElementTypeBF16: + return "bf16" + case ElementTypeF8E4M3: + return "f8e4m3fn" + default: + return fmt.Sprintf("unknown(%d)", int(t)) + } +} + +// ByteSize returns the size in bytes of a single element of this type. +func (t ElementType) ByteSize() int { + switch t { + case ElementTypePRED, ElementTypeS8, ElementTypeU8, ElementTypeF8E4M3: + return 1 + case ElementTypeS16, ElementTypeU16, ElementTypeF16, ElementTypeBF16: + return 2 + case ElementTypeS32, ElementTypeU32, ElementTypeF32: + return 4 + case ElementTypeS64, ElementTypeU64, ElementTypeF64: + return 8 + default: + return 0 + } +} + +// GoTypeToElementType maps a Go type (via its size and kind) to the +// corresponding PJRT element type. +func GoTypeToElementType[T any]() ElementType { + var zero T + switch any(zero).(type) { + case float32: + return ElementTypeF32 + case float64: + return ElementTypeF64 + case float16.Float16: + return ElementTypeF16 + case float16.BFloat16: + return ElementTypeBF16 + case float8.Float8: + return ElementTypeF8E4M3 + case int32: + return ElementTypeS32 + case int64: + return ElementTypeS64 + case int16: + return ElementTypeS16 + case int8: + return ElementTypeS8 + case uint8: + return ElementTypeU8 + case uint16: + return ElementTypeU16 + case uint32: + return ElementTypeU32 + case uint64: + return ElementTypeU64 + case bool: + return ElementTypePRED + default: + return ElementTypeInvalid + } +} + +// HostBufferSemantics controls how PJRT handles the host data pointer +// during BufferFromHostBuffer. +type HostBufferSemantics int32 + +const ( + // HostBufferImmutableOnlyDuringCall means PJRT copies the data during + // the call and the host buffer can be modified immediately after return. + HostBufferImmutableOnlyDuringCall HostBufferSemantics = 0 + + // HostBufferImmutableUntilTransferCompletes means the host buffer must + // remain valid until the returned event completes. Avoids a copy on + // some backends. + HostBufferImmutableUntilTransferCompletes HostBufferSemantics = 1 + + // HostBufferImmutableZeroCopy means PJRT uses the host memory directly + // (zero-copy). The host buffer must remain valid for the buffer lifetime. + HostBufferImmutableZeroCopy HostBufferSemantics = 2 +) + +// Buffer wraps a PJRT_Buffer handle and provides Go-friendly methods +// for device-to-host readback, metadata queries, and lifecycle management. +// +// Buffers must be closed with Close() when no longer needed. Double-close +// is a safe no-op (finalizer safety). +type Buffer struct { + lib *PJRTLib + client uintptr // PJRT_Client* (for readback calls) + handle uintptr // PJRT_Buffer* + + mu sync.Mutex + closed bool +} + +// BufferFromHost transfers a Go slice to a PJRT device buffer. +// +// The data slice is copied during the call (ImmutableOnlyDuringCall semantics +// by default). The shape describes the tensor dimensions. The target device +// determines where the buffer is placed. +// +// Use WithDonation() to enable buffer donation for KV cache optimization. +func BufferFromHost[T any](client *Client, data []T, shape []int, device *Device, opts ...BufferOption) (*Buffer, error) { + if client == nil || client.handle == 0 { + return nil, fmt.Errorf("pjrt: cannot create buffer from nil or closed client") + } + if device == nil || device.handle == 0 { + return nil, fmt.Errorf("pjrt: cannot create buffer on nil or closed device") + } + if len(data) == 0 { + return nil, fmt.Errorf("pjrt: cannot create buffer from empty data") + } + + elemType := GoTypeToElementType[T]() + if elemType == ElementTypeInvalid { + return nil, fmt.Errorf("pjrt: unsupported Go type for PJRT buffer") + } + + // Verify element count matches shape. + numElements := 1 + for _, d := range shape { + numElements *= d + } + if numElements != len(data) { + return nil, fmt.Errorf("pjrt: shape %v requires %d elements, got %d", shape, numElements, len(data)) + } + + cfg := bufferConfig{ + semantics: HostBufferImmutableOnlyDuringCall, + } + for _, o := range opts { + o(&cfg) + } + + lib := client.lib + + // Build the int64 dimensions array that PJRT expects. + dims := make([]int64, len(shape)) + for i, d := range shape { + dims[i] = int64(d) + } + + var dimsPtr uintptr + if len(dims) > 0 { + dimsPtr = uintptr(unsafe.Pointer(&dims[0])) + } + + // PJRT_Client_BufferFromHostBuffer_Args: + // struct_size uintptr + // client uintptr (PJRT_Client*) + // data uintptr (const void*) + // type int32 (PJRT_Buffer_Type) + // _ [4]byte (padding) + // dims uintptr (const int64_t*) + // num_dims uintptr (size_t) + // byte_strides uintptr (const int64_t*, may be 0) + // num_byte_strides uintptr (size_t) + // host_buffer_semantics int32 (PJRT_HostBufferSemantics) + // _ [4]byte (padding) + // device uintptr (PJRT_Device*) + // memory uintptr (PJRT_Memory*, may be 0) + // device_layout uintptr (PJRT_Buffer_MemoryLayout*, may be 0) + // done_with_host_buffer uintptr (out: PJRT_Event*) + // buffer uintptr (out: PJRT_Buffer*) + type bufferFromHostArgs struct { + structSize uintptr + client uintptr + data uintptr + typ int32 + _ [4]byte + dims uintptr + numDims uintptr + byteStrides uintptr + numByteStrides uintptr + hostBufferSemantics int32 + _ [4]byte + device uintptr + memory uintptr + deviceLayout uintptr + doneWithHostBuffer uintptr + buffer uintptr + } + + args := bufferFromHostArgs{ + structSize: unsafe.Sizeof(bufferFromHostArgs{}), + client: client.handle, + data: uintptr(unsafe.Pointer(&data[0])), + typ: int32(elemType), + dims: dimsPtr, + numDims: uintptr(len(dims)), + hostBufferSemantics: int32(cfg.semantics), + device: device.handle, + } + + errPtr := cuda.Ccall(lib.PJRT_Client_BufferFromHostBuffer, uintptr(unsafe.Pointer(&args))) + if err := lib.checkError(errPtr); err != nil { + return nil, fmt.Errorf("PJRT_Client_BufferFromHostBuffer: %w", err) + } + if args.buffer == 0 { + return nil, fmt.Errorf("pjrt: PJRT_Client_BufferFromHostBuffer returned null buffer") + } + + // If the transfer produces a done event, wait for it so the host + // buffer is safe to reuse immediately. + if args.doneWithHostBuffer != 0 { + if err := lib.awaitEvent(args.doneWithHostBuffer); err != nil { + return nil, fmt.Errorf("pjrt: await host buffer transfer: %w", err) + } + lib.destroyEvent(args.doneWithHostBuffer) + } + + return &Buffer{ + lib: lib, + client: client.handle, + handle: args.buffer, + }, nil +} + +// ToHost copies device buffer data back to a pre-allocated Go slice. +// +// The destination slice must have exactly the right number of elements +// (product of Shape dimensions). The call blocks until the readback +// completes (PJRT_Event_Await). +func (b *Buffer) ToHost(dst []byte) error { + b.mu.Lock() + if b.closed { + b.mu.Unlock() + return fmt.Errorf("pjrt: buffer is closed") + } + b.mu.Unlock() + + if len(dst) == 0 { + return fmt.Errorf("pjrt: destination slice is empty") + } + + // PJRT_Buffer_ToHostBuffer_Args: + // struct_size uintptr + // src uintptr (PJRT_Buffer*) + // dst uintptr (void*) + // dst_size uintptr (size_t, bytes) + // event uintptr (out: PJRT_Event*) + type toHostArgs struct { + structSize uintptr + src uintptr + dst uintptr + dstSize uintptr + event uintptr + } + + args := toHostArgs{ + structSize: unsafe.Sizeof(toHostArgs{}), + src: b.handle, + dst: uintptr(unsafe.Pointer(&dst[0])), + dstSize: uintptr(len(dst)), + } + + errPtr := cuda.Ccall(b.lib.PJRT_Buffer_ToHostBuffer, uintptr(unsafe.Pointer(&args))) + if err := b.lib.checkError(errPtr); err != nil { + return fmt.Errorf("PJRT_Buffer_ToHostBuffer: %w", err) + } + + // Wait for the async readback to complete. + if args.event != 0 { + if err := b.lib.awaitEvent(args.event); err != nil { + return fmt.Errorf("pjrt: await readback: %w", err) + } + b.lib.destroyEvent(args.event) + } + + return nil +} + +// ToHostSlice is a typed convenience wrapper around ToHost that copies +// device buffer data into a pre-allocated Go slice of the appropriate type. +func ToHostSlice[T any](b *Buffer, dst []T) error { + var zero T + elemSize := int(unsafe.Sizeof(zero)) + byteLen := len(dst) * elemSize + bytes := unsafe.Slice((*byte)(unsafe.Pointer(&dst[0])), byteLen) + return b.ToHost(bytes) +} + +// Dtype returns the PJRT element type of this buffer. +func (b *Buffer) Dtype() (ElementType, error) { + b.mu.Lock() + if b.closed { + b.mu.Unlock() + return ElementTypeInvalid, fmt.Errorf("pjrt: buffer is closed") + } + b.mu.Unlock() + + // PJRT_Buffer_ElementType_Args: + // struct_size uintptr + // buffer uintptr (PJRT_Buffer*) + // type int32 (out: PJRT_Buffer_Type) + type elementTypeArgs struct { + structSize uintptr + buffer uintptr + typ int32 + _ [4]byte + } + + args := elementTypeArgs{ + structSize: unsafe.Sizeof(elementTypeArgs{}), + buffer: b.handle, + } + + errPtr := cuda.Ccall(b.lib.PJRT_Buffer_ElementType, uintptr(unsafe.Pointer(&args))) + if err := b.lib.checkError(errPtr); err != nil { + return ElementTypeInvalid, fmt.Errorf("PJRT_Buffer_ElementType: %w", err) + } + return ElementType(args.typ), nil +} + +// Shape returns the dimensions of this buffer. +func (b *Buffer) Shape() ([]int, error) { + b.mu.Lock() + if b.closed { + b.mu.Unlock() + return nil, fmt.Errorf("pjrt: buffer is closed") + } + b.mu.Unlock() + + // PJRT_Buffer_Dimensions_Args: + // struct_size uintptr + // buffer uintptr (PJRT_Buffer*) + // dims uintptr (out: const int64_t*) + // num_dims uintptr (out: size_t) + type dimensionsArgs struct { + structSize uintptr + buffer uintptr + dims uintptr + numDims uintptr + } + + args := dimensionsArgs{ + structSize: unsafe.Sizeof(dimensionsArgs{}), + buffer: b.handle, + } + + errPtr := cuda.Ccall(b.lib.PJRT_Buffer_Dimensions, uintptr(unsafe.Pointer(&args))) + if err := b.lib.checkError(errPtr); err != nil { + return nil, fmt.Errorf("PJRT_Buffer_Dimensions: %w", err) + } + + if args.numDims == 0 { + return nil, nil // scalar + } + + cDims := unsafe.Slice((*int64)(unsafe.Pointer(args.dims)), int(args.numDims)) + shape := make([]int, len(cDims)) + for i, d := range cDims { + shape[i] = int(d) + } + return shape, nil +} + +// OnDeviceSizeInBytes returns the buffer's memory footprint on the device. +func (b *Buffer) OnDeviceSizeInBytes() (int64, error) { + b.mu.Lock() + if b.closed { + b.mu.Unlock() + return 0, fmt.Errorf("pjrt: buffer is closed") + } + b.mu.Unlock() + + // PJRT_Buffer_OnDeviceSizeInBytes_Args: + // struct_size uintptr + // buffer uintptr (PJRT_Buffer*) + // on_device_size int64 (out: size_t) + type sizeArgs struct { + structSize uintptr + buffer uintptr + onDeviceSize int64 + } + + args := sizeArgs{ + structSize: unsafe.Sizeof(sizeArgs{}), + buffer: b.handle, + } + + errPtr := cuda.Ccall(b.lib.PJRT_Buffer_OnDeviceSizeInBytes, uintptr(unsafe.Pointer(&args))) + if err := b.lib.checkError(errPtr); err != nil { + return 0, fmt.Errorf("PJRT_Buffer_OnDeviceSizeInBytes: %w", err) + } + return args.onDeviceSize, nil +} + +// ReadyEvent returns the PJRT_Event handle for this buffer's readiness. +// The caller is responsible for destroying the event via awaitEvent or +// destroyEvent. +func (b *Buffer) ReadyEvent() (uintptr, error) { + b.mu.Lock() + if b.closed { + b.mu.Unlock() + return 0, fmt.Errorf("pjrt: buffer is closed") + } + b.mu.Unlock() + + // PJRT_Buffer_ReadyEvent_Args: + // struct_size uintptr + // buffer uintptr (PJRT_Buffer*) + // event uintptr (out: PJRT_Event*) + type readyEventArgs struct { + structSize uintptr + buffer uintptr + event uintptr + } + + args := readyEventArgs{ + structSize: unsafe.Sizeof(readyEventArgs{}), + buffer: b.handle, + } + + errPtr := cuda.Ccall(b.lib.PJRT_Buffer_ReadyEvent, uintptr(unsafe.Pointer(&args))) + if err := b.lib.checkError(errPtr); err != nil { + return 0, fmt.Errorf("PJRT_Buffer_ReadyEvent: %w", err) + } + return args.event, nil +} + +// Delete marks the buffer for deletion. The runtime may release the +// device memory immediately or defer it. After Delete, the buffer +// handle should not be used for data access, but Destroy is still +// required for handle cleanup. +func (b *Buffer) Delete() error { + b.mu.Lock() + if b.closed { + b.mu.Unlock() + return nil + } + b.mu.Unlock() + + // PJRT_Buffer_Delete_Args: + // struct_size uintptr + // buffer uintptr (PJRT_Buffer*) + type deleteArgs struct { + structSize uintptr + buffer uintptr + } + + args := deleteArgs{ + structSize: unsafe.Sizeof(deleteArgs{}), + buffer: b.handle, + } + + errPtr := cuda.Ccall(b.lib.PJRT_Buffer_Delete, uintptr(unsafe.Pointer(&args))) + return b.lib.checkError(errPtr) +} + +// Close destroys the PJRT buffer handle and releases associated resources. +// Safe to call multiple times (double-close is a no-op for finalizer safety). +func (b *Buffer) Close() error { + b.mu.Lock() + defer b.mu.Unlock() + + if b.closed { + return nil + } + b.closed = true + + // PJRT_Buffer_Destroy_Args: + // struct_size uintptr + // buffer uintptr (PJRT_Buffer*) + type destroyArgs struct { + structSize uintptr + buffer uintptr + } + + args := destroyArgs{ + structSize: unsafe.Sizeof(destroyArgs{}), + buffer: b.handle, + } + + errPtr := cuda.Ccall(b.lib.PJRT_Buffer_Destroy, uintptr(unsafe.Pointer(&args))) + b.handle = 0 + return b.lib.checkError(errPtr) +} + +// Handle returns the raw PJRT_Buffer pointer. +func (b *Buffer) Handle() uintptr { + return b.handle +} + +// awaitEvent calls PJRT_Event_Await to block until the event completes. +func (lib *PJRTLib) awaitEvent(event uintptr) error { + if event == 0 { + return nil + } + + // PJRT_Event_Await_Args: + // struct_size uintptr + // event uintptr (PJRT_Event*) + type awaitArgs struct { + structSize uintptr + event uintptr + } + + args := awaitArgs{ + structSize: unsafe.Sizeof(awaitArgs{}), + event: event, + } + + errPtr := cuda.Ccall(lib.PJRT_Event_Await, uintptr(unsafe.Pointer(&args))) + return lib.checkError(errPtr) +} + +// destroyEvent frees a PJRT_Event. Safe to call with event == 0. +func (lib *PJRTLib) destroyEvent(event uintptr) { + if event == 0 { + return + } + + // PJRT_Event_Destroy_Args: + // struct_size uintptr + // event uintptr (PJRT_Event*) + type destroyArgs struct { + structSize uintptr + event uintptr + } + + args := destroyArgs{ + structSize: unsafe.Sizeof(destroyArgs{}), + event: event, + } + cuda.Ccall(lib.PJRT_Event_Destroy, uintptr(unsafe.Pointer(&args))) +} + +// BufferOption configures BufferFromHost behavior. +type BufferOption func(*bufferConfig) + +type bufferConfig struct { + semantics HostBufferSemantics +} + +// WithSemantics sets the host buffer semantics for the transfer. +func WithSemantics(s HostBufferSemantics) BufferOption { + return func(c *bufferConfig) { + c.semantics = s + } +} + +// WithDonation enables buffer donation semantics. The runtime is allowed +// to take ownership of the host memory, avoiding a copy. The caller must +// not access the source slice after calling BufferFromHost with this option. +func WithDonation() BufferOption { + return func(c *bufferConfig) { + c.semantics = HostBufferImmutableZeroCopy + } +} From 30953e4f975a32aaace1e27f92304731c88015fe Mon Sep 17 00:00:00 2001 From: David Ndungu Date: Thu, 2 Apr 2026 15:20:40 -0700 Subject: [PATCH 4/4] fix(pjrt): centralize internal/cuda import in pjrt.go Add ccall, dlopenPath, and dlsym forwarding functions in pjrt.go so that buffer.go, client.go, and device.go do not import internal/cuda directly. Only pjrt.go holds the cuda dependency. --- internal/pjrt/buffer.go | 53 ++++++++++++++++++++--------------------- internal/pjrt/client.go | 14 +++++------ internal/pjrt/device.go | 12 ++++------ internal/pjrt/pjrt.go | 27 +++++++++++++++++---- 4 files changed, 59 insertions(+), 47 deletions(-) diff --git a/internal/pjrt/buffer.go b/internal/pjrt/buffer.go index 509581c..e600255 100644 --- a/internal/pjrt/buffer.go +++ b/internal/pjrt/buffer.go @@ -7,7 +7,6 @@ import ( "github.com/zerfoo/float16" "github.com/zerfoo/float8" - "github.com/zerfoo/ztensor/internal/cuda" ) // ElementType mirrors the PJRT_Buffer_Type enum from the PJRT C API. @@ -224,22 +223,22 @@ func BufferFromHost[T any](client *Client, data []T, shape []int, device *Device // done_with_host_buffer uintptr (out: PJRT_Event*) // buffer uintptr (out: PJRT_Buffer*) type bufferFromHostArgs struct { - structSize uintptr - client uintptr - data uintptr - typ int32 - _ [4]byte - dims uintptr - numDims uintptr - byteStrides uintptr - numByteStrides uintptr - hostBufferSemantics int32 - _ [4]byte - device uintptr - memory uintptr - deviceLayout uintptr - doneWithHostBuffer uintptr - buffer uintptr + structSize uintptr + client uintptr + data uintptr + typ int32 + _ [4]byte + dims uintptr + numDims uintptr + byteStrides uintptr + numByteStrides uintptr + hostBufferSemantics int32 + _ [4]byte + device uintptr + memory uintptr + deviceLayout uintptr + doneWithHostBuffer uintptr + buffer uintptr } args := bufferFromHostArgs{ @@ -253,7 +252,7 @@ func BufferFromHost[T any](client *Client, data []T, shape []int, device *Device device: device.handle, } - errPtr := cuda.Ccall(lib.PJRT_Client_BufferFromHostBuffer, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(lib.PJRT_Client_BufferFromHostBuffer, uintptr(unsafe.Pointer(&args))) if err := lib.checkError(errPtr); err != nil { return nil, fmt.Errorf("PJRT_Client_BufferFromHostBuffer: %w", err) } @@ -315,7 +314,7 @@ func (b *Buffer) ToHost(dst []byte) error { dstSize: uintptr(len(dst)), } - errPtr := cuda.Ccall(b.lib.PJRT_Buffer_ToHostBuffer, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(b.lib.PJRT_Buffer_ToHostBuffer, uintptr(unsafe.Pointer(&args))) if err := b.lib.checkError(errPtr); err != nil { return fmt.Errorf("PJRT_Buffer_ToHostBuffer: %w", err) } @@ -366,7 +365,7 @@ func (b *Buffer) Dtype() (ElementType, error) { buffer: b.handle, } - errPtr := cuda.Ccall(b.lib.PJRT_Buffer_ElementType, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(b.lib.PJRT_Buffer_ElementType, uintptr(unsafe.Pointer(&args))) if err := b.lib.checkError(errPtr); err != nil { return ElementTypeInvalid, fmt.Errorf("PJRT_Buffer_ElementType: %w", err) } @@ -399,7 +398,7 @@ func (b *Buffer) Shape() ([]int, error) { buffer: b.handle, } - errPtr := cuda.Ccall(b.lib.PJRT_Buffer_Dimensions, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(b.lib.PJRT_Buffer_Dimensions, uintptr(unsafe.Pointer(&args))) if err := b.lib.checkError(errPtr); err != nil { return nil, fmt.Errorf("PJRT_Buffer_Dimensions: %w", err) } @@ -440,7 +439,7 @@ func (b *Buffer) OnDeviceSizeInBytes() (int64, error) { buffer: b.handle, } - errPtr := cuda.Ccall(b.lib.PJRT_Buffer_OnDeviceSizeInBytes, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(b.lib.PJRT_Buffer_OnDeviceSizeInBytes, uintptr(unsafe.Pointer(&args))) if err := b.lib.checkError(errPtr); err != nil { return 0, fmt.Errorf("PJRT_Buffer_OnDeviceSizeInBytes: %w", err) } @@ -473,7 +472,7 @@ func (b *Buffer) ReadyEvent() (uintptr, error) { buffer: b.handle, } - errPtr := cuda.Ccall(b.lib.PJRT_Buffer_ReadyEvent, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(b.lib.PJRT_Buffer_ReadyEvent, uintptr(unsafe.Pointer(&args))) if err := b.lib.checkError(errPtr); err != nil { return 0, fmt.Errorf("PJRT_Buffer_ReadyEvent: %w", err) } @@ -505,7 +504,7 @@ func (b *Buffer) Delete() error { buffer: b.handle, } - errPtr := cuda.Ccall(b.lib.PJRT_Buffer_Delete, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(b.lib.PJRT_Buffer_Delete, uintptr(unsafe.Pointer(&args))) return b.lib.checkError(errPtr) } @@ -533,7 +532,7 @@ func (b *Buffer) Close() error { buffer: b.handle, } - errPtr := cuda.Ccall(b.lib.PJRT_Buffer_Destroy, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(b.lib.PJRT_Buffer_Destroy, uintptr(unsafe.Pointer(&args))) b.handle = 0 return b.lib.checkError(errPtr) } @@ -562,7 +561,7 @@ func (lib *PJRTLib) awaitEvent(event uintptr) error { event: event, } - errPtr := cuda.Ccall(lib.PJRT_Event_Await, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(lib.PJRT_Event_Await, uintptr(unsafe.Pointer(&args))) return lib.checkError(errPtr) } @@ -584,7 +583,7 @@ func (lib *PJRTLib) destroyEvent(event uintptr) { structSize: unsafe.Sizeof(destroyArgs{}), event: event, } - cuda.Ccall(lib.PJRT_Event_Destroy, uintptr(unsafe.Pointer(&args))) + ccall(lib.PJRT_Event_Destroy, uintptr(unsafe.Pointer(&args))) } // BufferOption configures BufferFromHost behavior. diff --git a/internal/pjrt/client.go b/internal/pjrt/client.go index 6ed2169..1353190 100644 --- a/internal/pjrt/client.go +++ b/internal/pjrt/client.go @@ -3,8 +3,6 @@ package pjrt import ( "fmt" "unsafe" - - "github.com/zerfoo/ztensor/internal/cuda" ) // Client wraps a PJRT_Client handle and provides Go-friendly methods @@ -57,7 +55,7 @@ func NewClient(lib *PJRTLib, opts ...ClientOption) (*Client, error) { createOptions: cfg.createOptions, } - errPtr := cuda.Ccall(lib.PJRT_Client_Create, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(lib.PJRT_Client_Create, uintptr(unsafe.Pointer(&args))) if err := lib.checkError(errPtr); err != nil { return nil, fmt.Errorf("PJRT_Client_Create: %w", err) } @@ -83,7 +81,7 @@ func (c *Client) Close() error { structSize: unsafe.Sizeof(destroyArgs{}), client: c.handle, } - errPtr := cuda.Ccall(c.lib.PJRT_Client_Destroy, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(c.lib.PJRT_Client_Destroy, uintptr(unsafe.Pointer(&args))) c.handle = 0 return c.lib.checkError(errPtr) } @@ -105,7 +103,7 @@ func (c *Client) PlatformName() (string, error) { structSize: unsafe.Sizeof(platformNameArgs{}), client: c.handle, } - errPtr := cuda.Ccall(c.lib.PJRT_Client_PlatformName, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(c.lib.PJRT_Client_PlatformName, uintptr(unsafe.Pointer(&args))) if err := c.lib.checkError(errPtr); err != nil { return "", fmt.Errorf("PJRT_Client_PlatformName: %w", err) } @@ -124,7 +122,7 @@ func (c *Client) PlatformVersion() (string, error) { structSize: unsafe.Sizeof(platformVersionArgs{}), client: c.handle, } - errPtr := cuda.Ccall(c.lib.PJRT_Client_PlatformVersion, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(c.lib.PJRT_Client_PlatformVersion, uintptr(unsafe.Pointer(&args))) if err := c.lib.checkError(errPtr); err != nil { return "", fmt.Errorf("PJRT_Client_PlatformVersion: %w", err) } @@ -148,7 +146,7 @@ func (c *Client) Devices() ([]*Device, error) { structSize: unsafe.Sizeof(devicesArgs{}), client: c.handle, } - errPtr := cuda.Ccall(c.lib.PJRT_Client_Devices, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(c.lib.PJRT_Client_Devices, uintptr(unsafe.Pointer(&args))) if err := c.lib.checkError(errPtr); err != nil { return nil, fmt.Errorf("PJRT_Client_Devices: %w", err) } @@ -167,7 +165,7 @@ func (c *Client) AddressableDevices() ([]*Device, error) { structSize: unsafe.Sizeof(addressableDevicesArgs{}), client: c.handle, } - errPtr := cuda.Ccall(c.lib.PJRT_Client_AddressableDevices, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(c.lib.PJRT_Client_AddressableDevices, uintptr(unsafe.Pointer(&args))) if err := c.lib.checkError(errPtr); err != nil { return nil, fmt.Errorf("PJRT_Client_AddressableDevices: %w", err) } diff --git a/internal/pjrt/device.go b/internal/pjrt/device.go index eeeb94a..ff573ba 100644 --- a/internal/pjrt/device.go +++ b/internal/pjrt/device.go @@ -3,8 +3,6 @@ package pjrt import ( "fmt" "unsafe" - - "github.com/zerfoo/ztensor/internal/cuda" ) // Device wraps a PJRT_Device handle and provides methods for @@ -37,7 +35,7 @@ func (d *Device) ID() (int, error) { structSize: unsafe.Sizeof(idArgs{}), deviceDescription: desc, } - errPtr := cuda.Ccall(d.lib.PJRT_DeviceDescription_Id, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(d.lib.PJRT_DeviceDescription_Id, uintptr(unsafe.Pointer(&args))) if err := d.lib.checkError(errPtr); err != nil { return 0, fmt.Errorf("PJRT_DeviceDescription_Id: %w", err) } @@ -66,7 +64,7 @@ func (d *Device) Kind() (string, error) { structSize: unsafe.Sizeof(kindArgs{}), deviceDescription: desc, } - errPtr := cuda.Ccall(d.lib.PJRT_DeviceDescription_Kind, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(d.lib.PJRT_DeviceDescription_Kind, uintptr(unsafe.Pointer(&args))) if err := d.lib.checkError(errPtr); err != nil { return "", fmt.Errorf("PJRT_DeviceDescription_Kind: %w", err) } @@ -90,7 +88,7 @@ func (d *Device) IsAddressable() (bool, error) { structSize: unsafe.Sizeof(isAddressableArgs{}), device: d.handle, } - errPtr := cuda.Ccall(d.lib.PJRT_Device_IsAddressable, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(d.lib.PJRT_Device_IsAddressable, uintptr(unsafe.Pointer(&args))) if err := d.lib.checkError(errPtr); err != nil { return false, fmt.Errorf("PJRT_Device_IsAddressable: %w", err) } @@ -114,7 +112,7 @@ func (d *Device) LocalHardwareId() (int, error) { structSize: unsafe.Sizeof(localHWIDArgs{}), device: d.handle, } - errPtr := cuda.Ccall(d.lib.PJRT_Device_LocalHardwareId, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(d.lib.PJRT_Device_LocalHardwareId, uintptr(unsafe.Pointer(&args))) if err := d.lib.checkError(errPtr); err != nil { return 0, fmt.Errorf("PJRT_Device_LocalHardwareId: %w", err) } @@ -142,7 +140,7 @@ func (d *Device) getDescription() (uintptr, error) { structSize: unsafe.Sizeof(getDescArgs{}), device: d.handle, } - errPtr := cuda.Ccall(d.lib.PJRT_Device_GetDescription, uintptr(unsafe.Pointer(&args))) + errPtr := ccall(d.lib.PJRT_Device_GetDescription, uintptr(unsafe.Pointer(&args))) if err := d.lib.checkError(errPtr); err != nil { return 0, fmt.Errorf("PJRT_Device_GetDescription: %w", err) } diff --git a/internal/pjrt/pjrt.go b/internal/pjrt/pjrt.go index 109917d..13247da 100644 --- a/internal/pjrt/pjrt.go +++ b/internal/pjrt/pjrt.go @@ -132,7 +132,7 @@ func Load(pluginName string) (*PJRTLib, error) { lib := &PJRTLib{} var lastErr string for _, path := range candidates { - h, err := cuda.DlopenPath(path) + h, err := dlopenPath(path) if err == nil { lib.handle = h break @@ -144,14 +144,14 @@ func Load(pluginName string) (*PJRTLib, error) { } // Resolve the single entry point. - getPjrtApi, err := cuda.Dlsym(lib.handle, "GetPjrtApi") + getPjrtApi, err := dlsym(lib.handle, "GetPjrtApi") if err != nil { lib.Close() return nil, fmt.Errorf("pjrt: dlsym GetPjrtApi: %w", err) } // Call GetPjrtApi() -> *PJRT_Api. - apiPtr := cuda.Ccall(getPjrtApi) + apiPtr := ccall(getPjrtApi) if apiPtr == 0 { lib.Close() return nil, fmt.Errorf("pjrt: GetPjrtApi returned null") @@ -240,7 +240,7 @@ func (lib *PJRTLib) errorMessage(errPtr uintptr) string { structSize: unsafe.Sizeof(errorMessageArgs{}), error: errPtr, } - cuda.Ccall(lib.PJRT_Error_Message, uintptr(unsafe.Pointer(&args))) + ccall(lib.PJRT_Error_Message, uintptr(unsafe.Pointer(&args))) if args.message == 0 || args.messageLen == 0 { return "unknown PJRT error" @@ -261,7 +261,7 @@ func (lib *PJRTLib) destroyError(errPtr uintptr) { structSize: unsafe.Sizeof(destroyArgs{}), error: errPtr, } - cuda.Ccall(lib.PJRT_Error_Destroy, uintptr(unsafe.Pointer(&args))) + ccall(lib.PJRT_Error_Destroy, uintptr(unsafe.Pointer(&args))) } // checkError converts a PJRT_Error pointer to a Go error. @@ -275,6 +275,23 @@ func (lib *PJRTLib) checkError(errPtr uintptr) error { return fmt.Errorf("pjrt: %s", msg) } +// ccall calls a C function pointer with the given arguments. +// Centralizes the internal/cuda dependency so other files in this +// package do not need to import it directly. +func ccall(fn uintptr, args ...uintptr) uintptr { + return cuda.Ccall(fn, args...) +} + +// dlopenPath opens a shared library at the given path. +func dlopenPath(path string) (uintptr, error) { + return cuda.DlopenPath(path) +} + +// dlsym resolves a symbol from a dlopen handle. +func dlsym(handle uintptr, name string) (uintptr, error) { + return cuda.Dlsym(handle, name) +} + // goStringN converts a C string pointer and length to a Go string. // //go:nosplit