First-class functions
Most functional programming languages (really, most modern programming languages of all stripes) support some form of first-class functions. Right now, functions in our language are totally separate from values. We can define them and call them, but we can't:
- Put a function in a variable
- Return a function from a function
- Pass a function to a function as an argument
Here are some programs, written in a slightly extended version of our language, that use first-class functions:
(define (f g)
  (g 2))
(define (mul2 x)
  (+ x x))
(print (f mul2))

(define (f g)
  (g 2))
(print (f (lambda (x) (+ x x))))

(define (f g)
  (g 2))
(let ((y 3))
  (print (f (lambda (x) (+ x y)))))
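For reference, here is a sketch of what each program should print. It transcribes the three programs into OCaml, whose functions are already first-class. It only illustrates the intended behavior; program1, program2 and program3 are names made up for this sketch.

```ocaml
(* Illustration only: the three example programs written in OCaml. *)

(* Program 1: pass a named function as an argument. *)
let program1 () =
  let f g = g 2 in
  let mul2 x = x + x in
  f mul2

(* Program 2: pass an anonymous function (a lambda). *)
let program2 () =
  let f g = g 2 in
  f (fun x -> x + x)

(* Program 3: the lambda refers to y, bound by the enclosing let. *)
let program3 () =
  let f g = g 2 in
  let y = 3 in
  f (fun x -> x + y)

let () =
  Printf.printf "%d\n%d\n%d\n" (program1 ()) (program2 ()) (program3 ())
```

Running it prints 4, 4 and 5, one per line.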
Our first goal is to support the first program, where a named function is passed around like any other value. Next time we'll support the other two programs.
Extending the AST
It's easy enough to change our AST to support this. The function position of a Call becomes an arbitrary expression instead of just a function name.
ast.ml
type expr =
  | Prim0 of prim0
  | Prim1 of prim1 * expr
  | Prim2 of prim2 * expr * expr
  | Let of string * expr * expr
  | If of expr * expr * expr
  | Do of expr list
  | Num of int
  | Var of string
  | Call of expr * expr list  (* the function position is now any expression *)
  | True
  | False

let rec expr_of_s_exp : s_exp -> expr = function
  (* some cases elided ... *)
  | Lst (f :: args) ->
      Call (expr_of_s_exp f, List.map expr_of_s_exp args)
  | e ->
      raise (BadSExpression e)
We'll modify this a bit when we introduce lambda expressions!
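The following is a minimal, self-contained sketch of the parser change. It assumes a simplified s_exp type, which is wrapped in a module S here so its Num constructor doesn't clash with expr's, and it supports only three expression forms. It shows why the callee must be an expr: in ((f) 2), the thing being called is itself a call.

```ocaml
(* Simplified stand-in for the course's s_exp type (an assumption). *)
module S = struct
  type s_exp = Num of int | Sym of string | Lst of s_exp list
end

exception BadSExpression of S.s_exp

(* Only the cases needed for this sketch. *)
type expr =
  | Num of int
  | Var of string
  | Call of expr * expr list (* the callee is an arbitrary expression *)

let rec expr_of_s_exp : S.s_exp -> expr = function
  | S.Num n -> Num n
  | S.Sym name -> Var name
  | S.Lst (f :: args) -> Call (expr_of_s_exp f, List.map expr_of_s_exp args)
  | e -> raise (BadSExpression e) (* includes the empty list *)

(* (f mul2) parses to Call (Var "f", [Var "mul2"]) *)
let named_call = expr_of_s_exp (S.Lst [ S.Sym "f"; S.Sym "mul2" ])

(* ((f) 2): the callee is the call (f) *)
let nested_call = expr_of_s_exp (S.Lst [ S.Lst [ S.Sym "f" ]; S.Num 2 ])
```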
Extending the interpreter
We'll need a new type of value in our interpreter. It's up for debate how we should represent these function values!
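The main idea is to reuse the definition list as much as we can. A function value is simply the name of a top-level definition. To call it, we look that name up in defns to find its parameters and body.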
type value =
  | Number of int
  | Boolean of bool
  | Pair of (value * value)
  | Function of string  (* the name of a top-level definition *)

let rec string_of_value (v : value) : string =
  match v with
  (* some cases elided ... *)
  | Function _ -> "<function>"

let rec interp_exp (defns : defn list) (env : value symtab) (exp : expr) : value =
  match exp with
  (* some cases elided ... *)
  | Call (f, args) -> (
      let argvals = args |> List.map (interp_exp defns env) in
      let fv = interp_exp defns env f in
      match fv with
      | Function name when is_defn defns name ->
          let defn = get_defn defns name in
          if List.length args = List.length defn.args then
            let fenv = List.combine defn.args argvals |> Symtab.of_list in
            interp_exp defns fenv defn.body
          else raise (BadExpression exp)
      | _ -> raise (BadExpression exp))
  (* a bare definition name evaluates to a function value *)
  | Var var when is_defn defns var -> Function var
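The following self-contained sketch puts these pieces together and runs the first example program. It is a simplification, not the course interpreter. An association list stands in for symtab, an Add case stands in for the Prim2 addition, and is_defn and get_defn are reimplemented with List.exists and List.find.

```ocaml
(* Simplified sketch: a function value is the name of a top-level definition. *)
type expr =
  | Num of int
  | Var of string
  | Add of expr * expr (* stand-in for Prim2 addition *)
  | Call of expr * expr list

type defn = { name : string; args : string list; body : expr }
type value = Number of int | Function of string

exception BadExpression of expr

let is_defn defns name = List.exists (fun d -> d.name = name) defns
let get_defn defns name = List.find (fun d -> d.name = name) defns

let rec interp_exp (defns : defn list) (env : (string * value) list)
    (exp : expr) : value =
  match exp with
  | Num n -> Number n
  (* local variables are checked first, so they shadow definitions *)
  | Var var when List.mem_assoc var env -> List.assoc var env
  | Var var when is_defn defns var -> Function var
  | Var _ -> raise (BadExpression exp)
  | Add (e1, e2) -> (
      match (interp_exp defns env e1, interp_exp defns env e2) with
      | Number a, Number b -> Number (a + b)
      | _ -> raise (BadExpression exp))
  | Call (f, args) -> (
      let argvals = List.map (interp_exp defns env) args in
      match interp_exp defns env f with
      | Function name when is_defn defns name ->
          let defn = get_defn defns name in
          if List.length args = List.length defn.args then
            interp_exp defns (List.combine defn.args argvals) defn.body
          else raise (BadExpression exp)
      | _ -> raise (BadExpression exp))

(* (define (f g) (g 2)) (define (mul2 x) (+ x x)) *)
let example_defns =
  [ { name = "f"; args = [ "g" ]; body = Call (Var "g", [ Num 2 ]) };
    { name = "mul2"; args = [ "x" ]; body = Add (Var "x", Var "x") } ]

(* (f mul2) *)
let result = interp_exp example_defns [] (Call (Var "f", [ Var "mul2" ]))
```

Here result is Number 4. Inside f's body, the parameter g is found in the environment as Function "mul2", so (g 2) calls mul2.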
Extending the compiler
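As in the interpreter, we use a function's name to find its code. In the compiler, the name gives us the label (defn_label name) at the start of that code. At runtime, a function value is a pointer to that code, with fn_tag stored in its low bits. Those bits are otherwise zero because each function's label is 8-byte aligned.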
let fn_tag = 0b110

(* Jump to the error handler unless op carries the function tag. *)
let ensure_fn (op : operand) : directive list =
  [ Mov (Reg R8, op)
  ; And (Reg R8, Imm heap_mask)
  ; Cmp (Reg R8, Imm fn_tag)
  ; Jnz "error" ]

let rec compile_exp (defns : defn list) (tab : int symtab) (stack_index : int)
    (exp : expr) (is_tail : bool) : directive list =
  match exp with
  (* some cases elided ... *)
  | Call (f, args) when not is_tail ->
      let stack_base = align_stack_index (stack_index + 8) in
      (* store argument i at stack_base - 8 * (i + 2), leaving
         stack_base - 8 free for the return address *)
      let compiled_args =
        args
        |> List.mapi (fun i arg ->
               compile_exp defns tab (stack_base - (8 * (i + 2))) arg false
               @ [Mov (stack_address (stack_base - (8 * (i + 2))), Reg Rax)])
        |> List.concat
      in
      compiled_args
      @ compile_exp defns tab (stack_base - (8 * (List.length args + 2))) f false
      @ ensure_fn (Reg Rax)
      @ [Sub (Reg Rax, Imm fn_tag)]  (* untag to get the code address *)
      @ [ Add (Reg Rsp, Imm stack_base)
        ; ComputedCall (Reg Rax)
        ; Sub (Reg Rsp, Imm stack_base) ]
  | Call (f, args) ->
      (* tail call: evaluate the arguments into temporaries, copy them
         into our own argument slots, then jump *)
      let compiled_args =
        args
        |> List.mapi (fun i arg ->
               compile_exp defns tab (stack_index - (8 * i)) arg false
               @ [Mov (stack_address (stack_index - (8 * i)), Reg Rax)])
        |> List.concat
      in
      let moved_args =
        args
        |> List.mapi (fun i _ ->
               [ Mov (Reg R8, stack_address (stack_index - (8 * i)))
               ; Mov (stack_address ((i + 1) * -8), Reg R8) ])
        |> List.concat
      in
      compiled_args
      @ compile_exp defns tab (stack_index - (8 * (List.length args + 2))) f false
      @ ensure_fn (Reg Rax)
      @ [Sub (Reg Rax, Imm fn_tag)]
      @ moved_args
      @ [ComputedJmp (Reg Rax)]
  | Var var when is_defn defns var ->
      (* load the function's code address and tag it *)
      [LeaLabel (Reg Rax, defn_label var); Or (Reg Rax, Imm fn_tag)]

let compile_defn defns defn =
  let ftab =
    defn.args |> List.mapi (fun i arg -> (arg, -8 * (i + 1))) |> Symtab.of_list
  in
  (* Align 8 keeps the label's low three bits zero, free for the tag *)
  [Align 8; Label (defn_label defn.name)]
  @ compile_exp defns ftab (-8 * (List.length defn.args + 1)) defn.body true
  @ [Ret]
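The stack offsets in the non-tail Call case are easy to get wrong. This small sketch checks them with plain integers, measured relative to the caller's rsp. It confirms that argument i, which the caller stores at stack_base - 8*(i+2), lands where compile_defn's ftab expects it: 8*(i+1) below the callee's rsp. This accounts for Add (Reg Rsp, Imm stack_base) and the 8-byte return address that call pushes. The helper names are hypothetical and not part of the compiler. The value -32 is just an example stack_base, since the check doesn't depend on what align_stack_index returns.

```ocaml
(* Offsets are plain ints relative to the caller's rsp. *)

(* Caller side: argument i is written to [rsp + stack_base - 8*(i+2)]. *)
let caller_arg_offset ~stack_base i = stack_base - (8 * (i + 2))

(* Add (Reg Rsp, Imm stack_base) moves rsp by stack_base; the call then
   pushes an 8-byte return address, so the callee's rsp is 8 lower. *)
let callee_rsp ~stack_base = stack_base - 8

(* Callee side: ftab places argument i at [callee rsp - 8*(i+1)]. *)
let callee_arg_offset ~stack_base i = callee_rsp ~stack_base - (8 * (i + 1))

(* Check that both views agree for every argument position. *)
let layouts_agree ~stack_base ~nargs =
  List.for_all
    (fun i -> caller_arg_offset ~stack_base i = callee_arg_offset ~stack_base i)
    (List.init nargs (fun i -> i))

let () =
  Printf.printf "layouts agree: %b\n" (layouts_agree ~stack_base:(-32) ~nargs:3)
```

The slot at stack_base - 8 is never written by the arguments, because it becomes the return address slot at the callee's rsp.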
We should also update our runtime to print functions:
#define fn_tag 0b110

void print_value(uint64_t value) {
  /* ... existing cases elided ... */
  else if ((value & heap_mask) == fn_tag) {
    printf("<function>");
  }
}
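ensure_fn and print_value both rely on the same tag arithmetic. The sketch below models it with OCaml ints. It assumes heap_mask = 0b111, since the real constant isn't shown in this section, and it uses a made-up, 8-byte-aligned code address.

```ocaml
(* Assumption: heap_mask = 0b111; the real constant is defined elsewhere. *)
let heap_mask = 0b111
let fn_tag = 0b110

(* Labels follow [Align 8], so a code address has its low three bits
   clear and can hold the tag. *)
let tag_fn (addr : int) : int =
  assert (addr land heap_mask = 0);
  addr lor fn_tag

(* The test done by ensure_fn (And + Cmp) and by print_value in C. *)
let is_fn (v : int) : bool = v land heap_mask = fn_tag

(* Recover the code address, as Sub (Reg Rax, Imm fn_tag) does. *)
let untag_fn (v : int) : int = v - fn_tag

let () =
  let addr = 0x401000 (* hypothetical aligned code address *) in
  let v = tag_fn addr in
  Printf.printf "tagged=0x%x is_fn=%b untagged=0x%x\n" v (is_fn v) (untag_fn v)
```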
Note: we also need to add -no-pie to our linker call.