Skip to content

Commit

Permalink
cond replaces if in pattern matching
Browse files Browse the repository at this point in the history
  • Loading branch information
billhails committed Dec 2, 2023
1 parent ca2ace5 commit 7b4172c
Show file tree
Hide file tree
Showing 15 changed files with 406 additions and 104 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ PROFILING=-pg
OPTIMIZING=-O2
DEBUGGING=-g

CC=cc -Werror $(DEBUGGING)
# CC=cc -Werror $(OPTIMIZING)
# CC=cc -Werror $(DEBUGGING)
CC=cc -Werror $(OPTIMIZING)
# CC=cc -Werror $(PROFILING)

CFILES=$(wildcard src/*.c)
Expand Down
45 changes: 45 additions & 0 deletions docs/MATCH.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,48 @@ the faster it will be.
We might think about implementing arrays in the target language, initially for this single purpose,
to avoid having to use cons for everything. the array would have a length field and an array of
values.

# Cond

`match` is long-since implemented as described above, and is as fast as expected. However, as
explained, it cannot apply to non-exhaustive pattern matches like the Fibonacci function, and
in those situations we fall back to a chain of if-then-else's as discussed.

Situations like the Fibonacci function are very common (and `fib(40)` is a benchmark :grin:)
so it would be better to have a special construct to deal with those, which I'll call `cond` because
it has a lot of similarities with the lisp `cond` of old.

There's no surface grammar for it, it's generated by the TPMC and looks as follows:

```
(cond <value>
(<index> <expression>)
(<index> <expression>)
...
(default <expression>))
```

When converted to ANF, the `<value>` must be an Aexp, `<index>` values must be literal integers,
and `<expression>` values are Exps.

Conversion to bytecode is then fairly trivial:

```
| aexp | COND | num-cases | <index1> | addr(exp1) | ... | <indexn> | addr(expn) | ..expdef.. | ..exp1.. | ... | ..expn.. |
```

`expdef` is the default expression, it requires no index or address.

The virtual machine, on seeing a `COND` instruction, pops the result of the Aexp off the stack and
reads the following number of cases.
It then walks the tuples of `| <indexi> | <addri> |` looking for a matching index. If it finds one it
jumps to the addr, otherwise it falls through naturally into the default. Obviously each expr must end
with a JMP instruction to go to the end of the entire COND construct.

This should be significantly faster than a nest of ifs, which has to read and perform separate
operations for each condition.

## Update

`cond` is now implemented and the `fib(40)` benchmark has gone from around 40s
to 28s, 1.4 times faster.
23 changes: 22 additions & 1 deletion src/analysis.c
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ static void analizeAexpUnaryApp(AexpUnaryApp *x, CTEnv *env);
static void analizeAexpList(AexpList *x, CTEnv *env);
static void analizeCexpApply(CexpApply *x, CTEnv *env);
static void analizeCexpIf(CexpIf *x, CTEnv *env);
static void analizeCexpCond(CexpCond *x, CTEnv *env);
static void analizeCexpCondCases(CexpCondCases *x, CTEnv *env);
static void analizeCexpLetRec(CexpLetRec *x, CTEnv *env);
static void analizeCexpAmb(CexpAmb *x, CTEnv *env);
static void analizeCexpCut(CexpCut *x, CTEnv *env);
Expand Down Expand Up @@ -129,6 +131,23 @@ static void analizeCexpIf(CexpIf *x, CTEnv *env) {
analizeExp(x->alternative, env);
}

/*
 * Analyse a cond expression: the tested value (x->condition) is analysed
 * first, followed by the list of cases (x->cases), both against the same env.
 */
static void analizeCexpCond(CexpCond *x, CTEnv *env) {
#ifdef DEBUG_ANALIZE
    // trace the node and environment being analysed
    fprintf(stderr, "analizeCexpCond ");
    printCexpCond(x);
    fprintf(stderr, " ");
    printCTEnv(env);
    fprintf(stderr, "\n");
#endif
    analizeAexp(x->condition, env);
    analizeCexpCondCases(x->cases, env);
}

/*
 * Analyse the body of every case in a cond case list, in list order.
 * A NULL list is accepted and does nothing.
 */
static void analizeCexpCondCases(CexpCondCases *x, CTEnv *env) {
    for (CexpCondCases *current = x; current != NULL; current = current->next) {
#ifdef DEBUG_ANALIZE
        // trace the remaining cases (from this one onwards) and the environment
        fprintf(stderr, "analizeCexpCondCases ");
        printCexpCondCases(current);
        fprintf(stderr, " ");
        printCTEnv(env);
        fprintf(stderr, "\n");
#endif
        analizeExp(current->body, env);
    }
}

static void analizeLetRecLam(Aexp *x, CTEnv *env, int letRecOffset) {
switch (x->type) {
case AEXP_TYPE_LAM:
Expand Down Expand Up @@ -220,7 +239,6 @@ static void analizeAexp(Aexp *x, CTEnv *env) {
case AEXP_TYPE_FALSE:
case AEXP_TYPE_INT:
case AEXP_TYPE_CHAR:
case AEXP_TYPE_DEFAULT:
case AEXP_TYPE_VOID:
break;
case AEXP_TYPE_PRIM:
Expand Down Expand Up @@ -266,6 +284,9 @@ static void analizeCexp(Cexp *x, CTEnv *env) {
case CEXP_TYPE_IF:
analizeCexpIf(x->val.iff, env);
break;
case CEXP_TYPE_COND:
analizeCexpCond(x->val.cond, env);
break;
case CEXP_TYPE_CALLCC:
analizeAexp(x->val.callCC, env);
break;
Expand Down
54 changes: 44 additions & 10 deletions src/anf.c
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ static Exp *normalizeMatch(LamMatch *match, Exp *tail);
static MatchList *normalizeMatchList(LamMatchList *matchList);
static AexpList *convertIntList(LamIntList *list);
static Exp *normalizeCond(LamCond *cond, Exp *tail);
static Exp *replaceCond(Aexp *value, LamCondCases *cases, HashTable *replacements);
static CexpCondCases *normalizeCondCases(LamCondCases *cases);
static CexpLetRec *replaceCexpLetRec(CexpLetRec *cexpLetRec, LamLetRecBindings *lamLetRecBindings);

Exp *anfNormalize(LamExp *lamExp) {
Expand Down Expand Up @@ -152,14 +152,21 @@ static Exp *normalizeCond(LamCond *cond, Exp *tail) {
int save = PROTECT(replacements);
Aexp *value = replaceLamExp(cond->value, replacements);
int save2 = PROTECT(value);
Exp *exp = replaceCond(value, cond->cases, replacements);
CexpCondCases *cases = normalizeCondCases(cond->cases);
PROTECT(cases);
CexpCond *cexpCond = newCexpCond(value, cases);
UNPROTECT(save2);
save2 = PROTECT(cexpCond);
Cexp *cexp = newCexp(CEXP_TYPE_COND, CEXP_VAL_COND(cexpCond));
REPLACE_PROTECT(save2, cexp);
Exp *exp = wrapCexp(cexp);
REPLACE_PROTECT(save2, exp);
exp = wrapTail(exp, tail);
REPLACE_PROTECT(save2, exp);
Exp *res = letBind(exp, replacements);
exp = letBind(exp, replacements);
UNPROTECT(save);
LEAVE(normalizeCond);
return res;
return exp;
}

static Exp *normalizeMatch(LamMatch *match, Exp *tail) {
Expand Down Expand Up @@ -689,6 +696,7 @@ static Aexp *cloneAexp(Aexp *orig) {
return orig;
}

/*
static Exp *replaceCond(Aexp *value, LamCondCases *cases, HashTable *replacements) {
ENTER(replaceCond);
if (cases == NULL) {
Expand Down Expand Up @@ -726,13 +734,40 @@ static Exp *replaceCond(Aexp *value, LamCondCases *cases, HashTable *replacement
LEAVE(replaceCond);
return exp;
}
*/


static Aexp *replaceCondDefault() {
return newAexp(AEXP_TYPE_DEFAULT, AEXP_VAL_DEFAULT());
/*
 * Convert a list of LamCondCases into the corresponding CexpCondCases list.
 *
 * The tail of the list is converted first (recursively), then the current
 * case is prepended to it. Each case's constant is reduced to a plain int:
 * integers are used as-is, characters are cast to int, and the cond default
 * (which must be the last case) is given the value 0. Each case body is
 * normalized with a NULL tail.
 *
 * Returns NULL for an empty list, otherwise the newly built list head.
 * NOTE(review): the returned value is no longer protected once
 * UNPROTECT(save) has run; callers are presumably expected to PROTECT it
 * themselves - confirm against the memory management conventions.
 */
static CexpCondCases *normalizeCondCases(LamCondCases *cases) {
    ENTER(normalizeCondCases);
    if (cases == NULL) {
        LEAVE(normalizeCondCases);
        return NULL;
    }
    CexpCondCases *next = normalizeCondCases(cases->next);
    int save = PROTECT(next);
    int constant = 0;
    switch (cases->constant->type) {
        case LAMEXP_TYPE_INTEGER:
            constant = cases->constant->val.integer;
            break;
        case LAMEXP_TYPE_CHARACTER:
            constant = (int) cases->constant->val.character;
            break;
        case LAMEXP_TYPE_COND_DEFAULT:
            // the default carries no meaningful constant, but it must terminate the list
            if (next != NULL) {
                cant_happen("cond default not last case");
            }
            constant = 0;
            break;
        default:
            // supply the argument that the %d conversion in the message requires
            cant_happen("unexpected type %d for constant in normalizeCondCases", cases->constant->type);
    }
    Exp *body = normalize(cases->body, NULL);
    PROTECT(body);
    CexpCondCases *this = newCexpCondCases(constant, body, next);
    UNPROTECT(save);
    // balance the ENTER above on the normal return path too
    LEAVE(normalizeCondCases);
    return this;
}


static Aexp *replaceLamExp(LamExp *lamExp, HashTable *replacements) {
ENTER(replaceLamExp);
Aexp *res = NULL;
Expand Down Expand Up @@ -774,8 +809,7 @@ static Aexp *replaceLamExp(LamExp *lamExp, HashTable *replacements) {
res = replaceLamCexp(lamExp, replacements);
break;
case LAMEXP_TYPE_COND_DEFAULT:
res = replaceCondDefault();
break;
cant_happen("replaceLamExp encountered cond default");
default:
cant_happen("unrecognised type %d in replaceLamExp", lamExp->type);
}
Expand Down
72 changes: 72 additions & 0 deletions src/bytecode.c
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ static void writeWordAt(int loc, ByteCodeArray *b, int word) {
b->entries[loc + 1] = word & 255;
}

/*
 * Overwrite the four bytes of b->entries starting at loc with the 32-bit
 * value of word, most significant byte first.
 */
static void writeIntAt(int loc, ByteCodeArray *b, int word) {
    for (int i = 0; i < 4; i++) {
        // byte 0 takes bits 24..31, byte 3 takes bits 0..7
        int shift = 8 * (3 - i);
        b->entries[loc + i] = (word >> shift) & 255;
    }
}

static void writeCurrentAddressAt(int patch, ByteCodeArray *b) {
int offset = b->count - patch;
writeWordAt(patch, b, offset);
Expand All @@ -88,6 +95,15 @@ static int reserveWord(ByteCodeArray *b) {
return address;
}

/*
 * Append four zero bytes to b as a placeholder and return the position
 * (b->count before the bytes were added) of the first of them.
 */
static int reserveInt(ByteCodeArray *b) {
    int start = b->count;
    for (int i = 0; i < 4; i++) {
        addByte(b, 0);
    }
    return start;
}

static void writeInt(ByteCodeArray *b, int word) {
if (word > 4294967295) {
cant_happen("maximum int size exceeded");
Expand Down Expand Up @@ -243,6 +259,58 @@ void writeCexpIf(CexpIf *x, ByteCodeArray *b) {
writeCurrentAddressAt(patch2, b);
}

/*
 * Return the number of entries in a cond case list (0 for a NULL list).
 */
static int countCexpCondCases(CexpCondCases *x) {
    int total = 0;
    for (CexpCondCases *node = x; node != NULL; node = node->next) {
        total++;
    }
    return total;
}

/*
 * Emit the bodies of a cond's cases and patch the dispatch table slots
 * that were reserved for them.
 *
 * depth     - position of x in the case list (0 for the first case)
 * values    - positions in b of the reserved int slot for each case's option
 * addresses - positions in b of the reserved word slot for each case's address
 * jumps     - receives the positions of the reserved words following each emitted JMP
 * x         - the current case; NULL ends the recursion
 * b         - the bytecode being written
 *
 * The recursive call comes first, so bodies are emitted in reverse list
 * order: the last case (treated as the default, it has no table entry) is
 * written first and the first case is written last.
 */
void writeCexpCondCases(int depth, int *values, int *addresses, int *jumps, CexpCondCases *x, ByteCodeArray *b) {
    if (x == NULL) {
        return;
    }
    writeCexpCondCases(depth + 1, values, addresses, jumps, x->next, b);
    int isDefault = (x->next == NULL);
    if (!isDefault) {
        // fill in this case's table entry before emitting its body
        writeIntAt(values[depth], b, x->option);
        writeCurrentAddressAt(addresses[depth], b);
    }
    writeExp(x->body, b);
    if (depth > 0) {
        // every body except the one emitted last (depth 0) is followed by a JMP
        // whose target is patched later by the caller
        addByte(b, BYTECODE_JMP);
        jumps[depth - 1] = reserveWord(b);
    }
}

/*
 * Emit bytecode for a cond expression: the condition, the COND opcode, the
 * number of non-default cases, a table of reserved (int, word) slots for
 * each of those cases, then the case bodies, and finally patch the JMP
 * recorded after each body with the address reached once all bodies are written.
 */
void writeCexpCond(CexpCond *x, ByteCodeArray *b) {
    // the final (default) case gets no table entry, so it is not counted
    int numCases = countCexpCondCases(x->cases) - 1;
    if (numCases <= 0) {
        cant_happen("zero cases in writeCexpCond");
    }

    writeAexp(x->condition, b);
    addByte(b, BYTECODE_COND);
    writeWord(b, numCases);

    // positions in b of each reserved index_i slot
    int *values = NEW_ARRAY(int, numCases);
    // positions in b of each reserved addr(exp_i) slot
    int *addresses = NEW_ARRAY(int, numCases);
    // positions in b of the JMP operand emitted after each expression
    int *jumps = NEW_ARRAY(int, numCases);

    for (int slot = 0; slot < numCases; slot++) {
        values[slot] = reserveInt(b);
        addresses[slot] = reserveWord(b);
    }

    writeCexpCondCases(0, values, addresses, jumps, x->cases, b);

    // all bodies are written: point every recorded JMP at the current position
    for (int slot = 0; slot < numCases; slot++) {
        writeCurrentAddressAt(jumps[slot], b);
    }

    FREE_ARRAY(int, values, numCases);
    FREE_ARRAY(int, addresses, numCases);
    FREE_ARRAY(int, jumps, numCases);
}

void writeCexpLetRec(CexpLetRec *x, ByteCodeArray *b) {
writeLetRecBindings(x->bindings, b);
addByte(b, BYTECODE_LETREC);
Expand Down Expand Up @@ -412,6 +480,10 @@ void writeCexp(Cexp *x, ByteCodeArray *b) {
writeCexpIf(x->val.iff, b);
}
break;
case CEXP_TYPE_COND: {
writeCexpCond(x->val.cond, b);
}
break;
case CEXP_TYPE_MATCH: {
writeCexpMatch(x->val.match, b);
}
Expand Down
1 change: 1 addition & 0 deletions src/bytecode.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ typedef enum ByteCodes {
BYTECODE_MATCH,
BYTECODE_APPLY,
BYTECODE_IF,
BYTECODE_COND,
BYTECODE_LETREC,
BYTECODE_AMB,
BYTECODE_CUT,
Expand Down
2 changes: 1 addition & 1 deletion src/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
// #define TEST_STACK
// #define DEBUG_STACK
// #define DEBUG_STEP
#define DEBUG_STRESS_GC
// #define DEBUG_STRESS_GC
// #define DEBUG_LOG_GC
// #define DEBUG_GC
// #define DEBUG_TPMC_MATCH
Expand Down
Loading

0 comments on commit 7b4172c

Please sign in to comment.