Skip to content

Commit d748c22

Browse files
committed
switch to using files as memory is likely not enough for the whole ipd dataset for most people.
Also resolve sharding of size 1
1 parent b50d697 commit d748c22

File tree

1 file changed

+22
-30
lines changed

1 file changed

+22
-30
lines changed

ibpc_py/src/ibpc/ibpc.py

Lines changed: 22 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import argparse
22
import os
33
import shlex
4+
import shutil
45
import threading
56
import subprocess
67

@@ -10,7 +11,7 @@
1011
from rocker.core import OPERATIONS_NON_INTERACTIVE
1112

1213
from io import BytesIO
13-
from urllib.request import urlopen
14+
from urllib.request import urlretrieve
1415
import urllib.request
1516
from zipfile import ZipFile
1617
from contextlib import nullcontext
@@ -43,39 +44,30 @@ def get_ipd_template(modelname):
4344

4445
def fetch_dataset(dataset, output_path):
4546
(url_base, suffixes) = available_datasets[dataset]
46-
for suffix in sorted(suffixes):
47-
if suffix.endswith("01"):
48-
continue
47+
fetched_files = []
48+
for suffix in sorted(suffixes, reverse=True):
49+
4950
# Sorted so that zip comes before z01
5051

5152
url = url_base + suffix
52-
print(f"Downloading from url: {url}")
53-
url_two = None
54-
if url.endswith('_test_all.zip'):
55-
url_two = url[:-2]+"01"
5653

57-
58-
with urlopen(url) as zipurlfile:
59-
with urlopen(url_two) if url_two else nullcontext(BytesIO()) as zipulrfile2:
60-
if not url_two:
61-
zipulrfile2 = BytesIO() # Empty file to be ignored
62-
else:
63-
print(f"Fetching extra zip file: {url_two}")
64-
with ZipFile(BytesIO(zipurlfile.read() + zipulrfile2.read())) as zfile:
65-
zfile.extractall(output_path)
66-
# native zipfile doesn't support sharded zip files which are in the ipd dataset (see .z01)
67-
# zipfile.BadZipFile: zipfiles that span multiple disks are not supported
68-
# There's the same problem in the datasets module too
69-
70-
continue
71-
72-
basename = os.path.basename(url)
73-
zip_file = os.path.join(output_path, basename)
74-
urllib.request.urlretrieve(url, zip_file)
75-
if not suffix.endswith("zip"):
76-
# SKip the extraction of z01 making sure it is present next to the same named zip file.
77-
continue
78-
subprocess.check_call(["7z", "x", "-y", basename], cwd=output_path)
54+
print(f"Downloading from url: {url}")
55+
outfile = os.path.basename(url)
56+
(filename, headers) = urlretrieve(url, outfile)
57+
# Append shard if found
58+
if url.endswith("01"):
59+
orig_filename = filename[:-2] + "ip"
60+
print(f"Appending shard {filename} to {orig_filename}")
61+
with open(filename,'ab') as zipfile:
62+
with open(orig_filename,'rb') as fd:
63+
shutil.copyfileobj(fd, zipfile)
64+
else:
65+
fetched_files.append(filename)
66+
67+
for filename in fetched_files:
68+
print(f"Unzipping {filename}")
69+
with ZipFile(filename) as zfile:
70+
zfile.extractall(output_path)
7971

8072

8173
def main():

0 commit comments

Comments
 (0)