Skip to content

Commit 13d65f6

Browse files
committed
vendor @ValarDragon's MDS generation script
This commit vendors the script at https://gist.github.com/ValarDragon/e7dab59889157758e8469d9e146514a4
1 parent dea5d3f commit 13d65f6

File tree

2 files changed

+305
-0
lines changed

2 files changed

+305
-0
lines changed

LICENSE

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
SPDX: "MIT OR Apache-2.0"
2+
3+
This does *not* apply to the supplementary files in `vendor/`, which have
4+
their own licensing.

vendor/generate_mds.sage

+301
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
# This is a notably edited copy of STARKWare's code for generating MDS.
2+
# New Fields, optimized MDS matrix generation for low state size,
3+
# Poseidon parameterization, and generating Cpp/Rust code has been added here.
4+
# Copying their license here
5+
# Copyright 2019 StarkWare Industries Ltd.
6+
#
7+
# Licensed under the Apache License, Version 2.0 (the "License").
8+
# You may not use this file except in compliance with the License.
9+
# You may obtain a copy of the License at
10+
#
11+
# https://www.starkware.co/open-source-license/
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an "AS IS" BASIS,
15+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
# See the License for the specific language governing permissions
17+
# and limitations under the License.
18+
# language enum
19+
cpp = 0
20+
rust = 1
21+
# Prime fields.
22+
F61 = GF(2**61 + 20 * 2**32 + 1)
23+
F81 = GF(2**81 + 80 * 2**64 + 1)
24+
F91 = GF(2**91 + 5 * 2**64 + 1)
25+
F125 = GF(2**125 + 266 * 2**64 + 1)
26+
F161 = GF(2**161 + 23 * 2**128 + 1)
27+
F253 = GF(2**253 + 2**199 + 1)
28+
AltBn = GF(21888242871839275222246405745257275088548364400416034343698204186575808495617)
29+
bw6 = GF(258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139573360124440321458177)
30+
# Binary fields.
31+
X = GF(2)['X'].gen()
32+
Bin63 = GF(2**63, name='a', modulus=X**63 + X + 1)
33+
Bin81 = GF(2**81, name='a', modulus=X**81 + X**4 + 1)
34+
Bin91 = GF(2**91, name='a', modulus=X**91 + X**8 + X**5 + X + 1)
35+
Bin127 = GF(2**127, name='a', modulus=X**127 + X + 1)
36+
Bin161 = GF(2**161, name='a', modulus=X**161 + X**18 + 1)
37+
Bin255 = GF(2**255, name='a', modulus=X**255 + X**5 + X**3 + X**2 + 1)
38+
def sponge(permutation_func, inputs, params):
39+
"""
40+
Applies the sponge construction to permutation_func.
41+
inputs should be a vector of field elements whose size is divisible by
42+
params.r.
43+
permutation_func should be a function which gets (state, params) where state
44+
is a vector of params.m field elements, and returns a vector of params.m
45+
field elements.
46+
"""
47+
assert parent(inputs) == VectorSpace(params.field, len(inputs)), \
48+
'inputs must be a vector of field elements. Found: %r' % parent(inputs)
49+
assert len(inputs) % params.r == 0, \
50+
'Number of field elements must be divisible by %s. Found: %s' % (
51+
params.r, len(inputs))
52+
state = vector([params.field(0)] * params.m)
53+
for i in xrange(0, len(inputs), params.r):
54+
state[:params.r] += inputs[i:i+params.r]
55+
state = permutation_func(state, params)
56+
# We do not support more than r output elements, since this requires
57+
# additional invocations of permutation_func.
58+
assert params.output_size <= params.r
59+
return state[:params.output_size]
60+
def generate_round_constant(fn_name, field, idx):
61+
"""
62+
Returns a field element based on the result of sha256.
63+
The input to sha256 is the concatenation of the name of the hash function
64+
and an index.
65+
For example, the first element for MiMC will be computed using the value
66+
of sha256('MiMC0').
67+
"""
68+
from hashlib import sha256
69+
val = int(sha256('%s%d' % (fn_name, idx)).hexdigest(), 16)
70+
if field.is_prime_field():
71+
return field(val)
72+
else:
73+
return int2field(field, val % field.order())
74+
def int2field(field, val):
75+
"""
76+
Converts val to an element of a binary field according to the binary
77+
representation of val.
78+
For example, 11=0b1011 is converted to 1*a^3 + 0*a^2 + 1*a + 1.
79+
"""
80+
assert field.characteristic() == 2
81+
assert 0 <= val < field.order(), \
82+
'Value %d out of range. Expected 0 <= val < %d.' % (val, field.order())
83+
res = field(map(int, bin(val)[2:][::-1]))
84+
assert res.integer_representation() == val
85+
return res
86+
def binary_vector(field, values):
87+
"""
88+
Converts a list of integers to field elements using int2field.
89+
"""
90+
return vector(field, [int2field(field, val) for val in values])
91+
def binary_matrix(field, values):
92+
"""
93+
Converts a list of lists of integers to field elements using int2field.
94+
"""
95+
return matrix(field, [[int2field(field, val) for val in row]
96+
for row in values])
97+
def generate_mds_matrix(name, field, m, optimize_mds=True):
98+
"""
99+
Generates an MDS matrix of size m x m over the given field, with no
100+
eigenvalues in the field.
101+
Given two disjoint sets of size m: {x_1, ..., x_m}, {y_1, ..., y_m} we set
102+
A_{ij} = 1 / (x_i - y_j).
103+
"""
104+
for attempt in xrange(100):
105+
x_values = [generate_round_constant(name + 'x', field, attempt * m + i)
106+
for i in xrange(m)]
107+
y_values = [generate_round_constant(name + 'y', field, attempt * m + i)
108+
for i in xrange(m)]
109+
# Make sure the values are distinct.
110+
assert len(set(x_values + y_values)) == 2 * m, \
111+
'The values of x_values and y_values are not distinct'
112+
mds_proto = ([[1 / (x_values[i] - y_values[j]) for j in xrange(m)]
113+
for i in xrange(m)])
114+
if optimize_mds:
115+
# massive reduction in constraint complexity, and computation time
116+
# These are near-MDS matrices
117+
if m == 3:
118+
# [[1, 0, 1],
119+
# [1, 1, 0],
120+
# [0, 1, 1]]
121+
mds_proto = [[field(1), field(0), field(1)],
122+
[field(1), field(1), field(0)],
123+
[field(0), field(1), field(1)]]
124+
elif m == 4:
125+
mds_proto = [[field(1) if x != y else field(0) for x in range(m)] for y in range(m)]
126+
else:
127+
raise RuntimeError
128+
mds = matrix(mds_proto)
129+
return mds
130+
mds = matrix(mds_proto)
131+
assert mds.determinant() != 0
132+
if not optimize_mds:
133+
# Sanity check: check the determinant of the matrix.
134+
x_prod = product(
135+
[x_values[i] - x_values[j] for i in xrange(m) for j in xrange(i)])
136+
y_prod = product(
137+
[y_values[i] - y_values[j] for i in xrange(m) for j in xrange(i)])
138+
xy_prod = product(
139+
[x_values[i] - y_values[j] for i in xrange(m) for j in xrange(m)])
140+
expected_det = (1 if m % 4 < 2 else -1) * x_prod * y_prod / xy_prod
141+
det = mds.determinant()
142+
assert det != 0
143+
assert det == expected_det, \
144+
'Expected determinant %s. Found %s' % (expected_det, det)
145+
if len(mds.characteristic_polynomial().roots()) == 0:
146+
# There are no eigenvalues in the field.
147+
return mds
148+
print(mds.characteristic_polynomial().roots())
149+
raise Exception('No good MDS found')
150+
def generate_mds_cpp(mds):
151+
s = ["std::vector<std::vector<FieldT>> mds_matrix;"]
152+
# bigint<FieldT::num_limbs>("1234")
153+
for r in mds:
154+
line = "mds_matrix.push_back(std::vector<FieldT>({"
155+
for c in r:
156+
line += "FieldT(bigint<FieldT::num_limbs>(\"" + str(c) + "\")),"
157+
line = line[:-1]
158+
line += "}));"
159+
s += [line]
160+
print('\n'.join(s))
161+
def generate_mds_rs(mds):
162+
# let mds = vec![vec![F::one(), F::zero(), F::one()],
163+
# vec![F::one(), F::one(), F::zero()],
164+
# vec![F::zero(), F::one(), F::one()]];
165+
# but F::from_str(\"" + str(c) + "\").map_err(|_| ()).unwrap()
166+
s = ["let mds = vec!["]
167+
for r in mds:
168+
line = "vec!["
169+
for c in r:
170+
line += "F::from_str(\"" + str(c) + "\").map_err(|_| ()).unwrap(),"
171+
line = line[:-1]
172+
line += "],"
173+
s += [line]
174+
s[-1] += "];"
175+
print('\n'.join(s))
176+
def generate_ark(hash_name, field, state_size, num_rounds, optimize_ark=False, R_p=0, mds=None):
177+
ark = [vector(generate_round_constant('Hades', field, state_size * i + j)
178+
for j in xrange(state_size))
179+
for i in xrange(num_rounds)]
180+
s = ["std::vector<std::vector<FieldT>> ark_matrix;"]
181+
# bigint<FieldT::num_limbs>("1234")
182+
for r in ark:
183+
line = "ark_matrix.push_back(std::vector<FieldT>({"
184+
for c in r:
185+
line += "FieldT(bigint<FieldT::num_limbs>(\"" + str(c) + "\")),"
186+
line = line[:-1]
187+
line += "}));"
188+
s += [line]
189+
print('\n'.join(s))
190+
if optimize_ark:
191+
# Remove Rf
192+
ark = ark[4:-4]
193+
# ark = ark[::-1]
194+
mds_inv = mds.inverse()
195+
mds_pow = mds_inv
196+
linear_constants = ark[0]
197+
linear_constants[0] = field(0)
198+
non_linear_constants = [ark[0][0]]
199+
for i in range(1, R_p):
200+
cur = mds_pow * ark[i]
201+
non_linear_constants = non_linear_constants + [cur[0]]
202+
cur[0] = field(0)
203+
linear_constants += cur
204+
mds_pow *= mds_inv
205+
s = "std::vector<FieldT> rp_linear_ark_constants = std::vector<FieldT>({"
206+
for c in linear_constants:
207+
s += "FieldT(bigint<FieldT::num_limbs>(\"" + str(c) + "\")),"
208+
s += "});"
209+
print(s)
210+
s = "std::vector<FieldT> rp_non-linear_ark_constants = std::vector<FieldT>({"
211+
def generate_ark_rs(hash_name, field, state_size, num_rounds, optimize_ark=False, R_p=0, mds=None):
212+
ark = [vector(generate_round_constant('Hades', field, state_size * i + j)
213+
for j in xrange(state_size))
214+
for i in xrange(num_rounds)]
215+
s = ["let ark = vec!["]
216+
# bigint<FieldT::num_limbs>("1234")
217+
for r in ark:
218+
line = "vec!["
219+
for c in r:
220+
line += "F::from_str(\"" + str(c) + "\").map_err(|_| ()).unwrap(),"
221+
line = line[:-1]
222+
line += "],"
223+
s += [line]
224+
print('\n'.join(s) + '];')
225+
# if optimize_ark:
226+
# # Remove Rf
227+
# ark = ark[4:-4]
228+
# # ark = ark[::-1]
229+
# mds_inv = mds.inverse()
230+
# mds_pow = mds_inv
231+
# linear_constants = ark[0]
232+
# linear_constants[0] = field(0)
233+
# non_linear_constants = [ark[0][0]]
234+
# for i in range(1, R_p):
235+
# cur = mds_pow * ark[i]
236+
# non_linear_constants = non_linear_constants + [cur[0]]
237+
# cur[0] = field(0)
238+
# linear_constants += cur
239+
# mds_pow *= mds_inv
240+
# s = "std::vector<FieldT> rp_linear_ark_constants = std::vector<FieldT>({"
241+
# for c in linear_constants:
242+
# s += "FieldT(bigint<FieldT::num_limbs>(\"" + str(c) + "\")),"
243+
# s += "});"
244+
# print(s)
245+
# s = "std::vector<FieldT> rp_non-linear_ark_constants = std::vector<FieldT>({"
246+
247+
def generate_mds_code(hash_name, field, state_size, optimize_mds, lang=cpp):
248+
mds = generate_mds_matrix(hash_name + "MDS", field, state_size, optimize_mds)
249+
if lang == cpp:
250+
generate_mds_cpp(mds)
251+
else:
252+
generate_mds_rs(mds)
253+
return mds
254+
def generate_poseidon_param_code(hash_name, field, state_size, num_rounds, optimize_mds=True, optimize_ark=False, lang=cpp):
255+
mds = generate_mds_code(hash_name, field, state_size, optimize_mds, lang)
256+
R_f = 8
257+
R_p = num_rounds - R_f
258+
if lang == cpp:
259+
generate_ark(hash_name, field, state_size, num_rounds, optimize_ark, R_p, mds)
260+
else:
261+
generate_ark_rs(hash_name, field, state_size, num_rounds, optimize_ark, R_p, mds)
262+
def generate_rescue_param_code(hash_name, field, state_size, num_rounds, optimize_mds=False, lang=cpp):
263+
mds = generate_mds_code(hash_name, field, state_size, optimize_mds, lang)
264+
num_steps = 2*num_rounds
265+
# Poseidon params that we don't need
266+
optimize_ark = False
267+
R_p = 0
268+
if lang == cpp:
269+
generate_ark(hash_name, field, state_size, num_steps, optimize_ark, R_p, mds)
270+
else:
271+
generate_ark_rs(hash_name, field, state_size, num_steps, optimize_ark, R_p, mds)
272+
# This just calculates the minimum number rounds per the dominating constraint
273+
# We do a full calculation against all relevant equations in the CPP logic, this is just for sanity checks.
274+
# The dominating constraint is the defense against interpolation attack.
275+
# TODO: Include analysis not in the paper of capacity > 1 improving num rounds
276+
def calculate_num_poseidon_rounds(field, sec, alpha, num_capacity_elems, state_size):
277+
assert num_capacity_elems * len(bin(field.order())[2:]) > sec
278+
max_num_rounds = log(2, alpha) * sec + log(state_size, 2)
279+
partial_rounds = max_num_rounds - 6
280+
# apply security margin of the paper, 7.5%
281+
partial_rounds = int(ceil(partial_rounds*1.075))
282+
num_rounds = partial_rounds + 8
283+
return num_rounds
284+
poseidon_hash_name = "Hades"
285+
rescue_hash_name = "Rescue"
286+
# default = True
287+
# if default:
288+
# # alpha = 5
289+
# generate_poseidon_param_code(poseidon_hash_name, AltBn, 17, 66, optimize_mds=False, lang=rust)
290+
# generate_rescue_param_code(rescue_hash_name, AltBn, 17, 10, False, rust)
291+
# Dev's recommended params for Fractal recursion / Marlin recursion
292+
alpha=17
293+
capacity_size=1
294+
arity=2
295+
sec = 128
296+
state_size = arity + capacity_size
297+
num_rounds = calculate_num_poseidon_rounds(bw6, sec, alpha, capacity_size, state_size)
298+
# You can set optimize_mds to True if you want to take a heuristic on using near-MDS matrices
299+
# The paper authors were of the opinion that this worked (with an update to the differential analysis given)
300+
# which will be satisfied for any large field
301+
generate_poseidon_param_code(poseidon_hash_name, bw6, state_size, num_rounds, optimize_mds=False, lang=rust)

0 commit comments

Comments
 (0)