Skip to content

Commit 536e0e2

Browse files
lingvo-botcopybara-github
authored andcommitted
Adding a test case
PiperOrigin-RevId: 824605655
1 parent 5ed7725 commit 536e0e2

File tree

1 file changed

+37
-2
lines changed

1 file changed

+37
-2
lines changed

lingvo/core/metrics_test.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@
1414
# ==============================================================================
1515
"""Tests for metrics."""
1616

17+
from absl.testing import parameterized
1718
import lingvo.compat as tf
1819
from lingvo.core import metrics
1920
from lingvo.core import py_utils
2021
from lingvo.core import test_utils
2122
import numpy as np
2223

2324

24-
class MetricsTest(test_utils.TestCase):
25+
class MetricsTest(parameterized.TestCase, test_utils.TestCase):
2526

2627
def testAverageMetric(self):
2728
m = metrics.AverageMetric()
@@ -179,7 +180,7 @@ def _CreateSummary(self, name):
179180
summary = m.Summary('test')
180181
# Reservoir sampling will sample values 5 and 3 to remain with the current
181182
# seed.
182-
self.assertEqual(2, len(summary.value))
183+
self.assertLen(summary.value, 2)
183184
self.assertEqual(5, summary.value[0].simple_value)
184185
self.assertEqual(3, summary.value[1].simple_value)
185186

@@ -273,5 +274,39 @@ def testGroupPairAUCMetric(self):
273274
left = right
274275
self.assertEqual(pair_m.value, group_m.value)
275276

277+
@parameterized.named_parameters(
278+
dict(
279+
testcase_name='all_positive_first',
280+
target=[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0],
281+
logits=[0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0],
282+
expected_auc=0.0, # 0.0 because all pairs have the same label.
283+
),
284+
dict(
285+
testcase_name='all_negative_first',
286+
target=[0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0],
287+
logits=[1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0],
288+
expected_auc=0.0, # 0.0 because all pairs have the same label.
289+
),
290+
dict(
291+
testcase_name='all_positive_first_except_first_group',
292+
target=[0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0],
293+
logits=[1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0],
294+
expected_auc=0.17,
295+
),
296+
)
297+
def testGroupPairAUCMetricTargetSorting(
298+
self, *, target, logits, expected_auc
299+
):
300+
if not metrics.HAS_SKLEARN:
301+
self.skipTest('sklearn is not installed.')
302+
group_m = metrics.GroupPairAUCMetric()
303+
group_ids = [0, 0, 1, 1, 2, 2, 3, 3]
304+
weight = [1.0] * 8
305+
group_m.UpdateRaw(
306+
group_ids=group_ids, target=target, logits=logits, weight=weight
307+
)
308+
self.assertAlmostEqual(group_m.value, expected_auc, places=2)
309+
310+
276311
if __name__ == '__main__':
277312
test_utils.main()

0 commit comments

Comments
 (0)