From 5fca11ea01c8f821678b982870fab906ef889260 Mon Sep 17 00:00:00 2001 From: David Ndungu Date: Thu, 2 Apr 2026 14:56:29 -0700 Subject: [PATCH 1/2] feat(stablehlo): add shape inference for arithmetic ops --- internal/stablehlo/shapes.go | 95 +++++++++++++++++ internal/stablehlo/shapes_test.go | 167 ++++++++++++++++++++++++++++++ 2 files changed, 262 insertions(+) create mode 100644 internal/stablehlo/shapes.go create mode 100644 internal/stablehlo/shapes_test.go diff --git a/internal/stablehlo/shapes.go b/internal/stablehlo/shapes.go new file mode 100644 index 0000000..58f880f --- /dev/null +++ b/internal/stablehlo/shapes.go @@ -0,0 +1,95 @@ +// Package stablehlo provides StableHLO MLIR text emission for the PJRT backend. +package stablehlo + +import ( + "fmt" + "slices" +) + +// InferShape computes the output shape for a given operation name, input shapes, +// and optional attributes. It returns an error if the shapes are incompatible. +func InferShape(opName string, inputShapes [][]int, attrs map[string]any) ([]int, error) { + switch opName { + // Element-wise binary ops with numpy-style broadcasting. + case "Add", "Sub", "Mul", "Div": + return inferBroadcastBinary(opName, inputShapes) + + // Scalar broadcast ops: output = shape of the non-scalar input. + case "MulScalar", "DivScalar", "AddScalar": + return inferScalarBroadcast(opName, inputShapes) + + // Unary ops: output shape = input shape. + case "Exp", "Log", "Sin", "Cos", "Tanh", "Sqrt", "Rsqrt", "Neg", "Abs": + return inferUnary(opName, inputShapes) + + // Binary same-shape ops: Pow takes two inputs of the same shape. + case "Pow": + return inferBroadcastBinary(opName, inputShapes) + + default: + return nil, fmt.Errorf("stablehlo.InferShape: unsupported op %q", opName) + } +} + +// inferBroadcastBinary computes the broadcast output shape for two input shapes +// using numpy-style broadcasting: dimensions are aligned from the trailing end, +// and each pair must be equal or one of them must be 1. +func inferBroadcastBinary(opName string, inputShapes [][]int) ([]int, error) { + if len(inputShapes) != 2 { + return nil, fmt.Errorf("stablehlo.InferShape(%s): expected 2 input shapes, got %d", opName, len(inputShapes)) + } + return broadcastShapes(opName, inputShapes[0], inputShapes[1]) +} + +// broadcastShapes returns the broadcast-compatible output shape for a and b, +// or an error if broadcasting is impossible. +func broadcastShapes(opName string, a, b []int) ([]int, error) { + rank := max(len(a), len(b)) + out := make([]int, rank) + + for i := range rank { + // Index from the trailing end. + ai := len(a) - rank + i + bi := len(b) - rank + i + + da := 1 + if ai >= 0 { + da = a[ai] + } + db := 1 + if bi >= 0 { + db = b[bi] + } + + switch { + case da == db: + out[i] = da + case da == 1: + out[i] = db + case db == 1: + out[i] = da + default: + return nil, fmt.Errorf("stablehlo.InferShape(%s): incompatible shapes %v and %v at dimension %d (%d vs %d)", + opName, a, b, i, da, db) + } + } + return out, nil +} + +// inferScalarBroadcast handles ops like MulScalar where one input is a scalar +// and the output takes the shape of the non-scalar (or first) input. +func inferScalarBroadcast(opName string, inputShapes [][]int) ([]int, error) { + if len(inputShapes) < 1 || len(inputShapes) > 2 { + return nil, fmt.Errorf("stablehlo.InferShape(%s): expected 1-2 input shapes, got %d", opName, len(inputShapes)) + } + // The first input is the tensor; the second (if present) is the scalar. + return slices.Clone(inputShapes[0]), nil +} + +// inferUnary validates that exactly one input is provided and returns its shape. +func inferUnary(opName string, inputShapes [][]int) ([]int, error) { + if len(inputShapes) != 1 { + return nil, fmt.Errorf("stablehlo.InferShape(%s): expected 1 input shape, got %d", opName, len(inputShapes)) + } + return slices.Clone(inputShapes[0]), nil +} diff --git a/internal/stablehlo/shapes_test.go b/internal/stablehlo/shapes_test.go new file mode 100644 index 0000000..a0fe3b5 --- /dev/null +++ b/internal/stablehlo/shapes_test.go @@ -0,0 +1,167 @@ +package stablehlo + +import ( + "slices" + "testing" +) + +func TestInferShapeSameShapeArithmetic(t *testing.T) { + for _, op := range []string{"Add", "Sub", "Mul", "Div"} { + got, err := InferShape(op, [][]int{{2, 3}, {2, 3}}, nil) + if err != nil { + t.Fatalf("%s same-shape: %v", op, err) + } + if !slices.Equal(got, []int{2, 3}) { + t.Errorf("%s same-shape: got %v, want [2 3]", op, got) + } + } +} + +func TestInferShapeBroadcast2D(t *testing.T) { + // {2,3} + {1,3} -> {2,3} + got, err := InferShape("Add", [][]int{{2, 3}, {1, 3}}, nil) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{2, 3}) { + t.Errorf("got %v, want [2 3]", got) + } +} + +func TestInferShapeBroadcast3D(t *testing.T) { + // {4,1,3} + {1,5,3} -> {4,5,3} + got, err := InferShape("Add", [][]int{{4, 1, 3}, {1, 5, 3}}, nil) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{4, 5, 3}) { + t.Errorf("got %v, want [4 5 3]", got) + } +} + +func TestInferShapeBroadcastRankMismatch(t *testing.T) { + // {3} + {2,3} -> {2,3} (lower-rank tensor is left-padded with 1s) + got, err := InferShape("Mul", [][]int{{3}, {2, 3}}, nil) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{2, 3}) { + t.Errorf("got %v, want [2 3]", got) + } +} + +func TestInferShapeScalarOps(t *testing.T) { + for _, op := range []string{"MulScalar", "DivScalar", "AddScalar"} { + got, err := InferShape(op, [][]int{{4, 5}}, nil) + if err != nil { + t.Fatalf("%s: %v", op, err) + } + if !slices.Equal(got, []int{4, 5}) { + t.Errorf("%s: got %v, want [4 5]", op, got) + } + } +} + +func TestInferShapeScalarOpsTwoInputs(t *testing.T) { + // Scalar ops with explicit scalar shape as second input. + got, err := InferShape("MulScalar", [][]int{{3, 4}, {}}, nil) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{3, 4}) { + t.Errorf("got %v, want [3 4]", got) + } +} + +func TestInferShapeUnaryOps(t *testing.T) { + unary := []string{"Exp", "Log", "Sin", "Cos", "Tanh", "Sqrt", "Rsqrt", "Neg", "Abs"} + for _, op := range unary { + got, err := InferShape(op, [][]int{{2, 3, 4}}, nil) + if err != nil { + t.Fatalf("%s: %v", op, err) + } + if !slices.Equal(got, []int{2, 3, 4}) { + t.Errorf("%s: got %v, want [2 3 4]", op, got) + } + } +} + +func TestInferShapePow(t *testing.T) { + got, err := InferShape("Pow", [][]int{{2, 3}, {2, 3}}, nil) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{2, 3}) { + t.Errorf("got %v, want [2 3]", got) + } +} + +func TestInferShapePowBroadcast(t *testing.T) { + // Pow supports broadcasting: {2,3} ** {1,3} -> {2,3} + got, err := InferShape("Pow", [][]int{{2, 3}, {1, 3}}, nil) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{2, 3}) { + t.Errorf("got %v, want [2 3]", got) + } +} + +func TestInferShapeIncompatibleShapes(t *testing.T) { + _, err := InferShape("Add", [][]int{{2, 3}, {4, 3}}, nil) + if err == nil { + t.Fatal("expected error for incompatible shapes {2,3} and {4,3}") + } +} + +func TestInferShapeIncompatibleShapes3D(t *testing.T) { + _, err := InferShape("Mul", [][]int{{2, 5, 3}, {2, 4, 3}}, nil) + if err == nil { + t.Fatal("expected error for incompatible shapes {2,5,3} and {2,4,3}") + } +} + +func TestInferShapeWrongInputCount(t *testing.T) { + // Binary op with one input. + _, err := InferShape("Add", [][]int{{2, 3}}, nil) + if err == nil { + t.Fatal("expected error for Add with 1 input") + } + + // Unary op with two inputs. + _, err = InferShape("Exp", [][]int{{2, 3}, {2, 3}}, nil) + if err == nil { + t.Fatal("expected error for Exp with 2 inputs") + } +} + +func TestInferShapeUnsupportedOp(t *testing.T) { + _, err := InferShape("FooBarBaz", [][]int{{2, 3}}, nil) + if err == nil { + t.Fatal("expected error for unsupported op") + } +} + +func TestInferShapeScalarInputs(t *testing.T) { + // Two scalar (rank-0) inputs. + got, err := InferShape("Add", [][]int{{}, {}}, nil) + if err != nil { + t.Fatal(err) + } + if len(got) != 0 { + t.Errorf("got %v, want [] (scalar)", got) + } +} + +func TestInferShapeOutputNotAliased(t *testing.T) { + // Verify the returned slice is a copy, not the original input. + input := [][]int{{5, 6}} + got, err := InferShape("Neg", input, nil) + if err != nil { + t.Fatal(err) + } + got[0] = 999 + if input[0][0] != 5 { + t.Error("InferShape returned a slice that aliases the input") + } +} From 508477af2d98abeb543c06ef72638a1f0a3bff91 Mon Sep 17 00:00:00 2001 From: David Ndungu Date: Thu, 2 Apr 2026 14:59:14 -0700 Subject: [PATCH 2/2] feat(stablehlo): add shape inference for structural ops Implement InferStructuralShape for MatMul (2D and batched dot_general), Transpose (axis permutation), Reshape (element-count validated), Concat (axis concatenation), Slice (start/end indices), Gather (indices + slice sizes), and ReduceSum/ReduceMax/ReduceMean (with keepDims support). Separate file from shapes.go to avoid conflicts with T61.1.2. --- internal/stablehlo/shapes_structural.go | 318 ++++++++++++ internal/stablehlo/shapes_structural_test.go | 502 +++++++++++++++++++ 2 files changed, 820 insertions(+) create mode 100644 internal/stablehlo/shapes_structural.go create mode 100644 internal/stablehlo/shapes_structural_test.go diff --git a/internal/stablehlo/shapes_structural.go b/internal/stablehlo/shapes_structural.go new file mode 100644 index 0000000..20f6b68 --- /dev/null +++ b/internal/stablehlo/shapes_structural.go @@ -0,0 +1,318 @@ +package stablehlo + +import ( + "fmt" + "slices" +) + +// InferStructuralShape computes the output shape for structural operations: +// MatMul, Transpose, Reshape, Concat, Slice, Gather, ReduceSum, ReduceMax, ReduceMean. +// +// attrs supports the following keys depending on the operation: +// +// - "perm" ([]int): axis permutation for Transpose +// - "shape" ([]int): target shape for Reshape +// - "axis" (int): concatenation axis for Concat, reduction axis for Reduce* +// - "start" ([]int): start indices for Slice +// - "end" ([]int): end indices for Slice +// - "sliceSizes" ([]int): slice sizes for Gather +// - "keepDims" (bool): whether to keep the reduced dimension for Reduce* +func InferStructuralShape(opName string, inputShapes [][]int, attrs map[string]any) ([]int, error) { + switch opName { + case "MatMul": + return inferMatMul(inputShapes) + case "Transpose": + return inferTranspose(inputShapes, attrs) + case "Reshape": + return inferReshape(inputShapes, attrs) + case "Concat": + return inferConcat(inputShapes, attrs) + case "Slice": + return inferSlice(inputShapes, attrs) + case "Gather": + return inferGather(inputShapes, attrs) + case "ReduceSum", "ReduceMax", "ReduceMean": + return inferReduce(opName, inputShapes, attrs) + default: + return nil, fmt.Errorf("stablehlo.InferStructuralShape: unsupported op %q", opName) + } +} + +// inferMatMul computes the output shape for matrix multiplication (dot_general). +// +// 2D: [M,K] @ [K,N] -> [M,N] +// 3D batched: [B,M,K] @ [B,K,N] -> [B,M,N] +// Higher-rank batched: [...,M,K] @ [...,K,N] -> [...,M,N] +func inferMatMul(inputShapes [][]int) ([]int, error) { + if len(inputShapes) != 2 { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(MatMul): expected 2 input shapes, got %d", len(inputShapes)) + } + a, b := inputShapes[0], inputShapes[1] + if len(a) < 2 || len(b) < 2 { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(MatMul): inputs must be at least rank 2, got rank %d and %d", len(a), len(b)) + } + + // Contraction dimension: last of a must match second-to-last of b. + m, k := a[len(a)-2], a[len(a)-1] + k2, n := b[len(b)-2], b[len(b)-1] + if k != k2 { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(MatMul): contraction dimension mismatch: %v (K=%d) and %v (K=%d)", a, k, b, k2) + } + + // Batch dimensions: everything except the last two dims. + batchA := a[:len(a)-2] + batchB := b[:len(b)-2] + if len(batchA) != len(batchB) { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(MatMul): batch rank mismatch: %v and %v", a, b) + } + for i := range batchA { + if batchA[i] != batchB[i] { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(MatMul): batch dimension %d mismatch: %d vs %d", i, batchA[i], batchB[i]) + } + } + + out := make([]int, 0, len(batchA)+2) + out = append(out, batchA...) + out = append(out, m, n) + return out, nil +} + +// inferTranspose computes the output shape for axis permutation. +// Requires attrs["perm"] ([]int) specifying the permutation. +// Example: [2,3,4] with perm [2,0,1] -> [4,2,3]. +func inferTranspose(inputShapes [][]int, attrs map[string]any) ([]int, error) { + if len(inputShapes) != 1 { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Transpose): expected 1 input shape, got %d", len(inputShapes)) + } + shape := inputShapes[0] + + permRaw, ok := attrs["perm"] + if !ok { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Transpose): missing required attr \"perm\"") + } + perm, ok := permRaw.([]int) + if !ok { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Transpose): attr \"perm\" must be []int, got %T", permRaw) + } + if len(perm) != len(shape) { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Transpose): perm length %d does not match rank %d", len(perm), len(shape)) + } + + // Validate permutation: must be a valid permutation of [0..rank-1]. + seen := make([]bool, len(perm)) + for i, p := range perm { + if p < 0 || p >= len(shape) { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Transpose): perm[%d]=%d out of range [0,%d)", i, p, len(shape)) + } + if seen[p] { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Transpose): duplicate axis %d in perm", p) + } + seen[p] = true + } + + out := make([]int, len(shape)) + for i, p := range perm { + out[i] = shape[p] + } + return out, nil +} + +// inferReshape computes the output shape for a reshape operation. +// Requires attrs["shape"] ([]int) specifying the target shape. +// Validates that the total element count matches. +func inferReshape(inputShapes [][]int, attrs map[string]any) ([]int, error) { + if len(inputShapes) != 1 { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Reshape): expected 1 input shape, got %d", len(inputShapes)) + } + + targetRaw, ok := attrs["shape"] + if !ok { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Reshape): missing required attr \"shape\"") + } + target, ok := targetRaw.([]int) + if !ok { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Reshape): attr \"shape\" must be []int, got %T", targetRaw) + } + + srcElems := numElements(inputShapes[0]) + dstElems := numElements(target) + if srcElems != dstElems { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Reshape): element count mismatch: input has %d elements, target shape %v has %d", srcElems, target, dstElems) + } + return slices.Clone(target), nil +} + +// inferConcat computes the output shape for axis concatenation. +// Requires attrs["axis"] (int) specifying the concatenation axis. +// All input shapes must match on all dimensions except the concat axis. +func inferConcat(inputShapes [][]int, attrs map[string]any) ([]int, error) { + if len(inputShapes) < 2 { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Concat): expected at least 2 input shapes, got %d", len(inputShapes)) + } + + axisRaw, ok := attrs["axis"] + if !ok { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Concat): missing required attr \"axis\"") + } + axis, ok := axisRaw.(int) + if !ok { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Concat): attr \"axis\" must be int, got %T", axisRaw) + } + + rank := len(inputShapes[0]) + if rank == 0 { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Concat): cannot concatenate scalar tensors") + } + if axis < 0 || axis >= rank { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Concat): axis %d out of range [0,%d)", axis, rank) + } + + // Validate all inputs have the same rank and match on non-concat dims. + concatSize := 0 + for i, s := range inputShapes { + if len(s) != rank { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Concat): input %d has rank %d, expected %d", i, len(s), rank) + } + for d := range rank { + if d == axis { + continue + } + if s[d] != inputShapes[0][d] { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Concat): input %d dim %d is %d, expected %d", i, d, s[d], inputShapes[0][d]) + } + } + concatSize += s[axis] + } + + out := slices.Clone(inputShapes[0]) + out[axis] = concatSize + return out, nil +} + +// inferSlice computes the output shape for a slice operation. +// Requires attrs["start"] ([]int) and attrs["end"] ([]int). +// Output shape[i] = end[i] - start[i]. +func inferSlice(inputShapes [][]int, attrs map[string]any) ([]int, error) { + if len(inputShapes) != 1 { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Slice): expected 1 input shape, got %d", len(inputShapes)) + } + shape := inputShapes[0] + + startRaw, ok := attrs["start"] + if !ok { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Slice): missing required attr \"start\"") + } + start, ok := startRaw.([]int) + if !ok { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Slice): attr \"start\" must be []int, got %T", startRaw) + } + + endRaw, ok := attrs["end"] + if !ok { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Slice): missing required attr \"end\"") + } + end, ok := endRaw.([]int) + if !ok { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Slice): attr \"end\" must be []int, got %T", endRaw) + } + + if len(start) != len(shape) || len(end) != len(shape) { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Slice): start/end length (%d, %d) must match rank %d", len(start), len(end), len(shape)) + } + + out := make([]int, len(shape)) + for i := range shape { + if start[i] < 0 || end[i] > shape[i] || start[i] > end[i] { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Slice): invalid range [%d:%d] for dimension %d (size %d)", start[i], end[i], i, shape[i]) + } + out[i] = end[i] - start[i] + } + return out, nil +} + +// inferGather computes the output shape for a gather operation. +// Requires attrs["sliceSizes"] ([]int) specifying the size of each slice dimension. +// The output shape is determined by the indices shape (first input) and the slice sizes. +// +// For a simple gather: output shape = indices_shape + sliceSizes (with collapsed dims removed). +// We use a simplified model: output = indices_shape[:-1] + sliceSizes. +func inferGather(inputShapes [][]int, attrs map[string]any) ([]int, error) { + if len(inputShapes) != 2 { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Gather): expected 2 input shapes (operand, indices), got %d", len(inputShapes)) + } + indices := inputShapes[1] + + sliceSizesRaw, ok := attrs["sliceSizes"] + if !ok { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Gather): missing required attr \"sliceSizes\"") + } + sliceSizes, ok := sliceSizesRaw.([]int) + if !ok { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Gather): attr \"sliceSizes\" must be []int, got %T", sliceSizesRaw) + } + + // Output shape: batch dims from indices (all but last) + slice sizes. + // The last dim of indices is the index vector dimension. + if len(indices) == 0 { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(Gather): indices must be at least rank 1") + } + + out := make([]int, 0, len(indices)-1+len(sliceSizes)) + out = append(out, indices[:len(indices)-1]...) + out = append(out, sliceSizes...) + return out, nil +} + +// inferReduce computes the output shape for a reduction operation. +// Requires attrs["axis"] (int) specifying the reduction axis. +// Optional attrs["keepDims"] (bool) to retain the reduced dimension as size 1. +func inferReduce(opName string, inputShapes [][]int, attrs map[string]any) ([]int, error) { + if len(inputShapes) != 1 { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(%s): expected 1 input shape, got %d", opName, len(inputShapes)) + } + shape := inputShapes[0] + + axisRaw, ok := attrs["axis"] + if !ok { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(%s): missing required attr \"axis\"", opName) + } + axis, ok := axisRaw.(int) + if !ok { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(%s): attr \"axis\" must be int, got %T", opName, axisRaw) + } + if axis < 0 || axis >= len(shape) { + return nil, fmt.Errorf("stablehlo.InferStructuralShape(%s): axis %d out of range [0,%d)", opName, axis, len(shape)) + } + + keepDims := false + if kd, ok := attrs["keepDims"]; ok { + if b, ok := kd.(bool); ok { + keepDims = b + } + } + + if keepDims { + out := slices.Clone(shape) + out[axis] = 1 + return out, nil + } + + // Remove the reduced dimension. + out := make([]int, 0, len(shape)-1) + for i, d := range shape { + if i != axis { + out = append(out, d) + } + } + return out, nil +} + +// numElements returns the total number of elements in a tensor with the given shape. +// Returns 1 for a scalar (empty shape). +func numElements(shape []int) int { + n := 1 + for _, d := range shape { + n *= d + } + return n +} diff --git a/internal/stablehlo/shapes_structural_test.go b/internal/stablehlo/shapes_structural_test.go new file mode 100644 index 0000000..945b880 --- /dev/null +++ b/internal/stablehlo/shapes_structural_test.go @@ -0,0 +1,502 @@ +package stablehlo + +import ( + "slices" + "testing" +) + +func TestInferMatMul2D(t *testing.T) { + got, err := InferStructuralShape("MatMul", [][]int{{2, 3}, {3, 4}}, nil) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{2, 4}) { + t.Errorf("got %v, want [2 4]", got) + } +} + +func TestInferMatMulBatched3D(t *testing.T) { + got, err := InferStructuralShape("MatMul", [][]int{{5, 2, 3}, {5, 3, 4}}, nil) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{5, 2, 4}) { + t.Errorf("got %v, want [5 2 4]", got) + } +} + +func TestInferMatMulBatched4D(t *testing.T) { + got, err := InferStructuralShape("MatMul", [][]int{{2, 3, 4, 5}, {2, 3, 5, 6}}, nil) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{2, 3, 4, 6}) { + t.Errorf("got %v, want [2 3 4 6]", got) + } +} + +func TestInferMatMulSquare(t *testing.T) { + got, err := InferStructuralShape("MatMul", [][]int{{3, 3}, {3, 3}}, nil) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{3, 3}) { + t.Errorf("got %v, want [3 3]", got) + } +} + +func TestInferMatMulContractionMismatch(t *testing.T) { + _, err := InferStructuralShape("MatMul", [][]int{{2, 3}, {4, 5}}, nil) + if err == nil { + t.Fatal("expected error for contraction dimension mismatch") + } +} + +func TestInferMatMulBatchMismatch(t *testing.T) { + _, err := InferStructuralShape("MatMul", [][]int{{2, 3, 4}, {5, 4, 6}}, nil) + if err == nil { + t.Fatal("expected error for batch dimension mismatch") + } +} + +func TestInferMatMulRank1(t *testing.T) { + _, err := InferStructuralShape("MatMul", [][]int{{3}, {3, 4}}, nil) + if err == nil { + t.Fatal("expected error for rank-1 input") + } +} + +func TestInferMatMulWrongInputCount(t *testing.T) { + _, err := InferStructuralShape("MatMul", [][]int{{2, 3}}, nil) + if err == nil { + t.Fatal("expected error for single input") + } +} + +func TestInferMatMulBatchRankMismatch(t *testing.T) { + _, err := InferStructuralShape("MatMul", [][]int{{2, 3, 4}, {3, 4, 5}}, nil) + if err == nil { + t.Fatal("expected error for batch rank mismatch (batch dims [2] vs [3])") + } +} + +func TestInferTranspose(t *testing.T) { + got, err := InferStructuralShape("Transpose", [][]int{{2, 3, 4}}, map[string]any{"perm": []int{2, 0, 1}}) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{4, 2, 3}) { + t.Errorf("got %v, want [4 2 3]", got) + } +} + +func TestInferTranspose2D(t *testing.T) { + got, err := InferStructuralShape("Transpose", [][]int{{3, 7}}, map[string]any{"perm": []int{1, 0}}) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{7, 3}) { + t.Errorf("got %v, want [7 3]", got) + } +} + +func TestInferTransposeIdentity(t *testing.T) { + got, err := InferStructuralShape("Transpose", [][]int{{2, 3, 4}}, map[string]any{"perm": []int{0, 1, 2}}) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{2, 3, 4}) { + t.Errorf("got %v, want [2 3 4]", got) + } +} + +func TestInferTransposeMissingPerm(t *testing.T) { + _, err := InferStructuralShape("Transpose", [][]int{{2, 3}}, nil) + if err == nil { + t.Fatal("expected error for missing perm attr") + } +} + +func TestInferTransposePermLengthMismatch(t *testing.T) { + _, err := InferStructuralShape("Transpose", [][]int{{2, 3, 4}}, map[string]any{"perm": []int{1, 0}}) + if err == nil { + t.Fatal("expected error for perm length mismatch") + } +} + +func TestInferTransposeDuplicateAxis(t *testing.T) { + _, err := InferStructuralShape("Transpose", [][]int{{2, 3, 4}}, map[string]any{"perm": []int{0, 0, 1}}) + if err == nil { + t.Fatal("expected error for duplicate axis in perm") + } +} + +func TestInferTransposeOutOfRange(t *testing.T) { + _, err := InferStructuralShape("Transpose", [][]int{{2, 3}}, map[string]any{"perm": []int{0, 5}}) + if err == nil { + t.Fatal("expected error for perm axis out of range") + } +} + +func TestInferReshape(t *testing.T) { + got, err := InferStructuralShape("Reshape", [][]int{{2, 3, 4}}, map[string]any{"shape": []int{6, 4}}) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{6, 4}) { + t.Errorf("got %v, want [6 4]", got) + } +} + +func TestInferReshapeFlatten(t *testing.T) { + got, err := InferStructuralShape("Reshape", [][]int{{2, 3, 4}}, map[string]any{"shape": []int{24}}) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{24}) { + t.Errorf("got %v, want [24]", got) + } +} + +func TestInferReshapeToScalar(t *testing.T) { + got, err := InferStructuralShape("Reshape", [][]int{{1}}, map[string]any{"shape": []int{}}) + if err != nil { + t.Fatal(err) + } + if len(got) != 0 { + t.Errorf("got %v, want [] (scalar)", got) + } +} + +func TestInferReshapeFromScalar(t *testing.T) { + got, err := InferStructuralShape("Reshape", [][]int{{}}, map[string]any{"shape": []int{1, 1}}) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{1, 1}) { + t.Errorf("got %v, want [1 1]", got) + } +} + +func TestInferReshapeElementCountMismatch(t *testing.T) { + _, err := InferStructuralShape("Reshape", [][]int{{2, 3}}, map[string]any{"shape": []int{7}}) + if err == nil { + t.Fatal("expected error for element count mismatch") + } +} + +func TestInferReshapeMissingShape(t *testing.T) { + _, err := InferStructuralShape("Reshape", [][]int{{2, 3}}, nil) + if err == nil { + t.Fatal("expected error for missing shape attr") + } +} + +func TestInferReshapeNotAliased(t *testing.T) { + target := []int{6, 4} + got, err := InferStructuralShape("Reshape", [][]int{{2, 3, 4}}, map[string]any{"shape": target}) + if err != nil { + t.Fatal(err) + } + got[0] = 999 + if target[0] != 6 { + t.Error("InferStructuralShape(Reshape) returned a slice that aliases the target attr") + } +} + +func TestInferConcat(t *testing.T) { + got, err := InferStructuralShape("Concat", [][]int{{2, 3}, {2, 5}}, map[string]any{"axis": 1}) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{2, 8}) { + t.Errorf("got %v, want [2 8]", got) + } +} + +func TestInferConcatAxis0(t *testing.T) { + got, err := InferStructuralShape("Concat", [][]int{{2, 3}, {4, 3}}, map[string]any{"axis": 0}) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{6, 3}) { + t.Errorf("got %v, want [6 3]", got) + } +} + +func TestInferConcatMultiple(t *testing.T) { + got, err := InferStructuralShape("Concat", [][]int{{2, 3}, {2, 4}, {2, 1}}, map[string]any{"axis": 1}) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{2, 8}) { + t.Errorf("got %v, want [2 8]", got) + } +} + +func TestInferConcat3D(t *testing.T) { + got, err := InferStructuralShape("Concat", [][]int{{2, 3, 4}, {2, 3, 6}}, map[string]any{"axis": 2}) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{2, 3, 10}) { + t.Errorf("got %v, want [2 3 10]", got) + } +} + +func TestInferConcatDimMismatch(t *testing.T) { + _, err := InferStructuralShape("Concat", [][]int{{2, 3}, {4, 5}}, map[string]any{"axis": 1}) + if err == nil { + t.Fatal("expected error for non-concat dimension mismatch") + } +} + +func TestInferConcatRankMismatch(t *testing.T) { + _, err := InferStructuralShape("Concat", [][]int{{2, 3}, {2, 3, 4}}, map[string]any{"axis": 0}) + if err == nil { + t.Fatal("expected error for rank mismatch") + } +} + +func TestInferConcatAxisOutOfRange(t *testing.T) { + _, err := InferStructuralShape("Concat", [][]int{{2, 3}, {2, 4}}, map[string]any{"axis": 5}) + if err == nil { + t.Fatal("expected error for axis out of range") + } +} + +func TestInferConcatSingleInput(t *testing.T) { + _, err := InferStructuralShape("Concat", [][]int{{2, 3}}, map[string]any{"axis": 0}) + if err == nil { + t.Fatal("expected error for single input") + } +} + +func TestInferSlice(t *testing.T) { + got, err := InferStructuralShape("Slice", [][]int{{10, 20}}, map[string]any{ + "start": []int{2, 5}, + "end": []int{5, 15}, + }) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{3, 10}) { + t.Errorf("got %v, want [3 10]", got) + } +} + +func TestInferSlice3D(t *testing.T) { + got, err := InferStructuralShape("Slice", [][]int{{8, 6, 4}}, map[string]any{ + "start": []int{0, 1, 2}, + "end": []int{8, 4, 4}, + }) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{8, 3, 2}) { + t.Errorf("got %v, want [8 3 2]", got) + } +} + +func TestInferSliceFullDim(t *testing.T) { + got, err := InferStructuralShape("Slice", [][]int{{5, 10}}, map[string]any{ + "start": []int{0, 0}, + "end": []int{5, 10}, + }) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{5, 10}) { + t.Errorf("got %v, want [5 10]", got) + } +} + +func TestInferSliceInvalidRange(t *testing.T) { + _, err := InferStructuralShape("Slice", [][]int{{10, 20}}, map[string]any{ + "start": []int{5, 0}, + "end": []int{3, 10}, + }) + if err == nil { + t.Fatal("expected error for start > end") + } +} + +func TestInferSliceOutOfBounds(t *testing.T) { + _, err := InferStructuralShape("Slice", [][]int{{10, 20}}, map[string]any{ + "start": []int{0, 0}, + "end": []int{11, 20}, + }) + if err == nil { + t.Fatal("expected error for end > dim size") + } +} + +func TestInferSliceMissingAttrs(t *testing.T) { + _, err := InferStructuralShape("Slice", [][]int{{10}}, map[string]any{"start": []int{0}}) + if err == nil { + t.Fatal("expected error for missing end attr") + } + + _, err = InferStructuralShape("Slice", [][]int{{10}}, map[string]any{"end": []int{5}}) + if err == nil { + t.Fatal("expected error for missing start attr") + } +} + +func TestInferSliceLengthMismatch(t *testing.T) { + _, err := InferStructuralShape("Slice", [][]int{{10, 20}}, map[string]any{ + "start": []int{0}, + "end": []int{5, 10}, + }) + if err == nil { + t.Fatal("expected error for start/end length mismatch") + } +} + +func TestInferGather(t *testing.T) { + // operand [10, 20], indices [3, 1], sliceSizes [5] -> [3, 5] + got, err := InferStructuralShape("Gather", [][]int{{10, 20}, {3, 1}}, map[string]any{ + "sliceSizes": []int{5}, + }) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{3, 5}) { + t.Errorf("got %v, want [3 5]", got) + } +} + +func TestInferGatherBatched(t *testing.T) { + // operand [100, 64], indices [4, 8, 1], sliceSizes [64] -> [4, 8, 64] + got, err := InferStructuralShape("Gather", [][]int{{100, 64}, {4, 8, 1}}, map[string]any{ + "sliceSizes": []int{64}, + }) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{4, 8, 64}) { + t.Errorf("got %v, want [4 8 64]", got) + } +} + +func TestInferGatherMissingSliceSizes(t *testing.T) { + _, err := InferStructuralShape("Gather", [][]int{{10, 20}, {3, 1}}, nil) + if err == nil { + t.Fatal("expected error for missing sliceSizes attr") + } +} + +func TestInferGatherWrongInputCount(t *testing.T) { + _, err := InferStructuralShape("Gather", [][]int{{10, 20}}, map[string]any{"sliceSizes": []int{5}}) + if err == nil { + t.Fatal("expected error for single input (missing indices)") + } +} + +func TestInferReduceSum(t *testing.T) { + got, err := InferStructuralShape("ReduceSum", [][]int{{2, 3, 4}}, map[string]any{"axis": 1}) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{2, 4}) { + t.Errorf("got %v, want [2 4]", got) + } +} + +func TestInferReduceMax(t *testing.T) { + got, err := InferStructuralShape("ReduceMax", [][]int{{2, 3, 4}}, map[string]any{"axis": 0}) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{3, 4}) { + t.Errorf("got %v, want [3 4]", got) + } +} + +func TestInferReduceMean(t *testing.T) { + got, err := InferStructuralShape("ReduceMean", [][]int{{2, 3, 4}}, map[string]any{"axis": 2}) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{2, 3}) { + t.Errorf("got %v, want [2 3]", got) + } +} + +func TestInferReduceKeepDims(t *testing.T) { + got, err := InferStructuralShape("ReduceSum", [][]int{{2, 3, 4}}, map[string]any{ + "axis": 1, + "keepDims": true, + }) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{2, 1, 4}) { + t.Errorf("got %v, want [2 1 4]", got) + } +} + +func TestInferReduceKeepDimsAxis0(t *testing.T) { + got, err := InferStructuralShape("ReduceMax", [][]int{{5, 3}}, map[string]any{ + "axis": 0, + "keepDims": true, + }) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{1, 3}) { + t.Errorf("got %v, want [1 3]", got) + } +} + +func TestInferReduceKeepDimsFalse(t *testing.T) { + got, err := InferStructuralShape("ReduceSum", [][]int{{2, 3, 4}}, map[string]any{ + "axis": 1, + "keepDims": false, + }) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(got, []int{2, 4}) { + t.Errorf("got %v, want [2 4]", got) + } +} + +func TestInferReduce2DToScalar(t *testing.T) { + // Reducing axis 0 of a rank-1 tensor produces a scalar. + got, err := InferStructuralShape("ReduceSum", [][]int{{5}}, map[string]any{"axis": 0}) + if err != nil { + t.Fatal(err) + } + if len(got) != 0 { + t.Errorf("got %v, want [] (scalar)", got) + } +} + +func TestInferReduceAxisOutOfRange(t *testing.T) { + _, err := InferStructuralShape("ReduceSum", [][]int{{2, 3}}, map[string]any{"axis": 3}) + if err == nil { + t.Fatal("expected error for axis out of range") + } +} + +func TestInferReduceMissingAxis(t *testing.T) { + _, err := InferStructuralShape("ReduceSum", [][]int{{2, 3}}, nil) + if err == nil { + t.Fatal("expected error for missing axis attr") + } +} + +func TestInferReduceWrongInputCount(t *testing.T) { + _, err := InferStructuralShape("ReduceSum", [][]int{{2, 3}, {2, 3}}, map[string]any{"axis": 0}) + if err == nil { + t.Fatal("expected error for two inputs") + } +} + +func TestInferStructuralShapeUnsupportedOp(t *testing.T) { + _, err := InferStructuralShape("FooBarBaz", [][]int{{2, 3}}, nil) + if err == nil { + t.Fatal("expected error for unsupported op") + } +}