Skip to content

Commit

Permalink
feat(arrow/ipc): add functions to generate payloads
Browse files Browse the repository at this point in the history
  • Loading branch information
zeroshade committed Nov 22, 2024
1 parent fb174ba commit 9786155
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 14 deletions.
49 changes: 35 additions & 14 deletions arrow/ipc/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ func (w *recordEncoder) visit(p *Payload, arr arrow.Array) error {
// non-zero offset: slice the buffer
offset := int64(data.Offset()) * typeWidth
// send padding if available
len := minI64(bitutil.CeilByte64(arrLen*typeWidth), int64(values.Len())-offset)
len := min(bitutil.CeilByte64(arrLen*typeWidth), int64(values.Len())-offset)
values = memory.NewBufferBytes(values.Bytes()[offset : offset+len])
default:
if values != nil {
Expand All @@ -628,7 +628,7 @@ func (w *recordEncoder) visit(p *Payload, arr arrow.Array) error {
// slice data buffer to include the range we need now.
var (
beg int64 = 0
len = minI64(paddedLength(totalDataBytes, kArrowAlignment), int64(totalDataBytes))
len = min(paddedLength(totalDataBytes, kArrowAlignment), int64(totalDataBytes))
)
if arr.Len() > 0 {
beg = arr.ValueOffset64(0)
Expand All @@ -655,7 +655,7 @@ func (w *recordEncoder) visit(p *Payload, arr arrow.Array) error {
// non-zero offset: slice the buffer
offset := data.Offset() * int(typeWidth)
// send padding if available
len := int(minI64(bitutil.CeilByte64(arrLen*typeWidth), int64(values.Len()-offset)))
len := int(min(bitutil.CeilByte64(arrLen*typeWidth), int64(values.Len()-offset)))
values = memory.SliceBuffer(values, offset, len)
default:
if values != nil {
Expand Down Expand Up @@ -1028,7 +1028,7 @@ func (w *recordEncoder) rebaseDenseUnionValueOffsets(arr *array.DenseUnion, offs
} else {
shiftedOffsets[i] = unshiftedOffsets[i] - offsets[c]
}
lengths[c] = maxI32(lengths[c], shiftedOffsets[i]+1)
lengths[c] = max(lengths[c], shiftedOffsets[i]+1)
}
return shiftedOffsetsBuf
}
Expand Down Expand Up @@ -1071,7 +1071,7 @@ func getTruncatedBuffer(offset, length int64, byteWidth int32, buf *memory.Buffe

paddedLen := paddedLength(length*int64(byteWidth), kArrowAlignment)
if offset != 0 || paddedLen < int64(buf.Len()) {
return memory.SliceBuffer(buf, int(offset*int64(byteWidth)), int(minI64(paddedLen, int64(buf.Len()))))
return memory.SliceBuffer(buf, int(offset*int64(byteWidth)), int(min(paddedLen, int64(buf.Len()))))
}
buf.Retain()
return buf
Expand All @@ -1084,16 +1084,37 @@ func needTruncate(offset int64, buf *memory.Buffer, minLength int64) bool {
return offset != 0 || minLength < int64(buf.Len())
}

func minI64(a, b int64) int64 {
if a < b {
return a
// GetRecordBatchPayload produces the ipc payload for a given record batch.
// The resulting payload itself must be released by the caller via the Release
// method after it is no longer needed.
//
// Options are applied via newConfig; the payload is encoded with a fresh
// recordEncoder configured from the resulting cfg (allocator, compression
// codec, parallelism, and minimum space savings).
func GetRecordBatchPayload(batch arrow.Record, opts ...Option) (Payload, error) {
	cfg := newConfig(opts...)
	var (
		data = Payload{msg: MessageRecordBatch}
		enc  = newRecordEncoder(
			cfg.alloc,
			0,                // starting offset
			kMaxNestingDepth, // maximum allowed nesting depth
			true,             // allow 64-bit lengths
			cfg.codec,
			cfg.compressNP,
			cfg.minSpaceSavings,
			make([]compressor, cfg.compressNP),
		)
	)

	// On failure return the zero Payload so callers never see a
	// partially-populated value alongside a non-nil error.
	if err := enc.Encode(&data, batch); err != nil {
		return Payload{}, err
	}

	return data, nil
}

func maxI32(a, b int32) int32 {
if a > b {
return a
}
return b
// GetSchemaPayload produces the ipc payload for a given schema.
func GetSchemaPayload(schema *arrow.Schema, mem memory.Allocator) Payload {
	// Build the dictionary mapper for the schema, then take the first
	// (schema) payload produced for it.
	var m dictutils.Mapper
	m.ImportSchema(schema)

	payloads := payloadFromSchema(schema, mem, &m)
	return payloads[0]
}
78 changes: 78 additions & 0 deletions arrow/ipc/writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"bytes"
"encoding/binary"
"fmt"
"io"
"math"
"strings"
"testing"
Expand Down Expand Up @@ -254,3 +255,80 @@ func TestWriterInferSchema(t *testing.T) {

require.True(t, r.Schema().Equal(rec.Schema()))
}

// testMsgReader is a MessageReader backed by an in-memory slice of
// messages, used to exercise NewReaderFromMessageReader in tests.
type testMsgReader struct {
	// messages holds the messages not yet handed out by Message.
	messages []*Message

	// curmsg is the message most recently returned by Message; it is
	// released when the next message is requested or on Release.
	curmsg *Message
}

// Message releases the previously returned message (if any) and returns
// the next queued message, or io.EOF once the queue is exhausted.
func (r *testMsgReader) Message() (*Message, error) {
	if cur := r.curmsg; cur != nil {
		cur.Release()
		r.curmsg = nil
	}

	if len(r.messages) == 0 {
		return nil, io.EOF
	}

	next := r.messages[0]
	r.messages = r.messages[1:]
	r.curmsg = next
	return next, nil
}

// Release frees the currently outstanding message and any messages that
// were never consumed, leaving the reader empty.
func (r *testMsgReader) Release() {
	if cur := r.curmsg; cur != nil {
		cur.Release()
		r.curmsg = nil
	}

	remaining := r.messages
	r.messages = nil
	for _, msg := range remaining {
		msg.Release()
	}
}

func (r *testMsgReader) Retain() {}

// TestGetPayloads round-trips a record through GetSchemaPayload /
// GetRecordBatchPayload: the payload bodies are serialized, wrapped in
// Messages, and fed back through NewReaderFromMessageReader, verifying
// the reconstructed schema and record match the originals.
func TestGetPayloads(t *testing.T) {
	mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
	defer mem.AssertSize(t, 0)

	schema := arrow.NewSchema([]arrow.Field{
		{Name: "s", Type: arrow.BinaryTypes.String},
	}, nil)

	b := array.NewRecordBuilder(mem, schema)
	defer b.Release()

	b.Field(0).(*array.StringBuilder).AppendValues([]string{"foo", "bar", "baz"}, nil)
	rec := b.NewRecord()
	defer rec.Release()

	schemaPayload := GetSchemaPayload(rec.Schema(), mem)
	defer schemaPayload.Release()
	dataPayload, err := GetRecordBatchPayload(rec, WithAllocator(mem))
	require.NoError(t, err)
	defer dataPayload.Release()

	// SerializeBody returns an error; checking it keeps a serialization
	// failure from surfacing later as a confusing reader error.
	var schemaBuf, dataBuf bytes.Buffer
	require.NoError(t, schemaPayload.SerializeBody(&schemaBuf))
	require.NoError(t, dataPayload.SerializeBody(&dataBuf))

	msgrdr := &testMsgReader{
		messages: []*Message{
			NewMessage(schemaPayload.meta, memory.NewBufferBytes(schemaBuf.Bytes())),
			NewMessage(dataPayload.meta, memory.NewBufferBytes(dataBuf.Bytes())),
		},
	}

	rdr, err := NewReaderFromMessageReader(msgrdr, WithAllocator(mem))
	require.NoError(t, err)
	defer rdr.Release()

	assert.Truef(t, rdr.Schema().Equal(rec.Schema()), "expected: %s\ngot: %s", rec.Schema(), rdr.Schema())
	got, err := rdr.Read()
	require.NoError(t, err)

	assert.Truef(t, array.RecordEqual(rec, got), "expected: %s\ngot: %s", rec, got)
}

0 comments on commit 9786155

Please sign in to comment.