(* mpri-funcprog-project/src/Defun.ml

   109 lines
   3.8 KiB
   OCaml *)

(* The source calculus of this pass: terms may allocate closures ([S.Lam],
   inside [S.LetBlo]) and call them ([S.TailCall]). *)
module S = Tail
(* The target calculus: programs made of function declarations ([T.Fun]) and
   a main term, in which closures are represented as tagged blocks ([T.Con])
   and dispatched on with [T.Swi]. *)
module T = Top
(* Finite maps indexed by arity. The comparison is annotated at [int] so the
   compiler emits the specialized integer comparison instead of calling the
   generic polymorphic [compare] on every lookup. *)
module IMap = Map.Make (struct
    type t = int
    let compare (a : int) (b : int) = compare a b
  end)
exception InconsistentFuncState of int
(** Raised by [defun_term] when the two parts of a [funcState] (the [apply]
    function allocated for each arity, and the switch branches collected for
    each arity) disagree about some arity. The parameter is the arity at which
    the two parts diverge. *)
(* [freshTag ()] yields a new integer tag on every call: 1, then 2, then 3,
   and so on. The counter is private to this closure and never reset, so each
   closure block allocated by the pass receives a distinct tag. *)
let freshTag =
  let last = ref 0 in
  fun () ->
    last := !last + 1;
    !last
(** Function state handling *)

(** An [apply] function: its name and its formal parameters. The first
    parameter receives the closure being called (it is the value the
    generated switch inspects); the remaining ones receive the arguments of
    the call. *)
type applyFunc = ApplyFunc of T.variable * T.variable list

(** For each arity, the switch branches collected so far (one per closure of
    that arity). Despite its name, this is a map from arity to branch lists.
    Arities count the closure parameter, i.e. they are one more than the
    number of arguments. *)
type funcArityList = (T.branch list) IMap.t

(** For each arity, the [apply] function allocated for it. *)
type applyArityMap = applyFunc IMap.t

(** The state threaded through the AST walk: the [apply] functions created so
    far, and the branches they will dispatch to, both indexed by arity. *)
type funcState = FuncState of applyArityMap * funcArityList

(** The initial state: no [apply] function and no branch. *)
let empty_fs = FuncState(IMap.empty, IMap.empty)
(** [get_apply fs arity] returns the [apply] function of the given [arity]
    (closure parameter included) together with the resulting state.

    If [fs] already holds one, it is returned along with [fs] unchanged.
    Otherwise a new one is built from fresh atoms, using these name hints:
    - ["apply<arity>_"] for the function;
    - ["fct_"] for its first (closure) parameter;
    - ["arg<i>_"] for parameter [i].

    The new function is then recorded in the returned state. *)
let get_apply (FuncState (applyMap, funcList) as fs) arity =
  match IMap.find arity applyMap with
  | existing -> existing, fs
  | exception Not_found ->
    let fname = Atom.fresh ("apply" ^ (string_of_int arity) ^ "_") in
    let param_for i =
      if i = 0 then Atom.fresh "fct_"
      else Atom.fresh ("arg" ^ (string_of_int i) ^ "_") in
    (* Parameters are requested from the last one down to [fct_]; the order
       of the [Atom.fresh] calls is kept stable on purpose. *)
    let rec collect remaining params =
      if remaining = 0 then params
      else collect (remaining - 1) (param_for (remaining - 1) :: params) in
    let created = ApplyFunc (fname, collect arity []) in
    created, FuncState (IMap.add arity created applyMap, funcList)
(** [add_func fs arity fct] records the branch [fct] for the [apply] function
    of the given [arity], in front of the branches already collected for that
    arity. The [apply] functions of [fs] are left untouched. *)
let add_func (FuncState (applyMap, funcList)) arity fct =
  let existing =
    match IMap.find arity funcList with
    | branches -> branches
    | exception Not_found -> []
  in
  let updated = IMap.add arity (fct :: existing) funcList in
  FuncState (applyMap, updated)
(** AST walking *)

(** [walk_term fs t] translates the Tail term [t] into a Top term, threading
    the function state [fs] through the traversal. It returns the updated
    state and the translated term.

    - [S.TailCall (f, args)] becomes a call to the [apply] function of arity
      [List.length args + 1], with the closure [f] passed first.
    - [S.LetBlo (x, S.Lam (self, xs, body), next)] allocates a block
      [T.Con (tag, fvs)], where [tag] is fresh and [fvs] are the free variables
      of the function. The translated body is recorded in the state as a
      branch [T.Branch (tag, fvs, body')] of the [apply] function of arity
      [List.length xs + 1].
    - Recursive functions ([S.Self f]): inside the branch, [f] is bound to
      the closure being invoked, i.e. the first parameter of [apply]. *)
let rec walk_term fs t =
  match t with
  | S.Exit ->
    fs, T.Exit
  | S.TailCall (func, args) ->
    let ApplyFunc (applyFct, _), fs =
      get_apply fs (List.length args + 1) in
    fs, T.TailCall (applyFct, func :: args)
  | S.Print (v, next) ->
    let fs, nNext = walk_term fs next in
    fs, T.Print (v, nNext)
  | S.LetVal (var, value, next) ->
    let fs, nNext = walk_term fs next in
    fs, T.LetVal (var, value, nNext)
  | S.LetBlo (var, S.Lam (self, vars, body), next) ->
    let fs, nNext = walk_term fs next in
    let fs, nBody = walk_term fs body in
    let arity = List.length vars + 1 in
    let ApplyFunc (_, args), fs = get_apply fs arity in
    (* [args] is [closure :: actuals]: the closure being invoked, then the
       arguments of the call. *)
    let closure = List.hd args and actuals = List.tl args in
    (* Bind each formal parameter to the matching parameter of [apply]. *)
    let nBody = List.fold_left2 (fun prevBody formal actual ->
        T.LetVal (formal, actual, prevBody))
        nBody vars (S.vvars actuals) in
    (* A recursive function refers to itself through its self name. That
       name is not bound where the block is allocated, so it must not be
       captured as a free variable. Inside the branch, the closure being
       invoked *is* the function, so bind the name to [apply]'s closure
       parameter. *)
    let nBody, locallyBound =
      match self with
      | S.Self f -> T.LetVal (f, T.vvar closure, nBody), f :: vars
      | S.NoSelf -> nBody, vars in
    let thisTag = freshTag () in
    let freeVars = Atom.Set.elements @@
      Atom.Set.diff (S.fv_term body) (Atom.Set.of_list locallyBound) in
    let thisFunc = T.Branch (thisTag, freeVars, nBody) in
    let fs = add_func fs arity thisFunc in
    fs, T.LetBlo (var, T.Con (thisTag, T.vvars freeVars), nNext)
(** [apply_of_arity name args branches] builds the [T.Fun] declaration of an
    [apply] function named [name] with parameters [args]. Its arity is
    [List.length args], so it handles functions whose original arity is one
    less: the first parameter receives the closure. The body switches on that
    closure and dispatches to [branches]. [args] must be non-empty. *)
let apply_of_arity name args branches =
  let scrutinee = args |> List.hd |> T.vvar in
  T.Fun (name, args, T.Swi (scrutinee, branches))
(** [defun_term t] defunctionalizes the Tail term [t] into a Top program.

    The main term is [t] translated by [walk_term]: closures become tagged
    blocks, and calls become calls to [apply] functions. The declared
    functions are those [apply] functions, one per arity encountered. Each
    one switches on the tag of its closure parameter.

    Calls also request [apply] functions, so an arity can have an [apply]
    function but no branch. For example, a function that is defined but never
    called may still call one of its parameters at an arity for which no
    closure is ever built. Such an [apply] function gets an empty switch,
    since no closure of that arity can ever reach it.

    @raise InconsistentFuncState if branches were collected for an arity
    that has no [apply] function. *)
let defun_term (t : S.term) : T.program =
  let FuncState (applyOfArity, funcOfArity), mainTerm =
    walk_term empty_fs t in
  (* Every collected branch must be reachable through some [apply]. *)
  IMap.iter (fun arity _ ->
      if not (IMap.mem arity applyOfArity) then
        raise (InconsistentFuncState arity))
    funcOfArity;
  let applyFuncs = IMap.fold
      (fun arity (ApplyFunc (name, args)) accu ->
         let branches =
           match IMap.find arity funcOfArity with
           | branches -> branches
           | exception Not_found -> [] in
         apply_of_arity name args branches :: accu)
      applyOfArity [] in
  T.Prog (applyFuncs, mainTerm)