1010
1111from io import BytesIO
1212from urllib .request import urlopen
13+ import urllib .request
1314from zipfile import ZipFile
1415
1516
def get_bop_template(modelname):
    """Return the URL prefix for a dataset stored in the shared BOP repository.

    The result is not a downloadable URL on its own: callers append an
    archive suffix (for example ``"_base.zip"``) to obtain the address of
    one concrete file.

    Args:
        modelname: Short dataset identifier, e.g. ``"lm"``. It is used both
            as the directory name and as the file-name stem.

    Returns:
        The URL prefix as a string, without any whitespace.
    """
    # The URL must not contain spaces around the path separators or at the
    # end, otherwise the suffix concatenation produces an invalid address.
    return (
        "https://huggingface.co/datasets/bop-benchmark/datasets/resolve/main/"
        f"{modelname}/{modelname}"
    )
1819
1920
20- available_datasets = {"lm" : get_bop_template ("lm" )}
def get_ipb_template(modelname):
    """Return the URL prefix for a dataset that has its own repository.

    Unlike the shared ``datasets`` repository layout, the dataset name is
    the repository name here, and the files sit at the repository root.
    Callers append an archive suffix (for example ``"_base.zip"``) to get
    the address of one concrete file.

    Args:
        modelname: Short dataset identifier, e.g. ``"ipb"``. It is used both
            as the repository name and as the file-name stem.

    Returns:
        The URL prefix as a string, without any whitespace.
    """
    # Keep the URL free of spaces so that appending a suffix yields a valid
    # address.
    return (
        f"https://huggingface.co/datasets/bop-benchmark/{modelname}"
        f"/resolve/main/{modelname}"
    )
23+
2124
2225bop_suffixes = [
2326 "_base.zip" ,
@@ -26,11 +29,22 @@ def get_bop_template(modelname):
2629 "_train_pbr.zip" ,
2730]
2831
# Archive suffixes for the "ipb" dataset: everything in bop_suffixes plus
# two extra files. A new list is built so bop_suffixes is left unmodified.
# NOTE(review): "_test_all.z01" looks like one segment of a split archive
# rather than a standalone zip -- confirm it can be opened on its own.
ipb_suffixes = [*bop_suffixes, "_val.zip", "_test_all.z01"]

# Registry of downloadable datasets.
# Maps dataset name -> (URL prefix, list of archive suffixes); each file URL
# is formed as prefix + suffix.
available_datasets = {
    "ipb": (get_ipb_template("ipb"), ipb_suffixes),
    "lm": (get_bop_template("lm"), bop_suffixes),
}
40+
2941
def fetch_dataset(dataset, output_path):
    """Download every archive of a registered dataset and extract it.

    The dataset is looked up in ``available_datasets``, which supplies a URL
    prefix and a list of archive suffixes. Each ``prefix + suffix`` URL is
    downloaded and unpacked into ``output_path``.

    Args:
        dataset: Name of the dataset; must be a key of ``available_datasets``.
        output_path: Directory into which all archives are extracted.

    Raises:
        KeyError: If ``dataset`` is not a registered dataset name.
        urllib.error.URLError: If a download fails.
        zipfile.BadZipFile: If a downloaded file is not a valid zip archive.
    """
    # Imported locally so this function does not rely on module-level
    # imports beyond the ones it already used.
    import shutil
    import tempfile

    try:
        url_base, suffixes = available_datasets[dataset]
    except KeyError:
        known = ", ".join(sorted(available_datasets))
        raise KeyError(
            f"Unknown dataset {dataset!r}; available datasets: {known}"
        ) from None

    for suffix in suffixes:
        url = url_base + suffix
        print(f"Downloading from url: {url}")
        # Stream the download to a temporary file instead of holding the
        # whole archive in memory. ZipFile needs a seekable file object,
        # which the HTTP response itself is not.
        with urlopen(url) as response, tempfile.TemporaryFile() as archive:
            shutil.copyfileobj(response, archive)
            archive.seek(0)
            with ZipFile(archive) as zfile:
                zfile.extractall(output_path)
@@ -55,7 +69,7 @@ def main():
5569 test_parser .add_argument ("--debug-inside" , action = "store_true" )
5670
5771 fetch_parser = sub_parsers .add_parser ("fetch" )
58- fetch_parser .add_argument ("dataset" , choices = [ "lm" ] )
72+ fetch_parser .add_argument ("dataset" , choices = available_datasets . keys () )
5973 fetch_parser .add_argument ("--dataset-path" , default = "." )
6074
6175 extension_manager = RockerExtensionManager ()
@@ -66,7 +80,7 @@ def main():
6680 args_dict = vars (args )
6781 if args .subparser_name == "fetch" :
6882 print (f"Fetching dataset { args_dict ['dataset' ]} to { args_dict ['dataset_path' ]} " )
69- fetch_bop_dataset (args_dict ["dataset" ], args_dict ["dataset_path" ])
83+ fetch_dataset (args_dict ["dataset" ], args_dict ["dataset_path" ])
7084 print ("Fetch complete" )
7185 return
7286
0 commit comments