Skip to content

Commit

Permalink
Merge pull request #1095 from dolthub/zachmu/copy
Browse files Browse the repository at this point in the history
Better support for COPY FROM statements
  • Loading branch information
zachmu authored Jan 7, 2025
2 parents 55c7b8f + 9f92726 commit a8d996d
Show file tree
Hide file tree
Showing 17 changed files with 833 additions and 456 deletions.
218 changes: 122 additions & 96 deletions core/dataloader/csvdataloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (

"github.com/dolthub/dolt/go/libraries/doltcore/table"
"github.com/dolthub/go-mysql-server/sql"
"github.com/sirupsen/logrus"

"github.com/dolthub/doltgresql/server/types"
)
Expand All @@ -31,155 +30,182 @@ import (
type CsvDataLoader struct {
results LoadDataResults
partialRecord string
rowInserter sql.RowInserter
nextDataChunk *bufio.Reader
colTypes []*types.DoltgresType
sch sql.Schema
removeHeader bool
delimiter string
}

func (cdl *CsvDataLoader) SetNextDataChunk(ctx *sql.Context, data *bufio.Reader) error {
cdl.nextDataChunk = data
return nil
}

var _ DataLoader = (*CsvDataLoader)(nil)

const defaultCsvDelimiter = ","

// NewCsvDataLoader creates a new DataLoader instance that will insert records from chunks of CSV data into |table|. If
// NewCsvDataLoader creates a new DataLoader instance that will produce rows for the schema provided.
// |header| is true, the first line of the data will be treated as a header and ignored. If |delimiter| is not the empty
// string, it will be used as the delimiter separating value.
func NewCsvDataLoader(ctx *sql.Context, table sql.InsertableTable, delimiter string, header bool) (*CsvDataLoader, error) {
colTypes, err := getColumnTypes(table.Schema())
func NewCsvDataLoader(colNames []string, sch sql.Schema, delimiter string, header bool) (*CsvDataLoader, error) {
colTypes, reducedSch, err := getColumnTypes(colNames, sch)
if err != nil {
return nil, err
}

rowInserter := table.Inserter(ctx)
rowInserter.StatementBegin(ctx)

if delimiter == "" {
delimiter = defaultCsvDelimiter
}

return &CsvDataLoader{
rowInserter: rowInserter,
colTypes: colTypes,
sch: table.Schema(),
sch: reducedSch,
removeHeader: header,
delimiter: delimiter,
}, nil
}

// LoadChunk implements the DataLoader interface
func (cdl *CsvDataLoader) LoadChunk(ctx *sql.Context, data *bufio.Reader) error {
combinedReader := NewStringPrefixReader(cdl.partialRecord, data)
cdl.partialRecord = ""

reader, err := newCsvReaderWithDelimiter(combinedReader, cdl.delimiter)
if err != nil {
return err
// nextRow attempts to read the next row from the data and return it, and returns true if a row was read
func (cdl *CsvDataLoader) nextRow(ctx *sql.Context, reader *csvReader) (sql.Row, bool, error) {
if cdl.removeHeader {
_, err := reader.readLine()
cdl.removeHeader = false
if err != nil {
return nil, false, err
}
}

for {
// Read the next record from the data
if cdl.removeHeader {
_, err := reader.readLine()
cdl.removeHeader = false
if err != nil {
return err
}
record, err := reader.ReadSqlRow()
if err != nil {
if ple, ok := err.(*partialLineError); ok {
cdl.partialRecord = ple.partialLine
return nil, false, nil
}

record, err := reader.ReadSqlRow()
if err != nil {
if ple, ok := err.(*partialLineError); ok {
cdl.partialRecord = ple.partialLine
break
}

// csvReader will return a BadRow error if it encounters an input line without the
// correct number of columns. If we see the end of data marker, then break out of the
// loop and return from this function without returning an error.
if _, ok := err.(*table.BadRow); ok {
if len(record) == 1 && record[0] == "\\." {
break
}
}

if err != io.EOF {
return err
// csvReader will return a BadRow error if it encounters an input line without the
// correct number of columns. If we see the end of data marker, then break out of the
// loop and return from this function without returning an error.
if _, ok := err.(*table.BadRow); ok {
if len(record) == 1 && record[0] == "\\." {
return nil, false, nil
}

recordValues := make([]string, 0, len(record))
for _, v := range record {
recordValues = append(recordValues, fmt.Sprintf("%v", v))
}
cdl.partialRecord = strings.Join(recordValues, ",")
break
}

// If we see the end of data marker, then break out of the loop. Normally this will happen in the code
// above when we receive a BadRow error, since there won't be enough values, but if a table only has
// one column, we won't get a BadRow error, and we'll handle the end of data marker here.
if len(record) == 1 && record[0] == "\\." {
break
}

if len(record) > len(cdl.colTypes) {
return fmt.Errorf("extra data after last expected column")
} else if len(record) < len(cdl.colTypes) {
return fmt.Errorf(`missing data for column "%s"`, cdl.sch[len(record)].Name)
if err != io.EOF {
return nil, false, err
}

// Cast the values using I/O input
row := make(sql.Row, len(cdl.colTypes))
for i := range cdl.colTypes {
if record[i] == nil {
row[i] = nil
} else {
row[i], err = cdl.colTypes[i].IoInput(ctx, fmt.Sprintf("%v", record[i]))
if err != nil {
return err
}
}
recordValues := make([]string, 0, len(record))
for _, v := range record {
recordValues = append(recordValues, fmt.Sprintf("%v", v))
}
cdl.partialRecord = strings.Join(recordValues, ",")
return nil, false, nil
}

// Insert the row
if err = cdl.rowInserter.Insert(ctx, row); err != nil {
return err
}
cdl.results.RowsLoaded += 1
// If we see the end of data marker, then break out of the loop. Normally this will happen in the code
// above when we receive a BadRow error, since there won't be enough values, but if a table only has
// one column, we won't get a BadRow error, and we'll handle the end of data marker here.
if len(record) == 1 && record[0] == "\\." {
return nil, false, nil
}

return nil
}
if len(record) > len(cdl.colTypes) {
return nil, false, fmt.Errorf("extra data after last expected column")
} else if len(record) < len(cdl.colTypes) {
return nil, false, fmt.Errorf(`missing data for column "%s"`, cdl.sch[len(record)].Name)
}

// Abort implements the DataLoader interface
func (cdl *CsvDataLoader) Abort(ctx *sql.Context) error {
defer func() {
if closeErr := cdl.rowInserter.Close(ctx); closeErr != nil {
logrus.Warnf("error closing rowInserter: %v", closeErr)
// Cast the values using I/O input
row := make(sql.Row, len(cdl.colTypes))
for i := range cdl.colTypes {
if record[i] == nil {
row[i] = nil
} else {
row[i], err = cdl.colTypes[i].IoInput(ctx, fmt.Sprintf("%v", record[i]))
if err != nil {
return nil, false, err
}
}
}()
}

return cdl.rowInserter.DiscardChanges(ctx, nil)
return row, true, nil
}

// Finish implements the DataLoader interface
func (cdl *CsvDataLoader) Finish(ctx *sql.Context) (*LoadDataResults, error) {
defer func() {
if closeErr := cdl.rowInserter.Close(ctx); closeErr != nil {
logrus.Warnf("error closing rowInserter: %v", closeErr)
}
}()

// If there is partial data from the last chunk that hasn't been inserted, return an error.
if cdl.partialRecord != "" {
return nil, fmt.Errorf("partial record (%s) found at end of data load", cdl.partialRecord)
}

err := cdl.rowInserter.StatementComplete(ctx)
return &cdl.results, nil
}

func (cdl *CsvDataLoader) Resolved() bool {
return true
}

func (cdl *CsvDataLoader) String() string {
return "CsvDataLoader"
}

func (cdl *CsvDataLoader) Schema() sql.Schema {
return cdl.sch
}

func (cdl *CsvDataLoader) Children() []sql.Node {
return nil
}

func (cdl *CsvDataLoader) WithChildren(children ...sql.Node) (sql.Node, error) {
if len(children) != 0 {
return nil, sql.ErrInvalidChildrenNumber.New(cdl, len(children), 0)
}
return cdl, nil
}

func (cdl *CsvDataLoader) IsReadOnly() bool {
return true
}

type csvRowIter struct {
cdl *CsvDataLoader
reader *csvReader
}

func (c csvRowIter) Next(ctx *sql.Context) (sql.Row, error) {
row, hasNext, err := c.cdl.nextRow(ctx, c.reader)
if err != nil {
err = cdl.rowInserter.DiscardChanges(ctx, err)
return nil, err
}

return &cdl.results, nil
// TODO: this isn't the best way to handle the count of rows, something like a RowUpdateAccumulator would be better
if hasNext {
c.cdl.results.RowsLoaded++
} else {
return nil, io.EOF
}

return row, nil
}

func (c csvRowIter) Close(context *sql.Context) error {
return nil
}

var _ sql.RowIter = (*csvRowIter)(nil)

func (cdl *CsvDataLoader) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error) {
combinedReader := NewStringPrefixReader(cdl.partialRecord, cdl.nextDataChunk)
cdl.partialRecord = ""

csvReader, err := newCsvReaderWithDelimiter(combinedReader, cdl.delimiter)
if err != nil {
return nil, err
}

return &csvRowIter{cdl: cdl, reader: csvReader}, nil
}
40 changes: 24 additions & 16 deletions core/dataloader/dataloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,16 @@ import (
// incomplete records, and saving that partial record until the next call to LoadChunk, so that it may be prefixed
// with the incomplete record.
type DataLoader interface {
// LoadChunk reads the records from |data| and inserts them into the previously configured table. Data records
sql.ExecSourceRel

// SetNextDataChunk sets the next data chunk to be processed by the DataLoader. Data records
// are not guaranteed to start and end cleanly on chunk boundaries, so implementations must recognize incomplete
// records and save them to prepend on the next processed chunk.
LoadChunk(ctx *sql.Context, data *bufio.Reader) error

// Abort aborts the current load operation and releases all used resources.
Abort(ctx *sql.Context) error
SetNextDataChunk(ctx *sql.Context, data *bufio.Reader) error

// Finish finalizes the current load operation and commits the inserted rows so that the data becomes visibile
// to clients. Implementations should check that the last call to LoadChunk did not end with an incomplete
// record and return an error to the caller if so. The returned LoadDataResults describe the load operation,
// including how many rows were inserted.
// Finish finalizes the current load operation and cleans up any resources used. Implementations should check that
// the last call to LoadChunk did not end with an incomplete record and return an error to the caller if so. The
// returned LoadDataResults describe the load operation, including how many rows were inserted.
Finish(ctx *sql.Context) (*LoadDataResults, error)
}

Expand All @@ -49,17 +47,27 @@ type LoadDataResults struct {
RowsLoaded int32
}

// getColumnTypes examines |sch| and returns a slice of DoltgresTypes in the order of the schema's columns. If any
// columns in the schema are not DoltgresType instances, an error is returned.
func getColumnTypes(sch sql.Schema) ([]*types.DoltgresType, error) {
colTypes := make([]*types.DoltgresType, len(sch))
for i, col := range sch {
// getColumnTypes returns the types of the columns in the schema that match the provided column names, in the order
// they are provided. If a subset of column names are provided, the returned types will only contain those columns.
// If the column names are not found in the schema, an error is returned.
func getColumnTypes(colNames []string, sch sql.Schema) ([]*types.DoltgresType, sql.Schema, error) {
colTypes := make([]*types.DoltgresType, len(colNames))
reducedSch := make(sql.Schema, len(colNames))
for i, colName := range colNames {
colIdx := sch.IndexOfColName(colName)
if colIdx < 0 {
// should be impossible
return nil, nil, fmt.Errorf("column %s not found in schema", colName)
}
col := sch[colIdx]
var ok bool
colTypes[i], ok = col.Type.(*types.DoltgresType)
if !ok {
return nil, fmt.Errorf("unsupported column type: name: %s, type: %T", col.Name, col.Type)
return nil, nil, fmt.Errorf("unsupported column type: name: %s, type: %T", col.Name, col.Type)
}

reducedSch[i] = col
}

return colTypes, nil
return colTypes, reducedSch, nil
}
Loading

0 comments on commit a8d996d

Please sign in to comment.