Skip to content

Commit b23a4dd

Browse files
Add overflow checks to Scalar methods + add CMov method
1 parent b689178 commit b23a4dd

File tree

1 file changed

+102
-16
lines changed

1 file changed

+102
-16
lines changed

Src/Autarkysoft.Bitcoin/Cryptography/EllipticCurve/Scalar8x32.cs

+102-16
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,6 @@ private static unsafe bool SetB32(byte* pt, uint* r)
159159
uint of = GetOverflow(r);
160160
Debug.Assert(of == 0 || of == 1);
161161
Reduce(r, of);
162-
Debug.Assert(GetOverflow(r) == 0);
163162
return of != 0;
164163
}
165164

@@ -238,22 +237,48 @@ private static unsafe bool SetB32(byte* pt, uint* r)
238237
/// <summary>
239238
/// Returns if the value is equal to zero
240239
/// </summary>
241-
public bool IsZero => (b0 | b1 | b2 | b3 | b4 | b5 | b6 | b7) == 0;
240+
public bool IsZero
241+
{
242+
get
243+
{
244+
Debug.Assert(GetOverflow(this) == 0);
245+
return (b0 | b1 | b2 | b3 | b4 | b5 | b6 | b7) == 0;
246+
}
247+
}
248+
242249
/// <summary>
243250
/// Returns if the value is equal to one
244251
/// </summary>
245-
public bool IsOne => ((b0 ^ 1) | b1 | b2 | b3 | b4 | b5 | b6 | b7) == 0;
252+
public bool IsOne
253+
{
254+
get
255+
{
256+
Debug.Assert(GetOverflow(this) == 0);
257+
return ((b0 ^ 1) | b1 | b2 | b3 | b4 | b5 | b6 | b7) == 0;
258+
}
259+
}
260+
246261
/// <summary>
247262
/// Returns if the value is even
248263
/// </summary>
249-
public bool IsEven => (b0 & 1) == 0;
264+
public bool IsEven
265+
{
266+
get
267+
{
268+
Debug.Assert(GetOverflow(this) == 0);
269+
return (b0 & 1) == 0;
270+
}
271+
}
272+
250273
/// <summary>
251274
/// Returns if this scalar is higher than the group order divided by 2
252275
/// </summary>
253276
public bool IsHigh
254277
{
255278
get
256279
{
280+
Debug.Assert(GetOverflow(this) == 0);
281+
257282
int yes = 0;
258283
int no = 0;
259284
no |= (b7 < NH7 ? 1 : 0);
@@ -350,6 +375,8 @@ private static unsafe void Reduce(uint* r, uint overflow)
350375
r[6] = (uint)t; t >>= 32;
351376
t += r[7];
352377
r[7] = (uint)t;
378+
379+
Debug.Assert(GetOverflow(r) == 0);
353380
}
354381

355382
/// <summary>
@@ -360,6 +387,9 @@ private static unsafe void Reduce(uint* r, uint overflow)
360387
/// <returns>Result</returns>
361388
public unsafe Scalar8x32 Add(in Scalar8x32 other, out bool overflow)
362389
{
390+
Debug.Assert(GetOverflow(this) == 0);
391+
Debug.Assert(GetOverflow(other) == 0);
392+
363393
uint* r = stackalloc uint[8];
364394

365395
ulong t = (ulong)b0 + other.b0;
@@ -397,7 +427,9 @@ public unsafe Scalar8x32 Add(in Scalar8x32 other, out bool overflow)
397427
/// <returns></returns>
398428
public Scalar8x32 CAddBit(uint bit, uint flag)
399429
{
430+
Debug.Assert(GetOverflow(this) == 0);
400431
Debug.Assert(bit < 256);
432+
401433
bit += (flag - 1) & 0x100; // forcing (bit >> 5) > 7 makes this a noop
402434
ulong t = (ulong)b0 + (((bit >> 5) == 0 ? 1U : 0) << ((int)bit & 0x1F));
403435
uint r0 = (uint)t; t >>= 32;
@@ -424,12 +456,14 @@ public Scalar8x32 CAddBit(uint bit, uint flag)
424456

425457
public static unsafe uint GetBits(uint* pt, int offset, int count)
426458
{
459+
Debug.Assert(GetOverflow(pt) == 0);
427460
Debug.Assert((offset + count - 1) >> 5 == offset >> 5);
428461
return (pt[offset >> 5] >> (offset & 0x1F)) & ((1U << count) - 1);
429462
}
430463

431464
public static unsafe uint GetBitsVar(uint* pt, int offset, int count)
432465
{
466+
Debug.Assert(GetOverflow(pt) == 0);
433467
Debug.Assert(count < 32);
434468
Debug.Assert(offset + count <= 256);
435469
if ((offset + count - 1) >> 5 == offset >> 5)
@@ -630,6 +664,8 @@ public Scalar8x32 Inverse_old()
630664

631665
public Scalar8x32 InverseVar_old()
632666
{
667+
Debug.Assert(GetOverflow(this) == 0);
668+
633669
return Inverse_old();
634670
}
635671

@@ -641,6 +677,9 @@ public Scalar8x32 InverseVar_old()
641677
/// <returns></returns>
642678
public unsafe Scalar8x32 Multiply(in Scalar8x32 b)
643679
{
680+
Debug.Assert(GetOverflow(this) == 0);
681+
Debug.Assert(GetOverflow(b) == 0);
682+
644683
uint* l = stackalloc uint[16];
645684
Mult512(l, this, b);
646685
return Reduce512(l);
@@ -987,6 +1026,7 @@ private static Scalar8x32 Reduce(in Scalar8x32 r, uint overflow)
9871026
/// <returns>Shifted scalar</returns>
9881027
public unsafe Scalar8x32 Shr16(int shift, out uint ret)
9891028
{
1029+
Debug.Assert(GetOverflow(this) == 0);
9901030
Debug.Assert(shift > 0);
9911031
Debug.Assert(shift < 16);
9921032

@@ -1004,26 +1044,36 @@ public unsafe Scalar8x32 Shr16(int shift, out uint ret)
10041044
}
10051045

10061046

1047+
/// <summary>
1048+
/// Multiply a and b (without taking the modulus!), divide by 2**shift, and round to the nearest integer.
1049+
/// Shift must be at least 256
1050+
/// </summary>
1051+
/// <param name="a">A</param>
1052+
/// <param name="b">B</param>
1053+
/// <param name="shift">Shift must be at least 256</param>
1054+
/// <returns>Result</returns>
10071055
public static unsafe Scalar8x32 MulShiftVar(in Scalar8x32 a, in Scalar8x32 b, int shift)
10081056
{
1057+
Debug.Assert(GetOverflow(a) == 0);
1058+
Debug.Assert(GetOverflow(b) == 0);
10091059
Debug.Assert(shift >= 256);
10101060

10111061
uint* l = stackalloc uint[16];
10121062
Mult512(l, a, b);
10131063

1014-
int shiftlimbs = shift >> 5;
1064+
int shLimbs = shift >> 5;
10151065
int shiftlow = shift & 0x1F;
10161066
int shifthigh = 32 - shiftlow;
10171067
bool sb = shiftlow != 0;
10181068

1019-
uint r0 = shift < 512 ? (l[0 + shiftlimbs] >> shiftlow | (shift < 480 && sb ? (l[1 + shiftlimbs] << shifthigh) : 0)) : 0;
1020-
uint r1 = shift < 480 ? (l[1 + shiftlimbs] >> shiftlow | (shift < 448 && sb ? (l[2 + shiftlimbs] << shifthigh) : 0)) : 0;
1021-
uint r2 = shift < 448 ? (l[2 + shiftlimbs] >> shiftlow | (shift < 416 && sb ? (l[3 + shiftlimbs] << shifthigh) : 0)) : 0;
1022-
uint r3 = shift < 416 ? (l[3 + shiftlimbs] >> shiftlow | (shift < 384 && sb ? (l[4 + shiftlimbs] << shifthigh) : 0)) : 0;
1023-
uint r4 = shift < 384 ? (l[4 + shiftlimbs] >> shiftlow | (shift < 352 && sb ? (l[5 + shiftlimbs] << shifthigh) : 0)) : 0;
1024-
uint r5 = shift < 352 ? (l[5 + shiftlimbs] >> shiftlow | (shift < 320 && sb ? (l[6 + shiftlimbs] << shifthigh) : 0)) : 0;
1025-
uint r6 = shift < 320 ? (l[6 + shiftlimbs] >> shiftlow | (shift < 288 && sb ? (l[7 + shiftlimbs] << shifthigh) : 0)) : 0;
1026-
uint r7 = shift < 288 ? (l[7 + shiftlimbs] >> shiftlow) : 0;
1069+
uint r0 = shift < 512 ? (l[0 + shLimbs] >> shiftlow | (shift < 480 && sb ? (l[1 + shLimbs] << shifthigh) : 0)) : 0;
1070+
uint r1 = shift < 480 ? (l[1 + shLimbs] >> shiftlow | (shift < 448 && sb ? (l[2 + shLimbs] << shifthigh) : 0)) : 0;
1071+
uint r2 = shift < 448 ? (l[2 + shLimbs] >> shiftlow | (shift < 416 && sb ? (l[3 + shLimbs] << shifthigh) : 0)) : 0;
1072+
uint r3 = shift < 416 ? (l[3 + shLimbs] >> shiftlow | (shift < 384 && sb ? (l[4 + shLimbs] << shifthigh) : 0)) : 0;
1073+
uint r4 = shift < 384 ? (l[4 + shLimbs] >> shiftlow | (shift < 352 && sb ? (l[5 + shLimbs] << shifthigh) : 0)) : 0;
1074+
uint r5 = shift < 352 ? (l[5 + shLimbs] >> shiftlow | (shift < 320 && sb ? (l[6 + shLimbs] << shifthigh) : 0)) : 0;
1075+
uint r6 = shift < 320 ? (l[6 + shLimbs] >> shiftlow | (shift < 288 && sb ? (l[7 + shLimbs] << shifthigh) : 0)) : 0;
1076+
uint r7 = shift < 288 ? (l[7 + shLimbs] >> shiftlow) : 0;
10271077

10281078
Scalar8x32 r = new Scalar8x32(r0, r1, r2, r3, r4, r5, r6, r7);
10291079
return r.CAddBit(0, (l[(shift - 1) >> 5] >> ((shift - 1) & 0x1f)) & 1);
@@ -1035,6 +1085,8 @@ public static unsafe Scalar8x32 MulShiftVar(in Scalar8x32 a, in Scalar8x32 b, in
10351085
/// <returns></returns>
10361086
public Scalar8x32 Negate()
10371087
{
1088+
Debug.Assert(GetOverflow(this) == 0);
1089+
10381090
// uint32_t nonzero = 0xFFFFFFFFUL * (secp256k1_scalar_is_zero(a) == 0);
10391091
// Instead of a branch to get 1/0 then multiply it by the constant we use branch to get the resulting constant directly
10401092
// ie. we skip multiplication (optimization effect is minuscule though!)
@@ -1067,6 +1119,8 @@ public Scalar8x32 Negate()
10671119
/// <returns>-1 if the number was negated; otherwise 1.</returns>
10681120
public int NegateConditional(int flag, out Scalar8x32 result)
10691121
{
1122+
Debug.Assert(GetOverflow(this) == 0);
1123+
10701124
// If flag = 0 then mask = 00...00 so this is a no-op
10711125
// if flag = 1 then mask = 11...11 so this is identical Negate()
10721126
uint mask = (uint)-flag;
@@ -1094,6 +1148,27 @@ public int NegateConditional(int flag, out Scalar8x32 result)
10941148
}
10951149

10961150

1151+
1152+
public static Scalar8x32 CMov(in Scalar8x32 r, in Scalar8x32 a, uint flag)
1153+
{
1154+
Debug.Assert(GetOverflow(r) == 0);
1155+
Debug.Assert(GetOverflow(a) == 0);
1156+
1157+
uint mask0 = flag + ~0U;
1158+
uint mask1 = ~mask0;
1159+
uint r0 = (r.b0 & mask0) | (a.b0 & mask1);
1160+
uint r1 = (r.b1 & mask0) | (a.b1 & mask1);
1161+
uint r2 = (r.b2 & mask0) | (a.b2 & mask1);
1162+
uint r3 = (r.b3 & mask0) | (a.b3 & mask1);
1163+
uint r4 = (r.b4 & mask0) | (a.b4 & mask1);
1164+
uint r5 = (r.b5 & mask0) | (a.b5 & mask1);
1165+
uint r6 = (r.b6 & mask0) | (a.b6 & mask1);
1166+
uint r7 = (r.b7 & mask0) | (a.b7 & mask1);
1167+
1168+
return new Scalar8x32(r0, r1, r2, r3, r4, r5, r6, r7);
1169+
}
1170+
1171+
10971172
/// <summary>
10981173
/// Find r1 and r2 such that r1+r2*2^128 = k
10991174
/// </summary>
@@ -1102,6 +1177,8 @@ public int NegateConditional(int flag, out Scalar8x32 result)
11021177
/// <param name="r2"></param>
11031178
internal static void Split128(in Scalar8x32 k, out Scalar8x32 r1, out Scalar8x32 r2)
11041179
{
1180+
Debug.Assert(GetOverflow(k) == 0);
1181+
11051182
r1 = new Scalar8x32(k.b0, k.b1, k.b2, k.b3, 0, 0, 0, 0);
11061183
r2 = new Scalar8x32(k.b4, k.b5, k.b6, k.b7, 0, 0, 0, 0);
11071184
}
@@ -1178,6 +1255,8 @@ private static int MemCmpVar(Span<byte> s1, Span<byte> s2, int n)
11781255
/// <returns>32 bytes</returns>
11791256
public byte[] ToByteArray()
11801257
{
1258+
Debug.Assert(GetOverflow(this) == 0);
1259+
11811260
return new byte[32]
11821261
{
11831262
(byte)(b7 >> 24), (byte)(b7 >> 16), (byte)(b7 >> 8), (byte)b7,
@@ -1197,6 +1276,8 @@ public void WriteToSpan(Span<byte> stream)
11971276
if (stream.Length < 32)
11981277
throw new ArgumentOutOfRangeException();
11991278

1279+
Debug.Assert(GetOverflow(this) == 0);
1280+
12001281
stream[0] = (byte)(b7 >> 24); stream[1] = (byte)(b7 >> 16); stream[2] = (byte)(b7 >> 8); stream[3] = (byte)b7;
12011282
stream[4] = (byte)(b6 >> 24); stream[5] = (byte)(b6 >> 16); stream[6] = (byte)(b6 >> 8); stream[7] = (byte)b6;
12021283
stream[8] = (byte)(b5 >> 24); stream[9] = (byte)(b5 >> 16); stream[10] = (byte)(b5 >> 8); stream[11] = (byte)b5;
@@ -1223,9 +1304,14 @@ public void WriteToSpan(Span<byte> stream)
12231304
/// <param name="left">First scalar</param>
12241305
/// <param name="right">Second scalar</param>
12251306
/// <returns>True if the two scalars are equal; otherwise false.</returns>
1226-
public static bool operator ==(in Scalar8x32 left, in Scalar8x32 right) =>
1227-
((left.b0 ^ right.b0) | (left.b1 ^ right.b1) | (left.b2 ^ right.b2) | (left.b3 ^ right.b3) |
1228-
(left.b4 ^ right.b4) | (left.b5 ^ right.b5) | (left.b6 ^ right.b6) | (left.b7 ^ right.b7)) == 0;
1307+
public static bool operator ==(in Scalar8x32 left, in Scalar8x32 right)
1308+
{
1309+
Debug.Assert(GetOverflow(left) == 0);
1310+
Debug.Assert(GetOverflow(right) == 0);
1311+
1312+
return ((left.b0 ^ right.b0) | (left.b1 ^ right.b1) | (left.b2 ^ right.b2) | (left.b3 ^ right.b3) |
1313+
(left.b4 ^ right.b4) | (left.b5 ^ right.b5) | (left.b6 ^ right.b6) | (left.b7 ^ right.b7)) == 0;
1314+
}
12291315

12301316
/// <summary>
12311317
/// Returns if the two scalars are not equal to each other

0 commit comments

Comments
 (0)