Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Carry over binders positions in EAbs #734

Merged
merged 1 commit into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/catala_utils/mark.ml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type 'a pos = ('a, Pos.t) ed
let add m e = e, m
let remove (x, _) = x
let get (_, m) = m
let ghost x = x, Pos.no_pos
let set m (x, _) = x, m
let map f (x, m) = f x, m
let map_mark f (a, m) = a, f m
Expand Down
1 change: 1 addition & 0 deletions compiler/catala_utils/mark.mli
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type 'a pos = ('a, Pos.t) ed
val add : 'm -> 'a -> ('a, 'm) ed
val remove : ('a, 'm) ed -> 'a
val get : ('a, 'm) ed -> 'm
val ghost : 'a -> 'a pos
val set : 'm -> ('a, _) ed -> ('a, 'm) ed
val map : ('a -> 'b) -> ('a, 'm) ed -> ('b, 'm) ed
val map_mark : ('m1 -> 'm2) -> ('a, 'm1) ed -> ('a, 'm2) ed
Expand Down
13 changes: 7 additions & 6 deletions compiler/dcalc/from_scopelang.ml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ let merge_defaults
let m_callee = Mark.get callee in
let unboxed_callee = Expr.unbox callee in
match Mark.remove unboxed_callee with
| EAbs { binder; tys } ->
| EAbs { binder; pos; tys } ->
let vars, body = Bindlib.unmbind binder in
let m_body = Mark.get body in
let caller =
Expand All @@ -103,6 +103,7 @@ let merge_defaults
let d =
Expr.edefault ~excepts:[caller] ~just:ltrue ~cons (Mark.get cons)
in
let vars = List.map2 (fun v p -> Mark.add p v) (Array.to_list vars) pos in
Expr.make_abs vars (Expr.make_erroronempty d) tys (Expr.mark_pos m_callee)
| _ -> assert false
(* should not happen because there should always be a lambda at the
Expand Down Expand Up @@ -225,7 +226,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm S.expr) : 'm Ast.expr boxed =
let v =
match var_ctx.scope_input_typ with
| TArrow ([t_arg], t_ret) ->
Expr.make_abs [| Var.make "_" |] (e_empty t_ret) [t_arg] pos
Expr.make_ghost_abs [Var.make "_"] (e_empty t_ret) [t_arg] pos
| TDefault _ as ty -> e_empty (ty, pos)
| _ -> assert false
in
Expand Down Expand Up @@ -366,8 +367,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm S.expr) : 'm Ast.expr boxed =
@ [Mark.add (Expr.pos e) ("input" ^ string_of_int i)]))
(List.combine params_vars ts_in)
in
Expr.make_abs
(Array.of_list params_vars)
Expr.make_ghost_abs params_vars
(tag_with_log_entry
(tag_with_log_entry
(Expr.eapp
Expand Down Expand Up @@ -411,10 +411,11 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm S.expr) : 'm Ast.expr boxed =
in
(* let result_var = calling_expr in let result_eta_expanded_var =
result_eta_expaneded in log (if_then_else_returned ) *)
Expr.make_let_in result_var
Expr.make_let_in (Mark.ghost result_var)
(TStruct sc_sig.scope_sig_output_struct, Expr.pos e)
calling_expr
(Expr.make_let_in result_eta_expanded_var
(Expr.make_let_in
(Mark.ghost result_eta_expanded_var)
(TStruct sc_sig.scope_sig_output_struct, Expr.pos e)
result_eta_expanded
(tag_with_log_entry
Expand Down
116 changes: 64 additions & 52 deletions compiler/desugared/from_surface.ml
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ let rec translate_expr
Var.make
(match names with [] -> "zip" | _ -> String.concat "_" names)
in
Expr.make_abs [| x1; x2 |]
Expr.make_ghost_abs [x1; x2]
(Expr.make_tuple (Expr.evar x1 m :: explode (Expr.evar x2 m)) m)
tys pos
in
Expand Down Expand Up @@ -302,7 +302,7 @@ let rec translate_expr
(fun c_uid' tau ->
if EnumConstructor.compare c_uid c_uid' <> 0 then
let nop_var = Var.make "_" in
Expr.make_abs [| nop_var |]
Expr.make_ghost_abs [nop_var]
(Expr.elit (LBool false) emark)
[tau] pos_op
else
Expand All @@ -311,7 +311,9 @@ let rec translate_expr
Ident.Map.add (Mark.remove binding) binding_var local_vars
in
let e2 = rec_helper ~local_vars e2 in
Expr.make_abs [| binding_var |] e2 [tau] pos_op)
Expr.make_abs
[Mark.add (Mark.get binding) binding_var]
e2 [tau] pos_op)
(fst (EnumName.Map.find enum_uid ctxt.enums))
in
Expr.ematch ~e:(rec_helper e1_sub) ~name:enum_uid ~cases emark
Expand Down Expand Up @@ -556,17 +558,18 @@ let rec translate_expr
in
Expr.escopecall ~scope:called_scope ~args:in_struct emark
| LetIn (xs, e1, e2) ->
let vs = List.map (fun x -> Var.make (Mark.remove x)) xs in
let m_xs : _ Var.t Mark.pos list =
List.map (fun x -> Mark.map Var.make x) xs
in
let local_vars =
List.fold_left2
(fun local_vars x v -> Ident.Map.add (Mark.remove x) v local_vars)
local_vars xs vs
(fun local_vars x v ->
Ident.Map.add (Mark.remove x) (Mark.remove v) local_vars)
local_vars xs m_xs
in
let taus = List.map (fun x -> TAny, Mark.get x) xs in
(* This type will be resolved in Scopelang.Desambiguation *)
let f =
Expr.make_abs (Array.of_list vs) (rec_helper ~local_vars e2) taus pos
in
let f = Expr.make_abs m_xs (rec_helper ~local_vars e2) taus pos in
Expr.eapp ~f ~args:[rec_helper e1] ~tys:[] emark
| StructReplace (e, fields) ->
let fields =
Expand Down Expand Up @@ -732,7 +735,7 @@ let rec translate_expr
EnumConstructor.Map.mapi
(fun c_uid' tau ->
let nop_var = Var.make "_" in
Expr.make_abs [| nop_var |]
Expr.make_ghost_abs [nop_var]
(Expr.elit (LBool (EnumConstructor.compare c_uid c_uid' = 0)) emark)
[tau] pos)
(fst (EnumName.Map.find enum_uid ctxt.enums))
Expand All @@ -747,14 +750,14 @@ let rec translate_expr
let collection =
detuplify_list opos (List.map Mark.remove param_names) collection
in
let params = List.map (fun n -> Var.make (Mark.remove n)) param_names in
let params = List.map (fun n -> Mark.map Var.make n) param_names in
let local_vars =
List.fold_left2
(fun vars n p -> Ident.Map.add (Mark.remove n) p vars)
(fun vars n p -> Ident.Map.add (Mark.remove n) (Mark.remove p) vars)
local_vars param_names params
in
let f_pred =
Expr.make_abs (Array.of_list params)
Expr.make_abs params
(rec_helper ~local_vars predicate)
(List.map (fun _ -> TAny, pos) params)
pos
Expand All @@ -770,7 +773,8 @@ let rec translate_expr
in
let x = Expr.evar v emark in
let tys = List.map (fun _ -> TAny, pos) param_names in
Expr.make_abs [| v |]
Expr.make_abs
[Mark.add Pos.no_pos v]
(Expr.make_app f_pred
(List.init nb_args (fun i ->
Expr.etupleaccess ~e:x ~index:i ~size:nb_args emark))
Expand All @@ -791,22 +795,21 @@ let rec translate_expr
let collection =
detuplify_list opos (List.map Mark.remove param_names) collection
in
let accs = List.map (fun n -> Var.make (Mark.remove n)) acc_names in
let params = List.map (fun n -> Var.make (Mark.remove n)) param_names in
let accs = List.map (fun n -> Mark.map Var.make n) acc_names in
let params = List.map (fun n -> Mark.map Var.make n) param_names in
let init = rec_helper ~local_vars init in
let local_vars =
List.fold_left2
(fun vars n p -> Ident.Map.add (Mark.remove n) p vars)
(fun vars n p -> Ident.Map.add (Mark.remove n) (Mark.remove p) vars)
local_vars param_names params
in
let local_vars =
List.fold_left2
(fun vars n p -> Ident.Map.add (Mark.remove n) p vars)
(fun vars n p -> Ident.Map.add (Mark.remove n) (Mark.remove p) vars)
local_vars acc_names accs
in
let f_proc =
Expr.make_abs
(Array.of_list (accs @ params))
Expr.make_abs (accs @ params)
(rec_helper ~local_vars fct)
(List.map (fun _ -> TAny, pos) (accs @ params))
pos
Expand All @@ -818,18 +821,18 @@ let rec translate_expr
| nb_accs, nb_args ->
let v_acc =
match accs with
| [v] -> v
| [v] -> Mark.remove v
| _ -> Var.make (String.concat "_" (List.map Mark.remove acc_names))
in
let v_param =
match params with
| [v] -> v
| [v] -> Mark.remove v
| _ -> Var.make (String.concat "_" (List.map Mark.remove param_names))
in
let x_acc = Expr.evar v_acc emark in
let x_param = Expr.evar v_param emark in
let tys = List.init (nb_accs + nb_args) (fun _ -> TAny, pos) in
Expr.make_abs [| v_acc; v_param |]
Expr.make_ghost_abs [v_acc; v_param]
(Expr.make_app f_proc
((if nb_accs = 1 then [x_acc]
else
Expand Down Expand Up @@ -860,24 +863,23 @@ let rec translate_expr
let collection =
detuplify_list opos (List.map Mark.remove param_names) collection
in
let params = List.map (fun n -> Var.make (Mark.remove n)) param_names in
let params = List.map (fun n -> Mark.map Var.make n) param_names in
let local_vars =
List.fold_left2
(fun vars n p -> Ident.Map.add (Mark.remove n) p vars)
(fun vars n p -> Ident.Map.add (Mark.remove n) (Mark.remove p) vars)
local_vars param_names params
in
let cmp_op = if max then Op.Gt, opos else Op.Lt, opos in
let f_pred =
Expr.make_abs (Array.of_list params)
(rec_helper ~local_vars predicate)
[TAny, pos]
pos
Expr.make_abs params (rec_helper ~local_vars predicate) [TAny, pos] pos
in
let add_weight_f =
let vs = List.map (fun p -> Var.make (Bindlib.name_of p)) params in
let vs =
List.map (fun p -> Var.make (Bindlib.name_of (Mark.remove p))) params
in
let xs = List.map (fun v -> Expr.evar v emark) vs in
let x = match xs with [x] -> x | xs -> Expr.etuple xs emark in
Expr.make_abs (Array.of_list vs)
Expr.make_ghost_abs vs
(Expr.make_tuple [x; Expr.eapp ~f:f_pred ~args:xs ~tys:[] emark] emark)
[TAny, pos]
pos
Expand All @@ -886,7 +888,7 @@ let rec translate_expr
(* fun x1 x2 -> if cmp_op (x1.2) (x2.2) cmp *)
let v1, v2 = Var.make "x1", Var.make "x2" in
let x1, x2 = Expr.make_var v1 emark, Expr.make_var v2 emark in
Expr.make_abs [| v1; v2 |]
Expr.make_ghost_abs [v1; v2]
(Expr.eifthenelse
(Expr.eappop ~op:cmp_op
~tys:[TAny, pos_dft; TAny, pos_dft]
Expand All @@ -903,7 +905,7 @@ let rec translate_expr
let weights_var = Var.make "weights" in
let default = Expr.make_app add_weight_f [default] [TAny, pos] pos_dft in
let weighted_result =
Expr.make_let_in weights_var
Expr.make_let_in (Mark.ghost weights_var)
(TArray (TTuple [TAny, pos; TAny, pos], pos), pos)
(Expr.eappop ~op:(Map, opos)
~tys:[TAny, pos; TArray (TAny, pos), pos]
Expand All @@ -929,23 +931,25 @@ let rec translate_expr
in
let init = Expr.elit (LBool init) emark in
let params0, predicate = predicate in
let params = List.map (fun n -> Var.make (Mark.remove n)) params0 in
let params = List.map (fun n -> Mark.map Var.make n) params0 in
let local_vars =
List.fold_left2
(fun vars n p -> Ident.Map.add (Mark.remove n) p vars)
(fun vars n p -> Ident.Map.add (Mark.remove n) (Mark.remove p) vars)
local_vars params0 params
in
let f =
let acc_var = Var.make "acc" in
let acc =
Expr.make_var acc_var (Untyped { pos = Mark.get (List.hd params0) })
in
Expr.eabs
(Expr.bind
(Array.of_list (acc_var :: params))
(translate_binop op pos acc (rec_helper ~local_vars predicate)))
[TAny, pos; TAny, pos]
emark
let vs = Mark.ghost acc_var :: params in
let vs_marks = List.map Mark.get vs in
let mvars =
Expr.bind
(Array.of_list (List.map Mark.remove vs))
(translate_binop op pos acc (rec_helper ~local_vars predicate))
in
Expr.eabs mvars vs_marks [TAny, pos; TAny, pos] emark
in
Expr.eappop ~op:(Fold, opos)
~tys:[TAny, pos; TAny, pos; TAny, pos]
Expand All @@ -960,7 +964,7 @@ let rec translate_expr
let v1, v2 = Var.make (vname ^ "1"), Var.make (vname ^ "2") in
let x1 = Expr.make_var v1 emark in
let x2 = Expr.make_var v2 emark in
Expr.make_abs [| v1; v2 |]
Expr.make_ghost_abs [v1; v2]
(Expr.eifthenelse (translate_binop (op, pos) pos x1 x2) x1 x2 emark)
[TAny, pos; TAny, pos]
pos
Expand Down Expand Up @@ -990,7 +994,7 @@ let rec translate_expr
let v1, v2 = Var.make "sum1", Var.make "sum2" in
let x1 = Expr.make_var v1 emark in
let x2 = Expr.make_var v2 emark in
Expr.make_abs [| v1; v2 |]
Expr.make_ghost_abs [v1; v2]
(translate_binop (S.Add KPoly, opos) pos x1 x2)
[TAny, pos; TAny, pos]
pos
Expand Down Expand Up @@ -1019,9 +1023,11 @@ let rec translate_expr
]
emark
in
let vars = [Mark.ghost acc_var; Mark.add opos param_var] in
let f =
Expr.eabs
(Expr.bind [| acc_var; param_var |] f_body)
(Expr.bind (Array.of_list (List.map Mark.remove vars)) f_body)
(List.map Mark.get vars)
[TLit TBool, pos; TAny, pos]
emark
in
Expand All @@ -1047,8 +1053,9 @@ and disambiguate_match_and_build_expression
(e_uid : EnumName.t)
(ctxt : Name_resolution.context)
case_body
e_binder =
Expr.eabs e_binder
e_binder
pos_binder =
Expr.eabs e_binder pos_binder
[
EnumConstructor.Map.find c_uid
(fst (EnumName.Map.find e_uid ctxt.Name_resolution.enums));
Expand Down Expand Up @@ -1091,7 +1098,14 @@ and disambiguate_match_and_build_expression
case.S.match_case_expr
in
let e_binder = Expr.bind [| param_var |] case_body in
let case_expr = bind_case_body c_uid e_uid ctxt case_body e_binder in
let pos_binder =
match binding with
| None -> [Pos.no_pos]
| Some binding -> [Mark.get binding]
in
let case_expr =
bind_case_body c_uid e_uid ctxt case_body e_binder pos_binder
in
( EnumConstructor.Map.add c_uid case_expr cases_d,
Some e_uid,
curr_index + 1 )
Expand Down Expand Up @@ -1147,12 +1161,12 @@ and disambiguate_match_and_build_expression
match_case_expr
in
let e_binder = Expr.bind [| payload_var |] case_body in

let pos_binder = [Pos.no_pos] in
(* For each missing cases, binds the wildcard payload. *)
EnumConstructor.Map.fold
(fun c_uid _ (cases_d, e_uid_opt, curr_index) ->
let case_expr =
bind_case_body c_uid e_uid ctxt case_body e_binder
bind_case_body c_uid e_uid ctxt case_body e_binder pos_binder
in
( EnumConstructor.Map.add c_uid case_expr cases_d,
e_uid_opt,
Expand Down Expand Up @@ -1568,9 +1582,7 @@ let process_topdef
| _ -> ()
in
let e =
Expr.make_abs
(Array.of_list (List.map Mark.remove args))
body
Expr.make_abs args body
(List.map translate_tbase tys)
(Mark.get def.S.topdef_name)
in
Expand Down
Loading