Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Lee Yang <[email protected]>
  • Loading branch information
leewyang committed Jan 13, 2025
1 parent bf2a602 commit 13bdc99
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions user_tools/src/spark_rapids_tools/tools/qualx/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,9 @@ def compute_accuracy(
return scores


def compute_precision_recall(results: pd.DataFrame, y: str, y_preds: Dict[str, str], threshold: float):
def compute_precision_recall(
results: pd.DataFrame, y: str, y_preds: Dict[str, str], threshold: float
) -> Tuple[Dict[str, float], Dict[str, float]]:
"""Compute precision and recall from a dataframe using a threshold for identifying true positives.
Parameters
Expand Down Expand Up @@ -255,7 +257,6 @@ def compute_precision_recall(results: pd.DataFrame, y: str, y_preds: Dict[str, s
for name, y_pred in y_preds.items():
tp = sum((results[y_pred] >= threshold) & (results[y] >= threshold))
fp = sum((results[y_pred] >= threshold) & (results[y] < threshold))
# tn = sum((results[y_pred] < threshold) & (results[y] < threshold))
fn = sum((results[y_pred] < threshold) & (results[y] >= threshold))

precision[name] = tp / (tp + fp) if (tp + fp) > 0 else np.nan
Expand Down

0 comments on commit 13bdc99

Please sign in to comment.