
Commit f78ffb3

Support Native Write
1 parent 5e9b3ec commit f78ffb3

28 files changed: +1942 −214 lines

backends-clickhouse/src/main/java/org/apache/gluten/metrics/OperatorMetrics.java

+3
@@ -27,6 +27,9 @@ public class OperatorMetrics implements IOperatorMetrics {
   public JoinParams joinParams;
   public AggregationParams aggParams;
 
+  public long physicalWrittenBytes;
+  public long numWrittenFiles;
+
   /** Create an instance for operator metrics. */
   public OperatorMetrics(
       List<MetricsData> metricsList, JoinParams joinParams, AggregationParams aggParams) {

backends-clickhouse/src/main/java/org/apache/gluten/vectorized/CHNativeExpressionEvaluator.java

+5 −2
@@ -28,6 +28,7 @@
 import org.apache.spark.SparkConf;
 import org.apache.spark.sql.internal.SQLConf;
 
+import java.nio.charset.StandardCharsets;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
@@ -79,8 +80,10 @@ private static Map<String, String> getNativeBackendConf() {
   }
 
   public static void injectWriteFilesTempPath(String path, String fileName) {
-    throw new UnsupportedOperationException(
-        "injectWriteFilesTempPath Not supported in CHNativeExpressionEvaluator");
+    ExpressionEvaluatorJniWrapper.injectWriteFilesTempPath(
+        CHNativeMemoryAllocators.contextInstance().getNativeInstanceId(),
+        path.getBytes(StandardCharsets.UTF_8),
+        fileName.getBytes(StandardCharsets.UTF_8));
   }
 
   // Used by WholeStageTransform to create the native computing pipeline and

backends-clickhouse/src/main/java/org/apache/gluten/vectorized/ExpressionEvaluatorJniWrapper.java

+9
@@ -42,4 +42,13 @@ public static native long nativeCreateKernelWithIterator(
       GeneralInIterator[] batchItr,
       byte[] confArray,
       boolean materializeInput);
+
+  /**
+   * Set the temp path for writing files.
+   *
+   * @param allocatorId allocator id for the current task attempt (or thread)
+   * @param path the temp path for writing files
+   */
+  public static native void injectWriteFilesTempPath(
+      long allocatorId, byte[] path, byte[] filename);
 }
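Together, the two Java changes above give the temp-path injection a real implementation: the evaluator resolves the task-local native allocator id and forwards the path and file name to the native side as UTF-8 bytes. A minimal sketch of a call site, assuming a hypothetical per-task temp directory and target file name (both placeholders, not taken from this commit):

import org.apache.gluten.vectorized.CHNativeExpressionEvaluator

// Presumably invoked on the task side before the native write pipeline is built.
val taskTempDir = "/tmp/gluten_native_write/stage_3_task_0" // illustrative
val targetFile = "part-00000-example.parquet"               // illustrative
CHNativeExpressionEvaluator.injectWriteFilesTempPath(taskTempDir, targetFile)
// Internally this delegates to ExpressionEvaluatorJniWrapper.injectWriteFilesTempPath(
//   allocatorId, path.getBytes(UTF_8), fileName.getBytes(UTF_8)).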

backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala

+73 −1
@@ -18,20 +18,25 @@ package org.apache.gluten.backendsapi.clickhouse
 
 import org.apache.gluten.{CH_BRANCH, CH_COMMIT, GlutenConfig}
 import org.apache.gluten.backendsapi._
+import org.apache.gluten.execution.WriteFilesExecTransformer
 import org.apache.gluten.expression.WindowFunctionsBuilder
 import org.apache.gluten.extension.ValidationResult
 import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat
 import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat._
 
 import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.datasources.FileFormat
+import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
+import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{ArrayType, MapType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, MapType, Metadata, StructField, StructType}
 
 import java.util.Locale
 
@@ -187,6 +192,73 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
     }
   }
 
+  override def supportWriteFilesExec(
+      format: FileFormat,
+      fields: Array[StructField],
+      bucketSpec: Option[BucketSpec],
+      options: Map[String, String]): ValidationResult = {
+
+    def validateCompressionCodec(): Option[String] = {
+      // FIXME: verify the supported compression codecs
+      val compressionCodec = WriteFilesExecTransformer.getCompressionCodec(options)
+      None
+    }
+
+    def validateFileFormat(): Option[String] = {
+      format match {
+        case _: ParquetFileFormat => None
+        case _: OrcFileFormat => None
+        case f: FileFormat => Some(s"Unsupported FileFormat: ${f.getClass.getSimpleName}")
+      }
+    }
+
+    // Validate if all types are supported.
+    def validateDateTypes(): Option[String] = {
+      None
+    }
+
+    def validateFieldMetadata(): Option[String] = {
+      // copy of CharVarcharUtils.CHAR_VARCHAR_TYPE_STRING_METADATA_KEY
+      val CHAR_VARCHAR_TYPE_STRING_METADATA_KEY = "__CHAR_VARCHAR_TYPE_STRING"
+      fields
+        .find(_.metadata != Metadata.empty)
+        .filterNot(_.metadata.contains(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY))
+        .map {
+          field =>
+            s"StructField contains metadata information: $field, metadata: ${field.metadata}"
+        }
+    }
+
+    def validateWriteFilesOptions(): Option[String] = {
+      val maxRecordsPerFile = options
+        .get("maxRecordsPerFile")
+        .map(_.toLong)
+        .getOrElse(SQLConf.get.maxRecordsPerFile)
+      if (maxRecordsPerFile > 0) {
+        Some("Unsupported native write: maxRecordsPerFile is not supported.")
+      } else {
+        None
+      }
+    }
+
+    def validateBucketSpec(): Option[String] = {
+      if (bucketSpec.nonEmpty) {
+        Some("Unsupported native write: bucketed write is not supported.")
+      } else {
+        None
+      }
+    }
+
+    validateCompressionCodec()
+      .orElse(validateFileFormat())
+      .orElse(validateFieldMetadata())
+      .orElse(validateDateTypes())
+      .orElse(validateWriteFilesOptions())
+      .orElse(validateBucketSpec()) match {
+      case Some(reason) => ValidationResult.failed(reason)
+      case _ => ValidationResult.succeeded
+    }
+  }
+
   override def supportShuffleWithProject(
       outputPartitioning: Partitioning,
       child: SparkPlan): Boolean = {
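The validators above compose with Option.orElse, so the first check that returns a reason short-circuits the chain, and an all-None chain maps to ValidationResult.succeeded. A rough sketch of the expected behaviour, assuming the default SQLConf (spark.sql.files.maxRecordsPerFile = 0); the BucketSpec below is illustrative:

import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.types.StructField

// Parquet target, no buckets, no maxRecordsPerFile: every validator returns None.
val ok = CHBackendSettings.supportWriteFilesExec(
  new ParquetFileFormat,
  fields = Array.empty[StructField],
  bucketSpec = None,
  options = Map.empty)
// ok == ValidationResult.succeeded

// A bucketed write is rejected by validateBucketSpec, the last check in the chain.
val rejected = CHBackendSettings.supportWriteFilesExec(
  new ParquetFileFormat,
  fields = Array.empty[StructField],
  bucketSpec = Some(BucketSpec(8, Seq("id"), Nil)),
  options = Map.empty)
// rejected carries the "bucketed write is not supported" reason.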

backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala

+6 −6
@@ -383,13 +383,13 @@ class CHMetricsApi extends MetricsApi with Logging with LogLevelUtil {
       s"SampleTransformer metrics update is not supported in CH backend")
   }
 
-  def genWriteFilesTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric] = {
-    throw new UnsupportedOperationException(
-      s"WriteFilesTransformer metrics update is not supported in CH backend")
-  }
+  def genWriteFilesTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric] =
+    Map(
+      "physicalWrittenBytes" -> SQLMetrics.createMetric(sparkContext, "number of written bytes"),
+      "numWrittenFiles" -> SQLMetrics.createMetric(sparkContext, "number of written files")
+    )
 
   def genWriteFilesTransformerMetricsUpdater(metrics: Map[String, SQLMetric]): MetricsUpdater = {
-    throw new UnsupportedOperationException(
-      s"WriteFilesTransformer metrics update is not supported in CH backend")
+    new WriteFilesMetricsUpdater(metrics)
   }
 }

backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala

+2 −2
@@ -674,8 +674,8 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
     CHRegExpReplaceTransformer(substraitExprName, children, expr)
   }
 
-  def createBackendWrite(description: WriteJobDescription): BackendWrite =
-    throw new UnsupportedOperationException("createBackendWrite is not supported in ch backend.")
+  def createBackendWrite(description: WriteJobDescription): BackendWrite =
+    ClickhouseBackendWrite(description)
 
   override def createColumnarArrowEvalPythonExec(
       udfs: Seq[PythonUDF],
backends-clickhouse/src/main/scala/org/apache/gluten/metrics/WriteFilesMetricsUpdater.scala (new file)

+30

@@ -0,0 +1,30 @@
+/* Apache License, Version 2.0 header (16 lines) */
+package org.apache.gluten.metrics
+
+import org.apache.spark.sql.execution.metric.SQLMetric
+
+class WriteFilesMetricsUpdater(val metrics: Map[String, SQLMetric]) extends MetricsUpdater {
+
+  override def updateNativeMetrics(opMetrics: IOperatorMetrics): Unit = {
+    if (opMetrics != null) {
+      val operatorMetrics = opMetrics.asInstanceOf[OperatorMetrics]
+      metrics("physicalWrittenBytes") += operatorMetrics.physicalWrittenBytes
+      metrics("numWrittenFiles") += operatorMetrics.numWrittenFiles
+    }
+  }
+}
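The updater folds the two native counters carried by OperatorMetrics into the SQLMetrics created by genWriteFilesTransformerMetrics. A small sketch of the intended flow; the counter values are made up, and passing nulls for the unused constructor arguments is only for illustration:

import org.apache.gluten.backendsapi.clickhouse.CHMetricsApi
import org.apache.gluten.metrics.{OperatorMetrics, WriteFilesMetricsUpdater}
import org.apache.spark.SparkContext

val sc = SparkContext.getOrCreate()
val metrics = new CHMetricsApi().genWriteFilesTransformerMetrics(sc)
val updater = new WriteFilesMetricsUpdater(metrics)

// The native side reports its counters through OperatorMetrics; only the two
// write-related fields matter here, so the list/join/agg arguments are left null.
val opMetrics = new OperatorMetrics(null, null, null)
opMetrics.physicalWrittenBytes = 128L * 1024 * 1024
opMetrics.numWrittenFiles = 2
updater.updateNativeMetrics(opMetrics)
// metrics("physicalWrittenBytes").value == 134217728
// metrics("numWrittenFiles").value == 2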
backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/ClickhouseBackendWrite.scala (new file)

+84

@@ -0,0 +1,84 @@
+/* Apache License, Version 2.0 header (16 lines) */
+package org.apache.spark.sql.execution
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+import scala.collection.mutable
+
+case class ClickhouseBackendWrite(description: WriteJobDescription)
+  extends BackendWrite
+  with Logging {
+
+  override def collectNativeWriteFilesMetrics(cb: ColumnarBatch): Option[WriteTaskResult] = {
+    val numFiles = cb.numRows()
+    // The input iterator was empty; nothing was written.
+    if (numFiles == 0) {
+      None
+    } else {
+      val file_col = cb.column(0)
+      val partition_col = cb.column(1)
+      val count_col = cb.column(2)
+
+      val outputPath = description.path
+      var updatedPartitions = Set.empty[String]
+      val addedAbsPathFiles: mutable.Map[String, String] = mutable.Map[String, String]()
+
+      val write_stats = Range(0, cb.numRows()).map {
+        i =>
+          val targetFileName = file_col.getUTF8String(i).toString
+          val partition = partition_col.getUTF8String(i).toString
+          if (partition != "__NO_PARTITION_ID__") {
+            updatedPartitions += partition
+            val tmpOutputPath = outputPath + "/" + partition + "/" + targetFileName
+            val customOutputPath =
+              description.customPartitionLocations.get(
+                PartitioningUtils.parsePathFragment(partition))
+            if (customOutputPath.isDefined) {
+              addedAbsPathFiles(tmpOutputPath) = customOutputPath.get + "/" + targetFileName
+            }
+          }
+          count_col.getLong(i)
+      }
+
+      val partitionsInternalRows = updatedPartitions.map {
+        part =>
+          val parts = new Array[Any](1)
+          parts(0) = part
+          new GenericInternalRow(parts)
+      }.toSeq
+
+      val numWrittenRows = write_stats.sum
+      val stats = BasicWriteTaskStats(
+        partitions = partitionsInternalRows,
+        numFiles = numFiles,
+        numBytes = 101,
+        numRows = numWrittenRows)
+      val summary =
+        ExecutedWriteSummary(updatedPartitions = updatedPartitions, stats = Seq(stats))
+
+      Some(
+        WriteTaskResult(
+          new TaskCommitMessage(addedAbsPathFiles.toMap -> updatedPartitions),
+          summary))
+    }
+  }
+}
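ClickhouseBackendWrite expects the native writer to report one row per produced file, with three columns: the file name, the partition directory (or the "__NO_PARTITION_ID__" sentinel for unpartitioned output), and the number of rows written to that file; note that numBytes is currently a fixed placeholder (101) in the stats above. A minimal sketch that builds such a report batch with on-heap vectors and hands it to the collector; the file and partition names are made up:

import java.nio.charset.StandardCharsets.UTF_8

import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.types.{LongType, StringType}
import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}

// One report row per written file: (file name, partition id, row count). Values illustrative.
val fileCol = new OnHeapColumnVector(1, StringType)
val partCol = new OnHeapColumnVector(1, StringType)
val countCol = new OnHeapColumnVector(1, LongType)
fileCol.putByteArray(0, "part-00000-example.parquet".getBytes(UTF_8))
partCol.putByteArray(0, "dt=2024-01-01".getBytes(UTF_8))
countCol.putLong(0, 42L)
val reportBatch = new ColumnarBatch(Array[ColumnVector](fileCol, partCol, countCol), 1)

// Given a real WriteJobDescription, the task side would then call:
//   ClickhouseBackendWrite(description).collectNativeWriteFilesMetrics(reportBatch)
// yielding a WriteTaskResult that records one file under partition "dt=2024-01-01"
// with 42 written rows.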

cpp-ch/local-engine/Common/CHUtil.h

+2
@@ -42,6 +42,8 @@ namespace local_engine
 static const String MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE = "mergetree.insert_without_local_storage";
 static const String MERGETREE_MERGE_AFTER_INSERT = "mergetree.merge_after_insert";
 static const std::string DECIMAL_OPERATIONS_ALLOW_PREC_LOSS = "spark.sql.decimalOperations.allowPrecisionLoss";
+static const std::string SPARK_TASK_WRITE_TMEP_DIR = "gluten.write.temp.dir";
+static const std::string SPARK_TASK_WRITE_FILENAME = "gluten.write.file.name";
 
 static const std::unordered_set<String> BOOL_VALUE_SETTINGS{
     MERGETREE_MERGE_AFTER_INSERT, MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE, DECIMAL_OPERATIONS_ALLOW_PREC_LOSS};

cpp-ch/local-engine/Parser/SerializedPlanParser.cpp

+15 −6
@@ -57,6 +57,7 @@
 #include <Parser/RelParser.h>
 #include <Parser/SubstraitParserUtils.h>
 #include <Parser/TypeParser.h>
+#include <Parser/WriteRelParser.h>
 #include <Parsers/ASTIdentifier.h>
 #include <Parsers/ExpressionListParsers.h>
 #include <Processors/Formats/Impl/ArrowBlockOutputFormat.h>
@@ -423,12 +424,13 @@ QueryPlanPtr SerializedPlanParser::parse(const substrait::Plan & plan)
     if (!root_rel.has_root())
         throw Exception(ErrorCodes::BAD_ARGUMENTS, "must have root rel!");
 
-    if (root_rel.root().input().has_write())
-        throw Exception(ErrorCodes::BAD_ARGUMENTS, "write pipeline is not supported yet!");
+    const bool writePipeline = root_rel.root().input().has_write();
+    const substrait::Rel & first_read_rel = writePipeline ? root_rel.root().input().write().input() : root_rel.root().input();
 
     std::list<const substrait::Rel *> rel_stack;
-    auto query_plan = parseOp(root_rel.root().input(), rel_stack);
-    adjustOutput(query_plan, root_rel);
+    auto query_plan = parseOp(first_read_rel, rel_stack);
+    if (!writePipeline)
+        adjustOutput(query_plan, root_rel);
 
 #ifndef NDEBUG
     PlanUtil::checkOuputType(*query_plan);
@@ -1339,9 +1341,16 @@ std::unique_ptr<LocalExecutor> SerializedPlanParser::createExecutor(DB::QueryPla
     Stopwatch stopwatch;
 
     const Settings & settings = context->getSettingsRef();
-    auto pipeline_builder = buildQueryPipeline(*query_plan);
+    auto builder = buildQueryPipeline(*query_plan);
 
-    QueryPipeline pipeline = QueryPipelineBuilder::getPipeline(std::move(*pipeline_builder));
+    ///
+    assert(s_plan.relations_size() == 1);
+    const substrait::PlanRel & root_rel = s_plan.relations().at(0);
+    assert(root_rel.has_root());
+    if (root_rel.root().input().has_write())
+        addSinkTransfrom(context, root_rel.root().input().write(), builder);
+    ///
+    QueryPipeline pipeline = QueryPipelineBuilder::getPipeline(std::move(*builder));
 
     auto * logger = &Poco::Logger::get("SerializedPlanParser");
     LOG_INFO(logger, "build pipeline {} ms", stopwatch.elapsedMicroseconds() / 1000.0);
