Lambda lifting and closures

Anonymous functions

After last class, we can now talk about functions as values, but only when they come with names! We'd like to be able to write anonymous functions like (lambda (x) (+ x 4)). This is analogous to OCaml's (fun x -> x + 4) syntax.

Our implementation will feel a little bit like cheating at first -- it seems too easy, and we won't change our compiler or interpreter at all! Consider the following program, which implements a range function returning a list of numbers and a map over lists:

(define (range lo hi)
  (if (< lo hi) 
    (pair lo (range (add1 lo) hi))
    false))
(define (map f l)
  (if (not l) 
    l
    (pair (f (left l)) (map f (right l)))))
(print (map (lambda (x) (+ x 4)) (range 0 2)))

We can transform this program into an equivalent one without any anonymous functions.

(define (range lo hi)
  (if (< lo hi) 
    (pair lo (range (add1 lo) hi))
    false))
(define (map f l)
  (if (not l) 
    l
    (pair (f (left l)) (map f (right l)))))
(define (lambda_fun_0 x) (+ x 4))
(print (map lambda_fun_0 (range 0 2)))

We'll do this by adding a new version of our AST type, which includes lambdas, and parsing into that type; then we'll convert terms of that type into our familiar, lambda-free version.

ASTs with lambdas

ast_lam.ml

open S_exp
open Util
include Ast

(** Expression AST produced by the parser. Unlike the lambda-free [expr]
    type it is later converted to, it has a [Lambda] constructor for
    anonymous functions. *)
type expr_lam =
  | Prim0 of prim0
  | Prim1 of prim1 * expr_lam
  | Prim2 of prim2 * expr_lam * expr_lam
  | Let of string * expr_lam * expr_lam (* variable, bound expression, body *)
  | If of expr_lam * expr_lam * expr_lam (* test, then-branch, else-branch *)
  | Do of expr_lam list
  | Num of int
  | Var of string
  | Call of expr_lam * expr_lam list (* function expression, arguments *)
  | True
  | False
  | Lambda of string list * expr_lam (* parameter names, body *)

(** [is_sym e] is [true] exactly when [e] is a [Sym] node. *)
let is_sym = function Sym _ -> true | _ -> false

(** [as_sym e] returns the name carried by a [Sym] node.
    @raise Not_found if [e] is any other kind of s-expression. *)
let as_sym = function Sym name -> name | _ -> raise Not_found

(** [expr_lam_of_s_exp sexp] parses one s-expression into an [expr_lam].

    Atoms are tried first, then the special forms [let], [do], [if] and
    [lambda], then primitive applications of arity 0, 1 and 2. Any other
    non-empty list is parsed as a call, so a list that merely resembles a
    special form but does not match its shape ends up as a [Call].

    @raise BadSExpression if [sexp] matches none of the cases (for example
    an empty list). *)
let rec expr_lam_of_s_exp (sexp : s_exp) : expr_lam =
  match sexp with
  | Num n ->
      Num n
  | Sym "true" ->
      True
  | Sym "false" ->
      False
  | Sym name ->
      Var name
  | Lst [Sym "let"; Lst [Lst [Sym name; rhs]]; body] ->
      Let (name, expr_lam_of_s_exp rhs, expr_lam_of_s_exp body)
  (* [do] needs at least one sub-expression; an empty [do] falls through. *)
  | Lst (Sym "do" :: (_ :: _ as steps)) ->
      Do (List.map expr_lam_of_s_exp steps)
  | Lst [Sym "if"; cond; if_true; if_false] ->
      If
        ( expr_lam_of_s_exp cond
        , expr_lam_of_s_exp if_true
        , expr_lam_of_s_exp if_false )
  (* Only a parameter list made entirely of symbols counts as a lambda. *)
  | Lst [Sym "lambda"; Lst params; body] when List.for_all is_sym params ->
      Lambda (List.map as_sym params, expr_lam_of_s_exp body)
  | Lst [Sym op] when Option.is_some (prim0_of_string op) ->
      Prim0 (Option.get (prim0_of_string op))
  | Lst [Sym op; operand] when Option.is_some (prim1_of_string op) ->
      Prim1 (Option.get (prim1_of_string op), expr_lam_of_s_exp operand)
  | Lst [Sym op; lhs; rhs] when Option.is_some (prim2_of_string op) ->
      Prim2
        ( Option.get (prim2_of_string op)
        , expr_lam_of_s_exp lhs
        , expr_lam_of_s_exp rhs )
  | Lst (head :: actuals) ->
      Call (expr_lam_of_s_exp head, List.map expr_lam_of_s_exp actuals)
  | bad ->
      raise (BadSExpression bad)

(** [expr_of_expr_lam defns e] rewrites [e] into the lambda-free [expr] type.

    Each [Lambda] is replaced by a [Var] holding a fresh name obtained from
    [gensym "_lambda"], and a definition with that name, the lambda's
    parameters and its converted body is pushed onto the front of [defns].
    Every other constructor is rebuilt unchanged around its converted
    children. *)
let rec expr_of_expr_lam (defns : defn list ref) (e : expr_lam) : expr =
  let lower = expr_of_expr_lam defns in
  match e with
  | Lambda (args, lam_body) ->
      (* Take the name before converting the body, and register the new
         definition only after the body (and any lambdas nested in it) has
         been converted. *)
      let name = gensym "_lambda" in
      let body = lower lam_body in
      defns := {name; args; body} :: !defns ;
      Var name
  | Num n ->
      Num n
  | Var name ->
      Var name
  | True ->
      True
  | False ->
      False
  | Prim0 op ->
      Prim0 op
  | Prim1 (op, operand) ->
      Prim1 (op, lower operand)
  | Prim2 (op, lhs, rhs) ->
      Prim2 (op, lower lhs, lower rhs)
  | Let (name, rhs, body) ->
      Let (name, lower rhs, lower body)
  | If (cond, if_true, if_false) ->
      If (lower cond, lower if_true, lower if_false)
  | Do steps ->
      Do (List.map lower steps)
  | Call (callee, actuals) ->
      Call (lower callee, List.map lower actuals)

(** [program_of_s_exps exps] builds a [program] from a list of top-level
    s-expressions: every element but the last must be a
    [(define (name arg ...) body)] form, and the last one is the program
    body. Definitions created while lifting lambdas share the same list as
    the user's definitions.

    @raise BadSExpression if [exps] is empty, if a non-final element is not
    a well-formed [define], or if a parameter is not a symbol. *)
let program_of_s_exps (exps : s_exp list) : program =
  (* Accumulated in reverse; both [define]s and lifted lambdas land here. *)
  let collected = ref [] in
  let lower sexp = expr_of_expr_lam collected (expr_lam_of_s_exp sexp) in
  let param_name = function
    | Sym v ->
        v
    | bad ->
        raise (BadSExpression bad)
  in
  let defn_of_s_exp = function
    | Lst [Sym "define"; Lst (Sym name :: params); body_s] ->
        (* Parameters are checked before the body is parsed. *)
        let args = List.map param_name params in
        let body = lower body_s in
        {name; args; body}
    | bad ->
        raise (BadSExpression bad)
  in
  let rec walk = function
    | [] ->
        raise (BadSExpression (Sym "empty"))
    | [main] ->
        let body = lower main in
        {defns= List.rev !collected; body}
    | d :: rest ->
        (* Convert first, then push: lambdas lifted out of this definition
           are already on the list by the time the definition is added. *)
        let defn = defn_of_s_exp d in
        collected := defn :: !collected ;
        walk rest
  in
  walk exps

A hiccup

What should the following program evaluate to?

(define (adder x) (lambda (y) (+ x y)))
(print ((adder 2) 3))

This will have some trouble. What would things look like if we did the lambda lifting step by hand? See the problem?

We need to implement something called a closure. A "closure," somehow, is supposed to "close" the environment -- to bundle together the values of variables that are "open," i.e. unbound, in the body of a lifted lambda function.

In the interpreter

Our interpreter first runs into functions when they're represented as variables appearing in the list of function definitions. This isn't going to be good enough anymore.

We need to "undo" our cheat from before, where our AST had no particular representation for functions. We'll add one constructor to the expr datatype:

| Closure of string

and also update the Lambda case of expr_of_expr_lam.

Our lambda-lifting will end up looking something like this:

(define (_lambda_1 y) (+ x y))
(define (adder x) (closure _lambda_1))
(print ((adder 2) 3))

A function value now needs to store a value symtab.

Closures in the interpreter

ast.ml

type expr = 
(* ... *)
(* [Closure name]: a function value for the lifted definition called [name]. *)
| Closure of string

ast_lam.ml

let rec expr_of_expr_lam (defns : defn list ref) : expr_lam -> expr =
(* ... *)
| Lambda (args, body) ->
(* Push a freshly named definition for the lambda onto [defns] and leave a
   [Closure] carrying that name in the lambda's place. *)
let name = gensym "_lambda" in
defns := {name; args; body= expr_of_expr_lam defns body} :: !defns ;
Closure name

interp.ml

(** Values produced by the interpreter. *)
type value =
  | Number of int
  | Boolean of bool
  | Pair of (value * value)
  (* Name of a definition, paired with the environment saved when the
     function value was created. *)
  | Function of (string * value symtab)


(* Excerpt of the interpreter: only the cases that create and apply function
   values are shown. *)
let rec interp_exp (defns : defn list) (env : value symtab)
    (exp : expr) : value =
  (* ... *)
  (* A variable that names a definition evaluates to a function value with an
     empty saved environment. *)
  | Var var when is_defn defns var ->
      Function (var, Symtab.empty)
  (* A closure saves the environment in effect at the point where it is
     evaluated. *)
  | Closure f -> Function (f, env)
  | Call (f, args) -> (
      (* Arguments are evaluated before the function expression, both in the
         caller's environment. *)
      let vals = args |> List.map (interp_exp defns env) in
      let fv = interp_exp defns env f in
      match fv with
      | Function (name, saved_env) when is_defn defns name ->
          let defn = get_defn defns name in
          if List.length args = List.length defn.args then
            (* The body runs in an environment built from [saved_env] and the
               parameter/argument pairs, not from the caller's [env]. *)
            let fenv =
              List.combine defn.args vals |> Symtab.add_list saved_env
            in
            interp_exp defns fenv defn.body
          else raise (BadExpression exp) (* wrong number of arguments *)
      | _ ->
          (* The callee did not evaluate to a known function. *)
          raise (BadExpression exp) )

Closures in the compiler

In the interpreter implementation of closures, we're pretty much just packaging up the environment and tacking it onto our function values. How could we possibly do this in the compiler? There's no runtime representation of a symbol table.

Note that we don't need the whole symbol table. We just need the variables that are free in the body of the lambda.

Let's write a function to compute what variables appear free in an expression.

(** [fv bound exp] lists the variables that occur in [exp] and are neither
    in [bound] nor bound by an enclosing [Let] within [exp]. Occurrences are
    listed in left-to-right order and are not deduplicated. Constructors
    without a case of their own contribute no variables. *)
let rec fv (bound : string list) (exp : expr) =
  let free = fv bound in
  match exp with
  | Var name ->
      if List.mem name bound then [] else [name]
  | Let (name, rhs, body) ->
      (* [name] is in scope in the body only, not in its own right-hand
         side. *)
      List.append (free rhs) (fv (name :: bound) body)
  | If (cond, if_true, if_false) ->
      List.concat [free cond; free if_true; free if_false]
  | Do steps ->
      List.concat_map free steps
  | Call (callee, actuals) ->
      List.append (free callee) (List.concat_map free actuals)
  | Prim1 (_, operand) ->
      free operand
  | Prim2 (_, lhs, rhs) ->
      List.append (free lhs) (free rhs)
  | _ ->
      []
We know, at compile time, which variables appear free in the body of any function definition. So we'll change our function representation to include both a pointer to a function definition and values for these free variables.

Consider the following program:

(print ((let ((x 2)) (lambda (y) (+ x y))) 3))

We're going to store our representation of the lambda expression on the heap at runtime. The first cell on the heap will contain a pointer to the function code, like we talked about on Monday. The next cell will contain the runtime value of x at the point when the lambda expression is evaluated.

When we call this lambda function, we'll set up a stack frame with its argument (3) stored at rsp - 8 and a pointer to the function on the heap at rsp - 16.

Let's run through an example of all of this in action:

(let ((y 4))
  (let ((f (lambda (x) (+ x y))))
    (f 3)))