(* mpri-funcprog-project/src/CPS.ml (OCaml; originally 128 lines, 4.6 KiB) *)

(* The source calculus. *)
module S = Lambda
(* The target calculus. *)
module T = Tail

(** Raised by [cps_value] when given a term it cannot translate into a
    [T.value]. Only variables, literals, and binary operations whose operands
    themselves qualify are accepted; [Let], [Lam], [App], [Print] and [IfZero]
    are rejected. *)
exception NotValue of S.term

(** Raised by [light_term] when its no-continuation path (taken when
    [has_calls] returns [false]) meets an [App], [Print] or [IfZero] term.
    NOTE(review): [has_calls] returns [true] for all three, so this looks
    unreachable; it acts as an internal sanity check. *)
exception NotLightCPSable of S.term
(** Returns a new decimal string on every call: "1", then "2", then "3", ...
    The counter is private to this closure. *)
let freshId =
  let counter = ref 0 in
  fun () ->
    counter := !counter + 1;
    string_of_int !counter
(** [freshWithPrefix pre] builds the hint [pre ^ <new freshId> ^ "_"]
    (e.g. ["v_12_"]) and passes it to [Atom.fresh]. *)
let freshWithPrefix pre =
  let hint = String.concat "" [pre; freshId (); "_"] in
  Atom.fresh hint
(** [prefixHint prefix hint] is [prefix] followed by the hint when one is
    given, and just [prefix] otherwise. *)
let prefixHint prefix = function
  | None -> prefix
  | Some h -> prefix ^ h
(** Fresh atom naming a block (a lambda or a continuation); its name hint
    starts with ["bl_"], optionally followed by [hint]. *)
let freshBlockVarHinted hint = prefixHint "bl_" hint |> freshWithPrefix
(** Fresh block atom without a name hint. *)
let freshBlockVar () = freshBlockVarHinted None
(** Fresh atom naming a value variable; its name hint starts with ["v_"],
    optionally followed by [hint]. *)
let freshVarHinted hint = prefixHint "v_" hint |> freshWithPrefix
(** Fresh value-variable atom without a name hint. *)
let freshVar () = freshVarHinted None
(** [letCont name varName body next] allocates, under [name], a continuation
    block with the single parameter [varName] and body [body], then continues
    with [next]. *)
let letCont name varName body next =
  let continuation = T.Lam (T.NoSelf, [varName], body) in
  T.LetBlo (name, continuation, next)
(** [has_calls t] is [true] when evaluating [t] may perform a call, a print or
    a branch. In that case its translation needs an explicit continuation
    block (see [light_term]). Lambda bodies are not inspected. *)
let rec has_calls (t: S.term): bool = match t with
  | S.Var _ | S.Lit _ -> false
  | S.BinOp (lhs, _, rhs) ->
    (* An operand may itself contain a call, e.g. [f x + 1]. Such an
     * operation cannot be translated as a plain [T.value]. *)
    has_calls lhs || has_calls rhs
  | S.Lam _ ->
    (* A lambda itself may contain calls, but this call is not evaluated at
     * declaration time *)
    false
  | S.App _ -> true
  | S.IfZero _ -> true (* Cannot optimize that with the current languages *)
  | S.Print _ -> true (* Cannot optimize that with the current languages *)
  | S.Let (_, value, next) -> has_calls value || has_calls next
(** Translates a value term (variable, literal, or binary operation over
    values) into a [T.value].
    @raise NotValue on any other term, carrying the offending subterm. *)
let rec cps_value : S.term -> T.value = function
  | S.Var x -> T.VVar x
  | S.Lit n -> T.VLit n
  | S.BinOp (lhs, op, rhs) -> T.VBinOp (cps_value lhs, op, cps_value rhs)
  | (S.Let _ | S.Lam _ | S.App _ | S.Print _ | S.IfZero _) as nonValue ->
    raise (NotValue nonValue)
(** [cps_value_as_term t cont] translates the value term [t] and passes it
    straight to the continuation [cont] with a tail call.
    @raise NotValue if [cps_value] rejects [t]. *)
let cps_value_as_term (t: S.term) (cont: T.variable): T.term =
  let value = cps_value t in
  T.TailCall (T.vvar cont, [value])
(** [cps_term_inner t cont nameHint] translates [t] into a [Tail] term that
    passes the value of [t] to the continuation [cont]. For [IfZero], each
    branch passes its value to [cont].
    [nameHint], when present, names the block allocated for a lambda, or the
    argument variable of an application. *)
let rec cps_term_inner (t: S.term) (cont: T.variable) (nameHint: string option)
  : T.term = match t with
  | S.Var _ | S.Lit _ -> cps_value_as_term t cont
  | S.BinOp (t1, op, t2) ->
    (try cps_value_as_term t cont
     with NotValue _ ->
       (* Some operand must be evaluated first. Bind each operand to a fresh
        * variable, then pass the operation over those variables. *)
       let t1Var = freshVar ()
       and t2Var = freshVar () in
       light_term t1Var t1 None @@
       light_term t2Var t2 None @@
       T.TailCall (T.vvar cont,
                   [T.VBinOp (T.vvar t1Var, op, T.vvar t2Var)]))
  | S.Lam _ as lambda ->
    let fName = freshBlockVarHinted nameHint in
    light_term fName lambda None @@
    T.TailCall (T.vvar cont, T.vvars [fName])
  | S.App (f, x) ->
    let xVal = freshVarHinted nameHint
    and fVal = freshVar () in
    light_term xVal x None @@
    light_term fVal f None @@
    (* The callee receives its argument and [cont] as its return
     * continuation. *)
    T.TailCall (T.vvar fVal, T.vvars [xVal; cont])
  | S.Print term ->
    let termVal = freshVar () in
    light_term termVal term None @@
    T.Print (T.vvar termVal,
             T.TailCall (T.vvar cont, T.vvars [termVal]))
  | S.Let (var, value, next) ->
    light_term var value (Some (Atom.hint var)) @@
    cps_term_inner next cont None
  | S.IfZero (expr, tIf, tElse) ->
    let exprVal = freshVar () in
    light_term exprVal expr None @@
    (* Both branches return to the same continuation [cont]. *)
    T.IfZero (T.vvar exprVal,
              cps_term_inner tIf cont None,
              cps_term_inner tElse cont None)

(* [light_term varName valExpr valHint next] binds the value of [valExpr] to
 * [varName], then continues with the already-translated [Tail] term [next].
 * - If [valExpr] may perform calls ([has_calls]), [next] is wrapped in a
 *   continuation block taking [varName], and [valExpr] is translated to
 *   return into that block.
 * - Otherwise no continuation is allocated. Values are bound with
 *   [T.LetVal] and lambdas with [T.LetBlo]. *)
and light_term varName valExpr valHint next =
  if has_calls valExpr then begin
    let contName = freshBlockVar () in
    letCont contName varName next @@
    cps_term_inner valExpr contName valHint
  end else
    (match valExpr with
     (* This term has no calls: no need to CPS-transform it *)
     | S.Var _ | S.Lit _ ->
       T.LetVal (varName, cps_value valExpr, next)
     | S.BinOp (lhs, op, rhs) ->
       (try T.LetVal (varName, cps_value valExpr, next)
        with NotValue _ ->
          (* An operand is not a plain value (a lambda, a [Let], or a call
           * missed by the classification). Bind each operand first, as
           * [cps_term_inner] does. *)
          let lhsVar = freshVar ()
          and rhsVar = freshVar () in
          light_term lhsVar lhs None @@
          light_term rhsVar rhs None @@
          T.LetVal (varName,
                    T.VBinOp (T.vvar lhsVar, op, T.vvar rhsVar),
                    next))
     | S.Let (subLetVar, subLetVal, subLetNext) ->
       (* [subLetVal] may be a lambda or a nested [Let], which [cps_value]
        * rejects, so it is translated with [light_term] as well. *)
       light_term subLetVar subLetVal (Some (Atom.hint subLetVar)) @@
       light_term varName subLetNext valHint next
     | S.Lam (self, lamVar, lamBody) ->
       (* The translated lambda takes an extra parameter: its return
        * continuation. *)
       let lamCont = freshBlockVar () in
       T.LetBlo (varName,
                 T.Lam (self, [lamVar; lamCont],
                        cps_term_inner lamBody lamCont None),
                 next)
     | S.App _ | S.Print _ | S.IfZero _ ->
       raise (NotLightCPSable valExpr))
(** Entry point. Transforms a [Lambda] term into a [Tail] term, applying a
    continuation-passing-style transformation. The program's final value is
    passed to an "exit" continuation block, whose body is [T.Exit] and does
    not use that value. *)
let cps_term (t: S.term): T.term =
  let exitBlock = freshBlockVarHinted (Some "exit") in
  let program = cps_term_inner t exitBlock (Some "main_entry") in
  let ignoredResult = freshVar () in
  letCont exitBlock ignoredResult T.Exit program