|
14 | 14 | # ============================================================================== |
15 | 15 | """Tests for metrics.""" |
16 | 16 |
|
| 17 | +from absl.testing import parameterized |
17 | 18 | import lingvo.compat as tf |
18 | 19 | from lingvo.core import metrics |
19 | 20 | from lingvo.core import py_utils |
20 | 21 | from lingvo.core import test_utils |
21 | 22 | import numpy as np |
22 | 23 |
|
23 | 24 |
|
24 | | -class MetricsTest(test_utils.TestCase): |
| 25 | +class MetricsTest(parameterized.TestCase, test_utils.TestCase): |
25 | 26 |
|
26 | 27 | def testAverageMetric(self): |
27 | 28 | m = metrics.AverageMetric() |
@@ -179,7 +180,7 @@ def _CreateSummary(self, name): |
179 | 180 | summary = m.Summary('test') |
180 | 181 | # Reservoir sampling will sample values 5 and 3 to remain with the current |
181 | 182 | # seed. |
182 | | - self.assertEqual(2, len(summary.value)) |
| 183 | + self.assertLen(summary.value, 2) |
183 | 184 | self.assertEqual(5, summary.value[0].simple_value) |
184 | 185 | self.assertEqual(3, summary.value[1].simple_value) |
185 | 186 |
|
@@ -273,5 +274,39 @@ def testGroupPairAUCMetric(self): |
273 | 274 | left = right |
274 | 275 | self.assertEqual(pair_m.value, group_m.value) |
275 | 276 |
|
| 277 | + @parameterized.named_parameters( |
| 278 | + dict( |
| 279 | + testcase_name='all_positive_first', |
| 280 | + target=[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0], |
| 281 | + logits=[0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0], |
| 282 | + expected_auc=0.0, # 0.0 because all pairs have the same label. |
| 283 | + ), |
| 284 | + dict( |
| 285 | + testcase_name='all_negative_first', |
| 286 | + target=[0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0], |
| 287 | + logits=[1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0], |
| 288 | + expected_auc=0.0, # 0.0 because all pairs have the same label. |
| 289 | + ), |
| 290 | + dict( |
| 291 | + testcase_name='all_positive_first_except_first_group', |
| 292 | + target=[0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0], |
| 293 | + logits=[1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0], |
| 294 | + expected_auc=0.17, |
| 295 | + ), |
| 296 | + ) |
| 297 | + def testGroupPairAUCMetricTargetSorting( |
| 298 | + self, *, target, logits, expected_auc |
| 299 | + ): |
| 300 | + if not metrics.HAS_SKLEARN: |
| 301 | + self.skipTest('sklearn is not installed.') |
| 302 | + group_m = metrics.GroupPairAUCMetric() |
| 303 | + group_ids = [0, 0, 1, 1, 2, 2, 3, 3] |
| 304 | + weight = [1.0] * 8 |
| 305 | + group_m.UpdateRaw( |
| 306 | + group_ids=group_ids, target=target, logits=logits, weight=weight |
| 307 | + ) |
| 308 | + self.assertAlmostEqual(group_m.value, expected_auc, places=2) |
| 309 | + |
| 310 | + |
276 | 311 | if __name__ == '__main__': |
277 | 312 | test_utils.main() |
0 commit comments