Skip to content

Commit ec3c980

Browse files
committed
Fix coset ntt and have correctness
1 parent d9713f8 commit ec3c980

File tree

1 file changed

+47
-140
lines changed

1 file changed

+47
-140
lines changed

backend/groth16/bn254/icicle/icicle.go

+47-140
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,8 @@ import (
66
"fmt"
77
"math/big"
88
"math/bits"
9-
"runtime"
109
"time"
1110

12-
"github.com/consensys/gnark-crypto/ecc"
1311
curve "github.com/consensys/gnark-crypto/ecc/bn254"
1412
"github.com/consensys/gnark-crypto/ecc/bn254/fp"
1513
"github.com/consensys/gnark-crypto/ecc/bn254/fr"
@@ -32,7 +30,7 @@ import (
3230
icicle_g2 "github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bn254/g2"
3331
icicle_msm "github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bn254/msm"
3432
icicle_ntt "github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bn254/ntt"
35-
// icicle_vecops "github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bn254/vecOps"
33+
icicle_vecops "github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bn254/vecOps"
3634

3735
fcs "github.com/consensys/gnark/frontend/cs"
3836
)
@@ -263,11 +261,10 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b
263261
}
264262

265263
// H (witness reduction / FFT part)
266-
// var h icicle_core.DeviceSlice
267-
var hCPU []fr.Element
264+
var h icicle_core.DeviceSlice
268265
chHDone := make(chan struct{}, 1)
269266
go func() {
270-
hCPU = computeH(solution.A, solution.B, solution.C, &pk.Domain)
267+
h = computeH(solution.A, solution.B, solution.C, pk)
271268

272269
solution.A = nil
273270
solution.B = nil
@@ -329,7 +326,6 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b
329326

330327
// computes r[δ], s[δ], kr[δ]
331328
deltas := curve.BatchScalarMultiplicationG1(&pk.G1.Delta, []fr.Element{_r, _s, _kr})
332-
n := runtime.NumCPU()
333329

334330
var bs1, ar curve.G1Jac
335331

@@ -368,31 +364,15 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b
368364

369365
computeKRS := func() error {
370366
var krs, krs2, p1 curve.G1Jac
371-
var krs2CPU curve.G1Jac
372367
sizeH := int(pk.Domain.Cardinality - 1)
373368

374-
// CPU START
375-
376-
if _, err := krs2CPU.MultiExp(pk.G1.Z, hCPU[:sizeH], ecc.MultiExpConfig{NbTasks: n / 2}); err != nil {
377-
panic("krs2CPU didn't complete")
378-
}
379-
380-
// CPU END
381-
382369
cfg := icicle_msm.GetDefaultMSMConfig()
383370
cfg.ArePointsMontgomeryForm = true
384371
cfg.AreScalarsMontgomeryForm = true
385372
resKrs2 := make(icicle_core.HostSlice[icicle_bn254.Projective], 1)
386-
// icicle_msm.Msm(h.RangeTo(sizeH, false), pk.G1Device.Z, &cfg, resKrs2)
387-
icicle_msm.Msm(icicle_core.HostSliceFromElements(hCPU[:sizeH]), pk.G1Device.Z, &cfg, resKrs2)
388-
373+
icicle_msm.Msm(h.RangeTo(sizeH, false), pk.G1Device.Z, &cfg, resKrs2)
389374
krs2 = g1ProjectiveToG1Jac(resKrs2[0])
390375

391-
if krs2.Equal(&krs2CPU) {
392-
fmt.Println("krs2 succeeded")
393-
} else {
394-
fmt.Println("krs2 failed correctness")
395-
}
396376
// filter the wire values if needed
397377
// TODO Perf @Tabaie worst memory allocation offender
398378
toRemove := commitmentInfo.GetPrivateCommitted()
@@ -498,68 +478,7 @@ func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr
498478
return
499479
}
500480

501-
// func computeH(a, b, c []fr.Element, pk *ProvingKey) icicle_core.DeviceSlice {
502-
// // H part of Krs
503-
// // Compute H (hz=ab-c, where z=-2 on ker X^n+1 (z(x)=x^n-1))
504-
// // 1 - _a = ifft(a), _b = ifft(b), _c = ifft(c)
505-
// // 2 - ca = fft_coset(_a), ba = fft_coset(_b), cc = fft_coset(_c)
506-
// // 3 - h = ifft_coset(ca o cb - cc)
507-
508-
// n := len(a)
509-
510-
// // add padding to ensure input length is domain cardinality
511-
// padding := make([]fr.Element, int(pk.Domain.Cardinality)-n)
512-
// a = append(a, padding...)
513-
// b = append(b, padding...)
514-
// c = append(c, padding...)
515-
// n = len(a)
516-
517-
// computeADone := make(chan icicle_core.DeviceSlice, 1)
518-
// computeBDone := make(chan icicle_core.DeviceSlice, 1)
519-
// computeCDone := make(chan icicle_core.DeviceSlice, 1)
520-
521-
// computeInttNttOnDevice := func(scalars []fr.Element, channel chan icicle_core.DeviceSlice) {
522-
// cfg := icicle_ntt.GetDefaultNttConfig()
523-
// scalarsStream, _ := icicle_cr.CreateStream()
524-
// cfg.Ctx.Stream = &scalarsStream
525-
// cfg.Ordering = icicle_core.KNR
526-
// cfg.IsAsync = true
527-
// scalarsHost := icicle_core.HostSliceFromElements(scalars)
528-
// var scalarsDevice icicle_core.DeviceSlice
529-
// scalarsHost.CopyToDeviceAsync(&scalarsDevice, scalarsStream, true)
530-
// icicle_ntt.Ntt(scalarsDevice, icicle_core.KInverse, &cfg, scalarsDevice)
531-
// cfg.Ordering = icicle_core.KRN
532-
// cfg.CosetGen = [8]uint32(icicle_core.ConvertUint64ArrToUint32Arr(pk.Domain.FrMultiplicativeGen[:]))
533-
// icicle_ntt.Ntt(scalarsDevice, icicle_core.KForward, &cfg, scalarsDevice)
534-
// icicle_cr.SynchronizeStream(&scalarsStream)
535-
// channel <-scalarsDevice
536-
// }
537-
538-
// go computeInttNttOnDevice(a, computeADone)
539-
// go computeInttNttOnDevice(b, computeBDone)
540-
// go computeInttNttOnDevice(c, computeCDone)
541-
542-
// aDevice := <-computeADone
543-
// bDevice := <-computeBDone
544-
// cDevice := <-computeCDone
545-
546-
// vecCfg := icicle_core.DefaultVecOpsConfig()
547-
// icicle_vecops.VecOp(aDevice, bDevice, aDevice, vecCfg, icicle_core.Mul)
548-
// icicle_vecops.VecOp(aDevice, cDevice, aDevice, vecCfg, icicle_core.Sub)
549-
// icicle_vecops.VecOp(aDevice, pk.DenDevice, aDevice, vecCfg, icicle_core.Mul)
550-
551-
// cfg := icicle_ntt.GetDefaultNttConfig()
552-
// cfg.CosetGen = [8]uint32(icicle_core.ConvertUint64ArrToUint32Arr(pk.Domain.FrMultiplicativeGenInv[:]))
553-
// cfg.Ordering = icicle_core.KNR
554-
// icicle_ntt.Ntt(aDevice, icicle_core.KInverse, &cfg, aDevice)
555-
556-
// resHost := make(icicle_core.HostSlice[fr.Element], n)
557-
// resHost.CopyFromDevice(&aDevice)
558-
559-
// return aDevice
560-
// }
561-
562-
func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element {
481+
func computeH(a, b, c []fr.Element, pk *ProvingKey) icicle_core.DeviceSlice {
563482
// H part of Krs
564483
// Compute H (hz=ab-c, where z=-2 on ker X^n+1 (z(x)=x^n-1))
565484
// 1 - _a = ifft(a), _b = ifft(b), _c = ifft(c)
@@ -569,69 +488,57 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element {
569488
n := len(a)
570489

571490
// add padding to ensure input length is domain cardinality
572-
padding := make([]fr.Element, int(domain.Cardinality)-n)
491+
padding := make([]fr.Element, int(pk.Domain.Cardinality)-n)
573492
a = append(a, padding...)
574493
b = append(b, padding...)
575494
c = append(c, padding...)
576495
n = len(a)
577496

578-
aCopy := make([]fr.Element, n)
579-
copy(aCopy, a)
580-
581-
cfg := icicle_ntt.GetDefaultNttConfig()
582-
cfg.Ordering = icicle_core.KNR
583-
scalarsHost := icicle_core.HostSliceFromElements(aCopy)
584-
scalarsHostOut := make(icicle_core.HostSlice[fr.Element], len(aCopy))
585-
icicle_ntt.Ntt(scalarsHost, icicle_core.KInverse, &cfg, scalarsHostOut)
586-
587-
domain.FFTInverse(a, fft.DIF)
588-
589-
for i, elem := range a {
590-
if !elem.Equal(&scalarsHostOut[i]) {
591-
fmt.Println("computeH: A failed")
592-
}
497+
computeADone := make(chan icicle_core.DeviceSlice, 1)
498+
computeBDone := make(chan icicle_core.DeviceSlice, 1)
499+
computeCDone := make(chan icicle_core.DeviceSlice, 1)
500+
501+
cosetGenBits := pk.Domain.FrMultiplicativeGen.Bits()
502+
cosetGen := icicle_core.ConvertUint64ArrToUint32Arr(cosetGenBits[:])
503+
var configCosetGen [8]uint32
504+
copy(configCosetGen[:], cosetGen[:8])
505+
506+
computeInttNttOnDevice := func(scalars []fr.Element, channel chan icicle_core.DeviceSlice) {
507+
cfg := icicle_ntt.GetDefaultNttConfig()
508+
scalarsStream, _ := icicle_cr.CreateStream()
509+
cfg.Ctx.Stream = &scalarsStream
510+
cfg.Ordering = icicle_core.KNR
511+
cfg.IsAsync = true
512+
scalarsHost := icicle_core.HostSliceFromElements(scalars)
513+
var scalarsDevice icicle_core.DeviceSlice
514+
scalarsHost.CopyToDeviceAsync(&scalarsDevice, scalarsStream, true)
515+
icicle_ntt.Ntt(scalarsDevice, icicle_core.KInverse, &cfg, scalarsDevice)
516+
cfg.Ordering = icicle_core.KRN
517+
cfg.CosetGen = configCosetGen
518+
icicle_ntt.Ntt(scalarsDevice, icicle_core.KForward, &cfg, scalarsDevice)
519+
icicle_cr.SynchronizeStream(&scalarsStream)
520+
channel <-scalarsDevice
593521
}
594522

595-
domain.FFTInverse(b, fft.DIF)
596-
domain.FFTInverse(c, fft.DIF)
597-
523+
go computeInttNttOnDevice(a, computeADone)
524+
go computeInttNttOnDevice(b, computeBDone)
525+
go computeInttNttOnDevice(c, computeCDone)
598526

599-
gen, _ := fft.Generator(2 * domain.Cardinality)
600-
// genBits := gen.Bits()
601-
// limbs := icicle_core.ConvertUint64ArrToUint32Arr(genBits[:])
602-
// var rouIcicle icicle_bn254.ScalarField
603-
// rouIcicle.FromLimbs(limbs)
604-
cfgCustom := icicle_ntt.GetDefaultNttConfig()
605-
cfg.CosetGen = ([8]uint32)(icicle_core.ConvertUint64ArrToUint32Arr(gen[:]))
606-
cfgCustom.Ordering = icicle_core.KRN
607-
icicle_ntt.Ntt(scalarsHostOut, icicle_core.KForward, &cfgCustom, scalarsHost)
527+
aDevice := <-computeADone
528+
bDevice := <-computeBDone
529+
cDevice := <-computeCDone
608530

609-
domain.FFT(a, fft.DIT, fft.OnCoset())
610-
611-
if !scalarsHost[0].Equal(&a[0]) {
612-
fmt.Println("computeH: A Forward failed")
613-
}
531+
vecCfg := icicle_core.DefaultVecOpsConfig()
532+
icicle_bn254.FromMontgomery(&aDevice)
533+
icicle_vecops.VecOp(aDevice, bDevice, aDevice, vecCfg, icicle_core.Mul)
534+
icicle_vecops.VecOp(aDevice, cDevice, aDevice, vecCfg, icicle_core.Sub)
535+
icicle_bn254.FromMontgomery(&aDevice)
536+
icicle_vecops.VecOp(aDevice, pk.DenDevice, aDevice, vecCfg, icicle_core.Mul)
614537

615-
domain.FFT(b, fft.DIT, fft.OnCoset())
616-
domain.FFT(c, fft.DIT, fft.OnCoset())
617-
618-
var den, one fr.Element
619-
one.SetOne()
620-
den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality)))
621-
den.Sub(&den, &one).Inverse(&den)
622-
623-
// h = ifft_coset(ca o cb - cc)
624-
// reusing a to avoid unnecessary memory allocation
625-
utils.Parallelize(n, func(start, end int) {
626-
for i := start; i < end; i++ {
627-
a[i].Mul(&a[i], &b[i]).
628-
Sub(&a[i], &c[i]).
629-
Mul(&a[i], &den)
630-
}
631-
})
632-
633-
// ifft_coset
634-
domain.FFTInverse(a, fft.DIF, fft.OnCoset())
538+
cfg := icicle_ntt.GetDefaultNttConfig()
539+
cfg.CosetGen = configCosetGen
540+
cfg.Ordering = icicle_core.KNR
541+
icicle_ntt.Ntt(aDevice, icicle_core.KInverse, &cfg, aDevice)
635542

636-
return a
543+
return aDevice
637544
}

0 commit comments

Comments
 (0)