More Functions

Functions in the AST

We decided at the end of last class that a program is no longer a single expression. Rather, it's a list of definitions followed by the program body. Think about it for a minute: what should the AST of a definition look like?

We define it here, with a few helper functions alongside.

ast.ml

type defn = { name : string; args : string list; body : expr }
type program = { defns : defn list; body : expr }

let is_defn defns name = List.exists (fun d -> d.name = name) defns
let get_defn defns name = List.find (fun d -> d.name = name) defns

The final helper function converts the list of s_exp values we get from the parser into a program: every s-expression except the last must be a definition, and the last one is the program body.

let program_of_s_exps (exps : s_exp list) : program =
  let rec get_args args =
    match args with
    | Sym v :: args -> v :: get_args args
    | e :: _ -> raise (BadSExpression e)
    | [] -> []
  in
  let get_defn = function
    | Lst [ Sym "define"; Lst (Sym name :: args); body ] ->
        let args = get_args args in
        { name; args; body = expr_of_s_exp body }
    | e -> raise (BadSExpression e)
  in
  let rec go exps defns =
    match exps with
    | [ e ] -> { defns = List.rev defns; body = expr_of_s_exp e }
    | d :: exps -> go exps (get_defn d :: defns)
    | _ -> raise (BadSExpression (Sym "empty"))
  in
  go exps []
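To see what program_of_s_exps produces, here is a self-contained harness. The s_exp type and the ENum/EVar/ECall expr type are minimal hypothetical stand-ins for the course's real types (the real expr has many more cases). The function body is the one above, with type annotations added on the two record literals because defn and program share the body field.

```ocaml
(* Hypothetical minimal stand-ins for the course's parser output and AST. *)
type s_exp = Num of int | Sym of string | Lst of s_exp list

exception BadSExpression of s_exp

type expr = ENum of int | EVar of string | ECall of string * expr list

let rec expr_of_s_exp (e : s_exp) : expr =
  match e with
  | Num n -> ENum n
  | Sym x -> EVar x
  | Lst (Sym f :: args) -> ECall (f, List.map expr_of_s_exp args)
  | _ -> raise (BadSExpression e)

type defn = { name : string; args : string list; body : expr }
type program = { defns : defn list; body : expr }

let program_of_s_exps (exps : s_exp list) : program =
  (* Parameter lists must consist only of symbols. *)
  let rec get_args args =
    match args with
    | Sym v :: args -> v :: get_args args
    | e :: _ -> raise (BadSExpression e)
    | [] -> []
  in
  let get_defn = function
    | Lst [ Sym "define"; Lst (Sym name :: args); body ] ->
        let args = get_args args in
        ({ name; args; body = expr_of_s_exp body } : defn)
    | e -> raise (BadSExpression e)
  in
  (* Definitions are accumulated in reverse, then restored to source order. *)
  let rec go exps defns =
    match exps with
    | [ e ] -> ({ defns = List.rev defns; body = expr_of_s_exp e } : program)
    | d :: exps -> go exps (get_defn d :: defns)
    | [] -> raise (BadSExpression (Sym "empty"))
  in
  go exps []
```

Note that, with this expr_of_s_exp, a program made only of definitions is not rejected at this stage: the last (define ...) is treated as the body and read as a call to a function named define.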

But this was all structural: what about representing function calls themselves? We'll have to add a constructor to our AST.

type expr =
  (* ... *)
  | Call of string * expr list

let rec expr_of_s_exp (e : s_exp) : expr =
  match e with
  (* ... *)
  | Lst (Sym f :: args) ->
      Call (f, List.map expr_of_s_exp args)
  | _ ->
      raise (BadSExpression e)
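Because the new case matches any list whose head is a symbol, it has to come after the special forms (let, if, and so on); otherwise (let ((z 4)) ...) would be parsed as a call to a function named let. A minimal sketch with a hypothetical four-case expr type shows the ordering:

```ocaml
type s_exp = Num of int | Sym of string | Lst of s_exp list

exception BadSExpression of s_exp

(* Hypothetical, trimmed-down AST: just enough to show case ordering. *)
type expr =
  | ENum of int
  | EVar of string
  | Let of string * expr * expr
  | Call of string * expr list

let rec expr_of_s_exp (e : s_exp) : expr =
  match e with
  | Num n -> ENum n
  | Sym x -> EVar x
  (* Special forms are matched first... *)
  | Lst [ Sym "let"; Lst [ Lst [ Sym x; e1 ] ]; e2 ] ->
      Let (x, expr_of_s_exp e1, expr_of_s_exp e2)
  (* ...and any remaining symbol-headed list is a function call. *)
  | Lst (Sym f :: args) -> Call (f, List.map expr_of_s_exp args)
  | _ -> raise (BadSExpression e)
```

A malformed special form such as (let) also falls through to the call case; it is only rejected later, when no definition named let is found.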

Function definitions in the interpreter

Suppose we're in the case where we want to interpret (f e1 e2 e3), where f corresponds to a definition (which we have already parsed into a defn record). Every definition has an arity associated with it; that is, it expects a certain number of arguments. We need to check that, in this case, f expects exactly 3 arguments.

Then we need to interpret the body of this function, replacing its parameters with the values of these arguments. Evaluating the arguments themselves is easy (we just call interp_exp recursively), but then we need to substitute them into the body. Luckily, we already have a way to do this: the environment maps symbol names to values, so we just need to update an environment with a binding for each parameter.

But a question comes up: do we want to update the old environment, or start fresh? Consider the following program:

(define (f x) (+ x z))
(let ((z 4)) (f 2))

Should this program throw an error (z is unbound in the body of f), or evaluate to 6? There are arguments to be made in both directions. To implement the latter option, we would extend our existing environment (where z -> 4) when evaluating the body of f. This behavior is called dynamic scoping, and it can be both useful and confusing. A more principled approach, and one that will be much easier to implement in our compiler, rejects this program by interpreting the body of f in a fresh environment (where x -> 2 and z is unbound). This is known as lexical scoping. We'll take this route.
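The two choices differ in a single line of the interpreter: which environment the parameters are added to. The sketch below makes that a parameter so both behaviors can be run on the example program. The tiny expr type, the Env module (Map.Make (String) standing in for the course's symbol table), and the Unbound exception are hypothetical.

```ocaml
module Env = Map.Make (String)

(* Hypothetical minimal language: enough for the f/z example. *)
type expr =
  | Num of int
  | Var of string
  | Plus of expr * expr
  | Let of string * expr * expr
  | Call of string * expr list

type defn = { name : string; args : string list; body : expr }
type scoping = Lexical | Dynamic

exception Unbound of string

let rec interp (scoping : scoping) (defns : defn list) (env : int Env.t) (e : expr) : int =
  match e with
  | Num n -> n
  | Var x -> (
      match Env.find_opt x env with Some v -> v | None -> raise (Unbound x))
  | Plus (a, b) -> interp scoping defns env a + interp scoping defns env b
  | Let (x, e1, e2) ->
      let v = interp scoping defns env e1 in
      interp scoping defns (Env.add x v env) e2
  | Call (f, args) ->
      let d = List.find (fun d -> d.name = f) defns in
      let vals = List.map (interp scoping defns env) args in
      (* The only difference: start from nothing (lexical) or from the
         caller's environment (dynamic). *)
      let start = match scoping with Lexical -> Env.empty | Dynamic -> env in
      let fenv = List.fold_left2 (fun acc x v -> Env.add x v acc) start d.args vals in
      interp scoping defns fenv d.body

(* (define (f x) (+ x z)) with body (let ((z 4)) (f 2)) *)
let example_defns = [ { name = "f"; args = [ "x" ]; body = Plus (Var "x", Var "z") } ]
let example_body = Let ("z", Num 4, Call ("f", [ Num 2 ]))
```

Under Dynamic the example evaluates to 6; under Lexical it raises Unbound "z".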

let rec interp_exp (defns : defn list) (env : value symtab) (exp : expr) :
    value =
  match exp with
  (* ... *)
  | Call (f, args) when is_defn defns f ->
      let defn = get_defn defns f in
      if List.length args = List.length defn.args then
        let vals = List.map (interp_exp defns env) args in
        let fenv = List.combine defn.args vals |> Symtab.of_list in
        interp_exp defns fenv defn.body
      else raise (BadExpression exp)
  | Call _ ->
      raise (BadExpression exp)

let interp (program : string) : unit =
  let program2 = parse_many program |> program_of_s_exps in
  interp_exp program2.defns Symtab.empty program2.body |> ignore
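To try the call case end to end, here is a small self-contained interpreter. Its expr type (with a generic Prim2 for + and =), the Number/Boolean value type, and the rule that only false is falsy are hypothetical simplifications; Map.Make (String) stands in for the course's Symtab, so Symtab.of_list becomes List.to_seq followed by Symtab.of_seq. The two Call cases follow the code above.

```ocaml
module Symtab = Map.Make (String)

type expr =
  | Num of int
  | Var of string
  | Prim2 of string * expr * expr (* "+" or "=" *)
  | Sub1 of expr
  | If of expr * expr * expr
  | Call of string * expr list

type value = Number of int | Boolean of bool
type defn = { name : string; args : string list; body : expr }

exception BadExpression of expr

let is_defn defns name = List.exists (fun d -> d.name = name) defns
let get_defn defns name = List.find (fun d -> d.name = name) defns

let rec interp_exp (defns : defn list) (env : value Symtab.t) (exp : expr) : value =
  match exp with
  | Num n -> Number n
  | Var x -> (
      match Symtab.find_opt x env with Some v -> v | None -> raise (BadExpression exp))
  | Prim2 (op, e1, e2) -> (
      match (op, interp_exp defns env e1, interp_exp defns env e2) with
      | "+", Number a, Number b -> Number (a + b)
      | "=", Number a, Number b -> Boolean (a = b)
      | _ -> raise (BadExpression exp))
  | Sub1 e -> (
      match interp_exp defns env e with
      | Number n -> Number (n - 1)
      | _ -> raise (BadExpression exp))
  | If (c, t, e) -> (
      (* Assumption: only Boolean false counts as false. *)
      match interp_exp defns env c with
      | Boolean false -> interp_exp defns env e
      | _ -> interp_exp defns env t)
  | Call (f, args) when is_defn defns f ->
      let defn = get_defn defns f in
      (* Arity check: exactly as many arguments as declared parameters. *)
      if List.length args = List.length defn.args then
        let vals = List.map (interp_exp defns env) args in
        (* Lexical scope: the body sees only its own parameters. *)
        let fenv = List.combine defn.args vals |> List.to_seq |> Symtab.of_seq in
        interp_exp defns fenv defn.body
      else raise (BadExpression exp)
  | Call _ -> raise (BadExpression exp)

(* The Fibonacci function from this lecture, written directly as an AST. *)
let fib_defn =
  { name = "fib";
    args = [ "x" ];
    body =
      If
        ( Prim2 ("=", Var "x", Num 0),
          Num 1,
          If
            ( Prim2 ("=", Var "x", Num 1),
              Num 1,
              Prim2
                ( "+",
                  Call ("fib", [ Sub1 (Var "x") ]),
                  Call ("fib", [ Sub1 (Sub1 (Var "x")) ]) ) ) ) }
```

Evaluating Call ("fib", [ Num 4 ]) with [ fib_defn ] gives Number 5, while Call ("fib", []) raises BadExpression because of the arity check.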

Try interpreting a program that computes a large Fibonacci number, such as the one below with a bigger argument. It's slow, but we can now write and run real, nontrivial programs!

(define (fib x) (if (= x 0) 1 (if (= x 1) 1 (+ (fib (sub1 x)) (fib (sub1 (sub1 x)))))))
(print (fib 4))
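Part of the slowness is the algorithm itself: every call with x > 1 makes two more calls. A direct OCaml transcription of fib, instrumented with a counter (hypothetical, not part of the interpreter), shows the call count is 2 * fib(n) - 1, so it grows as fast as the answer does. The interpreter does considerably more work than this per call (building environments, traversing the AST).

```ocaml
(* Counts every call to fib, mirroring the structure of the program above. *)
let calls = ref 0

let rec fib x =
  incr calls;
  if x = 0 then 1 else if x = 1 then 1 else fib (x - 1) + fib (x - 2)

(* Returns (fib n, number of calls made to compute it). *)
let fib_with_count n =
  calls := 0;
  let result = fib n in
  (result, !calls)
```

For example, fib_with_count 4 is (5, 9) and fib_with_count 20 is (10946, 21891).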

And in the compiler

We face the same scoping question in the compiler. Since we chose to use lexical scope in the interpreter, we should do the same, but where should we store the values of the arguments to a function at runtime?

Like let bindings, it makes sense to put these values on the stack. It also makes sense to index them in our symbol table. And like the interpreter, to scope things correctly, we'll need to temporarily use a new symbol table.

Unlike the interpreter, though, we won't be able to compile a function call simply by compiling the function body in place. Think about why not! (Hint: what would happen in the recursive even/odd program?) Sometimes this is possible; doing so is an optimization technique called inlining. But in general we can't.
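One way to see when inlining breaks down: expanding a call to f in place can never finish if f can reach itself through a chain of calls. The sketch below checks that on a call graph. The graph, the is_recursive name, and the double/quad functions are hypothetical; the even/odd entries stand for the program the hint mentions, which this section does not show.

```ocaml
module S = Set.Make (String)

(* Hypothetical call graph: each function with the functions its body calls. *)
let call_graph : (string * string list) list =
  [ ("even", [ "odd" ]); ("odd", [ "even" ]); ("double", []); ("quad", [ "double" ]) ]

let callees g f = try List.assoc f g with Not_found -> []

(* Depth-first search from f's callees; true if we ever get back to f. *)
let is_recursive g f =
  let rec visit seen = function
    | [] -> false
    | h :: rest ->
        if h = f then true
        else if S.mem h seen then visit seen rest
        else visit (S.add h seen) (callees g h @ rest)
  in
  visit S.empty (callees g f)
```

Here even and odd are (mutually) recursive, so inlining either one would unfold forever, while quad and double could be expanded in place.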

Instead, we'll give each function its own label, and call it almost like it's an external C function. This means that each time we call one of our newly defined functions, it will have its own stack frame.

First, we'll have to update our main compiler function to deal with definitions, as we did in the interpreter.

let compile (prog : program) : string =
  [ Global "entry"
  ; Extern "error"
  ; Extern "read_num"
  ; Extern "print_newline"
  ; Extern "print_value"
  ; Label "entry" ]
  @ compile_exp prog.defns Symtab.empty (-8) prog.body
  @ [Ret]
  @ List.concat_map (compile_defn prog.defns) prog.defns
  |> List.map string_of_directive
  |> String.concat "\n"

let compile_to_file (program : string) : unit =
  let file = open_out "program.s" in
  parse_many program |> program_of_s_exps |> compile |> output_string file ; close_out file

We've outsourced dealing with definitions to a new function, compile_defn, which labels and implements a single definition. The code it produces assumes that it has a fresh stack frame: just as at entry, rsp should point to the return address, and rsp-8 and everything above it (toward lower addresses, since the stack grows downward) should be free for the new function to use.

let compile_defn defns {name; args; body} =
  let ftab =
    args
    |> List.mapi (fun i arg -> (arg, -8 * (i + 1)))
    |> Symtab.of_list
  in
  [Label (defn_label name)]
  @ compile_exp defns ftab (-8 * (List.length args + 1)) body
  @ [Ret]
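The frame layout compile_defn sets up can be checked in isolation. In this sketch, frame_layout is a hypothetical name and Map.Make (String) stands in for the course's Symtab: the parameters take the slots right after the return address at [rsp], and the body's first free slot comes right after the last parameter.

```ocaml
module Symtab = Map.Make (String)

(* Returns the parameter table and the stack_index used to compile the body. *)
let frame_layout (args : string list) : int Symtab.t * int =
  let ftab =
    args
    |> List.mapi (fun i arg -> (arg, -8 * (i + 1))) (* i-th parameter at rsp - 8(i+1) *)
    |> List.to_seq
    |> Symtab.of_seq
  in
  let body_stack_index = -8 * (List.length args + 1) in
  (ftab, body_stack_index)
```

For a definition like (define (g x y z) ...), x, y and z live at offsets -8, -16 and -24, and the body starts compiling with stack_index -32.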

This uses an auxiliary function defn_label which strips from the function name any characters that are not allowed in label names.
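The notes don't show defn_label itself; here is one hypothetical way to write it. It keeps only ASCII letters, digits and underscores (a conservative set accepted by common x86 assemblers) and adds a prefix, an assumption chosen so that a user function cannot start with a digit or collide with labels like entry or the extern names.

```ocaml
(* Hypothetical sketch of defn_label: strip unsafe characters, add a prefix. *)
let defn_label (name : string) : string =
  let is_safe c =
    (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c = '_'
  in
  "function_" ^ (name |> String.to_seq |> Seq.filter is_safe |> String.of_seq)
```

One caveat of stripping rather than escaping: distinct names such as sub-total and subtotal map to the same label, so a compiler using this scheme would want to reject or rename such collisions.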

Then we can update the main compiler loop. To execute a defined function, we compute where the bottom of its stack frame will be, move the values of the arguments into this part of the stack, and then Call the function we defined.

let rec compile_exp (defns : defn list) (tab : int symtab) (stack_index : int)
    (exp : expr) : directive list =
  match exp with
  (* ... *)
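  | Call (f, args) when is_defn defns f ->
      let defn = get_defn defns f in
      (* Reject calls with the wrong number of arguments. *)
      if List.length args <> List.length defn.args then raise (BadExpression exp)
      else
        (* Offset (from our rsp) that rsp is moved to just before Call. *)
        let stack_base = align_stack_index (stack_index + 8) in
        (* Evaluate each argument and store it where the callee expects it.
           Temporaries for argument i start at argument i's own slot, so
           earlier arguments are never overwritten. *)
        let compiled_args =
          args
          |> List.mapi (fun i arg ->
                 compile_exp defns tab (stack_base - (8 * (i + 2))) arg
                 @ [ Mov (stack_address (stack_base - (8 * (i + 2))), Reg Rax) ])
          |> List.concat
        in
        compiled_args
        @ [ Add (Reg Rsp, Imm stack_base)
          ; Call (defn_label f)
          ; Sub (Reg Rsp, Imm stack_base) ]
  | Call _ -> raise (BadExpression exp)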

There's some slightly tricky stack pointer arithmetic here, so let's be careful.

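When we execute the Call assembly directive, we want rsp to contain the address of the topmost 16-byte misaligned spot on the stack such that all memory above it is free. Here alignment is measured by the offset from the rsp our function started with: a spot is misaligned if its offset is an odd multiple of 8. Call will decrement rsp by 8 (making the offset a multiple of 16 again, and pointing to an empty cell) and insert a return address at [rsp]. Because our own rsp was itself 8 bytes off a 16-byte boundary on entry (our caller's Call pushed a return address), a misaligned offset is in fact a 16-byte-aligned address at the moment of the call, and the callee begins in exactly the situation we began in.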

stack_base, computed in the OCaml code above as align_stack_index (stack_index + 8), is exactly the amount we need to add to rsp to achieve this. Before we start compiling the arguments to the function call, the first free spot on the stack is [rsp + stack_index]. If stack_index is a multiple of 16, then stack_index + 8 is misaligned, and everything above rsp + stack_index + 8 is empty. Otherwise, stack_index itself (i.e. (stack_index + 8) - 8) is misaligned, and rsp + stack_index has this property.
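align_stack_index is not defined in this section. The sketch below assumes a definition consistent with the two cases just described: for negative offsets that are multiples of 8, it returns its input if that is already an odd multiple of 8, and otherwise moves 8 bytes further from rsp. Tabulating stack_base for the first few stack indices shows both cases.

```ocaml
(* Assumed definition; OCaml's mod keeps the sign of the dividend, so for
   negative offsets the result is either 0 or -8. *)
let align_stack_index (stack_index : int) : int =
  if stack_index mod 16 = -8 then stack_index else stack_index - 8

let stack_base (stack_index : int) : int = align_stack_index (stack_index + 8)
```

For stack_index = -16 (a multiple of 16), stack_base is -8, so the return address lands in the free slot at -16. For stack_index = -8, stack_base is also -8: rsp moves to the free slot at -8, which is left unused as padding, and the return address goes at -16.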

So, after we set rsp to rsp + stack_base, Call decrements rsp by another 8 and inserts a return address. The first free spot on the stack is 8 bytes above this new rsp; in terms of our original rsp, it is rsp + stack_base - 16. It is at this spot that we store the first argument to our function. Subsequent arguments go on the stack above it: the ith argument (counting from 0, as List.mapi does) goes at rsp + stack_base - 8*(i + 2). This is exactly where compile_defn looks for it: the callee's rsp is rsp + stack_base - 8, and its ftab places argument i at offset -8*(i + 1) from there.
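The whole argument can be checked mechanically. This sketch models a call site with a concrete caller rsp that is 8 mod 16 (as at any function entry) and, for a range of stack indices and argument counts, confirms four things: rsp is 16-byte aligned at the Call, the callee's rsp is again 8 mod 16, the return address lands in a free slot, and every argument is stored exactly where the callee's ftab expects it. The helper names and the align_stack_index definition are assumptions, as above.

```ocaml
(* Assumed definition, as in the text's stack_base computation. *)
let align_stack_index (stack_index : int) : int =
  if stack_index mod 16 = -8 then stack_index else stack_index - 8

(* Offset, relative to the caller's rsp, where the caller stores argument i. *)
let caller_arg_offset ~stack_index i = align_stack_index (stack_index + 8) - (8 * (i + 2))

(* Offset, relative to the callee's rsp, where compile_defn's ftab puts argument i. *)
let callee_arg_offset i = -8 * (i + 1)

let call_site_ok ~caller_rsp ~stack_index ~nargs =
  let stack_base = align_stack_index (stack_index + 8) in
  let rsp_at_call = caller_rsp + stack_base in (* after Add rsp, stack_base *)
  let callee_rsp = rsp_at_call - 8 in (* Call pushes the return address *)
  let args_agree =
    List.for_all
      (fun i -> caller_rsp + caller_arg_offset ~stack_index i = callee_rsp + callee_arg_offset i)
      (List.init nargs Fun.id)
  in
  rsp_at_call mod 16 = 0 (* aligned at the call instruction *)
  && callee_rsp mod 16 = 8 (* callee starts like we did *)
  && callee_rsp <= caller_rsp + stack_index (* return address slot was free *)
  && args_agree
```

Running it with caller_rsp = 0x7ff8 for stack indices -8 through -160 and zero to five arguments, every check passes.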