Skip to content

Commit 219298e

Browse files
committed
[Script] Add rsp_sha3_gen.py
1 parent 547559d commit 219298e

File tree

1 file changed

+168
-0
lines changed

1 file changed

+168
-0
lines changed

tests/scripts/rsp_sha3_gen.py

+168
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
#!/bin/python
2+
# Test case generator for SHA3 test vectors from NIST FIPS 202 standard
3+
# Author: Crt Vavros
4+
5+
import os, sys
6+
7+
supported_lengths = [
8+
256, 384, 512
9+
]
10+
11+
class TestVectors:
12+
def __init__(self):
13+
self.header = ''
14+
self.hash_len = 0
15+
16+
class MsgTestVector:
17+
def __init__(self):
18+
self.msg = ''
19+
self.msg_len = 0
20+
self.md =''
21+
22+
class MsgTestVectors(TestVectors):
23+
def __init__(self):
24+
self.entries = [MsgTestVector]
25+
26+
class MonteCarloTestVector:
27+
def __init__(self):
28+
self.count = 0
29+
self.md =''
30+
31+
class MonteCarloTestVectors(TestVectors):
32+
def __init__(self, seed: str):
33+
self.seed = seed
34+
self.entries = [MonteCarloTestVector]
35+
36+
def parse_header(file):
37+
file.seek(0)
38+
header = ''
39+
hash_len = 0
40+
for num, line in enumerate(file):
41+
line = line.strip()
42+
if line.startswith('#'):
43+
header += line[1:].strip() + "\n"
44+
elif "[L = " in line:
45+
hash_len = round(int(line.strip('[]= L').strip()))
46+
return (header, hash_len, num)
47+
elif len(line) != 0:
48+
print("warning: unexpected end of header or corrupted header at line: {}".format(num + 1))
49+
return (header.strip(), hash_len, num)
50+
51+
def parse_monte_carlo_entries(file, current_line):
52+
seed = ''
53+
entries = []
54+
tv = MonteCarloTestVector()
55+
for num, line in enumerate(file):
56+
line = line.strip()
57+
if 'Seed' in line:
58+
seed = line.replace('Seed', '')
59+
seed = seed.strip('= ')
60+
elif 'COUNT' in line:
61+
tv.count = int(line.replace('COUNT', '').strip('= '))
62+
elif 'MD' in line:
63+
tv.md = line.replace('MD', '').strip('= ')
64+
65+
if len(tv.md) != 0:
66+
entries.append(tv)
67+
tv = MonteCarloTestVector()
68+
return (seed, entries)
69+
70+
def parse_msg_test_vector_entries(file, current_line):
71+
entries = []
72+
tv = MsgTestVector()
73+
for num, line in enumerate(file):
74+
line = line.strip()
75+
if 'Msg' in line:
76+
tv.msg = line.replace('Msg', '').strip('= ')
77+
elif 'Len' in line:
78+
tv.msg_len = int(line.replace('Len', '').strip('= '))
79+
elif 'MD' in line:
80+
tv.md = line.replace('MD', '').strip('= ')
81+
82+
if len(tv.md) != 0:
83+
entries.append(tv)
84+
tv = MsgTestVector()
85+
return entries
86+
87+
88+
def parse_rsp(file_path):
89+
header = ''
90+
hash_len = 0
91+
with open(file_path) as f:
92+
header, hash_len, end_line_num = parse_header(f)
93+
if hash_len not in supported_lengths:
94+
return None
95+
96+
# Try to make sure that this is a SHA3 RSP file
97+
if "SHA3" not in header and "Sha3" not in header and "sha3" not in header:
98+
return None
99+
100+
if "Monte" in header:
101+
seed, entries = parse_monte_carlo_entries(f, end_line_num)
102+
tests = MonteCarloTestVectors(seed)
103+
tests.header = header
104+
tests.hash_len = hash_len
105+
tests.entries = entries
106+
return tests
107+
else:
108+
entries = parse_msg_test_vector_entries(f, end_line_num)
109+
tests = MsgTestVectors()
110+
tests.header = header
111+
tests.hash_len = hash_len
112+
tests.entries = entries
113+
return tests
114+
115+
def indent(text:str, width, ch=' '):
116+
padding = width * ch
117+
return ''.join(padding+line for line in text.splitlines(True))
118+
119+
def format_var(var: str, decl: bool, indent_size: int = 0, var_type = 'auto') -> str:
120+
str = f'{f"{var_type} " if decl else ""}{var};'
121+
if indent_size > 0:
122+
str = indent(str, indent_size)
123+
return str
124+
125+
def main():
126+
if len(sys.argv) < 2:
127+
print("Usage:\n rsp_sha3_gen.py <path_to_rsp_fle>")
128+
return 0
129+
elif os.path.splitext(sys.argv[1])[1].lower() != '.rsp':
130+
print("Invalid file!", file=sys.stderr)
131+
print("Usage:\n rsp_sha3_gen.py <path_to_rsp_fle>")
132+
return 1
133+
134+
tests = parse_rsp(sys.argv[1])
135+
if tests is None or len(tests.entries) == 0:
136+
print("Invalid file or unsupported SHA3 RSP test vector file!", file=sys.stderr)
137+
print("Usage:\n rsp_sha3_gen.py <path_to_rsp_fle>")
138+
return 1
139+
140+
indent_size = 4
141+
if isinstance(tests, MonteCarloTestVectors):
142+
print("/*NIST Monte Carlo tests")
143+
print(tests.header)
144+
print("*/\n{")
145+
print(format_var(f'hashes = monte_carlo_sha3_{tests.hash_len}( "{tests.seed}"_hex )', True, indent_size))
146+
for e in tests.entries:
147+
print(indent(f'REQUIRE_EQUAL( hashes[{e.count}], \"{e.md}\"_hex );', indent_size))
148+
print("}")
149+
else:
150+
out_file = os.path.splitext(sys.argv[1])[0] + '.hpp'
151+
with open(out_file, "w") as f:
152+
print("/* NIST tests", file=f)
153+
print(tests.header, file=f)
154+
print("*/\n{", file=f)
155+
declvar = True
156+
for tv in tests.entries:
157+
print(indent(f"// Len = { tv.msg_len }", indent_size), file=f)
158+
msg = f'"{ tv.msg }"_hex' if tv.msg_len > 0 else 'bytes()'
159+
print(format_var(f'msg = { msg }', declvar, indent_size), file=f)
160+
print(format_var(f'md = "{ tv.md }"_hex', declvar, indent_size), file=f)
161+
print(indent(f'REQUIRE_EQUAL( sha3_{tests.hash_len}( msg ), md );\n', indent_size), file=f)
162+
declvar = False
163+
print("}\n", file=f)
164+
print(f"Generated test(s) written to file: '{out_file}'" )
165+
return 0
166+
167+
if __name__ == "__main__":
168+
exit(main())

0 commit comments

Comments
 (0)