Skip to content

Commit

Permalink
Ast: add binders positions to EAbs
Browse files Browse the repository at this point in the history
  • Loading branch information
vincent-botbol committed Oct 28, 2024
1 parent 7974044 commit aeb15cd
Show file tree
Hide file tree
Showing 22 changed files with 210 additions and 148 deletions.
1 change: 1 addition & 0 deletions compiler/catala_utils/mark.ml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type 'a pos = ('a, Pos.t) ed
let add m e = e, m
let remove (x, _) = x
let get (_, m) = m
let ghost x = x, Pos.no_pos
let set m (x, _) = x, m
let map f (x, m) = f x, m
let map_mark f (a, m) = a, f m
Expand Down
1 change: 1 addition & 0 deletions compiler/catala_utils/mark.mli
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type 'a pos = ('a, Pos.t) ed
val add : 'm -> 'a -> ('a, 'm) ed
val remove : ('a, 'm) ed -> 'a
val get : ('a, 'm) ed -> 'm
val ghost : 'a -> 'a pos
val set : 'm -> ('a, _) ed -> ('a, 'm) ed
val map : ('a -> 'b) -> ('a, 'm) ed -> ('b, 'm) ed
val map_mark : ('m1 -> 'm2) -> ('a, 'm1) ed -> ('a, 'm2) ed
Expand Down
13 changes: 7 additions & 6 deletions compiler/dcalc/from_scopelang.ml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ let merge_defaults
let m_callee = Mark.get callee in
let unboxed_callee = Expr.unbox callee in
match Mark.remove unboxed_callee with
| EAbs { binder; tys } ->
| EAbs { binder; pos; tys } ->
let vars, body = Bindlib.unmbind binder in
let m_body = Mark.get body in
let caller =
Expand All @@ -103,6 +103,7 @@ let merge_defaults
let d =
Expr.edefault ~excepts:[caller] ~just:ltrue ~cons (Mark.get cons)
in
let vars = List.map2 (fun v p -> Mark.add p v) (Array.to_list vars) pos in
Expr.make_abs vars (Expr.make_erroronempty d) tys (Expr.mark_pos m_callee)
| _ -> assert false
(* should not happen because there should always be a lambda at the
Expand Down Expand Up @@ -225,7 +226,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm S.expr) : 'm Ast.expr boxed =
let v =
match var_ctx.scope_input_typ with
| TArrow ([t_arg], t_ret) ->
Expr.make_abs [| Var.make "_" |] (e_empty t_ret) [t_arg] pos
Expr.make_ghost_abs [Var.make "_"] (e_empty t_ret) [t_arg] pos
| TDefault _ as ty -> e_empty (ty, pos)
| _ -> assert false
in
Expand Down Expand Up @@ -366,8 +367,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm S.expr) : 'm Ast.expr boxed =
@ [Mark.add (Expr.pos e) ("input" ^ string_of_int i)]))
(List.combine params_vars ts_in)
in
Expr.make_abs
(Array.of_list params_vars)
Expr.make_ghost_abs params_vars
(tag_with_log_entry
(tag_with_log_entry
(Expr.eapp
Expand Down Expand Up @@ -411,10 +411,11 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm S.expr) : 'm Ast.expr boxed =
in
(* let result_var = calling_expr in let result_eta_expanded_var =
result_eta_expaneded in log (if_then_else_returned ) *)
Expr.make_let_in result_var
Expr.make_let_in (Mark.ghost result_var)
(TStruct sc_sig.scope_sig_output_struct, Expr.pos e)
calling_expr
(Expr.make_let_in result_eta_expanded_var
(Expr.make_let_in
(Mark.ghost result_eta_expanded_var)
(TStruct sc_sig.scope_sig_output_struct, Expr.pos e)
result_eta_expanded
(tag_with_log_entry
Expand Down
116 changes: 64 additions & 52 deletions compiler/desugared/from_surface.ml
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ let rec translate_expr
Var.make
(match names with [] -> "zip" | _ -> String.concat "_" names)
in
Expr.make_abs [| x1; x2 |]
Expr.make_ghost_abs [x1; x2] (*?*)
(Expr.make_tuple (Expr.evar x1 m :: explode (Expr.evar x2 m)) m)
tys pos
in
Expand Down Expand Up @@ -302,7 +302,7 @@ let rec translate_expr
(fun c_uid' tau ->
if EnumConstructor.compare c_uid c_uid' <> 0 then
let nop_var = Var.make "_" in
Expr.make_abs [| nop_var |]
Expr.make_ghost_abs [nop_var]
(Expr.elit (LBool false) emark)
[tau] pos_op
else
Expand All @@ -311,7 +311,9 @@ let rec translate_expr
Ident.Map.add (Mark.remove binding) binding_var local_vars
in
let e2 = rec_helper ~local_vars e2 in
Expr.make_abs [| binding_var |] e2 [tau] pos_op)
Expr.make_abs
[Mark.add (Mark.get binding) binding_var]
e2 [tau] pos_op)
(fst (EnumName.Map.find enum_uid ctxt.enums))
in
Expr.ematch ~e:(rec_helper e1_sub) ~name:enum_uid ~cases emark
Expand Down Expand Up @@ -556,17 +558,18 @@ let rec translate_expr
in
Expr.escopecall ~scope:called_scope ~args:in_struct emark
| LetIn (xs, e1, e2) ->
let vs = List.map (fun x -> Var.make (Mark.remove x)) xs in
let m_xs : _ Var.t Mark.pos list =
List.map (fun x -> Mark.map Var.make x) xs
in
let local_vars =
List.fold_left2
(fun local_vars x v -> Ident.Map.add (Mark.remove x) v local_vars)
local_vars xs vs
(fun local_vars x v ->
Ident.Map.add (Mark.remove x) (Mark.remove v) local_vars)
local_vars xs m_xs
in
let taus = List.map (fun x -> TAny, Mark.get x) xs in
(* This type will be resolved in Scopelang.Desambiguation *)
let f =
Expr.make_abs (Array.of_list vs) (rec_helper ~local_vars e2) taus pos
in
let f = Expr.make_abs m_xs (rec_helper ~local_vars e2) taus pos in
Expr.eapp ~f ~args:[rec_helper e1] ~tys:[] emark
| StructReplace (e, fields) ->
let fields =
Expand Down Expand Up @@ -732,7 +735,7 @@ let rec translate_expr
EnumConstructor.Map.mapi
(fun c_uid' tau ->
let nop_var = Var.make "_" in
Expr.make_abs [| nop_var |]
Expr.make_ghost_abs [nop_var]
(Expr.elit (LBool (EnumConstructor.compare c_uid c_uid' = 0)) emark)
[tau] pos)
(fst (EnumName.Map.find enum_uid ctxt.enums))
Expand All @@ -747,14 +750,14 @@ let rec translate_expr
let collection =
detuplify_list opos (List.map Mark.remove param_names) collection
in
let params = List.map (fun n -> Var.make (Mark.remove n)) param_names in
let params = List.map (fun n -> Mark.map Var.make n) param_names in
let local_vars =
List.fold_left2
(fun vars n p -> Ident.Map.add (Mark.remove n) p vars)
(fun vars n p -> Ident.Map.add (Mark.remove n) (Mark.remove p) vars)
local_vars param_names params
in
let f_pred =
Expr.make_abs (Array.of_list params)
Expr.make_abs params
(rec_helper ~local_vars predicate)
(List.map (fun _ -> TAny, pos) params)
pos
Expand All @@ -770,7 +773,8 @@ let rec translate_expr
in
let x = Expr.evar v emark in
let tys = List.map (fun _ -> TAny, pos) param_names in
Expr.make_abs [| v |]
Expr.make_abs
[Mark.add Pos.no_pos v]
(Expr.make_app f_pred
(List.init nb_args (fun i ->
Expr.etupleaccess ~e:x ~index:i ~size:nb_args emark))
Expand All @@ -791,22 +795,21 @@ let rec translate_expr
let collection =
detuplify_list opos (List.map Mark.remove param_names) collection
in
let accs = List.map (fun n -> Var.make (Mark.remove n)) acc_names in
let params = List.map (fun n -> Var.make (Mark.remove n)) param_names in
let accs = List.map (fun n -> Mark.map Var.make n) acc_names in
let params = List.map (fun n -> Mark.map Var.make n) param_names in
let init = rec_helper ~local_vars init in
let local_vars =
List.fold_left2
(fun vars n p -> Ident.Map.add (Mark.remove n) p vars)
(fun vars n p -> Ident.Map.add (Mark.remove n) (Mark.remove p) vars)
local_vars param_names params
in
let local_vars =
List.fold_left2
(fun vars n p -> Ident.Map.add (Mark.remove n) p vars)
(fun vars n p -> Ident.Map.add (Mark.remove n) (Mark.remove p) vars)
local_vars acc_names accs
in
let f_proc =
Expr.make_abs
(Array.of_list (accs @ params))
Expr.make_abs (accs @ params)
(rec_helper ~local_vars fct)
(List.map (fun _ -> TAny, pos) (accs @ params))
pos
Expand All @@ -818,18 +821,18 @@ let rec translate_expr
| nb_accs, nb_args ->
let v_acc =
match accs with
| [v] -> v
| [v] -> Mark.remove v
| _ -> Var.make (String.concat "_" (List.map Mark.remove acc_names))
in
let v_param =
match params with
| [v] -> v
| [v] -> Mark.remove v
| _ -> Var.make (String.concat "_" (List.map Mark.remove param_names))
in
let x_acc = Expr.evar v_acc emark in
let x_param = Expr.evar v_param emark in
let tys = List.init (nb_accs + nb_args) (fun _ -> TAny, pos) in
Expr.make_abs [| v_acc; v_param |]
Expr.make_ghost_abs [v_acc; v_param]
(Expr.make_app f_proc
((if nb_accs = 1 then [x_acc]
else
Expand Down Expand Up @@ -860,24 +863,23 @@ let rec translate_expr
let collection =
detuplify_list opos (List.map Mark.remove param_names) collection
in
let params = List.map (fun n -> Var.make (Mark.remove n)) param_names in
let params = List.map (fun n -> Mark.map Var.make n) param_names in
let local_vars =
List.fold_left2
(fun vars n p -> Ident.Map.add (Mark.remove n) p vars)
(fun vars n p -> Ident.Map.add (Mark.remove n) (Mark.remove p) vars)
local_vars param_names params
in
let cmp_op = if max then Op.Gt, opos else Op.Lt, opos in
let f_pred =
Expr.make_abs (Array.of_list params)
(rec_helper ~local_vars predicate)
[TAny, pos]
pos
Expr.make_abs params (rec_helper ~local_vars predicate) [TAny, pos] pos
in
let add_weight_f =
let vs = List.map (fun p -> Var.make (Bindlib.name_of p)) params in
let vs =
List.map (fun p -> Var.make (Bindlib.name_of (Mark.remove p))) params
in
let xs = List.map (fun v -> Expr.evar v emark) vs in
let x = match xs with [x] -> x | xs -> Expr.etuple xs emark in
Expr.make_abs (Array.of_list vs)
Expr.make_ghost_abs vs
(Expr.make_tuple [x; Expr.eapp ~f:f_pred ~args:xs ~tys:[] emark] emark)
[TAny, pos]
pos
Expand All @@ -886,7 +888,7 @@ let rec translate_expr
(* fun x1 x2 -> if cmp_op (x1.2) (x2.2) cmp *)
let v1, v2 = Var.make "x1", Var.make "x2" in
let x1, x2 = Expr.make_var v1 emark, Expr.make_var v2 emark in
Expr.make_abs [| v1; v2 |]
Expr.make_ghost_abs [v1; v2]
(Expr.eifthenelse
(Expr.eappop ~op:cmp_op
~tys:[TAny, pos_dft; TAny, pos_dft]
Expand All @@ -903,7 +905,7 @@ let rec translate_expr
let weights_var = Var.make "weights" in
let default = Expr.make_app add_weight_f [default] [TAny, pos] pos_dft in
let weighted_result =
Expr.make_let_in weights_var
Expr.make_let_in (Mark.ghost weights_var)
(TArray (TTuple [TAny, pos; TAny, pos], pos), pos)
(Expr.eappop ~op:(Map, opos)
~tys:[TAny, pos; TArray (TAny, pos), pos]
Expand All @@ -929,23 +931,25 @@ let rec translate_expr
in
let init = Expr.elit (LBool init) emark in
let params0, predicate = predicate in
let params = List.map (fun n -> Var.make (Mark.remove n)) params0 in
let params = List.map (fun n -> Mark.map Var.make n) params0 in
let local_vars =
List.fold_left2
(fun vars n p -> Ident.Map.add (Mark.remove n) p vars)
(fun vars n p -> Ident.Map.add (Mark.remove n) (Mark.remove p) vars)
local_vars params0 params
in
let f =
let acc_var = Var.make "acc" in
let acc =
Expr.make_var acc_var (Untyped { pos = Mark.get (List.hd params0) })
in
Expr.eabs
(Expr.bind
(Array.of_list (acc_var :: params))
(translate_binop op pos acc (rec_helper ~local_vars predicate)))
[TAny, pos; TAny, pos]
emark
let vs = Mark.ghost acc_var :: params in
let vs_marks = List.map Mark.get vs in
let mvars =
Expr.bind
(Array.of_list (List.map Mark.remove vs))
(translate_binop op pos acc (rec_helper ~local_vars predicate))
in
Expr.eabs mvars vs_marks [TAny, pos; TAny, pos] emark
in
Expr.eappop ~op:(Fold, opos)
~tys:[TAny, pos; TAny, pos; TAny, pos]
Expand All @@ -960,7 +964,7 @@ let rec translate_expr
let v1, v2 = Var.make (vname ^ "1"), Var.make (vname ^ "2") in
let x1 = Expr.make_var v1 emark in
let x2 = Expr.make_var v2 emark in
Expr.make_abs [| v1; v2 |]
Expr.make_ghost_abs [v1; v2]
(Expr.eifthenelse (translate_binop (op, pos) pos x1 x2) x1 x2 emark)
[TAny, pos; TAny, pos]
pos
Expand Down Expand Up @@ -990,7 +994,7 @@ let rec translate_expr
let v1, v2 = Var.make "sum1", Var.make "sum2" in
let x1 = Expr.make_var v1 emark in
let x2 = Expr.make_var v2 emark in
Expr.make_abs [| v1; v2 |]
Expr.make_ghost_abs [v1; v2]
(translate_binop (S.Add KPoly, opos) pos x1 x2)
[TAny, pos; TAny, pos]
pos
Expand Down Expand Up @@ -1019,9 +1023,11 @@ let rec translate_expr
]
emark
in
let vars = [Mark.ghost acc_var; Mark.add opos param_var] in
let f =
Expr.eabs
(Expr.bind [| acc_var; param_var |] f_body)
(Expr.bind (Array.of_list (List.map Mark.remove vars)) f_body)
(List.map Mark.get vars)
[TLit TBool, pos; TAny, pos]
emark
in
Expand All @@ -1047,8 +1053,9 @@ and disambiguate_match_and_build_expression
(e_uid : EnumName.t)
(ctxt : Name_resolution.context)
case_body
e_binder =
Expr.eabs e_binder
e_binder
pos_binder =
Expr.eabs e_binder pos_binder
[
EnumConstructor.Map.find c_uid
(fst (EnumName.Map.find e_uid ctxt.Name_resolution.enums));
Expand Down Expand Up @@ -1091,7 +1098,14 @@ and disambiguate_match_and_build_expression
case.S.match_case_expr
in
let e_binder = Expr.bind [| param_var |] case_body in
let case_expr = bind_case_body c_uid e_uid ctxt case_body e_binder in
let pos_binder =
match binding with
| None -> [Pos.no_pos]
| Some binding -> [Mark.get binding]
in
let case_expr =
bind_case_body c_uid e_uid ctxt case_body e_binder pos_binder
in
( EnumConstructor.Map.add c_uid case_expr cases_d,
Some e_uid,
curr_index + 1 )
Expand Down Expand Up @@ -1147,12 +1161,12 @@ and disambiguate_match_and_build_expression
match_case_expr
in
let e_binder = Expr.bind [| payload_var |] case_body in

let pos_binder = [Pos.no_pos] in
(* For each missing cases, binds the wildcard payload. *)
EnumConstructor.Map.fold
(fun c_uid _ (cases_d, e_uid_opt, curr_index) ->
let case_expr =
bind_case_body c_uid e_uid ctxt case_body e_binder
bind_case_body c_uid e_uid ctxt case_body e_binder pos_binder
in
( EnumConstructor.Map.add c_uid case_expr cases_d,
e_uid_opt,
Expand Down Expand Up @@ -1568,9 +1582,7 @@ let process_topdef
| _ -> ()
in
let e =
Expr.make_abs
(Array.of_list (List.map Mark.remove args))
body
Expr.make_abs args body
(List.map translate_tbase tys)
(Mark.get def.S.topdef_name)
in
Expand Down
Loading

0 comments on commit aeb15cd

Please sign in to comment.