Skip to content

Commit

Permalink
mysql: GetTableSchemaConnector
Browse files Browse the repository at this point in the history
  • Loading branch information
serprex committed Dec 26, 2024
1 parent 2bc2709 commit 9b1bd7f
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 3 deletions.
1 change: 1 addition & 0 deletions flow/connectors/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ var (
_ CDCNormalizeConnector = &connclickhouse.ClickHouseConnector{}

_ GetTableSchemaConnector = &connpostgres.PostgresConnector{}
_ GetTableSchemaConnector = &connmysql.MySqlConnector{}
_ GetTableSchemaConnector = &connsnowflake.SnowflakeConnector{}
_ GetTableSchemaConnector = &connclickhouse.ClickHouseConnector{}

Expand Down
129 changes: 128 additions & 1 deletion flow/connectors/mysql/cdc.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"log/slog"
"time"

"github.com/go-mysql-org/go-mysql/mysql"
Expand All @@ -14,10 +15,12 @@ import (
"go.opentelemetry.io/otel/metric"

"github.com/PeerDB-io/peer-flow/alerting"
"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model"
"github.com/PeerDB-io/peer-flow/model/qvalue"
"github.com/PeerDB-io/peer-flow/otel_metrics"
"github.com/PeerDB-io/peer-flow/peerdbenv"
)

func (c *MySqlConnector) GetTableSchema(
Expand All @@ -26,7 +29,131 @@ func (c *MySqlConnector) GetTableSchema(
system protos.TypeSystem,
tableIdentifiers []string,
) (map[string]*protos.TableSchema, error) {
panic("TODO")
res := make(map[string]*protos.TableSchema, len(tableIdentifiers))
for _, tableName := range tableIdentifiers {
tableSchema, err := c.getTableSchemaForTable(ctx, env, tableName, system)
if err != nil {
c.logger.Info("error fetching schema for table "+tableName, slog.Any("error", err))
return nil, err
}
res[tableName] = tableSchema
c.logger.Info("fetched schema for table " + tableName)
}

return res, nil
}

func (c *MySqlConnector) getTableSchemaForTable(
ctx context.Context,
env map[string]string,
tableName string,
system protos.TypeSystem,
) (*protos.TableSchema, error) {
schemaTable, err := utils.ParseSchemaTable(tableName)
if err != nil {
return nil, err
}

nullableEnabled, err := peerdbenv.PeerDBNullable(ctx, env)
if err != nil {
return nil, err
}

rs, err := c.Execute(ctx, fmt.Sprintf("select * from %s limit 0", schemaTable.String()))
if err != nil {
return nil, err
}
columns := make([]*protos.FieldDescription, 0, len(rs.Values))
primary := make([]string, 0)
for _, field := range rs.Fields {
var qkind qvalue.QValueKind
switch field.Type {
case mysql.MYSQL_TYPE_DECIMAL:
qkind = qvalue.QValueKindNumeric
case mysql.MYSQL_TYPE_TINY:
qkind = qvalue.QValueKindInt16 // TODO qvalue.QValueKindInt8
case mysql.MYSQL_TYPE_SHORT:
qkind = qvalue.QValueKindInt16
case mysql.MYSQL_TYPE_LONG:
qkind = qvalue.QValueKindInt32
case mysql.MYSQL_TYPE_FLOAT:
qkind = qvalue.QValueKindFloat32
case mysql.MYSQL_TYPE_DOUBLE:
qkind = qvalue.QValueKindFloat64
case mysql.MYSQL_TYPE_NULL:
qkind = qvalue.QValueKindInvalid // TODO qvalue.QValueKindNothing
case mysql.MYSQL_TYPE_TIMESTAMP:
qkind = qvalue.QValueKindTimestamp
case mysql.MYSQL_TYPE_LONGLONG:
qkind = qvalue.QValueKindInt64
case mysql.MYSQL_TYPE_INT24:
qkind = qvalue.QValueKindInt32
case mysql.MYSQL_TYPE_DATE:
qkind = qvalue.QValueKindDate
case mysql.MYSQL_TYPE_TIME:
qkind = qvalue.QValueKindTime
case mysql.MYSQL_TYPE_DATETIME:
qkind = qvalue.QValueKindTimestamp
case mysql.MYSQL_TYPE_YEAR:
qkind = qvalue.QValueKindInt16
case mysql.MYSQL_TYPE_NEWDATE:
qkind = qvalue.QValueKindDate
case mysql.MYSQL_TYPE_VARCHAR:
qkind = qvalue.QValueKindString
case mysql.MYSQL_TYPE_BIT:
qkind = qvalue.QValueKindInt64
case mysql.MYSQL_TYPE_TIMESTAMP2:
qkind = qvalue.QValueKindTimestamp
case mysql.MYSQL_TYPE_DATETIME2:
qkind = qvalue.QValueKindTimestamp
case mysql.MYSQL_TYPE_TIME2:
qkind = qvalue.QValueKindTime
case mysql.MYSQL_TYPE_JSON:
qkind = qvalue.QValueKindJSON
case mysql.MYSQL_TYPE_NEWDECIMAL:
qkind = qvalue.QValueKindNumeric
case mysql.MYSQL_TYPE_ENUM:
qkind = qvalue.QValueKindInt64
case mysql.MYSQL_TYPE_SET:
qkind = qvalue.QValueKindInt64
case mysql.MYSQL_TYPE_TINY_BLOB:
qkind = qvalue.QValueKindBytes
case mysql.MYSQL_TYPE_MEDIUM_BLOB:
qkind = qvalue.QValueKindBytes
case mysql.MYSQL_TYPE_LONG_BLOB:
qkind = qvalue.QValueKindBytes
case mysql.MYSQL_TYPE_BLOB:
qkind = qvalue.QValueKindBytes
case mysql.MYSQL_TYPE_VAR_STRING:
qkind = qvalue.QValueKindString
case mysql.MYSQL_TYPE_STRING:
qkind = qvalue.QValueKindString
case mysql.MYSQL_TYPE_GEOMETRY:
qkind = qvalue.QValueKindGeometry
default:
return nil, fmt.Errorf("unknown mysql type %d", field.Type)
}
column := &protos.FieldDescription{
Name: string(field.Name),
Type: string(qkind),
TypeModifier: 0, // TODO numeric precision info
Nullable: (field.Flag & mysql.NOT_NULL_FLAG) == 0,
}
if (field.Flag & mysql.PRI_KEY_FLAG) != 0 {
primary = append(primary, column.Name)

}

Check failure on line 145 in flow/connectors/mysql/cdc.go

View workflow job for this annotation

GitHub Actions / lint

unnecessary trailing newline (whitespace)
columns = append(columns, column)
}

return &protos.TableSchema{
TableIdentifier: tableName,
PrimaryKeyColumns: primary,
IsReplicaIdentityFull: false,
System: system,
NullableEnabled: nullableEnabled,
Columns: columns,
}, nil
}

func (c *MySqlConnector) EnsurePullability(
Expand Down
3 changes: 1 addition & 2 deletions flow/workflows/cdc_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ func GetSideEffect[T any](ctx workflow.Context, f func(workflow.Context) T) T {
})

var result T
err := sideEffect.Get(&result)
if err != nil {
if err := sideEffect.Get(&result); err != nil {
panic(err)
}
return result
Expand Down

0 comments on commit 9b1bd7f

Please sign in to comment.