Skip to content

Commit eec04a4

Browse files
committed
Add single parameter (array<array<T>>) support for array_intersect
1 parent f93eae6 commit eec04a4

File tree

2 files changed

+380
-6
lines changed

2 files changed

+380
-6
lines changed

velox/functions/prestosql/ArrayIntersectExcept.cpp

+163-6
Original file line number | Diff line number | Diff line change
@@ -138,19 +138,140 @@ void generateSet(
138138
/// Decodes the elements vector of the array vector held by 'arrayDecoder'.
///
/// @param arrayDecoder Decoder that already holds the decoded array vector.
/// @param elementsDecoder Decoder that receives the decoded elements.
/// @param rows Array rows whose elements are needed.
/// @param elementRows Output, must not be null. Overwritten with the element
/// rows computed by toElementRows() for 'rows'; these are the rows the
/// elements vector is decoded for.
/// @return The decoded elements vector held by 'elementsDecoder'.
DecodedVector* decodeArrayElements(
    exec::LocalDecodedVector& arrayDecoder,
    exec::LocalDecodedVector& elementsDecoder,
    const SelectivityVector& rows,
    SelectivityVector* elementRows) {
  // The out-parameter is dereferenced unconditionally below.
  VELOX_CHECK_NOT_NULL(elementRows);

  auto decodedVector = arrayDecoder.get();
  auto baseArrayVector = arrayDecoder->base()->as<ArrayVector>();

  // Decode and acquire array elements vector.
  auto elementsVector = baseArrayVector->elements();
  *elementRows = toElementRows(
      elementsVector->size(), rows, baseArrayVector, decodedVector->indices());
  elementsDecoder.get()->decode(*elementsVector, *elementRows);
  auto decodedElementsVector = elementsDecoder.get();
  return decodedElementsVector;
}
153154

155+
/// Convenience overload for callers that do not need the computed element
/// rows: forwards to the four-argument overload with a throwaway
/// SelectivityVector and returns its result.
DecodedVector* decodeArrayElements(
    exec::LocalDecodedVector& arrayDecoder,
    exec::LocalDecodedVector& elementsDecoder,
    const SelectivityVector& rows) {
  SelectivityVector discardedElementRows;
  DecodedVector* decodedElements = decodeArrayElements(
      arrayDecoder, elementsDecoder, rows, &discardedElementRows);
  return decodedElements;
}
162+
163+
template <typename T>
164+
class ArraysIntersectSingleParam : public exec::VectorFunction {
165+
public:
166+
/// This class is used for array_intersect function with single parameter.
167+
void apply(
168+
const SelectivityVector& rows,
169+
std::vector<VectorPtr>& args,
170+
const TypePtr& outputType,
171+
exec::EvalCtx& context,
172+
VectorPtr& result) const override {
173+
memory::MemoryPool* pool = context.pool();
174+
175+
exec::LocalDecodedVector outerArrayDecoder(context, *args[0], rows);
176+
auto decodedOuterArray = outerArrayDecoder.get();
177+
auto outerArray = decodedOuterArray->base()->as<ArrayVector>();
178+
179+
exec::LocalDecodedVector innerArrayDecoder(context);
180+
SelectivityVector innerRows;
181+
auto decodedInnerArray = decodeArrayElements(
182+
outerArrayDecoder, innerArrayDecoder, rows, &innerRows);
183+
auto innerArray = decodedInnerArray->base()->as<ArrayVector>();
184+
185+
exec::LocalDecodedVector elementDecoder(context);
186+
SelectivityVector elementRows;
187+
auto decodedInnerElement = decodeArrayElements(
188+
innerArrayDecoder, elementDecoder, innerRows, &elementRows);
189+
190+
const auto elementCount =
191+
countElements<ArrayVector>(innerRows, *decodedInnerArray);
192+
const auto rowCount = rows.end();
193+
194+
// Allocate new vectors for indices, nulls, length and offsets.
195+
BufferPtr newIndices = allocateIndices(elementCount, pool);
196+
BufferPtr newElementNulls =
197+
AlignedBuffer::allocate<bool>(elementCount, pool, bits::kNotNull);
198+
BufferPtr newLengths = allocateSizes(rowCount, pool);
199+
BufferPtr newOffsets = allocateOffsets(rowCount, pool);
200+
201+
// Pointers and cursors to the raw data.
202+
auto rawNewIndices = newIndices->asMutable<vector_size_t>();
203+
auto rawNewElementNulls = newElementNulls->asMutable<uint64_t>();
204+
auto rawNewOffsets = newOffsets->asMutable<vector_size_t>();
205+
auto rawNewLengths = newLengths->asMutable<vector_size_t>();
206+
auto indicesCursor = 0;
207+
208+
rows.applyToSelected([&](vector_size_t row) {
209+
rawNewOffsets[row] = indicesCursor;
210+
std::optional<vector_size_t> finalNullIndex;
211+
SetWithNull<T> finalSet;
212+
213+
auto idx = decodedOuterArray->index(row);
214+
auto offset = outerArray->offsetAt(idx);
215+
auto size = outerArray->sizeAt(idx);
216+
for (auto i = offset; i < (offset + size); ++i) {
217+
if (decodedInnerArray->isNullAt(i)) {
218+
continue;
219+
}
220+
auto innerIdx = decodedInnerArray->index(i);
221+
auto innerOffset = innerArray->offsetAt(innerIdx);
222+
auto innerSize = innerArray->sizeAt(innerIdx);
223+
224+
// prepare for next iteration
225+
indicesCursor = rawNewOffsets[row];
226+
SetWithNull<T> intermediateSet;
227+
std::optional<vector_size_t> intermediateNullIndex;
228+
229+
for (auto j = innerOffset; j < (innerOffset + innerSize); ++j) {
230+
// null element
231+
if (decodedInnerElement->isNullAt(j)) {
232+
if ((finalSet.empty() || finalSet.hasNull) &&
233+
!intermediateNullIndex.has_value()) {
234+
intermediateSet.hasNull = true;
235+
intermediateNullIndex = std::optional(indicesCursor++);
236+
}
237+
continue;
238+
}
239+
240+
// regular element
241+
if (finalSet.empty() || finalSet.count(decodedInnerElement, j)) {
242+
auto success = intermediateSet.insert(decodedInnerElement, j);
243+
if (success) {
244+
rawNewIndices[indicesCursor++] = j;
245+
}
246+
}
247+
}
248+
finalSet = intermediateSet;
249+
finalNullIndex = intermediateNullIndex;
250+
rawNewLengths[row] = indicesCursor - rawNewOffsets[row];
251+
if (finalSet.empty()) {
252+
break;
253+
}
254+
}
255+
256+
if (finalNullIndex.has_value()) {
257+
bits::setNull(rawNewElementNulls, finalNullIndex.value(), true);
258+
}
259+
});
260+
261+
auto newElements = BaseVector::wrapInDictionary(
262+
newElementNulls, newIndices, indicesCursor, innerArray->elements());
263+
auto resultArray = std::make_shared<ArrayVector>(
264+
pool,
265+
outputType,
266+
nullptr,
267+
rowCount,
268+
newOffsets,
269+
newLengths,
270+
newElements);
271+
context.moveOrCopyResult(resultArray, rows, result);
272+
}
273+
};
274+
154275
// See documentation at https://prestodb.io/docs/current/functions/array.html
155276
template <bool isIntersect, typename T>
156277
class ArrayIntersectExceptFunction : public exec::VectorFunction {
@@ -211,7 +332,7 @@ class ArrayIntersectExceptFunction : public exec::VectorFunction {
211332

212333
auto leftElementsCount =
213334
countElements<ArrayVector>(rows, *decodedLeftArray);
214-
vector_size_t rowCount = left->size();
335+
vector_size_t rowCount = rows.end();
215336

216337
// Allocate new vectors for indices, nulls, length and offsets.
217338
BufferPtr newIndices = allocateIndices(leftElementsCount, pool);
@@ -414,7 +535,7 @@ class ArraysOverlapFunction : public exec::VectorFunction {
414535
void validateMatchingArrayTypes(
415536
const std::vector<exec::VectorFunctionArg>& inputArgs,
416537
const std::string& name,
417-
vector_size_t expectedArgCount) {
538+
size_t expectedArgCount) {
418539
VELOX_USER_CHECK_EQ(
419540
inputArgs.size(),
420541
expectedArgCount,
@@ -504,10 +625,30 @@ std::shared_ptr<exec::VectorFunction> createTypedArraysIntersectExcept(
504625
}
505626
}
506627

628+
template <TypeKind kind>
629+
std::shared_ptr<exec::VectorFunction> createArraysIntersectSingleParam(
630+
const TypePtr& elementType) {
631+
if (elementType->providesCustomComparison()) {
632+
return std::make_shared<ArraysIntersectSingleParam<WrappedVectorEntry>>();
633+
} else {
634+
using T = std::conditional_t<
635+
TypeTraits<kind>::isPrimitiveType,
636+
typename TypeTraits<kind>::NativeType,
637+
WrappedVectorEntry>;
638+
return std::make_shared<ArraysIntersectSingleParam<T>>();
639+
}
640+
}
641+
507642
std::shared_ptr<exec::VectorFunction> createArrayIntersect(
508643
const std::string& name,
509644
const std::vector<exec::VectorFunctionArg>& inputArgs,
510645
const core::QueryConfig& /*config*/) {
646+
if (inputArgs.size() == 1) {
647+
auto elementType = inputArgs.front().type->childAt(0)->childAt(0);
648+
return VELOX_DYNAMIC_TYPE_DISPATCH(
649+
createArraysIntersectSingleParam, elementType->kind(), elementType);
650+
}
651+
511652
validateMatchingArrayTypes(inputArgs, name, 2);
512653
auto elementType = inputArgs.front().type->childAt(0);
513654

@@ -534,6 +675,22 @@ std::shared_ptr<exec::VectorFunction> createArrayExcept(
534675
elementType);
535676
}
536677

678+
/// Returns the signatures accepted by array_intersect, in this order:
///   array_intersect(array(T), array(T)) -> array(T)
///   array_intersect(array(array(T))) -> array(T)
std::vector<std::shared_ptr<exec::FunctionSignature>>
arrayIntersectSignatures() {
  // Two-argument form.
  auto twoArraysSignature = exec::FunctionSignatureBuilder()
                                .typeVariable("T")
                                .returnType("array(T)")
                                .argumentType("array(T)")
                                .argumentType("array(T)")
                                .build();

  // Single-argument form taking an array of arrays.
  auto nestedArraySignature = exec::FunctionSignatureBuilder()
                                  .typeVariable("T")
                                  .returnType("array(T)")
                                  .argumentType("array(array(T))")
                                  .build();

  return std::vector<std::shared_ptr<exec::FunctionSignature>>{
      twoArraysSignature, nestedArraySignature};
}
693+
537694
std::vector<std::shared_ptr<exec::FunctionSignature>> signatures(
538695
const std::string& returnType) {
539696
return std::vector<std::shared_ptr<exec::FunctionSignature>>{
@@ -600,7 +757,7 @@ VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
600757

601758
// Registers array_intersect with the signatures returned by
// arrayIntersectSignatures(), which replaces the previous
// signatures("array(T)") argument.
VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
    udf_array_intersect,
    arrayIntersectSignatures(),
    createArrayIntersect);
605762

606763
VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(

0 commit comments

Comments
 (0)