Skip to content

Commit

Permalink
add boolean parameter to relevance based calculations
Browse files Browse the repository at this point in the history
  • Loading branch information
tobhey committed Aug 21, 2024
1 parent 59b4c3a commit 4847756
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,38 +21,52 @@ interface RankMetricsCalculator {
* @param rankedResults the ranked results as a list of sorted lists (most relevant artifacts first). Each list represents one query of a source artifact.
* @param groundTruth the ground truth
* @param stringProvider a function to convert the ranked results and ground truth to strings
* @param rankedRelevances An optional list of lists representing the relevance scores associated with each ranked result.
* If provided, this list must correspond to `rankedResults` in structure. If `null`, the relevance scores are ignored.
* @param doubleProvider A function that converts the ranked relevances into their double representations.
* @param relevanceBasedInput the input for relevance based calculations
*
* @return the rank metrics result
*/
fun <T> calculateMetrics(
rankedResults: List<List<T>>,
groundTruth: Set<T>,
stringProvider: (T) -> String,
rankedRelevances: List<List<T>>?,
doubleProvider: (T) -> Double
relevanceBasedInput: RelevanceBasedInput<T>?
): SingleRankMetricsResult {
return calculateMetrics(
rankedResults.map { id -> id.map { stringProvider(it) } },
groundTruth.map { stringProvider(it) }.toSet(),
rankedRelevances?.map { id -> id.map { doubleProvider(it) } }
relevanceBasedInput?.let { rbi ->
RelevanceBasedInput(
rbi.rankedRelevances.map { id -> id.map { rbi.doubleProvider(it) } },
{ it },
rbi.biggerIsMoreSimilar
)
}
)
}

/**
* @param rankedRelevances An optional list of lists representing the relevance scores associated with each ranked result.
* If provided, this list must correspond to `rankedResults` in structure. If `null`, the relevance scores are ignored.
* @param doubleProvider A function that converts the ranked relevances into their double representations.
* @param biggerIsMoreSimilar Whether the relevance scores are more similar if bigger
*/
data class RelevanceBasedInput<T>(
val rankedRelevances: List<List<T>>,
val doubleProvider: ((T) -> Double),
val biggerIsMoreSimilar: Boolean
)

/**
* Calculates the metrics for the given ranked results.
* @param rankedResults the ranked results as a list of sorted lists (most relevant artifacts first). Each list represents one query of a source artifact.
* @param groundTruth the ground truth
* @param rankedRelevances An optional list of lists representing the relevance scores associated with each ranked result.
* If provided, this list must correspond to `rankedResults` in structure. If `null`, the relevance scores are ignored.
* @param relevanceBasedInput the input for relevance based calculation
* @return the rank metrics result
*/
fun calculateMetrics(
rankedResults: List<List<String>>,
groundTruth: Set<String>,
rankedRelevances: List<List<Double>>?
relevanceBasedInput: RelevanceBasedInput<Double>?
): SingleRankMetricsResult

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ fun calculateLAG(
fun calculateAUC(
rankedResults: List<List<String>>,
rankedRelevances: List<List<Double>>,
groundTruth: Set<String>
groundTruth: Set<String>,
biggerIsMoreSimilar: Boolean
): Double {
require(rankedResults.size == rankedRelevances.size) {
"Results and relevance lists must have the same size."
Expand All @@ -103,7 +104,7 @@ fun calculateAUC(
flattenedTPLabels += results.map { groundTruth.contains(it) }
}

return calculateAUC(calculateROC(flattenedRelevances, flattenedTPLabels))
return calculateAUC(calculateROC(flattenedRelevances, flattenedTPLabels, biggerIsMoreSimilar))
}

/**
Expand All @@ -127,13 +128,14 @@ fun calculateAUC(
*/
fun calculateROC(
relevances: List<Double>,
isTPLabels: List<Boolean>
isTPLabels: List<Boolean>,
biggerIsMoreSimilar: Boolean
): List<DoubleArray> {
require(relevances.size == isTPLabels.size) { "Relevances and labels must have the same length" }

// Create a list of pairs (relevance, isTPLabel) and sort it by relevance in descending order
val relevanceIsTPList: MutableList<Pair<Double, Boolean>> = relevances.zip(isTPLabels).toMutableList()
relevanceIsTPList.sortByDescending { it.first }
if (biggerIsMoreSimilar) relevanceIsTPList.sortByDescending { it.first } else relevanceIsTPList.sortBy { it.first }

// Initialize variables for TPR and FPR
val totalPositives = isTPLabels.count { it }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ internal class RankMetricsCalculatorImpl : RankMetricsCalculator {
override fun calculateMetrics(
rankedResults: List<List<String>>,
groundTruth: Set<String>,
rankedRelevances: List<List<Double>>?
relevanceBasedInput: RankMetricsCalculator.RelevanceBasedInput<Double>?
): SingleRankMetricsResult {
require(rankedResults.isNotEmpty())
require(rankedResults.all { it.size == rankedResults.first().size })
val map = calculateMAP(rankedResults, groundTruth)
val lag = calculateLAG(rankedResults, groundTruth)
val auc = rankedRelevances?.let { calculateAUC(rankedResults, it, groundTruth) }
val auc = relevanceBasedInput?.let { calculateAUC(rankedResults, it.rankedRelevances, groundTruth, it.biggerIsMoreSimilar) }
return SingleRankMetricsResult(map, lag, auc, groundTruth.size)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,22 @@ import java.io.File
@OptIn(ExperimentalCli::class)
class RankCommand(private val outputFileOption: SingleNullableOption<String>) : Subcommand("rank", "Calculates rank metrics") {
private val rankedListDirectoryOption by option(
ArgType.String,
shortName = "r",
description = "The directory of the ranked list files",
fullName = "ranked-list-directory"
ArgType.String, shortName = "r", description = "The directory of the ranked list files", fullName = "ranked-list-directory"
).required()
private val groundTruthFileOption by option(
ArgType.String,
shortName = "g",
description = "The ground truth file",
fullName = "ground-truth"
ArgType.String, shortName = "g", description = "The ground truth file", fullName = "ground-truth"
).required()
private val fileHeaderOption by option(ArgType.Boolean, description = "Whether the files have a header", fullName = "header").default(false)
private val rankedRelevanceListDirectoryOption by option(
ArgType.String,
shortName = "rrl",
description = "The directory of the ranked relevance list files",
fullName = "ranked-relevance-list-directory"
).default("")
)
private val biggerIsMoreSimilar by option(
ArgType.String, shortName = "b", description = "Whether the relevance scores are more similar if bigger", fullName = "bigger-is-more-similar"
)


override fun execute() {
println("Calculating rank metrics")
Expand All @@ -47,39 +45,41 @@ class RankCommand(private val outputFileOption: SingleNullableOption<String>) :
println("The provided path is not a directory")
return
}
val rankedResults: List<List<String>> =
rankedListDirectory.listFiles()?.filter { file ->
file.isFile
}?.map { file -> file.readLines().filter { it.isNotBlank() }.drop(if (fileHeaderOption) 1 else 0) } ?: emptyList()
val rankedResults: List<List<String>> = rankedListDirectory.listFiles()?.filter { file ->
file.isFile
}?.map { file -> file.readLines().filter { it.isNotBlank() }.drop(if (fileHeaderOption) 1 else 0) } ?: emptyList()
if (rankedResults.isEmpty()) {
println("No classification results found")
return
}
val groundTruth = groundTruthFile.readLines().filter { it.isNotBlank() }.drop(if (fileHeaderOption) 1 else 0).toSet()

var rankedRelevances: List<List<Double>>? = null
if (rankedRelevanceListDirectoryOption != "") {
val rankedRelevanceListDirectory = File(rankedRelevanceListDirectoryOption)
var relevanceBasedInput: RankMetricsCalculator.RelevanceBasedInput<Double>? = null
if (rankedRelevanceListDirectoryOption != null) {
val rankedRelevanceListDirectory = File(rankedRelevanceListDirectoryOption!!)
if (!rankedRelevanceListDirectory.exists() || !rankedRelevanceListDirectory.isDirectory) {
println("The directory of the ranked relevance list files does not exist or is not a directory")
return
}
rankedRelevances = rankedRelevanceListDirectory.listFiles()?.filter { file ->
val rankedRelevances = rankedRelevanceListDirectory.listFiles()?.filter { file ->
file.isFile
}?.map { file -> file.readLines().filter { it.isNotBlank() }.map { it.toDouble() }.drop(if (fileHeaderOption) 1 else 0) } ?: emptyList()
if (rankedRelevances.isEmpty()) {
println("No relevance scores found")
return
}
if (biggerIsMoreSimilar == null) {
throw IllegalArgumentException("ranked relevances and bigger is more similar can only occur together")
}
relevanceBasedInput = if (biggerIsMoreSimilar != null) RankMetricsCalculator.RelevanceBasedInput(
rankedRelevances, { it }, biggerIsMoreSimilar.toBoolean()
) else null
}
val rankMetrics = RankMetricsCalculator.Instance

val result =
rankMetrics.calculateMetrics(
rankedResults,
groundTruth,
rankedRelevances
)
val result = rankMetrics.calculateMetrics(
rankedResults, groundTruth, relevanceBasedInput
)
result.prettyPrint()

val output = outputFileOption.value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ import org.springframework.web.bind.annotation.RestController
class ClassificationMetricsController {
@Operation(summary = "Check if the service is running")
@GetMapping
fun running(): String {
return "ClassificationMetricsController is running"
}
fun running(): String = "ClassificationMetricsController is running"

@Operation(summary = "Calculate classification metrics for one project")
@PostMapping
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,22 @@ class RankMetricsController {
@RequestBody body: RankMetricsRequest
): SingleRankMetricsResult {
val rankMetricsCalculator = RankMetricsCalculator.Instance
val result = rankMetricsCalculator.calculateMetrics(body.rankedResults, body.groundTruth, body.rankedRelevances)
val relevanceBasedInput = relevanceBasedInput(body)
val result = rankMetricsCalculator.calculateMetrics(body.rankedResults, body.groundTruth, relevanceBasedInput)
return result
}

private fun relevanceBasedInput(body: RankMetricsRequest): RankMetricsCalculator.RelevanceBasedInput<Double>? {
if ((body.rankedRelevances == null) != (body.biggerIsMoreSimilar == null)) {
throw IllegalArgumentException("ranked relevances and bigger is more similar can only occur together")
}
val relevanceBasedInput =
if (body.rankedRelevances != null && body.biggerIsMoreSimilar != null) RankMetricsCalculator.RelevanceBasedInput(
body.rankedRelevances, { it }, body.biggerIsMoreSimilar
) else null
return relevanceBasedInput
}

@Operation(summary = "Calculate rank metrics for multiple projects. Calculate the average and optionally a weighted average.")
@PostMapping("/average")
fun calculateMultipleRankMetrics(
Expand All @@ -37,19 +49,18 @@ class RankMetricsController {
val rankMetricsCalculator = RankMetricsCalculator.Instance

val requests = body.rankMetricsRequests
val results =
requests.map {
rankMetricsCalculator.calculateMetrics(it.rankedResults, it.groundTruth, it.rankedRelevances)
}
val results = requests.map {
val relevanceBasedInput = relevanceBasedInput(it)
rankMetricsCalculator.calculateMetrics(it.rankedResults, it.groundTruth, relevanceBasedInput)
}

val averages = rankMetricsCalculator.calculateAverages(results, body.weights)

return AverageRankMetricsResponse(averages)
}

data class AverageRankMetricsRequest(
val rankMetricsRequests: List<RankMetricsRequest>,
val weights: List<Int>? = null
val rankMetricsRequests: List<RankMetricsRequest>, val weights: List<Int>? = null
)

data class AverageRankMetricsResponse(
Expand All @@ -59,6 +70,7 @@ class RankMetricsController {
data class RankMetricsRequest(
val rankedResults: List<List<String>>,
val groundTruth: Set<String>,
val rankedRelevances: List<List<Double>>?
val rankedRelevances: List<List<Double>>?,
val biggerIsMoreSimilar: Boolean?
)
}

0 comments on commit 4847756

Please sign in to comment.