-
Notifications
You must be signed in to change notification settings - Fork 66
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
pqarrow/arrowutils: Add SortRecord and ReorderRecord (#628)
* pqarrow/arrowutils: Add SortRecord and ReorderRecord This is extract from a previous PR #461. * pqarrow/arrowutils: Update SortRecord to allow for multiple sort columns This isn't implemented yet, just the function signature is future proof. * pqarrow/arrowutils: Use compute.Take for ReorderRecord * pqarrow/arrowutils: Add support for sorting NULL NULL always gets sorted to the back. This seems to be the default for other language implementations. It can be made configurable in the future. * Update pqarrow/arrowutils/sort.go Co-authored-by: Geofrey Ernest <[email protected]> * Update pqarrow/arrowutils/sort.go Co-authored-by: Geofrey Ernest <[email protected]> * Update pqarrow/arrowutils/sort.go Co-authored-by: Geofrey Ernest <[email protected]> * Update pqarrow/arrowutils/sort.go Co-authored-by: Geofrey Ernest <[email protected]> * Update pqarrow/arrowutils/sort.go Co-authored-by: Geofrey Ernest <[email protected]> * pqarrow/arrowutils: Remove sorting *array.Binary This isn't properly unit tested and was more of an experiment. * pqarrow/arrowutils: Add context and reserve indices length --------- Co-authored-by: Geofrey Ernest <[email protected]>
- Loading branch information
1 parent
b41edc6
commit ee6970e
Showing
4 changed files
with
180 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
package arrowutils | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"sort" | ||
|
||
"github.com/apache/arrow/go/v14/arrow" | ||
"github.com/apache/arrow/go/v14/arrow/array" | ||
"github.com/apache/arrow/go/v14/arrow/compute" | ||
"github.com/apache/arrow/go/v14/arrow/memory" | ||
) | ||
|
||
// SortRecord sorts the given record's rows by the given column. Currently only supports int64, string and binary columns. | ||
func SortRecord(mem memory.Allocator, r arrow.Record, cols []int) (*array.Int64, error) { | ||
if len(cols) > 1 { | ||
return nil, fmt.Errorf("sorting by multiple columns isn't implemented yet") | ||
} | ||
indicesBuilder := array.NewInt64Builder(mem) | ||
|
||
if r.NumRows() == 0 { | ||
return indicesBuilder.NewInt64Array(), nil | ||
} | ||
if r.NumRows() == 1 { | ||
indicesBuilder.Append(0) | ||
return indicesBuilder.NewInt64Array(), nil | ||
} | ||
|
||
indices := make([]int64, r.NumRows()) | ||
// populate indices | ||
for i := range indices { | ||
indices[i] = int64(i) | ||
} | ||
|
||
switch c := r.Column(cols[0]).(type) { | ||
case *array.Int64: | ||
sort.Sort(orderedSorter[int64]{array: c, indices: indices}) | ||
case *array.String: | ||
sort.Sort(orderedSorter[string]{array: c, indices: indices}) | ||
default: | ||
return nil, fmt.Errorf("unsupported column type for sorting %T", c) | ||
} | ||
|
||
indicesBuilder.Reserve(len(indices)) | ||
for _, i := range indices { | ||
indicesBuilder.Append(i) | ||
} | ||
|
||
return indicesBuilder.NewInt64Array(), nil | ||
} | ||
|
||
// ReorderRecord reorders the given record's rows by the given indices. | ||
// This is a wrapper around compute.Take which handles the type castings. | ||
func ReorderRecord(ctx context.Context, r arrow.Record, indices arrow.Array) (arrow.Record, error) { | ||
res, err := compute.Take( | ||
ctx, | ||
*compute.DefaultTakeOptions(), | ||
compute.NewDatum(r), | ||
compute.NewDatum(indices), | ||
) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return res.(*compute.RecordDatum).Value, nil | ||
} | ||
|
||
type orderedArray[T int64 | float64 | string] interface { | ||
Value(int) T | ||
IsNull(int) bool | ||
Len() int | ||
} | ||
|
||
type orderedSorter[T int64 | float64 | string] struct { | ||
array orderedArray[T] | ||
indices []int64 | ||
} | ||
|
||
func (s orderedSorter[T]) Len() int { | ||
return s.array.Len() | ||
} | ||
|
||
func (s orderedSorter[T]) Less(i, j int) bool { | ||
if s.array.IsNull(int(s.indices[i])) { | ||
return false | ||
} | ||
if s.array.IsNull(int(s.indices[j])) { | ||
return true | ||
} | ||
return s.array.Value(int(s.indices[i])) < s.array.Value(int(s.indices[j])) | ||
} | ||
|
||
func (s orderedSorter[T]) Swap(i, j int) { | ||
s.indices[i], s.indices[j] = s.indices[j], s.indices[i] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
package arrowutils | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
|
||
"github.com/apache/arrow/go/v14/arrow" | ||
"github.com/apache/arrow/go/v14/arrow/array" | ||
"github.com/apache/arrow/go/v14/arrow/memory" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestSortRecord(t *testing.T) { | ||
ctx := context.Background() | ||
schema := arrow.NewSchema( | ||
[]arrow.Field{ | ||
{Name: "int", Type: arrow.PrimitiveTypes.Int64}, | ||
{Name: "string", Type: arrow.BinaryTypes.String}, | ||
}, | ||
nil, | ||
) | ||
|
||
mem := memory.DefaultAllocator | ||
ib := array.NewInt64Builder(mem) | ||
ib.Append(0) | ||
ib.AppendNull() | ||
ib.Append(3) | ||
ib.Append(5) | ||
ib.Append(1) | ||
|
||
sb := array.NewStringBuilder(mem) | ||
sb.Append("d") | ||
sb.Append("c") | ||
sb.Append("b") | ||
sb.AppendNull() | ||
sb.Append("a") | ||
|
||
record := array.NewRecord(schema, []arrow.Array{ib.NewArray(), sb.NewArray()}, int64(5)) | ||
|
||
// Sort the record by the first column - int64 | ||
{ | ||
sortedIndices, err := SortRecord(mem, record, []int{record.Schema().FieldIndices("int")[0]}) | ||
require.NoError(t, err) | ||
require.Equal(t, []int64{0, 4, 2, 3, 1}, sortedIndices.Int64Values()) | ||
|
||
sortedByInts, err := ReorderRecord(ctx, record, sortedIndices) | ||
require.NoError(t, err) | ||
|
||
// check that the column got sortedIndices | ||
intCol := sortedByInts.Column(0).(*array.Int64) | ||
require.Equal(t, []int64{0, 1, 3, 5, 0}, intCol.Int64Values()) | ||
require.True(t, intCol.IsNull(intCol.Len()-1)) // last is NULL | ||
// make sure the other column got updated too | ||
strings := make([]string, sortedByInts.NumRows()) | ||
stringCol := sortedByInts.Column(1).(*array.String) | ||
for i := 0; i < int(sortedByInts.NumRows()); i++ { | ||
strings[i] = stringCol.Value(i) | ||
} | ||
require.Equal(t, []string{"d", "a", "b", "", "c"}, strings) | ||
} | ||
|
||
// Sort the record by the second column - string | ||
{ | ||
sortedIndices, err := SortRecord(mem, record, []int{record.Schema().FieldIndices("string")[0]}) | ||
require.NoError(t, err) | ||
require.Equal(t, []int64{4, 2, 1, 0, 3}, sortedIndices.Int64Values()) | ||
|
||
sortedByStrings, err := ReorderRecord(ctx, record, sortedIndices) | ||
require.NoError(t, err) | ||
|
||
// check that the column got sortedByInts | ||
intCol := sortedByStrings.Column(0).(*array.Int64) | ||
require.Equal(t, []int64{1, 3, 0, 0, 5}, intCol.Int64Values()) | ||
// make sure the other column got updated too | ||
strings := make([]string, sortedByStrings.NumRows()) | ||
stringCol := sortedByStrings.Column(1).(*array.String) | ||
for i := 0; i < int(sortedByStrings.NumRows()); i++ { | ||
strings[i] = stringCol.Value(i) | ||
} | ||
require.Equal(t, []string{"a", "b", "c", "d", ""}, strings) | ||
require.True(t, stringCol.IsNull(stringCol.Len()-1)) // last is NULL | ||
} | ||
} |