From a70781b0f3b8d019dfbbd297ffa3e240b1ee402f Mon Sep 17 00:00:00 2001 From: Owen Conoly Date: Tue, 16 May 2023 23:38:40 -0400 Subject: [PATCH] put the new rewrite rule in its own pass --- fiat-c/src/secp256k1_dettman_64.c | 216 ++++++++++---------- src/BoundsPipeline.v | 3 +- src/Rewriter/All.v | 3 + src/Rewriter/Passes/ArithWithRelaxedCasts.v | 43 ++++ src/Rewriter/Rules.v | 36 +++- src/Rewriter/RulesProofs.v | 107 ++++------ 6 files changed, 213 insertions(+), 195 deletions(-) create mode 100644 src/Rewriter/Passes/ArithWithRelaxedCasts.v diff --git a/fiat-c/src/secp256k1_dettman_64.c b/fiat-c/src/secp256k1_dettman_64.c index 8e25f88b06..10a8d67764 100644 --- a/fiat-c/src/secp256k1_dettman_64.c +++ b/fiat-c/src/secp256k1_dettman_64.c @@ -43,75 +43,71 @@ FIAT_SECP256K1_DETTMAN_FIAT_EXTENSION typedef unsigned __int128 fiat_secp256k1_d static FIAT_SECP256K1_DETTMAN_FIAT_INLINE void fiat_secp256k1_dettman_mul(uint64_t out1[5], const uint64_t arg1[5], const uint64_t arg2[5]) { fiat_secp256k1_dettman_uint128 x1; uint64_t x2; - uint64_t x3; - fiat_secp256k1_dettman_uint128 x4; + fiat_secp256k1_dettman_uint128 x3; + uint64_t x4; uint64_t x5; - uint64_t x6; - fiat_secp256k1_dettman_uint128 x7; + fiat_secp256k1_dettman_uint128 x6; + uint64_t x7; uint64_t x8; uint64_t x9; uint64_t x10; - uint64_t x11; - fiat_secp256k1_dettman_uint128 x12; + fiat_secp256k1_dettman_uint128 x11; + uint64_t x12; uint64_t x13; - uint64_t x14; - fiat_secp256k1_dettman_uint128 x15; + fiat_secp256k1_dettman_uint128 x14; + uint64_t x15; uint64_t x16; - uint64_t x17; - fiat_secp256k1_dettman_uint128 x18; + fiat_secp256k1_dettman_uint128 x17; + uint64_t x18; uint64_t x19; - uint64_t x20; - fiat_secp256k1_dettman_uint128 x21; + fiat_secp256k1_dettman_uint128 x20; + uint64_t x21; uint64_t x22; - uint64_t x23; - fiat_secp256k1_dettman_uint128 x24; - uint64_t x25; + fiat_secp256k1_dettman_uint128 x23; + uint64_t x24; + fiat_secp256k1_dettman_uint128 x25; uint64_t x26; - fiat_secp256k1_dettman_uint128 x27; - uint64_t x28; + uint64_t x27; + fiat_secp256k1_dettman_uint128 x28; uint64_t x29; - fiat_secp256k1_dettman_uint128 x30; + uint64_t x30; uint64_t x31; - uint64_t x32; - uint64_t x33; x1 = ((fiat_secp256k1_dettman_uint128)(arg1[4]) * (arg2[4])); x2 = (uint64_t)(x1 >> 64); - x3 = (uint64_t)(x1 & UINT64_C(0xffffffffffffffff)); - x4 = ((((fiat_secp256k1_dettman_uint128)(arg1[0]) * (arg2[3])) + (((fiat_secp256k1_dettman_uint128)(arg1[1]) * (arg2[2])) + (((fiat_secp256k1_dettman_uint128)(arg1[2]) * (arg2[1])) + ((fiat_secp256k1_dettman_uint128)(arg1[3]) * (arg2[0]))))) + ((fiat_secp256k1_dettman_uint128)x3 * UINT64_C(0x1000003d10))); - x5 = (uint64_t)(x4 >> 52); - x6 = (uint64_t)(x4 & UINT64_C(0xfffffffffffff)); - x7 = (((((fiat_secp256k1_dettman_uint128)(arg1[0]) * (arg2[4])) + (((fiat_secp256k1_dettman_uint128)(arg1[1]) * (arg2[3])) + (((fiat_secp256k1_dettman_uint128)(arg1[2]) * (arg2[2])) + (((fiat_secp256k1_dettman_uint128)(arg1[3]) * (arg2[1])) + ((fiat_secp256k1_dettman_uint128)(arg1[4]) * (arg2[0])))))) + x5) + ((fiat_secp256k1_dettman_uint128)x2 * UINT64_C(0x1000003d10000))); - x8 = (uint64_t)(x7 >> 52); - x9 = (uint64_t)(x7 & UINT64_C(0xfffffffffffff)); - x10 = (x9 >> 48); - x11 = (x9 & UINT64_C(0xffffffffffff)); - x12 = ((((fiat_secp256k1_dettman_uint128)(arg1[1]) * (arg2[4])) + (((fiat_secp256k1_dettman_uint128)(arg1[2]) * (arg2[3])) + (((fiat_secp256k1_dettman_uint128)(arg1[3]) * (arg2[2])) + ((fiat_secp256k1_dettman_uint128)(arg1[4]) * (arg2[1]))))) + x8); - x13 = (uint64_t)(x12 >> 52); - x14 = (uint64_t)(x12 & UINT64_C(0xfffffffffffff)); - x15 = (((fiat_secp256k1_dettman_uint128)(arg1[0]) * (arg2[0])) + ((fiat_secp256k1_dettman_uint128)((x14 << 4) + x10) * UINT64_C(0x1000003d1))); - x16 = (uint64_t)(x15 >> 52); - x17 = (uint64_t)(x15 & UINT64_C(0xfffffffffffff)); - x18 = ((((fiat_secp256k1_dettman_uint128)(arg1[2]) * (arg2[4])) + (((fiat_secp256k1_dettman_uint128)(arg1[3]) * (arg2[3])) + ((fiat_secp256k1_dettman_uint128)(arg1[4]) * (arg2[2])))) + x13); - x19 = (uint64_t)(x18 >> 52); - x20 = (uint64_t)(x18 & UINT64_C(0xfffffffffffff)); - x21 = (((((fiat_secp256k1_dettman_uint128)(arg1[0]) * (arg2[1])) + ((fiat_secp256k1_dettman_uint128)(arg1[1]) * (arg2[0]))) + x16) + ((fiat_secp256k1_dettman_uint128)x20 * UINT64_C(0x1000003d10))); - x22 = (uint64_t)(x21 >> 52); - x23 = (uint64_t)(x21 & UINT64_C(0xfffffffffffff)); - x24 = ((((fiat_secp256k1_dettman_uint128)(arg1[3]) * (arg2[4])) + ((fiat_secp256k1_dettman_uint128)(arg1[4]) * (arg2[3]))) + x19); - x25 = (uint64_t)(x24 >> 64); - x26 = (uint64_t)(x24 & UINT64_C(0xffffffffffffffff)); - x27 = (((((fiat_secp256k1_dettman_uint128)(arg1[0]) * (arg2[2])) + (((fiat_secp256k1_dettman_uint128)(arg1[1]) * (arg2[1])) + ((fiat_secp256k1_dettman_uint128)(arg1[2]) * (arg2[0])))) + x22) + ((fiat_secp256k1_dettman_uint128)x26 * UINT64_C(0x1000003d10))); - x28 = (uint64_t)(x27 >> 52); - x29 = (uint64_t)(x27 & UINT64_C(0xfffffffffffff)); - x30 = ((x6 + x28) + ((fiat_secp256k1_dettman_uint128)x25 * UINT64_C(0x1000003d10000))); - x31 = (uint64_t)(x30 >> 52); - x32 = (uint64_t)(x30 & UINT64_C(0xfffffffffffff)); - x33 = (x11 + x31); - out1[0] = x17; - out1[1] = x23; - out1[2] = x29; - out1[3] = x32; - out1[4] = x33; + x3 = ((((fiat_secp256k1_dettman_uint128)(arg1[0]) * (arg2[3])) + (((fiat_secp256k1_dettman_uint128)(arg1[1]) * (arg2[2])) + (((fiat_secp256k1_dettman_uint128)(arg1[2]) * (arg2[1])) + ((fiat_secp256k1_dettman_uint128)(arg1[3]) * (arg2[0]))))) + ((fiat_secp256k1_dettman_uint128)(uint64_t)x1 * UINT64_C(0x1000003d10))); + x4 = (uint64_t)(x3 >> 52); + x5 = (uint64_t)(x3 & UINT64_C(0xfffffffffffff)); + x6 = (((((fiat_secp256k1_dettman_uint128)(arg1[0]) * (arg2[4])) + (((fiat_secp256k1_dettman_uint128)(arg1[1]) * (arg2[3])) + (((fiat_secp256k1_dettman_uint128)(arg1[2]) * (arg2[2])) + (((fiat_secp256k1_dettman_uint128)(arg1[3]) * (arg2[1])) + ((fiat_secp256k1_dettman_uint128)(arg1[4]) * (arg2[0])))))) + x4) + ((fiat_secp256k1_dettman_uint128)x2 * UINT64_C(0x1000003d10000))); + x7 = (uint64_t)(x6 >> 52); + x8 = (uint64_t)(x6 & UINT64_C(0xfffffffffffff)); + x9 = (x8 >> 48); + x10 = (x8 & UINT64_C(0xffffffffffff)); + x11 = ((((fiat_secp256k1_dettman_uint128)(arg1[1]) * (arg2[4])) + (((fiat_secp256k1_dettman_uint128)(arg1[2]) * (arg2[3])) + (((fiat_secp256k1_dettman_uint128)(arg1[3]) * (arg2[2])) + ((fiat_secp256k1_dettman_uint128)(arg1[4]) * (arg2[1]))))) + x7); + x12 = (uint64_t)(x11 >> 52); + x13 = (uint64_t)(x11 & UINT64_C(0xfffffffffffff)); + x14 = (((fiat_secp256k1_dettman_uint128)(arg1[0]) * (arg2[0])) + ((fiat_secp256k1_dettman_uint128)((x13 << 4) + x9) * UINT64_C(0x1000003d1))); + x15 = (uint64_t)(x14 >> 52); + x16 = (uint64_t)(x14 & UINT64_C(0xfffffffffffff)); + x17 = ((((fiat_secp256k1_dettman_uint128)(arg1[2]) * (arg2[4])) + (((fiat_secp256k1_dettman_uint128)(arg1[3]) * (arg2[3])) + ((fiat_secp256k1_dettman_uint128)(arg1[4]) * (arg2[2])))) + x12); + x18 = (uint64_t)(x17 >> 52); + x19 = (uint64_t)(x17 & UINT64_C(0xfffffffffffff)); + x20 = (((((fiat_secp256k1_dettman_uint128)(arg1[0]) * (arg2[1])) + ((fiat_secp256k1_dettman_uint128)(arg1[1]) * (arg2[0]))) + x15) + ((fiat_secp256k1_dettman_uint128)x19 * UINT64_C(0x1000003d10))); + x21 = (uint64_t)(x20 >> 52); + x22 = (uint64_t)(x20 & UINT64_C(0xfffffffffffff)); + x23 = ((((fiat_secp256k1_dettman_uint128)(arg1[3]) * (arg2[4])) + ((fiat_secp256k1_dettman_uint128)(arg1[4]) * (arg2[3]))) + x18); + x24 = (uint64_t)(x23 >> 64); + x25 = (((((fiat_secp256k1_dettman_uint128)(arg1[0]) * (arg2[2])) + (((fiat_secp256k1_dettman_uint128)(arg1[1]) * (arg2[1])) + ((fiat_secp256k1_dettman_uint128)(arg1[2]) * (arg2[0])))) + x21) + ((fiat_secp256k1_dettman_uint128)(uint64_t)x23 * UINT64_C(0x1000003d10))); + x26 = (uint64_t)(x25 >> 52); + x27 = (uint64_t)(x25 & UINT64_C(0xfffffffffffff)); + x28 = ((x5 + x26) + ((fiat_secp256k1_dettman_uint128)x24 * UINT64_C(0x1000003d10000))); + x29 = (uint64_t)(x28 >> 52); + x30 = (uint64_t)(x28 & UINT64_C(0xfffffffffffff)); + x31 = (x10 + x29); + out1[0] = x16; + out1[1] = x22; + out1[2] = x27; + out1[3] = x30; + out1[4] = x31; } /* @@ -132,77 +128,73 @@ static FIAT_SECP256K1_DETTMAN_FIAT_INLINE void fiat_secp256k1_dettman_square(uin uint64_t x4; fiat_secp256k1_dettman_uint128 x5; uint64_t x6; - uint64_t x7; - fiat_secp256k1_dettman_uint128 x8; + fiat_secp256k1_dettman_uint128 x7; + uint64_t x8; uint64_t x9; - uint64_t x10; - fiat_secp256k1_dettman_uint128 x11; + fiat_secp256k1_dettman_uint128 x10; + uint64_t x11; uint64_t x12; uint64_t x13; uint64_t x14; - uint64_t x15; - fiat_secp256k1_dettman_uint128 x16; + fiat_secp256k1_dettman_uint128 x15; + uint64_t x16; uint64_t x17; - uint64_t x18; - fiat_secp256k1_dettman_uint128 x19; + fiat_secp256k1_dettman_uint128 x18; + uint64_t x19; uint64_t x20; - uint64_t x21; - fiat_secp256k1_dettman_uint128 x22; + fiat_secp256k1_dettman_uint128 x21; + uint64_t x22; uint64_t x23; - uint64_t x24; - fiat_secp256k1_dettman_uint128 x25; + fiat_secp256k1_dettman_uint128 x24; + uint64_t x25; uint64_t x26; - uint64_t x27; - fiat_secp256k1_dettman_uint128 x28; - uint64_t x29; + fiat_secp256k1_dettman_uint128 x27; + uint64_t x28; + fiat_secp256k1_dettman_uint128 x29; uint64_t x30; - fiat_secp256k1_dettman_uint128 x31; - uint64_t x32; + uint64_t x31; + fiat_secp256k1_dettman_uint128 x32; uint64_t x33; - fiat_secp256k1_dettman_uint128 x34; + uint64_t x34; uint64_t x35; - uint64_t x36; - uint64_t x37; x1 = ((arg1[3]) * 0x2); x2 = ((arg1[2]) * 0x2); x3 = ((arg1[1]) * 0x2); x4 = ((arg1[0]) * 0x2); x5 = ((fiat_secp256k1_dettman_uint128)(arg1[4]) * (arg1[4])); x6 = (uint64_t)(x5 >> 64); - x7 = (uint64_t)(x5 & UINT64_C(0xffffffffffffffff)); - x8 = ((((fiat_secp256k1_dettman_uint128)x4 * (arg1[3])) + ((fiat_secp256k1_dettman_uint128)x3 * (arg1[2]))) + ((fiat_secp256k1_dettman_uint128)x7 * UINT64_C(0x1000003d10))); - x9 = (uint64_t)(x8 >> 52); - x10 = (uint64_t)(x8 & UINT64_C(0xfffffffffffff)); - x11 = (((((fiat_secp256k1_dettman_uint128)x4 * (arg1[4])) + (((fiat_secp256k1_dettman_uint128)x3 * (arg1[3])) + ((fiat_secp256k1_dettman_uint128)(arg1[2]) * (arg1[2])))) + x9) + ((fiat_secp256k1_dettman_uint128)x6 * UINT64_C(0x1000003d10000))); - x12 = (uint64_t)(x11 >> 52); - x13 = (uint64_t)(x11 & UINT64_C(0xfffffffffffff)); - x14 = (x13 >> 48); - x15 = (x13 & UINT64_C(0xffffffffffff)); - x16 = ((((fiat_secp256k1_dettman_uint128)x3 * (arg1[4])) + ((fiat_secp256k1_dettman_uint128)x2 * (arg1[3]))) + x12); - x17 = (uint64_t)(x16 >> 52); - x18 = (uint64_t)(x16 & UINT64_C(0xfffffffffffff)); - x19 = (((fiat_secp256k1_dettman_uint128)(arg1[0]) * (arg1[0])) + ((fiat_secp256k1_dettman_uint128)((x18 << 4) + x14) * UINT64_C(0x1000003d1))); - x20 = (uint64_t)(x19 >> 52); - x21 = (uint64_t)(x19 & UINT64_C(0xfffffffffffff)); - x22 = ((((fiat_secp256k1_dettman_uint128)x2 * (arg1[4])) + ((fiat_secp256k1_dettman_uint128)(arg1[3]) * (arg1[3]))) + x17); - x23 = (uint64_t)(x22 >> 52); - x24 = (uint64_t)(x22 & UINT64_C(0xfffffffffffff)); - x25 = ((((fiat_secp256k1_dettman_uint128)x4 * (arg1[1])) + x20) + ((fiat_secp256k1_dettman_uint128)x24 * UINT64_C(0x1000003d10))); - x26 = (uint64_t)(x25 >> 52); - x27 = (uint64_t)(x25 & UINT64_C(0xfffffffffffff)); - x28 = (((fiat_secp256k1_dettman_uint128)x1 * (arg1[4])) + x23); - x29 = (uint64_t)(x28 >> 64); - x30 = (uint64_t)(x28 & UINT64_C(0xffffffffffffffff)); - x31 = (((((fiat_secp256k1_dettman_uint128)x4 * (arg1[2])) + ((fiat_secp256k1_dettman_uint128)(arg1[1]) * (arg1[1]))) + x26) + ((fiat_secp256k1_dettman_uint128)x30 * UINT64_C(0x1000003d10))); - x32 = (uint64_t)(x31 >> 52); - x33 = (uint64_t)(x31 & UINT64_C(0xfffffffffffff)); - x34 = ((x10 + x32) + ((fiat_secp256k1_dettman_uint128)x29 * UINT64_C(0x1000003d10000))); - x35 = (uint64_t)(x34 >> 52); - x36 = (uint64_t)(x34 & UINT64_C(0xfffffffffffff)); - x37 = (x15 + x35); - out1[0] = x21; - out1[1] = x27; - out1[2] = x33; - out1[3] = x36; - out1[4] = x37; + x7 = ((((fiat_secp256k1_dettman_uint128)x4 * (arg1[3])) + ((fiat_secp256k1_dettman_uint128)x3 * (arg1[2]))) + ((fiat_secp256k1_dettman_uint128)(uint64_t)x5 * UINT64_C(0x1000003d10))); + x8 = (uint64_t)(x7 >> 52); + x9 = (uint64_t)(x7 & UINT64_C(0xfffffffffffff)); + x10 = (((((fiat_secp256k1_dettman_uint128)x4 * (arg1[4])) + (((fiat_secp256k1_dettman_uint128)x3 * (arg1[3])) + ((fiat_secp256k1_dettman_uint128)(arg1[2]) * (arg1[2])))) + x8) + ((fiat_secp256k1_dettman_uint128)x6 * UINT64_C(0x1000003d10000))); + x11 = (uint64_t)(x10 >> 52); + x12 = (uint64_t)(x10 & UINT64_C(0xfffffffffffff)); + x13 = (x12 >> 48); + x14 = (x12 & UINT64_C(0xffffffffffff)); + x15 = ((((fiat_secp256k1_dettman_uint128)x3 * (arg1[4])) + ((fiat_secp256k1_dettman_uint128)x2 * (arg1[3]))) + x11); + x16 = (uint64_t)(x15 >> 52); + x17 = (uint64_t)(x15 & UINT64_C(0xfffffffffffff)); + x18 = (((fiat_secp256k1_dettman_uint128)(arg1[0]) * (arg1[0])) + ((fiat_secp256k1_dettman_uint128)((x17 << 4) + x13) * UINT64_C(0x1000003d1))); + x19 = (uint64_t)(x18 >> 52); + x20 = (uint64_t)(x18 & UINT64_C(0xfffffffffffff)); + x21 = ((((fiat_secp256k1_dettman_uint128)x2 * (arg1[4])) + ((fiat_secp256k1_dettman_uint128)(arg1[3]) * (arg1[3]))) + x16); + x22 = (uint64_t)(x21 >> 52); + x23 = (uint64_t)(x21 & UINT64_C(0xfffffffffffff)); + x24 = ((((fiat_secp256k1_dettman_uint128)x4 * (arg1[1])) + x19) + ((fiat_secp256k1_dettman_uint128)x23 * UINT64_C(0x1000003d10))); + x25 = (uint64_t)(x24 >> 52); + x26 = (uint64_t)(x24 & UINT64_C(0xfffffffffffff)); + x27 = (((fiat_secp256k1_dettman_uint128)x1 * (arg1[4])) + x22); + x28 = (uint64_t)(x27 >> 64); + x29 = (((((fiat_secp256k1_dettman_uint128)x4 * (arg1[2])) + ((fiat_secp256k1_dettman_uint128)(arg1[1]) * (arg1[1]))) + x25) + ((fiat_secp256k1_dettman_uint128)(uint64_t)x27 * UINT64_C(0x1000003d10))); + x30 = (uint64_t)(x29 >> 52); + x31 = (uint64_t)(x29 & UINT64_C(0xfffffffffffff)); + x32 = ((x9 + x30) + ((fiat_secp256k1_dettman_uint128)x28 * UINT64_C(0x1000003d10000))); + x33 = (uint64_t)(x32 >> 52); + x34 = (uint64_t)(x32 & UINT64_C(0xfffffffffffff)); + x35 = (x14 + x33); + out1[0] = x20; + out1[1] = x26; + out1[2] = x31; + out1[3] = x34; + out1[4] = x35; } diff --git a/src/BoundsPipeline.v b/src/BoundsPipeline.v index 2590de9ab1..bfbc990359 100644 --- a/src/BoundsPipeline.v +++ b/src/BoundsPipeline.v @@ -806,7 +806,8 @@ Module Pipeline. match E' with (* rewrites after bounds relaxation---add a new one named arithWithRelaxedCasts or something. *) | inl E - => (E <- match split_mul_to with + => (E <- RewriteAndEliminateDeadAndInline "RewriteArithWithRelaxedCasts" (RewriteRules.RewriteArithWithRelaxedCasts opts) with_dead_code_elimination with_subst01 with_let_bind_return E; + E <- match split_mul_to with | Some (max_bitwidth, lgcarrymax) => wrap_debug_rewrite "RewriteMulSplit" (RewriteRules.RewriteMulSplit max_bitwidth lgcarrymax opts) E | None => Debug.ret E diff --git a/src/Rewriter/All.v b/src/Rewriter/All.v index b4159ffb28..1ef44882b7 100644 --- a/src/Rewriter/All.v +++ b/src/Rewriter/All.v @@ -2,6 +2,7 @@ Require Import Crypto.Rewriter.Passes.NBE. Require Import Crypto.Rewriter.Passes.AddAssocLeft. Require Import Crypto.Rewriter.Passes.Arith. Require Import Crypto.Rewriter.Passes.ArithWithCasts. +Require Import Crypto.Rewriter.Passes.ArithWithRelaxedCasts. Require Import Crypto.Rewriter.Passes.StripLiteralCasts. Require Import Crypto.Rewriter.Passes.FlattenThunkedRects. Require Import Crypto.Rewriter.Passes.MulSplit. @@ -17,6 +18,7 @@ Module Compilers. Export AddAssocLeft.Compilers. Export Arith.Compilers. Export ArithWithCasts.Compilers. + Export ArithWithRelaxedCasts.Compilers. Export StripLiteralCasts.Compilers. Export FlattenThunkedRects.Compilers. Export MulSplit.Compilers. @@ -32,6 +34,7 @@ Module Compilers. Export AddAssocLeft.Compilers.RewriteRules. Export Arith.Compilers.RewriteRules. Export ArithWithCasts.Compilers.RewriteRules. + Export ArithWithRelaxedCasts.Compilers.RewriteRules. Export StripLiteralCasts.Compilers.RewriteRules. Export FlattenThunkedRects.Compilers.RewriteRules. Export MulSplit.Compilers.RewriteRules. diff --git a/src/Rewriter/Passes/ArithWithRelaxedCasts.v b/src/Rewriter/Passes/ArithWithRelaxedCasts.v new file mode 100644 index 0000000000..7c56f6b212 --- /dev/null +++ b/src/Rewriter/Passes/ArithWithRelaxedCasts.v @@ -0,0 +1,43 @@ +Require Import Rewriter.Language.Language. +Require Import Crypto.Language.API. +Require Import Rewriter.Language.Wf. +Require Import Crypto.Language.WfExtra. +Require Import Crypto.Rewriter.AllTacticsExtra. +Require Import Crypto.Rewriter.RulesProofs. + +Module Compilers. + Import Language.Compilers. + Import Language.API.Compilers. + Import Language.Wf.Compilers. + Import Language.WfExtra.Compilers. + Import Rewriter.AllTacticsExtra.Compilers.RewriteRules.GoalType. + Import Rewriter.AllTactics.Compilers.RewriteRules.Tactic. + Import Compilers.Classes. + + Module Import RewriteRules. + Section __. + Definition VerifiedRewriterArithWithRelaxedCasts : VerifiedRewriter_with_args false false true arith_with_relaxed_casts_rewrite_rules_proofs. + Proof using All. make_rewriter. Defined. + + Definition default_opts := Eval hnf in @default_opts VerifiedRewriterArithWithRelaxedCasts. + Let optsT := Eval hnf in optsT VerifiedRewriterArithWithRelaxedCasts. + + Definition RewriteArithWithRelaxedCasts (opts : optsT) {t : API.type} := Eval hnf in @Rewrite VerifiedRewriterArithWithRelaxedCasts opts t. + + Lemma Wf_RewriteArithWithRelaxedCasts opts {t} e (Hwf : Wf e) : Wf (@RewriteArithWithRelaxedCasts opts t e). + Proof. now apply VerifiedRewriterArithWithRelaxedCasts. Qed. + + Lemma Interp_RewriteArithWithRelaxedCasts opts {t} e (Hwf : Wf e) : API.Interp (@RewriteArithWithRelaxedCasts opts t e) == API.Interp e. + Proof. now apply VerifiedRewriterArithWithRelaxedCasts. Qed. + End __. + End RewriteRules. + + Module Export Hints. +#[global] + Hint Resolve Wf_RewriteArithWithRelaxedCasts : wf wf_extra. +#[global] + Hint Opaque RewriteArithWithRelaxedCasts : wf wf_extra interp interp_extra rewrite. +#[global] + Hint Rewrite @Interp_RewriteArithWithRelaxedCasts : interp interp_extra. + End Hints. +End Compilers. diff --git a/src/Rewriter/Rules.v b/src/Rewriter/Rules.v index 33ee22858b..0ce2464d76 100644 --- a/src/Rewriter/Rules.v +++ b/src/Rewriter/Rules.v @@ -561,6 +561,30 @@ Definition arith_with_casts_rewrite_rulesT (adc_no_carry_to_add : bool) : list ( ] ]%Z%zrange. +Definition arith_with_relaxed_casts_rewrite_rulesT : list (bool * Prop) + := Eval cbv [myapp mymap myflatten] in + myflatten + [mymap + dont_do_again + [(forall rland rm1 rv v, + rland.(upper) ∈ rm1 + -> rland.(upper) = Z.ones (Z.succ (Z.log2 rland.(upper))) + -> 0 = rland.(lower) + -> 0 = rv.(lower) + -> 0 <= rv.(upper) + -> (rv.(upper) + 1) mod (rland.(upper) + 1) = 0 + -> cstZ rland (Z.land (cstZ rv v) (cstZ rm1 ('rland.(upper)))) = cstZ rland v) + ; (forall rland rm1 rv v, + rland.(upper) ∈ rm1 + -> rland.(upper) = Z.ones (Z.succ (Z.log2 rland.(upper))) + -> 0 = rland.(lower) + -> 0 = rv.(lower) + -> 0 <= rv.(upper) + -> (rv.(upper) + 1) mod (rland.(upper) + 1) = 0 + -> cstZ rland (Z.land (cstZ rm1 ('rland.(upper))) (cstZ rv v)) = cstZ rland v) + ] + ]%Z%zrange. + Definition strip_literal_casts_rewrite_rulesT : list (bool * Prop) := [dont_do_again (forall rx x, x ∈ rx -> cstZ rx ('x) = 'x)]%Z%zrange. @@ -1060,17 +1084,7 @@ Section with_bitwidth. [mymap dont_do_again [] ; mymap do_again - [ - (* owen put this here, and he needs to remove it. *) - (forall rland rm1 rv v, - rland.(upper) ∈ rm1 - -> rland.(upper) = Z.ones (Z.succ (Z.log2 rland.(upper))) - -> 0 = rland.(lower) - -> 0 = rv.(lower) - -> 0 <= rv.(upper) - -> (rv.(upper) + 1) mod (rland.(upper) + 1) = 0 - -> cstZ rland (Z.land (cstZ rv v) (cstZ rm1 ('rland.(upper)))) = cstZ rland v) - ; (forall A B x y, @fst A B (x, y) = x) + [(forall A B x y, @fst A B (x, y) = x) ; (forall A B x y, @snd A B (x, y) = y) (** In order to avoid tautological compares, we need to deal with carry/borrows being 0 *) ; (forall r0 s x y r1 r2, diff --git a/src/Rewriter/RulesProofs.v b/src/Rewriter/RulesProofs.v index 8edff0131f..1029eb9b55 100644 --- a/src/Rewriter/RulesProofs.v +++ b/src/Rewriter/RulesProofs.v @@ -564,63 +564,52 @@ Local Ltac do_clear_nia x y r H H' := => clear -Hx Hy Hm Hr H' H; nia end. -Search ident.cast. Lemma arith_with_casts_rewrite_rules_proofs (adc_no_carry_to_add : bool) : PrimitiveHList.hlist (@snd bool Prop) (arith_with_casts_rewrite_rulesT adc_no_carry_to_add). Proof using Type. start_proof; auto; intros; try lia. all: repeat interp_good_t_step_related. - (*11: { replace (ident.cast rland v) with (ident.cast rland (ident.cast rv v)). - - interp_good_t_step_arith. interp_good_t_step_arith. rewrite Z.land_ones. - + replace (2 ^ Z.succ (Z.log2 (upper rland))) with (upper rland + 1). - -- rewrite <- ident.cast_out_of_bounds_simple_0_mod. - ++ destruct rland. simpl in H1. subst. apply ident.cast_idempotent. - ++ rewrite H0. apply Ones.Z.ones_nonneg. remember (Z.log2_nonneg (upper rland)). lia. - -- remember (Z.log2 _) as x. rewrite H0. subst. rewrite Z.ones_equiv. lia. - + remember (Z.log2_nonneg (upper rland)). lia. - - Search ident.cast. destruct rland. destruct rv. simpl in *. subst. - (*Search ident.cast. Search ZRange.normalize. - repeat rewrite <- (ident.cast_normalize r[0~>upper]). - repeat rewrite <- (ident.cast_normalize r[0~>upper0]).*) - Check ident.cast_out_of_bounds_simple_0_mod. - repeat rewrite ident.cast_out_of_bounds_simple_0_mod. - + Search ((_ mod _) mod _). rewrite <- Z.mod_div_mod_full. - -- reflexivity. - -- Search Z.divide. rewrite <- Z.mod_divide_full. assumption. - + Search Z.ones. rewrite H0. apply Ones.Z.ones_nonneg. - remember (Z.log2_nonneg (upper)). lia. - + lia. - + Search Z.ones. rewrite H0. apply Ones.Z.ones_nonneg. - remember (Z.log2_nonneg (upper)). lia. - }*) - (* Search Z.ones. rewrite H0. apply Ones.Z.ones_nonneg. - remember (Z.log2_nonneg (upper)). lia. - } - cbv [Z.succ]. Check Z.ones_equiv. rewrite <- ident.cast_out_of_bounds_simple_0_mod. - Search (Z.ones (Z.succ _)). rewrite Z.ones_equiv. - rewrite Z.land_ones. - - - all: repeat interp_good_t_step_arith. - Search (Z.land _ (Z.ones _)). rewrite Z.land_ones. - + Search (ident.cast _ _ = _ mod _). cbv [Z.succ]. replace (2^(Z.log2 (upper rland) + 1)) with ((upper rland) + 1). - -- rewrite <- ident.cast_out_of_bounds_simple_0_mod. - ++ Search (ident.cast _ (ident.cast _ _)). - replace r[0~>upper rland]%zrange with rland. - --- rewrite ident.cast_idempotent. rep apply ident.cast_idempotent. - --- reflect_hyps. destruct rland. simpl in *. subst. reflexivity. - ++ reflect_hyps. simpl in *. Search (0 <= Z.ones _). rewrite H1. - apply Ones.Z.ones_nonneg. Search (0 <= Z.log2 _). remember (Z.log2_nonneg (upper rland)). lia. - -- remember (Z.log2 _) as x. rewrite H1. subst. Search Z.ones. rewrite Z.ones_equiv. cbv [Z.succ Z.pred]. lia. - + remember (Z.log2_nonneg (upper rland)). lia. - - Check Ones.Z.ones_succ. remember (Ones.Z.ones_nonneg (Z.succ (Z.l lia. - interp_good_t_step_arith. - all: repeat interp_good_t_step_arith. assert (is_bounded_by_bool v rland = true). - { reflect_hyps. cbv [is_bounded_by_bool]. lia. } Search is_tighter_than_bool. reflect_hyps.*) all: repeat interp_good_t_step_arith. all: remove_casts; try fin_with_nia. all: try (reflect_hyps; lia). Qed. +Lemma relaxed_rules_work rland rm1 rv v : + is_bounded_by_bool (upper rland) (ZRange.normalize rm1) = true -> + upper rland = Z.ones (Z.succ (Z.log2 (upper rland))) -> + 0 = lower rland -> + 0 = lower rv -> + 0 <= upper rv -> + (upper rv + 1) mod (upper rland + 1) = 0 -> + ident.cast rland (ident.cast rv v &' ident.cast rm1 (upper rland)) = ident.cast rland v. +Proof. + intros H1 H2 H3 H4 H5 H6. + replace (ident.cast rland v) with (ident.cast rland (ident.cast rv v)). + - interp_good_t_step_arith. interp_good_t_step_arith. rewrite Z.land_ones. + + replace (2 ^ Z.succ (Z.log2 (upper rland))) with (upper rland + 1). + -- rewrite <- ident.cast_out_of_bounds_simple_0_mod. + ++ destruct rland. simpl in *. subst. apply ident.cast_idempotent. + ++ rewrite H2. apply Ones.Z.ones_nonneg. remember (Z.log2_nonneg (upper rland)). lia. + -- remember (Z.log2 _) as x. rewrite H2. subst. rewrite Z.ones_equiv. lia. + + remember (Z.log2_nonneg (upper rland)). lia. + - destruct rland. destruct rv. simpl in *. subst. + repeat rewrite ident.cast_out_of_bounds_simple_0_mod. + + rewrite <- Z.mod_div_mod_full. + -- reflexivity. + -- rewrite <- Z.mod_divide_full. assumption. + + rewrite H2. apply Ones.Z.ones_nonneg. remember (Z.log2_nonneg (upper)). lia. + + lia. + + rewrite H2. apply Ones.Z.ones_nonneg. remember (Z.log2_nonneg (upper)). lia. +Qed. + +Lemma arith_with_relaxed_casts_rewrite_rules_proofs + : PrimitiveHList.hlist (@snd bool Prop) arith_with_relaxed_casts_rewrite_rulesT. +Proof using Type. + start_proof; auto; intros; try lia. + - apply relaxed_rules_work; assumption. + - rewrite Z.land_comm. apply relaxed_rules_work; assumption. +Qed. + Lemma strip_literal_casts_rewrite_rules_proofs : PrimitiveHList.hlist (@snd bool Prop) strip_literal_casts_rewrite_rulesT. Proof using Type. @@ -861,30 +850,6 @@ Proof using Type. by (intros; apply Z.pow_gt_lin_r; auto with zarith). start_proof; auto; intros; try lia. - 1: { - replace (ident.cast rland v) with (ident.cast rland (ident.cast rv v)). - - interp_good_t_step_arith. interp_good_t_step_arith. rewrite Z.land_ones. - + replace (2 ^ Z.succ (Z.log2 (upper rland))) with (upper rland + 1). - -- rewrite <- ident.cast_out_of_bounds_simple_0_mod. - ++ destruct rland. simpl in *. subst. apply ident.cast_idempotent. - ++ rewrite H2. apply Ones.Z.ones_nonneg. remember (Z.log2_nonneg (upper rland)). lia. - -- remember (Z.log2 _) as x. rewrite H2. subst. rewrite Z.ones_equiv. lia. - + remember (Z.log2_nonneg (upper rland)). lia. - - Search ident.cast. destruct rland. destruct rv. simpl in *. subst. - (*Search ident.cast. Search ZRange.normalize. - repeat rewrite <- (ident.cast_normalize r[0~>upper]). - repeat rewrite <- (ident.cast_normalize r[0~>upper0]).*) - Check ident.cast_out_of_bounds_simple_0_mod. - repeat rewrite ident.cast_out_of_bounds_simple_0_mod. - + Search ((_ mod _) mod _). rewrite <- Z.mod_div_mod_full. - -- reflexivity. - -- Search Z.divide. rewrite <- Z.mod_divide_full. assumption. - + Search Z.ones. rewrite H2. apply Ones.Z.ones_nonneg. - remember (Z.log2_nonneg (upper)). lia. - + lia. - + Search Z.ones. rewrite H2. apply Ones.Z.ones_nonneg. - remember (Z.log2_nonneg (upper)). lia. - } all: repeat interp_good_t_step_related. all: systematically_handle_casts; autorewrite with zsimplify_fast; try reflexivity. all: subst; rewrite !ident.platform_specific_cast_0_is_mod, ?Z.sub_add, ?Z.mod_mod by lia; try reflexivity.