Skip to content

Commit

Permalink
Merge pull request #22 from billhails/anf-conversion
Browse files Browse the repository at this point in the history
Anf conversion
  • Loading branch information
billhails authored Nov 26, 2023
2 parents 93bfb45 + 1828927 commit cc0e9e2
Show file tree
Hide file tree
Showing 31 changed files with 2,482 additions and 417 deletions.
86 changes: 86 additions & 0 deletions docs/ANF.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# ANF Conversion

Another [blog post](https://matt.might.net/articles/a-normalization/)
from Matt Might describes a "simple" algorithm for doing A-Normalization,
but that algorithm uses continuations which aren't available in C, and so
I'll have to reverse-engineer a solution.

Probably best to do specific cases.

## 1
```scheme
(a (b c)) => (let (t$1 (b c))
(a t$1))
```
So on encountering an application, descend into the arguments.
If you find an inner function application, generate a symbol, replace the
inner application with the symbol, and wrap the outer application with a let
binding of that symbol to the inner application.

## 2
```scheme
(a (b c) (d e)) => (let (t$1 (d e))
(let (t$2 (a b))
(a t$2 t$1)))
```
The same logic applies here, the outer application is wrapped in let bindings.
Worth noting the recursion is like `foldr`, the let bindings are constructed
right to left on the way out of the recursion.

## 3
```scheme
(a (b (c d))) => (let (t$2 (c d))
(let (t$1 (b t$1))
(a t$1)))
```
Ok in this case we need to be a bit more careful, the let bindings being
extended remain the bindings for the `a` application, we can't just treat
`b` as an outer application or we'd end up with a `let` inside the arguments
to `a`.

It might be possible to perform that transformation in two stages?
```scheme
(a (b (c d))) => (let (t$1 (b (c d))) => (let (t$2 (c d))
(a t$1)) (let (t$1 (b t$2))
(a t$1)))
```

regardless of the details, the plan is to recurse into the leaves of the arguments, replacing
and constructing the let bindings on the way back out. so in processing `(a (b (c d)))` the
`t$2` binding for `(c d)` must be in scope when `(b t$2)` is replaced.

| Recursion | Let Bindings |
| --------- | ------------ |
| `(a (b (c d)))` | `...` |
| `(b (c d))` | `...` |
| `(c d)` | `...` |
| `t$1` | `(let (t$1 (c d)) ... )` |
| `(b t$1)` | `(let (t$1 (c d)) ... )` |
| `t$2` | `(let (t$1 (c d)) (let (t$2 (b t$1)) ... ))` |
| `(a t$2)` | `(let (t$1 (c d)) (let (t$2 (b t$1)) (a t$2)))` |

or the other way, where we recurse on the bindings

| Expression | Let Bindings |
| --------- | ------------ |
| `(a (b (c d)))` | `...` |
| `(a t$1)` | `(let (t$1 (b (c d))) (a t$1))` |
| `(a t$1)` | `(let (t$2 (c d)) (let (t$1 (b t$2)) (a t$1)))` |

The second approach seems a little more intuitive as we're prepending to the let bindins and the
body of the call is fully substituted in step 1

so informally:
1. Walk the application, replacing any cexp with a fresh
symbol and binding that symbol to the cexp.
2. Iteratively walk each application bound in step 1.
3. Stop when the iteration binds no new variables.

Not sure that's quite right though, as nests of primitive applications are fine as aexps, but then
again the recursion is not limited to just the top level of an application.

In terms of types, the cexp starts out as a LamExp which is being transformed into an Exp.
The result of walking a LamExp application should be an Exp, with variables substituted,
but the substitutions temporarily bound before being assigned to let bindings are still LamExps.
Let bindings are only constructed from sub-applications once the sub-application is translated
and the next iteration of bindings prepared for transformation.
4 changes: 3 additions & 1 deletion fn/colours.fn
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ let
typedef colour { red | green | blue }

fn tostr {
(red()) { "red" }
(red) { "red" }
(green) { "green" }
(blue) { "blue" }
}
in
tostr
1 change: 1 addition & 0 deletions fn/if.fn
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
if (true and true) { 10 } else { 20 }
20 changes: 10 additions & 10 deletions fn/interpreter.fn
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ let

// an interpreter
fn eval {
(addition(l, r), e) { add(eval(l, e), eval(r, e)) }
(subtraction(l, r), e) { sub(eval(l, e), eval(r, e)) }
(multiplication(l, r), e) { mul(eval(l, e), eval(r, e)) }
(division(l, r), e) { div(eval(l, e), eval(r, e)) }
(addition(l, r), e) { opAdd(eval(l, e), eval(r, e)) }
(subtraction(l, r), e) { opSub(eval(l, e), eval(r, e)) }
(multiplication(l, r), e) { opMul(eval(l, e), eval(r, e)) }
(division(l, r), e) { opDiv(eval(l, e), eval(r, e)) }
(i = number(_), e) { i }
(symbol(s), e) { lookup(s, e) }
(conditional(test, pro, con), e) { cond(test, pro, con, e) }
(conditional(test, pro, con), e) { opCond(test, pro, con, e) }
(l = lambda(_, _), e) { closure(l, e) }
(application(function, arg), e) { apply(eval(function, e), eval(arg, e)) }
}
Expand All @@ -35,15 +35,15 @@ let
}

// built-ins
fn add (number(a), number(b)) { number(a + b) }
fn opAdd (number(a), number(b)) { number(a + b) }

fn sub (number(a), number(b)) { number(a - b) }
fn opSub (number(a), number(b)) { number(a - b) }

fn mul (number(a), number(b)) { number(a * b) }
fn opMul (number(a), number(b)) { number(a * b) }

fn div (number(a), number(b)) { number(a / b) }
fn opDiv (number(a), number(b)) { number(a / b) }

fn cond(test, pro, con, e) {
fn opCond(test, pro, con, e) {
switch (eval(test, e)) {
(number(0)) { eval(con, e) } // 0 is false
(number(_)) { eval(pro, e) }
Expand Down
11 changes: 11 additions & 0 deletions fn/testLetRec.fn
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
let
fn bad() {
let
d1 = [1];
d2 = d1 @@ [2];
d3 = d2 @@ [3];
in
d3
}
in
bad
126 changes: 126 additions & 0 deletions prototyping/ANF.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
#! /usr/bin/env python3

import functools

def flip(func):
@functools.wraps(func)
def newfunc(x, y):
return func(y, x)
return newfunc

def foldr(func, acc, xs):
return functools.reduce(flip(func), reversed(xs), acc)

class Base:
counter = 0
def genSym(self):
Base.counter += 1
return "t$" + str(Base.counter)

def normalize_term(self):
return self.normalize(lambda x: x)

def normalize_name(self, k):
return self.normalize(lambda n: n.normalize_helper(n, k))

def normalize_helper(self, n, k):
t = self.genSym()
return Let(t, n, k(t))


class Lambda(Base):
def __init__(self, args, body):
self.args = args
self.body = body

def normalize(self, k):
return k(Lambda(self.args, self.body.normalize_term()))

def __str__(self):
return "(lambda (" + " ".join([str(x) for x in self.args]) + ") " + str(self.body) + ")"


class Let(Base):
def __init__(self, var, val, body):
self.var = var
self.val = val
self.body = body

def normalize(self, k):
return self.val.normalize(lambda n1: Let(self.var, n1, self.body.normalize(k)))

def __str__(self):
return "(let (" + str(self.var) + " " + str(self.val) + ") " + str(self.body) + ")"

class If(Base):
def __init__(self, test, consequent, alternative):
self.test = test
self.consequent = consequent
self.alternative = alternative

def normalize(self, k):
return self.test.normalize_name(lambda t: k(If(t, self.consequent.normalize_term(), self.alternative.normalize_term())))

def __str__(self):
return "(if " + str(self.test) + " " + str(self.consequent) + " " + str(self.alternative) + ")"


class Apply(Base):
def __init__(self, fun, *args):
self.fun = fun
self.args = Args.build([x for x in args])

def normalize(self, k):
return self.fun.normalize_name(lambda t: self.args.normalize_name(lambda t2: k(Apply(t, t2))))

def __str__(self):
return "(" + str(self.fun) + " " + str(self.args) + ")"


class Null(Base):
def normalize_name(self, k):
return k(self)

def __str__(self):
return ""

class Args(Base):
def __init__(self, val, rest):
self.val = val
self.rest = rest

def normalize_name(self, k):
return self.val.normalize_name(lambda t: self.rest.normalize_name(lambda t2: k(Args(t, t2))))

def __str__(self):
return str(self.val) + " " + str(self.rest)

@classmethod
def build(cls, args):
return foldr(lambda val, acc: cls(val, acc), Null(), args)

class Value(Base):
def __init__(self, val):
self.val = val

def normalize(self, k):
return k(self)

def normalize_helper(self, n, k):
return k(self)

def __str__(self):
return str(self.val)


def test(testexpr):
print(str(testexpr))
result = testexpr.normalize_term()
print(str(result))
print()

test(Apply(Value("a"), Apply(Value("b"), Value("c"))))

test(Lambda([Value("x"), Value("y")], Apply(Value("+"), Value("x"), Apply(Value("-"), Value("y")))))

test(Lambda([Value("x"), Value("y")], If(Apply(Value("+"), Value("x"), Apply(Value("-"), Value("y"))), Value("x"), Value("y"))))
Loading

0 comments on commit cc0e9e2

Please sign in to comment.