From 6a068e091285ccdf075c4095c40350818cc8ef33 Mon Sep 17 00:00:00 2001 From: manuzhang Date: Sun, 9 Sep 2018 16:43:27 +0800 Subject: [PATCH 1/7] Add TypedOneHotEncoder --- .../ml/feature/TypedOneHotEncoder.scala | 47 +++++++++++++++ .../ml/feature/TypedOneHotEncoderTests.scala | 57 +++++++++++++++++++ 2 files changed, 104 insertions(+) create mode 100644 ml/src/main/scala/frameless/ml/feature/TypedOneHotEncoder.scala create mode 100644 ml/src/test/scala/frameless/ml/feature/TypedOneHotEncoderTests.scala diff --git a/ml/src/main/scala/frameless/ml/feature/TypedOneHotEncoder.scala b/ml/src/main/scala/frameless/ml/feature/TypedOneHotEncoder.scala new file mode 100644 index 00000000..23ebc22c --- /dev/null +++ b/ml/src/main/scala/frameless/ml/feature/TypedOneHotEncoder.scala @@ -0,0 +1,47 @@ +package frameless +package ml +package feature + +import frameless.ml.feature.TypedOneHotEncoder.HandleInvalid +import frameless.ml.internals.UnaryInputsChecker +import org.apache.spark.ml.Estimator +import org.apache.spark.ml.feature.{OneHotEncoderEstimator, OneHotEncoderModel} +import org.apache.spark.ml.linalg.Vector + +/** + * A one-hot encoder that maps a column of category indices to a column of binary vectors, with + * at most a single one-value per row that indicates the input category index. + * + * @see `TypedStringIndexer` for converting categorical values into category indices + */ +class TypedOneHotEncoder[Inputs] private[ml](oneHotEncoder: OneHotEncoderEstimator, inputCol: String) + extends TypedEstimator[Inputs, TypedOneHotEncoder.Outputs, OneHotEncoderModel] { + + override val estimator: Estimator[OneHotEncoderModel] = oneHotEncoder + .setInputCols(Array(inputCol)) + .setOutputCols(Array(AppendTransformer.tempColumnName)) + + def setHandleInvalid(value: HandleInvalid): TypedOneHotEncoder[Inputs] = + copy(oneHotEncoder.setHandleInvalid(value.sparkValue)) + + def setDropLast(value: Boolean): TypedOneHotEncoder[Inputs] = + copy(oneHotEncoder.setDropLast(value)) + + private def copy(newOneHotEncoder: OneHotEncoderEstimator): TypedOneHotEncoder[Inputs] = + new TypedOneHotEncoder[Inputs](newOneHotEncoder, inputCol) +} + +object TypedOneHotEncoder { + + case class Outputs(output: Vector) + + sealed abstract class HandleInvalid(val sparkValue: String) + object HandleInvalid { + case object Error extends HandleInvalid("error") + case object Keep extends HandleInvalid("keep") + } + + def apply[Inputs](implicit inputsChecker: UnaryInputsChecker[Inputs, Int]): TypedOneHotEncoder[Inputs] = { + new TypedOneHotEncoder[Inputs](new OneHotEncoderEstimator(), inputsChecker.inputCol) + } +} diff --git a/ml/src/test/scala/frameless/ml/feature/TypedOneHotEncoderTests.scala b/ml/src/test/scala/frameless/ml/feature/TypedOneHotEncoderTests.scala new file mode 100644 index 00000000..2f53c1a1 --- /dev/null +++ b/ml/src/test/scala/frameless/ml/feature/TypedOneHotEncoderTests.scala @@ -0,0 +1,57 @@ +package frameless +package ml +package feature + +import frameless.ml.feature.TypedOneHotEncoder.HandleInvalid +import org.apache.spark.ml.linalg._ +import org.scalacheck.{Arbitrary, Gen} +import org.scalacheck.Prop._ +import shapeless.test.illTyped + +class TypedOneHotEncoderTests extends FramelessMlSuite { + + test(".fit() returns a correct TypedTransformer") { + implicit val arbInt = Arbitrary(Gen.choose(0, 99)) + def prop[A: TypedEncoder : Arbitrary] = forAll { (x2: X2[Int, A], dropLast: Boolean) => + val encoder = TypedOneHotEncoder[X1[Int]].setDropLast(dropLast) + val inputs = 0.to(x2.a).map(i => X2(i, x2.b)) + val ds = TypedDataset.create(inputs) + val model = encoder.fit(ds).run() + val resultDs = model.transform(TypedDataset.create(Seq(x2))).as[X3[Int, A, Vector]] + val result = resultDs.collect.run() + if (dropLast) { + result == Seq (X3(x2.a, x2.b, + Vectors.sparse(x2.a, Array.emptyIntArray, Array.emptyDoubleArray))) + } else { + result == Seq (X3(x2.a, x2.b, + Vectors.sparse(x2.a + 1, Array(x2.a), Array(1.0)))) + } + } + + check(prop[Double]) + check(prop[String]) + } + + test("param setting is retained") { + implicit val arbHandleInvalid: Arbitrary[HandleInvalid] = Arbitrary { + Gen.oneOf(HandleInvalid.Keep, HandleInvalid.Error) + } + + val prop = forAll { handleInvalid: HandleInvalid => + val encoder = TypedOneHotEncoder[X1[Int]] + .setHandleInvalid(handleInvalid) + val ds = TypedDataset.create(Seq(X1(1))) + val model = encoder.fit(ds).run() + + model.transformer.getHandleInvalid == handleInvalid.sparkValue + } + + check(prop) + } + + test("create() compiles only with correct inputs") { + illTyped("TypedOneHotEncoder.create[Double]()") + illTyped("TypedOneHotEncoder.create[X1[Double]]()") + illTyped("TypedOneHotEncoder.create[X2[String, Long]]()") + } +} From b51b33783bef7da53caf61ad3d47006cfb09c8a4 Mon Sep 17 00:00:00 2001 From: manuzhang Date: Fri, 28 Sep 2018 10:22:17 +0800 Subject: [PATCH 2/7] Fix testing TypedOneHotEncoder#apply --- .../frameless/ml/feature/TypedOneHotEncoderTests.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ml/src/test/scala/frameless/ml/feature/TypedOneHotEncoderTests.scala b/ml/src/test/scala/frameless/ml/feature/TypedOneHotEncoderTests.scala index 2f53c1a1..0d563e0f 100644 --- a/ml/src/test/scala/frameless/ml/feature/TypedOneHotEncoderTests.scala +++ b/ml/src/test/scala/frameless/ml/feature/TypedOneHotEncoderTests.scala @@ -49,9 +49,9 @@ class TypedOneHotEncoderTests extends FramelessMlSuite { check(prop) } - test("create() compiles only with correct inputs") { - illTyped("TypedOneHotEncoder.create[Double]()") - illTyped("TypedOneHotEncoder.create[X1[Double]]()") - illTyped("TypedOneHotEncoder.create[X2[String, Long]]()") + test("apply() compiles only with correct inputs") { + illTyped("TypedOneHotEncoder.apply[Double]()") + illTyped("TypedOneHotEncoder.apply[X1[Double]]()") + illTyped("TypedOneHotEncoder.apply[X2[String, Long]]()") } } From 7569bffb401d3c2f08297581b92efde9292ed342 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Chantepie?= Date: Sat, 4 Sep 2021 16:08:25 +0200 Subject: [PATCH 3/7] Update TypedOneHotEncoder.scala --- .../frameless/ml/feature/TypedOneHotEncoder.scala | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/ml/src/main/scala/frameless/ml/feature/TypedOneHotEncoder.scala b/ml/src/main/scala/frameless/ml/feature/TypedOneHotEncoder.scala index 23ebc22c..c3543ab0 100644 --- a/ml/src/main/scala/frameless/ml/feature/TypedOneHotEncoder.scala +++ b/ml/src/main/scala/frameless/ml/feature/TypedOneHotEncoder.scala @@ -1,6 +1,4 @@ -package frameless -package ml -package feature +package frameless.ml.feature import frameless.ml.feature.TypedOneHotEncoder.HandleInvalid import frameless.ml.internals.UnaryInputsChecker @@ -35,13 +33,13 @@ object TypedOneHotEncoder { case class Outputs(output: Vector) - sealed abstract class HandleInvalid(val sparkValue: String) + final class HandleInvalid private(val sparkValue: String) extends AnyVal + object HandleInvalid { case object Error extends HandleInvalid("error") case object Keep extends HandleInvalid("keep") } - def apply[Inputs](implicit inputsChecker: UnaryInputsChecker[Inputs, Int]): TypedOneHotEncoder[Inputs] = { + def apply[T](implicit inputsChecker: UnaryInputsChecker[T, Int]): TypedOneHotEncoder[T] = new TypedOneHotEncoder[Inputs](new OneHotEncoderEstimator(), inputsChecker.inputCol) - } } From 9d980e63c1a9f3d8a8d27f4ba8eb19aa9be69b1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Chantepie?= Date: Sat, 4 Sep 2021 16:08:57 +0200 Subject: [PATCH 4/7] Update TypedOneHotEncoderTests.scala --- .../scala/frameless/ml/feature/TypedOneHotEncoderTests.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ml/src/test/scala/frameless/ml/feature/TypedOneHotEncoderTests.scala b/ml/src/test/scala/frameless/ml/feature/TypedOneHotEncoderTests.scala index 0d563e0f..7ab714ca 100644 --- a/ml/src/test/scala/frameless/ml/feature/TypedOneHotEncoderTests.scala +++ b/ml/src/test/scala/frameless/ml/feature/TypedOneHotEncoderTests.scala @@ -1,6 +1,4 @@ -package frameless -package ml -package feature +package frameless.ml.feature import frameless.ml.feature.TypedOneHotEncoder.HandleInvalid import org.apache.spark.ml.linalg._ From 9e3f78d3b4ac65e343aa65581f61e77d9fec945f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Chantepie?= Date: Sat, 4 Sep 2021 18:53:29 +0200 Subject: [PATCH 5/7] Update TypedOneHotEncoder.scala --- .../frameless/ml/feature/TypedOneHotEncoder.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/ml/src/main/scala/frameless/ml/feature/TypedOneHotEncoder.scala b/ml/src/main/scala/frameless/ml/feature/TypedOneHotEncoder.scala index c3543ab0..b792ad88 100644 --- a/ml/src/main/scala/frameless/ml/feature/TypedOneHotEncoder.scala +++ b/ml/src/main/scala/frameless/ml/feature/TypedOneHotEncoder.scala @@ -1,9 +1,11 @@ package frameless.ml.feature +import frameless.ml.TypedEstimator import frameless.ml.feature.TypedOneHotEncoder.HandleInvalid import frameless.ml.internals.UnaryInputsChecker + import org.apache.spark.ml.Estimator -import org.apache.spark.ml.feature.{OneHotEncoderEstimator, OneHotEncoderModel} +import org.apache.spark.ml.feature.{OneHotEncoder, OneHotEncoderModel} import org.apache.spark.ml.linalg.Vector /** @@ -12,7 +14,7 @@ import org.apache.spark.ml.linalg.Vector * * @see `TypedStringIndexer` for converting categorical values into category indices */ -class TypedOneHotEncoder[Inputs] private[ml](oneHotEncoder: OneHotEncoderEstimator, inputCol: String) +class TypedOneHotEncoder[Inputs] private[ml](oneHotEncoder: OneHotEncoder, inputCol: String) extends TypedEstimator[Inputs, TypedOneHotEncoder.Outputs, OneHotEncoderModel] { override val estimator: Estimator[OneHotEncoderModel] = oneHotEncoder @@ -25,7 +27,7 @@ class TypedOneHotEncoder[Inputs] private[ml](oneHotEncoder: OneHotEncoderEstimat def setDropLast(value: Boolean): TypedOneHotEncoder[Inputs] = copy(oneHotEncoder.setDropLast(value)) - private def copy(newOneHotEncoder: OneHotEncoderEstimator): TypedOneHotEncoder[Inputs] = + private def copy(newOneHotEncoder: OneHotEncoder): TypedOneHotEncoder[Inputs] = new TypedOneHotEncoder[Inputs](newOneHotEncoder, inputCol) } @@ -33,7 +35,7 @@ object TypedOneHotEncoder { case class Outputs(output: Vector) - final class HandleInvalid private(val sparkValue: String) extends AnyVal + sealed class HandleInvalid private(val sparkValue: String) extends AnyVal object HandleInvalid { case object Error extends HandleInvalid("error") @@ -41,5 +43,5 @@ object TypedOneHotEncoder { } def apply[T](implicit inputsChecker: UnaryInputsChecker[T, Int]): TypedOneHotEncoder[T] = - new TypedOneHotEncoder[Inputs](new OneHotEncoderEstimator(), inputsChecker.inputCol) + new TypedOneHotEncoder[T](new OneHotEncoder(), inputsChecker.inputCol) } From 11bf7d605f7aaee4f2cfe8ad3feee6b1904ebaf0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Chantepie?= Date: Sat, 4 Sep 2021 19:07:31 +0200 Subject: [PATCH 6/7] Update TypedOneHotEncoder.scala --- .../scala/frameless/ml/feature/TypedOneHotEncoder.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ml/src/main/scala/frameless/ml/feature/TypedOneHotEncoder.scala b/ml/src/main/scala/frameless/ml/feature/TypedOneHotEncoder.scala index b792ad88..89ad2bf6 100644 --- a/ml/src/main/scala/frameless/ml/feature/TypedOneHotEncoder.scala +++ b/ml/src/main/scala/frameless/ml/feature/TypedOneHotEncoder.scala @@ -1,6 +1,6 @@ package frameless.ml.feature -import frameless.ml.TypedEstimator +import frameless.ml.{AppendTransformer, TypedEstimator} import frameless.ml.feature.TypedOneHotEncoder.HandleInvalid import frameless.ml.internals.UnaryInputsChecker @@ -35,11 +35,11 @@ object TypedOneHotEncoder { case class Outputs(output: Vector) - sealed class HandleInvalid private(val sparkValue: String) extends AnyVal + final class HandleInvalid private(val sparkValue: String) extends AnyVal object HandleInvalid { - case object Error extends HandleInvalid("error") - case object Keep extends HandleInvalid("keep") + val Error = new HandleInvalid("error") + val Keep = new HandleInvalid("keep") } def apply[T](implicit inputsChecker: UnaryInputsChecker[T, Int]): TypedOneHotEncoder[T] = From 347f9bc8598c44ca8fe162c14139f6e0fce3edea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Chantepie?= Date: Sat, 4 Sep 2021 21:00:08 +0200 Subject: [PATCH 7/7] Update TypedOneHotEncoderTests.scala --- .../scala/frameless/ml/feature/TypedOneHotEncoderTests.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ml/src/test/scala/frameless/ml/feature/TypedOneHotEncoderTests.scala b/ml/src/test/scala/frameless/ml/feature/TypedOneHotEncoderTests.scala index 7ab714ca..5e6d445f 100644 --- a/ml/src/test/scala/frameless/ml/feature/TypedOneHotEncoderTests.scala +++ b/ml/src/test/scala/frameless/ml/feature/TypedOneHotEncoderTests.scala @@ -1,12 +1,14 @@ package frameless.ml.feature +import frameless.ml.FramelessMlSuite import frameless.ml.feature.TypedOneHotEncoder.HandleInvalid + import org.apache.spark.ml.linalg._ import org.scalacheck.{Arbitrary, Gen} import org.scalacheck.Prop._ import shapeless.test.illTyped -class TypedOneHotEncoderTests extends FramelessMlSuite { +final class TypedOneHotEncoderTests extends FramelessMlSuite { test(".fit() returns a correct TypedTransformer") { implicit val arbInt = Arbitrary(Gen.choose(0, 99))