@@ -190,12 +190,96 @@ class ArrayIntersectExceptFunction : public exec::VectorFunction {
190
190
explicit ArrayIntersectExceptFunction (SetWithNull<T> constantSet)
191
191
: constantSet_(std::move(constantSet)) {}
192
192
193
+ void intersectSingle (
194
+ const SelectivityVector& rows,
195
+ std::vector<VectorPtr>& args,
196
+ const TypePtr& outputType,
197
+ exec::EvalCtx& context,
198
+ VectorPtr& result) const {
199
+ memory::MemoryPool* pool = context.pool ();
200
+ BaseVector* data = args[0 ].get ();
201
+
202
+ exec::LocalDecodedVector dataDecoder (context, *data, rows);
203
+ auto decodedArray = dataDecoder.get ();
204
+ auto baseArray = decodedArray->base ()->as <ArrayVector>();
205
+
206
+ // Decode and acquire array elements vector.
207
+ exec::LocalDecodedVector elementsDecoder (context);
208
+ auto decodedElements =
209
+ decodeArrayElements (dataDecoder, elementsDecoder, rows);
210
+
211
+ auto elementsCount = countElements<ArrayVector>(rows, *decodedArray);
212
+ vector_size_t rowCount = decodedArray->size ();
213
+
214
+ // Allocate new vectors for indices, nulls, length and offsets.
215
+ BufferPtr newIndices = allocateIndices (elementsCount, pool);
216
+ BufferPtr newElementNulls =
217
+ AlignedBuffer::allocate<bool >(elementsCount, pool, bits::kNotNull );
218
+ BufferPtr newLengths = allocateSizes (rowCount, pool);
219
+ BufferPtr newOffsets = allocateOffsets (rowCount, pool);
220
+
221
+ // Pointers and cursors to the raw data.
222
+ auto rawNewIndices = newIndices->asMutable <vector_size_t >();
223
+ auto rawNewElementNulls = newElementNulls->asMutable <uint64_t >();
224
+ auto rawNewLengths = newLengths->asMutable <vector_size_t >();
225
+ auto rawNewOffsets = newOffsets->asMutable <vector_size_t >();
226
+
227
+ vector_size_t indicesCursor = 0 ;
228
+
229
+ // Lambda that process each row. This is detached from the code so we can
230
+ // apply it differently based on whether the right-hand side set is constant
231
+ // or not.
232
+ auto processRow = [&](vector_size_t row) {
233
+ auto idx = decodedArray->index (row);
234
+ auto size = baseArray->sizeAt (idx);
235
+ auto offset = baseArray->offsetAt (idx);
236
+
237
+ for (vector_size_t i = offset; i < (offset + size); ++i) {
238
+ if (decodedElements->isNullAt (i)) {
239
+ continue ;
240
+ }
241
+ auto arrayVector = decodedElements->base ()->as <ArrayVector>();
242
+ auto arrIdx = decodedElements->index (i);
243
+ auto arrSize = arrayVector->sizeAt (idx);
244
+ auto arrOffset = arrayVector->offsetAt (idx);
245
+ arrayVector->elements ();
246
+
247
+ exec::LocalDecodedVector arrayDecoder (context, *right, rows);
248
+ // Decode and acquire array elements vector.
249
+ exec::LocalDecodedVector elementsDecoder (context);
250
+ auto decodedRightElements =
251
+ decodeArrayElements (arrayDecoder, elementsDecoder, rows);
252
+
253
+
254
+ }
255
+ rawNewLengths[row] = indicesCursor - rawNewOffsets[row];
256
+ };
257
+
258
+ rows.applyToSelected ([&](vector_size_t row) { processRow (row); });
259
+
260
+ auto newElements = BaseVector::wrapInDictionary (
261
+ newElementNulls, newIndices, indicesCursor, baseArray->elements ());
262
+ auto resultArray = std::make_shared<ArrayVector>(
263
+ pool,
264
+ outputType,
265
+ nullptr ,
266
+ rowCount,
267
+ newOffsets,
268
+ newLengths,
269
+ newElements);
270
+ context.moveOrCopyResult (resultArray, rows, result);
271
+ }
272
+
193
273
void apply (
194
274
const SelectivityVector& rows,
195
275
std::vector<VectorPtr>& args,
196
276
const TypePtr& outputType,
197
277
exec::EvalCtx& context,
198
278
VectorPtr& result) const override {
279
+ if (isIntersect && args.size () == 1 ) {
280
+ intersectSingle (rows, args, outputType, context, result);
281
+ }
282
+
199
283
memory::MemoryPool* pool = context.pool ();
200
284
BaseVector* left = args[0 ].get ();
201
285
BaseVector* right = args[1 ].get ();
@@ -489,8 +573,7 @@ template <bool isIntersect, TypeKind kind>
489
573
std::shared_ptr<exec::VectorFunction> createTypedArraysIntersectExcept (
490
574
const std::vector<exec::VectorFunctionArg>& inputArgs,
491
575
const TypePtr& elementType) {
492
- VELOX_CHECK_EQ (inputArgs.size (), 2 );
493
- const BaseVector* rhs = inputArgs[1 ].constantValue .get ();
576
+ const BaseVector* rhs = inputArgs.size () == 1 ? nullptr : inputArgs[1 ].constantValue .get ();
494
577
495
578
if (elementType->providesCustomComparison ()) {
496
579
return createTypedArraysIntersectExcept<isIntersect, WrappedVectorEntry>(
@@ -508,8 +591,10 @@ std::shared_ptr<exec::VectorFunction> createArrayIntersect(
508
591
const std::string& name,
509
592
const std::vector<exec::VectorFunctionArg>& inputArgs,
510
593
const core::QueryConfig& /* config*/ ) {
511
- validateMatchingArrayTypes (inputArgs, name, 2 );
512
- auto elementType = inputArgs.front ().type ->childAt (0 );
594
+ validateMatchingArrayTypes (inputArgs, name, inputArgs.size ());
595
+ auto elementType = inputArgs.size () == 1
596
+ ? inputArgs.front ().type ->childAt (0 )->childAt (0 )
597
+ : inputArgs.front ().type ->childAt (0 );
513
598
514
599
return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH (
515
600
createTypedArraysIntersectExcept,
@@ -534,6 +619,23 @@ std::shared_ptr<exec::VectorFunction> createArrayExcept(
534
619
elementType);
535
620
}
536
621
622
+ std::vector<std::shared_ptr<exec::FunctionSignature>>
623
+ arrayIntersectSignatures () {
624
+ return {// array(T), array(T) -> array(T)
625
+ (exec::FunctionSignatureBuilder ()
626
+ .typeVariable (" T" )
627
+ .returnType (" array(T)" )
628
+ .argumentType (" array(T)" )
629
+ .argumentType (" array(T)" )
630
+ .build ()),
631
+ // array(array(T)) -> array(T)
632
+ (exec::FunctionSignatureBuilder ()
633
+ .typeVariable (" T" )
634
+ .returnType (" array(T)" )
635
+ .argumentType (" array(array(T))" )
636
+ .build ())};
637
+ }
638
+
537
639
std::vector<std::shared_ptr<exec::FunctionSignature>> signatures (
538
640
const std::string& returnType) {
539
641
return std::vector<std::shared_ptr<exec::FunctionSignature>>{
@@ -600,7 +702,7 @@ VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
600
702
601
703
VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION (
602
704
udf_array_intersect,
603
- signatures ( " array(T) " ),
705
+ arrayIntersectSignatures ( ),
604
706
createArrayIntersect);
605
707
606
708
VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION (
0 commit comments