diff --git a/batch_writer.go b/batch_writer.go index 5ed22cd8a..550f16566 100644 --- a/batch_writer.go +++ b/batch_writer.go @@ -7,6 +7,7 @@ import ( sql "github.com/Shopify/ghostferry/sqlwrapper" + "github.com/go-mysql-org/go-mysql/schema" "github.com/sirupsen/logrus" ) @@ -56,14 +57,65 @@ func (w *BatchWriter) WriteRowBatch(batch *RowBatch) error { return nil } - startPaginationKeypos, err := values[0].GetUint64(batch.PaginationKeyIndex()) - if err != nil { - return err - } + var startPaginationKeypos, endPaginationKeypos PaginationKey + var err error + + paginationColumn := batch.TableSchema().GetPaginationColumn() - endPaginationKeypos, err := values[len(values)-1].GetUint64(batch.PaginationKeyIndex()) - if err != nil { - return err + switch paginationColumn.Type { + case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: + var startValue, endValue uint64 + startValue, err = values[0].GetUint64(batch.PaginationKeyIndex()) + if err != nil { + return err + } + endValue, err = values[len(values)-1].GetUint64(batch.PaginationKeyIndex()) + if err != nil { + return err + } + startPaginationKeypos = NewUint64Key(startValue) + endPaginationKeypos = NewUint64Key(endValue) + + case schema.TYPE_BINARY, schema.TYPE_STRING: + startValueInterface := values[0][batch.PaginationKeyIndex()] + endValueInterface := values[len(values)-1][batch.PaginationKeyIndex()] + + getBytes := func(val interface{}) ([]byte, error) { + switch v := val.(type) { + case []byte: + return v, nil + case string: + return []byte(v), nil + default: + return nil, fmt.Errorf("expected binary/string pagination key, got %T", val) + } + } + + startValue, err := getBytes(startValueInterface) + if err != nil { + return err + } + + endValue, err := getBytes(endValueInterface) + if err != nil { + return err + } + + startPaginationKeypos = NewBinaryKey(startValue) + endPaginationKeypos = NewBinaryKey(endValue) + + default: + var startValue, endValue uint64 + startValue, err = 
values[0].GetUint64(batch.PaginationKeyIndex()) + if err != nil { + return err + } + endValue, err = values[len(values)-1].GetUint64(batch.PaginationKeyIndex()) + if err != nil { + return err + } + startPaginationKeypos = NewUint64Key(startValue) + endPaginationKeypos = NewUint64Key(endValue) } db := batch.TableSchema().Schema @@ -78,12 +130,12 @@ func (w *BatchWriter) WriteRowBatch(batch *RowBatch) error { query, args, err := batch.AsSQLQuery(db, table) if err != nil { - return fmt.Errorf("during generating sql query at paginationKey %v -> %v: %v", startPaginationKeypos, endPaginationKeypos, err) + return fmt.Errorf("during generating sql query at paginationKey %s -> %s: %v", startPaginationKeypos.String(), endPaginationKeypos.String(), err) } stmt, err := w.stmtCache.StmtFor(w.DB, query) if err != nil { - return fmt.Errorf("during prepare query near paginationKey %v -> %v (%s): %v", startPaginationKeypos, endPaginationKeypos, query, err) + return fmt.Errorf("during prepare query near paginationKey %s -> %s (%s): %v", startPaginationKeypos.String(), endPaginationKeypos.String(), query, err) } tx, err := w.DB.Begin() @@ -94,14 +146,14 @@ func (w *BatchWriter) WriteRowBatch(batch *RowBatch) error { _, err = tx.Stmt(stmt).Exec(args...) 
if err != nil { tx.Rollback() - return fmt.Errorf("during exec query near paginationKey %v -> %v (%s): %v", startPaginationKeypos, endPaginationKeypos, query, err) + return fmt.Errorf("during exec query near paginationKey %s -> %s (%s): %v", startPaginationKeypos.String(), endPaginationKeypos.String(), query, err) } if w.InlineVerifier != nil { mismatches, err := w.InlineVerifier.CheckFingerprintInline(tx, db, table, batch, w.EnforceInlineVerification) if err != nil { tx.Rollback() - return fmt.Errorf("during fingerprint checking for paginationKey %v -> %v (%s): %v", startPaginationKeypos, endPaginationKeypos, query, err) + return fmt.Errorf("during fingerprint checking for paginationKey %s -> %s (%s): %v", startPaginationKeypos.String(), endPaginationKeypos.String(), query, err) } if w.EnforceInlineVerification { @@ -119,7 +171,7 @@ func (w *BatchWriter) WriteRowBatch(batch *RowBatch) error { err = tx.Commit() if err != nil { tx.Rollback() - return fmt.Errorf("during commit near paginationKey %v -> %v (%s): %v", startPaginationKeypos, endPaginationKeypos, query, err) + return fmt.Errorf("during commit near paginationKey %s -> %s (%s): %v", startPaginationKeypos.String(), endPaginationKeypos.String(), query, err) } // Note that the state tracker expects us the track based on the original diff --git a/compression_verifier.go b/compression_verifier.go index 0efb0fd33..eb1b55a66 100644 --- a/compression_verifier.go +++ b/compression_verifier.go @@ -49,6 +49,7 @@ func (e UnsupportedCompressionError) Error() string { type CompressionVerifier struct { logger *logrus.Entry + TableSchemaCache TableSchemaCache supportedAlgorithms map[string]struct{} tableColumnCompressions TableColumnCompressionConfig } @@ -59,32 +60,66 @@ type CompressionVerifier struct { // The GetCompressedHashes method checks if the existing table contains compressed data // and will apply the decompression algorithm to the applicable columns if necessary. 
// After the columns are decompressed, the hashes of the data are used to verify equality -func (c *CompressionVerifier) GetCompressedHashes(db *sql.DB, schema, table, paginationKeyColumn string, columns []schema.TableColumn, paginationKeys []uint64) (map[uint64][]byte, error) { +func (c *CompressionVerifier) GetCompressedHashes(db *sql.DB, schemaName, tableName, paginationKeyColumn string, columns []schema.TableColumn, paginationKeys []interface{}) (map[string][]byte, error) { c.logger.WithFields(logrus.Fields{ "tag": "compression_verifier", - "table": table, + "table": tableName, }).Info("decompressing table data before verification") - tableCompression := c.tableColumnCompressions[table] + tableCompression := c.tableColumnCompressions[tableName] + + table := c.TableSchemaCache.Get(schemaName, tableName) + if table == nil { + return nil, fmt.Errorf("table %s.%s not found in schema cache", schemaName, tableName) + } + paginationColumns := table.GetPaginationColumns() // Extract the raw rows using SQL to be decompressed - rows, err := getRows(db, schema, table, paginationKeyColumn, columns, paginationKeys) + rows, err := getRows(db, schemaName, tableName, paginationColumns, columns, paginationKeys) if err != nil { return nil, err } defer rows.Close() - // Decompress applicable columns and hash the resulting column values for comparison - resultSet := make(map[uint64][]byte) + resultSet := make(map[string][]byte) + numPaginationCols := len(paginationColumns) + for rows.Next() { - rowData, err := ScanByteRow(rows, len(columns)+1) + // Scan: pagination_col1, pagination_col2, ..., data_cols... 
+ rowData, err := ScanByteRow(rows, len(columns)+numPaginationCols) if err != nil { return nil, err } - paginationKey, err := strconv.ParseUint(string(rowData[0]), 10, 64) - if err != nil { - return nil, err + // Build pagination key from columns (works for both single and composite keys) + keys := make([]PaginationKey, len(paginationColumns)) + for i, paginationColumn := range paginationColumns { + switch paginationColumn.Type { + case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: + paginationKeyUint, err := strconv.ParseUint(string(rowData[i]), 10, 64) + if err != nil { + return nil, err + } + keys[i] = NewUint64Key(paginationKeyUint) + + case schema.TYPE_BINARY, schema.TYPE_STRING: + keys[i] = NewBinaryKey(rowData[i]) + + default: + paginationKeyUint, err := strconv.ParseUint(string(rowData[i]), 10, 64) + if err != nil { + return nil, err + } + keys[i] = NewUint64Key(paginationKeyUint) + } + } + + // For single column, use the key directly; for composite, wrap in CompositeKey + var paginationKeyStr string + if len(keys) == 1 { + paginationKeyStr = keys[0].String() + } else { + paginationKeyStr = CompositeKey(keys).String() } // Decompress the applicable columns and then hash them together @@ -94,14 +129,14 @@ func (c *CompressionVerifier) GetCompressedHashes(db *sql.DB, schema, table, pag decompressedRowData := [][]byte{} for idx, column := range columns { if algorithm, ok := tableCompression[column.Name]; ok { - // rowData contains the result of "SELECT paginationKeyColumn, * FROM ...", so idx+1 to get each column - decompressedColData, err := c.Decompress(table, column.Name, algorithm, rowData[idx+1]) + // rowData contains the result of "SELECT paginationKeyCols..., * FROM ...", so idx+numPaginationCols to get each data column + decompressedColData, err := c.Decompress(tableName, column.Name, algorithm, rowData[idx+numPaginationCols]) if err != nil { return nil, err } decompressedRowData = append(decompressedRowData, decompressedColData) } else { - 
decompressedRowData = append(decompressedRowData, rowData[idx+1]) + decompressedRowData = append(decompressedRowData, rowData[idx+numPaginationCols]) } } @@ -111,20 +146,20 @@ func (c *CompressionVerifier) GetCompressedHashes(db *sql.DB, schema, table, pag return nil, err } - resultSet[paginationKey] = decompressedRowHash + resultSet[paginationKeyStr] = decompressedRowHash } metrics.Gauge( "compression_verifier_decompress_rows", float64(len(resultSet)), - []MetricTag{{"table", table}}, + []MetricTag{{"table", tableName}}, 1.0, ) logrus.WithFields(logrus.Fields{ "tag": "compression_verifier", "rows": len(resultSet), - "table": table, + "table": tableName, }).Debug("decompressed rows will be compared") return resultSet, nil @@ -192,12 +227,13 @@ func (c *CompressionVerifier) verifyConfiguredCompression(tableColumnCompression // NewCompressionVerifier first checks the map for supported compression algorithms before // initializing and returning the initialized instance. -func NewCompressionVerifier(tableColumnCompressions TableColumnCompressionConfig) (*CompressionVerifier, error) { +func NewCompressionVerifier(tableColumnCompressions TableColumnCompressionConfig, tableSchemaCache TableSchemaCache) (*CompressionVerifier, error) { supportedAlgorithms := make(map[string]struct{}) supportedAlgorithms[CompressionSnappy] = struct{}{} compressionVerifier := &CompressionVerifier{ logger: logrus.WithField("tag", "compression_verifier"), + TableSchemaCache: tableSchemaCache, supportedAlgorithms: supportedAlgorithms, tableColumnCompressions: tableColumnCompressions, } @@ -209,14 +245,75 @@ func NewCompressionVerifier(tableColumnCompressions TableColumnCompressionConfig return compressionVerifier, nil } -func getRows(db *sql.DB, schema, table, paginationKeyColumn string, columns []schema.TableColumn, paginationKeys []uint64) (*sqlorig.Rows, error) { - quotedPaginationKey := QuoteField(paginationKeyColumn) - sql, args, err := rowSelector(columns, paginationKeyColumn). 
- From(QuotedTableNameFromString(schema, table)). - Where(sq.Eq{quotedPaginationKey: paginationKeys}). - OrderBy(quotedPaginationKey). - ToSql() - +func getRows(db *sql.DB, schemaName, table string, paginationKeyColumns []*schema.TableColumn, columns []schema.TableColumn, paginationKeys []interface{}) (*sqlorig.Rows, error) { + builder := rowSelector(columns, paginationKeyColumns). + From(QuotedTableNameFromString(schemaName, table)) + + if len(paginationKeyColumns) == 1 { + // Single column WHERE clause + quotedPaginationKey := QuoteField(paginationKeyColumns[0].Name) + builder = builder.Where(sq.Eq{quotedPaginationKey: paginationKeys}) + builder = builder.OrderBy(quotedPaginationKey) + } else { + // Composite key WHERE clause: (col1, col2) IN ((?, ?), (?, ?), ...) + quotedPKCols := make([]string, len(paginationKeyColumns)) + for i, col := range paginationKeyColumns { + quotedPKCols[i] = QuoteField(col.Name) + } + tuple := fmt.Sprintf("(%s)", strings.Join(quotedPKCols, ", ")) + + // Build placeholder tuples for each pagination key string + placeholderTuples := make([]string, len(paginationKeys)) + args := make([]interface{}, 0, len(paginationKeys)*len(paginationKeyColumns)) + + for i, pkInterface := range paginationKeys { + pkStr, ok := pkInterface.(string) + if !ok { + return nil, fmt.Errorf("expected string pagination key for composite key, got %T", pkInterface) + } + + // Parse the composite key string (comma-separated) + parts := strings.Split(pkStr, ",") + if len(parts) != len(paginationKeyColumns) { + return nil, fmt.Errorf("pagination key has %d parts but expected %d", len(parts), len(paginationKeyColumns)) + } + + placeholders := make([]string, len(parts)) + for j, part := range parts { + placeholders[j] = "?" 
+ // Convert string representation back to appropriate type + col := paginationKeyColumns[j] + switch col.Type { + case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: + val, err := strconv.ParseUint(part, 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse pagination key part %q as uint64: %w", part, err) + } + args = append(args, val) + case schema.TYPE_BINARY, schema.TYPE_STRING: + // For binary keys, the string is hex-encoded + decoded, err := hex.DecodeString(part) + if err != nil { + return nil, fmt.Errorf("failed to decode pagination key part %q: %w", part, err) + } + args = append(args, decoded) + default: + val, err := strconv.ParseUint(part, 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse pagination key part %q: %w", part, err) + } + args = append(args, val) + } + } + placeholderTuples[i] = fmt.Sprintf("(%s)", strings.Join(placeholders, ", ")) + } + + whereClause := fmt.Sprintf("%s IN (%s)", tuple, strings.Join(placeholderTuples, ", ")) + builder = builder.Where(whereClause, args...) 
+ builder = builder.OrderBy(strings.Join(quotedPKCols, ", ")) + } + + sql, args, err := builder.ToSql() if err != nil { return nil, err } @@ -238,11 +335,17 @@ func getRows(db *sql.DB, schema, table, paginationKeyColumn string, columns []sc return rows, nil } -func rowSelector(columns []schema.TableColumn, paginationKeyColumn string) sq.SelectBuilder { +func rowSelector(columns []schema.TableColumn, paginationKeyColumns []*schema.TableColumn) sq.SelectBuilder { + // Select all pagination key columns first + selectParts := make([]string, len(paginationKeyColumns)) + for i, col := range paginationKeyColumns { + selectParts[i] = QuoteField(col.Name) + } + columnStrs := make([]string, len(columns)) for idx, column := range columns { columnStrs[idx] = column.Name } - return sq.Select(fmt.Sprintf("%s, %s", QuoteField(paginationKeyColumn), strings.Join(columnStrs, ","))) + return sq.Select(fmt.Sprintf("%s, %s", strings.Join(selectParts, ", "), strings.Join(columnStrs, ","))) } diff --git a/config.go b/config.go index d9351f01a..49340a582 100644 --- a/config.go +++ b/config.go @@ -376,12 +376,19 @@ func (c ForceIndexConfig) IndexFor(schemaName, tableName string) string { // CascadingPaginationColumnConfig to configure pagination columns to be // used. The term `Cascading` to denote that greater specificity takes // precedence. +// +// IMPORTANT: All configured pagination columns must contain unique values. +// When specifying a FallbackColumn for tables with composite primary keys, +// ensure the column has a unique constraint to prevent data loss during migration. type CascadingPaginationColumnConfig struct { // PerTable has greatest specificity and takes precedence over the other options - PerTable map[string]map[string]string // SchemaName => TableName => ColumnName + // For composite keys, specify comma-separated column names (e.g., "tenant_id,user_id") + PerTable map[string]map[string]string // SchemaName => TableName => "col1,col2,..." 
// FallbackColumn is a global default to fallback to and is less specific than the - // default, which is the Primary Key + // default, which is the Primary Key. + // For composite keys, specify comma-separated column names (e.g., "tenant_id,user_id") + // These columns MUST have unique values together (ideally a unique constraint) for data integrity. FallbackColumn string } @@ -727,10 +734,19 @@ type Config struct { // ForceIndexForVerification ForceIndexConfig - // Ghostferry requires a single numeric column to paginate over tables. Inferring that column is done in the following exact order: - // 1. Use the PerTable pagination column, if configured for a table. Fail if we cannot find this column in the table. - // 2. Use the table's primary key column as the pagination column. Fail if the primary key is not numeric or is a composite key without a FallbackColumn specified. - // 3. Use the FallbackColumn pagination column, if configured. Fail if we cannot find this column in the table. + // Ghostferry requires one or more numeric/binary columns to paginate over tables. Inferring columns is done in the following exact order: + // 1. Use the PerTable pagination column(s), if configured for a table. Fail if we cannot find these columns in the table. + // 2. Use the table's primary key column(s) as the pagination column(s). This now supports composite primary keys. + // 3. Use the FallbackColumn pagination column(s), if configured. Fail if we cannot find these columns in the table. + // + // IMPORTANT: The pagination column(s) MUST contain unique values together for data integrity. + // For composite keys (e.g., "tenant_id,user_id"), the combination must be unique (ideally a unique constraint or primary key). + // The pagination algorithm uses WHERE (col1, col2) > (last_val1, last_val2) ORDER BY col1, col2 LIMIT batch_size. + // If duplicate value combinations exist, rows may be skipped during iteration, resulting in data loss during the migration. 
+ // + // Examples: + // Single column: "id" + // Composite key: "tenant_id,user_id" (comma-separated, no spaces recommended) CascadingPaginationColumnConfig *CascadingPaginationColumnConfig // SkipTargetVerification is used to enable or disable target verification during moves. diff --git a/cursor.go b/cursor.go index 9a7a72ed1..d6d5cbabe 100644 --- a/cursor.go +++ b/cursor.go @@ -38,7 +38,7 @@ type CursorConfig struct { Throttler Throttler ColumnsToSelect []string - BuildSelect func([]string, *TableSchema, uint64, uint64) (squirrel.SelectBuilder, error) + BuildSelect func([]string, *TableSchema, PaginationKey, uint64) (squirrel.SelectBuilder, error) // BatchSize is a pointer to the BatchSize in Config.UpdatableConfig which can be independently updated from this code. // Having it as a pointer allows the updated value to be read without needing additional code to copy the batch size value into the cursor config for each cursor we create. BatchSize *uint64 @@ -47,7 +47,7 @@ type CursorConfig struct { } // returns a new Cursor with an embedded copy of itself -func (c *CursorConfig) NewCursor(table *TableSchema, startPaginationKey, maxPaginationKey uint64) *Cursor { +func (c *CursorConfig) NewCursor(table *TableSchema, startPaginationKey, maxPaginationKey PaginationKey) *Cursor { return &Cursor{ CursorConfig: *c, Table: table, @@ -58,7 +58,7 @@ func (c *CursorConfig) NewCursor(table *TableSchema, startPaginationKey, maxPagi } // returns a new Cursor with an embedded copy of itself -func (c *CursorConfig) NewCursorWithoutRowLock(table *TableSchema, startPaginationKey, maxPaginationKey uint64) *Cursor { +func (c *CursorConfig) NewCursorWithoutRowLock(table *TableSchema, startPaginationKey, maxPaginationKey PaginationKey) *Cursor { cursor := c.NewCursor(table, startPaginationKey, maxPaginationKey) cursor.RowLock = false return cursor @@ -77,11 +77,12 @@ type Cursor struct { CursorConfig Table *TableSchema - MaxPaginationKey uint64 + MaxPaginationKey PaginationKey 
RowLock bool paginationKeyColumn *schema.TableColumn - lastSuccessfulPaginationKey uint64 + paginationKeyColumns []*schema.TableColumn + lastSuccessfulPaginationKey PaginationKey logger *logrus.Entry } @@ -91,15 +92,16 @@ func (c *Cursor) Each(f func(*RowBatch) error) error { "tag": "cursor", }) c.paginationKeyColumn = c.Table.GetPaginationColumn() + c.paginationKeyColumns = c.Table.GetPaginationColumns() if len(c.ColumnsToSelect) == 0 { c.ColumnsToSelect = []string{"*"} } - for c.lastSuccessfulPaginationKey < c.MaxPaginationKey { + for c.lastSuccessfulPaginationKey.Compare(c.MaxPaginationKey) < 0 { var tx SqlPreparerAndRollbacker var batch *RowBatch - var paginationKeypos uint64 + var paginationKeypos PaginationKey err := WithRetries(c.ReadRetries, 1*time.Second, c.logger, "fetch rows", func() (err error) { if c.Throttler != nil { @@ -137,9 +139,9 @@ func (c *Cursor) Each(f func(*RowBatch) error) error { break } - if paginationKeypos <= c.lastSuccessfulPaginationKey { + if paginationKeypos.Compare(c.lastSuccessfulPaginationKey) <= 0 { tx.Rollback() - err = fmt.Errorf("new paginationKeypos %d <= lastSuccessfulPaginationKey %d", paginationKeypos, c.lastSuccessfulPaginationKey) + err = fmt.Errorf("new paginationKeypos %s <= lastSuccessfulPaginationKey %s", paginationKeypos.String(), c.lastSuccessfulPaginationKey.String()) c.logger.WithError(err).Errorf("last successful paginationKey position did not advance") return err } @@ -159,7 +161,7 @@ func (c *Cursor) Each(f func(*RowBatch) error) error { return nil } -func (c *Cursor) Fetch(db SqlPreparer) (batch *RowBatch, paginationKeypos uint64, err error) { +func (c *Cursor) Fetch(db SqlPreparer) (batch *RowBatch, paginationKeypos PaginationKey, err error) { var selectBuilder squirrel.SelectBuilder batchSize := c.CursorConfig.GetBatchSize(c.Table.Schema, c.Table.Name) @@ -176,7 +178,7 @@ func (c *Cursor) Fetch(db SqlPreparer) (batch *RowBatch, paginationKeypos uint64 if c.RowLock { mySqlVersion, err := 
c.DB.QueryMySQLVersion() if err != nil { - return nil, 0, err + return nil, NewUint64Key(0), err } if strings.HasPrefix(mySqlVersion, "8.") { selectBuilder = selectBuilder.Suffix("FOR SHARE NOWAIT") @@ -228,18 +230,21 @@ func (c *Cursor) Fetch(db SqlPreparer) (batch *RowBatch, paginationKeypos uint64 return } - var paginationKeyIndex int = -1 - for idx, col := range columns { - if col == c.paginationKeyColumn.Name { - paginationKeyIndex = idx - break + paginationKeyIndexes := make([]int, len(c.paginationKeyColumns)) + for i, pkCol := range c.paginationKeyColumns { + found := false + for idx, col := range columns { + if col == pkCol.Name { + paginationKeyIndexes[i] = idx + found = true + break + } + } + if !found { + err = fmt.Errorf("paginationKey column %s is not found during iteration with columns: %v", pkCol.Name, columns) + logger.WithError(err).Error("failed to get paginationKey index") + return } - } - - if paginationKeyIndex < 0 { - err = fmt.Errorf("paginationKey is not found during iteration with columns: %v", columns) - logger.WithError(err).Error("failed to get paginationKey index") - return } var rowData RowData @@ -261,20 +266,59 @@ func (c *Cursor) Fetch(db SqlPreparer) (batch *RowBatch, paginationKeypos uint64 } if len(batchData) > 0 { - paginationKeypos, err = batchData[len(batchData)-1].GetUint64(paginationKeyIndex) - if err != nil { - logger.WithError(err).Error("failed to get uint64 paginationKey value") - return - } - } + lastRowData := batchData[len(batchData)-1] + + // Construct paginationKeypos + keys := make([]PaginationKey, len(c.paginationKeyColumns)) + for i, idx := range paginationKeyIndexes { + col := c.paginationKeyColumns[i] + + switch col.Type { + case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: + var value uint64 + value, err = lastRowData.GetUint64(idx) + if err != nil { + logger.WithError(err).Errorf("failed to get uint64 paginationKey value for column %s", col.Name) + return + } + keys[i] = NewUint64Key(value) + + case 
schema.TYPE_BINARY, schema.TYPE_STRING: + valueInterface := lastRowData[idx] + var valueBytes []byte + switch v := valueInterface.(type) { + case []byte: + valueBytes = v + case string: + valueBytes = []byte(v) + default: + err = fmt.Errorf("expected binary pagination key to be []byte or string, got %T", valueInterface) + logger.WithError(err).Errorf("failed to get binary paginationKey value for column %s", col.Name) + return + } + keys[i] = NewBinaryKey(valueBytes) - batch = &RowBatch{ - values: batchData, - paginationKeyIndex: paginationKeyIndex, - table: c.Table, - columns: columns, + default: + // Fallback + var value uint64 + value, err = lastRowData.GetUint64(idx) + if err != nil { + logger.WithError(err).Errorf("failed to get uint64 paginationKey value for column %s", col.Name) + return + } + keys[i] = NewUint64Key(value) + } + } + + if len(keys) == 1 { + paginationKeypos = keys[0] + } else { + paginationKeypos = CompositeKey(keys) + } } + batch = NewRowBatchWithIndexes(c.Table, batchData, paginationKeyIndexes) + logger.Debugf("found %d rows", batch.Size()) return @@ -304,12 +348,40 @@ func ScanByteRow(rows *sqlorig.Rows, columnCount int) ([][]byte, error) { return values, err } -func DefaultBuildSelect(columns []string, table *TableSchema, lastPaginationKey, batchSize uint64) squirrel.SelectBuilder { - quotedPaginationKey := QuoteField(table.GetPaginationColumn().Name) +func DefaultBuildSelect(columns []string, table *TableSchema, lastPaginationKey PaginationKey, batchSize uint64) squirrel.SelectBuilder { + pkCols := table.GetPaginationColumns() + quotedPKCols := make([]string, len(pkCols)) + for i, col := range pkCols { + quotedPKCols[i] = QuoteField(col.Name) + } - return squirrel.Select(columns...). + builder := squirrel.Select(columns...). From(QuotedTableName(table)). - Where(squirrel.Gt{quotedPaginationKey: lastPaginationKey}). - Limit(batchSize). 
- OrderBy(quotedPaginationKey) + Limit(batchSize) + + // Add OrderBy + orderBy := make([]string, len(quotedPKCols)) + for i, colName := range quotedPKCols { + orderBy[i] = colName + } + builder = builder.OrderBy(strings.Join(orderBy, ", ")) + + // Add Where + if len(pkCols) == 1 { + builder = builder.Where(squirrel.Gt{quotedPKCols[0]: lastPaginationKey.SQLValue()}) + } else { + // Composite key: (k1, k2) > (v1, v2) + tuple := fmt.Sprintf("(%s)", strings.Join(quotedPKCols, ", ")) + + placeholders := make([]string, len(quotedPKCols)) + for i := range placeholders { + placeholders[i] = "?" + } + tuplePlaceholders := fmt.Sprintf("(%s)", strings.Join(placeholders, ", ")) + + vals := lastPaginationKey.SQLValue().([]interface{}) + builder = builder.Where(fmt.Sprintf("%s > %s", tuple, tuplePlaceholders), vals...) + } + + return builder } diff --git a/data_iterator.go b/data_iterator.go index 5621a24e0..2ab31b6aa 100644 --- a/data_iterator.go +++ b/data_iterator.go @@ -2,11 +2,11 @@ package ghostferry import ( "fmt" - "math" "sync" sql "github.com/Shopify/ghostferry/sqlwrapper" + "github.com/go-mysql-org/go-mysql/schema" "github.com/sirupsen/logrus" ) @@ -28,7 +28,7 @@ type DataIterator struct { type TableMaxPaginationKey struct { Table *TableSchema - MaxPaginationKey uint64 + MaxPaginationKey PaginationKey } func (d *DataIterator) Run(tables []*TableSchema) { @@ -86,15 +86,15 @@ func (d *DataIterator) Run(tables []*TableSchema) { return } - startPaginationKey := d.StateTracker.LastSuccessfulPaginationKey(table.String()) - if startPaginationKey == math.MaxUint64 { + startPaginationKey := d.StateTracker.LastSuccessfulPaginationKey(table.String(), table) + if startPaginationKey.IsMax() { err := fmt.Errorf("%v has been marked as completed but a table iterator has been spawned, this is likely a programmer error which resulted in the inconsistent starting state", table.String()) logger.WithError(err).Error("this is definitely a bug") d.ErrorHandler.Fatal("data_iterator", err) 
return } - cursor := d.CursorConfig.NewCursor(table, startPaginationKey, targetPaginationKeyInterface.(uint64)) + cursor := d.CursorConfig.NewCursor(table, startPaginationKey, targetPaginationKeyInterface.(PaginationKey)) if d.SelectFingerprint { if len(cursor.ColumnsToSelect) == 0 { cursor.ColumnsToSelect = []string{"*"} @@ -110,27 +110,65 @@ func (d *DataIterator) Run(tables []*TableSchema) { }, 1.0) if d.SelectFingerprint { - fingerprints := make(map[uint64][]byte) + fingerprints := make(map[string][]byte) rows := make([]RowData, batch.Size()) + paginationColumns := table.GetPaginationColumns() + paginationKeyIndexes := batch.PaginationKeyIndexes() for i, rowData := range batch.Values() { - paginationKey, err := rowData.GetUint64(batch.PaginationKeyIndex()) - if err != nil { - logger.WithError(err).Error("failed to get paginationKey data") - return err + var paginationKeyStr string + keys := make([]PaginationKey, len(paginationColumns)) + + for k, col := range paginationColumns { + idx := paginationKeyIndexes[k] + switch col.Type { + case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: + paginationKeyUint, err := rowData.GetUint64(idx) + if err != nil { + logger.WithError(err).Error("failed to get uint64 paginationKey data") + return err + } + keys[k] = NewUint64Key(paginationKeyUint) + + case schema.TYPE_BINARY, schema.TYPE_STRING: + paginationKeyInterface := rowData[idx] + var paginationKeyBytes []byte + switch v := paginationKeyInterface.(type) { + case []byte: + paginationKeyBytes = v + case string: + paginationKeyBytes = []byte(v) + default: + return fmt.Errorf("expected binary/string pagination key, got %T", paginationKeyInterface) + } + keys[k] = NewBinaryKey(paginationKeyBytes) + + default: + paginationKeyUint, err := rowData.GetUint64(idx) + if err != nil { + logger.WithError(err).Error("failed to get paginationKey data") + return err + } + keys[k] = NewUint64Key(paginationKeyUint) + } } - fingerprints[paginationKey] = rowData[len(rowData)-1].([]byte) + if 
len(keys) == 1 { + paginationKeyStr = keys[0].String() + } else { + paginationKeyStr = CompositeKey(keys).String() + } + + fingerprints[paginationKeyStr] = rowData[len(rowData)-1].([]byte) rows[i] = rowData[:len(rowData)-1] } - batch = &RowBatch{ - values: rows, - paginationKeyIndex: batch.PaginationKeyIndex(), - table: table, - fingerprints: fingerprints, - columns: batch.columns[:len(batch.columns)-1], - } + batch = NewRowBatchWithIndexes( + table, + rows, + batch.PaginationKeyIndexes(), + ) + batch.fingerprints = fingerprints } for _, listener := range d.batchListeners { diff --git a/data_iterator_sorter.go b/data_iterator_sorter.go index 4dc912869..8e00625b6 100644 --- a/data_iterator_sorter.go +++ b/data_iterator_sorter.go @@ -8,13 +8,13 @@ import ( // DataIteratorSorter is an interface for the DataIterator to choose which order it will process table type DataIteratorSorter interface { - Sort(unorderedTables map[*TableSchema]uint64) ([]TableMaxPaginationKey, error) + Sort(unorderedTables map[*TableSchema]PaginationKey) ([]TableMaxPaginationKey, error) } // MaxPaginationKeySorter arranges table based on the MaxPaginationKey in DESC order type MaxPaginationKeySorter struct{} -func (s *MaxPaginationKeySorter) Sort(unorderedTables map[*TableSchema]uint64) ([]TableMaxPaginationKey, error) { +func (s *MaxPaginationKeySorter) Sort(unorderedTables map[*TableSchema]PaginationKey) ([]TableMaxPaginationKey, error) { orderedTables := make([]TableMaxPaginationKey, len(unorderedTables)) i := 0 @@ -24,7 +24,17 @@ func (s *MaxPaginationKeySorter) Sort(unorderedTables map[*TableSchema]uint64) ( } sort.Slice(orderedTables, func(i, j int) bool { - return orderedTables[i].MaxPaginationKey > orderedTables[j].MaxPaginationKey + keyI := orderedTables[i].MaxPaginationKey + keyJ := orderedTables[j].MaxPaginationKey + + // Handle mixed types by sorting by type string first to prevent panics in Compare + typeI := fmt.Sprintf("%T", keyI) + typeJ := fmt.Sprintf("%T", keyJ) + if typeI != 
typeJ { + return typeI > typeJ + } + + return keyI.Compare(keyJ) > 0 }) return orderedTables, nil @@ -35,7 +45,7 @@ type MaxTableSizeSorter struct { DataIterator *DataIterator } -func (s *MaxTableSizeSorter) Sort(unorderedTables map[*TableSchema]uint64) ([]TableMaxPaginationKey, error) { +func (s *MaxTableSizeSorter) Sort(unorderedTables map[*TableSchema]PaginationKey) ([]TableMaxPaginationKey, error) { orderedTables := []TableMaxPaginationKey{} tableNames := []string{} databaseSchemasSet := map[string]struct{}{} diff --git a/dml_events.go b/dml_events.go index 46695800f..923a390f6 100644 --- a/dml_events.go +++ b/dml_events.go @@ -76,7 +76,7 @@ type DMLEvent interface { AsSQLString(string, string) (string, error) OldValues() RowData NewValues() RowData - PaginationKey() (uint64, error) + PaginationKey() (string, error) BinlogPosition() mysql.Position ResumableBinlogPosition() mysql.Position Annotation() (string, error) @@ -180,7 +180,7 @@ func (e *BinlogInsertEvent) AsSQLString(schemaName, tableName string) (string, e return query, nil } -func (e *BinlogInsertEvent) PaginationKey() (uint64, error) { +func (e *BinlogInsertEvent) PaginationKey() (string, error) { return paginationKeyFromEventData(e.table, e.newValues) } @@ -233,7 +233,7 @@ func (e *BinlogUpdateEvent) AsSQLString(schemaName, tableName string) (string, e return query, nil } -func (e *BinlogUpdateEvent) PaginationKey() (uint64, error) { +func (e *BinlogUpdateEvent) PaginationKey() (string, error) { return paginationKeyFromEventData(e.table, e.newValues) } @@ -274,7 +274,7 @@ func (e *BinlogDeleteEvent) AsSQLString(schemaName, tableName string) (string, e return query, nil } -func (e *BinlogDeleteEvent) PaginationKey() (uint64, error) { +func (e *BinlogDeleteEvent) PaginationKey() (string, error) { return paginationKeyFromEventData(e.table, e.oldValues) } @@ -571,10 +571,49 @@ func appendEscapedBuffer(buffer, value []byte, isJSON bool) []byte { return buffer } -func paginationKeyFromEventData(table 
*TableSchema, rowData RowData) (uint64, error) { +func paginationKeyFromEventData(table *TableSchema, rowData RowData) (string, error) { if err := verifyValuesHasTheSameLengthAsColumns(table, rowData); err != nil { - return 0, err + return "", err + } + + paginationColumns := table.GetPaginationColumns() + paginationKeyIndexes := table.GetPaginationKeyIndexes() + keys := make([]PaginationKey, len(paginationColumns)) + + for i, col := range paginationColumns { + idx := paginationKeyIndexes[i] + switch col.Type { + case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: + paginationKeyUint, err := rowData.GetUint64(idx) + if err != nil { + return "", err + } + keys[i] = NewUint64Key(paginationKeyUint) + + case schema.TYPE_BINARY, schema.TYPE_STRING: + paginationKeyInterface := rowData[idx] + var paginationKeyBytes []byte + switch v := paginationKeyInterface.(type) { + case []byte: + paginationKeyBytes = v + case string: + paginationKeyBytes = []byte(v) + default: + return "", fmt.Errorf("expected binary/string pagination key, got %T", paginationKeyInterface) + } + keys[i] = NewBinaryKey(paginationKeyBytes) + + default: + paginationKeyUint, err := rowData.GetUint64(idx) + if err != nil { + return "", err + } + keys[i] = NewUint64Key(paginationKeyUint) + } } - return rowData.GetUint64(table.GetPaginationKeyIndex()) + if len(keys) == 1 { + return keys[0].String(), nil + } + return CompositeKey(keys).String(), nil } diff --git a/ferry.go b/ferry.go index 3718931b9..28ab46efb 100644 --- a/ferry.go +++ b/ferry.go @@ -288,7 +288,7 @@ func (f *Ferry) NewIterativeVerifier() (*IterativeVerifier, error) { var compressionVerifier *CompressionVerifier if config.TableColumnCompression != nil { - compressionVerifier, err = NewCompressionVerifier(config.TableColumnCompression) + compressionVerifier, err = NewCompressionVerifier(config.TableColumnCompression, f.Tables) if err != nil { return nil, err } @@ -995,7 +995,7 @@ func (f *Ferry) Progress() *Progress { s.Tables = 
make(map[string]TableProgress) targetPaginationKeys := make(map[string]uint64) f.DataIterator.TargetPaginationKeys.Range(func(k, v interface{}) bool { - targetPaginationKeys[k.(string)] = v.(uint64) + targetPaginationKeys[k.(string)] = uint64(v.(PaginationKey).NumericPosition()) return true }) @@ -1009,7 +1009,7 @@ func (f *Ferry) Progress() *Progress { for _, table := range tables { var currentAction string tableName := table.String() - lastSuccessfulPaginationKey, foundInProgress := serializedState.LastSuccessfulPaginationKeys[tableName] + lastSuccessfulPaginationKeyInterface, foundInProgress := serializedState.LastSuccessfulPaginationKeys[tableName] if serializedState.CompletedTables[tableName] { currentAction = TableActionCompleted @@ -1022,6 +1022,11 @@ func (f *Ferry) Progress() *Progress { rowWrittenStats, _ := rowStatsWrittenPerTable[tableName] + var lastSuccessfulPaginationKey uint64 + if lastSuccessfulPaginationKeyInterface != nil { + lastSuccessfulPaginationKey = uint64(lastSuccessfulPaginationKeyInterface.NumericPosition()) + } + s.Tables[tableName] = TableProgress{ LastSuccessfulPaginationKey: lastSuccessfulPaginationKey, TargetPaginationKey: targetPaginationKeys[tableName], @@ -1041,7 +1046,7 @@ func (f *Ferry) Progress() *Progress { } for _, completedPaginationKey := range serializedState.LastSuccessfulPaginationKeys { - completedPaginationKeys += completedPaginationKey + completedPaginationKeys += uint64(completedPaginationKey.NumericPosition()) } var remainingPaginationKeys float64 = 0 diff --git a/filter.go b/filter.go index de6621429..9c367156e 100644 --- a/filter.go +++ b/filter.go @@ -12,10 +12,10 @@ type CopyFilter interface { // allowing for restricting copying to a subset of data. Returning an error // here will cause the query to be retried, until the retry limit is // reached, at which point the ferry will be aborted. 
BuildSelect is passed - // the columns to be selected, table being copied, the last primary key value + // the columns to be selected, table being copied, the last pagination key value // from the previous batch, and the batch size. Call DefaultBuildSelect to // generate the default query, which may be used as a starting point. - BuildSelect([]string, *TableSchema, uint64, uint64) (sq.SelectBuilder, error) + BuildSelect([]string, *TableSchema, PaginationKey, uint64) (sq.SelectBuilder, error) // ApplicableEvent is used to filter events for rows that have been // filtered in ConstrainSelect. ApplicableEvent should return true if the diff --git a/inline_verifier.go b/inline_verifier.go index 552c88e6e..cfcb7cce5 100644 --- a/inline_verifier.go +++ b/inline_verifier.go @@ -15,6 +15,7 @@ import ( sql "github.com/Shopify/ghostferry/sqlwrapper" + "github.com/go-mysql-org/go-mysql/schema" "github.com/golang/snappy" "github.com/sirupsen/logrus" ) @@ -56,7 +57,7 @@ type BinlogVerifyStore struct { currentRowCount uint64 // The number of rows in store currently. 
} -type BinlogVerifySerializedStore map[string]map[string]map[uint64]int +type BinlogVerifySerializedStore map[string]map[string]map[string]int func (s BinlogVerifySerializedStore) RowCount() uint64 { var v uint64 = 0 @@ -85,9 +86,9 @@ func (s BinlogVerifySerializedStore) Copy() BinlogVerifySerializedStore { copyS := make(BinlogVerifySerializedStore) for db, _ := range s { - copyS[db] = make(map[string]map[uint64]int) + copyS[db] = make(map[string]map[string]int) for table, _ := range s[db] { - copyS[db][table] = make(map[uint64]int) + copyS[db][table] = make(map[string]int) for paginationKey, count := range s[db][table] { copyS[db][table][paginationKey] = count } @@ -100,14 +101,14 @@ func (s BinlogVerifySerializedStore) Copy() BinlogVerifySerializedStore { type BinlogVerifyBatch struct { SchemaName string TableName string - PaginationKeys []uint64 + PaginationKeys []interface{} } func NewBinlogVerifyStore() *BinlogVerifyStore { return &BinlogVerifyStore{ - EmitLogPerRowsAdded: uint64(10000), // TODO: make this configurable + EmitLogPerRowsAdded: uint64(10000), mutex: &sync.Mutex{}, - store: make(map[string]map[string]map[uint64]int), + store: make(map[string]map[string]map[string]int), totalRowCount: uint64(0), currentRowCount: uint64(0), } @@ -123,18 +124,18 @@ func NewBinlogVerifyStoreFromSerialized(serialized BinlogVerifySerializedStore) return s } -func (s *BinlogVerifyStore) Add(table *TableSchema, paginationKey uint64) { +func (s *BinlogVerifyStore) Add(table *TableSchema, paginationKey string) { s.mutex.Lock() defer s.mutex.Unlock() _, exists := s.store[table.Schema] if !exists { - s.store[table.Schema] = make(map[string]map[uint64]int) + s.store[table.Schema] = make(map[string]map[string]int) } _, exists = s.store[table.Schema][table.Name] if !exists { - s.store[table.Schema][table.Name] = make(map[uint64]int) + s.store[table.Schema][table.Name] = make(map[string]int) } _, exists = s.store[table.Schema][table.Name][paginationKey] @@ -172,13 +173,15 @@ 
func (s *BinlogVerifyStore) RemoveVerifiedBatch(batch BinlogVerifyBatch) { } for _, paginationKey := range batch.PaginationKeys { - if _, exists = tableStore[paginationKey]; exists { - if tableStore[paginationKey] <= 1 { - // Even though this doesn't save as RAM, it will save space on the - // serialized output. - delete(tableStore, paginationKey) + paginationKeyStr, ok := paginationKey.(string) + if !ok { + continue + } + if _, exists = tableStore[paginationKeyStr]; exists { + if tableStore[paginationKeyStr] <= 1 { + delete(tableStore, paginationKeyStr) } else { - tableStore[paginationKey]-- + tableStore[paginationKeyStr]-- } s.currentRowCount-- } @@ -192,17 +195,17 @@ func (s *BinlogVerifyStore) Batches(batchsize int) []BinlogVerifyBatch { batches := make([]BinlogVerifyBatch, 0) for schemaName, _ := range s.store { for tableName, paginationKeySet := range s.store[schemaName] { - paginationKeyBatch := make([]uint64, 0, batchsize) + paginationKeyBatch := make([]interface{}, 0, batchsize) - for paginationKey, _ := range paginationKeySet { - paginationKeyBatch = append(paginationKeyBatch, paginationKey) + for paginationKeyStr, _ := range paginationKeySet { + paginationKeyBatch = append(paginationKeyBatch, paginationKeyStr) if len(paginationKeyBatch) >= batchsize { batches = append(batches, BinlogVerifyBatch{ SchemaName: schemaName, TableName: tableName, PaginationKeys: paginationKeyBatch, }) - paginationKeyBatch = make([]uint64, 0, batchsize) + paginationKeyBatch = make([]interface{}, 0, batchsize) } } @@ -247,7 +250,7 @@ const ( ) type InlineVerifierMismatches struct { - Pk uint64 + Pk string SourceChecksum string TargetChecksum string MismatchColumn string @@ -328,15 +331,48 @@ func (v *InlineVerifier) Result() (VerificationResultAndStatus, error) { func (v *InlineVerifier) CheckFingerprintInline(tx *sql.Tx, targetSchema, targetTable string, sourceBatch *RowBatch, enforceInlineVerification bool) ([]InlineVerifierMismatches, error) { table := 
sourceBatch.TableSchema() + paginationColumns := table.GetPaginationColumns() + paginationKeyIndexes := sourceBatch.PaginationKeyIndexes() - paginationKeys := make([]uint64, len(sourceBatch.Values())) + paginationKeys := make([]PaginationKey, len(sourceBatch.Values())) for i, row := range sourceBatch.Values() { - paginationKey, err := row.GetUint64(sourceBatch.PaginationKeyIndex()) - if err != nil { - return nil, err - } + keys := make([]PaginationKey, len(paginationColumns)) + for k, col := range paginationColumns { + idx := paginationKeyIndexes[k] + switch col.Type { + case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: + paginationKeyUint, err := row.GetUint64(idx) + if err != nil { + return nil, err + } + keys[k] = NewUint64Key(paginationKeyUint) + + case schema.TYPE_BINARY, schema.TYPE_STRING: + paginationKeyInterface := row[idx] + var paginationKeyBytes []byte + switch v := paginationKeyInterface.(type) { + case []byte: + paginationKeyBytes = v + case string: + paginationKeyBytes = []byte(v) + default: + return nil, fmt.Errorf("expected binary/string pagination key, got %T", paginationKeyInterface) + } + keys[k] = NewBinaryKey(paginationKeyBytes) - paginationKeys[i] = paginationKey + default: + paginationKeyUint, err := row.GetUint64(idx) + if err != nil { + return nil, err + } + keys[k] = NewUint64Key(paginationKeyUint) + } + } + if len(keys) == 1 { + paginationKeys[i] = keys[0] + } else { + paginationKeys[i] = CompositeKey(keys) + } } // Fetch target data @@ -347,15 +383,12 @@ func (v *InlineVerifier) CheckFingerprintInline(tx *sql.Tx, targetSchema, target // Fetch source data sourceFingerprints := sourceBatch.Fingerprints() - sourceDecompressedData := make(map[uint64]map[string][]byte) + sourceDecompressedData := make(map[string]map[string][]byte) - for _, rowData := range sourceBatch.Values() { - paginationKey, err := rowData.GetUint64(sourceBatch.PaginationKeyIndex()) - if err != nil { - return nil, err - } + for i, rowData := range sourceBatch.Values() { + 
paginationKeyStr := paginationKeys[i].String() - sourceDecompressedData[paginationKey] = make(map[string][]byte) + sourceDecompressedData[paginationKeyStr] = make(map[string][]byte) for idx, col := range table.Columns { var compressedData []byte var ok bool @@ -368,7 +401,7 @@ func (v *InlineVerifier) CheckFingerprintInline(tx *sql.Tx, targetSchema, target return nil, fmt.Errorf("cannot convert column %v to []byte", col.Name) } - sourceDecompressedData[paginationKey][col.Name], err = v.decompressData(table, col.Name, compressedData) + sourceDecompressedData[paginationKeyStr][col.Name], err = v.decompressData(table, col.Name, compressedData) } } @@ -468,7 +501,7 @@ func formatMismatches(mismatches map[string]map[string][]InlineVerifierMismatche messageBuf.WriteString(tableNameWithSchema) messageBuf.WriteString(" [PKs: ") for _, mismatch := range mismatches[schemaName][tableName] { - messageBuf.WriteString(strconv.FormatUint(mismatch.Pk, 10)) + messageBuf.WriteString(mismatch.Pk) messageBuf.WriteString(" (type: ") messageBuf.WriteString(string(mismatch.MismatchType)) if mismatch.SourceChecksum != "" { @@ -521,15 +554,15 @@ func (v *InlineVerifier) VerifyDuringCutover() (VerificationResult, error) { }, nil } -func (v *InlineVerifier) getFingerprintDataFromSourceDb(schemaName, tableName string, tx *sql.Tx, table *TableSchema, paginationKeys []uint64) (map[uint64][]byte, map[uint64]map[string][]byte, error) { +func (v *InlineVerifier) getFingerprintDataFromSourceDb(schemaName, tableName string, tx *sql.Tx, table *TableSchema, paginationKeys []PaginationKey) (map[string][]byte, map[string]map[string][]byte, error) { return v.getFingerprintDataFromDb(v.SourceDB, v.sourceStmtCache, schemaName, tableName, tx, table, paginationKeys) } -func (v *InlineVerifier) getFingerprintDataFromTargetDb(schemaName, tableName string, tx *sql.Tx, table *TableSchema, paginationKeys []uint64) (map[uint64][]byte, map[uint64]map[string][]byte, error) { +func (v *InlineVerifier) 
getFingerprintDataFromTargetDb(schemaName, tableName string, tx *sql.Tx, table *TableSchema, paginationKeys []PaginationKey) (map[string][]byte, map[string]map[string][]byte, error) { return v.getFingerprintDataFromDb(v.TargetDB, v.targetStmtCache, schemaName, tableName, tx, table, paginationKeys) } -func (v *InlineVerifier) getFingerprintDataFromDb(db *sql.DB, stmtCache *StmtCache, schemaName, tableName string, tx *sql.Tx, table *TableSchema, paginationKeys []uint64) (map[uint64][]byte, map[uint64]map[string][]byte, error) { +func (v *InlineVerifier) getFingerprintDataFromDb(db *sql.DB, stmtCache *StmtCache, schemaName, tableName string, tx *sql.Tx, table *TableSchema, paginationKeys []PaginationKey) (map[string][]byte, map[string]map[string][]byte, error) { fingerprintQuery := table.FingerprintQuery(schemaName, tableName, len(paginationKeys)) fingerprintStmt, err := stmtCache.StmtFor(db, fingerprintQuery) if err != nil { @@ -540,9 +573,17 @@ func (v *InlineVerifier) getFingerprintDataFromDb(db *sql.DB, stmtCache *StmtCac fingerprintStmt = tx.Stmt(fingerprintStmt) } - args := make([]interface{}, len(paginationKeys)) - for i, paginationKey := range paginationKeys { - args[i] = paginationKey + args := make([]interface{}, 0, len(paginationKeys)) + for _, key := range paginationKeys { + // Flatten keys for arguments + switch k := key.(type) { + case CompositeKey: + for _, subKey := range k { + args = append(args, subKey.SQLValue()) + } + default: + args = append(args, key.SQLValue()) + } } rows, err := fingerprintStmt.Query(args...) 
@@ -555,8 +596,10 @@ func (v *InlineVerifier) getFingerprintDataFromDb(db *sql.DB, stmtCache *StmtCac return nil, nil, err } - fingerprints := make(map[uint64][]byte) // paginationKey -> fingerprint - decompressedData := make(map[uint64]map[string][]byte) // paginationKey -> columnName -> decompressedData + fingerprints := make(map[string][]byte) + decompressedData := make(map[string]map[string][]byte) + paginationCols := table.GetPaginationColumns() + numPKs := len(paginationCols) for rows.Next() { rowData, err := ScanByteRow(rows, len(columns)) @@ -564,20 +607,41 @@ func (v *InlineVerifier) getFingerprintDataFromDb(db *sql.DB, stmtCache *StmtCac return nil, nil, err } - paginationKey, err := strconv.ParseUint(string(rowData[0]), 10, 64) - if err != nil { - return nil, nil, err + // Reconstruct the key + keys := make([]PaginationKey, numPKs) + for i, col := range paginationCols { + switch col.Type { + case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: + paginationKeyUint, err := strconv.ParseUint(string(rowData[i]), 10, 64) + if err != nil { + return nil, nil, err + } + keys[i] = NewUint64Key(paginationKeyUint) + + case schema.TYPE_BINARY, schema.TYPE_STRING: + keys[i] = NewBinaryKey(rowData[i]) + + default: + paginationKeyUint, err := strconv.ParseUint(string(rowData[i]), 10, 64) + if err != nil { + return nil, nil, err + } + keys[i] = NewUint64Key(paginationKeyUint) + } + } + + var paginationKeyStr string + if len(keys) == 1 { + paginationKeyStr = keys[0].String() + } else { + paginationKeyStr = CompositeKey(keys).String() } - fingerprints[paginationKey] = rowData[1] - decompressedData[paginationKey] = make(map[string][]byte) + fingerprints[paginationKeyStr] = rowData[numPKs] + decompressedData[paginationKeyStr] = make(map[string][]byte) - // Note that the FingerprintQuery returns the columns: paginationKey, fingerprint, - // compressedData1, compressedData2, ... - // If there are no compressed data, only 2 columns are returned and this - // loop will be skipped. 
- for i := 2; i < len(columns); i++ { - decompressedData[paginationKey][columns[i]], err = v.decompressData(table, columns[i], rowData[i]) + for i := numPKs + 1; i < len(columns); i++ { + decompressedData[paginationKeyStr][columns[i]], err = v.decompressData(table, columns[i], rowData[i]) if err != nil { return nil, nil, err } @@ -606,8 +670,8 @@ func (v *InlineVerifier) decompressData(table *TableSchema, column string, compr } } -func (v *InlineVerifier) compareHashes(source, target map[uint64][]byte) map[uint64]InlineVerifierMismatches { - mismatchSet := map[uint64]InlineVerifierMismatches{} +func (v *InlineVerifier) compareHashes(source, target map[string][]byte) map[string]InlineVerifierMismatches { + mismatchSet := map[string]InlineVerifierMismatches{} for paginationKey, targetHash := range target { sourceHash, exists := source[paginationKey] @@ -639,8 +703,8 @@ func (v *InlineVerifier) compareHashes(source, target map[uint64][]byte) map[uin return mismatchSet } -func compareDecompressedData(source, target map[uint64]map[string][]byte) map[uint64]InlineVerifierMismatches { - mismatchSet := map[uint64]InlineVerifierMismatches{} +func compareDecompressedData(source, target map[string]map[string][]byte) map[string]InlineVerifierMismatches { + mismatchSet := map[string]InlineVerifierMismatches{} for paginationKey, targetDecompressedColumns := range target { sourceDecompressedColumns, exists := source[paginationKey] @@ -704,7 +768,7 @@ func compareDecompressedData(source, target map[uint64]map[string][]byte) map[ui return mismatchSet } -func (v *InlineVerifier) compareHashesAndData(sourceHashes, targetHashes map[uint64][]byte, sourceData, targetData map[uint64]map[string][]byte) []InlineVerifierMismatches { +func (v *InlineVerifier) compareHashesAndData(sourceHashes, targetHashes map[string][]byte, sourceData, targetData map[string]map[string][]byte) []InlineVerifierMismatches { mismatches := v.compareHashes(sourceHashes, targetHashes) compressedMismatch := 
compareDecompressedData(sourceData, targetData) for paginationKey, mismatch := range compressedMismatch { @@ -789,6 +853,43 @@ func (v *InlineVerifier) verifyAllEventsInStore() (bool, map[string]map[string][ return mismatchFound, mismatches, nil } +func parsePaginationKeyFromString(table *TableSchema, keyStr string) (PaginationKey, error) { + parts := strings.Split(keyStr, ",") + paginationCols := table.GetPaginationColumns() + if len(parts) != len(paginationCols) { + return nil, fmt.Errorf("key string %s does not match column count %d", keyStr, len(paginationCols)) + } + + keys := make([]PaginationKey, len(paginationCols)) + for i, col := range paginationCols { + switch col.Type { + case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: + val, err := strconv.ParseUint(parts[i], 10, 64) + if err != nil { + return nil, err + } + keys[i] = NewUint64Key(val) + case schema.TYPE_BINARY, schema.TYPE_STRING: + valBytes, err := hex.DecodeString(parts[i]) + if err != nil { + return nil, err + } + keys[i] = NewBinaryKey(valBytes) + default: + val, err := strconv.ParseUint(parts[i], 10, 64) + if err != nil { + return nil, err + } + keys[i] = NewUint64Key(val) + } + } + + if len(keys) == 1 { + return keys[0], nil + } + return CompositeKey(keys), nil +} + // Returns a list of mismatched PaginationKeys. // Since the mismatches gets re-added to the reverify store, this must return // a union of mismatches of fingerprints and mismatches due to decompressed @@ -809,11 +910,25 @@ func (v *InlineVerifier) verifyBinlogBatch(batch BinlogVerifyBatch) ([]InlineVer return []InlineVerifierMismatches{}, fmt.Errorf("programming error? 
%s.%s is not found in TableSchemaCache but is being reverified", batch.SchemaName, batch.TableName) } + keys := make([]PaginationKey, 0, len(batch.PaginationKeys)) + for _, kRaw := range batch.PaginationKeys { + kStr, ok := kRaw.(string) + if !ok { + continue + } + pk, err := parsePaginationKeyFromString(sourceTableSchema, kStr) + if err != nil { + v.logger.WithError(err).Warnf("failed to parse pagination key %s", kStr) + continue + } + keys = append(keys, pk) + } + wg := &sync.WaitGroup{} wg.Add(2) - var sourceFingerprints map[uint64][]byte - var sourceDecompressedData map[uint64]map[string][]byte + var sourceFingerprints map[string][]byte + var sourceDecompressedData map[string]map[string][]byte var sourceErr error go func() { defer wg.Done() @@ -822,14 +937,14 @@ func (v *InlineVerifier) verifyBinlogBatch(batch BinlogVerifyBatch) ([]InlineVer batch.SchemaName, batch.TableName, nil, // No transaction sourceTableSchema, - batch.PaginationKeys, + keys, ) return }) }() - var targetFingerprints map[uint64][]byte - var targetDecompressedData map[uint64]map[string][]byte + var targetFingerprints map[string][]byte + var targetDecompressedData map[string]map[string][]byte var targetErr error go func() { defer wg.Done() @@ -838,7 +953,7 @@ func (v *InlineVerifier) verifyBinlogBatch(batch BinlogVerifyBatch) ([]InlineVer targetSchema, targetTable, nil, // No transaction sourceTableSchema, - batch.PaginationKeys, + keys, ) return }) diff --git a/inline_verifier_test.go b/inline_verifier_test.go index 1db37362b..58b5073d4 100644 --- a/inline_verifier_test.go +++ b/inline_verifier_test.go @@ -7,31 +7,31 @@ import ( ) func TestCompareDecompressedDataNoDifference(t *testing.T) { - source := map[uint64]map[string][]byte{ - 31: {"name": []byte("Leszek")}, + source := map[string]map[string][]byte{ + "31": {"name": []byte("Leszek")}, } - target := map[uint64]map[string][]byte{ - 31: {"name": []byte("Leszek")}, + target := map[string]map[string][]byte{ + "31": {"name": 
[]byte("Leszek")}, } result := compareDecompressedData(source, target) - assert.Equal(t, map[uint64]InlineVerifierMismatches{}, result) + assert.Equal(t, map[string]InlineVerifierMismatches{}, result) } func TestCompareDecompressedDataContentDifference(t *testing.T) { - source := map[uint64]map[string][]byte{ - 1: {"name": []byte("Leszek")}, + source := map[string]map[string][]byte{ + "1": {"name": []byte("Leszek")}, } - target := map[uint64]map[string][]byte{ - 1: {"name": []byte("Steve")}, + target := map[string]map[string][]byte{ + "1": {"name": []byte("Steve")}, } result := compareDecompressedData(source, target) - assert.Equal(t, map[uint64]InlineVerifierMismatches{ - 1: { - Pk: 1, + assert.Equal(t, map[string]InlineVerifierMismatches{ + "1": { + Pk: "1", MismatchType: MismatchColumnValueDifference, MismatchColumn: "name", SourceChecksum: "e356a972989f87a1531252cfa2152797", @@ -41,25 +41,25 @@ func TestCompareDecompressedDataContentDifference(t *testing.T) { } func TestCompareDecompressedDataMissingTarget(t *testing.T) { - source := map[uint64]map[string][]byte{ - 1: {"name": []byte("Leszek")}, + source := map[string]map[string][]byte{ + "1": {"name": []byte("Leszek")}, } - target := map[uint64]map[string][]byte{} + target := map[string]map[string][]byte{} result := compareDecompressedData(source, target) - assert.Equal(t, map[uint64]InlineVerifierMismatches{1: {Pk: 1, MismatchType: MismatchRowMissingOnTarget}}, result) + assert.Equal(t, map[string]InlineVerifierMismatches{"1": {Pk: "1", MismatchType: MismatchRowMissingOnTarget}}, result) } func TestCompareDecompressedDataMissingSource(t *testing.T) { - source := map[uint64]map[string][]byte{} - target := map[uint64]map[string][]byte{ - 3: {"name": []byte("Leszek")}, + source := map[string]map[string][]byte{} + target := map[string]map[string][]byte{ + "3": {"name": []byte("Leszek")}, } result := compareDecompressedData(source, target) - assert.Equal(t, map[uint64]InlineVerifierMismatches{3: {Pk: 3, 
MismatchType: MismatchRowMissingOnSource}}, result) + assert.Equal(t, map[string]InlineVerifierMismatches{"3": {Pk: "3", MismatchType: MismatchRowMissingOnSource}}, result) } func TestFormatMismatch(t *testing.T) { @@ -67,7 +67,7 @@ func TestFormatMismatch(t *testing.T) { "default": { "users": { InlineVerifierMismatches{ - Pk: 1, + Pk: "1", MismatchType: MismatchRowMissingOnSource, }, }, @@ -84,17 +84,17 @@ func TestFormatMismatches(t *testing.T) { "default": { "users": { InlineVerifierMismatches{ - Pk: 1, + Pk: "1", MismatchType: MismatchRowMissingOnSource, }, InlineVerifierMismatches{ - Pk: 5, + Pk: "5", MismatchType: MismatchRowMissingOnTarget, }, }, "posts": { InlineVerifierMismatches{ - Pk: 9, + Pk: "9", MismatchType: MismatchColumnValueDifference, MismatchColumn: string("title"), SourceChecksum: "boo", @@ -103,7 +103,7 @@ func TestFormatMismatches(t *testing.T) { }, "attachments": { InlineVerifierMismatches{ - Pk: 7, + Pk: "7", MismatchType: MismatchColumnValueDifference, MismatchColumn: string("name"), SourceChecksum: "boo", diff --git a/iterative_verifier.go b/iterative_verifier.go index 114bf35e5..c167b35a3 100644 --- a/iterative_verifier.go +++ b/iterative_verifier.go @@ -2,9 +2,9 @@ package ghostferry import ( "bytes" + "encoding/hex" "errors" "fmt" - "math" "strconv" "strings" "sync" @@ -18,17 +18,17 @@ import ( ) type ReverifyBatch struct { - PaginationKeys []uint64 + PaginationKeys []interface{} Table TableIdentifier } type ReverifyEntry struct { - PaginationKey uint64 + PaginationKey string Table *TableSchema } type ReverifyStore struct { - MapStore map[TableIdentifier]map[uint64]struct{} + MapStore map[TableIdentifier]map[string]struct{} mapStoreMutex *sync.Mutex BatchStore []ReverifyBatch RowCount uint64 @@ -50,13 +50,14 @@ func (r *ReverifyStore) Add(entry ReverifyEntry) { r.mapStoreMutex.Lock() defer r.mapStoreMutex.Unlock() + paginationKeyStr := entry.PaginationKey tableId := NewTableIdentifierFromSchemaTable(entry.Table) if _, exists := 
r.MapStore[tableId]; !exists { - r.MapStore[tableId] = make(map[uint64]struct{}) + r.MapStore[tableId] = make(map[string]struct{}) } - if _, exists := r.MapStore[tableId][entry.PaginationKey]; !exists { - r.MapStore[tableId][entry.PaginationKey] = struct{}{} + if _, exists := r.MapStore[tableId][paginationKeyStr]; !exists { + r.MapStore[tableId][paginationKeyStr] = struct{}{} r.RowCount++ if r.RowCount%r.EmitLogPerRowCount == 0 { metrics.Gauge("iterative_verifier_store_rows", float64(r.RowCount), []MetricTag{}, 1.0) @@ -74,16 +75,16 @@ func (r *ReverifyStore) FlushAndBatchByTable(batchsize int) []ReverifyBatch { r.BatchStore = make([]ReverifyBatch, 0) for tableId, paginationKeySet := range r.MapStore { - paginationKeyBatch := make([]uint64, 0, batchsize) - for paginationKey, _ := range paginationKeySet { - paginationKeyBatch = append(paginationKeyBatch, paginationKey) - delete(paginationKeySet, paginationKey) + paginationKeyBatch := make([]interface{}, 0, batchsize) + for paginationKeyStr, _ := range paginationKeySet { + paginationKeyBatch = append(paginationKeyBatch, paginationKeyStr) + delete(paginationKeySet, paginationKeyStr) if len(paginationKeyBatch) >= batchsize { r.BatchStore = append(r.BatchStore, ReverifyBatch{ PaginationKeys: paginationKeyBatch, Table: tableId, }) - paginationKeyBatch = make([]uint64, 0, batchsize) + paginationKeyBatch = make([]interface{}, 0, batchsize) } } @@ -102,7 +103,7 @@ func (r *ReverifyStore) FlushAndBatchByTable(batchsize int) []ReverifyBatch { } func (r *ReverifyStore) flushStore() { - r.MapStore = make(map[TableIdentifier]map[uint64]struct{}) + r.MapStore = make(map[TableIdentifier]map[string]struct{}) r.RowCount = 0 } @@ -184,10 +185,10 @@ func (v *IterativeVerifier) Initialize() error { func (v *IterativeVerifier) VerifyOnce() (VerificationResult, error) { v.logger.Info("starting one-off verification of all tables") - err := v.iterateAllTables(func(paginationKey uint64, tableSchema *TableSchema) error { + err := 
v.iterateAllTables(func(paginationKey string, tableSchema *TableSchema) error { return VerificationResult{ DataCorrect: false, - Message: fmt.Sprintf("verification failed on table: %s for paginationKey: %d", tableSchema.String(), paginationKey), + Message: fmt.Sprintf("verification failed on table: %s for paginationKey: %s", tableSchema.String(), paginationKey), IncorrectTables: []string{tableSchema.String()}, } }) @@ -213,7 +214,7 @@ func (v *IterativeVerifier) VerifyBeforeCutover() error { v.BinlogStreamer.AddEventListener(v.binlogEventListener) v.logger.Debug("verifying all tables") - err := v.iterateAllTables(func(paginationKey uint64, tableSchema *TableSchema) error { + err := v.iterateAllTables(func(paginationKey string, tableSchema *TableSchema) error { v.reverifyStore.Add(ReverifyEntry{PaginationKey: paginationKey, Table: tableSchema}) return nil }) @@ -290,15 +291,15 @@ func (v *IterativeVerifier) Result() (VerificationResultAndStatus, error) { return v.verificationResultAndStatus, v.verificationErr } -func (v *IterativeVerifier) GetHashes(db *sql.DB, schema, table, paginationKeyColumn string, columns []schema.TableColumn, paginationKeys []uint64) (map[uint64][]byte, error) { - sql, args, err := GetMd5HashesSql(schema, table, paginationKeyColumn, columns, paginationKeys) +func (v *IterativeVerifier) GetHashes(db *sql.DB, schemaName, tableName, paginationKeyColumn string, columns []schema.TableColumn, paginationKeys []interface{}) (map[string][]byte, error) { + table := v.TableSchemaCache.Get(schemaName, tableName) + paginationColumns := table.GetPaginationColumns() + + sql, args, err := GetMd5HashesSql(schemaName, tableName, paginationColumns, columns, paginationKeys) if err != nil { return nil, err } - // This query must be a prepared query. If it is not, querying will use - // MySQL's plain text interface, which will scan all values into []uint8 - // if we give it []interface{}. 
stmt, err := db.Prepare(sql) if err != nil { return nil, err @@ -313,19 +314,52 @@ func (v *IterativeVerifier) GetHashes(db *sql.DB, schema, table, paginationKeyCo defer rows.Close() - resultSet := make(map[uint64][]byte) + resultSet := make(map[string][]byte) + numPaginationCols := len(paginationColumns) + for rows.Next() { - rowData, err := ScanGenericRow(rows, 2) + // Scan: pagination_col1, pagination_col2, ..., row_fingerprint + rowData, err := ScanGenericRow(rows, numPaginationCols+1) if err != nil { return nil, err } - paginationKey, err := rowData.GetUint64(0) - if err != nil { - return nil, err + // Build pagination key from columns (works for both single and composite keys) + keys := make([]PaginationKey, len(paginationColumns)) + for i, paginationColumn := range paginationColumns { + switch paginationColumn.Type { + case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: + paginationKeyUint, err := rowData.GetUint64(i) + if err != nil { + return nil, err + } + keys[i] = NewUint64Key(paginationKeyUint) + + case schema.TYPE_BINARY, schema.TYPE_STRING: + paginationKeyBytes, ok := rowData[i].([]byte) + if !ok { + return nil, fmt.Errorf("expected []byte for binary pagination key, got %T", rowData[i]) + } + keys[i] = NewBinaryKey(paginationKeyBytes) + + default: + paginationKeyUint, err := rowData.GetUint64(i) + if err != nil { + return nil, err + } + keys[i] = NewUint64Key(paginationKeyUint) + } } - resultSet[paginationKey] = rowData[1].([]byte) + // For single column, use the key directly; for composite, wrap in CompositeKey + var paginationKeyStr string + if len(keys) == 1 { + paginationKeyStr = keys[0].String() + } else { + paginationKeyStr = CompositeKey(keys).String() + } + + resultSet[paginationKeyStr] = rowData[numPaginationCols].([]byte) } return resultSet, nil } @@ -363,7 +397,7 @@ func (v *IterativeVerifier) reverifyUntilStoreIsSmallEnough(maxIterations int) e return nil } -func (v *IterativeVerifier) iterateAllTables(mismatchedPaginationKeyFunc 
func(uint64, *TableSchema) error) error { +func (v *IterativeVerifier) iterateAllTables(mismatchedPaginationKeyFunc func(string, *TableSchema) error) error { pool := &WorkerPool{ Concurrency: v.Concurrency, Process: func(tableIndex int) (interface{}, error) { @@ -386,28 +420,100 @@ func (v *IterativeVerifier) iterateAllTables(mismatchedPaginationKeyFunc func(ui return err } -func (v *IterativeVerifier) iterateTableFingerprints(table *TableSchema, mismatchedPaginationKeyFunc func(uint64, *TableSchema) error) error { +func (v *IterativeVerifier) iterateTableFingerprints(table *TableSchema, mismatchedPaginationKeyFunc func(string, *TableSchema) error) error { // The cursor will stop iterating when it cannot find anymore rows, - // so it will not iterate until MaxUint64. - cursor := v.CursorConfig.NewCursorWithoutRowLock(table, 0, math.MaxUint64) + // so it will not iterate until MaxPaginationKey. + paginationColumns := table.GetPaginationColumns() + minKey := MinPaginationKey(paginationColumns) + maxKey := MaxPaginationKey(paginationColumns) + + cursor := v.CursorConfig.NewCursorWithoutRowLock(table, minKey, maxKey) // It only needs the PaginationKeys, not the entire row. 
- cursor.ColumnsToSelect = []string{fmt.Sprintf("`%s`", table.GetPaginationColumn().Name)} + columnsToSelect := make([]string, len(paginationColumns)) + for i, col := range paginationColumns { + columnsToSelect[i] = fmt.Sprintf("`%s`", col.Name) + } + cursor.ColumnsToSelect = columnsToSelect + return cursor.Each(func(batch *RowBatch) error { metrics.Count("RowEvent", int64(batch.Size()), []MetricTag{ MetricTag{"table", table.Name}, MetricTag{"source", "iterative_verifier_before_cutover"}, }, 1.0) - paginationKeys := make([]uint64, 0, batch.Size()) + paginationKeys := make([]interface{}, 0, batch.Size()) + paginationKeyIndexes := batch.PaginationKeyIndexes() for _, rowData := range batch.Values() { - paginationKey, err := rowData.GetUint64(batch.PaginationKeyIndex()) - if err != nil { - return err + if len(paginationColumns) == 1 { + // Single column - use existing logic + paginationColumn := paginationColumns[0] + switch paginationColumn.Type { + case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: + paginationKeyUint, err := rowData.GetUint64(paginationKeyIndexes[0]) + if err != nil { + return err + } + paginationKeys = append(paginationKeys, paginationKeyUint) + + case schema.TYPE_BINARY, schema.TYPE_STRING: + paginationKeyInterface := rowData[paginationKeyIndexes[0]] + var paginationKeyBytes []byte + switch v := paginationKeyInterface.(type) { + case []byte: + paginationKeyBytes = v + case string: + paginationKeyBytes = []byte(v) + default: + return fmt.Errorf("expected binary/string pagination key, got %T", paginationKeyInterface) + } + paginationKeys = append(paginationKeys, paginationKeyBytes) + + default: + paginationKeyUint, err := rowData.GetUint64(paginationKeyIndexes[0]) + if err != nil { + return err + } + paginationKeys = append(paginationKeys, paginationKeyUint) + } + } else { + // Composite key - append as a string representation + keys := make([]PaginationKey, len(paginationColumns)) + for i, paginationColumn := range paginationColumns { + idx := 
paginationKeyIndexes[i] + switch paginationColumn.Type { + case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: + paginationKeyUint, err := rowData.GetUint64(idx) + if err != nil { + return err + } + keys[i] = NewUint64Key(paginationKeyUint) + + case schema.TYPE_BINARY, schema.TYPE_STRING: + paginationKeyInterface := rowData[idx] + var paginationKeyBytes []byte + switch v := paginationKeyInterface.(type) { + case []byte: + paginationKeyBytes = v + case string: + paginationKeyBytes = []byte(v) + default: + return fmt.Errorf("expected binary/string pagination key, got %T", paginationKeyInterface) + } + keys[i] = NewBinaryKey(paginationKeyBytes) + + default: + paginationKeyUint, err := rowData.GetUint64(idx) + if err != nil { + return err + } + keys[i] = NewUint64Key(paginationKeyUint) + } + } + // Store as string for comparison with GetHashes results + paginationKeys = append(paginationKeys, CompositeKey(keys).String()) } - - paginationKeys = append(paginationKeys, paginationKey) } mismatchedPaginationKeys, err := v.compareFingerprints(paginationKeys, batch.TableSchema()) @@ -513,7 +619,7 @@ func (v *IterativeVerifier) verifyStore(sourceTag string, additionalTags []Metri return result, err } -func (v *IterativeVerifier) reverifyPaginationKeys(table *TableSchema, paginationKeys []uint64) (VerificationResult, []uint64, error) { +func (v *IterativeVerifier) reverifyPaginationKeys(table *TableSchema, paginationKeys []interface{}) (VerificationResult, []string, error) { mismatchedPaginationKeys, err := v.compareFingerprints(paginationKeys, table) if err != nil { return VerificationResult{}, mismatchedPaginationKeys, err @@ -523,14 +629,9 @@ func (v *IterativeVerifier) reverifyPaginationKeys(table *TableSchema, paginatio return NewCorrectVerificationResult(), mismatchedPaginationKeys, nil } - paginationKeyStrings := make([]string, len(mismatchedPaginationKeys)) - for idx, paginationKey := range mismatchedPaginationKeys { - paginationKeyStrings[idx] = 
strconv.FormatUint(paginationKey, 10) - } - return VerificationResult{ DataCorrect: false, - Message: fmt.Sprintf("verification failed on table: %s for paginationKeys: %s", table.String(), strings.Join(paginationKeyStrings, ",")), + Message: fmt.Sprintf("verification failed on table: %s for paginationKeys: %s", table.String(), strings.Join(mismatchedPaginationKeys, ",")), IncorrectTables: []string{table.String()}, }, mismatchedPaginationKeys, nil } @@ -582,7 +683,7 @@ func (v *IterativeVerifier) columnsToVerify(table *TableSchema) []schema.TableCo return columns } -func (v *IterativeVerifier) compareFingerprints(paginationKeys []uint64, table *TableSchema) ([]uint64, error) { +func (v *IterativeVerifier) compareFingerprints(paginationKeys []interface{}, table *TableSchema) ([]string, error) { targetDb := table.Schema if targetDbName, exists := v.DatabaseRewrites[targetDb]; exists { targetDb = targetDbName @@ -596,22 +697,24 @@ func (v *IterativeVerifier) compareFingerprints(paginationKeys []uint64, table * wg := &sync.WaitGroup{} wg.Add(2) - var sourceHashes map[uint64][]byte + var sourceHashes map[string][]byte var sourceErr error go func() { defer wg.Done() sourceErr = WithRetries(5, 0, v.logger, "get fingerprints from source db", func() (err error) { - sourceHashes, err = v.GetHashes(v.SourceDB, table.Schema, table.Name, table.GetPaginationColumn().Name, v.columnsToVerify(table), paginationKeys) + // Pass deprecated single column name for backward compatibility (unused in GetHashes now) + sourceHashes, err = v.GetHashes(v.SourceDB, table.Schema, table.Name, "", v.columnsToVerify(table), paginationKeys) return }) }() - var targetHashes map[uint64][]byte + var targetHashes map[string][]byte var targetErr error go func() { defer wg.Done() targetErr = WithRetries(5, 0, v.logger, "get fingerprints from target db", func() (err error) { - targetHashes, err = v.GetHashes(v.TargetDB, targetDb, targetTable, table.GetPaginationColumn().Name, v.columnsToVerify(table), 
paginationKeys) + // Pass deprecated single column name for backward compatibility (unused in GetHashes now) + targetHashes, err = v.GetHashes(v.TargetDB, targetDb, targetTable, "", v.columnsToVerify(table), paginationKeys) return }) }() @@ -632,13 +735,14 @@ func (v *IterativeVerifier) compareFingerprints(paginationKeys []uint64, table * return mismatches, nil } -func (v *IterativeVerifier) compareCompressedHashes(targetDb, targetTable string, table *TableSchema, paginationKeys []uint64) ([]uint64, error) { - sourceHashes, err := v.CompressionVerifier.GetCompressedHashes(v.SourceDB, table.Schema, table.Name, table.GetPaginationColumn().Name, v.columnsToVerify(table), paginationKeys) +func (v *IterativeVerifier) compareCompressedHashes(targetDb, targetTable string, table *TableSchema, paginationKeys []interface{}) ([]string, error) { + // Pass empty string for deprecated paginationKeyColumn parameter (CompressionVerifier will get it from TableSchemaCache) + sourceHashes, err := v.CompressionVerifier.GetCompressedHashes(v.SourceDB, table.Schema, table.Name, "", v.columnsToVerify(table), paginationKeys) if err != nil { return nil, err } - targetHashes, err := v.CompressionVerifier.GetCompressedHashes(v.TargetDB, targetDb, targetTable, table.GetPaginationColumn().Name, v.columnsToVerify(table), paginationKeys) + targetHashes, err := v.CompressionVerifier.GetCompressedHashes(v.TargetDB, targetDb, targetTable, "", v.columnsToVerify(table), paginationKeys) if err != nil { return nil, err } @@ -646,8 +750,8 @@ func (v *IterativeVerifier) compareCompressedHashes(targetDb, targetTable string return compareHashes(sourceHashes, targetHashes), nil } -func compareHashes(source, target map[uint64][]byte) []uint64 { - mismatchSet := map[uint64]struct{}{} +func compareHashes(source, target map[string][]byte) []string { + mismatchSet := map[string]struct{}{} for paginationKey, targetHash := range target { sourceHash, exists := source[paginationKey] @@ -663,7 +767,7 @@ func 
compareHashes(source, target map[uint64][]byte) []uint64 { } } - mismatches := make([]uint64, 0, len(mismatchSet)) + mismatches := make([]string, 0, len(mismatchSet)) for mismatch, _ := range mismatchSet { mismatches = append(mismatches, mismatch) } @@ -671,17 +775,85 @@ func compareHashes(source, target map[uint64][]byte) []uint64 { return mismatches } -func GetMd5HashesSql(schema, table, paginationKeyColumn string, columns []schema.TableColumn, paginationKeys []uint64) (string, []interface{}, error) { - quotedPaginationKey := QuoteField(paginationKeyColumn) - return rowMd5Selector(columns, paginationKeyColumn). - From(QuotedTableNameFromString(schema, table)). - Where(sq.Eq{quotedPaginationKey: paginationKeys}). - OrderBy(quotedPaginationKey). - ToSql() +func GetMd5HashesSql(schemaName, table string, paginationKeyColumns []*schema.TableColumn, columns []schema.TableColumn, paginationKeys []interface{}) (string, []interface{}, error) { + builder := rowMd5Selector(columns, paginationKeyColumns). + From(QuotedTableNameFromString(schemaName, table)) + + if len(paginationKeyColumns) == 1 { + // Single column WHERE clause + quotedPaginationKey := QuoteField(paginationKeyColumns[0].Name) + builder = builder.Where(sq.Eq{quotedPaginationKey: paginationKeys}) + builder = builder.OrderBy(quotedPaginationKey) + + return builder.ToSql() + } + + // Composite key WHERE clause: (col1, col2) IN ((?, ?), (?, ?), ...) 
+ quotedPKCols := make([]string, len(paginationKeyColumns)) + for i, col := range paginationKeyColumns { + quotedPKCols[i] = QuoteField(col.Name) + } + tuple := fmt.Sprintf("(%s)", strings.Join(quotedPKCols, ", ")) + + // Build placeholder tuples for each pagination key string + placeholderTuples := make([]string, len(paginationKeys)) + args := make([]interface{}, 0, len(paginationKeys)*len(paginationKeyColumns)) + + for i, pkInterface := range paginationKeys { + pkStr, ok := pkInterface.(string) + if !ok { + return "", nil, fmt.Errorf("expected string pagination key for composite key, got %T", pkInterface) + } + + // Parse the composite key string (comma-separated) + parts := strings.Split(pkStr, ",") + if len(parts) != len(paginationKeyColumns) { + return "", nil, fmt.Errorf("pagination key has %d parts but expected %d", len(parts), len(paginationKeyColumns)) + } + + placeholders := make([]string, len(parts)) + for j, part := range parts { + placeholders[j] = "?" + // Convert string representation back to appropriate type + col := paginationKeyColumns[j] + switch col.Type { + case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: + val, err := strconv.ParseUint(part, 10, 64) + if err != nil { + return "", nil, fmt.Errorf("failed to parse pagination key part %q as uint64: %w", part, err) + } + args = append(args, val) + case schema.TYPE_BINARY, schema.TYPE_STRING: + // For binary keys, the string is hex-encoded + decoded, err := hex.DecodeString(part) + if err != nil { + return "", nil, fmt.Errorf("failed to decode pagination key part %q: %w", part, err) + } + args = append(args, decoded) + default: + val, err := strconv.ParseUint(part, 10, 64) + if err != nil { + return "", nil, fmt.Errorf("failed to parse pagination key part %q: %w", part, err) + } + args = append(args, val) + } + } + placeholderTuples[i] = fmt.Sprintf("(%s)", strings.Join(placeholders, ", ")) + } + + whereClause := fmt.Sprintf("%s IN (%s)", tuple, strings.Join(placeholderTuples, ", ")) + builder = 
builder.Where(whereClause, args...) + builder = builder.OrderBy(strings.Join(quotedPKCols, ", ")) + + return builder.ToSql() } -func rowMd5Selector(columns []schema.TableColumn, paginationKeyColumn string) sq.SelectBuilder { - quotedPaginationKey := QuoteField(paginationKeyColumn) +func rowMd5Selector(columns []schema.TableColumn, paginationKeyColumns []*schema.TableColumn) sq.SelectBuilder { + // Select all pagination key columns + selectParts := make([]string, len(paginationKeyColumns)) + for i, col := range paginationKeyColumns { + selectParts[i] = QuoteField(col.Name) + } hashStrs := make([]string, len(columns)) for idx, column := range columns { @@ -691,7 +863,7 @@ func rowMd5Selector(columns []schema.TableColumn, paginationKeyColumn string) sq return sq.Select(fmt.Sprintf( "%s, MD5(CONCAT(%s)) AS row_fingerprint", - quotedPaginationKey, + strings.Join(selectParts, ", "), strings.Join(hashStrs, ","), )) } diff --git a/pagination_key.go b/pagination_key.go new file mode 100644 index 000000000..2d4527812 --- /dev/null +++ b/pagination_key.go @@ -0,0 +1,333 @@ +package ghostferry + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "encoding/json" + "fmt" + "math" + "strings" + + "github.com/go-mysql-org/go-mysql/schema" +) + +type PaginationKey interface { + SQLValue() interface{} + Compare(other PaginationKey) int + NumericPosition() float64 + String() string + MarshalJSON() ([]byte, error) + IsMax() bool +} + +type Uint64Key uint64 + +func NewUint64Key(value uint64) Uint64Key { + return Uint64Key(value) +} + +func (k Uint64Key) SQLValue() interface{} { + return uint64(k) +} + +func (k Uint64Key) Compare(other PaginationKey) int { + otherKey, ok := other.(Uint64Key) + if !ok { + panic(fmt.Sprintf("cannot compare Uint64Key with %T", other)) + } + + if k < otherKey { + return -1 + } else if k > otherKey { + return 1 + } + return 0 +} + +func (k Uint64Key) NumericPosition() float64 { + return float64(k) +} + +func (k Uint64Key) String() string { + return 
fmt.Sprintf("%d", uint64(k)) +} + +func (k Uint64Key) IsMax() bool { + return k == Uint64Key(math.MaxUint64) +} + +func (k Uint64Key) MarshalJSON() ([]byte, error) { + return json.Marshal(uint64(k)) +} + +type BinaryKey []byte + +func NewBinaryKey(value []byte) BinaryKey { + clone := make([]byte, len(value)) + copy(clone, value) + return BinaryKey(clone) +} + +func (k BinaryKey) SQLValue() interface{} { + return []byte(k) +} + +func (k BinaryKey) Compare(other PaginationKey) int { + otherKey, ok := other.(BinaryKey) + if !ok { + panic(fmt.Sprintf("type mismatch: cannot compare BinaryKey with %T", other)) + } + return bytes.Compare(k, otherKey) +} + +// NumericPosition calculates a rough float position. +func (k BinaryKey) NumericPosition() float64 { + if len(k) == 0 { + return 0.0 + } + + // Take up to the first 8 bytes to form a uint64 for estimation + var buf [8]byte + copy(buf[:], k) + + val := binary.BigEndian.Uint64(buf[:]) + return float64(val) +} + +func (k BinaryKey) String() string { + return hex.EncodeToString(k) +} + +func (k BinaryKey) IsMax() bool { + // We cannot know the true "Max" of a VARBINARY without knowing the length. + // However, for UUID(16), we can check for FF... 
+ if len(k) == 0 { + return false + } + for _, b := range k { + if b != 0xFF { + return false + } + } + return true +} + +func (k BinaryKey) MarshalJSON() ([]byte, error) { + return json.Marshal(hex.EncodeToString(k)) +} + +type CompositeKey []PaginationKey + +func (k CompositeKey) SQLValue() interface{} { + values := make([]interface{}, len(k)) + for i, subKey := range k { + values[i] = subKey.SQLValue() + } + return values +} + +func (k CompositeKey) Compare(other PaginationKey) int { + otherKey, ok := other.(CompositeKey) + if !ok { + panic(fmt.Sprintf("type mismatch: cannot compare CompositeKey with %T", other)) + } + + if len(k) != len(otherKey) { + panic(fmt.Sprintf("length mismatch: %d vs %d", len(k), len(otherKey))) + } + + for i := range k { + cmp := k[i].Compare(otherKey[i]) + if cmp != 0 { + return cmp + } + } + + return 0 +} + +func (k CompositeKey) NumericPosition() float64 { + if len(k) == 0 { + return 0.0 + } + // Use the first key's position as a heuristic + return k[0].NumericPosition() +} + +func (k CompositeKey) String() string { + parts := make([]string, len(k)) + for i, subKey := range k { + parts[i] = subKey.String() + } + return strings.Join(parts, ",") +} + +func (k CompositeKey) IsMax() bool { + if len(k) == 0 { + return false + } + for _, subKey := range k { + if !subKey.IsMax() { + return false + } + } + return true +} + +func (k CompositeKey) MarshalJSON() ([]byte, error) { + encoded := make([]json.RawMessage, len(k)) + for i, subKey := range k { + b, err := MarshalPaginationKey(subKey) + if err != nil { + return nil, err + } + encoded[i] = b + } + return json.Marshal(encoded) +} + +type encodedKey struct { + Type string `json:"type"` + Value json.RawMessage `json:"value"` +} + +func MarshalPaginationKey(k PaginationKey) ([]byte, error) { + var typeName string + var valBytes []byte + var err error + + switch t := k.(type) { + case Uint64Key: + typeName = "uint64" + valBytes, err = t.MarshalJSON() + case BinaryKey: + typeName = "binary" + 
valBytes, err = t.MarshalJSON() + case CompositeKey: + typeName = "composite" + valBytes, err = t.MarshalJSON() + default: + return nil, fmt.Errorf("unknown pagination key type: %T", k) + } + + if err != nil { + return nil, err + } + + return json.Marshal(encodedKey{ + Type: typeName, + Value: valBytes, + }) +} + +func UnmarshalPaginationKey(data []byte) (PaginationKey, error) { + var wrapper encodedKey + if err := json.Unmarshal(data, &wrapper); err != nil { + return nil, err + } + + switch wrapper.Type { + case "uint64": + var i uint64 + if err := json.Unmarshal(wrapper.Value, &i); err != nil { + return nil, err + } + return NewUint64Key(i), nil + case "binary": + var s string + if err := json.Unmarshal(wrapper.Value, &s); err != nil { + return nil, err + } + b, err := hex.DecodeString(s) + if err != nil { + return nil, err + } + return NewBinaryKey(b), nil + case "composite": + var parts []json.RawMessage + if err := json.Unmarshal(wrapper.Value, &parts); err != nil { + return nil, err + } + keys := make([]PaginationKey, len(parts)) + for i, part := range parts { + pk, err := UnmarshalPaginationKey(part) + if err != nil { + return nil, err + } + keys[i] = pk + } + return CompositeKey(keys), nil + default: + return nil, fmt.Errorf("unknown key type: %s", wrapper.Type) + } +} + +// minPaginationKeyForColumn returns the minimum pagination key for a single column. +func minPaginationKeyForColumn(column *schema.TableColumn) PaginationKey { + switch column.Type { + case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: + return NewUint64Key(0) + // Handle all potential binary/string types + case schema.TYPE_BINARY, schema.TYPE_STRING: + // The smallest value for any binary/string type is an empty slice. + // Even for fixed BINARY(N), starting at empty ensures we catch [0x00, ...] + return NewBinaryKey([]byte{}) + default: + return NewUint64Key(0) + } +} + +// maxPaginationKeyForColumn returns the maximum pagination key for a single column. 
+func maxPaginationKeyForColumn(column *schema.TableColumn) PaginationKey { + switch column.Type { + case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: + return NewUint64Key(math.MaxUint64) + case schema.TYPE_BINARY, schema.TYPE_STRING: + // SAFETY: Cap the size to prevent OOM on LONGBLOB (4GB). + // InnoDB index limit is 3072 bytes. 4KB is a safe upper bound for a PK. + size := column.MaxSize + if size > 4096 { + size = 4096 + } + maxBytes := make([]byte, size) + for i := range maxBytes { + maxBytes[i] = 0xFF + } + return NewBinaryKey(maxBytes) + default: + return NewUint64Key(math.MaxUint64) + } +} + +// MinPaginationKey creates a minimum pagination key for the given columns. +// For single-column keys, returns a simple key type. For multi-column keys, returns a CompositeKey. +func MinPaginationKey(columns []*schema.TableColumn) PaginationKey { + if len(columns) == 0 { + return NewUint64Key(0) + } + if len(columns) == 1 { + return minPaginationKeyForColumn(columns[0]) + } + keys := make([]PaginationKey, len(columns)) + for i, col := range columns { + keys[i] = minPaginationKeyForColumn(col) + } + return CompositeKey(keys) +} + +// MaxPaginationKey creates a maximum pagination key for the given columns. +// For single-column keys, returns a simple key type. For multi-column keys, returns a CompositeKey. 
+func MaxPaginationKey(columns []*schema.TableColumn) PaginationKey { + if len(columns) == 0 { + return NewUint64Key(math.MaxUint64) + } + if len(columns) == 1 { + return maxPaginationKeyForColumn(columns[0]) + } + keys := make([]PaginationKey, len(columns)) + for i, col := range columns { + keys[i] = maxPaginationKeyForColumn(col) + } + return CompositeKey(keys) +} diff --git a/row_batch.go b/row_batch.go index 4426fc127..cf5cee053 100644 --- a/row_batch.go +++ b/row_batch.go @@ -6,19 +6,36 @@ import ( ) type RowBatch struct { - values []RowData - paginationKeyIndex int - table *TableSchema - fingerprints map[uint64][]byte - columns []string + values []RowData + paginationKeyIndex int // Deprecated: use paginationKeyIndexes + paginationKeyIndexes []int + table *TableSchema + fingerprints map[string][]byte + columns []string } +// Deprecated: Use NewRowBatchWithIndexes func NewRowBatch(table *TableSchema, values []RowData, paginationKeyIndex int) *RowBatch { return &RowBatch{ - values: values, - paginationKeyIndex: paginationKeyIndex, - table: table, - columns: ConvertTableColumnsToStrings(table.Columns), + values: values, + paginationKeyIndex: paginationKeyIndex, + paginationKeyIndexes: []int{paginationKeyIndex}, + table: table, + columns: ConvertTableColumnsToStrings(table.Columns), + } +} + +func NewRowBatchWithIndexes(table *TableSchema, values []RowData, paginationKeyIndexes []int) *RowBatch { + var legacyIndex int = -1 + if len(paginationKeyIndexes) > 0 { + legacyIndex = paginationKeyIndexes[0] + } + return &RowBatch{ + values: values, + paginationKeyIndex: legacyIndex, + paginationKeyIndexes: paginationKeyIndexes, + table: table, + columns: ConvertTableColumnsToStrings(table.Columns), } } @@ -39,12 +56,17 @@ func (e *RowBatch) EstimateByteSize() uint64 { return uint64(total) } +// Deprecated: Use PaginationKeyIndexes func (e *RowBatch) PaginationKeyIndex() int { return e.paginationKeyIndex } +func (e *RowBatch) PaginationKeyIndexes() []int { + return 
e.paginationKeyIndexes +} + func (e *RowBatch) ValuesContainPaginationKey() bool { - return e.paginationKeyIndex >= 0 + return len(e.paginationKeyIndexes) > 0 && e.paginationKeyIndexes[0] >= 0 } func (e *RowBatch) Size() int { @@ -55,7 +77,7 @@ func (e *RowBatch) TableSchema() *TableSchema { return e.table } -func (e *RowBatch) Fingerprints() map[uint64][]byte { +func (e *RowBatch) Fingerprints() map[string][]byte { return e.fingerprints } diff --git a/sharding/filter.go b/sharding/filter.go index 0f095b8ae..c5d17c9e5 100644 --- a/sharding/filter.go +++ b/sharding/filter.go @@ -33,7 +33,7 @@ type ShardedCopyFilter struct { missingShardingKeyIndexLogged sync.Map } -func (f *ShardedCopyFilter) BuildSelect(columns []string, table *ghostferry.TableSchema, lastPaginationKey, batchSize uint64) (sq.SelectBuilder, error) { +func (f *ShardedCopyFilter) BuildSelect(columns []string, table *ghostferry.TableSchema, lastPaginationKey ghostferry.PaginationKey, batchSize uint64) (sq.SelectBuilder, error) { quotedPaginationKey := "`" + table.GetPaginationColumn().Name + "`" quotedShardingKey := "`" + f.ShardingKey + "`" quotedTable := ghostferry.QuotedTableName(table) @@ -49,7 +49,7 @@ func (f *ShardedCopyFilter) BuildSelect(columns []string, table *ghostferry.Tabl return sq.Select(columns...). From(quotedTable + " USE INDEX (PRIMARY)"). Where(sq.Eq{quotedPaginationKey: f.ShardingValue}). - Where(sq.Gt{quotedPaginationKey: lastPaginationKey}), nil + Where(sq.Gt{quotedPaginationKey: lastPaginationKey.SQLValue()}), nil } joinTables, exists := f.JoinedTables[table.Name] @@ -90,7 +90,7 @@ func (f *ShardedCopyFilter) BuildSelect(columns []string, table *ghostferry.Tabl return sq.Select(columns...). From(quotedTable). 
- Join("("+selectPaginationKeys+") AS `batch` USING("+quotedPaginationKey+")", f.ShardingValue, lastPaginationKey), nil + Join("("+selectPaginationKeys+") AS `batch` USING("+quotedPaginationKey+")", f.ShardingValue, lastPaginationKey.SQLValue()), nil } // This is a "joined table". It is the only supported type of table that @@ -126,7 +126,7 @@ func (f *ShardedCopyFilter) BuildSelect(columns []string, table *ghostferry.Tabl pattern := "SELECT `%s` AS sharding_join_alias FROM `%s`.`%s` WHERE `%s` = ? AND `%s` > ?" sql := fmt.Sprintf(pattern, joinTable.JoinColumn, table.Schema, joinTable.TableName, f.ShardingKey, joinTable.JoinColumn) clauses = append(clauses, sql) - args = append(args, f.ShardingValue, lastPaginationKey) + args = append(args, f.ShardingValue, lastPaginationKey.SQLValue()) } subquery := strings.Join(clauses, " UNION DISTINCT ") diff --git a/sharding/test/copy_filter_test.go b/sharding/test/copy_filter_test.go index d0e4ebd42..dab0dfed4 100644 --- a/sharding/test/copy_filter_test.go +++ b/sharding/test/copy_filter_test.go @@ -18,7 +18,7 @@ type CopyFilterTestSuite struct { suite.Suite shardingValue int64 - paginationKeyCursor uint64 + paginationKeyCursor ghostferry.PaginationKey normalTable, normalTable2, joinedTable, primaryKeyTable *ghostferry.TableSchema @@ -27,7 +27,7 @@ type CopyFilterTestSuite struct { func (t *CopyFilterTestSuite) SetupTest() { t.shardingValue = int64(1) - t.paginationKeyCursor = uint64(12345) + t.paginationKeyCursor = ghostferry.NewUint64Key(12345) columns := []schema.TableColumn{{Name: "id"}, {Name: "tenant_id"}, {Name: "data"}} t.normalTable = &ghostferry.TableSchema{ @@ -105,7 +105,7 @@ func (t *CopyFilterTestSuite) TestSelectsRegularTables() { sql, args, err := selectBuilder.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`normaltable` JOIN (SELECT `id` FROM `shard_1`.`normaltable` USE INDEX (`good_sharding_index`) WHERE `tenant_id` = ? AND `id` > ? 
ORDER BY `id` LIMIT 1024) AS `batch` USING(`id`)", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) } func (t *CopyFilterTestSuite) TestFallsBackToLessGoodIndex() { @@ -116,7 +116,7 @@ func (t *CopyFilterTestSuite) TestFallsBackToLessGoodIndex() { sql, args, err := selectBuilder.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`normaltable` JOIN (SELECT `id` FROM `shard_1`.`normaltable` USE INDEX (`less_good_sharding_index`) WHERE `tenant_id` = ? AND `id` > ? ORDER BY `id` LIMIT 1024) AS `batch` USING(`id`)", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) } func (t *CopyFilterTestSuite) TestFallsBackToIgnoredPrimaryIndex() { @@ -128,7 +128,7 @@ func (t *CopyFilterTestSuite) TestFallsBackToIgnoredPrimaryIndex() { sql, args, err := selectBuilder.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`normaltable` JOIN (SELECT `id` FROM `shard_1`.`normaltable` IGNORE INDEX (PRIMARY) WHERE `tenant_id` = ? AND `id` > ? ORDER BY `id` LIMIT 1024) AS `batch` USING(`id`)", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) } func (t *CopyFilterTestSuite) TestRemovesIndexHint() { @@ -139,7 +139,7 @@ func (t *CopyFilterTestSuite) TestRemovesIndexHint() { sql, args, err := selectBuilder.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`normaltable` JOIN (SELECT `id` FROM `shard_1`.`normaltable` WHERE `tenant_id` = ? AND `id` > ? 
ORDER BY `id` LIMIT 1024) AS `batch` USING(`id`)", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) } func (t *CopyFilterTestSuite) TestUsesForceIndex() { @@ -150,7 +150,7 @@ func (t *CopyFilterTestSuite) TestUsesForceIndex() { sql, args, err := selectBuilder.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`normaltable` JOIN (SELECT `id` FROM `shard_1`.`normaltable` FORCE INDEX (`good_sharding_index`) WHERE `tenant_id` = ? AND `id` > ? ORDER BY `id` LIMIT 1024) AS `batch` USING(`id`)", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) } func (t *CopyFilterTestSuite) TestUsesIndexHintThatIsNotLowercased() { @@ -161,7 +161,7 @@ func (t *CopyFilterTestSuite) TestUsesIndexHintThatIsNotLowercased() { sql, args, err := selectBuilder.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`normaltable` JOIN (SELECT `id` FROM `shard_1`.`normaltable` FORCE INDEX (`good_sharding_index`) WHERE `tenant_id` = ? AND `id` > ? ORDER BY `id` LIMIT 1024) AS `batch` USING(`id`)", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) } func (t *CopyFilterTestSuite) TestHigherSpecificityOfIndexHintingPerTable() { @@ -179,7 +179,7 @@ func (t *CopyFilterTestSuite) TestHigherSpecificityOfIndexHintingPerTable() { sql, args, err := selectBuilder1.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`normaltable` JOIN (SELECT `id` FROM `shard_1`.`normaltable` USE INDEX (`good_sharding_index`) WHERE `tenant_id` = ? AND `id` > ? 
ORDER BY `id` LIMIT 1024) AS `batch` USING(`id`)", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) selectBuilder2, err := t.filter.BuildSelect([]string{"*"}, t.normalTable2, t.paginationKeyCursor, 1024) t.Require().Nil(err) @@ -187,7 +187,7 @@ func (t *CopyFilterTestSuite) TestHigherSpecificityOfIndexHintingPerTable() { sql, args, err = selectBuilder2.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`normaltable2` JOIN (SELECT `id` FROM `shard_1`.`normaltable2` WHERE `tenant_id` = ? AND `id` > ? ORDER BY `id` LIMIT 1024) AS `batch` USING(`id`)", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) } func (t *CopyFilterTestSuite) TestHigherSpecificityOfIndexHintingPerTable2() { @@ -205,7 +205,7 @@ func (t *CopyFilterTestSuite) TestHigherSpecificityOfIndexHintingPerTable2() { sql, args, err := selectBuilder1.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`normaltable` JOIN (SELECT `id` FROM `shard_1`.`normaltable` WHERE `tenant_id` = ? AND `id` > ? ORDER BY `id` LIMIT 1024) AS `batch` USING(`id`)", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) selectBuilder2, err := t.filter.BuildSelect([]string{"*"}, t.normalTable2, t.paginationKeyCursor, 1024) t.Require().Nil(err) @@ -213,7 +213,7 @@ func (t *CopyFilterTestSuite) TestHigherSpecificityOfIndexHintingPerTable2() { sql, args, err = selectBuilder2.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`normaltable2` JOIN (SELECT `id` FROM `shard_1`.`normaltable2` FORCE INDEX (`good_sharding_index`) WHERE `tenant_id` = ? AND `id` > ? 
ORDER BY `id` LIMIT 1024) AS `batch` USING(`id`)", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) } func (t *CopyFilterTestSuite) TestIndexHintingPerTableWithNonExistentIndex() { @@ -236,7 +236,7 @@ func (t *CopyFilterTestSuite) TestIndexHintingPerTableWithNonExistentIndex() { sql, args, err := selectBuilder1.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`normaltable` JOIN (SELECT `id` FROM `shard_1`.`normaltable` WHERE `tenant_id` = ? AND `id` > ? ORDER BY `id` LIMIT 1024) AS `batch` USING(`id`)", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) selectBuilder2, err := t.filter.BuildSelect([]string{"*"}, t.normalTable2, t.paginationKeyCursor, 1024) t.Require().Nil(err) @@ -244,7 +244,7 @@ func (t *CopyFilterTestSuite) TestIndexHintingPerTableWithNonExistentIndex() { sql, args, err = selectBuilder2.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`normaltable2` JOIN (SELECT `id` FROM `shard_1`.`normaltable2` FORCE INDEX (`good_sharding_index`) WHERE `tenant_id` = ? AND `id` > ? ORDER BY `id` LIMIT 1024) AS `batch` USING(`id`)", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) } func (t *CopyFilterTestSuite) TestIndexHintingPerTableWithIndexOnTable() { @@ -262,7 +262,7 @@ func (t *CopyFilterTestSuite) TestIndexHintingPerTableWithIndexOnTable() { sql, args, err := selectBuilder1.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`normaltable` JOIN (SELECT `id` FROM `shard_1`.`normaltable` FORCE INDEX (`less_good_sharding_index`) WHERE `tenant_id` = ? AND `id` > ? 
ORDER BY `id` LIMIT 1024) AS `batch` USING(`id`)", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) } func (t *CopyFilterTestSuite) TestSelectsJoinedTables() { @@ -272,7 +272,7 @@ func (t *CopyFilterTestSuite) TestSelectsJoinedTables() { sql, args, err := selectBuilder.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`joinedtable` WHERE `joined_paginationKey` IN (SELECT * FROM (SELECT `joined_paginationKey1` AS sharding_join_alias FROM `shard_1`.`join1` WHERE `tenant_id` = ? AND `joined_paginationKey1` > ? UNION DISTINCT SELECT `joined_paginationKey2` AS sharding_join_alias FROM `shard_1`.`join2` WHERE `tenant_id` = ? AND `joined_paginationKey2` > ? ORDER BY sharding_join_alias LIMIT 1024) AS sharding_join_table) ORDER BY `joined_paginationKey`", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor, t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue(), t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) } func (t *CopyFilterTestSuite) TestSelectsPrimaryKeyTables() { @@ -282,7 +282,7 @@ func (t *CopyFilterTestSuite) TestSelectsPrimaryKeyTables() { sql, args, err := selectBuilder.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`pkTable` USE INDEX (PRIMARY) WHERE `tenant_id` = ? 
AND `tenant_id` > ?", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) } func (t *CopyFilterTestSuite) TestShardingValueTypes() { diff --git a/sharding/test/trivial_integration_test.go b/sharding/test/trivial_integration_test.go index 907753b9c..bd4c28982 100644 --- a/sharding/test/trivial_integration_test.go +++ b/sharding/test/trivial_integration_test.go @@ -2,6 +2,7 @@ package test import ( "math/rand" + "sync/atomic" "testing" sql "github.com/Shopify/ghostferry/sqlwrapper" @@ -83,6 +84,8 @@ func TestSelectiveCopyDataWithInsertLoadOnOtherTenants(t *testing.T) { } func TestSelectiveCopyDataWithInsertLoadOnAllTenants(t *testing.T) { + var firstInsert atomic.Bool + testcase := &testhelpers.IntegrationTestCase{ T: t, Ferry: selectiveFerry(int64(2)), @@ -93,7 +96,11 @@ func TestSelectiveCopyDataWithInsertLoadOnAllTenants(t *testing.T) { Tables: []string{"gftest.table1"}, ExtraInsertData: func(tableName string, vals map[string]interface{}) { - vals["tenant_id"] = rand.Intn(3) + if firstInsert.CompareAndSwap(false, true) { + vals["tenant_id"] = 2 + } else { + vals["tenant_id"] = rand.Intn(3) + } }, }, } diff --git a/state_tracker.go b/state_tracker.go index 760481a80..ba5226f21 100644 --- a/state_tracker.go +++ b/state_tracker.go @@ -2,7 +2,7 @@ package ghostferry import ( "container/ring" - "math" + "encoding/json" "sync" "time" @@ -34,7 +34,7 @@ type SerializableState struct { GhostferryVersion string LastKnownTableSchemaCache TableSchemaCache - LastSuccessfulPaginationKeys map[string]uint64 + LastSuccessfulPaginationKeys map[string]PaginationKey CompletedTables map[string]bool LastWrittenBinlogPosition mysql.Position BinlogVerifyStore BinlogVerifySerializedStore @@ -42,6 +42,53 @@ type SerializableState struct { LastStoredBinlogPositionForTargetVerifier mysql.Position } +func (s *SerializableState) MarshalJSON() ([]byte, error) { + // 
Create an alias to avoid infinite recursion, but change the map type + type Alias SerializableState + aux := &struct { + LastSuccessfulPaginationKeys map[string]json.RawMessage + *Alias + }{ + Alias: (*Alias)(s), + LastSuccessfulPaginationKeys: make(map[string]json.RawMessage), + } + + for k, v := range s.LastSuccessfulPaginationKeys { + b, err := MarshalPaginationKey(v) + if err != nil { + return nil, err + } + aux.LastSuccessfulPaginationKeys[k] = b + } + + return json.Marshal(aux) +} + +func (s *SerializableState) UnmarshalJSON(data []byte) error { + type Alias SerializableState + aux := &struct { + LastSuccessfulPaginationKeys map[string]json.RawMessage + *Alias + }{ + Alias: (*Alias)(s), + } + + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + s.LastSuccessfulPaginationKeys = make(map[string]PaginationKey) + for k, v := range aux.LastSuccessfulPaginationKeys { + pk, err := UnmarshalPaginationKey(v) + if err != nil { + return err + } + s.LastSuccessfulPaginationKeys[k] = pk + } + + return nil +} + func (s *SerializableState) MinSourceBinlogPosition() mysql.Position { nilPosition := mysql.Position{} if s.LastWrittenBinlogPosition == nilPosition { @@ -61,7 +108,7 @@ func (s *SerializableState) MinSourceBinlogPosition() mysql.Position { // For tracking the speed of the copy type PaginationKeyPositionLog struct { - Position uint64 + Position float64 At time.Time } @@ -92,7 +139,7 @@ type StateTracker struct { lastStoredBinlogPositionForInlineVerifier mysql.Position lastStoredBinlogPositionForTargetVerifier mysql.Position - lastSuccessfulPaginationKeys map[string]uint64 + lastSuccessfulPaginationKeys map[string]PaginationKey completedTables map[string]bool // TODO: Performance tracking should be refactored out of the state tracker, @@ -106,7 +153,7 @@ func NewStateTracker(speedLogCount int) *StateTracker { BinlogRWMutex: &sync.RWMutex{}, CopyRWMutex: &sync.RWMutex{}, - lastSuccessfulPaginationKeys: make(map[string]uint64), + 
lastSuccessfulPaginationKeys: make(map[string]PaginationKey), completedTables: make(map[string]bool), iterationSpeedLog: newSpeedLogRing(speedLogCount), rowStatsWrittenPerTable: make(map[string]RowStats), @@ -146,11 +193,16 @@ func (s *StateTracker) UpdateLastResumableBinlogPositionForTargetVerifier(pos my s.lastStoredBinlogPositionForTargetVerifier = pos } -func (s *StateTracker) UpdateLastSuccessfulPaginationKey(table string, paginationKey uint64, rowStats RowStats) { +func (s *StateTracker) UpdateLastSuccessfulPaginationKey(table string, paginationKey PaginationKey, rowStats RowStats) { s.CopyRWMutex.Lock() defer s.CopyRWMutex.Unlock() - deltaPaginationKey := paginationKey - s.lastSuccessfulPaginationKeys[table] + var deltaPaginationKey float64 + if lastKey, exists := s.lastSuccessfulPaginationKeys[table]; exists { + deltaPaginationKey = paginationKey.NumericPosition() - lastKey.NumericPosition() + } else { + deltaPaginationKey = paginationKey.NumericPosition() + } s.lastSuccessfulPaginationKeys[table] = paginationKey // TODO: this code is intentionally left here so it is kind of crappy and @@ -174,18 +226,18 @@ func (s *StateTracker) RowStatsWrittenPerTable() map[string]RowStats { return d } -func (s *StateTracker) LastSuccessfulPaginationKey(table string) uint64 { +func (s *StateTracker) LastSuccessfulPaginationKey(table string, tableSchema *TableSchema) PaginationKey { s.CopyRWMutex.RLock() defer s.CopyRWMutex.RUnlock() _, found := s.completedTables[table] if found { - return math.MaxUint64 + return MaxPaginationKey(tableSchema.GetPaginationColumns()) } paginationKey, found := s.lastSuccessfulPaginationKeys[table] if !found { - return 0 + return MinPaginationKey(tableSchema.GetPaginationColumns()) } return paginationKey @@ -240,7 +292,7 @@ func (s *StateTracker) updateRowStatsForTable(table string, rowStats RowStats) { } } -func (s *StateTracker) updateSpeedLog(deltaPaginationKey uint64) { +func (s *StateTracker) updateSpeedLog(deltaPaginationKey float64) { 
if s.iterationSpeedLog == nil { return } @@ -263,7 +315,7 @@ func (s *StateTracker) Serialize(lastKnownTableSchemaCache TableSchemaCache, bin state := &SerializableState{ GhostferryVersion: VersionString, LastKnownTableSchemaCache: lastKnownTableSchemaCache, - LastSuccessfulPaginationKeys: make(map[string]uint64), + LastSuccessfulPaginationKeys: make(map[string]PaginationKey), CompletedTables: make(map[string]bool), LastWrittenBinlogPosition: s.lastWrittenBinlogPosition, LastStoredBinlogPositionForInlineVerifier: s.lastStoredBinlogPositionForInlineVerifier, diff --git a/table_schema_cache.go b/table_schema_cache.go index ca5b1df81..01a18b2ff 100644 --- a/table_schema_cache.go +++ b/table_schema_cache.go @@ -40,8 +40,10 @@ type TableSchema struct { CompressedColumnsForVerification map[string]string // Map of column name => compression type IgnoredColumnsForVerification map[string]struct{} // Set of column name ForcedIndexForVerification string // Forced index name - PaginationKeyColumn *schema.TableColumn - PaginationKeyIndex int + PaginationKeyColumn *schema.TableColumn // Deprecated: Use PaginationKeyColumns + PaginationKeyIndex int // Deprecated: Use PaginationKeyIndexes + PaginationKeyColumns []*schema.TableColumn + PaginationKeyIndexes []int rowMd5Query string } @@ -61,26 +63,66 @@ type TableSchema struct { func (t *TableSchema) FingerprintQuery(schemaName, tableName string, numRows int) string { var forceIndex string - columnsToSelect := make([]string, 2+len(t.CompressedColumnsForVerification)) - columnsToSelect[0] = QuoteField(t.GetPaginationColumn().Name) - columnsToSelect[1] = t.RowMd5Query() - i := 2 - for columnName, _ := range t.CompressedColumnsForVerification { - columnsToSelect[i] = QuoteField(columnName) - i += 1 + // Construct the column list. + // Start with pagination key columns. 
+ paginationCols := t.GetPaginationColumns() + columnsToSelect := make([]string, 0, len(paginationCols)+1+len(t.CompressedColumnsForVerification)) + for _, col := range paginationCols { + columnsToSelect = append(columnsToSelect, QuoteField(col.Name)) + } + + // Add the MD5 hash column + columnsToSelect = append(columnsToSelect, t.RowMd5Query()) + + // Add compressed columns + for columnName := range t.CompressedColumnsForVerification { + columnsToSelect = append(columnsToSelect, QuoteField(columnName)) } if t.ForcedIndexForVerification != "" { forceIndex = fmt.Sprintf(" FORCE INDEX (%s)", t.ForcedIndexForVerification) } + // Build the WHERE clause + var whereClause string + if len(paginationCols) == 1 { + // Single column: WHERE `id` IN (?,?,?) + colName := QuoteField(paginationCols[0].Name) + placeholders := make([]string, numRows) + for i := range placeholders { + placeholders[i] = "?" + } + whereClause = fmt.Sprintf("WHERE %s IN (%s)", colName, strings.Join(placeholders, ",")) + } else { + // Composite key: WHERE (col1, col2) IN ((?,?), (?,?)) + pkColsQuoted := make([]string, len(paginationCols)) + for i, col := range paginationCols { + pkColsQuoted[i] = QuoteField(col.Name) + } + pkTuple := fmt.Sprintf("(%s)", strings.Join(pkColsQuoted, ",")) + + // Build placeholders: (?,?) + placeholders := make([]string, len(paginationCols)) + for i := range placeholders { + placeholders[i] = "?" 
+ } + tuplePlaceholder := fmt.Sprintf("(%s)", strings.Join(placeholders, ",")) + + // Repeat tuple placeholders for numRows + allPlaceholders := make([]string, numRows) + for i := range allPlaceholders { + allPlaceholders[i] = tuplePlaceholder + } + + whereClause = fmt.Sprintf("WHERE %s IN (%s)", pkTuple, strings.Join(allPlaceholders, ",")) + } + return fmt.Sprintf( - "SELECT %s FROM %s%s WHERE %s IN (%s)", + "SELECT %s FROM %s%s %s", strings.Join(columnsToSelect, ","), QuotedTableNameFromString(schemaName, tableName), forceIndex, - columnsToSelect[0], - strings.Repeat("?,", numRows-1)+"?", + whereClause, ) } @@ -126,8 +168,8 @@ func QuotedTableNameFromString(database, table string) string { return fmt.Sprintf("`%s`.`%s`", database, table) } -func MaxPaginationKeys(db *sql.DB, tables []*TableSchema, logger *logrus.Entry) (map[*TableSchema]uint64, []*TableSchema, error) { - tablesWithData := make(map[*TableSchema]uint64) +func MaxPaginationKeys(db *sql.DB, tables []*TableSchema, logger *logrus.Entry) (map[*TableSchema]PaginationKey, []*TableSchema, error) { + tablesWithData := make(map[*TableSchema]PaginationKey) emptyTables := make([]*TableSchema, 0, len(tables)) for _, table := range tables { @@ -135,7 +177,7 @@ func MaxPaginationKeys(db *sql.DB, tables []*TableSchema, logger *logrus.Entry) maxPaginationKey, maxPaginationKeyExists, err := maxPaginationKey(db, table) if err != nil { - logger.WithError(err).Errorf("failed to get max primary key %s", table.GetPaginationColumn().Name) + logger.WithError(err).Errorf("failed to get max primary key for %s", table.String()) return tablesWithData, emptyTables, err } @@ -216,13 +258,19 @@ func LoadTables(db *sql.DB, tableFilter TableFilter, columnCompressionConfig Col tableLog := dbLog.WithField("table", tableName) tableLog.Debug("caching table schema") - paginationKeyColumn, paginationKeyIndex, err := tableSchema.paginationKeyColumn(cascadingPaginationColumnConfig) + paginationKeyColumns, paginationKeyIndexes, err := 
tableSchema.getPaginationKeyColumns(cascadingPaginationColumnConfig) if err != nil { logger.WithError(err).Error("invalid table") return tableSchemaCache, err } - tableSchema.PaginationKeyColumn = paginationKeyColumn - tableSchema.PaginationKeyIndex = paginationKeyIndex + tableSchema.PaginationKeyColumns = paginationKeyColumns + tableSchema.PaginationKeyIndexes = paginationKeyIndexes + + // Backwards compatibility + if len(paginationKeyColumns) > 0 { + tableSchema.PaginationKeyColumn = paginationKeyColumns[0] + tableSchema.PaginationKeyIndex = paginationKeyIndexes[0] + } tableSchemaCache[tableSchema.String()] = tableSchema } @@ -257,42 +305,105 @@ func NonNumericPaginationKeyError(schema, table, paginationKey string) error { return fmt.Errorf("Pagination Key `%s` for %s is non-numeric", paginationKey, QuotedTableNameFromString(schema, table)) } -func (t *TableSchema) paginationKeyColumn(cascadingPaginationColumnConfig *CascadingPaginationColumnConfig) (*schema.TableColumn, int, error) { +func (t *TableSchema) getPaginationKeyColumns(cascadingPaginationColumnConfig *CascadingPaginationColumnConfig) ([]*schema.TableColumn, []int, error) { var err error - var paginationKeyColumn *schema.TableColumn - var paginationKeyIndex int - - if paginationColumn, found := cascadingPaginationColumnConfig.PaginationColumnFor(t.Schema, t.Name); found { - // Use per-schema, per-table pagination key from config - paginationKeyColumn, paginationKeyIndex, err = t.findColumnByName(paginationColumn) - } else if len(t.PKColumns) == 1 { - // Use Primary Key - paginationKeyIndex = t.PKColumns[0] - paginationKeyColumn = &t.Columns[paginationKeyIndex] - } else if fallbackColumnName, found := cascadingPaginationColumnConfig.FallbackPaginationColumnName(); found { - // Try fallback from config - paginationKeyColumn, paginationKeyIndex, err = t.findColumnByName(fallbackColumnName) + var paginationKeyColumns []*schema.TableColumn + var paginationKeyIndexes []int + + var paginationColumnStr string 
+ found := false + + if cascadingPaginationColumnConfig != nil { + paginationColumnStr, found = cascadingPaginationColumnConfig.PaginationColumnFor(t.Schema, t.Name) + } + + if found { + // Configured + cols := strings.Split(paginationColumnStr, ",") + for _, colName := range cols { + colName = strings.TrimSpace(colName) + col, idx, e := t.findColumnByName(colName) + if e != nil { + return nil, nil, e + } + paginationKeyColumns = append(paginationKeyColumns, col) + paginationKeyIndexes = append(paginationKeyIndexes, idx) + } + } else if len(t.PKColumns) > 0 { + // Default to PK + for _, idx := range t.PKColumns { + paginationKeyIndexes = append(paginationKeyIndexes, idx) + paginationKeyColumns = append(paginationKeyColumns, &t.Columns[idx]) + } + } else if cascadingPaginationColumnConfig != nil { + // Fallback + if fallbackColumnName, ok := cascadingPaginationColumnConfig.FallbackPaginationColumnName(); ok { + cols := strings.Split(fallbackColumnName, ",") + for _, colName := range cols { + colName = strings.TrimSpace(colName) + col, idx, e := t.findColumnByName(colName) + if e != nil { + return nil, nil, e + } + paginationKeyColumns = append(paginationKeyColumns, col) + paginationKeyIndexes = append(paginationKeyIndexes, idx) + } + } else { + err = NonExistingPaginationKeyError(t.Schema, t.Name) + } } else { - // No usable pagination key found err = NonExistingPaginationKeyError(t.Schema, t.Name) } + + if err != nil { + return nil, nil, err + } + + // Validate types + for _, col := range paginationKeyColumns { + isNumber := col.Type == schema.TYPE_NUMBER || col.Type == schema.TYPE_MEDIUM_INT + isBinary := col.Type == schema.TYPE_BINARY || col.Type == schema.TYPE_STRING - if paginationKeyColumn != nil && paginationKeyColumn.Type != schema.TYPE_NUMBER && paginationKeyColumn.Type != schema.TYPE_MEDIUM_INT { - return nil, -1, NonNumericPaginationKeyError(t.Schema, t.Name, paginationKeyColumn.Name) + if !isNumber && !isBinary { + return nil, nil, 
NonNumericPaginationKeyError(t.Schema, t.Name, col.Name) + } } - return paginationKeyColumn, paginationKeyIndex, err + return paginationKeyColumns, paginationKeyIndexes, nil } +// Deprecated: Use getPaginationKeyColumns +func (t *TableSchema) paginationKeyColumn(cascadingPaginationColumnConfig *CascadingPaginationColumnConfig) (*schema.TableColumn, int, error) { + cols, idxs, err := t.getPaginationKeyColumns(cascadingPaginationColumnConfig) + if err != nil { + return nil, -1, err + } + if len(cols) == 0 { + return nil, -1, NonExistingPaginationKeyError(t.Schema, t.Name) + } + return cols[0], idxs[0], nil +} + + // GetPaginationColumn retrieves PaginationKeyColumn +// Deprecated: Use GetPaginationColumns func (t *TableSchema) GetPaginationColumn() *schema.TableColumn { return t.PaginationKeyColumn } +func (t *TableSchema) GetPaginationColumns() []*schema.TableColumn { + return t.PaginationKeyColumns +} + +// Deprecated: Use GetPaginationKeyIndexes func (t *TableSchema) GetPaginationKeyIndex() int { return t.PaginationKeyIndex } +func (t *TableSchema) GetPaginationKeyIndexes() []int { + return t.PaginationKeyIndexes +} + func (c TableSchemaCache) AsSlice() (tables []*TableSchema) { for _, tableSchema := range c { tables = append(tables, tableSchema) @@ -398,29 +509,89 @@ func showTablesFrom(c *sql.DB, dbname string) ([]string, error) { return tables, nil } -func maxPaginationKey(db *sql.DB, table *TableSchema) (uint64, bool, error) { - primaryKeyColumn := table.GetPaginationColumn() - paginationKeyName := QuoteField(primaryKeyColumn.Name) +func maxPaginationKey(db *sql.DB, table *TableSchema) (PaginationKey, bool, error) { + primaryKeyColumns := table.GetPaginationColumns() + if len(primaryKeyColumns) == 0 { + return nil, false, fmt.Errorf("no pagination key columns for table %s", table.String()) + } + + pkNames := make([]string, len(primaryKeyColumns)) + orderByClauses := make([]string, len(primaryKeyColumns)) + for i, col := range primaryKeyColumns { + quotedName 
:= QuoteField(col.Name) + pkNames[i] = quotedName + orderByClauses[i] = fmt.Sprintf("%s DESC", quotedName) + } + query, args, err := sq. - Select(paginationKeyName). + Select(pkNames...). From(QuotedTableName(table)). - OrderBy(fmt.Sprintf("%s DESC", paginationKeyName)). + OrderBy(strings.Join(orderByClauses, ", ")). Limit(1). ToSql() if err != nil { - return 0, false, err + return nil, false, err } - var maxPaginationKey uint64 - err = db.QueryRow(query, args...).Scan(&maxPaginationKey) + scanArgs := make([]interface{}, len(primaryKeyColumns)) + // We need temp variables to hold scanned values + values := make([]interface{}, len(primaryKeyColumns)) + + for i, col := range primaryKeyColumns { + switch col.Type { + case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: + var v uint64 + values[i] = &v + scanArgs[i] = &v + case schema.TYPE_BINARY, schema.TYPE_STRING: + // Use interface{} for flexibility (bytes or string) + var v interface{} + values[i] = &v + scanArgs[i] = &v + default: + var v uint64 + values[i] = &v + scanArgs[i] = &v + } + } - switch { - case err == sqlorig.ErrNoRows: - return 0, false, nil - case err != nil: - return 0, false, err - default: - return maxPaginationKey, true, nil + err = db.QueryRow(query, args...).Scan(scanArgs...) 
+ if err != nil { + if err == sqlorig.ErrNoRows { + return nil, false, nil + } + return nil, false, err } + + // Now convert scanned values to PaginationKey + keys := make([]PaginationKey, len(primaryKeyColumns)) + for i, col := range primaryKeyColumns { + switch col.Type { + case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: + val := *(values[i].(*uint64)) + keys[i] = NewUint64Key(val) + case schema.TYPE_BINARY, schema.TYPE_STRING: + val := *(values[i].(*interface{})) + var binValue []byte + switch v := val.(type) { + case []byte: + binValue = v + case string: + binValue = []byte(v) + default: + return nil, false, fmt.Errorf("expected binary/string for max key column %s, got %T", col.Name, val) + } + keys[i] = NewBinaryKey(binValue) + default: + val := *(values[i].(*uint64)) + keys[i] = NewUint64Key(val) + } + } + + if len(keys) == 1 { + return keys[0], true, nil + } + return CompositeKey(keys), true, nil } + diff --git a/target_verifier.go b/target_verifier.go index 1ffe99eb0..2ce7fa04b 100644 --- a/target_verifier.go +++ b/target_verifier.go @@ -42,7 +42,7 @@ func (t *TargetVerifier) BinlogEventListener(evs []DMLEvent) error { if err != nil { return err } - return fmt.Errorf("row data with paginationKey %d on `%s`.`%s` has been corrupted by a change directly performed in the target at binlog file: %s and position: %d", paginationKey, ev.Database(), ev.Table(), ev.BinlogPosition().Name, ev.BinlogPosition().Pos) + return fmt.Errorf("row data with paginationKey %s on `%s`.`%s` has been corrupted by a change directly performed in the target at binlog file: %s and position: %d", paginationKey, ev.Database(), ev.Table(), ev.BinlogPosition().Name, ev.BinlogPosition().Pos) } } diff --git a/test/go/data_iterator_sorter_test.go b/test/go/data_iterator_sorter_test.go index 8a32f9a50..cd45abd49 100644 --- a/test/go/data_iterator_sorter_test.go +++ b/test/go/data_iterator_sorter_test.go @@ -6,6 +6,7 @@ import ( "sync" "testing" + "github.com/go-mysql-org/go-mysql/schema" 
"github.com/stretchr/testify/suite" "github.com/Shopify/ghostferry" @@ -13,12 +14,12 @@ import ( ) const ( - TestDB1 = "gftest2" - TestDB2 = "gftest3" - TestDB3 = "gftest4" - TestTableDB1 = "test_db_2" - TestTableDB2 = "test_db_3" - TestTableDB3 = "test_db_4" + TestDB1 = "gftest2" + TestDB2 = "gftest3" + TestDB3 = "gftest4" + TestTableDB1 = "test_db_2" + TestTableDB2 = "test_db_3" + TestTableDB3 = "test_db_4" ) var TestDBs = []string{TestDB1, TestDB2, TestDB3} @@ -32,8 +33,8 @@ var DBTableMap = map[string]string{ type DataIteratorSorterTestSuite struct { *testhelpers.GhostferryUnitTestSuite - unsortedTables map[*ghostferry.TableSchema]uint64 - dataIterator *ghostferry.DataIterator + unsortedTables map[*ghostferry.TableSchema]ghostferry.PaginationKey + dataIterator *ghostferry.DataIterator } func (t *DataIteratorSorterTestSuite) SetupTest() { @@ -48,17 +49,17 @@ func (t *DataIteratorSorterTestSuite) SetupTest() { } tables, _ := ghostferry.LoadTables(t.Ferry.SourceDB, tableFilter, nil, nil, nil, nil) - t.unsortedTables = make(map[*ghostferry.TableSchema]uint64, len(tables)) + t.unsortedTables = make(map[*ghostferry.TableSchema]ghostferry.PaginationKey, len(tables)) i := 0 - for _,f := range tables.AsSlice() { + for _, f := range tables.AsSlice() { maxPaginationKey := uint64(100_000 - i) - t.unsortedTables[f] = maxPaginationKey + t.unsortedTables[f] = ghostferry.NewUint64Key(maxPaginationKey) i++ } t.dataIterator = &ghostferry.DataIterator{ - DB: t.Ferry.SourceDB, - ErrorHandler: t.Ferry.ErrorHandler, + DB: t.Ferry.SourceDB, + ErrorHandler: t.Ferry.ErrorHandler, TargetPaginationKeys: &sync.Map{}, } } @@ -83,7 +84,7 @@ func (t *DataIteratorSorterTestSuite) TestOrderMaxPaginationKeys() { copy(expectedTables, sortedTables) sort.Slice(expectedTables, func(i, j int) bool { - return sortedTables[i].MaxPaginationKey > sortedTables[j].MaxPaginationKey + return sortedTables[i].MaxPaginationKey.Compare(sortedTables[j].MaxPaginationKey) > 0 }) 
t.Require().Equal(len(t.unsortedTables), len(sortedTables)) @@ -91,6 +92,31 @@ func (t *DataIteratorSorterTestSuite) TestOrderMaxPaginationKeys() { } +func (t *DataIteratorSorterTestSuite) TestOrderMaxPaginationKeysMixedTypes() { + sorter := ghostferry.MaxPaginationKeySorter{} + + // Create a mix of keys: Uint64Key, BinaryKey, CompositeKey + mixedTables := map[*ghostferry.TableSchema]ghostferry.PaginationKey{ + {Table: &schema.Table{Schema: "test", Name: "uint_table"}}: ghostferry.NewUint64Key(100), + {Table: &schema.Table{Schema: "test", Name: "binary_table"}}: ghostferry.NewBinaryKey([]byte("abc")), + {Table: &schema.Table{Schema: "test", Name: "composite_table"}}: ghostferry.CompositeKey{ghostferry.NewUint64Key(1)}, + } + + // This should not panic + sortedTables, err := sorter.Sort(mixedTables) + t.Require().Nil(err) + t.Require().Equal(3, len(sortedTables)) + + // Verify we have one of each + typesSeen := make(map[string]bool) + for _, item := range sortedTables { + typesSeen[fmt.Sprintf("%T", item.MaxPaginationKey)] = true + } + t.Require().True(typesSeen["ghostferry.Uint64Key"]) + t.Require().True(typesSeen["ghostferry.BinaryKey"]) + t.Require().True(typesSeen["ghostferry.CompositeKey"]) +} + func (t *DataIteratorSorterTestSuite) TestOrderByInformationSchemaTableSize() { // information_schemas.table does not update automatically on every write diff --git a/test/go/data_iterator_test.go b/test/go/data_iterator_test.go index f30d0cfb2..f1b19b06f 100644 --- a/test/go/data_iterator_test.go +++ b/test/go/data_iterator_test.go @@ -56,7 +56,7 @@ func (this *DataIteratorTestSuite) SetupTest() { BatchSizePerTableOverride: config.DataIterationBatchSizePerTableOverride, ReadRetries: config.DBReadRetries, }, - StateTracker: ghostferry.NewStateTracker(config.DataIterationConcurrency * 10), + StateTracker: ghostferry.NewStateTracker(config.DataIterationConcurrency * 10), TargetPaginationKeys: &sync.Map{}, } @@ -273,6 +273,246 @@ func (this *DataIteratorTestSuite) 
TestDataIterationBatchSizePerTableOverrideCal } } +func (this *DataIteratorTestSuite) TestCompositeKeyIterationOrder() { + // Create a table with a composite primary key (tenant_id, id) + // We use this structure to test that iteration respects the lexicographical order + // i.e., (1, 100) < (2, 1) + dbName := testhelpers.TestSchemaName + tableName := "composite_key_test" + + // Drop if exists + _, err := this.Ferry.SourceDB.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s.%s", dbName, tableName)) + this.Require().Nil(err) + + // Create table + query := fmt.Sprintf("CREATE TABLE %s.%s (tenant_id int, id int, data varchar(255), PRIMARY KEY (tenant_id, id))", dbName, tableName) + _, err = this.Ferry.SourceDB.Exec(query) + this.Require().Nil(err) + + // Insert data in a scrambled order to ensure the DB returns them sorted by PK during iteration + values := []struct { + tenant_id int + id int + data string + }{ + {2, 1, "t2-id1"}, + {1, 10, "t1-id10"}, + {1, 1, "t1-id1"}, + {3, 1, "t3-id1"}, + {1, 2, "t1-id2"}, + {2, 5, "t2-id5"}, + } + + for _, v := range values { + _, err = this.Ferry.SourceDB.Exec(fmt.Sprintf("INSERT INTO %s.%s (tenant_id, id, data) VALUES (?, ?, ?)", dbName, tableName), v.tenant_id, v.id, v.data) + this.Require().Nil(err) + } + + // Configure Ghostferry to use this table + filter := &testhelpers.TestTableFilter{ + DbsFunc: testhelpers.DbApplicabilityFilter([]string{dbName}), + TablesFunc: nil, + } + + tables, err := ghostferry.LoadTables(this.Ferry.SourceDB, filter, nil, nil, nil, nil) + this.Require().Nil(err) + + targetTable := tables.Get(dbName, tableName) + this.Require().NotNil(targetTable) + this.Require().Equal(2, len(targetTable.GetPaginationColumns())) + + // Setup DataIterator with a batch size of 2 to force multiple batches and pagination logic + batchSize := uint64(2) + dataIterator := &ghostferry.DataIterator{ + DB: this.Ferry.SourceDB, + Concurrency: 1, + ErrorHandler: this.Ferry.ErrorHandler, + CursorConfig: &ghostferry.CursorConfig{ + DB: 
this.Ferry.SourceDB, + BatchSize: &batchSize, + }, + StateTracker: ghostferry.NewStateTracker(10), + TargetPaginationKeys: &sync.Map{}, + } + + // Collect rows + var collectedRows []struct{ t, i int } + dataIterator.AddBatchListener(func(batch *ghostferry.RowBatch) error { + vals := batch.Values() + for _, row := range vals { + tID, _ := row.GetUint64(0) + id, _ := row.GetUint64(1) + collectedRows = append(collectedRows, struct{ t, i int }{int(tID), int(id)}) + } + return nil + }) + + // Run iterator + dataIterator.Run([]*ghostferry.TableSchema{targetTable}) + + // Verify order: (1,1), (1,2), (1,10), (2,1), (2,5), (3,1) + expected := []struct{ t, i int }{ + {1, 1}, + {1, 2}, + {1, 10}, + {2, 1}, + {2, 5}, + {3, 1}, + } + + this.Require().Equal(len(expected), len(collectedRows)) + for i := range expected { + this.Require().Equal(expected[i], collectedRows[i], "Mismatch at index %d", i) + } +} + +func (this *DataIteratorTestSuite) TestCompositeKeyWithBinaryType() { + // Test mixing types: (tenant_id int, uuid varbinary(16)) + dbName := testhelpers.TestSchemaName + tableName := "composite_binary_test" + + _, err := this.Ferry.SourceDB.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s.%s", dbName, tableName)) + this.Require().Nil(err) + + query := fmt.Sprintf("CREATE TABLE %s.%s (tenant_id int, uuid varbinary(16), data varchar(255), PRIMARY KEY (tenant_id, uuid))", dbName, tableName) + _, err = this.Ferry.SourceDB.Exec(query) + this.Require().Nil(err) + + // Insert data: (1, 'A'), (1, 'B'), (2, 'A') + _, err = this.Ferry.SourceDB.Exec(fmt.Sprintf("INSERT INTO %s.%s VALUES (1, 'A', 'd1'), (2, 'A', 'd2'), (1, 'B', 'd3')", dbName, tableName)) + this.Require().Nil(err) + + tables, err := ghostferry.LoadTables(this.Ferry.SourceDB, + &testhelpers.TestTableFilter{DbsFunc: testhelpers.DbApplicabilityFilter([]string{dbName})}, + nil, nil, nil, nil) + this.Require().Nil(err) + targetTable := tables.Get(dbName, tableName) + + batchSize := uint64(1) // Force batching per row + 
dataIterator := &ghostferry.DataIterator{ + DB: this.Ferry.SourceDB, + Concurrency: 1, + ErrorHandler: this.Ferry.ErrorHandler, + CursorConfig: &ghostferry.CursorConfig{ + DB: this.Ferry.SourceDB, + BatchSize: &batchSize, + }, + StateTracker: ghostferry.NewStateTracker(10), + TargetPaginationKeys: &sync.Map{}, + } + + var collectedRows []string + dataIterator.AddBatchListener(func(batch *ghostferry.RowBatch) error { + for _, row := range batch.Values() { + tID, _ := row.GetUint64(0) + uuid := row[1].([]byte) + collectedRows = append(collectedRows, fmt.Sprintf("%d-%s", tID, string(uuid))) + } + return nil + }) + + dataIterator.Run([]*ghostferry.TableSchema{targetTable}) + + expected := []string{"1-A", "1-B", "2-A"} + this.Require().Equal(expected, collectedRows) +} + +func (this *DataIteratorTestSuite) TestThreeColumnCompositeKeyIterationOrder() { + // Test iteration with a 3-column composite primary key (region_id, tenant_id, id) + dbName := testhelpers.TestSchemaName + tableName := "three_col_composite_test" + + // Drop if exists + _, err := this.Ferry.SourceDB.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s.%s", dbName, tableName)) + this.Require().Nil(err) + + // Create table + query := fmt.Sprintf("CREATE TABLE %s.%s (region_id int, tenant_id int, id int, data varchar(255), PRIMARY KEY (region_id, tenant_id, id))", dbName, tableName) + _, err = this.Ferry.SourceDB.Exec(query) + this.Require().Nil(err) + + // Insert data in scrambled order to test sorting + values := []struct { + region_id int + tenant_id int + id int + data string + }{ + {2, 1, 1, "r2-t1-id1"}, + {1, 2, 3, "r1-t2-id3"}, + {1, 1, 1, "r1-t1-id1"}, + {1, 1, 2, "r1-t1-id2"}, + {1, 2, 1, "r1-t2-id1"}, + {2, 1, 2, "r2-t1-id2"}, + {1, 1, 3, "r1-t1-id3"}, + {2, 2, 1, "r2-t2-id1"}, + } + + for _, v := range values { + _, err = this.Ferry.SourceDB.Exec(fmt.Sprintf("INSERT INTO %s.%s (region_id, tenant_id, id, data) VALUES (?, ?, ?, ?)", dbName, tableName), v.region_id, v.tenant_id, v.id, v.data) + 
this.Require().Nil(err) + } + + // Configure Ghostferry + filter := &testhelpers.TestTableFilter{ + DbsFunc: testhelpers.DbApplicabilityFilter([]string{dbName}), + TablesFunc: nil, + } + + tables, err := ghostferry.LoadTables(this.Ferry.SourceDB, filter, nil, nil, nil, nil) + this.Require().Nil(err) + + targetTable := tables.Get(dbName, tableName) + this.Require().NotNil(targetTable) + this.Require().Equal(3, len(targetTable.GetPaginationColumns())) + + // Setup DataIterator with a batch size of 2 + batchSize := uint64(2) + dataIterator := &ghostferry.DataIterator{ + DB: this.Ferry.SourceDB, + Concurrency: 1, + ErrorHandler: this.Ferry.ErrorHandler, + CursorConfig: &ghostferry.CursorConfig{ + DB: this.Ferry.SourceDB, + BatchSize: &batchSize, + }, + StateTracker: ghostferry.NewStateTracker(10), + TargetPaginationKeys: &sync.Map{}, + } + + // Collect rows + var collectedRows []struct{ r, t, i int } + dataIterator.AddBatchListener(func(batch *ghostferry.RowBatch) error { + vals := batch.Values() + for _, row := range vals { + rID, _ := row.GetUint64(0) + tID, _ := row.GetUint64(1) + id, _ := row.GetUint64(2) + collectedRows = append(collectedRows, struct{ r, t, i int }{int(rID), int(tID), int(id)}) + } + return nil + }) + + dataIterator.Run([]*ghostferry.TableSchema{targetTable}) + + // Verify order: (1,1,1), (1,1,2), (1,1,3), (1,2,1), (1,2,3), (2,1,1), (2,1,2), (2,2,1) + expected := []struct{ r, t, i int }{ + {1, 1, 1}, + {1, 1, 2}, + {1, 1, 3}, + {1, 2, 1}, + {1, 2, 3}, + {2, 1, 1}, + {2, 1, 2}, + {2, 2, 1}, + } + + this.Require().Equal(len(expected), len(collectedRows)) + for i := range expected { + this.Require().Equal(expected[i], collectedRows[i], "Mismatch at index %d", i) + } +} + func TestDataIterator(t *testing.T) { testhelpers.SetupTest() suite.Run(t, &DataIteratorTestSuite{GhostferryUnitTestSuite: &testhelpers.GhostferryUnitTestSuite{}}) diff --git a/test/go/inline_verifier_test.go b/test/go/inline_verifier_test.go index ee97e3ea5..d707d583f 100644 --- 
a/test/go/inline_verifier_test.go +++ b/test/go/inline_verifier_test.go @@ -9,18 +9,18 @@ import ( func newMockBinlogVerifySerializedStore() ghostferry.BinlogVerifySerializedStore { s := make(ghostferry.BinlogVerifySerializedStore) - s["db"] = map[string]map[uint64]int{ - "table1": map[uint64]int{ - 3: 1, - 10: 2, - 30: 3, + s["db"] = map[string]map[string]int{ + "table1": map[string]int{ + "3": 1, + "10": 2, + "30": 3, }, } - s["db2"] = map[string]map[uint64]int{ - "table2": map[uint64]int{ - 4: 1, - 20: 2, - 40: 1, + s["db2"] = map[string]map[string]int{ + "table2": map[string]int{ + "4": 1, + "20": 2, + "40": 1, }, } return s @@ -39,7 +39,7 @@ func TestBinlogVerifySerializedStoreCopy(t *testing.T) { s := newMockBinlogVerifySerializedStore() s2 := s.Copy() - s2["db"]["table1"][3] += 1 + s2["db"]["table1"]["3"] += 1 r.Equal(uint64(10), s.RowCount()) r.Equal(uint64(11), s2.RowCount()) diff --git a/test/go/iterative_verifier_integration_test.go b/test/go/iterative_verifier_integration_test.go index c0e087635..70b747657 100644 --- a/test/go/iterative_verifier_integration_test.go +++ b/test/go/iterative_verifier_integration_test.go @@ -14,8 +14,12 @@ import ( func TestHashesSql(t *testing.T) { columns := []schema.TableColumn{schema.TableColumn{Name: "id"}, schema.TableColumn{Name: "data"}, schema.TableColumn{Name: "float_col", Type: schema.TYPE_FLOAT}} paginationKeys := []uint64{1, 5, 42} + paginationKeysInterface := make([]interface{}, len(paginationKeys)) + for i, pk := range paginationKeys { + paginationKeysInterface[i] = pk + } - sql, args, err := ghostferry.GetMd5HashesSql("gftest", "test_table", "id", columns, paginationKeys) + sql, args, err := ghostferry.GetMd5HashesSql("gftest", "test_table", []*schema.TableColumn{&columns[0]}, columns, paginationKeysInterface) assert.Nil(t, err) assert.Equal(t, "SELECT `id`, MD5(CONCAT(MD5(COALESCE(`id`, 'NULL')),MD5(COALESCE(`data`, 'NULL')),MD5(COALESCE((if (`float_col` = '-0', 0, `float_col`)), 'NULL')))) "+ @@ -292,3 
+296,105 @@ func deleteTestRowsToTriggerFailure(ferry *testhelpers.TestFerry) { _, err = ferry.Ferry.TargetDB.Exec("DELETE FROM gftest.table1 WHERE id = \"43\"") testhelpers.PanicIfError(err) } + +func TestHashesSqlWithCompositeKey(t *testing.T) { + columns := []schema.TableColumn{ + {Name: "tenant_id", Type: schema.TYPE_NUMBER}, + {Name: "id", Type: schema.TYPE_NUMBER}, + {Name: "data"}, + } + // Composite keys as comma-separated strings: "1,10", "2,20" + paginationKeysInterface := []interface{}{"1,10", "2,20"} + + sql, args, err := ghostferry.GetMd5HashesSql( + "gftest", "test_table", + []*schema.TableColumn{&columns[0], &columns[1]}, + columns, + paginationKeysInterface, + ) + + assert.Nil(t, err) + assert.Equal(t, + "SELECT `tenant_id`, `id`, MD5(CONCAT(MD5(COALESCE(`tenant_id`, 'NULL')),MD5(COALESCE(`id`, 'NULL')),MD5(COALESCE(`data`, 'NULL')))) "+ + "AS row_fingerprint FROM `gftest`.`test_table` WHERE (`tenant_id`, `id`) IN ((?, ?), (?, ?)) ORDER BY `tenant_id`, `id`", + sql) + assert.Equal(t, []interface{}{uint64(1), uint64(10), uint64(2), uint64(20)}, args) +} + +func TestHashesSqlWithThreeColumnCompositeKey(t *testing.T) { + columns := []schema.TableColumn{ + {Name: "region_id", Type: schema.TYPE_NUMBER}, + {Name: "tenant_id", Type: schema.TYPE_NUMBER}, + {Name: "id", Type: schema.TYPE_NUMBER}, + {Name: "data"}, + } + // Composite keys as comma-separated strings: "1,2,10", "1,3,20" + paginationKeysInterface := []interface{}{"1,2,10", "1,3,20"} + + sql, args, err := ghostferry.GetMd5HashesSql( + "gftest", "test_table", + []*schema.TableColumn{&columns[0], &columns[1], &columns[2]}, + columns, + paginationKeysInterface, + ) + + assert.Nil(t, err) + assert.Equal(t, + "SELECT `region_id`, `tenant_id`, `id`, MD5(CONCAT(MD5(COALESCE(`region_id`, 'NULL')),MD5(COALESCE(`tenant_id`, 'NULL')),MD5(COALESCE(`id`, 'NULL')),MD5(COALESCE(`data`, 'NULL')))) "+ + "AS row_fingerprint FROM `gftest`.`test_table` WHERE (`region_id`, `tenant_id`, `id`) IN ((?, ?, ?), (?, ?, ?)) 
ORDER BY `region_id`, `tenant_id`, `id`", + sql) + assert.Equal(t, []interface{}{uint64(1), uint64(2), uint64(10), uint64(1), uint64(3), uint64(20)}, args) +} + +func TestVerificationWithCompositeKeyDetectsMismatch(t *testing.T) { + ferry := testhelpers.NewTestFerry() + // Filter to only composite_table_2 to avoid mixing key types with table1 + ferry.Config.TableFilter = &testhelpers.TestTableFilter{ + DbsFunc: testhelpers.DbApplicabilityFilter([]string{"gftest"}), + TablesFunc: func(tables []*ghostferry.TableSchema) []*ghostferry.TableSchema { + for _, t := range tables { + if t.Name == "composite_table_2" { + return []*ghostferry.TableSchema{t} + } + } + return nil + }, + } + + iterativeVerifier := &ghostferry.IterativeVerifier{} + ran := false + + testcase := &testhelpers.IntegrationTestCase{ + T: t, + SetupAction: setupCompositeKeyTableDatabase(2), + AfterRowCopyIsComplete: func(ferry *testhelpers.TestFerry, sourceDB, targetDB *sql.DB) { + setupIterativeVerifierFromFerry(iterativeVerifier, ferry.Ferry) + + err := iterativeVerifier.Initialize() + testhelpers.PanicIfError(err) + + err = iterativeVerifier.VerifyBeforeCutover() + testhelpers.PanicIfError(err) + }, + BeforeStoppingBinlogStreaming: func(ferry *testhelpers.TestFerry, sourceDB, targetDB *sql.DB) { + _, err := sourceDB.Exec("INSERT INTO gftest.composite_table_2 VALUES (1, 1, 'test') ON DUPLICATE KEY UPDATE data = 'reverify'") + testhelpers.PanicIfError(err) + }, + AfterStoppedBinlogStreaming: func(ferry *testhelpers.TestFerry, sourceDB, targetDB *sql.DB) { + // Modify target data to create mismatch + _, err := targetDB.Exec("UPDATE gftest.composite_table_2 SET data = 'MISMATCH' WHERE k1 = 1 AND k2 = 1") + testhelpers.PanicIfError(err) + + result, err := iterativeVerifier.VerifyDuringCutover() + assert.Nil(t, err) + assert.False(t, result.DataCorrect) + assert.Contains(t, result.Message, "composite_table_2") + ran = true + }, + Ferry: ferry, + DisableChecksumVerifier: true, + } + + testcase.Run() + 
assert.True(t, ran) +} diff --git a/test/go/iterative_verifier_test.go b/test/go/iterative_verifier_test.go index f48008d5e..487b24fb9 100644 --- a/test/go/iterative_verifier_test.go +++ b/test/go/iterative_verifier_test.go @@ -3,6 +3,7 @@ package test import ( "fmt" "sort" + "strconv" "testing" "time" @@ -31,7 +32,7 @@ func (t *IterativeVerifierTestSuite) SetupTest() { tableCompressions[testhelpers.TestCompressedTable1Name] = make(map[string]string) tableCompressions[testhelpers.TestCompressedTable1Name][testhelpers.TestCompressedColumn1Name] = ghostferry.CompressionSnappy - compressionVerifier, err := ghostferry.NewCompressionVerifier(tableCompressions) + compressionVerifier, err := ghostferry.NewCompressionVerifier(tableCompressions, t.Ferry.Tables) if err != nil { t.FailNow(err.Error()) } @@ -223,13 +224,13 @@ func (t *IterativeVerifierTestSuite) TestChangingDataChangesHash() { func (t *IterativeVerifierTestSuite) TestDeduplicatesHashes() { t.InsertRow(42, "foo") - hashes, err := t.verifier.GetHashes(t.db, t.table.Schema, t.table.Name, t.table.GetPaginationColumn().Name, t.table.Columns, []uint64{42, 42}) + hashes, err := t.verifier.GetHashes(t.db, t.table.Schema, t.table.Name, t.table.GetPaginationColumn().Name, t.table.Columns, []interface{}{uint64(42), uint64(42)}) t.Require().Nil(err) t.Require().Equal(1, len(hashes)) } func (t *IterativeVerifierTestSuite) TestDoesntReturnHashIfRecordDoesntExist() { - hashes, err := t.verifier.GetHashes(t.db, t.table.Schema, t.table.Name, t.table.GetPaginationColumn().Name, t.table.Columns, []uint64{42, 42}) + hashes, err := t.verifier.GetHashes(t.db, t.table.Schema, t.table.Name, t.table.GetPaginationColumn().Name, t.table.Columns, []interface{}{uint64(42), uint64(42)}) t.Require().Nil(err) t.Require().Equal(0, len(hashes)) } @@ -347,14 +348,20 @@ func (t *IterativeVerifierTestSuite) DeleteRow(id int) { } func (t *IterativeVerifierTestSuite) GetHashes(ids []uint64) []string { - hashes, err := t.verifier.GetHashes(t.db, 
t.table.Schema, t.table.Name, t.table.GetPaginationColumn().Name, t.table.Columns, ids) + paginationKeys := make([]interface{}, len(ids)) + for i, id := range ids { + paginationKeys[i] = id + } + + hashes, err := t.verifier.GetHashes(t.db, t.table.Schema, t.table.Name, t.table.GetPaginationColumn().Name, t.table.Columns, paginationKeys) t.Require().Nil(err) t.Require().Equal(len(hashes), len(ids)) res := make([]string, len(ids)) for idx, id := range ids { - hash, ok := hashes[id] + paginationKeyStr := ghostferry.NewUint64Key(id).String() + hash, ok := hashes[paginationKeyStr] t.Require().True(ok) t.Require().True(len(hash) > 0) @@ -376,6 +383,9 @@ func (t *IterativeVerifierTestSuite) reloadTables() { t.Ferry.Tables = tables t.verifier.Tables = tables.AsSlice() t.verifier.TableSchemaCache = tables + if t.verifier.CompressionVerifier != nil { + t.verifier.CompressionVerifier.TableSchemaCache = tables + } t.table = tables.Get(testhelpers.TestSchemaName, testhelpers.TestTable1Name) t.Require().NotNil(t.table) @@ -394,19 +404,21 @@ func (t *ReverifyStoreTestSuite) SetupTest() { func (t *ReverifyStoreTestSuite) TestAddEntryIntoReverifyStoreWillDeduplicate() { paginationKey1 := uint64(100) paginationKey2 := uint64(101) + paginationKey1Str := ghostferry.NewUint64Key(paginationKey1).String() + paginationKey2Str := ghostferry.NewUint64Key(paginationKey2).String() table1 := &ghostferry.TableSchema{Table: &schema.Table{Schema: "gftest", Name: "table1"}} - t.store.Add(ghostferry.ReverifyEntry{PaginationKey: paginationKey1, Table: table1}) - t.store.Add(ghostferry.ReverifyEntry{PaginationKey: paginationKey1, Table: table1}) - t.store.Add(ghostferry.ReverifyEntry{PaginationKey: paginationKey1, Table: table1}) - t.store.Add(ghostferry.ReverifyEntry{PaginationKey: paginationKey2, Table: table1}) - t.store.Add(ghostferry.ReverifyEntry{PaginationKey: paginationKey2, Table: table1}) + t.store.Add(ghostferry.ReverifyEntry{PaginationKey: paginationKey1Str, Table: table1}) + 
t.store.Add(ghostferry.ReverifyEntry{PaginationKey: paginationKey1Str, Table: table1}) + t.store.Add(ghostferry.ReverifyEntry{PaginationKey: paginationKey1Str, Table: table1}) + t.store.Add(ghostferry.ReverifyEntry{PaginationKey: paginationKey2Str, Table: table1}) + t.store.Add(ghostferry.ReverifyEntry{PaginationKey: paginationKey2Str, Table: table1}) t.Require().Equal(uint64(2), t.store.RowCount) t.Require().Equal(1, len(t.store.MapStore)) t.Require().Equal( - map[uint64]struct{}{ - paginationKey1: struct{}{}, - paginationKey2: struct{}{}, + map[string]struct{}{ + paginationKey1Str: struct{}{}, + paginationKey2Str: struct{}{}, }, t.store.MapStore[ghostferry.TableIdentifier{"gftest", "table1"}], ) @@ -417,13 +429,15 @@ func (t *ReverifyStoreTestSuite) TestFlushAndBatchByTableWillCreateReverifyBatch table1 := &ghostferry.TableSchema{Table: &schema.Table{Schema: "gftest", Name: "table1"}} table2 := &ghostferry.TableSchema{Table: &schema.Table{Schema: "gftest", Name: "table2"}} for i := uint64(100); i < 155; i++ { - t.store.Add(ghostferry.ReverifyEntry{PaginationKey: i, Table: table1}) + paginationKeyStr := ghostferry.NewUint64Key(i).String() + t.store.Add(ghostferry.ReverifyEntry{PaginationKey: paginationKeyStr, Table: table1}) expectedTable1PaginationKeys = append(expectedTable1PaginationKeys, i) } expectedTable2PaginationKeys := make([]uint64, 0, 45) for i := uint64(200); i < 245; i++ { - t.store.Add(ghostferry.ReverifyEntry{PaginationKey: i, Table: table2}) + paginationKeyStr := ghostferry.NewUint64Key(i).String() + t.store.Add(ghostferry.ReverifyEntry{PaginationKey: paginationKeyStr, Table: table2}) expectedTable2PaginationKeys = append(expectedTable2PaginationKeys, i) } @@ -446,8 +460,11 @@ func (t *ReverifyStoreTestSuite) TestFlushAndBatchByTableWillCreateReverifyBatch actualTable1PaginationKeys := make([]uint64, 0) for _, batch := range table1Batches { - for _, paginationKey := range batch.PaginationKeys { - actualTable1PaginationKeys = 
append(actualTable1PaginationKeys, paginationKey) + for _, paginationKeyInterface := range batch.PaginationKeys { + paginationKeyStr := paginationKeyInterface.(string) + paginationKeyUint, err := strconv.ParseUint(paginationKeyStr, 10, 64) + t.Require().Nil(err) + actualTable1PaginationKeys = append(actualTable1PaginationKeys, paginationKeyUint) } } @@ -456,8 +473,11 @@ func (t *ReverifyStoreTestSuite) TestFlushAndBatchByTableWillCreateReverifyBatch actualTable2PaginationKeys := make([]uint64, 0) for _, batch := range table2Batches { - for _, paginationKey := range batch.PaginationKeys { - actualTable2PaginationKeys = append(actualTable2PaginationKeys, paginationKey) + for _, paginationKeyInterface := range batch.PaginationKeys { + paginationKeyStr := paginationKeyInterface.(string) + paginationKeyUint, err := strconv.ParseUint(paginationKeyStr, 10, 64) + t.Require().Nil(err) + actualTable2PaginationKeys = append(actualTable2PaginationKeys, paginationKeyUint) } } diff --git a/test/go/pagination_key_test.go b/test/go/pagination_key_test.go new file mode 100644 index 000000000..4a2bf7955 --- /dev/null +++ b/test/go/pagination_key_test.go @@ -0,0 +1,1284 @@ +package test + +import ( + "encoding/hex" + "encoding/json" + "math" + "testing" + + "github.com/Shopify/ghostferry" + "github.com/go-mysql-org/go-mysql/schema" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUint64Key_SQLValue(t *testing.T) { + key := ghostferry.NewUint64Key(12345) + assert.Equal(t, uint64(12345), key.SQLValue()) +} + +func TestUint64Key_Compare(t *testing.T) { + tests := []struct { + name string + key1 ghostferry.Uint64Key + key2 ghostferry.Uint64Key + expected int + }{ + {"less than", ghostferry.NewUint64Key(100), ghostferry.NewUint64Key(200), -1}, + {"equal", ghostferry.NewUint64Key(100), ghostferry.NewUint64Key(100), 0}, + {"greater than", ghostferry.NewUint64Key(200), ghostferry.NewUint64Key(100), 1}, + {"zero vs non-zero", 
ghostferry.NewUint64Key(0), ghostferry.NewUint64Key(1), -1}, + {"max uint64", ghostferry.NewUint64Key(math.MaxUint64), ghostferry.NewUint64Key(math.MaxUint64 - 1), 1}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.key1.Compare(tt.key2) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestUint64Key_ComparePanicsOnTypeMismatch(t *testing.T) { + key1 := ghostferry.NewUint64Key(100) + key2 := ghostferry.NewBinaryKey([]byte{0x01, 0x02}) + + assert.Panics(t, func() { + key1.Compare(key2) + }) +} + +func TestUint64Key_NumericPosition(t *testing.T) { + tests := []struct { + value uint64 + expected float64 + }{ + {0, 0.0}, + {100, 100.0}, + {math.MaxUint64, float64(math.MaxUint64)}, + } + + for _, tt := range tests { + t.Run("", func(t *testing.T) { + key := ghostferry.NewUint64Key(tt.value) + assert.Equal(t, tt.expected, key.NumericPosition()) + }) + } +} + +func TestUint64Key_String(t *testing.T) { + tests := []struct { + value uint64 + expected string + }{ + {0, "0"}, + {12345, "12345"}, + {math.MaxUint64, "18446744073709551615"}, + } + + for _, tt := range tests { + t.Run(tt.expected, func(t *testing.T) { + key := ghostferry.NewUint64Key(tt.value) + assert.Equal(t, tt.expected, key.String()) + }) + } +} + +func TestUint64Key_IsMax(t *testing.T) { + assert.True(t, ghostferry.NewUint64Key(math.MaxUint64).IsMax()) + assert.False(t, ghostferry.NewUint64Key(math.MaxUint64-1).IsMax()) + assert.False(t, ghostferry.NewUint64Key(0).IsMax()) +} + +func TestUint64Key_MarshalJSON(t *testing.T) { + key := ghostferry.NewUint64Key(12345) + data, err := key.MarshalJSON() + require.NoError(t, err) + assert.Equal(t, "12345", string(data)) +} + +func TestBinaryKey_NewBinaryKeyClones(t *testing.T) { + original := []byte{0x01, 0x02, 0x03} + key := ghostferry.NewBinaryKey(original) + + original[0] = 0xFF + + assert.Equal(t, []byte{0x01, 0x02, 0x03}, []byte(key)) +} + +func TestBinaryKey_SQLValue(t *testing.T) { + original := []byte{0x01, 
0x02, 0x03} + key := ghostferry.NewBinaryKey(original) + assert.Equal(t, original, key.SQLValue()) +} + +func TestBinaryKey_Compare(t *testing.T) { + tests := []struct { + name string + key1 ghostferry.BinaryKey + key2 ghostferry.BinaryKey + expected int + }{ + { + "less than", + ghostferry.NewBinaryKey([]byte{0x01, 0x02}), + ghostferry.NewBinaryKey([]byte{0x01, 0x03}), + -1, + }, + { + "equal", + ghostferry.NewBinaryKey([]byte{0x01, 0x02}), + ghostferry.NewBinaryKey([]byte{0x01, 0x02}), + 0, + }, + { + "greater than", + ghostferry.NewBinaryKey([]byte{0x02, 0x01}), + ghostferry.NewBinaryKey([]byte{0x01, 0x02}), + 1, + }, + { + "empty vs non-empty", + ghostferry.NewBinaryKey([]byte{}), + ghostferry.NewBinaryKey([]byte{0x01}), + -1, + }, + { + "different lengths", + ghostferry.NewBinaryKey([]byte{0x01}), + ghostferry.NewBinaryKey([]byte{0x01, 0x00}), + -1, + }, + { + "UUID comparison", + ghostferry.NewBinaryKey([]byte{ + 0x01, 0x8f, 0x3e, 0x4c, 0x5a, 0x6b, 0x7c, 0x8d, + 0x9e, 0xaf, 0xb0, 0xc1, 0xd2, 0xe3, 0xf4, 0x05, + }), + ghostferry.NewBinaryKey([]byte{ + 0x01, 0x8f, 0x3e, 0x4c, 0x5a, 0x6b, 0x7c, 0x8d, + 0x9e, 0xaf, 0xb0, 0xc1, 0xd2, 0xe3, 0xf4, 0x06, + }), + -1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.key1.Compare(tt.key2) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestBinaryKey_ComparePanicsOnTypeMismatch(t *testing.T) { + key1 := ghostferry.NewBinaryKey([]byte{0x01, 0x02}) + key2 := ghostferry.NewUint64Key(100) + + assert.Panics(t, func() { + key1.Compare(key2) + }) +} + +func TestBinaryKey_NumericPosition(t *testing.T) { + tests := []struct { + name string + bytes []byte + expected float64 + }{ + { + "empty", + []byte{}, + 0.0, + }, + { + "single byte", + []byte{0x01}, + float64(0x0100000000000000), + }, + { + "8 bytes", + []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, + float64(0x0102030405060708), + }, + { + "more than 8 bytes uses first 8", + []byte{0x01, 0x02, 0x03, 0x04, 0x05, 
0x06, 0x07, 0x08, 0x09, 0x0a}, + float64(0x0102030405060708), + }, + { + "UUIDv7 timestamp ordering", + []byte{0x01, 0x8f, 0x3e, 0x4c, 0x5a, 0x6b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + float64(0x018f3e4c5a6b0000), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + key := ghostferry.NewBinaryKey(tt.bytes) + assert.Equal(t, tt.expected, key.NumericPosition()) + }) + } +} + +func TestBinaryKey_NumericPosition_Monotonic(t *testing.T) { + key1 := ghostferry.NewBinaryKey([]byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) + key2 := ghostferry.NewBinaryKey([]byte{0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) + + assert.True(t, key1.NumericPosition() < key2.NumericPosition()) +} + +func TestBinaryKey_String(t *testing.T) { + tests := []struct { + name string + bytes []byte + expected string + }{ + {"empty", []byte{}, ""}, + {"single byte", []byte{0x01}, "01"}, + {"multiple bytes", []byte{0x01, 0x02, 0x03}, "010203"}, + {"UUID", []byte{ + 0x01, 0x8f, 0x3e, 0x4c, 0x5a, 0x6b, 0x7c, 0x8d, + 0x9e, 0xaf, 0xb0, 0xc1, 0xd2, 0xe3, 0xf4, 0x05, + }, "018f3e4c5a6b7c8d9eafb0c1d2e3f405"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + key := ghostferry.NewBinaryKey(tt.bytes) + assert.Equal(t, tt.expected, key.String()) + }) + } +} + +func TestBinaryKey_IsMax(t *testing.T) { + tests := []struct { + name string + bytes []byte + expected bool + }{ + {"empty is not max", []byte{}, false}, + {"all FF is max", []byte{0xFF, 0xFF, 0xFF, 0xFF}, true}, + {"UUID(16) all FF is max", []byte{ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + }, true}, + {"one non-FF byte is not max", []byte{0xFF, 0xFE, 0xFF, 0xFF}, false}, + {"zero is not max", []byte{0x00}, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + key := ghostferry.NewBinaryKey(tt.bytes) + assert.Equal(t, tt.expected, key.IsMax()) + }) + } +} + +func 
TestBinaryKey_MarshalJSON(t *testing.T) { + key := ghostferry.NewBinaryKey([]byte{0x01, 0x02, 0x03}) + data, err := key.MarshalJSON() + require.NoError(t, err) + assert.Equal(t, `"010203"`, string(data)) +} + +func TestMarshalPaginationKey_Uint64(t *testing.T) { + key := ghostferry.NewUint64Key(12345) + data, err := ghostferry.MarshalPaginationKey(key) + require.NoError(t, err) + + var result map[string]interface{} + err = json.Unmarshal(data, &result) + require.NoError(t, err) + + assert.Equal(t, "uint64", result["type"]) + assert.Equal(t, float64(12345), result["value"]) +} + +func TestMarshalPaginationKey_Binary(t *testing.T) { + key := ghostferry.NewBinaryKey([]byte{0x01, 0x02, 0x03}) + data, err := ghostferry.MarshalPaginationKey(key) + require.NoError(t, err) + + var result map[string]interface{} + err = json.Unmarshal(data, &result) + require.NoError(t, err) + + assert.Equal(t, "binary", result["type"]) + assert.Equal(t, "010203", result["value"]) +} + +func TestUnmarshalPaginationKey_Uint64(t *testing.T) { + data := []byte(`{"type":"uint64","value":12345}`) + key, err := ghostferry.UnmarshalPaginationKey(data) + require.NoError(t, err) + + uint64Key, ok := key.(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(12345), uint64(uint64Key)) +} + +func TestUnmarshalPaginationKey_Binary(t *testing.T) { + data := []byte(`{"type":"binary","value":"010203"}`) + key, err := ghostferry.UnmarshalPaginationKey(data) + require.NoError(t, err) + + binaryKey, ok := key.(ghostferry.BinaryKey) + require.True(t, ok) + assert.Equal(t, []byte{0x01, 0x02, 0x03}, []byte(binaryKey)) +} + +func TestUnmarshalPaginationKey_InvalidType(t *testing.T) { + data := []byte(`{"type":"invalid","value":"something"}`) + _, err := ghostferry.UnmarshalPaginationKey(data) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unknown key type") +} + +func TestUnmarshalPaginationKey_InvalidJSON(t *testing.T) { + data := []byte(`{invalid json}`) + _, err := 
ghostferry.UnmarshalPaginationKey(data) + assert.Error(t, err) +} + +func TestUnmarshalPaginationKey_InvalidBinaryHex(t *testing.T) { + data := []byte(`{"type":"binary","value":"ZZZZ"}`) + _, err := ghostferry.UnmarshalPaginationKey(data) + assert.Error(t, err) +} + +func TestPaginationKey_RoundTrip_Uint64(t *testing.T) { + original := ghostferry.NewUint64Key(98765) + + marshaled, err := ghostferry.MarshalPaginationKey(original) + require.NoError(t, err) + + unmarshaled, err := ghostferry.UnmarshalPaginationKey(marshaled) + require.NoError(t, err) + + assert.Equal(t, original, unmarshaled) +} + +func TestPaginationKey_RoundTrip_Binary(t *testing.T) { + original := ghostferry.NewBinaryKey([]byte{0xDE, 0xAD, 0xBE, 0xEF}) + + marshaled, err := ghostferry.MarshalPaginationKey(original) + require.NoError(t, err) + + unmarshaled, err := ghostferry.UnmarshalPaginationKey(marshaled) + require.NoError(t, err) + + assert.Equal(t, original, unmarshaled) +} + +func TestMinPaginationKey_Numeric(t *testing.T) { + column := &schema.TableColumn{ + Name: "id", + Type: schema.TYPE_NUMBER, + } + + minKey := ghostferry.MinPaginationKey([]*schema.TableColumn{column}) + uint64Key, ok := minKey.(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(0), uint64(uint64Key)) +} + +func TestMinPaginationKey_MediumInt(t *testing.T) { + column := &schema.TableColumn{ + Name: "id", + Type: schema.TYPE_MEDIUM_INT, + } + + minKey := ghostferry.MinPaginationKey([]*schema.TableColumn{column}) + uint64Key, ok := minKey.(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(0), uint64(uint64Key)) +} + +func TestMinPaginationKey_Binary(t *testing.T) { + column := &schema.TableColumn{ + Name: "uuid", + Type: schema.TYPE_BINARY, + } + + minKey := ghostferry.MinPaginationKey([]*schema.TableColumn{column}) + binaryKey, ok := minKey.(ghostferry.BinaryKey) + require.True(t, ok) + assert.Equal(t, []byte{}, []byte(binaryKey)) +} + +func TestMinPaginationKey_String(t *testing.T) { 
+ column := &schema.TableColumn{ + Name: "key", + Type: schema.TYPE_STRING, + } + + minKey := ghostferry.MinPaginationKey([]*schema.TableColumn{column}) + binaryKey, ok := minKey.(ghostferry.BinaryKey) + require.True(t, ok) + assert.Equal(t, []byte{}, []byte(binaryKey)) +} + +func TestMaxPaginationKey_Numeric(t *testing.T) { + column := &schema.TableColumn{ + Name: "id", + Type: schema.TYPE_NUMBER, + } + + maxKey := ghostferry.MaxPaginationKey([]*schema.TableColumn{column}) + uint64Key, ok := maxKey.(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(math.MaxUint64), uint64(uint64Key)) +} + +func TestMaxPaginationKey_MediumInt(t *testing.T) { + column := &schema.TableColumn{ + Name: "id", + Type: schema.TYPE_MEDIUM_INT, + } + + maxKey := ghostferry.MaxPaginationKey([]*schema.TableColumn{column}) + uint64Key, ok := maxKey.(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(math.MaxUint64), uint64(uint64Key)) +} + +func TestMaxPaginationKey_Binary_UUID16(t *testing.T) { + column := &schema.TableColumn{ + Name: "uuid", + Type: schema.TYPE_BINARY, + MaxSize: 16, + } + + maxKey := ghostferry.MaxPaginationKey([]*schema.TableColumn{column}) + binaryKey, ok := maxKey.(ghostferry.BinaryKey) + require.True(t, ok) + assert.Equal(t, 16, len(binaryKey)) + + for _, b := range binaryKey { + assert.Equal(t, byte(0xFF), b) + } + assert.True(t, binaryKey.IsMax()) +} + +func TestMaxPaginationKey_Binary_LargeSize(t *testing.T) { + column := &schema.TableColumn{ + Name: "large", + Type: schema.TYPE_STRING, + MaxSize: 100000, + } + + maxKey := ghostferry.MaxPaginationKey([]*schema.TableColumn{column}) + binaryKey, ok := maxKey.(ghostferry.BinaryKey) + require.True(t, ok) + assert.Equal(t, 4096, len(binaryKey)) +} + +func TestMaxPaginationKey_DefaultToNumeric(t *testing.T) { + column := &schema.TableColumn{ + Name: "id", + Type: 999, + } + + maxKey := ghostferry.MaxPaginationKey([]*schema.TableColumn{column}) + uint64Key, ok := 
maxKey.(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(math.MaxUint64), uint64(uint64Key)) +} + +func TestPaginationKey_CrossTypeComparison_UUIDv7Ordering(t *testing.T) { + uuidBytes1, _ := hex.DecodeString("018f3e4c5a6b7c8d9eafb0c1d2e3f405") + uuidBytes2, _ := hex.DecodeString("018f3e4c5a6c7c8d9eafb0c1d2e3f405") + uuidBytes3, _ := hex.DecodeString("018f3e4c5b6b7c8d9eafb0c1d2e3f405") + + key1 := ghostferry.NewBinaryKey(uuidBytes1) + key2 := ghostferry.NewBinaryKey(uuidBytes2) + key3 := ghostferry.NewBinaryKey(uuidBytes3) + + assert.Equal(t, -1, key1.Compare(key2)) + assert.Equal(t, -1, key1.Compare(key3)) + assert.Equal(t, -1, key2.Compare(key3)) + + assert.True(t, key1.NumericPosition() < key2.NumericPosition()) + assert.True(t, key2.NumericPosition() < key3.NumericPosition()) +} + +// CompositeKey Tests + +func TestCompositeKey_SQLValue(t *testing.T) { + key := ghostferry.CompositeKey{ + ghostferry.NewUint64Key(123), + ghostferry.NewUint64Key(456), + } + + values := key.SQLValue().([]interface{}) + assert.Equal(t, 2, len(values)) + assert.Equal(t, uint64(123), values[0]) + assert.Equal(t, uint64(456), values[1]) +} + +func TestCompositeKey_SQLValue_Mixed(t *testing.T) { + key := ghostferry.CompositeKey{ + ghostferry.NewUint64Key(100), + ghostferry.NewBinaryKey([]byte{0x01, 0x02}), + } + + values := key.SQLValue().([]interface{}) + assert.Equal(t, 2, len(values)) + assert.Equal(t, uint64(100), values[0]) + assert.Equal(t, []byte{0x01, 0x02}, values[1]) +} + +func TestCompositeKey_Compare(t *testing.T) { + tests := []struct { + name string + key1 ghostferry.CompositeKey + key2 ghostferry.CompositeKey + expected int + }{ + { + "equal composite keys", + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(100), + ghostferry.NewUint64Key(200), + }, + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(100), + ghostferry.NewUint64Key(200), + }, + 0, + }, + { + "less than - first element differs", + ghostferry.CompositeKey{ + 
ghostferry.NewUint64Key(100), + ghostferry.NewUint64Key(200), + }, + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(101), + ghostferry.NewUint64Key(200), + }, + -1, + }, + { + "less than - second element differs", + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(100), + ghostferry.NewUint64Key(200), + }, + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(100), + ghostferry.NewUint64Key(201), + }, + -1, + }, + { + "greater than - first element differs", + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(200), + ghostferry.NewUint64Key(100), + }, + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(100), + ghostferry.NewUint64Key(100), + }, + 1, + }, + { + "greater than - second element differs", + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(100), + ghostferry.NewUint64Key(300), + }, + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(100), + ghostferry.NewUint64Key(200), + }, + 1, + }, + { + "three element composite - middle differs", + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(100), + ghostferry.NewUint64Key(200), + ghostferry.NewUint64Key(300), + }, + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(100), + ghostferry.NewUint64Key(199), + ghostferry.NewUint64Key(999), + }, + 1, + }, + { + "three element composite - last element differs", + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(1), + ghostferry.NewUint64Key(2), + ghostferry.NewUint64Key(3), + }, + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(1), + ghostferry.NewUint64Key(2), + ghostferry.NewUint64Key(4), + }, + -1, + }, + { + "three element composite - equal", + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(10), + ghostferry.NewUint64Key(20), + ghostferry.NewUint64Key(30), + }, + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(10), + ghostferry.NewUint64Key(20), + ghostferry.NewUint64Key(30), + }, + 0, + }, + { + "four element composite - third element differs", + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(1), + ghostferry.NewUint64Key(2), + 
ghostferry.NewUint64Key(3), + ghostferry.NewUint64Key(4), + }, + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(1), + ghostferry.NewUint64Key(2), + ghostferry.NewUint64Key(2), + ghostferry.NewUint64Key(999), + }, + 1, + }, + { + "three element mixed types", + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(100), + ghostferry.NewBinaryKey([]byte{0x01, 0x02}), + ghostferry.NewUint64Key(300), + }, + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(100), + ghostferry.NewBinaryKey([]byte{0x01, 0x01}), + ghostferry.NewUint64Key(999), + }, + 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.key1.Compare(tt.key2) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestCompositeKey_ComparePanicsOnTypeMismatch(t *testing.T) { + key1 := ghostferry.CompositeKey{ + ghostferry.NewUint64Key(100), + ghostferry.NewUint64Key(200), + } + key2 := ghostferry.NewUint64Key(100) + + assert.Panics(t, func() { + key1.Compare(key2) + }) +} + +func TestCompositeKey_ComparePanicsOnLengthMismatch(t *testing.T) { + key1 := ghostferry.CompositeKey{ + ghostferry.NewUint64Key(100), + ghostferry.NewUint64Key(200), + } + key2 := ghostferry.CompositeKey{ + ghostferry.NewUint64Key(100), + } + + assert.Panics(t, func() { + key1.Compare(key2) + }) +} + +func TestCompositeKey_NumericPosition(t *testing.T) { + tests := []struct { + name string + key ghostferry.CompositeKey + expected float64 + }{ + { + "uses first element for position", + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(12345), + ghostferry.NewUint64Key(67890), + }, + 12345.0, + }, + { + "empty composite key", + ghostferry.CompositeKey{}, + 0.0, + }, + { + "binary first element", + ghostferry.CompositeKey{ + ghostferry.NewBinaryKey([]byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}), + ghostferry.NewUint64Key(999), + }, + float64(0x0100000000000000), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, 
tt.key.NumericPosition()) + }) + } +} + +func TestCompositeKey_String(t *testing.T) { + tests := []struct { + name string + key ghostferry.CompositeKey + expected string + }{ + { + "two uint64 keys", + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(123), + ghostferry.NewUint64Key(456), + }, + "123,456", + }, + { + "mixed types", + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(100), + ghostferry.NewBinaryKey([]byte{0xAB, 0xCD}), + }, + "100,abcd", + }, + { + "three elements", + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(1), + ghostferry.NewUint64Key(2), + ghostferry.NewUint64Key(3), + }, + "1,2,3", + }, + { + "four elements", + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(10), + ghostferry.NewUint64Key(20), + ghostferry.NewUint64Key(30), + ghostferry.NewUint64Key(40), + }, + "10,20,30,40", + }, + { + "four elements mixed types", + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(1), + ghostferry.NewBinaryKey([]byte{0xAB}), + ghostferry.NewUint64Key(2), + ghostferry.NewBinaryKey([]byte{0xCD, 0xEF}), + }, + "1,ab,2,cdef", + }, + { + "empty composite key", + ghostferry.CompositeKey{}, + "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.key.String()) + }) + } +} + +func TestCompositeKey_IsMax(t *testing.T) { + tests := []struct { + name string + key ghostferry.CompositeKey + expected bool + }{ + { + "all elements are max", + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(math.MaxUint64), + ghostferry.NewUint64Key(math.MaxUint64), + }, + true, + }, + { + "first element not max", + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(100), + ghostferry.NewUint64Key(math.MaxUint64), + }, + false, + }, + { + "second element not max", + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(math.MaxUint64), + ghostferry.NewUint64Key(100), + }, + false, + }, + { + "empty composite key", + ghostferry.CompositeKey{}, + false, + }, + { + "mixed types all max", + ghostferry.CompositeKey{ + 
ghostferry.NewUint64Key(math.MaxUint64), + ghostferry.NewBinaryKey([]byte{0xFF, 0xFF, 0xFF, 0xFF}), + }, + true, + }, + { + "three elements all max", + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(math.MaxUint64), + ghostferry.NewUint64Key(math.MaxUint64), + ghostferry.NewUint64Key(math.MaxUint64), + }, + true, + }, + { + "three elements third not max", + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(math.MaxUint64), + ghostferry.NewUint64Key(math.MaxUint64), + ghostferry.NewUint64Key(100), + }, + false, + }, + { + "four elements all max mixed types", + ghostferry.CompositeKey{ + ghostferry.NewUint64Key(math.MaxUint64), + ghostferry.NewBinaryKey([]byte{0xFF, 0xFF}), + ghostferry.NewUint64Key(math.MaxUint64), + ghostferry.NewBinaryKey([]byte{0xFF}), + }, + true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.key.IsMax()) + }) + } +} + +func TestCompositeKey_MarshalJSON(t *testing.T) { + key := ghostferry.CompositeKey{ + ghostferry.NewUint64Key(123), + ghostferry.NewUint64Key(456), + } + + data, err := key.MarshalJSON() + require.NoError(t, err) + + // Should be an array of encoded keys + var decoded []json.RawMessage + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + assert.Equal(t, 2, len(decoded)) +} + +func TestCompositeKey_MarshalJSON_ThreeElements(t *testing.T) { + key := ghostferry.CompositeKey{ + ghostferry.NewUint64Key(1), + ghostferry.NewUint64Key(2), + ghostferry.NewUint64Key(3), + } + + data, err := key.MarshalJSON() + require.NoError(t, err) + + var decoded []json.RawMessage + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + assert.Equal(t, 3, len(decoded)) +} + +func TestCompositeKey_MarshalJSON_FourElementsMixed(t *testing.T) { + key := ghostferry.CompositeKey{ + ghostferry.NewUint64Key(1), + ghostferry.NewBinaryKey([]byte{0xAB, 0xCD}), + ghostferry.NewUint64Key(2), + ghostferry.NewBinaryKey([]byte{0xEF}), + } + + data, err := key.MarshalJSON() + 
require.NoError(t, err) + + var decoded []json.RawMessage + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + assert.Equal(t, 4, len(decoded)) +} + +func TestCompositeKey_SQLValue_ThreeElements(t *testing.T) { + key := ghostferry.CompositeKey{ + ghostferry.NewUint64Key(100), + ghostferry.NewUint64Key(200), + ghostferry.NewUint64Key(300), + } + + sqlVal := key.SQLValue() + values, ok := sqlVal.([]interface{}) + require.True(t, ok) + require.Equal(t, 3, len(values)) + assert.Equal(t, uint64(100), values[0]) + assert.Equal(t, uint64(200), values[1]) + assert.Equal(t, uint64(300), values[2]) +} + +func TestCompositeKey_SQLValue_FourElementsMixed(t *testing.T) { + key := ghostferry.CompositeKey{ + ghostferry.NewUint64Key(1), + ghostferry.NewBinaryKey([]byte{0xAB, 0xCD}), + ghostferry.NewUint64Key(2), + ghostferry.NewBinaryKey([]byte{0xEF}), + } + + sqlVal := key.SQLValue() + values, ok := sqlVal.([]interface{}) + require.True(t, ok) + require.Equal(t, 4, len(values)) + assert.Equal(t, uint64(1), values[0]) + assert.Equal(t, []byte{0xAB, 0xCD}, values[1]) + assert.Equal(t, uint64(2), values[2]) + assert.Equal(t, []byte{0xEF}, values[3]) +} + +func TestMarshalPaginationKey_Composite(t *testing.T) { + key := ghostferry.CompositeKey{ + ghostferry.NewUint64Key(100), + ghostferry.NewUint64Key(200), + } + + data, err := ghostferry.MarshalPaginationKey(key) + require.NoError(t, err) + + var result map[string]interface{} + err = json.Unmarshal(data, &result) + require.NoError(t, err) + + assert.Equal(t, "composite", result["type"]) + + // Value should be an array + value, ok := result["value"].([]interface{}) + require.True(t, ok) + assert.Equal(t, 2, len(value)) +} + +func TestMarshalPaginationKey_CompositeThreeElements(t *testing.T) { + key := ghostferry.CompositeKey{ + ghostferry.NewUint64Key(10), + ghostferry.NewUint64Key(20), + ghostferry.NewUint64Key(30), + } + + data, err := ghostferry.MarshalPaginationKey(key) + require.NoError(t, err) + + var result 
map[string]interface{} + err = json.Unmarshal(data, &result) + require.NoError(t, err) + + assert.Equal(t, "composite", result["type"]) + + value, ok := result["value"].([]interface{}) + require.True(t, ok) + assert.Equal(t, 3, len(value)) +} + +func TestUnmarshalPaginationKey_Composite(t *testing.T) { + // Manually construct a composite key JSON + data := []byte(`{"type":"composite","value":[{"type":"uint64","value":100},{"type":"uint64","value":200}]}`) + + key, err := ghostferry.UnmarshalPaginationKey(data) + require.NoError(t, err) + + compositeKey, ok := key.(ghostferry.CompositeKey) + require.True(t, ok) + assert.Equal(t, 2, len(compositeKey)) + + // Check first element + uint64Key1, ok := compositeKey[0].(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(100), uint64(uint64Key1)) + + // Check second element + uint64Key2, ok := compositeKey[1].(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(200), uint64(uint64Key2)) +} + +func TestUnmarshalPaginationKey_CompositeMixed(t *testing.T) { + // Composite key with mixed types + data := []byte(`{"type":"composite","value":[{"type":"uint64","value":100},{"type":"binary","value":"abcd"}]}`) + + key, err := ghostferry.UnmarshalPaginationKey(data) + require.NoError(t, err) + + compositeKey, ok := key.(ghostferry.CompositeKey) + require.True(t, ok) + assert.Equal(t, 2, len(compositeKey)) + + // Check first element (uint64) + uint64Key, ok := compositeKey[0].(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(100), uint64(uint64Key)) + + // Check second element (binary) + binaryKey, ok := compositeKey[1].(ghostferry.BinaryKey) + require.True(t, ok) + expectedBytes, _ := hex.DecodeString("abcd") + assert.Equal(t, expectedBytes, []byte(binaryKey)) +} + +func TestUnmarshalPaginationKey_CompositeThreeElements(t *testing.T) { + data := []byte(`{"type":"composite","value":[{"type":"uint64","value":1},{"type":"uint64","value":2},{"type":"uint64","value":3}]}`) + + key, 
err := ghostferry.UnmarshalPaginationKey(data) + require.NoError(t, err) + + compositeKey, ok := key.(ghostferry.CompositeKey) + require.True(t, ok) + assert.Equal(t, 3, len(compositeKey)) + + assert.Equal(t, uint64(1), uint64(compositeKey[0].(ghostferry.Uint64Key))) + assert.Equal(t, uint64(2), uint64(compositeKey[1].(ghostferry.Uint64Key))) + assert.Equal(t, uint64(3), uint64(compositeKey[2].(ghostferry.Uint64Key))) +} + +func TestUnmarshalPaginationKey_CompositeFourElementsMixed(t *testing.T) { + data := []byte(`{"type":"composite","value":[{"type":"uint64","value":10},{"type":"binary","value":"ab"},{"type":"uint64","value":20},{"type":"binary","value":"cd"}]}`) + + key, err := ghostferry.UnmarshalPaginationKey(data) + require.NoError(t, err) + + compositeKey, ok := key.(ghostferry.CompositeKey) + require.True(t, ok) + assert.Equal(t, 4, len(compositeKey)) + + assert.Equal(t, uint64(10), uint64(compositeKey[0].(ghostferry.Uint64Key))) + expectedBytes1, _ := hex.DecodeString("ab") + assert.Equal(t, expectedBytes1, []byte(compositeKey[1].(ghostferry.BinaryKey))) + assert.Equal(t, uint64(20), uint64(compositeKey[2].(ghostferry.Uint64Key))) + expectedBytes2, _ := hex.DecodeString("cd") + assert.Equal(t, expectedBytes2, []byte(compositeKey[3].(ghostferry.BinaryKey))) +} + +func TestPaginationKey_RoundTrip_Composite(t *testing.T) { + original := ghostferry.CompositeKey{ + ghostferry.NewUint64Key(12345), + ghostferry.NewBinaryKey([]byte{0xDE, 0xAD, 0xBE, 0xEF}), + ghostferry.NewUint64Key(67890), + } + + marshaled, err := ghostferry.MarshalPaginationKey(original) + require.NoError(t, err) + + unmarshaled, err := ghostferry.UnmarshalPaginationKey(marshaled) + require.NoError(t, err) + + compositeKey, ok := unmarshaled.(ghostferry.CompositeKey) + require.True(t, ok) + assert.Equal(t, 3, len(compositeKey)) + + // Verify each element matches + assert.Equal(t, original[0], compositeKey[0]) + assert.Equal(t, original[1], compositeKey[1]) + assert.Equal(t, original[2], 
compositeKey[2]) +} + +func TestMinCompositePaginationKey(t *testing.T) { + columns := []*schema.TableColumn{ + {Name: "tenant_id", Type: schema.TYPE_NUMBER}, + {Name: "user_id", Type: schema.TYPE_NUMBER}, + } + + minKey := ghostferry.MinPaginationKey(columns) + compositeKey, ok := minKey.(ghostferry.CompositeKey) + require.True(t, ok) + assert.Equal(t, 2, len(compositeKey)) + + // Both should be zero + assert.Equal(t, uint64(0), uint64(compositeKey[0].(ghostferry.Uint64Key))) + assert.Equal(t, uint64(0), uint64(compositeKey[1].(ghostferry.Uint64Key))) +} + +func TestMinCompositePaginationKey_SingleColumn(t *testing.T) { + columns := []*schema.TableColumn{ + {Name: "id", Type: schema.TYPE_NUMBER}, + } + + minKey := ghostferry.MinPaginationKey(columns) + // Single column should return Uint64Key, not CompositeKey + uint64Key, ok := minKey.(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(0), uint64(uint64Key)) +} + +func TestMaxCompositePaginationKey(t *testing.T) { + columns := []*schema.TableColumn{ + {Name: "tenant_id", Type: schema.TYPE_NUMBER}, + {Name: "user_id", Type: schema.TYPE_NUMBER}, + } + + maxKey := ghostferry.MaxPaginationKey(columns) + compositeKey, ok := maxKey.(ghostferry.CompositeKey) + require.True(t, ok) + assert.Equal(t, 2, len(compositeKey)) + + // Both should be max uint64 + assert.Equal(t, uint64(math.MaxUint64), uint64(compositeKey[0].(ghostferry.Uint64Key))) + assert.Equal(t, uint64(math.MaxUint64), uint64(compositeKey[1].(ghostferry.Uint64Key))) + assert.True(t, compositeKey.IsMax()) +} + +func TestMaxCompositePaginationKey_SingleColumn(t *testing.T) { + columns := []*schema.TableColumn{ + {Name: "id", Type: schema.TYPE_NUMBER}, + } + + maxKey := ghostferry.MaxPaginationKey(columns) + // Single column should return Uint64Key, not CompositeKey + uint64Key, ok := maxKey.(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(math.MaxUint64), uint64(uint64Key)) +} + +func 
TestMinCompositePaginationKey_ThreeColumns(t *testing.T) { + columns := []*schema.TableColumn{ + {Name: "tenant_id", Type: schema.TYPE_NUMBER}, + {Name: "user_id", Type: schema.TYPE_NUMBER}, + {Name: "order_id", Type: schema.TYPE_NUMBER}, + } + + minKey := ghostferry.MinPaginationKey(columns) + compositeKey, ok := minKey.(ghostferry.CompositeKey) + require.True(t, ok) + assert.Equal(t, 3, len(compositeKey)) + + assert.Equal(t, uint64(0), uint64(compositeKey[0].(ghostferry.Uint64Key))) + assert.Equal(t, uint64(0), uint64(compositeKey[1].(ghostferry.Uint64Key))) + assert.Equal(t, uint64(0), uint64(compositeKey[2].(ghostferry.Uint64Key))) +} + +func TestMinCompositePaginationKey_FourColumnsMixed(t *testing.T) { + columns := []*schema.TableColumn{ + {Name: "region", Type: schema.TYPE_STRING, MaxSize: 10}, + {Name: "tenant_id", Type: schema.TYPE_NUMBER}, + {Name: "uuid", Type: schema.TYPE_BINARY, MaxSize: 16}, + {Name: "seq", Type: schema.TYPE_NUMBER}, + } + + minKey := ghostferry.MinPaginationKey(columns) + compositeKey, ok := minKey.(ghostferry.CompositeKey) + require.True(t, ok) + assert.Equal(t, 4, len(compositeKey)) + + // Binary key for string + assert.Equal(t, []byte{}, []byte(compositeKey[0].(ghostferry.BinaryKey))) + // Uint64 key + assert.Equal(t, uint64(0), uint64(compositeKey[1].(ghostferry.Uint64Key))) + // Binary key + assert.Equal(t, []byte{}, []byte(compositeKey[2].(ghostferry.BinaryKey))) + // Uint64 key + assert.Equal(t, uint64(0), uint64(compositeKey[3].(ghostferry.Uint64Key))) +} + +func TestMaxCompositePaginationKey_ThreeColumns(t *testing.T) { + columns := []*schema.TableColumn{ + {Name: "tenant_id", Type: schema.TYPE_NUMBER}, + {Name: "user_id", Type: schema.TYPE_NUMBER}, + {Name: "order_id", Type: schema.TYPE_NUMBER}, + } + + maxKey := ghostferry.MaxPaginationKey(columns) + compositeKey, ok := maxKey.(ghostferry.CompositeKey) + require.True(t, ok) + assert.Equal(t, 3, len(compositeKey)) + + assert.Equal(t, uint64(math.MaxUint64), 
uint64(compositeKey[0].(ghostferry.Uint64Key))) + assert.Equal(t, uint64(math.MaxUint64), uint64(compositeKey[1].(ghostferry.Uint64Key))) + assert.Equal(t, uint64(math.MaxUint64), uint64(compositeKey[2].(ghostferry.Uint64Key))) + assert.True(t, compositeKey.IsMax()) +} + +func TestMaxCompositePaginationKey_FourColumnsMixed(t *testing.T) { + columns := []*schema.TableColumn{ + {Name: "region", Type: schema.TYPE_STRING, MaxSize: 10}, + {Name: "tenant_id", Type: schema.TYPE_NUMBER}, + {Name: "uuid", Type: schema.TYPE_BINARY, MaxSize: 16}, + {Name: "seq", Type: schema.TYPE_NUMBER}, + } + + maxKey := ghostferry.MaxPaginationKey(columns) + compositeKey, ok := maxKey.(ghostferry.CompositeKey) + require.True(t, ok) + assert.Equal(t, 4, len(compositeKey)) + + // First (string) should be 0xFF bytes + binaryKey1 := compositeKey[0].(ghostferry.BinaryKey) + assert.Equal(t, 10, len(binaryKey1)) + for _, b := range binaryKey1 { + assert.Equal(t, byte(0xFF), b) + } + + // Second (number) should be max uint64 + assert.Equal(t, uint64(math.MaxUint64), uint64(compositeKey[1].(ghostferry.Uint64Key))) + + // Third (binary) should be 0xFF bytes + binaryKey2 := compositeKey[2].(ghostferry.BinaryKey) + assert.Equal(t, 16, len(binaryKey2)) + for _, b := range binaryKey2 { + assert.Equal(t, byte(0xFF), b) + } + + // Fourth (number) should be max uint64 + assert.Equal(t, uint64(math.MaxUint64), uint64(compositeKey[3].(ghostferry.Uint64Key))) + + assert.True(t, compositeKey.IsMax()) +} + +func TestMaxCompositePaginationKey_MixedTypes(t *testing.T) { + columns := []*schema.TableColumn{ + {Name: "tenant_id", Type: schema.TYPE_NUMBER}, + {Name: "uuid", Type: schema.TYPE_BINARY, MaxSize: 16}, + } + + maxKey := ghostferry.MaxPaginationKey(columns) + compositeKey, ok := maxKey.(ghostferry.CompositeKey) + require.True(t, ok) + assert.Equal(t, 2, len(compositeKey)) + + // First should be max uint64 + assert.Equal(t, uint64(math.MaxUint64), uint64(compositeKey[0].(ghostferry.Uint64Key))) + + // Second 
should be max binary (all 0xFF bytes) + binaryKey, ok := compositeKey[1].(ghostferry.BinaryKey) + require.True(t, ok) + assert.Equal(t, 16, len(binaryKey)) + assert.True(t, binaryKey.IsMax()) +} diff --git a/test/go/state_serialization_test.go b/test/go/state_serialization_test.go new file mode 100644 index 000000000..fdf8b7a7a --- /dev/null +++ b/test/go/state_serialization_test.go @@ -0,0 +1,543 @@ +package test + +import ( + "encoding/json" + "testing" + + "github.com/Shopify/ghostferry" + "github.com/go-mysql-org/go-mysql/mysql" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSerializableState_MarshalJSON_EmptyState(t *testing.T) { + state := &ghostferry.SerializableState{ + GhostferryVersion: "test-version", + LastSuccessfulPaginationKeys: make(map[string]ghostferry.PaginationKey), + CompletedTables: make(map[string]bool), + } + + data, err := json.Marshal(state) + require.NoError(t, err) + assert.NotEmpty(t, data) + + var decoded ghostferry.SerializableState + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + assert.Equal(t, "test-version", decoded.GhostferryVersion) + assert.Empty(t, decoded.LastSuccessfulPaginationKeys) + assert.Empty(t, decoded.CompletedTables) +} + +func TestSerializableState_MarshalJSON_WithUint64Keys(t *testing.T) { + state := &ghostferry.SerializableState{ + GhostferryVersion: "test-version", + LastSuccessfulPaginationKeys: map[string]ghostferry.PaginationKey{ + "db.table1": ghostferry.NewUint64Key(100), + "db.table2": ghostferry.NewUint64Key(200), + "db.table3": ghostferry.NewUint64Key(300), + }, + CompletedTables: map[string]bool{ + "db.table4": true, + }, + } + + data, err := json.Marshal(state) + require.NoError(t, err) + + var decoded ghostferry.SerializableState + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + assert.Equal(t, "test-version", decoded.GhostferryVersion) + assert.Len(t, decoded.LastSuccessfulPaginationKeys, 3) + + key1, ok := 
decoded.LastSuccessfulPaginationKeys["db.table1"].(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(100), uint64(key1)) + + key2, ok := decoded.LastSuccessfulPaginationKeys["db.table2"].(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(200), uint64(key2)) + + key3, ok := decoded.LastSuccessfulPaginationKeys["db.table3"].(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(300), uint64(key3)) + + assert.True(t, decoded.CompletedTables["db.table4"]) +} + +func TestSerializableState_MarshalJSON_WithBinaryKeys(t *testing.T) { + uuid1 := []byte{0x01, 0x8f, 0x3e, 0x4c, 0x5a, 0x6b, 0x7c, 0x8d, 0x9e, 0xaf, 0xb0, 0xc1, 0xd2, 0xe3, 0xf4, 0x01} + uuid2 := []byte{0x01, 0x8f, 0x3e, 0x4c, 0x5a, 0x6b, 0x7c, 0x8d, 0x9e, 0xaf, 0xb0, 0xc1, 0xd2, 0xe3, 0xf4, 0x02} + + state := &ghostferry.SerializableState{ + GhostferryVersion: "test-version", + LastSuccessfulPaginationKeys: map[string]ghostferry.PaginationKey{ + "db.uuid_table1": ghostferry.NewBinaryKey(uuid1), + "db.uuid_table2": ghostferry.NewBinaryKey(uuid2), + }, + CompletedTables: make(map[string]bool), + } + + data, err := json.Marshal(state) + require.NoError(t, err) + + var decoded ghostferry.SerializableState + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + assert.Equal(t, "test-version", decoded.GhostferryVersion) + assert.Len(t, decoded.LastSuccessfulPaginationKeys, 2) + + key1, ok := decoded.LastSuccessfulPaginationKeys["db.uuid_table1"].(ghostferry.BinaryKey) + require.True(t, ok) + assert.Equal(t, uuid1, []byte(key1)) + + key2, ok := decoded.LastSuccessfulPaginationKeys["db.uuid_table2"].(ghostferry.BinaryKey) + require.True(t, ok) + assert.Equal(t, uuid2, []byte(key2)) +} + +func TestSerializableState_MarshalJSON_WithMixedKeys(t *testing.T) { + uuid := []byte{0x01, 0x8f, 0x3e, 0x4c, 0x5a, 0x6b, 0x7c, 0x8d, 0x9e, 0xaf, 0xb0, 0xc1, 0xd2, 0xe3, 0xf4, 0x01} + + state := &ghostferry.SerializableState{ + GhostferryVersion: "test-version", + 
LastSuccessfulPaginationKeys: map[string]ghostferry.PaginationKey{ + "db.numeric_table": ghostferry.NewUint64Key(12345), + "db.uuid_table": ghostferry.NewBinaryKey(uuid), + "db.varchar_table": ghostferry.NewBinaryKey([]byte("some_key")), + "db.bigint_table": ghostferry.NewUint64Key(999999999), + }, + CompletedTables: map[string]bool{ + "db.completed_table": true, + }, + } + + data, err := json.Marshal(state) + require.NoError(t, err) + + var decoded ghostferry.SerializableState + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + assert.Equal(t, "test-version", decoded.GhostferryVersion) + assert.Len(t, decoded.LastSuccessfulPaginationKeys, 4) + + numericKey, ok := decoded.LastSuccessfulPaginationKeys["db.numeric_table"].(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(12345), uint64(numericKey)) + + uuidKey, ok := decoded.LastSuccessfulPaginationKeys["db.uuid_table"].(ghostferry.BinaryKey) + require.True(t, ok) + assert.Equal(t, uuid, []byte(uuidKey)) + + varcharKey, ok := decoded.LastSuccessfulPaginationKeys["db.varchar_table"].(ghostferry.BinaryKey) + require.True(t, ok) + assert.Equal(t, []byte("some_key"), []byte(varcharKey)) + + bigintKey, ok := decoded.LastSuccessfulPaginationKeys["db.bigint_table"].(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(999999999), uint64(bigintKey)) + + assert.True(t, decoded.CompletedTables["db.completed_table"]) +} + +func TestSerializableState_MarshalJSON_WithBinlogPosition(t *testing.T) { + state := &ghostferry.SerializableState{ + GhostferryVersion: "test-version", + LastSuccessfulPaginationKeys: map[string]ghostferry.PaginationKey{ + "db.table1": ghostferry.NewUint64Key(100), + }, + CompletedTables: make(map[string]bool), + LastWrittenBinlogPosition: mysql.Position{ + Name: "mysql-bin.000123", + Pos: 456789, + }, + LastStoredBinlogPositionForInlineVerifier: mysql.Position{ + Name: "mysql-bin.000122", + Pos: 123456, + }, + LastStoredBinlogPositionForTargetVerifier: 
mysql.Position{ + Name: "mysql-bin.000121", + Pos: 987654, + }, + } + + data, err := json.Marshal(state) + require.NoError(t, err) + + var decoded ghostferry.SerializableState + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + assert.Equal(t, "mysql-bin.000123", decoded.LastWrittenBinlogPosition.Name) + assert.Equal(t, uint32(456789), decoded.LastWrittenBinlogPosition.Pos) + + assert.Equal(t, "mysql-bin.000122", decoded.LastStoredBinlogPositionForInlineVerifier.Name) + assert.Equal(t, uint32(123456), decoded.LastStoredBinlogPositionForInlineVerifier.Pos) + + assert.Equal(t, "mysql-bin.000121", decoded.LastStoredBinlogPositionForTargetVerifier.Name) + assert.Equal(t, uint32(987654), decoded.LastStoredBinlogPositionForTargetVerifier.Pos) +} + +func TestSerializableState_UnmarshalJSON_CorruptedData(t *testing.T) { + corruptedJSON := `{ + "GhostferryVersion": "test-version", + "LastSuccessfulPaginationKeys": { + "db.table1": {"type": "invalid_type", "value": 123} + } + }` + + var decoded ghostferry.SerializableState + err := json.Unmarshal([]byte(corruptedJSON), &decoded) + assert.Error(t, err) +} + +func TestSerializableState_RoundTrip_LargeState(t *testing.T) { + uuid1 := []byte{0x01, 0x8f, 0x3e, 0x4c, 0x5a, 0x6b, 0x7c, 0x8d, 0x9e, 0xaf, 0xb0, 0xc1, 0xd2, 0xe3, 0xf4, 0x01} + uuid2 := []byte{0x01, 0x8f, 0x3e, 0x4c, 0x5a, 0x6b, 0x7c, 0x8d, 0x9e, 0xaf, 0xb0, 0xc1, 0xd2, 0xe3, 0xf4, 0x02} + + state := &ghostferry.SerializableState{ + GhostferryVersion: "test-version-1.2.3", + LastSuccessfulPaginationKeys: map[string]ghostferry.PaginationKey{ + "prod.users": ghostferry.NewUint64Key(1000000), + "prod.orders": ghostferry.NewUint64Key(5000000), + "prod.products": ghostferry.NewUint64Key(250000), + "prod.sessions": ghostferry.NewBinaryKey(uuid1), + "prod.api_keys": ghostferry.NewBinaryKey(uuid2), + "staging.users": ghostferry.NewUint64Key(500), + "staging.orders": ghostferry.NewUint64Key(1000), + }, + CompletedTables: map[string]bool{ + "prod.old_table1": 
true, + "prod.old_table2": true, + "staging.old_table": true, + }, + LastWrittenBinlogPosition: mysql.Position{ + Name: "mysql-bin.001234", + Pos: 987654321, + }, + LastStoredBinlogPositionForInlineVerifier: mysql.Position{ + Name: "mysql-bin.001233", + Pos: 123456789, + }, + LastStoredBinlogPositionForTargetVerifier: mysql.Position{ + Name: "mysql-bin.001232", + Pos: 111222333, + }, + } + + data, err := json.Marshal(state) + require.NoError(t, err) + + var decoded ghostferry.SerializableState + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + assert.Equal(t, state.GhostferryVersion, decoded.GhostferryVersion) + assert.Len(t, decoded.LastSuccessfulPaginationKeys, 7) + assert.Len(t, decoded.CompletedTables, 3) + + usersKey, ok := decoded.LastSuccessfulPaginationKeys["prod.users"].(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(1000000), uint64(usersKey)) + + sessionsKey, ok := decoded.LastSuccessfulPaginationKeys["prod.sessions"].(ghostferry.BinaryKey) + require.True(t, ok) + assert.Equal(t, uuid1, []byte(sessionsKey)) + + assert.Equal(t, state.LastWrittenBinlogPosition, decoded.LastWrittenBinlogPosition) + assert.Equal(t, state.LastStoredBinlogPositionForInlineVerifier, decoded.LastStoredBinlogPositionForInlineVerifier) + assert.Equal(t, state.LastStoredBinlogPositionForTargetVerifier, decoded.LastStoredBinlogPositionForTargetVerifier) + + for tableName := range state.CompletedTables { + assert.True(t, decoded.CompletedTables[tableName]) + } +} + +func TestSerializableState_JSONStructure(t *testing.T) { + uuid := []byte{0xDE, 0xAD, 0xBE, 0xEF} + state := &ghostferry.SerializableState{ + GhostferryVersion: "test", + LastSuccessfulPaginationKeys: map[string]ghostferry.PaginationKey{ + "db.table1": ghostferry.NewUint64Key(123), + "db.table2": ghostferry.NewBinaryKey(uuid), + }, + CompletedTables: make(map[string]bool), + } + + data, err := json.Marshal(state) + require.NoError(t, err) + + var raw map[string]interface{} + err = 
json.Unmarshal(data, &raw) + require.NoError(t, err) + + keys, ok := raw["LastSuccessfulPaginationKeys"].(map[string]interface{}) + require.True(t, ok) + + table1Data := keys["db.table1"].(map[string]interface{}) + assert.Equal(t, "uint64", table1Data["type"]) + assert.Equal(t, float64(123), table1Data["value"]) + + table2Data := keys["db.table2"].(map[string]interface{}) + assert.Equal(t, "binary", table2Data["type"]) + assert.Equal(t, "deadbeef", table2Data["value"]) +} + +func TestSerializableState_EmptyBinaryKey(t *testing.T) { + state := &ghostferry.SerializableState{ + GhostferryVersion: "test", + LastSuccessfulPaginationKeys: map[string]ghostferry.PaginationKey{ + "db.table": ghostferry.NewBinaryKey([]byte{}), + }, + CompletedTables: make(map[string]bool), + } + + data, err := json.Marshal(state) + require.NoError(t, err) + + var decoded ghostferry.SerializableState + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + key, ok := decoded.LastSuccessfulPaginationKeys["db.table"].(ghostferry.BinaryKey) + require.True(t, ok) + assert.Equal(t, []byte{}, []byte(key)) +} + +func TestSerializableState_ZeroUint64Key(t *testing.T) { + state := &ghostferry.SerializableState{ + GhostferryVersion: "test", + LastSuccessfulPaginationKeys: map[string]ghostferry.PaginationKey{ + "db.table": ghostferry.NewUint64Key(0), + }, + CompletedTables: make(map[string]bool), + } + + data, err := json.Marshal(state) + require.NoError(t, err) + + var decoded ghostferry.SerializableState + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + key, ok := decoded.LastSuccessfulPaginationKeys["db.table"].(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(0), uint64(key)) +} + +func TestSerializableState_WithCompositeKeys(t *testing.T) { + compositeKey1 := ghostferry.CompositeKey{ + ghostferry.NewUint64Key(100), + ghostferry.NewUint64Key(200), + } + compositeKey2 := ghostferry.CompositeKey{ + ghostferry.NewUint64Key(300), + 
ghostferry.NewBinaryKey([]byte("abc")), + } + + state := &ghostferry.SerializableState{ + GhostferryVersion: "test-version", + LastSuccessfulPaginationKeys: map[string]ghostferry.PaginationKey{ + "db.composite_table1": compositeKey1, + "db.composite_table2": compositeKey2, + "db.simple_table": ghostferry.NewUint64Key(999), + }, + CompletedTables: make(map[string]bool), + } + + data, err := json.Marshal(state) + require.NoError(t, err) + + var decoded ghostferry.SerializableState + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + assert.Equal(t, "test-version", decoded.GhostferryVersion) + assert.Len(t, decoded.LastSuccessfulPaginationKeys, 3) + + // Check composite key 1 (two uint64s) + key1, ok := decoded.LastSuccessfulPaginationKeys["db.composite_table1"].(ghostferry.CompositeKey) + require.True(t, ok) + require.Len(t, key1, 2) + + subKey1_0, ok := key1[0].(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(100), uint64(subKey1_0)) + + subKey1_1, ok := key1[1].(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(200), uint64(subKey1_1)) + + // Check composite key 2 (uint64 + binary) + key2, ok := decoded.LastSuccessfulPaginationKeys["db.composite_table2"].(ghostferry.CompositeKey) + require.True(t, ok) + require.Len(t, key2, 2) + + subKey2_0, ok := key2[0].(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(300), uint64(subKey2_0)) + + subKey2_1, ok := key2[1].(ghostferry.BinaryKey) + require.True(t, ok) + assert.Equal(t, []byte("abc"), []byte(subKey2_1)) + + // Check simple key still works + key3, ok := decoded.LastSuccessfulPaginationKeys["db.simple_table"].(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(999), uint64(key3)) +} + +func TestSerializableState_CompositeKey_ThreeElements(t *testing.T) { + uuid := []byte{0x01, 0x8f, 0x3e, 0x4c, 0x5a, 0x6b, 0x7c, 0x8d, 0x9e, 0xaf, 0xb0, 0xc1, 0xd2, 0xe3, 0xf4, 0x05} + + compositeKey := ghostferry.CompositeKey{ + 
ghostferry.NewUint64Key(1000), + ghostferry.NewBinaryKey(uuid), + ghostferry.NewUint64Key(2000), + } + + state := &ghostferry.SerializableState{ + GhostferryVersion: "test-version", + LastSuccessfulPaginationKeys: map[string]ghostferry.PaginationKey{ + "db.three_col_table": compositeKey, + }, + CompletedTables: make(map[string]bool), + } + + data, err := json.Marshal(state) + require.NoError(t, err) + + var decoded ghostferry.SerializableState + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + key, ok := decoded.LastSuccessfulPaginationKeys["db.three_col_table"].(ghostferry.CompositeKey) + require.True(t, ok) + require.Len(t, key, 3) + + subKey0, ok := key[0].(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(1000), uint64(subKey0)) + + subKey1, ok := key[1].(ghostferry.BinaryKey) + require.True(t, ok) + assert.Equal(t, uuid, []byte(subKey1)) + + subKey2, ok := key[2].(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(2000), uint64(subKey2)) +} + +func TestSerializableState_CompositeKey_JSONStructure(t *testing.T) { + compositeKey := ghostferry.CompositeKey{ + ghostferry.NewUint64Key(123), + ghostferry.NewBinaryKey([]byte{0xDE, 0xAD}), + } + + state := &ghostferry.SerializableState{ + GhostferryVersion: "test", + LastSuccessfulPaginationKeys: map[string]ghostferry.PaginationKey{ + "db.composite": compositeKey, + }, + CompletedTables: make(map[string]bool), + } + + data, err := json.Marshal(state) + require.NoError(t, err) + + var raw map[string]interface{} + err = json.Unmarshal(data, &raw) + require.NoError(t, err) + + keys, ok := raw["LastSuccessfulPaginationKeys"].(map[string]interface{}) + require.True(t, ok) + + compositeData := keys["db.composite"].(map[string]interface{}) + assert.Equal(t, "composite", compositeData["type"]) + + valueArray := compositeData["value"].([]interface{}) + require.Len(t, valueArray, 2) + + // First element should be uint64 + elem0 := valueArray[0].(map[string]interface{}) + 
assert.Equal(t, "uint64", elem0["type"]) + assert.Equal(t, float64(123), elem0["value"]) + + // Second element should be binary + elem1 := valueArray[1].(map[string]interface{}) + assert.Equal(t, "binary", elem1["type"]) + assert.Equal(t, "dead", elem1["value"]) +} + +func TestSerializableState_MixedWithCompositeKeys_LargeState(t *testing.T) { + uuid1 := []byte{0x01, 0x8f, 0x3e, 0x4c, 0x5a, 0x6b, 0x7c, 0x8d, 0x9e, 0xaf, 0xb0, 0xc1, 0xd2, 0xe3, 0xf4, 0x01} + + compositeKey1 := ghostferry.CompositeKey{ + ghostferry.NewUint64Key(1000), + ghostferry.NewUint64Key(2000), + } + + compositeKey2 := ghostferry.CompositeKey{ + ghostferry.NewBinaryKey(uuid1), + ghostferry.NewUint64Key(5000), + } + + state := &ghostferry.SerializableState{ + GhostferryVersion: "test-version-composite", + LastSuccessfulPaginationKeys: map[string]ghostferry.PaginationKey{ + "prod.users": ghostferry.NewUint64Key(1000000), + "prod.tenant_data": compositeKey1, + "prod.sharded_events": compositeKey2, + "prod.sessions": ghostferry.NewBinaryKey(uuid1), + }, + CompletedTables: map[string]bool{ + "prod.old_table": true, + }, + LastWrittenBinlogPosition: mysql.Position{ + Name: "mysql-bin.001234", + Pos: 987654321, + }, + } + + data, err := json.Marshal(state) + require.NoError(t, err) + + var decoded ghostferry.SerializableState + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + assert.Equal(t, state.GhostferryVersion, decoded.GhostferryVersion) + assert.Len(t, decoded.LastSuccessfulPaginationKeys, 4) + + // Check simple uint64 key + usersKey, ok := decoded.LastSuccessfulPaginationKeys["prod.users"].(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(1000000), uint64(usersKey)) + + // Check composite key 1 + tenantKey, ok := decoded.LastSuccessfulPaginationKeys["prod.tenant_data"].(ghostferry.CompositeKey) + require.True(t, ok) + require.Len(t, tenantKey, 2) + assert.Equal(t, uint64(1000), uint64(tenantKey[0].(ghostferry.Uint64Key))) + assert.Equal(t, uint64(2000), 
uint64(tenantKey[1].(ghostferry.Uint64Key))) + + // Check composite key 2 + eventsKey, ok := decoded.LastSuccessfulPaginationKeys["prod.sharded_events"].(ghostferry.CompositeKey) + require.True(t, ok) + require.Len(t, eventsKey, 2) + assert.Equal(t, uuid1, []byte(eventsKey[0].(ghostferry.BinaryKey))) + assert.Equal(t, uint64(5000), uint64(eventsKey[1].(ghostferry.Uint64Key))) + + // Check simple binary key + sessionsKey, ok := decoded.LastSuccessfulPaginationKeys["prod.sessions"].(ghostferry.BinaryKey) + require.True(t, ok) + assert.Equal(t, uuid1, []byte(sessionsKey)) + + assert.Equal(t, state.LastWrittenBinlogPosition, decoded.LastWrittenBinlogPosition) + assert.True(t, decoded.CompletedTables["prod.old_table"]) +} diff --git a/test/go/table_schema_cache_test.go b/test/go/table_schema_cache_test.go index fc00a5406..7abe8cda3 100644 --- a/test/go/table_schema_cache_test.go +++ b/test/go/table_schema_cache_test.go @@ -86,18 +86,18 @@ func (this *TableSchemaCacheTestSuite) TestLoadTablesWithoutFiltering() { } } -func (this *TableSchemaCacheTestSuite) TestLoadTablesRejectTablesWithoutNumericPK() { +func (this *TableSchemaCacheTestSuite) TestLoadTablesAcceptTablesWithVarcharPK() { table := "test_table_4" paginationColumn := "id" query := fmt.Sprintf("CREATE TABLE %s.%s (%s varchar(20) not null, data TEXT, primary key(%s))", testhelpers.TestSchemaName, table, paginationColumn, paginationColumn) _, err := this.Ferry.SourceDB.Exec(query) this.Require().Nil(err) - _, err = ghostferry.LoadTables(this.Ferry.SourceDB, this.tableFilter, nil, nil, nil, nil) + tableSchemaCache, err := ghostferry.LoadTables(this.Ferry.SourceDB, this.tableFilter, nil, nil, nil, nil) - this.Require().NotNil(err) - this.Require().EqualError(err, ghostferry.NonNumericPaginationKeyError(testhelpers.TestSchemaName, table, paginationColumn).Error()) - this.Require().Contains(err.Error(), table) + this.Require().Nil(err) + this.Require().Contains(tableSchemaCache, testhelpers.TestSchemaName+"."+table) + 
this.Require().Equal(paginationColumn, tableSchemaCache[testhelpers.TestSchemaName+"."+table].GetPaginationColumn().Name) } func (this *TableSchemaCacheTestSuite) TestLoadTablesRejectTablesWithoutNumericPKWithMediumInt() { table := "pagination_by_column_medium_int_pk" @@ -222,21 +222,35 @@ func (this *TableSchemaCacheTestSuite) TestLoadTablesRejectTablesWithoutPKColumn this.Require().EqualError(err, ghostferry.NonExistingPaginationKeyError(testhelpers.TestSchemaName, table).Error()) } -func (this *TableSchemaCacheTestSuite) TestLoadTablesRejectTablesWithCompositePKButNoAlternateColumnToFallBackTo() { +func (this *TableSchemaCacheTestSuite) TestLoadTablesAcceptsTablesWithCompositePK() { table := "composite_pk_without_fallback" query := fmt.Sprintf("CREATE TABLE %s.%s (identity bigint(20) not null, other_id bigint(20) not null, data TEXT, primary key(identity, other_id))", testhelpers.TestSchemaName, table) _, err := this.Ferry.SourceDB.Exec(query) this.Require().Nil(err) - _, err = ghostferry.LoadTables(this.Ferry.SourceDB, this.tableFilter, nil, nil, nil, nil) + tables, err := ghostferry.LoadTables(this.Ferry.SourceDB, this.tableFilter, nil, nil, nil, nil) - this.Require().NotNil(err) - this.Require().EqualError(err, ghostferry.NonExistingPaginationKeyError(testhelpers.TestSchemaName, table).Error()) + // Composite PKs are now supported, so this should succeed + this.Require().Nil(err) + tableSchema := tables.Get(testhelpers.TestSchemaName, table) + this.Require().NotNil(tableSchema) + // Should use the composite PK + this.Require().Equal(2, len(tableSchema.PaginationKeyColumns)) + this.Require().Equal("identity", tableSchema.PaginationKeyColumns[0].Name) + this.Require().Equal("other_id", tableSchema.PaginationKeyColumns[1].Name) + // Backward compat: PaginationKeyColumn should be the first column + this.Require().Equal("identity", tableSchema.PaginationKeyColumn.Name) } -func (this *TableSchemaCacheTestSuite) TestLoadTablesWithCompositePKButIDColumnToFallBackTo() 
{ +func (this *TableSchemaCacheTestSuite) TestLoadTablesWithCompositePKButIDColumnToOverride() { table := "composite_pk_with_id_fallback" paginationColumn := "id" + + query := fmt.Sprintf("CREATE TABLE %s.%s (identity bigint(20) not null, id bigint(20) not null, data TEXT, primary key(identity, id))", testhelpers.TestSchemaName, table) + _, err := this.Ferry.SourceDB.Exec(query) + this.Require().Nil(err) + + // Test 1: PerTable config should override the composite PK cascadingPaginationColumnConfig := &ghostferry.CascadingPaginationColumnConfig{ PerTable: map[string]map[string]string{ testhelpers.TestSchemaName: map[string]string{ @@ -244,17 +258,20 @@ func (this *TableSchemaCacheTestSuite) TestLoadTablesWithCompositePKButIDColumnT }, }, } - - query := fmt.Sprintf("CREATE TABLE %s.%s (identity bigint(20) not null, id bigint(20) not null, data TEXT, primary key(identity, id))", testhelpers.TestSchemaName, table) - _, err := this.Ferry.SourceDB.Exec(query) - this.Require().Nil(err) - this.assertLoadTablesWithCascadingPaginationColumnConfig(table, paginationColumn, cascadingPaginationColumnConfig) + // Test 2: FallbackColumn is NOT used when table has a PK (including composite) + // The composite PK takes precedence cascadingPaginationColumnConfig = &ghostferry.CascadingPaginationColumnConfig{ FallbackColumn: paginationColumn, } - this.assertLoadTablesWithCascadingPaginationColumnConfig(table, paginationColumn, cascadingPaginationColumnConfig) + // With FallbackColumn, it should use the composite PK instead (first column: "identity") + tableSchemaCache, err := ghostferry.LoadTables(this.Ferry.SourceDB, this.tableFilter, nil, nil, nil, cascadingPaginationColumnConfig) + this.Require().Nil(err) + tableSchema := tableSchemaCache.Get(testhelpers.TestSchemaName, table) + // Should use composite PK, not fallback + this.Require().Equal(2, len(tableSchema.PaginationKeyColumns)) + this.Require().Equal("identity", tableSchema.PaginationKeyColumn.Name) } func (this 
*TableSchemaCacheTestSuite) TestAllTableNames() { diff --git a/test/go/trivial_integration_test.go b/test/go/trivial_integration_test.go index 22a2631b4..449c7a507 100644 --- a/test/go/trivial_integration_test.go +++ b/test/go/trivial_integration_test.go @@ -4,6 +4,7 @@ import ( sqlorig "database/sql" "fmt" "math/rand" + "strings" "testing" sql "github.com/Shopify/ghostferry/sqlwrapper" @@ -154,6 +155,18 @@ func TestCopyDataWithNullInColumn(t *testing.T) { testcase.Run() } +func TestCopyDataWithThreeColumnCompositeKey(t *testing.T) { + ferry := testhelpers.NewTestFerry() + + testcase := &testhelpers.IntegrationTestCase{ + T: t, + SetupAction: setupCompositeKeyTableDatabase(3), + Ferry: ferry, + } + + testcase.Run() +} + // ==================== // Helper methods below // ==================== @@ -180,3 +193,36 @@ func setupSingleTableDatabaseWithExtraNullColumn(f *testhelpers.TestFerry, sourc _, err := sourceDB.Exec("UPDATE gftest.table1 SET tenant_id = NULL") testhelpers.PanicIfError(err) } + +// setupCompositeKeyTableDatabase returns a setup function for a composite key table with numColumns key columns. 
+func setupCompositeKeyTableDatabase(numColumns int) func(*testhelpers.TestFerry, *sql.DB, *sql.DB) { + return func(f *testhelpers.TestFerry, sourceDB, targetDB *sql.DB) { + tableName := fmt.Sprintf("composite_table_%d", numColumns) + // 10^(numColumns-1) rows gives us a good spread across all key columns + maxRows := 1 + for i := 0; i < numColumns; i++ { + maxRows *= 10 + } + if maxRows > 100 { + maxRows = 100 // Cap at 100 rows for reasonable test speed + } + + testhelpers.SeedInitialDataCompositeKey(sourceDB, "gftest", tableName, maxRows, numColumns) + + // Delete a few random rows + for i := 0; i < 4; i++ { + args := make([]interface{}, numColumns) + var whereClauses []string + remaining := rand.Intn(maxRows) + for col := numColumns - 1; col >= 0; col-- { + args[col] = (remaining % 10) + 1 + remaining /= 10 + whereClauses = append([]string{fmt.Sprintf("k%d = ?", col+1)}, whereClauses...) + } + query := fmt.Sprintf("DELETE FROM gftest.%s WHERE %s", tableName, strings.Join(whereClauses, " AND ")) + sourceDB.Exec(query, args...) 
// Ignore errors for non-existent rows + } + + testhelpers.SeedInitialDataCompositeKey(targetDB, "gftest", tableName, 0, numColumns) + } +} diff --git a/test/helpers/db_helper.rb b/test/helpers/db_helper.rb index 68e6468c9..bcb2e0f43 100644 --- a/test/helpers/db_helper.rb +++ b/test/helpers/db_helper.rb @@ -1,5 +1,6 @@ require "logger" require "mysql2" +require "securerandom" module DbHelper ALPHANUMERICS = ("0".."9").to_a + ("a".."z").to_a + ("A".."Z").to_a @@ -9,6 +10,7 @@ module DbHelper DEFAULT_DB = "gftest" DEFAULT_TABLE = "test_table_1" + UUID_TABLE = "test_uuid_table" class Mysql2::Client alias_method :query_without_maginalia, :query @@ -42,6 +44,7 @@ def self.rand_data(length: 32) end DEFAULT_FULL_TABLE_NAME = full_table_name(DEFAULT_DB, DEFAULT_TABLE) + UUID_FULL_TABLE_NAME = full_table_name(DEFAULT_DB, UUID_TABLE) def full_table_name(db, table) DbHelper.full_table_name(db, table) @@ -160,6 +163,50 @@ def seed_simple_database_with_single_table seed_random_data(target_db, number_of_rows: 0) end + def generate_uuid_bytes + uuid_string = SecureRandom.uuid + uuid_string.gsub("-", "").scan(/../).map { |x| x.hex.chr }.join + end + + def seed_uuid_data(connection, database_name: DEFAULT_DB, table_name: UUID_TABLE, number_of_rows: 1111) + dbtable = full_table_name(database_name, table_name) + + connection.query("CREATE DATABASE IF NOT EXISTS #{database_name}") + connection.query("CREATE TABLE IF NOT EXISTS #{dbtable} (id VARBINARY(16) NOT NULL, data TEXT, PRIMARY KEY(id))") + + return if number_of_rows == 0 + + transaction(connection) do + number_of_rows.times do + uuid_bytes = generate_uuid_bytes + data = rand_data + insert_statement = connection.prepare("INSERT INTO #{dbtable} (id, data) VALUES (?, ?)") + insert_statement.execute(uuid_bytes, data) + end + end + end + + def seed_simple_database_with_uuid_table + max_rows = 1111 + seed_uuid_data(source_db, number_of_rows: max_rows) + + num_holes = 140 + result = source_db.query("SELECT id FROM 
#{UUID_FULL_TABLE_NAME} ORDER BY id LIMIT #{num_holes}") + + holes_ids = [] + result.each do |row| + holes_ids << row["id"] + end + + unless holes_ids.empty? + sqlargs = (["?"]*holes_ids.length).join(",") + delete_statement = source_db.prepare("DELETE FROM #{UUID_FULL_TABLE_NAME} WHERE id IN (#{sqlargs})") + delete_statement.execute(*holes_ids) + end + + seed_uuid_data(target_db, number_of_rows: 0) + end + # Get some overall metrics like CHECKSUM, row count, sample row from tables. # Generally used for test validation. def source_and_target_table_metrics(tables: [DEFAULT_FULL_TABLE_NAME]) @@ -186,7 +233,12 @@ def table_metric(conn, table, sample_id: nil) result = conn.query("SELECT * FROM #{table} ORDER BY RAND() LIMIT 1") metrics[:sample_row] = result.first else - result = conn.query("SELECT * FROM #{table} WHERE id = #{sample_id} LIMIT 1") + if sample_id.is_a?(String) && sample_id.encoding == Encoding::ASCII_8BIT + stmt = conn.prepare("SELECT * FROM #{table} WHERE id = ? LIMIT 1") + result = stmt.execute(sample_id) + else + result = conn.query("SELECT * FROM #{table} WHERE id = #{sample_id} LIMIT 1") + end metrics[:sample_row] = result.first end diff --git a/test/integration/interrupt_resume_test.rb b/test/integration/interrupt_resume_test.rb index fde868efe..2ed3f23c9 100644 --- a/test/integration/interrupt_resume_test.rb +++ b/test/integration/interrupt_resume_test.rb @@ -24,7 +24,7 @@ def test_interrupt_resume_without_writes_to_source_to_check_target_state_when_in result = target_db.query("SELECT MAX(id) AS max_id FROM #{DEFAULT_FULL_TABLE_NAME}") last_successful_id = result.first["max_id"] assert last_successful_id > 0 - assert_equal last_successful_id, dumped_state["LastSuccessfulPaginationKeys"]["#{DEFAULT_DB}.#{DEFAULT_TABLE}"] + assert_equal last_successful_id, dumped_state["LastSuccessfulPaginationKeys"]["#{DEFAULT_DB}.#{DEFAULT_TABLE}"]["value"] end def test_interrupt_and_resume_without_last_known_schema_cache @@ -553,7 +553,7 @@ def 
test_issue_149_correct dumped_state = ghostferry.run_expecting_interrupt assert_basic_fields_exist_in_dumped_state(dumped_state) - last_pk = dumped_state["LastSuccessfulPaginationKeys"]["#{DEFAULT_DB}.#{DEFAULT_TABLE}"] + last_pk = dumped_state["LastSuccessfulPaginationKeys"]["#{DEFAULT_DB}.#{DEFAULT_TABLE}"]["value"] assert last_pk > 200 # We need to rewind the state backwards, and then change that row on the @@ -573,7 +573,7 @@ def test_issue_149_correct data_changed = source_db.query("SELECT data FROM #{DEFAULT_FULL_TABLE_NAME} WHERE id = #{id_to_change}").first["data"] assert_equal "changed", data_changed - dumped_state["LastSuccessfulPaginationKeys"]["#{DEFAULT_DB}.#{DEFAULT_TABLE}"] = id_to_change - 1 + dumped_state["LastSuccessfulPaginationKeys"]["#{DEFAULT_DB}.#{DEFAULT_TABLE}"]["value"] = id_to_change - 1 ghostferry = new_ghostferry(MINIMAL_GHOSTFERRY, config: { verifier_type: "Inline" }) changed_row_copied = false @@ -623,7 +623,7 @@ def test_issue_149_corrupted dumped_state = ghostferry.run_expecting_interrupt assert_basic_fields_exist_in_dumped_state(dumped_state) - last_pk = dumped_state["LastSuccessfulPaginationKeys"]["#{DEFAULT_DB}.#{DEFAULT_TABLE}"] + last_pk = dumped_state["LastSuccessfulPaginationKeys"]["#{DEFAULT_DB}.#{DEFAULT_TABLE}"]["value"] assert last_pk > 200 # This should be similar to test_issue_149_correct, except we force the @@ -641,7 +641,7 @@ def test_issue_149_corrupted data_corrupted = target_db.query("SELECT data FROM #{DEFAULT_FULL_TABLE_NAME} WHERE id = #{id_to_change}").first["data"] assert_equal "corrupted", data_corrupted - dumped_state["LastSuccessfulPaginationKeys"]["#{DEFAULT_DB}.#{DEFAULT_TABLE}"] = id_to_change - 1 + dumped_state["LastSuccessfulPaginationKeys"]["#{DEFAULT_DB}.#{DEFAULT_TABLE}"]["value"] = id_to_change - 1 ghostferry = new_ghostferry(MINIMAL_GHOSTFERRY, config: { verifier_type: "Inline" }) changed_row_copied = false @@ -680,4 +680,128 @@ def test_issue_149_corrupted assert expectation, "error message: 
#{error_message.inspect}, didn't start with #{predicate.inspect}" end + + def test_interrupt_resume_without_writes_to_source_with_uuid_table + seed_simple_database_with_uuid_table + + ghostferry = new_ghostferry(MINIMAL_GHOSTFERRY) + + ghostferry.on_status(Ghostferry::Status::AFTER_ROW_COPY) do + ghostferry.send_signal("TERM") + end + + dumped_state = ghostferry.run_expecting_interrupt + assert_basic_fields_exist_in_dumped_state(dumped_state) + + result = target_db.query("SELECT COUNT(*) AS cnt FROM #{UUID_FULL_TABLE_NAME}") + count = result.first["cnt"] + assert_equal 200, count + + result = target_db.query("SELECT MAX(id) AS max_id FROM #{UUID_FULL_TABLE_NAME}") + last_successful_id_bytes = result.first["max_id"] + assert last_successful_id_bytes.length > 0 + + last_key_in_state = dumped_state["LastSuccessfulPaginationKeys"]["#{DEFAULT_DB}.#{UUID_TABLE}"]["value"] + assert_equal last_successful_id_bytes.unpack1("H*"), last_key_in_state + end + + def test_interrupt_and_resume_without_last_known_schema_cache_with_uuid_table + seed_simple_database_with_uuid_table + + ghostferry = new_ghostferry(MINIMAL_GHOSTFERRY) + + ghostferry.on_status(Ghostferry::Status::AFTER_ROW_COPY) do + ghostferry.send_signal("TERM") + end + + dumped_state = ghostferry.run_expecting_interrupt + assert_basic_fields_exist_in_dumped_state(dumped_state) + dumped_state["LastKnownTableSchemaCache"] = nil + + ghostferry = new_ghostferry(MINIMAL_GHOSTFERRY) + + ghostferry.run(dumped_state) + + assert_uuid_table_is_identical + end + + def test_interrupt_resume_with_writes_to_source_with_uuid_table + seed_simple_database_with_uuid_table + + datawriter = new_source_datawriter + ghostferry = new_ghostferry_with_interrupt_after_row_copy(MINIMAL_GHOSTFERRY, after_batches_written: 2) + + start_datawriter_with_ghostferry(datawriter, ghostferry) + + dumped_state = ghostferry.run_expecting_interrupt + assert_basic_fields_exist_in_dumped_state(dumped_state) + + ghostferry = new_ghostferry(MINIMAL_GHOSTFERRY) 
+ + stop_datawriter_during_cutover(datawriter, ghostferry) + + ghostferry.run(dumped_state) + + assert_uuid_table_is_identical + end + + def test_interrupt_resume_idempotence_with_uuid_table + seed_simple_database_with_uuid_table + + ghostferry = new_ghostferry_with_interrupt_after_row_copy(MINIMAL_GHOSTFERRY) + dumped_state = ghostferry.run_expecting_interrupt + + ghostferry = new_ghostferry(MINIMAL_GHOSTFERRY) + ghostferry.run_with_logs(dumped_state) + + assert_uuid_table_is_identical + + ghostferry.run_with_logs(dumped_state) + + assert_uuid_table_is_identical + + assert_ghostferry_completed(ghostferry, times: 2) + end + + def test_interrupt_resume_inline_verifier_with_uuid_table + seed_simple_database_with_single_table + seed_simple_database_with_uuid_table + + datawriter = new_source_datawriter + ghostferry = new_ghostferry(MINIMAL_GHOSTFERRY, config: { verifier_type: "Inline" }) + + start_datawriter_with_ghostferry(datawriter, ghostferry) + + batches_written = 0 + ghostferry.on_status(Ghostferry::Status::AFTER_ROW_COPY) do + batches_written += 1 + if batches_written >= 5 + ghostferry.term_and_wait_for_exit + end + end + + dumped_state = ghostferry.run_expecting_interrupt + assert_basic_fields_exist_in_dumped_state(dumped_state) + refute_nil dumped_state["BinlogVerifyStore"] + refute_nil dumped_state["BinlogVerifyStore"]["gftest"] + refute_nil dumped_state["BinlogVerifyStore"]["gftest"]["test_table_1"] + + ghostferry = new_ghostferry(MINIMAL_GHOSTFERRY, config: { verifier_type: "Inline" }) + + verification_ran = false + incorrect_tables = [] + ghostferry.on_status(Ghostferry::Status::VERIFIED) do |*tables| + verification_ran = true + incorrect_tables = tables + end + + stop_datawriter_during_cutover(datawriter, ghostferry) + + ghostferry.run(dumped_state) + + assert verification_ran + assert_equal 0, incorrect_tables.length + assert_test_table_is_identical + assert_uuid_table_is_identical + end end diff --git a/test/test_helper.rb b/test/test_helper.rb index 
a3a5ba5af..d7247860b 100644 --- a/test/test_helper.rb +++ b/test/test_helper.rb @@ -92,8 +92,8 @@ def new_ghostferry_with_interrupt_after_row_copy(filepath, config: {}, after_bat g end - def new_source_datawriter(*args) - dw = DataWriter.new(source_db_config, *args, logger: @log_capturer.logger) + def new_source_datawriter(*args, **kwargs) + dw = DataWriter.new(source_db_config, *args, **kwargs, logger: @log_capturer.logger) @datawriter_instances << dw dw end @@ -179,6 +179,25 @@ def assert_test_table_is_identical ) end + def assert_uuid_table_is_identical + source, target = source_and_target_table_metrics(tables: [UUID_FULL_TABLE_NAME]) + + assert source[UUID_FULL_TABLE_NAME][:row_count] > 0 + assert target[UUID_FULL_TABLE_NAME][:row_count] > 0 + + assert_equal( + source[UUID_FULL_TABLE_NAME][:row_count], + target[UUID_FULL_TABLE_NAME][:row_count], + "source and target row count don't match", + ) + + assert_equal( + source[UUID_FULL_TABLE_NAME][:checksum], + target[UUID_FULL_TABLE_NAME][:checksum], + "source and target checksum don't match", + ) + end + # Use this method to assert the validity of the structure of the dumped # state. # diff --git a/testhelpers/data_writer.go b/testhelpers/data_writer.go index c80107f16..a3e700311 100644 --- a/testhelpers/data_writer.go +++ b/testhelpers/data_writer.go @@ -3,6 +3,7 @@ package testhelpers import ( "fmt" "math/rand" + "strings" "sync" "time" @@ -86,6 +87,66 @@ func SeedInitialData(db *sql.DB, dbname, tablename string, numberOfRows int) { PanicIfError(tx.Commit()) } +// SeedInitialDataCompositeKey seeds a table with an n-column composite primary key. +// numColumns specifies how many key columns (k1, k2, ..., kN) to create. 
+func SeedInitialDataCompositeKey(db *sql.DB, dbname, tablename string, numberOfRows, numColumns int) { + if numColumns < 2 { + panic("SeedInitialDataCompositeKey fails: number of columns must be at least 2") + } + + var err error + + query := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", dbname) + _, err = db.Exec(query) + PanicIfError(err) + + // Build column definitions: k1 int, k2 int, ..., kN int + var colDefs, colNames, placeholders, pkCols []string + for i := 1; i <= numColumns; i++ { + colName := fmt.Sprintf("k%d", i) + colDefs = append(colDefs, colName+" int") + colNames = append(colNames, colName) + placeholders = append(placeholders, "?") + pkCols = append(pkCols, colName) + } + colDefs = append(colDefs, "data TEXT") + colNames = append(colNames, "data") + placeholders = append(placeholders, "?") + + createQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (%s, PRIMARY KEY (%s))", + dbname, tablename, + strings.Join(colDefs, ", "), + strings.Join(pkCols, ", ")) + _, err = db.Exec(createQuery) + PanicIfError(err) + + tx, err := db.Begin() + PanicIfError(err) + + insertQuery := fmt.Sprintf("INSERT INTO %s.%s (%s) VALUES (%s)", + dbname, tablename, + strings.Join(colNames, ", "), + strings.Join(placeholders, ", ")) + + // Generate unique composite keys by treating i as a base-10 number + // where each digit (1-10) corresponds to a key column value + for i := 0; i < numberOfRows; i++ { + args := make([]interface{}, numColumns+1) + remaining := i + // Calculate key values from least significant to most significant + for col := numColumns - 1; col >= 0; col-- { + args[col] = (remaining % 10) + 1 + remaining /= 10 + } + args[numColumns] = RandData() + + _, err = tx.Exec(insertQuery, args...) + PanicIfError(err) + } + + PanicIfError(tx.Commit()) +} + func AddTenantID(db *sql.DB, dbName, tableName string, numberOfTenants int) { query := "ALTER TABLE %s.%s ADD tenant_id bigint(20)" query = fmt.Sprintf(query, dbName, tableName)