@@ -120,25 +120,31 @@ def input_map(
120
120
121
121
def _base_input_map (self , bit_type : str ) -> dict :
122
122
"""
123
- Base input map for the bit 0, 1, and 00 cases.
123
+ Base input map for the bit 0, 1, 00, 01, and 10 cases.
124
124
"""
125
125
input_map = {}
126
126
127
127
# Add pair inputs
128
128
for k in range (self .n_fixed_g2 ):
129
129
input_map [f"yInv_{ k } " ] = u384
130
130
input_map [f"xNegOverY_{ k } " ] = u384
131
- input_map [f"G2_line_{ k } " ] = G2Line
132
- if bit_type == "1" :
133
- input_map [f"Q_or_Q_neg_line{ k } " ] = G2Line
131
+ input_map [f"G2_line_dbl_{ k } " ] = G2Line
132
+ if bit_type in ("1" ):
133
+ input_map [f"G2_line_add_{ k } " ] = G2Line
134
+ if bit_type == "10" :
135
+ input_map [f"G2_line_add_1_{ k } " ] = G2Line
136
+ input_map [f"G2_line_dbl_0_{ k } " ] = G2Line
137
+ if bit_type == "01" :
138
+ input_map [f"G2_line_dbl_1{ k } " ] = G2Line
139
+ input_map [f"G2_line_add_1{ k } " ] = G2Line
134
140
if bit_type == "00" :
135
141
input_map [f"G2_line_2nd_0_{ k } " ] = G2Line
136
142
137
143
for k in range (self .n_fixed_g2 , self .n_pairs ):
138
144
input_map [f"yInv_{ k } " ] = u384
139
145
input_map [f"xNegOverY_{ k } " ] = u384
140
146
input_map [f"Q_{ k } " ] = G2PointCircuit
141
- if bit_type == "1" :
147
+ if bit_type in ( "1" , "01" , "10" ) :
142
148
input_map [f"Q_or_Q_neg_{ k } " ] = G2PointCircuit
143
149
144
150
# Add common inputs
@@ -147,7 +153,7 @@ def _base_input_map(self, bit_type: str) -> dict:
147
153
input_map ["f_i_plus_one_of_z" ] = u384
148
154
149
155
# Add bit-specific inputs
150
- if bit_type == "1" :
156
+ if bit_type in ( "1" , "01" , "10" ) :
151
157
input_map ["c_or_cinv_of_z" ] = u384
152
158
153
159
input_map ["z" ] = u384
@@ -235,12 +241,15 @@ def _execute_circuit_logic(self, circuit, vars) -> ModuloCircuit:
235
241
Implement the circuit logic using the processed input variables.
236
242
"""
237
243
238
- def _execute_circuit_bit_logic_base (self , circuit , vars , bit_type ):
244
+ def _execute_circuit_bit_logic_base (self , circuit : ModuloCircuit , vars , bit_type ):
239
245
n_pairs = self .n_pairs
240
246
assert n_pairs >= 2 , f"n_pairs must be >= 2, got { n_pairs } "
241
247
242
248
current_points , q_or_q_neg_points = parse_precomputed_g1_consts_and_g2_points (
243
- circuit , vars , n_pairs , bit_1 = (bit_type == "1" )
249
+ circuit ,
250
+ vars ,
251
+ n_pairs ,
252
+ bit_1 = (bit_type == "1" or bit_type == "01" or bit_type == "10" ),
244
253
)
245
254
246
255
circuit .create_lines_z_powers (vars ["z" ])
@@ -256,8 +265,12 @@ def _execute_circuit_bit_logic_base(self, circuit, vars, bit_type):
256
265
circuit , current_points , q_or_q_neg_points , sum_i_prod_k_P , bit_type
257
266
)
258
267
259
- if bit_type == "1" :
268
+ if bit_type in ( "1" , "01" ) :
260
269
sum_i_prod_k_P = circuit .mul (sum_i_prod_k_P , vars ["c_or_cinv_of_z" ])
270
+ elif bit_type == "10" :
271
+ sum_i_prod_k_P = circuit .mul (
272
+ sum_i_prod_k_P , circuit .square (vars ["c_or_cinv_of_z" ])
273
+ )
261
274
262
275
f_i_plus_one_of_z = vars ["f_i_plus_one_of_z" ]
263
276
new_lhs = circuit .mul (
@@ -304,6 +317,55 @@ def _process_points(
304
317
)
305
318
new_new_points .append (T )
306
319
return new_new_points , sum_i_prod_k_P
320
+ elif bit_type == "01" :
321
+ for k in range (self .n_pairs ):
322
+ T , l1 = circuit .double_step (current_points [k ], k )
323
+ sum_i_prod_k_P = self ._multiply_line_evaluations (
324
+ circuit , sum_i_prod_k_P , [l1 ], k
325
+ )
326
+ new_points .append (T )
327
+
328
+ sum_i_prod_k_P = circuit .mul (
329
+ sum_i_prod_k_P ,
330
+ sum_i_prod_k_P ,
331
+ "Compute (f^2 * Π(i,k) (line_i,k(z))) ^ 2 = f^4 * (Π(i,k) (line_i,k(z)))^2" ,
332
+ )
333
+
334
+ new_new_points = []
335
+ for k in range (self .n_pairs ):
336
+ T , l1 , l2 = circuit .double_and_add_step (
337
+ new_points [k ], q_or_q_neg_points [k ], k
338
+ )
339
+ sum_i_prod_k_P = self ._multiply_line_evaluations (
340
+ circuit , sum_i_prod_k_P , [l1 , l2 ], k
341
+ )
342
+ new_new_points .append (T )
343
+
344
+ return new_new_points , sum_i_prod_k_P
345
+
346
+ elif bit_type == "10" :
347
+ for k in range (self .n_pairs ):
348
+ T , l1 , l2 = circuit .double_and_add_step (
349
+ current_points [k ], q_or_q_neg_points [k ], k
350
+ )
351
+ sum_i_prod_k_P = self ._multiply_line_evaluations (
352
+ circuit , sum_i_prod_k_P , [l1 , l2 ], k
353
+ )
354
+ new_points .append (T )
355
+
356
+ sum_i_prod_k_P = circuit .mul (
357
+ sum_i_prod_k_P ,
358
+ sum_i_prod_k_P ,
359
+ "Compute (f^2 * Π(i,k) (line_i,k(z))) ^ 2 = f^4 * (Π(i,k) (line_i,k(z)))^2" ,
360
+ )
361
+ new_new_points = []
362
+ for k in range (self .n_pairs ):
363
+ T , l1 = circuit .double_step (new_points [k ], k )
364
+ sum_i_prod_k_P = self ._multiply_line_evaluations (
365
+ circuit , sum_i_prod_k_P , [l1 ], k
366
+ )
367
+ new_new_points .append (T )
368
+ return new_new_points , sum_i_prod_k_P
307
369
elif bit_type == "0" :
308
370
for k in range (self .n_pairs ):
309
371
T , l1 = circuit .double_step (current_points [k ], k )
@@ -320,9 +382,18 @@ def _process_points(
320
382
circuit , sum_i_prod_k_P , [l1 , l2 ], k
321
383
)
322
384
new_points .append (T )
385
+
386
+ else :
387
+ raise ValueError (f"Invalid bit type: { bit_type } " )
323
388
return new_points , sum_i_prod_k_P
324
389
325
- def _multiply_line_evaluations (self , circuit , sum_i_prod_k_P , lines , k ):
390
+ def _multiply_line_evaluations (
391
+ self ,
392
+ circuit : multi_pairing_check .MultiPairingCheckCircuit ,
393
+ sum_i_prod_k_P ,
394
+ lines ,
395
+ k ,
396
+ ):
326
397
for i , l in enumerate (lines ):
327
398
sum_i_prod_k_P = circuit .mul (
328
399
sum_i_prod_k_P ,
@@ -358,17 +429,27 @@ def _extend_output(self, circuit, new_points, lhs_i_plus_one, ci_plus_one):
358
429
circuit .extend_struct_output (u384 (name = "ci_plus_one" , elmts = [ci_plus_one ]))
359
430
360
431
361
- class FixedG2MPCheckBit0 (BaseFixedG2PointsMPCheck ):
432
+ class FixedG2MPCheckBitBase (BaseFixedG2PointsMPCheck ):
433
+ """Base class for bit checking circuits with default parameters."""
434
+
435
+ BIT_TYPE = None # Override in subclasses
436
+ DEFAULT_PAIRS = 3
437
+ DEFAULT_FIXED_G2 = 2
438
+
362
439
def __init__ (
363
440
self ,
364
441
curve_id : int ,
365
- n_pairs : int ,
366
- n_fixed_g2 : int ,
442
+ n_pairs : int = None ,
443
+ n_fixed_g2 : int = None ,
367
444
auto_run : bool = True ,
368
445
compilation_mode : int = 1 ,
369
446
):
447
+ assert compilation_mode == 1 , "Compilation mode 1 is required for this circuit"
448
+ n_pairs = n_pairs if n_pairs is not None else self .DEFAULT_PAIRS
449
+ n_fixed_g2 = n_fixed_g2 if n_fixed_g2 is not None else self .DEFAULT_FIXED_G2
450
+
370
451
super ().__init__ (
371
- name = f"mp_check_bit0_ { n_pairs } P_{ n_fixed_g2 } F" ,
452
+ name = f"mp_check_bit { self . BIT_TYPE } _ { n_pairs } P_{ n_fixed_g2 } F" ,
372
453
curve_id = curve_id ,
373
454
n_pairs = n_pairs ,
374
455
n_fixed_g2 = n_fixed_g2 ,
@@ -378,63 +459,30 @@ def __init__(
378
459
379
460
@property
380
461
def input_map (self ):
381
- return self ._base_input_map ("0" )
462
+ return self ._base_input_map (self . BIT_TYPE )
382
463
383
464
def _execute_circuit_logic (self , circuit , vars ) -> ModuloCircuit :
384
- return self ._execute_circuit_bit_logic_base (circuit , vars , "0" )
465
+ return self ._execute_circuit_bit_logic_base (circuit , vars , self . BIT_TYPE )
385
466
386
467
387
- class FixedG2MPCheckBit00 (BaseFixedG2PointsMPCheck ):
388
- def __init__ (
389
- self ,
390
- curve_id : int ,
391
- auto_run : bool = True ,
392
- compilation_mode : int = 1 ,
393
- n_pairs : int = 3 ,
394
- n_fixed_g2 : int = 2 ,
395
- ):
396
- super ().__init__ (
397
- name = f"mp_check_bit00_{ n_pairs } P_{ n_fixed_g2 } F" ,
398
- curve_id = curve_id ,
399
- n_pairs = n_pairs ,
400
- n_fixed_g2 = n_fixed_g2 ,
401
- auto_run = auto_run ,
402
- compilation_mode = compilation_mode ,
403
- )
468
+ class FixedG2MPCheckBit0 (FixedG2MPCheckBitBase ):
469
+ BIT_TYPE = "0"
404
470
405
- @property
406
- def input_map (self ):
407
- return self ._base_input_map ("00" )
408
471
409
- def _execute_circuit_logic ( self , circuit , vars ) -> ModuloCircuit :
410
- return self . _execute_circuit_bit_logic_base ( circuit , vars , "00" )
472
+ class FixedG2MPCheckBit00 ( FixedG2MPCheckBitBase ) :
473
+ BIT_TYPE = "00"
411
474
412
475
413
- class FixedG2MPCheckBit1 (BaseFixedG2PointsMPCheck ):
414
- def __init__ (
415
- self ,
416
- curve_id : int ,
417
- auto_run : bool = True ,
418
- n_pairs : int = 3 ,
419
- n_fixed_g2 : int = 2 ,
420
- compilation_mode : int = 1 ,
421
- ):
422
- assert compilation_mode == 1 , "Compilation mode 1 is required for this circuit"
423
- super ().__init__ (
424
- name = f"mp_check_bit1_{ n_pairs } P_{ n_fixed_g2 } F" ,
425
- curve_id = curve_id ,
426
- n_pairs = n_pairs ,
427
- n_fixed_g2 = n_fixed_g2 ,
428
- auto_run = auto_run ,
429
- compilation_mode = compilation_mode ,
430
- )
476
+ class FixedG2MPCheckBit1 (FixedG2MPCheckBitBase ):
477
+ BIT_TYPE = "1"
431
478
432
- @property
433
- def input_map (self ):
434
- return self ._base_input_map ("1" )
435
479
436
- def _execute_circuit_logic (self , circuit , vars ) -> ModuloCircuit :
437
- return self ._execute_circuit_bit_logic_base (circuit , vars , "1" )
480
+ class FixedG2MPCheckBit01 (FixedG2MPCheckBitBase ):
481
+ BIT_TYPE = "01"
482
+
483
+
484
+ class FixedG2MPCheckBit10 (FixedG2MPCheckBitBase ):
485
+ BIT_TYPE = "10"
438
486
439
487
440
488
class FixedG2MPCheckInitBit (BaseFixedG2PointsMPCheck ):
0 commit comments