44 "context"
55 "database/sql"
66 "fmt"
7+ "slices"
78 "strings"
89)
910
@@ -19,22 +20,23 @@ const BatchSize = 100
1920// Batch inserts a slice of items into the given table using multi-row INSERT
2021// statements. Items are inserted in chunks of [BatchSize]. The columns
2122// parameter specifies the column names, and the values function maps each item to
22- // its column values. The length of the slice returned by values must match the
23+ // its column values. The length of the slice returned by extractValues must match the
2324// length of columns. Batch does nothing if items is empty. Table, columns and onConfict
24- // are not sanitized; they must come from a trusted source. The values function will
25+ // are not sanitized; they must come from a trusted source. The extractValues function will
2526// never be called concurrently.
26- func Batch [T any ](ctx context.Context , exec Executor , table string , columns []string , onConflict string , items []T , values func (T ) []any ) error {
27- for i := 0 ; i < len ( items ); i += BatchSize {
28- end := min ( i + BatchSize , len (items ))
29- query , args := batchQuery (table , columns , onConflict , items [ i : end ], values )
27+ func Batch [T any ](ctx context.Context , exec Executor , table string , columns []string , onConflict string , items []T , extractValues func (T ) []any ) error {
28+ batch := 0
29+ for chunk := range slices . Chunk (items , BatchSize ) {
30+ query , args := batchQuery (table , columns , onConflict , chunk , extractValues )
3031 if _ , err := exec .ExecContext (ctx , query , args ... ); err != nil {
31- return err
32+ return fmt . Errorf ( "batch %d (%d items) failed: %w" , batch , len ( chunk ), err )
3233 }
34+ batch ++
3335 }
3436 return nil
3537}
3638
37- func batchQuery [T any ](table string , columns []string , onConflict string , items []T , values func (T ) []any ) (string , []any ) {
39+ func batchQuery [T any ](table string , columns []string , onConflict string , items []T , extractValues func (T ) []any ) (string , []any ) {
3840 ncols := len (columns )
3941 args := make ([]any , 0 , len (items )* ncols )
4042
@@ -46,7 +48,7 @@ func batchQuery[T any](table string, columns []string, onConflict string, items
4648 b .WriteString (", " )
4749 }
4850 b .WriteByte ('(' )
49- vals := values (item )
51+ vals := extractValues (item )
5052 for j , v := range vals {
5153 if j > 0 {
5254 b .WriteString (", " )
0 commit comments