Add NFCorpus support
Adds a Collection and Benchmark to support NFCorpus [1]
 
[1] Vera Boteva, Demian Gholipour, Artem Sokolov and Stefan Riezler. A Full-Text Learning to Rank Dataset for Medical Information Retrieval. ECIR '16.
crystina-z authored Jun 18, 2020
1 parent 0278041 commit a9272d2
Showing 3 changed files with 167 additions and 5 deletions.
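
For orientation, the sketch below (not part of the commit) lists the artifacts the two new modules produce; every module name, option, and filename pattern is taken from the diff that follows, and the example option values are simply the defaults.

# New benchmark config options (defaults shown):
labelrange = "0-2"       # or "1-3": which graded-relevance qrel variant to use
fields = "all_fields"    # or "all_titles", "nontopics", "vid_title", "vid_desc": which query fields to load

# capreolus/benchmark/__init__.py, class NF -- files written under PACKAGE_PATH / "data":
qrel_file = f"qrels.nf.{labelrange}.txt"             # judgments for all splits
test_qrel_file = f"test.qrels.nf.{labelrange}.txt"   # judgments for the test split only
topic_file = f"topics.nf.{fields}.txt"               # TREC-format topics with title queries
fold_file = "nf.json"                                # one fold "s1": train qids plus dev/test qids

# capreolus/collection/__init__.py, class NF -- file written under the collection cache dir:
collection_file = "documents/nf-collection.txt"      # all train/dev/test docs in TREC format
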
108 changes: 106 additions & 2 deletions capreolus/benchmark/__init__.py
@@ -1,7 +1,7 @@
from profane import import_all_modules


import json
import re
import os
import gzip
import pickle
@@ -79,7 +79,6 @@ class Robust04Yang19(Benchmark):
"""Robust04 benchmark using the folds from Yang et al. [1]
[1] Wei Yang, Kuang Lu, Peilin Yang, and Jimmy Lin. 2019. Critically Examining the "Neural Hype": Weak Baselines and the Additivity of Effectiveness Gains from Neural Ranking Models. SIGIR 2019.
"""

module_name = "robust04.yang19"
@@ -90,6 +89,111 @@ class Robust04Yang19(Benchmark):
query_type = "title"


@Benchmark.register
class NF(Benchmark):
""" A Full-Text Learning to Rank Dataset for Medical Information Retrieval [1]
[1] Vera Boteva, Demian Gholipour, Artem Sokolov and Stefan Riezler. A Full-Text Learning to Rank Dataset for Medical Information Retrieval Proceedings of the 38th European Conference on Information Retrieval (ECIR), Padova, Italy, 2016
"""

module_name = "nf"
dependencies = [Dependency(key="collection", module="collection", name="nf")]
config_spec = [
ConfigOption(key="labelrange", default_value="0-2", description="range of dataset qrels, options: 0-2, 1-3"),
ConfigOption(
key="fields",
default_value="all_fields",
description="query fields included in topic file, "
"options: 'all_fields', 'all_titles', 'nontopics', 'vid_title', 'vid_desc'",
),
]

fold_file = PACKAGE_PATH / "data" / "nf.json"

query_type = "title"

def __init__(self, config, provide, share_dependency_objects):
super().__init__(config, provide, share_dependency_objects)
fields, label_range = self.config["fields"], self.config["labelrange"]
self.field2kws = {
"all_fields": ["all"],
"nontopics": ["nontopic-titles"],
"vid_title": ["vid-titles"],
"vid_desc": ["vid-desc"],
"all_titles": ["nontopic-titles", "vid-titles", "nontopic-titles"],
}
self.labelrange2kw = {"0-2": "2-1-0", "1-3": "3-2-1"}

if fields not in self.field2kws:
raise ValueError(f"Unexpected fields value: {fields}, expect: {', '.join(self.field2kws.keys())}")
if label_range not in self.labelrange2kw:
raise ValueError(f"Unexpected label range: {label_range}, expect: {', '.join(self.field2kws.keys())}")

self.qrel_file = PACKAGE_PATH / "data" / f"qrels.nf.{label_range}.txt"
self.test_qrel_file = PACKAGE_PATH / "data" / f"test.qrels.nf.{label_range}.txt"
self.topic_file = PACKAGE_PATH / "data" / f"topics.nf.{fields}.txt"
self.download_if_missing()

def _transform_qid(self, raw):
""" NFCorpus dataset specific, remove prefix in query id since anserini convert all qid to integer """
return raw.replace("PLAIN-", "")

def download_if_missing(self):
if all([f.exists() for f in [self.topic_file, self.fold_file, self.qrel_file]]):
return

tmp_corpus_dir = self.collection.download_raw()
topic_f = open(self.topic_file, "w", encoding="utf-8")
qrel_f = open(self.qrel_file, "w", encoding="utf-8")
test_qrel_f = open(self.test_qrel_file, "w", encoding="utf-8")

set_names = ["train", "dev", "test"]
folds = {s: set() for s in set_names}
qrel_kw = self.labelrange2kw[self.config["labelrange"]]
for set_name in set_names:
with open(tmp_corpus_dir / f"{set_name}.{qrel_kw}.qrel") as f:
for line in f:
line = self._transform_qid(line)
qid = line.strip().split()[0]
folds[set_name].add(qid)
if set_name == "test":
test_qrel_f.write(line)
qrel_f.write(line)

files = [tmp_corpus_dir / f"{set_name}.{keyword}.queries" for keyword in self.field2kws[self.config["fields"]]]
qids2topics = self._align_queries(files, "title")

for qid, txts in qids2topics.items():
topic_f.write(topic_to_trectxt(qid, txts["title"]))

json.dump(
{"s1": {"train_qids": list(folds["train"]), "predict": {"dev": list(folds["dev"]), "test": list(folds["test"])}}},
open(self.fold_file, "w"),
)

topic_f.close()
qrel_f.close()
test_qrel_f.close()
logger.info(f"nf benchmark prepared")

def _align_queries(self, files, field, qid2queries=None):
if not qid2queries:
qid2queries = {}
for fn in files:
with open(fn, "r", encoding="utf-8") as f:
for line in f:
qid, txt = line.strip().split("\t")
qid = self._transform_qid(qid)
txt = " ".join(re.sub("[^A-Za-z]", " ", txt).split()[:1020])
if qid not in qid2queries:
qid2queries[qid] = {field: txt}
else:
if field in qid2queries[qid]:
logger.warning(f"Overwriting title for query {qid}")
qid2queries[qid][field] = txt
return qid2queries


@Benchmark.register
class ANTIQUE(Benchmark):
"""A Non-factoid Question Answering Benchmark from Hashemi et al. [1]
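
To make the per-line processing above concrete, here is a small standalone sketch of how a qrel line and a query line from the NFCorpus distribution are handled by download_if_missing() and _align_queries(); the query id, document id, label, and query text are illustrative examples, not values verified against the corpus.

import re

# qrel handling in download_if_missing(): strip the "PLAIN-" prefix, keep the line otherwise intact
raw_qrel_line = "PLAIN-2\t0\tMED-2427\t2\n"    # illustrative "<qid> 0 <docid> <label>" line
qrel_line = raw_qrel_line.replace("PLAIN-", "")
qid = qrel_line.strip().split()[0]             # -> "2", collected into the fold for this split

# query handling in _align_queries(): same qid normalization, then keep letters only, cap length
raw_query_line = "PLAIN-2\tDo Cholesterol Statin Drugs Cause Breast Cancer?"
qid, txt = raw_query_line.strip().split("\t")
qid = qid.replace("PLAIN-", "")                                  # -> "2"
txt = " ".join(re.sub("[^A-Za-z]", " ", txt).split()[:1020])     # -> "Do Cholesterol Statin Drugs Cause Breast Cancer"
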
51 changes: 51 additions & 0 deletions capreolus/collection/__init__.py
@@ -186,6 +186,57 @@ def _validate_document_path(self, path):
return "dummy_trec_doc" in os.listdir(path)


@Collection.register
class NF(Collection):
module_name = "nf"
_path = PACKAGE_PATH / "data" / "nf-collection"
url = "http://www.cl.uni-heidelberg.de/statnlpgroup/nfcorpus/nfcorpus.tar.gz"

collection_type = "TrecCollection"
generator_type = "DefaultLuceneDocumentGenerator"

def download_raw(self):
cachedir = self.get_cache_path()
tmp_dir = cachedir / "tmp"
tmp_tar_fn, tmp_corpus_dir = tmp_dir / "nfcorpus.tar.gz", tmp_dir / "nfcorpus"

os.makedirs(tmp_dir, exist_ok=True)

if not tmp_tar_fn.exists():
download_file(self.url, tmp_tar_fn)

with tarfile.open(tmp_tar_fn) as f:
f.extractall(tmp_dir)
return tmp_corpus_dir

def download_if_missing(self):
cachedir = self.get_cache_path()
document_dir = os.path.join(cachedir, "documents")
coll_filename = os.path.join(document_dir, "nf-collection.txt")
if os.path.exists(coll_filename):
return document_dir

os.makedirs(document_dir, exist_ok=True)
tmp_corpus_dir = self.download_raw()

inp_fns = [tmp_corpus_dir / f"{set_name}.docs" for set_name in ["train", "dev", "test"]]
logger.debug(f"NFCorpus source files to convert: {inp_fns}")
with open(coll_filename, "w", encoding="utf-8") as outp_file:
self._convert_to_trec(inp_fns, outp_file)
logger.info(f"nf collection file prepared, stored at {coll_filename}")

return document_dir

def _convert_to_trec(self, inp_fns, outp_file):
for inp_fn in inp_fns:
assert os.path.exists(inp_fn)

with open(inp_fn, "rt", encoding="utf-8") as f:
for line in f:
docid, doc = line.strip().split("\t")
outp_file.write(f"<DOC>\n<DOCNO>{docid}</DOCNO>\n<TEXT>\n{doc}\n</TEXT>\n</DOC>\n")


@Collection.register
class ANTIQUE(Collection):
module_name = "antique"
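
For reference, _convert_to_trec() maps each tab-separated line of an NFCorpus *.docs file to one TREC document record; the docid and text below are illustrative examples.

# One illustrative line from a {train,dev,test}.docs file: "<docid>\t<document text>"
line = "MED-10\tStatin use and breast cancer survival: a nationwide cohort study ..."
docid, doc = line.strip().split("\t")

# The record written to documents/nf-collection.txt, indexable by Anserini's TrecCollection:
record = f"<DOC>\n<DOCNO>{docid}</DOCNO>\n<TEXT>\n{doc}\n</TEXT>\n</DOC>\n"
print(record)
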
13 changes: 10 additions & 3 deletions capreolus/task/rank.py
@@ -12,8 +12,13 @@
class RankTask(Task):
module_name = "rank"
requires_random_seed = False
config_spec = [ConfigOption("optimize", "map", "metric to maximize on the dev set"), ConfigOption("filter", False)]
config_keys_not_in_path = ["optimize"] # only used for choosing best result; does not affect search()
config_spec = [
ConfigOption("filter", False),
ConfigOption("optimize", "map", "metric to maximize on the dev set"),
ConfigOption("metrics", "default", "metrics reported for evaluation", value_type="strlist"),
]
config_keys_not_in_path = ["optimize", "metrics"]  # affect only evaluation, not search()

dependencies = [
Dependency(key="benchmark", module="benchmark", name="wsdm20demo", provide_this=True, provide_children=["collection"]),
Dependency(key="searcher", module="searcher", name="BM25"),
@@ -44,8 +49,10 @@ def search(self):
return search_results_folder

def evaluate(self):
metrics = self.config["metrics"] if list(self.config["metrics"]) != ["default"] else evaluator.DEFAULT_METRICS

best_results = evaluator.search_best_run(
self.get_results_path(), self.benchmark, primary_metric=self.config["optimize"], metrics=evaluator.DEFAULT_METRICS
self.get_results_path(), self.benchmark, primary_metric=self.config["optimize"], metrics=metrics
)

for fold, path in best_results["path"].items():
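
A standalone restatement of the new metrics handling in RankTask.evaluate(), for clarity; the metric names below are examples, and DEFAULT_METRICS here is a stand-in for evaluator.DEFAULT_METRICS rather than its actual value.

DEFAULT_METRICS = ["map", "ndcg_cut_20", "P_20"]   # stand-in; the real list lives in capreolus.evaluator

def resolve_metrics(configured):
    # The strlist option defaults to ["default"], which keeps the evaluator's standard
    # metric set; any other value, e.g. ["map", "recall_100"], is passed through as-is.
    return configured if list(configured) != ["default"] else DEFAULT_METRICS

assert resolve_metrics(["default"]) == DEFAULT_METRICS
assert resolve_metrics(["map", "recall_100"]) == ["map", "recall_100"]
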
