diff --git a/internal/stablehlo/types.go b/internal/stablehlo/types.go
new file mode 100644
index 0000000..62ed8c1
--- /dev/null
+++ b/internal/stablehlo/types.go
@@ -0,0 +1,145 @@
+// Package stablehlo generates StableHLO MLIR text for PJRT compilation.
+//
+// It provides type mapping (Go types to MLIR tensor type strings),
+// SSA value naming (%v0, %v1, ...), shape formatting (tensor<2x3x4xf32>),
+// and StableHLO operation name constants.
+package stablehlo
+
+import (
+	"fmt"
+	"strings"
+	"sync"
+)
+
+// MLIR dtype strings for StableHLO tensor types.
+const (
+	DTypeF32  = "f32"
+	DTypeF64  = "f64"
+	DTypeF16  = "f16"
+	DTypeBF16 = "bf16"
+	DTypeF8   = "f8E4M3FN"
+	DTypeI8   = "i8"
+	DTypeI16  = "i16"
+	DTypeI32  = "i32"
+	DTypeI64  = "i64"
+	DTypeUI8  = "ui8"
+	DTypeUI32 = "ui32"
+	DTypeUI64 = "ui64"
+	DTypeBool = "i1"
+)
+
+// StableHLO operation name constants.
+const (
+	OpAdd         = "stablehlo.add"
+	OpSubtract    = "stablehlo.subtract"
+	OpMultiply    = "stablehlo.multiply"
+	OpDivide      = "stablehlo.divide"
+	OpDotGeneral  = "stablehlo.dot_general"
+	OpTranspose   = "stablehlo.transpose"
+	OpReshape     = "stablehlo.reshape"
+	OpBroadcastIn = "stablehlo.broadcast_in_dim"
+	OpReduce      = "stablehlo.reduce"
+	OpGather      = "stablehlo.gather"
+	OpSlice       = "stablehlo.slice"
+	OpConcatenate = "stablehlo.concatenate"
+	OpExp         = "stablehlo.exponential"
+	OpLog         = "stablehlo.log"
+	OpSin         = "stablehlo.sine"
+	OpCos         = "stablehlo.cosine"
+	OpTanh        = "stablehlo.tanh"
+	OpNegate      = "stablehlo.negate"
+	OpAbs         = "stablehlo.abs"
+	OpSqrt        = "stablehlo.sqrt"
+	OpRsqrt       = "stablehlo.rsqrt"
+	OpMaximum     = "stablehlo.maximum"
+	OpMinimum     = "stablehlo.minimum"
+	OpClamp       = "stablehlo.clamp"
+	OpSelect      = "stablehlo.select"
+	OpCompare     = "stablehlo.compare"
+	OpConvert     = "stablehlo.convert"
+	OpConstant    = "stablehlo.constant"
+	OpIota        = "stablehlo.iota"
+	OpPower       = "stablehlo.power"
+)
+
+// SSANamer generates monotonically increasing SSA value names (%v0, %v1, ...).
+type SSANamer struct {
+	mu      sync.Mutex
+	counter int
+}
+
+// NextName returns the next SSA value name and advances the counter.
+func (n *SSANamer) NextName() string {
+	n.mu.Lock()
+	name := fmt.Sprintf("%%v%d", n.counter)
+	n.counter++
+	n.mu.Unlock()
+	return name
+}
+
+// Count returns the current counter value (number of names issued so far).
+func (n *SSANamer) Count() int {
+	n.mu.Lock()
+	defer n.mu.Unlock()
+	return n.counter
+}
+
+// FormatTensorType formats a MLIR tensor type string from a shape and dtype.
+// Example: FormatTensorType([]int{2, 3, 4}, "f32") returns "tensor<2x3x4xf32>".
+// For scalar tensors (empty shape), it returns "tensor<dtype>", e.g. "tensor<f32>".
+func FormatTensorType(shape []int, dtype string) string {
+	if len(shape) == 0 {
+		return "tensor<" + dtype + ">"
+	}
+	var b strings.Builder
+	b.WriteString("tensor<")
+	for _, dim := range shape {
+		fmt.Fprintf(&b, "%dx", dim)
+	}
+	b.WriteString(dtype)
+	b.WriteByte('>')
+	return b.String()
+}
+
+// FormatScalarType returns the MLIR scalar type string for a dtype.
+// Example: FormatScalarType("f32") returns "f32".
+func FormatScalarType(dtype string) string {
+	return dtype
+}
+
+// GoDTypeToMLIR maps a Go reflect type name to a MLIR dtype string.
+// Supported mappings:
+//
+//	float32  -> f32
+//	float64  -> f64
+//	float16  -> f16
+//	bfloat16 -> bf16
+//	float8   -> f8E4M3FN
+//	int8     -> i8
+//	int16    -> i16
+//	int32    -> i32
+//	int64    -> i64
+//	uint8    -> ui8
+//	uint32   -> ui32
+//	uint64   -> ui64
+//
+// Returns the MLIR dtype string and true if the mapping exists, or ("", false) otherwise.
+func GoDTypeToMLIR(goType string) (string, bool) {
+	dtype, ok := goTypeMap[goType]
+	return dtype, ok
+}
+
+var goTypeMap = map[string]string{
+	"float32":  DTypeF32,
+	"float64":  DTypeF64,
+	"float16":  DTypeF16,
+	"bfloat16": DTypeBF16,
+	"float8":   DTypeF8,
+	"int8":     DTypeI8,
+	"int16":    DTypeI16,
+	"int32":    DTypeI32,
+	"int64":    DTypeI64,
+	"uint8":    DTypeUI8,
+	"uint32":   DTypeUI32,
+	"uint64":   DTypeUI64,
+}
diff --git a/internal/stablehlo/types_test.go b/internal/stablehlo/types_test.go
new file mode 100644
index 0000000..8ababa0
--- /dev/null
+++ b/internal/stablehlo/types_test.go
@@ -0,0 +1,154 @@
+package stablehlo
+
+import (
+	"testing"
+)
+
+func TestFormatTensorType(t *testing.T) {
+	tests := []struct {
+		name  string
+		shape []int
+		dtype string
+		want  string
+	}{
+		{"3D f32", []int{2, 3, 4}, "f32", "tensor<2x3x4xf32>"},
+		{"2D f64", []int{8, 16}, "f64", "tensor<8x16xf64>"},
+		{"1D i32", []int{10}, "i32", "tensor<10xi32>"},
+		{"4D bf16", []int{1, 2, 3, 4}, "bf16", "tensor<1x2x3x4xbf16>"},
+		{"scalar", []int{}, "f32", "tensor<f32>"},
+		{"f16", []int{3, 5}, "f16", "tensor<3x5xf16>"},
+		{"i64", []int{100}, "i64", "tensor<100xi64>"},
+		{"f8E4M3FN", []int{4, 8}, "f8E4M3FN", "tensor<4x8xf8E4M3FN>"},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := FormatTensorType(tt.shape, tt.dtype)
+			if got != tt.want {
+				t.Errorf("FormatTensorType(%v, %q) = %q, want %q", tt.shape, tt.dtype, got, tt.want)
+			}
+		})
+	}
+}
+
+func TestFormatScalarType(t *testing.T) {
+	tests := []struct {
+		dtype string
+		want  string
+	}{
+		{"f32", "f32"},
+		{"f64", "f64"},
+		{"i32", "i32"},
+		{"bf16", "bf16"},
+	}
+	for _, tt := range tests {
+		t.Run(tt.dtype, func(t *testing.T) {
+			got := FormatScalarType(tt.dtype)
+			if got != tt.want {
+				t.Errorf("FormatScalarType(%q) = %q, want %q", tt.dtype, got, tt.want)
+			}
+		})
+	}
+}
+
+func TestSSANamer(t *testing.T) {
+	n := &SSANamer{}
+
+	if n.Count() != 0 {
+		t.Fatalf("initial count = %d, want 0", n.Count())
+	}
+
+	expected := []string{"%v0", "%v1", "%v2", "%v3", "%v4"}
+	for i, want := range expected {
+		got := n.NextName()
+		if got != want {
+			t.Errorf("NextName() call %d = %q, want %q", i, got, want)
+		}
+	}
+
+	if n.Count() != 5 {
+		t.Errorf("count after 5 calls = %d, want 5", n.Count())
+	}
+}
+
+func TestSSANamerConcurrent(t *testing.T) {
+	n := &SSANamer{}
+	const goroutines = 100
+
+	done := make(chan string, goroutines)
+	for range goroutines {
+		go func() {
+			done <- n.NextName()
+		}()
+	}
+
+	seen := make(map[string]bool, goroutines)
+	for range goroutines {
+		name := <-done
+		if seen[name] {
+			t.Errorf("duplicate SSA name: %s", name)
+		}
+		seen[name] = true
+	}
+
+	if n.Count() != goroutines {
+		t.Errorf("count = %d, want %d", n.Count(), goroutines)
+	}
+}
+
+func TestGoDTypeToMLIR(t *testing.T) {
+	tests := []struct {
+		goType string
+		want   string
+		ok     bool
+	}{
+		{"float32", "f32", true},
+		{"float64", "f64", true},
+		{"float16", "f16", true},
+		{"bfloat16", "bf16", true},
+		{"float8", "f8E4M3FN", true},
+		{"int8", "i8", true},
+		{"int16", "i16", true},
+		{"int32", "i32", true},
+		{"int64", "i64", true},
+		{"uint8", "ui8", true},
+		{"uint32", "ui32", true},
+		{"uint64", "ui64", true},
+		{"complex64", "", false},
+		{"string", "", false},
+	}
+	for _, tt := range tests {
+		t.Run(tt.goType, func(t *testing.T) {
+			got, ok := GoDTypeToMLIR(tt.goType)
+			if ok != tt.ok {
+				t.Errorf("GoDTypeToMLIR(%q) ok = %v, want %v", tt.goType, ok, tt.ok)
+			}
+			if got != tt.want {
+				t.Errorf("GoDTypeToMLIR(%q) = %q, want %q", tt.goType, got, tt.want)
+			}
+		})
+	}
+}
+
+func TestOpConstants(t *testing.T) {
+	// Verify key op constants have the expected stablehlo. prefix and op names.
+	ops := map[string]string{
+		"OpAdd":         OpAdd,
+		"OpSubtract":    OpSubtract,
+		"OpMultiply":    OpMultiply,
+		"OpDivide":      OpDivide,
+		"OpDotGeneral":  OpDotGeneral,
+		"OpTranspose":   OpTranspose,
+		"OpReshape":     OpReshape,
+		"OpBroadcastIn": OpBroadcastIn,
+		"OpReduce":      OpReduce,
+		"OpExp":         OpExp,
+		"OpLog":         OpLog,
+		"OpTanh":        OpTanh,
+		"OpConstant":    OpConstant,
+	}
+	for name, val := range ops {
+		if len(val) < 12 || val[:10] != "stablehlo." { // length guard keeps val[:10] in bounds; "stablehlo." is 10 bytes
+			t.Errorf("%s = %q, expected stablehlo.* prefix", name, val)
+		}
+	}
+}