Skip to content

Commit bd20a2a

Browse files
committed
Extract unzip via subprocess
Signed-off-by: Tully Foote <[email protected]>
1 parent e97f6f9 commit bd20a2a

File tree

1 file changed

+86
-10
lines changed

1 file changed

+86
-10
lines changed

ibpc_py/src/ibpc/ibpc.py

Lines changed: 86 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,29 @@
11
import argparse
2+
import hashlib
23
import os
34
import shlex
5+
import shutil
46
import threading
7+
import subprocess
58

69
from rocker.core import DockerImageGenerator
710
from rocker.core import get_rocker_version
811
from rocker.core import RockerExtensionManager
912
from rocker.core import OPERATIONS_NON_INTERACTIVE
1013

1114
from io import BytesIO
12-
from urllib.request import urlopen
15+
from urllib.request import urlretrieve
1316
import urllib.request
1417
from zipfile import ZipFile
18+
from contextlib import nullcontext
1519

1620

1721
def get_bop_template(modelname):
    """Return the base download URL for a dataset in the classic BOP repo.

    The returned URL ends with a trailing slash; callers append an archive
    filename (e.g. ``lm_base.zip``) to form the full download URL.
    """
    host = "https://huggingface.co/datasets/bop-benchmark/datasets"
    return f"{host}/resolve/main/{modelname}/"
1923

2024

2125
def get_ipd_template(modelname):
    """Return the base download URL for a dataset hosted in its own
    bop-benchmark repository (as opposed to the shared ``datasets`` repo).

    The returned URL ends with a trailing slash; callers append an archive
    filename to form the full download URL.
    """
    return "https://huggingface.co/datasets/bop-benchmark/" + modelname + "/resolve/main/"
2327

2428

2529
bop_suffixes = [
@@ -33,21 +37,93 @@ def get_ipd_template(modelname):
3337
ipd_suffixes.append("_val.zip")
3438
ipd_suffixes.append("_test_all.z01")
3539

# Archive filenames for the "lm" (Linemod) dataset, mapped to their expected
# SHA256 hex digests. The digests let fetch_dataset() skip re-downloading
# archives already present on disk with the right content.
lm_files = {
    "lm_base.zip": "a1d793837d4de0dbd33f04e8b04ce4403c909248c527b2d7d61ef5eac3ef2c39",
    "lm_models.zip": "cb5b5366ce620d41800c7941c2e770036c7c13c178514fa07e6a89fda5ff0e7f",
    "lm_test_all.zip": "28e65e9530b94a87c35f33cba81e8f37bc4d59f95755573dea6e9ca0492f00fe",
    "lm_train_pbr.zip": "b7814cc0cd8b6f0d9dddff7b3ded2a189eacfd2c19fa10b3e332f022930551a9",
}

# Archive filenames for the "ipd" dataset, mapped to their expected SHA256 hex
# digests. Note the trailing "ipd_test_all.z01" entry: it is the second shard
# of a spanned zip archive, listed after its ".zip" part so the main part is
# downloaded first (dicts preserve insertion order).
ipd_files = {
    "ipd_base.zip": "c4943d90040df0737ac617c30a9b4e451a7fc94d96c03406376ce34e5a9724d1",
    "ipd_models.zip": "e7435057b48c66faf3a10353a7ae0bffd63ec6351a422d2c97d4ca4b7e6b797a",
    "ipd_test_all.zip": "e1b042f046d7d07f8c8811f7739fb68a25ad8958d1b58c5cbc925f98096eb6f9",
    "ipd_train_pbr.zip": "6afde1861ce781adc33fcdb3c91335fa39c5e7208a0b20433deb21f92f3e9a94",
    "ipd_val.zip": "50df37c370557a3cccc11b2e6d5f37f13783159ed29f4886e09c9703c1cad8de",
    "ipd_test_all.z01": "25ce71feb7d9811db51772e44ebc981d57d9f10c91776707955ab1e616346cb3",
}

# Dataset name -> (base URL, {archive filename: sha256}) for every dataset
# this tool knows how to fetch; consumed by fetch_dataset().
available_datasets = {
    "ipd": (get_ipd_template("ipd"), ipd_files),
    "lm": (get_bop_template("lm"), lm_files),
}
4060

4161

62+
def sha256_file(filename):
    """Return the SHA256 hex digest of *filename*, read in 64 KiB chunks.

    Streams the file instead of loading it whole, so it works for the
    multi-gigabyte dataset archives this tool downloads.
    """
    digest = hashlib.sha256()
    with open(filename, "rb") as stream:
        # iter() with a sentinel yields chunks until read() returns b"" (EOF).
        for chunk in iter(lambda: stream.read(65536), b""):
            digest.update(chunk)
    return digest.hexdigest()
72+
73+
4274
def fetch_dataset(dataset, output_path):
    """Download and extract every archive of *dataset* into *output_path*.

    Archives are downloaded into the current working directory. A file that
    already exists locally with the expected SHA256 is reused instead of being
    downloaded again. Extraction is delegated to the external ``7z`` tool
    because Python's zipfile cannot handle spanned ("multiple disk") archives
    such as the ipd ``.z01`` shard.

    Args:
        dataset: Key into ``available_datasets`` (e.g. ``"lm"`` or ``"ipd"``).
        output_path: Directory the archives are extracted into.

    Raises:
        KeyError: If *dataset* is not in ``available_datasets``.
        subprocess.CalledProcessError: If ``7z`` fails to extract an archive.
    """
    (url_base, files) = available_datasets[dataset]
    fetched_files = []
    for suffix in files:
        url = url_base + suffix
        outfile = os.path.basename(url)
        expected_hash = files[suffix]

        if os.path.exists(outfile):
            print(f"File {outfile} already present checking hash")
            computed_hash = sha256_file(outfile)
            if computed_hash == expected_hash:
                print(f"File {outfile} detected with expected sha256 skipping download")
                fetched_files.append(outfile)
                continue
            print(
                f"File {outfile}'s hash {computed_hash} didn't match the expected hash {expected_hash}, downloading again."
            )

        print(f"Downloading from url: {url}")
        (filename, headers) = urlretrieve(url, outfile)
        # Verify the fresh download too; warn rather than abort so a stale
        # upstream checksum doesn't block fetching the rest of the dataset.
        computed_hash = sha256_file(filename)
        if computed_hash != expected_hash:
            print(
                f"Warning: downloaded file {filename}'s hash {computed_hash} didn't match the expected hash {expected_hash}."
            )
        fetched_files.append(filename)

    # Drop ".z01"-style shard files from the extraction list: given the main
    # ".zip" part, 7z locates the sibling shards on its own. Filtering into a
    # new list (instead of remove() inside a loop over the same list) avoids
    # mutating a list while iterating it, which silently skips entries.
    archives = [f for f in fetched_files if not f.endswith("01")]

    for filename in archives:
        print(f"Unzipping {filename}")
        subprocess.check_call(["7z", "x", "-y", filename, f"-o{output_path}"])
51127

52128

53129
def main():

0 commit comments

Comments
 (0)