Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TypedOneHotEncoder #322

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions ml/src/main/scala/frameless/ml/feature/TypedOneHotEncoder.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package frameless.ml.feature

import frameless.ml.{AppendTransformer, TypedEstimator}
import frameless.ml.feature.TypedOneHotEncoder.HandleInvalid
import frameless.ml.internals.UnaryInputsChecker

import org.apache.spark.ml.Estimator
import org.apache.spark.ml.feature.{OneHotEncoder, OneHotEncoderModel}
import org.apache.spark.ml.linalg.Vector

/**
* A one-hot encoder that maps a column of category indices to a column of binary vectors, with
* at most a single one-value per row that indicates the input category index.
*
* @see `TypedStringIndexer` for converting categorical values into category indices
*/
class TypedOneHotEncoder[Inputs] private[ml](oneHotEncoder: OneHotEncoder, inputCol: String)
extends TypedEstimator[Inputs, TypedOneHotEncoder.Outputs, OneHotEncoderModel] {

override val estimator: Estimator[OneHotEncoderModel] = oneHotEncoder
.setInputCols(Array(inputCol))
.setOutputCols(Array(AppendTransformer.tempColumnName))

def setHandleInvalid(value: HandleInvalid): TypedOneHotEncoder[Inputs] =
copy(oneHotEncoder.setHandleInvalid(value.sparkValue))

def setDropLast(value: Boolean): TypedOneHotEncoder[Inputs] =
copy(oneHotEncoder.setDropLast(value))

private def copy(newOneHotEncoder: OneHotEncoder): TypedOneHotEncoder[Inputs] =
new TypedOneHotEncoder[Inputs](newOneHotEncoder, inputCol)
}

object TypedOneHotEncoder {

case class Outputs(output: Vector)

final class HandleInvalid private(val sparkValue: String) extends AnyVal

object HandleInvalid {
val Error = new HandleInvalid("error")
val Keep = new HandleInvalid("keep")
}

def apply[T](implicit inputsChecker: UnaryInputsChecker[T, Int]): TypedOneHotEncoder[T] =
new TypedOneHotEncoder[T](new OneHotEncoder(), inputsChecker.inputCol)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package frameless.ml.feature

import frameless.ml.FramelessMlSuite
import frameless.ml.feature.TypedOneHotEncoder.HandleInvalid

import org.apache.spark.ml.linalg._
import org.scalacheck.{Arbitrary, Gen}
import org.scalacheck.Prop._
import shapeless.test.illTyped

final class TypedOneHotEncoderTests extends FramelessMlSuite {

test(".fit() returns a correct TypedTransformer") {
implicit val arbInt = Arbitrary(Gen.choose(0, 99))
def prop[A: TypedEncoder : Arbitrary] = forAll { (x2: X2[Int, A], dropLast: Boolean) =>
val encoder = TypedOneHotEncoder[X1[Int]].setDropLast(dropLast)
val inputs = 0.to(x2.a).map(i => X2(i, x2.b))
val ds = TypedDataset.create(inputs)
val model = encoder.fit(ds).run()
val resultDs = model.transform(TypedDataset.create(Seq(x2))).as[X3[Int, A, Vector]]
val result = resultDs.collect.run()
if (dropLast) {
result == Seq (X3(x2.a, x2.b,
Vectors.sparse(x2.a, Array.emptyIntArray, Array.emptyDoubleArray)))
} else {
result == Seq (X3(x2.a, x2.b,
Vectors.sparse(x2.a + 1, Array(x2.a), Array(1.0))))
}
}

check(prop[Double])
check(prop[String])
}

test("param setting is retained") {
implicit val arbHandleInvalid: Arbitrary[HandleInvalid] = Arbitrary {
Gen.oneOf(HandleInvalid.Keep, HandleInvalid.Error)
}

val prop = forAll { handleInvalid: HandleInvalid =>
val encoder = TypedOneHotEncoder[X1[Int]]
.setHandleInvalid(handleInvalid)
val ds = TypedDataset.create(Seq(X1(1)))
val model = encoder.fit(ds).run()

model.transformer.getHandleInvalid == handleInvalid.sparkValue
}

check(prop)
}

test("apply() compiles only with correct inputs") {
illTyped("TypedOneHotEncoder.apply[Double]()")
illTyped("TypedOneHotEncoder.apply[X1[Double]]()")
illTyped("TypedOneHotEncoder.apply[X2[String, Long]]()")
}
}