diff --git a/backend/groth16/bn254/prove_gpu.go b/backend/groth16/bn254/prove_gpu.go index f6c5952d99..e067c0d450 100644 --- a/backend/groth16/bn254/prove_gpu.go +++ b/backend/groth16/bn254/prove_gpu.go @@ -152,6 +152,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b wireValuesASize := len(wireValuesA) scalarBytes := wireValuesASize * fr.Bytes + // Copy scalars to the device and retain ptr to them copyDone := make(chan unsafe.Pointer, 1) iciclegnark.CopyToDevice(wireValuesA, scalarBytes, copyDone) wireValuesADevicePtr := <-copyDone @@ -175,6 +176,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b wireValuesBSize := len(wireValuesB) scalarBytes := wireValuesBSize * fr.Bytes + // Copy scalars to the device and retain ptr to them copyDone := make(chan unsafe.Pointer, 1) iciclegnark.CopyToDevice(wireValuesB, scalarBytes, copyDone) wireValuesBDevicePtr := <-copyDone @@ -237,6 +239,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b var krs, krs2, p1 curve.G1Jac sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 + // check for small circuits as iciclegnark doesn't handle zero sizes well if len(pk.G1.Z) > 0 { if krs2, _, err = iciclegnark.MsmOnDevice(h, pk.G1Device.Z, sizeH, true); err != nil { return err @@ -249,6 +252,8 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) scalars := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) + // filter zero/infinity points since icicle doesn't handle them + // See https://github.com/ingonyama-zk/icicle/issues/169 for more info for _, indexToRemove := range pk.InfinityPointIndicesK { scalars = append(scalars[:indexToRemove], scalars[indexToRemove+1:]...) } @@ -318,6 +323,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b log.Debug().Dur("took", time.Since(start)).Msg("prover done") + // free device/GPU memory that is not needed for future proofs (scalars/hpoly) go func() { iciclegnark.FreeDevicePointer(wireValuesADevice.P) iciclegnark.FreeDevicePointer(wireValuesBDevice.P) diff --git a/backend/groth16/bn254/setup_gpu.go b/backend/groth16/bn254/setup_gpu.go index e051890080..7826d04ff2 100644 --- a/backend/groth16/bn254/setup_gpu.go +++ b/backend/groth16/bn254/setup_gpu.go @@ -356,6 +356,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // set domain pk.Domain = *domain + // Move static values (points, domain, hpoly denom) to the device/GPU err = pk.setupDevicePointers() return nil