diff --git a/internal/pjrt/cache.go b/internal/pjrt/cache.go new file mode 100644 index 0000000..e9745ba --- /dev/null +++ b/internal/pjrt/cache.go @@ -0,0 +1,254 @@ +package pjrt + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "os" + "path/filepath" + "sort" + "sync" + "time" +) + +// DefaultCacheDir is the default directory for cached PJRT executables. +const DefaultCacheDir = ".cache/zerfoo/pjrt" + +// DefaultMaxCacheSize is the default maximum cache size in bytes (2 GB). +const DefaultMaxCacheSize int64 = 2 << 30 + +// CacheOption configures a Cache. +type CacheOption func(*Cache) + +// WithCacheDir sets the cache directory. If empty, defaults to +// $ZERFOO_PJRT_CACHE or ~/.cache/zerfoo/pjrt/. +func WithCacheDir(dir string) CacheOption { + return func(c *Cache) { c.dir = dir } +} + +// WithMaxCacheSize sets the maximum total size of cached files in bytes. +func WithMaxCacheSize(n int64) CacheOption { + return func(c *Cache) { c.maxSize = n } +} + +// CacheStats holds cache hit/miss/size statistics. +type CacheStats struct { + Hits int64 + Misses int64 + Size int64 // total bytes on disk + Files int // number of cached entries +} + +// Cache stores serialized PJRT executables keyed by a content hash of +// the StableHLO program text and platform name. It provides LRU eviction +// when the total size exceeds MaxSize. +type Cache struct { + mu sync.Mutex + dir string + maxSize int64 + hits int64 + misses int64 +} + +// NewCache creates a new executable cache. The cache directory is created +// on first Put if it does not already exist. +func NewCache(opts ...CacheOption) *Cache { + c := &Cache{maxSize: DefaultMaxCacheSize} + for _, o := range opts { + o(c) + } + if c.dir == "" { + c.dir = resolveCacheDir() + } + return c +} + +// Key returns the content-addressed cache key for the given StableHLO +// program and platform name: SHA256(program + "\x00" + platform). +func Key(stablehloMLIR, platformName string) string { + h := sha256.New() + h.Write([]byte(stablehloMLIR)) + h.Write([]byte{0}) + h.Write([]byte(platformName)) + return hex.EncodeToString(h.Sum(nil)) +} + +// Get looks up a cached serialized executable by key. If found, the raw +// bytes are returned (caller must DeserializeAndLoad). Returns nil, nil +// on cache miss. +func (c *Cache) Get(key string) ([]byte, error) { + c.mu.Lock() + defer c.mu.Unlock() + + path := c.entryPath(key) + data, err := os.ReadFile(path) + if os.IsNotExist(err) { + c.misses++ + return nil, nil + } + if err != nil { + c.misses++ + return nil, fmt.Errorf("pjrt cache: read %s: %w", key, err) + } + + // Touch access time for LRU tracking. + now := time.Now() + _ = os.Chtimes(path, now, now) + + c.hits++ + return data, nil +} + +// Put stores serialized executable bytes under the given key. If storing +// the new entry would exceed MaxSize, the least-recently-used entries are +// evicted first. +func (c *Cache) Put(key string, data []byte) error { + c.mu.Lock() + defer c.mu.Unlock() + + if err := os.MkdirAll(c.dir, 0o755); err != nil { + return fmt.Errorf("pjrt cache: create dir: %w", err) + } + + path := c.entryPath(key) + + // Write atomically: write to tmp then rename. + tmp := path + ".tmp" + if err := os.WriteFile(tmp, data, 0o644); err != nil { + return fmt.Errorf("pjrt cache: write %s: %w", key, err) + } + if err := os.Rename(tmp, path); err != nil { + _ = os.Remove(tmp) + return fmt.Errorf("pjrt cache: rename %s: %w", key, err) + } + + // Evict if over budget. + c.evictLocked() + return nil +} + +// Evict removes the least-recently-used entries until total cache size +// is within MaxSize. +func (c *Cache) Evict() { + c.mu.Lock() + defer c.mu.Unlock() + c.evictLocked() +} + +// Clear removes all cached entries. +func (c *Cache) Clear() error { + c.mu.Lock() + defer c.mu.Unlock() + + entries, _ := os.ReadDir(c.dir) + for _, e := range entries { + if e.IsDir() { + continue + } + _ = os.Remove(filepath.Join(c.dir, e.Name())) + } + return nil +} + +// Stats returns current cache statistics. +func (c *Cache) Stats() CacheStats { + c.mu.Lock() + defer c.mu.Unlock() + + var totalSize int64 + var fileCount int + entries, _ := os.ReadDir(c.dir) + for _, e := range entries { + if e.IsDir() { + continue + } + info, err := e.Info() + if err != nil { + continue + } + totalSize += info.Size() + fileCount++ + } + + return CacheStats{ + Hits: c.hits, + Misses: c.misses, + Size: totalSize, + Files: fileCount, + } +} + +// Dir returns the cache directory path. +func (c *Cache) Dir() string { + return c.dir +} + +// entryPath returns the filesystem path for a cache key. +func (c *Cache) entryPath(key string) string { + return filepath.Join(c.dir, key+".pjrt") +} + +// cacheEntry holds file info for LRU sorting. +type cacheEntry struct { + path string + size int64 + modTime time.Time +} + +// evictLocked removes LRU entries until total size <= maxSize. Caller must hold mu. +func (c *Cache) evictLocked() { + entries, err := os.ReadDir(c.dir) + if err != nil { + return + } + + var files []cacheEntry + var totalSize int64 + for _, e := range entries { + if e.IsDir() { + continue + } + info, err := e.Info() + if err != nil { + continue + } + files = append(files, cacheEntry{ + path: filepath.Join(c.dir, e.Name()), + size: info.Size(), + modTime: info.ModTime(), + }) + totalSize += info.Size() + } + + if totalSize <= c.maxSize { + return + } + + // Sort oldest first (least recently used). + sort.Slice(files, func(i, j int) bool { + return files[i].modTime.Before(files[j].modTime) + }) + + for _, f := range files { + if totalSize <= c.maxSize { + break + } + if err := os.Remove(f.path); err == nil { + totalSize -= f.size + } + } +} + +// resolveCacheDir returns the cache directory, checking ZERFOO_PJRT_CACHE +// env var first, then falling back to ~/.cache/zerfoo/pjrt/. +func resolveCacheDir() string { + if dir := os.Getenv("ZERFOO_PJRT_CACHE"); dir != "" { + return dir + } + home, err := os.UserHomeDir() + if err != nil { + return filepath.Join(os.TempDir(), "zerfoo-pjrt-cache") + } + return filepath.Join(home, DefaultCacheDir) +} + diff --git a/internal/pjrt/cache_test.go b/internal/pjrt/cache_test.go new file mode 100644 index 0000000..3d8f98b --- /dev/null +++ b/internal/pjrt/cache_test.go @@ -0,0 +1,261 @@ +package pjrt + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestCacheKey(t *testing.T) { + // Same inputs produce same key. + k1 := Key("module { func.func @main() {} }", "cpu") + k2 := Key("module { func.func @main() {} }", "cpu") + if k1 != k2 { + t.Fatalf("same input produced different keys: %s vs %s", k1, k2) + } + + // Different programs produce different keys. + k3 := Key("module { func.func @other() {} }", "cpu") + if k1 == k3 { + t.Fatal("different programs produced same key") + } + + // Different platforms produce different keys. + k4 := Key("module { func.func @main() {} }", "cuda") + if k1 == k4 { + t.Fatal("different platforms produced same key") + } + + // Key is hex-encoded SHA256 (64 chars). + if len(k1) != 64 { + t.Fatalf("expected 64-char hex key, got %d chars", len(k1)) + } +} + +func TestCacheMiss(t *testing.T) { + dir := t.TempDir() + c := NewCache(WithCacheDir(dir)) + + data, err := c.Get("nonexistent") + if err != nil { + t.Fatalf("unexpected error on miss: %v", err) + } + if data != nil { + t.Fatal("expected nil data on cache miss") + } + + stats := c.Stats() + if stats.Misses != 1 { + t.Fatalf("expected 1 miss, got %d", stats.Misses) + } + if stats.Hits != 0 { + t.Fatalf("expected 0 hits, got %d", stats.Hits) + } +} + +func TestCachePutGet(t *testing.T) { + dir := t.TempDir() + c := NewCache(WithCacheDir(dir)) + + key := Key("program1", "cpu") + payload := []byte("serialized-executable-bytes") + + if err := c.Put(key, payload); err != nil { + t.Fatalf("Put: %v", err) + } + + data, err := c.Get(key) + if err != nil { + t.Fatalf("Get: %v", err) + } + if string(data) != string(payload) { + t.Fatalf("Get returned %q, want %q", data, payload) + } + + stats := c.Stats() + if stats.Hits != 1 { + t.Fatalf("expected 1 hit, got %d", stats.Hits) + } + if stats.Files != 1 { + t.Fatalf("expected 1 file, got %d", stats.Files) + } + if stats.Size != int64(len(payload)) { + t.Fatalf("expected size %d, got %d", len(payload), stats.Size) + } +} + +func TestCacheLRUEviction(t *testing.T) { + dir := t.TempDir() + // Max size = 50 bytes. Each entry is 20 bytes. + c := NewCache(WithCacheDir(dir), WithMaxCacheSize(50)) + + k1 := Key("prog1", "cpu") + k2 := Key("prog2", "cpu") + k3 := Key("prog3", "cpu") + + d := make([]byte, 20) + + // Put k1, then sleep briefly so mod time differs. + if err := c.Put(k1, d); err != nil { + t.Fatal(err) + } + // Backdate k1 so it's the oldest. + p1 := c.entryPath(k1) + old := time.Now().Add(-2 * time.Second) + if err := os.Chtimes(p1, old, old); err != nil { + t.Fatal(err) + } + + if err := c.Put(k2, d); err != nil { + t.Fatal(err) + } + + // At this point: 40 bytes total, under 50 limit. Both should exist. + stats := c.Stats() + if stats.Files != 2 { + t.Fatalf("expected 2 files before eviction, got %d", stats.Files) + } + + // Put k3: total would be 60 bytes, exceeding 50. k1 (oldest) evicted. + if err := c.Put(k3, d); err != nil { + t.Fatal(err) + } + + stats = c.Stats() + if stats.Files != 2 { + t.Fatalf("expected 2 files after eviction, got %d", stats.Files) + } + if stats.Size != 40 { + t.Fatalf("expected size 40 after eviction, got %d", stats.Size) + } + + // k1 should be evicted. + data, _ := c.Get(k1) + if data != nil { + t.Fatal("expected k1 to be evicted") + } + + // k2 and k3 should still exist. + data, _ = c.Get(k2) + if data == nil { + t.Fatal("expected k2 to still exist") + } + data, _ = c.Get(k3) + if data == nil { + t.Fatal("expected k3 to still exist") + } +} + +func TestCacheClear(t *testing.T) { + dir := t.TempDir() + c := NewCache(WithCacheDir(dir)) + + for i := 0; i < 5; i++ { + key := Key(string(rune('a'+i)), "cpu") + if err := c.Put(key, []byte("data")); err != nil { + t.Fatal(err) + } + } + + stats := c.Stats() + if stats.Files != 5 { + t.Fatalf("expected 5 files, got %d", stats.Files) + } + + if err := c.Clear(); err != nil { + t.Fatalf("Clear: %v", err) + } + + stats = c.Stats() + if stats.Files != 0 { + t.Fatalf("expected 0 files after clear, got %d", stats.Files) + } +} + +func TestCacheAtomicWrite(t *testing.T) { + dir := t.TempDir() + c := NewCache(WithCacheDir(dir)) + + key := Key("atomic-test", "cpu") + if err := c.Put(key, []byte("good-data")); err != nil { + t.Fatal(err) + } + + // Verify no .tmp files remain. + entries, _ := os.ReadDir(dir) + for _, e := range entries { + if filepath.Ext(e.Name()) == ".tmp" { + t.Fatalf("found leftover .tmp file: %s", e.Name()) + } + } + + data, _ := c.Get(key) + if string(data) != "good-data" { + t.Fatalf("expected 'good-data', got %q", data) + } +} + +func TestCacheOverwrite(t *testing.T) { + dir := t.TempDir() + c := NewCache(WithCacheDir(dir)) + + key := Key("overwrite-test", "cpu") + if err := c.Put(key, []byte("v1")); err != nil { + t.Fatal(err) + } + if err := c.Put(key, []byte("v2")); err != nil { + t.Fatal(err) + } + + data, err := c.Get(key) + if err != nil { + t.Fatal(err) + } + if string(data) != "v2" { + t.Fatalf("expected 'v2', got %q", data) + } + + stats := c.Stats() + if stats.Files != 1 { + t.Fatalf("expected 1 file after overwrite, got %d", stats.Files) + } +} + +func TestCacheEnvVar(t *testing.T) { + dir := t.TempDir() + t.Setenv("ZERFOO_PJRT_CACHE", dir) + + c := NewCache() + if c.Dir() != dir { + t.Fatalf("expected dir %s from env, got %s", dir, c.Dir()) + } +} + +func TestCacheManualEvict(t *testing.T) { + dir := t.TempDir() + c := NewCache(WithCacheDir(dir), WithMaxCacheSize(10)) + + // Put 30 bytes, which exceeds budget. + k1 := Key("evict1", "cpu") + k2 := Key("evict2", "cpu") + if err := c.Put(k1, make([]byte, 15)); err != nil { + t.Fatal(err) + } + // Backdate k1. + old := time.Now().Add(-2 * time.Second) + _ = os.Chtimes(c.entryPath(k1), old, old) + + if err := c.Put(k2, make([]byte, 15)); err != nil { + t.Fatal(err) + } + + // After Put(k2), eviction already ran. But let's also test manual Evict(). + c.maxSize = 5 + c.Evict() + + stats := c.Stats() + if stats.Size > 5 { + t.Fatalf("expected size <= 5 after manual evict, got %d", stats.Size) + } +}