Skip to content

Commit

Permalink
Merge pull request #231 from javascriptdata/classification-twod
Browse files Browse the repository at this point in the history
SGD classifier can now accept 2D y values
  • Loading branch information
dcrescim authored May 22, 2022
2 parents 4975d96 + 10141cd commit 3e4ce65
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 13 deletions.
62 changes: 62 additions & 0 deletions src/linear_model/LogisticRegression.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,68 @@ describe('LogisticRegression', function () {
expect(results.arraySync()).toEqual([0, 0, 0, 1, 1, 1])
expect(logreg.score(X, y) > 0.5).toBe(true)
}, 30000)
it('Test of the function used with 2 classes (one hot)', async function () {
let X = [
[0, -1],
[1, 0],
[1, 1],
[1, -1],
[2, 0],
[2, 1],
[2, -1],
[3, 2],
[0, 4],
[1, 3],
[1, 4],
[1, 5],
[2, 3],
[2, 4],
[2, 5],
[3, 4]
]
let y = [
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1]
]

let Xtest = [
[0, -2],
[1, 0.5],
[1.5, -1],
[1, 4.5],
[2, 3.5],
[1.5, 5]
]

let logreg = new LogisticRegression({ penalty: 'none' })
await logreg.fit(X, y)
let probabilities = logreg.predictProba(X)
expect(probabilities instanceof tf.Tensor).toBe(true)
let results = logreg.predict(Xtest) // compute results of the training set
expect(results.arraySync()).toEqual([
[1, 0],
[1, 0],
[1, 0],
[0, 1],
[0, 1],
[0, 1]
])
expect(logreg.score(X, y) > 0.5).toBe(true)
}, 30000)
it('Test of the prediction with 3 classes', async function () {
let X = [
[0, -1],
Expand Down
30 changes: 18 additions & 12 deletions src/linear_model/SgdClassifier.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
* ==========================================================================
*/

import { convertToNumericTensor1D, convertToNumericTensor2D } from '../utils'
import {
convertToNumericTensor1D_2D,
convertToNumericTensor2D
} from '../utils'
import {
Scikit2D,
Scikit1D,
Expand All @@ -23,8 +26,7 @@ import {
Tensor2D,
Tensor,
ModelCompileArgs,
ModelFitArgs,
RecursiveArray
ModelFitArgs
} from '../types'
import { OneHotEncoder } from '../preprocessing/OneHotEncoder'
import { assert } from '../typesUtils'
Expand Down Expand Up @@ -103,6 +105,7 @@ export class SGDClassifier extends ClassifierMixin {
lossType: LossTypes
oneHot: OneHotEncoder
tf: any
isMultiOutput: boolean

constructor({
modelFitArgs,
Expand All @@ -119,6 +122,7 @@ export class SGDClassifier extends ClassifierMixin {
this.denseLayerArgs = denseLayerArgs
this.optimizerType = optimizerType
this.lossType = lossType
this.isMultiOutput = false
// Next steps: Implement "drop" mechanics for OneHotEncoder
// There is a possibility to do a drop => if_binary which would
// squash down on the number of variables that we'd have to learn
Expand Down Expand Up @@ -200,12 +204,17 @@ export class SGDClassifier extends ClassifierMixin {
* // lr model weights have been updated
*/

public async fit(X: Scikit2D, y: Scikit1D): Promise<SGDClassifier> {
public async fit(
X: Scikit2D,
y: Scikit1D | Scikit2D
): Promise<SGDClassifier> {
let XTwoD = convertToNumericTensor2D(X)
let yOneD = convertToNumericTensor1D(y)
let yOneD = convertToNumericTensor1D_2D(y)

const yTwoD = this.initializeModelForClassification(yOneD)

if (yOneD.shape.length > 1) {
this.isMultiOutput = true
}
if (this.model.layers.length === 0) {
this.initializeModel(XTwoD, yTwoD)
}
Expand Down Expand Up @@ -344,6 +353,9 @@ export class SGDClassifier extends ClassifierMixin {
public predict(X: Scikit2D): Tensor1D {
assert(this.model.layers.length > 0, 'Need to call "fit" before "predict"')
const y2D = this.predictProba(X)
if (this.isMultiOutput) {
return this.tf.oneHot(y2D.argMax(1), y2D.shape[1])
}
return this.tf.tensor1d(this.oneHot.inverseTransform(y2D))
}

Expand Down Expand Up @@ -418,10 +430,4 @@ export class SGDClassifier extends ClassifierMixin {

return intercept
}

private getModelWeight(): Promise<RecursiveArray<number>> {
return Promise.all(
this.model.getWeights().map((weight: any) => weight.array())
)
}
}
13 changes: 12 additions & 1 deletion src/mixins.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import { Scikit2D, Scikit1D, Tensor2D, Tensor1D } from './types'
import { r2Score, accuracyScore } from './metrics/metrics'
import { Serialize } from './simpleSerializer'
import { assert, isScikit2D } from './typesUtils'
import { convertToNumericTensor1D_2D } from './utils'
export class TransformerMixin extends Serialize {
// We assume that fit and transform exist
[x: string]: any
Expand Down Expand Up @@ -35,8 +37,17 @@ export class ClassifierMixin extends Serialize {
[x: string]: any

EstimatorType = 'classifier'
public score(X: Scikit2D, y: Scikit1D): number {
public score(X: Scikit2D, y: Scikit1D | Scikit2D): number {
const yPred = this.predict(X)
const yTrue = convertToNumericTensor1D_2D(y)
assert(
yPred.shape.length === yTrue.shape.length,
"The shape of the model output doesn't match the shape of the actual y values"
)

if (isScikit2D(y)) {
return accuracyScore(yTrue.argMax(1) as Scikit1D, yPred.argMax(1))
}
return accuracyScore(y, yPred)
}
}
Expand Down

0 comments on commit 3e4ce65

Please sign in to comment.