From 936b448c209fa683beef2dfc3d7ead2c0ecb35e9 Mon Sep 17 00:00:00 2001 From: lukonik <81145822+lukonik@users.noreply.github.com> Date: Thu, 22 Aug 2024 20:00:22 +0400 Subject: [PATCH] Subject: Add R2Score metric. (#8169) (#8353) Body: FEATURE Co-authored-by: Matthew Soulanille --- tfjs-layers/src/exports_metrics.ts | 19 +++++++++++++++++++ tfjs-layers/src/metrics.ts | 12 +++++++++--- tfjs-layers/src/metrics_test.ts | 23 ++++++++++++++++++++++- 3 files changed, 50 insertions(+), 4 deletions(-) diff --git a/tfjs-layers/src/exports_metrics.ts b/tfjs-layers/src/exports_metrics.ts index dd6472c34f3..84ffc6e220d 100644 --- a/tfjs-layers/src/exports_metrics.ts +++ b/tfjs-layers/src/exports_metrics.ts @@ -314,3 +314,22 @@ export function MSE(yTrue: Tensor, yPred: Tensor): Tensor { export function mse(yTrue: Tensor, yPred: Tensor): Tensor { return losses.meanSquaredError(yTrue, yPred); } + +/** + * Computes R2 score. + * + * ```js + * const yTrue = tf.tensor2d([[0, 1], [3, 4]]); + * const yPred = tf.tensor2d([[0, 1], [-3, -4]]); + * const r2Score = tf.metrics.r2Score(yTrue, yPred); + * r2Score.print(); + * ``` + * @param yTrue Truth Tensor. + * @param yPred Prediction Tensor. + * @return R2 score Tensor. + * + * @doc {heading: 'Metrics', namespace: 'metrics'} + */ +export function r2Score(yTrue: Tensor, yPred: Tensor): Tensor { + return metrics.r2Score(yTrue, yPred); +} diff --git a/tfjs-layers/src/metrics.ts b/tfjs-layers/src/metrics.ts index a8080d8bf89..7c0f52d41a2 100644 --- a/tfjs-layers/src/metrics.ts +++ b/tfjs-layers/src/metrics.ts @@ -17,9 +17,7 @@ import {Tensor, tidy} from '@tensorflow/tfjs-core'; import * as K from './backend/tfjs_backend'; import {NotImplementedError, ValueError} from './errors'; -import {categoricalCrossentropy as categoricalCrossentropyLoss, cosineProximity, meanAbsoluteError, meanAbsolutePercentageError, meanSquaredError, sparseCategoricalCrossentropy as sparseCategoricalCrossentropyLoss} from './losses'; -import {binaryCrossentropy as lossBinaryCrossentropy} from './losses'; -import {lossesMap} from './losses'; +import {binaryCrossentropy as lossBinaryCrossentropy, categoricalCrossentropy as categoricalCrossentropyLoss, cosineProximity, lossesMap, meanAbsoluteError, meanAbsolutePercentageError, meanSquaredError, sparseCategoricalCrossentropy as sparseCategoricalCrossentropyLoss} from './losses'; import {LossOrMetricFn} from './types'; import * as util from './utils/generic_utils'; @@ -112,6 +110,14 @@ export function sparseTopKCategoricalAccuracy( throw new NotImplementedError(); } +export function r2Score(yTrue: Tensor, yPred: Tensor): Tensor { + return tidy(() => { + const sumSquaresResiduals = yTrue.sub(yPred).square().sum(); + const sumSquares = yTrue.sub(yTrue.mean()).square().sum(); + return tfc.scalar(1).sub(sumSquaresResiduals.div(sumSquares)); + }); +} + // Aliases. export const mse = meanSquaredError; export const MSE = meanSquaredError; diff --git a/tfjs-layers/src/metrics_test.ts b/tfjs-layers/src/metrics_test.ts index 3bcb3308bca..e34852d9f7d 100644 --- a/tfjs-layers/src/metrics_test.ts +++ b/tfjs-layers/src/metrics_test.ts @@ -16,7 +16,7 @@ import {scalar, Tensor, tensor, tensor1d, tensor2d} from '@tensorflow/tfjs-core' import {setEpsilon} from './backend/common'; import * as tfl from './index'; -import {binaryAccuracy, categoricalAccuracy, get, getLossOrMetricName} from './metrics'; +import {binaryAccuracy, categoricalAccuracy, get, getLossOrMetricName, r2Score} from './metrics'; import {LossOrMetricFn} from './types'; import {describeMathCPUAndGPU, describeMathCPUAndWebGL2, expectTensorsClose} from './utils/test_utils'; @@ -283,6 +283,27 @@ describeMathCPUAndGPU('recall metric', () => { }); }); +describeMathCPUAndGPU('r2Score', () => { + it('1D', () => { + const yTrue = tensor1d([3, -0.5, 2, 7, 4.2, 8.5, 1.3, 2.8, 6.7, 9.0]); + const yPred = tensor1d([2.5, 0.0, 2.1, 7.8, 4.0, 8.2, 1.4, 2.9, 6.5, 9.1]); + const score = r2Score(yTrue, yPred); + expectTensorsClose(score, scalar(0.985)); + }); + it('2D', () => { + const yTrue = tensor2d([ + [3, 2.5], [-0.5, 3.2], [2, 1.9], [7, 5.1], [4.2, 3.8], [8.5, 7.4], + [1.3, 0.6], [2.8, 2.1], [6.7, 5.3], [9.0, 8.7] + ]); + const yPred = tensor2d([ + [2.7, 2.3], [0.0, 3.1], [2.1, 1.8], [6.8, 5.0], [4.1, 3.7], [8.4, 7.2], + [1.4, 0.7], [2.9, 2.2], [6.6, 5.2], [9.2, 8.9] + ]); + const score = r2Score(yTrue, yPred); + expectTensorsClose(score, scalar(0.995)); + }); +}); + describe('metrics.get', () => { it('valid name, not alias', () => { expect(get('binaryAccuracy') === get('categoricalAccuracy')).toEqual(false);