First-class and anonymous functions

We'll continue right where we left off: adding support for function objects.

Reminder: extending the AST

Last time, we introduced a notion of abstract syntax for our language. We can convert s_exp terms into ASTs pretty easily. To support function objects in our AST, we had to slightly change what Call expressions looked like:

ast.ml

(** Abstract syntax of expressions (the lambda-free form). *)
type expr =
  | Prim0 of prim0
  | Prim1 of prim1 * expr
  | Prim2 of prim2 * expr * expr
  | Let of string * expr * expr
  | If of expr * expr * expr
  | Do of expr list
  | Num of int
  | Var of string
  | Call of expr * expr list
      (** The callee is itself an expression rather than a bare name; the
          list holds the argument expressions. *)
  | True
  | False

(** [expr_of_s_exp e] converts the s-expression [e] into an [expr].

    @raise BadSExpression if [e] matches none of the recognized forms. *)
let rec expr_of_s_exp : s_exp -> expr = function
  (* some cases elided ... *)
  | Lst (f :: args) ->
      (* The callee must be converted as well: [Call] carries an [expr] in
         function position, not a raw s-expression. *)
      Call (expr_of_s_exp f, List.map expr_of_s_exp args)
  | e ->
      raise (BadSExpression e)

Extending the interpreter

We'll need a new type of value in our interpreter. It's up for debate how we should represent these function values!

The main point here: we're going to reuse the definition list as much as we can.

(** Values produced by the interpreter. *)
type value =
  | Number of int
  | Boolean of bool
  | Pair of (value * value)
  | Function of string
      (** A function value; the string is the name of a definition, which
          the interpreter looks up in its definition list when the function
          is called. *)

(** [string_of_value v] renders [v] as a string. Every function value is
    rendered as the fixed placeholder ["<function>"]. *)
let rec string_of_value (v : value) : string =
  match v with
  (* some cases elided ... *)
  | Function _ -> "<function>"

(** [interp_exp defns env exp] evaluates [exp] to a [value], using [defns]
    for top-level definitions and [env] for variable bindings.

    Only the cases for calls and for variables that name a definition are
    shown here.

    @raise BadExpression if the callee of a call does not evaluate to a
    known function, or if the number of arguments differs from the number
    of parameters. *)
let rec interp_exp (defns : defn list) (env : value symtab) (exp : expr) : value
    =
  match exp with
  (* some cases elided ... *)
  | Call (f, args) -> (
      (* Arguments are evaluated first, then the callee expression. *)
      let argvals = List.map (interp_exp defns env) args in
      match interp_exp defns env f with
      | Function name when is_defn defns name ->
          let callee = get_defn defns name in
          if List.length args <> List.length callee.args then
            raise (BadExpression exp)
          else
            (* The body runs in an environment built from the parameters
               alone; the caller's [env] is not extended. *)
            let params = Symtab.of_list (List.combine callee.args argvals) in
            interp_exp defns params callee.body
      | _ ->
          raise (BadExpression exp) )
  | Var var when is_defn defns var ->
      (* A variable naming a definition evaluates to a function value. *)
      Function var

Extending the compiler

Similarly to the interpreter, we want to use a function name to look up the location of that function's code. The runtime representation of a function needs to be a pointer.

(* Tag bits that mark a runtime value as a function. *)
let fn_tag = 0b110

(** [ensure_fn op] emits a check that [op] is tagged as a function: [op] is
    copied into R8, masked with [heap_mask], and compared against [fn_tag];
    on a mismatch control jumps to the ["error"] label. R8 is overwritten. *)
let ensure_fn (op : operand) : directive list =
  let scratch = Reg R8 in
  [ Mov (scratch, op)
  ; And (scratch, Imm heap_mask)
  ; Cmp (scratch, Imm fn_tag)
  ; Jnz "error" ]

(** [compile_exp defns tab stack_index exp is_tail] compiles [exp] into a
    list of directives that leave the result in Rax.

    Only the cases for calls and for variables that name a definition are
    shown here. A call with [is_tail = false] is emitted as a
    [ComputedCall]; a call with [is_tail = true] overwrites the current
    argument slots and is emitted as a [ComputedJmp]. *)
let rec compile_exp (defns : defn list) (tab : int symtab) (stack_index : int)
    (exp : expr) (is_tail : bool) : directive list =
  match exp with
  (* some cases elided ... *)
  | Call (f, args) when not is_tail ->
      let stack_base = align_stack_index (stack_index + 8) in
      (* Argument [i] is stored 8 * (i + 2) bytes below [stack_base] --
         presumably leaving the slot just below the new Rsp for the return
         address pushed by the call. *)
      let slot i = stack_base - (8 * (i + 2)) in
      let eval_arg i arg =
        compile_exp defns tab (slot i) arg false
        @ [Mov (stack_address (slot i), Reg Rax)]
      in
      let compiled_args = List.concat (List.mapi eval_arg args) in
      (* The callee is compiled after the arguments, using the first slot
         below them as its free stack index. *)
      let compiled_fn =
        compile_exp defns tab (slot (List.length args)) f false
      in
      compiled_args @ compiled_fn
      @ ensure_fn (Reg Rax)
      @ [ Sub (Reg Rax, Imm fn_tag) (* strip the tag to get the code address *)
        ; Add (Reg Rsp, Imm stack_base)
        ; ComputedCall (Reg Rax)
        ; Sub (Reg Rsp, Imm stack_base) ]
  | Call (f, args) when is_tail ->
      (* Arguments are first evaluated into temporaries starting at
         [stack_index], then copied into the argument slots. *)
      let temp i = stack_index - (8 * i) in
      let eval_arg i arg =
        compile_exp defns tab (temp i) arg false
        @ [Mov (stack_address (temp i), Reg Rax)]
      in
      let compiled_args = List.concat (List.mapi eval_arg args) in
      let relocate i _ =
        [ Mov (Reg R8, stack_address (temp i))
        ; Mov (stack_address ((i + 1) * -8), Reg R8) ]
      in
      let moved_args = List.concat (List.mapi relocate args) in
      let compiled_fn =
        compile_exp defns tab (temp (List.length args + 2)) f false
      in
      (* The copies go through R8, so the callee address in Rax survives
         until the jump. *)
      compiled_args @ compiled_fn
      @ ensure_fn (Reg Rax)
      @ [Sub (Reg Rax, Imm fn_tag)]
      @ moved_args
      @ [ComputedJmp (Reg Rax)]
  | Var var when is_defn defns var ->
      (* A definition's name evaluates to its label address tagged with
         [fn_tag]. *)
      [LeaLabel (Reg Rax, defn_label var); Or (Reg Rax, Imm fn_tag)]

(** [compile_defn defns defn] compiles one definition: an [Align 8]
    directive, the definition's label, its body compiled in tail position,
    and a final [Ret]. Parameter [i] is bound to stack offset
    [-8 * (i + 1)]. *)
let compile_defn defns defn =
  let param_slot i arg = (arg, -8 * (i + 1)) in
  let ftab = Symtab.of_list (List.mapi param_slot defn.args) in
  (* First stack offset below the parameters. *)
  let first_free = -8 * (List.length defn.args + 1) in
  let body = compile_exp defns ftab first_free defn.body true in
  (* NOTE(review): the 8-byte alignment presumably keeps the low three bits
     of the label address clear so [fn_tag] can be OR-ed in -- confirm. *)
  [Align 8; Label (defn_label defn.name)] @ body @ [Ret]

We should also update our runtime to print functions:

/* Tag bits identifying a function value; same value as fn_tag in the compiler. */
#define fn_tag 0b110

/* Prints a textual representation of a runtime value (other cases elided). */
void print_value(uint64_t value) {
    ...
  } else if ((value & heap_mask) == fn_tag) {
    /* Function values are printed as a fixed placeholder. */
    printf("<function>");
  }

Note: we also need to add -no-pie to our linker call.

Anonymous functions

We can now talk about functions as values, but only when they come with names! We'd like to be able to write anonymous functions like (lambda (x) (+ x 4)). This is analogous to OCaml's (fun x -> x + 4) syntax.

Our implementation will feel a little bit like cheating at first -- it seems too easy, and we won't change our compiler or interpreter at all! Consider the following program, which implements a range function returning a list of numbers and a map over lists:

(define (range lo hi)
  (if (< lo hi) 
    (pair lo (range (add1 lo) hi))
    false))
(define (map f l)
  (if (not l) 
    l
    (pair (f (left l)) (map f (right l)))))
(print (map (lambda (x) (+ x 4)) (range 0 2)))

We can transform this program into an equivalent one without any anonymous functions.

(define (range lo hi)
  (if (< lo hi) 
    (pair lo (range (add1 lo) hi))
    false))
(define (map f l)
  (if (not l) 
    l
    (pair (f (left l)) (map f (right l)))))
(define (lambda_fun_0 x) (+ x 4))
(print (map lambda_fun_0 (range 0 2)))

We'll do this by adding a new version of our AST type, which includes lambdas, and parsing into that type; then we'll convert terms of that type into our familiar, lambda-free version.

ASTs with lambdas

ast_lam.ml

open S_exp
open Util
include Ast

(** Abstract syntax of expressions, extended with anonymous functions. *)
type expr_lam =
  | Prim0 of prim0
  | Prim1 of prim1 * expr_lam
  | Prim2 of prim2 * expr_lam * expr_lam
  | Let of string * expr_lam * expr_lam
  | If of expr_lam * expr_lam * expr_lam
  | Do of expr_lam list
  | Num of int
  | Var of string
  | Call of expr_lam * expr_lam list
  | True
  | False
  | Lambda of string list * expr_lam
      (** An anonymous function: its parameter names and its body. *)

(** [is_sym e] is [true] exactly when [e] is a [Sym]. *)
let is_sym = function Sym _ -> true | _ -> false

(** [as_sym e] is the name carried by [e] when [e] is a [Sym].

    @raise Not_found if [e] is not a [Sym]. *)
let as_sym = function Sym s -> s | _ -> raise Not_found

(** [expr_lam_of_s_exp sexp] parses [sexp] into an [expr_lam].

    Special forms ([let], [do], [if], [lambda]) and primitives are tried
    first; any other non-empty list is parsed as a call.

    @raise BadSExpression if [sexp] matches no form (for instance an empty
    list). *)
let rec expr_lam_of_s_exp (sexp : s_exp) : expr_lam =
  match sexp with
  | Num n ->
      Num n
  | Sym "true" ->
      True
  | Sym "false" ->
      False
  | Sym name ->
      Var name
  | Lst [Sym "let"; Lst [Lst [Sym name; bound]]; body] ->
      Let (name, expr_lam_of_s_exp bound, expr_lam_of_s_exp body)
  | Lst (Sym "do" :: (_ :: _ as steps)) ->
      (* An empty [(do)] does not match here and falls through to the
         later cases. *)
      Do (List.map expr_lam_of_s_exp steps)
  | Lst [Sym "if"; cond; if_true; if_false] ->
      If
        ( expr_lam_of_s_exp cond
        , expr_lam_of_s_exp if_true
        , expr_lam_of_s_exp if_false )
  | Lst [Sym "lambda"; Lst params; body] when List.for_all is_sym params ->
      Lambda (List.map as_sym params, expr_lam_of_s_exp body)
  | Lst [Sym op] when Option.is_some (prim0_of_string op) ->
      Prim0 (Option.get (prim0_of_string op))
  | Lst [Sym op; operand] when Option.is_some (prim1_of_string op) ->
      Prim1 (Option.get (prim1_of_string op), expr_lam_of_s_exp operand)
  | Lst [Sym op; lhs; rhs] when Option.is_some (prim2_of_string op) ->
      Prim2
        ( Option.get (prim2_of_string op)
        , expr_lam_of_s_exp lhs
        , expr_lam_of_s_exp rhs )
  | Lst (callee :: actuals) ->
      Call (expr_lam_of_s_exp callee, List.map expr_lam_of_s_exp actuals)
  | _ ->
      raise (BadSExpression sexp)

(** [expr_of_expr_lam defns e] rewrites [e] into a lambda-free [expr].

    Every [Lambda] is replaced by a [Var] referring to a freshly named
    definition, and that definition is prepended to [!defns]. All other
    forms are rebuilt with their sub-expressions rewritten. *)
let rec expr_of_expr_lam (defns : defn list ref) (e : expr_lam) : expr =
  let lower = expr_of_expr_lam defns in
  match e with
  | Num n ->
      Num n
  | Var name ->
      Var name
  | True ->
      True
  | False ->
      False
  | If (cond, if_true, if_false) ->
      If (lower cond, lower if_true, lower if_false)
  | Let (name, bound, body) ->
      Let (name, lower bound, lower body)
  | Prim0 op ->
      Prim0 op
  | Prim1 (op, operand) ->
      Prim1 (op, lower operand)
  | Prim2 (op, lhs, rhs) ->
      Prim2 (op, lower lhs, lower rhs)
  | Do steps ->
      Do (List.map lower steps)
  | Call (callee, actuals) ->
      Call (lower callee, List.map lower actuals)
  | Lambda (params, lam_body) ->
      (* The name is generated before the body is lowered, so an enclosing
         lambda gets its name before any lambda nested inside it. *)
      let name = gensym "_lambda" in
      let lifted_body = lower lam_body in
      defns := {name; args= params; body= lifted_body} :: !defns ;
      Var name

(** [program_of_s_exps exps] builds a [program] from a list of
    s-expressions: every element but the last must be a
    [(define (name arg ...) body)] form, and the last is the program body.
    Definitions created for lambdas are collected alongside the explicit
    ones.

    @raise BadSExpression if [exps] is empty, if a non-final element is not
    a well-formed [define], or if any expression fails to parse. *)
let program_of_s_exps (exps : s_exp list) : program =
  let defns = ref [] in
  (* Parse an s-expression and lift its lambdas into [defns]. *)
  let lower sexp = sexp |> expr_lam_of_s_exp |> expr_of_expr_lam defns in
  let param_names params =
    List.map (function Sym v -> v | e -> raise (BadSExpression e)) params
  in
  let get_defn = function
    | Lst [Sym "define"; Lst (Sym name :: params); body] ->
        let args = param_names params in
        {name; args; body= lower body}
    | e ->
        raise (BadSExpression e)
  in
  let rec consume = function
    | [] ->
        raise (BadSExpression (Sym "empty"))
    | [main] ->
        (* Lower the body before reading [!defns]: lowering may add lifted
           lambdas to it. *)
        let body = lower main in
        {defns= List.rev !defns; body}
    | d :: rest ->
        (* Same ordering concern: [get_defn] may push lifted lambdas, so it
           must finish before [!defns] is read. *)
        let defn = get_defn d in
        defns := defn :: !defns ;
        consume rest
  in
  consume exps

A hiccup

What should the following program evaluate to?

(define (adder x) (lambda (y) (+ x y)))
(print ((adder 2) 3))

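Our lambda-lifting approach will have some trouble with this program. What would things look like if we did the lambda lifting step by hand? See the problem? (Hint: once the lambda is lifted into its own top-level definition, where does x come from in its body?)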