diff --git a/middle_end/flambda/from_lambda/closure_conversion.ml b/middle_end/flambda/from_lambda/closure_conversion.ml index f667fccb3090..539cc88ae7dc 100644 --- a/middle_end/flambda/from_lambda/closure_conversion.ml +++ b/middle_end/flambda/from_lambda/closure_conversion.ml @@ -547,17 +547,33 @@ let rec close acc env (ilam : Ilambda.t) : Acc.t * Expr_with_acc.t = in let callee = find_simple_from_id env func in let acc, args = find_simples acc env args in - let apply = - Apply.create ~callee - ~continuation:(Return continuation) - exn_continuation - ~args - ~call_kind - (Debuginfo.from_location loc) - ~inline:(LC.inline_attribute inlined) - ~inlining_state:(Inlining_state.default) + let call_is_tail = + (* No need to check for the exception continuation. If the exception + continuation is not the tail one, the return continuation cannot + be the tail one either: there must be a pop continuation. *) + Continuation.equal continuation (Env.return_continuation env) in - Expr_with_acc.create_apply acc apply + begin match Env.current_function env with + | Some { let_rec_ident; arity; self_tail_call_continuation } when + call_is_tail && + Ident.same func let_rec_ident && + List.length args = arity -> + Apply_cont.create self_tail_call_continuation ~args + ~dbg:(Debuginfo.from_location loc) + |> Expr_with_acc.create_apply_cont acc + | _ -> + let apply = + Apply.create ~callee + ~continuation:(Return continuation) + exn_continuation + ~args + ~call_kind + (Debuginfo.from_location loc) + ~inline:(LC.inline_attribute inlined) + ~inlining_state:(Inlining_state.default) + in + Expr_with_acc.create_apply acc apply + end | Apply_cont (cont, trap_action, args) -> let acc, args = find_simples acc env args in let trap_action = close_trap_action_opt trap_action in @@ -881,8 +897,26 @@ and close_one_function acc ~external_env ~by_closure_id decl (Variable.Map.empty, Ident.Map.empty) (Function_decls.to_list function_declarations) in + let self_tail_call_continuation = + Continuation.create ~name:"self_tail_call" () + in let closure_env_without_parameters = - let empty_env = Env.clear_local_bindings external_env in + let arity = + match Function_decl.kind decl with + | Curried -> List.length params + | Tupled -> 1 + in + let current_function_info : Env.function_being_defined = + { let_rec_ident = our_let_rec_ident; + arity; + self_tail_call_continuation; + } + in + let empty_env = + Env.clear_local_bindings external_env + ~return_continuation:(Function_decl.return_continuation decl) + current_function_info + in Env.add_var_map (Env.add_var_map empty_env var_within_closures_for_idents) vars_for_project_closure in @@ -977,6 +1011,122 @@ and close_one_function acc ~external_env ~by_closure_id decl close_exn_continuation acc external_env (Function_decl.exn_continuation decl) in + let acc, body = + let self_tail_call_is_used = + (* CR pchambart: can we avoid computing free names here ? *) + Name_occurrences.mem_continuation (Flambda.Expr.free_names body) + self_tail_call_continuation + in + if not self_tail_call_is_used then + acc, body + else + let acc, rec_cont_handler, bound_continuation = + match Function_decl.kind decl with + | Curried -> + acc, body, self_tail_call_continuation + | Tupled -> + (* If the function is tupled a recursive call must first unbox the + argument. The function body is replaced by: + + let rec cont self_tail_cal_tupled a b = + let cont self_tail_call tupled_param = + let a = block_load 0 tupled_param in + let b = block_load 1 tupled_param in + apply_cont self_tail_cal_tupled a b + in + function body using 'self_tail_call' + in + apply_cont self_tail_cal_tupled a b + + The tuple is expected to be unboxed and the intermediate continuation + to be eliminated by simplification + *) + let self_tail_call_tupled_continuation = + Continuation.create ~name:"self_tail_call_tupled" () + in + let tupled_var = Variable.create "tupled_param" in + let tupled_param = + Kinded_parameter.create tupled_var Flambda_kind.With_subkind.any_value + in + let block_access : P.Block_access_kind.t = + Values { + tag = Tag.Scannable.zero; + size = Known (Targetint.OCaml.of_int (List.length params)); + field_kind = Any_value; + } + in + let unbox_arg ~pos ~param ~body ~acc = + let var = VB.create param Name_mode.normal in + let pos = Target_imm.int (Targetint.OCaml.of_int pos) in + let named = + Named.create_prim + (Binary ( + Block_load (block_access, Immutable), + Simple.var tupled_var, + Simple.const (Reg_width_const.tagged_immediate pos))) + Debuginfo.none + in + Let_with_acc.create acc + (Bindable_let_bound.singleton var) named + ~body ~free_names_of_body:Unknown + |> Expr_with_acc.create_let + in + let args = List.map Simple.var (Kinded_parameter.List.vars params) in + let cost_metrics_of_handler, acc, tupled_handler = + Acc.measure_cost_metrics acc ~f:(fun acc -> + let acc, tupled_call = + Apply_cont.create self_tail_call_tupled_continuation + ~args + ~dbg:Debuginfo.none + |> Expr_with_acc.create_apply_cont acc + in + let acc, tupled_handler = + List.fold_right (fun (pos, param) (acc, body) -> + unbox_arg ~pos ~param ~body ~acc) + (List.mapi (fun i p -> i, p) + (Kinded_parameter.List.vars params)) + (acc, tupled_call) + in + let acc, tupled_handler = + Continuation_handler_with_acc.create acc + [tupled_param] + ~handler:tupled_handler + ~free_names_of_handler:Unknown + ~is_exn_handler:false + in + acc, tupled_handler) + in + let acc, body' = + Let_cont_with_acc.create_non_recursive acc + self_tail_call_continuation + tupled_handler + ~body + ~free_names_of_body:Unknown + ~cost_metrics_of_handler + in + acc, body', self_tail_call_tupled_continuation + in + let args = + List.map Simple.var (Kinded_parameter.List.vars params) + in + let acc, handler = + Continuation_handler_with_acc.create acc params + ~handler:rec_cont_handler + ~free_names_of_handler:Unknown + ~is_exn_handler:false + in + let handlers = + Continuation.Map.singleton bound_continuation handler + in + let cost_metrics_of_handlers, acc, continuation_body = + Acc.measure_cost_metrics acc ~f:(fun acc -> + Apply_cont.create bound_continuation + ~args ~dbg:Debuginfo.none + |> Expr_with_acc.create_apply_cont acc) + in + Let_cont_with_acc.create_recursive acc handlers ~body:continuation_body + ~cost_metrics_of_handlers + in let inline : Inline_attribute.t = (* We make a decision based on [fallback_inlining_heuristic] here to try to mimic Closure's behaviour as closely as possible, particularly @@ -1027,7 +1177,7 @@ and close_one_function acc ~external_env ~by_closure_id decl let ilambda_to_flambda ~backend ~module_ident ~module_block_size_in_words (ilam : Ilambda.program) = let module Backend = (val backend : Flambda_backend_intf.S) in - let env = Env.empty ~backend in + let env = Env.empty ~backend ~return_continuation:ilam.return_continuation in let module_symbol = Backend.symbol_for_global' ( Ident.create_persistent (Ident.name module_ident)) diff --git a/middle_end/flambda/from_lambda/closure_conversion_aux.ml b/middle_end/flambda/from_lambda/closure_conversion_aux.ml index f391f28734de..97152f8122bc 100644 --- a/middle_end/flambda/from_lambda/closure_conversion_aux.ml +++ b/middle_end/flambda/from_lambda/closure_conversion_aux.ml @@ -17,6 +17,13 @@ [@@@ocaml.warning "+a-4-30-40-41-42"] module Env = struct + + type function_being_defined = { + let_rec_ident : Ident.t; + arity : int; + self_tail_call_continuation : Continuation.t; + } + type t = { variables : Variable.t Ident.Map.t; globals : Symbol.t Numbers.Int.Map.t; @@ -24,13 +31,15 @@ module Env = struct backend : (module Flambda_backend_intf.S); current_unit_id : Ident.t; symbol_for_global' : (Ident.t -> Symbol.t); + return_continuation : Continuation.t; + current_function : function_being_defined option; } let backend t = t.backend let current_unit_id t = t.current_unit_id let symbol_for_global' t = t.symbol_for_global' - let empty ~backend = + let empty ~backend ~return_continuation = let module Backend = (val backend : Flambda_backend_intf.S) in let compilation_unit = Compilation_unit.get_current_exn () in { variables = Ident.Map.empty; @@ -39,17 +48,24 @@ module Env = struct backend; current_unit_id = Compilation_unit.get_persistent_ident compilation_unit; symbol_for_global' = Backend.symbol_for_global'; + return_continuation; + current_function = None; } let clear_local_bindings { variables = _; globals; simples_to_substitute = _; backend; - current_unit_id; symbol_for_global'; } = + current_unit_id; symbol_for_global'; return_continuation = _; + current_function = _; } + ~return_continuation + current_function_info = { variables = Ident.Map.empty; globals; simples_to_substitute = Ident.Map.empty; backend; current_unit_id; symbol_for_global'; + return_continuation; + current_function = Some current_function_info; } let add_var t id var = { t with variables = Ident.Map.add id var t.variables } @@ -117,6 +133,10 @@ module Env = struct let find_simple_to_substitute_exn t id = Ident.Map.find id t.simples_to_substitute + + let return_continuation t = t.return_continuation + + let current_function t = t.current_function end module Acc = struct diff --git a/middle_end/flambda/from_lambda/closure_conversion_aux.mli b/middle_end/flambda/from_lambda/closure_conversion_aux.mli index 4ec42bd0dbfe..529d98022b20 100644 --- a/middle_end/flambda/from_lambda/closure_conversion_aux.mli +++ b/middle_end/flambda/from_lambda/closure_conversion_aux.mli @@ -24,9 +24,22 @@ module Env : sig type t - val empty : backend:(module Flambda_backend_intf.S) -> t - - val clear_local_bindings : t -> t + val empty + : backend:(module Flambda_backend_intf.S) + -> return_continuation:Continuation.t + -> t + + type function_being_defined = { + let_rec_ident : Ident.t; + arity : int; + self_tail_call_continuation : Continuation.t; + } + + val clear_local_bindings + : t + -> return_continuation:Continuation.t + -> function_being_defined + -> t val add_var : t -> Ident.t -> Variable.t -> t val add_vars : t -> Ident.t list -> Variable.t list -> t @@ -57,6 +70,10 @@ module Env : sig val backend : t -> (module Flambda_backend_intf.S) val current_unit_id : t -> Ident.t val symbol_for_global' : t -> (Ident.t -> Symbol.t) + + val return_continuation : t -> Continuation.t + + val current_function : t -> function_being_defined option end (** Used to pipe some data through closure conversion *) diff --git a/middle_end/flambda/naming/name_occurrences.ml b/middle_end/flambda/naming/name_occurrences.ml index e736e9416b6b..5cf49311487f 100644 --- a/middle_end/flambda/naming/name_occurrences.ml +++ b/middle_end/flambda/naming/name_occurrences.ml @@ -1041,6 +1041,8 @@ let mem_newer_version_of_code_id t code_id = For_code_ids.mem t.newer_version_of_code_ids code_id let mem_closure_var t closure_var = For_closure_vars.mem t.closure_vars closure_var +let mem_continuation t continuation = + For_continuations.mem t.continuations continuation let remove_var t var = if For_names.is_empty t.names then t diff --git a/middle_end/flambda/naming/name_occurrences.mli b/middle_end/flambda/naming/name_occurrences.mli index 3b37cdadecd5..5281ee3b5fa4 100644 --- a/middle_end/flambda/naming/name_occurrences.mli +++ b/middle_end/flambda/naming/name_occurrences.mli @@ -153,6 +153,8 @@ val mem_newer_version_of_code_id : t -> Code_id.t -> bool val mem_closure_var : t -> Var_within_closure.t -> bool +val mem_continuation : t -> Continuation.t -> bool + val remove_var : t -> Variable.t -> t val remove_code_id_or_symbol : t -> Code_id_or_symbol.t -> t diff --git a/middle_end/flambda/simplify/simplify_apply_cont_expr.ml b/middle_end/flambda/simplify/simplify_apply_cont_expr.ml index e72eb906fbd1..00eea4234fe3 100644 --- a/middle_end/flambda/simplify/simplify_apply_cont_expr.ml +++ b/middle_end/flambda/simplify/simplify_apply_cont_expr.ml @@ -141,6 +141,35 @@ let rebuild_apply_cont apply_cont ~args ~rewrite_id uacc ~after_rebuild = Cost_metrics.from_size (Code_size.apply_cont apply_cont), Apply_cont.free_names apply_cont) +let apply_cont_use_kind apply_cont : Continuation_use_kind.t = + (* CR mshinwell: Is [Continuation.sort] reliable enough to detect + the toplevel continuation? Probably not -- we should store it in + the environment. *) + match Continuation.sort (AC.continuation apply_cont) with + | Normal_or_exn -> + begin match Apply_cont.trap_action apply_cont with + | None -> Inlinable + | Some (Push _) -> Non_inlinable { escaping = false; } + | Some (Pop { raise_kind; _ }) -> + match raise_kind with + | None | Some Regular | Some Reraise -> + (* Until such time as we can manually add to the backtrace buffer, + we only convert "raise_notrace" into jumps, except if debugging + information generation is disabled. (This matches the handling + at Cmm level; see [Cmm_helpers.raise_prim].) + We set [escaping = true] for the cases we do not want to + convert into jumps. *) + if !Clflags.debug then Non_inlinable { escaping = true; } + else Non_inlinable { escaping = false; } + | Some No_trace -> + Non_inlinable { escaping = false; } + end + | Return | Toplevel_return -> + Non_inlinable { escaping = false; } + | Define_root_symbol -> + assert (Option.is_none (Apply_cont.trap_action apply_cont)); + Inlinable + let simplify_apply_cont dacc apply_cont ~down_to_up = let { S. simples = args; simple_tys = arg_types; } = S.simplify_simples dacc (AC.args apply_cont) @@ -152,35 +181,7 @@ let simplify_apply_cont dacc apply_cont ~down_to_up = (List.map Simple.free_names args) data_flow) in - let use_kind : Continuation_use_kind.t = - (* CR mshinwell: Is [Continuation.sort] reliable enough to detect - the toplevel continuation? Probably not -- we should store it in - the environment. *) - match Continuation.sort (AC.continuation apply_cont) with - | Normal_or_exn -> - begin match Apply_cont.trap_action apply_cont with - | None -> Inlinable - | Some (Push _) -> Non_inlinable { escaping = false; } - | Some (Pop { raise_kind; _ }) -> - match raise_kind with - | None | Some Regular | Some Reraise -> - (* Until such time as we can manually add to the backtrace buffer, - we only convert "raise_notrace" into jumps, except if debugging - information generation is disabled. (This matches the handling - at Cmm level; see [Cmm_helpers.raise_prim].) - We set [escaping = true] for the cases we do not want to - convert into jumps. *) - if !Clflags.debug then Non_inlinable { escaping = true; } - else Non_inlinable { escaping = false; } - | Some No_trace -> - Non_inlinable { escaping = false; } - end - | Return | Toplevel_return -> - Non_inlinable { escaping = false; } - | Define_root_symbol -> - assert (Option.is_none (Apply_cont.trap_action apply_cont)); - Inlinable - in + let use_kind = apply_cont_use_kind apply_cont in let dacc, rewrite_id = DA.record_continuation_use dacc (AC.continuation apply_cont) use_kind ~env_at_use:(DA.denv dacc) ~arg_types diff --git a/middle_end/flambda/simplify/simplify_apply_cont_expr.mli b/middle_end/flambda/simplify/simplify_apply_cont_expr.mli index 95cbabfb34f5..d44c1e9be1da 100644 --- a/middle_end/flambda/simplify/simplify_apply_cont_expr.mli +++ b/middle_end/flambda/simplify/simplify_apply_cont_expr.mli @@ -16,4 +16,6 @@ [@@@ocaml.warning "+a-30-40-41-42"] +val apply_cont_use_kind : Flambda.Apply_cont.t -> Continuation_use_kind.t + val simplify_apply_cont : Flambda.Apply_cont.t Simplify_common.expr_simplifier diff --git a/middle_end/flambda/simplify/simplify_switch_expr.ml b/middle_end/flambda/simplify/simplify_switch_expr.ml index ecd2d2a54378..22afc0f2c4ae 100644 --- a/middle_end/flambda/simplify/simplify_switch_expr.ml +++ b/middle_end/flambda/simplify/simplify_switch_expr.ml @@ -279,12 +279,13 @@ let simplify_switch ~simplify_let dacc switch ~down_to_up = TE.add_env_extension typing_env_at_use env_extension |> DE.with_typing_env (DA.denv dacc) in + let use_kind = Simplify_apply_cont_expr.apply_cont_use_kind action in let args = AC.args action in match args with | [] -> let dacc, rewrite_id = DA.record_continuation_use dacc (AC.continuation action) - (Non_inlinable { escaping = false; }) ~env_at_use ~arg_types:[] + use_kind ~env_at_use ~arg_types:[] in let dacc = DA.map_data_flow dacc ~f:( @@ -300,7 +301,7 @@ let simplify_switch ~simplify_let dacc switch ~down_to_up = in let dacc, rewrite_id = DA.record_continuation_use dacc (AC.continuation action) - (Non_inlinable { escaping = false; }) ~env_at_use ~arg_types + use_kind ~env_at_use ~arg_types in let arity = List.map T.kind arg_types in let action = Apply_cont.update_args action ~args in