
Commit 1de9154

yuandagits authored and facebook-github-bot committed
feat: add classification functions (facebookincubator#11792)
Summary: Add the classification functions from Presto to Velox: https://prestodb.io/docs/current/functions/aggregate.html#classification-metrics-aggregate-functions

All classification functions use `FixedDoubleHistogram`, a data structure that represents buckets of weights. The bucket indices are evenly distributed between the min and max values.

The classification functions differ only in the extraction phase; all other steps are the same. At a high level:

- addRawInput adds a value to either the true or the false weight bucket. The bucket is chosen by the prediction value, which is linearly mapped to a bucket based on min, max, and bucketCount by normalizing the prediction between min and max.
- The schema of the intermediate state is [version header][bucket count][min][max][weights].

Differential Revision: D66684198
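The bucket mapping described in the summary can be sketched as follows. This is a minimal Python illustration, not the Velox C++ code: the names `bucket_index` and `WeightHistograms` are hypothetical, and clamping a prediction equal to max into the last bucket is an assumption about how the upper bound might be handled.

```python
def bucket_index(value, lo, hi, bucket_count):
    """Linearly map a value in [lo, hi] to a bucket in [0, bucket_count - 1]."""
    if not lo <= value <= hi:
        raise ValueError(f"value {value} is outside [{lo}, {hi}]")
    # Normalize the prediction between lo and hi, then scale to a bucket.
    normalized = (value - lo) / (hi - lo)
    # Assumption: value == hi is clamped into the last bucket.
    return min(int(normalized * bucket_count), bucket_count - 1)


class WeightHistograms:
    """Hypothetical pair of weight histograms, one per outcome."""

    def __init__(self, bucket_count, lo=0.0, hi=1.0):
        self.bucket_count = bucket_count
        self.lo = lo
        self.hi = hi
        self.true_weights = [0.0] * bucket_count
        self.false_weights = [0.0] * bucket_count

    def add(self, outcome, prediction, weight=1.0):
        # The prediction picks the bucket; the outcome picks the histogram.
        index = bucket_index(prediction, self.lo, self.hi, self.bucket_count)
        target = self.true_weights if outcome else self.false_weights
        target[index] += weight


histograms = WeightHistograms(bucket_count=4)
histograms.add(True, 0.6, 2.0)
histograms.add(False, 0.1)
print(histograms.true_weights, histograms.false_weights)
```

With four buckets over [0, 1], a prediction of 0.6 falls in bucket 2 and a prediction of 0.1 falls in bucket 0, so each input row touches exactly one slot in one of the two histograms.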
1 parent ac13440 commit 1de9154

11 files changed, +1250 -7 lines changed

velox/docs/functions/presto/aggregate.rst

+195
@@ -411,6 +411,201 @@ __ https://www.cse.ust.hk/~raywong/comp5331/References/EfficientComputationOfFre
411411
As ``approx_percentile(x, w, percentages)``, but with a maximum rank error
412412
of ``accuracy``.
413413

414+
Classification Metrics Aggregate Functions
415+
------------------------------------------
416+
417+
The following functions each measure how some metric of a binary
418+
`confusion matrix <https://en.wikipedia.org/wiki/Confusion_matrix>`_ changes as a function of
419+
classification thresholds. They are meant to be used in conjunction with one another.
420+
421+
For example, to find the `precision-recall curve <https://en.wikipedia.org/wiki/Precision_and_recall>`_, use
422+
423+
.. code-block:: none
424+
425+
WITH
426+
recall_precision AS (
427+
SELECT
428+
CLASSIFICATION_RECALL(10000, correct, pred) AS recalls,
429+
CLASSIFICATION_PRECISION(10000, correct, pred) AS precisions
430+
FROM
431+
classification_dataset
432+
)
433+
SELECT
434+
recall,
435+
precision
436+
FROM
437+
recall_precision
438+
CROSS JOIN UNNEST(recalls, precisions) AS t(recall, precision)
439+
440+
To get the corresponding thresholds for these values, use
441+
442+
.. code-block:: none
443+
444+
WITH
445+
recall_precision AS (
446+
SELECT
447+
CLASSIFICATION_THRESHOLDS(10000, correct, pred) AS thresholds,
448+
CLASSIFICATION_RECALL(10000, correct, pred) AS recalls,
449+
CLASSIFICATION_PRECISION(10000, correct, pred) AS precisions
450+
FROM
451+
classification_dataset
452+
)
453+
SELECT
454+
threshold,
455+
recall,
456+
precision
457+
FROM
458+
recall_precision
459+
CROSS JOIN UNNEST(thresholds, recalls, precisions) AS t(threshold, recall, precision)
460+
461+
To find the `ROC curve <https://en.wikipedia.org/wiki/Receiver_operating_characteristic>`_, use
462+
463+
.. code-block:: none
464+
465+
WITH
466+
fallout_recall AS (
467+
SELECT
468+
CLASSIFICATION_FALL_OUT(10000, correct, pred) AS fallouts,
469+
CLASSIFICATION_RECALL(10000, correct, pred) AS recalls
470+
FROM
471+
classification_dataset
472+
)
473+
SELECT
474+
fallout,
475+
recall
476+
FROM
477+
fallout_recall
478+
CROSS JOIN UNNEST(fallouts, recalls) AS t(fallout, recall)
479+
480+
481+
.. function:: classification_miss_rate(buckets, y, x, weight) -> array<double>
482+
483+
Computes the miss-rate using up to ``buckets`` buckets. Returns
484+
an array of miss-rate values.
485+
486+
``y`` should be a boolean outcome value; ``x`` should be predictions, each
487+
between 0 and 1; ``weight`` should be non-negative values, indicating the weight of the instance.
488+
489+
The
490+
`miss-rate <https://en.wikipedia.org/wiki/Type_I_and_type_II_errors#False_positive_and_false_negative_rates>`_
491+
is defined as a sequence whose :math:`j`-th entry is
492+
493+
.. math ::
494+
495+
{
496+
\sum_{i \;|\; x_i \leq t_j \bigwedge y_i = 1} \left[ w_i \right]
497+
\over
498+
\sum_{i \;|\; x_i \leq t_j \bigwedge y_i = 1} \left[ w_i \right]
499+
+
500+
\sum_{i \;|\; x_i > t_j \bigwedge y_i = 1} \left[ w_i \right]
501+
},
502+
503+
where :math:`t_j` is the :math:`j`-th smallest threshold,
504+
and :math:`y_i`, :math:`x_i`, and :math:`w_i` are the :math:`i`-th
505+
entries of ``y``, ``x``, and ``weight``, respectively.
506+
507+
.. function:: classification_miss_rate(buckets, y, x) -> array<double>
508+
509+
This function is equivalent to the variant of
510+
:func:`!classification_miss_rate` that takes a ``weight``, with a per-item weight of ``1``.
511+
512+
.. function:: classification_fall_out(buckets, y, x, weight) -> array<double>
513+
514+
Computes the fall-out using up to ``buckets`` buckets. Returns
515+
an array of fall-out values.
516+
517+
``y`` should be a boolean outcome value; ``x`` should be predictions, each
518+
between 0 and 1; ``weight`` should be non-negative values, indicating the weight of the instance.
519+
520+
The
521+
`fall-out <https://en.wikipedia.org/wiki/Information_retrieval#Fall-out>`_
522+
is defined as a sequence whose :math:`j`-th entry is
523+
524+
.. math ::
525+
526+
{
527+
\sum_{i \;|\; x_i > t_j \bigwedge y_i = 0} \left[ w_i \right]
528+
\over
529+
\sum_{i \;|\; y_i = 0} \left[ w_i \right]
530+
},
531+
532+
where :math:`t_j` is the :math:`j`-th smallest threshold,
533+
and :math:`y_i`, :math:`x_i`, and :math:`w_i` are the :math:`i`-th
534+
entries of ``y``, ``x``, and ``weight``, respectively.
535+
536+
.. function:: classification_fall_out(buckets, y, x) -> array<double>
537+
538+
This function is equivalent to the variant of
539+
:func:`!classification_fall_out` that takes a ``weight``, with a per-item weight of ``1``.
540+
541+
.. function:: classification_precision(buckets, y, x, weight) -> array<double>
542+
543+
Computes the precision using up to ``buckets`` buckets. Returns
544+
an array of precision values.
545+
546+
``y`` should be a boolean outcome value; ``x`` should be predictions, each
547+
between 0 and 1; ``weight`` should be non-negative values, indicating the weight of the instance.
548+
549+
The
550+
`precision <https://en.wikipedia.org/wiki/Positive_and_negative_predictive_values>`_
551+
is defined as a sequence whose :math:`j`-th entry is
552+
553+
.. math ::
554+
555+
{
556+
\sum_{i \;|\; x_i > t_j \bigwedge y_i = 1} \left[ w_i \right]
557+
\over
558+
\sum_{i \;|\; x_i > t_j} \left[ w_i \right]
559+
},
560+
561+
where :math:`t_j` is the :math:`j`-th smallest threshold,
562+
and :math:`y_i`, :math:`x_i`, and :math:`w_i` are the :math:`i`-th
563+
entries of ``y``, ``x``, and ``weight``, respectively.
564+
565+
.. function:: classification_precision(buckets, y, x) -> array<double>
566+
567+
This function is equivalent to the variant of
568+
:func:`!classification_precision` that takes a ``weight``, with a per-item weight of ``1``.
569+
570+
.. function:: classification_recall(buckets, y, x, weight) -> array<double>
571+
572+
Computes the recall using up to ``buckets`` buckets. Returns
573+
an array of recall values.
574+
575+
``y`` should be a boolean outcome value; ``x`` should be predictions, each
576+
between 0 and 1; ``weight`` should be non-negative values, indicating the weight of the instance.
577+
578+
The
579+
`recall <https://en.wikipedia.org/wiki/Precision_and_recall#Recall>`_
580+
is defined as a sequence whose :math:`j`-th entry is
581+
582+
.. math ::
583+
584+
{
585+
\sum_{i \;|\; x_i > t_j \bigwedge y_i = 1} \left[ w_i \right]
586+
\over
587+
\sum_{i \;|\; y_i = 1} \left[ w_i \right]
588+
},
589+
590+
where :math:`t_j` is the :math:`j`-th smallest threshold,
591+
and :math:`y_i`, :math:`x_i`, and :math:`w_i` are the :math:`i`-th
592+
entries of ``y``, ``x``, and ``weight``, respectively.
593+
594+
.. function:: classification_recall(buckets, y, x) -> array<double>
595+
596+
This function is equivalent to the variant of
597+
:func:`!classification_recall` that takes a ``weight``, with a per-item weight of ``1``.
598+
599+
.. function:: classification_thresholds(buckets, y, x) -> array<double>
600+
601+
Computes the thresholds using up to ``buckets`` buckets. Returns
602+
an array of threshold values.
603+
604+
``y`` should be a boolean outcome value; ``x`` should be predictions, each
605+
between 0 and 1.
606+
607+
The thresholds are defined as a sequence whose :math:`j`-th entry is the :math:`j`-th smallest threshold.
608+
414609
Statistical Aggregate Functions
415610
-------------------------------
416611
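For reference, the four ratio formulas documented in aggregate.rst above can be evaluated directly on a small dataset. The sketch below is an illustrative Python calculation and not how Velox computes the results (per the commit summary, the functions work on bucketed histograms). The function name `classification_metrics`, the explicit `thresholds` argument, and returning NaN for an empty denominator are all assumptions made for the example.

```python
import math


def _ratio(numerator, denominator):
    # Assumption: the documentation does not say what happens when the
    # denominator is zero, so this sketch returns NaN in that case.
    return numerator / denominator if denominator else math.nan


def classification_metrics(y, x, thresholds, weight=None):
    """Evaluate recall, precision, fall-out and miss-rate at each threshold."""
    if weight is None:
        weight = [1.0] * len(y)  # unweighted variant: per-item weight of 1
    rows = list(zip(y, x, weight))
    total_pos = sum(w for yi, _, w in rows if yi)
    total_neg = sum(w for yi, _, w in rows if not yi)

    result = {"recall": [], "precision": [], "fall_out": [], "miss_rate": []}
    for t in thresholds:
        tp = sum(w for yi, xi, w in rows if yi and xi > t)       # x > t, y = 1
        fp = sum(w for yi, xi, w in rows if not yi and xi > t)   # x > t, y = 0
        fn = sum(w for yi, xi, w in rows if yi and xi <= t)      # x <= t, y = 1
        result["recall"].append(_ratio(tp, total_pos))
        result["precision"].append(_ratio(tp, tp + fp))
        result["fall_out"].append(_ratio(fp, total_neg))
        result["miss_rate"].append(_ratio(fn, fn + tp))
    return result


metrics = classification_metrics(
    y=[True, True, False, False],
    x=[0.9, 0.4, 0.6, 0.1],
    thresholds=[0.0, 0.5],
)
print(metrics)
```

Because the miss-rate denominator is the total positive weight, recall and miss-rate at the same threshold sum to 1 whenever at least one positive example exists.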

velox/docs/functions/presto/coverage.rst

+5 -5
@@ -325,11 +325,11 @@ Here is a list of all scalar and aggregate Presto functions with functions that
325325
:func:`array_duplicates` :func:`dow` :func:`json_extract` :func:`repeat` st_union :func:`bool_and` :func:`rank`
326326
:func:`array_except` :func:`doy` :func:`json_extract_scalar` :func:`replace` st_within :func:`bool_or` :func:`row_number`
327327
:func:`array_frequency` :func:`e` :func:`json_format` replace_first st_x :func:`checksum`
328-
:func:`array_has_duplicates` :func:`element_at` :func:`json_parse` :func:`reverse` st_xmax classification_fall_out
329-
:func:`array_intersect` :func:`empty_approx_set` :func:`json_size` rgb st_xmin classification_miss_rate
330-
:func:`array_join` :func:`ends_with` key_sampling_percent :func:`round` st_y classification_precision
331-
array_least_frequent enum_key :func:`laplace_cdf` :func:`rpad` st_ymax classification_recall
332-
:func:`array_max` :func:`exp` :func:`last_day_of_month` :func:`rtrim` st_ymin classification_thresholds
328+
:func:`array_has_duplicates` :func:`element_at` :func:`json_parse` :func:`reverse` st_xmax :func:`classification_fall_out`
329+
:func:`array_intersect` :func:`empty_approx_set` :func:`json_size` rgb st_xmin :func:`classification_miss_rate`
330+
:func:`array_join` :func:`ends_with` key_sampling_percent :func:`round` st_y :func:`classification_precision`
331+
array_least_frequent enum_key :func:`laplace_cdf` :func:`rpad` st_ymax :func:`classification_recall`
332+
:func:`array_max` :func:`exp` :func:`last_day_of_month` :func:`rtrim` st_ymin :func:`classification_thresholds`
333333
array_max_by expand_envelope :func:`least` scale_qdigest :func:`starts_with` convex_hull_agg
334334
:func:`array_min` :func:`f_cdf` :func:`length` :func:`second` :func:`strpos` :func:`corr`
335335
array_min_by features :func:`levenshtein_distance` secure_rand :func:`strrpos` :func:`count`

velox/functions/prestosql/aggregates/AggregateNames.h

+5
@@ -32,6 +32,11 @@ const char* const kBitwiseXor = "bitwise_xor_agg";
3232
const char* const kBoolAnd = "bool_and";
3333
const char* const kBoolOr = "bool_or";
3434
const char* const kChecksum = "checksum";
35+
const char* const kClassificationFallout = "classification_fall_out";
36+
const char* const kClassificationPrecision = "classification_precision";
37+
const char* const kClassificationRecall = "classification_recall";
38+
const char* const kClassificationMissRate = "classification_miss_rate";
39+
const char* const kClassificationThreshold = "classification_thresholds";
3540
const char* const kCorr = "corr";
3641
const char* const kCount = "count";
3742
const char* const kCountIf = "count_if";

velox/functions/prestosql/aggregates/CMakeLists.txt

+1
@@ -28,6 +28,7 @@ velox_add_library(
2828
CountIfAggregate.cpp
2929
CovarianceAggregates.cpp
3030
ChecksumAggregate.cpp
31+
ClassificationAggregation.cpp
3132
EntropyAggregates.cpp
3233
GeometricMeanAggregate.cpp
3334
HistogramAggregate.cpp
