diff --git a/docs/pettersson92.pdf b/docs/pettersson92.pdf new file mode 100644 index 0000000..43d59b2 Binary files /dev/null and b/docs/pettersson92.pdf differ diff --git a/fn/curry.fn b/fn/curry.fn index c4af4d7..da3cc02 100644 --- a/fn/curry.fn +++ b/fn/curry.fn @@ -1,5 +1,5 @@ let fn add3(a, b, c) { a + b + c } in - add3(1)(2)(3) + print(add3(1)(2)(3)) diff --git a/fn/liars.fn b/fn/liars.fn index a5b47d5..b1e0b12 100644 --- a/fn/liars.fn +++ b/fn/liars.fn @@ -27,6 +27,41 @@ let } } + fn sortBy(predicate, lst) { + let + fn full_sort { + ([]) { [] } + (first @ rest) { + partition(first, rest, fn (lesser, greater) { + partial_sort(lesser, first @ full_sort(greater)) + }) + } + } + fn partial_sort { + (first @ rest, already_sorted) { + partition(first, rest, fn (lesser, greater) { + partial_sort(lesser, first @ partial_sort(greater, already_sorted)) + }) + } + ([], sorted) { sorted } + } + fn partition(key, lst, kont) { + let fn helper { + ([], lesser, greater) { kont(lesser, greater) } + (first @ rest, lesser, greater) { + if (predicate(key, first) == lt) { + helper(rest, lesser, first @ greater) + } else { + helper(rest, first @ lesser, greater) + } + } + } + in helper(lst, [], []) + } + in + full_sort(lst) + } + fn liars() { let ranks = [1, 2, 3, 4, 5]; @@ -41,7 +76,16 @@ let require((joan == 3) xor (ethel == 5)); require((kitty == 2) xor (mary == 4)); require((mary == 4) xor (betty == 1)); - [betty, ethel, joan, kitty, mary] + sortBy( + fn (#(_, a), #(_, b)) { a <=> b }, + [ + #("Betty", betty), + #("Ethel", ethel), + #("Joan", joan), + #("Kitty", kitty), + #("Mary", mary) + ] + ) } in print(liars()) diff --git a/fn/listutils.fn b/fn/listutils.fn index df59056..dcee891 100644 --- a/fn/listutils.fn +++ b/fn/listutils.fn @@ -146,7 +146,7 @@ let } } - fn sort(lst) { + fn sortBy(predicate, lst) { let fn full_sort { ([]) { [] } @@ -168,7 +168,7 @@ let let fn helper { ([], lesser, greater) { kont(lesser, greater) } (first @ rest, lesser, greater) { - if (key < first) { + if (predicate(key, first) == lt) { helper(rest, lesser, first @ greater) } else { helper(rest, first @ lesser, greater) @@ -181,5 +181,7 @@ let full_sort(lst) } + sort = sortBy(fn (a, b) { a <=> b }); + in print(concat(take(3, ["well", " ", "hi", " ", "there"]))) diff --git a/fn/triple.fn b/fn/triple.fn new file mode 100644 index 0000000..5f72117 --- /dev/null +++ b/fn/triple.fn @@ -0,0 +1,8 @@ +let + typedef T(#a, #b, #c) { triple(#a, #b, #c) } + fn testTriple { + (triple(1, c, d)) { c * d } + (triple(2, c, d)) { c + d } + } +in + true diff --git a/fn/tuple.fn b/fn/tuple.fn new file mode 100644 index 0000000..d3f6502 --- /dev/null +++ b/fn/tuple.fn @@ -0,0 +1,7 @@ +let + fn testTuple { + (#(1, c, d)) { #(c, d, 'h', c * d) } + (#(2, c, d)) { #(c, d, 'h', c + d) } + } +in + print(testTuple(#(1, 2, 3))) diff --git a/fn/tuple2.fn b/fn/tuple2.fn new file mode 100644 index 0000000..fa5edaf --- /dev/null +++ b/fn/tuple2.fn @@ -0,0 +1 @@ +print(#(1, 2, 3)) diff --git a/src/anf_normalize.c b/src/anf_normalize.c index 19abd60..27440ed 100644 --- a/src/anf_normalize.c +++ b/src/anf_normalize.c @@ -87,6 +87,8 @@ static CexpCondCases *normalizeCondCases(LamCondCases *cases); static CexpLetRec *replaceCexpLetRec(CexpLetRec *cexpLetRec, LamLetRecBindings *lamLetRecBindings); static Exp *normalizeConstruct(LamConstruct *construct, Exp *tail); +static Exp *normalizeMakeTuple(LamList *tuple, Exp *tail); +static Exp *normalizeTupleIndex(LamTupleIndex *construct, Exp *tail); static Exp *normalizeDeconstruct(LamDeconstruct *deconstruct, Exp *tail); static Exp *normalizeTag(LamExp *tag, Exp *tail); @@ -132,6 +134,8 @@ static Exp *normalize(LamExp *lamExp, Exp *tail) { return normalizePrint(lamExp->val.print, tail); case LAMEXP_TYPE_LETREC: return normalizeLetRec(lamExp->val.letrec, tail); + case LAMEXP_TYPE_TUPLE_INDEX: + return normalizeTupleIndex(lamExp->val.tuple_index, tail); case LAMEXP_TYPE_DECONSTRUCT: return normalizeDeconstruct(lamExp->val.deconstruct, tail); case LAMEXP_TYPE_CONSTRUCT: @@ -152,10 +156,12 @@ static Exp *normalize(LamExp *lamExp, Exp *tail) { return normalizeBack(tail); case LAMEXP_TYPE_ERROR: return normalizeError(tail); + case LAMEXP_TYPE_MAKE_TUPLE: + return normalizeMakeTuple(lamExp->val.make_tuple, tail); case LAMEXP_TYPE_COND_DEFAULT: cant_happen("normalize encountered cond default"); default: - cant_happen("unrecognized type %d in normalize", lamExp->type); + cant_happen("unrecognized type %s", lamExpTypeName(lamExp->type)); } LEAVE(normalize); } @@ -274,6 +280,27 @@ static Exp *normalizeDeconstruct(LamDeconstruct *deconstruct, Exp *tail) { return res; } +static LamPrimApp *tupleIndexToPrimApp(LamTupleIndex *tupleIndex) { + LamExp *index = + newLamExp(LAMEXP_TYPE_STDINT, LAMEXP_VAL_STDINT(tupleIndex->vec)); + int save = PROTECT(index); + LamPrimApp *res = + newLamPrimApp(LAMPRIMOP_TYPE_VEC, index, tupleIndex->exp); + UNPROTECT(save); + return res; +} + +static Exp *normalizeTupleIndex(LamTupleIndex *index, Exp *tail) { + ENTER(noramaalizeTupleIndex); + LamPrimApp *primApp = tupleIndexToPrimApp(index); + int save = PROTECT(primApp); + Exp *res = normalizePrim(primApp, tail); + UNPROTECT(save); + LEAVE(noramaalizeTupleIndex); + return res; +} + + static Exp *normalizeTag(LamExp *tagged, Exp *tail) { ENTER(noramaalizeTag); LamPrimApp *primApp = tagToPrimApp(tagged); @@ -414,10 +441,7 @@ static Exp *normalizeMakeVec(LamMakeVec *lamMakeVec, Exp *tail) { } static LamMakeVec *constructToMakeVec(LamConstruct *construct) { - int nargs = 0; - for (LamList *args = construct->args; args != NULL; args = args->next) { - nargs++; - } + int nargs = countLamList(construct->args); LamExp *newArg = newLamExp(LAMEXP_TYPE_STDINT, LAMEXP_VAL_STDINT(construct->tag)); int save = PROTECT(newArg); @@ -428,6 +452,12 @@ static LamMakeVec *constructToMakeVec(LamConstruct *construct) { return res; } +static LamMakeVec *tupleToMakeVec(LamList *tuple) { + int nargs = countLamList(tuple); + LamMakeVec *res = newLamMakeVec(nargs, tuple); + return res; +} + static Exp *normalizeConstruct(LamConstruct *construct, Exp *tail) { ENTER(normalizeConstruct); LamMakeVec *makeVec = constructToMakeVec(construct); @@ -438,6 +468,14 @@ static Exp *normalizeConstruct(LamConstruct *construct, Exp *tail) { return res; } +static Exp *normalizeMakeTuple(LamList *tuple, Exp *tail) { + LamMakeVec *makeVec = tupleToMakeVec(tuple); + int save = PROTECT(makeVec); + Exp *res = normalizeMakeVec(makeVec, tail); + UNPROTECT(save); + return res; +} + // sequences are not covered by the algorithm // however the algorithm states that "All non-atomic // (complex) expressions must be let-bound or appear @@ -912,13 +950,13 @@ static Aexp *replaceLamExp(LamExp *lamExp, LamExpTable *replacements) { case LAMEXP_TYPE_AND: case LAMEXP_TYPE_OR: case LAMEXP_TYPE_AMB: + case LAMEXP_TYPE_MAKE_TUPLE: res = replaceLamCexp(lamExp, replacements); break; case LAMEXP_TYPE_COND_DEFAULT: cant_happen("replaceLamExp encountered cond default"); default: - cant_happen("unrecognised type %d in replaceLamExp", - lamExp->type); + cant_happen("unrecognised type %s", lamExpTypeName(lamExp->type)); } LEAVE(replaceLamExp); return res; diff --git a/src/anf_pp.c b/src/anf_pp.c index 17b7b28..1077d79 100644 --- a/src/anf_pp.c +++ b/src/anf_pp.c @@ -45,6 +45,22 @@ void ppAexpVarList(AexpVarList *x) { eprintf(")"); } +static void ppChar(char c) { + switch(c) { + case '\n': + eprintf("\"\\n\""); + break; + case '\t': + eprintf("\"\\t\""); + break; + case '\"': + eprintf("\"\\\"\""); + break; + default: + eprintf("\"%c\"", c); + } +} + void ppAexpVar(HashSymbol *x) { eprintf("%s", x->name); } @@ -99,8 +115,11 @@ void ppAexpPrimApp(AexpPrimApp *x) { case AEXPPRIMOP_TYPE_MOD: eprintf("mod "); break; + case AEXPPRIMOP_TYPE_CMP: + eprintf("cmp "); + break; default: - cant_happen("unrecognized op in ppAexpPrimApp (%d)", x->type); + cant_happen("unrecognized op %s", aexpPrimOpName(x->type)); } ppAexp(x->exp1); if (x->exp2 != NULL) { @@ -228,7 +247,9 @@ void ppCexpIntCondCases(CexpIntCondCases *x) { void ppCexpCharCondCases(CexpCharCondCases *x) { while (x != NULL) { - eprintf("('%c' ", x->option); + eprintf("("); + ppChar(x->option); + eprintf(" "); ppExp(x->body); eprintf(")"); if (x->next) { @@ -357,7 +378,7 @@ void ppAexp(Aexp *x) { eprintf("%d", x->val.littleinteger); break; case AEXP_TYPE_CHARACTER: - eprintf("'%c'", x->val.character); + ppChar(x->val.character); break; case AEXP_TYPE_PRIM: ppAexpPrimApp(x->val.prim); diff --git a/src/ast.yaml b/src/ast.yaml index bd24338..caa07b1 100644 --- a/src/ast.yaml +++ b/src/ast.yaml @@ -148,6 +148,7 @@ unions: unpack: AstUnpack number: BigInt character: char + tuple: AstArgList AstExpression: back: void_ptr @@ -160,6 +161,7 @@ unions: nest: AstNest iff: AstIff print: AstPrint + tuple: AstExpressions arrays: AstCharArray: diff --git a/src/cekf.h b/src/cekf.h index eaa242f..b47167a 100644 --- a/src/cekf.h +++ b/src/cekf.h @@ -107,6 +107,8 @@ void setFrame(Stack *stack, int nargs); void clearFrame(Stack *stack); void copyTosToEnv(Stack *s, Env *e, int n); void copyValues(Value *to, Value *from, int size); +// safe version of copyValues: +void moveValues(Value *to, Value *from, int size); extern Snapshot noSnapshot; @@ -135,5 +137,6 @@ void markEnv(Env *x); void markKont(Kont *x); void markFail(Fail *x); void markVec(Vec *x); +void dumpStack(Stack *stack); #endif diff --git a/src/common.h b/src/common.h index 8af2700..9f0372b 100644 --- a/src/common.h +++ b/src/common.h @@ -25,34 +25,35 @@ # define DEBUG_ANY # ifdef DEBUG_ANY -// #define DEBUG_STACK -// #define DEBUG_STEP +// # define DEBUG_STACK +// # define DEBUG_STEP // if DEBUG_STEP is defined, this sleeps for 1 second between each machine step -// #define DEBUG_SLOW_STEP +// # define DEBUG_SLOW_STEP // define this to cause a GC at every malloc (catches memory leaks early) # define DEBUG_STRESS_GC -// #define DEBUG_LOG_GC -// #define DEBUG_GC +// # define DEBUG_LOG_GC +// # define DEBUG_GC // # define DEBUG_TPMC_MATCH // # define DEBUG_TPMC_TRANSLATE // # define DEBUG_TPMC_LOGIC -// #define DEBUG_ANNOTATE -// #define DEBUG_DESUGARING -// #define DEBUG_HASHTABLE -// #define DEBUG_TIN_SUBSTITUTION -// #define DEBUG_TIN_INSTANTIATION -// #define DEBUG_TIN_UNIFICATION -// #define DEBUG_BYTECODE +// # define DEBUG_TPMC_COMPARE +// # define DEBUG_ANNOTATE +// # define DEBUG_DESUGARING +// # define DEBUG_HASHTABLE +// # define DEBUG_TIN_SUBSTITUTION +// # define DEBUG_TIN_INSTANTIATION +// # define DEBUG_TIN_UNIFICATION +// # define DEBUG_BYTECODE // define this to make fatal errors dump core (if ulimit allows) # define DEBUG_DUMP_CORE // # define DEBUG_TC // # define DEBUG_LAMBDA_CONVERT -// #define DEBUG_LAMBDA_SUBSTITUTE -// #define DEBUG_LEAK -// #define DEBUG_ANF -// #define DEBUG_ALLOC -// #define DEBUG_PRINT_GENERATOR -// #define DEBUG_PRINT_COMPILER +// # define DEBUG_LAMBDA_SUBSTITUTE +// # define DEBUG_LEAK +// # define DEBUG_ANF +// # define DEBUG_ALLOC +// # define DEBUG_PRINT_GENERATOR +// # define DEBUG_PRINT_COMPILER // define this to turn on additional safety checks for things that shouldn't but just possibly might happen # define SAFETY_CHECKS # endif diff --git a/src/debug.c b/src/debug.c index 93fc128..446651c 100644 --- a/src/debug.c +++ b/src/debug.c @@ -25,7 +25,7 @@ #include "debug.h" #include "hash.h" -static void printClo(Clo *x, int depth); +static void printClo(Clo *x, char *type, int depth); static void printElidedEnv(Env *x); static void printEnv(Env *x, int depth); static void printFail(Fail *x, int depth); @@ -66,7 +66,10 @@ void printContainedValue(Value x, int depth) { } break; case VALUE_TYPE_CLO: - printClo(x.val.clo, depth); + printClo(x.val.clo, "C", depth); + break; + case VALUE_TYPE_PCLO: + printClo(x.val.clo, "PC", depth); break; case VALUE_TYPE_CONT: printKont(x.val.k, depth); @@ -76,7 +79,7 @@ void printContainedValue(Value x, int depth) { printVec(x.val.vec); break; default: - cant_happen("unrecognised value type in printContainedValue"); + cant_happen("unrecognised value type %d", x.type); } } @@ -119,8 +122,8 @@ void printValue(Value x, int depth) { eprintf("]"); } -void printElidedClo(Clo *x) { - eprintf("C[%d, %04lx, E[<...>], ", x->nvar, x->c); +void printElidedClo(Clo *x, char *type) { + eprintf("%s[%d, %04lx, E[<...>], ", type, x->nvar, x->c); eprintf("]"); } @@ -173,7 +176,10 @@ void printElidedValue(Value x) { printVec(x.val.vec); break; case VALUE_TYPE_CLO: - printElidedClo(x.val.clo); + printElidedClo(x.val.clo, "C"); + break; + case VALUE_TYPE_PCLO: + printElidedClo(x.val.clo, "PC"); break; case VALUE_TYPE_CONT: printElidedKont(x.val.k); @@ -184,9 +190,9 @@ void printElidedValue(Value x) { eprintf("]"); } -static void printClo(Clo *x, int depth) { +static void printClo(Clo *x, char *type, int depth) { printPad(depth); - eprintf("C[%d, %04lx, ", x->nvar, x->c); + eprintf("%s[%d, %04lx, ", type, x->nvar, x->c); printElidedEnv(x->rho); eprintf("]"); } diff --git a/src/lambda.yaml b/src/lambda.yaml index c2f8ac6..7a9b7f6 100644 --- a/src/lambda.yaml +++ b/src/lambda.yaml @@ -71,6 +71,11 @@ structs: vec: int # offset of the field in the vector exp: LamExp # expression being deconstructed + LamTupleIndex: + vec: int # offset of the field in the untagged tuple vector + size: int # size of the tuple + exp: LamExp # expression being accessed + LamMakeVec: nargs: int args: LamList @@ -228,6 +233,9 @@ unions: makeVec: LamMakeVec construct: LamConstruct deconstruct: LamDeconstruct + tuple_index: LamTupleIndex + tuple: LamList + make_tuple: LamList tag: LamExp constant: LamConstant apply: LamApply diff --git a/src/lambda_conversion.c b/src/lambda_conversion.c index 374cc2e..d2d8c98 100644 --- a/src/lambda_conversion.c +++ b/src/lambda_conversion.c @@ -32,6 +32,8 @@ #include "ast_debug.h" #include "print_generator.h" +char *lambda_conversion_function = NULL; // set by --lambda-conversion flag + static LamLetRecBindings *convertFuncDefs(AstDefinitions *definitions, LamContext *env); static LamList *convertExpressions(AstExpressions *expressions, @@ -141,6 +143,14 @@ static LamExp *lamConvertPrint(AstPrint *print, LamContext *context) { return result; } +static LamExp *lamConvertTuple(AstExpressions *tuple, LamContext *env) { + LamList *expressions = convertExpressions(tuple, env); + int save = PROTECT(expressions); + LamExp *res = newLamExp(LAMEXP_TYPE_MAKE_TUPLE, LAMEXP_VAL_MAKE_TUPLE(expressions)); + UNPROTECT(save); + return res; +} + static LamLetRecBindings *convertFuncDefs(AstDefinitions *definitions, LamContext *env) { ENTER(convertFuncDefs); @@ -410,6 +420,10 @@ static LamLetRecBindings *prependDefine(AstDefine * define, LamContext * env, if (doMermaid) tpmc_mermaid_flag = 1; LamExp *exp = convertExpression(define->expression, env); + if (lambda_conversion_function != NULL && strcmp(lambda_conversion_function, define->symbol->name) == 0) { + ppLamExp(exp); + eprintf("\n"); + } if (doMermaid) tpmc_mermaid_flag = 0; int save = PROTECT(exp); @@ -703,10 +717,13 @@ static LamExp *convertExpression(AstExpression *expression, LamContext *env) { case AST_EXPRESSION_TYPE_PRINT: result = lamConvertPrint(expression->val.print, env); break; + case AST_EXPRESSION_TYPE_TUPLE: + result = lamConvertTuple(expression->val.tuple, env); + break; default: cant_happen - ("unrecognised expression type %d in convertExpression", - expression->type); + ("unrecognised expression type %s", + astExpressionTypeName(expression->type)); } LEAVE(convertExpression); return result; diff --git a/src/lambda_conversion.h b/src/lambda_conversion.h index 2eb2a29..f76cf5b 100644 --- a/src/lambda_conversion.h +++ b/src/lambda_conversion.h @@ -22,4 +22,7 @@ # include "lambda.h" LamExp *lamConvertNest(AstNest *nest, LamContext *env); + +extern char *lambda_conversion_function; + #endif diff --git a/src/lambda_pp.c b/src/lambda_pp.c index 44ae11f..b50329e 100644 --- a/src/lambda_pp.c +++ b/src/lambda_pp.c @@ -187,8 +187,14 @@ void ppLamExp(LamExp *exp) { case LAMEXP_TYPE_COND_DEFAULT: eprintf("default"); break; + case LAMEXP_TYPE_TUPLE_INDEX: + ppLamTupleIndex(exp->val.tuple_index); + break; + case LAMEXP_TYPE_MAKE_TUPLE: + ppLamMakeTuple(exp->val.make_tuple); + break; default: - cant_happen("unrecognized type %d in ppLamExp", exp->type); + cant_happen("unrecognized type %s", lamExpTypeName(exp->type)); } } @@ -312,6 +318,12 @@ static void _ppLamList(LamList *list) { _ppLamList(list->next); } +void ppLamMakeTuple(LamList *args) { + eprintf("(make-tuple"); + _ppLamList(args); + eprintf(")"); +} + void ppLamSequence(LamSequence *sequence) { eprintf("(begin "); _ppLamSequence(sequence); @@ -498,6 +510,16 @@ void ppLamMatch(LamMatch *match) { eprintf(")"); } +void ppLamTupleIndex(LamTupleIndex *index) { + if (index == NULL) { + eprintf(""); + return; + } + eprintf("(index %d ", index->vec); + ppLamExp(index->exp); + eprintf(")"); +} + static void _ppLamLetRecBindings(LamLetRecBindings *bindings) { if (bindings == NULL) return; diff --git a/src/lambda_pp.h b/src/lambda_pp.h index cfb9b72..9a5a8be 100644 --- a/src/lambda_pp.h +++ b/src/lambda_pp.h @@ -46,7 +46,9 @@ void ppLamConstant(LamConstant *constant); void ppLamTypeDefs(LamTypeDefs *typeDefs); void ppLamLet(LamLet *let); void ppLamMatch(LamMatch *match); +void ppLamTupleIndex(LamTupleIndex *index); void ppLamLetRecBindings(LamLetRecBindings *bindings); void ppLamIntList(LamIntList *list); +void ppLamMakeTuple(LamList *args); #endif diff --git a/src/lambda_substitution.c b/src/lambda_substitution.c index 798adfb..5463864 100644 --- a/src/lambda_substitution.c +++ b/src/lambda_substitution.c @@ -114,6 +114,12 @@ static LamList *performListSubstitutions(LamList *list, TpmcSubstitutionTable return list; } +static LamTupleIndex *performTupleIndexSubstitutions(LamTupleIndex *tupleIndex, + TpmcSubstitutionTable *substitutions) { + tupleIndex->exp = lamPerformSubstitutions(tupleIndex->exp, substitutions); + return tupleIndex; +} + static LamMakeVec *performMakeVecSubstitutions(LamMakeVec *makeVec, TpmcSubstitutionTable *substitutions) { ENTER(performMakeVecSubstitutions); @@ -413,10 +419,17 @@ LamExp *lamPerformSubstitutions(LamExp *exp, exp->val.amb = performAmbSubstitutions(exp->val.amb, substitutions); break; + case LAMEXP_TYPE_MAKE_TUPLE: + exp->val.make_tuple = + performListSubstitutions(exp->val.make_tuple, substitutions); + break; + case LAMEXP_TYPE_TUPLE_INDEX: + exp->val.tuple_index = + performTupleIndexSubstitutions(exp->val.tuple_index, substitutions); + break; default: cant_happen - ("unrecognized LamExp type (%d) in lamPerformSubstitutions", - exp->type); + ("unrecognized LamExp type %s", lamExpTypeName(exp->type)); } LEAVE(lamPerformSubstitutions); return exp; diff --git a/src/main.c b/src/main.c index c211119..f1a4024 100644 --- a/src/main.c +++ b/src/main.c @@ -46,6 +46,7 @@ int report_flag = 0; static int help_flag = 0; +static int anf_flag = 0; static void processArgs(int argc, char *argv[]) { int c; @@ -54,8 +55,11 @@ static void processArgs(int argc, char *argv[]) { static struct option long_options[] = { { "bigint", no_argument, &bigint_flag, 1 }, { "report", no_argument, &report_flag, 1 }, + { "anf", no_argument, &anf_flag, 1 }, + { "dump-bytecode", no_argument, &dump_bytecode_flag, 1 }, { "help", no_argument, &help_flag, 1 }, - { "tpmc-mermaid", required_argument, 0, 'm' }, + { "tpmc", required_argument, 0, 'm' }, + { "lambda", required_argument, 0, 'l' }, { 0, 0, 0, 0 } }; int option_index = 0; @@ -68,14 +72,22 @@ static void processArgs(int argc, char *argv[]) { if (c == 'm') { tpmc_mermaid_function = optarg; } + + if (c == 'l') { + lambda_conversion_function = optarg; + } } if (help_flag) { printf("%s", "--bigint use arbitrary precision integers\n" "--report report statistics\n" - "--tpmc-mermaid=function produce a mermaid graph of the\n" + "--anf display the generated ANF\n" + "--lambda=function display the intermediate code\n" + " generated for the function\n" + "--tpmc=function produce a mermaid graph of the\n" " function's TPMC state table\n" + "--dump-bytecode dump the generated bytecode\n" "--help this help\n"); exit(0); } @@ -172,6 +184,11 @@ int main(int argc, char *argv[]) { Exp *anfExp = anfNormalize(exp); REPLACE_PROTECT(save, anfExp); + if (anf_flag) { + ppExp(anfExp); + eprintf("\n"); + } + anfExp = desugar(anfExp); REPLACE_PROTECT(save, anfExp); diff --git a/src/parser.y b/src/parser.y index f6fac13..3e78187 100644 --- a/src/parser.y +++ b/src/parser.y @@ -160,7 +160,7 @@ static AstCompositeFunction *makeAstCompositeFunction(AstAltFunction *functions, %type str %type number %type farg -%type fargs +%type fargs arg_tuple %type composite_function functions fun %type defun denv %type definition @@ -168,7 +168,7 @@ static AstCompositeFunction *makeAstCompositeFunction(AstAltFunction *functions, %type env env_expr %type env_type %type expression -%type expressions expression_statements +%type expressions expression_statements tuple %type user_type %type fun_call binop conslist unop switch string %type load @@ -362,6 +362,7 @@ consfargs : farg { $$ = newAstUnpack(consSymbol(), newAstArgList( number : NUMBER { $$ = makeBigInt($1); } ; + farg : symbol { $$ = newAstArg(AST_ARG_TYPE_SYMBOL, AST_ARG_VAL_SYMBOL($1)); } | unpack { $$ = newAstArg(AST_ARG_TYPE_UNPACK, AST_ARG_VAL_UNPACK($1)); } | cons { $$ = newAstArg(AST_ARG_TYPE_UNPACK, AST_ARG_VAL_UNPACK($1)); } @@ -373,8 +374,12 @@ farg : symbol { $$ = newAstArg(AST_ARG_TYPE_SYMBOL, AST_ARG_VAL_SYM | stringarg { $$ = newAstArg(AST_ARG_TYPE_UNPACK, AST_ARG_VAL_UNPACK($1)); } | CHAR { $$ = newAstArg(AST_ARG_TYPE_CHARACTER, AST_ARG_VAL_CHARACTER($1)); } | WILDCARD { $$ = newAstArg(AST_ARG_TYPE_WILDCARD, AST_ARG_VAL_WILDCARD()); } + | arg_tuple { $$ = newAstArg(AST_ARG_TYPE_TUPLE, AST_ARG_VAL_TUPLE($1)); } ; +arg_tuple: '#' '(' fargs ')' { $$ = $3; } + ; + cons : farg CONS farg { $$ = newAstUnpack(consSymbol(), newAstArgList($1, newAstArgList($3, NULL))); } ; @@ -412,6 +417,7 @@ expression : binop { $$ = newAstExpression(AST_EXPRESSION_TYPE_FU | CHAR { $$ = newAstExpression(AST_EXPRESSION_TYPE_CHARACTER, AST_EXPRESSION_VAL_CHARACTER($1)); } | nest { $$ = newAstExpression(AST_EXPRESSION_TYPE_NEST, AST_EXPRESSION_VAL_NEST($1)); } | print { $$ = newAstExpression(AST_EXPRESSION_TYPE_PRINT, AST_EXPRESSION_VAL_PRINT($1)); } + | tuple { $$ = newAstExpression(AST_EXPRESSION_TYPE_TUPLE, AST_EXPRESSION_VAL_TUPLE($1)); } | '(' expression ')' { $$ = $2; } ; @@ -425,6 +431,9 @@ unop : '-' expression %prec NEG { $$ = unOpToFunCall(negSymbol(), $2); } fun_call : expression '(' expressions ')' { $$ = newAstFunCall($1, $3); } ; +tuple : '#' '(' expressions ')' { $$ = $3; } + ; + binop : expression THEN expression { $$ = binOpToFunCall(thenSymbol(), $1, $3); } | expression AND expression { $$ = binOpToFunCall(andSymbol(), $1, $3); } | expression OR expression { $$ = binOpToFunCall(orSymbol(), $1, $3); } diff --git a/src/preamble.c b/src/preamble.c index d00c42a..ab16e3d 100644 --- a/src/preamble.c +++ b/src/preamble.c @@ -20,32 +20,124 @@ // pre-defined, for example the infix `@` is mapped to `cons`, `@@` to `append`, // prefix `<` to `car` etc. // `puts` is required for the print system, and `cmp` for the `<=>` operator. + +// *INDENT-OFF* const char *preamble = - "let" " typedef cmp { lt | eq | gt }" + "let" + " typedef cmp { lt | eq | gt }" " typedef bool { false | true }" - " typedef list(#t) { nil | cons(#t, list(#t)) }" " fn append {" - " ([], b) { b }" " (h @ t, b) { h @ append(t, b) }" " }" - " fn car {" " (h @ _) { h }" " }" " fn cdr {" - " (_ @ t) { t }" " }" " fn puts(s) {" " let" - " fn helper {" " ([]) { true }" - " (h @ t) {" " putc(h);" - " helper(t)" " }" " }" - " in" " helper(s);" " s" " }" - " fn print_list(helper, l) {" " let" " fn h1 {" - " ([]) { true }" " (h @ t) {" - " helper(h);" " h2(t)" - " }" " }" " fn h2 {" - " ([]) { true }" " (h @ t) {" - " puts(\", \");" " helper(h);" - " h2(t)" " }" " }" - " in" " puts(\"[\");" " h1(l);" - " puts(\"]\");" " l" " }" " fn print_fn(f) {" - " puts(\"\");" " f" " }" - " fn print_int(n) {" " putn(n);" " n" " }" - " fn print_char(c) {" " putc('\\'');" " putc(c);" - " putc('\\'');" " c" " }" " fn print_(v) {" - " putv(v);" " v" " }" " fn print_string(s) {" - " putc('\"');" " puts(s);" " putc('\"');" " s" - " }" "in {"; + " typedef list(#t) { nil | cons(#t, list(#t)) }" + " fn append {" + " ([], b) { b }" + " (h @ t, b) { h @ append(t, b) }" + " }" + " fn car {" + " (h @ _) { h }" + " }" + " fn cdr {" + " (_ @ t) { t }" + " }" + " fn puts(s) {" + " let" + " fn helper {" + " ([]) { true }" + " (h @ t) {" + " putc(h);" + " helper(t)" + " }" + " }" + " in" + " helper(s);" + " s" + " }" + " fn print_list(helper, l) {" + " let" + " fn h1 {" + " ([]) { true }" + " (h @ t) {" + " helper(h);" + " h2(t)" + " }" + " }" + " fn h2 {" + " ([]) { true }" + " (h @ t) {" + " puts(\", \");" + " helper(h);" + " h2(t)" + " }" + " }" + " in" + " puts(\"[\");" + " h1(l);" + " puts(\"]\");" + " l" + " }" + " fn print_fn(f) {" + " puts(\"\");" + " f" + " }" + " fn print_int(n) {" + " putn(n);" + " n" + " }" + " fn print_char(c) {" + " putc('\\'');" + " putc(c);" + " putc('\\'');" + " c" + " }" + " fn print_(v) {" + " putv(v);" + " v" + " }" + " fn print_string(s) {" + " putc('\"');" + " puts(s);" + " putc('\"');" + " s" + " }" + " fn print_tuple_0(t) {" + " puts(\"#()\");" + " t" + " }" + " fn print_tuple_1(p1, t=#(a)) {" + " puts(\"#(\");" + " p1(a);" + " puts(\")\");" + " t" + " }" + " fn print_tuple_2(p1, p2, t=#(a, b)) {" + " puts(\"#(\");" + " p1(a);" + " puts(\", \");" + " p2(b);" + " puts(\")\");" + " t" + " }" + " fn print_tuple_3(p1, p2, p3, t=#(a, b, c)) {" + " puts(\"#(\");" + " p1(a);" + " puts(\", \");" + " p2(b);" + " puts(\", \");" + " p3(c);" + " puts(\")\");" + " t" + " }" + " fn print_tuple_4(p1, p2, p3, p4, t=#(a, b, c, d)) {" + " puts(\"#(\");" + " p1(a);" + " puts(\", \");" + " p2(b);" + " puts(\", \");" + " p3(c);" + " puts(\", \");" + " p4(d);" + " puts(\")\");" + " t" + " }" + "in {"; +// *INDENT-ON* const char *postamble = "}"; diff --git a/src/print_compiler.c b/src/print_compiler.c index 8320e29..8e233f0 100644 --- a/src/print_compiler.c +++ b/src/print_compiler.c @@ -27,6 +27,7 @@ #include "common.h" #include "lambda.h" #include "lambda_helper.h" +#include "lambda_pp.h" #include "symbol.h" #include "symbols.h" @@ -42,6 +43,7 @@ static LamExp *compilePrinterForVar(TcVar *var, TcEnv *env); static LamExp *compilePrinterForInt(); static LamExp *compilePrinterForChar(); static LamExp *compilePrinterForUserType(TcUserType *userType, TcEnv *env); +static LamExp *compilePrinterForTuple(TcTypeArray *tuple, TcEnv *env); static LamExp *compilePrinter(TcType *type, TcEnv *env); LamExp *compilePrinterForType(TcType *type, TcEnv *env) { @@ -110,9 +112,11 @@ static LamExp *compilePrinter(TcType *type, TcEnv *env) { case TCTYPE_TYPE_USERTYPE: res = compilePrinterForUserType(type->val.userType, env); break; + case TCTYPE_TYPE_TUPLE: + res = compilePrinterForTuple(type->val.tuple, env); + break; default: - cant_happen("unrecognised TcType %d in compilePrinter", - type->type); + cant_happen("unrecognised TcType %s", tcTypeTypeName(type->type)); } LEAVE(compilePrinter); return res; @@ -159,6 +163,20 @@ static LamList *compilePrinterForUserTypeArgs(TcUserTypeArgs *args, return res; } +static LamList *compilePrinterForTupleArgs(TcTypeArray *tuple, TcEnv *env) { + LamList *res = NULL; + int save = PROTECT(res); + for (int i = tuple->size; i > 0; i--) { + int index = i - 1; + LamExp *this = compilePrinter(tuple->entries[index], env); + PROTECT(this); + res = newLamList(this, res); + PROTECT(res); + } + UNPROTECT(save); + return res; +} + static LamExp *compilePrinterForString() { HashSymbol *name = newSymbol("print$string"); return newLamExp(LAMEXP_TYPE_VAR, LAMEXP_VAL_VAR(name)); @@ -191,3 +209,29 @@ static LamExp *compilePrinterForUserType(TcUserType *userType, TcEnv *env) { UNPROTECT(save); return res; } + +static LamExp *compilePrinterForTuple(TcTypeArray *tuple, TcEnv *env) { + ENTER(compilePrinterForTuple); + if (tuple->size < 5) { + char buf[16]; + sprintf(buf, "print$tuple$%d", tuple->size); + LamExp *exp = makeSymbolExpr(buf); + if (tuple->size == 0) { + LEAVE(compilePrinterForTuple); + return exp; + } + int save = PROTECT(exp); + LamList *args = compilePrinterForTupleArgs(tuple, env); + PROTECT(args); + LamApply *apply = newLamApply(exp, tuple->size, args); + PROTECT(apply); + LamExp *res = newLamExp(LAMEXP_TYPE_APPLY, LAMEXP_VAL_APPLY(apply)); + UNPROTECT(save); + IFDEBUG(ppLamExp(res)); + LEAVE(compilePrinterForTuple); + return res; + } else { + LEAVE(compilePrinterForTuple); + return makeSymbolExpr("print$"); + } +} diff --git a/src/stack.c b/src/stack.c index 60a42f1..531dc93 100644 --- a/src/stack.c +++ b/src/stack.c @@ -26,8 +26,8 @@ #include #include #include +#include "debug.h" #ifdef DEBUG_STACK -# include "debug.h" # include "debugging_on.h" #else # include "debugging_off.h" @@ -80,6 +80,17 @@ void pushValue(Stack *s, Value v) { s->stack[s->sp++] = v; } +void dumpStack(Stack *s) { + eprintf("STACK DUMP sp = %d, capacity = %d\n", s->sp, s->capacity); + eprintf("=================================\n"); + for (int i = 0; i < s->sp; i++) { + eprintf("[%d] *** ", i); + printValue(s->stack[i], 0); + eprintf("\n"); + } + eprintf("=================================\n"); +} + Value popValue(Stack *s) { DEBUG("popValue()"); if (s->sp == 0) { @@ -138,6 +149,10 @@ void copyValues(Value *to, Value *from, int size) { COPY_ARRAY(Value, to, from, size); } +void moveValues(Value *to, Value *from, int size) { + MOVE_ARRAY(Value, to, from, size); +} + void copyTosToEnv(Stack *s, Env *e, int n) { DEBUG("copyTosToEnv, sp = %d, capacity = %d", s->sp, s->capacity); copyValues(e->values, &(s->stack[s->sp - n]), n); diff --git a/src/step.c b/src/step.c index e1637fe..e8d4f6d 100644 --- a/src/step.c +++ b/src/step.c @@ -31,6 +31,8 @@ #include "step.h" #include "hash.h" +int dump_bytecode_flag = 0; + #ifdef DEBUG_STEP # define DEBUGPRINTF(...) printf(__VA_ARGS__) #else @@ -392,14 +394,14 @@ static Value not(Value a) { static Value vec(Value index, Value vector) { #ifdef SAFETY_CHECKS if (index.type != VALUE_TYPE_STDINT) - cant_happen("invalid index type for vec %d", index.type); + cant_happen("invalid index type for vec %d location %04lx", index.type, state.C); if (vector.type != VALUE_TYPE_VEC) - cant_happen("invalid vector type for vec %d", vector.type); + cant_happen("invalid vector type for vec %d location %04lx", vector.type, state.C); #endif int i = index.val.z; Vec *v = vector.val.vec; if (i < 0 || i >= v->size) - cant_happen("index out of range 0 - %d for vec (%d)", v->size, i); + cant_happen("index out of range 0 - %d for vec (%d), location %04lx", v->size, i, state.C); return v->values[i]; } @@ -440,15 +442,30 @@ static void applyProc(int naargs) { switch (callable.type) { case VALUE_TYPE_PCLO:{ Clo *clo = callable.val.clo; +#ifdef DEBUG_STEP + eprintf("partial closure, count == %d, nvar == %d, naargs == %d\n", clo->rho->count, clo->nvar, naargs); +#endif if (clo->nvar == naargs) { state.C = clo->c; state.E = clo->rho->next; - copyValues(state.S.stack, clo->rho->values, - clo->rho->count); - copyValues(&(state.S.stack[clo->rho->count]), +#ifdef DEBUG_STEP + eprintf("copying %d values from stack[%d] to stack[%d]\n", clo->nvar, state.S.sp - clo->nvar, clo->rho->count); +#endif + moveValues(&(state.S.stack[clo->rho->count]), &(state.S.stack[state.S.sp - clo->nvar]), clo->nvar); +#ifdef DEBUG_STEP + eprintf("copying %d values from closure[0] to stack[0]\n", clo->rho->count); +#endif + copyValues(state.S.stack, clo->rho->values, + clo->rho->count); +#ifdef DEBUG_STEP + eprintf("setting stack sp to %d\n", clo->rho->count + clo->nvar); +#endif state.S.sp = clo->rho->count + clo->nvar; +#ifdef DEBUG_STEP + dumpStack(&state.S); +#endif } else if (naargs == 0) { push(callable); } else if (naargs < clo->nvar) { @@ -526,9 +543,8 @@ void reportSteps(void) { } static void step() { -#ifdef DEBUG_STEP - dumpByteCode(&state.B); -#endif + if (dump_bytecode_flag) + dumpByteCode(&state.B); if (bigint_flag) { add = bigAdd; mul = bigMul; @@ -584,7 +600,11 @@ static void step() { case BYTECODE_LVAR:{ // look up a stack variable and push it int offset = readCurrentByte(); - DEBUGPRINTF("LVAR [%d]\n", offset); + DEBUGPRINTF("LVAR [%d] ", offset); +#ifdef DEBUG_STEP + printValue(peek(offset), 0); + eprintf("\n"); +#endif push(peek(offset)); } break; diff --git a/src/step.h b/src/step.h index 4859063..784edf2 100644 --- a/src/step.h +++ b/src/step.h @@ -27,4 +27,6 @@ void markCEKF(void); void reportSteps(void); +extern int dump_bytecode_flag; + #endif diff --git a/src/tc.yaml b/src/tc.yaml index 2836ab6..5bf3f46 100644 --- a/src/tc.yaml +++ b/src/tc.yaml @@ -61,6 +61,11 @@ hashes: TcTypeTable: entries: TcType +arrays: + TcTypeArray: + dimension: 1 + entries: TcType + unions: TcType: function: TcFunction @@ -71,5 +76,6 @@ unions: character: void_ptr unknown: HashSymbol userType: TcUserType + tuple: TcTypeArray primitives: !include primitives.yaml diff --git a/src/tc_analyze.c b/src/tc_analyze.c index e62cc7e..3d8a8aa 100644 --- a/src/tc_analyze.c +++ b/src/tc_analyze.c @@ -50,6 +50,7 @@ static TcType *makeUnknown(HashSymbol *var); static TcType *makeFreshVar(char *name __attribute__((unused))); static TcType *makeVar(HashSymbol *t); static TcType *makeFn(TcType *arg, TcType *result); +static TcType *makeTuple(int size); static void addBoolBinOpToEnv(TcEnv *env, HashSymbol *symbol); static void addHereToEnv(TcEnv *env); static void addIfToEnv(TcEnv *env); @@ -84,6 +85,8 @@ static TcType *analyzeCond(LamCond *cond, TcEnv *env, TcNg *ng); static TcType *analyzeAnd(LamAnd *and, TcEnv *env, TcNg *ng); static TcType *analyzeOr(LamOr *or, TcEnv *env, TcNg *ng); static TcType *analyzeAmb(LamAmb *amb, TcEnv *env, TcNg *ng); +static TcType *analyzeTupleIndex(LamTupleIndex *index, TcEnv *env, TcNg *ng); +static TcType *analyzeMakeTuple(LamList *tuple, TcEnv *env, TcNg *ng); static TcType *analyzeCharacter(); static TcType *analyzeBack(); static TcType *analyzeError(); @@ -198,10 +201,14 @@ static TcType *analyzeExp(LamExp *exp, TcEnv *env, TcNg *ng) { return prune(analyzeBack()); case LAMEXP_TYPE_ERROR: return prune(analyzeError()); + case LAMEXP_TYPE_TUPLE_INDEX: + return prune(analyzeTupleIndex(exp->val.tuple_index, env, ng)); + case LAMEXP_TYPE_MAKE_TUPLE: + return prune(analyzeMakeTuple(exp->val.make_tuple, env, ng)); case LAMEXP_TYPE_COND_DEFAULT: cant_happen("encountered cond default in analyzeExp"); default: - cant_happen("unrecognized type %d in analyzeExp", exp->type); + cant_happen("unrecognized type %s", lamExpTypeName(exp->type)); } } @@ -494,6 +501,37 @@ static TcType *analyzeDeconstruct(LamDeconstruct *deconstruct, TcEnv *env, return fieldType; } +static TcType *analyzeTupleIndex(LamTupleIndex *index, TcEnv *env, TcNg *ng) { + TcType *tuple = analyzeExp(index->exp, env, ng); + int save = PROTECT(tuple); + TcType *template = makeTuple(index->size); + PROTECT(template); + if (!unify(tuple, template, "tuple index")) { + eprintf("while analyzing tuple "); + ppTcType(tuple); + HashSymbol *name = newSymbol("tuple"); + UNPROTECT(save); + return makeUnknown(name); + } + UNPROTECT(save); + return template->val.tuple->entries[index->vec]; +} + +static TcType *analyzeMakeTuple(LamList *tuple, TcEnv *env, TcNg *ng) { + TcTypeArray *values = newTcTypeArray(); + int save = PROTECT(values); + while (tuple != NULL) { + TcType *part = analyzeExp(tuple->exp, env, ng); + int save2 = PROTECT(part); + pushTcTypeArray(values, part); + UNPROTECT(save2); + tuple = tuple->next; + } + TcType *res = newTcType(TCTYPE_TYPE_TUPLE, TCTYPE_VAL_TUPLE(values)); + UNPROTECT(save); + return res; +} + static TcType *analyzeTag(LamExp *tagged, TcEnv *env, TcNg *ng) { return analyzeExp(tagged, env, ng); } @@ -686,10 +724,12 @@ static TcType *analyzeLetRec(LamLetRec *letRec, TcEnv *env, TcNg *ng) { processLetRecBinding(bindings, env, ng); } // HACK! second pass through fixes up forward references - for (LamLetRecBindings *bindings = letRec->bindings; bindings != NULL; - bindings = bindings->next) { - if (isLambdaBinding(bindings)) { - processLetRecBinding(bindings, env, ng); + if (!hadErrors()) { + for (LamLetRecBindings *bindings = letRec->bindings; bindings != NULL; + bindings = bindings->next) { + if (isLambdaBinding(bindings)) { + processLetRecBinding(bindings, env, ng); + } } } TcType *res = analyzeExp(letRec->body, env, ng); @@ -734,6 +774,20 @@ static TcType *makeTcUserType(LamType *lamType, TcTypeTable *map) { return res; } +static TcType *makeTuple(int size) { + TcTypeArray *array = newTcTypeArray(); + int save = PROTECT(array); + while (size-- > 0) { + TcType *part = makeFreshVar("tuple"); + int save2 = PROTECT(part); + pushTcTypeArray(array, part); + UNPROTECT(save2); + } + TcType *res = newTcType(TCTYPE_TYPE_TUPLE, TCTYPE_VAL_TUPLE(array)); + UNPROTECT(save); + return res; +} + static TcType *makeTypeConstructorArg(LamTypeConstructorType *arg, TcTypeTable *map); @@ -1199,6 +1253,20 @@ static TcType *freshUserType(TcUserType *userType, TcNg *ng, TcTypeTable *map) { return res; } +static TcType *freshTuple(TcTypeArray *tuple, TcNg *ng, TcTypeTable *map) { + TcTypeArray *fresh = newTcTypeArray(); + int save = PROTECT(fresh); + for (int i = 0; i < tuple->size; i ++) { + TcType *part = freshRec(tuple->entries[i], ng, map); + int save2 = PROTECT(part); + pushTcTypeArray(fresh, part); + UNPROTECT(save2); + } + TcType *res = newTcType(TCTYPE_TYPE_TUPLE, TCTYPE_VAL_TUPLE(fresh)); + UNPROTECT(save); + return res; +} + static bool isGeneric(TcType *typeVar, TcNg *ng) { while (ng != NULL) { int i = 0; @@ -1228,9 +1296,10 @@ static TcType *typeGetOrPut(TcTypeTable *map, TcType *typeVar, static TcType *freshRec(TcType *type, TcNg *ng, TcTypeTable *map) { type = prune(type); switch (type->type) { - case TCTYPE_TYPE_FUNCTION: + case TCTYPE_TYPE_FUNCTION: { TcType *res = freshFunction(type->val.function, ng, map); return res; + } case TCTYPE_TYPE_PAIR:{ TcType *res = freshPair(type->val.pair, ng, map); return res; @@ -1253,8 +1322,10 @@ static TcType *freshRec(TcType *type, TcNg *ng, TcTypeTable *map) { TcType *res = freshUserType(type->val.userType, ng, map); return res; } + case TCTYPE_TYPE_TUPLE: + return freshTuple(type->val.tuple, ng, map); default: - cant_happen("unrecognised type %d in freshRec", type->type); + cant_happen("unrecognised type %s", tcTypeTypeName(type->type)); } } @@ -1484,6 +1555,20 @@ static bool unifyPairs(TcPair *a, TcPair *b) { return res; } +static bool unifyTuples(TcTypeArray *a, TcTypeArray *b) { + if (a->size != b->size) { + can_happen("tuple sizes differ: %d vs %d", a->size, b->size); + return false; + } + bool unified = true; + for (int i = 0; i < a->size; i++) { + if (!unify(a->entries[i], b->entries[i], "tuples")) { + unified = false; + } + } + return unified; +} + static bool unifyUserTypes(TcUserType *a, TcUserType *b) { if (a->name != b->name) { can_happen("unification failed[1]"); @@ -1560,8 +1645,10 @@ static bool _unify(TcType *a, TcType *b) { return false; case TCTYPE_TYPE_USERTYPE: return unifyUserTypes(a->val.userType, b->val.userType); + case TCTYPE_TYPE_TUPLE: + return unifyTuples(a->val.tuple, b->val.tuple); default: - cant_happen("unrecognised type %d in unify", a->type); + cant_happen("unrecognised type %s", tcTypeTypeName(a->type)); } } cant_happen("reached end of unify"); @@ -1684,6 +1771,15 @@ static bool occursInUserType(TcType *var, TcUserType *userType) { return false; } +static bool occursInTuple(TcType *var, TcTypeArray *tuple) { + for (int i = 0; i < tuple->size; ++i) { + if (occursInType(var, tuple->entries[i])) { + return true; + } + } + return false; +} + static bool occursIn(TcType *a, TcType *b) { switch (b->type) { case TCTYPE_TYPE_FUNCTION: @@ -1699,7 +1795,9 @@ static bool occursIn(TcType *a, TcType *b) { return false; case TCTYPE_TYPE_USERTYPE: return occursInUserType(a, b->val.userType); + case TCTYPE_TYPE_TUPLE: + return occursInTuple(a, b->val.tuple); default: - cant_happen("unrecognised type %d in occursIn", b->type); + cant_happen("unrecognised type %s", tcTypeTypeName(b->type)); } } diff --git a/src/tpmc.yaml b/src/tpmc.yaml index 107dde7..b684bcb 100644 --- a/src/tpmc.yaml +++ b/src/tpmc.yaml @@ -96,6 +96,7 @@ unions: character: char biginteger: BigInt constructor: TpmcConstructorPattern + tuple: TpmcPatternArray TpmcStateValue: test: TpmcTestState diff --git a/src/tpmc_compare.c b/src/tpmc_compare.c index 03558ee..0566a19 100644 --- a/src/tpmc_compare.c +++ b/src/tpmc_compare.c @@ -21,6 +21,14 @@ // TODO should be able to get rid of this now we auto-generate comparison functions #include "tpmc_compare.h" +#include "tpmc_pp.h" +#include "common.h" + +#ifdef DEBUG_TPMC_COMPARE +# include "debugging_on.h" +#else +# include "debugging_off.h" +#endif #define PREAMBLE() do {\ if (a == b) { \ @@ -79,11 +87,15 @@ bool tpmcArcEq(TpmcArc *a, TpmcArc *b) { } bool tpmcArcInArray(TpmcArc *arc, TpmcArcArray *arcArray) { + DEBUGN("tpmcArcInArray: "); + IFDEBUGN(ppTpmcPattern(arc->test)); for (int i = 0; i < arcArray->size; i++) { if (tpmcArcEq(arcArray->entries[i], arc)) { + DEBUG("tpmcArcInArray returning true"); return true; } } + DEBUG("tpmcArcInArray returning false"); return false; } @@ -115,8 +127,10 @@ bool tpmcPatternValueEq(TpmcPatternValue *a, TpmcPatternValue *b) { case TPMCPATTERNVALUE_TYPE_CONSTRUCTOR: return tpmcConstructorPatternEq(a->val.constructor, b->val.constructor); + case TPMCPATTERNVALUE_TYPE_TUPLE: + return tpmcPatternArrayEq(a->val.tuple, b->val.tuple); default: - cant_happen("unrecognised type %d in tpmcPatternEq", a->type); + cant_happen("unrecognised type %s", tpmcPatternValueTypeName(a->type)); } } diff --git a/src/tpmc_logic.c b/src/tpmc_logic.c index cf87575..a268842 100644 --- a/src/tpmc_logic.c +++ b/src/tpmc_logic.c @@ -141,6 +141,18 @@ static TpmcPattern *makeConstructorPattern(AstUnpack *unpack, LamContext *env) { return pattern; } +static TpmcPattern *makeTuplePattern(AstArgList *args, LamContext *env) { + TpmcPatternArray *tuple = convertArgList(args, env); + int save = PROTECT(tuple); + TpmcPatternValue *val = + newTpmcPatternValue(TPMCPATTERNVALUE_TYPE_TUPLE, + TPMCPATTERNVALUE_VAL_TUPLE(tuple)); + PROTECT(val); + TpmcPattern *pattern = newTpmcPattern(val); + UNPROTECT(save); + return pattern; +} + static TpmcPattern *makeBigIntegerPattern(BigInt *number) { TpmcPatternValue *val = newTpmcPatternValue(TPMCPATTERNVALUE_TYPE_BIGINTEGER, @@ -173,13 +185,15 @@ static TpmcPattern *convertPattern(AstArg *arg, LamContext *env) { cant_happen("env arg type not supported yet in convertPattern"); case AST_ARG_TYPE_UNPACK: return makeConstructorPattern(arg->val.unpack, env); + case AST_ARG_TYPE_TUPLE: + return makeTuplePattern(arg->val.tuple, env); case AST_ARG_TYPE_NUMBER: return makeBigIntegerPattern(arg->val.number); case AST_ARG_TYPE_CHARACTER: return makeCharacterPattern(arg->val.character); default: - cant_happen("unrecognized arg type %d in convertPattern", - arg->type); + cant_happen("unrecognized arg type %s in convertPattern", + astArgTypeName(arg->type)); } } @@ -253,6 +267,19 @@ static void renameConstructorPattern(TpmcConstructorPattern *pattern, } } +static void renameTuplePattern(TpmcPatternArray *components, + HashSymbol *path) { + char buf[512]; + for (int i = 0; i < components->size; i++) { + if (snprintf(buf, 512, "%s$%d", path->name, i) >= 511) { + can_happen("maximum path depth exceeded"); + } + DEBUG("renameTuplePattern: %s", buf); + HashSymbol *newPath = newSymbol(buf); + renamePattern(components->entries[i], newPath); + } +} + static void renamePattern(TpmcPattern *pattern, HashSymbol *variable) { pattern->path = variable; switch (pattern->pattern->type) { @@ -273,8 +300,11 @@ static void renamePattern(TpmcPattern *pattern, HashSymbol *variable) { renameConstructorPattern(pattern->pattern->val.constructor, variable); break; + case TPMCPATTERNVALUE_TYPE_TUPLE: + renameTuplePattern(pattern->pattern->val.tuple, variable); + break; default: - cant_happen("unrecognised pattern type in renamePattern"); + cant_happen("unrecognised pattern type %s", tpmcPatternValueTypeName(pattern->pattern->type)); } } @@ -354,6 +384,16 @@ static TpmcPattern *replaceConstructorPattern(TpmcPattern *pattern, return pattern; } +static TpmcPattern *replaceTuplePattern(TpmcPattern *pattern, + TpmcPatternTable *seen) { + TpmcPatternArray *components = pattern->pattern->val.tuple; + for (int i = 0; i < components->size; ++i) { + components->entries[i] = + replaceComparisonPattern(components->entries[i], seen); + } + return pattern; +} + static TpmcPattern *replaceComparisonPattern(TpmcPattern *pattern, TpmcPatternTable *seen) { switch (pattern->pattern->type) { @@ -365,13 +405,15 @@ static TpmcPattern *replaceComparisonPattern(TpmcPattern *pattern, return replaceVarPattern(pattern, seen); case TPMCPATTERNVALUE_TYPE_ASSIGNMENT: return replaceAssignmentPattern(pattern, seen); + case TPMCPATTERNVALUE_TYPE_TUPLE: + return replaceTuplePattern(pattern, seen); case TPMCPATTERNVALUE_TYPE_CONSTRUCTOR: return replaceConstructorPattern(pattern, seen); case TPMCPATTERNVALUE_TYPE_COMPARISON: cant_happen ("encounterted nested comparison pattern during replaceComparisonPattern"); default: - cant_happen("unrecognised pattern type in renamePattern"); + cant_happen("unrecognised pattern type %s", tpmcPatternValueTypeName(pattern->pattern->type)); } } @@ -428,6 +470,17 @@ static TpmcPattern *collectConstructorSubstitutions(TpmcPattern *pattern, TpmcSu return pattern; } +static TpmcPattern *collectTupleSubstitutions(TpmcPattern *pattern, TpmcSubstitutionTable *substitutions) { + TpmcPatternArray *components = + pattern->pattern->val.tuple; + for (int i = 0; i < components->size; ++i) { + components->entries[i] = + collectPatternSubstitutions(components->entries[i], + substitutions); + } + return pattern; +} + static TpmcPattern *collectComparisonSubstitutions(TpmcPattern *pattern, TpmcSubstitutionTable *substitutions) { TpmcPattern *previous = pattern->pattern->val.comparison->previous; @@ -479,10 +532,12 @@ static TpmcPattern *collectPatternSubstitutions(TpmcPattern *pattern, TpmcSubsti return collectAssignmentSubstitutions(pattern, substitutions); case TPMCPATTERNVALUE_TYPE_CONSTRUCTOR: return collectConstructorSubstitutions(pattern, substitutions); + case TPMCPATTERNVALUE_TYPE_TUPLE: + return collectTupleSubstitutions(pattern, substitutions); case TPMCPATTERNVALUE_TYPE_COMPARISON: return collectComparisonSubstitutions(pattern, substitutions); default: - cant_happen("unrecognised pattern type in renamePattern"); + cant_happen("unrecognised pattern type %s", tpmcPatternValueTypeName(pattern->pattern->type)); } } diff --git a/src/tpmc_match.c b/src/tpmc_match.c index e27f7a9..72c250d 100644 --- a/src/tpmc_match.c +++ b/src/tpmc_match.c @@ -56,6 +56,10 @@ static bool patternIsWildcard(TpmcPattern *pattern) { return pattern->pattern->type == TPMCPATTERNVALUE_TYPE_WILDCARD; } +static bool patternIsComparison(TpmcPattern *pattern) { + return pattern->pattern->type == TPMCPATTERNVALUE_TYPE_COMPARISON; +} + static bool topRowOnlyVariables(TpmcMatrix *matrix) { for (int x = 0; x < matrix->width; x++) { if (!patternIsWildcard(getTpmcMatrixIndex(matrix, x, 0))) { @@ -64,6 +68,14 @@ static bool topRowOnlyVariables(TpmcMatrix *matrix) { } return true; } +static bool columnHasComparisons(int x, TpmcMatrix *matrix) { + for (int y = 0; y < matrix->height; y++) { + if (patternIsComparison(getTpmcMatrixIndex(matrix, x, y))) { + return true; + } + } + return false; +} static int findFirstConstructorColumn(TpmcMatrix *matrix) { for (int x = 0; x < matrix->width; x++) { @@ -126,8 +138,18 @@ static bool patternMatches(TpmcPattern *constructor, TpmcPattern *pattern) { pattern->pattern->val.constructor->tag) || isComparison; return res; } + case TPMCPATTERNVALUE_TYPE_TUPLE:{ + bool res = + (constructor->pattern->type == TPMCPATTERNVALUE_TYPE_TUPLE) || isComparison; + if (countTpmcPatternArray(constructor->pattern->val.tuple) != + countTpmcPatternArray(pattern->pattern->val.tuple)) { + can_happen("tuple arity mismatch"); + return false; + } + return res; + } default: - cant_happen("unrecognized pattern type %s in patternMatches", + cant_happen("unrecognized pattern type %s", tpmcPatternValueTypeName(pattern->pattern->type)); } } @@ -244,12 +266,16 @@ static TpmcStateArray *extractStateArraySubset(TpmcIntArray *indices, TpmcStateA } static int arityOf(TpmcPattern *pattern) { - if (pattern->pattern->type == TPMCPATTERNVALUE_TYPE_CONSTRUCTOR) { - LamTypeConstructorInfo *info = - pattern->pattern->val.constructor->info; - return info->arity; - } else { - return 0; + switch (pattern->pattern->type) { + case TPMCPATTERNVALUE_TYPE_CONSTRUCTOR:{ + LamTypeConstructorInfo *info = + pattern->pattern->val.constructor->info; + return info->arity; + } + case TPMCPATTERNVALUE_TYPE_TUPLE: + return countTpmcPatternArray(pattern->pattern->val.tuple); + default: + return 0; } } @@ -274,13 +300,13 @@ static void populateSubPatternMatrixRowWithWildcards(TpmcMatrix *matrix, } } -static void populateSubPatternMatrixRowWithComponents(TpmcMatrix *matrix, +static void populateSubPatternMatrixRowWithConstructor(TpmcMatrix *matrix, int y, int arity, TpmcPattern *pattern) { if (arity != pattern->pattern->val.constructor->components->size) { ppTpmcPattern(pattern); cant_happen - ("arity %d does not match constructor arity %d in populateSubPatternMatrixRowWithComponents", + ("arity %d does not match constructor arity %d", arity, pattern->pattern->val.constructor->components->size); } for (int i = 0; i < arity; i++) { @@ -290,6 +316,21 @@ static void populateSubPatternMatrixRowWithComponents(TpmcMatrix *matrix, } } +static void populateSubPatternMatrixRowWithTuple(TpmcMatrix *matrix, + int y, int arity, + TpmcPattern *pattern) { + if (arity != countTpmcPatternArray(pattern->pattern->val.tuple)) { + ppTpmcPattern(pattern); + cant_happen + ("arity %d does not match tuple arity %d", + arity, countTpmcPatternArray(pattern->pattern->val.tuple)); + } + for (int i = 0; i < arity; i++) { + TpmcPattern *entry = pattern->pattern->val.tuple->entries[i]; + setTpmcMatrixIndex(matrix, i, y, entry); + } +} + static TpmcMatrix *makeSubPatternMatrix(TpmcPatternArray *patterns, int arity) { TpmcMatrix *matrix = newTpmcMatrix(arity, patterns->size); if (arity == 0) { @@ -308,21 +349,20 @@ static TpmcMatrix *makeSubPatternMatrix(TpmcPatternArray *patterns, int arity) { pattern); break; case TPMCPATTERNVALUE_TYPE_CONSTRUCTOR: - populateSubPatternMatrixRowWithComponents(matrix, i, arity, + populateSubPatternMatrixRowWithConstructor(matrix, i, arity, pattern); break; + case TPMCPATTERNVALUE_TYPE_TUPLE: + populateSubPatternMatrixRowWithTuple(matrix, i, arity, pattern); + break; case TPMCPATTERNVALUE_TYPE_ASSIGNMENT: - cant_happen - ("encountered pattern type assignment during makeSubPatternMatrix"); + cant_happen("encountered pattern type assignment"); case TPMCPATTERNVALUE_TYPE_CHARACTER: - cant_happen - ("encountered pattern type char during makeSubPatternMatrix"); + cant_happen("encountered pattern type char"); case TPMCPATTERNVALUE_TYPE_BIGINTEGER: - cant_happen - ("encountered pattern type int during makeSubPatternMatrix"); + cant_happen("encountered pattern type int"); default: - cant_happen - ("unrecognised pattern type %s during makeSubPatternMatrix", + cant_happen("unrecognised pattern type %s", tpmcPatternValueTypeName(pattern->pattern->type)); } } @@ -330,41 +370,73 @@ static TpmcMatrix *makeSubPatternMatrix(TpmcPatternArray *patterns, int arity) { return matrix; } -static TpmcPattern *replaceComponentsWithWildcards(TpmcPattern *pattern) { - if (pattern->pattern->type == TPMCPATTERNVALUE_TYPE_CONSTRUCTOR) { - TpmcConstructorPattern *constructor = - pattern->pattern->val.constructor; - if (constructor->components->size > 0) { - TpmcPatternArray *components = - newTpmcPatternArray("replaceComponentsWithWildcards"); - int save = PROTECT(components); - for (int i = 0; i < constructor->components->size; i++) { - TpmcPatternValue *wc = - newTpmcPatternValue(TPMCPATTERNVALUE_TYPE_WILDCARD, - TPMCPATTERNVALUE_VAL_WILDCARD()); - int save2 = PROTECT(wc); - TpmcPattern *replacement = newTpmcPattern(wc); - PROTECT(replacement); - replacement->path = constructor->components->entries[i]->path; - pushTpmcPatternArray(components, replacement); - UNPROTECT(save2); +static TpmcPatternArray *replaceComponentsWithWildcards(TpmcPatternArray *components) { + ENTER(replaceComponentsWithWildcards); + TpmcPatternArray *result = + newTpmcPatternArray("replaceComponentsWithWildcards"); + int save = PROTECT(result); + for (int i = 0; i < components->size; i++) { + DEBUG("i = %d, size = %d", i, components->size); + TpmcPatternValue *wc = + newTpmcPatternValue(TPMCPATTERNVALUE_TYPE_WILDCARD, + TPMCPATTERNVALUE_VAL_WILDCARD()); + int save2 = PROTECT(wc); + TpmcPattern *replacement = newTpmcPattern(wc); + PROTECT(replacement); + replacement->path = components->entries[i]->path; + pushTpmcPatternArray(result, replacement); + UNPROTECT(save2); + } + UNPROTECT(save); + LEAVE(replaceComponentsWithWildcards); + return result; +} + +static TpmcPattern *replacePatternComponentsWithWildcards(TpmcPattern *pattern) { + switch (pattern->pattern->type) { + case TPMCPATTERNVALUE_TYPE_CONSTRUCTOR: { + TpmcConstructorPattern *constructor = + pattern->pattern->val.constructor; + if (constructor->components->size > 0) { + TpmcPatternArray *components = replaceComponentsWithWildcards(constructor->components); + int save = PROTECT(components); + TpmcConstructorPattern *newCons = + newTpmcConstructorPattern(constructor->tag, constructor->info, + components); + PROTECT(newCons); + TpmcPatternValue *patternValue = + newTpmcPatternValue(TPMCPATTERNVALUE_TYPE_CONSTRUCTOR, + TPMCPATTERNVALUE_VAL_CONSTRUCTOR + (newCons)); + PROTECT(patternValue); + TpmcPattern *replacement = newTpmcPattern(patternValue); + replacement->path = pattern->path; + UNPROTECT(save); + return replacement; + } else { + return pattern; + } + } + case TPMCPATTERNVALUE_TYPE_TUPLE: { + TpmcPatternArray *tuple = pattern->pattern->val.tuple; + if (tuple->size > 0) { + TpmcPatternArray *components = replaceComponentsWithWildcards(tuple); + int save = PROTECT(components); + TpmcPatternValue *patternValue = + newTpmcPatternValue(TPMCPATTERNVALUE_TYPE_TUPLE, + TPMCPATTERNVALUE_VAL_TUPLE(components)); + PROTECT(patternValue); + TpmcPattern *replacement = newTpmcPattern(patternValue); + replacement->path = pattern->path; + UNPROTECT(save); + return replacement; + } else { + return pattern; } - TpmcConstructorPattern *newCons = - newTpmcConstructorPattern(constructor->tag, constructor->info, - components); - PROTECT(newCons); - TpmcPatternValue *patternValue = - newTpmcPatternValue(TPMCPATTERNVALUE_TYPE_CONSTRUCTOR, - TPMCPATTERNVALUE_VAL_CONSTRUCTOR - (newCons)); - PROTECT(patternValue); - TpmcPattern *replacement = newTpmcPattern(patternValue); - replacement->path = pattern->path; - UNPROTECT(save); - return replacement; } + default: + return pattern; } - return pattern; } static TpmcIntArray *makeTpmcIntArray(int size, int initialValue) { @@ -476,8 +548,13 @@ static void collectPathsBoundByPattern(TpmcPattern *pattern, collectPathsBoundByConstructor(components, boundVariables); } break; + case TPMCPATTERNVALUE_TYPE_TUPLE:{ + TpmcPatternArray *components = pattern->pattern->val.tuple; + collectPathsBoundByConstructor(components, boundVariables); + } + break; default: - cant_happen("unrecognised type %s in collectPathsBoundByPattern", + cant_happen("unrecognised type %s", tpmcPatternValueTypeName(pattern->pattern->type)); } } @@ -584,7 +661,10 @@ static TpmcState *mixture(TpmcMatrix *M, TpmcStateArray *finalStates, ENTER(mixture); // there is some column N whose topmost pattern is a constructor int firstConstructorColumn = findFirstConstructorColumn(M); - firstConstructorColumn = 0; + // this heuristic allows for comparisons to work: + if (firstConstructorColumn > 0 && columnHasComparisons(firstConstructorColumn, M)) { + firstConstructorColumn = 0; + } TpmcPatternArray *N = extractMatrixColumn(firstConstructorColumn, M); int save = PROTECT(N); // let M-N be all the columns in M except N @@ -599,29 +679,29 @@ static TpmcState *mixture(TpmcMatrix *M, TpmcStateArray *finalStates, // For each constructor c in the selected column, its arc is defined as follows: if (!patternIsWildcard(c)) { // Let {i1 , ... , ij} be the row-indices of the patterns in N that match c. - TpmcIntArray *matchingIndices = findPatternsMatching(c, N); - int save2 = PROTECT(matchingIndices); + TpmcIntArray *indicesMatchingC = findPatternsMatching(c, N); + int save2 = PROTECT(indicesMatchingC); // Let {pat1 , ... , patj} be the patterns in the column corresponding to the indices computed above, - TpmcPatternArray *matchingPatterns = extractColumnSubset(matchingIndices, N); - PROTECT(matchingPatterns); + TpmcPatternArray *patternsMatchingC = extractColumnSubset(indicesMatchingC, N); + PROTECT(patternsMatchingC); // let n be the arity of the constructor c int n = arityOf(c); // For each pati, its n sub-patterns are extracted; // if pati is a wildcard, n wildcards are produced instead, each tagged with the right path variable. - TpmcMatrix *subPatternMatrix = makeSubPatternMatrix(matchingPatterns, n); - PROTECT(subPatternMatrix); + TpmcMatrix *subPatternsMatchingC = makeSubPatternMatrix(patternsMatchingC, n); + PROTECT(subPatternsMatchingC); // This matrix is then appended to the result of selecting, from each column in MN, // those rows whose indices are in {i1 , ... , ij}. - TpmcMatrix *prefixMatrix = extractMatrixRows(matchingIndices, MN); + TpmcMatrix *prefixMatrix = extractMatrixRows(indicesMatchingC, MN); PROTECT(prefixMatrix); - TpmcMatrix *newMatrix = appendMatrices(prefixMatrix, subPatternMatrix); + TpmcMatrix *newMatrix = appendMatrices(prefixMatrix, subPatternsMatchingC); PROTECT(newMatrix); // Finally the indices are used to select the corresponding final states that go with these rows. - TpmcStateArray *newFinalStates = extractStateArraySubset(matchingIndices, finalStates); + TpmcStateArray *newFinalStates = extractStateArraySubset(indicesMatchingC, finalStates); PROTECT(newFinalStates); // The arc for the constructor c is now defined as (c’,state), where c’ is c with any immediate // sub-patterns replaced by their path variables (thus c’ is a simple pattern) - TpmcPattern *cPrime = replaceComponentsWithWildcards(c); + TpmcPattern *cPrime = replacePatternComponentsWithWildcards(c); PROTECT(cPrime); // and state is the result of recursively applying match to the new matrix and the new sequence of final states TpmcState *newState = match(newMatrix, newFinalStates, errorState, knownStates); diff --git a/src/tpmc_mermaid.c b/src/tpmc_mermaid.c index a0a3337..2c642ea 100644 --- a/src/tpmc_mermaid.c +++ b/src/tpmc_mermaid.c @@ -73,7 +73,7 @@ static char *mermaidStateName(TpmcState *state) { printf("%s(\"%s\\n", buf, state->state->val.test->path->name); mermaidFreeVariables(state->freeVariables); - printf("\")\n"); + printf("\\n(arcs %d)\")\n", countTpmcArcArray(state->state->val.test->arcs)); } break; case TPMCSTATEVALUE_TYPE_FINAL: @@ -135,9 +135,13 @@ static void mermaidPattern(TpmcPattern *pattern) { mermaidConstructorComponents(value->val.constructor->components); printf(")"); break; + case TPMCPATTERNVALUE_TYPE_TUPLE: + printf("#("); + mermaidConstructorComponents(value->val.tuple); + printf(")"); + break; default: - cant_happen("unrecognised type %d in mermaidArcLabel", - value->type); + cant_happen("unrecognised type %s", tpmcPatternValueTypeName(value->type)); } } diff --git a/src/tpmc_pp.c b/src/tpmc_pp.c index edb7fe4..ef01313 100644 --- a/src/tpmc_pp.c +++ b/src/tpmc_pp.c @@ -35,6 +35,11 @@ void ppTpmcConstructorPattern(TpmcConstructorPattern *constructorPattern) { ppTpmcPatternArray(constructorPattern->components); } +void ppTpmcTuplePattern(TpmcPatternArray *tuple) { + eprintf("#"); + ppTpmcPatternArray(tuple); +} + void ppTpmcPatternArray(TpmcPatternArray *patternArray) { eprintf("("); int i = 0; @@ -73,6 +78,9 @@ void ppTpmcPatternValue(TpmcPatternValue *patternValue) { case TPMCPATTERNVALUE_TYPE_BIGINTEGER: fprintBigInt(errout, patternValue->val.biginteger); break; + case TPMCPATTERNVALUE_TYPE_TUPLE: + ppTpmcTuplePattern(patternValue->val.tuple); + break; case TPMCPATTERNVALUE_TYPE_CONSTRUCTOR: ppTpmcConstructorPattern(patternValue->val.constructor); break; diff --git a/src/tpmc_translate.c b/src/tpmc_translate.c index ea87c4b..b47ae30 100644 --- a/src/tpmc_translate.c +++ b/src/tpmc_translate.c @@ -20,9 +20,11 @@ #include #include +#include #include "lambda.h" #include "lambda_helper.h" #include "tpmc.h" +#include "tpmc_pp.h" #include "symbol.h" #include "common.h" @@ -87,7 +89,6 @@ static LamList *convertVarListToList(LamVarList *vars) { LamList *next = convertVarListToList(vars->next); int save = PROTECT(next); LamExp *exp = newLamExp(LAMEXP_TYPE_VAR, LAMEXP_VAL_VAR(vars->var)); - DEBUG("[newLamExp]"); PROTECT(exp); LamList *this = newLamList(exp, next); UNPROTECT(save); @@ -98,7 +99,6 @@ static LamList *convertVarListToList(LamVarList *vars) { static LamExp *translateToApply(HashSymbol *name, TpmcState *dfa) { ENTER(translateToApply); LamExp *function = newLamExp(LAMEXP_TYPE_VAR, LAMEXP_VAL_VAR(name)); - DEBUG("[newLamExp]"); int save = PROTECT(function); LamVarList *cargs = makeCanonicalArgs(dfa->freeVariables); PROTECT(cargs); @@ -109,7 +109,6 @@ static LamExp *translateToApply(HashSymbol *name, TpmcState *dfa) { args); PROTECT(apply); LamExp *res = newLamExp(LAMEXP_TYPE_APPLY, LAMEXP_VAL_APPLY(apply)); - DEBUG("[newLamExp]"); UNPROTECT(save); LEAVE(translateToApply); return res; @@ -125,7 +124,6 @@ static LamExp *translateToLambda(TpmcState *dfa, LamExpTable *lambdaCache) { newLamLam(countTpmcVariableTable(dfa->freeVariables), args, exp); PROTECT(lambda); LamExp *res = newLamExp(LAMEXP_TYPE_LAM, LAMEXP_VAL_LAM(lambda)); - DEBUG("[newLamExp]"); UNPROTECT(save); LEAVE(translateToLambda); return res; @@ -162,16 +160,13 @@ static LamExp *translateComparisonArcToTest(TpmcArc *arc) { TpmcComparisonPattern *pattern = arc->test->pattern->val.comparison; LamExp *a = newLamExp(LAMEXP_TYPE_VAR, LAMEXP_VAL_VAR(pattern->previous->path)); - DEBUG("[newLamExp]"); int save = PROTECT(a); LamExp *b = newLamExp(LAMEXP_TYPE_VAR, LAMEXP_VAL_VAR(pattern->current->path)); - DEBUG("[newLamExp]"); PROTECT(b); LamPrimApp *eq = newLamPrimApp(LAMPRIMOP_TYPE_EQ, a, b); PROTECT(eq); LamExp *res = newLamExp(LAMEXP_TYPE_PRIM, LAMEXP_VAL_PRIM(eq)); - DEBUG("[newLamExp]"); UNPROTECT(save); LEAVE(translateComparisonArcToTest); return res; @@ -181,56 +176,85 @@ static LamExp *prependLetBindings(TpmcPattern *test, TpmcVariableTable *freeVariables, LamExp *body) { ENTER(prependLetBindings); -#ifdef SAFETY_CHECKS - if (test->pattern->type != TPMCPATTERNVALUE_TYPE_CONSTRUCTOR) { - cant_happen("prependLetBindings passed non-constructor %d", - test->pattern->type); - } -#endif - TpmcConstructorPattern *constructor = test->pattern->val.constructor; - if (constructor->components->size == 0) { - LEAVE(prependLetBindings); - return body; - } - HashSymbol *name = constructor->info->type->name; int save = PROTECT(body); - DEBUG("constructor %s has size %d", name->name, constructor->components->size); - for (int i = 0; i < constructor->components->size; i++) { - HashSymbol *path = constructor->components->entries[i]->path; - DEBUG("considering variable %s", path->name); - if (getTpmcVariableTable(freeVariables, path)) { - DEBUG("%s is free", path->name); - LamExp *base = - newLamExp(LAMEXP_TYPE_VAR, LAMEXP_VAL_VAR(test->path)); - int save2 = PROTECT(base); - PROTECT(base); - LamDeconstruct *deconstruct = - newLamDeconstruct(name, i + 1, base); - PROTECT(deconstruct); - LamExp *deconstructExp = newLamExp(LAMEXP_TYPE_DECONSTRUCT, - LAMEXP_VAL_DECONSTRUCT - (deconstruct)); - PROTECT(deconstructExp); - LamLet *let = newLamLet(path, deconstructExp, body); - PROTECT(let); - body = newLamExp(LAMEXP_TYPE_LET, LAMEXP_VAL_LET(let)); - REPLACE_PROTECT(save, body); - UNPROTECT(save2); - } else { - DEBUG("%s is not free", path->name); + switch (test->pattern->type) { + case TPMCPATTERNVALUE_TYPE_CONSTRUCTOR: { + TpmcConstructorPattern *constructor = test->pattern->val.constructor; + TpmcPatternArray *components = constructor->components; + HashSymbol *name = constructor->info->type->name; + DEBUG("constructor %s has size %d", name->name, components->size); + IFDEBUG(ppTpmcConstructorPattern(constructor)); + for (int i = 0; i < components->size; i++) { + HashSymbol *path = components->entries[i]->path; + DEBUG("considering variable %s", path->name); + if (getTpmcVariableTable(freeVariables, path)) { + DEBUG("%s is free", path->name); + LamExp *base = + newLamExp(LAMEXP_TYPE_VAR, LAMEXP_VAL_VAR(test->path)); + int save2 = PROTECT(base); + LamDeconstruct *deconstruct = + newLamDeconstruct(name, i + 1, base); + PROTECT(deconstruct); + LamExp *deconstructExp = newLamExp(LAMEXP_TYPE_DECONSTRUCT, + LAMEXP_VAL_DECONSTRUCT + (deconstruct)); + PROTECT(deconstructExp); + LamLet *let = newLamLet(path, deconstructExp, body); + PROTECT(let); + body = newLamExp(LAMEXP_TYPE_LET, LAMEXP_VAL_LET(let)); + REPLACE_PROTECT(save, body); + UNPROTECT(save2); + } else { + DEBUG("%s is not free", path->name); + } + } + } + break; + case TPMCPATTERNVALUE_TYPE_TUPLE: { + TpmcPatternArray *components = test->pattern->val.tuple; + int size = components->size; + for (int i = 0; i < size; i++) { + HashSymbol *path = components->entries[i]->path; + if (getTpmcVariableTable(freeVariables, path)) { + LamExp *base = + newLamExp(LAMEXP_TYPE_VAR, LAMEXP_VAL_VAR(test->path)); + int save2 = PROTECT(base); + LamTupleIndex *index = newLamTupleIndex(i, size, base); + PROTECT(index); + LamExp *tupleIndex = newLamExp(LAMEXP_TYPE_TUPLE_INDEX, + LAMEXP_VAL_TUPLE_INDEX(index)); + PROTECT(tupleIndex); + LamLet *let = newLamLet(path, tupleIndex, body); + PROTECT(let); + body = newLamExp(LAMEXP_TYPE_LET, LAMEXP_VAL_LET(let)); + REPLACE_PROTECT(save, body); + UNPROTECT(save2); + } + } } + break; + default: + cant_happen("prependLetBindings passed non-constructor %s", + tpmcPatternValueTypeName(test->pattern->type)); } LEAVE(prependLetBindings); + UNPROTECT(save); return body; } static LamExp *translateArcToCode(TpmcArc *arc, LamExpTable *lambdaCache) { ENTER(translateArcToCode); LamExp *res = translateState(arc->state, lambdaCache); - if (arc->test->pattern->type == TPMCPATTERNVALUE_TYPE_CONSTRUCTOR) { - int save = PROTECT(res); - res = prependLetBindings(arc->test, arc->state->freeVariables, res); - UNPROTECT(save); + switch (arc->test->pattern->type) { + case TPMCPATTERNVALUE_TYPE_CONSTRUCTOR: + case TPMCPATTERNVALUE_TYPE_TUPLE: { + int save = PROTECT(res); + res = prependLetBindings(arc->test, arc->state->freeVariables, res); + UNPROTECT(save); + } + break; + default: + // no-op } LEAVE(translateArcToCode); return res; @@ -249,7 +273,6 @@ static LamExp *translateComparisonArcAndAlternativeToIf(TpmcArc *arc, LamExpTabl LamIff *iff = newLamIff(test, consequent, alternative); PROTECT(iff); LamExp *res = newLamExp(LAMEXP_TYPE_IFF, LAMEXP_VAL_IFF(iff)); - DEBUG("[newLamExp]"); UNPROTECT(save); LEAVE(translateComparisonArcAndAlternativeToIf); return res; @@ -325,7 +348,6 @@ static LamExp *translateTestState(TpmcTestState *testState, int save = PROTECT(arcList); LamExp *testVar = newLamExp(LAMEXP_TYPE_VAR, LAMEXP_VAL_VAR(testState->path)); - DEBUG("[newLamExp]"); PROTECT(testVar); LamExp *res = translateArcList(arcList, testVar, lambdaCache); UNPROTECT(save); @@ -449,9 +471,12 @@ static LamExp *translateArcList(TpmcArcList *arcList, LamExp *testVar, UNPROTECT(save); break; } + case TPMCPATTERNVALUE_TYPE_TUPLE:{ + res = translateArcToCode(arcList->arc, lambdaCache); + break; + } default: - cant_happen("unrecognized pattern type %d in translateArcList", - arcList->arc->test->pattern->type); + cant_happen("unrecognised type %s", tpmcPatternValueTypeName(arcList->arc->test->pattern->type)); } LEAVE(translateArcList); return res; @@ -673,7 +698,6 @@ static LamExp *translateStateToInlineCode(TpmcState *dfa, break; case TPMCSTATEVALUE_TYPE_ERROR: res = newLamExp(LAMEXP_TYPE_ERROR, LAMEXP_VAL_ERROR()); - DEBUG("[newLamExp]"); break; default: cant_happen("unrecognised state type %d in tpmcTranslate", @@ -759,14 +783,15 @@ static LamExp *prependLetRec(LamExpTable *lambdaCache, LamExp *body) { LamLetRec *letrec = newLamLetRec(nbindings, bindings, body); PROTECT(letrec); LamExp *res = newLamExp(LAMEXP_TYPE_LETREC, LAMEXP_VAL_LETREC(letrec)); - DEBUG("[newLamExp]"); UNPROTECT(save); LEAVE(prependLetRec); return res; } LamExp *tpmcTranslate(TpmcState *dfa) { + // IFDEBUG(system("clear")); ENTER(tpmcTranslate); + IFDEBUG(ppTpmcState(dfa)); LamExpTable *lambdaCache = newLamExpTable(); int save = PROTECT(lambdaCache); recalculateRefCounts(dfa); diff --git a/tools/makeAST.py b/tools/makeAST.py index 2996ca0..b48a3a3 100644 --- a/tools/makeAST.py +++ b/tools/makeAST.py @@ -727,6 +727,12 @@ def __init__(self, name, data): self.height = SimpleField(self.name, "height", "int") self.entries = SimpleField(self.name,"entries", data["entries"]) + def getDefineValue(self): + return 'x' + + def getDefineArg(self): + return 'x' + def tag(self): super().tag() self.tagField = SimpleField(self.name, "_tag", "string") @@ -736,8 +742,9 @@ def getTypeDeclaration(self): def printCompareField(self, field, depth, prefix=''): myName=self.getName() + extraCmpArgs = self.getExtraCmpAargs(catalog) pad(depth) - print(f"if (!eq{myName}(a->{prefix}{field}, b->{prefix}{field})) return false; // SimpleArray.printCompareField") + print(f"if (!eq{myName}(a->{prefix}{field}, b->{prefix}{field}{extraCmpArgs})) return false; // SimpleArray.printCompareField") def printCopyField(self, field, depth, prefix=''): myName=self.getName() @@ -952,6 +959,14 @@ def getExtraCmpFargs(self, catalog): return ", " + ", ".join(extra) return "" + def getExtraCmpAargs(self, catalog): + extra = [] + for name in self.extraCmpArgs: + extra += [name] + if len(extra) > 0: + return ", " + ", ".join(extra) + return "" + def getCompareSignature(self, catalog): myType = self.getTypeDeclaration() myName = self.getName() @@ -1737,6 +1752,30 @@ def printCopyField(self, field, depth, prefix=''): pad(depth) print(f'x->{field} = o->{field}; // SimpleEnum.printCopyField') + def getNameFunctionDeclaration(self): + name = self.getName(); + camel = name[0].lower() + name[1:] + return f"char * {camel}Name(enum {name} type)" + + def printNameFunctionDeclaration(self): + decl = self.getNameFunctionDeclaration() + print(f"{decl}; // SimpleEnum.printNameFunctionDeclaration") + + def printNameFunctionBody(self): + decl = self.getNameFunctionDeclaration() + comment = '// SimpleEnum.printNameFunctionDeclaration' + print(f"{decl} {{ {comment}") + print(f" switch(type) {{ {comment}") + for field in self.fields: + field.printNameFunctionLine() + print(f" default: {{ {comment}") + print(f" static char buf[64]; {comment}") + print(f' sprintf(buf, "%d", type); {comment}') + print(f" return buf; {comment}"); + print(f" }} {comment}") + print(f" }} {comment}") + print(f"}} {comment}") + print("") class DiscriminatedUnionEnum(Base): @@ -1765,17 +1804,18 @@ def printNameFunctionDeclaration(self): def printNameFunctionBody(self): decl = self.getNameFunctionDeclaration() - print(f"{decl} {{ // DiscriminatedUnionEnum.printNameFunctionDeclaration") - print(" switch(type) {") + comment = '// DiscriminatedUnionEnum.printNameFunctionDeclaration' + print(f"{decl} {{ {comment}") + print(f" switch(type) {{ {comment}") for field in self.fields: field.printNameFunctionLine() - print(" default: {") - print(" static char buf[64];") - print(' sprintf(buf, "%d", type);') - print(" return buf;"); - print(" }") - print(" }") - print("}") + print(f" default: {{ {comment}") + print(f" static char buf[64]; {comment}") + print(f' sprintf(buf, "%d", type); {comment}') + print(f" return buf; {comment}"); + print(f" }} {comment}") + print(f" }} {comment}") + print(f"}} {comment}") print("") def getTypeDeclaration(self):