Skip to content

Commit aeb15cd

Browse files
Ast: add binders positions to EAbs
1 parent 7974044 commit aeb15cd

File tree

22 files changed

+210
-148
lines changed

22 files changed

+210
-148
lines changed

compiler/catala_utils/mark.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ type 'a pos = ('a, Pos.t) ed
2121
let add m e = e, m
2222
let remove (x, _) = x
2323
let get (_, m) = m
24+
let ghost x = x, Pos.no_pos
2425
let set m (x, _) = x, m
2526
let map f (x, m) = f x, m
2627
let map_mark f (a, m) = a, f m

compiler/catala_utils/mark.mli

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ type 'a pos = ('a, Pos.t) ed
2828
val add : 'm -> 'a -> ('a, 'm) ed
2929
val remove : ('a, 'm) ed -> 'a
3030
val get : ('a, 'm) ed -> 'm
31+
val ghost : 'a -> 'a pos
3132
val set : 'm -> ('a, _) ed -> ('a, 'm) ed
3233
val map : ('a -> 'b) -> ('a, 'm) ed -> ('b, 'm) ed
3334
val map_mark : ('m1 -> 'm2) -> ('a, 'm1) ed -> ('a, 'm2) ed

compiler/dcalc/from_scopelang.ml

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ let merge_defaults
7777
let m_callee = Mark.get callee in
7878
let unboxed_callee = Expr.unbox callee in
7979
match Mark.remove unboxed_callee with
80-
| EAbs { binder; tys } ->
80+
| EAbs { binder; pos; tys } ->
8181
let vars, body = Bindlib.unmbind binder in
8282
let m_body = Mark.get body in
8383
let caller =
@@ -103,6 +103,7 @@ let merge_defaults
103103
let d =
104104
Expr.edefault ~excepts:[caller] ~just:ltrue ~cons (Mark.get cons)
105105
in
106+
let vars = List.map2 (fun v p -> Mark.add p v) (Array.to_list vars) pos in
106107
Expr.make_abs vars (Expr.make_erroronempty d) tys (Expr.mark_pos m_callee)
107108
| _ -> assert false
108109
(* should not happen because there should always be a lambda at the
@@ -225,7 +226,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm S.expr) : 'm Ast.expr boxed =
225226
let v =
226227
match var_ctx.scope_input_typ with
227228
| TArrow ([t_arg], t_ret) ->
228-
Expr.make_abs [| Var.make "_" |] (e_empty t_ret) [t_arg] pos
229+
Expr.make_ghost_abs [Var.make "_"] (e_empty t_ret) [t_arg] pos
229230
| TDefault _ as ty -> e_empty (ty, pos)
230231
| _ -> assert false
231232
in
@@ -366,8 +367,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm S.expr) : 'm Ast.expr boxed =
366367
@ [Mark.add (Expr.pos e) ("input" ^ string_of_int i)]))
367368
(List.combine params_vars ts_in)
368369
in
369-
Expr.make_abs
370-
(Array.of_list params_vars)
370+
Expr.make_ghost_abs params_vars
371371
(tag_with_log_entry
372372
(tag_with_log_entry
373373
(Expr.eapp
@@ -411,10 +411,11 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm S.expr) : 'm Ast.expr boxed =
411411
in
412412
(* let result_var = calling_expr in let result_eta_expanded_var =
413413
result_eta_expaneded in log (if_then_else_returned ) *)
414-
Expr.make_let_in result_var
414+
Expr.make_let_in (Mark.ghost result_var)
415415
(TStruct sc_sig.scope_sig_output_struct, Expr.pos e)
416416
calling_expr
417-
(Expr.make_let_in result_eta_expanded_var
417+
(Expr.make_let_in
418+
(Mark.ghost result_eta_expanded_var)
418419
(TStruct sc_sig.scope_sig_output_struct, Expr.pos e)
419420
result_eta_expanded
420421
(tag_with_log_entry

compiler/desugared/from_surface.ml

Lines changed: 64 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ let rec translate_expr
270270
Var.make
271271
(match names with [] -> "zip" | _ -> String.concat "_" names)
272272
in
273-
Expr.make_abs [| x1; x2 |]
273+
Expr.make_ghost_abs [x1; x2] (*?*)
274274
(Expr.make_tuple (Expr.evar x1 m :: explode (Expr.evar x2 m)) m)
275275
tys pos
276276
in
@@ -302,7 +302,7 @@ let rec translate_expr
302302
(fun c_uid' tau ->
303303
if EnumConstructor.compare c_uid c_uid' <> 0 then
304304
let nop_var = Var.make "_" in
305-
Expr.make_abs [| nop_var |]
305+
Expr.make_ghost_abs [nop_var]
306306
(Expr.elit (LBool false) emark)
307307
[tau] pos_op
308308
else
@@ -311,7 +311,9 @@ let rec translate_expr
311311
Ident.Map.add (Mark.remove binding) binding_var local_vars
312312
in
313313
let e2 = rec_helper ~local_vars e2 in
314-
Expr.make_abs [| binding_var |] e2 [tau] pos_op)
314+
Expr.make_abs
315+
[Mark.add (Mark.get binding) binding_var]
316+
e2 [tau] pos_op)
315317
(fst (EnumName.Map.find enum_uid ctxt.enums))
316318
in
317319
Expr.ematch ~e:(rec_helper e1_sub) ~name:enum_uid ~cases emark
@@ -556,17 +558,18 @@ let rec translate_expr
556558
in
557559
Expr.escopecall ~scope:called_scope ~args:in_struct emark
558560
| LetIn (xs, e1, e2) ->
559-
let vs = List.map (fun x -> Var.make (Mark.remove x)) xs in
561+
let m_xs : _ Var.t Mark.pos list =
562+
List.map (fun x -> Mark.map Var.make x) xs
563+
in
560564
let local_vars =
561565
List.fold_left2
562-
(fun local_vars x v -> Ident.Map.add (Mark.remove x) v local_vars)
563-
local_vars xs vs
566+
(fun local_vars x v ->
567+
Ident.Map.add (Mark.remove x) (Mark.remove v) local_vars)
568+
local_vars xs m_xs
564569
in
565570
let taus = List.map (fun x -> TAny, Mark.get x) xs in
566571
(* This type will be resolved in Scopelang.Desambiguation *)
567-
let f =
568-
Expr.make_abs (Array.of_list vs) (rec_helper ~local_vars e2) taus pos
569-
in
572+
let f = Expr.make_abs m_xs (rec_helper ~local_vars e2) taus pos in
570573
Expr.eapp ~f ~args:[rec_helper e1] ~tys:[] emark
571574
| StructReplace (e, fields) ->
572575
let fields =
@@ -732,7 +735,7 @@ let rec translate_expr
732735
EnumConstructor.Map.mapi
733736
(fun c_uid' tau ->
734737
let nop_var = Var.make "_" in
735-
Expr.make_abs [| nop_var |]
738+
Expr.make_ghost_abs [nop_var]
736739
(Expr.elit (LBool (EnumConstructor.compare c_uid c_uid' = 0)) emark)
737740
[tau] pos)
738741
(fst (EnumName.Map.find enum_uid ctxt.enums))
@@ -747,14 +750,14 @@ let rec translate_expr
747750
let collection =
748751
detuplify_list opos (List.map Mark.remove param_names) collection
749752
in
750-
let params = List.map (fun n -> Var.make (Mark.remove n)) param_names in
753+
let params = List.map (fun n -> Mark.map Var.make n) param_names in
751754
let local_vars =
752755
List.fold_left2
753-
(fun vars n p -> Ident.Map.add (Mark.remove n) p vars)
756+
(fun vars n p -> Ident.Map.add (Mark.remove n) (Mark.remove p) vars)
754757
local_vars param_names params
755758
in
756759
let f_pred =
757-
Expr.make_abs (Array.of_list params)
760+
Expr.make_abs params
758761
(rec_helper ~local_vars predicate)
759762
(List.map (fun _ -> TAny, pos) params)
760763
pos
@@ -770,7 +773,8 @@ let rec translate_expr
770773
in
771774
let x = Expr.evar v emark in
772775
let tys = List.map (fun _ -> TAny, pos) param_names in
773-
Expr.make_abs [| v |]
776+
Expr.make_abs
777+
[Mark.add Pos.no_pos v]
774778
(Expr.make_app f_pred
775779
(List.init nb_args (fun i ->
776780
Expr.etupleaccess ~e:x ~index:i ~size:nb_args emark))
@@ -791,22 +795,21 @@ let rec translate_expr
791795
let collection =
792796
detuplify_list opos (List.map Mark.remove param_names) collection
793797
in
794-
let accs = List.map (fun n -> Var.make (Mark.remove n)) acc_names in
795-
let params = List.map (fun n -> Var.make (Mark.remove n)) param_names in
798+
let accs = List.map (fun n -> Mark.map Var.make n) acc_names in
799+
let params = List.map (fun n -> Mark.map Var.make n) param_names in
796800
let init = rec_helper ~local_vars init in
797801
let local_vars =
798802
List.fold_left2
799-
(fun vars n p -> Ident.Map.add (Mark.remove n) p vars)
803+
(fun vars n p -> Ident.Map.add (Mark.remove n) (Mark.remove p) vars)
800804
local_vars param_names params
801805
in
802806
let local_vars =
803807
List.fold_left2
804-
(fun vars n p -> Ident.Map.add (Mark.remove n) p vars)
808+
(fun vars n p -> Ident.Map.add (Mark.remove n) (Mark.remove p) vars)
805809
local_vars acc_names accs
806810
in
807811
let f_proc =
808-
Expr.make_abs
809-
(Array.of_list (accs @ params))
812+
Expr.make_abs (accs @ params)
810813
(rec_helper ~local_vars fct)
811814
(List.map (fun _ -> TAny, pos) (accs @ params))
812815
pos
@@ -818,18 +821,18 @@ let rec translate_expr
818821
| nb_accs, nb_args ->
819822
let v_acc =
820823
match accs with
821-
| [v] -> v
824+
| [v] -> Mark.remove v
822825
| _ -> Var.make (String.concat "_" (List.map Mark.remove acc_names))
823826
in
824827
let v_param =
825828
match params with
826-
| [v] -> v
829+
| [v] -> Mark.remove v
827830
| _ -> Var.make (String.concat "_" (List.map Mark.remove param_names))
828831
in
829832
let x_acc = Expr.evar v_acc emark in
830833
let x_param = Expr.evar v_param emark in
831834
let tys = List.init (nb_accs + nb_args) (fun _ -> TAny, pos) in
832-
Expr.make_abs [| v_acc; v_param |]
835+
Expr.make_ghost_abs [v_acc; v_param]
833836
(Expr.make_app f_proc
834837
((if nb_accs = 1 then [x_acc]
835838
else
@@ -860,24 +863,23 @@ let rec translate_expr
860863
let collection =
861864
detuplify_list opos (List.map Mark.remove param_names) collection
862865
in
863-
let params = List.map (fun n -> Var.make (Mark.remove n)) param_names in
866+
let params = List.map (fun n -> Mark.map Var.make n) param_names in
864867
let local_vars =
865868
List.fold_left2
866-
(fun vars n p -> Ident.Map.add (Mark.remove n) p vars)
869+
(fun vars n p -> Ident.Map.add (Mark.remove n) (Mark.remove p) vars)
867870
local_vars param_names params
868871
in
869872
let cmp_op = if max then Op.Gt, opos else Op.Lt, opos in
870873
let f_pred =
871-
Expr.make_abs (Array.of_list params)
872-
(rec_helper ~local_vars predicate)
873-
[TAny, pos]
874-
pos
874+
Expr.make_abs params (rec_helper ~local_vars predicate) [TAny, pos] pos
875875
in
876876
let add_weight_f =
877-
let vs = List.map (fun p -> Var.make (Bindlib.name_of p)) params in
877+
let vs =
878+
List.map (fun p -> Var.make (Bindlib.name_of (Mark.remove p))) params
879+
in
878880
let xs = List.map (fun v -> Expr.evar v emark) vs in
879881
let x = match xs with [x] -> x | xs -> Expr.etuple xs emark in
880-
Expr.make_abs (Array.of_list vs)
882+
Expr.make_ghost_abs vs
881883
(Expr.make_tuple [x; Expr.eapp ~f:f_pred ~args:xs ~tys:[] emark] emark)
882884
[TAny, pos]
883885
pos
@@ -886,7 +888,7 @@ let rec translate_expr
886888
(* fun x1 x2 -> if cmp_op (x1.2) (x2.2) cmp *)
887889
let v1, v2 = Var.make "x1", Var.make "x2" in
888890
let x1, x2 = Expr.make_var v1 emark, Expr.make_var v2 emark in
889-
Expr.make_abs [| v1; v2 |]
891+
Expr.make_ghost_abs [v1; v2]
890892
(Expr.eifthenelse
891893
(Expr.eappop ~op:cmp_op
892894
~tys:[TAny, pos_dft; TAny, pos_dft]
@@ -903,7 +905,7 @@ let rec translate_expr
903905
let weights_var = Var.make "weights" in
904906
let default = Expr.make_app add_weight_f [default] [TAny, pos] pos_dft in
905907
let weighted_result =
906-
Expr.make_let_in weights_var
908+
Expr.make_let_in (Mark.ghost weights_var)
907909
(TArray (TTuple [TAny, pos; TAny, pos], pos), pos)
908910
(Expr.eappop ~op:(Map, opos)
909911
~tys:[TAny, pos; TArray (TAny, pos), pos]
@@ -929,23 +931,25 @@ let rec translate_expr
929931
in
930932
let init = Expr.elit (LBool init) emark in
931933
let params0, predicate = predicate in
932-
let params = List.map (fun n -> Var.make (Mark.remove n)) params0 in
934+
let params = List.map (fun n -> Mark.map Var.make n) params0 in
933935
let local_vars =
934936
List.fold_left2
935-
(fun vars n p -> Ident.Map.add (Mark.remove n) p vars)
937+
(fun vars n p -> Ident.Map.add (Mark.remove n) (Mark.remove p) vars)
936938
local_vars params0 params
937939
in
938940
let f =
939941
let acc_var = Var.make "acc" in
940942
let acc =
941943
Expr.make_var acc_var (Untyped { pos = Mark.get (List.hd params0) })
942944
in
943-
Expr.eabs
944-
(Expr.bind
945-
(Array.of_list (acc_var :: params))
946-
(translate_binop op pos acc (rec_helper ~local_vars predicate)))
947-
[TAny, pos; TAny, pos]
948-
emark
945+
let vs = Mark.ghost acc_var :: params in
946+
let vs_marks = List.map Mark.get vs in
947+
let mvars =
948+
Expr.bind
949+
(Array.of_list (List.map Mark.remove vs))
950+
(translate_binop op pos acc (rec_helper ~local_vars predicate))
951+
in
952+
Expr.eabs mvars vs_marks [TAny, pos; TAny, pos] emark
949953
in
950954
Expr.eappop ~op:(Fold, opos)
951955
~tys:[TAny, pos; TAny, pos; TAny, pos]
@@ -960,7 +964,7 @@ let rec translate_expr
960964
let v1, v2 = Var.make (vname ^ "1"), Var.make (vname ^ "2") in
961965
let x1 = Expr.make_var v1 emark in
962966
let x2 = Expr.make_var v2 emark in
963-
Expr.make_abs [| v1; v2 |]
967+
Expr.make_ghost_abs [v1; v2]
964968
(Expr.eifthenelse (translate_binop (op, pos) pos x1 x2) x1 x2 emark)
965969
[TAny, pos; TAny, pos]
966970
pos
@@ -990,7 +994,7 @@ let rec translate_expr
990994
let v1, v2 = Var.make "sum1", Var.make "sum2" in
991995
let x1 = Expr.make_var v1 emark in
992996
let x2 = Expr.make_var v2 emark in
993-
Expr.make_abs [| v1; v2 |]
997+
Expr.make_ghost_abs [v1; v2]
994998
(translate_binop (S.Add KPoly, opos) pos x1 x2)
995999
[TAny, pos; TAny, pos]
9961000
pos
@@ -1019,9 +1023,11 @@ let rec translate_expr
10191023
]
10201024
emark
10211025
in
1026+
let vars = [Mark.ghost acc_var; Mark.add opos param_var] in
10221027
let f =
10231028
Expr.eabs
1024-
(Expr.bind [| acc_var; param_var |] f_body)
1029+
(Expr.bind (Array.of_list (List.map Mark.remove vars)) f_body)
1030+
(List.map Mark.get vars)
10251031
[TLit TBool, pos; TAny, pos]
10261032
emark
10271033
in
@@ -1047,8 +1053,9 @@ and disambiguate_match_and_build_expression
10471053
(e_uid : EnumName.t)
10481054
(ctxt : Name_resolution.context)
10491055
case_body
1050-
e_binder =
1051-
Expr.eabs e_binder
1056+
e_binder
1057+
pos_binder =
1058+
Expr.eabs e_binder pos_binder
10521059
[
10531060
EnumConstructor.Map.find c_uid
10541061
(fst (EnumName.Map.find e_uid ctxt.Name_resolution.enums));
@@ -1091,7 +1098,14 @@ and disambiguate_match_and_build_expression
10911098
case.S.match_case_expr
10921099
in
10931100
let e_binder = Expr.bind [| param_var |] case_body in
1094-
let case_expr = bind_case_body c_uid e_uid ctxt case_body e_binder in
1101+
let pos_binder =
1102+
match binding with
1103+
| None -> [Pos.no_pos]
1104+
| Some binding -> [Mark.get binding]
1105+
in
1106+
let case_expr =
1107+
bind_case_body c_uid e_uid ctxt case_body e_binder pos_binder
1108+
in
10951109
( EnumConstructor.Map.add c_uid case_expr cases_d,
10961110
Some e_uid,
10971111
curr_index + 1 )
@@ -1147,12 +1161,12 @@ and disambiguate_match_and_build_expression
11471161
match_case_expr
11481162
in
11491163
let e_binder = Expr.bind [| payload_var |] case_body in
1150-
1164+
let pos_binder = [Pos.no_pos] in
11511165
(* For each missing cases, binds the wildcard payload. *)
11521166
EnumConstructor.Map.fold
11531167
(fun c_uid _ (cases_d, e_uid_opt, curr_index) ->
11541168
let case_expr =
1155-
bind_case_body c_uid e_uid ctxt case_body e_binder
1169+
bind_case_body c_uid e_uid ctxt case_body e_binder pos_binder
11561170
in
11571171
( EnumConstructor.Map.add c_uid case_expr cases_d,
11581172
e_uid_opt,
@@ -1568,9 +1582,7 @@ let process_topdef
15681582
| _ -> ()
15691583
in
15701584
let e =
1571-
Expr.make_abs
1572-
(Array.of_list (List.map Mark.remove args))
1573-
body
1585+
Expr.make_abs args body
15741586
(List.map translate_tbase tys)
15751587
(Mark.get def.S.topdef_name)
15761588
in

0 commit comments

Comments
 (0)