Skip to content

Commit c63c430

Browse files
authored
Merge pull request #15 from Yadunund/ibp_download
downloading ipb
2 parents 5f3b702 + 450c2f7 commit c63c430

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

ibpc_py/src/ibpc/ibpc.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,17 @@
1010

1111
from io import BytesIO
1212
from urllib.request import urlopen
13+
import urllib.request
1314
from zipfile import ZipFile
1415

1516

1617
def get_bop_template(modelname):
1718
return f"https://huggingface.co/datasets/bop-benchmark/datasets/resolve/main/{modelname}/{modelname}"
1819

1920

20-
available_datasets = {"lm": get_bop_template("lm")}
21+
def get_ipb_template(modelname):
22+
return f"https://huggingface.co/datasets/bop-benchmark/{modelname}/resolve/main/{modelname}"
23+
2124

2225
bop_suffixes = [
2326
"_base.zip",
@@ -26,11 +29,22 @@ def get_bop_template(modelname):
2629
"_train_pbr.zip",
2730
]
2831

32+
ipb_suffixes = [s for s in bop_suffixes]
33+
ipb_suffixes.append("_val.zip")
34+
ipb_suffixes.append("_test_all.z01")
35+
36+
available_datasets = {
37+
"ipb": (get_ipb_template("ipb"), ipb_suffixes),
38+
"lm": (get_bop_template("lm"), bop_suffixes),
39+
}
40+
2941

30-
def fetch_bop_dataset(dataset, output_path):
31-
for suffix in bop_suffixes:
42+
def fetch_dataset(dataset, output_path):
43+
(url_base, suffixes) = available_datasets[dataset]
44+
for suffix in suffixes:
3245

33-
url = get_bop_template(dataset) + suffix
46+
url = url_base + suffix
47+
print(f"Downloading from url: {url}")
3448
with urlopen(url) as zipurlfile:
3549
with ZipFile(BytesIO(zipurlfile.read())) as zfile:
3650
zfile.extractall(output_path)
@@ -55,7 +69,7 @@ def main():
5569
test_parser.add_argument("--debug-inside", action="store_true")
5670

5771
fetch_parser = sub_parsers.add_parser("fetch")
58-
fetch_parser.add_argument("dataset", choices=["lm"])
72+
fetch_parser.add_argument("dataset", choices=available_datasets.keys())
5973
fetch_parser.add_argument("--dataset-path", default=".")
6074

6175
extension_manager = RockerExtensionManager()
@@ -66,7 +80,7 @@ def main():
6680
args_dict = vars(args)
6781
if args.subparser_name == "fetch":
6882
print(f"Fetching dataset {args_dict['dataset']} to {args_dict['dataset_path']}")
69-
fetch_bop_dataset(args_dict["dataset"], args_dict["dataset_path"])
83+
fetch_dataset(args_dict["dataset"], args_dict["dataset_path"])
7084
print("Fetch complete")
7185
return
7286

0 commit comments

Comments
 (0)