Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 145 additions & 0 deletions internal/stablehlo/types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
// Package stablehlo generates StableHLO MLIR text for PJRT compilation.
//
// It provides type mapping (Go types to MLIR tensor type strings),
// SSA value naming (%v0, %v1, ...), shape formatting (tensor<2x3x4xf32>),
// and StableHLO operation name constants.
package stablehlo

import (
"fmt"
"strings"
"sync"
)

// MLIR dtype strings for StableHLO tensor types.
//
// These are the element-type tokens that appear inside "tensor<...>" in
// StableHLO MLIR text (e.g. the "f32" in "tensor<2x3xf32>").
const (
	// Floating-point types.
	DTypeF32  = "f32"
	DTypeF64  = "f64"
	DTypeF16  = "f16"
	DTypeBF16 = "bf16"
	DTypeF8   = "f8E4M3FN" // 8-bit float, E4M3FN variant
	// Signed integer types.
	DTypeI8  = "i8"
	DTypeI16 = "i16"
	DTypeI32 = "i32"
	DTypeI64 = "i64"
	// Unsigned integer types ("ui" prefix in MLIR).
	DTypeUI8  = "ui8"
	DTypeUI32 = "ui32"
	DTypeUI64 = "ui64"
	// Booleans are spelled as 1-bit integers in MLIR.
	DTypeBool = "i1"
)

// StableHLO operation name constants.
//
// Each value is the fully-qualified operation name as it appears in
// StableHLO MLIR text: the "stablehlo." dialect prefix plus the op mnemonic.
const (
	// Elementwise binary arithmetic.
	OpAdd      = "stablehlo.add"
	OpSubtract = "stablehlo.subtract"
	OpMultiply = "stablehlo.multiply"
	OpDivide   = "stablehlo.divide"
	// Linear algebra and shape/layout manipulation.
	OpDotGeneral  = "stablehlo.dot_general"
	OpTranspose   = "stablehlo.transpose"
	OpReshape     = "stablehlo.reshape"
	OpBroadcastIn = "stablehlo.broadcast_in_dim"
	OpReduce      = "stablehlo.reduce"
	OpGather      = "stablehlo.gather"
	OpSlice       = "stablehlo.slice"
	OpConcatenate = "stablehlo.concatenate"
	// Elementwise unary math. Note some mnemonics differ from the Go-style
	// name (exponential, sine, cosine).
	OpExp    = "stablehlo.exponential"
	OpLog    = "stablehlo.log"
	OpSin    = "stablehlo.sine"
	OpCos    = "stablehlo.cosine"
	OpTanh   = "stablehlo.tanh"
	OpNegate = "stablehlo.negate"
	OpAbs    = "stablehlo.abs"
	OpSqrt   = "stablehlo.sqrt"
	OpRsqrt  = "stablehlo.rsqrt"
	// Comparison and selection.
	OpMaximum = "stablehlo.maximum"
	OpMinimum = "stablehlo.minimum"
	OpClamp   = "stablehlo.clamp"
	OpSelect  = "stablehlo.select"
	OpCompare = "stablehlo.compare"
	// Type conversion and value creation.
	OpConvert  = "stablehlo.convert"
	OpConstant = "stablehlo.constant"
	OpIota     = "stablehlo.iota"
	OpPower    = "stablehlo.power"
)

// SSANamer generates monotonically increasing SSA value names (%v0, %v1, ...).
// The zero value is ready to use; all methods are safe for concurrent use.
type SSANamer struct {
	mu      sync.Mutex // guards counter
	counter int        // index of the next name to issue
}

// NextName returns the next SSA value name and advances the counter.
func (n *SSANamer) NextName() string {
	n.mu.Lock()
	defer n.mu.Unlock()
	idx := n.counter
	n.counter++
	// "%%" escapes the literal '%' of the SSA value sigil.
	return fmt.Sprintf("%%v%d", idx)
}

// Count returns the current counter value (number of names issued so far).
func (n *SSANamer) Count() int {
	n.mu.Lock()
	issued := n.counter
	n.mu.Unlock()
	return issued
}

// FormatTensorType formats a MLIR tensor type string from a shape and dtype.
// Example: FormatTensorType([]int{2, 3, 4}, "f32") returns "tensor<2x3x4xf32>".
// For scalar tensors (empty shape), it returns "tensor<f32>".
func FormatTensorType(shape []int, dtype string) string {
	// Dimensions and the dtype token are all joined with 'x', so a scalar
	// (no dims) naturally collapses to "tensor<dtype>" with no separator.
	parts := make([]string, 0, len(shape)+1)
	for _, dim := range shape {
		parts = append(parts, fmt.Sprintf("%d", dim))
	}
	parts = append(parts, dtype)
	return "tensor<" + strings.Join(parts, "x") + ">"
}

// FormatScalarType returns the MLIR scalar type string for a dtype.
// Example: FormatScalarType("f32") returns "f32".
func FormatScalarType(dtype string) string {
	// MLIR spells scalar element types with the same token used inside
	// tensor types, so the dtype string passes through unchanged. Kept as
	// a function so call sites document intent and the mapping can evolve.
	return dtype
}

// GoDTypeToMLIR maps a Go reflect type name to a MLIR dtype string.
// Supported mappings:
//
//	float32 -> f32
//	float64 -> f64
//	float16 -> f16
//	bfloat16 -> bf16
//	float8 -> f8E4M3FN
//	int8 -> i8
//	int16 -> i16
//	int32 -> i32
//	int64 -> i64
//	uint8 -> ui8
//	uint32 -> ui32
//	uint64 -> ui64
//
// Returns the MLIR dtype string and true if the mapping exists, or ("", false) otherwise.
func GoDTypeToMLIR(goType string) (string, bool) {
	mlir, known := goTypeMap[goType]
	return mlir, known
}

// goTypeMap is the lookup table backing GoDTypeToMLIR. Keys are Go type
// names as produced by reflection; values are the MLIR dtype tokens above.
var goTypeMap = map[string]string{
	// Floating point.
	"float32":  DTypeF32,
	"float64":  DTypeF64,
	"float16":  DTypeF16,
	"bfloat16": DTypeBF16,
	"float8":   DTypeF8,
	// Signed integers.
	"int8":  DTypeI8,
	"int16": DTypeI16,
	"int32": DTypeI32,
	"int64": DTypeI64,
	// Unsigned integers.
	"uint8":  DTypeUI8,
	"uint32": DTypeUI32,
	"uint64": DTypeUI64,
}
154 changes: 154 additions & 0 deletions internal/stablehlo/types_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package stablehlo

import (
	"strings"
	"testing"
)

// TestFormatTensorType checks tensor type formatting across ranks,
// dtypes, and the scalar (empty-shape) special case.
func TestFormatTensorType(t *testing.T) {
	cases := []struct {
		name  string
		shape []int
		dtype string
		want  string
	}{
		{"3D f32", []int{2, 3, 4}, "f32", "tensor<2x3x4xf32>"},
		{"2D f64", []int{8, 16}, "f64", "tensor<8x16xf64>"},
		{"1D i32", []int{10}, "i32", "tensor<10xi32>"},
		{"4D bf16", []int{1, 2, 3, 4}, "bf16", "tensor<1x2x3x4xbf16>"},
		{"scalar", []int{}, "f32", "tensor<f32>"},
		{"f16", []int{3, 5}, "f16", "tensor<3x5xf16>"},
		{"i64", []int{100}, "i64", "tensor<100xi64>"},
		{"f8E4M3FN", []int{4, 8}, "f8E4M3FN", "tensor<4x8xf8E4M3FN>"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := FormatTensorType(tc.shape, tc.dtype); got != tc.want {
				t.Errorf("FormatTensorType(%v, %q) = %q, want %q", tc.shape, tc.dtype, got, tc.want)
			}
		})
	}
}

// TestFormatScalarType checks that scalar type formatting is the
// identity on dtype tokens, so the expected value equals the input.
func TestFormatScalarType(t *testing.T) {
	for _, dtype := range []string{"f32", "f64", "i32", "bf16"} {
		t.Run(dtype, func(t *testing.T) {
			if got := FormatScalarType(dtype); got != dtype {
				t.Errorf("FormatScalarType(%q) = %q, want %q", dtype, got, dtype)
			}
		})
	}
}

// TestSSANamer checks sequential name generation and counter tracking.
func TestSSANamer(t *testing.T) {
	n := &SSANamer{}

	// A fresh namer has issued nothing.
	if n.Count() != 0 {
		t.Fatalf("initial count = %d, want 0", n.Count())
	}

	// Names come out in order starting at %v0.
	for i, want := range []string{"%v0", "%v1", "%v2", "%v3", "%v4"} {
		if got := n.NextName(); got != want {
			t.Errorf("NextName() call %d = %q, want %q", i, got, want)
		}
	}

	if got := n.Count(); got != 5 {
		t.Errorf("count after 5 calls = %d, want 5", got)
	}
}

// TestSSANamerConcurrent checks that concurrent NextName calls never
// produce duplicate names and that the counter ends at the call count.
// Run with -race for full effect.
func TestSSANamerConcurrent(t *testing.T) {
	n := &SSANamer{}
	const goroutines = 100

	results := make(chan string, goroutines)
	for i := 0; i < goroutines; i++ {
		go func() {
			results <- n.NextName()
		}()
	}

	// Drain all results; every name must be unique.
	seen := make(map[string]bool, goroutines)
	for i := 0; i < goroutines; i++ {
		name := <-results
		if seen[name] {
			t.Errorf("duplicate SSA name: %s", name)
		}
		seen[name] = true
	}

	if n.Count() != goroutines {
		t.Errorf("count = %d, want %d", n.Count(), goroutines)
	}
}

// TestGoDTypeToMLIR checks every supported Go-type mapping plus two
// unsupported names that must report ("", false).
func TestGoDTypeToMLIR(t *testing.T) {
	cases := []struct {
		goType string
		want   string
		ok     bool
	}{
		{"float32", "f32", true},
		{"float64", "f64", true},
		{"float16", "f16", true},
		{"bfloat16", "bf16", true},
		{"float8", "f8E4M3FN", true},
		{"int8", "i8", true},
		{"int16", "i16", true},
		{"int32", "i32", true},
		{"int64", "i64", true},
		{"uint8", "ui8", true},
		{"uint32", "ui32", true},
		{"uint64", "ui64", true},
		{"complex64", "", false},
		{"string", "", false},
	}
	for _, tc := range cases {
		t.Run(tc.goType, func(t *testing.T) {
			got, ok := GoDTypeToMLIR(tc.goType)
			if ok != tc.ok {
				t.Errorf("GoDTypeToMLIR(%q) ok = %v, want %v", tc.goType, ok, tc.ok)
			}
			if got != tc.want {
				t.Errorf("GoDTypeToMLIR(%q) = %q, want %q", tc.goType, got, tc.want)
			}
		})
	}
}

// TestOpConstants verifies key op constants carry the "stablehlo." dialect
// prefix followed by a non-empty op mnemonic.
func TestOpConstants(t *testing.T) {
	ops := map[string]string{
		"OpAdd":         OpAdd,
		"OpSubtract":    OpSubtract,
		"OpMultiply":    OpMultiply,
		"OpDivide":      OpDivide,
		"OpDotGeneral":  OpDotGeneral,
		"OpTranspose":   OpTranspose,
		"OpReshape":     OpReshape,
		"OpBroadcastIn": OpBroadcastIn,
		"OpReduce":      OpReduce,
		"OpExp":         OpExp,
		"OpLog":         OpLog,
		"OpTanh":        OpTanh,
		"OpConstant":    OpConstant,
	}
	const prefix = "stablehlo."
	for name, val := range ops {
		// strings.HasPrefix replaces the previous hand-indexed slice check
		// (val[:10]) guarded by an arbitrary minimum length of 12, which
		// could panic-proof the slice only by accident of the chosen bound.
		if !strings.HasPrefix(val, prefix) {
			t.Errorf("%s = %q, expected stablehlo.* prefix", name, val)
			continue
		}
		// The prefix alone is not a valid op name; require a mnemonic.
		if strings.TrimPrefix(val, prefix) == "" {
			t.Errorf("%s = %q, missing op name after prefix", name, val)
		}
	}
}
Loading