Skip to content

Commit 14cf511

Browse files
committed
Support passing the file format to the backend
1 parent c362c5e commit 14cf511

File tree

4 files changed

+76
-59
lines changed

4 files changed

+76
-59
lines changed

gluten-core/src/main/scala/org/apache/gluten/execution/CartesianProductExecTransformer.scala

+4-6
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,10 @@ case class CartesianProductExecTransformer(
8686
val (inputRightRelNode, inputRightOutput) =
8787
(rightPlanContext.root, rightPlanContext.outputAttributes)
8888

89-
val expressionNode = condition.map {
90-
expr =>
91-
ExpressionConverter
92-
.replaceWithExpressionTransformer(expr, inputLeftOutput ++ inputRightOutput)
93-
.doTransform(context.registeredFunction)
94-
}
89+
val expressionNode =
90+
condition.map {
91+
SubstraitUtil.toSubstraitExpression(_, inputLeftOutput ++ inputRightOutput, context)
92+
}
9593

9694
val extensionNode =
9795
JoinUtils.createExtensionNode(inputLeftOutput ++ inputRightOutput, validation = false)

gluten-core/src/main/scala/org/apache/gluten/execution/JoinUtils.scala

+22-33
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,12 @@
1616
*/
1717
package org.apache.gluten.execution
1818

19-
import org.apache.gluten.backendsapi.BackendsApiManager
20-
import org.apache.gluten.expression.{AttributeReferenceTransformer, ConverterUtils, ExpressionConverter}
21-
import org.apache.gluten.substrait.`type`.TypeBuilder
19+
import org.apache.gluten.expression.{AttributeReferenceTransformer, ExpressionConverter}
2220
import org.apache.gluten.substrait.SubstraitContext
2321
import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode}
2422
import org.apache.gluten.substrait.extensions.{AdvancedExtensionNode, ExtensionBuilder}
2523
import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}
24+
import org.apache.gluten.utils.SubstraitUtil
2625

2726
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
2827
import org.apache.spark.sql.catalyst.plans._
@@ -34,21 +33,11 @@ import io.substrait.proto.{CrossRel, JoinRel}
3433
import scala.collection.JavaConverters._
3534

3635
object JoinUtils {
37-
private def createEnhancement(output: Seq[Attribute]): com.google.protobuf.Any = {
38-
val inputTypeNodes = output.map {
39-
attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)
40-
}
41-
// Normally the enhancement node is only used for plan validation. But here the enhancement
42-
// is also used in execution phase. In this case an empty typeUrlPrefix needs to be passed,
43-
// so that it can be correctly parsed into json string on the cpp side.
44-
BackendsApiManager.getTransformerApiInstance.packPBMessage(
45-
TypeBuilder.makeStruct(false, inputTypeNodes.asJava).toProtobuf)
46-
}
4736

4837
def createExtensionNode(output: Seq[Attribute], validation: Boolean): AdvancedExtensionNode = {
4938
// Use field [enhancement] in an extension node for input type validation.
5039
if (validation) {
51-
ExtensionBuilder.makeAdvancedExtension(createEnhancement(output))
40+
ExtensionBuilder.makeAdvancedExtension(SubstraitUtil.createEnhancement(output))
5241
} else {
5342
null
5443
}
@@ -58,7 +47,7 @@ object JoinUtils {
5847
!keyExprs.forall(_.isInstanceOf[AttributeReference])
5948
}
6049

61-
def createPreProjectionIfNeeded(
50+
private def createPreProjectionIfNeeded(
6251
keyExprs: Seq[Expression],
6352
inputNode: RelNode,
6453
inputNodeOutput: Seq[Attribute],
@@ -131,17 +120,17 @@ object JoinUtils {
131120
}
132121
}
133122

134-
def createJoinExtensionNode(
123+
private def createJoinExtensionNode(
135124
joinParameters: Any,
136125
output: Seq[Attribute]): AdvancedExtensionNode = {
137126
// Use field [optimization] in a extension node
138127
// through an extension node to send some join parameters through Substrait plan.
139-
val enhancement = createEnhancement(output)
128+
val enhancement = SubstraitUtil.createEnhancement(output)
140129
ExtensionBuilder.makeAdvancedExtension(joinParameters, enhancement)
141130
}
142131

143132
// Return the direct join output.
144-
protected def getDirectJoinOutput(
133+
private def getDirectJoinOutput(
145134
joinType: JoinType,
146135
leftOutput: Seq[Attribute],
147136
rightOutput: Seq[Attribute]): (Seq[Attribute], Seq[Attribute]) = {
@@ -164,7 +153,7 @@ object JoinUtils {
164153
}
165154
}
166155

167-
protected def getDirectJoinOutputSeq(
156+
private def getDirectJoinOutputSeq(
168157
joinType: JoinType,
169158
leftOutput: Seq[Attribute],
170159
rightOutput: Seq[Attribute]): Seq[Attribute] = {
@@ -209,8 +198,8 @@ object JoinUtils {
209198
validation)
210199

211200
// Combine join keys to make a single expression.
212-
val joinExpressionNode = (streamedKeys
213-
.zip(buildKeys))
201+
val joinExpressionNode = streamedKeys
202+
.zip(buildKeys)
214203
.map {
215204
case ((leftKey, leftType), (rightKey, rightType)) =>
216205
HashJoinLikeExecTransformer.makeEqualToExpression(
@@ -225,12 +214,10 @@ object JoinUtils {
225214
HashJoinLikeExecTransformer.makeAndExpression(l, r, substraitContext.registeredFunction))
226215

227216
// Create post-join filter, which will be computed in hash join.
228-
val postJoinFilter = condition.map {
229-
expr =>
230-
ExpressionConverter
231-
.replaceWithExpressionTransformer(expr, streamedOutput ++ buildOutput)
232-
.doTransform(substraitContext.registeredFunction)
233-
}
217+
val postJoinFilter =
218+
condition.map {
219+
SubstraitUtil.toSubstraitExpression(_, streamedOutput ++ buildOutput, substraitContext)
220+
}
234221

235222
// Create JoinRel.
236223
val joinRel = RelBuilder.makeJoinRel(
@@ -340,12 +327,14 @@ object JoinUtils {
340327
joinParameters: Any,
341328
validation: Boolean = false
342329
): RelNode = {
343-
val expressionNode = condition.map {
344-
expr =>
345-
ExpressionConverter
346-
.replaceWithExpressionTransformer(expr, inputStreamedOutput ++ inputBuildOutput)
347-
.doTransform(substraitContext.registeredFunction)
348-
}
330+
val expressionNode =
331+
condition.map {
332+
SubstraitUtil.toSubstraitExpression(
333+
_,
334+
inputStreamedOutput ++ inputBuildOutput,
335+
substraitContext)
336+
}
337+
349338
val extensionNode =
350339
createJoinExtensionNode(joinParameters, inputStreamedOutput ++ inputBuildOutput)
351340

gluten-core/src/main/scala/org/apache/gluten/execution/WriteFilesExecTransformer.scala

+21-20
Original file line numberDiff line numberDiff line change
@@ -21,26 +21,28 @@ import org.apache.gluten.backendsapi.BackendsApiManager
2121
import org.apache.gluten.expression.ConverterUtils
2222
import org.apache.gluten.extension.ValidationResult
2323
import org.apache.gluten.metrics.MetricsUpdater
24-
import org.apache.gluten.substrait.`type`.{ColumnTypeNode, TypeBuilder}
24+
import org.apache.gluten.substrait.`type`.ColumnTypeNode
2525
import org.apache.gluten.substrait.SubstraitContext
2626
import org.apache.gluten.substrait.extensions.ExtensionBuilder
2727
import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}
28+
import org.apache.gluten.utils.SubstraitUtil
2829

2930
import org.apache.spark.sql.catalyst.catalog.BucketSpec
3031
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
3132
import org.apache.spark.sql.catalyst.expressions.Attribute
3233
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
3334
import org.apache.spark.sql.execution.SparkPlan
3435
import org.apache.spark.sql.execution.datasources.FileFormat
36+
import org.apache.spark.sql.execution.metric.SQLMetric
3537
import org.apache.spark.sql.internal.SQLConf
38+
import org.apache.spark.sql.sources.DataSourceRegister
3639
import org.apache.spark.sql.types.MetadataBuilder
3740

3841
import com.google.protobuf.{Any, StringValue}
3942
import org.apache.parquet.hadoop.ParquetOutputFormat
4043

4144
import java.util.Locale
4245

43-
import scala.collection.JavaConverters._
4446
import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable`
4547

4648
/**
@@ -56,7 +58,7 @@ case class WriteFilesExecTransformer(
5658
staticPartitions: TablePartitionSpec)
5759
extends UnaryTransformSupport {
5860
// Note: "metrics" is made transient to avoid sending driver-side metrics to tasks.
59-
@transient override lazy val metrics =
61+
@transient override lazy val metrics: Map[String, SQLMetric] =
6062
BackendsApiManager.getMetricsApiInstance.genWriteFilesTransformerMetrics(sparkContext)
6163

6264
override def metricsUpdater(): MetricsUpdater =
@@ -66,27 +68,25 @@ case class WriteFilesExecTransformer(
6668

6769
private val caseInsensitiveOptions = CaseInsensitiveMap(options)
6870

69-
def genWriteParameters(): Any = {
71+
private def genWriteParameters(): Any = {
72+
val fileFormatStr = fileFormat match {
73+
case register: DataSourceRegister =>
74+
register.shortName
75+
case _ => "UnknownFileFormat"
76+
}
7077
val compressionCodec =
7178
WriteFilesExecTransformer.getCompressionCodec(caseInsensitiveOptions).capitalize
7279
val writeParametersStr = new StringBuffer("WriteParameters:")
73-
writeParametersStr.append("is").append(compressionCodec).append("=1").append("\n")
80+
writeParametersStr.append("is").append(compressionCodec).append("=1")
81+
writeParametersStr.append(";format=").append(fileFormatStr).append("\n")
82+
7483
val message = StringValue
7584
.newBuilder()
7685
.setValue(writeParametersStr.toString)
7786
.build()
7887
BackendsApiManager.getTransformerApiInstance.packPBMessage(message)
7988
}
8089

81-
def createEnhancement(output: Seq[Attribute]): com.google.protobuf.Any = {
82-
val inputTypeNodes = output.map {
83-
attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)
84-
}
85-
86-
BackendsApiManager.getTransformerApiInstance.packPBMessage(
87-
TypeBuilder.makeStruct(false, inputTypeNodes.asJava).toProtobuf)
88-
}
89-
9090
def getRelNode(
9191
context: SubstraitContext,
9292
originalInputAttributes: Seq[Attribute],
@@ -118,10 +118,11 @@ case class WriteFilesExecTransformer(
118118
val extensionNode = if (!validation) {
119119
ExtensionBuilder.makeAdvancedExtension(
120120
genWriteParameters(),
121-
createEnhancement(originalInputAttributes))
121+
SubstraitUtil.createEnhancement(originalInputAttributes))
122122
} else {
123123
// Use an extension node to send the input types through Substrait plan for validation.
124-
ExtensionBuilder.makeAdvancedExtension(createEnhancement(originalInputAttributes))
124+
ExtensionBuilder.makeAdvancedExtension(
125+
SubstraitUtil.createEnhancement(originalInputAttributes))
125126
}
126127
RelBuilder.makeWriteRel(
127128
input,
@@ -133,7 +134,7 @@ case class WriteFilesExecTransformer(
133134
operatorId)
134135
}
135136

136-
private def getFinalChildOutput(): Seq[Attribute] = {
137+
private def getFinalChildOutput: Seq[Attribute] = {
137138
val metadataExclusionList = conf
138139
.getConf(GlutenConfig.NATIVE_WRITE_FILES_COLUMN_METADATA_EXCLUSION_LIST)
139140
.split(",")
@@ -143,7 +144,7 @@ case class WriteFilesExecTransformer(
143144
}
144145

145146
override protected def doValidateInternal(): ValidationResult = {
146-
val finalChildOutput = getFinalChildOutput()
147+
val finalChildOutput = getFinalChildOutput
147148
val validationResult =
148149
BackendsApiManager.getSettings.supportWriteFilesExec(
149150
fileFormat,
@@ -165,7 +166,7 @@ case class WriteFilesExecTransformer(
165166
val childCtx = child.asInstanceOf[TransformSupport].transform(context)
166167
val operatorId = context.nextOperatorId(this.nodeName)
167168
val currRel =
168-
getRelNode(context, getFinalChildOutput(), operatorId, childCtx.root, validation = false)
169+
getRelNode(context, getFinalChildOutput, operatorId, childCtx.root, validation = false)
169170
assert(currRel != null, "Write Rel should be valid")
170171
TransformContext(childCtx.outputAttributes, output, currRel)
171172
}
@@ -196,7 +197,7 @@ object WriteFilesExecTransformer {
196197
"__file_source_generated_metadata_col"
197198
)
198199

199-
def removeMetadata(attr: Attribute, metadataExclusionList: Seq[String]): Attribute = {
200+
private def removeMetadata(attr: Attribute, metadataExclusionList: Seq[String]): Attribute = {
200201
val metadataKeys = INTERNAL_METADATA_KEYS ++ metadataExclusionList
201202
attr.withMetadata {
202203
var builder = new MetadataBuilder().withMetadata(attr.metadata)

gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala

+29
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,19 @@
1616
*/
1717
package org.apache.gluten.utils
1818

19+
import org.apache.gluten.backendsapi.BackendsApiManager
20+
import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter}
21+
import org.apache.gluten.substrait.`type`.TypeBuilder
22+
import org.apache.gluten.substrait.SubstraitContext
23+
import org.apache.gluten.substrait.expression.ExpressionNode
24+
25+
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
1926
import org.apache.spark.sql.catalyst.plans.{FullOuter, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}
2027

2128
import io.substrait.proto.{CrossRel, JoinRel}
2229

30+
import scala.collection.JavaConverters._
31+
2332
object SubstraitUtil {
2433
def toSubstrait(sparkJoin: JoinType): JoinRel.JoinType = sparkJoin match {
2534
case _: InnerLike =>
@@ -55,4 +64,24 @@ object SubstraitUtil {
5564
case _ =>
5665
CrossRel.JoinType.UNRECOGNIZED
5766
}
67+
68+
def createEnhancement(output: Seq[Attribute]): com.google.protobuf.Any = {
69+
val inputTypeNodes = output.map {
70+
attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)
71+
}
72+
// Normally the enhancement node is only used for plan validation. But here the enhancement
73+
// is also used in execution phase. In this case an empty typeUrlPrefix needs to be passed,
74+
// so that it can be correctly parsed into json string on the cpp side.
75+
BackendsApiManager.getTransformerApiInstance.packPBMessage(
76+
TypeBuilder.makeStruct(false, inputTypeNodes.asJava).toProtobuf)
77+
}
78+
79+
def toSubstraitExpression(
80+
expr: Expression,
81+
attributeSeq: Seq[Attribute],
82+
context: SubstraitContext): ExpressionNode = {
83+
ExpressionConverter
84+
.replaceWithExpressionTransformer(expr, attributeSeq)
85+
.doTransform(context.registeredFunction)
86+
}
5887
}

0 commit comments

Comments
 (0)