16
16
*/
17
17
package org .apache .gluten .backendsapi .clickhouse
18
18
19
- import org .apache .gluten .{ GlutenConfig , GlutenNumaBindingInfo }
19
+ import org .apache .gluten .GlutenNumaBindingInfo
20
20
import org .apache .gluten .backendsapi .IteratorApi
21
21
import org .apache .gluten .execution ._
22
22
import org .apache .gluten .expression .ConverterUtils
@@ -61,6 +61,52 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
61
61
StructType (dataSchema)
62
62
}
63
63
64
+ private def createNativeIterator (
65
+ splitInfoByteArray : Array [Array [Byte ]],
66
+ wsPlan : Array [Byte ],
67
+ materializeInput : Boolean ,
68
+ inputIterators : Seq [Iterator [ColumnarBatch ]]): BatchIterator = {
69
+
70
+ /** Generate closeable ColumnBatch iterator. */
71
+ val listIterator =
72
+ inputIterators
73
+ .map {
74
+ case i : CloseableCHColumnBatchIterator => i
75
+ case it => new CloseableCHColumnBatchIterator (it)
76
+ }
77
+ .map(it => new ColumnarNativeIterator (it.asJava).asInstanceOf [GeneralInIterator ])
78
+ .asJava
79
+ new CHNativeExpressionEvaluator ().createKernelWithBatchIterator(
80
+ wsPlan,
81
+ splitInfoByteArray,
82
+ listIterator,
83
+ materializeInput
84
+ )
85
+ }
86
+
87
+ private def createCloseIterator (
88
+ context : TaskContext ,
89
+ pipelineTime : SQLMetric ,
90
+ updateNativeMetrics : IMetrics => Unit ,
91
+ updateInputMetrics : Option [InputMetricsWrapper => Unit ] = None ,
92
+ nativeIter : BatchIterator ): CloseableCHColumnBatchIterator = {
93
+
94
+ val iter = new CollectMetricIterator (
95
+ nativeIter,
96
+ updateNativeMetrics,
97
+ updateInputMetrics,
98
+ updateInputMetrics.map(_ => context.taskMetrics().inputMetrics).orNull)
99
+
100
+ context.addTaskFailureListener(
101
+ (ctx, _) => {
102
+ if (ctx.isInterrupted()) {
103
+ iter.cancel()
104
+ }
105
+ })
106
+ context.addTaskCompletionListener[Unit ](_ => iter.close())
107
+ new CloseableCHColumnBatchIterator (iter, Some (pipelineTime))
108
+ }
109
+
64
110
// only set file schema for text format table
65
111
private def setFileSchemaForLocalFiles (
66
112
localFilesNode : LocalFilesNode ,
@@ -198,45 +244,24 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
198
244
inputIterators : Seq [Iterator [ColumnarBatch ]] = Seq ()
199
245
): Iterator [ColumnarBatch ] = {
200
246
201
- assert (
247
+ require (
202
248
inputPartition.isInstanceOf [GlutenPartition ],
203
249
" CH backend only accepts GlutenPartition in GlutenWholeStageColumnarRDD." )
204
-
205
- val transKernel = new CHNativeExpressionEvaluator ()
206
- val inBatchIters = new JArrayList [GeneralInIterator ](inputIterators.map {
207
- iter => new ColumnarNativeIterator (CHIteratorApi .genCloseableColumnBatchIterator(iter).asJava)
208
- }.asJava)
209
-
210
250
val splitInfoByteArray = inputPartition
211
251
.asInstanceOf [GlutenPartition ]
212
252
.splitInfosByteArray
213
- val nativeIter =
214
- transKernel.createKernelWithBatchIterator(
215
- inputPartition.plan,
216
- splitInfoByteArray,
217
- inBatchIters,
218
- false )
253
+ val wsPlan = inputPartition.plan
254
+ val materializeInput = false
219
255
220
- val iter = new CollectMetricIterator (
221
- nativeIter,
222
- updateNativeMetrics,
223
- updateInputMetrics,
224
- context.taskMetrics().inputMetrics)
225
-
226
- context.addTaskFailureListener(
227
- (ctx, _) => {
228
- if (ctx.isInterrupted()) {
229
- iter.cancel()
230
- }
231
- })
232
- context.addTaskCompletionListener[Unit ](_ => iter.close())
233
-
234
- // TODO: SPARK-25083 remove the type erasure hack in data source scan
235
256
new InterruptibleIterator (
236
257
context,
237
- new CloseableCHColumnBatchIterator (
238
- iter.asInstanceOf [Iterator [ColumnarBatch ]],
239
- Some (pipelineTime)))
258
+ createCloseIterator(
259
+ context,
260
+ pipelineTime,
261
+ updateNativeMetrics,
262
+ Some (updateInputMetrics),
263
+ createNativeIterator(splitInfoByteArray, wsPlan, materializeInput, inputIterators))
264
+ )
240
265
}
241
266
242
267
// Generate Iterator[ColumnarBatch] for final stage.
@@ -252,52 +277,26 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
252
277
partitionIndex : Int ,
253
278
materializeInput : Boolean ): Iterator [ColumnarBatch ] = {
254
279
// scalastyle:on argcount
255
- GlutenConfig .getConf
256
-
257
- val transKernel = new CHNativeExpressionEvaluator ()
258
- val columnarNativeIterator =
259
- new JArrayList [GeneralInIterator ](inputIterators.map {
260
- iter =>
261
- new ColumnarNativeIterator (CHIteratorApi .genCloseableColumnBatchIterator(iter).asJava)
262
- }.asJava)
263
- // we need to complete dependency RDD's firstly
264
- val nativeIterator = transKernel.createKernelWithBatchIterator(
265
- rootNode.toProtobuf.toByteArray,
266
- // Final iterator does not contain scan split, so pass empty split info to native here.
267
- new Array [Array [Byte ]](0 ),
268
- columnarNativeIterator,
269
- materializeInput
270
- )
271
-
272
- val iter = new CollectMetricIterator (nativeIterator, updateNativeMetrics, null , null )
273
280
274
- context.addTaskFailureListener(
275
- (ctx, _) => {
276
- if (ctx.isInterrupted()) {
277
- iter.cancel()
278
- }
279
- })
280
- context.addTaskCompletionListener[Unit ](_ => iter.close())
281
- new CloseableCHColumnBatchIterator (iter, Some (pipelineTime))
282
- }
283
- }
281
+ // Final iterator does not contain scan split, so pass empty split info to native here.
282
+ val splitInfoByteArray = new Array [Array [Byte ]](0 )
283
+ val wsPlan = rootNode.toProtobuf.toByteArray
284
284
285
- object CHIteratorApi {
286
-
287
- /** Generate closeable ColumnBatch iterator. */
288
- def genCloseableColumnBatchIterator (iter : Iterator [ColumnarBatch ]): Iterator [ColumnarBatch ] = {
289
- iter match {
290
- case _ : CloseableCHColumnBatchIterator => iter
291
- case _ => new CloseableCHColumnBatchIterator (iter)
292
- }
285
+ // we need to complete dependency RDD's firstly
286
+ createCloseIterator(
287
+ context,
288
+ pipelineTime,
289
+ updateNativeMetrics,
290
+ None ,
291
+ createNativeIterator(splitInfoByteArray, wsPlan, materializeInput, inputIterators))
293
292
}
294
293
}
295
294
296
295
class CollectMetricIterator (
297
296
val nativeIterator : BatchIterator ,
298
297
val updateNativeMetrics : IMetrics => Unit ,
299
- val updateInputMetrics : InputMetricsWrapper => Unit ,
300
- val inputMetrics : InputMetrics
298
+ val updateInputMetrics : Option [ InputMetricsWrapper => Unit ] = None ,
299
+ val inputMetrics : InputMetrics = null
301
300
) extends Iterator [ColumnarBatch ] {
302
301
private var outputRowCount = 0L
303
302
private var outputVectorCount = 0L
@@ -329,9 +328,7 @@ class CollectMetricIterator(
329
328
val nativeMetrics = nativeIterator.getMetrics.asInstanceOf [NativeMetrics ]
330
329
nativeMetrics.setFinalOutputMetrics(outputRowCount, outputVectorCount)
331
330
updateNativeMetrics(nativeMetrics)
332
- if (updateInputMetrics != null ) {
333
- updateInputMetrics(inputMetrics)
334
- }
331
+ updateInputMetrics.foreach(_(inputMetrics))
335
332
metricsUpdated = true
336
333
}
337
334
}
0 commit comments