Skip to content

Commit fb8c6cb

Browse files
experiment with operator overloading
1 parent e46b917 commit fb8c6cb

File tree

2 files changed

+153
-4
lines changed

2 files changed

+153
-4
lines changed
+148
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
Require Import Coq.ZArith.ZArith.
2+
Require Import coqutil.Map.Interface.
3+
Require Import coqutil.Word.Interface.
4+
Require Import bedrock2.Map.SeparationLogic.
5+
6+
7+
Ltac get_type x :=
8+
let tx := type of x in
9+
let __ := match constr:(Set) with
10+
| _ => tryif has_evar tx
11+
then fail 1 "type of" x "should be fully determined, but is" tx
12+
else idtac
13+
end in
14+
tx.
15+
16+
(* might insert a coercion, or use subsumption to turn `nat: Set` into `nat: Type` *)
17+
Ltac has_type x T :=
18+
tryif (let _ := constr:(x: T) in idtac) then idtac
19+
else fail 0 x "does not have type" T.
20+
21+
Ltac binop cl x y :=
22+
(* At the moment, we require both the type of x and y to be fully determined
23+
(no evars), but later we might be more adventurous and allow the type of
24+
one operand to determine the type of the other.
25+
But then we'll have to reconsider what to do about coercions. *)
26+
let tx := get_type x in
27+
let ty := get_type y in
28+
(* unify instead of constr_eq because has_type also uses unification *)
29+
tryif unify tx ty then
30+
(* beta gets rid of cast and of eta expansion in case of prod used for Set *)
31+
let r := eval cbv beta in ((_: cl tx) x y) in exact r
32+
else
33+
tryif has_type x ty then
34+
tryif has_type y tx then
35+
fail "bidirectional coercion between" tx "and" ty
36+
"makes" cl "of" x "and" y "ambiguous"
37+
else
38+
(* will coerce type of x to ty *)
39+
let r := eval cbv beta in ((_: cl ty) x y) in exact r
40+
else
41+
tryif has_type y tx then
42+
(* will coerce type of y to tx *)
43+
let r := eval cbv beta in ((_: cl tx) x y) in exact r
44+
else
45+
fail "operands" x "and" y "have incompatible types" tx "and" ty.
46+
47+
48+
Declare Scope oo_scope.
49+
Local Open Scope oo_scope.
50+
51+
Definition Multiplication(T: Type) := T -> T -> T.
52+
Existing Class Multiplication.
53+
#[export] Hint Mode Multiplication + : typeclass_instances.
54+
55+
#[export] Hint Extern 1 (Multiplication nat) => exact Nat.mul : typeclass_instances.
56+
#[export] Hint Extern 1 (Multiplication N) => exact N.mul : typeclass_instances.
57+
#[export] Hint Extern 1 (Multiplication Z) => exact Z.mul : typeclass_instances.
58+
#[export] Hint Extern 1 (Multiplication Set) => exact prod : typeclass_instances.
59+
#[export] Hint Extern 1 (Multiplication Prop) => exact prod : typeclass_instances.
60+
#[export] Hint Extern 1 (Multiplication Type) => exact prod : typeclass_instances.
61+
#[export] Hint Extern 1 (Multiplication (@word.rep ?wi ?wo)) =>
62+
exact (@word.mul wi wo) : typeclass_instances.
63+
#[export] Hint Extern 1 (Multiplication (@map.rep ?K ?V ?M -> Prop)) =>
64+
exact (@sep K V M) : typeclass_instances.
65+
66+
Notation "x * y" := (ltac:(binop Multiplication x y)) (only parsing) : oo_scope.
67+
Notation "x * y" := (Nat.mul x y) (only printing): oo_scope.
68+
Notation "x * y" := (N.mul x y) (only printing): oo_scope.
69+
Notation "x * y" := (Z.mul x y) (only printing): oo_scope.
70+
Notation "x * y" := (prod x y) (only printing): oo_scope.
71+
Notation "x * y" := (word.mul x y) (only printing): oo_scope.
72+
Notation "x * y" := (sep x y) (only printing): oo_scope.
73+
74+
75+
Definition Addition(T: Type) := T -> T -> T.
76+
Existing Class Addition.
77+
#[export] Hint Mode Addition + : typeclass_instances.
78+
79+
#[export] Hint Extern 1 (Addition nat) => exact Nat.add : typeclass_instances.
80+
#[export] Hint Extern 1 (Addition N) => exact N.add : typeclass_instances.
81+
#[export] Hint Extern 1 (Addition Z) => exact Z.add : typeclass_instances.
82+
#[export] Hint Extern 1 (Addition (@word.rep ?wi ?wo)) =>
83+
exact (@word.add wi wo) : typeclass_instances.
84+
85+
Notation "x + y" := (ltac:(binop Addition x y)) (only parsing) : oo_scope.
86+
Notation "x + y" := (Nat.add x y) (only printing): oo_scope.
87+
Notation "x + y" := (N.add x y) (only printing): oo_scope.
88+
Notation "x + y" := (Z.add x y) (only printing): oo_scope.
89+
Notation "x + y" := (word.add x y) (only printing): oo_scope.
90+
91+
(* Tests:
92+
93+
Goal False.
94+
has_type 4%nat nat.
95+
let T := open_constr:(_: Type) in has_type 3 T; idtac T.
96+
Fail has_type 4%nat unit.
97+
has_type nat Set.
98+
has_type nat Type.
99+
has_type (fun (x: Set) => (x * x)%type) (Set -> Set).
100+
has_type (fun (x: Type) => (x * x)%type) (Set -> Set).
101+
has_type (fun (x: Type) => (x * x)%type) (Set -> Type).
102+
Fail has_type (fun (x: Set) => (x * x)%type) (Type -> Type).
103+
Fail has_type (fun (x: Set) => (x * x)%type) (Type -> Set).
104+
Abort.
105+
106+
Fail Check (fun a b => a * b).
107+
Fail Check (fun a (b: Z) => a * b).
108+
Fail Check (fun (a: nat) b => a * b).
109+
Check (fun (a b: nat) => a * b).
110+
Fail Check (fun (a: nat) (b: Z) => a * b).
111+
Fail Check (tt * 2).
112+
Fail Check (tt * tt).
113+
Check (nat * Z)%type.
114+
Check (nat * Z).
115+
Check (nat * Set).
116+
Check (nat * Type).
117+
118+
Section WithParameters.
119+
Context {word: word.word 32} {mem: map.map word Byte.byte}.
120+
Context {word_ok: word.ok word} {mem_ok: map.ok mem}.
121+
122+
Check (fun (P Q: mem -> Prop) => P * Q).
123+
124+
Local Coercion Z.of_nat : nat >-> Z.
125+
Local Coercion word.unsigned : word.rep >-> Z.
126+
127+
(* Local Set Printing Coercions.*)
128+
129+
Check (fun (x: word) => x: Z).
130+
Check (fun (x: nat) => x: Z).
131+
132+
(*
133+
New coercion path [word_to_nat; Z.of_nat] : word.rep >-> Z is ambiguous with existing
134+
[word.unsigned] : word.rep >-> Z. [ambiguous-paths,typechecker]
135+
Coercion word_to_nat(w: word): nat := Z.to_nat (word.unsigned w).
136+
*)
137+
138+
Check (fun (a: nat) (b: Z) => a * b).
139+
Check (fun (a: Z) (b: nat) => a * b).
140+
Check (fun (a: Z) (b: nat) (c: Z) => a * b * c).
141+
Check (fun (a b c: Z) => a * b * c).
142+
Check (fun (a: nat) (b: nat) (c: Z) => a * b * c).
143+
144+
Fail Check (fun (x: word) (y: mem -> Prop) => x * y).
145+
146+
Check (fun (P Q R: mem -> Prop) m => (P * Q * R) m).
147+
End WithParameters.
148+
*)

bedrock2/src/bedrock2Examples/SepAutoArrayTests.v

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
Require Import Coq.micromega.Lia.
22
Require Import coqutil.Word.Bitwidth32.
33
Require Import bedrock2.SepAutoArray bedrock2.SepAuto.
4+
Require Import bedrock2.OperatorOverloading. Local Open Scope oo_scope.
45

56
Lemma list_goal_after_simplifications[A: Type]{inh: inhabited A}(f: A -> A): forall ws,
67
(16 <= length ws)%nat ->
@@ -44,7 +45,7 @@ Qed.
4445
Section WithParameters.
4546
Context {word: word.word 32} {mem: map.map word Byte.byte}.
4647
Context {word_ok: word.ok word} {mem_ok: map.ok mem}.
47-
Local Open Scope Z_scope. Local Open Scope list_scope.
48+
Local Open Scope list_scope.
4849

4950
Add Ring wring : (Properties.word.ring_theory (word := word))
5051
((*This preprocessing is too expensive to be always run, especially if
@@ -58,7 +59,7 @@ Section WithParameters.
5859
Context (wp: cmd -> mem -> (mem -> Prop) -> Prop).
5960
Context (sample_call: word -> word -> cmd).
6061

61-
Hypothesis sample_call_correct: forall m a1 n vs R (post: mem -> Prop),
62+
Hypothesis sample_call_correct: forall m a1 n (vs: list word) R (post: mem -> Prop),
6263
seps [a1 |-> with_len (Z.to_nat (word.unsigned n)) word_array vs; R] m /\
6364
(forall m',
6465
(* Currently, the postcondition also needs a `with_len` so that when the caller
@@ -67,7 +68,7 @@ Section WithParameters.
6768
TODO consider ways of supporting to drop with_len in the postcondition when
6869
it can be derived like here (List.upd is guaranteed to preserve the length). *)
6970
seps [a1 |-> with_len (Z.to_nat (word.unsigned n))
70-
word_array (List.upd vs 5 (word.mul (List.nth 5 vs default) (word.of_Z 2)));
71+
word_array (List.upd vs 5 (List.nth 5 vs default * (word.of_Z (width := 32) 2)));
7172
R] m' ->
7273
post m') ->
7374
wp (sample_call a1 n) m post.
@@ -85,7 +86,7 @@ Section WithParameters.
8586
(word.of_Z 10))
8687
m
8788
(fun m' => seps [addr |-> word_array
88-
(List.upd ws 7 (word.mul (List.nth 7 ws default) (word.of_Z 2))); R] m').
89+
(List.upd ws 7 (List.nth 7 ws default * word.of_Z (width := 32) 2)); R] m').
8990

9091
Lemma use_sample_call_with_tactics_unfolded: sample_call_usage_goal.
9192
Proof.

0 commit comments

Comments
 (0)