Skip to content

Commit 8dfef3e

Browse files
committed
Extract unzip via subprocess
Signed-off-by: Tully Foote <[email protected]>
1 parent e97f6f9 commit 8dfef3e

File tree

1 file changed

+84
-10
lines changed

1 file changed

+84
-10
lines changed

ibpc_py/src/ibpc/ibpc.py

Lines changed: 84 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,29 @@
11
import argparse
2+
import hashlib
23
import os
34
import shlex
5+
import shutil
46
import threading
7+
import subprocess
58

69
from rocker.core import DockerImageGenerator
710
from rocker.core import get_rocker_version
811
from rocker.core import RockerExtensionManager
912
from rocker.core import OPERATIONS_NON_INTERACTIVE
1013

1114
from io import BytesIO
12-
from urllib.request import urlopen
15+
from urllib.request import urlretrieve
1316
import urllib.request
1417
from zipfile import ZipFile
18+
from contextlib import nullcontext
1519

1620

1721
def get_bop_template(modelname):
    """Build the Hugging Face download URL prefix for a classic BOP-benchmark dataset.

    The prefix ends with '/' so an archive filename can be appended directly.
    """
    prefix = "https://huggingface.co/datasets/bop-benchmark/datasets/resolve/main"
    return f"{prefix}/{modelname}/"
1923

2024

2125
def get_ipd_template(modelname):
    """Build the Hugging Face download URL prefix for a dataset hosted in its own repo.

    The prefix ends with '/' so an archive filename can be appended directly.
    """
    return "https://huggingface.co/datasets/bop-benchmark/" + modelname + "/resolve/main/"
2327

2428

2529
bop_suffixes = [
@@ -33,21 +37,91 @@ def get_ipd_template(modelname):
3337
ipd_suffixes.append("_val.zip")
3438
ipd_suffixes.append("_test_all.z01")
3539

40+
# Expected SHA-256 checksums for each Linemod (lm) archive, keyed by filename.
lm_files = {
    "lm_base.zip": 'a1d793837d4de0dbd33f04e8b04ce4403c909248c527b2d7d61ef5eac3ef2c39',
    "lm_models.zip": 'cb5b5366ce620d41800c7941c2e770036c7c13c178514fa07e6a89fda5ff0e7f',
    "lm_test_all.zip": '28e65e9530b94a87c35f33cba81e8f37bc4d59f95755573dea6e9ca0492f00fe',
    "lm_train_pbr.zip": 'b7814cc0cd8b6f0d9dddff7b3ded2a189eacfd2c19fa10b3e332f022930551a9',
}

# Expected SHA-256 checksums for each IPD archive, keyed by filename.
# Note: ipd_test_all.z01 is a shard of ipd_test_all.zip (multi-part archive).
ipd_files = {
    "ipd_base.zip": 'c4943d90040df0737ac617c30a9b4e451a7fc94d96c03406376ce34e5a9724d1',
    "ipd_models.zip": 'e7435057b48c66faf3a10353a7ae0bffd63ec6351a422d2c97d4ca4b7e6b797a',
    "ipd_test_all.zip": 'e1b042f046d7d07f8c8811f7739fb68a25ad8958d1b58c5cbc925f98096eb6f9',
    "ipd_train_pbr.zip": '6afde1861ce781adc33fcdb3c91335fa39c5e7208a0b20433deb21f92f3e9a94',
    "ipd_val.zip": '50df37c370557a3cccc11b2e6d5f37f13783159ed29f4886e09c9703c1cad8de',
    "ipd_test_all.z01": '25ce71feb7d9811db51772e44ebc981d57d9f10c91776707955ab1e616346cb3',
}
55+
3656
# Registry of fetchable datasets:
#   name -> (download URL prefix, {archive filename: expected sha256}).
available_datasets = {
    "ipd": (get_ipd_template("ipd"), ipd_files),
    "lm": (get_bop_template("lm"), lm_files),
}
4060

61+
def sha256_file(filename):
    """Return the hex SHA-256 digest of the file at *filename*.

    Reads in 64 KiB chunks so large archives are hashed without loading
    the whole file into memory.
    """
    digest = hashlib.sha256()
    with open(filename, "rb") as stream:
        # iter() with a b"" sentinel yields chunks until EOF.
        for chunk in iter(lambda: stream.read(65536), b""):
            digest.update(chunk)
    return digest.hexdigest()
71+
4172

4273
def fetch_dataset(dataset, output_path):
    """Download (with checksum caching) and extract every archive of *dataset*.

    Looks *dataset* up in ``available_datasets`` to get a URL prefix and the
    expected SHA-256 for each archive.  An archive already present in the
    current directory with a matching checksum is not re-downloaded.  All
    fetched archives are then extracted into *output_path* with the external
    ``7z`` tool.

    :param dataset: key into ``available_datasets`` (e.g. "lm" or "ipd").
    :param output_path: directory passed to 7z's ``-o`` option for extraction.
    :raises KeyError: if *dataset* is not a known dataset name.
    :raises subprocess.CalledProcessError: if 7z exits with a failure status.
    """
    (url_base, files) = available_datasets[dataset]
    fetched_files = []
    for suffix in files.keys():
        url = url_base + suffix

        outfile = os.path.basename(url)
        if os.path.exists(outfile):
            print(f"File {outfile} already present checking hash")
            computed_hash = sha256_file(outfile)
            expected_hash = files[suffix]
            if computed_hash == expected_hash:
                print(f"File {outfile} detected with expected sha256 skipping download")
                fetched_files.append(outfile)
                continue
            else:
                print(f"File {outfile}'s hash {computed_hash} didn't match the expected hash {expected_hash}, downloading again.")

        print(f"Downloading from url: {url}")
        # urlretrieve streams straight to disk instead of buffering in memory.
        (filename, _headers) = urlretrieve(url, outfile)
        fetched_files.append(filename)

    # Drop shard files (e.g. *.z01) from the extraction list: 7z locates the
    # remaining parts itself when given the main .zip, and Python's zipfile
    # can't handle multi-part ("multiple disks") archives at all.
    # BUG FIX: the original removed items from fetched_files while iterating
    # over the same list, which skips the element following each removal;
    # building a filtered list avoids that.
    fetched_files = [f for f in fetched_files if not f.endswith("01")]

    for filename in fetched_files:
        # BUG FIX: original printed a constant "Unzipping" message with no
        # placeholder; name the archive actually being extracted.
        print(f"Unzipping {filename}")
        # -y answers prompts automatically so extraction is non-interactive.
        subprocess.check_call(['7z', 'x', '-y', filename, f"-o{output_path}"])
51125

52126

53127
def main():

0 commit comments

Comments
 (0)