Skip to content

Commit be0c1b6

Browse files
committed
Use slices.Chunk and rename values func
1 parent f45614e commit be0c1b6

1 file changed

Lines changed: 11 additions & 9 deletions

File tree

batch.go

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"database/sql"
66
"fmt"
7+
"slices"
78
"strings"
89
)
910

@@ -19,22 +20,23 @@ const BatchSize = 100
1920
// Batch inserts a slice of items into the given table using multi-row INSERT
2021
// statements. Items are inserted in chunks of [BatchSize]. The columns
2122
// parameter specifies the column names, and the values function maps each item to
22-
// its column values. The length of the slice returned by values must match the
23+
// its column values. The length of the slice returned by extractValues must match the
2324
// length of columns. Batch does nothing if items is empty. Table, columns and onConfict
24-
// are not sanitized; they must come from a trusted source. The values function will
25+
// are not sanitized; they must come from a trusted source. The extractValues function will
2526
// never be called concurrently.
26-
func Batch[T any](ctx context.Context, exec Executor, table string, columns []string, onConflict string, items []T, values func(T) []any) error {
27-
for i := 0; i < len(items); i += BatchSize {
28-
end := min(i+BatchSize, len(items))
29-
query, args := batchQuery(table, columns, onConflict, items[i:end], values)
27+
func Batch[T any](ctx context.Context, exec Executor, table string, columns []string, onConflict string, items []T, extractValues func(T) []any) error {
28+
batch := 0
29+
for chunk := range slices.Chunk(items, BatchSize) {
30+
query, args := batchQuery(table, columns, onConflict, chunk, extractValues)
3031
if _, err := exec.ExecContext(ctx, query, args...); err != nil {
31-
return err
32+
return fmt.Errorf("batch %d (%d items) failed: %w", batch, len(chunk), err)
3233
}
34+
batch++
3335
}
3436
return nil
3537
}
3638

37-
func batchQuery[T any](table string, columns []string, onConflict string, items []T, values func(T) []any) (string, []any) {
39+
func batchQuery[T any](table string, columns []string, onConflict string, items []T, extractValues func(T) []any) (string, []any) {
3840
ncols := len(columns)
3941
args := make([]any, 0, len(items)*ncols)
4042

@@ -46,7 +48,7 @@ func batchQuery[T any](table string, columns []string, onConflict string, items
4648
b.WriteString(", ")
4749
}
4850
b.WriteByte('(')
49-
vals := values(item)
51+
vals := extractValues(item)
5052
for j, v := range vals {
5153
if j > 0 {
5254
b.WriteString(", ")

0 commit comments

Comments
 (0)