diff --git a/pkg/memory/database/backup.go b/pkg/memory/database/backup.go new file mode 100644 index 000000000..927b78c08 --- /dev/null +++ b/pkg/memory/database/backup.go @@ -0,0 +1,104 @@ +package database + +import ( + "context" + "database/sql" + "fmt" + "os" + "path/filepath" + "strings" + + atomic "github.com/natefinch/atomic" + + "github.com/docker/docker-agent/pkg/sqliteutil" +) + +// ExportSnapshot writes a consistent SQLite snapshot of dbPath to finalPath. +// +// The snapshot is written to a temp file in finalPath's directory and then +// renamed into place, so readers of finalPath see either the previous snapshot +// or the complete new snapshot. The source memory DB lock is held while the +// snapshot is created to serialize it with memory writes. +func ExportSnapshot(ctx context.Context, dbPath, finalPath string) error { + if ctx == nil { + ctx = context.Background() + } + + lock := NewFileLock(LockPathForDatabase(dbPath)) + if err := lock.Lock(ctx); err != nil { + return err + } + defer func() { _ = lock.Unlock() }() + + db, err := sqliteutil.OpenDB(dbPath) + if err != nil { + return err + } + defer db.Close() + + dir := filepath.Dir(finalPath) + if err := os.MkdirAll(dir, 0o700); err != nil { + return fmt.Errorf("creating memory snapshot directory %q: %w", dir, err) + } + + tmp, err := os.CreateTemp(dir, ".mem_backup_*.db.tmp") + if err != nil { + return fmt.Errorf("creating temp memory snapshot: %w", err) + } + tmpName := tmp.Name() + if err := tmp.Close(); err != nil { + _ = os.Remove(tmpName) + return fmt.Errorf("closing temp memory snapshot: %w", err) + } + if err := os.Remove(tmpName); err != nil { + return fmt.Errorf("removing empty temp memory snapshot: %w", err) + } + defer os.Remove(tmpName) + + if err := vacuumInto(ctx, db, tmpName); err != nil { + return err + } + if err := syncFile(tmpName); err != nil { + return err + } + + if err := atomic.ReplaceFile(tmpName, finalPath); err != nil { + return fmt.Errorf("publishing memory snapshot %q: %w", finalPath, err) + } + + syncDir(dir) + return nil +} + +func vacuumInto(ctx context.Context, db *sql.DB, path string) error { + stmt := "VACUUM INTO " + sqliteString(path) + if _, err := db.ExecContext(ctx, stmt); err != nil { + return fmt.Errorf("exporting memory snapshot: %w", err) + } + return nil +} + +func sqliteString(s string) string { + return "'" + strings.ReplaceAll(s, "'", "''") + "'" +} + +func syncFile(path string) error { + f, err := os.OpenFile(path, os.O_RDWR, 0) + if err != nil { + return fmt.Errorf("opening memory snapshot for sync: %w", err) + } + defer f.Close() + if err := f.Sync(); err != nil { + return fmt.Errorf("syncing memory snapshot: %w", err) + } + return nil +} + +func syncDir(dir string) { + d, err := os.Open(dir) + if err != nil { + return + } + defer d.Close() + _ = d.Sync() +} diff --git a/pkg/memory/database/backup_test.go b/pkg/memory/database/backup_test.go new file mode 100644 index 000000000..1fb9a1a0d --- /dev/null +++ b/pkg/memory/database/backup_test.go @@ -0,0 +1,109 @@ +package database_test + +import ( + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/memory/database" + "github.com/docker/docker-agent/pkg/memory/database/sqlite" +) + +func TestExportSnapshotCreatesReadableBackup(t *testing.T) { + sourcePath := filepath.Join(t.TempDir(), "memory.db") + db, err := sqlite.NewMemoryDatabase(sourcePath) + require.NoError(t, err) + source := db.(*sqlite.MemoryDatabase) + defer source.Close() + + require.NoError(t, db.AddMemory(t.Context(), database.UserMemory{ + ID: "one", + CreatedAt: time.Now().Format(time.RFC3339), + Memory: "remember this", + Category: "fact", + })) + + backupPath := filepath.Join(t.TempDir(), "memory.bak") + require.NoError(t, database.ExportSnapshot(t.Context(), sourcePath, backupPath)) + + backupDB, err := sqlite.NewMemoryDatabase(backupPath) + require.NoError(t, err) + backup := backupDB.(*sqlite.MemoryDatabase) + defer backup.Close() + + memories, err := backupDB.GetMemories(t.Context()) + require.NoError(t, err) + require.Len(t, memories, 1) + assert.Equal(t, "one", memories[0].ID) + assert.Equal(t, "remember this", memories[0].Memory) + assert.Equal(t, "fact", memories[0].Category) +} + +func TestExportSnapshotIncludesOpenWALWrites(t *testing.T) { + sourcePath := filepath.Join(t.TempDir(), "memory.db") + db, err := sqlite.NewMemoryDatabase(sourcePath) + require.NoError(t, err) + source := db.(*sqlite.MemoryDatabase) + defer source.Close() + + for i := 0; i < 10; i++ { + require.NoError(t, db.AddMemory(t.Context(), database.UserMemory{ + ID: string(rune('a' + i)), + CreatedAt: time.Now().Format(time.RFC3339), + Memory: "open db write", + })) + } + + backupPath := filepath.Join(t.TempDir(), "memory.bak") + require.NoError(t, database.ExportSnapshot(t.Context(), sourcePath, backupPath)) + + backupDB, err := sqlite.NewMemoryDatabase(backupPath) + require.NoError(t, err) + backup := backupDB.(*sqlite.MemoryDatabase) + defer backup.Close() + + memories, err := backupDB.GetMemories(t.Context()) + require.NoError(t, err) + assert.Len(t, memories, 10) +} + +func TestExportSnapshotOverwritesAtomicallyAndCleansTemp(t *testing.T) { + sourcePath := filepath.Join(t.TempDir(), "memory.db") + db, err := sqlite.NewMemoryDatabase(sourcePath) + require.NoError(t, err) + source := db.(*sqlite.MemoryDatabase) + defer source.Close() + + require.NoError(t, db.AddMemory(t.Context(), database.UserMemory{ + ID: "first", + CreatedAt: time.Now().Format(time.RFC3339), + Memory: "first snapshot", + })) + + backupDir := t.TempDir() + backupPath := filepath.Join(backupDir, "memory.bak") + require.NoError(t, database.ExportSnapshot(t.Context(), sourcePath, backupPath)) + + require.NoError(t, db.AddMemory(t.Context(), database.UserMemory{ + ID: "second", + CreatedAt: time.Now().Format(time.RFC3339), + Memory: "second snapshot", + })) + require.NoError(t, database.ExportSnapshot(t.Context(), sourcePath, backupPath)) + + matches, err := filepath.Glob(filepath.Join(backupDir, ".mem_backup_*.db.tmp")) + require.NoError(t, err) + assert.Empty(t, matches) + + backupDB, err := sqlite.NewMemoryDatabase(backupPath) + require.NoError(t, err) + backup := backupDB.(*sqlite.MemoryDatabase) + defer backup.Close() + + memories, err := backupDB.GetMemories(t.Context()) + require.NoError(t, err) + assert.Len(t, memories, 2) +} diff --git a/pkg/memory/database/lock.go b/pkg/memory/database/lock.go new file mode 100644 index 000000000..2f0678ff4 --- /dev/null +++ b/pkg/memory/database/lock.go @@ -0,0 +1,153 @@ +package database + +import ( + "context" + "fmt" + "os" + "path/filepath" + "sync" + "time" +) + +const lockRetryInterval = 10 * time.Millisecond + +type processLock chan struct{} + +func newProcessLock() processLock { + lock := make(processLock, 1) + lock <- struct{}{} + return lock +} + +func (l processLock) Lock(ctx context.Context) error { + select { + case <-l: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (l processLock) Unlock() { + select { + case l <- struct{}{}: + default: + } +} + +// FileLock is an advisory file lock for coordinating memory database writes +// across docker-agent processes. +// +// The lock file is intentionally never deleted. Keeping a stable sentinel file +// avoids a race where different processes lock different inodes for the same +// logical database. +type FileLock struct { + path string + file *os.File + processLock processLock + mu sync.Mutex +} + +var processLocks sync.Map + +// NewFileLock returns a lock using path as its persistent sentinel file. +func NewFileLock(path string) *FileLock { + absPath, err := filepath.Abs(path) + if err != nil { + absPath = path + } + processLockValue, _ := processLocks.LoadOrStore(absPath, newProcessLock()) + return &FileLock{ + path: absPath, + processLock: processLockValue.(processLock), + } +} + +// LockPathForDatabase returns the companion lock-file path for a memory DB. +func LockPathForDatabase(dbPath string) string { + return filepath.Join(filepath.Dir(dbPath), "memory.lock") +} + +// Lock blocks until the exclusive advisory lock is acquired or ctx is canceled. +func (l *FileLock) Lock(ctx context.Context) error { + if ctx == nil { + ctx = context.Background() + } + if err := ctx.Err(); err != nil { + return err + } + + l.mu.Lock() + defer l.mu.Unlock() + + if l.file != nil { + return nil + } + + if err := l.lockProcess(ctx); err != nil { + return err + } + processLocked := true + defer func() { + if processLocked { + l.processLock.Unlock() + } + }() + + if err := os.MkdirAll(filepath.Dir(l.path), 0o700); err != nil { + return fmt.Errorf("creating memory lock directory %q: %w", filepath.Dir(l.path), err) + } + + f, err := os.OpenFile(l.path, os.O_RDWR|os.O_CREATE, 0o600) + if err != nil { + return fmt.Errorf("opening memory lock file %q: %w", l.path, err) + } + + for { + err = lockFileExclusive(f) + if err == nil { + l.file = f + processLocked = false + return nil + } + if !isLockUnavailable(err) { + _ = f.Close() + return fmt.Errorf("locking memory lock file %q: %w", l.path, err) + } + + select { + case <-ctx.Done(): + _ = f.Close() + return ctx.Err() + case <-time.After(lockRetryInterval): + } + } +} + +func (l *FileLock) lockProcess(ctx context.Context) error { + return l.processLock.Lock(ctx) +} + +// Unlock releases the advisory lock and closes the sentinel file descriptor. +func (l *FileLock) Unlock() error { + l.mu.Lock() + defer l.mu.Unlock() + + if l.file == nil { + return nil + } + + f := l.file + l.file = nil + + unlockErr := unlockFile(f) + closeErr := f.Close() + l.processLock.Unlock() + if unlockErr != nil { + return fmt.Errorf("unlocking memory lock file %q: %w", l.path, unlockErr) + } + if closeErr != nil { + return fmt.Errorf("closing memory lock file %q: %w", l.path, closeErr) + } + return nil +} diff --git a/pkg/memory/database/lock_js.go b/pkg/memory/database/lock_js.go new file mode 100644 index 000000000..43896a661 --- /dev/null +++ b/pkg/memory/database/lock_js.go @@ -0,0 +1,17 @@ +//go:build js && wasm + +package database + +import "os" + +func lockFileExclusive(_ *os.File) error { + return nil +} + +func unlockFile(_ *os.File) error { + return nil +} + +func isLockUnavailable(error) bool { + return false +} diff --git a/pkg/memory/database/lock_test.go b/pkg/memory/database/lock_test.go new file mode 100644 index 000000000..d5a5132ee --- /dev/null +++ b/pkg/memory/database/lock_test.go @@ -0,0 +1,76 @@ +package database + +import ( + "context" + "os" + "os/exec" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestFileLockRoundTripPersistsLockFile(t *testing.T) { + lockPath := filepath.Join(t.TempDir(), "memory.lock") + lock := NewFileLock(lockPath) + + require.NoError(t, lock.Lock(t.Context())) + require.FileExists(t, lockPath) + require.NoError(t, lock.Unlock()) + require.FileExists(t, lockPath) + + require.NoError(t, lock.Lock(t.Context())) + require.NoError(t, lock.Unlock()) + require.FileExists(t, lockPath) +} + +func TestFileLockSerializesAcrossProcesses(t *testing.T) { + lockPath := filepath.Join(t.TempDir(), "memory.lock") + lock := NewFileLock(lockPath) + require.NoError(t, lock.Lock(t.Context())) + + cmd := exec.Command(os.Args[0], "-test.run=TestFileLockHelperProcess", "--", lockPath) + cmd.Env = append(os.Environ(), "MEMORY_LOCK_HELPER=1") + + done := make(chan error, 1) + require.NoError(t, cmd.Start()) + go func() { done <- cmd.Wait() }() + + select { + case err := <-done: + require.NoError(t, err) + t.Fatal("helper acquired the lock before the parent released it") + case <-time.After(200 * time.Millisecond): + } + + require.NoError(t, lock.Unlock()) + + select { + case err := <-done: + require.NoError(t, err) + case <-time.After(5 * time.Second): + _ = cmd.Process.Kill() + t.Fatal("helper did not acquire the lock after the parent released it") + } +} + +func TestFileLockHelperProcess(t *testing.T) { + if os.Getenv("MEMORY_LOCK_HELPER") != "1" { + return + } + args := os.Args + for i, arg := range args { + if arg == "--" && i+1 < len(args) { + lock := NewFileLock(args[i+1]) + if err := lock.Lock(context.Background()); err != nil { + os.Exit(2) + } + if err := lock.Unlock(); err != nil { + os.Exit(3) + } + os.Exit(0) + } + } + os.Exit(4) +} diff --git a/pkg/memory/database/lock_unix.go b/pkg/memory/database/lock_unix.go new file mode 100644 index 000000000..177c6b3e7 --- /dev/null +++ b/pkg/memory/database/lock_unix.go @@ -0,0 +1,36 @@ +//go:build unix + +package database + +import ( + "errors" + "os" + + "golang.org/x/sys/unix" +) + +// lockFileExclusive attempts to acquire an exclusive advisory lock without +// blocking. The retry loop in FileLock.Lock handles waiting and cancellation. +func lockFileExclusive(f *os.File) error { + lock := unix.Flock_t{ + Type: unix.F_WRLCK, + Whence: int16(os.SEEK_SET), + Start: 0, + Len: 0, + } + return unix.FcntlFlock(f.Fd(), unix.F_SETLK, &lock) +} + +func unlockFile(f *os.File) error { + lock := unix.Flock_t{ + Type: unix.F_UNLCK, + Whence: int16(os.SEEK_SET), + Start: 0, + Len: 0, + } + return unix.FcntlFlock(f.Fd(), unix.F_SETLK, &lock) +} + +func isLockUnavailable(err error) bool { + return errors.Is(err, unix.EACCES) || errors.Is(err, unix.EAGAIN) +} diff --git a/pkg/memory/database/lock_windows.go b/pkg/memory/database/lock_windows.go new file mode 100644 index 000000000..028ba07fc --- /dev/null +++ b/pkg/memory/database/lock_windows.go @@ -0,0 +1,39 @@ +//go:build windows + +package database + +import ( + "errors" + "os" + + "golang.org/x/sys/windows" +) + +const maxLockRange = ^uint32(0) + +func lockFileExclusive(f *os.File) error { + var ol windows.Overlapped + return windows.LockFileEx( + windows.Handle(f.Fd()), + windows.LOCKFILE_EXCLUSIVE_LOCK|windows.LOCKFILE_FAIL_IMMEDIATELY, + 0, + maxLockRange, + maxLockRange, + &ol, + ) +} + +func unlockFile(f *os.File) error { + var ol windows.Overlapped + return windows.UnlockFileEx( + windows.Handle(f.Fd()), + 0, + maxLockRange, + maxLockRange, + &ol, + ) +} + +func isLockUnavailable(err error) bool { + return errors.Is(err, windows.ERROR_LOCK_VIOLATION) || errors.Is(err, windows.ERROR_SHARING_VIOLATION) +} diff --git a/pkg/memory/database/sqlite/sqlite.go b/pkg/memory/database/sqlite/sqlite.go index 3f96b4a3f..dea9b03cd 100644 --- a/pkg/memory/database/sqlite/sqlite.go +++ b/pkg/memory/database/sqlite/sqlite.go @@ -11,7 +11,9 @@ import ( ) type MemoryDatabase struct { - db *sql.DB + db *sql.DB + path string + lockPath string } func NewMemoryDatabase(path string) (database.Database, error) { @@ -22,6 +24,14 @@ func NewMemoryDatabase(path string) (database.Database, error) { // Ensure we close the connection if table creation fails // Note: We don't defer close here because we return the db on success + lockPath := database.LockPathForDatabase(path) + lock := database.NewFileLock(lockPath) + if err := lock.Lock(context.Background()); err != nil { + db.Close() + return nil, err + } + defer func() { _ = lock.Unlock() }() + _, err = db.ExecContext(context.Background(), "CREATE TABLE IF NOT EXISTS memories (id TEXT PRIMARY KEY, created_at TEXT, memory TEXT)") if err != nil { db.Close() @@ -36,16 +46,31 @@ func NewMemoryDatabase(path string) (database.Database, error) { } } - return &MemoryDatabase{db: db}, nil + return &MemoryDatabase{db: db, path: path, lockPath: lockPath}, nil +} + +func (m *MemoryDatabase) Close() error { + return sqliteutil.CheckpointAndClose(m.db) +} + +func (m *MemoryDatabase) withWriteLock(ctx context.Context, fn func() error) error { + lock := database.NewFileLock(m.lockPath) + if err := lock.Lock(ctx); err != nil { + return err + } + defer func() { _ = lock.Unlock() }() + return fn() } func (m *MemoryDatabase) AddMemory(ctx context.Context, memory database.UserMemory) error { if memory.ID == "" { return database.ErrEmptyID } - _, err := m.db.ExecContext(ctx, "INSERT INTO memories (id, created_at, memory, category) VALUES (?, ?, ?, ?)", - memory.ID, memory.CreatedAt, memory.Memory, memory.Category) - return err + return m.withWriteLock(ctx, func() error { + _, err := m.db.ExecContext(ctx, "INSERT INTO memories (id, created_at, memory, category) VALUES (?, ?, ?, ?)", + memory.ID, memory.CreatedAt, memory.Memory, memory.Category) + return err + }) } func (m *MemoryDatabase) GetMemories(ctx context.Context) ([]database.UserMemory, error) { @@ -73,8 +98,10 @@ func (m *MemoryDatabase) GetMemories(ctx context.Context) ([]database.UserMemory } func (m *MemoryDatabase) DeleteMemory(ctx context.Context, memory database.UserMemory) error { - _, err := m.db.ExecContext(ctx, "DELETE FROM memories WHERE id = ?", memory.ID) - return err + return m.withWriteLock(ctx, func() error { + _, err := m.db.ExecContext(ctx, "DELETE FROM memories WHERE id = ?", memory.ID) + return err + }) } func (m *MemoryDatabase) SearchMemories(ctx context.Context, query, category string) ([]database.UserMemory, error) { @@ -130,19 +157,21 @@ func (m *MemoryDatabase) UpdateMemory(ctx context.Context, memory database.UserM return database.ErrEmptyID } - result, err := m.db.ExecContext(ctx, "UPDATE memories SET memory = ?, category = ? WHERE id = ?", - memory.Memory, memory.Category, memory.ID) - if err != nil { - return err - } + return m.withWriteLock(ctx, func() error { + result, err := m.db.ExecContext(ctx, "UPDATE memories SET memory = ?, category = ? WHERE id = ?", + memory.Memory, memory.Category, memory.ID) + if err != nil { + return err + } - rows, err := result.RowsAffected() - if err != nil { - return err - } - if rows == 0 { - return fmt.Errorf("%w: %s", database.ErrMemoryNotFound, memory.ID) - } + rows, err := result.RowsAffected() + if err != nil { + return err + } + if rows == 0 { + return fmt.Errorf("%w: %s", database.ErrMemoryNotFound, memory.ID) + } - return nil + return nil + }) } diff --git a/pkg/memory/database/sqlite/sqlite_test.go b/pkg/memory/database/sqlite/sqlite_test.go index 1c1a15829..8946150ad 100644 --- a/pkg/memory/database/sqlite/sqlite_test.go +++ b/pkg/memory/database/sqlite/sqlite_test.go @@ -2,6 +2,9 @@ package sqlite import ( "context" + "fmt" + "path/filepath" + "sync" "testing" "time" @@ -23,7 +26,7 @@ func setupTestDB(t *testing.T) database.Database { t.Cleanup(func() { // Close connection memDB := db.(*MemoryDatabase) - memDB.db.Close() + _ = memDB.Close() }) return db @@ -286,13 +289,13 @@ func TestMigrationAddsCategory(t *testing.T) { Memory: "Old memory without category", }) require.NoError(t, err) - memDB1.db.Close() + _ = memDB1.Close() // Reopen - migration should be idempotent db2, err := NewMemoryDatabase(tmpFile) require.NoError(t, err) memDB2 := db2.(*MemoryDatabase) - defer memDB2.db.Close() + defer func() { _ = memDB2.Close() }() memories, err := db2.GetMemories(t.Context()) require.NoError(t, err) @@ -329,13 +332,26 @@ func TestDatabaseOperationsWithCanceledContext(t *testing.T) { require.Error(t, err, "UpdateMemory should fail with canceled context") } +func TestMemoryDatabaseUsesWALAndBusyTimeout(t *testing.T) { + db := setupTestDB(t) + memDB := db.(*MemoryDatabase) + + var journalMode string + require.NoError(t, memDB.db.QueryRowContext(t.Context(), "PRAGMA journal_mode").Scan(&journalMode)) + assert.Equal(t, "wal", journalMode) + + var busyTimeout int + require.NoError(t, memDB.db.QueryRowContext(t.Context(), "PRAGMA busy_timeout").Scan(&busyTimeout)) + assert.Equal(t, 5000, busyTimeout) +} + func TestDatabaseWithMultipleInstances(t *testing.T) { tmpFile := t.TempDir() + "/shared.db" db1, err := NewMemoryDatabase(tmpFile) require.NoError(t, err) defer func() { memDB := db1.(*MemoryDatabase) - memDB.db.Close() + _ = memDB.Close() }() memory := database.UserMemory{ @@ -351,7 +367,7 @@ func TestDatabaseWithMultipleInstances(t *testing.T) { require.NoError(t, err) defer func() { memDB := db2.(*MemoryDatabase) - memDB.db.Close() + _ = memDB.Close() }() memories, err := db2.GetMemories(t.Context()) @@ -360,3 +376,133 @@ func TestDatabaseWithMultipleInstances(t *testing.T) { assert.Equal(t, "shared-id", memories[0].ID) assert.Equal(t, "Shared memory", memories[0].Memory) } + +func TestConcurrentAddsPreserveAllRows(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "concurrent.db") + const workers = 8 + const perWorker = 25 + + dbs := make([]database.Database, workers) + for i := range workers { + db, err := NewMemoryDatabase(dbPath) + require.NoError(t, err) + dbs[i] = db + memDB := db.(*MemoryDatabase) + defer func() { _ = memDB.Close() }() + } + + var wg sync.WaitGroup + for worker := range workers { + worker := worker + wg.Add(1) + go func() { + defer wg.Done() + for i := range perWorker { + id := fmt.Sprintf("worker-%d-%d", worker, i) + require.NoError(t, dbs[worker].AddMemory(t.Context(), database.UserMemory{ + ID: id, + CreatedAt: time.Now().Format(time.RFC3339), + Memory: "concurrent add", + })) + } + }() + } + wg.Wait() + + reader, err := NewMemoryDatabase(dbPath) + require.NoError(t, err) + readerDB := reader.(*MemoryDatabase) + defer func() { _ = readerDB.Close() }() + + memories, err := reader.GetMemories(t.Context()) + require.NoError(t, err) + require.Len(t, memories, workers*perWorker) + + seen := make(map[string]bool, len(memories)) + for _, memory := range memories { + seen[memory.ID] = true + } + for worker := range workers { + for i := range perWorker { + assert.True(t, seen[fmt.Sprintf("worker-%d-%d", worker, i)]) + } + } +} + +func TestConcurrentReadsDuringWrites(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "reads-writes.db") + db, err := NewMemoryDatabase(dbPath) + require.NoError(t, err) + memDB := db.(*MemoryDatabase) + defer func() { _ = memDB.Close() }() + + ctx := t.Context() + done := make(chan struct{}) + readErr := make(chan error, 1) + + go func() { + defer close(readErr) + for { + select { + case <-done: + return + default: + } + memories, err := db.GetMemories(ctx) + if err != nil { + readErr <- err + return + } + for _, memory := range memories { + if memory.ID == "" { + readErr <- fmt.Errorf("read malformed memory with empty ID") + return + } + } + } + }() + + for i := range 100 { + id := fmt.Sprintf("rw-%d", i) + require.NoError(t, db.AddMemory(ctx, database.UserMemory{ + ID: id, + CreatedAt: time.Now().Format(time.RFC3339), + Memory: "initial", + })) + require.NoError(t, db.UpdateMemory(ctx, database.UserMemory{ + ID: id, + Memory: "updated", + })) + if i%3 == 0 { + require.NoError(t, db.DeleteMemory(ctx, database.UserMemory{ID: id})) + } + } + close(done) + if err := <-readErr; err != nil { + require.NoError(t, err) + } +} + +func TestWriteCreatesPersistentLockFile(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "locked.db") + db, err := NewMemoryDatabase(dbPath) + require.NoError(t, err) + memDB := db.(*MemoryDatabase) + defer func() { _ = memDB.Close() }() + + lockPath := database.LockPathForDatabase(dbPath) + require.FileExists(t, lockPath) + + require.NoError(t, db.AddMemory(t.Context(), database.UserMemory{ + ID: "lock-file", + CreatedAt: time.Now().Format(time.RFC3339), + Memory: "creates lock", + })) + require.FileExists(t, lockPath) + + require.NoError(t, db.UpdateMemory(t.Context(), database.UserMemory{ + ID: "lock-file", + Memory: "preserves lock", + })) + require.FileExists(t, lockPath) +}