Skip to content

Commit 570b8f3

Browse files
imatiach-msftsrowen
authored andcommitted
[SPARK-24102][ML][MLLIB] ML Evaluators should use weight column - added weight column for regression evaluator
## What changes were proposed in this pull request? The evaluators BinaryClassificationEvaluator, RegressionEvaluator, and MulticlassClassificationEvaluator and the corresponding metrics classes BinaryClassificationMetrics, RegressionMetrics and MulticlassMetrics should use sample weight data. I've closed the PR: apache#16557 as recommended in favor of creating three pull requests, one for each of the evaluators (binary/regression/multiclass) to make it easier to review/update. The updates to the regression metrics were based on (and updated with new changes based on comments): https://issues.apache.org/jira/browse/SPARK-11520 ("RegressionMetrics should support instance weights") but the pull request was closed as the changes were never checked in. ## How was this patch tested? I added tests to the metrics class. Closes apache#17085 from imatiach-msft/ilmat/regression-evaluate. Authored-by: Ilya Matiach <[email protected]> Signed-off-by: Sean Owen <[email protected]>
1 parent 79e36e2 commit 570b8f3

File tree

6 files changed

+106
-29
lines changed

6 files changed

+106
-29
lines changed

mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation
1919

2020
import org.apache.spark.annotation.{Experimental, Since}
2121
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
22-
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
22+
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol}
2323
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
2424
import org.apache.spark.mllib.evaluation.RegressionMetrics
2525
import org.apache.spark.sql.{Dataset, Row}
@@ -33,7 +33,8 @@ import org.apache.spark.sql.types.{DoubleType, FloatType}
3333
@Since("1.4.0")
3434
@Experimental
3535
final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
36-
extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable {
36+
extends Evaluator with HasPredictionCol with HasLabelCol
37+
with HasWeightCol with DefaultParamsWritable {
3738

3839
@Since("1.4.0")
3940
def this() = this(Identifiable.randomUID("regEval"))
@@ -69,6 +70,10 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
6970
@Since("1.4.0")
7071
def setLabelCol(value: String): this.type = set(labelCol, value)
7172

73+
/** @group setParam */
74+
@Since("3.0.0")
75+
def setWeightCol(value: String): this.type = set(weightCol, value)
76+
7277
setDefault(metricName -> "rmse")
7378

7479
@Since("2.0.0")
@@ -77,11 +82,13 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
7782
SchemaUtils.checkColumnTypes(schema, $(predictionCol), Seq(DoubleType, FloatType))
7883
SchemaUtils.checkNumericType(schema, $(labelCol))
7984

80-
val predictionAndLabels = dataset
81-
.select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType))
85+
val predictionAndLabelsWithWeights = dataset
86+
.select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType),
87+
if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)))
8288
.rdd
83-
.map { case Row(prediction: Double, label: Double) => (prediction, label) }
84-
val metrics = new RegressionMetrics(predictionAndLabels)
89+
.map { case Row(prediction: Double, label: Double, weight: Double) =>
90+
(prediction, label, weight) }
91+
val metrics = new RegressionMetrics(predictionAndLabelsWithWeights)
8592
val metric = $(metricName) match {
8693
case "rmse" => metrics.rootMeanSquaredError
8794
case "mse" => metrics.meanSquaredError

mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,18 @@ import org.apache.spark.sql.DataFrame
2727
/**
2828
* Evaluator for regression.
2929
*
30-
* @param predictionAndObservations an RDD of (prediction, observation) pairs
30+
* @param predAndObsWithOptWeight an RDD of either (prediction, observation, weight)
31+
* or (prediction, observation) pairs
3132
* @param throughOrigin True if the regression is through the origin. For example, in linear
3233
* regression, it will be true without fitting intercept.
3334
*/
3435
@Since("1.2.0")
3536
class RegressionMetrics @Since("2.0.0") (
36-
predictionAndObservations: RDD[(Double, Double)], throughOrigin: Boolean)
37+
predAndObsWithOptWeight: RDD[_ <: Product], throughOrigin: Boolean)
3738
extends Logging {
3839

3940
@Since("1.2.0")
40-
def this(predictionAndObservations: RDD[(Double, Double)]) =
41+
def this(predictionAndObservations: RDD[_ <: Product]) =
4142
this(predictionAndObservations, false)
4243

4344
/**
@@ -52,22 +53,27 @@ class RegressionMetrics @Since("2.0.0") (
5253
* Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors.
5354
*/
5455
private lazy val summary: MultivariateStatisticalSummary = {
55-
val summary: MultivariateStatisticalSummary = predictionAndObservations.map {
56-
case (prediction, observation) => Vectors.dense(observation, observation - prediction)
56+
val summary: MultivariateStatisticalSummary = predAndObsWithOptWeight.map {
57+
case (prediction: Double, observation: Double, weight: Double) =>
58+
(Vectors.dense(observation, observation - prediction), weight)
59+
case (prediction: Double, observation: Double) =>
60+
(Vectors.dense(observation, observation - prediction), 1.0)
5761
}.treeAggregate(new MultivariateOnlineSummarizer())(
58-
(summary, v) => summary.add(v),
62+
(summary, sample) => summary.add(sample._1, sample._2),
5963
(sum1, sum2) => sum1.merge(sum2)
6064
)
6165
summary
6266
}
6367

6468
private lazy val SSy = math.pow(summary.normL2(0), 2)
6569
private lazy val SSerr = math.pow(summary.normL2(1), 2)
66-
private lazy val SStot = summary.variance(0) * (summary.count - 1)
70+
private lazy val SStot = summary.variance(0) * (summary.weightSum - 1)
6771
private lazy val SSreg = {
6872
val yMean = summary.mean(0)
69-
predictionAndObservations.map {
70-
case (prediction, _) => math.pow(prediction - yMean, 2)
73+
predAndObsWithOptWeight.map {
74+
case (prediction: Double, _: Double, weight: Double) =>
75+
math.pow(prediction - yMean, 2) * weight
76+
case (prediction: Double, _: Double) => math.pow(prediction - yMean, 2)
7177
}.sum()
7278
}
7379

@@ -79,7 +85,7 @@ class RegressionMetrics @Since("2.0.0") (
7985
*/
8086
@Since("1.2.0")
8187
def explainedVariance: Double = {
82-
SSreg / summary.count
88+
SSreg / summary.weightSum
8389
}
8490

8591
/**
@@ -88,7 +94,7 @@ class RegressionMetrics @Since("2.0.0") (
8894
*/
8995
@Since("1.2.0")
9096
def meanAbsoluteError: Double = {
91-
summary.normL1(1) / summary.count
97+
summary.normL1(1) / summary.weightSum
9298
}
9399

94100
/**
@@ -97,7 +103,7 @@ class RegressionMetrics @Since("2.0.0") (
97103
*/
98104
@Since("1.2.0")
99105
def meanSquaredError: Double = {
100-
SSerr / summary.count
106+
SSerr / summary.weightSum
101107
}
102108

103109
/**

mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
5252
private var totalCnt: Long = 0
5353
private var totalWeightSum: Double = 0.0
5454
private var weightSquareSum: Double = 0.0
55-
private var weightSum: Array[Double] = _
55+
private var currWeightSum: Array[Double] = _
5656
private var nnz: Array[Long] = _
5757
private var currMax: Array[Double] = _
5858
private var currMin: Array[Double] = _
@@ -78,7 +78,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
7878
currM2n = Array.ofDim[Double](n)
7979
currM2 = Array.ofDim[Double](n)
8080
currL1 = Array.ofDim[Double](n)
81-
weightSum = Array.ofDim[Double](n)
81+
currWeightSum = Array.ofDim[Double](n)
8282
nnz = Array.ofDim[Long](n)
8383
currMax = Array.fill[Double](n)(Double.MinValue)
8484
currMin = Array.fill[Double](n)(Double.MaxValue)
@@ -91,7 +91,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
9191
val localCurrM2n = currM2n
9292
val localCurrM2 = currM2
9393
val localCurrL1 = currL1
94-
val localWeightSum = weightSum
94+
val localWeightSum = currWeightSum
9595
val localNumNonzeros = nnz
9696
val localCurrMax = currMax
9797
val localCurrMin = currMin
@@ -139,8 +139,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
139139
weightSquareSum += other.weightSquareSum
140140
var i = 0
141141
while (i < n) {
142-
val thisNnz = weightSum(i)
143-
val otherNnz = other.weightSum(i)
142+
val thisNnz = currWeightSum(i)
143+
val otherNnz = other.currWeightSum(i)
144144
val totalNnz = thisNnz + otherNnz
145145
val totalCnnz = nnz(i) + other.nnz(i)
146146
if (totalNnz != 0.0) {
@@ -157,7 +157,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
157157
currMax(i) = math.max(currMax(i), other.currMax(i))
158158
currMin(i) = math.min(currMin(i), other.currMin(i))
159159
}
160-
weightSum(i) = totalNnz
160+
currWeightSum(i) = totalNnz
161161
nnz(i) = totalCnnz
162162
i += 1
163163
}
@@ -170,7 +170,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
170170
this.totalCnt = other.totalCnt
171171
this.totalWeightSum = other.totalWeightSum
172172
this.weightSquareSum = other.weightSquareSum
173-
this.weightSum = other.weightSum.clone()
173+
this.currWeightSum = other.currWeightSum.clone()
174174
this.nnz = other.nnz.clone()
175175
this.currMax = other.currMax.clone()
176176
this.currMin = other.currMin.clone()
@@ -189,7 +189,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
189189
val realMean = Array.ofDim[Double](n)
190190
var i = 0
191191
while (i < n) {
192-
realMean(i) = currMean(i) * (weightSum(i) / totalWeightSum)
192+
realMean(i) = currMean(i) * (currWeightSum(i) / totalWeightSum)
193193
i += 1
194194
}
195195
Vectors.dense(realMean)
@@ -214,8 +214,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
214214
val len = currM2n.length
215215
while (i < len) {
216216
// We prevent variance from negative value caused by numerical error.
217-
realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
218-
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
217+
realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * currWeightSum(i) *
218+
(totalWeightSum - currWeightSum(i)) / totalWeightSum) / denominator, 0.0)
219219
i += 1
220220
}
221221
}
@@ -229,6 +229,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
229229
@Since("1.1.0")
230230
override def count: Long = totalCnt
231231

232+
/**
233+
* Sum of weights.
234+
*/
235+
override def weightSum: Double = totalWeightSum
236+
232237
/**
233238
* Number of nonzero elements in each dimension.
234239
*

mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ trait MultivariateStatisticalSummary {
4444
@Since("1.0.0")
4545
def count: Long
4646

47+
/**
48+
* Sum of weights.
49+
*/
50+
@Since("3.0.0")
51+
def weightSum: Double
52+
4753
/**
4854
* Number of nonzero elements (including explicitly presented zero values) in each column.
4955
*/

mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,4 +133,54 @@ class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
133133
"root mean squared error mismatch")
134134
assert(metrics.r2 ~== 1.0 absTol eps, "r2 score mismatch")
135135
}
136+
137+
test("regression metrics with same (1.0) weight samples") {
138+
val predictionAndObservationWithWeight = sc.parallelize(
139+
Seq((2.25, 3.0, 1.0), (-0.25, -0.5, 1.0), (1.75, 2.0, 1.0), (7.75, 7.0, 1.0)), 2)
140+
val metrics = new RegressionMetrics(predictionAndObservationWithWeight, false)
141+
assert(metrics.explainedVariance ~== 8.79687 absTol eps,
142+
"explained variance regression score mismatch")
143+
assert(metrics.meanAbsoluteError ~== 0.5 absTol eps, "mean absolute error mismatch")
144+
assert(metrics.meanSquaredError ~== 0.3125 absTol eps, "mean squared error mismatch")
145+
assert(metrics.rootMeanSquaredError ~== 0.55901 absTol eps,
146+
"root mean squared error mismatch")
147+
assert(metrics.r2 ~== 0.95717 absTol eps, "r2 score mismatch")
148+
}
149+
150+
/**
151+
* The following values are hand calculated using the formula:
152+
* [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]]
153+
* preds = c(2.25, -0.25, 1.75, 7.75)
154+
* obs = c(3.0, -0.5, 2.0, 7.0)
155+
* weights = c(0.1, 0.2, 0.15, 0.05)
156+
* count = 4
157+
*
158+
* Weighted metrics can be calculated with MultivariateStatisticalSummary.
159+
* (observations, observations - predictions)
160+
* mean (1.7, 0.05)
161+
* variance (7.3, 0.3)
162+
* numNonZeros (0.5, 0.5)
163+
* max (7.0, 0.75)
164+
* min (-0.5, -0.75)
165+
* normL2 (2.0, 0.32596)
166+
* normL1 (1.05, 0.2)
167+
*
168+
* explainedVariance: sum(pow((preds - 1.7),2)*weight) / weightedCount = 5.2425
169+
* meanAbsoluteError: normL1(1) / weightedCount = 0.4
170+
* meanSquaredError: pow(normL2(1),2) / weightedCount = 0.2125
171+
* rootMeanSquaredError: sqrt(meanSquaredError) = 0.46098
172+
* r2: 1 - pow(normL2(1),2) / (variance(0) * (weightedCount - 1)) = 1.02910
173+
*/
174+
test("regression metrics with weighted samples") {
175+
val predictionAndObservationWithWeight = sc.parallelize(
176+
Seq((2.25, 3.0, 0.1), (-0.25, -0.5, 0.2), (1.75, 2.0, 0.15), (7.75, 7.0, 0.05)), 2)
177+
val metrics = new RegressionMetrics(predictionAndObservationWithWeight, false)
178+
assert(metrics.explainedVariance ~== 5.2425 absTol eps,
179+
"explained variance regression score mismatch")
180+
assert(metrics.meanAbsoluteError ~== 0.4 absTol eps, "mean absolute error mismatch")
181+
assert(metrics.meanSquaredError ~== 0.2125 absTol eps, "mean squared error mismatch")
182+
assert(metrics.rootMeanSquaredError ~== 0.46098 absTol eps,
183+
"root mean squared error mismatch")
184+
assert(metrics.r2 ~== 1.02910 absTol eps, "r2 score mismatch")
185+
}
136186
}

project/MimaExcludes.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,10 @@ object MimaExcludes {
531531
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseColMajor"),
532532
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseMatrix"),
533533
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparseMatrix"),
534-
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSizeInBytes")
534+
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSizeInBytes"),
535+
536+
// [SPARK-18693] Added weightSum to trait MultivariateStatisticalSummary
537+
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.stat.MultivariateStatisticalSummary.weightSum")
535538
) ++ Seq(
536539
// [SPARK-17019] Expose on-heap and off-heap memory usage in various places
537540
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded.copy"),

0 commit comments

Comments
 (0)