diff --git a/internal/stablehlo/emit_reduce.go b/internal/stablehlo/emit_reduce.go new file mode 100644 index 0000000..8cbd4ee --- /dev/null +++ b/internal/stablehlo/emit_reduce.go @@ -0,0 +1,139 @@ +package stablehlo + +import ( + "fmt" + "strings" +) + +// EmitReduceSum emits a StableHLO reduce with an add body. +// +// The generated MLIR has the form: +// +// %result = "stablehlo.reduce"(%input, %init) ({ +// ^bb0(%arg0: tensor, %arg1: tensor): +// %0 = stablehlo.add %arg0, %arg1 : tensor +// stablehlo.return %0 : tensor +// }) {dimensions = array} : (inputType, tensor) -> outputType +func EmitReduceSum(namer *SSANamer, input string, inputShape []int, axis int, keepDims bool, dtype string) (string, string) { + return emitReduce(namer, input, inputShape, axis, keepDims, dtype, "add") +} + +// EmitReduceMax emits a StableHLO reduce with a maximum body. +func EmitReduceMax(namer *SSANamer, input string, inputShape []int, axis int, keepDims bool, dtype string) (string, string) { + return emitReduce(namer, input, inputShape, axis, keepDims, dtype, "maximum") +} + +// EmitReduceMean emits a ReduceSum followed by a DivScalar to compute the mean. +// Returns the final result SSA name and the emitted MLIR text for both ops. +func EmitReduceMean(namer *SSANamer, input string, inputShape []int, axis int, keepDims bool, dtype string) (string, string) { + sumName, sumMLIR := emitReduce(namer, input, inputShape, axis, keepDims, dtype, "add") + + // Compute the output shape of the reduction. + outShape := reduceShape(inputShape, axis, keepDims) + outType := FormatTensorType(outShape, dtype) + + // Emit a constant for the axis size and divide. + count := inputShape[axis] + constName := namer.NextName() + divName := namer.NextName() + + var b strings.Builder + b.WriteString(sumMLIR) + fmt.Fprintf(&b, "%s = stablehlo.constant dense<%d.0> : tensor<%s>\n", constName, count, dtype) + fmt.Fprintf(&b, "%s = stablehlo.divide %s, %s : %s\n", divName, sumName, constName, outType) + return divName, b.String() +} + +// EmitSoftmax decomposes Softmax into 5 StableHLO operations: +// 1. max = ReduceMax(input, axis, keepDims=true) +// 2. shifted = Sub(input, max) +// 3. exp = Exp(shifted) +// 4. sum = ReduceSum(exp, axis, keepDims=true) +// 5. result = Div(exp, sum) +// +// Returns the final result SSA name and the emitted MLIR text. +func EmitSoftmax(namer *SSANamer, input string, inputShape []int, axis int, dtype string) (string, string) { + inputType := FormatTensorType(inputShape, dtype) + + var b strings.Builder + + // 1. max = ReduceMax(input, axis, keepDims=true) + maxName, maxMLIR := EmitReduceMax(namer, input, inputShape, axis, true, dtype) + b.WriteString(maxMLIR) + + // 2. shifted = Sub(input, max) -- broadcast max back to input shape + shiftedName := namer.NextName() + fmt.Fprintf(&b, "%s = stablehlo.subtract %s, %s : %s\n", shiftedName, input, maxName, inputType) + + // 3. exp = Exp(shifted) + expName := namer.NextName() + fmt.Fprintf(&b, "%s = stablehlo.exponential %s : %s\n", expName, shiftedName, inputType) + + // 4. sum = ReduceSum(exp, axis, keepDims=true) + sumName, sumMLIR := EmitReduceSum(namer, expName, inputShape, axis, true, dtype) + b.WriteString(sumMLIR) + + // 5. result = Div(exp, sum) + resultName := namer.NextName() + fmt.Fprintf(&b, "%s = stablehlo.divide %s, %s : %s\n", resultName, expName, sumName, inputType) + + return resultName, b.String() +} + +// emitReduce generates a StableHLO reduce op with the given reduction body op. +// bodyOp should be "add" for sum or "maximum" for max. +func emitReduce(namer *SSANamer, input string, inputShape []int, axis int, keepDims bool, dtype, bodyOp string) (string, string) { + scalarType := fmt.Sprintf("tensor<%s>", dtype) + inputType := FormatTensorType(inputShape, dtype) + outShape := reduceShape(inputShape, axis, false) // reduce always removes the dim + outType := FormatTensorType(outShape, dtype) + + initName := namer.NextName() + reduceName := namer.NextName() + + var b strings.Builder + + // Emit the init value (zero for add, -inf for maximum). + initVal := "0.0" + if bodyOp == "maximum" { + initVal = "0xFF800000" + } + fmt.Fprintf(&b, "%s = stablehlo.constant dense<%s> : %s\n", initName, initVal, scalarType) + + // Emit the reduce op with inline region body. + fmt.Fprintf(&b, "%s = \"stablehlo.reduce\"(%s, %s) ({\n", reduceName, input, initName) + fmt.Fprintf(&b, "^bb0(%%arg0: %s, %%arg1: %s):\n", scalarType, scalarType) + fmt.Fprintf(&b, " %%0 = stablehlo.%s %%arg0, %%arg1 : %s\n", bodyOp, scalarType) + fmt.Fprintf(&b, " stablehlo.return %%0 : %s\n", scalarType) + fmt.Fprintf(&b, "}) {dimensions = array} : (%s, %s) -> %s\n", axis, inputType, scalarType, outType) + + // If keepDims, reshape to insert size-1 dimension at axis. + if keepDims { + keepShape := reduceShape(inputShape, axis, true) + keepType := FormatTensorType(keepShape, dtype) + reshapeName := namer.NextName() + fmt.Fprintf(&b, "%s = stablehlo.reshape %s : %s -> %s\n", reshapeName, reduceName, outType, keepType) + return reshapeName, b.String() + } + + return reduceName, b.String() +} + +// reduceShape computes the output shape after reducing along axis. +// If keepDims is true, the reduced axis becomes size 1. +// If keepDims is false, the reduced axis is removed. +func reduceShape(shape []int, axis int, keepDims bool) []int { + if keepDims { + out := make([]int, len(shape)) + copy(out, shape) + out[axis] = 1 + return out + } + out := make([]int, 0, len(shape)-1) + for i, d := range shape { + if i != axis { + out = append(out, d) + } + } + return out +} diff --git a/internal/stablehlo/emit_reduce_test.go b/internal/stablehlo/emit_reduce_test.go new file mode 100644 index 0000000..db1db20 --- /dev/null +++ b/internal/stablehlo/emit_reduce_test.go @@ -0,0 +1,213 @@ +package stablehlo + +import ( + "strings" + "testing" +) + +func TestEmitReduceSum(t *testing.T) { + namer := &SSANamer{} + name, mlir := EmitReduceSum(namer, "%input", []int{2, 3}, 1, false, "f32") + + if name != "%v1" { + t.Errorf("result name = %q, want %%v1", name) + } + if !strings.Contains(mlir, "stablehlo.reduce") { + t.Error("missing stablehlo.reduce") + } + if !strings.Contains(mlir, "stablehlo.add") { + t.Error("missing stablehlo.add in reduction body") + } + if !strings.Contains(mlir, "dimensions = array") { + t.Error("missing dimensions attribute") + } + if !strings.Contains(mlir, "tensor<2xf32>") { + t.Errorf("missing output type tensor<2xf32> in:\n%s", mlir) + } +} + +func TestEmitReduceMax(t *testing.T) { + namer := &SSANamer{} + name, mlir := EmitReduceMax(namer, "%input", []int{4, 5}, 0, false, "f32") + + if name != "%v1" { + t.Errorf("result name = %q, want %%v1", name) + } + if !strings.Contains(mlir, "stablehlo.maximum") { + t.Error("missing stablehlo.maximum in reduction body") + } + if !strings.Contains(mlir, "0xFF800000") { + t.Error("missing -inf init value for max reduction") + } + if !strings.Contains(mlir, "dimensions = array") { + t.Error("missing dimensions attribute") + } + if !strings.Contains(mlir, "tensor<5xf32>") { + t.Errorf("missing output type tensor<5xf32> in:\n%s", mlir) + } +} + +func TestEmitReduceSumKeepDims(t *testing.T) { + namer := &SSANamer{} + name, mlir := EmitReduceSum(namer, "%input", []int{2, 3}, 1, true, "f32") + + if name != "%v2" { + t.Errorf("result name = %q, want %%v2", name) + } + if !strings.Contains(mlir, "stablehlo.reshape") { + t.Error("missing reshape for keepDims") + } + if !strings.Contains(mlir, "tensor<2x1xf32>") { + t.Errorf("missing keepDims output type tensor<2x1xf32> in:\n%s", mlir) + } +} + +func TestEmitReduceMean(t *testing.T) { + namer := &SSANamer{} + name, mlir := EmitReduceMean(namer, "%input", []int{2, 6}, 1, false, "f32") + + if name != "%v3" { + t.Errorf("result name = %q, want %%v3", name) + } + if !strings.Contains(mlir, "stablehlo.add") { + t.Error("missing sum reduction for mean") + } + if !strings.Contains(mlir, "dense<6.0>") { + t.Error("missing divisor constant for mean") + } + if !strings.Contains(mlir, "stablehlo.divide") { + t.Error("missing divide for mean") + } +} + +func TestEmitSoftmax(t *testing.T) { + namer := &SSANamer{} + name, mlir := EmitSoftmax(namer, "%input", []int{2, 3}, 1, "f32") + + if name == "" { + t.Fatal("result name is empty") + } + + // Softmax should decompose into exactly 5 logical operations: + // 1. ReduceMax (reduce + reshape for keepDims) + // 2. Subtract + // 3. Exp + // 4. ReduceSum (reduce + reshape for keepDims) + // 5. Divide + ops := []struct { + name string + op string + }{ + {"ReduceMax", "stablehlo.maximum"}, + {"Subtract", "stablehlo.subtract"}, + {"Exp", "stablehlo.exponential"}, + {"ReduceSum", "stablehlo.add"}, + {"Divide", "stablehlo.divide"}, + } + for _, op := range ops { + if !strings.Contains(mlir, op.op) { + t.Errorf("missing %s (%s) in Softmax decomposition", op.name, op.op) + } + } + + // Count the 5 high-level ops by counting the distinct operation types. + highLevelOps := 0 + if strings.Contains(mlir, "stablehlo.maximum") { + highLevelOps++ + } + if strings.Contains(mlir, "stablehlo.subtract") { + highLevelOps++ + } + if strings.Contains(mlir, "stablehlo.exponential") { + highLevelOps++ + } + if strings.Contains(mlir, "stablehlo.add") { + highLevelOps++ + } + if strings.Contains(mlir, "stablehlo.divide") { + highLevelOps++ + } + if highLevelOps != 5 { + t.Errorf("Softmax decomposition has %d high-level ops, want 5", highLevelOps) + } +} + +func TestEmitSoftmaxMLIRStructure(t *testing.T) { + namer := &SSANamer{} + _, mlir := EmitSoftmax(namer, "%input", []int{2, 3}, 1, "f32") + + // Verify the MLIR contains two reduce regions (one for max, one for sum). + reduceCount := strings.Count(mlir, `"stablehlo.reduce"`) + if reduceCount != 2 { + t.Errorf("expected 2 stablehlo.reduce ops, got %d", reduceCount) + } + + // Verify two reshape ops (keepDims for max and sum). + reshapeCount := strings.Count(mlir, "stablehlo.reshape") + if reshapeCount != 2 { + t.Errorf("expected 2 reshape ops for keepDims, got %d", reshapeCount) + } + + // Verify the output types contain the broadcast shapes. + if !strings.Contains(mlir, "tensor<2x1xf32>") { + t.Error("missing keepDims shape tensor<2x1xf32>") + } + if !strings.Contains(mlir, "tensor<2x3xf32>") { + t.Error("missing full shape tensor<2x3xf32>") + } +} + +func TestEmitReduceRegionBody(t *testing.T) { + namer := &SSANamer{} + _, mlir := EmitReduceSum(namer, "%x", []int{3, 4}, 0, false, "f64") + + // Verify region body structure. + if !strings.Contains(mlir, "^bb0(%arg0: tensor, %arg1: tensor)") { + t.Errorf("missing region block args in:\n%s", mlir) + } + if !strings.Contains(mlir, "stablehlo.return %0 : tensor") { + t.Errorf("missing stablehlo.return in:\n%s", mlir) + } +} + +func TestEmitReduce3D(t *testing.T) { + namer := &SSANamer{} + _, mlir := EmitReduceSum(namer, "%t", []int{2, 3, 4}, 2, false, "f32") + + if !strings.Contains(mlir, "dimensions = array") { + t.Error("wrong dimension for axis=2") + } + if !strings.Contains(mlir, "tensor<2x3xf32>") { + t.Errorf("wrong output shape, expected tensor<2x3xf32> in:\n%s", mlir) + } +} + +func TestReduceShape(t *testing.T) { + tests := []struct { + name string + shape []int + axis int + keepDims bool + want []int + }{ + {"remove dim", []int{2, 3}, 1, false, []int{2}}, + {"keep dim", []int{2, 3}, 1, true, []int{2, 1}}, + {"remove first", []int{4, 5}, 0, false, []int{5}}, + {"keep first", []int{4, 5}, 0, true, []int{1, 5}}, + {"3D middle", []int{2, 3, 4}, 1, false, []int{2, 4}}, + {"3D middle keep", []int{2, 3, 4}, 1, true, []int{2, 1, 4}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := reduceShape(tt.shape, tt.axis, tt.keepDims) + if len(got) != len(tt.want) { + t.Fatalf("reduceShape(%v, %d, %v) = %v, want %v", tt.shape, tt.axis, tt.keepDims, got, tt.want) + } + for i := range got { + if got[i] != tt.want[i] { + t.Errorf("reduceShape(%v, %d, %v)[%d] = %d, want %d", tt.shape, tt.axis, tt.keepDims, i, got[i], tt.want[i]) + } + } + }) + } +}