Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 139 additions & 0 deletions internal/stablehlo/emit_reduce.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package stablehlo

import (
"fmt"
"strings"
)

// EmitReduceSum emits a StableHLO reduce with an add body.
//
// input is the SSA name of the operand, inputShape its static shape, axis the
// dimension to reduce, and dtype the MLIR element type (e.g. "f32"). When
// keepDims is true, a reshape is appended so the reduced axis survives with
// size 1.
//
// The generated MLIR has the form:
//
// %result = "stablehlo.reduce"(%input, %init) ({
// ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
// %0 = stablehlo.add %arg0, %arg1 : tensor<f32>
// stablehlo.return %0 : tensor<f32>
// }) {dimensions = array<i64: axis>} : (inputType, tensor<dtype>) -> outputType
//
// Returns the result SSA name and the emitted MLIR text.
func EmitReduceSum(namer *SSANamer, input string, inputShape []int, axis int, keepDims bool, dtype string) (string, string) {
	return emitReduce(namer, input, inputShape, axis, keepDims, dtype, "add")
}

// EmitReduceMax emits a StableHLO reduce with a maximum body.
// The reduction is seeded with a -inf init value (see emitReduce); parameters
// and the keepDims behavior are the same as EmitReduceSum.
// Returns the result SSA name and the emitted MLIR text.
func EmitReduceMax(namer *SSANamer, input string, inputShape []int, axis int, keepDims bool, dtype string) (string, string) {
	return emitReduce(namer, input, inputShape, axis, keepDims, dtype, "maximum")
}

// EmitReduceMean emits a ReduceSum followed by a divide to compute the mean.
//
// The divisor (the size of the reduced axis) is emitted as a dense splat
// constant of the reduction's output type so that both operands of
// stablehlo.divide have identical types — StableHLO elementwise ops have no
// implicit broadcasting, so a scalar tensor<dtype> constant would make the
// divide type-invalid.
//
// NOTE(review): the "%d.0" literal assumes a floating-point dtype
// (f16/bf16/f32/f64); an integer dtype would yield an invalid constant —
// confirm callers never pass integer dtypes.
//
// Returns the final result SSA name and the emitted MLIR text for both ops.
func EmitReduceMean(namer *SSANamer, input string, inputShape []int, axis int, keepDims bool, dtype string) (string, string) {
	sumName, sumMLIR := emitReduce(namer, input, inputShape, axis, keepDims, dtype, "add")

	// Compute the output shape/type of the reduction; both the divisor
	// constant and the divide are emitted at this type.
	outShape := reduceShape(inputShape, axis, keepDims)
	outType := FormatTensorType(outShape, dtype)

	// Emit a constant for the axis size and divide.
	count := inputShape[axis]
	constName := namer.NextName()
	divName := namer.NextName()

	var b strings.Builder
	b.WriteString(sumMLIR)
	// A dense splat (scalar literal with a shaped tensor type) replicates the
	// value across every element, matching the divide's operand type.
	fmt.Fprintf(&b, "%s = stablehlo.constant dense<%d.0> : %s\n", constName, count, outType)
	fmt.Fprintf(&b, "%s = stablehlo.divide %s, %s : %s\n", divName, sumName, constName, outType)
	return divName, b.String()
}

// EmitSoftmax decomposes Softmax into its reduce/elementwise primitives:
//  1. max = ReduceMax(input, axis, keepDims=true)
//  2. shifted = Sub(input, BroadcastInDim(max))
//  3. exp = Exp(shifted)
//  4. sum = ReduceSum(exp, axis, keepDims=true)
//  5. result = Div(exp, BroadcastInDim(sum))
//
// StableHLO elementwise ops (subtract, divide) require operands of identical
// type — there is no implicit broadcasting — so the keepDims reduction
// results (size-1 reduced axis) are explicitly broadcast back to the input
// shape with stablehlo.broadcast_in_dim before they are consumed.
//
// Returns the final result SSA name and the emitted MLIR text.
func EmitSoftmax(namer *SSANamer, input string, inputShape []int, axis int, dtype string) (string, string) {
	inputType := FormatTensorType(inputShape, dtype)
	keepShape := reduceShape(inputShape, axis, true)
	keepType := FormatTensorType(keepShape, dtype)

	// Identity dimension mapping "0, 1, ..., rank-1": each keepDims dimension
	// maps to the same dimension of the full input shape; the size-1 reduced
	// axis is expanded by the broadcast.
	dims := make([]string, len(inputShape))
	for i := range dims {
		dims[i] = fmt.Sprintf("%d", i)
	}
	bcastDims := strings.Join(dims, ", ")

	var b strings.Builder

	// broadcast emits a broadcast_in_dim of a keepDims value back to the
	// input shape and returns the new SSA name.
	broadcast := func(src string) string {
		name := namer.NextName()
		fmt.Fprintf(&b, "%s = \"stablehlo.broadcast_in_dim\"(%s) {broadcast_dimensions = array<i64: %s>} : (%s) -> %s\n",
			name, src, bcastDims, keepType, inputType)
		return name
	}

	// 1. max = ReduceMax(input, axis, keepDims=true)
	maxName, maxMLIR := EmitReduceMax(namer, input, inputShape, axis, true, dtype)
	b.WriteString(maxMLIR)

	// 2. shifted = Sub(input, max), with max broadcast to the input shape.
	maxBcast := broadcast(maxName)
	shiftedName := namer.NextName()
	fmt.Fprintf(&b, "%s = stablehlo.subtract %s, %s : %s\n", shiftedName, input, maxBcast, inputType)

	// 3. exp = Exp(shifted)
	expName := namer.NextName()
	fmt.Fprintf(&b, "%s = stablehlo.exponential %s : %s\n", expName, shiftedName, inputType)

	// 4. sum = ReduceSum(exp, axis, keepDims=true)
	sumName, sumMLIR := EmitReduceSum(namer, expName, inputShape, axis, true, dtype)
	b.WriteString(sumMLIR)

	// 5. result = Div(exp, sum), with sum broadcast to the input shape.
	sumBcast := broadcast(sumName)
	resultName := namer.NextName()
	fmt.Fprintf(&b, "%s = stablehlo.divide %s, %s : %s\n", resultName, expName, sumBcast, inputType)

	return resultName, b.String()
}

// emitReduce generates a StableHLO reduce op with the given reduction body op.
// bodyOp should be "add" for sum or "maximum" for max.
//
// The reduce itself always removes the reduced dimension; when keepDims is
// true a trailing stablehlo.reshape re-inserts it with size 1.
// Returns the result SSA name and the emitted MLIR text.
func emitReduce(namer *SSANamer, input string, inputShape []int, axis int, keepDims bool, dtype, bodyOp string) (string, string) {
	scalarType := fmt.Sprintf("tensor<%s>", dtype)
	inputType := FormatTensorType(inputShape, dtype)
	outShape := reduceShape(inputShape, axis, false) // reduce always removes the dim
	outType := FormatTensorType(outShape, dtype)

	initName := namer.NextName()
	reduceName := namer.NextName()

	var b strings.Builder

	// Emit the init value: the reduction identity — zero for add, -inf for
	// maximum. MLIR hex float literals are width-specific, so the -inf bit
	// pattern must match the element type (0xFF800000 is only valid for f32).
	initVal := "0.0"
	if bodyOp == "maximum" {
		switch dtype {
		case "f64":
			initVal = "0xFFF0000000000000" // -inf, binary64
		case "f16":
			initVal = "0xFC00" // -inf, binary16
		case "bf16":
			initVal = "0xFF80" // -inf, bfloat16
		default:
			initVal = "0xFF800000" // -inf, binary32 (f32)
		}
	}
	fmt.Fprintf(&b, "%s = stablehlo.constant dense<%s> : %s\n", initName, initVal, scalarType)

	// Emit the reduce op with inline region body. The region operates on
	// scalar tensors and must terminate with stablehlo.return.
	fmt.Fprintf(&b, "%s = \"stablehlo.reduce\"(%s, %s) ({\n", reduceName, input, initName)
	fmt.Fprintf(&b, "^bb0(%%arg0: %s, %%arg1: %s):\n", scalarType, scalarType)
	fmt.Fprintf(&b, " %%0 = stablehlo.%s %%arg0, %%arg1 : %s\n", bodyOp, scalarType)
	fmt.Fprintf(&b, " stablehlo.return %%0 : %s\n", scalarType)
	fmt.Fprintf(&b, "}) {dimensions = array<i64: %d>} : (%s, %s) -> %s\n", axis, inputType, scalarType, outType)

	// If keepDims, reshape to insert the size-1 dimension back at axis.
	if keepDims {
		keepShape := reduceShape(inputShape, axis, true)
		keepType := FormatTensorType(keepShape, dtype)
		reshapeName := namer.NextName()
		fmt.Fprintf(&b, "%s = stablehlo.reshape %s : %s -> %s\n", reshapeName, reduceName, outType, keepType)
		return reshapeName, b.String()
	}

	return reduceName, b.String()
}

// reduceShape returns the shape produced by reducing shape along axis.
// With keepDims the reduced axis collapses to size 1; without keepDims the
// axis disappears from the result entirely. The input slice is not mutated.
func reduceShape(shape []int, axis int, keepDims bool) []int {
	if keepDims {
		kept := append([]int(nil), shape...)
		kept[axis] = 1
		return kept
	}
	// Concatenate everything before and after the reduced axis.
	squeezed := make([]int, 0, len(shape)-1)
	squeezed = append(squeezed, shape[:axis]...)
	return append(squeezed, shape[axis+1:]...)
}
213 changes: 213 additions & 0 deletions internal/stablehlo/emit_reduce_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
package stablehlo

import (
"strings"
"testing"
)

// TestEmitReduceSum checks the basic (keepDims=false) sum reduction: result
// name, reduce op, add body, dimensions attribute, and output type.
func TestEmitReduceSum(t *testing.T) {
	namer := &SSANamer{}
	name, mlir := EmitReduceSum(namer, "%input", []int{2, 3}, 1, false, "f32")

	// has avoids repeating the mlir operand in every assertion.
	has := func(sub string) bool { return strings.Contains(mlir, sub) }

	if name != "%v1" {
		t.Errorf("result name = %q, want %%v1", name)
	}
	if !has("stablehlo.reduce") {
		t.Error("missing stablehlo.reduce")
	}
	if !has("stablehlo.add") {
		t.Error("missing stablehlo.add in reduction body")
	}
	if !has("dimensions = array<i64: 1>") {
		t.Error("missing dimensions attribute")
	}
	if !has("tensor<2xf32>") {
		t.Errorf("missing output type tensor<2xf32> in:\n%s", mlir)
	}
}

// TestEmitReduceMax checks the max reduction: maximum body, -inf init value,
// dimensions attribute, and output type for an axis-0 reduce.
func TestEmitReduceMax(t *testing.T) {
	namer := &SSANamer{}
	name, mlir := EmitReduceMax(namer, "%input", []int{4, 5}, 0, false, "f32")

	// has avoids repeating the mlir operand in every assertion.
	has := func(sub string) bool { return strings.Contains(mlir, sub) }

	if name != "%v1" {
		t.Errorf("result name = %q, want %%v1", name)
	}
	if !has("stablehlo.maximum") {
		t.Error("missing stablehlo.maximum in reduction body")
	}
	if !has("0xFF800000") {
		t.Error("missing -inf init value for max reduction")
	}
	if !has("dimensions = array<i64: 0>") {
		t.Error("missing dimensions attribute")
	}
	if !has("tensor<5xf32>") {
		t.Errorf("missing output type tensor<5xf32> in:\n%s", mlir)
	}
}

// TestEmitReduceSumKeepDims checks that keepDims appends a reshape and that
// the final type re-inserts the reduced axis with size 1.
func TestEmitReduceSumKeepDims(t *testing.T) {
	namer := &SSANamer{}
	name, mlir := EmitReduceSum(namer, "%input", []int{2, 3}, 1, true, "f32")

	// %v2: init is %v0, reduce is %v1, the keepDims reshape is %v2.
	if name != "%v2" {
		t.Errorf("result name = %q, want %%v2", name)
	}
	if !strings.Contains(mlir, "stablehlo.reshape") {
		t.Error("missing reshape for keepDims")
	}
	if !strings.Contains(mlir, "tensor<2x1xf32>") {
		t.Errorf("missing keepDims output type tensor<2x1xf32> in:\n%s", mlir)
	}
}

// TestEmitReduceMean checks the sum + constant + divide decomposition of the
// mean along an axis of size 6.
func TestEmitReduceMean(t *testing.T) {
	namer := &SSANamer{}
	name, mlir := EmitReduceMean(namer, "%input", []int{2, 6}, 1, false, "f32")

	if name != "%v3" {
		t.Errorf("result name = %q, want %%v3", name)
	}

	// Each required fragment is paired with its failure message.
	checks := []struct {
		frag string
		msg  string
	}{
		{"stablehlo.add", "missing sum reduction for mean"},
		{"dense<6.0>", "missing divisor constant for mean"},
		{"stablehlo.divide", "missing divide for mean"},
	}
	for _, c := range checks {
		if !strings.Contains(mlir, c.frag) {
			t.Error(c.msg)
		}
	}
}

// TestEmitSoftmax checks that Softmax decomposes into its 5 high-level
// operations. Presence and count are verified in a single pass over the op
// table; the previous version re-tested every substring a second time just to
// rebuild the same count.
func TestEmitSoftmax(t *testing.T) {
	namer := &SSANamer{}
	name, mlir := EmitSoftmax(namer, "%input", []int{2, 3}, 1, "f32")

	if name == "" {
		t.Fatal("result name is empty")
	}

	// Softmax should decompose into exactly 5 logical operations:
	// 1. ReduceMax (reduce + reshape for keepDims)
	// 2. Subtract
	// 3. Exp
	// 4. ReduceSum (reduce + reshape for keepDims)
	// 5. Divide
	ops := []struct {
		name string
		op   string
	}{
		{"ReduceMax", "stablehlo.maximum"},
		{"Subtract", "stablehlo.subtract"},
		{"Exp", "stablehlo.exponential"},
		{"ReduceSum", "stablehlo.add"},
		{"Divide", "stablehlo.divide"},
	}
	highLevelOps := 0
	for _, op := range ops {
		if strings.Contains(mlir, op.op) {
			highLevelOps++
			continue
		}
		t.Errorf("missing %s (%s) in Softmax decomposition", op.name, op.op)
	}
	if highLevelOps != 5 {
		t.Errorf("Softmax decomposition has %d high-level ops, want 5", highLevelOps)
	}
}

// TestEmitSoftmaxMLIRStructure checks the structural shape of the emitted
// MLIR: two reduce regions, two keepDims reshapes, and both tensor types.
func TestEmitSoftmaxMLIRStructure(t *testing.T) {
	namer := &SSANamer{}
	_, mlir := EmitSoftmax(namer, "%input", []int{2, 3}, 1, "f32")

	// One reduce region for the max, one for the sum.
	if reduceCount := strings.Count(mlir, `"stablehlo.reduce"`); reduceCount != 2 {
		t.Errorf("expected 2 stablehlo.reduce ops, got %d", reduceCount)
	}

	// One keepDims reshape after each reduce.
	if reshapeCount := strings.Count(mlir, "stablehlo.reshape"); reshapeCount != 2 {
		t.Errorf("expected 2 reshape ops for keepDims, got %d", reshapeCount)
	}

	// Both the keepDims shape and the full input shape must appear.
	types := []struct {
		typ string
		msg string
	}{
		{"tensor<2x1xf32>", "missing keepDims shape tensor<2x1xf32>"},
		{"tensor<2x3xf32>", "missing full shape tensor<2x3xf32>"},
	}
	for _, tc := range types {
		if !strings.Contains(mlir, tc.typ) {
			t.Error(tc.msg)
		}
	}
}

// TestEmitReduceRegionBody checks the inline region of the reduce op for an
// f64 reduction: scalar block arguments and the stablehlo.return terminator.
func TestEmitReduceRegionBody(t *testing.T) {
	namer := &SSANamer{}
	_, mlir := EmitReduceSum(namer, "%x", []int{3, 4}, 0, false, "f64")

	if frag := "^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>)"; !strings.Contains(mlir, frag) {
		t.Errorf("missing region block args in:\n%s", mlir)
	}
	if frag := "stablehlo.return %0 : tensor<f64>"; !strings.Contains(mlir, frag) {
		t.Errorf("missing stablehlo.return in:\n%s", mlir)
	}
}

// TestEmitReduce3D checks dimension attribute and output shape when reducing
// the last axis of a rank-3 tensor.
func TestEmitReduce3D(t *testing.T) {
	namer := &SSANamer{}
	_, mlir := EmitReduceSum(namer, "%t", []int{2, 3, 4}, 2, false, "f32")

	if found := strings.Contains(mlir, "dimensions = array<i64: 2>"); !found {
		t.Error("wrong dimension for axis=2")
	}
	if found := strings.Contains(mlir, "tensor<2x3xf32>"); !found {
		t.Errorf("wrong output shape, expected tensor<2x3xf32> in:\n%s", mlir)
	}
}

// TestReduceShape table-tests reduceShape across 2D/3D shapes with and
// without keepDims.
func TestReduceShape(t *testing.T) {
	cases := []struct {
		name     string
		shape    []int
		axis     int
		keepDims bool
		want     []int
	}{
		{"remove dim", []int{2, 3}, 1, false, []int{2}},
		{"keep dim", []int{2, 3}, 1, true, []int{2, 1}},
		{"remove first", []int{4, 5}, 0, false, []int{5}},
		{"keep first", []int{4, 5}, 0, true, []int{1, 5}},
		{"3D middle", []int{2, 3, 4}, 1, false, []int{2, 4}},
		{"3D middle keep", []int{2, 3, 4}, 1, true, []int{2, 1, 4}},
	}
	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			got := reduceShape(tt.shape, tt.axis, tt.keepDims)
			if len(got) != len(tt.want) {
				t.Fatalf("reduceShape(%v, %d, %v) = %v, want %v", tt.shape, tt.axis, tt.keepDims, got, tt.want)
			}
			for i, g := range got {
				if g != tt.want[i] {
					t.Errorf("reduceShape(%v, %d, %v)[%d] = %d, want %d", tt.shape, tt.axis, tt.keepDims, i, g, tt.want[i])
				}
			}
		})
	}
}
Loading