@@ -21,26 +21,28 @@ import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.expression.ConverterUtils
import org.apache.gluten.extension.ValidationResult
import org.apache.gluten.metrics.MetricsUpdater
- import org.apache.gluten.substrait.`type`.{ColumnTypeNode, TypeBuilder}
+ import org.apache.gluten.substrait.`type`.ColumnTypeNode
import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.extensions.ExtensionBuilder
import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}
+ import org.apache.gluten.utils.SubstraitUtil

import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.FileFormat
+ import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
+ import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.types.MetadataBuilder

import com.google.protobuf.{Any, StringValue}
import org.apache.parquet.hadoop.ParquetOutputFormat

import java.util.Locale

- import scala.collection.JavaConverters._
import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable`

/**
@@ -56,7 +58,7 @@ case class WriteFilesExecTransformer(
    staticPartitions: TablePartitionSpec)
  extends UnaryTransformSupport {
  // Note: "metrics" is made transient to avoid sending driver-side metrics to tasks.
-   @transient override lazy val metrics =
+   @transient override lazy val metrics: Map[String, SQLMetric] =
    BackendsApiManager.getMetricsApiInstance.genWriteFilesTransformerMetrics(sparkContext)

  override def metricsUpdater(): MetricsUpdater =
@@ -66,27 +68,25 @@ case class WriteFilesExecTransformer(

  private val caseInsensitiveOptions = CaseInsensitiveMap(options)

-   def genWriteParameters(): Any = {
+   private def genWriteParameters(): Any = {
+     val fileFormatStr = fileFormat match {
+       case register: DataSourceRegister =>
+         register.shortName
+       case _ => "UnknownFileFormat"
+     }
    val compressionCodec =
      WriteFilesExecTransformer.getCompressionCodec(caseInsensitiveOptions).capitalize
    val writeParametersStr = new StringBuffer("WriteParameters:")
-     writeParametersStr.append("is").append(compressionCodec).append("=1").append("\n")
+     writeParametersStr.append("is").append(compressionCodec).append("=1")
+     writeParametersStr.append(";format=").append(fileFormatStr).append("\n")
+
    val message = StringValue
      .newBuilder()
      .setValue(writeParametersStr.toString)
      .build()
    BackendsApiManager.getTransformerApiInstance.packPBMessage(message)
  }

-   def createEnhancement(output: Seq[Attribute]): com.google.protobuf.Any = {
-     val inputTypeNodes = output.map {
-       attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)
-     }
-
-     BackendsApiManager.getTransformerApiInstance.packPBMessage(
-       TypeBuilder.makeStruct(false, inputTypeNodes.asJava).toProtobuf)
-   }
-
  def getRelNode(
      context: SubstraitContext,
      originalInputAttributes: Seq[Attribute],
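
For reference, the parameter string packed into the StringValue after this hunk gains a format field alongside the compression flag. Below is a minimal, self-contained sketch of what gets built, assuming a Parquet write with snappy compression; "parquet" stands in for DataSourceRegister.shortName and "snappy" for the resolved codec, both illustrative only:

// Sketch of the parameter string built by genWriteParameters above,
// under the assumptions stated in the lead-in.
object WriteParametersSketch extends App {
  val compressionCodec = "snappy".capitalize // "Snappy"
  val fileFormatStr = "parquet"              // would come from shortName
  val writeParametersStr = new StringBuffer("WriteParameters:")
  writeParametersStr.append("is").append(compressionCodec).append("=1")
  writeParametersStr.append(";format=").append(fileFormatStr).append("\n")
  // Prints: WriteParameters:isSnappy=1;format=parquet
  print(writeParametersStr)
}
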
@@ -118,10 +118,11 @@ case class WriteFilesExecTransformer(
    val extensionNode = if (!validation) {
      ExtensionBuilder.makeAdvancedExtension(
        genWriteParameters(),
-         createEnhancement(originalInputAttributes))
+         SubstraitUtil.createEnhancement(originalInputAttributes))
    } else {
      // Use an extension node to send the input types through the Substrait plan for validation.
-       ExtensionBuilder.makeAdvancedExtension(createEnhancement(originalInputAttributes))
+       ExtensionBuilder.makeAdvancedExtension(
+         SubstraitUtil.createEnhancement(originalInputAttributes))
    }
    RelBuilder.makeWriteRel(
      input,
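
The local createEnhancement helper deleted in the previous hunk is now called via SubstraitUtil. Judging from the removed body, the shared version presumably has the same shape; this is a sketch reconstructed from the deleted code, not the verified SubstraitUtil source:

// Presumed shape of SubstraitUtil.createEnhancement, reconstructed from the
// local method this PR deletes; the shared implementation may differ.
// Needs scala.collection.JavaConverters._ for asJava, which is why that
// import could be dropped from this file once the helper moved.
def createEnhancement(output: Seq[Attribute]): com.google.protobuf.Any = {
  // Convert each output attribute into a Substrait type node.
  val inputTypeNodes = output.map {
    attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)
  }
  // Pack the struct of input types into a protobuf Any for the extension node.
  BackendsApiManager.getTransformerApiInstance.packPBMessage(
    TypeBuilder.makeStruct(false, inputTypeNodes.asJava).toProtobuf)
}
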
@@ -133,7 +134,7 @@ case class WriteFilesExecTransformer(
      operatorId)
  }

-   private def getFinalChildOutput(): Seq[Attribute] = {
+   private def getFinalChildOutput: Seq[Attribute] = {
    val metadataExclusionList = conf
      .getConf(GlutenConfig.NATIVE_WRITE_FILES_COLUMN_METADATA_EXCLUSION_LIST)
      .split(",")
@@ -143,7 +144,7 @@ case class WriteFilesExecTransformer(
  }

  override protected def doValidateInternal(): ValidationResult = {
-     val finalChildOutput = getFinalChildOutput()
+     val finalChildOutput = getFinalChildOutput
    val validationResult =
      BackendsApiManager.getSettings.supportWriteFilesExec(
        fileFormat,
@@ -165,7 +166,7 @@ case class WriteFilesExecTransformer(
    val childCtx = child.asInstanceOf[TransformSupport].transform(context)
    val operatorId = context.nextOperatorId(this.nodeName)
    val currRel =
-       getRelNode(context, getFinalChildOutput(), operatorId, childCtx.root, validation = false)
+       getRelNode(context, getFinalChildOutput, operatorId, childCtx.root, validation = false)
    assert(currRel != null, "Write Rel should be valid")
    TransformContext(childCtx.outputAttributes, output, currRel)
  }
@@ -196,7 +197,7 @@ object WriteFilesExecTransformer {
    "__file_source_generated_metadata_col"
  )

-   def removeMetadata(attr: Attribute, metadataExclusionList: Seq[String]): Attribute = {
+   private def removeMetadata(attr: Attribute, metadataExclusionList: Seq[String]): Attribute = {
    val metadataKeys = INTERNAL_METADATA_KEYS ++ metadataExclusionList
    attr.withMetadata {
      var builder = new MetadataBuilder().withMetadata(attr.metadata)