Skip to content

Commit

Permalink
Storage write api - support default stream (#226)
Browse files Browse the repository at this point in the history
* storage write api support default stream

* commit in case of committed stream

* Storage write api - support default stream: add test cases

* empty commit

---------

Co-authored-by: Matan Levy <[email protected]>
  • Loading branch information
MatanLevy and Matan Levy authored Jun 12, 2024
1 parent 316038b commit 76384ba
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 13 deletions.
62 changes: 59 additions & 3 deletions server/storage_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,6 @@ func (s *storageWriteServer) CreateWriteStream(ctx context.Context, req *storage
TableSchema: schema,
WriteMode: storagepb.WriteStream_INSERT,
}

s.mu.Lock()
s.streamMap[streamName] = &writeStreamStatus{
streamType: streamType,
Expand Down Expand Up @@ -525,6 +524,7 @@ func (s *storageWriteServer) appendRows(req *storagepb.AppendRowsRequest, msgDes
status.rows = append(status.rows, data...)
}
return s.sendResult(stream, streamName, offset+int64(len(rows)))

}

func (s *storageWriteServer) sendResult(stream storagepb.BigQueryWrite_AppendRowsServer, streamName string, offset int64) error {
Expand Down Expand Up @@ -677,10 +677,14 @@ func (s *storageWriteServer) insertTableData(ctx context.Context, tx *connection

// GetWriteStream returns the write stream registered under req.Name.
//
// When no stream is registered under that name, the request is treated as a
// lookup of a table's implicit "_default" stream and is delegated to
// createDefaultStream, which creates and registers it on demand. If that
// fallback fails too, an error mentioning req.Name (and wrapping the
// underlying cause) is returned.
func (s *storageWriteServer) GetWriteStream(ctx context.Context, req *storagepb.GetWriteStreamRequest) (*storagepb.WriteStream, error) {
	// The read lock is released explicitly (not deferred) before the fallback:
	// createDefaultStream acquires the write lock on the same mutex, so holding
	// the read lock across that call would block forever.
	s.mu.RLock()
	status, exists := s.streamMap[req.Name]
	s.mu.RUnlock()
	if exists {
		return status.stream, nil
	}

	stream, err := s.createDefaultStream(ctx, req)
	if err != nil {
		return nil, fmt.Errorf("failed to find stream from %s: %w", req.Name, err)
	}
	return stream, nil
}
Expand Down Expand Up @@ -775,6 +779,58 @@ func (s *storageWriteServer) FlushRows(ctx context.Context, req *storagepb.Flush
}, nil
}

// createDefaultStream creates and registers the implicit "_default" write
// stream of a table.
//
// According to the Google documentation
// (https://pkg.go.dev/cloud.google.com/go/bigquery/storage/apiv1#BigQueryWriteClient.GetWriteStream)
// every table has a special stream named "_default" to which data can be
// written; this stream does not need to be created using CreateWriteStream.
// This method therefore creates a COMMITTED stream for the table on demand and
// stores it in the stream map under the requested name.
//
// req.Name must have the form
//
//	projects/<projectId>/datasets/<datasetId>/tables/<tableId>/streams/_default
//
// It returns the write stream registered under req.Name, or an error if the
// name is not a default-stream name, or if stream creation, path parsing or
// the table metadata lookup fails.
func (s *storageWriteServer) createDefaultStream(ctx context.Context, req *storagepb.GetWriteStreamRequest) (*storagepb.WriteStream, error) {
	const defaultStreamSuffix = "/streams/_default"

	streamID := req.Name
	// Require the exact "/streams/_default" segment: a bare "_default" suffix
	// check would also accept names such as ".../streams/foo_default".
	if !strings.HasSuffix(streamID, defaultStreamSuffix) {
		return nil, fmt.Errorf("unexpected stream id: %s, expected '%s' suffix", streamID, defaultStreamSuffix)
	}
	tablePath := strings.TrimSuffix(streamID, defaultStreamSuffix)

	stream, err := s.CreateWriteStream(ctx, &storagepb.CreateWriteStreamRequest{
		Parent: tablePath,
		WriteStream: &storagepb.WriteStream{
			Type: storagepb.WriteStream_COMMITTED,
		},
	})
	if err != nil {
		return nil, err
	}
	projectID, datasetID, tableID, err := getIDsFromPath(tablePath)
	if err != nil {
		return nil, err
	}
	tableMetadata, err := getTableMetadata(ctx, s.server, projectID, datasetID, tableID)
	if err != nil {
		return nil, err
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	// The caller looked the name up before this method took the write lock, so
	// a concurrent request may have registered the default stream in between.
	// Keep the existing entry instead of overwriting it.
	if existing, ok := s.streamMap[streamID]; ok {
		return existing.stream, nil
	}
	s.streamMap[streamID] = &writeStreamStatus{
		streamType:    storagepb.WriteStream_COMMITTED,
		stream:        stream,
		projectID:     projectID,
		datasetID:     datasetID,
		tableID:       tableID,
		tableMetadata: tableMetadata,
	}
	return stream, nil
}

func getIDsFromPath(path string) (string, string, string, error) {
paths := strings.Split(path, "/")
if len(paths)%2 != 0 {
Expand Down
42 changes: 32 additions & 10 deletions server/storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ func TestStorageWrite(t *testing.T) {
for _, test := range []struct {
name string
streamType storagepb.WriteStream_Type
isDefaultStream bool
expectedRowsAfterFirstWrite int
expectedRowsAfterSecondWrite int
expectedRowsAfterThirdWrite int
Expand All @@ -416,6 +417,15 @@ func TestStorageWrite(t *testing.T) {
expectedRowsAfterThirdWrite: 6,
expectedRowsAfterExplicitCommit: 6,
},
{
name: "default",
streamType: storagepb.WriteStream_COMMITTED,
isDefaultStream: true,
expectedRowsAfterFirstWrite: 1,
expectedRowsAfterSecondWrite: 4,
expectedRowsAfterThirdWrite: 6,
expectedRowsAfterExplicitCommit: 6,
},
} {
const (
projectID = "test"
Expand Down Expand Up @@ -490,24 +500,36 @@ func TestStorageWrite(t *testing.T) {
}
defer client.Close()
t.Run(test.name, func(t *testing.T) {
writeStream, err := client.CreateWriteStream(ctx, &storagepb.CreateWriteStreamRequest{
Parent: fmt.Sprintf("projects/%s/datasets/%s/tables/%s", projectID, datasetID, tableID),
WriteStream: &storagepb.WriteStream{
Type: test.streamType,
},
})
if err != nil {
t.Fatalf("CreateWriteStream: %v", err)
var writeStreamName string
fullTableName := fmt.Sprintf("projects/%s/datasets/%s/tables/%s", projectID, datasetID, tableID)
if !test.isDefaultStream {
writeStream, err := client.CreateWriteStream(ctx, &storagepb.CreateWriteStreamRequest{
Parent: fullTableName,
WriteStream: &storagepb.WriteStream{
Type: test.streamType,
},
})
if err != nil {
t.Fatalf("CreateWriteStream: %v", err)
}
writeStreamName = writeStream.GetName()
}
m := &exampleproto.SampleData{}
descriptorProto, err := adapt.NormalizeDescriptor(m.ProtoReflect().Descriptor())
if err != nil {
t.Fatalf("NormalizeDescriptor: %v", err)
}
var writerOptions []managedwriter.WriterOption
if test.isDefaultStream {
writerOptions = append(writerOptions, managedwriter.WithType(managedwriter.DefaultStream))
writerOptions = append(writerOptions, managedwriter.WithDestinationTable(fullTableName))
} else {
writerOptions = append(writerOptions, managedwriter.WithStreamName(writeStreamName))
}
writerOptions = append(writerOptions, managedwriter.WithSchemaDescriptor(descriptorProto))
managedStream, err := client.NewManagedStream(
ctx,
managedwriter.WithStreamName(writeStream.GetName()),
managedwriter.WithSchemaDescriptor(descriptorProto),
writerOptions...,
)
if err != nil {
t.Fatalf("NewManagedStream: %v", err)
Expand Down

0 comments on commit 76384ba

Please sign in to comment.