@@ -138,19 +138,140 @@ void generateSet(
138
138
/// Decodes the elements vector that backs an already-decoded array vector.
///
/// @param arrayDecoder Decoder holding the decoded array vector; its base is
/// read as an ArrayVector.
/// @param elementsDecoder Decoder that receives the decoded elements.
/// @param rows Top-level rows of interest.
/// @param elementRows Output: overwritten with the result of toElementRows()
/// for 'rows'. Must not be null (it is dereferenced unconditionally).
/// @return The DecodedVector owned by 'elementsDecoder'.
DecodedVector* decodeArrayElements(
    exec::LocalDecodedVector& arrayDecoder,
    exec::LocalDecodedVector& elementsDecoder,
    const SelectivityVector& rows,
    SelectivityVector* elementRows) {
  auto* decodedArrays = arrayDecoder.get();
  auto* arrays = arrayDecoder->base()->as<ArrayVector>();

  // Translate the selected top-level rows into rows of the elements vector,
  // going through the array decoder's indices.
  const auto& elements = arrays->elements();
  *elementRows = toElementRows(
      elements->size(), rows, arrays, decodedArrays->indices());

  // Decode just those element rows and hand the decoded vector back.
  elementsDecoder.get()->decode(*elements, *elementRows);
  return elementsDecoder.get();
}
153
154
155
+ DecodedVector* decodeArrayElements (
156
+ exec::LocalDecodedVector& arrayDecoder,
157
+ exec::LocalDecodedVector& elementsDecoder,
158
+ const SelectivityVector& rows) {
159
+ SelectivityVector elementRows;
160
+ return decodeArrayElements (arrayDecoder, elementsDecoder, rows, &elementRows);
161
+ }
162
+
163
+ template <typename T>
164
+ class ArraysIntersectSingleParam : public exec ::VectorFunction {
165
+ public:
166
+ // / This class is used for array_intersect function with single parameter.
167
+ void apply (
168
+ const SelectivityVector& rows,
169
+ std::vector<VectorPtr>& args,
170
+ const TypePtr& outputType,
171
+ exec::EvalCtx& context,
172
+ VectorPtr& result) const override {
173
+ memory::MemoryPool* pool = context.pool ();
174
+
175
+ exec::LocalDecodedVector outerArrayDecoder (context, *args[0 ], rows);
176
+ auto decodedOuterArray = outerArrayDecoder.get ();
177
+ auto outerArray = decodedOuterArray->base ()->as <ArrayVector>();
178
+
179
+ exec::LocalDecodedVector innerArrayDecoder (context);
180
+ SelectivityVector innerRows;
181
+ auto decodedInnerArray = decodeArrayElements (
182
+ outerArrayDecoder, innerArrayDecoder, rows, &innerRows);
183
+ auto innerArray = decodedInnerArray->base ()->as <ArrayVector>();
184
+
185
+ exec::LocalDecodedVector elementDecoder (context);
186
+ SelectivityVector elementRows;
187
+ auto decodedInnerElement = decodeArrayElements (
188
+ innerArrayDecoder, elementDecoder, innerRows, &elementRows);
189
+
190
+ const auto elementCount =
191
+ countElements<ArrayVector>(innerRows, *decodedInnerArray);
192
+ const auto rowCount = rows.end ();
193
+
194
+ // Allocate new vectors for indices, nulls, length and offsets.
195
+ BufferPtr newIndices = allocateIndices (elementCount, pool);
196
+ BufferPtr newElementNulls =
197
+ AlignedBuffer::allocate<bool >(elementCount, pool, bits::kNotNull );
198
+ BufferPtr newLengths = allocateSizes (rowCount, pool);
199
+ BufferPtr newOffsets = allocateOffsets (rowCount, pool);
200
+
201
+ // Pointers and cursors to the raw data.
202
+ auto rawNewIndices = newIndices->asMutable <vector_size_t >();
203
+ auto rawNewElementNulls = newElementNulls->asMutable <uint64_t >();
204
+ auto rawNewOffsets = newOffsets->asMutable <vector_size_t >();
205
+ auto rawNewLengths = newLengths->asMutable <vector_size_t >();
206
+ auto indicesCursor = 0 ;
207
+
208
+ rows.applyToSelected ([&](vector_size_t row) {
209
+ rawNewOffsets[row] = indicesCursor;
210
+ std::optional<vector_size_t > finalNullIndex;
211
+ SetWithNull<T> finalSet;
212
+
213
+ auto idx = decodedOuterArray->index (row);
214
+ auto offset = outerArray->offsetAt (idx);
215
+ auto size = outerArray->sizeAt (idx);
216
+ for (auto i = offset; i < (offset + size); ++i) {
217
+ if (decodedInnerArray->isNullAt (i)) {
218
+ continue ;
219
+ }
220
+ auto innerIdx = decodedInnerArray->index (i);
221
+ auto innerOffset = innerArray->offsetAt (innerIdx);
222
+ auto innerSize = innerArray->sizeAt (innerIdx);
223
+
224
+ // prepare for next iteration
225
+ indicesCursor = rawNewOffsets[row];
226
+ SetWithNull<T> intermediateSet;
227
+ std::optional<vector_size_t > intermediateNullIndex;
228
+
229
+ for (auto j = innerOffset; j < (innerOffset + innerSize); ++j) {
230
+ // null element
231
+ if (decodedInnerElement->isNullAt (j)) {
232
+ if ((finalSet.empty () || finalSet.hasNull ) &&
233
+ !intermediateNullIndex.has_value ()) {
234
+ intermediateSet.hasNull = true ;
235
+ intermediateNullIndex = std::optional (indicesCursor++);
236
+ }
237
+ continue ;
238
+ }
239
+
240
+ // regular element
241
+ if (finalSet.empty () || finalSet.count (decodedInnerElement, j)) {
242
+ auto success = intermediateSet.insert (decodedInnerElement, j);
243
+ if (success) {
244
+ rawNewIndices[indicesCursor++] = j;
245
+ }
246
+ }
247
+ }
248
+ finalSet = intermediateSet;
249
+ finalNullIndex = intermediateNullIndex;
250
+ rawNewLengths[row] = indicesCursor - rawNewOffsets[row];
251
+ if (finalSet.empty ()) {
252
+ break ;
253
+ }
254
+ }
255
+
256
+ if (finalNullIndex.has_value ()) {
257
+ bits::setNull (rawNewElementNulls, finalNullIndex.value (), true );
258
+ }
259
+ });
260
+
261
+ auto newElements = BaseVector::wrapInDictionary (
262
+ newElementNulls, newIndices, indicesCursor, innerArray->elements ());
263
+ auto resultArray = std::make_shared<ArrayVector>(
264
+ pool,
265
+ outputType,
266
+ nullptr ,
267
+ rowCount,
268
+ newOffsets,
269
+ newLengths,
270
+ newElements);
271
+ context.moveOrCopyResult (resultArray, rows, result);
272
+ }
273
+ };
274
+
154
275
// See documentation at https://prestodb.io/docs/current/functions/array.html
155
276
template <bool isIntersect, typename T>
156
277
class ArrayIntersectExceptFunction : public exec ::VectorFunction {
@@ -211,7 +332,7 @@ class ArrayIntersectExceptFunction : public exec::VectorFunction {
211
332
212
333
auto leftElementsCount =
213
334
countElements<ArrayVector>(rows, *decodedLeftArray);
214
- vector_size_t rowCount = left-> size ();
335
+ vector_size_t rowCount = rows. end ();
215
336
216
337
// Allocate new vectors for indices, nulls, length and offsets.
217
338
BufferPtr newIndices = allocateIndices (leftElementsCount, pool);
@@ -414,7 +535,7 @@ class ArraysOverlapFunction : public exec::VectorFunction {
414
535
void validateMatchingArrayTypes (
415
536
const std::vector<exec::VectorFunctionArg>& inputArgs,
416
537
const std::string& name,
417
- vector_size_t expectedArgCount) {
538
+ size_t expectedArgCount) {
418
539
VELOX_USER_CHECK_EQ (
419
540
inputArgs.size (),
420
541
expectedArgCount,
@@ -504,10 +625,30 @@ std::shared_ptr<exec::VectorFunction> createTypedArraysIntersectExcept(
504
625
}
505
626
}
506
627
628
+ template <TypeKind kind>
629
+ std::shared_ptr<exec::VectorFunction> createArraysIntersectSingleParam (
630
+ const TypePtr& elementType) {
631
+ if (elementType->providesCustomComparison ()) {
632
+ return std::make_shared<ArraysIntersectSingleParam<WrappedVectorEntry>>();
633
+ } else {
634
+ using T = std::conditional_t <
635
+ TypeTraits<kind>::isPrimitiveType,
636
+ typename TypeTraits<kind>::NativeType,
637
+ WrappedVectorEntry>;
638
+ return std::make_shared<ArraysIntersectSingleParam<T>>();
639
+ }
640
+ }
641
+
507
642
std::shared_ptr<exec::VectorFunction> createArrayIntersect (
508
643
const std::string& name,
509
644
const std::vector<exec::VectorFunctionArg>& inputArgs,
510
645
const core::QueryConfig& /* config*/ ) {
646
+ if (inputArgs.size () == 1 ) {
647
+ auto elementType = inputArgs.front ().type ->childAt (0 )->childAt (0 );
648
+ return VELOX_DYNAMIC_TYPE_DISPATCH (
649
+ createArraysIntersectSingleParam, elementType->kind (), elementType);
650
+ }
651
+
511
652
validateMatchingArrayTypes (inputArgs, name, 2 );
512
653
auto elementType = inputArgs.front ().type ->childAt (0 );
513
654
@@ -534,6 +675,22 @@ std::shared_ptr<exec::VectorFunction> createArrayExcept(
534
675
elementType);
535
676
}
536
677
678
// Returns the function signatures built for array_intersect. Two signatures
// are produced, both with type variable T and return type array(T):
//   1. two arguments, each array(T);
//   2. one argument of type array(array(T)).
+ std::vector<std::shared_ptr<exec::FunctionSignature>>
679
+ arrayIntersectSignatures () {
680
+ return std::vector<std::shared_ptr<exec::FunctionSignature>>{
681
// Two-argument form: array(T), array(T) -> array(T).
+ exec::FunctionSignatureBuilder ()
682
+ .typeVariable (" T" )
683
+ .returnType (" array(T)" )
684
+ .argumentType (" array(T)" )
685
+ .argumentType (" array(T)" )
686
+ .build (),
687
// Single-argument form: array(array(T)) -> array(T).
+ exec::FunctionSignatureBuilder ()
688
+ .typeVariable (" T" )
689
+ .returnType (" array(T)" )
690
+ .argumentType (" array(array(T))" )
691
+ .build ()};
692
+ }
693
+
537
694
std::vector<std::shared_ptr<exec::FunctionSignature>> signatures (
538
695
const std::string& returnType) {
539
696
return std::vector<std::shared_ptr<exec::FunctionSignature>>{
@@ -600,7 +757,7 @@ VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
600
757
601
758
// Declares the stateful vector function udf_array_intersect. After this
// change its signatures come from arrayIntersectSignatures() instead of
// signatures(...), and instances are created by createArrayIntersect.
VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION (
602
759
udf_array_intersect,
603
- signatures ( " array(T) " ),
760
+ arrayIntersectSignatures ( ),
604
761
createArrayIntersect);
605
762
606
763
VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION (
0 commit comments