Tail calls

Before we get started, let's run through an example using the compiler as we left it at the end of last time:

(define (f x y) (+ x y))
(let ((a 2)) (let ((b 10)) (print (f a b))))

The key terminology here is that of a stack frame. Each time we begin a function call, we want the callee to have a fresh frame. Its arguments sit immediately below the stack pointer (at [rsp - 8], [rsp - 16], and so on). Everything the caller still needs lives above the callee's stack pointer, so nothing the callee writes below it can clobber the caller's data. See the lecture capture for details.

Uh oh, segfaults

We left off having implemented function definitions in the interpreter and compiler.

Here's a function we can now run:

(define (even n) (if (zero? n) true (odd (sub1 n))))
(define (odd n) (if (zero? n) false (not (even n))))
(print (even (read-num)))

But try running this on large inputs (say, 2000000) in either the interpreter or compiler. We see some strange behavior in both cases: OCaml errors, segmentation faults, ...

What's happening? Stack overflows! When executing this program, both the code our compiler generates and our interpreter (compiled by OCaml's compiler) store values on the stack for every pending function call. But the amount of stack space they use grows linearly with the input value. Eventually they run out of space on the stack and error out.

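There's a small tweak we can make to this program that fixes the issue in the interpreter, but not the compiler. Instead of negating the result of (even n), odd now calls even on (sub1 n) directly, so neither function has any work left to do after its recursive call returns: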

(define (even n) (if (zero? n) true (odd (sub1 n))))
(define (odd n) (if (zero? n) false (even (sub1 n))))
(print (even (read-num)))

Here's another example, a program to sum the first n natural numbers:

(define (sum n)
  (if (zero? n)
    n
    (+ n (sum (sub1 n)))))
(print (sum (read-num)))

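Calling (sum 1000000) again overflows the stack: every recursive call has to return so that its result can be added to n. And again, in our interpreter, we can fix this. We carry the running total in an extra accumulator argument, so that the recursive call is the very last thing sum does: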

(define (sum n total)
  (if (zero? n)
    total
    (sum (sub1 n) (+ n total))))
(print (sum (read-num) 0))

But our compiler still segfaults. What's going on here?

The question isn't why our compiler segfaults–why wouldn't it? It has a finite amount of stack space, after all, and we're doing a lot of function calls here. Each call to sum adds a few values to the stack. So why isn't the interpreter overflowing the stack?

Let's postpone the answer to that question. First, let's look at the assembly instructions we produce for our little sum program:

global _entry
extern _error
extern _read_num
extern _print_value
extern _print_newline
_entry:
    mov [rsp + -24], rdi
    add rsp, -24
    call _read_num
    sub rsp, -24
    mov rdi, [rsp + -24]
    mov [rsp + -24], rax
    mov rax, 0
    mov [rsp + -32], rax
    add rsp, -8
    call _function_sum_957043065
    sub rsp, -8
    mov [rsp + -8], rdi
    mov rdi, rax
    add rsp, -8
    call _print_value
    sub rsp, -8
    mov rdi, [rsp + -8]
    mov rax, 159
    ret
_function_sum_957043065:
    mov rax, [rsp + -8]
    cmp rax, 0
    mov rax, 0
    setz al
    shl rax, 7
    or rax, 31
    cmp rax, 31
    jz _else__2
    mov rax, [rsp + -16]
    jmp _continue__3
_else__2:
    mov rax, [rsp + -8]
    sub rax, 4
    mov [rsp + -40], rax
    mov rax, [rsp + -8]
    mov [rsp + -48], rax
    mov rax, [rsp + -16]
    mov r8, [rsp + -48]
    add rax, r8
    mov [rsp + -48], rax
    add rsp, -24
    call _function_sum_957043065
    sub rsp, -24
_continue__3:
    ret

Look at the end of the program, there. In the else case of our conditional expression, we're pushing a bunch of arguments to the stack and then calling our function. After we come back from that call we'll execute two instructions: we'll restore the stack pointer and then return again. In other words, we're really not doing any work once our function returns! We have this stack frame storing our local variables and function parameters, but we're not accessing any of them! It seems like once we're doing that function call, we shouldn't really need an extra stack frame; instead, we should be able to just re-use the one we already have–essentially, replacing the call with a jmp.

When can we do this? Well, we should be able to do this whenever we don't have any work to do after a function is called. If we're just going to return whatever another function returns, without modification, we should be able to re-use our stack frame: we're guaranteed not to need any of the things we've stored there.

Function calls in this position are called tail calls. Take a look at this little program:

(define (f x) (+ 3 x))
(define (sum-f n total)
  (if (zero? n)
    total
    (sum-f (sub1 n) (+ (f n) total))))
(print (sum-f (read-num) 0))

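Is the call to f in tail position? No–after f returns, we still have to add its result to total and then make the recursive call to sum-f. The recursive call to sum-f itself, however, is in tail position.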

How about this program?

(define (even n) (if (zero? n) true (odd (sub1 n))))
(define (odd n) (if (zero? n) false (not (even n))))
(print (even (read-num)))

Is even's call to odd a tail call? Yes. Is odd's call to even a tail call? No–after even returns, there's more work to do (specifically, negating the value). We've decided that we want to compile function calls differently when they are in tail position. So, let's add an argument to the compiler that will be true if the expression being compiled is in tail position and false otherwise. We'll need to add it to every call to compile_exp; here are some of the more interesting ones:

compile.ml

(* Compile [exp] into directives that leave its value in rax.
   - [defns]: the program's function definitions.
   - [tab]: maps each variable in scope to its stack offset relative to rsp.
   - [stack_index]: the next free stack slot; subexpressions may use it and
     the slots below it for temporaries.
   - [is_tail]: true when the value of [exp] is returned unchanged by the
     enclosing function (or main program) with no further work, i.e. [exp]
     is in tail position.
   This excerpt shows only cases where the handling of [is_tail] is
   interesting; the remaining cases are elided. *)
let rec compile_exp (defns : defn list) (tab : int symtab) (stack_index : int)
    (exp : expr) (is_tail : bool) : directive list =
  match exp with
  (* ... *)
  | Prim1 (Print, e) ->
      (* Printing happens after [e] is evaluated, so [e] is never in tail
         position. rdi is saved to the free slot and restored afterwards
         because it is used to pass [e]'s value to print_value. rsp is moved
         past the slots in use via align_stack_index (presumably to keep the
         call aligned) and put back after the call. The whole expression
         evaluates to true. *)
      compile_exp defns tab stack_index e false
      @ [ Mov (stack_address stack_index, Reg Rdi)
        ; Mov (Reg Rdi, Reg Rax)
        ; Add (Reg Rsp, Imm (align_stack_index stack_index))
        ; Call "print_value"
        ; Sub (Reg Rsp, Imm (align_stack_index stack_index))
        ; Mov (Reg Rdi, stack_address stack_index)
        ; Mov (Reg Rax, operand_of_bool true) ]
  | Do exps when List.length exps > 0 ->
      (* Only the last expression's value is the value of the [do]; the
         earlier ones are evaluated for their effects, so only the last one
         inherits [is_tail]. *)
      List.mapi
        (fun i exp ->
          compile_exp defns tab stack_index exp
            (if i = List.length exps - 1 then is_tail else false))
        exps
      |> List.concat
  | If (test_exp, then_exp, else_exp) ->
      (* The test's value is inspected afterwards (compared against false),
         so it is not in tail position. Whichever branch runs produces the
         value of the whole [if] directly, so both branches inherit
         [is_tail]. Any test value other than false takes the then-branch. *)
      let else_label = Util.gensym "else" in
      let continue_label = Util.gensym "continue" in
      compile_exp defns tab stack_index test_exp false
      @ [Cmp (Reg Rax, operand_of_bool false); Jz else_label]
      @ compile_exp defns tab stack_index then_exp is_tail
      @ [Jmp continue_label] @ [Label else_label]
      @ compile_exp defns tab stack_index else_exp is_tail
      @ [Label continue_label]
  | Prim2 (Plus, e1, e2) ->
      (* The addition happens after both operands are evaluated, so neither
         is in tail position. [e1]'s value is parked at [stack_index] while
         [e2] is compiled one slot lower, so [e2]'s temporaries cannot
         overwrite it. *)
      compile_exp defns tab stack_index e1 false
      @ [Mov (stack_address stack_index, Reg Rax)]
      @ compile_exp defns tab (stack_index - 8) e2 false
      @ [Mov (Reg R8, stack_address stack_index)]
      @ [Add (Reg Rax, Reg R8)]

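In our recursive calls, we have essentially two cases:

- If the subexpression's value is used for further work or is discarded, it is not in tail position, and we pass false. This covers an operand of a primitive like print or +, the test of an if, or a non-final expression in a do.
- If the subexpression's value becomes the value of the whole expression with no further work, it is in tail position exactly when the enclosing expression is, and we pass is_tail through unchanged. This covers a branch of an if, or the last expression in a do.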

When we call compile_exp, either to compile the body of a function or our main program body, is_tail will start out as true.

compile.ml

(* Compile one function definition to its label, its body, then [Ret].
   Parameter [i] (0-based) is read from [rsp - 8 * (i + 1)], the slot a
   caller fills before transferring control to the label. Temporaries used
   by the body start in the slot just below the last parameter. The body's
   value is returned as-is, so the body is compiled in tail position. *)
let compile_defn defns defn =
  let slot_of_param i = -8 * (i + 1) in
  let ftab =
    Symtab.of_list (List.mapi (fun i arg -> (arg, slot_of_param i)) defn.args)
  in
  let first_free_slot = slot_of_param (List.length defn.args) in
  let body = compile_exp defns ftab first_free_slot defn.body true in
  (Label (defn_label defn.name) :: body) @ [Ret]

(* Compile a whole program to assembly source text.
   Layout: the global/extern declarations, then the main body under the
   [entry] label followed by [Ret], then every function definition compiled
   by [compile_defn].
   The main body starts with an empty symbol table and [rsp - 8] as its
   first free slot. It is compiled in tail position because its value is
   returned directly by [Ret]. A tail call there therefore jumps into the
   callee, whose own [Ret] returns to [entry]'s caller.
   NOTE(review): OCaml does not specify the evaluation order of the operands
   of the [@] chain below, and [compile_exp] draws labels from
   [Util.gensym]. Generated label names may therefore depend on that
   order. *)
let compile (program : s_exp list) : string =
  let defns, body = defns_and_body program in
  [ Global "entry"
  ; Extern "error"
  ; Extern "read_num"
  ; Extern "print_value"
  ; Extern "print_newline"
  ; Label "entry" ]
  (* Main program body: tail position, empty environment. *)
  @ compile_exp defns Symtab.empty (-8) body true
  @ [Ret]
  @ (List.map (compile_defn defns) defns |> List.concat)
  (* [@] binds tighter than [|>], so the whole directive list is rendered. *)
  |> List.map string_of_directive
  |> String.concat "\n"

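Now we need to use is_tail to reuse the current stack frame when a function call is in tail position. We can add another case like this, placed before the existing case for ordinary function calls. OCaml tries match cases in order, so the more specific case has to come first: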

compile.ml

(* Tail-call case of [compile_exp] (other cases elided). A call to a
   defined function in tail position reuses the current stack frame. The
   arguments are written into this frame's parameter slots and control is
   transferred with [Jmp] rather than [Call], leaving rsp unchanged. The
   callee's eventual [Ret] then returns directly to our own caller. *)
let rec compile_exp (defns : defn list) (tab : int symtab) (stack_index : int)
    (exp : expr) (is_tail : bool) : directive list =
  match exp with
       (* ... *)
  (* OCaml tries match cases in order, so this guarded case must come before
     any case that matches every [Call]; otherwise it would never be
     reached. *)
  | Call (f, args) when is_defn defns f && is_tail ->
      let defn = get_defn defns f in
      (* A call with the wrong number of arguments is rejected. *)
      if List.length args = List.length defn.args then
        (* Step 1: argument [i] is never itself in tail position. It is
           compiled with [stack_index - 8 * i] as its first free slot, and
           its value is then stored in that slot.
           - Its temporaries sit at or below that slot, so they cannot
             clobber arguments already stored above it.
           - Every argument is evaluated before any parameter slot is
             overwritten, so arguments may freely read our current
             parameters. *)
        let compiled_args =
          args
          |> List.mapi (fun i arg ->
                 compile_exp defns tab (stack_index - (8 * i)) arg false
                 @ [Mov (stack_address (stack_index - (8 * i)), Reg Rax)])
          |> List.concat
        in
        (* Step 2: copy argument [i] into parameter slot [rsp - 8 * (i + 1)],
           where the callee expects it. Copying in increasing [i] is safe
           provided [stack_index <= -8], as with the initial values used in
           [compile] and [compile_defn]. Under that condition, argument [i]'s
           destination can only coincide with the temporary slot of some
           argument [j <= i], which has already been read. *)
        let moved_args =
          args
          |> List.mapi (fun i _ ->
                 [ Mov (Reg R8, stack_address (stack_index - (8 * i)))
                 ; Mov (stack_address ((i + 1) * -8), Reg R8) ])
          |> List.concat
        in
        compiled_args @ moved_args @ [Jmp (defn_label f)]
      else raise (BadExpression exp)

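We first compile each argument and store its value in the next available stack slot. Once every argument has been evaluated, we copy them into the argument slots at the top of the current frame ([rsp - 8], [rsp - 16], and so on), since that's where the callee expects its arguments. Overwriting our own parameters at this point is safe: all of the arguments have already been computed, and nothing runs in this frame after the jump. Then we can just jump to the right label. Notice that we're not changing rsp when we call a function in tail position: this is exactly what it means for us to reuse a stack frame.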

Tail calls in the interpreter

As we saw, our interpreter already seems to be doing this–it didn't overflow the stack when we interpreted a tail-recursive program. Why? Well, the interpreter is written in OCaml, which properly implements tail calls by reusing stack frames. So as long as the interpreter's calls to itself are in tail position, everything will work out!