|
1 | 1 | import argparse |
2 | 2 | import os |
3 | 3 | import shlex |
| 4 | +import shutil |
4 | 5 | import threading |
5 | 6 | import subprocess |
6 | 7 |
|
|
10 | 11 | from rocker.core import OPERATIONS_NON_INTERACTIVE |
11 | 12 |
|
12 | 13 | from io import BytesIO |
13 | | -from urllib.request import urlopen |
| 14 | +from urllib.request import urlretrieve |
14 | 15 | import urllib.request |
15 | 16 | from zipfile import ZipFile |
16 | 17 | from contextlib import nullcontext |
@@ -43,39 +44,30 @@ def get_ipd_template(modelname): |
43 | 44 |
|
def fetch_dataset(dataset, output_path):
    """Download the archives of *dataset* and extract them into *output_path*.

    Every archive is downloaded into the current working directory under the
    basename of its URL. A file whose URL ends in ``01`` is treated as a
    shard: its bytes are appended to the sibling file whose name ends in
    ``ip`` instead of ``01`` (``foo.z01`` -> ``foo.zip``), and the shard
    itself is never extracted. All other downloaded files are extracted with
    ``ZipFile`` once every download has finished.

    Args:
        dataset: Key into ``available_datasets``; the value is unpacked as a
            ``(url_base, suffixes)`` pair.
        output_path: Directory the archives are extracted into.
    """
    (url_base, suffixes) = available_datasets[dataset]
    fetched_files = []
    # Reverse-sorted so that zip comes before z01: the zip has to be on disk
    # already when its shard is appended to it.
    for suffix in sorted(suffixes, reverse=True):
        url = url_base + suffix

        print(f"Downloading from url: {url}")
        outfile = os.path.basename(url)
        filename, _headers = urlretrieve(url, outfile)

        if not url.endswith("01"):
            fetched_files.append(filename)
            continue

        # Append shard if found
        orig_filename = filename[:-2] + "ip"
        print(f"Appending shard {filename} to {orig_filename}")
        # The zip is the file that gets extracted below, so it is the one
        # that must grow: open it for append and read from the shard.
        with open(orig_filename, 'ab') as zip_fd, open(filename, 'rb') as shard_fd:
            shutil.copyfileobj(shard_fd, zip_fd)

    for filename in fetched_files:
        print(f"Unzipping {filename}")
        with ZipFile(filename) as zfile:
            zfile.extractall(output_path)
79 | 71 |
|
80 | 72 |
|
81 | 73 | def main(): |
|
0 commit comments