1
+ # This is a notably edited copy of STARKWare's code for generating MDS.
2
+ # New Fields, optimized MDS matrix generation for low state size,
3
+ # Poseidon parameterization, and generating Cpp/Rust code has been added here.
4
+ # Copying their license here
5
+ # Copyright 2019 StarkWare Industries Ltd.
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License").
8
+ # You may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # https://www.starkware.co/open-source-license/
12
+ #
13
+ # Unless required by applicable law or agreed to in writing,
14
+ # software distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions
17
+ # and limitations under the License.
18
# Target-language selectors understood by the code generators below.
cpp, rust = 0, 1

# Prime fields, each built from its modulus.
F61 = GF(2**61 + 20 * 2**32 + 1)
F81 = GF(2**81 + 80 * 2**64 + 1)
F91 = GF(2**91 + 5 * 2**64 + 1)
F125 = GF(2**125 + 266 * 2**64 + 1)
F161 = GF(2**161 + 23 * 2**128 + 1)
F253 = GF(2**253 + 2**199 + 1)
AltBn = GF(
    21888242871839275222246405745257275088548364400416034343698204186575808495617)
bw6 = GF(
    258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139573360124440321458177)

# Binary fields GF(2^n), each with generator name 'a' and an explicit
# reduction polynomial in X.
X = GF(2)['X'].gen()
Bin63 = GF(2**63, name='a', modulus=X**63 + X + 1)
Bin81 = GF(2**81, name='a', modulus=X**81 + X**4 + 1)
Bin91 = GF(2**91, name='a', modulus=X**91 + X**8 + X**5 + X + 1)
Bin127 = GF(2**127, name='a', modulus=X**127 + X + 1)
Bin161 = GF(2**161, name='a', modulus=X**161 + X**18 + 1)
Bin255 = GF(2**255, name='a', modulus=X**255 + X**5 + X**3 + X**2 + 1)
38
def sponge(permutation_func, inputs, params):
    """
    Applies the sponge construction to permutation_func.

    inputs should be a vector of field elements whose size is divisible by
    params.r.
    permutation_func should be a function which gets (state, params) where
    state is a vector of params.m field elements, and returns a vector of
    params.m field elements.

    params must provide the attributes field, r, m and output_size.
    Returns the first params.output_size elements of the final state.
    Raises AssertionError if inputs is not a vector over params.field, if its
    length is not a multiple of params.r, or if params.output_size exceeds
    params.r.
    """
    assert parent(inputs) == VectorSpace(params.field, len(inputs)), \
        'inputs must be a vector of field elements. Found: %r' % parent(inputs)
    assert len(inputs) % params.r == 0, \
        'Number of field elements must be divisible by %s. Found: %s' % (
            params.r, len(inputs))
    # We do not support more than r output elements, since this requires
    # additional invocations of permutation_func. Checked up front so that an
    # unsupported configuration fails before any permutation is evaluated.
    assert params.output_size <= params.r

    state = vector([params.field(0)] * params.m)
    # range (not xrange) so the loop also runs on Python 3 based Sage.
    for i in range(0, len(inputs), params.r):
        # Absorb the next r input elements into the first r state elements.
        state[:params.r] += inputs[i:i + params.r]
        state = permutation_func(state, params)

    return state[:params.output_size]
60
def generate_round_constant(fn_name, field, idx):
    """
    Returns a field element based on the result of sha256.

    The input to sha256 is the concatenation of the name of the hash function
    and an index.
    For example, the first element for MiMC will be computed using the value
    of sha256('MiMC0').

    For a prime field the digest, read as an integer, is passed to field().
    Otherwise the digest is reduced modulo field.order() and converted with
    int2field.
    """
    from hashlib import sha256
    # sha256 needs bytes; hashing a text string raises TypeError on Python 3.
    seed = ('%s%d' % (fn_name, idx)).encode('utf-8')
    val = int(sha256(seed).hexdigest(), 16)
    if field.is_prime_field():
        return field(val)
    return int2field(field, val % field.order())
74
def int2field(field, val):
    """
    Converts val to an element of a binary field according to the binary
    representation of val.

    For example, 11=0b1011 is converted to 1*a^3 + 0*a^2 + 1*a + 1.

    Raises AssertionError if the field does not have characteristic 2, if val
    is outside 0 <= val < field.order(), or if the constructed element does
    not map back to val.
    """
    assert field.characteristic() == 2
    assert 0 <= val < field.order(), \
        'Value %d out of range. Expected 0 <= val < %d.' % (val, field.order())
    # Coefficients, least significant bit first. Built as a real list because
    # map() is lazy on Python 3 and would hand the field an iterator instead
    # of a coefficient list.
    coefficients = [int(bit) for bit in reversed(bin(val)[2:])]
    res = field(coefficients)
    assert res.integer_representation() == val
    return res
86
def binary_vector(field, values):
    """
    Builds a vector over field from a list of integers, converting each
    integer with int2field.
    """
    elements = []
    for val in values:
        elements.append(int2field(field, val))
    return vector(field, elements)
91
def binary_matrix(field, values):
    """
    Builds a matrix over field from a list of lists of integers, converting
    each integer with int2field.
    """
    rows = []
    for row in values:
        converted = []
        for val in row:
            converted.append(int2field(field, val))
        rows.append(converted)
    return matrix(field, rows)
97
def generate_mds_matrix(name, field, m, optimize_mds=True):
    """
    Generates an m x m matrix over the given field.

    With optimize_mds=False, a Cauchy matrix is built: given two disjoint sets
    of size m: {x_1, ..., x_m}, {y_1, ..., y_m} we set
    A_{ij} = 1 / (x_i - y_j).
    The x and y values are derived with generate_round_constant from
    name + 'x' and name + 'y'. Up to 100 attempts are made and the first
    matrix with no eigenvalues in the field is returned.

    With optimize_mds=True, a fixed 0/1 near-MDS matrix is returned instead
    (only defined for m == 3 and m == 4). No determinant or eigenvalue check
    is applied to it.

    Raises RuntimeError if optimize_mds is set and m is neither 3 nor 4,
    AssertionError if a sanity check on the Cauchy matrix fails, and Exception
    if none of the 100 attempts yields a suitable matrix.
    """
    if optimize_mds:
        # Massive reduction in constraint complexity and computation time.
        # These are near-MDS matrices. They do not depend on the hashed x/y
        # values, so they are produced before any constants are derived.
        if m == 3:
            pattern = [[1, 0, 1],
                       [1, 1, 0],
                       [0, 1, 1]]
        elif m == 4:
            # Zero diagonal, ones everywhere else.
            pattern = [[0 if x == y else 1 for x in range(m)]
                       for y in range(m)]
        else:
            raise RuntimeError(
                'No optimized near-MDS matrix is defined for m = %s '
                '(supported sizes: 3 and 4)' % m)
        return matrix([[field(entry) for entry in row] for row in pattern])

    for attempt in range(100):
        x_values = [generate_round_constant(name + 'x', field, attempt * m + i)
                    for i in range(m)]
        y_values = [generate_round_constant(name + 'y', field, attempt * m + i)
                    for i in range(m)]
        # Make sure the values are distinct.
        assert len(set(x_values + y_values)) == 2 * m, \
            'The values of x_values and y_values are not distinct'
        mds = matrix([[1 / (x_values[i] - y_values[j]) for j in range(m)]
                      for i in range(m)])

        # Sanity check: compare the determinant with the closed form for a
        # Cauchy matrix.
        det = mds.determinant()
        assert det != 0
        x_prod = product(
            [x_values[i] - x_values[j] for i in range(m) for j in range(i)])
        y_prod = product(
            [y_values[i] - y_values[j] for i in range(m) for j in range(i)])
        xy_prod = product(
            [x_values[i] - y_values[j] for i in range(m) for j in range(m)])
        expected_det = (1 if m % 4 < 2 else -1) * x_prod * y_prod / xy_prod
        assert det == expected_det, \
            'Expected determinant %s. Found %s' % (expected_det, det)

        eigenvalues = mds.characteristic_polynomial().roots()
        if len(eigenvalues) == 0:
            # There are no eigenvalues in the field.
            return mds
        # Show the offending eigenvalues before trying the next attempt.
        print(eigenvalues)
    raise Exception('No good MDS found')
150
def generate_mds_cpp(mds):
    """
    Prints C++ source that fills a std::vector<std::vector<FieldT>> named
    mds_matrix with the entries of mds.

    mds is iterated row by row; each entry is rendered with str() and wrapped
    as FieldT(bigint<FieldT::num_limbs>("...")).
    Nothing is returned.
    """
    s = ["std::vector<std::vector<FieldT>> mds_matrix;"]
    for r in mds:
        # The number is quoted with no surrounding whitespace so the emitted
        # string literal contains digits only. join() also keeps the opening
        # brace intact when a row is empty.
        cells = ",".join(
            'FieldT(bigint<FieldT::num_limbs>("' + str(c) + '"))' for c in r)
        s.append("mds_matrix.push_back(std::vector<FieldT>({" + cells + "}));")
    print('\n '.join(s))
161
def generate_mds_rs(mds):
    """
    Prints a Rust `let mds = vec![...];` statement holding the entries of mds.

    The output has the shape
        let mds = vec![
        vec![F::from_str("..").map_err(|_| ()).unwrap(), ...],
        ...];
    mds is iterated row by row and each entry is rendered with str().
    Nothing is returned.
    """
    s = ["let mds = vec!["]
    for r in mds:
        # The number is quoted with no surrounding whitespace so that
        # from_str receives digits only. join() also keeps "vec![" intact
        # when a row is empty.
        cells = ",".join(
            'F::from_str("' + str(c) + '").map_err(|_| ()).unwrap()' for c in r)
        s.append("vec![" + cells + "],")
    # Close the outer vec! on the last emitted line.
    s[-1] += "];"
    print('\n '.join(s))
176
def generate_ark(hash_name, field, state_size, num_rounds, optimize_ark=False, R_p=0, mds=None):
    """
    Prints the round constants as C++ source.

    num_rounds rows of state_size constants are derived with
    generate_round_constant and printed as push_back calls on a
    std::vector<std::vector<FieldT>> named ark_matrix.

    If optimize_ark is set, the first and last 4 rows are dropped and the
    remaining rows are additionally printed as two flat vectors,
    rp_linear_ark_constants and rp_non_linear_ark_constants. Row 0 is used as
    is; row i (1 <= i < R_p) is first multiplied by the i-th power of the
    inverse of mds. The first element of each resulting row goes to the
    non-linear vector and is replaced by zero in the linear vector.

    NOTE(review): hash_name is not used; the seed is always 'Hades'. This is
    kept as is because changing the seed changes every emitted constant.
    NOTE(review): the optimize_ark arithmetic follows the statements that
    were already here; confirm it against the reference Poseidon
    implementation before relying on its output.

    Raises ValueError if optimize_ark is set but mds is None.
    Nothing is returned.
    """
    ark = [vector([generate_round_constant('Hades', field, state_size * i + j)
                   for j in range(state_size)])
           for i in range(num_rounds)]

    s = ["std::vector<std::vector<FieldT>> ark_matrix;"]
    for r in ark:
        # Numbers are quoted with no surrounding whitespace so the emitted
        # string literal contains digits only.
        cells = ",".join(
            'FieldT(bigint<FieldT::num_limbs>("' + str(c) + '"))' for c in r)
        s.append("ark_matrix.push_back(std::vector<FieldT>({" + cells + "}));")
    print('\n '.join(s))

    if not optimize_ark:
        return
    if mds is None:
        raise ValueError('optimize_ark requires the MDS matrix (mds)')

    # Remove Rf
    ark = ark[4:-4]
    mds_inv = mds.inverse()
    mds_pow = mds_inv

    # Work on a plain list copy: mutating the row itself would zero the value
    # before it is recorded as the first non-linear constant.
    first_row = list(ark[0])
    non_linear_constants = [first_row[0]]
    first_row[0] = field(0)
    linear_constants = first_row
    for i in range(1, R_p):
        cur = mds_pow * ark[i]
        non_linear_constants.append(cur[0])
        cur[0] = field(0)
        # Append the row; += on two vectors would add them element-wise.
        linear_constants.extend(cur)
        # Plain product rather than *= so mds_inv can never be modified.
        mds_pow = mds_pow * mds_inv

    print("std::vector<FieldT> rp_linear_ark_constants = std::vector<FieldT>({"
          + ",".join('FieldT(bigint<FieldT::num_limbs>("' + str(c) + '"))'
                     for c in linear_constants)
          + "});")
    # '-' is not valid in a C++ identifier, hence the underscore in the name.
    print("std::vector<FieldT> rp_non_linear_ark_constants = std::vector<FieldT>({"
          + ",".join('FieldT(bigint<FieldT::num_limbs>("' + str(c) + '"))'
                     for c in non_linear_constants)
          + "});")
211
def generate_ark_rs(hash_name, field, state_size, num_rounds, optimize_ark=False, R_p=0, mds=None):
    """
    Prints the round constants as a Rust `let ark = vec![...];` statement.

    num_rounds rows of state_size constants are derived with
    generate_round_constant; each constant is rendered as
    F::from_str("..").map_err(|_| ()).unwrap().

    hash_name, optimize_ark, R_p and mds are accepted so that the signature
    matches generate_ark, but they are not used: the seed is always 'Hades'
    and no optimized constants are emitted for Rust.
    Nothing is returned.
    """
    ark = [vector([generate_round_constant('Hades', field, state_size * i + j)
                   for j in range(state_size)])
           for i in range(num_rounds)]

    s = ["let ark = vec!["]
    for r in ark:
        # The number is quoted with no surrounding whitespace so that
        # from_str receives digits only.
        cells = ",".join(
            'F::from_str("' + str(c) + '").map_err(|_| ()).unwrap()' for c in r)
        s.append("vec![" + cells + "],")
    print('\n '.join(s) + '];')
    # TODO: emit the optimized partial-round constants (optimize_ark) for Rust.
246
+
247
def generate_mds_code(hash_name, field, state_size, optimize_mds, lang=cpp):
    """
    Generates the MDS matrix for hash_name and prints it as source code.

    The matrix is produced by generate_mds_matrix under the name
    hash_name + "MDS". It is printed as C++ when lang == cpp and as Rust
    otherwise. Returns the matrix.
    """
    matrix_name = hash_name + "MDS"
    mds = generate_mds_matrix(matrix_name, field, state_size, optimize_mds)
    emit = generate_mds_cpp if lang == cpp else generate_mds_rs
    emit(mds)
    return mds
254
def generate_poseidon_param_code(hash_name, field, state_size, num_rounds, optimize_mds=True, optimize_ark=False, lang=cpp):
    """
    Prints the Poseidon parameters (MDS matrix, then round constants) as
    source code in the language selected by lang.

    num_rounds is the total round count; 8 of those are taken as full rounds
    and the remainder is passed on as R_p.
    """
    mds = generate_mds_code(hash_name, field, state_size, optimize_mds, lang)
    R_f = 8
    R_p = num_rounds - R_f
    # C++ for lang == cpp, Rust for anything else.
    emit_ark = generate_ark if lang == cpp else generate_ark_rs
    emit_ark(hash_name, field, state_size, num_rounds, optimize_ark, R_p, mds)
262
def generate_rescue_param_code(hash_name, field, state_size, num_rounds, optimize_mds=False, lang=cpp):
    """
    Prints the Rescue parameters (MDS matrix, then round constants) as source
    code in the language selected by lang.

    Constants are generated for 2 * num_rounds steps.
    """
    mds = generate_mds_code(hash_name, field, state_size, optimize_mds, lang)
    num_steps = 2 * num_rounds
    # C++ for lang == cpp, Rust for anything else.
    emit_ark = generate_ark if lang == cpp else generate_ark_rs
    # optimize_ark and R_p are Poseidon params that we don't need here.
    emit_ark(hash_name, field, state_size, num_steps, False, 0, mds)
272
# This just calculates the minimum number of rounds per the dominating constraint.
# We do a full calculation against all relevant equations in the CPP logic; this is just for sanity checks.
# The dominating constraint is the defense against the interpolation attack.
# TODO: Include analysis not in the paper of capacity > 1 improving num rounds
276
def calculate_num_poseidon_rounds(field, sec, alpha, num_capacity_elems, state_size):
    """
    Returns a round count computed from sec, alpha and state_size.

    The estimate is log_alpha(2) * sec + log_2(state_size); 6 is subtracted,
    the result is scaled by 1.075 and rounded up, and 8 is added.

    Raises AssertionError unless num_capacity_elems times the bit length of
    field.order() exceeds sec.
    """
    field_bits = len(bin(field.order())[2:])
    assert num_capacity_elems * field_bits > sec

    round_estimate = log(2, alpha) * sec + log(state_size, 2)
    # apply security margin of the paper, 7.5%
    partial_rounds = int(ceil((round_estimate - 6) * 1.075))
    return partial_rounds + 8
284
poseidon_hash_name = "Hades"
rescue_hash_name = "Rescue"

# Earlier default invocation, kept for reference:
# default = True
# if default:
#     # alpha = 5
#     generate_poseidon_param_code(poseidon_hash_name, AltBn, 17, 66, optimize_mds=False, lang=rust)
#     generate_rescue_param_code(rescue_hash_name, AltBn, 17, 10, False, rust)

# Dev's recommended params for Fractal recursion / Marlin recursion
sec = 128
alpha = 17
arity = 2
capacity_size = 1
state_size = arity + capacity_size
num_rounds = calculate_num_poseidon_rounds(
    field=bw6,
    sec=sec,
    alpha=alpha,
    num_capacity_elems=capacity_size,
    state_size=state_size,
)

# You can set optimize_mds to True if you want to take a heuristic on using
# near-MDS matrices. The paper authors were of the opinion that this worked
# (with an update to the differential analysis given), which will be
# satisfied for any large field.
generate_poseidon_param_code(
    poseidon_hash_name,
    bw6,
    state_size,
    num_rounds,
    optimize_mds=False,
    lang=rust,
)
0 commit comments