Skip to content

Commit 1837724

Browse files
perf: BN254 Miller Loop "10" "01" compression (keep-starknet-strange#261)
1 parent 5f3b232 commit 1837724

File tree

18 files changed

+7444
-11926
lines changed

18 files changed

+7444
-11926
lines changed

.github/workflows/maturin.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ jobs:
111111
strategy:
112112
matrix:
113113
platform:
114-
- runner: macos-12
114+
- runner: macos-14
115115
target: x86_64
116116
- runner: macos-14
117117
target: aarch64

hydra/garaga/definitions.py

+31-1
Original file line numberDiff line numberDiff line change
@@ -1227,5 +1227,35 @@ def replace_consecutive_zeros(lst):
12271227
return result
12281228

12291229

1230+
def recode_naf_bits(lst):
1231+
result = []
1232+
i = 0
1233+
while i < len(lst):
1234+
if i < len(lst) - 1 and lst[i] == 0 and (lst[i + 1] == 1 or lst[i + 1] == -1):
1235+
# "01" or "0-1"
1236+
if lst[i + 1] == 1:
1237+
result.append(3) # Replace "01" with 3
1238+
else:
1239+
result.append(4) # Replace "0-1" with 4
1240+
i += 2
1241+
elif i < len(lst) - 1 and (lst[i] == 1 or lst[i] == -1) and lst[i + 1] == 0:
1242+
# "10" or "-10"
1243+
if lst[i] == 1:
1244+
result.append(1) # Replace 10 with 6
1245+
else:
1246+
result.append(2) # Replace -10 with 7
1247+
i += 2
1248+
elif i < len(lst) - 1 and lst[i] == 0 and lst[i + 1] == 0:
1249+
result.append(0) # Replace consecutive zeros with 0
1250+
i += 2
1251+
else:
1252+
raise ValueError(f"Unexpected bit sequence at index {i}")
1253+
return result
1254+
1255+
12301256
if __name__ == "__main__":
1231-
pass
1257+
r = recode_naf_bits(jy00(6 * 0x44E992B44A6909F1 + 2)[2:])
1258+
print(r, len(r))
1259+
1260+
# bls = [int(x) for x in bin(0xD201000000010000)[2:]][2:]
1261+
# recode_naf_bits(bls)

hydra/garaga/precompiled_circuits/all_circuits.py

+24-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
FixedG2MPCheckBit0,
1414
FixedG2MPCheckBit00,
1515
FixedG2MPCheckBit1,
16+
FixedG2MPCheckBit01,
17+
FixedG2MPCheckBit10,
1618
FixedG2MPCheckFinalizeBN,
1719
FixedG2MPCheckInitBit,
1820
FP12MulAssertOne,
@@ -90,6 +92,8 @@ class CircuitID(Enum):
9092
MP_CHECK_BIT0_LOOP = int.from_bytes(b"mp_check_bit0_loop", "big")
9193
MP_CHECK_BIT00_LOOP = int.from_bytes(b"mp_check_bit00_loop", "big")
9294
MP_CHECK_BIT1_LOOP = int.from_bytes(b"mp_check_bit1_loop", "big")
95+
MP_CHECK_BIT01_LOOP = int.from_bytes(b"mp_check_bit01_loop", "big")
96+
MP_CHECK_BIT10_LOOP = int.from_bytes(b"mp_check_bit10_loop", "big")
9397
MP_CHECK_PREPARE_PAIRS = int.from_bytes(b"mp_check_prepare_pairs", "big")
9498
MP_CHECK_PREPARE_LAMBDA_ROOT = int.from_bytes(
9599
b"mp_check_prepare_lambda_root", "big"
@@ -206,7 +210,7 @@ class CircuitID(Enum):
206210
{"n_pairs": 3, "n_fixed_g2": 2}, # Groth16
207211
],
208212
"filename": "multi_pairing_check",
209-
"curve_ids": [CurveID.BN254, CurveID.BLS12_381],
213+
"curve_ids": [CurveID.BLS12_381],
210214
},
211215
CircuitID.MP_CHECK_BIT00_LOOP: {
212216
"class": FixedG2MPCheckBit00,
@@ -224,7 +228,25 @@ class CircuitID(Enum):
224228
{"n_pairs": 3, "n_fixed_g2": 2}, # Groth16
225229
],
226230
"filename": "multi_pairing_check",
227-
"curve_ids": [CurveID.BN254, CurveID.BLS12_381],
231+
"curve_ids": [CurveID.BLS12_381],
232+
},
233+
CircuitID.MP_CHECK_BIT01_LOOP: {
234+
"class": FixedG2MPCheckBit01,
235+
"params": [
236+
{"n_pairs": 2, "n_fixed_g2": 2}, # BLS SIG / KZG Verif
237+
{"n_pairs": 3, "n_fixed_g2": 2}, # Groth16
238+
],
239+
"filename": "multi_pairing_check",
240+
"curve_ids": [CurveID.BN254],
241+
},
242+
CircuitID.MP_CHECK_BIT10_LOOP: {
243+
"class": FixedG2MPCheckBit10,
244+
"params": [
245+
{"n_pairs": 2, "n_fixed_g2": 2}, # BLS SIG / KZG Verif
246+
{"n_pairs": 3, "n_fixed_g2": 2}, # Groth16
247+
],
248+
"filename": "multi_pairing_check",
249+
"curve_ids": [CurveID.BN254],
228250
},
229251
CircuitID.MP_CHECK_PREPARE_PAIRS: {
230252
"class": MPCheckPreparePairs,

hydra/garaga/precompiled_circuits/compilable_circuits/cairo1_mpcheck_circuits.py

+109-61
Original file line numberDiff line numberDiff line change
@@ -120,25 +120,31 @@ def input_map(
120120

121121
def _base_input_map(self, bit_type: str) -> dict:
122122
"""
123-
Base input map for the bit 0, 1, and 00 cases.
123+
Base input map for the bit 0, 1, 00, 01, and 10 cases.
124124
"""
125125
input_map = {}
126126

127127
# Add pair inputs
128128
for k in range(self.n_fixed_g2):
129129
input_map[f"yInv_{k}"] = u384
130130
input_map[f"xNegOverY_{k}"] = u384
131-
input_map[f"G2_line_{k}"] = G2Line
132-
if bit_type == "1":
133-
input_map[f"Q_or_Q_neg_line{k}"] = G2Line
131+
input_map[f"G2_line_dbl_{k}"] = G2Line
132+
if bit_type in ("1"):
133+
input_map[f"G2_line_add_{k}"] = G2Line
134+
if bit_type == "10":
135+
input_map[f"G2_line_add_1_{k}"] = G2Line
136+
input_map[f"G2_line_dbl_0_{k}"] = G2Line
137+
if bit_type == "01":
138+
input_map[f"G2_line_dbl_1{k}"] = G2Line
139+
input_map[f"G2_line_add_1{k}"] = G2Line
134140
if bit_type == "00":
135141
input_map[f"G2_line_2nd_0_{k}"] = G2Line
136142

137143
for k in range(self.n_fixed_g2, self.n_pairs):
138144
input_map[f"yInv_{k}"] = u384
139145
input_map[f"xNegOverY_{k}"] = u384
140146
input_map[f"Q_{k}"] = G2PointCircuit
141-
if bit_type == "1":
147+
if bit_type in ("1", "01", "10"):
142148
input_map[f"Q_or_Q_neg_{k}"] = G2PointCircuit
143149

144150
# Add common inputs
@@ -147,7 +153,7 @@ def _base_input_map(self, bit_type: str) -> dict:
147153
input_map["f_i_plus_one_of_z"] = u384
148154

149155
# Add bit-specific inputs
150-
if bit_type == "1":
156+
if bit_type in ("1", "01", "10"):
151157
input_map["c_or_cinv_of_z"] = u384
152158

153159
input_map["z"] = u384
@@ -235,12 +241,15 @@ def _execute_circuit_logic(self, circuit, vars) -> ModuloCircuit:
235241
Implement the circuit logic using the processed input variables.
236242
"""
237243

238-
def _execute_circuit_bit_logic_base(self, circuit, vars, bit_type):
244+
def _execute_circuit_bit_logic_base(self, circuit: ModuloCircuit, vars, bit_type):
239245
n_pairs = self.n_pairs
240246
assert n_pairs >= 2, f"n_pairs must be >= 2, got {n_pairs}"
241247

242248
current_points, q_or_q_neg_points = parse_precomputed_g1_consts_and_g2_points(
243-
circuit, vars, n_pairs, bit_1=(bit_type == "1")
249+
circuit,
250+
vars,
251+
n_pairs,
252+
bit_1=(bit_type == "1" or bit_type == "01" or bit_type == "10"),
244253
)
245254

246255
circuit.create_lines_z_powers(vars["z"])
@@ -256,8 +265,12 @@ def _execute_circuit_bit_logic_base(self, circuit, vars, bit_type):
256265
circuit, current_points, q_or_q_neg_points, sum_i_prod_k_P, bit_type
257266
)
258267

259-
if bit_type == "1":
268+
if bit_type in ("1", "01"):
260269
sum_i_prod_k_P = circuit.mul(sum_i_prod_k_P, vars["c_or_cinv_of_z"])
270+
elif bit_type == "10":
271+
sum_i_prod_k_P = circuit.mul(
272+
sum_i_prod_k_P, circuit.square(vars["c_or_cinv_of_z"])
273+
)
261274

262275
f_i_plus_one_of_z = vars["f_i_plus_one_of_z"]
263276
new_lhs = circuit.mul(
@@ -304,6 +317,55 @@ def _process_points(
304317
)
305318
new_new_points.append(T)
306319
return new_new_points, sum_i_prod_k_P
320+
elif bit_type == "01":
321+
for k in range(self.n_pairs):
322+
T, l1 = circuit.double_step(current_points[k], k)
323+
sum_i_prod_k_P = self._multiply_line_evaluations(
324+
circuit, sum_i_prod_k_P, [l1], k
325+
)
326+
new_points.append(T)
327+
328+
sum_i_prod_k_P = circuit.mul(
329+
sum_i_prod_k_P,
330+
sum_i_prod_k_P,
331+
"Compute (f^2 * Π(i,k) (line_i,k(z))) ^ 2 = f^4 * (Π(i,k) (line_i,k(z)))^2",
332+
)
333+
334+
new_new_points = []
335+
for k in range(self.n_pairs):
336+
T, l1, l2 = circuit.double_and_add_step(
337+
new_points[k], q_or_q_neg_points[k], k
338+
)
339+
sum_i_prod_k_P = self._multiply_line_evaluations(
340+
circuit, sum_i_prod_k_P, [l1, l2], k
341+
)
342+
new_new_points.append(T)
343+
344+
return new_new_points, sum_i_prod_k_P
345+
346+
elif bit_type == "10":
347+
for k in range(self.n_pairs):
348+
T, l1, l2 = circuit.double_and_add_step(
349+
current_points[k], q_or_q_neg_points[k], k
350+
)
351+
sum_i_prod_k_P = self._multiply_line_evaluations(
352+
circuit, sum_i_prod_k_P, [l1, l2], k
353+
)
354+
new_points.append(T)
355+
356+
sum_i_prod_k_P = circuit.mul(
357+
sum_i_prod_k_P,
358+
sum_i_prod_k_P,
359+
"Compute (f^2 * Π(i,k) (line_i,k(z))) ^ 2 = f^4 * (Π(i,k) (line_i,k(z)))^2",
360+
)
361+
new_new_points = []
362+
for k in range(self.n_pairs):
363+
T, l1 = circuit.double_step(new_points[k], k)
364+
sum_i_prod_k_P = self._multiply_line_evaluations(
365+
circuit, sum_i_prod_k_P, [l1], k
366+
)
367+
new_new_points.append(T)
368+
return new_new_points, sum_i_prod_k_P
307369
elif bit_type == "0":
308370
for k in range(self.n_pairs):
309371
T, l1 = circuit.double_step(current_points[k], k)
@@ -320,9 +382,18 @@ def _process_points(
320382
circuit, sum_i_prod_k_P, [l1, l2], k
321383
)
322384
new_points.append(T)
385+
386+
else:
387+
raise ValueError(f"Invalid bit type: {bit_type}")
323388
return new_points, sum_i_prod_k_P
324389

325-
def _multiply_line_evaluations(self, circuit, sum_i_prod_k_P, lines, k):
390+
def _multiply_line_evaluations(
391+
self,
392+
circuit: multi_pairing_check.MultiPairingCheckCircuit,
393+
sum_i_prod_k_P,
394+
lines,
395+
k,
396+
):
326397
for i, l in enumerate(lines):
327398
sum_i_prod_k_P = circuit.mul(
328399
sum_i_prod_k_P,
@@ -358,17 +429,27 @@ def _extend_output(self, circuit, new_points, lhs_i_plus_one, ci_plus_one):
358429
circuit.extend_struct_output(u384(name="ci_plus_one", elmts=[ci_plus_one]))
359430

360431

361-
class FixedG2MPCheckBit0(BaseFixedG2PointsMPCheck):
432+
class FixedG2MPCheckBitBase(BaseFixedG2PointsMPCheck):
433+
"""Base class for bit checking circuits with default parameters."""
434+
435+
BIT_TYPE = None # Override in subclasses
436+
DEFAULT_PAIRS = 3
437+
DEFAULT_FIXED_G2 = 2
438+
362439
def __init__(
363440
self,
364441
curve_id: int,
365-
n_pairs: int,
366-
n_fixed_g2: int,
442+
n_pairs: int = None,
443+
n_fixed_g2: int = None,
367444
auto_run: bool = True,
368445
compilation_mode: int = 1,
369446
):
447+
assert compilation_mode == 1, "Compilation mode 1 is required for this circuit"
448+
n_pairs = n_pairs if n_pairs is not None else self.DEFAULT_PAIRS
449+
n_fixed_g2 = n_fixed_g2 if n_fixed_g2 is not None else self.DEFAULT_FIXED_G2
450+
370451
super().__init__(
371-
name=f"mp_check_bit0_{n_pairs}P_{n_fixed_g2}F",
452+
name=f"mp_check_bit{self.BIT_TYPE}_{n_pairs}P_{n_fixed_g2}F",
372453
curve_id=curve_id,
373454
n_pairs=n_pairs,
374455
n_fixed_g2=n_fixed_g2,
@@ -378,63 +459,30 @@ def __init__(
378459

379460
@property
380461
def input_map(self):
381-
return self._base_input_map("0")
462+
return self._base_input_map(self.BIT_TYPE)
382463

383464
def _execute_circuit_logic(self, circuit, vars) -> ModuloCircuit:
384-
return self._execute_circuit_bit_logic_base(circuit, vars, "0")
465+
return self._execute_circuit_bit_logic_base(circuit, vars, self.BIT_TYPE)
385466

386467

387-
class FixedG2MPCheckBit00(BaseFixedG2PointsMPCheck):
388-
def __init__(
389-
self,
390-
curve_id: int,
391-
auto_run: bool = True,
392-
compilation_mode: int = 1,
393-
n_pairs: int = 3,
394-
n_fixed_g2: int = 2,
395-
):
396-
super().__init__(
397-
name=f"mp_check_bit00_{n_pairs}P_{n_fixed_g2}F",
398-
curve_id=curve_id,
399-
n_pairs=n_pairs,
400-
n_fixed_g2=n_fixed_g2,
401-
auto_run=auto_run,
402-
compilation_mode=compilation_mode,
403-
)
468+
class FixedG2MPCheckBit0(FixedG2MPCheckBitBase):
469+
BIT_TYPE = "0"
404470

405-
@property
406-
def input_map(self):
407-
return self._base_input_map("00")
408471

409-
def _execute_circuit_logic(self, circuit, vars) -> ModuloCircuit:
410-
return self._execute_circuit_bit_logic_base(circuit, vars, "00")
472+
class FixedG2MPCheckBit00(FixedG2MPCheckBitBase):
473+
BIT_TYPE = "00"
411474

412475

413-
class FixedG2MPCheckBit1(BaseFixedG2PointsMPCheck):
414-
def __init__(
415-
self,
416-
curve_id: int,
417-
auto_run: bool = True,
418-
n_pairs: int = 3,
419-
n_fixed_g2: int = 2,
420-
compilation_mode: int = 1,
421-
):
422-
assert compilation_mode == 1, "Compilation mode 1 is required for this circuit"
423-
super().__init__(
424-
name=f"mp_check_bit1_{n_pairs}P_{n_fixed_g2}F",
425-
curve_id=curve_id,
426-
n_pairs=n_pairs,
427-
n_fixed_g2=n_fixed_g2,
428-
auto_run=auto_run,
429-
compilation_mode=compilation_mode,
430-
)
476+
class FixedG2MPCheckBit1(FixedG2MPCheckBitBase):
477+
BIT_TYPE = "1"
431478

432-
@property
433-
def input_map(self):
434-
return self._base_input_map("1")
435479

436-
def _execute_circuit_logic(self, circuit, vars) -> ModuloCircuit:
437-
return self._execute_circuit_bit_logic_base(circuit, vars, "1")
480+
class FixedG2MPCheckBit01(FixedG2MPCheckBitBase):
481+
BIT_TYPE = "01"
482+
483+
484+
class FixedG2MPCheckBit10(FixedG2MPCheckBitBase):
485+
BIT_TYPE = "10"
438486

439487

440488
class FixedG2MPCheckInitBit(BaseFixedG2PointsMPCheck):

0 commit comments

Comments
 (0)