Skip to content

Commit 160cf06

Browse files
authored
Add Support for Querying Against Read Replicas (#70)
1 parent 506b16d commit 160cf06

6 files changed

Lines changed: 519 additions & 27 deletions

File tree

main_test.go

Lines changed: 76 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,13 @@ const (
2020
postgresImage = "docker.io/postgres:16.0-alpine"
2121
)
2222

23-
var postgresDataSource string
23+
// Connection strings for the three databases created in [TestMain]
24+
// These are used in other tests
25+
var (
26+
postgresDataSource string
27+
postgresDataSourceRR1 string
28+
postgresDataSourceRR2 string
29+
)
2430

2531
func withSchemaSQL() testcontainers.CustomizeRequestOption {
2632
return func(req *testcontainers.GenericContainerRequest) error {
@@ -34,24 +40,74 @@ func withSchemaSQL() testcontainers.CustomizeRequestOption {
3440
}
3541

3642
func TestMain(m *testing.M) {
37-
var cleanups []func()
38-
cleanup := func(fn func()) {
39-
cleanups = append(cleanups, fn)
40-
}
41-
fatal := func(args ...any) {
42-
fmt.Fprintln(os.Stderr, args...)
43-
for _, fn := range cleanups {
44-
fn()
43+
ctx := context.Background()
44+
cleanups, err := createPostgresContainers(ctx)
45+
if err != nil {
46+
fmt.Printf("did not create postgres containers: %v\n", err)
47+
for _, cleanup := range cleanups {
48+
if cleanup != nil {
49+
cleanup()
50+
}
4551
}
4652
os.Exit(1)
4753
}
4854

49-
ctx := context.Background()
55+
// nolint:gocritic // The docs for Run specify that the returned int is to be passed to os.Exit
56+
retCode := m.Run()
57+
for _, cleanup := range cleanups {
58+
if cleanup != nil {
59+
cleanup()
60+
}
61+
}
62+
os.Exit(retCode)
63+
}
64+
65+
// createPostgresContainers creates 3 databases as containers. The first is intended to mimic a master database and
66+
// the last 2 are intended to mimic read replicas.
67+
// A []func() is returned to cleanups.
68+
// Package-level connection strings are set.
69+
func createPostgresContainers(ctx context.Context) ([]func(), error) {
70+
// Database connection strings are the same, except the port
71+
connString := func(port string) string {
72+
return fmt.Sprintf("postgres://%s:%s@localhost:%s/%s?sslmode=disable&application_name=test", dbUser, dbPassword, port, dbName)
73+
}
74+
75+
var cleanups []func()
76+
77+
// Create the master database
78+
cM, mpM, err := createPostgresContainer(ctx, "init-db.sh")
79+
cleanups = append(cleanups, cM)
80+
if err != nil {
81+
return cleanups, err
82+
}
83+
postgresDataSource = connString(mpM)
84+
85+
// Create first read replica
86+
cRR1, mpRR1, err := createPostgresContainer(ctx, "init-db-rr.sh")
87+
cleanups = append(cleanups, cRR1)
88+
if err != nil {
89+
return cleanups, err
90+
}
91+
postgresDataSourceRR1 = connString(mpRR1)
92+
93+
// Create second read replica
94+
cRR2, mpRR2, err := createPostgresContainer(ctx, "init-db-rr.sh")
95+
cleanups = append(cleanups, cRR2)
96+
if err != nil {
97+
return cleanups, err
98+
}
99+
postgresDataSourceRR2 = connString(mpRR2)
100+
101+
return cleanups, nil
102+
}
103+
104+
// createPostgresContainer creates a single postgres container on the specified port
105+
func createPostgresContainer(ctx context.Context, initFilename string) (func(), string, error) {
50106
postgresContainer, err := postgres.Run(ctx, postgresImage,
51107
postgres.WithDatabase(dbName),
52108
postgres.WithUsername(dbUser),
53109
postgres.WithPassword(dbPassword),
54-
postgres.WithInitScripts(filepath.Join("testdata", "init-db.sh")),
110+
postgres.WithInitScripts(filepath.Join("testdata", initFilename)),
55111
withSchemaSQL(),
56112
testcontainers.WithWaitStrategy(
57113
wait.ForLog("database system is ready to accept connections").
@@ -60,28 +116,27 @@ func TestMain(m *testing.M) {
60116
),
61117
)
62118
if err != nil {
63-
fatal("error creating postgres container:", err)
119+
return nil, "", fmt.Errorf("error creating postgres container: %w", err)
64120
}
65-
cleanup(func() {
121+
122+
cleanup := func() {
66123
if err := postgresContainer.Terminate(ctx); err != nil {
67124
fmt.Fprintln(os.Stderr, "error terminating postgres:", err)
68125
}
69-
})
126+
}
70127

71128
postgresState, err := postgresContainer.State(ctx)
72129
if err != nil {
73-
fatal(err)
130+
return cleanup, "", fmt.Errorf("checking container state: %w", err)
74131
}
75132
if !postgresState.Running {
76-
fatal("Postgres status:", postgresState.Status)
133+
return cleanup, "", fmt.Errorf("Postgres status %q is not \"running\"", postgresState.Status)
77134
}
78135

79-
postgresPort, err := postgresContainer.MappedPort(ctx, "5432/tcp")
136+
mp, err := postgresContainer.MappedPort(ctx, "5432/tcp")
80137
if err != nil {
81-
fatal(err)
138+
return cleanup, "", fmt.Errorf("mapped port 5432/tcp does not seem to be available: %w", err)
82139
}
83140

84-
postgresDataSource = fmt.Sprintf("postgres://%s:%s@localhost:%s/%s?sslmode=disable&application_name=test", dbUser, dbPassword, postgresPort.Port(), dbName)
85-
86-
os.Exit(m.Run())
141+
return cleanup, mp.Port(), nil
87142
}

read_replica_set.go

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
package sequel
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"errors"
7+
"fmt"
8+
"sync"
9+
10+
"github.com/go-sqlx/sqlx"
11+
)
12+
13+
var ErrNoReadReplicaConnections = errors.New("no read replica connections available")
14+
15+
type readReplica struct {
16+
db *sqlx.DB
17+
18+
next *readReplica
19+
}
20+
21+
// ReadReplicaSet contains a set of DB connections. It is intended to give fair round robin access
22+
// through a circular singly linked list.
23+
//
24+
// Replicas are appended after the current one. The intended use is to build the replica set before
25+
// querying, but all operations are concurrent-safe.
26+
type ReadReplicaSet struct {
27+
m sync.Mutex
28+
29+
current *readReplica
30+
}
31+
32+
// add adds a DB to the collection of read replicas
33+
func (rr *ReadReplicaSet) add(db *sqlx.DB) {
34+
rr.m.Lock()
35+
defer rr.m.Unlock()
36+
37+
r := readReplica{
38+
db: db,
39+
}
40+
41+
// Empty ring, add new DB
42+
if rr.current == nil {
43+
r.next = &r
44+
45+
rr.current = &r
46+
47+
return
48+
}
49+
50+
// Insert new db after current
51+
n := rr.current.next
52+
r.next = n
53+
rr.current.next = &r
54+
}
55+
56+
// next returns the current DB. The current pointer is advanced.
57+
func (rr *ReadReplicaSet) next() (*sqlx.DB, error) {
58+
rr.m.Lock()
59+
defer rr.m.Unlock()
60+
61+
if rr.current == nil {
62+
return nil, ErrNoReadReplicaConnections
63+
}
64+
65+
c := rr.current
66+
67+
rr.current = rr.current.next
68+
69+
return c.db, nil
70+
}
71+
72+
// Close closes all read replica connections
73+
func (rr *ReadReplicaSet) Close() {
74+
if rr == nil {
75+
return
76+
}
77+
78+
rr.m.Lock()
79+
defer rr.m.Unlock()
80+
81+
// If this instance has no replicas, current would be nil
82+
if rr.current == nil {
83+
return
84+
}
85+
86+
first := rr.current
87+
for c := first; ; c = c.next {
88+
c.db.Close()
89+
90+
if c.next == first {
91+
break
92+
}
93+
}
94+
}
95+
96+
// Query executes a query against a read replica. Queries that are not SELECTs may not work.
97+
// The args are for any placeholder parameters in the query.
98+
func (rr *ReadReplicaSet) Query(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
99+
if rr == nil {
100+
return nil, ErrNoReadReplicaConnections
101+
}
102+
103+
db, err := rr.next()
104+
if err != nil {
105+
return nil, fmt.Errorf("did not get read replica connection: %w", err)
106+
}
107+
108+
return db.QueryContext(ctx, query, args...)
109+
}
110+
111+
// QueryRow executes a query that is expected to return at most one row against a read replica.
112+
// QueryRowContext always returns a non-nil value. Errors are deferred until
113+
// Row's Scan method is called.
114+
//
115+
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
116+
// Otherwise, the *Row's Scan scans the first selected row and discards the
117+
// rest.
118+
func (rr *ReadReplicaSet) QueryRow(ctx context.Context, query string, args ...any) (*sql.Row, error) {
119+
if rr == nil {
120+
return nil, ErrNoReadReplicaConnections
121+
}
122+
123+
db, err := rr.next()
124+
if err != nil {
125+
return nil, fmt.Errorf("did not get read replica connection: %w", err)
126+
}
127+
128+
return db.QueryRowContext(ctx, query, args...), nil
129+
}
130+
131+
// Get populates the given model for the result of the given select query against a read replica.
132+
func (rr *ReadReplicaSet) Get(ctx context.Context, dest Model, query string, args ...any) error {
133+
if rr == nil {
134+
return ErrNoReadReplicaConnections
135+
}
136+
137+
db, err := rr.next()
138+
if err != nil {
139+
return fmt.Errorf("did not get read replica connection: %w", err)
140+
}
141+
142+
return db.GetContext(ctx, dest, query, args...)
143+
}
144+
145+
// GetAll populates the given destination with all the results of the given
146+
// select query (from a read replica). The method will fail if the destination is not a pointer to a
147+
// slice.
148+
func (rr *ReadReplicaSet) GetAll(ctx context.Context, dest any, query string, args ...any) error {
149+
if rr == nil {
150+
return ErrNoReadReplicaConnections
151+
}
152+
153+
db, err := rr.next()
154+
if err != nil {
155+
return fmt.Errorf("did not get read replica connection: %w", err)
156+
}
157+
158+
rows, err := db.QueryContext(ctx, query, args...)
159+
if err != nil {
160+
return err
161+
}
162+
163+
defer rows.Close()
164+
165+
if err := rows.Err(); err != nil {
166+
return err
167+
}
168+
return sqlx.StructScan(rows, dest)
169+
}

0 commit comments

Comments
 (0)