@@ -138,7 +138,8 @@ void generateSet(
138
138
DecodedVector* decodeArrayElements (
139
139
exec::LocalDecodedVector& arrayDecoder,
140
140
exec::LocalDecodedVector& elementsDecoder,
141
- const SelectivityVector& rows) {
141
+ const SelectivityVector& rows,
142
+ SelectivityVector* elementRows = nullptr ) {
142
143
auto decodedVector = arrayDecoder.get ();
143
144
auto baseArrayVector = arrayDecoder->base ()->as <ArrayVector>();
144
145
@@ -147,10 +148,125 @@ DecodedVector* decodeArrayElements(
147
148
auto elementsSelectivityRows = toElementRows (
148
149
elementsVector->size (), rows, baseArrayVector, decodedVector->indices ());
149
150
elementsDecoder.get ()->decode (*elementsVector, elementsSelectivityRows);
151
+ if (elementRows != nullptr ) {
152
+ *elementRows = elementsSelectivityRows;
153
+ }
150
154
auto decodedElementsVector = elementsDecoder.get ();
151
155
return decodedElementsVector;
152
156
}
153
157
158
+ template <typename T>
159
+ class ArraysIntersectSingleParam : public exec ::VectorFunction {
160
+ public:
161
+ // / This class is used for array_intersect function with single parameter.
162
+ void apply (
163
+ const SelectivityVector& rows,
164
+ std::vector<VectorPtr>& args,
165
+ const TypePtr& outputType,
166
+ exec::EvalCtx& context,
167
+ VectorPtr& result) const {
168
+ memory::MemoryPool* pool = context.pool ();
169
+
170
+ exec::LocalDecodedVector outerArrayDecoder (context, *args[0 ], rows);
171
+ auto decodedOuterArray = outerArrayDecoder.get ();
172
+ auto outerArray = decodedOuterArray->base ()->as <ArrayVector>();
173
+
174
+ exec::LocalDecodedVector innerArrayDecoder (context);
175
+ SelectivityVector innerRows;
176
+ auto decodedInnerArray = decodeArrayElements (
177
+ outerArrayDecoder, innerArrayDecoder, rows, &innerRows);
178
+ auto innerArray = decodedInnerArray->base ()->as <ArrayVector>();
179
+
180
+ exec::LocalDecodedVector elementDecoder (context);
181
+ SelectivityVector elementRows;
182
+ auto decodedInnerElement = decodeArrayElements (
183
+ innerArrayDecoder, elementDecoder, innerRows, &elementRows);
184
+
185
+ const auto elementCount =
186
+ countElements<ArrayVector>(innerRows, *decodedInnerArray);
187
+ const auto rowCount = rows.end ();
188
+
189
+ // Allocate new vectors for indices, nulls, length and offsets.
190
+ BufferPtr newIndices = allocateIndices (elementCount, pool);
191
+ BufferPtr newElementNulls =
192
+ AlignedBuffer::allocate<bool >(elementCount, pool, bits::kNotNull );
193
+ BufferPtr newLengths = allocateSizes (rowCount, pool);
194
+ BufferPtr newOffsets = allocateOffsets (rowCount, pool);
195
+
196
+ // Pointers and cursors to the raw data.
197
+ auto rawNewIndices = newIndices->asMutable <vector_size_t >();
198
+ auto rawNewElementNulls = newElementNulls->asMutable <uint64_t >();
199
+ auto rawNewOffsets = newOffsets->asMutable <vector_size_t >();
200
+ auto rawNewLengths = newLengths->asMutable <vector_size_t >();
201
+ auto indicesCursor = 0 ;
202
+
203
+ rows.applyToSelected ([&](vector_size_t row) {
204
+ rawNewOffsets[row] = indicesCursor;
205
+ std::optional<vector_size_t > finalNullIndex;
206
+ SetWithNull<T> finalSet;
207
+
208
+ auto idx = decodedOuterArray->index (row);
209
+ auto offset = outerArray->offsetAt (idx);
210
+ auto size = outerArray->sizeAt (idx);
211
+ for (auto i = offset; i < (offset + size); ++i) {
212
+ if (decodedInnerArray->isNullAt (i)) {
213
+ continue ;
214
+ }
215
+ auto innerIdx = decodedInnerArray->index (i);
216
+ auto innerOffset = innerArray->offsetAt (innerIdx);
217
+ auto innerSize = innerArray->sizeAt (innerIdx);
218
+
219
+ // prepare for next iteration
220
+ indicesCursor = rawNewOffsets[row];
221
+ SetWithNull<T> intermediateSet;
222
+ std::optional<vector_size_t > intermediateNullIndex;
223
+
224
+ for (auto j = innerOffset; j < (innerOffset + innerSize); ++j) {
225
+ // null element
226
+ if (decodedInnerElement->isNullAt (j)) {
227
+ if ((finalSet.empty () || finalSet.hasNull ) &&
228
+ !intermediateNullIndex.has_value ()) {
229
+ intermediateSet.hasNull = true ;
230
+ intermediateNullIndex = std::optional (indicesCursor++);
231
+ }
232
+ continue ;
233
+ }
234
+
235
+ // regular element
236
+ if (finalSet.empty () || finalSet.count (decodedInnerElement, j)) {
237
+ auto success = intermediateSet.insert (decodedInnerElement, j);
238
+ if (success) {
239
+ rawNewIndices[indicesCursor++] = j;
240
+ }
241
+ }
242
+ }
243
+ finalSet = intermediateSet;
244
+ finalNullIndex = intermediateNullIndex;
245
+ rawNewLengths[row] = indicesCursor - rawNewOffsets[row];
246
+ if (finalSet.empty ()) {
247
+ break ;
248
+ }
249
+ }
250
+
251
+ if (finalNullIndex.has_value ()) {
252
+ bits::setNull (rawNewElementNulls, finalNullIndex.value (), true );
253
+ }
254
+ });
255
+
256
+ auto newElements = BaseVector::wrapInDictionary (
257
+ newElementNulls, newIndices, indicesCursor, innerArray->elements ());
258
+ auto resultArray = std::make_shared<ArrayVector>(
259
+ pool,
260
+ outputType,
261
+ nullptr ,
262
+ rowCount,
263
+ newOffsets,
264
+ newLengths,
265
+ newElements);
266
+ context.moveOrCopyResult (resultArray, rows, result);
267
+ }
268
+ };
269
+
154
270
// See documentation at https://prestodb.io/docs/current/functions/array.html
155
271
template <bool isIntersect, typename T>
156
272
class ArrayIntersectExceptFunction : public exec ::VectorFunction {
@@ -211,7 +327,7 @@ class ArrayIntersectExceptFunction : public exec::VectorFunction {
211
327
212
328
auto leftElementsCount =
213
329
countElements<ArrayVector>(rows, *decodedLeftArray);
214
- vector_size_t rowCount = left-> size ();
330
+ vector_size_t rowCount = rows. end ();
215
331
216
332
// Allocate new vectors for indices, nulls, length and offsets.
217
333
BufferPtr newIndices = allocateIndices (leftElementsCount, pool);
@@ -414,7 +530,7 @@ class ArraysOverlapFunction : public exec::VectorFunction {
414
530
void validateMatchingArrayTypes (
415
531
const std::vector<exec::VectorFunctionArg>& inputArgs,
416
532
const std::string& name,
417
- vector_size_t expectedArgCount) {
533
+ size_t expectedArgCount) {
418
534
VELOX_USER_CHECK_EQ (
419
535
inputArgs.size (),
420
536
expectedArgCount,
@@ -504,10 +620,34 @@ std::shared_ptr<exec::VectorFunction> createTypedArraysIntersectExcept(
504
620
}
505
621
}
506
622
623
+ template <TypeKind kind>
624
+ std::shared_ptr<exec::VectorFunction> createArraysIntersectSingleParam (
625
+ const std::vector<exec::VectorFunctionArg>& inputArgs,
626
+ const TypePtr& elementType) {
627
+ if (elementType->providesCustomComparison ()) {
628
+ return std::make_shared<ArraysIntersectSingleParam<WrappedVectorEntry>>();
629
+ } else {
630
+ using T = std::conditional_t <
631
+ TypeTraits<kind>::isPrimitiveType,
632
+ typename TypeTraits<kind>::NativeType,
633
+ WrappedVectorEntry>;
634
+ return std::make_shared<ArraysIntersectSingleParam<T>>();
635
+ }
636
+ }
637
+
507
638
std::shared_ptr<exec::VectorFunction> createArrayIntersect (
508
639
const std::string& name,
509
640
const std::vector<exec::VectorFunctionArg>& inputArgs,
510
641
const core::QueryConfig& /* config*/ ) {
642
+ if (inputArgs.size () == 1 ) {
643
+ auto elementType = inputArgs.front ().type ->childAt (0 )->childAt (0 );
644
+ return VELOX_DYNAMIC_TYPE_DISPATCH (
645
+ createArraysIntersectSingleParam,
646
+ elementType->kind (),
647
+ inputArgs,
648
+ elementType);
649
+ }
650
+
511
651
validateMatchingArrayTypes (inputArgs, name, 2 );
512
652
auto elementType = inputArgs.front ().type ->childAt (0 );
513
653
@@ -534,6 +674,22 @@ std::shared_ptr<exec::VectorFunction> createArrayExcept(
534
674
elementType);
535
675
}
536
676
677
+ std::vector<std::shared_ptr<exec::FunctionSignature>>
678
+ arrayIntersectSignatures () {
679
+ return std::vector<std::shared_ptr<exec::FunctionSignature>>{
680
+ exec::FunctionSignatureBuilder ()
681
+ .typeVariable (" T" )
682
+ .returnType (" array(T)" )
683
+ .argumentType (" array(T)" )
684
+ .argumentType (" array(T)" )
685
+ .build (),
686
+ exec::FunctionSignatureBuilder ()
687
+ .typeVariable (" T" )
688
+ .returnType (" array(T)" )
689
+ .argumentType (" array(array(T))" )
690
+ .build ()};
691
+ }
692
+
537
693
std::vector<std::shared_ptr<exec::FunctionSignature>> signatures (
538
694
const std::string& returnType) {
539
695
return std::vector<std::shared_ptr<exec::FunctionSignature>>{
@@ -600,7 +756,7 @@ VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
600
756
601
757
VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION (
602
758
udf_array_intersect,
603
- signatures ( " array(T) " ),
759
+ arrayIntersectSignatures ( ),
604
760
createArrayIntersect);
605
761
606
762
VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION (
0 commit comments