First-class functions

First-class functions

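Most functional programming languages (really, most modern programming languages of all stripes) support some form of first-class functions. Right now, functions in our language are totally separate from values. We can define them and call them, but we can't use them as values: we can't pass a function as an argument to another function, return one from a function, store one in a variable or data structure, or create a new, anonymous function on the fly.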

Here are some programs, written in a slightly extended version of our language, that use first-class functions:

(define (f g) (g 2))
(define (mul2 x) (+ x x))
(print (f mul2))



(define (f g) (g 2))
(print (f (lambda (x) (+ x x))))


(define (f g) (g 2))
(let ((y 3)) (print (f (lambda (x) (+ x y)))))

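Our first goal is to support the first program—we'll be able to pass functions around like any other value. Next time we'll support the other two programs, which create anonymous functions with `lambda`; the lambda in the third program also refers to `y`, a variable bound outside of it.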

Extending the AST

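It's easy enough to change our AST to support this: the first component of `Call` is now an arbitrary expression, so the function being called can be computed just like any other value.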

ast.ml

(* Source-language expressions. To support first-class functions, the callee
   of a [Call] is an arbitrary expression, so the function being called can
   itself be computed. *)
type expr =
  | Prim0 of prim0
  | Prim1 of prim1 * expr
  | Prim2 of prim2 * expr * expr
  | Let of string * expr * expr
  | If of expr * expr * expr
  | Do of expr list
  | Num of int
  | Var of string
  (* [Call (f, args)]: [f] evaluates to the function being called and [args]
     are the argument expressions. *)
  | Call of expr * expr list
  | True
  | False

(* Convert a parsed s-expression into an [expr]. Any non-empty list not
   matched by an earlier (elided) case is a call; its head is converted as an
   ordinary expression, so the callee may be any expression. Anything else is
   rejected with [BadSExpression]. *)
let rec expr_of_s_exp : s_exp -> expr = function
  (* some cases elided ... *)
  | Lst (head :: rest) ->
      let arg_exprs = List.map expr_of_s_exp rest in
      Call (expr_of_s_exp head, arg_exprs)
  | unrecognized -> raise (BadSExpression unrecognized)

We'll modify this a bit when we introduce lambda expressions!

Extending the interpreter

We'll need a new type of value in our interpreter. It's up for debate how we should represent these function values!

The main idea here: a function value is just the name of a top-level definition, so we can reuse the definition list to look up its parameters and body whenever it is called.

(* Runtime values produced by the interpreter. *)
type value =
  | Number of int
  | Boolean of bool
  | Pair of (value * value)
  (* A function value is just the name of a top-level definition; calling it
     looks that definition up in the definition list (see [interp_exp]). *)
  | Function of string

(* Render a runtime value as text. Function values have no printable
   structure, so every function prints as "<function>" (the same text the C
   runtime prints for a function). *)
let rec string_of_value (v : value) : string =
  (* The [match] is required: without it the [|] arms below do not parse. *)
  match v with
  (* some cases elided ... *)
  | Function _ -> "<function>"

(* Evaluate [exp] in [env], using [defns] to resolve top-level functions.
   Raises [BadExpression exp] when a call's callee is not a known function or
   the number of arguments does not match the definition. *)
let rec interp_exp (defns : defn list) (env : value symtab) (exp : expr) : value
    =
  match exp with
  (* some cases elided ... *)
  | Call (callee, arg_exps) -> (
      (* Arguments are evaluated first, then the callee -- the same order the
         compiler uses. *)
      let arg_vals = List.map (fun e -> interp_exp defns env e) arg_exps in
      match interp_exp defns env callee with
      | Function fname when is_defn defns fname ->
          let defn = get_defn defns fname in
          if List.length arg_exps <> List.length defn.args then
            (* Arity mismatch. *)
            raise (BadExpression exp)
          else
            (* The callee's environment contains only its own parameters. *)
            let call_env = Symtab.of_list (List.combine defn.args arg_vals) in
            interp_exp defns call_env defn.body
      | _ -> raise (BadExpression exp))
  | Var name when is_defn defns name ->
      (* Naming a top-level definition produces a function value. *)
      Function name

Extending the compiler

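Just as in the interpreter, we use a function's name to find it—here, to find the location of that function's code. So at runtime, a function value is a pointer to the function's code with `fn_tag` stored in its low bits to mark it as a function. This works because `compile_defn` aligns every function's label to 8 bytes, so those low bits of the address are otherwise zero.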

(* Tag stored in the low bits of a function value (a code pointer). *)
let fn_tag = 0b110

(* [ensure_fn op] emits a runtime check that [op] holds a function value. It
   masks the value with [heap_mask] and compares the result to [fn_tag],
   jumping to "error" on a mismatch. Clobbers R8 and the flags. *)
let ensure_fn (op : operand) : directive list =
  let scratch = Reg R8 in
  let tag_check = [And (scratch, Imm heap_mask); Cmp (scratch, Imm fn_tag)] in
  (Mov (scratch, op) :: tag_check) @ [Jnz "error"]

(* Compile [exp] to directives that leave its value in rax. [tab] maps
   variables to stack offsets, [stack_index] is the next free stack offset,
   and [is_tail] says whether [exp] is in tail position.
   NOTE(review): unlike the interpreter, which raises [BadExpression] on an
   arity mismatch, compiled calls perform no arity check at runtime. *)
let rec compile_exp (defns : defn list) (tab : int symtab) (stack_index : int)
    (exp : expr) (is_tail : bool) : directive list =
  match exp with
  (* some cases elided ... *)
  | Call (f, args) when not is_tail ->
      (* Non-tail call. Stack offsets are negative (compile_defn starts at
         -8 * (n + 1)), so [Add (Reg Rsp, Imm stack_base)] moves rsp down.
         The call then pushes the return address at stack_base - 8. That puts
         argument i (stored at stack_base - 8 * (i + 2)) at the callee's
         offset -8 * (i + 1), which is exactly where compile_defn expects it.
         NOTE(review): align_stack_index is presumably what keeps rsp aligned
         at the call -- confirm. *)
      let stack_base = align_stack_index (stack_index + 8) in
      let compiled_args =
        args
        |> List.mapi (fun i arg ->
               compile_exp defns tab (stack_base - (8 * (i + 2))) arg false
               @ [Mov (stack_address (stack_base - (8 * (i + 2))), Reg Rax)])
        |> List.concat
      in
      (* Evaluate the arguments left to right, then the callee into the next
         free slot. Check its tag and strip it to get the raw code address. *)
      compiled_args
      @ compile_exp defns tab 
          (stack_base - (8 * (List.length args + 2))) 
          f false
      @ ensure_fn (Reg Rax)
      @ [Sub (Reg Rax, Imm fn_tag)]
      @ [ Add (Reg Rsp, Imm stack_base)
        ; ComputedCall (Reg Rax)
        ; Sub (Reg Rsp, Imm stack_base) ]
  | Call (f, args) ->
      (* Tail call. Evaluate the arguments into temporaries starting at
         [stack_index], then copy them into this frame's own argument slots
         (-8 * (i + 1)) and jump, so the callee reuses the current return
         address. Copying in increasing i is safe as long as
         stack_index <= -8: every temporary still to be read then lies below
         the destination slot being written. *)
      let compiled_args =
        args
        |> List.mapi (fun i arg ->
               compile_exp defns tab (stack_index - (8 * i)) arg false
               @ [Mov (stack_address (stack_index - (8 * i)), Reg Rax)])
        |> List.concat
      in
      let moved_args =
        args
        |> List.mapi (fun i _ ->
               [ Mov (Reg R8, stack_address (stack_index - (8 * i)))
               ; Mov (stack_address ((i + 1) * -8), Reg R8) ])
        |> List.concat
      in
      (* NOTE(review): the "+ 2" in the callee's stack index looks copied from
         the non-tail case; [stack_index - (8 * List.length args)] would be
         enough here. It only leaves two slots unused. The callee pointer stays
         in rax while moved_args uses R8. *)
      compiled_args
      @ compile_exp defns tab 
          (stack_index - (8 * (List.length args + 2))) 
          f false
      @ ensure_fn (Reg Rax)
      @ [Sub (Reg Rax, Imm fn_tag)]
      @ moved_args @ [ComputedJmp (Reg Rax)]
  | Var var when is_defn defns var ->
      (* A reference to a top-level definition becomes a tagged code pointer:
         the label's address with fn_tag OR'ed into its low bits. Those bits
         are free because compile_defn aligns labels to 8 bytes. *)
      [LeaLabel (Reg Rax, defn_label var); Or (Reg Rax, Imm fn_tag)]

(* Emit the code for one top-level definition. Parameter i lives at stack
   offset -8 * (i + 1), so the body's free stack slots start just past the last
   parameter. The body is compiled in tail position and followed by [Ret]. The
   label is 8-byte aligned so its address has free low bits for the function
   tag. *)
let compile_defn defns defn =
  let arity = List.length defn.args in
  let param_slots =
    Symtab.of_list (List.mapi (fun i name -> (name, -8 * (i + 1))) defn.args)
  in
  let body_code =
    compile_exp defns param_slots (-8 * (arity + 1)) defn.body true
  in
  (Align 8 :: Label (defn_label defn.name) :: body_code) @ [Ret]

We should also update our runtime to print functions:

/* Must match fn_tag in the compiler. Note: 0b binary literals are a GCC/Clang
   extension before C23. */
#define fn_tag 0b110

void print_value(uint64_t value) {
    ...
  } else if ((value & heap_mask) == fn_tag) {
    /* A function value is a tagged code pointer; there is nothing more
       useful to show, so print a fixed placeholder. */
    printf("<function>");
  }

Note: we also need to add `-no-pie` to our linker call.