From a2a59443875e0c503b42f530199fba90d2adb8d6 Mon Sep 17 00:00:00 2001
From: David Ndungu
Date: Thu, 2 Apr 2026 15:57:03 -0700
Subject: [PATCH 1/2] feat(pjrt): add content-addressed executable cache with LRU eviction

Cache key is SHA256(stablehlo_mlir + platform_name). Serialized PJRT
executables are stored to $ZERFOO_PJRT_CACHE or ~/.cache/zerfoo/pjrt/.
LRU eviction kicks in when total size exceeds configurable max
(default 2 GB). Atomic writes via tmp+rename. Thread-safe via sync.Mutex.

Implements T64.1.1 and T64.1.2.
---
 internal/pjrt/cache.go      | 254 +++++++++++++++++++++++++++++++++++
 internal/pjrt/cache_test.go | 261 ++++++++++++++++++++++++++++++++++++
 2 files changed, 515 insertions(+)
 create mode 100644 internal/pjrt/cache.go
 create mode 100644 internal/pjrt/cache_test.go

diff --git a/internal/pjrt/cache.go b/internal/pjrt/cache.go
new file mode 100644
index 0000000..e9745ba
--- /dev/null
+++ b/internal/pjrt/cache.go
@@ -0,0 +1,254 @@
+package pjrt
+
+import (
+	"crypto/sha256"
+	"encoding/hex"
+	"fmt"
+	"os"
+	"path/filepath"
+	"sort"
+	"sync"
+	"time"
+)
+
+// DefaultCacheDir is the default cache directory, relative to the user's home.
+const DefaultCacheDir = ".cache/zerfoo/pjrt"
+
+// DefaultMaxCacheSize is the default maximum cache size in bytes (2 GiB).
+const DefaultMaxCacheSize int64 = 2 << 30
+
+// CacheOption configures a Cache; NewCache applies options in argument order.
+type CacheOption func(*Cache)
+
+// WithCacheDir sets the cache directory. If dir is empty, NewCache falls back
+// to $ZERFOO_PJRT_CACHE and then to ~/.cache/zerfoo/pjrt/.
+func WithCacheDir(dir string) CacheOption {
+	return func(c *Cache) { c.dir = dir }
+}
+
+// WithMaxCacheSize sets the maximum total size of cached files in bytes.
+func WithMaxCacheSize(n int64) CacheOption {
+	return func(c *Cache) { c.maxSize = n }
+}
+
+// CacheStats holds cache hit/miss/size statistics.
+type CacheStats struct {
+	Hits   int64
+	Misses int64
+	Size   int64 // total bytes on disk
+	Files  int   // number of cached entries
+}
+
+// Cache is an on-disk store of serialized PJRT executables. Entries are
+// addressed by the content hash produced by Key, and the least recently
+// used ones are evicted once the directory grows past maxSize.
+type Cache struct {
+	mu      sync.Mutex
+	dir     string
+	maxSize int64
+	hits    int64
+	misses  int64
+}
+
+// NewCache builds a Cache with DefaultMaxCacheSize as its budget and then
+// applies opts in order. The directory itself is not touched here; it is
+// created by the first Put.
+func NewCache(opts ...CacheOption) *Cache {
+	cache := new(Cache)
+	cache.maxSize = DefaultMaxCacheSize
+	for i := range opts {
+		opts[i](cache)
+	}
+	if cache.dir != "" {
+		return cache
+	}
+	// No explicit directory: fall back to the environment / home default.
+	cache.dir = resolveCacheDir()
+	return cache
+}
+
+// Key derives the cache key for a StableHLO program on a given platform:
+// the hex-encoded SHA256 of program + "\x00" + platform. The NUL byte
+// keeps the two fields from running into each other.
+func Key(stablehloMLIR, platformName string) string {
+	material := make([]byte, 0, len(stablehloMLIR)+1+len(platformName))
+	material = append(material, stablehloMLIR...)
+	material = append(material, 0)
+	material = append(material, platformName...)
+	digest := sha256.Sum256(material)
+	return hex.EncodeToString(digest[:])
+}
+
+// Get returns the raw serialized executable stored under key; the caller
+// is responsible for DeserializeAndLoad. A missing entry is reported as
+// nil, nil. Any other read failure is returned as an error and, like a
+// missing entry, counted as a miss.
+func (c *Cache) Get(key string) ([]byte, error) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	entry := c.entryPath(key)
+	payload, readErr := os.ReadFile(entry)
+	switch {
+	case readErr == nil:
+		// Hit: handled below.
+	case os.IsNotExist(readErr):
+		c.misses++
+		return nil, nil
+	default:
+		c.misses++
+		return nil, fmt.Errorf("pjrt cache: read %s: %w", key, readErr)
+	}
+
+	// Bump the timestamps so eviction treats this entry as recently used.
+	// Failure to do so only affects eviction order, so it is ignored.
+	stamp := time.Now()
+	_ = os.Chtimes(entry, stamp, stamp)
+
+	c.hits++
+	return payload, nil
+}
+
+// Put stores serialized executable bytes under the given key. If storing
+// the new entry would exceed MaxSize, the least-recently-used entries are
+// evicted first.
+func (c *Cache) Put(key string, data []byte) error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if err := os.MkdirAll(c.dir, 0o755); err != nil {
+		return fmt.Errorf("pjrt cache: create dir: %w", err)
+	}
+
+	path := c.entryPath(key)
+
+	// Write atomically: stage the bytes in a temp file next to the entry,
+	// then rename it into place. The staging name carries the pid and a
+	// nanosecond timestamp so that two processes sharing the directory
+	// never write through the same temp file.
+	tmp := fmt.Sprintf("%s.%d.%d.tmp", path, os.Getpid(), time.Now().UnixNano())
+	if err := os.WriteFile(tmp, data, 0o644); err != nil {
+		// A failed write can leave a partial file behind; remove it so it
+		// does not linger in the cache directory. Best effort only.
+		_ = os.Remove(tmp)
+		return fmt.Errorf("pjrt cache: write %s: %w", key, err)
+	}
+	if err := os.Rename(tmp, path); err != nil {
+		_ = os.Remove(tmp)
+		return fmt.Errorf("pjrt cache: rename %s: %w", key, err)
+	}
+
+	// Evict if over budget.
+	c.evictLocked()
+	return nil
+}
+
+// Evict removes the least-recently-used entries until total cache size
+// is within MaxSize.
+func (c *Cache) Evict() {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.evictLocked()
+}
+
+// Clear removes every regular file in the cache directory; subdirectories
+// are left alone. A cache directory that does not exist is already clear
+// and is not an error. If the directory cannot be listed or a file cannot
+// be removed, Clear still tries the remaining files and returns the first
+// error it hit.
+func (c *Cache) Clear() error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	entries, err := os.ReadDir(c.dir)
+	if err != nil {
+		if os.IsNotExist(err) {
+			return nil
+		}
+		return fmt.Errorf("pjrt cache: clear: %w", err)
+	}
+
+	var firstErr error
+	for _, e := range entries {
+		if e.IsDir() {
+			continue
+		}
+		err := os.Remove(filepath.Join(c.dir, e.Name()))
+		// A file that vanished in the meantime is as good as removed.
+		if err != nil && !os.IsNotExist(err) && firstErr == nil {
+			firstErr = fmt.Errorf("pjrt cache: clear: %w", err)
+		}
+	}
+	return firstErr
+}
+
+// Stats returns current cache statistics. Hits and Misses are the counters
+// maintained by Get. Size and Files are recomputed from a directory listing
+// on every call and cover every regular file in the directory; entries
+// whose metadata cannot be read are skipped, and an unreadable or missing
+// directory reports zero for both.
+func (c *Cache) Stats() CacheStats {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	var totalSize int64
+	var fileCount int
+	entries, _ := os.ReadDir(c.dir)
+	for _, e := range entries {
+		if e.IsDir() {
+			continue
+		}
+		info, err := e.Info()
+		if err != nil {
+			continue
+		}
+		totalSize += info.Size()
+		fileCount++
+	}
+
+	return CacheStats{
+		Hits:   c.hits,
+		Misses: c.misses,
+		Size:   totalSize,
+		Files:  fileCount,
+	}
+}
+
+// Dir returns the cache directory path.
+func (c *Cache) Dir() string {
+	return c.dir
+}
+
+// entryPath returns the filesystem path for a cache key: the key with a
+// ".pjrt" suffix inside the cache directory.
+func (c *Cache) entryPath(key string) string {
+	return filepath.Join(c.dir, key+".pjrt")
+}
+
+// cacheEntry holds file info for LRU sorting.
+type cacheEntry struct {
+	path    string
+	size    int64
+	modTime time.Time
+}
+
+// evictLocked removes LRU entries until total size <= maxSize.
+// Caller must hold mu.
+func (c *Cache) evictLocked() {
+	entries, err := os.ReadDir(c.dir)
+	if err != nil {
+		return
+	}
+
+	var files []cacheEntry
+	var totalSize int64
+	for _, e := range entries {
+		if e.IsDir() {
+			continue
+		}
+		info, err := e.Info()
+		if err != nil {
+			continue
+		}
+		files = append(files, cacheEntry{
+			path:    filepath.Join(c.dir, e.Name()),
+			size:    info.Size(),
+			modTime: info.ModTime(),
+		})
+		totalSize += info.Size()
+	}
+
+	if totalSize <= c.maxSize {
+		return
+	}
+
+	// Sort oldest first (least recently used).
+	sort.Slice(files, func(i, j int) bool {
+		return files[i].modTime.Before(files[j].modTime)
+	})
+
+	for _, f := range files {
+		if totalSize <= c.maxSize {
+			break
+		}
+		// A file that is already gone has freed its bytes too; crediting it
+		// keeps the loop from evicting more entries than necessary.
+		if err := os.Remove(f.path); err == nil || os.IsNotExist(err) {
+			totalSize -= f.size
+		}
+	}
+}
+
+// resolveCacheDir returns the cache directory, checking ZERFOO_PJRT_CACHE
+// env var first, then falling back to ~/.cache/zerfoo/pjrt/. If the home
+// directory cannot be determined, a directory under os.TempDir is used.
+func resolveCacheDir() string {
+	if dir := os.Getenv("ZERFOO_PJRT_CACHE"); dir != "" {
+		return dir
+	}
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return filepath.Join(os.TempDir(), "zerfoo-pjrt-cache")
+	}
+	return filepath.Join(home, DefaultCacheDir)
+}
+
diff --git a/internal/pjrt/cache_test.go b/internal/pjrt/cache_test.go
new file mode 100644
index 0000000..3d8f98b
--- /dev/null
+++ b/internal/pjrt/cache_test.go
@@ -0,0 +1,273 @@
+package pjrt
+
+import (
+	"os"
+	"path/filepath"
+	"testing"
+	"time"
+)
+
+func TestCacheKey(t *testing.T) {
+	// Same inputs produce same key.
+	k1 := Key("module { func.func @main() {} }", "cpu")
+	k2 := Key("module { func.func @main() {} }", "cpu")
+	if k1 != k2 {
+		t.Fatalf("same input produced different keys: %s vs %s", k1, k2)
+	}
+
+	// Different programs produce different keys.
+	k3 := Key("module { func.func @other() {} }", "cpu")
+	if k1 == k3 {
+		t.Fatal("different programs produced same key")
+	}
+
+	// Different platforms produce different keys.
+	k4 := Key("module { func.func @main() {} }", "cuda")
+	if k1 == k4 {
+		t.Fatal("different platforms produced same key")
+	}
+
+	// Key is hex-encoded SHA256 (64 chars).
+	if len(k1) != 64 {
+		t.Fatalf("expected 64-char hex key, got %d chars", len(k1))
+	}
+}
+
+func TestCacheMiss(t *testing.T) {
+	dir := t.TempDir()
+	c := NewCache(WithCacheDir(dir))
+
+	data, err := c.Get("nonexistent")
+	if err != nil {
+		t.Fatalf("unexpected error on miss: %v", err)
+	}
+	if data != nil {
+		t.Fatal("expected nil data on cache miss")
+	}
+
+	stats := c.Stats()
+	if stats.Misses != 1 {
+		t.Fatalf("expected 1 miss, got %d", stats.Misses)
+	}
+	if stats.Hits != 0 {
+		t.Fatalf("expected 0 hits, got %d", stats.Hits)
+	}
+}
+
+func TestCachePutGet(t *testing.T) {
+	dir := t.TempDir()
+	c := NewCache(WithCacheDir(dir))
+
+	key := Key("program1", "cpu")
+	payload := []byte("serialized-executable-bytes")
+
+	if err := c.Put(key, payload); err != nil {
+		t.Fatalf("Put: %v", err)
+	}
+
+	data, err := c.Get(key)
+	if err != nil {
+		t.Fatalf("Get: %v", err)
+	}
+	if string(data) != string(payload) {
+		t.Fatalf("Get returned %q, want %q", data, payload)
+	}
+
+	stats := c.Stats()
+	if stats.Hits != 1 {
+		t.Fatalf("expected 1 hit, got %d", stats.Hits)
+	}
+	if stats.Files != 1 {
+		t.Fatalf("expected 1 file, got %d", stats.Files)
+	}
+	if stats.Size != int64(len(payload)) {
+		t.Fatalf("expected size %d, got %d", len(payload), stats.Size)
+	}
+}
+
+func TestCacheLRUEviction(t *testing.T) {
+	dir := t.TempDir()
+	// Max size = 50 bytes. Each entry is 20 bytes.
+	c := NewCache(WithCacheDir(dir), WithMaxCacheSize(50))
+
+	k1 := Key("prog1", "cpu")
+	k2 := Key("prog2", "cpu")
+	k3 := Key("prog3", "cpu")
+
+	d := make([]byte, 20)
+
+	// Put k1 first; its mod time is rewritten below instead of sleeping.
+	if err := c.Put(k1, d); err != nil {
+		t.Fatal(err)
+	}
+	// Backdate k1 so it's the oldest.
+	p1 := c.entryPath(k1)
+	old := time.Now().Add(-2 * time.Second)
+	if err := os.Chtimes(p1, old, old); err != nil {
+		t.Fatal(err)
+	}
+
+	if err := c.Put(k2, d); err != nil {
+		t.Fatal(err)
+	}
+
+	// At this point: 40 bytes total, under 50 limit. Both should exist.
+	stats := c.Stats()
+	if stats.Files != 2 {
+		t.Fatalf("expected 2 files before eviction, got %d", stats.Files)
+	}
+
+	// Put k3: total would be 60 bytes, exceeding 50. k1 (oldest) evicted.
+	if err := c.Put(k3, d); err != nil {
+		t.Fatal(err)
+	}
+
+	stats = c.Stats()
+	if stats.Files != 2 {
+		t.Fatalf("expected 2 files after eviction, got %d", stats.Files)
+	}
+	if stats.Size != 40 {
+		t.Fatalf("expected size 40 after eviction, got %d", stats.Size)
+	}
+
+	// k1 should be evicted.
+	data, _ := c.Get(k1)
+	if data != nil {
+		t.Fatal("expected k1 to be evicted")
+	}
+
+	// k2 and k3 should still exist.
+	data, _ = c.Get(k2)
+	if data == nil {
+		t.Fatal("expected k2 to still exist")
+	}
+	data, _ = c.Get(k3)
+	if data == nil {
+		t.Fatal("expected k3 to still exist")
+	}
+}
+
+func TestCacheClear(t *testing.T) {
+	dir := t.TempDir()
+	c := NewCache(WithCacheDir(dir))
+
+	for i := 0; i < 5; i++ {
+		key := Key(string(rune('a'+i)), "cpu")
+		if err := c.Put(key, []byte("data")); err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	stats := c.Stats()
+	if stats.Files != 5 {
+		t.Fatalf("expected 5 files, got %d", stats.Files)
+	}
+
+	if err := c.Clear(); err != nil {
+		t.Fatalf("Clear: %v", err)
+	}
+
+	stats = c.Stats()
+	if stats.Files != 0 {
+		t.Fatalf("expected 0 files after clear, got %d", stats.Files)
+	}
+}
+
+func TestCacheAtomicWrite(t *testing.T) {
+	dir := t.TempDir()
+	c := NewCache(WithCacheDir(dir))
+
+	key := Key("atomic-test", "cpu")
+	if err := c.Put(key, []byte("good-data")); err != nil {
+		t.Fatal(err)
+	}
+
+	// Verify no .tmp files remain.
+	entries, _ := os.ReadDir(dir)
+	for _, e := range entries {
+		if filepath.Ext(e.Name()) == ".tmp" {
+			t.Fatalf("found leftover .tmp file: %s", e.Name())
+		}
+	}
+
+	data, _ := c.Get(key)
+	if string(data) != "good-data" {
+		t.Fatalf("expected 'good-data', got %q", data)
+	}
+}
+
+func TestCacheOverwrite(t *testing.T) {
+	dir := t.TempDir()
+	c := NewCache(WithCacheDir(dir))
+
+	key := Key("overwrite-test", "cpu")
+	if err := c.Put(key, []byte("v1")); err != nil {
+		t.Fatal(err)
+	}
+	if err := c.Put(key, []byte("v2")); err != nil {
+		t.Fatal(err)
+	}
+
+	data, err := c.Get(key)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if string(data) != "v2" {
+		t.Fatalf("expected 'v2', got %q", data)
+	}
+
+	stats := c.Stats()
+	if stats.Files != 1 {
+		t.Fatalf("expected 1 file after overwrite, got %d", stats.Files)
+	}
+}
+
+func TestCacheEnvVar(t *testing.T) {
+	dir := t.TempDir()
+	t.Setenv("ZERFOO_PJRT_CACHE", dir)
+
+	c := NewCache()
+	if c.Dir() != dir {
+		t.Fatalf("expected dir %s from env, got %s", dir, c.Dir())
+	}
+}
+
+func TestCacheManualEvict(t *testing.T) {
+	dir := t.TempDir()
+	// Start with a budget large enough that Put never evicts on its own.
+	c := NewCache(WithCacheDir(dir), WithMaxCacheSize(100))
+
+	k1 := Key("evict1", "cpu")
+	k2 := Key("evict2", "cpu")
+	if err := c.Put(k1, make([]byte, 15)); err != nil {
+		t.Fatal(err)
+	}
+	// Backdate k1 so it is the least recently used entry.
+	old := time.Now().Add(-2 * time.Second)
+	if err := os.Chtimes(c.entryPath(k1), old, old); err != nil {
+		t.Fatal(err)
+	}
+
+	if err := c.Put(k2, make([]byte, 15)); err != nil {
+		t.Fatal(err)
+	}
+
+	if stats := c.Stats(); stats.Files != 2 || stats.Size != 30 {
+		t.Fatalf("expected 2 files / 30 bytes before evict, got %d files / %d bytes", stats.Files, stats.Size)
+	}
+
+	// Shrink the budget so only one entry fits, then evict manually.
+	c.maxSize = 20
+	c.Evict()
+
+	stats := c.Stats()
+	if stats.Files != 1 || stats.Size != 15 {
+		t.Fatalf("expected 1 file / 15 bytes after manual evict, got %d files / %d bytes", stats.Files, stats.Size)
+	}
+	if data, _ := c.Get(k1); data != nil {
+		t.Fatal("expected k1 (oldest) to be evicted")
+	}
+	if data, _ := c.Get(k2); data == nil {
+		t.Fatal("expected k2 to survive manual evict")
+	}
+}
From f01adf2843515b88b701e6aa36c5ce93209ff61f Mon Sep 17 00:00:00 2001
From: David Ndungu
Date: Thu, 2 Apr 2026 15:57:20 -0700
Subject: [PATCH 2/2] feat(stablehlo): add KV cache I/O rewriting to program emitter

PJRT is pure-functional and cannot handle mutable state. The graph's
StatefulInputNode KV cache feedback must be rewritten as explicit
function I/O: KV cache tensors become both function arguments and
return values.

Add KVCacheSlot type and EmitKVCacheProgram function that:
- Adds KV cache inputs as extra function arguments
- Adds KV cache outputs as extra return values (tuple return)
- For decode programs, emits stablehlo.concatenate to append new KV
  step along the sequence axis
- For prefill programs, passes KV outputs through directly

Implements T61.3.2 from plan-pjrt.md.
---
 internal/stablehlo/program.go      | 181 +++++++++++++++++++++++
 internal/stablehlo/program_test.go | 229 +++++++++++++++++++++++++++++
 2 files changed, 410 insertions(+)

diff --git a/internal/stablehlo/program.go b/internal/stablehlo/program.go
index d38af61..62f15e7 100644
--- a/internal/stablehlo/program.go
+++ b/internal/stablehlo/program.go
@@ -16,6 +16,19 @@ type ProgramOp struct {
 	Attrs map[string]any // op-specific attributes
 }
 
+// KVCacheSlot describes a stateful KV cache slot that must be rewritten as
+// explicit function I/O for PJRT's pure-functional execution model.
+//
+// In the original graph, the KV cache is fed back via StatefulInputNode.
+// For PJRT, each KV cache tensor becomes both a function argument (the
+// previous state) and a return value (the updated state).
+type KVCacheSlot struct {
+	InputSlot  int   // slot index where the KV cache is read (becomes a function arg)
+	OutputSlot int   // slot index where the updated KV cache is produced (becomes a return value)
+	Shape      []int // tensor shape (e.g., [num_heads, seq_len, head_dim])
+	SeqAxis    int   // axis along which decode concatenation occurs; must index into Shape
+}
+
 // EmitProgram takes a sequence of operations and produces a complete StableHLO
 // MLIR module. inputSlots identifies which slots are function arguments, and
 // inputShapes provides their shapes. The last operation's output slot is used
@@ -97,6 +110,192 @@ func EmitProgram(ops []ProgramOp, inputSlots []int, inputShapes [][]int, dtype s
 	return b.String(), nil
 }
 
+// EmitKVCacheProgram emits a StableHLO program with explicit KV cache I/O.
+//
+// KV cache tensors are added as both function arguments and return values.
+// The function signature becomes:
+//
+//	func.func @main(%regular_args..., %kv_in_0, %kv_in_1, ...) ->
+//	    (regular_output, %kv_out_0, %kv_out_1, ...)
+//
+// The regular output is the output slot of the last op in ops. KV arguments
+// are numbered after the regular arguments, in kvSlots order.
+//
+// For decode programs (decode=true), each KV cache output is produced by
+// concatenating the KV input with the new KV step along the sequence axis:
+//
+//	kv_out = concat(kv_in, kv_step, axis=seq_axis)
+//
+// The step is assumed to have the slot's Shape with SeqAxis set to 1, so the
+// returned tensor is one position longer than the input along SeqAxis.
+//
+// For prefill programs (decode=false), KV cache outputs are passed through
+// directly from the ops that produce them.
+//
+// An error is returned when ops or kvSlots is empty, when inputSlots and
+// inputShapes differ in length, when an op or KV slot references an undefined
+// slot, when a decode slot's SeqAxis is outside its Shape, or when emitting
+// an op fails.
+func EmitKVCacheProgram(ops []ProgramOp, inputSlots []int, inputShapes [][]int, kvSlots []KVCacheSlot, dtype string, decode bool) (string, error) {
+	if len(ops) == 0 {
+		return "", fmt.Errorf("EmitKVCacheProgram: no operations provided")
+	}
+	if len(inputSlots) != len(inputShapes) {
+		return "", fmt.Errorf("EmitKVCacheProgram: inputSlots length %d != inputShapes length %d", len(inputSlots), len(inputShapes))
+	}
+	if len(kvSlots) == 0 {
+		return "", fmt.Errorf("EmitKVCacheProgram: no KV cache slots provided")
+	}
+	if decode {
+		// Reject bad axes up front: indexing Shape with an out-of-range SeqAxis
+		// below would panic instead of reporting an error.
+		for i, kv := range kvSlots {
+			if kv.SeqAxis < 0 || kv.SeqAxis >= len(kv.Shape) {
+				return "", fmt.Errorf("EmitKVCacheProgram: KV slot %d: SeqAxis %d out of range for rank %d", i, kv.SeqAxis, len(kv.Shape))
+			}
+		}
+	}
+
+	// Build slot table: slot index -> SSA name.
+	slotTable := make(map[int]string)
+	for i, slot := range inputSlots {
+		slotTable[slot] = fmt.Sprintf("%%arg%d", i)
+	}
+
+	// Add KV cache input slots as additional function arguments.
+	kvArgStart := len(inputSlots)
+	for i, kv := range kvSlots {
+		slotTable[kv.InputSlot] = fmt.Sprintf("%%arg%d", kvArgStart+i)
+	}
+
+	// Build function argument declarations.
+	var argDecls []string
+	for i, shape := range inputShapes {
+		ty := FormatTensorType(shape, dtype)
+		argDecls = append(argDecls, fmt.Sprintf("%%arg%d: %s", i, ty))
+	}
+	for i, kv := range kvSlots {
+		ty := FormatTensorType(kv.Shape, dtype)
+		argDecls = append(argDecls, fmt.Sprintf("%%arg%d: %s", kvArgStart+i, ty))
+	}
+
+	// Emit all ops.
+	emitter := NewEmitter()
+	var bodyLines []string
+
+	for i, op := range ops {
+		inputNames := make([]string, len(op.InputSlots))
+		for j, slot := range op.InputSlots {
+			name, ok := slotTable[slot]
+			if !ok {
+				return "", fmt.Errorf("EmitKVCacheProgram: op %d (%s) references undefined slot %d", i, op.OpName, slot)
+			}
+			inputNames[j] = name
+		}
+
+		mlir, outName, err := dispatchProgramOp(emitter, op, inputNames)
+		if err != nil {
+			return "", fmt.Errorf("EmitKVCacheProgram: op %d (%s): %w", i, op.OpName, err)
+		}
+
+		slotTable[op.OutputSlot] = outName
+		bodyLines = append(bodyLines, mlir)
+	}
+
+	// Build return values: primary output + KV cache outputs.
+	lastOp := ops[len(ops)-1]
+	primaryReturnName, ok := slotTable[lastOp.OutputSlot]
+	if !ok {
+		return "", fmt.Errorf("EmitKVCacheProgram: return slot %d not found", lastOp.OutputSlot)
+	}
+	primaryReturnType := FormatTensorType(lastOp.OutputShape, lastOp.Dtype)
+
+	var returnNames []string
+	var returnTypes []string
+	returnNames = append(returnNames, primaryReturnName)
+	returnTypes = append(returnTypes, primaryReturnType)
+
+	// For each KV slot, resolve its output and optionally emit concat.
+	for i, kv := range kvSlots {
+		kvOutName, ok := slotTable[kv.OutputSlot]
+		if !ok {
+			return "", fmt.Errorf("EmitKVCacheProgram: KV output slot %d not found", kv.OutputSlot)
+		}
+
+		if decode {
+			// Decode: concat(kv_in, kv_step, axis=seq_axis).
+			kvInName := fmt.Sprintf("%%arg%d", kvArgStart+i)
+
+			// Compute the concat shapes: kv_in keeps kv.Shape, kv_step is one
+			// position along seq_axis, so the result grows by one along it.
+			kvStepShape := make([]int, len(kv.Shape))
+			copy(kvStepShape, kv.Shape)
+			kvStepShape[kv.SeqAxis] = 1 // the new step is a single position
+
+			concatOutShape := make([]int, len(kv.Shape))
+			copy(concatOutShape, kv.Shape)
+			concatOutShape[kv.SeqAxis] = kv.Shape[kv.SeqAxis] + 1
+
+			mlir, concatName, err := EmitConcat(
+				emitter.Namer,
+				[]string{kvInName, kvOutName},
+				[][]int{kv.Shape, kvStepShape},
+				kv.SeqAxis,
+				dtype,
+			)
+			if err != nil {
+				return "", fmt.Errorf("EmitKVCacheProgram: KV concat for slot %d: %w", i, err)
+			}
+			bodyLines = append(bodyLines, mlir)
+			kvOutName = concatName
+			returnTypes = append(returnTypes, FormatTensorType(concatOutShape, dtype))
+		} else {
+			// Prefill: pass through the KV output directly.
+			// Find the shape from the op that produces this slot.
+			kvOutShape := findSlotShape(ops, kv.OutputSlot)
+			if kvOutShape == nil {
+				return "", fmt.Errorf("EmitKVCacheProgram: cannot determine shape for KV output slot %d", kv.OutputSlot)
+			}
+			returnTypes = append(returnTypes, FormatTensorType(kvOutShape, dtype))
+		}
+		returnNames = append(returnNames, kvOutName)
+	}
+
+	// Build the module.
+	returnTypeStr := "(" + strings.Join(returnTypes, ", ") + ")"
+	returnValueStr := strings.Join(returnNames, ", ")
+	returnAnnotation := strings.Join(returnTypes, ", ")
+
+	var b strings.Builder
+	b.WriteString("module {\n")
+	fmt.Fprintf(&b, " func.func @main(%s) -> %s {\n", strings.Join(argDecls, ", "), returnTypeStr)
+	for _, line := range bodyLines {
+		for _, subLine := range strings.Split(line, "\n") {
+			if subLine == "" {
+				continue
+			}
+			fmt.Fprintf(&b, " %s\n", subLine)
+		}
+	}
+	fmt.Fprintf(&b, " return %s : %s\n", returnValueStr, returnAnnotation)
+	b.WriteString(" }\n")
+	b.WriteString("}")
+
+	return b.String(), nil
+}
+
+// findSlotShape returns the output shape of the op that produces slot, or nil
+// if no op writes it. When several ops write the same slot the last one wins,
+// matching the slot table, which keeps the most recent SSA name.
+func findSlotShape(ops []ProgramOp, slot int) []int {
+	for i := len(ops) - 1; i >= 0; i-- {
+		if ops[i].OutputSlot == slot {
+			return ops[i].OutputShape
+		}
+	}
+	return nil
+}
+
 // dispatchProgramOp dispatches a ProgramOp to the appropriate emitter.
 // It handles element-wise ops via Emitter.EmitOp and structural/reduce ops
 // via their dedicated emitters in emit_structural.go and emit_reduce.go.
diff --git a/internal/stablehlo/program_test.go b/internal/stablehlo/program_test.go
index 82ae116..8664ea9 100644
--- a/internal/stablehlo/program_test.go
+++ b/internal/stablehlo/program_test.go
@@ -246,6 +246,262 @@ func TestEmitProgram_Reshape(t *testing.T) {
 	}
 }
 
+func TestEmitKVCacheProgram_Prefill(t *testing.T) {
+	// Simulate a simple prefill: one matmul produces the KV cache, the last one the logits.
+	// Slots: 0=input_tokens, 1=weights, 2=kv_weights, 3=matmul_out(logits), 4=kv_out
+	ops := []ProgramOp{
+		{
+			OpName:      "MatMul",
+			InputSlots:  []int{0, 2},
+			OutputSlot:  4,
+			InputShapes: [][]int{{1, 2048}, {2048, 128}},
+			OutputShape: []int{1, 128},
+			Dtype:       DTypeF32,
+		},
+		{
+			OpName:      "MatMul",
+			InputSlots:  []int{0, 1},
+			OutputSlot:  3,
+			InputShapes: [][]int{{1, 2048}, {2048, 32000}},
+			OutputShape: []int{1, 32000},
+			Dtype:       DTypeF32,
+		},
+	}
+
+	kvSlots := []KVCacheSlot{
+		{InputSlot: 5, OutputSlot: 4, Shape: []int{32, 2048, 128}, SeqAxis: 1},
+	}
+
+	mlir, err := EmitKVCacheProgram(ops, []int{0, 1, 2}, [][]int{{1, 2048}, {2048, 32000}, {2048, 128}}, kvSlots, DTypeF32, false)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Should have tuple return type: (logits, kv_cache).
+	if !strings.Contains(mlir, "-> (") {
+		t.Errorf("expected tuple return type:\n%s", mlir)
+	}
+
+	// The primary result is the last op's output (logits); the KV output follows.
+	if !strings.Contains(mlir, "-> (tensor<1x32000xf32>, tensor<1x128xf32>)") {
+		t.Errorf("expected (logits, kv) return types:\n%s", mlir)
+	}
+
+	// KV cache input should appear as a function argument.
+	if !strings.Contains(mlir, "%arg3:") {
+		t.Errorf("expected KV cache input arg (%%arg3):\n%s", mlir)
+	}
+
+	// Return should have two values.
+	lines := strings.Split(mlir, "\n")
+	var returnLine string
+	for _, l := range lines {
+		if strings.Contains(l, "return ") {
+			returnLine = l
+			break
+		}
+	}
+	if returnLine == "" {
+		t.Fatal("no return statement found")
+	}
+	// Should return two comma-separated values.
+	returnParts := strings.SplitN(returnLine, ":", 2)
+	if len(returnParts) < 2 {
+		t.Fatalf("malformed return line: %s", returnLine)
+	}
+	returnValues := strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(returnParts[0]), "return"))
+	commaCount := strings.Count(returnValues, ",")
+	if commaCount != 1 {
+		t.Errorf("expected 2 return values (1 comma), got %d commas in %q", commaCount, returnValues)
+	}
+}
+
+func TestEmitKVCacheProgram_Decode(t *testing.T) {
+	// Decode program: single token + KV cache -> logits + updated KV cache.
+	// Slots: 0=token, 1=weights, 2=matmul_out(logits), 3=kv_step
+	ops := []ProgramOp{
+		{
+			OpName:      "MatMul",
+			InputSlots:  []int{0, 1},
+			OutputSlot:  3,
+			InputShapes: [][]int{{1, 128}, {128, 128}},
+			OutputShape: []int{1, 128},
+			Dtype:       DTypeF32,
+		},
+		{
+			OpName:      "MatMul",
+			InputSlots:  []int{0, 1},
+			OutputSlot:  2,
+			InputShapes: [][]int{{1, 128}, {128, 32000}},
+			OutputShape: []int{1, 32000},
+			Dtype:       DTypeF32,
+		},
+	}
+
+	kvSlots := []KVCacheSlot{
+		{InputSlot: 10, OutputSlot: 3, Shape: []int{32, 64, 128}, SeqAxis: 1},
+	}
+
+	mlir, err := EmitKVCacheProgram(ops, []int{0, 1}, [][]int{{1, 128}, {128, 32000}}, kvSlots, DTypeF32, true)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Decode should emit a concatenate for KV cache.
+	if !strings.Contains(mlir, "stablehlo.concatenate") {
+		t.Errorf("decode program should contain stablehlo.concatenate:\n%s", mlir)
+	}
+
+	// KV cache arg should be present.
+	if !strings.Contains(mlir, "%arg2: tensor<32x64x128xf32>") {
+		t.Errorf("expected KV cache input arg with shape 32x64x128:\n%s", mlir)
+	}
+
+	// Return should have (logits, updated_kv) = 2 values.
+	lines := strings.Split(mlir, "\n")
+	var returnLine string
+	for _, l := range lines {
+		if strings.Contains(l, "return ") {
+			returnLine = l
+			break
+		}
+	}
+	if returnLine == "" {
+		t.Fatal("no return statement found")
+	}
+	returnParts := strings.SplitN(returnLine, ":", 2)
+	if len(returnParts) < 2 {
+		t.Fatalf("malformed return line: %s", returnLine)
+	}
+	returnValues := strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(returnParts[0]), "return"))
+	commaCount := strings.Count(returnValues, ",")
+	if commaCount != 1 {
+		t.Errorf("expected 2 return values (1 comma), got %d commas in %q", commaCount, returnValues)
+	}
+
+	// The updated KV shape should be seq_len+1 along the seq axis.
+	if !strings.Contains(mlir, "tensor<32x65x128xf32>") {
+		t.Errorf("expected updated KV cache shape 32x65x128 (seq_len 64+1):\n%s", mlir)
+	}
+}
+
+func TestEmitKVCacheProgram_MultiLayer(t *testing.T) {
+	// Two KV cache layers (like a 2-layer transformer).
+	ops := []ProgramOp{
+		{
+			OpName:      "Add",
+			InputSlots:  []int{0, 1},
+			OutputSlot:  4,
+			InputShapes: [][]int{{1, 64}, {1, 64}},
+			OutputShape: []int{1, 64},
+			Dtype:       DTypeF32,
+		},
+		{
+			OpName:      "Add",
+			InputSlots:  []int{0, 1},
+			OutputSlot:  5,
+			InputShapes: [][]int{{1, 64}, {1, 64}},
+			OutputShape: []int{1, 64},
+			Dtype:       DTypeF32,
+		},
+	}
+
+	kvSlots := []KVCacheSlot{
+		{InputSlot: 10, OutputSlot: 4, Shape: []int{8, 32, 64}, SeqAxis: 1},
+		{InputSlot: 11, OutputSlot: 5, Shape: []int{8, 32, 64}, SeqAxis: 1},
+	}
+
+	mlir, err := EmitKVCacheProgram(ops, []int{0, 1}, [][]int{{1, 64}, {1, 64}}, kvSlots, DTypeF32, false)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Should have 4 function args: 2 regular + 2 KV cache.
+	for _, arg := range []string{"%arg0:", "%arg1:", "%arg2:", "%arg3:"} {
+		if !strings.Contains(mlir, arg) {
+			t.Errorf("expected arg %s in signature:\n%s", arg, mlir)
+		}
+	}
+
+	// Return should have 3 values: primary + 2 KV outputs.
+	lines := strings.Split(mlir, "\n")
+	var returnLine string
+	for _, l := range lines {
+		if strings.Contains(l, "return ") {
+			returnLine = l
+			break
+		}
+	}
+	if returnLine == "" {
+		t.Fatal("no return statement found")
+	}
+	returnParts := strings.SplitN(returnLine, ":", 2)
+	returnValues := strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(returnParts[0]), "return"))
+	commaCount := strings.Count(returnValues, ",")
+	if commaCount != 2 {
+		t.Errorf("expected 3 return values (2 commas), got %d commas in %q", commaCount, returnValues)
+	}
+}
+
+func TestEmitKVCacheProgram_ErrorCases(t *testing.T) {
+	baseOps := []ProgramOp{
+		{OpName: "Add", InputSlots: []int{0, 1}, OutputSlot: 2,
+			InputShapes: [][]int{{2}, {2}}, OutputShape: []int{2}, Dtype: DTypeF32},
+	}
+
+	t.Run("no ops", func(t *testing.T) {
+		_, err := EmitKVCacheProgram(nil, nil, nil, []KVCacheSlot{{InputSlot: 0, OutputSlot: 1, Shape: []int{2}}}, DTypeF32, false)
+		if err == nil {
+			t.Fatal("expected error for empty ops")
+		}
+	})
+
+	t.Run("no kv slots", func(t *testing.T) {
+		_, err := EmitKVCacheProgram(baseOps, []int{0, 1}, [][]int{{2}, {2}}, nil, DTypeF32, false)
+		if err == nil {
+			t.Fatal("expected error for nil KV slots")
+		}
+	})
+
+	t.Run("mismatched input slots and shapes", func(t *testing.T) {
+		_, err := EmitKVCacheProgram(baseOps, []int{0}, [][]int{{2}, {2}},
+			[]KVCacheSlot{{InputSlot: 3, OutputSlot: 2, Shape: []int{2}}}, DTypeF32, false)
+		if err == nil {
+			t.Fatal("expected error for mismatched slots/shapes")
+		}
+	})
+
+	t.Run("undefined kv output slot", func(t *testing.T) {
+		_, err := EmitKVCacheProgram(baseOps, []int{0, 1}, [][]int{{2}, {2}},
+			[]KVCacheSlot{{InputSlot: 3, OutputSlot: 99, Shape: []int{2}}}, DTypeF32, false)
+		if err == nil {
+			t.Fatal("expected error for undefined KV output slot")
+		}
+	})
+
+	t.Run("decode seq axis out of range", func(t *testing.T) {
+		_, err := EmitKVCacheProgram(baseOps, []int{0, 1}, [][]int{{2}, {2}},
+			[]KVCacheSlot{{InputSlot: 3, OutputSlot: 2, Shape: []int{2}, SeqAxis: 1}}, DTypeF32, true)
+		if err == nil {
+			t.Fatal("expected error for out-of-range SeqAxis")
+		}
+	})
+}
+
+func TestFindSlotShape_LastWriterWins(t *testing.T) {
+	ops := []ProgramOp{
+		{OpName: "Add", OutputSlot: 2, OutputShape: []int{2}},
+		{OpName: "Add", OutputSlot: 2, OutputShape: []int{4}},
+	}
+	got := findSlotShape(ops, 2)
+	if len(got) != 1 || got[0] != 4 {
+		t.Fatalf("findSlotShape returned %v, want [4]", got)
+	}
+	if findSlotShape(ops, 99) != nil {
+		t.Fatal("expected nil for unknown slot")
+	}
+}
+
 func TestEmitProgram_ReduceSum(t *testing.T) {
 	ops := []ProgramOp{
 		{