Skip to content

Commit

Permalink
accel FFT by 30+% with vartime endomorphism support
Browse files Browse the repository at this point in the history
  • Loading branch information
mratsim committed Aug 28, 2023
1 parent ad04e6e commit 9b9fe7c
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 17 deletions.
75 changes: 71 additions & 4 deletions constantine/math/elliptic/ec_scalar_mul_vartime.nim
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,28 @@

import
# Internals
./ec_endomorphism_accel,
../arithmetic,
../extension_fields,
../ec_shortweierstrass,
../io/io_bigints,
../../platforms/abstractions
../constants/zoo_endomorphisms,
../../platforms/abstractions,
../../math_arbitrary_precision/arithmetic/limbs_views

{.push raises: [].} # No exceptions allowed in core cryptographic operations
{.push checks: off.} # No defects due to array bound checking or signed integer overflow allowed

# Support files for testing Elliptic Curve arithmetic
# Bit operations
# ------------------------------------------------------------------------------

iterator unpackBE(scalarByte: byte): bool =
for i in countdown(7, 0):
yield bool((scalarByte shr i) and 1)

# Variable-time scalar multiplication
# ------------------------------------------------------------------------------

func scalarMul_doubleAdd_vartime*[EC](P: var EC, scalar: BigInt) {.tags:[VarTime].} =
## **Variable-time** Elliptic Curve Scalar Multiplication
##
Expand All @@ -39,11 +46,17 @@ func scalarMul_doubleAdd_vartime*[EC](P: var EC, scalar: BigInt) {.tags:[VarTime
Paff.affine(P)

P.setInf()
var isInf = true

for scalarByte in scalarCanonical:
for bit in unpackBE(scalarByte):
P.double()
if not isInf:
P.double()
if bit:
P += Paff
if isInf:
P.fromAffine(Paff)
else:
P += Paff

func scalarMul_minHammingWeight_vartime*[EC](P: var EC, scalar: BigInt) {.tags:[VarTime].} =
## **Variable-time** Elliptic Curve Scalar Multiplication
Expand Down Expand Up @@ -120,3 +133,57 @@ func scalarMul_minHammingWeight_windowed_vartime*[EC](P: var EC, scalar: BigInt,
P += tab[digit shr 1]
elif digit < 0:
P -= tab[-digit shr 1]

func scalarMul_vartime*[scalBits; EC](
P: var EC,
scalar: BigInt[scalBits]
) {.inline.} =
## Elliptic Curve Scalar Multiplication
##
## P <- [k] P
##
## This select the best algorithm depending on heuristics
## and the scalar being multiplied.
## The scalar MUST NOT be a secret as this does not use side-channel countermeasures
##
## This may use endomorphism acceleration.
## As endomorphism acceleration requires:
## - Cofactor to be cleared
## - 0 <= scalar < curve order
## Those conditions will be assumed.

when P.F is Fp:
const M = 2
elif P.F is Fp2:
const M = 4
else:
{.error: "Unconfigured".}

const L = scalBits.ceilDiv_vartime(M) + 1

let usedBits = scalar.limbs.getBits_vartime()

when scalBits == EC.F.C.getCurveOrderBitwidth and
EC.F.C.hasEndomorphismAcceleration():
if usedBits >= L:
# The constant-time implementation is extremely efficient
when EC.F is Fp:
P.scalarMulGLV_m2w2(scalar)
elif EC.F is Fp2:
P.scalarMulEndo(scalar)
else: # Curves defined on Fp^m with m > 2
{.error: "Unreachable".}
return

if 64 < usedBits:
# With a window of 5, we precompute 2^3 = 8 points
P.scalarMul_minHammingWeight_windowed_vartime(scalar, window = 5)
elif 8 <= usedBits and usedBits <= 64:
# With a window of 3, we precompute 2^1 = 2 points
P.scalarMul_minHammingWeight_windowed_vartime(scalar, window = 3)
elif usedBits == 1:
discard
elif usedBits == 0:
P.setInf()
else:
P.scalarMul_doubleAdd_vartime(scalar)
8 changes: 4 additions & 4 deletions constantine/math/polynomials/fft.nim
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ func simpleFT[EC; bits: static int](

for i in 0 ..< L:
last = vals[0]
last.scalarMul_minHammingWeight_windowed_vartime(rootsOfUnity[0], window = 5)
last.scalarMul_vartime(rootsOfUnity[0])
for j in 1 ..< L:
v = vals[j]
v.scalarMul_minHammingWeight_windowed_vartime(rootsOfUnity[(i*j) mod L], window = 5)
v.scalarMul_vartime(rootsOfUnity[(i*j) mod L])
last += v
output[i] = last

Expand All @@ -100,7 +100,7 @@ func fft_internal[EC; bits: static int](
for i in 0 ..< half:
# FFT Butterfly
y_times_root = output[i+half]
y_times_root .scalarMul_minHammingWeight_windowed_vartime(rootsOfUnity[i], window = 5)
y_times_root .scalarMul_vartime(rootsOfUnity[i])
output[i+half] .diff(output[i], y_times_root)
output[i] += y_times_root

Expand Down Expand Up @@ -144,7 +144,7 @@ func ifft*[EC](
invLen.invmod_vartime(invLen, EC.F.C.getCurveOrder())

for i in 0 ..< output.len:
output[i].scalarMul_minHammingWeight_windowed_vartime(invLen, window = 5)
output[i].scalarMul_vartime(invLen)

return FFTS_Success

Expand Down
8 changes: 4 additions & 4 deletions constantine/signatures/bls_signatures.nim
Original file line number Diff line number Diff line change
Expand Up @@ -419,8 +419,8 @@ func update*[Pubkey, Sig: ECP_ShortW_Aff](

var randFactor{.noInit.}: BigInt[64]
randFactor.unmarshal(ctx.secureBlinding.toOpenArray(0, 7), bigEndian)
pkG1_jac.scalarMul_minHammingWeight_windowed_vartime(randFactor, window = 3)
sigG2_jac.scalarMul_minHammingWeight_windowed_vartime(randFactor, window = 3)
pkG1_jac.scalarMul_vartime(randFactor)
sigG2_jac.scalarMul_vartime(randFactor)

if ctx.aggSigOnce == false:
ctx.aggSig = sigG2_jac
Expand Down Expand Up @@ -455,8 +455,8 @@ func update*[Pubkey, Sig: ECP_ShortW_Aff](

var randFactor{.noInit.}: BigInt[64]
randFactor.unmarshal(ctx.secureBlinding.toOpenArray(0, 7), bigEndian)
hmsgG1_jac.scalarMul_minHammingWeight_windowed_vartime(randFactor, window = 3)
sigG1_jac.scalarMul_minHammingWeight_windowed_vartime(randFactor, window = 3)
hmsgG1_jac.scalarMul_vartime(randFactor)
sigG1_jac.scalarMul_vartime(randFactor)

if ctx.aggSigOnce == false:
ctx.aggSig = sigG1_jac
Expand Down
11 changes: 6 additions & 5 deletions research/kzg/fft_g1.nim
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import
../../constantine/math/config/curves,
../../constantine/math/arithmetic,
../../constantine/math/ec_shortweierstrass,
../../constantine/math/elliptic/ec_scalar_mul_vartime,
../../constantine/math/io/[io_fields, io_ec, io_bigints],
# Research
./strided_views,
Expand Down Expand Up @@ -105,10 +106,10 @@ func simpleFT[EC; bits: static int](

for i in 0 ..< L:
last = vals[0]
last.scalarMul(rootsOfUnity[0])
last.scalarMul_vartime(rootsOfUnity[0])
for j in 1 ..< L:
v = vals[j]
v.scalarMul(rootsOfUnity[(i*j) mod L])
v.scalarMul_vartime(rootsOfUnity[(i*j) mod L])
last += v
output[i] = last

Expand All @@ -135,7 +136,7 @@ func fft_internal[EC; bits: static int](
for i in 0 ..< half:
# FFT Butterfly
y_times_root = output[i+half]
y_times_root .scalarMul(rootsOfUnity[i])
y_times_root .scalarMul_vartime(rootsOfUnity[i])
output[i+half] .diff(output[i], y_times_root)
output[i] += y_times_root

Expand Down Expand Up @@ -180,7 +181,7 @@ func ifft*[EC](
let inv = invLen.toBig()

for i in 0..< output.len:
output[i].scalarMul(inv)
output[i].scalarMul_vartime(inv)

return FFTS_Success

Expand Down Expand Up @@ -262,7 +263,7 @@ when isMainModule:

warmup()

for scale in 4 ..< 10:
for scale in 4 ..< 16:
# Setup

let desc = FFTDescriptor[EC_G1].init(uint8 scale)
Expand Down

0 comments on commit 9b9fe7c

Please sign in to comment.