diff --git a/ablation/collate.py b/ablation/collate.py index ade722865..3cd530bcd 100644 --- a/ablation/collate.py +++ b/ablation/collate.py @@ -1,5 +1,6 @@ import logging import os +from typing import Collection, Optional import click import pandas as pd @@ -24,12 +25,13 @@ def read_collation() -> pd.DataFrame: return pykeen_report.utils.read_ablation_collation(COLLATION_PATH) -def collate(key: str = 'hits@10') -> pd.DataFrame: +def collate(key: str = 'hits@10', additional_metrics: Optional[Collection[str]] = None) -> pd.DataFrame: """Collate all results for a given metric.""" return pykeen_report.utils.collate_ablation( results_directory=RESULTS, output_path=COLLATION_PATH, key=key, + additional_metrics=additional_metrics, ) diff --git a/src/pykeen_report/utils.py b/src/pykeen_report/utils.py index 6bba6b056..f80ae43a0 100644 --- a/src/pykeen_report/utils.py +++ b/src/pykeen_report/utils.py @@ -7,7 +7,7 @@ import os from copy import deepcopy from pathlib import Path -from typing import Any, Iterable, Mapping, Optional, Type, Union +from typing import Any, Collection, Iterable, Mapping, Optional, Type, Union import pandas as pd from tqdm import tqdm @@ -33,11 +33,14 @@ def collate_ablation( results_directory: str, output_path: str, key: str, + additional_metrics: Optional[Collection[str]] = None, ) -> pd.DataFrame: """Collate all results for a given metric. :param key: The metric which you care about. Should be the same one against which you optimized + :param additional_metrics: + Additional metrics to collect. """ columns = [ 'searcher', @@ -55,6 +58,9 @@ def collate_ablation( 'evaluation_time', key, ] + if additional_metrics is None: + additional_metrics = [] + columns += list(additional_metrics) directories = [ (directory, filenames)