Skip to content

Commit

Permalink
Merge pull request #24 from billhails/cond-if
Browse files Browse the repository at this point in the history
Replace if with cond in pattern matching
  • Loading branch information
billhails authored Dec 2, 2023
2 parents 97164fc + 7b4172c commit 1245ce1
Show file tree
Hide file tree
Showing 18 changed files with 464 additions and 159 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ PROFILING=-pg
OPTIMIZING=-O2
DEBUGGING=-g

CC=cc -Werror $(DEBUGGING)
# CC=cc -Werror $(OPTIMIZING)
# CC=cc -Werror $(DEBUGGING)
CC=cc -Werror $(OPTIMIZING)
# CC=cc -Werror $(PROFILING)

CFILES=$(wildcard src/*.c)
Expand Down
45 changes: 45 additions & 0 deletions docs/MATCH.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,48 @@ the faster it will be.
We might think about implementing arrays in the target language, initially for this single purpose,
to avoid having to use cons for everything. the array would have a length field and an array of
values.

# Cond

`match` is long-since implemented as described above, and is as fast as expected, however as
explained it cannot apply to non-exhaustive pattern matches like the Fibonacci function, and
in those situations we fall back to a chain of if-then-else's as discussed.

Situations like the Fibonacci function are very common (and `fib(40)` is a benchmark :grin:)
so it would be better to have a special construct to deal with those, which I'll call `cond` because
it has a lot of similarities with the lisp `cond` of old.

There's no surface grammar for it, it's generated by the TPMC and looks as follows:

```
(cond <value>
(<index> <expression>)
(<index> <expression>)
...
(default <expression>))
```

When converted to ANF, the `<value>` must be an Aexp, `<index>` values must be literal integers,
and `<expression>` values are Exps.

Conversion to bytecode is then fairly trivial:

```
| aexp | COND | num-cases | <index1> | addr(exp1) | ... | <indexn> | addr(expn) | ..expdef.. | ..exp1.. | ... | ..expn.. |
```

`expdef` is the default expression, it requires no index or address.

The virtual machine, on seeing a `COND` instruction, pops the result of the Aexp off the stack and
reads the following number of cases.
it then walks the tuples of `| <indexi> | <addri> |` looking for a matching index. If it finds one it
jumps to the addr, otherwise it runs naturally in to the default. Obviously each expr must end with a JMP
instruction to go to the end of the entire COND construct.

This should be significanly faster than a nest of ifs, which has to read and perform separate
operations for each condition.

## Update

`cond` is now implemented and the `fib(40)` benchmark has gone from around 40s
to 28s, 1.4 times faster.
20 changes: 10 additions & 10 deletions fn/interpreter.fn
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ let

// an interpreter
fn eval {
(addition(l, r), e) { opAdd(eval(l, e), eval(r, e)) }
(subtraction(l, r), e) { opSub(eval(l, e), eval(r, e)) }
(multiplication(l, r), e) { opMul(eval(l, e), eval(r, e)) }
(division(l, r), e) { opDiv(eval(l, e), eval(r, e)) }
(addition(l, r), e) { add(eval(l, e), eval(r, e)) }
(subtraction(l, r), e) { sub(eval(l, e), eval(r, e)) }
(multiplication(l, r), e) { mul(eval(l, e), eval(r, e)) }
(division(l, r), e) { div(eval(l, e), eval(r, e)) }
(i = number(_), e) { i }
(symbol(s), e) { lookup(s, e) }
(conditional(test, pro, con), e) { opCond(test, pro, con, e) }
(conditional(test, pro, con), e) { cond(test, pro, con, e) }
(l = lambda(_, _), e) { closure(l, e) }
(application(function, arg), e) { apply(eval(function, e), eval(arg, e)) }
}
Expand All @@ -35,15 +35,15 @@ let
}

// built-ins
fn opAdd (number(a), number(b)) { number(a + b) }
fn add (number(a), number(b)) { number(a + b) }

fn opSub (number(a), number(b)) { number(a - b) }
fn sub (number(a), number(b)) { number(a - b) }

fn opMul (number(a), number(b)) { number(a * b) }
fn mul (number(a), number(b)) { number(a * b) }

fn opDiv (number(a), number(b)) { number(a / b) }
fn div (number(a), number(b)) { number(a / b) }

fn opCond(test, pro, con, e) {
fn cond(test, pro, con, e) {
switch (eval(test, e)) {
(number(0)) { eval(con, e) } // 0 is false
(number(_)) { eval(pro, e) }
Expand Down
27 changes: 24 additions & 3 deletions src/analysis.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ static void analizeAexpPrimApp(AexpPrimApp *x, CTEnv *env);
static void analizeAexpUnaryApp(AexpUnaryApp *x, CTEnv *env);
static void analizeAexpList(AexpList *x, CTEnv *env);
static void analizeCexpApply(CexpApply *x, CTEnv *env);
static void analizeCexpIf(CexpIf *x, CTEnv *env);
static void analizeCexpCond(CexpCond *x, CTEnv *env);
static void analizeCexpCondCases(CexpCondCases *x, CTEnv *env);
static void analizeCexpLetRec(CexpLetRec *x, CTEnv *env);
static void analizeCexpAmb(CexpAmb *x, CTEnv *env);
static void analizeCexpCut(CexpCut *x, CTEnv *env);
Expand Down Expand Up @@ -120,15 +122,32 @@ static void analizeCexpApply(CexpApply *x, CTEnv *env) {
analizeAexpList(x->args, env);
}

static void analizeCexpCond(CexpCond *x, CTEnv *env) {
static void analizeCexpIf(CexpIf *x, CTEnv *env) {
#ifdef DEBUG_ANALIZE
fprintf(stderr, "analizeCexpCond "); printCexpCond(x); fprintf(stderr, " "); printCTEnv(env); fprintf(stderr, "\n");
fprintf(stderr, "analizeCexpIf "); printCexpIf(x); fprintf(stderr, " "); printCTEnv(env); fprintf(stderr, "\n");
#endif
analizeAexp(x->condition, env);
analizeExp(x->consequent, env);
analizeExp(x->alternative, env);
}

static void analizeCexpCond(CexpCond *x, CTEnv *env) {
#ifdef DEBUG_ANALIZE
fprintf(stderr, "analizeCexpCond "); printCexpCond(x); fprintf(stderr, " "); printCTEnv(env); fprintf(stderr, "\n");
#endif
analizeAexp(x->condition, env);
analizeCexpCondCases(x->cases, env);
}

static void analizeCexpCondCases(CexpCondCases *x, CTEnv *env) {
if (x == NULL) return;
#ifdef DEBUG_ANALIZE
fprintf(stderr, "analizeCexpCondCases "); printCexpCondCases(x); fprintf(stderr, " "); printCTEnv(env); fprintf(stderr, "\n");
#endif
analizeExp(x->body, env);
analizeCexpCondCases(x->next, env);
}

static void analizeLetRecLam(Aexp *x, CTEnv *env, int letRecOffset) {
switch (x->type) {
case AEXP_TYPE_LAM:
Expand Down Expand Up @@ -220,7 +239,6 @@ static void analizeAexp(Aexp *x, CTEnv *env) {
case AEXP_TYPE_FALSE:
case AEXP_TYPE_INT:
case AEXP_TYPE_CHAR:
case AEXP_TYPE_DEFAULT:
case AEXP_TYPE_VOID:
break;
case AEXP_TYPE_PRIM:
Expand Down Expand Up @@ -263,6 +281,9 @@ static void analizeCexp(Cexp *x, CTEnv *env) {
case CEXP_TYPE_APPLY:
analizeCexpApply(x->val.apply, env);
break;
case CEXP_TYPE_IF:
analizeCexpIf(x->val.iff, env);
break;
case CEXP_TYPE_COND:
analizeCexpCond(x->val.cond, env);
break;
Expand Down
64 changes: 49 additions & 15 deletions src/anf.c
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ static Exp *normalizeMatch(LamMatch *match, Exp *tail);
static MatchList *normalizeMatchList(LamMatchList *matchList);
static AexpList *convertIntList(LamIntList *list);
static Exp *normalizeCond(LamCond *cond, Exp *tail);
static Exp *replaceCond(Aexp *value, LamCondCases *cases, HashTable *replacements);
static CexpCondCases *normalizeCondCases(LamCondCases *cases);
static CexpLetRec *replaceCexpLetRec(CexpLetRec *cexpLetRec, LamLetRecBindings *lamLetRecBindings);

Exp *anfNormalize(LamExp *lamExp) {
Expand Down Expand Up @@ -152,14 +152,21 @@ static Exp *normalizeCond(LamCond *cond, Exp *tail) {
int save = PROTECT(replacements);
Aexp *value = replaceLamExp(cond->value, replacements);
int save2 = PROTECT(value);
Exp *exp = replaceCond(value, cond->cases, replacements);
CexpCondCases *cases = normalizeCondCases(cond->cases);
PROTECT(cases);
CexpCond *cexpCond = newCexpCond(value, cases);
UNPROTECT(save2);
save2 = PROTECT(cexpCond);
Cexp *cexp = newCexp(CEXP_TYPE_COND, CEXP_VAL_COND(cexpCond));
REPLACE_PROTECT(save2, cexp);
Exp *exp = wrapCexp(cexp);
REPLACE_PROTECT(save2, exp);
exp = wrapTail(exp, tail);
REPLACE_PROTECT(save2, exp);
Exp *res = letBind(exp, replacements);
exp = letBind(exp, replacements);
UNPROTECT(save);
LEAVE(normalizeCond);
return res;
return exp;
}

static Exp *normalizeMatch(LamMatch *match, Exp *tail) {
Expand Down Expand Up @@ -321,10 +328,10 @@ static Exp *normalizeIff(LamIff *lamIff, Exp *tail) {
PROTECT(consequent);
Exp *alternative = normalize(lamIff->alternative, NULL);
PROTECT(alternative);
CexpCond *cexpCond = newCexpCond(condition, consequent, alternative);
CexpIf *cexpIf = newCexpIf(condition, consequent, alternative);
UNPROTECT(save2);
save2 = PROTECT(cexpCond);
Cexp *cexp = newCexp(CEXP_TYPE_COND, CEXP_VAL_COND(cexpCond));
save2 = PROTECT(cexpIf);
Cexp *cexp = newCexp(CEXP_TYPE_IF, CEXP_VAL_IF(cexpIf));
REPLACE_PROTECT(save2, cexp);
Exp *exp = wrapCexp(cexp);
REPLACE_PROTECT(save2, exp);
Expand Down Expand Up @@ -689,6 +696,7 @@ static Aexp *cloneAexp(Aexp *orig) {
return orig;
}

/*
static Exp *replaceCond(Aexp *value, LamCondCases *cases, HashTable *replacements) {
ENTER(replaceCond);
if (cases == NULL) {
Expand Down Expand Up @@ -717,22 +725,49 @@ static Exp *replaceCond(Aexp *value, LamCondCases *cases, HashTable *replacement
PROTECT(eq);
Aexp *condition = newAexp(AEXP_TYPE_PRIM, AEXP_VAL_PRIM(eq));
PROTECT(condition);
CexpCond *iff = newCexpCond(condition, consequent, alternative);
CexpIf *iff = newCexpIf(condition, consequent, alternative);
PROTECT(iff);
Cexp *cexp = newCexp(CEXP_TYPE_COND, CEXP_VAL_COND(iff));
Cexp *cexp = newCexp(CEXP_TYPE_IF, CEXP_VAL_IF(iff));
PROTECT(cexp);
Exp *exp = wrapCexp(cexp);
UNPROTECT(save);
LEAVE(replaceCond);
return exp;
}
*/


static Aexp *replaceCondDefault() {
return newAexp(AEXP_TYPE_DEFAULT, AEXP_VAL_DEFAULT());
static CexpCondCases *normalizeCondCases(LamCondCases *cases) {
ENTER(normalizeCondCases);
if (cases == NULL) {
LEAVE(normalizeCondCases);
return NULL;
}
CexpCondCases *next = normalizeCondCases(cases->next);
int save = PROTECT(next);
int constant = 0;
switch (cases->constant->type) {
case LAMEXP_TYPE_INTEGER:
constant = cases->constant->val.integer;
break;
case LAMEXP_TYPE_CHARACTER:
constant = (int) cases->constant->val.character;
break;
case LAMEXP_TYPE_COND_DEFAULT:
if (next != NULL) {
cant_happen("cond default not last case");
}
constant = 0;
break;
default:
cant_happen("unexpected type %d for constant in normalizeCondCases");
}
Exp *body = normalize(cases->body, NULL);
PROTECT(body);
CexpCondCases *this = newCexpCondCases(constant, body, next);
UNPROTECT(save);
return this;
}


static Aexp *replaceLamExp(LamExp *lamExp, HashTable *replacements) {
ENTER(replaceLamExp);
Aexp *res = NULL;
Expand Down Expand Up @@ -774,8 +809,7 @@ static Aexp *replaceLamExp(LamExp *lamExp, HashTable *replacements) {
res = replaceLamCexp(lamExp, replacements);
break;
case LAMEXP_TYPE_COND_DEFAULT:
res = replaceCondDefault();
break;
cant_happen("replaceLamExp encountered cond default");
default:
cant_happen("unrecognised type %d in replaceLamExp", lamExp->type);
}
Expand Down
74 changes: 73 additions & 1 deletion src/bytecode.c
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ static void writeWordAt(int loc, ByteCodeArray *b, int word) {
b->entries[loc + 1] = word & 255;
}

static void writeIntAt(int loc, ByteCodeArray *b, int word) {
b->entries[loc + 0] = (word >> 24) & 255;
b->entries[loc + 1] = (word >> 16) & 255;
b->entries[loc + 2] = (word >> 8) & 255;
b->entries[loc + 3] = word & 255;
}

static void writeCurrentAddressAt(int patch, ByteCodeArray *b) {
int offset = b->count - patch;
writeWordAt(patch, b, offset);
Expand All @@ -88,6 +95,15 @@ static int reserveWord(ByteCodeArray *b) {
return address;
}

static int reserveInt(ByteCodeArray *b) {
int address = b->count;
addByte(b, 0);
addByte(b, 0);
addByte(b, 0);
addByte(b, 0);
return address;
}

static void writeInt(ByteCodeArray *b, int word) {
if (word > 4294967295) {
cant_happen("maximum int size exceeded");
Expand Down Expand Up @@ -231,7 +247,7 @@ void writeCexpApply(CexpApply *x, ByteCodeArray *b) {
addByte(b, BYTECODE_APPLY);
}

void writeCexpCond(CexpCond *x, ByteCodeArray *b) {
void writeCexpIf(CexpIf *x, ByteCodeArray *b) {
writeAexp(x->condition, b);
addByte(b, BYTECODE_IF);
int patch = reserveWord(b);
Expand All @@ -243,6 +259,58 @@ void writeCexpCond(CexpCond *x, ByteCodeArray *b) {
writeCurrentAddressAt(patch2, b);
}

static int countCexpCondCases(CexpCondCases *x) {
int val = 0;
while (x != NULL) {
val++;
x = x->next;
}
return val;
}

void writeCexpCondCases(int depth, int *values, int *addresses, int *jumps, CexpCondCases *x, ByteCodeArray *b) {
if (x == NULL) {
return;
}
writeCexpCondCases(depth + 1, values, addresses, jumps, x->next, b);
if (x->next == NULL) { // default
writeExp(x->body, b);
} else {
writeIntAt(values[depth], b, x->option);
writeCurrentAddressAt(addresses[depth], b);
writeExp(x->body, b);
}
if (depth > 0) {
addByte(b, BYTECODE_JMP);
jumps[depth - 1] = reserveWord(b);
}
}

void writeCexpCond(CexpCond *x, ByteCodeArray *b) {
int numCases = countCexpCondCases(x->cases);
numCases--; // don't count the default case
if (numCases <= 0) {
cant_happen("zero cases in writeCexpCond");
}
writeAexp(x->condition, b);
addByte(b, BYTECODE_COND);
writeWord(b, numCases);
int *values = NEW_ARRAY(int, numCases); // address in b for each index_i
int *addresses = NEW_ARRAY(int, numCases); // address in b for each addr(exp_i)
int *jumps = NEW_ARRAY(int, numCases); // address in b for the JMP patch address at the end of each expression
for (int i = 0; i < numCases; i++) {
values[i] = reserveInt(b);
addresses[i] = reserveWord(b);
}
writeCexpCondCases(0, values, addresses, jumps, x->cases, b);
for (int i = 0; i < numCases; i++) {
writeCurrentAddressAt(jumps[i], b);
}
FREE_ARRAY(int, values, numCases);
FREE_ARRAY(int, addresses, numCases);
FREE_ARRAY(int, jumps, numCases);
}

void writeCexpLetRec(CexpLetRec *x, ByteCodeArray *b) {
writeLetRecBindings(x->bindings, b);
addByte(b, BYTECODE_LETREC);
Expand Down Expand Up @@ -408,6 +476,10 @@ void writeCexp(Cexp *x, ByteCodeArray *b) {
writeCexpApply(x->val.apply, b);
}
break;
case CEXP_TYPE_IF: {
writeCexpIf(x->val.iff, b);
}
break;
case CEXP_TYPE_COND: {
writeCexpCond(x->val.cond, b);
}
Expand Down
Loading

0 comments on commit 1245ce1

Please sign in to comment.