Skip to content

Commit c4a0d1b

Browse files
committed
Add array(array(T)) support for array_intersect()
1 parent 74a0db9 commit c4a0d1b

File tree

1 file changed

+107
-5
lines changed

1 file changed

+107
-5
lines changed

velox/functions/prestosql/ArrayIntersectExcept.cpp

+107-5
Original file line numberDiff line numberDiff line change
@@ -190,12 +190,96 @@ class ArrayIntersectExceptFunction : public exec::VectorFunction {
190190
explicit ArrayIntersectExceptFunction(SetWithNull<T> constantSet)
191191
: constantSet_(std::move(constantSet)) {}
192192

193+
void intersectSingle(
194+
const SelectivityVector& rows,
195+
std::vector<VectorPtr>& args,
196+
const TypePtr& outputType,
197+
exec::EvalCtx& context,
198+
VectorPtr& result) const {
199+
memory::MemoryPool* pool = context.pool();
200+
BaseVector* data = args[0].get();
201+
202+
exec::LocalDecodedVector dataDecoder(context, *data, rows);
203+
auto decodedArray = dataDecoder.get();
204+
auto baseArray = decodedArray->base()->as<ArrayVector>();
205+
206+
// Decode and acquire array elements vector.
207+
exec::LocalDecodedVector elementsDecoder(context);
208+
auto decodedElements =
209+
decodeArrayElements(dataDecoder, elementsDecoder, rows);
210+
211+
auto elementsCount = countElements<ArrayVector>(rows, *decodedArray);
212+
vector_size_t rowCount = decodedArray->size();
213+
214+
// Allocate new vectors for indices, nulls, length and offsets.
215+
BufferPtr newIndices = allocateIndices(elementsCount, pool);
216+
BufferPtr newElementNulls =
217+
AlignedBuffer::allocate<bool>(elementsCount, pool, bits::kNotNull);
218+
BufferPtr newLengths = allocateSizes(rowCount, pool);
219+
BufferPtr newOffsets = allocateOffsets(rowCount, pool);
220+
221+
// Pointers and cursors to the raw data.
222+
auto rawNewIndices = newIndices->asMutable<vector_size_t>();
223+
auto rawNewElementNulls = newElementNulls->asMutable<uint64_t>();
224+
auto rawNewLengths = newLengths->asMutable<vector_size_t>();
225+
auto rawNewOffsets = newOffsets->asMutable<vector_size_t>();
226+
227+
vector_size_t indicesCursor = 0;
228+
229+
// Lambda that process each row. This is detached from the code so we can
230+
// apply it differently based on whether the right-hand side set is constant
231+
// or not.
232+
auto processRow = [&](vector_size_t row) {
233+
auto idx = decodedArray->index(row);
234+
auto size = baseArray->sizeAt(idx);
235+
auto offset = baseArray->offsetAt(idx);
236+
237+
for (vector_size_t i = offset; i < (offset + size); ++i) {
238+
if (decodedElements->isNullAt(i)) {
239+
continue;
240+
}
241+
auto arrayVector = decodedElements->base()->as<ArrayVector>();
242+
auto arrIdx = decodedElements->index(i);
243+
auto arrSize = arrayVector->sizeAt(idx);
244+
auto arrOffset = arrayVector->offsetAt(idx);
245+
arrayVector->elements();
246+
247+
exec::LocalDecodedVector arrayDecoder(context, *right, rows);
248+
// Decode and acquire array elements vector.
249+
exec::LocalDecodedVector elementsDecoder(context);
250+
auto decodedRightElements =
251+
decodeArrayElements(arrayDecoder, elementsDecoder, rows);
252+
253+
254+
}
255+
rawNewLengths[row] = indicesCursor - rawNewOffsets[row];
256+
};
257+
258+
rows.applyToSelected([&](vector_size_t row) { processRow(row); });
259+
260+
auto newElements = BaseVector::wrapInDictionary(
261+
newElementNulls, newIndices, indicesCursor, baseArray->elements());
262+
auto resultArray = std::make_shared<ArrayVector>(
263+
pool,
264+
outputType,
265+
nullptr,
266+
rowCount,
267+
newOffsets,
268+
newLengths,
269+
newElements);
270+
context.moveOrCopyResult(resultArray, rows, result);
271+
}
272+
193273
void apply(
194274
const SelectivityVector& rows,
195275
std::vector<VectorPtr>& args,
196276
const TypePtr& outputType,
197277
exec::EvalCtx& context,
198278
VectorPtr& result) const override {
279+
if (isIntersect && args.size() == 1) {
280+
intersectSingle(rows, args, outputType, context, result);
281+
}
282+
199283
memory::MemoryPool* pool = context.pool();
200284
BaseVector* left = args[0].get();
201285
BaseVector* right = args[1].get();
@@ -489,8 +573,7 @@ template <bool isIntersect, TypeKind kind>
489573
std::shared_ptr<exec::VectorFunction> createTypedArraysIntersectExcept(
490574
const std::vector<exec::VectorFunctionArg>& inputArgs,
491575
const TypePtr& elementType) {
492-
VELOX_CHECK_EQ(inputArgs.size(), 2);
493-
const BaseVector* rhs = inputArgs[1].constantValue.get();
576+
const BaseVector* rhs = inputArgs.size() == 1 ? nullptr : inputArgs[1].constantValue.get();
494577

495578
if (elementType->providesCustomComparison()) {
496579
return createTypedArraysIntersectExcept<isIntersect, WrappedVectorEntry>(
@@ -508,8 +591,10 @@ std::shared_ptr<exec::VectorFunction> createArrayIntersect(
508591
const std::string& name,
509592
const std::vector<exec::VectorFunctionArg>& inputArgs,
510593
const core::QueryConfig& /*config*/) {
511-
validateMatchingArrayTypes(inputArgs, name, 2);
512-
auto elementType = inputArgs.front().type->childAt(0);
594+
validateMatchingArrayTypes(inputArgs, name, inputArgs.size());
595+
auto elementType = inputArgs.size() == 1
596+
? inputArgs.front().type->childAt(0)->childAt(0)
597+
: inputArgs.front().type->childAt(0);
513598

514599
return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH(
515600
createTypedArraysIntersectExcept,
@@ -534,6 +619,23 @@ std::shared_ptr<exec::VectorFunction> createArrayExcept(
534619
elementType);
535620
}
536621

622+
/// Returns the signatures accepted by array_intersect, in this order:
///   1. array(T), array(T) -> array(T)
///   2. array(array(T))    -> array(T)
std::vector<std::shared_ptr<exec::FunctionSignature>>
arrayIntersectSignatures() {
  // Two-argument form: intersect a pair of arrays.
  auto pairwise = exec::FunctionSignatureBuilder()
                      .typeVariable("T")
                      .returnType("array(T)")
                      .argumentType("array(T)")
                      .argumentType("array(T)")
                      .build();

  // One-argument form: intersect the inner arrays of a nested array.
  auto nested = exec::FunctionSignatureBuilder()
                    .typeVariable("T")
                    .returnType("array(T)")
                    .argumentType("array(array(T))")
                    .build();

  std::vector<std::shared_ptr<exec::FunctionSignature>> all;
  all.reserve(2);
  all.push_back(std::move(pairwise));
  all.push_back(std::move(nested));
  return all;
}
638+
537639
std::vector<std::shared_ptr<exec::FunctionSignature>> signatures(
538640
const std::string& returnType) {
539641
return std::vector<std::shared_ptr<exec::FunctionSignature>>{
@@ -600,7 +702,7 @@ VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
600702

601703
VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
602704
udf_array_intersect,
603-
signatures("array(T)"),
705+
arrayIntersectSignatures(),
604706
createArrayIntersect);
605707

606708
VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(

0 commit comments

Comments
 (0)