misc cleanup (#2474)
split out from #2395
serprex authored Jan 21, 2025
1 parent 67ba3d6 commit e418391
Showing 9 changed files with 112 additions and 121 deletions.
31 changes: 25 additions & 6 deletions flow/activities/flowable.go
@@ -6,6 +6,7 @@ import (
"fmt"
"log/slog"
"sync/atomic"
"time"

"github.com/jackc/pgerrcode"
"github.com/jackc/pgx/v5"
@@ -831,12 +832,30 @@ func (a *FlowableActivity) QRepHasNewRows(ctx context.Context,

logger.Info(fmt.Sprintf("current last partition value is %v", last))

result, err := srcConn.CheckForUpdatedMaxValue(ctx, config, last)
maxValue, err := srcConn.GetMaxValue(ctx, config, last)
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return false, fmt.Errorf("failed to check for new rows: %w", err)
}
return result, nil

if maxValue == nil || last == nil || last.Range == nil {
return maxValue != nil, nil
}

switch x := last.Range.Range.(type) {
case *protos.PartitionRange_IntRange:
if maxValue.(int64) > x.IntRange.End {
return true, nil
}
case *protos.PartitionRange_TimestampRange:
if maxValue.(time.Time).After(x.TimestampRange.End.AsTime()) {
return true, nil
}
default:
return false, fmt.Errorf("unknown range type: %v", x)
}

return false, nil
}

func (a *FlowableActivity) RenameTables(ctx context.Context, config *protos.RenameTablesInput) (*protos.RenameTablesOutput, error) {
@@ -1025,16 +1044,16 @@ func (a *FlowableActivity) RemoveTablesFromRawTable(
for _, table := range tablesToRemove {
tableNames = append(tableNames, table.DestinationTableIdentifier)
}
err = dstConn.RemoveTableEntriesFromRawTable(ctx, &protos.RemoveTablesFromRawTableInput{
if err := dstConn.RemoveTableEntriesFromRawTable(ctx, &protos.RemoveTablesFromRawTableInput{
FlowJobName: cfg.FlowJobName,
DestinationTableNames: tableNames,
SyncBatchId: syncBatchID,
NormalizeBatchId: normBatchID,
})
if err != nil {
}); err != nil {
a.Alerter.LogFlowError(ctx, cfg.FlowJobName, err)
return err
}
return err
return nil
}

func (a *FlowableActivity) RemoveTablesFromCatalog(
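Note: with this change the connector returns the raw max value and QRepHasNewRows performs the range comparison itself. A minimal standalone sketch of that comparison using plain Go values instead of the real protos types (names here are illustrative only, not code from this repo):

package main

import (
	"fmt"
	"time"
)

// hasNewRowsInt mirrors the *protos.PartitionRange_IntRange branch above:
// new rows exist only when the source max value is past the end of the
// last replicated partition range.
func hasNewRowsInt(maxValue, lastEnd int64) bool {
	return maxValue > lastEnd
}

// hasNewRowsTimestamp mirrors the timestamp branch in the same way.
func hasNewRowsTimestamp(maxValue, lastEnd time.Time) bool {
	return maxValue.After(lastEnd)
}

func main() {
	fmt.Println(hasNewRowsInt(150, 100)) // true: source has rows beyond the last range
	fmt.Println(hasNewRowsInt(100, 100)) // false: nothing new since the last sync
	now := time.Now()
	fmt.Println(hasNewRowsTimestamp(now.Add(time.Minute), now)) // true
}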
10 changes: 7 additions & 3 deletions flow/activities/snapshot_activity.go
@@ -110,9 +110,13 @@ func (a *SnapshotActivity) MaintainTx(ctx context.Context, sessionID string, pee
}

a.SnapshotStatesMutex.Lock()
a.TxSnapshotStates[sessionID] = TxSnapshotState{
SnapshotName: exportSnapshotOutput.SnapshotName,
SupportsTIDScans: exportSnapshotOutput.SupportsTidScans,
if exportSnapshotOutput != nil {
a.TxSnapshotStates[sessionID] = TxSnapshotState{
SnapshotName: exportSnapshotOutput.SnapshotName,
SupportsTIDScans: exportSnapshotOutput.SupportsTidScans,
}
} else {
a.TxSnapshotStates[sessionID] = TxSnapshotState{}
}
a.SnapshotStatesMutex.Unlock()

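Note: the guard above avoids dereferencing a possibly-nil export output. A small sketch of the same pattern with a hypothetical stand-in type (not the real proto message):

package main

import "fmt"

// exportOutput is a hypothetical stand-in for the snapshot export result;
// callers may receive nil when nothing was exported.
type exportOutput struct {
	SnapshotName     string
	SupportsTidScans bool
}

type txSnapshotState struct {
	SnapshotName     string
	SupportsTIDScans bool
}

// toState dereferences the output only after a nil check, otherwise it stores
// a zero-value state, mirroring the change above.
func toState(out *exportOutput) txSnapshotState {
	if out == nil {
		return txSnapshotState{}
	}
	return txSnapshotState{
		SnapshotName:     out.SnapshotName,
		SupportsTIDScans: out.SupportsTidScans,
	}
}

func main() {
	fmt.Printf("%+v\n", toState(nil))
	fmt.Printf("%+v\n", toState(&exportOutput{SnapshotName: "snap_1", SupportsTidScans: true}))
}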
58 changes: 18 additions & 40 deletions flow/connectors/postgres/qrep.go
@@ -43,12 +43,13 @@ func (c *PostgresConnector) GetQRepPartitions(
) ([]*protos.QRepPartition, error) {
if config.WatermarkColumn == "" {
// if no watermark column is specified, return a single partition
partition := &protos.QRepPartition{
PartitionId: uuid.New().String(),
FullTablePartition: true,
Range: nil,
}
return []*protos.QRepPartition{partition}, nil
return []*protos.QRepPartition{
{
PartitionId: uuid.New().String(),
FullTablePartition: true,
Range: nil,
},
}, nil
}

// begin a transaction
Expand Down Expand Up @@ -135,40 +136,38 @@ func (c *PostgresConnector) getNumRowsPartitions(
partitionsQuery := fmt.Sprintf(
`SELECT bucket, MIN(%[2]s) AS start, MAX(%[2]s) AS end
FROM (
SELECT NTILE(%[1]d) OVER (ORDER BY %[2]s) AS bucket, %[2]s
FROM %[3]s WHERE %[2]s > $1
SELECT NTILE(%[1]d) OVER (ORDER BY %[2]s) AS bucket, %[2]s
FROM %[3]s WHERE %[2]s > $1
) subquery
GROUP BY bucket
ORDER BY start
`,
ORDER BY start`,
numPartitions,
quotedWatermarkColumn,
parsedWatermarkTable.String(),
)
c.logger.Info("[row_based_next] partitions query: " + partitionsQuery)
c.logger.Info("[row_based_next] partitions query", slog.String("query", partitionsQuery))
rows, err = tx.Query(ctx, partitionsQuery, minVal)
} else {
partitionsQuery := fmt.Sprintf(
`SELECT bucket, MIN(%[2]s) AS start, MAX(%[2]s) AS end
FROM (
SELECT NTILE(%[1]d) OVER (ORDER BY %[2]s) AS bucket, %[2]s FROM %[3]s
SELECT NTILE(%[1]d) OVER (ORDER BY %[2]s) AS bucket, %[2]s FROM %[3]s
) subquery
GROUP BY bucket
ORDER BY start
`,
ORDER BY start`,
numPartitions,
quotedWatermarkColumn,
parsedWatermarkTable.String(),
)
c.logger.Info("[row_based] partitions query: " + partitionsQuery)
c.logger.Info("[row_based] partitions query", slog.String("query", partitionsQuery))
rows, err = tx.Query(ctx, partitionsQuery)
}
if err != nil {
return nil, shared.LogError(c.logger, fmt.Errorf("failed to query for partitions: %w", err))
}
defer rows.Close()

partitionHelper := partition_utils.NewPartitionHelper()
partitionHelper := partition_utils.NewPartitionHelper(c.logger)
for rows.Next() {
var bucket pgtype.Int8
var start, end interface{}
@@ -264,40 +263,19 @@ func (c *PostgresConnector) getMinMaxValues(
return minValue, maxValue, nil
}

func (c *PostgresConnector) CheckForUpdatedMaxValue(
func (c *PostgresConnector) GetMaxValue(
ctx context.Context,
config *protos.QRepConfig,
last *protos.QRepPartition,
) (bool, error) {
) (any, error) {
checkTx, err := c.conn.Begin(ctx)
if err != nil {
return false, fmt.Errorf("unable to begin transaction for getting max value: %w", err)
}
defer shared.RollbackTx(checkTx, c.logger)

_, maxValue, err := c.getMinMaxValues(ctx, checkTx, config, last)
if err != nil {
return false, fmt.Errorf("error while getting min and max values: %w", err)
}

if maxValue == nil || last == nil || last.Range == nil {
return maxValue != nil, nil
}

switch x := last.Range.Range.(type) {
case *protos.PartitionRange_IntRange:
if maxValue.(int64) > x.IntRange.End {
return true, nil
}
case *protos.PartitionRange_TimestampRange:
if maxValue.(time.Time).After(x.TimestampRange.End.AsTime()) {
return true, nil
}
default:
return false, fmt.Errorf("unknown range type: %v", x)
}

return false, nil
return maxValue, err
}

func (c *PostgresConnector) PullQRepRecords(
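Note: the partitions query relies on NTILE(n), which splits rows ordered by the watermark column into n roughly equal buckets; the MIN/MAX of each bucket become the partition bounds. A rough in-memory illustration of that bucketing (an assumption-level sketch, not code from this repo):

package main

import "fmt"

// ntileBounds simulates SQL NTILE(n) over an already-sorted slice: rows are
// split into n groups of as-equal-as-possible size, with earlier groups taking
// any remainder. The [min, max] of each group is what the partitions query
// turns into a QRep partition range.
func ntileBounds(sorted []int64, n int) [][2]int64 {
	base, extra := len(sorted)/n, len(sorted)%n
	bounds := make([][2]int64, 0, n)
	idx := 0
	for b := 0; b < n && idx < len(sorted); b++ {
		size := base
		if b < extra {
			size++
		}
		group := sorted[idx : idx+size]
		bounds = append(bounds, [2]int64{group[0], group[len(group)-1]})
		idx += size
	}
	return bounds
}

func main() {
	ids := []int64{1, 2, 3, 5, 8, 13, 21, 34, 55, 89}
	fmt.Println(ntileBounds(ids, 3)) // [[1 5] [8 21] [34 89]]
}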
3 changes: 1 addition & 2 deletions flow/connectors/postgres/qrep_partition_test.go
@@ -86,8 +86,7 @@ func TestGetQRepPartitions(t *testing.T) {
defer conn.Close(context.Background())

//nolint:gosec // Generate a random schema name, number has no cryptographic significance
rndUint := rand.Uint64()
schemaName := fmt.Sprintf("test_%d", rndUint)
schemaName := fmt.Sprintf("test_%d", rand.Uint64())

// Create the schema
_, err = conn.Exec(context.Background(), fmt.Sprintf(`CREATE SCHEMA %s;`, schemaName))
5 changes: 2 additions & 3 deletions flow/connectors/postgres/ssh_wrapped_pool.go
@@ -125,15 +125,14 @@ func (tunnel *SSHTunnel) NewPostgresConnFromConfig(
}

host := connConfig.Host
err = retryWithBackoff(logger, func() error {
if err := retryWithBackoff(logger, func() error {
_, err := conn.Exec(ctx, "SELECT 1")
if err != nil {
logger.Error("Failed to ping pool", slog.Any("error", err), slog.String("host", host))
return err
}
return nil
}, 5, 5*time.Second)
if err != nil {
}, 5, 5*time.Second); err != nil {
logger.Error("Failed to create pool", slog.Any("error", err), slog.String("host", host))
conn.Close(ctx)
return nil, err
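Note: this is the same cleanup idiom as in flowable.go, scoping the error to the if statement. A self-contained sketch with a simplified, hypothetical retry helper standing in for retryWithBackoff (the real helper takes a logger and may differ):

package main

import (
	"errors"
	"fmt"
	"time"
)

// retryWithBackoff here is a hypothetical stand-in: it retries fn up to
// attempts times, sleeping between failures, and returns the last error.
func retryWithBackoff(fn func() error, attempts int, wait time.Duration) error {
	var err error
	for i := 0; i < attempts; i++ {
		if err = fn(); err == nil {
			return nil
		}
		time.Sleep(wait)
	}
	return err
}

func main() {
	tries := 0
	// Scoping err to the if statement keeps it out of the enclosing scope,
	// which is what the diff above changes.
	if err := retryWithBackoff(func() error {
		tries++
		if tries < 3 {
			return errors.New("ping failed")
		}
		return nil
	}, 5, 10*time.Millisecond); err != nil {
		fmt.Println("giving up:", err)
		return
	}
	fmt.Println("succeeded after", tries, "attempts")
}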
100 changes: 46 additions & 54 deletions flow/connectors/sqlserver/qrep.go
@@ -22,7 +22,7 @@ func (c *SQLServerConnector) GetQRepPartitions(
ctx context.Context, config *protos.QRepConfig, last *protos.QRepPartition,
) ([]*protos.QRepPartition, error) {
if config.WatermarkTable == "" {
c.logger.Info("watermark table is empty, doing full table refresh")
// if no watermark column is specified, return a single partition
return []*protos.QRepPartition{
{
PartitionId: uuid.New().String(),
Expand All @@ -46,7 +46,7 @@ func (c *SQLServerConnector) GetQRepPartitions(

// Query to get the total number of rows in the table
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s %s", config.WatermarkTable, whereClause)
var minVal interface{} = nil
var minVal interface{}
var totalRows pgtype.Int8
if last != nil && last.Range != nil {
switch lastRange := last.Range.Range.(type) {
@@ -78,11 +78,8 @@
if err != nil {
return nil, fmt.Errorf("failed to query for total rows: %w", err)
}
} else {
row := c.db.QueryRowContext(ctx, countQuery)
if err = row.Scan(&totalRows); err != nil {
return nil, fmt.Errorf("failed to query for total rows: %w", err)
}
} else if err := c.db.QueryRowContext(ctx, countQuery).Scan(&totalRows); err != nil {
return nil, fmt.Errorf("failed to query for total rows: %w", err)
}

if totalRows.Int64 == 0 {
@@ -101,18 +98,16 @@
if minVal != nil {
// Query to get partitions using window functions
partitionsQuery := fmt.Sprintf(
`SELECT bucket_v, MIN(v_from) AS start_v, MAX(v_from) AS end_v
FROM (
SELECT NTILE(%d) OVER (ORDER BY %s) AS bucket_v, %s as v_from
FROM %s WHERE %s > :minVal
) AS subquery
GROUP BY bucket_v
ORDER BY start_v`,
`SELECT bucket, MIN(%[2]s) AS start_v, MAX(%[2]s) AS end_v
FROM (
SELECT NTILE(%[1]d) OVER (ORDER BY %s) AS bucket, %[2]s
FROM %[3]s WHERE %[2]s > :minVal
) AS subquery
GROUP BY bucket
ORDER BY start_v`,
numPartitions,
quotedWatermarkColumn,
quotedWatermarkColumn,
config.WatermarkTable,
quotedWatermarkColumn,
)
c.logger.Info(fmt.Sprintf("partitions query: %s - minVal: %v", partitionsQuery, minVal))
params := map[string]interface{}{
@@ -121,16 +116,14 @@
rows, err = c.db.NamedQuery(partitionsQuery, params)
} else {
partitionsQuery := fmt.Sprintf(
`SELECT bucket_v, MIN(v_from) AS start_v, MAX(v_from) AS end_v
FROM (
SELECT NTILE(%d) OVER (ORDER BY %s) AS bucket_v, %s as v_from
FROM %s
) AS subquery
GROUP BY bucket_v
ORDER BY start_v`,
`SELECT bucket, MIN(%[2]s) AS start_v, MAX(%[2]s) AS end_v
FROM (
SELECT NTILE(%[1]d) OVER (ORDER BY %[2]s) AS bucket, %[2]s FROM %[3]s
) AS subquery
GROUP BY bucket
ORDER BY start_v`,
numPartitions,
quotedWatermarkColumn,
quotedWatermarkColumn,
config.WatermarkTable,
)
c.logger.Info("partitions query: " + partitionsQuery)
@@ -142,16 +135,15 @@

defer rows.Close()

partitionHelper := utils.NewPartitionHelper()
partitionHelper := utils.NewPartitionHelper(c.logger)
for rows.Next() {
var bucket pgtype.Int8
var start, end interface{}
if err := rows.Scan(&bucket, &start, &end); err != nil {
return nil, fmt.Errorf("failed to scan row: %w", err)
}

err = partitionHelper.AddPartition(start, end)
if err != nil {
if err := partitionHelper.AddPartition(start, end); err != nil {
return nil, fmt.Errorf("failed to add partition: %w", err)
}
}
@@ -162,7 +154,7 @@
func (c *SQLServerConnector) PullQRepRecords(
ctx context.Context,
config *protos.QRepConfig,
partition *protos.QRepPartition,
last *protos.QRepPartition,
stream *model.QRecordStream,
) (int, error) {
// Build the query to pull records within the range from the source table
@@ -172,40 +164,40 @@
return 0, err
}

if partition.FullTablePartition {
var qbatch *model.QRecordBatch
if last.FullTablePartition {
// this is a full table partition, so just run the query
qbatch, err := c.ExecuteAndProcessQuery(ctx, query)
var err error
qbatch, err = c.ExecuteAndProcessQuery(ctx, query)
if err != nil {
return 0, err
}
qbatch.FeedToQRecordStream(stream)
return len(qbatch.Records), nil
}
} else {
var rangeStart interface{}
var rangeEnd interface{}

var rangeStart interface{}
var rangeEnd interface{}

// Depending on the type of the range, convert the range into the correct type
switch x := partition.Range.Range.(type) {
case *protos.PartitionRange_IntRange:
rangeStart = x.IntRange.Start
rangeEnd = x.IntRange.End
case *protos.PartitionRange_TimestampRange:
rangeStart = x.TimestampRange.Start.AsTime()
rangeEnd = x.TimestampRange.End.AsTime()
default:
return 0, fmt.Errorf("unknown range type: %v", x)
}
// Depending on the type of the range, convert the range into the correct type
switch x := last.Range.Range.(type) {
case *protos.PartitionRange_IntRange:
rangeStart = x.IntRange.Start
rangeEnd = x.IntRange.End
case *protos.PartitionRange_TimestampRange:
rangeStart = x.TimestampRange.Start.AsTime()
rangeEnd = x.TimestampRange.End.AsTime()
default:
return 0, fmt.Errorf("unknown range type: %v", x)
}

rangeParams := map[string]interface{}{
"startRange": rangeStart,
"endRange": rangeEnd,
var err error
qbatch, err = c.NamedExecuteAndProcessQuery(ctx, query, map[string]interface{}{
"startRange": rangeStart,
"endRange": rangeEnd,
})
if err != nil {
return 0, err
}
}

qbatch, err := c.NamedExecuteAndProcessQuery(ctx, query, rangeParams)
if err != nil {
return 0, err
}
qbatch.FeedToQRecordStream(stream)
return len(qbatch.Records), nil
}
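Note: the rewritten SQL Server partitions query uses indexed fmt verbs (%[1]d, %[2]s, %[3]s) so the watermark column does not have to be repeated in the argument list. A standalone sketch of how such a template renders, with illustrative table and column names only:

package main

import "fmt"

func main() {
	// Indexed verbs let one argument (the watermark column) appear several
	// times in the template while being passed only once. The values below
	// are hypothetical.
	query := fmt.Sprintf(
		`SELECT bucket, MIN(%[2]s) AS start_v, MAX(%[2]s) AS end_v
FROM (
	SELECT NTILE(%[1]d) OVER (ORDER BY %[2]s) AS bucket, %[2]s FROM %[3]s
) AS subquery
GROUP BY bucket
ORDER BY start_v`,
		4,            // number of partitions
		"updated_at", // watermark column
		"dbo.events", // watermark table
	)
	fmt.Println(query)
}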

