@@ -648,74 +648,73 @@ public unsafe Scalar8x32 Multiply(in Scalar8x32 b)
648
648
649
649
650
650
[ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
651
- private static void Muladd ( uint a , uint b , ref ulong c0 , ref ulong c1 , ref ulong c2 )
651
+ private static void Muladd ( uint a , uint b , ref uint c0 , ref uint c1 , ref uint c2 )
652
652
{
653
653
ulong t = ( ulong ) a * b ;
654
654
uint th = ( uint ) ( t >> 32 ) ;
655
655
uint tl = ( uint ) t ;
656
656
657
- c0 = ( c0 & uint . MaxValue ) + tl ; // overflow is handled on the next line
658
- th += ( uint ) ( c0 >> 32 ) ; // at most 0xFFFFFFFF
659
- c1 += th ; // overflow is handled on the next line
660
- c2 += ( uint ) ( c1 >> 32 ) ; // never overflows by contract (verified in the next line)
657
+ c0 += tl ; // overflow is handled on the next line
658
+ th += ( c0 < tl ) ? 1U : 0U ; // at most 0xFFFFFFFF
659
+ c1 += th ; // overflow is handled on the next line
660
+ c2 += ( c1 < th ) ? 1U : 0U ; // never overflows by contract (verified in the next line)
661
661
662
- c1 &= uint . MaxValue ;
663
662
Debug . Assert ( ( c1 >= th ) || ( c2 != 0 ) ) ;
664
663
}
665
664
666
665
// Add a*b to the number defined by (c0,c1). c1 must never overflow.
667
666
[ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
668
- private static void MuladdFast ( uint a , uint b , ref ulong c0 , ref ulong c1 )
667
+ private static void MuladdFast ( uint a , uint b , ref uint c0 , ref uint c1 )
669
668
{
670
669
ulong t = ( ulong ) a * b ;
671
- uint th = ( uint ) ( t >> 32 ) ; // at most 0xFFFFFFFE
670
+ uint th = ( uint ) ( t >> 32 ) ; // at most 0xFFFFFFFE
672
671
uint tl = ( uint ) t ;
673
672
674
- c0 = ( c0 & uint . MaxValue ) + tl ; // overflow is handled on the next line
675
- th += ( uint ) ( c0 >> 32 ) ; // at most 0xFFFFFFFF
676
- c1 += th ; // never overflows by contract (verified in the next line)
673
+ c0 += tl ; // overflow is handled on the next line
674
+ th += ( c0 < tl ) ? 1U : 0U ; // at most 0xFFFFFFFF
675
+ c1 += th ; // never overflows by contract (verified in the next line)
677
676
678
677
Debug . Assert ( c1 >= th ) ;
679
678
}
680
679
681
680
// Add a to the number defined by (c0,c1,c2). c2 must never overflow.
682
681
[ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
683
- private static void SumAdd ( uint a , ref ulong c0 , ref ulong c1 , ref ulong c2 )
682
+ private static void SumAdd ( uint a , ref uint c0 , ref uint c1 , ref uint c2 )
684
683
{
685
- c0 = ( c0 & uint . MaxValue ) + a ; // overflow is handled on the next line
686
- uint over = ( uint ) ( c0 >> 32 ) ;
684
+ c0 += a ; // overflow is handled on the next line
685
+ uint over = ( c0 < a ) ? 1U : 0U ;
687
686
c1 += over ; // overflow is handled on the next line
688
- c2 += ( uint ) ( c1 >> 32 ) ; // never overflows by contract
687
+ c2 += ( c1 < over ) ? 1U : 0U ; // never overflows by contract
689
688
690
689
c1 &= uint . MaxValue ;
691
690
}
692
691
693
692
// Add a to the number defined by (c0,c1). c1 must never overflow, c2 must be zero.
694
693
[ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
695
- private static void SumaddFast ( uint a , ref ulong c0 , ref ulong c1 , ref ulong c2 )
694
+ private static void SumaddFast ( uint a , ref uint c0 , ref uint c1 , ref uint c2 )
696
695
{
697
- c0 = ( c0 & uint . MaxValue ) + a ; // overflow is handled on the next line
698
- c1 += ( uint ) ( c0 >> 32 ) ; // never overflows by contract (verified the next line)
696
+ c0 += a ; // overflow is handled on the next line
697
+ c1 += ( c0 < a ) ? 1U : 0U ; // never overflows by contract (verified the next line)
699
698
700
699
Debug . Assert ( ( c1 != 0 ) | ( c0 >= a ) ) ;
701
700
Debug . Assert ( c2 == 0 ) ;
702
701
}
703
702
704
703
// Extract the lowest 32 bits of (c0,c1,c2) into n, and left shift the number 32 bits.
705
704
[ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
706
- private static void Extract ( ref uint n , ref ulong c0 , ref ulong c1 , ref ulong c2 )
705
+ private static void Extract ( ref uint n , ref uint c0 , ref uint c1 , ref uint c2 )
707
706
{
708
- n = ( uint ) c0 ;
707
+ n = c0 ;
709
708
c0 = c1 ;
710
709
c1 = c2 ;
711
710
c2 = 0 ;
712
711
}
713
712
714
713
// Extract the lowest 32 bits of (c0,c1,c2) into n, and left shift the number 32 bits. c2 is required to be zero.
715
714
[ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
716
- private static void ExtractFast ( ref uint n , ref ulong c0 , ref ulong c1 , ref ulong c2 )
715
+ private static void ExtractFast ( ref uint n , ref uint c0 , ref uint c1 , ref uint c2 )
717
716
{
718
- n = ( uint ) c0 ;
717
+ n = c0 ;
719
718
c0 = c1 ;
720
719
c1 = 0 ;
721
720
@@ -725,7 +724,7 @@ private static void ExtractFast(ref uint n, ref ulong c0, ref ulong c1, ref ulon
725
724
private static unsafe void Mult512 ( uint * l , in Scalar8x32 a , in Scalar8x32 b )
726
725
{
727
726
// 96 bit accumulator
728
- ulong c0 = 0 , c1 = 0 , c2 = 0 ;
727
+ uint c0 = 0 , c1 = 0 , c2 = 0 ;
729
728
730
729
// l[0..15] = a[0..7] * b[0..7]
731
730
MuladdFast ( a . b0 , b . b0 , ref c0 , ref c1 ) ;
@@ -808,7 +807,7 @@ private static unsafe void Mult512(uint* l, in Scalar8x32 a, in Scalar8x32 b)
808
807
MuladdFast ( a . b7 , b . b7 , ref c0 , ref c1 ) ;
809
808
ExtractFast ( ref l [ 14 ] , ref c0 , ref c1 , ref c2 ) ;
810
809
Debug . Assert ( c1 == 0 ) ;
811
- l [ 15 ] = ( uint ) c0 ;
810
+ l [ 15 ] = c0 ;
812
811
}
813
812
814
813
private static unsafe Scalar8x32 Reduce512 ( uint * l )
@@ -819,7 +818,7 @@ private static unsafe Scalar8x32 Reduce512(uint* l)
819
818
uint p0 = 0 , p1 = 0 , p2 = 0 , p3 = 0 , p4 = 0 , p5 = 0 , p6 = 0 , p7 = 0 , p8 = 0 ;
820
819
821
820
// 96 bit accumulator
822
- ulong c0 , c1 , c2 ;
821
+ uint c0 , c1 , c2 ;
823
822
824
823
// Reduce 512 bits into 385
825
824
// m[0..12] = l[0..7] + n[0..7] * NC
@@ -884,7 +883,7 @@ private static unsafe Scalar8x32 Reduce512(uint* l)
884
883
SumaddFast ( n7 , ref c0 , ref c1 , ref c2 ) ;
885
884
ExtractFast ( ref m11 , ref c0 , ref c1 , ref c2 ) ;
886
885
Debug . Assert ( c0 <= 1 ) ;
887
- m12 = ( uint ) c0 ;
886
+ m12 = c0 ;
888
887
889
888
// Reduce 385 bits into 258
890
889
// p[0..8] = m[0..7] + m[8..12] * NC
@@ -928,7 +927,7 @@ private static unsafe Scalar8x32 Reduce512(uint* l)
928
927
MuladdFast ( m12 , NC3 , ref c0 , ref c1 ) ;
929
928
SumaddFast ( m11 , ref c0 , ref c1 , ref c2 ) ;
930
929
ExtractFast ( ref p7 , ref c0 , ref c1 , ref c2 ) ;
931
- p8 = ( uint ) c0 + m12 ;
930
+ p8 = c0 + m12 ;
932
931
Debug . Assert ( p8 <= 2 ) ;
933
932
934
933
// Reduce 258 bits into 256
0 commit comments