Skip to content

Commit 5ef71dd

Browse files
committed
Add single parameter (array<array<T>>) support for array_intersect
1 parent f93eae6 commit 5ef71dd

File tree

2 files changed

+377
-4
lines changed

2 files changed

+377
-4
lines changed

velox/functions/prestosql/ArrayIntersectExcept.cpp

+160-4
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ void generateSet(
138138
DecodedVector* decodeArrayElements(
139139
exec::LocalDecodedVector& arrayDecoder,
140140
exec::LocalDecodedVector& elementsDecoder,
141-
const SelectivityVector& rows) {
141+
const SelectivityVector& rows,
142+
SelectivityVector* elementRows = nullptr) {
142143
auto decodedVector = arrayDecoder.get();
143144
auto baseArrayVector = arrayDecoder->base()->as<ArrayVector>();
144145

@@ -147,10 +148,125 @@ DecodedVector* decodeArrayElements(
147148
auto elementsSelectivityRows = toElementRows(
148149
elementsVector->size(), rows, baseArrayVector, decodedVector->indices());
149150
elementsDecoder.get()->decode(*elementsVector, elementsSelectivityRows);
151+
if (elementRows != nullptr) {
152+
*elementRows = elementsSelectivityRows;
153+
}
150154
auto decodedElementsVector = elementsDecoder.get();
151155
return decodedElementsVector;
152156
}
153157

158+
template <typename T>
159+
class ArraysIntersectSingleParam : public exec::VectorFunction {
160+
public:
161+
/// This class is used for array_intersect function with single parameter.
162+
void apply(
163+
const SelectivityVector& rows,
164+
std::vector<VectorPtr>& args,
165+
const TypePtr& outputType,
166+
exec::EvalCtx& context,
167+
VectorPtr& result) const {
168+
memory::MemoryPool* pool = context.pool();
169+
170+
exec::LocalDecodedVector outerArrayDecoder(context, *args[0], rows);
171+
auto decodedOuterArray = outerArrayDecoder.get();
172+
auto outerArray = decodedOuterArray->base()->as<ArrayVector>();
173+
174+
exec::LocalDecodedVector innerArrayDecoder(context);
175+
SelectivityVector innerRows;
176+
auto decodedInnerArray = decodeArrayElements(
177+
outerArrayDecoder, innerArrayDecoder, rows, &innerRows);
178+
auto innerArray = decodedInnerArray->base()->as<ArrayVector>();
179+
180+
exec::LocalDecodedVector elementDecoder(context);
181+
SelectivityVector elementRows;
182+
auto decodedInnerElement = decodeArrayElements(
183+
innerArrayDecoder, elementDecoder, innerRows, &elementRows);
184+
185+
const auto elementCount =
186+
countElements<ArrayVector>(innerRows, *decodedInnerArray);
187+
const auto rowCount = rows.end();
188+
189+
// Allocate new vectors for indices, nulls, length and offsets.
190+
BufferPtr newIndices = allocateIndices(elementCount, pool);
191+
BufferPtr newElementNulls =
192+
AlignedBuffer::allocate<bool>(elementCount, pool, bits::kNotNull);
193+
BufferPtr newLengths = allocateSizes(rowCount, pool);
194+
BufferPtr newOffsets = allocateOffsets(rowCount, pool);
195+
196+
// Pointers and cursors to the raw data.
197+
auto rawNewIndices = newIndices->asMutable<vector_size_t>();
198+
auto rawNewElementNulls = newElementNulls->asMutable<uint64_t>();
199+
auto rawNewOffsets = newOffsets->asMutable<vector_size_t>();
200+
auto rawNewLengths = newLengths->asMutable<vector_size_t>();
201+
auto indicesCursor = 0;
202+
203+
rows.applyToSelected([&](vector_size_t row) {
204+
rawNewOffsets[row] = indicesCursor;
205+
std::optional<vector_size_t> finalNullIndex;
206+
SetWithNull<T> finalSet;
207+
208+
auto idx = decodedOuterArray->index(row);
209+
auto offset = outerArray->offsetAt(idx);
210+
auto size = outerArray->sizeAt(idx);
211+
for (auto i = offset; i < (offset + size); ++i) {
212+
if (decodedInnerArray->isNullAt(i)) {
213+
continue;
214+
}
215+
auto innerIdx = decodedInnerArray->index(i);
216+
auto innerOffset = innerArray->offsetAt(innerIdx);
217+
auto innerSize = innerArray->sizeAt(innerIdx);
218+
219+
// prepare for next iteration
220+
indicesCursor = rawNewOffsets[row];
221+
SetWithNull<T> intermediateSet;
222+
std::optional<vector_size_t> intermediateNullIndex;
223+
224+
for (auto j = innerOffset; j < (innerOffset + innerSize); ++j) {
225+
// null element
226+
if (decodedInnerElement->isNullAt(j)) {
227+
if ((finalSet.empty() || finalSet.hasNull) &&
228+
!intermediateNullIndex.has_value()) {
229+
intermediateSet.hasNull = true;
230+
intermediateNullIndex = std::optional(indicesCursor++);
231+
}
232+
continue;
233+
}
234+
235+
// regular element
236+
if (finalSet.empty() || finalSet.count(decodedInnerElement, j)) {
237+
auto success = intermediateSet.insert(decodedInnerElement, j);
238+
if (success) {
239+
rawNewIndices[indicesCursor++] = j;
240+
}
241+
}
242+
}
243+
finalSet = intermediateSet;
244+
finalNullIndex = intermediateNullIndex;
245+
rawNewLengths[row] = indicesCursor - rawNewOffsets[row];
246+
if (finalSet.empty()) {
247+
break;
248+
}
249+
}
250+
251+
if (finalNullIndex.has_value()) {
252+
bits::setNull(rawNewElementNulls, finalNullIndex.value(), true);
253+
}
254+
});
255+
256+
auto newElements = BaseVector::wrapInDictionary(
257+
newElementNulls, newIndices, indicesCursor, innerArray->elements());
258+
auto resultArray = std::make_shared<ArrayVector>(
259+
pool,
260+
outputType,
261+
nullptr,
262+
rowCount,
263+
newOffsets,
264+
newLengths,
265+
newElements);
266+
context.moveOrCopyResult(resultArray, rows, result);
267+
}
268+
};
269+
154270
// See documentation at https://prestodb.io/docs/current/functions/array.html
155271
template <bool isIntersect, typename T>
156272
class ArrayIntersectExceptFunction : public exec::VectorFunction {
@@ -211,7 +327,7 @@ class ArrayIntersectExceptFunction : public exec::VectorFunction {
211327

212328
auto leftElementsCount =
213329
countElements<ArrayVector>(rows, *decodedLeftArray);
214-
vector_size_t rowCount = left->size();
330+
vector_size_t rowCount = rows.end();
215331

216332
// Allocate new vectors for indices, nulls, length and offsets.
217333
BufferPtr newIndices = allocateIndices(leftElementsCount, pool);
@@ -414,7 +530,7 @@ class ArraysOverlapFunction : public exec::VectorFunction {
414530
void validateMatchingArrayTypes(
415531
const std::vector<exec::VectorFunctionArg>& inputArgs,
416532
const std::string& name,
417-
vector_size_t expectedArgCount) {
533+
size_t expectedArgCount) {
418534
VELOX_USER_CHECK_EQ(
419535
inputArgs.size(),
420536
expectedArgCount,
@@ -504,10 +620,34 @@ std::shared_ptr<exec::VectorFunction> createTypedArraysIntersectExcept(
504620
}
505621
}
506622

623+
template <TypeKind kind>
624+
std::shared_ptr<exec::VectorFunction> createArraysIntersectSingleParam(
625+
const std::vector<exec::VectorFunctionArg>& inputArgs,
626+
const TypePtr& elementType) {
627+
if (elementType->providesCustomComparison()) {
628+
return std::make_shared<ArraysIntersectSingleParam<WrappedVectorEntry>>();
629+
} else {
630+
using T = std::conditional_t<
631+
TypeTraits<kind>::isPrimitiveType,
632+
typename TypeTraits<kind>::NativeType,
633+
WrappedVectorEntry>;
634+
return std::make_shared<ArraysIntersectSingleParam<T>>();
635+
}
636+
}
637+
507638
std::shared_ptr<exec::VectorFunction> createArrayIntersect(
508639
const std::string& name,
509640
const std::vector<exec::VectorFunctionArg>& inputArgs,
510641
const core::QueryConfig& /*config*/) {
642+
if (inputArgs.size() == 1) {
643+
auto elementType = inputArgs.front().type->childAt(0)->childAt(0);
644+
return VELOX_DYNAMIC_TYPE_DISPATCH(
645+
createArraysIntersectSingleParam,
646+
elementType->kind(),
647+
inputArgs,
648+
elementType);
649+
}
650+
511651
validateMatchingArrayTypes(inputArgs, name, 2);
512652
auto elementType = inputArgs.front().type->childAt(0);
513653

@@ -534,6 +674,22 @@ std::shared_ptr<exec::VectorFunction> createArrayExcept(
534674
elementType);
535675
}
536676

677+
std::vector<std::shared_ptr<exec::FunctionSignature>>
678+
arrayIntersectSignatures() {
679+
return std::vector<std::shared_ptr<exec::FunctionSignature>>{
680+
exec::FunctionSignatureBuilder()
681+
.typeVariable("T")
682+
.returnType("array(T)")
683+
.argumentType("array(T)")
684+
.argumentType("array(T)")
685+
.build(),
686+
exec::FunctionSignatureBuilder()
687+
.typeVariable("T")
688+
.returnType("array(T)")
689+
.argumentType("array(array(T))")
690+
.build()};
691+
}
692+
537693
std::vector<std::shared_ptr<exec::FunctionSignature>> signatures(
538694
const std::string& returnType) {
539695
return std::vector<std::shared_ptr<exec::FunctionSignature>>{
@@ -600,7 +756,7 @@ VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
600756

601757
VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
602758
udf_array_intersect,
603-
signatures("array(T)"),
759+
arrayIntersectSignatures(),
604760
createArrayIntersect);
605761

606762
VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(

0 commit comments

Comments
 (0)