Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions pkg/memory/database/backup.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package database

import (
"context"
"database/sql"
"fmt"
"os"
"path/filepath"
"strings"

atomic "github.com/natefinch/atomic"

"github.com/docker/docker-agent/pkg/sqliteutil"
)

// ExportSnapshot writes a consistent SQLite snapshot of dbPath to finalPath.
//
// The snapshot is written to a temp file in finalPath's directory and then
// renamed into place, so readers of finalPath see either the previous snapshot
// or the complete new snapshot. The source memory DB lock is held while the
// snapshot is created to serialize it with memory writes.
func ExportSnapshot(ctx context.Context, dbPath, finalPath string) error {
if ctx == nil {
ctx = context.Background()
}

lock := NewFileLock(LockPathForDatabase(dbPath))
if err := lock.Lock(ctx); err != nil {
return err
}
defer func() { _ = lock.Unlock() }()

db, err := sqliteutil.OpenDB(dbPath)
if err != nil {
return err
}
defer db.Close()

dir := filepath.Dir(finalPath)
if err := os.MkdirAll(dir, 0o700); err != nil {
return fmt.Errorf("creating memory snapshot directory %q: %w", dir, err)
}

tmp, err := os.CreateTemp(dir, ".mem_backup_*.db.tmp")
if err != nil {
return fmt.Errorf("creating temp memory snapshot: %w", err)
}
tmpName := tmp.Name()
if err := tmp.Close(); err != nil {
_ = os.Remove(tmpName)
return fmt.Errorf("closing temp memory snapshot: %w", err)
}
if err := os.Remove(tmpName); err != nil {
return fmt.Errorf("removing empty temp memory snapshot: %w", err)
}
defer os.Remove(tmpName)

if err := vacuumInto(ctx, db, tmpName); err != nil {
return err
}
if err := syncFile(tmpName); err != nil {
return err
}

if err := atomic.ReplaceFile(tmpName, finalPath); err != nil {
return fmt.Errorf("publishing memory snapshot %q: %w", finalPath, err)
}

syncDir(dir)
return nil
}

func vacuumInto(ctx context.Context, db *sql.DB, path string) error {
stmt := "VACUUM INTO " + sqliteString(path)

Check failure on line 74 in pkg/memory/database/backup.go

View workflow job for this annotation

GitHub Actions / lint

G202: SQL string concatenation (gosec)
if _, err := db.ExecContext(ctx, stmt); err != nil {
return fmt.Errorf("exporting memory snapshot: %w", err)
}
return nil
}

func sqliteString(s string) string {
return "'" + strings.ReplaceAll(s, "'", "''") + "'"
}

func syncFile(path string) error {
f, err := os.OpenFile(path, os.O_RDWR, 0)
if err != nil {
return fmt.Errorf("opening memory snapshot for sync: %w", err)
}
defer f.Close()
if err := f.Sync(); err != nil {
return fmt.Errorf("syncing memory snapshot: %w", err)
}
return nil
}

func syncDir(dir string) {
d, err := os.Open(dir)
if err != nil {
return
}
defer d.Close()
_ = d.Sync()
}
109 changes: 109 additions & 0 deletions pkg/memory/database/backup_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package database_test

import (
"path/filepath"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/docker/docker-agent/pkg/memory/database"
"github.com/docker/docker-agent/pkg/memory/database/sqlite"
)

func TestExportSnapshotCreatesReadableBackup(t *testing.T) {
sourcePath := filepath.Join(t.TempDir(), "memory.db")
db, err := sqlite.NewMemoryDatabase(sourcePath)
require.NoError(t, err)
source := db.(*sqlite.MemoryDatabase)
defer source.Close()

require.NoError(t, db.AddMemory(t.Context(), database.UserMemory{
ID: "one",
CreatedAt: time.Now().Format(time.RFC3339),
Memory: "remember this",
Category: "fact",
}))

backupPath := filepath.Join(t.TempDir(), "memory.bak")
require.NoError(t, database.ExportSnapshot(t.Context(), sourcePath, backupPath))

backupDB, err := sqlite.NewMemoryDatabase(backupPath)
require.NoError(t, err)
backup := backupDB.(*sqlite.MemoryDatabase)
defer backup.Close()

memories, err := backupDB.GetMemories(t.Context())
require.NoError(t, err)
require.Len(t, memories, 1)
assert.Equal(t, "one", memories[0].ID)
assert.Equal(t, "remember this", memories[0].Memory)
assert.Equal(t, "fact", memories[0].Category)
}

func TestExportSnapshotIncludesOpenWALWrites(t *testing.T) {
sourcePath := filepath.Join(t.TempDir(), "memory.db")
db, err := sqlite.NewMemoryDatabase(sourcePath)
require.NoError(t, err)
source := db.(*sqlite.MemoryDatabase)
defer source.Close()

for i := 0; i < 10; i++ {

Check failure on line 52 in pkg/memory/database/backup_test.go

View workflow job for this annotation

GitHub Actions / lint

for loop can be changed to use an integer range (Go 1.22+) (intrange)
require.NoError(t, db.AddMemory(t.Context(), database.UserMemory{
ID: string(rune('a' + i)),
CreatedAt: time.Now().Format(time.RFC3339),
Memory: "open db write",
}))
}

backupPath := filepath.Join(t.TempDir(), "memory.bak")
require.NoError(t, database.ExportSnapshot(t.Context(), sourcePath, backupPath))

backupDB, err := sqlite.NewMemoryDatabase(backupPath)
require.NoError(t, err)
backup := backupDB.(*sqlite.MemoryDatabase)
defer backup.Close()

memories, err := backupDB.GetMemories(t.Context())
require.NoError(t, err)
assert.Len(t, memories, 10)
}

func TestExportSnapshotOverwritesAtomicallyAndCleansTemp(t *testing.T) {
sourcePath := filepath.Join(t.TempDir(), "memory.db")
db, err := sqlite.NewMemoryDatabase(sourcePath)
require.NoError(t, err)
source := db.(*sqlite.MemoryDatabase)
defer source.Close()

require.NoError(t, db.AddMemory(t.Context(), database.UserMemory{
ID: "first",
CreatedAt: time.Now().Format(time.RFC3339),
Memory: "first snapshot",
}))

backupDir := t.TempDir()
backupPath := filepath.Join(backupDir, "memory.bak")
require.NoError(t, database.ExportSnapshot(t.Context(), sourcePath, backupPath))

require.NoError(t, db.AddMemory(t.Context(), database.UserMemory{
ID: "second",
CreatedAt: time.Now().Format(time.RFC3339),
Memory: "second snapshot",
}))
require.NoError(t, database.ExportSnapshot(t.Context(), sourcePath, backupPath))

matches, err := filepath.Glob(filepath.Join(backupDir, ".mem_backup_*.db.tmp"))
require.NoError(t, err)
assert.Empty(t, matches)

backupDB, err := sqlite.NewMemoryDatabase(backupPath)
require.NoError(t, err)
backup := backupDB.(*sqlite.MemoryDatabase)
defer backup.Close()

memories, err := backupDB.GetMemories(t.Context())
require.NoError(t, err)
assert.Len(t, memories, 2)
}
153 changes: 153 additions & 0 deletions pkg/memory/database/lock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
package database

import (
"context"
"fmt"
"os"
"path/filepath"
"sync"
"time"
)

const lockRetryInterval = 10 * time.Millisecond

type processLock chan struct{}

func newProcessLock() processLock {
lock := make(processLock, 1)
lock <- struct{}{}
return lock
}

func (l processLock) Lock(ctx context.Context) error {
select {
case <-l:
return nil
case <-ctx.Done():
return ctx.Err()
}
}

func (l processLock) Unlock() {
select {
case l <- struct{}{}:
default:
}
}

// FileLock is an advisory file lock for coordinating memory database writes
// across docker-agent processes.
//
// The lock file is intentionally never deleted. Keeping a stable sentinel file
// avoids a race where different processes lock different inodes for the same
// logical database.
type FileLock struct {
path string
file *os.File
processLock processLock
mu sync.Mutex
}

var processLocks sync.Map

// NewFileLock returns a lock using path as its persistent sentinel file.
func NewFileLock(path string) *FileLock {
absPath, err := filepath.Abs(path)
if err != nil {
absPath = path
}
processLockValue, _ := processLocks.LoadOrStore(absPath, newProcessLock())
return &FileLock{
path: absPath,
processLock: processLockValue.(processLock),
}
}

// LockPathForDatabase returns the companion lock-file path for a memory DB.
func LockPathForDatabase(dbPath string) string {
return filepath.Join(filepath.Dir(dbPath), "memory.lock")
}

// Lock blocks until the exclusive advisory lock is acquired or ctx is canceled.
func (l *FileLock) Lock(ctx context.Context) error {
if ctx == nil {
ctx = context.Background()
}
if err := ctx.Err(); err != nil {
return err
}

l.mu.Lock()
defer l.mu.Unlock()

if l.file != nil {
return nil
}

if err := l.lockProcess(ctx); err != nil {
return err
}
processLocked := true
defer func() {
if processLocked {
l.processLock.Unlock()
}
}()

if err := os.MkdirAll(filepath.Dir(l.path), 0o700); err != nil {
return fmt.Errorf("creating memory lock directory %q: %w", filepath.Dir(l.path), err)
}

f, err := os.OpenFile(l.path, os.O_RDWR|os.O_CREATE, 0o600)
if err != nil {
return fmt.Errorf("opening memory lock file %q: %w", l.path, err)
}

for {
err = lockFileExclusive(f)
if err == nil {
l.file = f
processLocked = false
return nil
}
if !isLockUnavailable(err) {
_ = f.Close()
return fmt.Errorf("locking memory lock file %q: %w", l.path, err)
}

select {
case <-ctx.Done():
_ = f.Close()
return ctx.Err()
case <-time.After(lockRetryInterval):
}
}
}

func (l *FileLock) lockProcess(ctx context.Context) error {
return l.processLock.Lock(ctx)
}

// Unlock releases the advisory lock and closes the sentinel file descriptor.
func (l *FileLock) Unlock() error {
l.mu.Lock()
defer l.mu.Unlock()

if l.file == nil {
return nil
}

f := l.file
l.file = nil

unlockErr := unlockFile(f)
closeErr := f.Close()
l.processLock.Unlock()
if unlockErr != nil {
return fmt.Errorf("unlocking memory lock file %q: %w", l.path, unlockErr)
}
if closeErr != nil {
return fmt.Errorf("closing memory lock file %q: %w", l.path, closeErr)
}
return nil
}
17 changes: 17 additions & 0 deletions pkg/memory/database/lock_js.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
//go:build js && wasm

package database

import "os"

func lockFileExclusive(_ *os.File) error {
return nil
}

func unlockFile(_ *os.File) error {
return nil
}

func isLockUnavailable(error) bool {
return false
}
Loading
Loading