 import argparse
 import hashlib
 import os
+import pathlib
 import shlex
 import shutil
 import signal
@@ -103,7 +104,7 @@ def sha256_file(filename):
     return sha256.hexdigest()


-def fetch_dataset(dataset, output_path):
+def fetch_dataset(dataset, output_path, remove_zip_after_extract):
     (url_base, files) = available_datasets[dataset]
     # Before we do anything make sure the directory exists
     dataset_dir = os.path.join(output_path, dataset)
@@ -165,6 +166,10 @@ def fetch_dataset(dataset, output_path):
     # with ZipFile(filename) as zfile:
     #     zfile.extractall(output_path)

+    if remove_zip_after_extract:
+        for filename in fetched_files:
+            pathlib.Path(filename).unlink(missing_ok=True)
+

 def main():

@@ -191,6 +196,12 @@ def main():
     fetch_parser = sub_parsers.add_parser("fetch")
     fetch_parser.add_argument("dataset", choices=available_datasets.keys())
     fetch_parser.add_argument("--dataset-path", default=".")
+    fetch_parser.add_argument(
+        "--remove-zip-after-extract",
+        default=False,
+        action="store_true",
+        help="Remove the zip files after extracting to save disk space.",
+    )

     extension_manager = RockerExtensionManager()

@@ -200,7 +211,9 @@ def main():
         dataset_name = args_dict["dataset"]
         dataset_directory = args_dict["dataset_path"]
         print(f"Fetching dataset {dataset_name} to {dataset_directory}")
-        fetch_dataset(dataset_name, dataset_directory)
+        fetch_dataset(
+            dataset_name, dataset_directory, args_dict["remove_zip_after_extract"]
+        )
         print("Fetch complete")
         return
0 commit comments