First-class and anonymous functions
We'll continue right where we left off: adding support for function objects.
Reminder: extending the AST
Last time, we introduced a notion of abstract syntax for our language.
We can convert s_exp terms into ASTs pretty easily.
To support function objects in our AST, we had to slightly change what Call expressions looked like:
ast.ml
type expr =
  | Prim0 of prim0
  | Prim1 of prim1 * expr
  | Prim2 of prim2 * expr * expr
  | Let of string * expr * expr
  | If of expr * expr * expr
  | Do of expr list
  | Num of int
  | Var of string
  | Call of expr * expr list
  | True
  | False

let rec expr_of_s_exp : s_exp -> expr = function
  (* some cases elided ... *)
  | Lst (f :: args) ->
      (* the function position is itself an expression, so convert it too *)
      Call (expr_of_s_exp f, List.map expr_of_s_exp args)
  | e ->
      raise (BadSExpression e)
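To see what the new Call shape buys us, here is a minimal, self-contained sketch. The cut-down s_exp and expr types below are stand-ins for illustration (ENum is renamed only to avoid a constructor clash in one file), and compose is a hypothetical function name, not something defined in these notes.

```ocaml
(* Toy stand-ins for the course's S_exp and Ast modules (assumptions). *)
type s_exp = Num of int | Sym of string | Lst of s_exp list

type expr = ENum of int | Var of string | Call of expr * expr list

exception BadSExpression of s_exp

let rec expr_of_s_exp : s_exp -> expr = function
  | Num n -> ENum n
  | Sym x -> Var x
  | Lst (f :: args) ->
      (* the function position is converted like any other expression *)
      Call (expr_of_s_exp f, List.map expr_of_s_exp args)
  | e -> raise (BadSExpression e)

(* ((compose f g) 3): the thing being called is itself a call *)
let nested =
  expr_of_s_exp (Lst [Lst [Sym "compose"; Sym "f"; Sym "g"]; Num 3])
```

Because the first component of Call is an expr rather than a string, the function being called can be a variable, another call, or (later) anything else that evaluates to a function.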
Extending the interpreter
We'll need a new type of value in our interpreter. It's up for debate how we should represent these function values!
Here we represent a function value by the name of its definition, which lets us reuse the definition list as much as we can.
type value =
  | Number of int
  | Boolean of bool
  | Pair of (value * value)
  | Function of string

let rec string_of_value (v : value) : string =
  match v with
  (* some cases elided ... *)
  | Function _ ->
      "<function>"

let rec interp_exp (defns : defn list) (env : value symtab) (exp : expr) : value =
  match exp with
  (* some cases elided ... *)
  | Call (f, args) -> (
      let argvals = args |> List.map (interp_exp defns env) in
      let fv = interp_exp defns env f in
      match fv with
      | Function name when is_defn defns name ->
          let defn = get_defn defns name in
          if List.length args = List.length defn.args then
            let fenv = List.combine defn.args argvals |> Symtab.of_list in
            interp_exp defns fenv defn.body
          else raise (BadExpression exp)
      | _ ->
          raise (BadExpression exp) )
  | Var var when is_defn defns var ->
      Function var
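The same idea can be tried out in isolation. The following is a minimal sketch under simplifying assumptions: the expression type is cut down to numbers, variables, addition, and calls; the environment is a plain association list instead of a symtab; and the definitions add4 and apply are made up for the example.

```ocaml
(* Toy model of "a function value is just a definition's name". *)
type expr =
  | Num of int
  | Var of string
  | Add of expr * expr
  | Call of expr * expr list

type defn = {name: string; args: string list; body: expr}

type value = Number of int | Function of string

exception Bad_expression

let rec interp (defns : defn list) (env : (string * value) list) (e : expr) :
    value =
  match e with
  | Num n -> Number n
  | Var x -> (
    match List.assoc_opt x env with
    | Some v -> v
    | None ->
        (* not a local variable: it may name a top-level definition *)
        if List.exists (fun d -> d.name = x) defns then Function x
        else raise Bad_expression )
  | Add (a, b) -> (
    match (interp defns env a, interp defns env b) with
    | Number x, Number y -> Number (x + y)
    | _ -> raise Bad_expression )
  | Call (f, args) -> (
      let argvals = List.map (interp defns env) args in
      match interp defns env f with
      | Function name ->
          let d = List.find (fun d -> d.name = name) defns in
          if List.length d.args = List.length argvals then
            (* the callee runs in a fresh environment holding only its arguments *)
            interp defns (List.combine d.args argvals) d.body
          else raise Bad_expression
      | _ -> raise Bad_expression )

(* Hypothetical definitions: add4 adds 4; apply calls its first argument. *)
let defns =
  [ {name= "add4"; args= ["x"]; body= Add (Var "x", Num 4)}
  ; {name= "apply"; args= ["f"; "v"]; body= Call (Var "f", [Var "v"])} ]

let result = interp defns [] (Call (Var "apply", [Var "add4"; Num 1]))
```

Here add4 is passed to apply as an ordinary argument: inside apply, the local variable f evaluates to Function "add4", and the call looks that name up in the definition list.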
Extending the compiler
As in the interpreter, we want to use a function name to look up the location of that function's code. At runtime, a function value needs to be a pointer to that code; we tag the pointer with fn_tag in its low bits so it can be told apart from other kinds of values.
let fn_tag = 0b110

let ensure_fn (op : operand) : directive list =
  [ Mov (Reg R8, op)
  ; And (Reg R8, Imm heap_mask)
  ; Cmp (Reg R8, Imm fn_tag)
  ; Jnz "error" ]

let rec compile_exp (defns : defn list) (tab : int symtab) (stack_index : int)
    (exp : expr) (is_tail : bool) : directive list =
  match exp with
  (* some cases elided ... *)
  | Call (f, args) when not is_tail ->
      let stack_base = align_stack_index (stack_index + 8) in
      let compiled_args =
        args
        |> List.mapi (fun i arg ->
               compile_exp defns tab (stack_base - (8 * (i + 2))) arg false
               @ [Mov (stack_address (stack_base - (8 * (i + 2))), Reg Rax)])
        |> List.concat
      in
      compiled_args
      @ compile_exp defns tab
          (stack_base - (8 * (List.length args + 2)))
          f false
      @ ensure_fn (Reg Rax)
      @ [Sub (Reg Rax, Imm fn_tag)]
      @ [ Add (Reg Rsp, Imm stack_base)
        ; ComputedCall (Reg Rax)
        ; Sub (Reg Rsp, Imm stack_base) ]
  | Call (f, args) when is_tail ->
      let compiled_args =
        args
        |> List.mapi (fun i arg ->
               compile_exp defns tab (stack_index - (8 * i)) arg false
               @ [Mov (stack_address (stack_index - (8 * i)), Reg Rax)])
        |> List.concat
      in
      let moved_args =
        args
        |> List.mapi (fun i _ ->
               [ Mov (Reg R8, stack_address (stack_index - (8 * i)))
               ; Mov (stack_address ((i + 1) * -8), Reg R8) ])
        |> List.concat
      in
      compiled_args
      @ compile_exp defns tab
          (stack_index - (8 * (List.length args + 2)))
          f false
      @ ensure_fn (Reg Rax)
      @ [Sub (Reg Rax, Imm fn_tag)]
      @ moved_args
      @ [ComputedJmp (Reg Rax)]
  | Var var when is_defn defns var ->
      [LeaLabel (Reg Rax, defn_label var); Or (Reg Rax, Imm fn_tag)]

let compile_defn defns defn =
  let ftab =
    defn.args |> List.mapi (fun i arg -> (arg, -8 * (i + 1))) |> Symtab.of_list
  in
  [Align 8; Label (defn_label defn.name)]
  @ compile_exp defns ftab (-8 * (List.length defn.args + 1)) defn.body true
  @ [Ret]
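The tagging arithmetic that these instructions perform can be modeled on plain integers. This is only a sketch: it assumes heap_mask is 0b111 (the notes use heap_mask without showing its value), and tag_fn, is_fn, and untag_fn are hypothetical helper names that mirror the Or, the And/Cmp pair, and the Sub above.

```ocaml
let fn_tag = 0b110

(* Assumption: the mask covers the low three bits. *)
let heap_mask = 0b111

(* Mirrors [Or (Reg Rax, Imm fn_tag)]. This only works because Align 8
   guarantees that the low three bits of a function's address are zero. *)
let tag_fn (addr : int) : int =
  assert (addr land heap_mask = 0) ;
  addr lor fn_tag

(* Mirrors ensure_fn: mask off everything but the tag bits and compare. *)
let is_fn (v : int) : bool = v land heap_mask = fn_tag

(* Mirrors [Sub (Reg Rax, Imm fn_tag)]: recover the raw code address. *)
let untag_fn (v : int) : int = v - fn_tag

let example = tag_fn 0x1000
```

Subtracting the tag is safe only after the check in is_fn has passed, which is why the compiled code always runs ensure_fn before the Sub.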
We should also update our runtime to print functions:
#define fn_tag 0b110

void print_value(uint64_t value) {
  ...
  } else if ((value & heap_mask) == fn_tag) {
    printf("<function>");
  }
}
Note: we also need to add -no-pie to our linker call.
Anonymous functions
We can now talk about functions as values, but only when they come with names!
We'd like to be able to write anonymous functions like (lambda (x) (+ x 4)).
This is analogous to OCaml's (fun x -> x + 4) syntax.
Our implementation will feel a little bit like cheating at first -- it seems too easy, and we won't change our compiler or interpreter at all!
Consider the following program, which implements a range function returning a list of numbers and a map over lists:
(define (range lo hi)
  (if (< lo hi)
    (pair lo (range (add1 lo) hi))
    false))

(define (map f l)
  (if (not l)
    l
    (pair (f (left l)) (map f (right l)))))

(print (map (lambda (x) (+ x 4)) (range 0 2)))
We can transform this program into an equivalent one without any anonymous functions.
(define (range lo hi)
  (if (< lo hi)
    (pair lo (range (add1 lo) hi))
    false))

(define (map f l)
  (if (not l)
    l
    (pair (f (left l)) (map f (right l)))))

(define (lambda_fun_0 x) (+ x 4))

(print (map lambda_fun_0 (range 0 2)))
We'll do this by adding a new version of our AST type, which includes lambdas, and parsing into that type; then we'll convert terms of that type into our familiar, lambda-free version. This conversion, which hoists each lambda into a top-level definition, is known as lambda lifting.
ASTs with lambdas
ast_lam.ml
open S_exp
open Util
include Ast

type expr_lam =
  | Prim0 of prim0
  | Prim1 of prim1 * expr_lam
  | Prim2 of prim2 * expr_lam * expr_lam
  | Let of string * expr_lam * expr_lam
  | If of expr_lam * expr_lam * expr_lam
  | Do of expr_lam list
  | Num of int
  | Var of string
  | Call of expr_lam * expr_lam list
  | True
  | False
  | Lambda of string list * expr_lam

let is_sym e = match e with Sym _ -> true | _ -> false

let as_sym e = match e with Sym s -> s | _ -> raise Not_found

let rec expr_lam_of_s_exp : s_exp -> expr_lam = function
  | Num x ->
      Num x
  | Sym "true" ->
      True
  | Sym "false" ->
      False
  | Sym var ->
      Var var
  | Lst [Sym "let"; Lst [Lst [Sym var; exp]]; body] ->
      Let (var, expr_lam_of_s_exp exp, expr_lam_of_s_exp body)
  | Lst (Sym "do" :: exps) when List.length exps > 0 ->
      Do (List.map expr_lam_of_s_exp exps)
  | Lst [Sym "if"; test_s; then_s; else_s] ->
      If
        ( expr_lam_of_s_exp test_s
        , expr_lam_of_s_exp then_s
        , expr_lam_of_s_exp else_s )
  | Lst [Sym "lambda"; Lst args; body] when List.for_all is_sym args ->
      Lambda (List.map as_sym args, expr_lam_of_s_exp body)
  | Lst [Sym prim] when Option.is_some (prim0_of_string prim) ->
      Prim0 (Option.get (prim0_of_string prim))
  | Lst [Sym prim; arg] when Option.is_some (prim1_of_string prim) ->
      Prim1 (Option.get (prim1_of_string prim), expr_lam_of_s_exp arg)
  | Lst [Sym prim; arg1; arg2] when Option.is_some (prim2_of_string prim) ->
      Prim2
        ( Option.get (prim2_of_string prim)
        , expr_lam_of_s_exp arg1
        , expr_lam_of_s_exp arg2 )
  | Lst (f :: args) ->
      Call (expr_lam_of_s_exp f, List.map expr_lam_of_s_exp args)
  | e ->
      raise (BadSExpression e)

let rec expr_of_expr_lam (defns : defn list ref) : expr_lam -> expr = function
  | Num x ->
      Num x
  | Var s ->
      Var s
  | True ->
      True
  | False ->
      False
  | If (test_exp, then_exp, else_exp) ->
      If
        ( expr_of_expr_lam defns test_exp
        , expr_of_expr_lam defns then_exp
        , expr_of_expr_lam defns else_exp )
  | Let (var, exp, body) ->
      Let (var, expr_of_expr_lam defns exp, expr_of_expr_lam defns body)
  | Prim0 p ->
      Prim0 p
  | Prim1 (p, e) ->
      Prim1 (p, expr_of_expr_lam defns e)
  | Prim2 (p, e1, e2) ->
      Prim2 (p, expr_of_expr_lam defns e1, expr_of_expr_lam defns e2)
  | Do exps ->
      Do (List.map (expr_of_expr_lam defns) exps)
  | Call (exp, args) ->
      Call (expr_of_expr_lam defns exp, List.map (expr_of_expr_lam defns) args)
  | Lambda (args, body) ->
      let name = gensym "_lambda" in
      let body = expr_of_expr_lam defns body in
      defns := {name; args; body} :: !defns ;
      Var name

let program_of_s_exps (exps : s_exp list) : program =
  let defns = ref [] in
  let rec get_args args =
    match args with
    | Sym v :: args ->
        v :: get_args args
    | e :: _ ->
        raise (BadSExpression e)
    | [] ->
        []
  in
  let get_defn = function
    | Lst [Sym "define"; Lst (Sym name :: args); body] ->
        let args = get_args args in
        {name; args; body= body |> expr_lam_of_s_exp |> expr_of_expr_lam defns}
    | e ->
        raise (BadSExpression e)
  in
  let rec go exps =
    match exps with
    | [e] ->
        let body = e |> expr_lam_of_s_exp |> expr_of_expr_lam defns in
        {defns= List.rev !defns; body}
    | d :: exps ->
        let defn = get_defn d in
        defns := defn :: !defns ;
        go exps
    | _ ->
        raise (BadSExpression (Sym "empty"))
  in
  go exps
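To watch the Lambda case do its work on a small input, here is a self-contained sketch with both AST types cut down to a handful of constructors. The L-prefixed constructor names, the lift function, and this gensym (including the "__0" suffix format) are assumptions made for the example; the real gensym comes from the course's Util module and may format names differently.

```ocaml
(* Cut-down AST with lambdas (L prefix only to keep both types in one file). *)
type expr_lam =
  | LNum of int
  | LVar of string
  | LAdd of expr_lam * expr_lam
  | LCall of expr_lam * expr_lam list
  | LLambda of string list * expr_lam

(* Cut-down lambda-free AST. *)
type expr =
  | Num of int
  | Var of string
  | Add of expr * expr
  | Call of expr * expr list

type defn = {name: string; args: string list; body: expr}

(* Stand-in for Util.gensym: appends a fresh counter to the base name. *)
let counter = ref 0

let gensym (base : string) : string =
  let n = !counter in
  incr counter ;
  Printf.sprintf "%s__%d" base n

let rec lift (defns : defn list ref) : expr_lam -> expr = function
  | LNum n -> Num n
  | LVar x -> Var x
  | LAdd (a, b) ->
      let a' = lift defns a in
      let b' = lift defns b in
      Add (a', b')
  | LCall (f, args) ->
      let f' = lift defns f in
      Call (f', List.map (lift defns) args)
  | LLambda (args, body) ->
      (* hoist the lambda to a fresh top-level definition and refer to it by name *)
      let name = gensym "_lambda" in
      let body = lift defns body in
      defns := {name; args; body} :: !defns ;
      Var name

let defns : defn list ref = ref []

(* (map (lambda (x) (+ x 4)) l) *)
let lifted =
  lift defns
    (LCall (LVar "map", [LLambda (["x"], LAdd (LVar "x", LNum 4)); LVar "l"]))
```

After this runs, lifted refers to the lambda only through a generated name, and defns holds one new definition whose body is the lambda's body. This is the same shape as the hand-written lambda_fun_0 version of the program.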
A hiccup
What should the following program evaluate to?
(define (adder x)
  (lambda (y) (+ x y)))

(print ((adder 2) 3))
We'd like this to print 5, but our implementation will have some trouble with it. What would the program look like if we did the lambda lifting step by hand? See the problem?
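One way to make the problem concrete is to write out the hand-lifted definitions and ask which variables each body mentions that are neither its own parameters nor top-level names. The sketch below does that on a cut-down AST; vars, unbound, and the generated name _lambda__0 are assumptions for illustration, and the check ignores let-bound variables because the toy AST has none.

```ocaml
type expr =
  | Num of int
  | Var of string
  | Add of expr * expr
  | Call of expr * expr list

type defn = {name: string; args: string list; body: expr}

(* Every variable mentioned in an expression (duplicates allowed). *)
let rec vars : expr -> string list = function
  | Num _ -> []
  | Var x -> [x]
  | Add (a, b) -> vars a @ vars b
  | Call (f, args) -> vars f @ List.concat (List.map vars args)

(* Variables in d's body that are neither parameters of d nor top-level names. *)
let unbound (defns : defn list) (d : defn) : string list =
  vars d.body
  |> List.filter (fun x ->
         (not (List.mem x d.args))
         && not (List.exists (fun d' -> d'.name = x) defns))

(* The adder program after lifting by hand. *)
let adder = {name= "adder"; args= ["x"]; body= Var "_lambda__0"}

let lifted_lambda =
  {name= "_lambda__0"; args= ["y"]; body= Add (Var "x", Var "y")}

let defns = [adder; lifted_lambda]

let problem = unbound defns lifted_lambda
```

The lifted body still mentions x, but x was a parameter of adder, not of the new top-level definition, so nothing in the lifted function's environment supplies it.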