From 2f1094deb65803e460aa09cad8ae0cf1144f7c00 Mon Sep 17 00:00:00 2001 From: Bill Hails Date: Tue, 26 Dec 2023 16:31:27 +0000 Subject: [PATCH 1/4] replacement in progress --- docs/LICENSES/README.md | 5 + docs/{LICENSE.bigint => LICENSES/bigint.md} | 7 +- .../reentrant_flex_bison_parser.md} | 4 +- docs/LICENSES/type-inference.md | 31 + docs/LambdaConversion.drawio | 488 ++++++- src/algorithm_W.c | 1075 -------------- src/{debug_ast.c => ast_debug.c} | 395 ++++- src/{debug_ast.h => ast_debug.h} | 36 +- src/common.h | 4 +- src/debug_tin.c | 177 --- src/debug_tin.h | 38 - src/debugging_on.h | 2 +- src/hash.c | 8 + src/hash.h | 1 + src/lambda_conversion.c | 2 +- src/{debug_lambda.c => lambda_debug.c} | 553 ++++++- src/{debug_lambda.h => lambda_debug.h} | 43 +- src/lambda_helper.h | 2 +- src/main.c | 45 +- src/memory.c | 24 +- src/memory.h | 7 +- src/tc.c | 301 ++++ src/tc.h | 137 ++ src/tc.yaml | 81 ++ src/tc_analyze.c | 1295 +++++++++++++++++ src/{algorithm_W.h => tc_analyze.h} | 20 +- src/tc_debug.c | 245 ++++ src/tc_debug.h | 45 + src/tc_helper.c | 86 ++ src/tc_helper.h | 30 + src/tc_objtypes.h | 49 + src/tin.c | 362 ----- src/tin.h | 158 -- src/tin.yaml | 87 -- src/tin_helper.c | 781 ---------- src/tin_helper.h | 51 - src/tin_objtypes.h | 53 - src/{debug_tpmc.c => tpmc_debug.c} | 262 +++- src/{debug_tpmc.h => tpmc_debug.h} | 28 +- src/tpmc_logic.c | 2 +- src/tpmc_match.c | 4 +- tests/src/test.h | 6 +- tests/src/test_typechecker.c | 32 +- tools/makeAST.py | 136 +- 44 files changed, 4200 insertions(+), 2998 deletions(-) create mode 100644 docs/LICENSES/README.md rename docs/{LICENSE.bigint => LICENSES/bigint.md} (94%) rename docs/{LICENSE.reentrant_flex_bison_parser.md => LICENSES/reentrant_flex_bison_parser.md} (93%) create mode 100644 docs/LICENSES/type-inference.md delete mode 100644 src/algorithm_W.c rename src/{debug_ast.c => ast_debug.c} (53%) rename src/{debug_ast.h => ast_debug.h} (55%) delete mode 100644 src/debug_tin.c delete mode 100644 src/debug_tin.h 
rename src/{debug_lambda.c => lambda_debug.c} (54%) rename src/{debug_lambda.h => lambda_debug.h} (53%) create mode 100644 src/tc.c create mode 100644 src/tc.h create mode 100644 src/tc.yaml create mode 100644 src/tc_analyze.c rename src/{algorithm_W.h => tc_analyze.h} (66%) create mode 100644 src/tc_debug.c create mode 100644 src/tc_debug.h create mode 100644 src/tc_helper.c create mode 100644 src/tc_helper.h create mode 100644 src/tc_objtypes.h delete mode 100644 src/tin.c delete mode 100644 src/tin.h delete mode 100644 src/tin.yaml delete mode 100644 src/tin_helper.c delete mode 100644 src/tin_helper.h delete mode 100644 src/tin_objtypes.h rename src/{debug_tpmc.c => tpmc_debug.c} (53%) rename src/{debug_tpmc.h => tpmc_debug.h} (56%) diff --git a/docs/LICENSES/README.md b/docs/LICENSES/README.md new file mode 100644 index 0000000..7c1f626 --- /dev/null +++ b/docs/LICENSES/README.md @@ -0,0 +1,5 @@ +# LICENSES + +This directory contains licenses for projects whose code I have used, +copied or made extensive reference to while working on CEKF. + diff --git a/docs/LICENSE.bigint b/docs/LICENSES/bigint.md similarity index 94% rename from docs/LICENSE.bigint rename to docs/LICENSES/bigint.md index 3b27410..24d7665 100644 --- a/docs/LICENSE.bigint +++ b/docs/LICENSES/bigint.md @@ -1,6 +1,11 @@ The following is the license that came with the bigint code I'm using. It does *not* apply to the CEKF project as a whole. +Original project [983/bigint](https://github.com/983/bigint). + +--- + +``` This is free and unencumbered software released into the public domain. Anyone is free to copy, modify, publish, use, compile, sell, or @@ -25,4 +30,4 @@ ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
For more information, please refer to - +``` diff --git a/docs/LICENSE.reentrant_flex_bison_parser.md b/docs/LICENSES/reentrant_flex_bison_parser.md similarity index 93% rename from docs/LICENSE.reentrant_flex_bison_parser.md rename to docs/LICENSES/reentrant_flex_bison_parser.md index 83c30f6..110f152 100644 --- a/docs/LICENSE.reentrant_flex_bison_parser.md +++ b/docs/LICENSES/reentrant_flex_bison_parser.md @@ -1,10 +1,11 @@ This is the license that comes with the re-entrant parser from -[gfrey](https://github.com/gfrey/reentrant_flex_bison_parser). +[gfrey/reentrant_flex_bison_parser](https://github.com/gfrey/reentrant_flex_bison_parser). It does *not* apply to the whole of the CEKF project. ---- +``` Copyright (c) 2015, Gereon Frey All rights reserved. @@ -30,3 +31,4 @@ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +``` diff --git a/docs/LICENSES/type-inference.md b/docs/LICENSES/type-inference.md new file mode 100644 index 0000000..1e0c939 --- /dev/null +++ b/docs/LICENSES/type-inference.md @@ -0,0 +1,31 @@ +The following is the license that comes with the type-inference project +that I'm using as the basis for my type inference code. It does *not* apply +to the entire CEKF project. + +Original source [k-mrm/type-inference](https://github.com/k-mrm/type-inference/tree/master). 
+ +--- + +``` +MIT License + +Copyright (c) 2019 mkei + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+``` diff --git a/docs/LambdaConversion.drawio b/docs/LambdaConversion.drawio index 361917f..75c85c6 100644 --- a/docs/LambdaConversion.drawio +++ b/docs/LambdaConversion.drawio @@ -1,6 +1,6 @@ - + - + @@ -242,7 +242,7 @@ - + @@ -296,7 +296,7 @@ - + @@ -305,37 +305,37 @@ - + - + - + - + - + - + - + - + - + - + - + @@ -343,16 +343,16 @@ - + - + - + - + @@ -360,22 +360,22 @@ - + - + - + - + - + - + @@ -383,24 +383,454 @@ - + - + - + - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/algorithm_W.c b/src/algorithm_W.c deleted file mode 100644 index 355c2b7..0000000 --- a/src/algorithm_W.c +++ /dev/null @@ -1,1075 +0,0 @@ -/* - * CEKF - VM supporting amb - * Copyright (C) 2022-2023 Bill Hails - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - * - * This file contains an implementation of the Hindley-Milner Algorithm W - * for type inference. - */ - -#include - -#include "algorithm_W.h" -#include "tin_helper.h" -#include "symbol.h" -#include "symbols.h" -#include "debug_tin.h" -#include "debug_ast.h" - -static TinMonoType *astTypeToMonoType(AstType *type); -static TinMonoType *constantTypeFunction(HashSymbol *name); -static void addMonoTypeToContext(TinContext *context, HashSymbol *symbol, TinMonoType *monoType); -static void generalizeMonoTypeToContext(TinContext *context, HashSymbol *symbol, TinMonoType *monoType); - -static WResult *WNest(TinContext *context, AstNest *nest, int depth); -static WResult *WExpression(TinContext *context, AstExpression *expr, int depth); -static WResult *WApplication(TinContext *context, AstFunCall *funCall, int depth); -static WResult *WFarg(TinContext *context, AstArg *arg, int depth); -static void addVars(TinContext *context, AstArgList *args); -static WResult *WIff(TinContext *context, AstIff *iff, int depth); - -#ifdef DEBUG_ALGORITHM_W -static int idSource = 0; - -static void enter(char *name, int id, int depth) { - bool newline = true; - if (depth < 0) { - newline = false; - depth = -depth; - } - int pad = depth * 4 - 8; - if (pad < 0) pad = 0; - eprintf("[%03d]>>>%*s%s", id, pad, "", name); - if (newline) - eprintf("\n"); - else - eprintf(" "); -} - -static void leave(char *name, int id, int depth) { - int pad = depth * 4 - 8; - if (pad < 0) pad = 0; - eprintf("[%03d]<<<%*s%s\n", id, pad, "", name); -} -#endif - -void markWResult(WResult *result) { - if (result == NULL) return; - if (MARKED(result)) return; - MARK(result); - markTinSubstitution(result->substitution); - markTinMonoType(result->monoType); -} - -void freeWResult(WResult *x) { - FREE(x, WResult); -} - -void markWResultObj(Header 
*h) { - markWResult((WResult *)h ); -} - -void freeWResultObj(Header *h) { - freeWResult((WResult *)h ); -} - -void printWResult(WResult *result, int depth) { - bool save = quietPrintHashTable; - quietPrintHashTable = false; - eprintf("%*sWResult [\n", depth * 4, ""); - printTinSubstitution(result->substitution, depth + 1); - eprintf(",\n%*s", (depth + 1) * 4, ""); - showTinMonoType(result->monoType); - eprintf("\n"); - eprintf("%*s]", depth * 4, ""); - quietPrintHashTable = save; -} - -static WResult *newWResult(TinSubstitution *substitution, TinMonoType *monoType) { - WResult *x = NEW(WResult, OBJTYPE_WRESULT); - x->substitution = substitution; - x->monoType = monoType; - return x; -} - -static TinPolyType *monoToPolyType(TinMonoType *monoType) { - return newTinPolyType( - TINPOLYTYPE_TYPE_MONOTYPE, - TINPOLYTYPE_VAL_MONOTYPE(monoType) - ); -} - -static void generalizeMonoTypeToContext(TinContext *context, HashSymbol *symbol, TinMonoType *monoType) { - TinPolyType *polyType = generalize(context, monoType); int save = PROTECT(polyType); - addVarToContext(context, symbol, polyType); UNPROTECT(save); -} - -static void generalizeTypeConstructorToContext(TinContext *context, HashSymbol *symbol, TinMonoType *monoType) { - TinPolyType *polyType = generalize(context, monoType); int save = PROTECT(polyType); - addConstructorToContext(context, symbol, polyType); UNPROTECT(save); -} - -static void addMonoTypeToContext(TinContext *context, HashSymbol *symbol, TinMonoType *monoType) { - TinPolyType *polyType = monoToPolyType(monoType); int save = PROTECT(polyType); - addVarToContext(context, symbol, polyType); UNPROTECT(save); -} - -static TinMonoType *monoTypeFunApp(TinFunctionApplication *funApp) { - return newTinMonoType( - TINMONOTYPE_TYPE_FUN, - TINMONOTYPE_VAL_FUN(funApp) - ); -} - -static TinMonoType *monoTypeVar(HashSymbol *var) { - return newTinMonoType( - TINMONOTYPE_TYPE_VAR, - TINMONOTYPE_VAL_VAR(var) - ); -} - -static TinMonoType *freshMonoTypeVar(const char 
*prefix) { - return monoTypeVar(freshTypeVariable(prefix)); -} - -static TinFunctionApplication *constantFunApp(HashSymbol *name) { - return newTinFunctionApplication(name, 0, NULL); -} - -static TinMonoType *anyArrowApplication(HashSymbol *arrow, TinMonoType *type1, TinMonoType *type2) { - TinMonoTypeList *args = newTinMonoTypeList(type2, NULL); int save = PROTECT(args); - args = newTinMonoTypeList(type1, args); UNPROTECT(save); save = PROTECT(args); - TinFunctionApplication *funApp = newTinFunctionApplication(arrow, 2, args); UNPROTECT(save); save = PROTECT(funApp); - TinMonoType *result = monoTypeFunApp(funApp); UNPROTECT(save); - return result; -} - -static TinMonoType *arrowApplication(TinMonoType *type1, TinMonoType *type2) { - HashSymbol *arrow = arrowSymbol(); - return anyArrowApplication(arrow, type1, type2); -} - -static void collectDefine(AstDefine *define, TinContext *context) { - TinMonoType *var = freshMonoTypeVar(define->symbol->name); int save = PROTECT(var); - addMonoTypeToContext(context, define->symbol, var); UNPROTECT(save); -} - -static void collectPrototype(AstPrototype *prototype, TinContext *context) { - TinMonoType *var = freshMonoTypeVar(prototype->symbol->name); int save = PROTECT(var); - addMonoTypeToContext(context, prototype->symbol, var); UNPROTECT(save); -} - -/* -static void collectLoad(AstLoad *load, TinContext *context) { -} -*/ - -static TinMonoTypeList *typeSymbolsToMonoTypeList(AstTypeSymbols *typeSymbols) { - if (typeSymbols == NULL) return NULL; - TinMonoTypeList *next = typeSymbolsToMonoTypeList(typeSymbols->next); int save = PROTECT(next); - TinMonoType *monoType = monoTypeVar(typeSymbols->typeSymbol); (void) PROTECT(monoType); - TinMonoTypeList *this = newTinMonoTypeList(monoType, next); UNPROTECT(save); - return this; -} - -static int countTypeFunctionArgs(TinMonoTypeList *args) { - int count = 0; - while (args != NULL) { - args = args->next; - count++; - } - return count; -} - -static TinMonoType 
*monoTypeFunctionApplication(HashSymbol *name, TinMonoTypeList *args) { - int nargs = countTypeFunctionArgs(args); - TinFunctionApplication *functionApplication = newTinFunctionApplication(name, nargs, args); int save = PROTECT(functionApplication); - TinMonoType *monoType = monoTypeFunApp(functionApplication); UNPROTECT(save); - return monoType; -} - -static TinMonoType *flatTypeToFunctionApplication(AstFlatType *flatType) { - TinMonoTypeList *typeFunctionArgs = typeSymbolsToMonoTypeList(flatType->typeSymbols); int save = PROTECT(typeFunctionArgs); - TinMonoType *monoType = monoTypeFunctionApplication(flatType->symbol, typeFunctionArgs); UNPROTECT(save); - return monoType; -} - -static TinMonoType *makeCallMonoType(TinMonoType *this, TinMonoType *rest) { - TinMonoTypeList *monoTypeList = newTinMonoTypeList(rest, NULL); int save = PROTECT(monoTypeList); - monoTypeList = newTinMonoTypeList(this, monoTypeList); UNPROTECT(save); save = PROTECT(monoTypeList); - HashSymbol *symbol = arrowSymbol(); - TinMonoType *result = monoTypeFunctionApplication(symbol, monoTypeList); UNPROTECT(save); - return result; -} - -static TinMonoType *makeTypeCallMonoType(TinMonoType *this, TinMonoType *rest) { - TinMonoTypeList *monoTypeList = newTinMonoTypeList(rest, NULL); int save = PROTECT(monoTypeList); - monoTypeList = newTinMonoTypeList(this, monoTypeList); UNPROTECT(save); save = PROTECT(monoTypeList); - HashSymbol *symbol = arrowSymbol(); - TinMonoType *result = monoTypeFunctionApplication(symbol, monoTypeList); UNPROTECT(save); - return result; -} - -static TinMonoType *constantTypeFunction(HashSymbol *name) { - TinFunctionApplication *funApp = constantFunApp(name); int save = PROTECT(funApp); - TinMonoType *monoType = monoTypeFunApp(funApp); UNPROTECT(save); - return monoType; -} - -static TinMonoTypeList *typeListToMonoTypeList(AstTypeList *typeList) { - if (typeList == NULL) return NULL; - TinMonoTypeList *next = typeListToMonoTypeList(typeList->next); int save = PROTECT(next); 
- TinMonoType *this = astTypeToMonoType(typeList->type); (void) PROTECT(this); - TinMonoTypeList *result = newTinMonoTypeList(this, next); UNPROTECT(save); - return result; -} - -static TinMonoType *astTypeFunctionToMonoType(AstTypeFunction *typeFunction) { - TinMonoTypeList *typeList = typeListToMonoTypeList(typeFunction->typeList); int save = PROTECT(typeList); - TinMonoType *type = monoTypeFunctionApplication(typeFunction->symbol, typeList); UNPROTECT(save); - return type; -} - -static TinMonoType *astTypeClauseToMonoType(AstTypeClause *typeClause) { - switch (typeClause->type) { - case AST_TYPECLAUSE_TYPE_INTEGER: - return constantTypeFunction(intSymbol()); - case AST_TYPECLAUSE_TYPE_CHARACTER: - return constantTypeFunction(charSymbol()); - case AST_TYPECLAUSE_TYPE_VAR: - return monoTypeVar(typeClause->val.var); - case AST_TYPECLAUSE_TYPE_TYPEFUNCTION: { - return astTypeFunctionToMonoType(typeClause->val.typeFunction); - } - default: - cant_happen("unrecognised type %d in astTypeClauseToMonoType", typeClause->type); - } -} - -static TinMonoType *astTypeToMonoType(AstType *type) { - if (type == NULL) return NULL; - TinMonoType *rest = astTypeToMonoType(type->next); int save = PROTECT(rest); - TinMonoType *this = astTypeClauseToMonoType(type->typeClause); - if (rest == NULL) { - UNPROTECT(save); - return this; - } - (void) PROTECT(this); - TinMonoType *funApp = makeCallMonoType(this, rest); - UNPROTECT(save); - return funApp; -} - -static TinMonoType *collectTypeList(AstTypeList *typeList, TinMonoType *final) { - if (typeList == NULL) return final; - TinMonoType *rest = collectTypeList(typeList->next, final); int save = PROTECT(rest); - TinMonoType *this = astTypeToMonoType(typeList->type); (void) PROTECT(this); - TinMonoType *funMonoType = makeTypeCallMonoType(this, rest); UNPROTECT(save); - return funMonoType; -} - -static void collectTypeConstructor(AstTypeConstructor *typeConstructor, TinMonoType *type, TinContext *context) { - TinMonoType *functionType = 
collectTypeList(typeConstructor->typeList, type); int save = PROTECT(functionType); - generalizeTypeConstructorToContext(context, typeConstructor->symbol, functionType); UNPROTECT(save); -} - -static void collectTypeDef(AstTypeDef *typeDef, TinContext *context) { - TinMonoType *type = flatTypeToFunctionApplication(typeDef->flatType); int save = PROTECT(type); - for (AstTypeBody *typeBody = typeDef->typeBody; typeBody != NULL; typeBody = typeBody->next) { - AstTypeConstructor *typeConstructor = typeBody->typeConstructor; - collectTypeConstructor(typeConstructor, type, context); - } UNPROTECT(save); -} - -static void collectDefinition(AstDefinition *definition, TinContext *context) { - switch (definition->type) { - case AST_DEFINITION_TYPE_DEFINE: - collectDefine(definition->val.define, context); - return; - case AST_DEFINITION_TYPE_PROTOTYPE: - collectPrototype(definition->val.prototype, context); - return; - case AST_DEFINITION_TYPE_LOAD: - cant_happen("load stntax not supported yet"); - case AST_DEFINITION_TYPE_TYPEDEF: - collectTypeDef(definition->val.typeDef, context); - return; - default: - cant_happen("unrecognised type %d in collectDefinition", definition->type); - } -} - -static WResult *anyResult() { - TinMonoType *fresh = freshMonoTypeVar("any"); int save = PROTECT(fresh); - TinSubstitution *empty = makeEmptySubstitution(); (void) PROTECT(empty); - WResult *r = newWResult(empty, fresh); UNPROTECT(save); - return r; -} - -static WResult *constantResult(HashSymbol *typeName) { - TinMonoType *type = constantTypeFunction(typeName); int save = PROTECT(type); - TinSubstitution *empty = makeEmptySubstitution(); (void) PROTECT(empty); - WResult *r = newWResult(empty, type); UNPROTECT(save); - return r; -} - -static WResult *WSymbol(TinContext *context, HashSymbol *symbol, int depth __attribute__((unused))) { -#ifdef DEBUG_ALGORITHM_W - int myId = idSource++; - enter("WSymbol", myId, -depth); - printHashSymbol(symbol); - eprintf("\n"); - printTinContext(context, 
depth + 1); - eprintf("\n"); -#endif - TinPolyType *value = lookupInContext(context, symbol); - if (value == NULL) { - can_happen("undefined variable '%s' in WSymbol", symbol->name); - return anyResult(); - } int save = PROTECT(value); - TinMonoType *instantiated = instantiate(value); PROTECT(instantiated); - TinSubstitution *empty = makeEmptySubstitution(); (void) PROTECT(empty); - WResult *result = newWResult(empty, instantiated); UNPROTECT(save); -#ifdef DEBUG_ALGORITHM_W - leave("WSymbol", myId, depth); - printWResult(result, depth+ 1); - eprintf("\n"); -#endif - return result; -} - -static int countExpressions(AstExpressions *expressions) { - int count = 0; - while (expressions != NULL) { - count++; - expressions = expressions->next; - } - return count; -} - -static AstExpression *nthExpression(int n, AstExpressions *expressions) { - while (n > 1) { // 1-indexed - n--; - expressions = expressions->next; - } - return expressions->expression; -} - -static AstArg *nthArg(int n, AstArgList *args) { - while (n > 1) { // 1-indexed - n--; - args = args->next; - } - return args->arg; -} - -static WResult *WApplicationRec(TinContext *context, AstFunCall *funCall, int nargs, int depth) { -#ifdef DEBUG_ALGORITHM_W - int myId = idSource++; - enter("WapplicationRec", myId, -depth); - eprintf("%d\n", nargs); - if (nargs == 0) { - printAstExpression(funCall->function, depth + 1); - } else { - printAstExpression(nthExpression(nargs, funCall->arguments), depth + 1); - } - eprintf("\n"); - printTinContext(context, depth + 1); - eprintf("\n"); -#endif - WResult *result; - int save = -1; - if (nargs == 0) { - result = WExpression(context, funCall->function, depth + 1); - } else { - WResult *Rn_1 = WApplicationRec(context, funCall, nargs - 1, depth); save = PROTECT(Rn_1); - TinSubstitution *S1 = Rn_1->substitution; - TinMonoType *t1 = Rn_1->monoType; - TinContext *S1Gamma = applyContextSubstitution(S1, context); (void) PROTECT(S1Gamma); - AstExpression *en = nthExpression(nargs, 
funCall->arguments); - WResult *Rn = WExpression(S1Gamma, en, depth + 1); (void) PROTECT(Rn); - TinMonoType *t2 = Rn->monoType; - TinSubstitution *S2 = Rn->substitution; - TinMonoType *beta = freshMonoTypeVar("apprec"); (void) PROTECT(beta); - TinMonoType *arrow = arrowApplication(t2, beta); (void) PROTECT(arrow); - TinMonoType *S2t1 = applyMonoTypeSubstitution(S2, t1); (void) PROTECT(S2t1); - TinSubstitution *S3 = unify(S2t1, arrow, "WApplicationRec"); (void) PROTECT(S3); - TinMonoType *S3Beta = applyMonoTypeSubstitution(S3, beta); (void) PROTECT(S3Beta); - TinSubstitution *S2S1 = applySubstitutionSubstitution(S2, S1); (void) PROTECT(S2S1); - TinSubstitution *S3S2S1 = applySubstitutionSubstitution(S3, S2S1); (void) PROTECT(S3S2S1); - result = newWResult(S3S2S1, S3Beta); - } - if (save != -1) { - UNPROTECT(save); - } -#ifdef DEBUG_ALGORITHM_W - leave("WApplicatonRec", myId, depth); - printWResult(result, depth + 1); - eprintf("\n"); -#endif - return result; -} - -static WResult *WUnpackRec(TinContext *context, AstUnpack *unpack, int nargs, int depth) { -#ifdef DEBUG_ALGORITHM_W - int myId = idSource++; - enter("WUnpackRec", myId, -depth); - eprintf("%d\n", nargs); - if (nargs == 0) { - eprintf("%*s", (depth + 1) * 4, ""); - printHashSymbol(unpack->symbol); - } else { - printAstArg(nthArg(nargs, unpack->argList), depth + 1); - } - eprintf("\n"); - printTinContext(context, depth + 1); - eprintf("\n"); -#endif - WResult *result; - if (nargs == 0) { - result = WSymbol(context, unpack->symbol, depth + 1); - } else { - WResult *Rn_1 = WUnpackRec(context, unpack, nargs - 1, depth); int save = PROTECT(Rn_1); - TinContext *Cn_1 = applyContextSubstitution(Rn_1->substitution, context); (void) PROTECT(Cn_1); - AstArg *e_n = nthArg(nargs, unpack->argList); - WResult *Rn = WFarg(Cn_1, e_n, depth + 1); (void) PROTECT(Rn); - TinMonoType *beta = freshMonoTypeVar("unpackrec"); (void) PROTECT(beta); - TinMonoType *arrow = arrowApplication(Rn->monoType, beta); (void) PROTECT(arrow); - 
TinMonoType *SnTn_1 = applyMonoTypeSubstitution(Rn->substitution, Rn_1->monoType); (void) PROTECT(SnTn_1); - TinSubstitution *Sprime = unify(SnTn_1, arrow, "WUnpackRec"); (void) PROTECT(Sprime); - TinMonoType *SprimeBeta = applyMonoTypeSubstitution(Sprime, beta); (void) PROTECT(SprimeBeta); - TinSubstitution *SnSn_1 = applySubstitutionSubstitution(Rn->substitution, Rn_1->substitution); (void) PROTECT(SnSn_1); - TinSubstitution *SprimeSnSn_1 = applySubstitutionSubstitution(Sprime, SnSn_1); (void) PROTECT(SprimeSnSn_1); - result = newWResult(SprimeSnSn_1, SprimeBeta); UNPROTECT(save); - } -#ifdef DEBUG_ALGORITHM_W - leave("WUnpackRec", myId, depth); - printWResult(result, depth + 1); - eprintf("\n"); -#endif - return result; -} - -static WResult *WApplication(TinContext *context, AstFunCall *funCall, int depth) { -#ifdef DEBUG_ALGORITHM_W - int myId = idSource++; - enter("WApplication", myId, depth); - printAstFunCall(funCall, depth + 1); - eprintf("\n"); - printTinContext(context, depth + 1); - eprintf("\n"); -#endif - int nargs = countExpressions(funCall->arguments); - WResult *result = WApplicationRec(context, funCall, nargs, depth); -#ifdef DEBUG_ALGORITHM_W - leave("WApplication", myId, depth); - printWResult(result, depth + 1); - eprintf("\n"); -#endif - return result; -} - -static int countArgList(AstArgList *args) { - int count = 0; - while (args != NULL) { - args = args->next; - count++; - } - return count; -} - -static WResult *WUnpack(TinContext *context, AstUnpack *unpack, int depth) { -#ifdef DEBUG_ALGORITHM_W - int myId = idSource++; - enter("WUnpack", myId, depth); - printAstUnpack(unpack, depth + 1); - eprintf("\n"); - printTinContext(context, depth + 1); - eprintf("\n"); -#endif - int nargs = countArgList(unpack->argList); - WResult *result = WUnpackRec(context, unpack, nargs, depth); -#ifdef DEBUG_ALGORITHM_W - leave("WUnpack", myId, depth); - printWResult(result, depth + 1); - eprintf("\n"); -#endif - return result; -} - -static WResult 
*WFarg_d(TinContext *context, AstArg *arg, int depth) { - switch (arg->type) { - case AST_ARG_TYPE_WILDCARD: - return anyResult(); - case AST_ARG_TYPE_SYMBOL: - return WSymbol(context, arg->val.symbol, depth + 1); - case AST_ARG_TYPE_NAMED: { - WResult *r = WFarg(context, arg->val.named->arg, depth + 1); int save = PROTECT(r); - addToSubstitution(r->substitution, arg->val.named->name, r->monoType); UNPROTECT(save); - return r; - } - case AST_ARG_TYPE_ENV: - cant_happen("arg type env not supported yet"); - case AST_ARG_TYPE_UNPACK: - /* equivalent to Funcall, but only type constructors are allowed */ - return WUnpack(context, arg->val.unpack, depth + 1); - case AST_ARG_TYPE_NUMBER: - return constantResult(intSymbol()); - case AST_ARG_TYPE_CHARACTER: - return constantResult(charSymbol()); - default: - cant_happen("unrecognised AstArg type %d in WFarg_d", arg->type); - } -} - -static WResult *WFarg(TinContext *context, AstArg *arg, int depth) { -#ifdef DEBUG_ALGORITHM_W - int myId = idSource++; - enter("WFarg", myId, depth); - printAstArg(arg, depth + 1); - eprintf("\n"); - printTinContext(context, depth + 1); - eprintf("\n"); -#endif - WResult *result = WFarg_d(context, arg, depth); -#ifdef DEBUG_ALGORITHM_W - leave("WFarg", myId, depth); - printWResult(result, depth + 1); - eprintf("\n"); -#endif - return result; -} - -static void addVar(TinContext *context, AstArg *arg) { - switch (arg->type) { - case AST_ARG_TYPE_WILDCARD: - case AST_ARG_TYPE_NUMBER: - case AST_ARG_TYPE_CHARACTER: - case AST_ARG_TYPE_ENV: - return; - case AST_ARG_TYPE_SYMBOL: { - if (!isTypeConstructor(context, arg->val.symbol)) { - TinMonoType *var = freshMonoTypeVar(arg->val.symbol->name); int save = PROTECT(var); - addMonoTypeToContext(context, arg->val.symbol, var); UNPROTECT(save); - } - return; - } - case AST_ARG_TYPE_NAMED: { - addVar(context, arg->val.named->arg); - TinMonoType *var = freshMonoTypeVar(arg->val.named->name->name); int save = PROTECT(var); - addMonoTypeToContext(context, 
arg->val.named->name, var); UNPROTECT(save); - return; - } - case AST_ARG_TYPE_UNPACK: - addVars(context, arg->val.unpack->argList); - return; - } -} - -static void addVars(TinContext *context, AstArgList *args) { - while (args != NULL) { - addVar(context, args->arg); - args = args->next; - } -} - -static TinVarsResult *vars(TinContext *context, HashTable *V, AstArg *arg, int depth) { - TinVarsResult *result; - int save = -1; - switch (arg->type) { - case AST_ARG_TYPE_SYMBOL: { - if (hashGet(V, arg->val.symbol, NULL)) { - result = newTinVarsResult(context, V); - } else if (isTypeConstructor(context, arg->val.symbol)) { - result = newTinVarsResult(context, V); - } else { - TinMonoType *fresh = freshMonoTypeVar(arg->val.symbol->name); save = PROTECT(fresh); - addMonoTypeToContext(context, arg->val.symbol, fresh); - hashSet(V, arg->val.symbol, NULL); - result = newTinVarsResult(context, V); - } - } - break; - case AST_ARG_TYPE_NAMED: { - WResult *wr = WFarg(context, arg->val.named->arg, depth + 1); save = PROTECT(wr); - TinVarsResult *r = vars(context, V, arg->val.named->arg, depth + 1); (void) PROTECT(r); - TinMonoType *subs = applyMonoTypeSubstitution(wr->substitution, wr->monoType); (void) PROTECT(subs); - addMonoTypeToContext(context, arg->val.named->name, subs); - hashSet(V, arg->val.named->name, NULL); - result = newTinVarsResult(context, V); - } - break; - case AST_ARG_TYPE_ENV: { - cant_happen("env not implemented yet (in vars)"); - } - break; - case AST_ARG_TYPE_UNPACK: { - result = newTinVarsResult(context, V); - save = PROTECT(result); - for (AstArgList *argList = arg->val.unpack->argList; argList != NULL; argList = argList->next) { - result = vars(result->context, result->set, argList->arg, depth + 1); UNPROTECT(save); save = PROTECT(result); - } - } - break; - case AST_ARG_TYPE_WILDCARD: - case AST_ARG_TYPE_NUMBER: - case AST_ARG_TYPE_CHARACTER: { - result = newTinVarsResult(context, V); - } - break; - default: - cant_happen("unrecognised arg type %d in 
vars", arg->type); - } - if (save != -1) { - UNPROTECT(save); - } - return result; -} - -static TinVarResult *var(TinContext *context, HashTable *V, AstArg *arg, int depth) { - TinVarResult *result; - int save = -1; - switch (arg->type) { - case AST_ARG_TYPE_WILDCARD: { - WResult *r = anyResult(); - save = PROTECT(r); - result = newTinVarResult(r->substitution, context, r->monoType, V); - } - break; - case AST_ARG_TYPE_SYMBOL: { - if (isTypeConstructor(context, arg->val.symbol)) { - WResult *r = WSymbol(context, arg->val.symbol, depth + 1); save = PROTECT(r); - result = newTinVarResult(r->substitution, context, r->monoType, V); - } else if (hashGet(V, arg->val.symbol, NULL)) { - TinSubstitution *empty = makeEmptySubstitution(); save = PROTECT(empty); - TinPolyType *tpt = lookupInContext(context, arg->val.symbol); - // TODO verify this is correct - TinMonoType *tmt = instantiate(tpt); (void) PROTECT(tmt); - result = newTinVarResult(empty, context, tmt, V); - } else { - TinMonoType *fresh = freshMonoTypeVar(arg->val.symbol->name); save = PROTECT(fresh); - TinSubstitution *empty = makeEmptySubstitution(); (void) PROTECT(empty); - addMonoTypeToContext(context, arg->val.symbol, fresh); - hashSet(V, arg->val.symbol, NULL); - result = newTinVarResult(empty, context, fresh, V); - } - } - break; - case AST_ARG_TYPE_NAMED: { - TinVarResult *vr = var(context, V, arg->val.named->arg, depth + 1); save = PROTECT(vr); - hashSet(V, arg->val.named->name, NULL); - addMonoTypeToContext(vr->context, arg->val.named->name, vr->monoType); - result = newTinVarResult(vr->substitution, vr->context, vr->monoType, V); - } - break; - case AST_ARG_TYPE_ENV: { - cant_happen("env not implemented yet (in var())"); - } - break; - case AST_ARG_TYPE_UNPACK: { - TinVarsResult *vsr = vars(context, V, arg, depth + 1); save = PROTECT(vsr); - WResult *wr = WUnpack(vsr->context, arg->val.unpack, depth + 1); (void) PROTECT(wr); - result = newTinVarResult(wr->substitution, vsr->context, wr->monoType, 
vsr->set); - } - break; - case AST_ARG_TYPE_NUMBER: { - WResult *wr = constantResult(intSymbol()); save = PROTECT(wr); - result = newTinVarResult(wr->substitution, context, wr->monoType, V); - } - break; - case AST_ARG_TYPE_CHARACTER: { - WResult *wr = constantResult(charSymbol()); save = PROTECT(wr); - result = newTinVarResult(wr->substitution, context, wr->monoType, V); - } - break; - default: - cant_happen("unrecognised arg type %d in var", arg->type); - } - if (save != -1) { - UNPROTECT(save); - } - return result; -} - -static WResult *WFunctionPrime(TinContext *context, AstFunction *fun, HashTable *V, AstArgList *args, int depth) { - WResult *result; - if (args == NULL) { - result = WNest(context, fun->nest, depth + 1); - } else { - TinVarResult *vr = var(context, V, args->arg, depth + 1); int save = PROTECT(vr); - WResult *wPrime = WFunctionPrime(vr->context, fun, vr->set, args->next, depth); (void) PROTECT(wPrime); - TinSubstitution *S = applySubstitutionSubstitution(vr->substitution, wPrime->substitution); (void) PROTECT(S); - TinMonoType *beta = applyMonoTypeSubstitution(S, vr->monoType); (void) PROTECT(beta); - TinMonoType *application = arrowApplication(beta, wPrime->monoType); (void) PROTECT(application); - result = newWResult(S, application); UNPROTECT(save); - } - return result; -} - -static WResult *WFunction(TinContext *context, AstFunction *fun, int depth) { -#ifdef DEBUG_ALGORITHM_W - int myId = idSource++; - enter("WFunction", myId, depth); - printAstFunction(fun, depth + 1); - eprintf("\n"); - printTinContext(context, depth + 1); - eprintf("\n"); -#endif - TinContext *fnContext = extendTinContext(context); int save = PROTECT(fnContext); - HashTable *V = newHashTable(0, NULL, NULL); (void) PROTECT(V); - WResult *result = WFunctionPrime(fnContext, fun, V, fun->argList, depth + 1); UNPROTECT(save); -#ifdef DEBUG_ALGORITHM_W - leave("WFunction", myId, depth); - printWResult(result, depth + 1); - eprintf("\n"); -#endif - return result; -} - -/* - * 
unifies between distinct function cases in a composite function - */ -static WResult *WAbstraction(TinContext *context, AstCompositeFunction *fun, int depth) { -#ifdef DEBUG_ALGORITHM_W - int myId = idSource++; - enter("WAbstraction", myId, depth); - printAstCompositeFunction(fun, depth + 1); - eprintf("\n"); - printTinContext(context, depth + 1); - eprintf("\n"); -#endif - WResult *result; - if (fun->next == NULL) { - result = WFunction(context, fun->function, depth + 1); - } else { - WResult *Wn_1 = WAbstraction(context, fun->next, depth); int save = PROTECT(Wn_1); - TinSubstitution *Sn1 = Wn_1->substitution; - TinMonoType *tn1 = Wn_1->monoType; - TinMonoType *Sn1tn1 = applyMonoTypeSubstitution(Sn1, tn1); (void) PROTECT(Sn1tn1); - - WResult *Wn = WFunction(context, fun->function, depth + 1); (void) PROTECT(Wn); - TinSubstitution *Sn = Wn->substitution; - TinMonoType *tn = Wn->monoType; - TinMonoType *Sntn = applyMonoTypeSubstitution(Sn, tn); (void) PROTECT(Sntn); - - TinSubstitution *S = unify(Sntn, Sn1tn1, "WAbstraction"); int save2 = PROTECT(S); - S = applySubstitutionSubstitution(Sn, S); UNPROTECT(save2); save2 = PROTECT(S); - S = applySubstitutionSubstitution(Sn1, S); UNPROTECT(save2); (void) PROTECT(S); - TinMonoType *Stn = applyMonoTypeSubstitution(S, Wn->monoType); (void) PROTECT(Stn); - result = newWResult(S, Stn); - UNPROTECT(save); - } -#ifdef DEBUG_ALGORITHM_W - leave("WAbstraction", myId, depth); - printWResult(result, depth + 1); - eprintf("\n"); -#endif - return result; -} - -static WResult *WExpression_d(TinContext *context, AstExpression *expr, int depth) { - switch (expr->type) { - case AST_EXPRESSION_TYPE_BACK: - return anyResult(); - case AST_EXPRESSION_TYPE_FUNCALL: - return WApplication(context, expr->val.funCall, depth + 1); - case AST_EXPRESSION_TYPE_SYMBOL: - return WSymbol(context, expr->val.symbol, depth + 1); - case AST_EXPRESSION_TYPE_NUMBER: - return constantResult(intSymbol()); - case AST_EXPRESSION_TYPE_CHARACTER: - return 
constantResult(charSymbol()); - case AST_EXPRESSION_TYPE_FUN: - return WAbstraction(context, expr->val.fun, depth + 1); - case AST_EXPRESSION_TYPE_NEST: - return WNest(context, expr->val.nest, depth + 1); - case AST_EXPRESSION_TYPE_IFF: - return WIff(context, expr->val.iff, depth + 1); - case AST_EXPRESSION_TYPE_ENV: - cant_happen("env type not supported yet"); - // return WEnv(context, expr->val.env); - default: - cant_happen("unrecognised type %d in WExpression", expr->type); - } -} - -static WResult *WExpression(TinContext *context, AstExpression *expr, int depth) { -#ifdef DEBUG_ALGORITHM_W - int myId = idSource++; - enter("WExpression", myId, depth); - printAstExpression(expr, depth + 1); - eprintf("\n"); - printTinContext(context, depth + 1); - eprintf("\n"); -#endif - WResult *result = WExpression_d(context, expr, depth); -#ifdef DEBUG_ALGORITHM_W - leave("WExpression", myId, depth); - printWResult(result, depth + 1); - eprintf("\n"); -#endif - return result; -} - -static TinMonoType *extractMonoType(TinPolyType *tpt) { - if (tpt == NULL) { - cant_happen("expected non-null TinPolyType in extractMonoType"); - } - if (tpt->type != TINPOLYTYPE_TYPE_MONOTYPE) { - cant_happen("expected TINPOLYTYPE_TYPE_MONOTYPE"); - } - return tpt->val.monoType; -} - -static void WDefine(TinContext *context, AstDefine *define, int depth) { -#ifdef DEBUG_ALGORITHM_W - int myId = idSource++; - enter("WDefine", myId, depth); - printTinContext(context, depth + 1); - eprintf("\n"); -#endif - HashSymbol *F = define->symbol; - AstExpression *expression = define->expression; - TinMonoType *originalTypeOfF = extractMonoType(lookupInContext(context, F)); - WResult *wResult = WExpression(context, expression, depth + 1); int save = PROTECT(wResult); - TinSubstitution *substitutionsCalculatedFromF = wResult->substitution; - TinMonoType *substitutedOriginalTypeOfF = applyMonoTypeSubstitution(substitutionsCalculatedFromF, originalTypeOfF); PROTECT(substitutedOriginalTypeOfF); - TinMonoType 
*calculatedTypeOfF = wResult->monoType; - TinSubstitution *unifyingSubstitution = unify(substitutedOriginalTypeOfF, calculatedTypeOfF, "WDefine"); PROTECT(unifyingSubstitution); - TinMonoType *finalMonoTypeOfF = applyMonoTypeSubstitution(unifyingSubstitution, calculatedTypeOfF); PROTECT(finalMonoTypeOfF); - TinSubstitution *finalSubstitution = - applySubstitutionSubstitution(unifyingSubstitution, substitutionsCalculatedFromF); PROTECT(finalSubstitution); - applyContextSubstitutionInPlace(finalSubstitution, context); - TinPolyType *finalTypeOfF = generalize(context, finalMonoTypeOfF); PROTECT(finalTypeOfF); - addVarToContext(context, F, finalTypeOfF); UNPROTECT(save); -#ifdef DEBUG_ALGORITHM_W - leave("WDefine", myId, depth); - printHashSymbol(F); - eprintf(" =\n%*s", depth * 4, ""); - showTinPolyType(finalTypeOfF); - eprintf("\n"); -#endif -} - -static void WDefinition(TinContext *context, AstDefinition *definition, int depth) { -#ifdef DEBUG_ALGORITHM_W - int myId = idSource++; - enter("WDefinition", myId, depth); - eprintf("\n"); - printTinContext(context, depth + 1); - eprintf("\n"); -#endif - switch (definition->type) { - case AST_DEFINITION_TYPE_DEFINE: - WDefine(context, definition->val.define, depth + 1); - break; - case AST_DEFINITION_TYPE_PROTOTYPE: - // WPrototype(context, definition->val.prototype); - break; - case AST_DEFINITION_TYPE_LOAD: - // WLoad(context, definition->val.load); - break; - case AST_DEFINITION_TYPE_TYPEDEF: - // WTypeDef(context, definition->val.typeDef); - break; - default: - cant_happen("unrecognised type %d in WDefinition", definition->type); - } -#ifdef DEBUG_ALGORITHM_W - leave("WDefinition", myId, depth); - eprintf("\n"); -#endif -} - -static AstFunction *makeBoolFunction(HashSymbol *argSymbol, AstNest *body) { - AstArg *arg = newAstArg(AST_ARG_TYPE_SYMBOL, AST_ARG_VAL_SYMBOL(argSymbol)); int save = PROTECT(arg); - AstArgList *argList = newAstArgList(arg, NULL); PROTECT(argList); - AstFunction *result = newAstFunction(argList, 
body); UNPROTECT(save); - return result; -} - -/* - * if (test) { consequent } else { alternative } - * - * becomes - * - * fn { (true) { consequent } (false) { alternative } }(test) - */ -static AstFunCall *fakeAstConditional(AstExpression *condition, AstNest *consequent, AstNest *alternative) { - AstFunction *trueFunction = makeBoolFunction(trueSymbol(), consequent); int save = PROTECT(trueFunction); - AstFunction *falseFunction = makeBoolFunction(falseSymbol(), alternative); PROTECT(falseFunction); - AstCompositeFunction *body = newAstCompositeFunction(falseFunction, NULL); PROTECT(body); - body = newAstCompositeFunction(trueFunction, body); PROTECT(body); - AstExpression *funExpression = newAstExpression(AST_EXPRESSION_TYPE_FUN, AST_EXPRESSION_VAL_FUN(body)); PROTECT(funExpression); - AstExpressions *args = newAstExpressions(condition, NULL); PROTECT(args); - AstFunCall *funCall = newAstFunCall(funExpression, args); UNPROTECT(save); - return funCall; -} - -static WResult *WIff(TinContext *context, AstIff *iff, int depth) { -#ifdef DEBUG_ALGORITHM_W - int myId = idSource++; - enter("WIff", myId, depth); - printAstIff(iff, depth + 1); - eprintf("\n"); - printTinContext(context, depth + 1); - eprintf("\n"); -#endif - AstFunCall *fakeIff = fakeAstConditional(iff->test, iff->consequent, iff->alternative); int save = PROTECT(fakeIff); - WResult *result = WApplication(context, fakeIff, depth); UNPROTECT(save); -#ifdef DEBUG_ALGORITHM_W - leave("WIff", myId, depth); - printWResult(result, depth + 1); - eprintf("\n"); -#endif - return result; -} - -static WResult *WNest(TinContext *context, AstNest *nest, int depth) { -#ifdef DEBUG_ALGORITHM_W - int myId = idSource++; - enter("WNest", myId, depth); - printAstNest(nest, depth + 1); - eprintf("\n"); - printTinContext(context, depth + 1); - eprintf("\n"); -#endif - if (nest == NULL) return newWResult(NULL, NULL); - TinContext *nestContext = extendTinContext(context); int save = PROTECT(nestContext); - AstDefinitions 
*definitions; - // the outer let of a nest is effectively letrec, so we collect the functions being defined first - for (definitions = nest->definitions; definitions != NULL; definitions = definitions->next) { - collectDefinition(definitions->definition, nestContext); - } - for (definitions = nest->definitions; definitions != NULL; definitions = definitions->next) { - WDefinition(nestContext, definitions->definition, depth + 1); - } - WResult *result = NULL; int save2 = PROTECT(nestContext); - for (AstExpressions *expressions = nest->expressions; expressions != NULL; expressions = expressions->next) { - result = WExpression(nestContext, expressions->expression, depth + 1); UNPROTECT(save2); save2 = PROTECT(result); - } - UNPROTECT(save); -#ifdef DEBUG_ALGORITHM_W - leave("WNest", myId, depth); - printWResult(result, depth + 1); - eprintf("\n"); -#endif - return result; -} - -static void addBinOp(TinContext *context, HashSymbol *op, TinMonoType *a, TinMonoType *b, TinMonoType *c) { - // #a -> ~b -> #c - TinMonoType *funApp = arrowApplication(b, c); int save = PROTECT(funApp); - funApp = arrowApplication(a, funApp); UNPROTECT(save); save = PROTECT(funApp); - generalizeMonoTypeToContext(context, op, funApp); UNPROTECT(save); -} - -static void addIntBinOp(TinContext *context, HashSymbol *symbol) { - // int -> int -> int - TinMonoType *intSym = constantTypeFunction(intSymbol()); int save = PROTECT(intSym); - addBinOp(context, symbol, intSym, intSym, intSym); UNPROTECT(save); -} - -static void addBoolBinOp(TinContext *context, HashSymbol *symbol) { - // bool -> bool -> bool - TinMonoType *boolSym = constantTypeFunction(boolSymbol()); int save = PROTECT(boolSym); - addBinOp(context, symbol, boolSym, boolSym, boolSym); UNPROTECT(save); -} - -static void addThen(TinContext *context) { - // #a -> #a -> #a - HashSymbol *then = thenSymbol(); - TinMonoType *fresh = freshMonoTypeVar(then->name); int save = PROTECT(fresh); - addBinOp(context, then, fresh, fresh, fresh); 
UNPROTECT(save); -} - -static void addBack(TinContext *context) { - // #a - HashSymbol *back = backSymbol(); - TinMonoType *fresh = freshMonoTypeVar(back->name); int save = PROTECT(fresh); - generalizeMonoTypeToContext(context, back, fresh); UNPROTECT(save); -} - -static void addComparisonBinOp(TinContext *context, HashSymbol *op) { - // #a -> #a -> bool - TinMonoType *fresh = freshMonoTypeVar(op->name); int save = PROTECT(fresh); - TinMonoType *boolSym = constantTypeFunction(boolSymbol()); (void) PROTECT(boolSym); - addBinOp(context, op, fresh, fresh, boolSym); UNPROTECT(save); -} - -static void addNot(TinContext *context) { - // bool -> bool - HashSymbol *op = notSymbol(); - TinMonoType *boolSym = constantTypeFunction(boolSymbol()); int save = PROTECT(boolSym); - TinMonoType *funApp = arrowApplication(boolSym, boolSym); UNPROTECT(save); save = PROTECT(funApp); - generalizeMonoTypeToContext(context, op, funApp); UNPROTECT(save); -} - -static void addNegate(TinContext *context) { - // int -> int - HashSymbol *op = negSymbol(); - TinMonoType *intSym = constantTypeFunction(intSymbol()); int save = PROTECT(intSym); - TinMonoType *funApp = arrowApplication(intSym, intSym); UNPROTECT(save); save = PROTECT(funApp); - generalizeMonoTypeToContext(context, op, funApp); UNPROTECT(save); -} - -static void addHere(TinContext *context) { - // ((#a -> #b) -> #a) -> #a - TinMonoType *a = freshMonoTypeVar("here-a"); int save = PROTECT(a); - TinMonoType *b = freshMonoTypeVar("here-b"); (void) PROTECT(b); - TinMonoType *arrow = arrowApplication(a, b); UNPROTECT(save); save = PROTECT(arrow); - arrow = arrowApplication(arrow, a); UNPROTECT(save); save = PROTECT(arrow); - arrow = arrowApplication(arrow, a); UNPROTECT(save); save = PROTECT(arrow); - HashSymbol *op = hereSymbol(); - generalizeMonoTypeToContext(context, op, arrow); UNPROTECT(save); -} - -static void addError(TinContext *context) { - // #t1 -> #t2 - HashSymbol *op = errorSymbol(); - TinMonoType *fresh1 = 
freshMonoTypeVar(op->name); int save = PROTECT(fresh1); - TinMonoType *fresh2 = freshMonoTypeVar(op->name); PROTECT(fresh2); - TinMonoType *funApp = arrowApplication(fresh1, fresh2); PROTECT(funApp); - generalizeMonoTypeToContext(context, op, funApp); UNPROTECT(save); -} - -WResult *WTop(AstNest *nest) { - TinContext *context = freshTinContext(); - int save = PROTECT(context); - addIntBinOp(context, addSymbol()); - addIntBinOp(context, subSymbol()); - addIntBinOp(context, mulSymbol()); - addIntBinOp(context, divSymbol()); - addIntBinOp(context, modSymbol()); - addIntBinOp(context, powSymbol()); - addBoolBinOp(context, andSymbol()); - addBoolBinOp(context, orSymbol()); - addBoolBinOp(context, xorSymbol()); - addComparisonBinOp(context, eqSymbol()); - addComparisonBinOp(context, ltSymbol()); - addComparisonBinOp(context, gtSymbol()); - addComparisonBinOp(context, leSymbol()); - addComparisonBinOp(context, geSymbol()); - addComparisonBinOp(context, neSymbol()); - addThen(context); - addNot(context); - addNegate(context); - addHere(context); - addBack(context); - addError(context); - WResult *res = WNest(context, nest, 0); - UNPROTECT(save); - return res; -} diff --git a/src/debug_ast.c b/src/ast_debug.c similarity index 53% rename from src/debug_ast.c rename to src/ast_debug.c index 78e041e..2a17638 100644 --- a/src/debug_ast.c +++ b/src/ast_debug.c @@ -22,7 +22,7 @@ #include -#include "debug_ast.h" +#include "ast_debug.h" #include "bigint.h" static void pad(int depth) { eprintf("%*s", depth * 4, ""); } @@ -55,7 +55,7 @@ void printAstDefine(struct AstDefine * x, int depth) { pad(depth); if (x == NULL) { eprintf("AstDefine (NULL)"); return; } eprintf("AstDefine[\n"); - printAstSymbol(x->symbol, depth + 1); + printAstSymbol(x->symbol, depth + 1); eprintf("\n"); printAstExpression(x->expression, depth + 1); eprintf("\n"); @@ -67,7 +67,7 @@ void printAstPrototype(struct AstPrototype * x, int depth) { pad(depth); if (x == NULL) { eprintf("AstPrototype (NULL)"); return; } 
eprintf("AstPrototype[\n"); - printAstSymbol(x->symbol, depth + 1); + printAstSymbol(x->symbol, depth + 1); eprintf("\n"); printAstPrototypeBody(x->body, depth + 1); eprintf("\n"); @@ -91,7 +91,7 @@ void printAstPrototypeSymbolType(struct AstPrototypeSymbolType * x, int depth) { pad(depth); if (x == NULL) { eprintf("AstPrototypeSymbolType (NULL)"); return; } eprintf("AstPrototypeSymbolType[\n"); - printAstSymbol(x->symbol, depth + 1); + printAstSymbol(x->symbol, depth + 1); eprintf("\n"); printAstType(x->type, depth + 1); eprintf("\n"); @@ -105,7 +105,7 @@ void printAstLoad(struct AstLoad * x, int depth) { eprintf("AstLoad[\n"); printAstPackage(x->package, depth + 1); eprintf("\n"); - printAstSymbol(x->symbol, depth + 1); + printAstSymbol(x->symbol, depth + 1); eprintf("\n"); pad(depth); eprintf("]"); @@ -127,7 +127,7 @@ void printAstFlatType(struct AstFlatType * x, int depth) { pad(depth); if (x == NULL) { eprintf("AstFlatType (NULL)"); return; } eprintf("AstFlatType[\n"); - printAstSymbol(x->symbol, depth + 1); + printAstSymbol(x->symbol, depth + 1); eprintf("\n"); printAstTypeSymbols(x->typeSymbols, depth + 1); eprintf("\n"); @@ -139,7 +139,7 @@ void printAstTypeSymbols(struct AstTypeSymbols * x, int depth) { pad(depth); if (x == NULL) { eprintf("AstTypeSymbols (NULL)"); return; } eprintf("AstTypeSymbols[\n"); - printAstSymbol(x->typeSymbol, depth + 1); + printAstSymbol(x->typeSymbol, depth + 1); eprintf("\n"); printAstTypeSymbols(x->next, depth + 1); eprintf("\n"); @@ -163,7 +163,7 @@ void printAstTypeConstructor(struct AstTypeConstructor * x, int depth) { pad(depth); if (x == NULL) { eprintf("AstTypeConstructor (NULL)"); return; } eprintf("AstTypeConstructor[\n"); - printAstSymbol(x->symbol, depth + 1); + printAstSymbol(x->symbol, depth + 1); eprintf("\n"); printAstTypeList(x->typeList, depth + 1); eprintf("\n"); @@ -175,7 +175,7 @@ void printAstTypeFunction(struct AstTypeFunction * x, int depth) { pad(depth); if (x == NULL) { eprintf("AstTypeFunction 
(NULL)"); return; } eprintf("AstTypeFunction[\n"); - printAstSymbol(x->symbol, depth + 1); + printAstSymbol(x->symbol, depth + 1); eprintf("\n"); printAstTypeList(x->typeList, depth + 1); eprintf("\n"); @@ -247,7 +247,7 @@ void printAstUnpack(struct AstUnpack * x, int depth) { pad(depth); if (x == NULL) { eprintf("AstUnpack (NULL)"); return; } eprintf("AstUnpack[\n"); - printAstSymbol(x->symbol, depth + 1); + printAstSymbol(x->symbol, depth + 1); eprintf("\n"); printAstArgList(x->argList, depth + 1); eprintf("\n"); @@ -259,7 +259,7 @@ void printAstNamedArg(struct AstNamedArg * x, int depth) { pad(depth); if (x == NULL) { eprintf("AstNamedArg (NULL)"); return; } eprintf("AstNamedArg[\n"); - printAstSymbol(x->name, depth + 1); + printAstSymbol(x->name, depth + 1); eprintf("\n"); printAstArg(x->arg, depth + 1); eprintf("\n"); @@ -271,9 +271,9 @@ void printAstEnvType(struct AstEnvType * x, int depth) { pad(depth); if (x == NULL) { eprintf("AstEnvType (NULL)"); return; } eprintf("AstEnvType[\n"); - printAstSymbol(x->name, depth + 1); + printAstSymbol(x->name, depth + 1); eprintf("\n"); - printAstSymbol(x->prototype, depth + 1); + printAstSymbol(x->prototype, depth + 1); eprintf("\n"); pad(depth); eprintf("]"); @@ -295,7 +295,7 @@ void printAstPackage(struct AstPackage * x, int depth) { pad(depth); if (x == NULL) { eprintf("AstPackage (NULL)"); return; } eprintf("AstPackage[\n"); - printAstSymbol(x->symbol, depth + 1); + printAstSymbol(x->symbol, depth + 1); eprintf("\n"); printAstPackage(x->next, depth + 1); eprintf("\n"); @@ -405,19 +405,19 @@ void printAstTypeClause(struct AstTypeClause * x, int depth) { case AST_TYPECLAUSE_TYPE_INTEGER: pad(depth + 1); eprintf("AST_TYPECLAUSE_TYPE_INTEGER\n"); - pad(depth + 1); + pad(depth + 1); eprintf("void * %p", x->val.integer); break; case AST_TYPECLAUSE_TYPE_CHARACTER: pad(depth + 1); eprintf("AST_TYPECLAUSE_TYPE_CHARACTER\n"); - pad(depth + 1); + pad(depth + 1); eprintf("void * %p", x->val.character); break; case 
AST_TYPECLAUSE_TYPE_VAR: pad(depth + 1); eprintf("AST_TYPECLAUSE_TYPE_VAR\n"); - printAstSymbol(x->val.var, depth + 1); + printAstSymbol(x->val.var, depth + 1); break; case AST_TYPECLAUSE_TYPE_TYPEFUNCTION: pad(depth + 1); @@ -440,13 +440,13 @@ void printAstArg(struct AstArg * x, int depth) { case AST_ARG_TYPE_WILDCARD: pad(depth + 1); eprintf("AST_ARG_TYPE_WILDCARD\n"); - pad(depth + 1); + pad(depth + 1); eprintf("void * %p", x->val.wildcard); break; case AST_ARG_TYPE_SYMBOL: pad(depth + 1); eprintf("AST_ARG_TYPE_SYMBOL\n"); - printAstSymbol(x->val.symbol, depth + 1); + printAstSymbol(x->val.symbol, depth + 1); break; case AST_ARG_TYPE_NAMED: pad(depth + 1); @@ -466,12 +466,12 @@ eprintf("void * %p", x->val.wildcard); case AST_ARG_TYPE_NUMBER: pad(depth + 1); eprintf("AST_ARG_TYPE_NUMBER\n"); - printBigInt(x->val.number, depth + 1); + printBigInt(x->val.number, depth + 1); break; case AST_ARG_TYPE_CHARACTER: pad(depth + 1); eprintf("AST_ARG_TYPE_CHARACTER\n"); - pad(depth + 1); + pad(depth + 1); eprintf("char %c", x->val.character); break; default: @@ -490,7 +490,7 @@ void printAstExpression(struct AstExpression * x, int depth) { case AST_EXPRESSION_TYPE_BACK: pad(depth + 1); eprintf("AST_EXPRESSION_TYPE_BACK\n"); - pad(depth + 1); + pad(depth + 1); eprintf("void * %p", x->val.back); break; case AST_EXPRESSION_TYPE_FUNCALL: @@ -501,17 +501,17 @@ eprintf("void * %p", x->val.back); case AST_EXPRESSION_TYPE_SYMBOL: pad(depth + 1); eprintf("AST_EXPRESSION_TYPE_SYMBOL\n"); - printAstSymbol(x->val.symbol, depth + 1); + printAstSymbol(x->val.symbol, depth + 1); break; case AST_EXPRESSION_TYPE_NUMBER: pad(depth + 1); eprintf("AST_EXPRESSION_TYPE_NUMBER\n"); - printBigInt(x->val.number, depth + 1); + printBigInt(x->val.number, depth + 1); break; case AST_EXPRESSION_TYPE_CHARACTER: pad(depth + 1); eprintf("AST_EXPRESSION_TYPE_CHARACTER\n"); - pad(depth + 1); + pad(depth + 1); eprintf("char %c", x->val.character); break; case AST_EXPRESSION_TYPE_FUN: @@ -542,3 +542,348 @@ 
eprintf("char %c", x->val.character); eprintf("]"); } + +/***************************************/ + +bool eqAstNest(struct AstNest * a, struct AstNest * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqAstDefinitions(a->definitions, b->definitions)) return false; + if (!eqAstExpressions(a->expressions, b->expressions)) return false; + return true; +} + +bool eqAstDefinitions(struct AstDefinitions * a, struct AstDefinitions * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqAstDefinition(a->definition, b->definition)) return false; + if (!eqAstDefinitions(a->next, b->next)) return false; + return true; +} + +bool eqAstDefine(struct AstDefine * a, struct AstDefine * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->symbol != b->symbol) return false; + if (!eqAstExpression(a->expression, b->expression)) return false; + return true; +} + +bool eqAstPrototype(struct AstPrototype * a, struct AstPrototype * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->symbol != b->symbol) return false; + if (!eqAstPrototypeBody(a->body, b->body)) return false; + return true; +} + +bool eqAstPrototypeBody(struct AstPrototypeBody * a, struct AstPrototypeBody * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqAstSinglePrototype(a->single, b->single)) return false; + if (!eqAstPrototypeBody(a->next, b->next)) return false; + return true; +} + +bool eqAstPrototypeSymbolType(struct AstPrototypeSymbolType * a, struct AstPrototypeSymbolType * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->symbol != b->symbol) return false; + if (!eqAstType(a->type, b->type)) return false; + return true; +} + +bool eqAstLoad(struct AstLoad * a, struct AstLoad * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqAstPackage(a->package, b->package)) return false; + 
if (a->symbol != b->symbol) return false; + return true; +} + +bool eqAstTypeDef(struct AstTypeDef * a, struct AstTypeDef * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqAstFlatType(a->flatType, b->flatType)) return false; + if (!eqAstTypeBody(a->typeBody, b->typeBody)) return false; + return true; +} + +bool eqAstFlatType(struct AstFlatType * a, struct AstFlatType * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->symbol != b->symbol) return false; + if (!eqAstTypeSymbols(a->typeSymbols, b->typeSymbols)) return false; + return true; +} + +bool eqAstTypeSymbols(struct AstTypeSymbols * a, struct AstTypeSymbols * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->typeSymbol != b->typeSymbol) return false; + if (!eqAstTypeSymbols(a->next, b->next)) return false; + return true; +} + +bool eqAstTypeBody(struct AstTypeBody * a, struct AstTypeBody * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqAstTypeConstructor(a->typeConstructor, b->typeConstructor)) return false; + if (!eqAstTypeBody(a->next, b->next)) return false; + return true; +} + +bool eqAstTypeConstructor(struct AstTypeConstructor * a, struct AstTypeConstructor * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->symbol != b->symbol) return false; + if (!eqAstTypeList(a->typeList, b->typeList)) return false; + return true; +} + +bool eqAstTypeFunction(struct AstTypeFunction * a, struct AstTypeFunction * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->symbol != b->symbol) return false; + if (!eqAstTypeList(a->typeList, b->typeList)) return false; + return true; +} + +bool eqAstTypeList(struct AstTypeList * a, struct AstTypeList * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqAstType(a->type, b->type)) return false; + if (!eqAstTypeList(a->next, b->next)) 
return false; + return true; +} + +bool eqAstType(struct AstType * a, struct AstType * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqAstTypeClause(a->typeClause, b->typeClause)) return false; + if (!eqAstType(a->next, b->next)) return false; + return true; +} + +bool eqAstCompositeFunction(struct AstCompositeFunction * a, struct AstCompositeFunction * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqAstFunction(a->function, b->function)) return false; + if (!eqAstCompositeFunction(a->next, b->next)) return false; + return true; +} + +bool eqAstFunction(struct AstFunction * a, struct AstFunction * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqAstArgList(a->argList, b->argList)) return false; + if (!eqAstNest(a->nest, b->nest)) return false; + return true; +} + +bool eqAstArgList(struct AstArgList * a, struct AstArgList * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqAstArg(a->arg, b->arg)) return false; + if (!eqAstArgList(a->next, b->next)) return false; + return true; +} + +bool eqAstUnpack(struct AstUnpack * a, struct AstUnpack * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->symbol != b->symbol) return false; + if (!eqAstArgList(a->argList, b->argList)) return false; + return true; +} + +bool eqAstNamedArg(struct AstNamedArg * a, struct AstNamedArg * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->name != b->name) return false; + if (!eqAstArg(a->arg, b->arg)) return false; + return true; +} + +bool eqAstEnvType(struct AstEnvType * a, struct AstEnvType * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->name != b->name) return false; + if (a->prototype != b->prototype) return false; + return true; +} + +bool eqAstFunCall(struct AstFunCall * a, struct AstFunCall * b) { + if (a == b) return true; + if (a 
== NULL || b == NULL) return false; + if (!eqAstExpression(a->function, b->function)) return false; + if (!eqAstExpressions(a->arguments, b->arguments)) return false; + return true; +} + +bool eqAstPackage(struct AstPackage * a, struct AstPackage * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->symbol != b->symbol) return false; + if (!eqAstPackage(a->next, b->next)) return false; + return true; +} + +bool eqAstExpressions(struct AstExpressions * a, struct AstExpressions * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqAstExpression(a->expression, b->expression)) return false; + if (!eqAstExpressions(a->next, b->next)) return false; + return true; +} + +bool eqAstEnv(struct AstEnv * a, struct AstEnv * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqAstPackage(a->package, b->package)) return false; + if (!eqAstDefinitions(a->definitions, b->definitions)) return false; + return true; +} + +bool eqAstIff(struct AstIff * a, struct AstIff * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqAstExpression(a->test, b->test)) return false; + if (!eqAstNest(a->consequent, b->consequent)) return false; + if (!eqAstNest(a->alternative, b->alternative)) return false; + return true; +} + +bool eqAstDefinition(struct AstDefinition * a, struct AstDefinition * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->type != b->type) return false; + switch(a->type) { + case AST_DEFINITION_TYPE_DEFINE: + if (!eqAstDefine(a->val.define, b->val.define)) return false; + break; + case AST_DEFINITION_TYPE_PROTOTYPE: + if (!eqAstPrototype(a->val.prototype, b->val.prototype)) return false; + break; + case AST_DEFINITION_TYPE_LOAD: + if (!eqAstLoad(a->val.load, b->val.load)) return false; + break; + case AST_DEFINITION_TYPE_TYPEDEF: + if (!eqAstTypeDef(a->val.typeDef, b->val.typeDef)) return false; + break; + default: 
+ cant_happen("unrecognised type %d in eqAstDefinition", a->type); + } + return true; +} + +bool eqAstSinglePrototype(struct AstSinglePrototype * a, struct AstSinglePrototype * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->type != b->type) return false; + switch(a->type) { + case AST_SINGLEPROTOTYPE_TYPE_SYMBOLTYPE: + if (!eqAstPrototypeSymbolType(a->val.symbolType, b->val.symbolType)) return false; + break; + case AST_SINGLEPROTOTYPE_TYPE_PROTOTYPE: + if (!eqAstPrototype(a->val.prototype, b->val.prototype)) return false; + break; + default: + cant_happen("unrecognised type %d in eqAstSinglePrototype", a->type); + } + return true; +} + +bool eqAstTypeClause(struct AstTypeClause * a, struct AstTypeClause * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->type != b->type) return false; + switch(a->type) { + case AST_TYPECLAUSE_TYPE_INTEGER: + if (a->val.integer != b->val.integer) return false; + break; + case AST_TYPECLAUSE_TYPE_CHARACTER: + if (a->val.character != b->val.character) return false; + break; + case AST_TYPECLAUSE_TYPE_VAR: + if (a->val.var != b->val.var) return false; + break; + case AST_TYPECLAUSE_TYPE_TYPEFUNCTION: + if (!eqAstTypeFunction(a->val.typeFunction, b->val.typeFunction)) return false; + break; + default: + cant_happen("unrecognised type %d in eqAstTypeClause", a->type); + } + return true; +} + +bool eqAstArg(struct AstArg * a, struct AstArg * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->type != b->type) return false; + switch(a->type) { + case AST_ARG_TYPE_WILDCARD: + if (a->val.wildcard != b->val.wildcard) return false; + break; + case AST_ARG_TYPE_SYMBOL: + if (a->val.symbol != b->val.symbol) return false; + break; + case AST_ARG_TYPE_NAMED: + if (!eqAstNamedArg(a->val.named, b->val.named)) return false; + break; + case AST_ARG_TYPE_ENV: + if (!eqAstEnvType(a->val.env, b->val.env)) return false; + break; + case 
AST_ARG_TYPE_UNPACK: + if (!eqAstUnpack(a->val.unpack, b->val.unpack)) return false; + break; + case AST_ARG_TYPE_NUMBER: + if (a->val.number != b->val.number) return false; + break; + case AST_ARG_TYPE_CHARACTER: + if (a->val.character != b->val.character) return false; + break; + default: + cant_happen("unrecognised type %d in eqAstArg", a->type); + } + return true; +} + +bool eqAstExpression(struct AstExpression * a, struct AstExpression * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->type != b->type) return false; + switch(a->type) { + case AST_EXPRESSION_TYPE_BACK: + if (a->val.back != b->val.back) return false; + break; + case AST_EXPRESSION_TYPE_FUNCALL: + if (!eqAstFunCall(a->val.funCall, b->val.funCall)) return false; + break; + case AST_EXPRESSION_TYPE_SYMBOL: + if (a->val.symbol != b->val.symbol) return false; + break; + case AST_EXPRESSION_TYPE_NUMBER: + if (a->val.number != b->val.number) return false; + break; + case AST_EXPRESSION_TYPE_CHARACTER: + if (a->val.character != b->val.character) return false; + break; + case AST_EXPRESSION_TYPE_FUN: + if (!eqAstCompositeFunction(a->val.fun, b->val.fun)) return false; + break; + case AST_EXPRESSION_TYPE_ENV: + if (!eqAstEnv(a->val.env, b->val.env)) return false; + break; + case AST_EXPRESSION_TYPE_NEST: + if (!eqAstNest(a->val.nest, b->val.nest)) return false; + break; + case AST_EXPRESSION_TYPE_IFF: + if (!eqAstIff(a->val.iff, b->val.iff)) return false; + break; + default: + cant_happen("unrecognised type %d in eqAstExpression", a->type); + } + return true; +} + diff --git a/src/debug_ast.h b/src/ast_debug.h similarity index 55% rename from src/debug_ast.h rename to src/ast_debug.h index 7e8c9d8..4c4601d 100644 --- a/src/debug_ast.h +++ b/src/ast_debug.h @@ -1,5 +1,5 @@ -#ifndef cekf_debug_ast_h -#define cekf_debug_ast_h +#ifndef cekf_ast_debug_h +#define cekf_ast_debug_h /* * CEKF - VM supporting amb * Copyright (C) 2022-2023 Bill Hails @@ -57,4 +57,36 @@ void 
printAstTypeClause(struct AstTypeClause * x, int depth); void printAstArg(struct AstArg * x, int depth); void printAstExpression(struct AstExpression * x, int depth); +bool eqAstNest(struct AstNest * a, struct AstNest * b); +bool eqAstDefinitions(struct AstDefinitions * a, struct AstDefinitions * b); +bool eqAstDefine(struct AstDefine * a, struct AstDefine * b); +bool eqAstPrototype(struct AstPrototype * a, struct AstPrototype * b); +bool eqAstPrototypeBody(struct AstPrototypeBody * a, struct AstPrototypeBody * b); +bool eqAstPrototypeSymbolType(struct AstPrototypeSymbolType * a, struct AstPrototypeSymbolType * b); +bool eqAstLoad(struct AstLoad * a, struct AstLoad * b); +bool eqAstTypeDef(struct AstTypeDef * a, struct AstTypeDef * b); +bool eqAstFlatType(struct AstFlatType * a, struct AstFlatType * b); +bool eqAstTypeSymbols(struct AstTypeSymbols * a, struct AstTypeSymbols * b); +bool eqAstTypeBody(struct AstTypeBody * a, struct AstTypeBody * b); +bool eqAstTypeConstructor(struct AstTypeConstructor * a, struct AstTypeConstructor * b); +bool eqAstTypeFunction(struct AstTypeFunction * a, struct AstTypeFunction * b); +bool eqAstTypeList(struct AstTypeList * a, struct AstTypeList * b); +bool eqAstType(struct AstType * a, struct AstType * b); +bool eqAstCompositeFunction(struct AstCompositeFunction * a, struct AstCompositeFunction * b); +bool eqAstFunction(struct AstFunction * a, struct AstFunction * b); +bool eqAstArgList(struct AstArgList * a, struct AstArgList * b); +bool eqAstUnpack(struct AstUnpack * a, struct AstUnpack * b); +bool eqAstNamedArg(struct AstNamedArg * a, struct AstNamedArg * b); +bool eqAstEnvType(struct AstEnvType * a, struct AstEnvType * b); +bool eqAstFunCall(struct AstFunCall * a, struct AstFunCall * b); +bool eqAstPackage(struct AstPackage * a, struct AstPackage * b); +bool eqAstExpressions(struct AstExpressions * a, struct AstExpressions * b); +bool eqAstEnv(struct AstEnv * a, struct AstEnv * b); +bool eqAstIff(struct AstIff * a, struct AstIff 
* b); +bool eqAstDefinition(struct AstDefinition * a, struct AstDefinition * b); +bool eqAstSinglePrototype(struct AstSinglePrototype * a, struct AstSinglePrototype * b); +bool eqAstTypeClause(struct AstTypeClause * a, struct AstTypeClause * b); +bool eqAstArg(struct AstArg * a, struct AstArg * b); +bool eqAstExpression(struct AstExpression * a, struct AstExpression * b); + #endif diff --git a/src/common.h b/src/common.h index 27b1471..4a00a0a 100644 --- a/src/common.h +++ b/src/common.h @@ -45,8 +45,8 @@ // #define DEBUG_BYTECODE // define this to make fatal errors dump core (if ulimit allows) #define DEBUG_DUMP_CORE -// #define DEBUG_ALGORITHM_W -#define DEBUG_LAMBDA_CONVERT +#define DEBUG_TC +// #define DEBUG_LAMBDA_CONVERT // #define DEBUG_LEAK // #define DEBUG_ANF // #define DEBUG_ALLOC diff --git a/src/debug_tin.c b/src/debug_tin.c deleted file mode 100644 index 0eed172..0000000 --- a/src/debug_tin.c +++ /dev/null @@ -1,177 +0,0 @@ -/* - * CEKF - VM supporting amb - * Copyright (C) 2022-2023 Bill Hails - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - * - * Type inference structures used by Algorithm W. 
- * - * generated from src/tin.yaml by makeAST.py - */ - -#include - -#include "debug_tin.h" - -static void pad(int depth) { eprintf("%*s", depth * 4, ""); } - -void printTinFunctionApplication(struct TinFunctionApplication * x, int depth) { - pad(depth); - if (x == NULL) { eprintf("TinFunctionApplication (NULL)"); return; } - eprintf("TinFunctionApplication[\n"); - printTinSymbol(x->name, depth + 1); - eprintf("\n"); - pad(depth + 1); -eprintf("int %d", x->nargs); - eprintf("\n"); - printTinMonoTypeList(x->args, depth + 1); - eprintf("\n"); - pad(depth); - eprintf("]"); -} - -void printTinMonoTypeList(struct TinMonoTypeList * x, int depth) { - pad(depth); - if (x == NULL) { eprintf("TinMonoTypeList (NULL)"); return; } - eprintf("TinMonoTypeList[\n"); - printTinMonoType(x->monoType, depth + 1); - eprintf("\n"); - printTinMonoTypeList(x->next, depth + 1); - eprintf("\n"); - pad(depth); - eprintf("]"); -} - -void printTinTypeQuantifier(struct TinTypeQuantifier * x, int depth) { - pad(depth); - if (x == NULL) { eprintf("TinTypeQuantifier (NULL)"); return; } - eprintf("TinTypeQuantifier[\n"); - printTinSymbol(x->var, depth + 1); - eprintf("\n"); - printTinPolyType(x->quantifiedType, depth + 1); - eprintf("\n"); - pad(depth); - eprintf("]"); -} - -void printTinContext(struct TinContext * x, int depth) { - pad(depth); - if (x == NULL) { eprintf("TinContext (NULL)"); return; } - eprintf("TinContext[\n"); - printHashTable(x->varFrame, depth + 1); - eprintf("\n"); - printHashTable(x->tcFrame, depth + 1); - eprintf("\n"); - printTinContext(x->next, depth + 1); - eprintf("\n"); - pad(depth); - eprintf("]"); -} - -void printTinSubstitution(struct TinSubstitution * x, int depth) { - pad(depth); - if (x == NULL) { eprintf("TinSubstitution (NULL)"); return; } - eprintf("TinSubstitution[\n"); - printHashTable(x->map, depth + 1); - eprintf("\n"); - pad(depth); - eprintf("]"); -} - -void printTinArgsResult(struct TinArgsResult * x, int depth) { - pad(depth); - if (x == NULL) { 
eprintf("TinArgsResult (NULL)"); return; } - eprintf("TinArgsResult[\n"); - printTinContext(x->context, depth + 1); - eprintf("\n"); - printTinMonoTypeList(x->vec, depth + 1); - eprintf("\n"); - pad(depth); - eprintf("]"); -} - -void printTinVarResult(struct TinVarResult * x, int depth) { - pad(depth); - if (x == NULL) { eprintf("TinVarResult (NULL)"); return; } - eprintf("TinVarResult[\n"); - printTinSubstitution(x->substitution, depth + 1); - eprintf("\n"); - printTinContext(x->context, depth + 1); - eprintf("\n"); - printTinMonoType(x->monoType, depth + 1); - eprintf("\n"); - printHashTable(x->set, depth + 1); - eprintf("\n"); - pad(depth); - eprintf("]"); -} - -void printTinVarsResult(struct TinVarsResult * x, int depth) { - pad(depth); - if (x == NULL) { eprintf("TinVarsResult (NULL)"); return; } - eprintf("TinVarsResult[\n"); - printTinContext(x->context, depth + 1); - eprintf("\n"); - printHashTable(x->set, depth + 1); - eprintf("\n"); - pad(depth); - eprintf("]"); -} - -void printTinMonoType(struct TinMonoType * x, int depth) { - pad(depth); - if (x == NULL) { eprintf("TinMonoType (NULL)"); return; } - eprintf("TinMonoType[\n"); - switch(x->type) { - case TINMONOTYPE_TYPE_VAR: - pad(depth + 1); - eprintf("TINMONOTYPE_TYPE_VAR\n"); - printTinSymbol(x->val.var, depth + 1); - break; - case TINMONOTYPE_TYPE_FUN: - pad(depth + 1); - eprintf("TINMONOTYPE_TYPE_FUN\n"); - printTinFunctionApplication(x->val.fun, depth + 1); - break; - default: - cant_happen("unrecognised type %d in printTinMonoType", x->type); - } - eprintf("\n"); - pad(depth); - eprintf("]"); -} - -void printTinPolyType(struct TinPolyType * x, int depth) { - pad(depth); - if (x == NULL) { eprintf("TinPolyType (NULL)"); return; } - eprintf("TinPolyType[\n"); - switch(x->type) { - case TINPOLYTYPE_TYPE_MONOTYPE: - pad(depth + 1); - eprintf("TINPOLYTYPE_TYPE_MONOTYPE\n"); - printTinMonoType(x->val.monoType, depth + 1); - break; - case TINPOLYTYPE_TYPE_QUANTIFIER: - pad(depth + 1); - 
eprintf("TINPOLYTYPE_TYPE_QUANTIFIER\n"); - printTinTypeQuantifier(x->val.quantifier, depth + 1); - break; - default: - cant_happen("unrecognised type %d in printTinPolyType", x->type); - } - eprintf("\n"); - pad(depth); - eprintf("]"); -} - diff --git a/src/debug_tin.h b/src/debug_tin.h deleted file mode 100644 index 9c6fafa..0000000 --- a/src/debug_tin.h +++ /dev/null @@ -1,38 +0,0 @@ -#ifndef cekf_debug_tin_h -#define cekf_debug_tin_h -/* - * CEKF - VM supporting amb - * Copyright (C) 2022-2023 Bill Hails - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - * - * Type inference structures used by Algorithm W. 
- * - * generated from src/tin.yaml by makeAST.py - */ - -#include "tin_helper.h" - -void printTinFunctionApplication(struct TinFunctionApplication * x, int depth); -void printTinMonoTypeList(struct TinMonoTypeList * x, int depth); -void printTinTypeQuantifier(struct TinTypeQuantifier * x, int depth); -void printTinContext(struct TinContext * x, int depth); -void printTinSubstitution(struct TinSubstitution * x, int depth); -void printTinArgsResult(struct TinArgsResult * x, int depth); -void printTinVarResult(struct TinVarResult * x, int depth); -void printTinVarsResult(struct TinVarsResult * x, int depth); -void printTinMonoType(struct TinMonoType * x, int depth); -void printTinPolyType(struct TinPolyType * x, int depth); - -#endif diff --git a/src/debugging_on.h b/src/debugging_on.h index df521b9..72d895b 100644 --- a/src/debugging_on.h +++ b/src/debugging_on.h @@ -18,7 +18,7 @@ static int _debugInvocationId = 0; #define DEBUG(...) do { \ - eprintf("**** %s:%-5d ", __FILE__, __LINE__); \ + eprintf("*** %s:%-5d ", __FILE__, __LINE__); \ eprintf(__VA_ARGS__); eprintf("\n"); \ } while(0) #define ENTER(name) int _debugMyId = _debugInvocationId++; \ diff --git a/src/hash.c b/src/hash.c index d20ca57..f6e0d23 100644 --- a/src/hash.c +++ b/src/hash.c @@ -169,6 +169,14 @@ void hashSet(HashTable *table, HashSymbol *var, void *src) { } } +bool hashContains(HashTable *table, HashSymbol *var) { /* Reports whether var is present as a key in table; the values array is never touched. An empty table (count == 0) is answered before findEntry is called, otherwise the slot chosen by findEntry is tested for a non-NULL key. NOTE(review): table is dereferenced unconditionally and var is dereferenced by the DEBUG line when that macro is active, so both are presumably required to be non-NULL; confirm with callers. */ + DEBUG("hashContains(%s) [%d]", var->name, table->id); + if (table->count == 0) return false; + hash_t index = findEntry(table->keys, table->capacity, var); + if (table->keys[index] == NULL) return false; + return true; +} + bool hashGet(HashTable *table, HashSymbol *var, void *dest) { DEBUG("hashGet(%s) [%d]", var->name, table->id); IFDEBUG(printMemHeader("values", table->values)); diff --git a/src/hash.h b/src/hash.h index 349398e..5df9092 100644 --- a/src/hash.h +++ b/src/hash.h @@ -54,6 +54,7 @@ hash_t hashString(const char *string); HashTable *newHashTable(size_t valuesize,
MarkHashValueFunction markfunction, PrintHashValueFunction printfunction); void hashSet(HashTable *table, struct HashSymbol *var, void *src); +bool hashContains(HashTable *table, HashSymbol *var); bool hashGet(HashTable *table, struct HashSymbol *var, void *dest); void copyHashTable(HashTable *to, HashTable *from); diff --git a/src/lambda_conversion.c b/src/lambda_conversion.c index 59375e6..6bc3ffe 100644 --- a/src/lambda_conversion.c +++ b/src/lambda_conversion.c @@ -27,7 +27,7 @@ #include "lambda_helper.h" #include "symbols.h" #include "tpmc_logic.h" -#include "debug_ast.h" +#include "ast_debug.h" #define ARG_CATEGORY_VAR 0 #define ARG_CATEGORY_CONST 1 diff --git a/src/debug_lambda.c b/src/lambda_debug.c similarity index 54% rename from src/debug_lambda.c rename to src/lambda_debug.c index c2531cd..b0366be 100644 --- a/src/debug_lambda.c +++ b/src/lambda_debug.c @@ -22,7 +22,7 @@ #include -#include "debug_lambda.h" +#include "lambda_debug.h" #include "bigint.h" static void pad(int depth) { eprintf("%*s", depth * 4, ""); } @@ -31,7 +31,7 @@ void printLamLam(struct LamLam * x, int depth) { pad(depth); if (x == NULL) { eprintf("LamLam (NULL)"); return; } eprintf("LamLam[\n"); - pad(depth + 1); + pad(depth + 1); eprintf("int %d", x->nargs); eprintf("\n"); printLamVarList(x->args, depth + 1); @@ -46,7 +46,7 @@ void printLamVarList(struct LamVarList * x, int depth) { pad(depth); if (x == NULL) { eprintf("LamVarList (NULL)"); return; } eprintf("LamVarList[\n"); - printLambdaSymbol(x->var, depth + 1); + printLambdaSymbol(x->var, depth + 1); eprintf("\n"); printLamVarList(x->next, depth + 1); eprintf("\n"); @@ -180,7 +180,7 @@ void printLamApply(struct LamApply * x, int depth) { eprintf("LamApply[\n"); printLamExp(x->function, depth + 1); eprintf("\n"); - pad(depth + 1); + pad(depth + 1); eprintf("int %d", x->nargs); eprintf("\n"); printLamList(x->args, depth + 1); @@ -193,9 +193,9 @@ void printLamConstant(struct LamConstant * x, int depth) { pad(depth); if (x == NULL) { 
eprintf("LamConstant (NULL)"); return; } eprintf("LamConstant[\n"); - printLambdaSymbol(x->name, depth + 1); + printLambdaSymbol(x->name, depth + 1); eprintf("\n"); - pad(depth + 1); + pad(depth + 1); eprintf("int %d", x->tag); eprintf("\n"); pad(depth); @@ -206,9 +206,9 @@ void printLamConstruct(struct LamConstruct * x, int depth) { pad(depth); if (x == NULL) { eprintf("LamConstruct (NULL)"); return; } eprintf("LamConstruct[\n"); - printLambdaSymbol(x->name, depth + 1); + printLambdaSymbol(x->name, depth + 1); eprintf("\n"); - pad(depth + 1); + pad(depth + 1); eprintf("int %d", x->tag); eprintf("\n"); printLamList(x->args, depth + 1); @@ -221,9 +221,9 @@ void printLamDeconstruct(struct LamDeconstruct * x, int depth) { pad(depth); if (x == NULL) { eprintf("LamDeconstruct (NULL)"); return; } eprintf("LamDeconstruct[\n"); - printLambdaSymbol(x->name, depth + 1); + printLambdaSymbol(x->name, depth + 1); eprintf("\n"); - pad(depth + 1); + pad(depth + 1); eprintf("int %d", x->vec); eprintf("\n"); printLamExp(x->exp, depth + 1); @@ -236,7 +236,7 @@ void printLamMakeVec(struct LamMakeVec * x, int depth) { pad(depth); if (x == NULL) { eprintf("LamMakeVec (NULL)"); return; } eprintf("LamMakeVec[\n"); - pad(depth + 1); + pad(depth + 1); eprintf("int %d", x->nargs); eprintf("\n"); printLamList(x->args, depth + 1); @@ -275,7 +275,7 @@ void printLamIntCondCases(struct LamIntCondCases * x, int depth) { pad(depth); if (x == NULL) { eprintf("LamIntCondCases (NULL)"); return; } eprintf("LamIntCondCases[\n"); - printBigInt(x->constant, depth + 1); + printBigInt(x->constant, depth + 1); eprintf("\n"); printLamExp(x->body, depth + 1); eprintf("\n"); @@ -289,7 +289,7 @@ void printLamCharCondCases(struct LamCharCondCases * x, int depth) { pad(depth); if (x == NULL) { eprintf("LamCharCondCases (NULL)"); return; } eprintf("LamCharCondCases[\n"); - pad(depth + 1); + pad(depth + 1); eprintf("char %c", x->constant); eprintf("\n"); printLamExp(x->body, depth + 1); @@ -330,7 +330,7 @@ void 
printLamIntList(struct LamIntList * x, int depth) { pad(depth); if (x == NULL) { eprintf("LamIntList (NULL)"); return; } eprintf("LamIntList[\n"); - pad(depth + 1); + pad(depth + 1); eprintf("int %d", x->item); eprintf("\n"); printLamIntList(x->next, depth + 1); @@ -343,7 +343,7 @@ void printLamLet(struct LamLet * x, int depth) { pad(depth); if (x == NULL) { eprintf("LamLet (NULL)"); return; } eprintf("LamLet[\n"); - printLambdaSymbol(x->var, depth + 1); + printLambdaSymbol(x->var, depth + 1); eprintf("\n"); printLamExp(x->value, depth + 1); eprintf("\n"); @@ -357,7 +357,7 @@ void printLamLetRec(struct LamLetRec * x, int depth) { pad(depth); if (x == NULL) { eprintf("LamLetRec (NULL)"); return; } eprintf("LamLetRec[\n"); - pad(depth + 1); + pad(depth + 1); eprintf("int %d", x->nbindings); eprintf("\n"); printLamLetRecBindings(x->bindings, depth + 1); @@ -372,7 +372,7 @@ void printLamLetRecBindings(struct LamLetRecBindings * x, int depth) { pad(depth); if (x == NULL) { eprintf("LamLetRecBindings (NULL)"); return; } eprintf("LamLetRecBindings[\n"); - printLambdaSymbol(x->var, depth + 1); + printLambdaSymbol(x->var, depth + 1); eprintf("\n"); printLamExp(x->val, depth + 1); eprintf("\n"); @@ -386,7 +386,7 @@ void printLamContext(struct LamContext * x, int depth) { pad(depth); if (x == NULL) { eprintf("LamContext (NULL)"); return; } eprintf("LamContext[\n"); - printHashTable(x->frame, depth + 1); + printHashTable(x->frame, depth + 1); eprintf("\n"); printLamContext(x->parent, depth + 1); eprintf("\n"); @@ -482,7 +482,7 @@ void printLamType(struct LamType * x, int depth) { pad(depth); if (x == NULL) { eprintf("LamType (NULL)"); return; } eprintf("LamType[\n"); - printLambdaSymbol(x->name, depth + 1); + printLambdaSymbol(x->name, depth + 1); eprintf("\n"); printLamTypeArgs(x->args, depth + 1); eprintf("\n"); @@ -494,7 +494,7 @@ void printLamTypeArgs(struct LamTypeArgs * x, int depth) { pad(depth); if (x == NULL) { eprintf("LamTypeArgs (NULL)"); return; } 
eprintf("LamTypeArgs[\n"); - printLambdaSymbol(x->name, depth + 1); + printLambdaSymbol(x->name, depth + 1); eprintf("\n"); printLamTypeArgs(x->next, depth + 1); eprintf("\n"); @@ -506,7 +506,7 @@ void printLamTypeConstructor(struct LamTypeConstructor * x, int depth) { pad(depth); if (x == NULL) { eprintf("LamTypeConstructor (NULL)"); return; } eprintf("LamTypeConstructor[\n"); - printLambdaSymbol(x->name, depth + 1); + printLambdaSymbol(x->name, depth + 1); eprintf("\n"); printLamType(x->type, depth + 1); eprintf("\n"); @@ -532,7 +532,7 @@ void printLamTypeFunction(struct LamTypeFunction * x, int depth) { pad(depth); if (x == NULL) { eprintf("LamTypeFunction (NULL)"); return; } eprintf("LamTypeFunction[\n"); - printLambdaSymbol(x->name, depth + 1); + printLambdaSymbol(x->name, depth + 1); eprintf("\n"); printLamTypeConstructorArgs(x->args, depth + 1); eprintf("\n"); @@ -546,16 +546,16 @@ void printLamTypeConstructorInfo(struct LamTypeConstructorInfo * x, int depth) { eprintf("LamTypeConstructorInfo[\n"); printLamTypeConstructor(x->type, depth + 1); eprintf("\n"); - pad(depth + 1); + pad(depth + 1); eprintf("bool %d", x->vec); eprintf("\n"); - pad(depth + 1); + pad(depth + 1); eprintf("int %d", x->arity); eprintf("\n"); - pad(depth + 1); + pad(depth + 1); eprintf("int %d", x->size); eprintf("\n"); - pad(depth + 1); + pad(depth + 1); eprintf("int %d", x->index); eprintf("\n"); pad(depth); @@ -575,18 +575,18 @@ void printLamExp(struct LamExp * x, int depth) { case LAMEXP_TYPE_VAR: pad(depth + 1); eprintf("LAMEXP_TYPE_VAR\n"); - printLambdaSymbol(x->val.var, depth + 1); + printLambdaSymbol(x->val.var, depth + 1); break; case LAMEXP_TYPE_STDINT: pad(depth + 1); eprintf("LAMEXP_TYPE_STDINT\n"); - pad(depth + 1); + pad(depth + 1); eprintf("int %d", x->val.stdint); break; case LAMEXP_TYPE_BIGINTEGER: pad(depth + 1); eprintf("LAMEXP_TYPE_BIGINTEGER\n"); - printBigInt(x->val.biginteger, depth + 1); + printBigInt(x->val.biginteger, depth + 1); break; case LAMEXP_TYPE_PRIM: 
pad(depth + 1); @@ -681,25 +681,25 @@ eprintf("int %d", x->val.stdint); case LAMEXP_TYPE_CHARACTER: pad(depth + 1); eprintf("LAMEXP_TYPE_CHARACTER\n"); - pad(depth + 1); + pad(depth + 1); eprintf("char %c", x->val.character); break; case LAMEXP_TYPE_BACK: pad(depth + 1); eprintf("LAMEXP_TYPE_BACK\n"); - pad(depth + 1); + pad(depth + 1); eprintf("void * %p", x->val.back); break; case LAMEXP_TYPE_ERROR: pad(depth + 1); eprintf("LAMEXP_TYPE_ERROR\n"); - pad(depth + 1); + pad(depth + 1); eprintf("void * %p", x->val.error); break; case LAMEXP_TYPE_COND_DEFAULT: pad(depth + 1); eprintf("LAMEXP_TYPE_COND_DEFAULT\n"); - pad(depth + 1); + pad(depth + 1); eprintf("void * %p", x->val.cond_default); break; default: @@ -741,19 +741,19 @@ void printLamTypeConstructorType(struct LamTypeConstructorType * x, int depth) { case LAMTYPECONSTRUCTORTYPE_TYPE_INTEGER: pad(depth + 1); eprintf("LAMTYPECONSTRUCTORTYPE_TYPE_INTEGER\n"); - pad(depth + 1); + pad(depth + 1); eprintf("void * %p", x->val.integer); break; case LAMTYPECONSTRUCTORTYPE_TYPE_CHARACTER: pad(depth + 1); eprintf("LAMTYPECONSTRUCTORTYPE_TYPE_CHARACTER\n"); - pad(depth + 1); + pad(depth + 1); eprintf("void * %p", x->val.character); break; case LAMTYPECONSTRUCTORTYPE_TYPE_VAR: pad(depth + 1); eprintf("LAMTYPECONSTRUCTORTYPE_TYPE_VAR\n"); - printLambdaSymbol(x->val.var, depth + 1); + printLambdaSymbol(x->val.var, depth + 1); break; case LAMTYPECONSTRUCTORTYPE_TYPE_FUNCTION: pad(depth + 1); @@ -768,3 +768,484 @@ eprintf("void * %p", x->val.character); eprintf("]"); } + +/***************************************/ +/* eqLamLam: deep equality for LamLam nodes. The same pointer (or two NULLs) is equal and exactly one NULL is unequal; nargs is then compared directly, and args and exp recurse through eqLamVarList and eqLamExp. */ +bool eqLamLam(struct LamLam * a, struct LamLam * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->nargs != b->nargs) return false; + if (!eqLamVarList(a->args, b->args)) return false; + if (!eqLamExp(a->exp, b->exp)) return false; + return true; +} +/* eqLamVarList: deep equality for LamVarList lists. After the same-pointer and NULL checks, var is compared directly with != and the tail recurses on next. */ +bool eqLamVarList(struct LamVarList * a, struct LamVarList * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return
false; + if (a->var != b->var) return false; + if (!eqLamVarList(a->next, b->next)) return false; + return true; +} +/* eqLamPrimApp: deep equality for binary primitive applications. The same pointer (or two NULLs) is equal and exactly one NULL is unequal. The operator tags must then agree: each arm compares a->type with b->type. Comparing the node pointers in the arms is wrong, because a == b has already been handled and so every pair of distinct nodes would be rejected. There is no default arm, so a tag outside the listed operators skips the tag check. Finally exp1 and exp2 are compared with eqLamExp. */ +bool eqLamPrimApp(struct LamPrimApp * a, struct LamPrimApp * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + switch (a->type) { + case LAMPRIMOP_TYPE_ADD: + if (a->type != b->type) return false; + break; + case LAMPRIMOP_TYPE_SUB: + if (a->type != b->type) return false; + break; + case LAMPRIMOP_TYPE_MUL: + if (a->type != b->type) return false; + break; + case LAMPRIMOP_TYPE_DIV: + if (a->type != b->type) return false; + break; + case LAMPRIMOP_TYPE_MOD: + if (a->type != b->type) return false; + break; + case LAMPRIMOP_TYPE_POW: + if (a->type != b->type) return false; + break; + case LAMPRIMOP_TYPE_EQ: + if (a->type != b->type) return false; + break; + case LAMPRIMOP_TYPE_NE: + if (a->type != b->type) return false; + break; + case LAMPRIMOP_TYPE_GT: + if (a->type != b->type) return false; + break; + case LAMPRIMOP_TYPE_LT: + if (a->type != b->type) return false; + break; + case LAMPRIMOP_TYPE_GE: + if (a->type != b->type) return false; + break; + case LAMPRIMOP_TYPE_LE: + if (a->type != b->type) return false; + break; + case LAMPRIMOP_TYPE_VEC: + if (a->type != b->type) return false; + break; + case LAMPRIMOP_TYPE_XOR: + if (a->type != b->type) return false; + break; + } + if (!eqLamExp(a->exp1, b->exp1)) return false; + if (!eqLamExp(a->exp2, b->exp2)) return false; + return true; +} +/* eqLamUnaryApp: deep equality for unary primitive applications. After the same-pointer and NULL checks, each arm compares the operator tags a->type and b->type (not the node pointers, which are already known to differ here); there is no default arm, so an unlisted tag skips the tag check. The operand exp is then compared with eqLamExp. */ +bool eqLamUnaryApp(struct LamUnaryApp * a, struct LamUnaryApp * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + switch (a->type) { + case LAMUNARYOP_TYPE_NEG: + if (a->type != b->type) return false; + break; + case LAMUNARYOP_TYPE_NOT: + if (a->type != b->type) return false; + break; + case LAMUNARYOP_TYPE_PRINT: + if (a->type != b->type) return false; + break; + } + if (!eqLamExp(a->exp, b->exp)) return false; + return true; +} +/* eqLamSequence: deep equality for LamSequence lists. After the same-pointer and NULL checks, exp is compared with eqLamExp and the tail recurses on next, so lists of different length are unequal once one side reaches NULL first. */ +bool eqLamSequence(struct LamSequence * a, struct LamSequence * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqLamExp(a->exp, b->exp)) return false; + if (!eqLamSequence(a->next, b->next)) return false; + return true; +} +/* eqLamList: deep equality for LamList lists, element by element through eqLamExp with recursion on next. */ +bool
eqLamList(struct LamList * a, struct LamList * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqLamExp(a->exp, b->exp)) return false; + if (!eqLamList(a->next, b->next)) return false; + return true; +} +/* eqLamApply: deep equality for LamApply nodes. After the same-pointer and NULL checks, function is compared with eqLamExp, nargs directly, and args with eqLamList. */ +bool eqLamApply(struct LamApply * a, struct LamApply * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqLamExp(a->function, b->function)) return false; + if (a->nargs != b->nargs) return false; + if (!eqLamList(a->args, b->args)) return false; + return true; +} +/* eqLamConstant: deep equality for LamConstant nodes. After the same-pointer and NULL checks, name and tag are both compared directly with !=. */ +bool eqLamConstant(struct LamConstant * a, struct LamConstant * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->name != b->name) return false; + if (a->tag != b->tag) return false; + return true; +} +/* eqLamConstruct: deep equality for LamConstruct nodes. After the same-pointer and NULL checks, name and tag are compared directly and args with eqLamList. */ +bool eqLamConstruct(struct LamConstruct * a, struct LamConstruct * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->name != b->name) return false; + if (a->tag != b->tag) return false; + if (!eqLamList(a->args, b->args)) return false; + return true; +} +/* eqLamDeconstruct: deep equality for LamDeconstruct nodes. After the same-pointer and NULL checks, name and vec are compared directly and exp with eqLamExp. */ +bool eqLamDeconstruct(struct LamDeconstruct * a, struct LamDeconstruct * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->name != b->name) return false; + if (a->vec != b->vec) return false; + if (!eqLamExp(a->exp, b->exp)) return false; + return true; +} +/* eqLamMakeVec: deep equality for LamMakeVec nodes. After the same-pointer and NULL checks, nargs is compared directly and args with eqLamList. */ +bool eqLamMakeVec(struct LamMakeVec * a, struct LamMakeVec * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->nargs != b->nargs) return false; + if (!eqLamList(a->args, b->args)) return false; + return true; +} +/* eqLamIff: deep equality for LamIff nodes. After the same-pointer and NULL checks, condition, consequent and alternative are each compared with eqLamExp. */ +bool eqLamIff(struct LamIff * a, struct LamIff * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqLamExp(a->condition, b->condition)) return false; + if (!eqLamExp(a->consequent, b->consequent)) return false; + if (!eqLamExp(a->alternative, b->alternative)) return false; + return true; +} +/* eqLamCond: deep equality for LamCond nodes. After the same-pointer and NULL checks, value is compared with eqLamExp and cases with eqLamCondCases. */ +bool eqLamCond(struct LamCond * a, struct LamCond *
b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqLamExp(a->value, b->value)) return false; + if (!eqLamCondCases(a->cases, b->cases)) return false; + return true; +} +/* eqLamIntCondCases: deep equality for LamIntCondCases lists. After the same-pointer and NULL checks, constant is compared directly with !=, body with eqLamExp, and the tail recurses on next. NOTE(review): the print routine hands constant to printBigInt; if that field is held by pointer, != tests identity rather than numeric value. Confirm against the LamIntCondCases declaration. */ +bool eqLamIntCondCases(struct LamIntCondCases * a, struct LamIntCondCases * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->constant != b->constant) return false; + if (!eqLamExp(a->body, b->body)) return false; + if (!eqLamIntCondCases(a->next, b->next)) return false; + return true; +} +/* eqLamCharCondCases: deep equality for LamCharCondCases lists. After the same-pointer and NULL checks, constant is compared directly with !=, body with eqLamExp, and the tail recurses on next. */ +bool eqLamCharCondCases(struct LamCharCondCases * a, struct LamCharCondCases * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->constant != b->constant) return false; + if (!eqLamExp(a->body, b->body)) return false; + if (!eqLamCharCondCases(a->next, b->next)) return false; + return true; +} +/* eqLamMatch: deep equality for LamMatch nodes. After the same-pointer and NULL checks, index is compared with eqLamExp and cases with eqLamMatchList. */ +bool eqLamMatch(struct LamMatch * a, struct LamMatch * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqLamExp(a->index, b->index)) return false; + if (!eqLamMatchList(a->cases, b->cases)) return false; + return true; +} +/* eqLamMatchList: deep equality for LamMatchList lists. After the same-pointer and NULL checks, matches is compared with eqLamIntList, body with eqLamExp, and the tail recurses on next. */ +bool eqLamMatchList(struct LamMatchList * a, struct LamMatchList * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqLamIntList(a->matches, b->matches)) return false; + if (!eqLamExp(a->body, b->body)) return false; + if (!eqLamMatchList(a->next, b->next)) return false; + return true; +} +/* eqLamIntList: deep equality for LamIntList lists. After the same-pointer and NULL checks, item is compared directly and the tail recurses on next. */ +bool eqLamIntList(struct LamIntList * a, struct LamIntList * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->item != b->item) return false; + if (!eqLamIntList(a->next, b->next)) return false; + return true; +} +/* eqLamLet: deep equality for LamLet nodes. After the same-pointer and NULL checks, var is compared directly with !=, and value and body with eqLamExp. */ +bool eqLamLet(struct LamLet * a, struct LamLet * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->var != b->var) return false; + if (!eqLamExp(a->value, b->value)) return false; + if (!eqLamExp(a->body, b->body)) return false; + return true; +} +/* eqLamLetRec: deep equality for LamLetRec nodes: nbindings directly, bindings through eqLamLetRecBindings, body through eqLamExp. */ +bool
eqLamLetRec(struct LamLetRec * a, struct LamLetRec * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->nbindings != b->nbindings) return false; + if (!eqLamLetRecBindings(a->bindings, b->bindings)) return false; + if (!eqLamExp(a->body, b->body)) return false; + return true; +} +/* eqLamLetRecBindings: deep equality for LamLetRecBindings lists. After the same-pointer and NULL checks, var is compared directly with !=, val with eqLamExp, and the tail recurses on next. */ +bool eqLamLetRecBindings(struct LamLetRecBindings * a, struct LamLetRecBindings * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->var != b->var) return false; + if (!eqLamExp(a->val, b->val)) return false; + if (!eqLamLetRecBindings(a->next, b->next)) return false; + return true; +} +/* eqLamContext: equality for LamContext chains. After the same-pointer and NULL checks, frame is compared directly with != and parent recurses through eqLamContext. NOTE(review): the print routine passes frame to printHashTable, so this looks like a pointer comparison: two contexts would be equal only when they share the same frame object, not when their frames hold equal entries. Confirm that identity is intended. */ +bool eqLamContext(struct LamContext * a, struct LamContext * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->frame != b->frame) return false; + if (!eqLamContext(a->parent, b->parent)) return false; + return true; +} +/* eqLamAnd: deep equality for LamAnd nodes. After the same-pointer and NULL checks, left and right are each compared with eqLamExp. */ +bool eqLamAnd(struct LamAnd * a, struct LamAnd * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqLamExp(a->left, b->left)) return false; + if (!eqLamExp(a->right, b->right)) return false; + return true; +} +/* eqLamOr: deep equality for LamOr nodes. After the same-pointer and NULL checks, left and right are each compared with eqLamExp. */ +bool eqLamOr(struct LamOr * a, struct LamOr * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqLamExp(a->left, b->left)) return false; + if (!eqLamExp(a->right, b->right)) return false; + return true; +} +/* eqLamAmb: deep equality for LamAmb nodes. After the same-pointer and NULL checks, left and right are each compared with eqLamExp. */ +bool eqLamAmb(struct LamAmb * a, struct LamAmb * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqLamExp(a->left, b->left)) return false; + if (!eqLamExp(a->right, b->right)) return false; + return true; +} +/* eqLamTypeDefs: deep equality for LamTypeDefs nodes. After the same-pointer and NULL checks, typeDefs is compared with eqLamTypeDefList and body with eqLamExp. */ +bool eqLamTypeDefs(struct LamTypeDefs * a, struct LamTypeDefs * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqLamTypeDefList(a->typeDefs, b->typeDefs)) return false; + if (!eqLamExp(a->body, b->body)) return false; + return true; +} +/* eqLamTypeDefList: deep equality for LamTypeDefList lists: typeDef through eqLamTypeDef, then recursion on next. */ +bool eqLamTypeDefList(struct LamTypeDefList * a, struct LamTypeDefList * b) { + if (a == b)
return true; + if (a == NULL || b == NULL) return false; + if (!eqLamTypeDef(a->typeDef, b->typeDef)) return false; + if (!eqLamTypeDefList(a->next, b->next)) return false; + return true; +} +/* eqLamTypeDef: deep equality for LamTypeDef nodes. After the same-pointer and NULL checks, type is compared with eqLamType and constructors with eqLamTypeConstructorList. */ +bool eqLamTypeDef(struct LamTypeDef * a, struct LamTypeDef * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqLamType(a->type, b->type)) return false; + if (!eqLamTypeConstructorList(a->constructors, b->constructors)) return false; + return true; +} +/* eqLamTypeConstructorList: deep equality for LamTypeConstructorList lists. After the same-pointer and NULL checks, constructor is compared with eqLamTypeConstructor and the tail recurses on next. */ +bool eqLamTypeConstructorList(struct LamTypeConstructorList * a, struct LamTypeConstructorList * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqLamTypeConstructor(a->constructor, b->constructor)) return false; + if (!eqLamTypeConstructorList(a->next, b->next)) return false; + return true; +} +/* eqLamType: deep equality for LamType nodes. After the same-pointer and NULL checks, name is compared directly with != and args with eqLamTypeArgs. */ +bool eqLamType(struct LamType * a, struct LamType * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->name != b->name) return false; + if (!eqLamTypeArgs(a->args, b->args)) return false; + return true; +} +/* eqLamTypeArgs: deep equality for LamTypeArgs lists. After the same-pointer and NULL checks, name is compared directly with != and the tail recurses on next. */ +bool eqLamTypeArgs(struct LamTypeArgs * a, struct LamTypeArgs * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->name != b->name) return false; + if (!eqLamTypeArgs(a->next, b->next)) return false; + return true; +} +/* eqLamTypeConstructor: deep equality for LamTypeConstructor nodes. After the same-pointer and NULL checks, name is compared directly with !=, type with eqLamType and args with eqLamTypeConstructorArgs. */ +bool eqLamTypeConstructor(struct LamTypeConstructor * a, struct LamTypeConstructor * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->name != b->name) return false; + if (!eqLamType(a->type, b->type)) return false; + if (!eqLamTypeConstructorArgs(a->args, b->args)) return false; + return true; +} +/* eqLamTypeConstructorArgs: deep equality for LamTypeConstructorArgs lists. After the same-pointer and NULL checks, arg is compared with eqLamTypeConstructorType and the tail recurses on next. */ +bool eqLamTypeConstructorArgs(struct LamTypeConstructorArgs * a, struct LamTypeConstructorArgs * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqLamTypeConstructorType(a->arg, b->arg)) return false; + if (!eqLamTypeConstructorArgs(a->next, b->next)) return false; + return true; +} +/* eqLamTypeFunction: deep equality for LamTypeFunction nodes: name directly with !=, args through eqLamTypeConstructorArgs. */ +bool
eqLamTypeFunction(struct LamTypeFunction * a, struct LamTypeFunction * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->name != b->name) return false; + if (!eqLamTypeConstructorArgs(a->args, b->args)) return false; + return true; +} +/* eqLamTypeConstructorInfo: deep equality for LamTypeConstructorInfo nodes. After the same-pointer and NULL checks, type is compared with eqLamTypeConstructor, and vec, arity, size and index are compared directly. */ +bool eqLamTypeConstructorInfo(struct LamTypeConstructorInfo * a, struct LamTypeConstructorInfo * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqLamTypeConstructor(a->type, b->type)) return false; + if (a->vec != b->vec) return false; + if (a->arity != b->arity) return false; + if (a->size != b->size) return false; + if (a->index != b->index) return false; + return true; +} +/* eqLamExp: deep equality for LamExp nodes. After the same-pointer, NULL and type-tag checks it dispatches on the tag: var, stdint, biginteger, character, back, error and cond_default are compared directly with !=, callcc recurses into eqLamExp, and every other member is passed to the eq function for its node type. An unknown tag goes to cant_happen. NOTE(review): biginteger is printed with printBigInt; if it is held by pointer, the != here is identity rather than numeric equality. Confirm against the LamExp declaration. */ +bool eqLamExp(struct LamExp * a, struct LamExp * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->type != b->type) return false; + switch(a->type) { + case LAMEXP_TYPE_LAM: + if (!eqLamLam(a->val.lam, b->val.lam)) return false; + break; + case LAMEXP_TYPE_VAR: + if (a->val.var != b->val.var) return false; + break; + case LAMEXP_TYPE_STDINT: + if (a->val.stdint != b->val.stdint) return false; + break; + case LAMEXP_TYPE_BIGINTEGER: + if (a->val.biginteger != b->val.biginteger) return false; + break; + case LAMEXP_TYPE_PRIM: + if (!eqLamPrimApp(a->val.prim, b->val.prim)) return false; + break; + case LAMEXP_TYPE_UNARY: + if (!eqLamUnaryApp(a->val.unary, b->val.unary)) return false; + break; + case LAMEXP_TYPE_LIST: + if (!eqLamSequence(a->val.list, b->val.list)) return false; + break; + case LAMEXP_TYPE_MAKEVEC: + if (!eqLamMakeVec(a->val.makeVec, b->val.makeVec)) return false; + break; + case LAMEXP_TYPE_CONSTRUCT: + if (!eqLamConstruct(a->val.construct, b->val.construct)) return false; + break; + case LAMEXP_TYPE_DECONSTRUCT: + if (!eqLamDeconstruct(a->val.deconstruct, b->val.deconstruct)) return false; + break; + case LAMEXP_TYPE_CONSTANT: + if (!eqLamConstant(a->val.constant, b->val.constant)) return false; + break; + case LAMEXP_TYPE_APPLY: + if
(!eqLamApply(a->val.apply, b->val.apply)) return false; + break; + case LAMEXP_TYPE_IFF: + if (!eqLamIff(a->val.iff, b->val.iff)) return false; + break; + case LAMEXP_TYPE_CALLCC: + if (!eqLamExp(a->val.callcc, b->val.callcc)) return false; + break; + case LAMEXP_TYPE_LETREC: + if (!eqLamLetRec(a->val.letrec, b->val.letrec)) return false; + break; + case LAMEXP_TYPE_TYPEDEFS: + if (!eqLamTypeDefs(a->val.typedefs, b->val.typedefs)) return false; + break; + case LAMEXP_TYPE_LET: + if (!eqLamLet(a->val.let, b->val.let)) return false; + break; + case LAMEXP_TYPE_MATCH: + if (!eqLamMatch(a->val.match, b->val.match)) return false; + break; + case LAMEXP_TYPE_COND: + if (!eqLamCond(a->val.cond, b->val.cond)) return false; + break; + case LAMEXP_TYPE_AND: + if (!eqLamAnd(a->val.and, b->val.and)) return false; + break; + case LAMEXP_TYPE_OR: + if (!eqLamOr(a->val.or, b->val.or)) return false; + break; + case LAMEXP_TYPE_AMB: + if (!eqLamAmb(a->val.amb, b->val.amb)) return false; + break; + case LAMEXP_TYPE_CHARACTER: + if (a->val.character != b->val.character) return false; + break; + case LAMEXP_TYPE_BACK: + if (a->val.back != b->val.back) return false; + break; + case LAMEXP_TYPE_ERROR: + if (a->val.error != b->val.error) return false; + break; + case LAMEXP_TYPE_COND_DEFAULT: + if (a->val.cond_default != b->val.cond_default) return false; + break; + default: + cant_happen("unrecognised type %d in eqLamExp", a->type); + } + return true; +} + +bool eqLamCondCases(struct LamCondCases * a, struct LamCondCases * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->type != b->type) return false; + switch(a->type) { + case LAMCONDCASES_TYPE_INTEGERS: + if (!eqLamIntCondCases(a->val.integers, b->val.integers)) return false; + break; + case LAMCONDCASES_TYPE_CHARACTERS: + if (!eqLamCharCondCases(a->val.characters, b->val.characters)) return false; + break; + default: + cant_happen("unrecognised type %d in eqLamCondCases", a->type); + } + return 
true; +} + +bool eqLamTypeConstructorType(struct LamTypeConstructorType * a, struct LamTypeConstructorType * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->type != b->type) return false; + switch(a->type) { + case LAMTYPECONSTRUCTORTYPE_TYPE_INTEGER: + if (a->val.integer != b->val.integer) return false; + break; + case LAMTYPECONSTRUCTORTYPE_TYPE_CHARACTER: + if (a->val.character != b->val.character) return false; + break; + case LAMTYPECONSTRUCTORTYPE_TYPE_VAR: + if (a->val.var != b->val.var) return false; + break; + case LAMTYPECONSTRUCTORTYPE_TYPE_FUNCTION: + if (!eqLamTypeFunction(a->val.function, b->val.function)) return false; + break; + default: + cant_happen("unrecognised type %d in eqLamTypeConstructorType", a->type); + } + return true; +} + diff --git a/src/debug_lambda.h b/src/lambda_debug.h similarity index 53% rename from src/debug_lambda.h rename to src/lambda_debug.h index 1eff672..052815f 100644 --- a/src/debug_lambda.h +++ b/src/lambda_debug.h @@ -1,5 +1,5 @@ -#ifndef cekf_debug_lambda_h -#define cekf_debug_lambda_h +#ifndef cekf_lambda_debug_h +#define cekf_lambda_debug_h /* * CEKF - VM supporting amb * Copyright (C) 2022-2023 Bill Hails @@ -64,4 +64,43 @@ void printLamExp(struct LamExp * x, int depth); void printLamCondCases(struct LamCondCases * x, int depth); void printLamTypeConstructorType(struct LamTypeConstructorType * x, int depth); +bool eqLamLam(struct LamLam * a, struct LamLam * b); +bool eqLamVarList(struct LamVarList * a, struct LamVarList * b); +bool eqLamPrimApp(struct LamPrimApp * a, struct LamPrimApp * b); +bool eqLamUnaryApp(struct LamUnaryApp * a, struct LamUnaryApp * b); +bool eqLamSequence(struct LamSequence * a, struct LamSequence * b); +bool eqLamList(struct LamList * a, struct LamList * b); +bool eqLamApply(struct LamApply * a, struct LamApply * b); +bool eqLamConstant(struct LamConstant * a, struct LamConstant * b); +bool eqLamConstruct(struct LamConstruct * a, struct LamConstruct * 
b); +bool eqLamDeconstruct(struct LamDeconstruct * a, struct LamDeconstruct * b); +bool eqLamMakeVec(struct LamMakeVec * a, struct LamMakeVec * b); +bool eqLamIff(struct LamIff * a, struct LamIff * b); +bool eqLamCond(struct LamCond * a, struct LamCond * b); +bool eqLamIntCondCases(struct LamIntCondCases * a, struct LamIntCondCases * b); +bool eqLamCharCondCases(struct LamCharCondCases * a, struct LamCharCondCases * b); +bool eqLamMatch(struct LamMatch * a, struct LamMatch * b); +bool eqLamMatchList(struct LamMatchList * a, struct LamMatchList * b); +bool eqLamIntList(struct LamIntList * a, struct LamIntList * b); +bool eqLamLet(struct LamLet * a, struct LamLet * b); +bool eqLamLetRec(struct LamLetRec * a, struct LamLetRec * b); +bool eqLamLetRecBindings(struct LamLetRecBindings * a, struct LamLetRecBindings * b); +bool eqLamContext(struct LamContext * a, struct LamContext * b); +bool eqLamAnd(struct LamAnd * a, struct LamAnd * b); +bool eqLamOr(struct LamOr * a, struct LamOr * b); +bool eqLamAmb(struct LamAmb * a, struct LamAmb * b); +bool eqLamTypeDefs(struct LamTypeDefs * a, struct LamTypeDefs * b); +bool eqLamTypeDefList(struct LamTypeDefList * a, struct LamTypeDefList * b); +bool eqLamTypeDef(struct LamTypeDef * a, struct LamTypeDef * b); +bool eqLamTypeConstructorList(struct LamTypeConstructorList * a, struct LamTypeConstructorList * b); +bool eqLamType(struct LamType * a, struct LamType * b); +bool eqLamTypeArgs(struct LamTypeArgs * a, struct LamTypeArgs * b); +bool eqLamTypeConstructor(struct LamTypeConstructor * a, struct LamTypeConstructor * b); +bool eqLamTypeConstructorArgs(struct LamTypeConstructorArgs * a, struct LamTypeConstructorArgs * b); +bool eqLamTypeFunction(struct LamTypeFunction * a, struct LamTypeFunction * b); +bool eqLamTypeConstructorInfo(struct LamTypeConstructorInfo * a, struct LamTypeConstructorInfo * b); +bool eqLamExp(struct LamExp * a, struct LamExp * b); +bool eqLamCondCases(struct LamCondCases * a, struct LamCondCases * b); +bool 
eqLamTypeConstructorType(struct LamTypeConstructorType * a, struct LamTypeConstructorType * b); + #endif diff --git a/src/lambda_helper.h b/src/lambda_helper.h index 508abf4..13fd975 100644 --- a/src/lambda_helper.h +++ b/src/lambda_helper.h @@ -19,7 +19,7 @@ */ #include "lambda.h" -#include "debug_lambda.h" +#include "lambda_debug.h" #include "hash.h" #include "memory.h" diff --git a/src/main.c b/src/main.c index 2e75da4..b339afc 100644 --- a/src/main.c +++ b/src/main.c @@ -24,7 +24,6 @@ #include "common.h" #include "ast.h" #include "debug_ast.h" -#include "debug_tin.h" #include "debug_lambda.h" #include "lambda_conversion.h" #include "module.h" @@ -36,13 +35,12 @@ #include "debug.h" #include "bytecode.h" #include "desugaring.h" -#include "algorithm_W.h" -#include "tin.h" -#include "tin_helper.h" #include "hash.h" #include "lambda_pp.h" #include "anf.h" #include "bigint.h" +#include "tc_analyze.h" +#include "debug_tc.h" #ifdef DEBUG_RUN_TESTS #if DEBUG_RUN_TESTS == 1 @@ -75,29 +73,6 @@ int main(int argc, char *argv[]) { testTin(); } -#elif DEBUG_RUN_TESTS == 4 // testing algorithm W - -extern AstNest *result; - -int main(int argc, char *argv[]) { - initProtection(); - disableGC(); - if (argc < 2) { - eprintf("need filename\n"); - exit(1); - } - AstNest *result = pm_parseFile(argv[1]); - PROTECT(result); - enableGC(); - // quietPrintHashTable = true; - WResult *wr = WTop(result); - showTinMonoType(wr->monoType); - printf("\n"); - if (hadErrors()) { - printf("(errors detected)\n"); - } -} - #else // testing lambda conversion extern AstNest *result; @@ -176,19 +151,21 @@ int main(int argc, char *argv[]) { PROTECT(mod); pmParseModule(mod); enableGC(); - // printAstNest(mod->nest, 0); - /* WResult *wr = */ (void) WTop(mod->nest); - validateLastAlloc(); - if (hadErrors()) { - printf("(errors detected)\n"); - exit(1); - } LamExp *exp = lamConvertNest(mod->nest, NULL); int save = PROTECT(exp); #ifdef DEBUG_LAMBDA_CONVERT ppLamExp(exp); eprintf("\n"); #endif + TcEnv *env = 
tc_init(); + PROTECT(env); + TcType *res = tc_analyze(exp, env); + if (hadErrors()) { + return 1; + } + PROTECT(res); + printTcType(res, 0); + eprintf("\n"); Exp *anfExp = anfNormalize(exp); PROTECT(anfExp); disableGC(); diff --git a/src/memory.c b/src/memory.c index 201263e..67aa00f 100644 --- a/src/memory.c +++ b/src/memory.c @@ -130,22 +130,20 @@ const char *typeName(ObjType type, void *p) { return "valuelist"; case OBJTYPE_HASHTABLE: return "hashtable"; - case OBJTYPE_WRESULT: - return "wresult"; case OBJTYPE_PROTECTION: return "protection"; case OBJTYPE_BIGINT: return "bigint"; case OBJTYPE_PMMODULE: return "pmmodule"; - TIN_OBJTYPE_CASES() - return typenameTinObj(type); AST_OBJTYPE_CASES() return typenameAstObj(type); LAMBDA_OBJTYPE_CASES() return typenameLambdaObj(type); TPMC_OBJTYPE_CASES() return typenameTpmcObj(type); + TC_OBJTYPE_CASES() + return typenameTcObj(type); default: cant_happen("unrecognised ObjType %d in typeName at %p", type, p); } @@ -356,18 +354,12 @@ void markObj(Header *h, int i) { case OBJTYPE_HASHSYMBOL: markHashSymbolObj(h); break; - case OBJTYPE_WRESULT: - markWResultObj(h); - break; case OBJTYPE_PROTECTION: markProtectionObj(h); break; case OBJTYPE_PMMODULE: markPmModule(h); break; - TIN_OBJTYPE_CASES() - markTinObj(h); - break; AST_OBJTYPE_CASES() markAstObj(h); break; @@ -377,6 +369,9 @@ void markObj(Header *h, int i) { TPMC_OBJTYPE_CASES() markTpmcObj(h); break; + TC_OBJTYPE_CASES() + markTcObj(h); + break; case OBJTYPE_BIGINT: markBigInt((BigInt *)h); break; @@ -438,18 +433,12 @@ void freeObj(Header *h) { case OBJTYPE_HASHSYMBOL: freeHashSymbolObj(h); break; - case OBJTYPE_WRESULT: - freeWResultObj(h); - break; case OBJTYPE_PROTECTION: freeProtectionObj(h); break; case OBJTYPE_PMMODULE: freePmModule(h); break; - TIN_OBJTYPE_CASES() - freeTinObj(h); - break; AST_OBJTYPE_CASES() freeAstObj(h); break; @@ -459,6 +448,9 @@ void freeObj(Header *h) { TPMC_OBJTYPE_CASES() freeTpmcObj(h); break; + TC_OBJTYPE_CASES() + freeTcObj(h); + 
break; default: cant_happen("unrecognised ObjType %d in freeObj at %p", h->type, (void *)h); } diff --git a/src/memory.h b/src/memory.h index f8e8a43..d5c666d 100644 --- a/src/memory.h +++ b/src/memory.h @@ -24,9 +24,9 @@ struct Header; #include "ast_objtypes.h" -#include "tin_objtypes.h" #include "lambda_objtypes.h" #include "tpmc_objtypes.h" +#include "tc_objtypes.h" typedef enum { // exp types @@ -74,15 +74,14 @@ typedef enum { OBJTYPE_HASHTABLE, OBJTYPE_HASHSYMBOL, - OBJTYPE_WRESULT, OBJTYPE_PROTECTION, OBJTYPE_BIGINT, OBJTYPE_PMMODULE, AST_OBJTYPES(), - TIN_OBJTYPES(), LAMBDA_OBJTYPES(), TPMC_OBJTYPES(), + TC_OBJTYPES(), } ObjType; typedef struct Header { @@ -103,13 +102,11 @@ void markObj(Header *h, int i); void markExpObj(Header *x); void markCekfObj(Header *x); void markHashTableObj(Header *x); -void markWResultObj(Header *x); void freeObj(Header *h); void freeExpObj(Header *x); void freeCekfObj(Header *x); void freeHashTableObj(Header *x); -void freeWResultObj(Header *x); bool enableGC(void); bool disableGC(void); diff --git a/src/tc.c b/src/tc.c new file mode 100644 index 0000000..711b1c6 --- /dev/null +++ b/src/tc.c @@ -0,0 +1,301 @@ +/* + * CEKF - VM supporting amb + * Copyright (C) 2022-2023 Bill Hails + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . 
+ * + * Structures to support type inference + * + * generated from src/tc.yaml by makeAST.py + */ + +#include "tc.h" +#include +#include +#include "common.h" +#ifdef DEBUG_ALLOC +#include "debugging_on.h" +#else +#include "debugging_off.h" +#endif + +struct TcEnv * newTcEnv(HashTable * table, struct TcEnv * next) { + struct TcEnv * x = NEW(TcEnv, OBJTYPE_TCENV); + DEBUG("new TcEnv %pn", x); + x->table = table; + x->next = next; + return x; +} + +struct TcNg * newTcNg(HashTable * table, struct TcNg * next) { + struct TcNg * x = NEW(TcNg, OBJTYPE_TCNG); + DEBUG("new TcNg %pn", x); + x->table = table; + x->next = next; + return x; +} + +struct TcFunction * newTcFunction(struct TcType * arg, struct TcType * result) { + struct TcFunction * x = NEW(TcFunction, OBJTYPE_TCFUNCTION); + DEBUG("new TcFunction %pn", x); + x->arg = arg; + x->result = result; + return x; +} + +struct TcPair * newTcPair(struct TcType * first, struct TcType * second) { + struct TcPair * x = NEW(TcPair, OBJTYPE_TCPAIR); + DEBUG("new TcPair %pn", x); + x->first = first; + x->second = second; + return x; +} + +struct TcTypeDef * newTcTypeDef(HashSymbol * name, struct TcTypeDefArgs * args) { + struct TcTypeDef * x = NEW(TcTypeDef, OBJTYPE_TCTYPEDEF); + DEBUG("new TcTypeDef %pn", x); + x->name = name; + x->args = args; + return x; +} + +struct TcTypeDefArgs * newTcTypeDefArgs(struct TcType * type, struct TcTypeDefArgs * next) { + struct TcTypeDefArgs * x = NEW(TcTypeDefArgs, OBJTYPE_TCTYPEDEFARGS); + DEBUG("new TcTypeDefArgs %pn", x); + x->type = type; + x->next = next; + return x; +} + +struct TcVar * newTcVar(HashSymbol * name) { + struct TcVar * x = NEW(TcVar, OBJTYPE_TCVAR); + DEBUG("new TcVar %pn", x); + x->name = name; + x->instance = NULL; + return x; +} + +struct TcType * newTcType(enum TcTypeType type, union TcTypeVal val) { + struct TcType * x = NEW(TcType, OBJTYPE_TCTYPE); + DEBUG("new TcType %pn", x); + x->type = type; + x->val = val; + return x; +} + + + 
+/************************************/ + +void markTcEnv(struct TcEnv * x) { + if (x == NULL) return; + if (MARKED(x)) return; + MARK(x); + markHashTable(x->table); + markTcEnv(x->next); +} + +void markTcNg(struct TcNg * x) { + if (x == NULL) return; + if (MARKED(x)) return; + MARK(x); + markHashTable(x->table); + markTcNg(x->next); +} + +void markTcFunction(struct TcFunction * x) { + if (x == NULL) return; + if (MARKED(x)) return; + MARK(x); + markTcType(x->arg); + markTcType(x->result); +} + +void markTcPair(struct TcPair * x) { + if (x == NULL) return; + if (MARKED(x)) return; + MARK(x); + markTcType(x->first); + markTcType(x->second); +} + +void markTcTypeDef(struct TcTypeDef * x) { + if (x == NULL) return; + if (MARKED(x)) return; + MARK(x); + markTcTypeDefArgs(x->args); +} + +void markTcTypeDefArgs(struct TcTypeDefArgs * x) { + if (x == NULL) return; + if (MARKED(x)) return; + MARK(x); + markTcType(x->type); + markTcTypeDefArgs(x->next); +} + +void markTcVar(struct TcVar * x) { + if (x == NULL) return; + if (MARKED(x)) return; + MARK(x); + markTcType(x->instance); +} + +void markTcType(struct TcType * x) { + if (x == NULL) return; + if (MARKED(x)) return; + MARK(x); + switch(x->type) { + case TCTYPE_TYPE_FUNCTION: + markTcFunction(x->val.function); + break; + case TCTYPE_TYPE_PAIR: + markTcPair(x->val.pair); + break; + case TCTYPE_TYPE_VAR: + markTcVar(x->val.var); + break; + case TCTYPE_TYPE_INTEGER: + break; + case TCTYPE_TYPE_CHARACTER: + break; + case TCTYPE_TYPE_TYPEDEF: + markTcTypeDef(x->val.typeDef); + break; + default: + cant_happen("unrecognised type %d in markTcType", x->type); + } +} + + +void markTcObj(struct Header *h) { + switch(h->type) { + case OBJTYPE_TCENV: + markTcEnv((TcEnv *)h); + break; + case OBJTYPE_TCNG: + markTcNg((TcNg *)h); + break; + case OBJTYPE_TCFUNCTION: + markTcFunction((TcFunction *)h); + break; + case OBJTYPE_TCPAIR: + markTcPair((TcPair *)h); + break; + case OBJTYPE_TCTYPEDEF: + markTcTypeDef((TcTypeDef *)h); + break; + 
case OBJTYPE_TCTYPEDEFARGS: + markTcTypeDefArgs((TcTypeDefArgs *)h); + break; + case OBJTYPE_TCVAR: + markTcVar((TcVar *)h); + break; + case OBJTYPE_TCTYPE: + markTcType((TcType *)h); + break; + default: + cant_happen("unrecognised type %d in markTcObj\n", h->type); + } +} + +/************************************/ + +void freeTcEnv(struct TcEnv * x) { + FREE(x, TcEnv); +} + +void freeTcNg(struct TcNg * x) { + FREE(x, TcNg); +} + +void freeTcFunction(struct TcFunction * x) { + FREE(x, TcFunction); +} + +void freeTcPair(struct TcPair * x) { + FREE(x, TcPair); +} + +void freeTcTypeDef(struct TcTypeDef * x) { + FREE(x, TcTypeDef); +} + +void freeTcTypeDefArgs(struct TcTypeDefArgs * x) { + FREE(x, TcTypeDefArgs); +} + +void freeTcVar(struct TcVar * x) { + FREE(x, TcVar); +} + +void freeTcType(struct TcType * x) { + FREE(x, TcType); +} + + +void freeTcObj(struct Header *h) { + switch(h->type) { + case OBJTYPE_TCENV: + freeTcEnv((TcEnv *)h); + break; + case OBJTYPE_TCNG: + freeTcNg((TcNg *)h); + break; + case OBJTYPE_TCFUNCTION: + freeTcFunction((TcFunction *)h); + break; + case OBJTYPE_TCPAIR: + freeTcPair((TcPair *)h); + break; + case OBJTYPE_TCTYPEDEF: + freeTcTypeDef((TcTypeDef *)h); + break; + case OBJTYPE_TCTYPEDEFARGS: + freeTcTypeDefArgs((TcTypeDefArgs *)h); + break; + case OBJTYPE_TCVAR: + freeTcVar((TcVar *)h); + break; + case OBJTYPE_TCTYPE: + freeTcType((TcType *)h); + break; + default: + cant_happen("unrecognised type %d in freeTcObj\n", h->type); + } +} + +char *typenameTcObj(int type) { + switch(type) { + case OBJTYPE_TCENV: + return "TcEnv"; + case OBJTYPE_TCNG: + return "TcNg"; + case OBJTYPE_TCFUNCTION: + return "TcFunction"; + case OBJTYPE_TCPAIR: + return "TcPair"; + case OBJTYPE_TCTYPEDEF: + return "TcTypeDef"; + case OBJTYPE_TCTYPEDEFARGS: + return "TcTypeDefArgs"; + case OBJTYPE_TCVAR: + return "TcVar"; + case OBJTYPE_TCTYPE: + return "TcType"; + default: + cant_happen("unrecognised type %d in typenameTcObj\n", type); + } +} + diff --git a/src/tc.h 
b/src/tc.h new file mode 100644 index 0000000..31d2049 --- /dev/null +++ b/src/tc.h @@ -0,0 +1,137 @@ +#ifndef cekf_tc_h +#define cekf_tc_h +/* + * CEKF - VM supporting amb + * Copyright (C) 2022-2023 Bill Hails + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + * Structures to support type inference + * + * generated from src/tc.yaml by makeAST.py + */ + +#include "hash.h" +#include "memory.h" +#include "common.h" + +typedef enum TcTypeType { + TCTYPE_TYPE_FUNCTION, + TCTYPE_TYPE_PAIR, + TCTYPE_TYPE_VAR, + TCTYPE_TYPE_INTEGER, + TCTYPE_TYPE_CHARACTER, + TCTYPE_TYPE_TYPEDEF, +} TcTypeType; + + + +typedef union TcTypeVal { + struct TcFunction * function; + struct TcPair * pair; + struct TcVar * var; + void * integer; + void * character; + struct TcTypeDef * typeDef; +} TcTypeVal; + + + +typedef struct TcEnv { + Header header; + HashTable * table; + struct TcEnv * next; +} TcEnv; + +typedef struct TcNg { + Header header; + HashTable * table; + struct TcNg * next; +} TcNg; + +typedef struct TcFunction { + Header header; + struct TcType * arg; + struct TcType * result; +} TcFunction; + +typedef struct TcPair { + Header header; + struct TcType * first; + struct TcType * second; +} TcPair; + +typedef struct TcTypeDef { + Header header; + HashSymbol * name; + struct TcTypeDefArgs * args; +} TcTypeDef; + +typedef struct TcTypeDefArgs { + Header header; + struct TcType * type; + struct TcTypeDefArgs 
* next; +} TcTypeDefArgs; + +typedef struct TcVar { + Header header; + HashSymbol * name; + struct TcType * instance; +} TcVar; + +typedef struct TcType { + Header header; + enum TcTypeType type; + union TcTypeVal val; +} TcType; + + + +struct TcEnv * newTcEnv(HashTable * table, struct TcEnv * next); +struct TcNg * newTcNg(HashTable * table, struct TcNg * next); +struct TcFunction * newTcFunction(struct TcType * arg, struct TcType * result); +struct TcPair * newTcPair(struct TcType * first, struct TcType * second); +struct TcTypeDef * newTcTypeDef(HashSymbol * name, struct TcTypeDefArgs * args); +struct TcTypeDefArgs * newTcTypeDefArgs(struct TcType * type, struct TcTypeDefArgs * next); +struct TcVar * newTcVar(HashSymbol * name); +struct TcType * newTcType(enum TcTypeType type, union TcTypeVal val); + +void markTcEnv(struct TcEnv * x); +void markTcNg(struct TcNg * x); +void markTcFunction(struct TcFunction * x); +void markTcPair(struct TcPair * x); +void markTcTypeDef(struct TcTypeDef * x); +void markTcTypeDefArgs(struct TcTypeDefArgs * x); +void markTcVar(struct TcVar * x); +void markTcType(struct TcType * x); + +void freeTcEnv(struct TcEnv * x); +void freeTcNg(struct TcNg * x); +void freeTcFunction(struct TcFunction * x); +void freeTcPair(struct TcPair * x); +void freeTcTypeDef(struct TcTypeDef * x); +void freeTcTypeDefArgs(struct TcTypeDefArgs * x); +void freeTcVar(struct TcVar * x); +void freeTcType(struct TcType * x); + + +#define TCTYPE_VAL_FUNCTION(x) ((union TcTypeVal ){.function = (x)}) +#define TCTYPE_VAL_PAIR(x) ((union TcTypeVal ){.pair = (x)}) +#define TCTYPE_VAL_VAR(x) ((union TcTypeVal ){.var = (x)}) +#define TCTYPE_VAL_INTEGER() ((union TcTypeVal ){.integer = (NULL)}) +#define TCTYPE_VAL_CHARACTER() ((union TcTypeVal ){.character = (NULL)}) +#define TCTYPE_VAL_TYPEDEF(x) ((union TcTypeVal ){.typeDef = (x)}) + + +#endif diff --git a/src/tc.yaml b/src/tc.yaml new file mode 100644 index 0000000..01da121 --- /dev/null +++ b/src/tc.yaml @@ -0,0 +1,81 @@ 
+# +# CEKF - VM supporting amb +# Copyright (C) 2022-2023 Bill Hails +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# + +config: + name: tc + +description: Structures to support type inference + +structs: + TcEnv: + table: HashTable + next: TcEnv + + TcNg: + table: HashTable + next: TcNg + + TcFunction: + arg: TcType + result: TcType + + TcPair: + first: TcType + second: TcType + + TcTypeDef: + name: HashSymbol + args: TcTypeDefArgs + + TcTypeDefArgs: + type: TcType + next: TcTypeDefArgs + + TcVar: + name: HashSymbol + instance: TcType=NULL + +unions: + TcType: + function: TcFunction + pair: TcPair + var: TcVar + integer: void_ptr + character: void_ptr + typeDef: TcTypeDef + +enums: {} + +primitives: + HashSymbol: + cname: "HashSymbol *" + printFn: printAstSymbol + valued: true + HashTable: + cname: "HashTable *" + printFn: printHashTable + markFn: markHashTable + valued: true + int: + cname: int + printf: "%d" + valued: true + void_ptr: + cname: "void *" + printf: "%p" + valued: false diff --git a/src/tc_analyze.c b/src/tc_analyze.c new file mode 100644 index 0000000..b573e87 --- /dev/null +++ b/src/tc_analyze.c @@ -0,0 +1,1295 @@ +/* + * CEKF - VM supporting amb + * Copyright (C) 2022-2023 Bill Hails + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of 
the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include + +#include "tc_analyze.h" +#include "symbols.h" +#include "symbol.h" +#include "memory.h" +#include "hash.h" +#include "tc_debug.h" +#include "tc_helper.h" + +#ifdef DEBUG_TC +#include "debugging_on.h" +#else +#include "debugging_off.h" +#endif + + +static TcEnv *extendEnv(TcEnv *parent); +static TcNg *extendNg(TcNg *parent); +static void addToEnv(TcEnv *env, HashSymbol *key, TcType *type); +static bool getFromEnv(TcEnv *env, HashSymbol *symbol, TcType **type); +static void addToNg(TcNg *env, HashSymbol *symbol, TcType *type); +static void addFreshVarToEnv(TcEnv *env, HashSymbol *key); +static void addCmpToEnv(TcEnv *env, HashSymbol *key); +static TcType *makeBoolean(void); +static TcType *makeInteger(void); +static TcType *makeCharacter(void); +static TcType *makeFreshVar(void); +static TcType *makeVar(HashSymbol *t); +static TcType *makeFn(TcType *arg, TcType *result); +static void addBoolBinOpToEnv(TcEnv *env, HashSymbol *symbol); +static void addHereToEnv(TcEnv *env); +static void addIfToEnv(TcEnv *env); +static void addIntBinOpToEnv(TcEnv *env, HashSymbol *symbol); +static void addNegToEnv(TcEnv *env); +static void addNotToEnv(TcEnv *env); +static void addThenToEnv(TcEnv *env); +static TcType *analyzeExp(LamExp *exp, TcEnv *env, TcNg *ng); +static TcType *analyzeLam(LamLam *lam, TcEnv *env, TcNg *ng); +static TcType *analyzeVar(HashSymbol *var, TcEnv *env, TcNg *ng); +static TcType *analyzeStdInt(int val, TcEnv *env, TcNg *ng); +static TcType *analyzeBigInteger(); +static TcType *analyzePrim(LamPrimApp *app, TcEnv *env, TcNg 
*ng); +static TcType *analyzeUnary(LamUnaryApp *app, TcEnv *env, TcNg *ng); +static TcType *analyzeSequence(LamSequence *sequence, TcEnv *env, TcNg *ng); +static TcType *analyzeConstruct(LamConstruct *construct, TcEnv *env, TcNg *ng); +static TcType *analyzeDeconstruct(LamDeconstruct *deconstruct, TcEnv *env, TcNg *ng); +static TcType *analyzeConstant(LamConstant *constant, TcEnv *env, TcNg *ng); +static TcType *analyzeApply(LamApply *apply, TcEnv *env, TcNg *ng); +static TcType *analyzeIff(LamIff *iff, TcEnv *env, TcNg *ng); +static TcType *analyzeCallCC(LamExp *called, TcEnv *env, TcNg *ng); +static TcType *analyzeLetRec(LamLetRec *letRec, TcEnv *env, TcNg *ng); +static TcType *analyzeTypeDefs(LamTypeDefs *typeDefs, TcEnv *env, TcNg *ng); +static TcType *analyzeLet(LamLet *let, TcEnv *env, TcNg *ng); +static TcType *analyzeMatch(LamMatch *match, TcEnv *env, TcNg *ng); +static TcType *analyzeCond(LamCond *cond, TcEnv *env, TcNg *ng); +static TcType *analyzeAnd(LamAnd *and, TcEnv *env, TcNg *ng); +static TcType *analyzeOr(LamOr *or, TcEnv *env, TcNg *ng); +static TcType *analyzeAmb(LamAmb *amb, TcEnv *env, TcNg *ng); +static TcType *analyzeCharacter(char c, TcEnv *env, TcNg *ng); +static TcType *analyzeBack(); +static TcType *analyzeError(); +static bool unify(TcType *a, TcType *b); +static TcType *prune(TcType *t); +static bool occursInType(TcType *a, TcType *b); +static bool occursIn(TcType *a, TcType *b); +static bool sameType(TcType *a, TcType *b); +static TcType *analyzeIntegerExp(LamExp *exp, TcEnv *env, TcNg *ng); +static TcType *analyzeBooleanExp(LamExp *exp, TcEnv *env, TcNg *ng); +static TcType *freshRec(TcType *type, TcNg *ng, HashTable *map); +static TcType *lookup(TcEnv *env, HashSymbol *symbol, TcNg *ng); +static TcType *makeTypeDef(HashSymbol *name, TcTypeDefArgs *args); + +TcEnv *tc_init(void) { + TcEnv *env = extendEnv(NULL); + int save = PROTECT(env); + addBoolBinOpToEnv(env, andSymbol()); + addBoolBinOpToEnv(env, orSymbol()); + 
addBoolBinOpToEnv(env, xorSymbol()); + addCmpToEnv(env, eqSymbol()); + addCmpToEnv(env, geSymbol()); + addCmpToEnv(env, gtSymbol()); + addCmpToEnv(env, leSymbol()); + addCmpToEnv(env, ltSymbol()); + addCmpToEnv(env, neSymbol()); + addFreshVarToEnv(env, backSymbol()); + addFreshVarToEnv(env, errorSymbol()); + addHereToEnv(env); + addIfToEnv(env); + addIntBinOpToEnv(env, addSymbol()); + addIntBinOpToEnv(env, divSymbol()); + addIntBinOpToEnv(env, mulSymbol()); + addIntBinOpToEnv(env, powSymbol()); + addIntBinOpToEnv(env, subSymbol()); + addNegToEnv(env); + addNotToEnv(env); + addThenToEnv(env); + UNPROTECT(save); + return env; +} + +TcType *tc_analyze(LamExp *exp, TcEnv *env) { + TcNg *ng = extendNg(NULL); + return prune(analyzeExp(exp, env, ng)); +} + +static TcType *analyzeExp(LamExp *exp, TcEnv *env, TcNg *ng) { + if (exp == NULL) return NULL; + switch(exp->type) { + case LAMEXP_TYPE_LAM: + return analyzeLam(exp->val.lam, env, ng); + case LAMEXP_TYPE_VAR: + return analyzeVar(exp->val.var, env, ng); + case LAMEXP_TYPE_STDINT: + return analyzeStdInt(exp->val.stdint, env, ng); + case LAMEXP_TYPE_BIGINTEGER: + return analyzeBigInteger(); + case LAMEXP_TYPE_PRIM: + return analyzePrim(exp->val.prim, env, ng); + case LAMEXP_TYPE_UNARY: + return analyzeUnary(exp->val.unary, env, ng); + case LAMEXP_TYPE_LIST: + return analyzeSequence(exp->val.list, env, ng); + case LAMEXP_TYPE_MAKEVEC: + cant_happen("encountered make-vec in analyzeLamExp"); + case LAMEXP_TYPE_CONSTRUCT: + return analyzeConstruct(exp->val.construct, env, ng); + case LAMEXP_TYPE_DECONSTRUCT: + return analyzeDeconstruct(exp->val.deconstruct, env, ng); + case LAMEXP_TYPE_CONSTANT: + return analyzeConstant(exp->val.constant, env, ng); + case LAMEXP_TYPE_APPLY: + return analyzeApply(exp->val.apply, env, ng); + case LAMEXP_TYPE_IFF: + return analyzeIff(exp->val.iff, env, ng); + case LAMEXP_TYPE_CALLCC: + return analyzeCallCC(exp->val.callcc, env, ng); + case LAMEXP_TYPE_LETREC: + return 
analyzeLetRec(exp->val.letrec, env, ng); + case LAMEXP_TYPE_TYPEDEFS: + return analyzeTypeDefs(exp->val.typedefs, env, ng); + case LAMEXP_TYPE_LET: + return analyzeLet(exp->val.let, env, ng); + case LAMEXP_TYPE_MATCH: + return analyzeMatch(exp->val.match, env, ng); + case LAMEXP_TYPE_COND: + return analyzeCond(exp->val.cond, env, ng); + case LAMEXP_TYPE_AND: + return analyzeAnd(exp->val.and, env, ng); + case LAMEXP_TYPE_OR: + return analyzeOr(exp->val.or, env, ng); + case LAMEXP_TYPE_AMB: + return analyzeAmb(exp->val.amb, env, ng); + case LAMEXP_TYPE_CHARACTER: + return analyzeCharacter(exp->val.character, env, ng); + case LAMEXP_TYPE_BACK: + return analyzeBack(); + case LAMEXP_TYPE_ERROR: + return analyzeError(); + case LAMEXP_TYPE_COND_DEFAULT: + cant_happen("encountered cond default in analyzeLamExp"); + default: + cant_happen("unrecognized type %d in analyzeLamExp", exp->type); + } +} + +static TcType *makeFunctionType(LamVarList *args, TcEnv *env, TcType *returnType) { + ENTER(makeFunctionType); + if (args == NULL) return returnType; + TcType *next = makeFunctionType(args->next, env, returnType); + int save = PROTECT(next); + TcType *this = NULL; + if (!getFromEnv(env, args->var, &this)) { + cant_happen("cannot find var in env in makeFunctionType"); + } + TcType *ret = makeFn(this, next); + UNPROTECT(save); + LEAVE(makeFunctionType); + return ret; +} + +static TcType *analyzeLam(LamLam *lam, TcEnv *env, TcNg *ng) { + ENTER(analyzeLam); + env = extendEnv(env); + int save = PROTECT(env); + ng = extendNg(ng); + PROTECT(ng); + for (LamVarList *args = lam->args; args != NULL; args = args->next) { + TcType *fresh = makeFreshVar(); + int save2 = PROTECT(fresh); + addToEnv(env, args->var, fresh); + addToNg(ng, fresh->val.var->name, fresh); + UNPROTECT(save2); + } + TcType *returnType = analyzeExp(lam->exp, env, ng); + PROTECT(returnType); + TcType *functionType = makeFunctionType(lam->args, env, returnType); + UNPROTECT(save); + LEAVE(analyzeLam); + return 
functionType; +} + +static TcType *analyzeVar(HashSymbol *var, TcEnv *env, TcNg *ng) { + ENTER(analyzeVar); + TcType *res = lookup(env, var, ng); + if (res == NULL) { + can_happen("undefined variable %s in analyzeVar", var->name); + } + LEAVE(analyzeVar); + return res; +} + +static TcType *analyzeStdInt(int val __attribute__((unused)), TcEnv *env __attribute__((unused)), TcNg *ng __attribute__((unused))) { + cant_happen("analyzeStdInt not implemented yet"); +} + +static TcType *analyzeBigInteger() { + ENTER(analyzeBigInteger); + TcType *res = makeInteger(); + LEAVE(analyzeBigInteger); + return res; +} + +static TcType *analyzeBinaryArith(LamExp *exp1, LamExp *exp2, TcEnv *env, TcNg *ng) { + ENTER(analyzeBinaryArith); + (void) analyzeIntegerExp(exp1, env, ng); + TcType *res = analyzeIntegerExp(exp2, env, ng); + LEAVE(analyzeBinaryArith); + return res; +} + +static TcType *analyzeComparison(LamExp *exp1, LamExp *exp2, TcEnv *env, TcNg *ng) { + ENTER(analyzeComparison); + TcType *type1 = analyzeExp(exp1, env, ng); + int save = PROTECT(type1); + TcType *type2 = analyzeExp(exp2, env, ng); + PROTECT(type2); + unify(type1, type2); + UNPROTECT(save); + TcType *res = makeBoolean(); + LEAVE(analyzeComparison); + return res; +} + +static TcType *analyzeBinaryBool(LamExp *exp1, LamExp *exp2, TcEnv *env, TcNg *ng) { + ENTER(analyzeBinaryBool); + (void) analyzeBooleanExp(exp1, env, ng); + TcType *res = analyzeBooleanExp(exp2, env, ng); + LEAVE(analyzeBinaryBool); + return res; +} + +static TcType *analyzePrim(LamPrimApp *app, TcEnv *env, TcNg *ng) { + ENTER(analyzePrim); + switch (app->type) { + case LAMPRIMOP_TYPE_ADD: + case LAMPRIMOP_TYPE_SUB: + case LAMPRIMOP_TYPE_MUL: + case LAMPRIMOP_TYPE_DIV: + case LAMPRIMOP_TYPE_MOD: + case LAMPRIMOP_TYPE_POW: { + TcType *res = analyzeBinaryArith(app->exp1, app->exp2, env, ng); + LEAVE(analyzePrim); + return res; + } + case LAMPRIMOP_TYPE_EQ: + case LAMPRIMOP_TYPE_NE: + case LAMPRIMOP_TYPE_GT: + case LAMPRIMOP_TYPE_LT: + case 
LAMPRIMOP_TYPE_GE: + case LAMPRIMOP_TYPE_LE: { + TcType *res = analyzeComparison(app->exp1, app->exp2, env, ng); + LEAVE(analyzePrim); + return res; + } + case LAMPRIMOP_TYPE_VEC: { + // Hacky, but the only time the type checker should encounter a literal vec + // is the (vec 0 ) to retrieve the tag of a constructor. All other calls to vec + // should be hidden behind (deconstruct ...) syntax + // plus in this case the exp2 arg will always be a simple var so analysis would not prove anything + TcType *res = makeInteger(); + LEAVE(analyzePrim); + return res; + } + case LAMPRIMOP_TYPE_XOR: { + TcType *res = analyzeBinaryBool(app->exp1, app->exp2, env, ng); + LEAVE(analyzePrim); + return res; + } + default: + cant_happen("unrecognised type %d in analyzePrim", app->type); + } +} + +static TcType *analyzeUnary(LamUnaryApp *app __attribute__((unused)), TcEnv *env __attribute__((unused)), TcNg *ng __attribute__((unused))) { + cant_happen("analyzeUnary not implemented yet"); +} + +static TcType *analyzeSequence(LamSequence *sequence, TcEnv *env, TcNg *ng) { + ENTER(analyzeSequence); + if (sequence == NULL) { + cant_happen("NULL sequence in analyzeSequence"); + } + TcType *type = analyzeExp(sequence->exp, env, ng); + if (sequence->next != NULL) { + TcType *res = analyzeSequence(sequence->next, env, ng); + LEAVE(analyzeSequence); + return res; + } + LEAVE(analyzeSequence); + return type; +} + +static int countLamList(LamList *list) { + int i = 0; + while (list != NULL) { + i++; + list = list->next; + } + return i; +} + +static LamApply *constructToApply(LamConstruct *construct) { + ENTER(constructToApply); + LamExp *constructor = newLamExp(LAMEXP_TYPE_VAR, LAMEXP_VAL_VAR(construct->name)); + int save = PROTECT(constructor); + LamApply *apply = newLamApply(constructor, countLamList(construct->args), construct->args); + UNPROTECT(save); + LEAVE(constructToApply); + return apply; +} + +static TcType *analyzeConstruct(LamConstruct *construct, TcEnv *env, TcNg *ng) { + 
ENTER(analyzeConstruct); + LamApply *apply = constructToApply(construct); + int save = PROTECT(apply); + TcType *res = analyzeApply(apply, env, ng); + UNPROTECT(save); + LEAVE(analyzeConstruct); + return res; +} + +static TcType *findNthArg(int n, TcType *fn) { + if (fn == NULL) { + cant_happen("findNthArg hit NULL"); + } + if (fn->type != TCTYPE_TYPE_FUNCTION) { + cant_happen("findNthArg given non-function type %d", fn->type); + } + if (n == 0) { + return fn->val.function->arg; + } + TcType *res = findNthArg(n - 1, fn->val.function->result); + return res; +} + +static TcType *findResultType(TcType *fn) { + if (fn == NULL) { + cant_happen("findResultType hit NULL"); + } + if (fn->type != TCTYPE_TYPE_FUNCTION) { + return fn; + } + TcType *res = findResultType(fn->val.function->result); + return res; +} + +static TcType *analyzeDeconstruct(LamDeconstruct *deconstruct, TcEnv *env, TcNg *ng) { + ENTER(analyzeDeconstruct); + TcType *constructor = NULL; + if (!getFromEnv(env, deconstruct->name, &constructor)) { + can_happen("undefined type deconstructor %s", deconstruct->name->name); + TcType *res = makeFreshVar(); + LEAVE(analyzeDeconstruct); + return res; + } + TcType *fieldType = findNthArg(deconstruct->vec - 1, constructor); + TcType *resultType = findResultType(constructor); + TcType *expType = analyzeExp(deconstruct->exp, env, ng); + int save = PROTECT(expType); + unify(expType, resultType); + UNPROTECT(save); + LEAVE(analyzeDeconstruct); + return fieldType; +} + +static TcType *analyzeConstant(LamConstant *constant, TcEnv *env, TcNg *ng) { + ENTER(analyzeConstant); + TcType *constType = lookup(env, constant->name, ng); + if (constType == NULL) { + can_happen("undefined constant %s", constant->name->name); + TcType *res = makeFreshVar(); + LEAVE(analyzeConstant); + return res; + } + LEAVE(analyzeConstant); + return constType; +} + +// apply(fn) => fn +// apply(fn, arg_1, arg_2, arg_3) => apply(apply(apply(fn, arg1), arg_2), arg_3) +static LamApply 
*curryLamApplyHelper(int nargs, LamExp *function, LamList *args) { + if (nargs == 1) { + LamApply *res = newLamApply(function, 1, args); + return res; + } + LamList *singleArg = newLamList(args->exp, NULL); + int save = PROTECT(singleArg); + LamApply *new = newLamApply(function, 1, singleArg); + PROTECT(new); + LamExp *newFunction = newLamExp(LAMEXP_TYPE_APPLY, LAMEXP_VAL_APPLY(new)); + PROTECT(newFunction); + LamApply *curried = curryLamApplyHelper(nargs - 1, newFunction, args->next); + UNPROTECT(save); + return curried; +} + +static LamApply *curryLamApply(LamApply *apply) { + return curryLamApplyHelper(apply->nargs, apply->function, apply->args); +} + +static TcType *analyzeApply(LamApply *apply, TcEnv *env, TcNg *ng) { + ENTER(analyzeApply); + switch (apply->nargs) { + case 0: { + TcType *res = analyzeExp(apply->function, env, ng); + LEAVE(analyzeApply); + return res; + } + case 1: { + TcType *fn = analyzeExp(apply->function, env, ng); + int save = PROTECT(fn); + TcType *arg = analyzeExp(apply->args->exp, env, ng); + PROTECT(arg); + TcType *res = makeFreshVar(); + PROTECT(res); + TcType *functionType = makeFn(arg, res); + PROTECT(functionType); + unify(fn, functionType); + UNPROTECT(save); + LEAVE(analyzeApply); + return res; + } + default:{ + LamApply *curried = curryLamApply(apply); + int save = PROTECT(curried); + TcType *res = analyzeApply(curried, env, ng); + UNPROTECT(save); + LEAVE(analyzeApply); + return res; + } + } +} + +static TcType *analyzeIff(LamIff *iff, TcEnv *env, TcNg *ng) { + ENTER(analyzeIff); + (void) analyzeBooleanExp(iff->condition, env, ng); + TcType *consequent = analyzeExp(iff->consequent, env, ng); + int save = PROTECT(consequent); + TcType *alternative = analyzeExp(iff->alternative, env, ng); + PROTECT(alternative); + unify(consequent, alternative); + UNPROTECT(save); + LEAVE(analyzeIff); + return consequent; +} + +static TcType *analyzeCallCC(LamExp *called __attribute__((unused)), TcEnv *env __attribute__((unused)), TcNg *ng 
__attribute__((unused))) { + cant_happen("analyzeCallCC not implemented yet"); +} + +static TcType *analyzeLetRec(LamLetRec *letRec, TcEnv *env, TcNg *ng) { + ENTER(analyzeLetRec); + env = extendEnv(env); + int save = PROTECT(env); + ng = extendNg(ng); + PROTECT(ng); + for (LamLetRecBindings *bindings = letRec->bindings; bindings != NULL; bindings = bindings->next) { + TcType *fresh = makeFreshVar(); + int save2 = PROTECT(fresh); + addToEnv(env, bindings->var, fresh); + addToNg(ng, fresh->val.var->name, fresh); + UNPROTECT(save2); + } + for (LamLetRecBindings *bindings = letRec->bindings; bindings != NULL; bindings = bindings->next) { + TcType *fresh = lookup(env, bindings->var, ng); + if (fresh == NULL) { + cant_happen("failed to retrieve fresh var from env in analyzeLetRec"); + } + TcType *type = analyzeExp(bindings->val, env, ng); + int save2 = PROTECT(type); + unify(fresh, type); + UNPROTECT(save2); + } + printTcEnv(env, 0); + eprintf("\n"); + printTcNg(ng, 0); + eprintf("\n"); + TcType *res = analyzeExp(letRec->body, env, ng); + UNPROTECT(save); + LEAVE(analyzeLetRec); + return res; +} + +static TcTypeDefArgs *makeTcTypeDefArgs(LamTypeArgs *lamTypeArgs) { + if (lamTypeArgs == NULL) { + return NULL; + } + TcTypeDefArgs *next = makeTcTypeDefArgs(lamTypeArgs->next); + int save = PROTECT(next); + TcType *name = makeVar(lamTypeArgs->name); + PROTECT(name); + TcTypeDefArgs *this = newTcTypeDefArgs(name, next); + UNPROTECT(save); + return this; +} + +static TcType *makeTypeDef(HashSymbol *name, TcTypeDefArgs *args) { + TcTypeDef *tcTypeDef = newTcTypeDef(name, args); + int save = PROTECT(tcTypeDef); + TcType *res = newTcType(TCTYPE_TYPE_TYPEDEF, TCTYPE_VAL_TYPEDEF(tcTypeDef)); + UNPROTECT(save); + DEBUG("makeTypeDef: %s %p", name->name, res); + return res; +} + +static TcType *makeTcTypeDefType(LamType *lamType) { + TcTypeDefArgs *args = makeTcTypeDefArgs(lamType->args); + int save = PROTECT(args); + TcType *res = makeTypeDef(lamType->name, args); + UNPROTECT(save); 
+ return res; +} + +static TcType *makeTypeConstructorArg(LamTypeConstructorType *arg); + +static TcTypeDefArgs *makeTypeDefArgs(LamTypeConstructorArgs *args) { + if (args == NULL) { + return NULL; + } + TcTypeDefArgs *next = makeTypeDefArgs(args->next); + int save = PROTECT(next); + TcType *arg = makeTypeConstructorArg(args->arg); + PROTECT(arg); + TcTypeDefArgs *this = newTcTypeDefArgs(arg, next); + UNPROTECT(save); + return this; +} + +static TcType *makeTypeConstructorApplication(LamTypeFunction *func) { + // this code is building the inner application of a type, i.e. + // list(t) in the context of t -> list(t) -> list(t) + TcTypeDefArgs *args = makeTypeDefArgs(func->args); + if (args == NULL) { + cant_happen("null args to LamTypeDefFunction '%s' in makeTypeConstructorApplication", func->name->name); + } + int save = PROTECT(args); + TcType *res = makeTypeDef(func->name, args); + UNPROTECT(save); + return res; +} + +static TcType *makeTypeConstructorArg(LamTypeConstructorType *arg) { + TcType *res = NULL; + switch (arg->type) { + case LAMTYPECONSTRUCTORTYPE_TYPE_INTEGER: + res = makeInteger(); + break; + case LAMTYPECONSTRUCTORTYPE_TYPE_CHARACTER: + res = makeCharacter(); + break; + case LAMTYPECONSTRUCTORTYPE_TYPE_VAR: + res = makeVar(arg->val.var); + break; + case LAMTYPECONSTRUCTORTYPE_TYPE_FUNCTION: + res = makeTypeConstructorApplication(arg->val.function); + break; + default: + cant_happen("unrecognised type %d in collectTypeConstructorArg", arg->type); + } + return res; +} + +static TcType *makeTypeDefConstructor(LamTypeConstructorArgs *args, TcType *result) { + // this code is building the top-level type of a type constructor, i.e. 
+ // pair => t -> list(t) -> list(t) + if (args == NULL) { + return result; + } + TcType *next = makeTypeDefConstructor(args->next, result); + int save = PROTECT(next); + TcType *this = makeTypeConstructorArg(args->arg); + PROTECT(this); + TcType *res = makeFn(this, next); + UNPROTECT(save); + return res; +} + +static void collectTypeDefConstructor(LamTypeConstructor *constructor, TcType *type, TcEnv *env) { + TcType *res = makeTypeDefConstructor(constructor->args, type); + int save = PROTECT(res); + addToEnv(env, constructor->name, res); + UNPROTECT(save); +} + +static void collectTypeDef(LamTypeDef *lamTypeDef, TcEnv *env) { + LamType *lamType = lamTypeDef->type; + TcType *tcType = makeTcTypeDefType(lamType); + int save = PROTECT(tcType); + for (LamTypeConstructorList *list = lamTypeDef->constructors; list != NULL; list = list->next) { + collectTypeDefConstructor(list->constructor, tcType, env); + } + UNPROTECT(save); +} + +static TcType *analyzeTypeDefs(LamTypeDefs *typeDefs, TcEnv *env, TcNg *ng) { + ENTER(analyzeTypeDefs); + env = extendEnv(env); + int save = PROTECT(env); + for (LamTypeDefList *list = typeDefs->typeDefs; list != NULL; list = list->next) { + collectTypeDef(list->typeDef, env); + } + TcType *res = analyzeExp(typeDefs->body, env, ng); + UNPROTECT(save); + LEAVE(analyzeTypeDefs); + return res; +} + +static TcType *analyzeLet(LamLet *let, TcEnv *env, TcNg *ng) { + ENTER(analyzeLet); + // let expression is evaluated in the current environment + TcType *valType = analyzeExp(let->value, env, ng); + int save = PROTECT(valType); + env = extendEnv(env); + PROTECT(env); + addToEnv(env, let->var, valType); + TcType *res = analyzeExp(let->body, env, ng); + UNPROTECT(save); + LEAVE(analyzeLet); + return res; +} + +static TcType *unifyMatchCases(LamMatchList *cases, TcEnv *env, TcNg *ng) { + ENTER(unifyMatchCases); + if (cases == NULL) { + TcType *res = makeFreshVar(); + LEAVE(unifyMatchCases); + return res; + } + TcType *rest = unifyMatchCases(cases->next, 
env, ng); + int save = PROTECT(rest); + TcType *this = analyzeExp(cases->body, env, ng); + PROTECT(this); + unify(this, rest); + UNPROTECT(save); + LEAVE(unifyMatchCases); + return this; +} + +static TcType *analyzeIntegerExp(LamExp *exp, TcEnv *env, TcNg *ng) { + TcType *type = analyzeExp(exp, env, ng); + int save = PROTECT(type); + TcType *integer = makeInteger(); + PROTECT(integer); + unify(type, integer); + UNPROTECT(save); + return integer; +} + +static TcType *analyzeBooleanExp(LamExp *exp, TcEnv *env, TcNg *ng) { + TcType *type = analyzeExp(exp, env, ng); + int save = PROTECT(type); + TcType *boolean = makeBoolean(); + PROTECT(boolean); + unify(type, boolean); + UNPROTECT(save); + return boolean; +} + +static TcType *analyzeMatch(LamMatch *match, TcEnv *env, TcNg *ng) { + (void) analyzeIntegerExp(match->index, env, ng); + TcType *res = unifyMatchCases(match->cases, env, ng); + return res; +} + +static TcType *unifyIntCondCases(LamIntCondCases *cases, TcEnv *env, TcNg *ng) { + if (cases == NULL) return makeFreshVar(); + TcType *rest = unifyIntCondCases(cases->next, env, ng); + int save = PROTECT(rest); + TcType *this = analyzeExp(cases->body, env, ng); + PROTECT(this); + unify(this, rest); + UNPROTECT(save); + return this; +} + +static TcType *unifyCharCondCases(LamCharCondCases *cases, TcEnv *env, TcNg *ng) { + if (cases == NULL) return makeFreshVar(); + TcType *rest = unifyCharCondCases(cases->next, env, ng); + int save = PROTECT(rest); + TcType *this = analyzeExp(cases->body, env, ng); + PROTECT(this); + unify(this, rest); + UNPROTECT(save); + return this; +} + +static TcType *analyzeCond(LamCond *cond, TcEnv *env, TcNg *ng) { + TcType *result = NULL; + int save = PROTECT(result); + TcType *value = analyzeExp(cond->value, env, ng); + PROTECT(value); + switch (cond->cases->type) { + case LAMCONDCASES_TYPE_INTEGERS: { + TcType *integer = makeInteger(); + PROTECT(integer); + unify(value, integer); + result = unifyIntCondCases(cond->cases->val.integers, env, 
ng); + } + break; + case LAMCONDCASES_TYPE_CHARACTERS: { + TcType *character = makeCharacter(); + PROTECT(character); + unify(value, character); + result = unifyCharCondCases(cond->cases->val.characters, env, ng); + } + break; + default: + cant_happen("unrecognized type %d in analyzeCond", cond->cases->type); + } + UNPROTECT(save); + return result; +} + +static TcType *analyzeAnd(LamAnd *and __attribute__((unused)), TcEnv *env __attribute__((unused)), TcNg *ng __attribute__((unused))) { + cant_happen("analyzeAnd not implemented yet"); +} + +static TcType *analyzeOr(LamOr *or, TcEnv *env, TcNg *ng) { + TcType *res = analyzeBinaryBool(or->left, or->right, env, ng); + return res; +} + +static TcType *analyzeAmb(LamAmb *amb, TcEnv *env, TcNg *ng) { + TcType *left = analyzeExp(amb->left, env, ng); + int save = PROTECT(left); + TcType *right = analyzeExp(amb->right, env, ng); + PROTECT(right); + unify(left, right); + UNPROTECT(save); + return left; +} + +static TcType *analyzeCharacter(char c __attribute__((unused)), TcEnv *env __attribute__((unused)), TcNg *ng __attribute__((unused))) { + cant_happen("analyzeCharacter not implemented yet"); +} + +static TcType *analyzeBack() { + TcType *res = makeFreshVar(); + return res; +} + +static TcType *analyzeError() { + TcType *res = makeFreshVar(); + return res; +} + + +static void markType(void *ptr) { + markTcType(*((TcType **) ptr)); +} + +static void printType(void *ptr, int depth) { + eprintf("%*s", depth * 4, ""); + ppTcType(*((TcType **) ptr)); +} + +static void addToEnv(TcEnv *env, HashSymbol *symbol, TcType *type) { + hashSet(env->table, symbol, &type); +} + +static bool getFromEnv(TcEnv *env, HashSymbol *symbol, TcType **type) { + if (env == NULL) { + return false; + } + if (hashGet(env->table, symbol, type)) { + return true; + } + bool res = getFromEnv(env->next, symbol, type); + return res; +} + +static HashTable *makeTypeMap() { + HashTable *res = newHashTable(sizeof(TcType *), markType, printType); + return res; 
+} + +static TcType *freshFunction(TcFunction *fn, TcNg *ng, HashTable *map) { + TcType *arg = freshRec(fn->arg, ng, map); + int save = PROTECT(arg); + TcType *result = freshRec(fn->result, ng, map); + PROTECT(result); + TcType *res = makeFn(arg, result); + UNPROTECT(save); + return res; +} + +static TcType *makePair(TcType *first, TcType *second) { + TcPair *resPair = newTcPair(first, second); + int save = PROTECT(resPair); + TcType *res = newTcType(TCTYPE_TYPE_PAIR, TCTYPE_VAL_PAIR(resPair)); + UNPROTECT(save); + DEBUG("makePair: %p", res); + return res; +} + +static TcType *freshPair(TcPair *pair, TcNg *ng, HashTable *map) { + TcType *first = freshRec(pair->first, ng, map); + int save = PROTECT(first); + TcType *second = freshRec(pair->second, ng, map); + PROTECT(second); + TcType *res = makePair(first, second); + UNPROTECT(save); + return res; +} + +static TcTypeDefArgs *freshTypeDefArgs(TcTypeDefArgs *args, TcNg *ng, HashTable *map) { + if (args == NULL) return NULL; + TcTypeDefArgs *next = freshTypeDefArgs(args->next, ng, map); + int save = PROTECT(next); + TcType *type = freshRec(args->type, ng, map); + PROTECT(type); + TcTypeDefArgs *this = newTcTypeDefArgs(type, next); + UNPROTECT(save); + return this; +} + +static TcType *freshTypeDef(TcTypeDef *typeDef, TcNg *ng, HashTable *map) { + TcTypeDefArgs *args = freshTypeDefArgs(typeDef->args, ng, map); + int save = PROTECT(args); + TcType *res = makeTypeDef(typeDef->name, args); + UNPROTECT(save); + return res; +} + +static bool isGeneric(TcType *typeVar, TcNg *ng) { + if (ng == NULL) { + return true; + } + int i = 0; + TcType *entry = NULL; + HashSymbol *s = NULL; + while ((s = iterateHashTable(ng->table, &i, &entry)) != NULL) { + if (occursInType(typeVar, entry)) { + return false; + } + } + bool res = isGeneric(typeVar, ng->next); + return res; +} + +static TcType *typeGetOrPut(HashTable *map, TcType *typeVar, TcType *defaultValue) { + HashSymbol *name = typeVar->val.var->name; + TcType *res = NULL; + if 
(hashGet(map, name, &res)) {
+        return res;
+    }
+    // The map holds TcType* values (see makeTypeMap), so hashSet must be given
+    // the address of the pointer, exactly as addToEnv/addToNg do and as the
+    // hashGet(..., &res) above expects to read it back.
+    hashSet(map, name, &defaultValue);
+    return defaultValue;
+}
+
+// Build a copy of `type` in which every generic type variable (one that
+// isGeneric() reports as not occurring in ng) is replaced by a fresh
+// variable. `map` remembers the replacement chosen for each variable name so
+// repeated occurrences map to the same fresh variable. Non-generic
+// variables and the integer/character types are returned unchanged.
+static TcType *freshRec(TcType *type, TcNg *ng, HashTable *map) {
+    type = prune(type);
+    switch (type->type) {
+        // braced: a case label may not be followed directly by a declaration
+        case TCTYPE_TYPE_FUNCTION: {
+            TcType *res = freshFunction(type->val.function, ng, map);
+            return res;
+        }
+        case TCTYPE_TYPE_PAIR: {
+            TcType *res = freshPair(type->val.pair, ng, map);
+            return res;
+        }
+        case TCTYPE_TYPE_VAR:
+            if (isGeneric(type, ng)) {
+                TcType *freshVar = makeFreshVar();
+                int save = PROTECT(freshVar);
+                TcType *res = typeGetOrPut(map, type, freshVar);
+                UNPROTECT(save);
+                return res;
+            }
+            return type;
+        case TCTYPE_TYPE_INTEGER:
+        case TCTYPE_TYPE_CHARACTER:
+            return type;
+        case TCTYPE_TYPE_TYPEDEF: {
+            TcType *res = freshTypeDef(type->val.typeDef, ng, map);
+            return res;
+        }
+        default:
+            cant_happen("unrecognised type %d in freshRec", type->type);
+    }
+}
+
+static TcType *fresh(TcType *type, TcNg *ng) {
+    HashTable *map = makeTypeMap();
+    int save = PROTECT(map);
+    TcType *res = freshRec(type, ng, map);
+    UNPROTECT(save);
+    return res;
+}
+
+static TcType *lookup(TcEnv *env, HashSymbol *symbol, TcNg *ng) {
+    TcType *type = NULL;
+    if (getFromEnv(env, symbol, &type)) {
+        TcType *res = fresh(type, ng);
+        return res;
+    }
+    return NULL;
+}
+
+static void addToNg(TcNg *env, HashSymbol *symbol, TcType *type) {
+    hashSet(env->table, symbol, &type);
+}
+
+static TcType *makeBoolean() {
+    TcType *res = makeTypeDef(boolSymbol(), NULL);
+    return res;
+}
+
+static TcType *makeFn(TcType *arg, TcType *result) {
+    TcFunction *fn = newTcFunction(arg, result);
+    int save = PROTECT(fn);
+    assert(fn != NULL);
+    TcType *type = newTcType(TCTYPE_TYPE_FUNCTION, TCTYPE_VAL_FUNCTION(fn));
+    UNPROTECT(save);
+    DEBUG("makeFunction: %p", type);
+    return type;
+}
+
+static TcEnv *extendEnv(TcEnv *parent) {
+    HashTable *table = newHashTable(sizeof(TcType *), markType, printType);
+    int save = PROTECT(table);
+    table->shortEntries = true;
+    TcEnv *env =
newTcEnv(table, parent);
+    UNPROTECT(save);
+    return env;
+}
+
+// Create a new TcNg frame: a fresh hash table of TcType* values chained to
+// `parent`.
+static TcNg *extendNg(TcNg *parent) {
+    HashTable *table = newHashTable(sizeof(TcType *), markType, printType);
+    int save = PROTECT(table);
+    table->shortEntries = true;
+    TcNg *ng = newTcNg(table, parent);
+    UNPROTECT(save);
+    return ng;
+}
+
+// Wrap symbol `t` in a TcVar and return it as a TCTYPE_TYPE_VAR type.
+static TcType *makeVar(HashSymbol *t) {
+    TcVar *var = newTcVar(t);
+    int save = PROTECT(var);
+    TcType *res = newTcType(TCTYPE_TYPE_VAR, TCTYPE_VAL_VAR(var));
+    UNPROTECT(save);
+    DEBUG("makeVar %p", res);
+    return res;
+}
+
+// Type variable whose name comes from genSym("t$").
+static TcType *makeFreshVar() {
+    return makeVar(genSym("t$"));
+}
+
+static TcType *makeInteger() {
+    TcType *res = newTcType(TCTYPE_TYPE_INTEGER, TCTYPE_VAL_INTEGER());
+    DEBUG("makeInteger %p", res);
+    return res;
+}
+
+static TcType *makeCharacter() {
+    TcType *res = newTcType(TCTYPE_TYPE_CHARACTER, TCTYPE_VAL_CHARACTER());
+    DEBUG("makeCharacter %p", res);
+    return res;
+}
+
+// Bind `symbol` in env to the unary function type `type -> type`.
+static void addUnOpToEnv(TcEnv *env, HashSymbol *symbol, TcType *type) {
+    TcType *aa = makeFn(type, type);
+    int save = PROTECT(aa);
+    addToEnv(env, symbol, aa);
+    UNPROTECT(save);
+}
+
+// Registers int -> int under negSymbol().
+static void addNegToEnv(TcEnv *env) {
+    TcType *integer = makeInteger();
+    int save = PROTECT(integer);
+    addUnOpToEnv(env, negSymbol(), integer);
+    UNPROTECT(save);
+}
+
+// Registers bool -> bool.
+// NOTE(review): this binds under negSymbol(), the same key addNegToEnv uses,
+// so whichever of the two runs last overwrites the other's entry in the same
+// env table. Looks like a copy-paste slip; presumably a dedicated symbol for
+// 'not' was intended -- TODO confirm and switch the key.
+static void addNotToEnv(TcEnv *env) {
+    TcType *boolean = makeBoolean();
+    int save = PROTECT(boolean);
+    addUnOpToEnv(env, negSymbol(), boolean);
+    UNPROTECT(save);
+}
+
+static void addIfToEnv(TcEnv *env) {
+    // 'if' is bool -> a -> a -> a
+    TcType *boolean = makeBoolean();
+    int save = PROTECT(boolean);
+    TcType *a = makeFreshVar();
+    (void) PROTECT(a);
+    TcType *aa = makeFn(a, a);
+    (void) PROTECT(aa);
+    TcType *aaa = makeFn(a, aa);
+    (void) PROTECT(aaa);
+    TcType *baaa = makeFn(boolean, aaa);
+    (void) PROTECT(baaa);
+    addToEnv(env, ifSymbol(), baaa);
+    UNPROTECT(save);
+}
+
+static void addHereToEnv(TcEnv *env) {
+    // 'call/cc' is ((a -> b) -> a) -> a
+    TcType *a = makeFreshVar();
+    int save =
PROTECT(a); + TcType *b = makeFreshVar(); + (void) PROTECT(b); + TcType *ab = makeFn(a, b); + (void) PROTECT(ab); + TcType *aba = makeFn(ab, a); + (void) PROTECT(aba); + TcType *abaa = makeFn(aba, a); + (void) PROTECT(abaa); + addToEnv(env, hereSymbol(), abaa); + UNPROTECT(save); +} + +static void addCmpToEnv(TcEnv *env, HashSymbol *symbol) { + // all binary comparisons are a -> a -> bool + TcType *freshVar = makeFreshVar(); + int save = PROTECT(freshVar); + TcType *boolean = makeBoolean(); + (void) PROTECT(boolean); + TcType *unOp = makeFn(freshVar, boolean); + (void) PROTECT(unOp); + TcType *binOp = makeFn(freshVar, unOp); + (void) PROTECT(binOp); + addToEnv(env, symbol, binOp); + UNPROTECT(save); +} + +static void addFreshVarToEnv(TcEnv *env, HashSymbol *symbol) { + // 'error' and 'back' both have unconstrained types + TcType *freshVar = makeFreshVar(); + int save = PROTECT(freshVar); + addToEnv(env, symbol, freshVar); + UNPROTECT(save); +} + +static void addBinOpToEnv(TcEnv *env, HashSymbol *symbol, TcType *type) { + // handle all fonctions of the form a -> a -> a + TcType *unOp = makeFn(type, type); + int save = PROTECT(unOp); + TcType *binOp = makeFn(type, unOp); + (void) PROTECT(binOp); + addToEnv(env, symbol, binOp); + UNPROTECT(save); +} + +static void addIntBinOpToEnv(TcEnv *env, HashSymbol *symbol) { + // int -> int -> int + TcType *integer = makeInteger(); + int save = PROTECT(integer); + addBinOpToEnv(env, symbol, integer); + UNPROTECT(save); +} + +static void addBoolBinOpToEnv(TcEnv *env, HashSymbol *symbol) { + // bool -> bool -> bool + TcType *boolean = makeBoolean(); + int save = PROTECT(boolean); + addBinOpToEnv(env, symbol, boolean); + UNPROTECT(save); +} + +static void addThenToEnv(TcEnv *env) { + // a -> a -> a + TcType *freshVar = makeFreshVar(); + int save = PROTECT(freshVar); + addBinOpToEnv(env, thenSymbol(), freshVar); + UNPROTECT(save); +} + +static bool unifyFunctions(TcFunction *a, TcFunction *b) { + bool res = unify(a->arg, b->arg) && 
unify(a->result, b->result); + return res; +} + +static bool unifyPairs(TcPair *a, TcPair *b) { + bool res = unify(a->first, b->first) && unify(a->second, b->second); + return res; +} + +static bool unifyTypeDefs(TcTypeDef *a, TcTypeDef *b) { + if (a->name != b->name) { + can_happen("unification failed"); + ppTcTypeDef(a); + eprintf(" vs "); + ppTcTypeDef(b); + eprintf("\n"); + return false; + } + TcTypeDefArgs *aArgs = a->args; + TcTypeDefArgs *bArgs = b->args; + while (aArgs != NULL && bArgs != NULL) { + if (!unify(aArgs->type, bArgs->type)) { + return false; + } + aArgs = aArgs->next; + bArgs = bArgs->next; + } + if (aArgs != NULL || bArgs != NULL) { + can_happen("unification failed"); + ppTcTypeDef(a); + eprintf(" vs "); + ppTcTypeDef(b); + eprintf("\n"); + return false; + } + return true; +} + +static bool unify(TcType *a, TcType *b) { + a = prune(a); + b = prune(b); + if (a->type == TCTYPE_TYPE_VAR) { + if (b->type != TCTYPE_TYPE_VAR) { + if (occursInType(a, b)) { + can_happen("occurs-in check failed"); + return false; + } + a->val.var->instance = b; + return true; + } + if (a->val.var->name != b->val.var->name) { + a->val.var->instance = b; + } + return true; + } else if (b->type == TCTYPE_TYPE_VAR) { + return unify(b, a); + } else { + if (a->type != b->type) { + can_happen("unification failed"); + ppTcType(a); + eprintf(" vs "); + ppTcType(b); + eprintf("\n"); + return false; + } + switch (a->type) { + case TCTYPE_TYPE_FUNCTION: + return unifyFunctions(a->val.function, b->val.function); + case TCTYPE_TYPE_PAIR: + return unifyPairs(a->val.pair, b->val.pair); + case TCTYPE_TYPE_VAR: + cant_happen("encountered var in unify"); + case TCTYPE_TYPE_INTEGER: + case TCTYPE_TYPE_CHARACTER: + return true; + case TCTYPE_TYPE_TYPEDEF: + return unifyTypeDefs(a->val.typeDef, b->val.typeDef); + default: + cant_happen("unrecognised type %d in unify", a->type); + } + } + cant_happen("reached end of unify"); +} + +static TcType *prune(TcType *t) { + if (t == NULL) return 
NULL; + if (t->type == TCTYPE_TYPE_VAR) { + if (t->val.var->instance != NULL) { + t->val.var->instance = prune(t->val.var->instance); + return t->val.var->instance; + } + } + return t; +} + +static bool sameFunctionType(TcFunction *a, TcFunction *b) { + return sameType(a->arg, b->arg) && sameType(a->result, b->result); +} + +static bool samePairType(TcPair *a, TcPair *b) { + return sameType(a->first, b->first) && sameType(a->second, b->second); +} + +static bool sameTypeDefType(TcTypeDef *a, TcTypeDef *b) { + if (a->name != b->name) { + return false; + } + TcTypeDefArgs *aArgs = a->args; + TcTypeDefArgs *bArgs = b->args; + while (aArgs != NULL && bArgs != NULL) { + if (!sameType(aArgs->type, bArgs->type)) return false; + aArgs = aArgs->next; + bArgs = bArgs->next; + } + if (aArgs != NULL || bArgs != NULL) { + return false; + } + return true; +} + +static bool sameType(TcType *a, TcType *b) { + if (a == NULL || b == NULL) { + cant_happen("NULL in sameType"); + } + if (a->type != b->type) { + return false; + } + switch (a->type) { + case TCTYPE_TYPE_FUNCTION: + return sameFunctionType(a->val.function, b->val.function); + case TCTYPE_TYPE_PAIR: + return samePairType(a->val.pair, b->val.pair); + case TCTYPE_TYPE_VAR: + return a->val.var->name == b->val.var->name; + case TCTYPE_TYPE_INTEGER: + case TCTYPE_TYPE_CHARACTER: + return true; + case TCTYPE_TYPE_TYPEDEF: + return sameTypeDefType(a->val.typeDef, b->val.typeDef); + default: + cant_happen("unrecognised type %d in sameType", a->type); + } +} + +static bool occursInType(TcType *a, TcType *b) { + b = prune(b); + if (b->type == TCTYPE_TYPE_VAR) { + return sameType(a, b); + } else { + return occursIn(a, b); + } +} + +static bool occursInFunction(TcType *var, TcFunction *fun) { + assert(fun != NULL); + return occursInType(var, fun->arg) || occursInType(var, fun->result); +} + +static bool occursInPair(TcType *var, TcPair *pair) { + return occursInType(var, pair->first) || occursInType(var, pair->second); +} + +static 
bool occursInTypeDef(TcType *var, TcTypeDef *typeDef) { + for (TcTypeDefArgs *args = typeDef->args; args != NULL; args = args->next) { + if (occursInType(var, args->type)) return true; + } + return false; +} + +static bool occursIn(TcType *a, TcType *b) { + switch(b->type) { + case TCTYPE_TYPE_FUNCTION: + return occursInFunction(a, b->val.function); + case TCTYPE_TYPE_PAIR: + return occursInPair(a, b->val.pair); + case TCTYPE_TYPE_VAR: + cant_happen("occursIn 2nd arg should not be a var"); + case TCTYPE_TYPE_INTEGER: + case TCTYPE_TYPE_CHARACTER: + return false; + case TCTYPE_TYPE_TYPEDEF: + return occursInTypeDef(a, b->val.typeDef); + default: + cant_happen("unrecognised type %d in occursIn", b->type); + } +} diff --git a/src/algorithm_W.h b/src/tc_analyze.h similarity index 66% rename from src/algorithm_W.h rename to src/tc_analyze.h index 48c59da..076e28f 100644 --- a/src/algorithm_W.h +++ b/src/tc_analyze.h @@ -1,5 +1,5 @@ -#ifndef cekf_algorithm_w_h -#define cekf_algorithm_w_h +#ifndef cekf_tc_analyze_h +#define cekf_tc_analyze_h /* * CEKF - VM supporting amb * Copyright (C) 2022-2023 Bill Hails @@ -18,18 +18,10 @@ * along with this program. If not, see . 
*/ -#include "tin.h" -#include "ast.h" +#include "tc.h" +#include "lambda.h" -typedef struct WResult { - struct Header header; - struct TinSubstitution *substitution; - struct TinMonoType *monoType; -} WResult; - -void markWResult(struct WResult *result); -void printWResult(struct WResult *result, int depth); - -struct WResult *WTop(struct AstNest *nest); +TcEnv *tc_init(void); +TcType *tc_analyze(LamExp *exp, TcEnv *env); #endif diff --git a/src/tc_debug.c b/src/tc_debug.c new file mode 100644 index 0000000..04e59ac --- /dev/null +++ b/src/tc_debug.c @@ -0,0 +1,245 @@ +/* + * CEKF - VM supporting amb + * Copyright (C) 2022-2023 Bill Hails + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . 
+ * + * Structures to support type inference + * + * generated from src/tc.yaml by makeAST.py + */ + +#include + +#include "tc_debug.h" + +static void pad(int depth) { eprintf("%*s", depth * 4, ""); } + +void printTcEnv(struct TcEnv * x, int depth) { + pad(depth); + if (x == NULL) { eprintf("TcEnv (NULL)"); return; } + eprintf("TcEnv[\n"); + printHashTable(x->table, depth + 1); + eprintf("\n"); + printTcEnv(x->next, depth + 1); + eprintf("\n"); + pad(depth); + eprintf("]"); +} + +void printTcNg(struct TcNg * x, int depth) { + pad(depth); + if (x == NULL) { eprintf("TcNg (NULL)"); return; } + eprintf("TcNg[\n"); + printHashTable(x->table, depth + 1); + eprintf("\n"); + printTcNg(x->next, depth + 1); + eprintf("\n"); + pad(depth); + eprintf("]"); +} + +void printTcFunction(struct TcFunction * x, int depth) { + pad(depth); + if (x == NULL) { eprintf("TcFunction (NULL)"); return; } + eprintf("TcFunction[\n"); + printTcType(x->arg, depth + 1); + eprintf("\n"); + printTcType(x->result, depth + 1); + eprintf("\n"); + pad(depth); + eprintf("]"); +} + +void printTcPair(struct TcPair * x, int depth) { + pad(depth); + if (x == NULL) { eprintf("TcPair (NULL)"); return; } + eprintf("TcPair[\n"); + printTcType(x->first, depth + 1); + eprintf("\n"); + printTcType(x->second, depth + 1); + eprintf("\n"); + pad(depth); + eprintf("]"); +} + +void printTcTypeDef(struct TcTypeDef * x, int depth) { + pad(depth); + if (x == NULL) { eprintf("TcTypeDef (NULL)"); return; } + eprintf("TcTypeDef[\n"); + printAstSymbol(x->name, depth + 1); + eprintf("\n"); + printTcTypeDefArgs(x->args, depth + 1); + eprintf("\n"); + pad(depth); + eprintf("]"); +} + +void printTcTypeDefArgs(struct TcTypeDefArgs * x, int depth) { + pad(depth); + if (x == NULL) { eprintf("TcTypeDefArgs (NULL)"); return; } + eprintf("TcTypeDefArgs[\n"); + printTcType(x->type, depth + 1); + eprintf("\n"); + printTcTypeDefArgs(x->next, depth + 1); + eprintf("\n"); + pad(depth); + eprintf("]"); +} + +void printTcVar(struct TcVar * x, 
int depth) { + pad(depth); + if (x == NULL) { eprintf("TcVar (NULL)"); return; } + eprintf("TcVar[\n"); + printAstSymbol(x->name, depth + 1); + eprintf("\n"); + printTcType(x->instance, depth + 1); + eprintf("\n"); + pad(depth); + eprintf("]"); +} + +void printTcType(struct TcType * x, int depth) { + pad(depth); + if (x == NULL) { eprintf("TcType (NULL)"); return; } + eprintf("TcType[\n"); + switch(x->type) { + case TCTYPE_TYPE_FUNCTION: + pad(depth + 1); + eprintf("TCTYPE_TYPE_FUNCTION\n"); + printTcFunction(x->val.function, depth + 1); + break; + case TCTYPE_TYPE_PAIR: + pad(depth + 1); + eprintf("TCTYPE_TYPE_PAIR\n"); + printTcPair(x->val.pair, depth + 1); + break; + case TCTYPE_TYPE_VAR: + pad(depth + 1); + eprintf("TCTYPE_TYPE_VAR\n"); + printTcVar(x->val.var, depth + 1); + break; + case TCTYPE_TYPE_INTEGER: + pad(depth + 1); + eprintf("TCTYPE_TYPE_INTEGER\n"); + pad(depth + 1); +eprintf("void * %p", x->val.integer); + break; + case TCTYPE_TYPE_CHARACTER: + pad(depth + 1); + eprintf("TCTYPE_TYPE_CHARACTER\n"); + pad(depth + 1); +eprintf("void * %p", x->val.character); + break; + case TCTYPE_TYPE_TYPEDEF: + pad(depth + 1); + eprintf("TCTYPE_TYPE_TYPEDEF\n"); + printTcTypeDef(x->val.typeDef, depth + 1); + break; + default: + cant_happen("unrecognised type %d in printTcType", x->type); + } + eprintf("\n"); + pad(depth); + eprintf("]"); +} + + +/***************************************/ + +bool eqTcEnv(struct TcEnv * a, struct TcEnv * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->table != b->table) return false; + if (!eqTcEnv(a->next, b->next)) return false; + return true; +} + +bool eqTcNg(struct TcNg * a, struct TcNg * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->table != b->table) return false; + if (!eqTcNg(a->next, b->next)) return false; + return true; +} + +bool eqTcFunction(struct TcFunction * a, struct TcFunction * b) { + if (a == b) return true; + if (a == NULL || b == NULL) 
return false; + if (!eqTcType(a->arg, b->arg)) return false; + if (!eqTcType(a->result, b->result)) return false; + return true; +} + +bool eqTcPair(struct TcPair * a, struct TcPair * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqTcType(a->first, b->first)) return false; + if (!eqTcType(a->second, b->second)) return false; + return true; +} + +bool eqTcTypeDef(struct TcTypeDef * a, struct TcTypeDef * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->name != b->name) return false; + if (!eqTcTypeDefArgs(a->args, b->args)) return false; + return true; +} + +bool eqTcTypeDefArgs(struct TcTypeDefArgs * a, struct TcTypeDefArgs * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqTcType(a->type, b->type)) return false; + if (!eqTcTypeDefArgs(a->next, b->next)) return false; + return true; +} + +bool eqTcVar(struct TcVar * a, struct TcVar * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->name != b->name) return false; + if (!eqTcType(a->instance, b->instance)) return false; + return true; +} + +bool eqTcType(struct TcType * a, struct TcType * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->type != b->type) return false; + switch(a->type) { + case TCTYPE_TYPE_FUNCTION: + if (!eqTcFunction(a->val.function, b->val.function)) return false; + break; + case TCTYPE_TYPE_PAIR: + if (!eqTcPair(a->val.pair, b->val.pair)) return false; + break; + case TCTYPE_TYPE_VAR: + if (!eqTcVar(a->val.var, b->val.var)) return false; + break; + case TCTYPE_TYPE_INTEGER: + if (a->val.integer != b->val.integer) return false; + break; + case TCTYPE_TYPE_CHARACTER: + if (a->val.character != b->val.character) return false; + break; + case TCTYPE_TYPE_TYPEDEF: + if (!eqTcTypeDef(a->val.typeDef, b->val.typeDef)) return false; + break; + default: + cant_happen("unrecognised type %d in eqTcType", a->type); + } + return true; 
+} + diff --git a/src/tc_debug.h b/src/tc_debug.h new file mode 100644 index 0000000..8b2f496 --- /dev/null +++ b/src/tc_debug.h @@ -0,0 +1,45 @@ +#ifndef cekf_tc_debug_h +#define cekf_tc_debug_h +/* + * CEKF - VM supporting amb + * Copyright (C) 2022-2023 Bill Hails + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + * Structures to support type inference + * + * generated from src/tc.yaml by makeAST.py + */ + +#include "tc_helper.h" + +void printTcEnv(struct TcEnv * x, int depth); +void printTcNg(struct TcNg * x, int depth); +void printTcFunction(struct TcFunction * x, int depth); +void printTcPair(struct TcPair * x, int depth); +void printTcTypeDef(struct TcTypeDef * x, int depth); +void printTcTypeDefArgs(struct TcTypeDefArgs * x, int depth); +void printTcVar(struct TcVar * x, int depth); +void printTcType(struct TcType * x, int depth); + +bool eqTcEnv(struct TcEnv * a, struct TcEnv * b); +bool eqTcNg(struct TcNg * a, struct TcNg * b); +bool eqTcFunction(struct TcFunction * a, struct TcFunction * b); +bool eqTcPair(struct TcPair * a, struct TcPair * b); +bool eqTcTypeDef(struct TcTypeDef * a, struct TcTypeDef * b); +bool eqTcTypeDefArgs(struct TcTypeDefArgs * a, struct TcTypeDefArgs * b); +bool eqTcVar(struct TcVar * a, struct TcVar * b); +bool eqTcType(struct TcType * a, struct TcType * b); + +#endif diff --git a/src/tc_helper.c b/src/tc_helper.c new file mode 100644 index 
0000000..73fe837 --- /dev/null +++ b/src/tc_helper.c @@ -0,0 +1,86 @@ +/* + * CEKF - VM supporting amb + * Copyright (C) 2022-2023 Bill Hails + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include "tc_helper.h" + +void ppTcType(TcType *type) { + if (type == NULL) { + eprintf(""); + return; + } + switch (type->type) { + case TCTYPE_TYPE_FUNCTION: + ppTcFunction(type->val.function); + break; + case TCTYPE_TYPE_PAIR: + ppTcPair(type->val.pair); + break; + case TCTYPE_TYPE_VAR: + ppTcVar(type->val.var); + break; + case TCTYPE_TYPE_INTEGER: + eprintf("int"); + break; + case TCTYPE_TYPE_CHARACTER: + eprintf("char"); + break; + case TCTYPE_TYPE_TYPEDEF: + ppTcTypeDef(type->val.typeDef); + break; + default: + cant_happen("unrecognized type %d in ppTcType", type->type); + } +} + +void ppTcFunction(TcFunction *function) { + eprintf("("); + ppTcType(function->arg); + eprintf(") -> "); + ppTcType(function->result); +} + +void ppTcPair(TcPair *pair) { + eprintf("#("); + ppTcType(pair->first); + eprintf(", "); + ppTcType(pair->second); + eprintf(")"); +} + +void ppTcVar(TcVar *var) { + eprintf("<%s>", var->name->name); + if (var->instance != NULL) { + eprintf(" ["); + ppTcType(var->instance); + eprintf("]"); + } +} + +static void ppTypeDefArgs(TcTypeDefArgs *args) { + while (args != NULL) { + ppTcType(args->type); + if (args->next) eprintf(", "); + args = args->next; + } +} + +void 
ppTcTypeDef(TcTypeDef *typeDef) { + eprintf("%s(", typeDef->name->name); + ppTypeDefArgs(typeDef->args); + eprintf(")"); +} diff --git a/src/tc_helper.h b/src/tc_helper.h new file mode 100644 index 0000000..61e767b --- /dev/null +++ b/src/tc_helper.h @@ -0,0 +1,30 @@ +#ifndef cekf_tc_helper_h +#define cekf_tc_helper_h +/* + * CEKF - VM supporting amb + * Copyright (C) 2022-2023 Bill Hails + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include "ast_helper.h" +#include "tc.h" + +void ppTcType(TcType *type); +void ppTcFunction(TcFunction *function); +void ppTcPair(TcPair *pair); +void ppTcVar(TcVar *var); +void ppTcTypeDef(TcTypeDef *typeDef); + +#endif diff --git a/src/tc_objtypes.h b/src/tc_objtypes.h new file mode 100644 index 0000000..f3a4db0 --- /dev/null +++ b/src/tc_objtypes.h @@ -0,0 +1,49 @@ +#ifndef cekf_tc_objtypes_h +#define cekf_tc_objtypes_h +/* + * CEKF - VM supporting amb + * Copyright (C) 2022-2023 Bill Hails + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + * Structures to support type inference + * + * generated from src/tc.yaml by makeAST.py + */ + +#define TC_OBJTYPES() OBJTYPE_TCENV, \ +OBJTYPE_TCNG, \ +OBJTYPE_TCFUNCTION, \ +OBJTYPE_TCPAIR, \ +OBJTYPE_TCTYPEDEF, \ +OBJTYPE_TCTYPEDEFARGS, \ +OBJTYPE_TCVAR, \ +OBJTYPE_TCTYPE + +#define TC_OBJTYPE_CASES() \ +case OBJTYPE_TCENV:\ +case OBJTYPE_TCNG:\ +case OBJTYPE_TCFUNCTION:\ +case OBJTYPE_TCPAIR:\ +case OBJTYPE_TCTYPEDEF:\ +case OBJTYPE_TCTYPEDEFARGS:\ +case OBJTYPE_TCVAR:\ +case OBJTYPE_TCTYPE:\ + + +void markTcObj(struct Header *h); +void freeTcObj(struct Header *h); +char *typenameTcObj(int type); + +#endif diff --git a/src/tin.c b/src/tin.c deleted file mode 100644 index 8f3c273..0000000 --- a/src/tin.c +++ /dev/null @@ -1,362 +0,0 @@ -/* - * CEKF - VM supporting amb - * Copyright (C) 2022-2023 Bill Hails - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - * - * Type inference structures used by Algorithm W. 
- * - * generated from src/tin.yaml by makeAST.py - */ - -#include "tin.h" -#include -#include -#include "common.h" -#ifdef DEBUG_ALLOC -#include "debugging_on.h" -#else -#include "debugging_off.h" -#endif - -struct TinFunctionApplication * newTinFunctionApplication(HashSymbol * name, int nargs, struct TinMonoTypeList * args) { - struct TinFunctionApplication * x = NEW(TinFunctionApplication, OBJTYPE_TINFUNCTIONAPPLICATION); - DEBUG("new TinFunctionApplication %pn", x); - x->name = name; - x->nargs = nargs; - x->args = args; - return x; -} - -struct TinMonoTypeList * newTinMonoTypeList(struct TinMonoType * monoType, struct TinMonoTypeList * next) { - struct TinMonoTypeList * x = NEW(TinMonoTypeList, OBJTYPE_TINMONOTYPELIST); - DEBUG("new TinMonoTypeList %pn", x); - x->monoType = monoType; - x->next = next; - return x; -} - -struct TinTypeQuantifier * newTinTypeQuantifier(HashSymbol * var, struct TinPolyType * quantifiedType) { - struct TinTypeQuantifier * x = NEW(TinTypeQuantifier, OBJTYPE_TINTYPEQUANTIFIER); - DEBUG("new TinTypeQuantifier %pn", x); - x->var = var; - x->quantifiedType = quantifiedType; - return x; -} - -struct TinContext * newTinContext(HashTable * varFrame, HashTable * tcFrame, struct TinContext * next) { - struct TinContext * x = NEW(TinContext, OBJTYPE_TINCONTEXT); - DEBUG("new TinContext %pn", x); - x->varFrame = varFrame; - x->tcFrame = tcFrame; - x->next = next; - return x; -} - -struct TinSubstitution * newTinSubstitution(HashTable * map) { - struct TinSubstitution * x = NEW(TinSubstitution, OBJTYPE_TINSUBSTITUTION); - DEBUG("new TinSubstitution %pn", x); - x->map = map; - return x; -} - -struct TinArgsResult * newTinArgsResult(struct TinContext * context, struct TinMonoTypeList * vec) { - struct TinArgsResult * x = NEW(TinArgsResult, OBJTYPE_TINARGSRESULT); - DEBUG("new TinArgsResult %pn", x); - x->context = context; - x->vec = vec; - return x; -} - -struct TinVarResult * newTinVarResult(struct TinSubstitution * substitution, struct 
TinContext * context, struct TinMonoType * monoType, HashTable * set) { - struct TinVarResult * x = NEW(TinVarResult, OBJTYPE_TINVARRESULT); - DEBUG("new TinVarResult %pn", x); - x->substitution = substitution; - x->context = context; - x->monoType = monoType; - x->set = set; - return x; -} - -struct TinVarsResult * newTinVarsResult(struct TinContext * context, HashTable * set) { - struct TinVarsResult * x = NEW(TinVarsResult, OBJTYPE_TINVARSRESULT); - DEBUG("new TinVarsResult %pn", x); - x->context = context; - x->set = set; - return x; -} - -struct TinMonoType * newTinMonoType(enum TinMonoTypeType type, union TinMonoTypeVal val) { - struct TinMonoType * x = NEW(TinMonoType, OBJTYPE_TINMONOTYPE); - DEBUG("new TinMonoType %pn", x); - x->type = type; - x->val = val; - return x; -} - -struct TinPolyType * newTinPolyType(enum TinPolyTypeType type, union TinPolyTypeVal val) { - struct TinPolyType * x = NEW(TinPolyType, OBJTYPE_TINPOLYTYPE); - DEBUG("new TinPolyType %pn", x); - x->type = type; - x->val = val; - return x; -} - - - -/************************************/ - -void markTinFunctionApplication(struct TinFunctionApplication * x) { - if (x == NULL) return; - if (MARKED(x)) return; - MARK(x); - markHashSymbol(x->name); - markTinMonoTypeList(x->args); -} - -void markTinMonoTypeList(struct TinMonoTypeList * x) { - if (x == NULL) return; - if (MARKED(x)) return; - MARK(x); - markTinMonoType(x->monoType); - markTinMonoTypeList(x->next); -} - -void markTinTypeQuantifier(struct TinTypeQuantifier * x) { - if (x == NULL) return; - if (MARKED(x)) return; - MARK(x); - markHashSymbol(x->var); - markTinPolyType(x->quantifiedType); -} - -void markTinContext(struct TinContext * x) { - if (x == NULL) return; - if (MARKED(x)) return; - MARK(x); - markHashTable(x->varFrame); - markHashTable(x->tcFrame); - markTinContext(x->next); -} - -void markTinSubstitution(struct TinSubstitution * x) { - if (x == NULL) return; - if (MARKED(x)) return; - MARK(x); - markHashTable(x->map); -} - 
-void markTinArgsResult(struct TinArgsResult * x) { - if (x == NULL) return; - if (MARKED(x)) return; - MARK(x); - markTinContext(x->context); - markTinMonoTypeList(x->vec); -} - -void markTinVarResult(struct TinVarResult * x) { - if (x == NULL) return; - if (MARKED(x)) return; - MARK(x); - markTinSubstitution(x->substitution); - markTinContext(x->context); - markTinMonoType(x->monoType); - markHashTable(x->set); -} - -void markTinVarsResult(struct TinVarsResult * x) { - if (x == NULL) return; - if (MARKED(x)) return; - MARK(x); - markTinContext(x->context); - markHashTable(x->set); -} - -void markTinMonoType(struct TinMonoType * x) { - if (x == NULL) return; - if (MARKED(x)) return; - MARK(x); - switch(x->type) { - case TINMONOTYPE_TYPE_VAR: - markHashSymbol(x->val.var); - break; - case TINMONOTYPE_TYPE_FUN: - markTinFunctionApplication(x->val.fun); - break; - default: - cant_happen("unrecognised type %d in markTinMonoType", x->type); - } -} - -void markTinPolyType(struct TinPolyType * x) { - if (x == NULL) return; - if (MARKED(x)) return; - MARK(x); - switch(x->type) { - case TINPOLYTYPE_TYPE_MONOTYPE: - markTinMonoType(x->val.monoType); - break; - case TINPOLYTYPE_TYPE_QUANTIFIER: - markTinTypeQuantifier(x->val.quantifier); - break; - default: - cant_happen("unrecognised type %d in markTinPolyType", x->type); - } -} - - -void markTinObj(struct Header *h) { - switch(h->type) { - case OBJTYPE_TINFUNCTIONAPPLICATION: - markTinFunctionApplication((TinFunctionApplication *)h); - break; - case OBJTYPE_TINMONOTYPELIST: - markTinMonoTypeList((TinMonoTypeList *)h); - break; - case OBJTYPE_TINTYPEQUANTIFIER: - markTinTypeQuantifier((TinTypeQuantifier *)h); - break; - case OBJTYPE_TINCONTEXT: - markTinContext((TinContext *)h); - break; - case OBJTYPE_TINSUBSTITUTION: - markTinSubstitution((TinSubstitution *)h); - break; - case OBJTYPE_TINARGSRESULT: - markTinArgsResult((TinArgsResult *)h); - break; - case OBJTYPE_TINVARRESULT: - markTinVarResult((TinVarResult *)h); - 
break; - case OBJTYPE_TINVARSRESULT: - markTinVarsResult((TinVarsResult *)h); - break; - case OBJTYPE_TINMONOTYPE: - markTinMonoType((TinMonoType *)h); - break; - case OBJTYPE_TINPOLYTYPE: - markTinPolyType((TinPolyType *)h); - break; - default: - cant_happen("unrecognised type %d in markTinObj\n", h->type); - } -} - -/************************************/ - -void freeTinFunctionApplication(struct TinFunctionApplication * x) { - FREE(x, TinFunctionApplication); -} - -void freeTinMonoTypeList(struct TinMonoTypeList * x) { - FREE(x, TinMonoTypeList); -} - -void freeTinTypeQuantifier(struct TinTypeQuantifier * x) { - FREE(x, TinTypeQuantifier); -} - -void freeTinContext(struct TinContext * x) { - FREE(x, TinContext); -} - -void freeTinSubstitution(struct TinSubstitution * x) { - FREE(x, TinSubstitution); -} - -void freeTinArgsResult(struct TinArgsResult * x) { - FREE(x, TinArgsResult); -} - -void freeTinVarResult(struct TinVarResult * x) { - FREE(x, TinVarResult); -} - -void freeTinVarsResult(struct TinVarsResult * x) { - FREE(x, TinVarsResult); -} - -void freeTinMonoType(struct TinMonoType * x) { - FREE(x, TinMonoType); -} - -void freeTinPolyType(struct TinPolyType * x) { - FREE(x, TinPolyType); -} - - -void freeTinObj(struct Header *h) { - switch(h->type) { - case OBJTYPE_TINFUNCTIONAPPLICATION: - freeTinFunctionApplication((TinFunctionApplication *)h); - break; - case OBJTYPE_TINMONOTYPELIST: - freeTinMonoTypeList((TinMonoTypeList *)h); - break; - case OBJTYPE_TINTYPEQUANTIFIER: - freeTinTypeQuantifier((TinTypeQuantifier *)h); - break; - case OBJTYPE_TINCONTEXT: - freeTinContext((TinContext *)h); - break; - case OBJTYPE_TINSUBSTITUTION: - freeTinSubstitution((TinSubstitution *)h); - break; - case OBJTYPE_TINARGSRESULT: - freeTinArgsResult((TinArgsResult *)h); - break; - case OBJTYPE_TINVARRESULT: - freeTinVarResult((TinVarResult *)h); - break; - case OBJTYPE_TINVARSRESULT: - freeTinVarsResult((TinVarsResult *)h); - break; - case OBJTYPE_TINMONOTYPE: - 
freeTinMonoType((TinMonoType *)h); - break; - case OBJTYPE_TINPOLYTYPE: - freeTinPolyType((TinPolyType *)h); - break; - default: - cant_happen("unrecognised type %d in freeTinObj\n", h->type); - } -} - -char *typenameTinObj(int type) { - switch(type) { - case OBJTYPE_TINFUNCTIONAPPLICATION: - return "TinFunctionApplication"; - case OBJTYPE_TINMONOTYPELIST: - return "TinMonoTypeList"; - case OBJTYPE_TINTYPEQUANTIFIER: - return "TinTypeQuantifier"; - case OBJTYPE_TINCONTEXT: - return "TinContext"; - case OBJTYPE_TINSUBSTITUTION: - return "TinSubstitution"; - case OBJTYPE_TINARGSRESULT: - return "TinArgsResult"; - case OBJTYPE_TINVARRESULT: - return "TinVarResult"; - case OBJTYPE_TINVARSRESULT: - return "TinVarsResult"; - case OBJTYPE_TINMONOTYPE: - return "TinMonoType"; - case OBJTYPE_TINPOLYTYPE: - return "TinPolyType"; - default: - cant_happen("unrecognised type %d in typenameTinObj\n", type); - } -} - diff --git a/src/tin.h b/src/tin.h deleted file mode 100644 index 1fb013f..0000000 --- a/src/tin.h +++ /dev/null @@ -1,158 +0,0 @@ -#ifndef cekf_tin_h -#define cekf_tin_h -/* - * CEKF - VM supporting amb - * Copyright (C) 2022-2023 Bill Hails - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - * - * Type inference structures used by Algorithm W. 
- * - * generated from src/tin.yaml by makeAST.py - */ - -#include "hash.h" -#include "memory.h" -#include "common.h" - -typedef enum TinMonoTypeType { - TINMONOTYPE_TYPE_VAR, - TINMONOTYPE_TYPE_FUN, -} TinMonoTypeType; - -typedef enum TinPolyTypeType { - TINPOLYTYPE_TYPE_MONOTYPE, - TINPOLYTYPE_TYPE_QUANTIFIER, -} TinPolyTypeType; - - - -typedef union TinMonoTypeVal { - HashSymbol * var; - struct TinFunctionApplication * fun; -} TinMonoTypeVal; - -typedef union TinPolyTypeVal { - struct TinMonoType * monoType; - struct TinTypeQuantifier * quantifier; -} TinPolyTypeVal; - - - -typedef struct TinFunctionApplication { - Header header; - HashSymbol * name; - int nargs; - struct TinMonoTypeList * args; -} TinFunctionApplication; - -typedef struct TinMonoTypeList { - Header header; - struct TinMonoType * monoType; - struct TinMonoTypeList * next; -} TinMonoTypeList; - -typedef struct TinTypeQuantifier { - Header header; - HashSymbol * var; - struct TinPolyType * quantifiedType; -} TinTypeQuantifier; - -typedef struct TinContext { - Header header; - HashTable * varFrame; - HashTable * tcFrame; - struct TinContext * next; -} TinContext; - -typedef struct TinSubstitution { - Header header; - HashTable * map; -} TinSubstitution; - -typedef struct TinArgsResult { - Header header; - struct TinContext * context; - struct TinMonoTypeList * vec; -} TinArgsResult; - -typedef struct TinVarResult { - Header header; - struct TinSubstitution * substitution; - struct TinContext * context; - struct TinMonoType * monoType; - HashTable * set; -} TinVarResult; - -typedef struct TinVarsResult { - Header header; - struct TinContext * context; - HashTable * set; -} TinVarsResult; - -typedef struct TinMonoType { - Header header; - enum TinMonoTypeType type; - union TinMonoTypeVal val; -} TinMonoType; - -typedef struct TinPolyType { - Header header; - enum TinPolyTypeType type; - union TinPolyTypeVal val; -} TinPolyType; - - - -struct TinFunctionApplication * 
newTinFunctionApplication(HashSymbol * name, int nargs, struct TinMonoTypeList * args); -struct TinMonoTypeList * newTinMonoTypeList(struct TinMonoType * monoType, struct TinMonoTypeList * next); -struct TinTypeQuantifier * newTinTypeQuantifier(HashSymbol * var, struct TinPolyType * quantifiedType); -struct TinContext * newTinContext(HashTable * varFrame, HashTable * tcFrame, struct TinContext * next); -struct TinSubstitution * newTinSubstitution(HashTable * map); -struct TinArgsResult * newTinArgsResult(struct TinContext * context, struct TinMonoTypeList * vec); -struct TinVarResult * newTinVarResult(struct TinSubstitution * substitution, struct TinContext * context, struct TinMonoType * monoType, HashTable * set); -struct TinVarsResult * newTinVarsResult(struct TinContext * context, HashTable * set); -struct TinMonoType * newTinMonoType(enum TinMonoTypeType type, union TinMonoTypeVal val); -struct TinPolyType * newTinPolyType(enum TinPolyTypeType type, union TinPolyTypeVal val); - -void markTinFunctionApplication(struct TinFunctionApplication * x); -void markTinMonoTypeList(struct TinMonoTypeList * x); -void markTinTypeQuantifier(struct TinTypeQuantifier * x); -void markTinContext(struct TinContext * x); -void markTinSubstitution(struct TinSubstitution * x); -void markTinArgsResult(struct TinArgsResult * x); -void markTinVarResult(struct TinVarResult * x); -void markTinVarsResult(struct TinVarsResult * x); -void markTinMonoType(struct TinMonoType * x); -void markTinPolyType(struct TinPolyType * x); - -void freeTinFunctionApplication(struct TinFunctionApplication * x); -void freeTinMonoTypeList(struct TinMonoTypeList * x); -void freeTinTypeQuantifier(struct TinTypeQuantifier * x); -void freeTinContext(struct TinContext * x); -void freeTinSubstitution(struct TinSubstitution * x); -void freeTinArgsResult(struct TinArgsResult * x); -void freeTinVarResult(struct TinVarResult * x); -void freeTinVarsResult(struct TinVarsResult * x); -void freeTinMonoType(struct 
TinMonoType * x); -void freeTinPolyType(struct TinPolyType * x); - - -#define TINMONOTYPE_VAL_VAR(x) ((union TinMonoTypeVal ){.var = (x)}) -#define TINMONOTYPE_VAL_FUN(x) ((union TinMonoTypeVal ){.fun = (x)}) -#define TINPOLYTYPE_VAL_MONOTYPE(x) ((union TinPolyTypeVal ){.monoType = (x)}) -#define TINPOLYTYPE_VAL_QUANTIFIER(x) ((union TinPolyTypeVal ){.quantifier = (x)}) - - -#endif diff --git a/src/tin.yaml b/src/tin.yaml deleted file mode 100644 index a4829eb..0000000 --- a/src/tin.yaml +++ /dev/null @@ -1,87 +0,0 @@ -# -# CEKF - VM supporting amb -# Copyright (C) 2022-2023 Bill Hails -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . -# - -config: - name: tin - -description: Type inference structures used by Algorithm W. 
- -structs: - TinFunctionApplication: - name: HashSymbol - nargs: int - args: TinMonoTypeList - - TinMonoTypeList: - monoType: TinMonoType - next: TinMonoTypeList - - TinTypeQuantifier: - var: HashSymbol - quantifiedType: TinPolyType - - TinContext: - varFrame: HashTable - tcFrame: HashTable - next: TinContext - - TinSubstitution: - map: HashTable - - TinArgsResult: - context: TinContext - vec: TinMonoTypeList - - TinVarResult: - substitution: TinSubstitution - context: TinContext - monoType: TinMonoType - set: HashTable - - TinVarsResult: - context: TinContext - set: HashTable - -unions: - TinMonoType: - var: HashSymbol - fun: TinFunctionApplication - - TinPolyType: - monoType: TinMonoType - quantifier: TinTypeQuantifier - -enums: {} - -primitives: - HashSymbol: - cname: "HashSymbol *" - printFn: "printTinSymbol" - markFn: "markHashSymbol" - valued: true - - HashTable: - cname: "HashTable *" - printFn: "printHashTable" - markFn: "markHashTable" - valued: true - - int: - cname: "int" - printf: "%d" - valued: true diff --git a/src/tin_helper.c b/src/tin_helper.c deleted file mode 100644 index 551828a..0000000 --- a/src/tin_helper.c +++ /dev/null @@ -1,781 +0,0 @@ -/* - * CEKF - VM supporting amb - * Copyright (C) 2022-2023 Bill Hails - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . 
- */ - -#include -#include -#include - -#include "common.h" -#include "tin_helper.h" -#include "debug_tin.h" -#include "symbol.h" - -static TinFunctionApplication *applyFunSubstitution(TinSubstitution *s, TinFunctionApplication *funApp); -static TinMonoType *applyVarSubstitution(TinSubstitution *s, TinMonoType *mtype); -static TinTypeQuantifier *applyQuantifierSubstitution(TinSubstitution *s, TinTypeQuantifier *tq); - -static TinMonoType *instantiateMonoType(TinMonoType *tmt, HashTable *map); -static TinMonoType *instantiateQuantifier(TinTypeQuantifier *ttq, HashTable *map); -static TinMonoType *instantiatePolyType(TinPolyType *tpt, HashTable *map); - -static void findMonoTypeVariables(HashTable *map, TinMonoType *monoType); - -static bool monoTypeContains(HashSymbol *var, TinMonoType *tmt); - -static void showTinSymbol(HashSymbol *symbol); -static void showTinTypeQuantifier(TinTypeQuantifier *quantifier); -static void showTinFun(TinFunctionApplication *fun); -static void showTinMonoTypeList(TinMonoTypeList *list); - -void printTinSymbol(struct HashSymbol * x, int depth) { - eprintf("%*s", depth * 4, ""); - if (x == NULL) { eprintf("TinSymbol (NULL)"); return; } - eprintf("TinSymbol[\"%s\"]", x->name); -} - -static void printSubstitutionFn(void *ptr, int depth) { - eprintf("%*s", depth * 4, ""); - showTinMonoType(*(TinMonoType **)ptr); -} - -static void markSubstitutionFn(void *ptr) { - markTinMonoType(*(TinMonoType **)ptr); -} - -static void printContextFn(void *ptr, int depth) { - eprintf("%*s", depth * 4, ""); - showTinPolyType(*(TinPolyType **)ptr); -} - -static void markContextFn(void *ptr) { - markTinPolyType(*(TinPolyType **)ptr); -} - -static void markVarFn(void *ptr) { - markHashSymbol(*(HashSymbol **) ptr); -} - -static void printVarFn(void *ptr, int depth __attribute__((unused))) { - printHashSymbol(*(HashSymbol **) ptr); -} - -static HashTable *newSubstitutionTable() { - HashTable *h = newHashTable( - sizeof(TinMonoType*), - markSubstitutionFn, - 
printSubstitutionFn - ); - h->shortEntries = true; - return h; -} - -static HashTable *newContextTcTable() { - HashTable *h = newHashTable(0, NULL, NULL); - h->shortEntries = true; - return h; -} - -static HashTable *newContextVarTable() { - HashTable *h = newHashTable( - sizeof(TinPolyType*), - markContextFn, - printContextFn - ); - h->shortEntries = true; - return h; -} - -static HashTable *newVarTable() { - HashTable *h = newHashTable( - sizeof(HashSymbol *), - markVarFn, - printVarFn - ); - h->shortEntries = true; - return h; -} - -TinContext *freshTinContext() { - HashTable *vars = newContextVarTable(); - int save = PROTECT(vars); - HashTable *constructors = newContextTcTable(); - PROTECT(constructors); - TinContext *c = newTinContext(vars, constructors, NULL); - UNPROTECT(save); - return c; -} - -static HashTable *newFreeVariableTable() { - return newHashTable(0, NULL, NULL); -} - -void addToSubstitution(TinSubstitution *substitution, HashSymbol *symbol, TinMonoType *monoType) { - hashSet(substitution->map, symbol, &monoType); -} - -void addVarToContext(TinContext *context, HashSymbol *symbol, TinPolyType *polyType) { - hashSet(context->varFrame, symbol, &polyType); -} - -void addConstructorToContext(TinContext *context, HashSymbol *symbol, TinPolyType *polyType) { - hashSet(context->varFrame, symbol, &polyType); - hashSet(context->tcFrame, symbol, NULL); -} - -TinPolyType *lookupInContext(TinContext *context, HashSymbol *var) { - if (context == NULL) return NULL; - TinPolyType *result; - if (hashGet(context->varFrame, var, &result)) { - return result; - } - return lookupInContext(context->next, var); -} - -bool isTypeConstructor(TinContext *context, HashSymbol *var) { - if (context == NULL) return false; - if (hashGet(context->tcFrame, var, NULL)) { - return true; - } - return isTypeConstructor(context->next, var); -} - -static TinMonoType *lookupInSubstitution(TinSubstitution *substitution, HashSymbol *var) { - TinMonoType *result; - if 
(hashGet(substitution->map, var, &result)) return result; - return NULL; -} - -static HashSymbol *lookupInMap(HashTable *map, HashSymbol *symbol) { - HashSymbol *result; - if (hashGet(map, symbol, &result)) return result; - return NULL; -} - -static void setInMap(HashTable *map, HashSymbol *key, HashSymbol *value) { - hashSet(map, key, &value); -} - -static void addFreeVariable(HashTable *map, HashSymbol *key) { - hashSet(map, key, NULL); -} - -static TinContext *copyContext(TinContext *source) { - if (source == NULL) { - return freshTinContext(); - } else { - TinContext *context = copyContext(source->next); - int save = PROTECT(context); - copyHashTable(context->varFrame, source->varFrame); - copyHashTable(context->tcFrame, source->tcFrame); - UNPROTECT(save); - return context; - } -} - -TinContext *extendTinContext(TinContext *parent) { - HashTable *varFrame = newContextVarTable(); - int save = PROTECT(varFrame); - HashTable *ctFrame = newContextTcTable(); - PROTECT(ctFrame); - TinContext *context = newTinContext(varFrame, ctFrame, parent); - UNPROTECT(save); - return context; -} - -static TinSubstitution *copySubstitution(TinSubstitution *source) { - TinSubstitution *ts = makeEmptySubstitution(); - int save = PROTECT(ts); - copyHashTable(ts->map, source->map); - UNPROTECT(save); - return ts; -} - -static TinMonoTypeList *applyArgsSubstitution(TinSubstitution *s, TinMonoTypeList *args) { -#ifdef DEBUG_TIN_SUBSTITUTION - eprintf("applyArgsSubstitution "); - showTinSubstitution(s); - eprintf("\n"); - showTinMonoTypeList(args); - eprintf("\n"); -#ifdef DEBUG_TIN_SUBSTITUTION_SLOWLY - sleep(1); -#endif -#endif - if (args == NULL) return NULL; - TinMonoTypeList *soFar = applyArgsSubstitution(s, args->next); - int save = PROTECT(soFar); - TinMonoType *t = applyMonoTypeSubstitution(s, args->monoType); - if (soFar == args->next && t == args->monoType) { - UNPROTECT(save); - return args; - } - PROTECT(t); - TinMonoTypeList *current = newTinMonoTypeList(t, soFar); - 
UNPROTECT(save); - return current; -} - -static TinFunctionApplication *applyFunSubstitution(TinSubstitution *s, TinFunctionApplication *funApp) { -#ifdef DEBUG_TIN_SUBSTITUTION - eprintf("applyFunSubstitution "); - showTinFun(funApp); - eprintf("\n"); -#ifdef DEBUG_TIN_SUBSTITUTION_SLOWLY - sleep(1); -#endif -#endif - TinMonoTypeList *args = applyArgsSubstitution(s, funApp->args); - int save = PROTECT(args); - if (args == funApp->args) return funApp; - TinFunctionApplication *result = newTinFunctionApplication(funApp->name, funApp->nargs, args); - UNPROTECT(save); - return result; -} - -static TinMonoType *applyVarSubstitution(TinSubstitution *s, TinMonoType *mtype) { -#ifdef DEBUG_TIN_SUBSTITUTION - eprintf("applyVarSubstitution "); - showTinMonoType(mtype); - eprintf("\n"); -#ifdef DEBUG_TIN_SUBSTITUTION_SLOWLY - sleep(1); -#endif -#endif - TinMonoType *replacement = lookupInSubstitution(s, mtype->val.var); - if (replacement != NULL) return replacement; - return mtype; -} - -TinMonoType *applyMonoTypeSubstitution(TinSubstitution *s, TinMonoType *mtype) { -#ifdef DEBUG_TIN_SUBSTITUTION - eprintf("applyMonoTypeSubstitution "); - showTinMonoType(mtype); - eprintf("\n"); -#ifdef DEBUG_TIN_SUBSTITUTION_SLOWLY - sleep(1); -#endif -#endif - switch (mtype->type) { - case TINMONOTYPE_TYPE_VAR: - return applyVarSubstitution(s, mtype); - case TINMONOTYPE_TYPE_FUN: { - TinFunctionApplication *tfa = applyFunSubstitution(s, mtype->val.fun); - if (tfa == mtype->val.fun) return mtype; - int save = PROTECT(tfa); - TinMonoType *tmt = newTinMonoType(TINMONOTYPE_TYPE_FUN, TINMONOTYPE_VAL_FUN(tfa)); - UNPROTECT(save); - return tmt; - } - default: - cant_happen("unrecognised type %d in applyMonoTypeSubstitution", mtype->type); - } -} - -static TinTypeQuantifier *applyQuantifierSubstitution(TinSubstitution *s, TinTypeQuantifier *tq) { -#ifdef DEBUG_TIN_SUBSTITUTION - eprintf("applyQuantifierSubstitution "); - showTinTypeQuantifier(tq); - eprintf("\n"); -#ifdef 
DEBUG_TIN_SUBSTITUTION_SLOWLY - sleep(1); -#endif -#endif - TinPolyType *pt = applyPolyTypeSubstitution(s, tq->quantifiedType); - if (pt == tq->quantifiedType) return tq; - int save = PROTECT(pt); - TinTypeQuantifier *result = newTinTypeQuantifier(tq->var, pt); - UNPROTECT(save); - return result; -} - -TinPolyType *applyPolyTypeSubstitution(TinSubstitution *s, TinPolyType *ptype) { -#ifdef DEBUG_TIN_SUBSTITUTION - eprintf("applyPolyTypeSubstitution "); - showTinPolyType(ptype); - eprintf("\n"); -#ifdef DEBUG_TIN_SUBSTITUTION_SLOWLY - sleep(1); -#endif -#endif - switch (ptype->type) { - case TINPOLYTYPE_TYPE_MONOTYPE: { - TinMonoType *tmt = applyMonoTypeSubstitution(s, ptype->val.monoType); - if (tmt == ptype->val.monoType) return ptype; - int save = PROTECT(tmt); - TinPolyType *pt = newTinPolyType(TINPOLYTYPE_TYPE_MONOTYPE, TINPOLYTYPE_VAL_MONOTYPE(tmt)); - UNPROTECT(save); - return pt; - } - case TINPOLYTYPE_TYPE_QUANTIFIER: { - TinTypeQuantifier *ttq = applyQuantifierSubstitution(s, ptype->val.quantifier); - if (ttq == ptype->val.quantifier) return ptype; - int save = PROTECT(ttq); - TinPolyType *pt = newTinPolyType(TINPOLYTYPE_TYPE_QUANTIFIER, TINPOLYTYPE_VAL_QUANTIFIER(ttq)); - UNPROTECT(save); - return pt; - } - default: - cant_happen("unrecognised type %d in applyPolyTypeSubstitution", ptype->type); - } -} - -void applyContextSubstitutionInPlace(TinSubstitution *s, TinContext *context) { -#ifdef DEBUG_TIN_SUBSTITUTION - eprintf("applyContextSubstitutionInPlace "); - showTinContext(context); - eprintf("\n"); -#ifdef DEBUG_TIN_SUBSTITUTION_SLOWLY - sleep(1); -#endif -#endif - int i = 0; - TinPolyType *t = NULL; - HashSymbol *key; - while ((key = iterateHashTable(context->varFrame, &i, &t)) != NULL) { - TinPolyType *ts = applyPolyTypeSubstitution(s, t); - addVarToContext(context, key, ts); // we know it's already in the context so this won't change the table - } -} - -TinContext *applyContextSubstitution(TinSubstitution *s, TinContext *context) { -#ifdef 
DEBUG_TIN_SUBSTITUTION - eprintf("applyContextSubstitution "); - showTinContext(context); - eprintf("\n"); -#ifdef DEBUG_TIN_SUBSTITUTION_SLOWLY - sleep(1); -#endif -#endif - TinContext *result = copyContext(context); - int save = PROTECT(result); - int i = 0; - TinPolyType *t = NULL; - HashSymbol *key; - while ((key = iterateHashTable(result->varFrame, &i, &t)) != NULL) { - TinPolyType *ts = applyPolyTypeSubstitution(s, t); - addVarToContext(result, key, ts); // we know it's already in the context so this won't change the table - } - UNPROTECT(save); - return result; -} - -TinSubstitution *applySubstitutionSubstitution(TinSubstitution *s1, TinSubstitution *s2) { -#ifdef DEBUG_TIN_SUBSTITUTION - eprintf("applySubstitutionSubstitution "); - showTinSubstitution(s2); - eprintf("\n"); -#ifdef DEBUG_TIN_SUBSTITUTION_SLOWLY - sleep(1); -#endif -#endif - TinSubstitution *ts = copySubstitution(s1); - int save = PROTECT(ts); - int i = 0; - TinMonoType *tmt1 = NULL; - HashSymbol *key; - while ((key = iterateHashTable(s2->map, &i, &tmt1)) != NULL) { - TinMonoType *tmt2 = applyMonoTypeSubstitution(s1, tmt1); - int save2 = PROTECT(tmt2); - addToSubstitution(ts, key, tmt2); - UNPROTECT(save2); - } - UNPROTECT(save); - return ts; -} - -HashSymbol *freshTypeVariable(const char *suffix) { - char buf[128]; - sprintf(buf, "#%s", suffix); - return genSym(buf); -} - -static HashSymbol *instantiateVar(HashSymbol *symbol, HashTable *map) { -#ifdef DEBUG_TIN_INSTANTIATION - eprintf("instantiateVar "); - printHashSymbol(symbol); - eprintf("\n"); -#ifdef DEBUG_TIN_SUBSTITUTION_SLOWLY - sleep(1); -#endif -#endif - HashSymbol *replacement = lookupInMap(map, symbol); - if (replacement != NULL) return replacement; - return symbol; -} - -static TinMonoTypeList *instantiateArgs(TinMonoTypeList *args, HashTable *map) { -#ifdef DEBUG_TIN_INSTANTIATION - eprintf("instantiateArgs "); - showTinMonoTypeList(args); - eprintf("\n"); -#ifdef DEBUG_TIN_SUBSTITUTION_SLOWLY - sleep(1); -#endif -#endif - if 
(args == NULL) return NULL; - TinMonoTypeList *next = instantiateArgs(args->next, map); - int save = PROTECT(next); - TinMonoType *monoType = instantiateMonoType(args->monoType, map); - if (next == args->next && monoType == args->monoType) { - UNPROTECT(save); - return args; - } - PROTECT(monoType); - TinMonoTypeList *newArgs = newTinMonoTypeList(monoType, next); - UNPROTECT(save); - return newArgs; - -} - -static TinFunctionApplication *instantiateFun(TinFunctionApplication *tfa, HashTable *map) { -#ifdef DEBUG_TIN_INSTANTIATION - eprintf("instantiateFun "); - showTinFun(tfa); - eprintf("\n"); -#ifdef DEBUG_TIN_SUBSTITUTION_SLOWLY - sleep(1); -#endif -#endif - TinMonoTypeList *args = instantiateArgs(tfa->args, map); - if (args == tfa->args) return tfa; - int save = PROTECT(tfa); - PROTECT(args); - TinFunctionApplication *newTfa = newTinFunctionApplication(tfa->name, tfa->nargs, args); - UNPROTECT(save); - return newTfa; -} - -static TinMonoType *instantiateMonoType(TinMonoType *tmt, HashTable *map) { -#ifdef DEBUG_TIN_INSTANTIATION - eprintf("instantiateMonoType "); - showTinMonoType(tmt); - eprintf("\n"); -#ifdef DEBUG_TIN_SUBSTITUTION_SLOWLY - sleep(1); -#endif -#endif - switch (tmt->type) { - case TINMONOTYPE_TYPE_VAR: { - HashSymbol *var = instantiateVar(tmt->val.var, map); - if (tmt->val.var == var) return tmt; - int save = PROTECT(var); - TinMonoType *newTmt = newTinMonoType( - TINMONOTYPE_TYPE_VAR, - TINMONOTYPE_VAL_VAR(var) - ); - UNPROTECT(save); - return newTmt; - } - case TINMONOTYPE_TYPE_FUN: { - TinFunctionApplication *tfa = instantiateFun(tmt->val.fun, map); - if (tfa == tmt->val.fun) return tmt; - int save = PROTECT(tfa); - TinMonoType *newTmt = newTinMonoType( - TINMONOTYPE_TYPE_FUN, - TINMONOTYPE_VAL_FUN(tfa) - ); - UNPROTECT(save); - return newTmt; - } - default: - cant_happen("unrecognised type %d in instantiateMonoType", tmt->type); - } -} - -static TinMonoType *instantiateQuantifier(TinTypeQuantifier *ttq, HashTable *map) { -#ifdef 
DEBUG_TIN_INSTANTIATION - eprintf("instantiateQuantifier "); - showTinTypeQuantifier(ttq); - eprintf("\n"); -#ifdef DEBUG_TIN_SUBSTITUTION_SLOWLY - sleep(1); -#endif -#endif - HashSymbol *newVar = freshTypeVariable("instantiate"); - int save = PROTECT(newVar); - setInMap(map, ttq->var, newVar); - UNPROTECT(save); - return instantiatePolyType(ttq->quantifiedType, map); -} - - -static TinMonoType *instantiatePolyType(TinPolyType *tpt, HashTable *map) { -#ifdef DEBUG_TIN_INSTANTIATION - eprintf("instantiatePolyType "); - showTinPolyType(tpt); - eprintf("\n"); -#ifdef DEBUG_TIN_SUBSTITUTION_SLOWLY - sleep(1); -#endif -#endif - switch (tpt->type) { - case TINPOLYTYPE_TYPE_MONOTYPE: - return instantiateMonoType(tpt->val.monoType, map); - case TINPOLYTYPE_TYPE_QUANTIFIER: - return instantiateQuantifier(tpt->val.quantifier, map); - default: - cant_happen("unrecognised type %d in instantiatePolyType", tpt->type); - } -} - -TinMonoType *instantiate(TinPolyType *tpt) { - HashTable *map = newVarTable(); - int save = PROTECT(map); - TinMonoType *result = instantiatePolyType(tpt, map); - validateLastAlloc(); - UNPROTECT(save); - return result; -} - -static void findArgVariables(HashTable *map, TinMonoTypeList *args) { - while (args != NULL) { - findMonoTypeVariables(map, args->monoType); - args = args->next; - } -} - -static void findFunVariables(HashTable *map, TinFunctionApplication *app) { - findArgVariables(map, app->args); -} - -static void findMonoTypeVariables(HashTable *map, TinMonoType *monoType) { - switch (monoType->type) { - case TINMONOTYPE_TYPE_VAR: - addFreeVariable(map, monoType->val.var); - break; - case TINMONOTYPE_TYPE_FUN: - findFunVariables(map, monoType->val.fun); - break; - default: - cant_happen("unrecognised type %d in findMonoTypeVariables", monoType->type); - } -} - -TinPolyType *generalize(TinContext *context, TinMonoType *monoType) { - HashTable *map = newFreeVariableTable(); - int save = PROTECT(map); - findMonoTypeVariables(map, monoType); - 
TinPolyType *tpt = newTinPolyType( - TINPOLYTYPE_TYPE_MONOTYPE, - TINPOLYTYPE_VAL_MONOTYPE(monoType) - ); - HashSymbol *var; - TinTypeQuantifier *tqt = NULL; - int i = 0; - while ((var = iterateHashTable(map, &i, NULL)) != NULL) { - if (lookupInContext(context, var) == NULL) { - int save2 = PROTECT(tpt); - tqt = newTinTypeQuantifier(var, tpt); - PROTECT(tqt); - tpt = newTinPolyType( - TINPOLYTYPE_TYPE_QUANTIFIER, - TINPOLYTYPE_VAL_QUANTIFIER(tqt) - ); - UNPROTECT(save2); - } - } - UNPROTECT(save); - return tpt; -} - -TinSubstitution *makeEmptySubstitution() { - HashTable *h = newSubstitutionTable(); - int save = PROTECT(h); - TinSubstitution *s = newTinSubstitution(h); - UNPROTECT(save); - return s; -} - -static bool argsContains(HashSymbol *var, TinMonoTypeList *args) { - while (args != NULL) { - if (monoTypeContains(var, args->monoType)) return true; - args = args->next; - } - return false; -} - -static bool funContains(HashSymbol *var, TinFunctionApplication *tfa) { - return argsContains(var, tfa->args); -} - -static bool monoTypeContains(HashSymbol *var, TinMonoType *tmt) { - switch (tmt->type) { - case TINMONOTYPE_TYPE_VAR: - return var == tmt->val.var; - case TINMONOTYPE_TYPE_FUN: - return funContains(var, tmt->val.fun); - default: - cant_happen("unrecognised type %d in monoTypeContains", tmt->type); - } -} - -TinSubstitution *unify(TinMonoType *t1, TinMonoType *t2, const char *caller) { -#ifdef DEBUG_TIN_UNIFICATION - eprintf("%s unify ", caller); - showTinMonoType(t1); - eprintf(" with "); - showTinMonoType(t2); - eprintf("\n"); -#endif - switch (t1->type) { - case TINMONOTYPE_TYPE_VAR: { - TinSubstitution *s = makeEmptySubstitution(); - HashSymbol *var = t1->val.var; - if (t2->type == TINMONOTYPE_TYPE_VAR && var == t2->val.var) { - return s; - } - if (monoTypeContains(var, t2)) { - can_happen("occurs check failed: %s", t1->val.var->name); - return s; - } - int save = PROTECT(s); - addToSubstitution(s, var, t2); - UNPROTECT(save); - return s; - } - case 
TINMONOTYPE_TYPE_FUN: { - switch (t2->type) { - case TINMONOTYPE_TYPE_VAR: - return unify(t2, t1, caller); - case TINMONOTYPE_TYPE_FUN: { - TinSubstitution *s = makeEmptySubstitution(); - TinFunctionApplication *f1 = t1->val.fun; - TinFunctionApplication *f2 = t2->val.fun; - if (f1->name != f2->name) { - can_happen("%s cannot unify %s with %s", caller, f1->name->name, f2->name->name); - return s; - } - if (f1->nargs != f2->nargs) { - cant_happen("%s different argument counts for type function %s: %d vs. %d", caller, f1->name->name, f1->nargs, f2->nargs); - } - int save = PROTECT(s); - TinMonoTypeList *a1 = f1->args; - TinMonoTypeList *a2 = f2->args; - while (a1 != NULL) { - if (a2 == NULL) { - cant_happen("%s argument lengths underrun in unify %s", caller, f1->name->name); - } - TinMonoType *tmt1 = applyMonoTypeSubstitution(s, a1->monoType); - PROTECT(tmt1); - TinMonoType *tmt2 = applyMonoTypeSubstitution(s, a2->monoType); - PROTECT(tmt2); - TinSubstitution *s2 = unify(tmt1, tmt2, caller); - PROTECT(s2); - s = applySubstitutionSubstitution(s, s2); - UNPROTECT(save); - save = PROTECT(s); - a1 = a1->next; - a2 = a2->next; - } - if (a2 != NULL) { - cant_happen("%s, argument lengths overrun in unify %s", caller, f1->name->name); - } - UNPROTECT(save); - return s; - } - default: - cant_happen("%s unrecognised t2 %d in unify", caller, t2->type); - } - } - default: - cant_happen("%s unrecognised t1 %d in unify", caller, t1->type); - } -} - - -static bool isAlphaSymbol(HashSymbol *symbol) { - if (symbol == NULL) return true; - return (bool) isalpha(symbol->name[0]); -} - -static void showTinMonoTypeList(TinMonoTypeList *list) { - eprintf("("); - for (TinMonoTypeList *args = list; args != NULL; args = args->next) { - showTinMonoType(args->monoType); - if (args->next != NULL) eprintf(", "); - } - eprintf(")"); -} - -static void showTinFun(TinFunctionApplication *fun) { - if (fun == NULL) { eprintf(""); return; } - if (isAlphaSymbol(fun->name) || fun->nargs != 2) { - 
showTinSymbol(fun->name); - showTinMonoTypeList(fun->args); - } else { - eprintf("("); - for (TinMonoTypeList *args = fun->args; args != NULL; args = args->next) { - showTinMonoType(args->monoType); - if (args->next != NULL) { - eprintf(" "); - showTinSymbol(fun->name); - eprintf(" "); - } - } - eprintf(")"); - } -} - -static void showTinSymbol(HashSymbol *symbol) { - if (symbol == NULL) { eprintf(""); return; } - eprintf("%s", symbol->name); -} - -static void showTinTypeQuantifier(TinTypeQuantifier *quantifier) { - if (quantifier == NULL) { eprintf(""); return; } - eprintf("V"); - showTinSymbol(quantifier->var); - eprintf("."); - showTinPolyType(quantifier->quantifiedType); -} - -void showTinMonoType(TinMonoType *monoType) { - if (monoType == NULL) { eprintf(""); return; } - switch (monoType->type) { - case TINMONOTYPE_TYPE_VAR: - showTinSymbol(monoType->val.var); - break; - case TINMONOTYPE_TYPE_FUN: - showTinFun(monoType->val.fun); - break; - default: - cant_happen("unrecognised type %d in showTinMonoType", monoType->type); - } -} - -void showTinPolyType(TinPolyType *polyType) { - if (polyType == NULL) { eprintf(""); return; } - switch (polyType->type) { - case TINPOLYTYPE_TYPE_MONOTYPE: - showTinMonoType(polyType->val.monoType); - break; - case TINPOLYTYPE_TYPE_QUANTIFIER: - showTinTypeQuantifier(polyType->val.quantifier); - break; - default: - cant_happen("unrecognised type %d in showTinPolyType", polyType->type); - } -} - -#ifdef DEBUG_RUN_TESTS -#if DEBUG_RUN_TESTS == 3 - -#include "tests/tin.inc" - -#endif -#endif diff --git a/src/tin_helper.h b/src/tin_helper.h deleted file mode 100644 index 60b6f10..0000000 --- a/src/tin_helper.h +++ /dev/null @@ -1,51 +0,0 @@ -#ifndef cekf_tin_helper_h -#define cekf_tin_helper_h -/* - * CEKF - VM supporting amb - * Copyright (C) 2022-2023 Bill Hails - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software 
Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ - -#include "tin.h" -#include "hash.h" -#include "memory.h" - -void markTinSymbolTable(void); - -void showTinMonoType(TinMonoType *monoType); -void showTinPolyType(TinPolyType *polyType); - -void printTinSymbol(HashSymbol *x, int depth); -TinContext *freshTinContext(void); - -void addToSubstitution(TinSubstitution *substitution, HashSymbol *symbol, TinMonoType *monotype); -TinContext *extendTinContext(TinContext *parent); -HashSymbol *freshTypeVariable(const char *suffix); -void addVarToContext(TinContext *context, HashSymbol *symbol, TinPolyType *polyType); -void addConstructorToContext(TinContext *context, HashSymbol *symbol, TinPolyType *polyType); -TinSubstitution *unify(TinMonoType *t1, TinMonoType *t2, const char *caller); -TinMonoType *applyMonoTypeSubstitution(TinSubstitution *s, TinMonoType *mtype); - -TinPolyType *applyPolyTypeSubstitution(TinSubstitution *s, TinPolyType *ptype); -TinPolyType *lookupInContext(TinContext *context, HashSymbol *var); -bool isTypeConstructor(TinContext *context, HashSymbol *var); -TinSubstitution *makeEmptySubstitution(void); -TinContext *applyContextSubstitution(TinSubstitution *s, TinContext *context); -void applyContextSubstitutionInPlace(TinSubstitution *s, TinContext *context); -TinMonoType *instantiate(TinPolyType *tpt); -TinSubstitution *applySubstitutionSubstitution(TinSubstitution *s1, TinSubstitution *s2); -TinPolyType *generalize(TinContext *context, TinMonoType *monoType); - -#endif diff --git a/src/tin_objtypes.h b/src/tin_objtypes.h deleted file mode 100644 
index acaa916..0000000 --- a/src/tin_objtypes.h +++ /dev/null @@ -1,53 +0,0 @@ -#ifndef cekf_tin_objtypes_h -#define cekf_tin_objtypes_h -/* - * CEKF - VM supporting amb - * Copyright (C) 2022-2023 Bill Hails - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - * - * Type inference structures used by Algorithm W. - * - * generated from src/tin.yaml by makeAST.py - */ - -#define TIN_OBJTYPES() OBJTYPE_TINFUNCTIONAPPLICATION, \ -OBJTYPE_TINMONOTYPELIST, \ -OBJTYPE_TINTYPEQUANTIFIER, \ -OBJTYPE_TINCONTEXT, \ -OBJTYPE_TINSUBSTITUTION, \ -OBJTYPE_TINARGSRESULT, \ -OBJTYPE_TINVARRESULT, \ -OBJTYPE_TINVARSRESULT, \ -OBJTYPE_TINMONOTYPE, \ -OBJTYPE_TINPOLYTYPE - -#define TIN_OBJTYPE_CASES() \ -case OBJTYPE_TINFUNCTIONAPPLICATION:\ -case OBJTYPE_TINMONOTYPELIST:\ -case OBJTYPE_TINTYPEQUANTIFIER:\ -case OBJTYPE_TINCONTEXT:\ -case OBJTYPE_TINSUBSTITUTION:\ -case OBJTYPE_TINARGSRESULT:\ -case OBJTYPE_TINVARRESULT:\ -case OBJTYPE_TINVARSRESULT:\ -case OBJTYPE_TINMONOTYPE:\ -case OBJTYPE_TINPOLYTYPE:\ - - -void markTinObj(struct Header *h); -void freeTinObj(struct Header *h); -char *typenameTinObj(int type); - -#endif diff --git a/src/debug_tpmc.c b/src/tpmc_debug.c similarity index 53% rename from src/debug_tpmc.c rename to src/tpmc_debug.c index 542fd91..e6d2a3b 100644 --- a/src/debug_tpmc.c +++ b/src/tpmc_debug.c @@ -22,7 +22,7 @@ #include -#include "debug_tpmc.h" +#include "tpmc_debug.h" #include 
"lambda_pp.h" #include "bigint.h" @@ -68,7 +68,7 @@ void printTpmcAssignmentPattern(struct TpmcAssignmentPattern * x, int depth) { pad(depth); if (x == NULL) { eprintf("TpmcAssignmentPattern (NULL)"); return; } eprintf("TpmcAssignmentPattern[\n"); - printAstSymbol(x->name, depth + 1); + printAstSymbol(x->name, depth + 1); eprintf("\n"); printTpmcPattern(x->value, depth + 1); eprintf("\n"); @@ -80,9 +80,9 @@ void printTpmcConstructorPattern(struct TpmcConstructorPattern * x, int depth) { pad(depth); if (x == NULL) { eprintf("TpmcConstructorPattern (NULL)"); return; } eprintf("TpmcConstructorPattern[\n"); - printAstSymbol(x->tag, depth + 1); + printAstSymbol(x->tag, depth + 1); eprintf("\n"); - printLamTypeConstructorInfo(x->info, depth + 1); + printLamTypeConstructorInfo(x->info, depth + 1); eprintf("\n"); printTpmcPatternArray(x->components, depth+1); eprintf("\n"); @@ -94,7 +94,7 @@ void printTpmcPattern(struct TpmcPattern * x, int depth) { pad(depth); if (x == NULL) { eprintf("TpmcPattern (NULL)"); return; } eprintf("TpmcPattern[\n"); - printAstSymbol(x->path, depth + 1); + printAstSymbol(x->path, depth + 1); eprintf("\n"); printTpmcPatternValue(x->pattern, depth + 1); eprintf("\n"); @@ -106,7 +106,7 @@ void printTpmcTestState(struct TpmcTestState * x, int depth) { pad(depth); if (x == NULL) { eprintf("TpmcTestState (NULL)"); return; } eprintf("TpmcTestState[\n"); - printAstSymbol(x->path, depth + 1); + printAstSymbol(x->path, depth + 1); eprintf("\n"); printTpmcArcArray(x->arcs, depth+1); eprintf("\n"); @@ -118,7 +118,7 @@ void printTpmcFinalState(struct TpmcFinalState * x, int depth) { pad(depth); if (x == NULL) { eprintf("TpmcFinalState (NULL)"); return; } eprintf("TpmcFinalState[\n"); - ppLamExpD(x->action, depth + 1); + ppLamExpD(x->action, depth + 1); eprintf("\n"); pad(depth); eprintf("]"); @@ -128,13 +128,13 @@ void printTpmcState(struct TpmcState * x, int depth) { pad(depth); if (x == NULL) { eprintf("TpmcState (NULL)"); return; } 
eprintf("TpmcState[\n"); - pad(depth + 1); + pad(depth + 1); eprintf("int %d", x->refcount); eprintf("\n"); - pad(depth + 1); + pad(depth + 1); eprintf("int %d", x->stamp); eprintf("\n"); - printHashTable(x->freeVariables, depth + 1); + printHashTable(x->freeVariables, depth + 1); eprintf("\n"); printTpmcStateValue(x->state, depth + 1); eprintf("\n"); @@ -150,7 +150,7 @@ void printTpmcArc(struct TpmcArc * x, int depth) { eprintf("\n"); printTpmcPattern(x->test, depth + 1); eprintf("\n"); - printHashTable(x->freeVariables, depth + 1); + printHashTable(x->freeVariables, depth + 1); eprintf("\n"); pad(depth); eprintf("]"); @@ -172,7 +172,7 @@ void printTpmcIntList(struct TpmcIntList * x, int depth) { pad(depth); if (x == NULL) { eprintf("TpmcIntList (NULL)"); return; } eprintf("TpmcIntList[\n"); - pad(depth + 1); + pad(depth + 1); eprintf("int %d", x->integer); eprintf("\n"); printTpmcIntList(x->next, depth + 1); @@ -189,7 +189,7 @@ void printTpmcPatternValue(struct TpmcPatternValue * x, int depth) { case TPMCPATTERNVALUE_TYPE_VAR: pad(depth + 1); eprintf("TPMCPATTERNVALUE_TYPE_VAR\n"); - printAstSymbol(x->val.var, depth + 1); + printAstSymbol(x->val.var, depth + 1); break; case TPMCPATTERNVALUE_TYPE_COMPARISON: pad(depth + 1); @@ -204,19 +204,19 @@ void printTpmcPatternValue(struct TpmcPatternValue * x, int depth) { case TPMCPATTERNVALUE_TYPE_WILDCARD: pad(depth + 1); eprintf("TPMCPATTERNVALUE_TYPE_WILDCARD\n"); - pad(depth + 1); + pad(depth + 1); eprintf("void * %p", x->val.wildcard); break; case TPMCPATTERNVALUE_TYPE_CHARACTER: pad(depth + 1); eprintf("TPMCPATTERNVALUE_TYPE_CHARACTER\n"); - pad(depth + 1); + pad(depth + 1); eprintf("char '%c'", x->val.character); break; case TPMCPATTERNVALUE_TYPE_BIGINTEGER: pad(depth + 1); eprintf("TPMCPATTERNVALUE_TYPE_BIGINTEGER\n"); - printBigInt(x->val.biginteger, depth + 1); + printBigInt(x->val.biginteger, depth + 1); break; case TPMCPATTERNVALUE_TYPE_CONSTRUCTOR: pad(depth + 1); @@ -249,7 +249,7 @@ void 
printTpmcStateValue(struct TpmcStateValue * x, int depth) { case TPMCSTATEVALUE_TYPE_ERROR: pad(depth + 1); eprintf("TPMCSTATEVALUE_TYPE_ERROR\n"); - pad(depth + 1); + pad(depth + 1); eprintf("void * %p", x->val.error); break; default: @@ -277,7 +277,7 @@ void printTpmcVariableArray(struct TpmcVariableArray * x, int depth) { if (x == NULL) { eprintf("TpmcVariableArray (NULL)"); return; } eprintf("TpmcVariableArray(%d)[\n", x->size); for (int i = 0; i < x->size; i++) { - printAstSymbol(x->entries[i], depth + 1); + printAstSymbol(x->entries[i], depth + 1); eprintf("\n"); } pad(depth); @@ -327,7 +327,7 @@ void printTpmcIntArray(struct TpmcIntArray * x, int depth) { if (x == NULL) { eprintf("TpmcIntArray (NULL)"); return; } eprintf("TpmcIntArray(%d)[\n", x->size); for (int i = 0; i < x->size; i++) { - pad(depth + 1); + pad(depth + 1); eprintf("int %d", x->entries[i]); eprintf("\n"); } @@ -353,3 +353,227 @@ void printTpmcMatrix(struct TpmcMatrix * x, int depth) { eprintf("]"); } + +/***************************************/ + +bool eqTpmcMatchRules(struct TpmcMatchRules * a, struct TpmcMatchRules * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqTpmcMatchRuleArray(a->rules, b->rules)) return false; + if (!eqTpmcVariableArray(a->rootVariables, b->rootVariables)) return false; + return true; +} + +bool eqTpmcMatchRule(struct TpmcMatchRule * a, struct TpmcMatchRule * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqTpmcState(a->action, b->action)) return false; + if (!eqTpmcPatternArray(a->patterns, b->patterns)) return false; + return true; +} + +bool eqTpmcComparisonPattern(struct TpmcComparisonPattern * a, struct TpmcComparisonPattern * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqTpmcPattern(a->previous, b->previous)) return false; + if (!eqTpmcPattern(a->current, b->current)) return false; + return true; +} + +bool eqTpmcAssignmentPattern(struct 
TpmcAssignmentPattern * a, struct TpmcAssignmentPattern * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->name != b->name) return false; + if (!eqTpmcPattern(a->value, b->value)) return false; + return true; +} + +bool eqTpmcConstructorPattern(struct TpmcConstructorPattern * a, struct TpmcConstructorPattern * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->tag != b->tag) return false; + if (a->info != b->info) return false; + if (!eqTpmcPatternArray(a->components, b->components)) return false; + return true; +} + +bool eqTpmcPattern(struct TpmcPattern * a, struct TpmcPattern * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->path != b->path) return false; + if (!eqTpmcPatternValue(a->pattern, b->pattern)) return false; + return true; +} + +bool eqTpmcTestState(struct TpmcTestState * a, struct TpmcTestState * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->path != b->path) return false; + if (!eqTpmcArcArray(a->arcs, b->arcs)) return false; + return true; +} + +bool eqTpmcFinalState(struct TpmcFinalState * a, struct TpmcFinalState * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->action != b->action) return false; + return true; +} + +bool eqTpmcState(struct TpmcState * a, struct TpmcState * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->refcount != b->refcount) return false; + if (a->stamp != b->stamp) return false; + if (a->freeVariables != b->freeVariables) return false; + if (!eqTpmcStateValue(a->state, b->state)) return false; + return true; +} + +bool eqTpmcArc(struct TpmcArc * a, struct TpmcArc * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqTpmcState(a->state, b->state)) return false; + if (!eqTpmcPattern(a->test, b->test)) return false; + if (a->freeVariables != b->freeVariables) return false; + 
return true; +} + +bool eqTpmcArcList(struct TpmcArcList * a, struct TpmcArcList * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (!eqTpmcArc(a->arc, b->arc)) return false; + if (!eqTpmcArcList(a->next, b->next)) return false; + return true; +} + +bool eqTpmcIntList(struct TpmcIntList * a, struct TpmcIntList * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->integer != b->integer) return false; + if (!eqTpmcIntList(a->next, b->next)) return false; + return true; +} + +bool eqTpmcPatternValue(struct TpmcPatternValue * a, struct TpmcPatternValue * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->type != b->type) return false; + switch(a->type) { + case TPMCPATTERNVALUE_TYPE_VAR: + if (a->val.var != b->val.var) return false; + break; + case TPMCPATTERNVALUE_TYPE_COMPARISON: + if (!eqTpmcComparisonPattern(a->val.comparison, b->val.comparison)) return false; + break; + case TPMCPATTERNVALUE_TYPE_ASSIGNMENT: + if (!eqTpmcAssignmentPattern(a->val.assignment, b->val.assignment)) return false; + break; + case TPMCPATTERNVALUE_TYPE_WILDCARD: + if (a->val.wildcard != b->val.wildcard) return false; + break; + case TPMCPATTERNVALUE_TYPE_CHARACTER: + if (a->val.character != b->val.character) return false; + break; + case TPMCPATTERNVALUE_TYPE_BIGINTEGER: + if (a->val.biginteger != b->val.biginteger) return false; + break; + case TPMCPATTERNVALUE_TYPE_CONSTRUCTOR: + if (!eqTpmcConstructorPattern(a->val.constructor, b->val.constructor)) return false; + break; + default: + cant_happen("unrecognised type %d in eqTpmcPatternValue", a->type); + } + return true; +} + +bool eqTpmcStateValue(struct TpmcStateValue * a, struct TpmcStateValue * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->type != b->type) return false; + switch(a->type) { + case TPMCSTATEVALUE_TYPE_TEST: + if (!eqTpmcTestState(a->val.test, b->val.test)) return false; + break; + 
case TPMCSTATEVALUE_TYPE_FINAL: + if (!eqTpmcFinalState(a->val.final, b->val.final)) return false; + break; + case TPMCSTATEVALUE_TYPE_ERROR: + if (a->val.error != b->val.error) return false; + break; + default: + cant_happen("unrecognised type %d in eqTpmcStateValue", a->type); + } + return true; +} + +bool eqTpmcMatchRuleArray(struct TpmcMatchRuleArray * a, struct TpmcMatchRuleArray * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->size != b->size) return false; + for (int i = 0; i < a->size; i++) { + if (!eqTpmcMatchRule(a->entries[i], b->entries[i])) return false; + } + return true; +} + +bool eqTpmcVariableArray(struct TpmcVariableArray * a, struct TpmcVariableArray * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->size != b->size) return false; + for (int i = 0; i < a->size; i++) { + if (a->entries[i] != b->entries[i]) return false; + } + return true; +} + +bool eqTpmcPatternArray(struct TpmcPatternArray * a, struct TpmcPatternArray * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->size != b->size) return false; + for (int i = 0; i < a->size; i++) { + if (!eqTpmcPattern(a->entries[i], b->entries[i])) return false; + } + return true; +} + +bool eqTpmcStateArray(struct TpmcStateArray * a, struct TpmcStateArray * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->size != b->size) return false; + for (int i = 0; i < a->size; i++) { + if (!eqTpmcState(a->entries[i], b->entries[i])) return false; + } + return true; +} + +bool eqTpmcArcArray(struct TpmcArcArray * a, struct TpmcArcArray * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->size != b->size) return false; + for (int i = 0; i < a->size; i++) { + if (!eqTpmcArc(a->entries[i], b->entries[i])) return false; + } + return true; +} + +bool eqTpmcIntArray(struct TpmcIntArray * a, struct TpmcIntArray * b) { + if (a == b) return true; + 
if (a == NULL || b == NULL) return false; + if (a->size != b->size) return false; + for (int i = 0; i < a->size; i++) { + if (a->entries[i] != b->entries[i]) return false; + } + return true; +} + +bool eqTpmcMatrix(struct TpmcMatrix * a, struct TpmcMatrix * b) { + if (a == b) return true; + if (a == NULL || b == NULL) return false; + if (a->width != b->width || a->height != b->height) return false; + for (int i = 0; i < (a->width * a->height); i++) { + if (!eqTpmcPattern(a->entries[i], b->entries[i])) return false; + } + return true; +} + diff --git a/src/debug_tpmc.h b/src/tpmc_debug.h similarity index 56% rename from src/debug_tpmc.h rename to src/tpmc_debug.h index b110cf9..a11f675 100644 --- a/src/debug_tpmc.h +++ b/src/tpmc_debug.h @@ -1,5 +1,5 @@ -#ifndef cekf_debug_tpmc_h -#define cekf_debug_tpmc_h +#ifndef cekf_tpmc_debug_h +#define cekf_tpmc_debug_h /* * CEKF - VM supporting amb * Copyright (C) 2022-2023 Bill Hails @@ -23,7 +23,7 @@ */ #include "tpmc_helper.h" -#include "debug_lambda.h" +#include "lambda_debug.h" #include "lambda_pp.h" #include "bigint.h" @@ -49,4 +49,26 @@ void printTpmcArcArray(struct TpmcArcArray * x, int depth); void printTpmcIntArray(struct TpmcIntArray * x, int depth); void printTpmcMatrix(struct TpmcMatrix * x, int depth); +bool eqTpmcMatchRules(struct TpmcMatchRules * a, struct TpmcMatchRules * b); +bool eqTpmcMatchRule(struct TpmcMatchRule * a, struct TpmcMatchRule * b); +bool eqTpmcComparisonPattern(struct TpmcComparisonPattern * a, struct TpmcComparisonPattern * b); +bool eqTpmcAssignmentPattern(struct TpmcAssignmentPattern * a, struct TpmcAssignmentPattern * b); +bool eqTpmcConstructorPattern(struct TpmcConstructorPattern * a, struct TpmcConstructorPattern * b); +bool eqTpmcPattern(struct TpmcPattern * a, struct TpmcPattern * b); +bool eqTpmcTestState(struct TpmcTestState * a, struct TpmcTestState * b); +bool eqTpmcFinalState(struct TpmcFinalState * a, struct TpmcFinalState * b); +bool eqTpmcState(struct TpmcState * a, struct 
TpmcState * b); +bool eqTpmcArc(struct TpmcArc * a, struct TpmcArc * b); +bool eqTpmcArcList(struct TpmcArcList * a, struct TpmcArcList * b); +bool eqTpmcIntList(struct TpmcIntList * a, struct TpmcIntList * b); +bool eqTpmcPatternValue(struct TpmcPatternValue * a, struct TpmcPatternValue * b); +bool eqTpmcStateValue(struct TpmcStateValue * a, struct TpmcStateValue * b); +bool eqTpmcMatchRuleArray(struct TpmcMatchRuleArray * a, struct TpmcMatchRuleArray * b); +bool eqTpmcVariableArray(struct TpmcVariableArray * a, struct TpmcVariableArray * b); +bool eqTpmcPatternArray(struct TpmcPatternArray * a, struct TpmcPatternArray * b); +bool eqTpmcStateArray(struct TpmcStateArray * a, struct TpmcStateArray * b); +bool eqTpmcArcArray(struct TpmcArcArray * a, struct TpmcArcArray * b); +bool eqTpmcIntArray(struct TpmcIntArray * a, struct TpmcIntArray * b); +bool eqTpmcMatrix(struct TpmcMatrix * a, struct TpmcMatrix * b); + #endif diff --git a/src/tpmc_logic.c b/src/tpmc_logic.c index 460be4d..3661dcb 100644 --- a/src/tpmc_logic.c +++ b/src/tpmc_logic.c @@ -23,7 +23,7 @@ #include "tpmc_logic.h" #include "tpmc_translate.h" #include "tpmc.h" -#include "debug_tpmc.h" +#include "tpmc_debug.h" #include "tpmc_match.h" #include "ast_helper.h" #include "symbol.h" diff --git a/src/tpmc_match.c b/src/tpmc_match.c index f3297df..41dec1a 100644 --- a/src/tpmc_match.c +++ b/src/tpmc_match.c @@ -22,8 +22,8 @@ #include "common.h" #include "tpmc_match.h" #include "tpmc_compare.h" -#include "debug_tpmc.h" -#include "debug_lambda.h" +#include "tpmc_debug.h" +#include "lambda_debug.h" #include "lambda_helper.h" #include "symbol.h" diff --git a/tests/src/test.h b/tests/src/test.h index 6903a68..bb1e000 100644 --- a/tests/src/test.h +++ b/tests/src/test.h @@ -22,7 +22,9 @@ #include "memory.h" #include "ast.h" #include "module.h" -#include "algorithm_W.h" -#include "tin_helper.h" +#include "lambda_conversion.h" +#include "lambda_pp.h" +#include "tc_debug.h" +#include "tc_analyze.h" #endif diff --git 
a/tests/src/test_typechecker.c b/tests/src/test_typechecker.c index 4b6cbc3..cb4c812 100644 --- a/tests/src/test_typechecker.c +++ b/tests/src/test_typechecker.c @@ -18,7 +18,6 @@ #include "test.h" -/* static AstNest *parseWrapped(char *string) { disableGC(); PmModule *mod = newPmToplevelFromString(string, string); @@ -30,8 +29,8 @@ static AstNest *parseWrapped(char *string) { assert(mod->nest != NULL); return mod->nest; } -*/ +/* static AstNest *parseSolo(char *string) { disableGC(); PmModule *mod = newPmModuleFromString(string, string); @@ -43,6 +42,7 @@ static AstNest *parseSolo(char *string) { assert(mod->nest != NULL); return mod->nest; } +*/ /* static void test_car() { @@ -66,18 +66,24 @@ static void test_cdr() { UNPROTECT(save); assert(!hadErrors()); } - +*/ static void test_car_of() { printf("test_car_of\n"); AstNest *result = parseWrapped("<[1]"); int save = PROTECT(result); - WResult *wr = WTop(result); - showTinMonoType(wr->monoType); + LamExp *exp = lamConvertNest(result, NULL); + PROTECT(exp); + ppLamExp(exp); printf("\n"); - UNPROTECT(save); + TcEnv *env = tc_init(); + PROTECT(env); + TcType *res = tc_analyze(exp, env); assert(!hadErrors()); + PROTECT(res); + printTcType(res, 0); + UNPROTECT(save); } - +/* static void test_adder() { printf("test_adder\n"); AstNest *result = parseSolo("fn(a,b){a+b}"); @@ -99,7 +105,6 @@ static void test_fact() { UNPROTECT(save); assert(!hadErrors()); } -*/ static void test_small_car() { printf("test_small_car\n"); @@ -113,15 +118,14 @@ static void test_small_car() { UNPROTECT(save); assert(!hadErrors()); } +*/ int main(int argc __attribute__((unused)), char *argv[] __attribute__((unused))) { initProtection(); - /* - test_car(); - test_cdr(); + // test_car(); + // test_cdr(); test_car_of(); - test_adder(); - */ - test_small_car(); + // test_adder(); + // test_small_car(); } diff --git a/tools/makeAST.py b/tools/makeAST.py index 3ad0ca7..75a53b6 100644 --- a/tools/makeAST.py +++ b/tools/makeAST.py @@ -94,10 +94,18 @@ def 
printPrintFunctions(self): for entity in self.contents.values(): entity.printPrintFunction(self) + def printCompareFunctions(self): + for entity in self.contents.values(): + entity.printCompareFunction(self) + def printPrintDeclarations(self): for entity in self.contents.values(): entity.printPrintDeclaration(self) + def printCompareDeclarations(self): + for entity in self.contents.values(): + entity.printCompareDeclaration(self) + def printDefines(self): for entity in self.contents.values(): entity.printDefines(self) @@ -205,9 +213,15 @@ def printNewFunction(self, catalog): def printPrintDeclaration(self, catalog): pass + def printCompareDeclaration(self, catalog): + pass + def printPrintFunction(self, catalog): pass + def printCompareFunction(self, catalog): + pass + def printMarkObjCase(self, catalog): pass @@ -264,6 +278,15 @@ def makeTypeName(self): v = v.upper().replace('AST', 'AST_') return v + def printCompareCase(self, depth): + typeName = self.makeTypeName() + pad(depth) + print(f'case {typeName}:') + pad(depth + 1) + print("if (a != b) return false;") + pad(depth + 1) + print('break;') + def printPrintCase(self, depth): typeName = self.makeTypeName() pad(depth) @@ -311,6 +334,10 @@ def printMarkArrayLine(self, catalog, key, depth): obj = catalog.get(self.typeName) obj.printMarkField(f"{self.name}[{key}]", depth) + def printCompareLine(self, catalog, depth): + obj = catalog.get(self.typeName) + obj.printCompareField(self.name, depth) + def printPrintLine(self, catalog, depth): obj = catalog.get(self.typeName) obj.printPrintField(self.name, depth) @@ -319,6 +346,10 @@ def printPrintArrayLine(self, catalog, key, depth): obj = catalog.get(self.typeName) obj.printPrintField(f"{self.name}[{key}]", depth) + def printCompareArrayLine(self, catalog, key, depth): + obj = catalog.get(self.typeName) + obj.printCompareField(f"{self.name}[{key}]", depth) + def printStructTypedefLine(self, catalog): print(" {decl};".format(decl=self.getSignature(catalog))) @@ -347,6 
+378,11 @@ def tag(self): def getTypeDeclaration(self): return "struct {name} *".format(name=self.getName()) + def printCompareField(self, field, depth, prefix=''): + myName=self.getName() + pad(depth) + print(f"if (!eq{myName}(a->{prefix}{field}, b->{prefix}{field})) return false;") + def printPrintField(self, field, depth, prefix=''): myName=self.getName() pad(depth) @@ -506,10 +542,35 @@ def printMark2dFunctionBody(self, catalog): def printPrintDeclaration(self, catalog): print("{decl};".format(decl=self.getPrintSignature(catalog))) + def printCompareDeclaration(self, catalog): + print("{decl};".format(decl=self.getCompareSignature(catalog))) + def getPrintSignature(self, catalog): myType = self.getTypeDeclaration() return "void print{myName}({myType} x, int depth)".format(myName=self.getName(), myType=myType) + def getCompareSignature(self, catalog): + myType = self.getTypeDeclaration() + return "bool eq{myName}({myType} a, {myType} b)".format(myName=self.getName(), myType=myType) + + def printCompareFunction(self, catalog): + myName = self.getName() + print("{decl} {{".format(decl=self.getCompareSignature(catalog))) + print(" if (a == b) return true;") + print(" if (a == NULL || b == NULL) return false;") + if self.dimension == 1: + print(" if (a->size != b->size) return false;") + print(" for (int i = 0; i < a->size; i++) {") + self.entries.printCompareArrayLine(catalog, "i", 2) + print(" }") + else: + print(" if (a->width != b->width || a->height != b->height) return false;") + print(" for (int i = 0; i < (a->width * a->height); i++) {") + self.entries.printCompareArrayLine(catalog, "i", 2) + print(" }") + print(" return true;") + print("}\n") + def printPrintFunction(self, catalog): myName = self.getName() print("{decl} {{".format(decl=self.getPrintSignature(catalog))) @@ -622,6 +683,10 @@ def getPrintSignature(self, catalog): myType = self.getTypeDeclaration() return "void print{myName}({myType} x, int depth)".format(myName=self.getName(), myType=myType) + 
def getCompareSignature(self, catalog): + myType = self.getTypeDeclaration() + return "bool eq{myName}({myType} a, {myType} b)".format(myName=self.getName(), myType=myType) + def getNewArgs(self): return [x for x in self.fields if x.default is None] @@ -649,6 +714,9 @@ def printMarkDeclaration(self, catalog): def printPrintDeclaration(self, catalog): print("{decl};".format(decl=self.getPrintSignature(catalog))) + def printCompareDeclaration(self, catalog): + print("{decl};".format(decl=self.getCompareSignature(catalog))) + def printNewFunction(self, catalog): print("{decl} {{".format(decl=self.getNewSignature(catalog))) myType = self.getTypeDeclaration() @@ -667,6 +735,10 @@ def printMarkFunctionBody(self, catalog): for field in self.fields: field.printMarkLine(catalog, 1) + def printCompareFunctionBody(self, catalog): + for field in self.fields: + field.printCompareLine(catalog, 1) + def printPrintFunctionBody(self, catalog): for field in self.fields: field.printPrintLine(catalog, 1) @@ -676,6 +748,11 @@ def printMarkField(self, field, depth, prefix=''): pad(depth) print("mark{myName}(x->{prefix}{field});".format(field=field, myName=self.getName(), prefix=prefix)) + def printCompareField(self, field, depth, prefix=''): + myName=self.getName() + pad(depth) + print(f"if (!eq{myName}(a->{prefix}{field}, b->{prefix}{field})) return false;") + def printPrintField(self, field, depth, prefix=''): myName=self.getName() pad(depth) @@ -716,6 +793,15 @@ def printTypeObjCase(self, catalog): pad(3) print('return "{name}";'.format(name=self.getName())) + def printCompareFunction(self, catalog): + myName = self.getName() + print("{decl} {{".format(decl=self.getCompareSignature(catalog))) + print(" if (a == b) return true;") + print(" if (a == NULL || b == NULL) return false;") + self.printCompareFunctionBody(catalog) + print(" return true;") + print("}\n") + def printPrintFunction(self, catalog): myName = self.getName() print("{decl} 
{{".format(decl=self.getPrintSignature(catalog))) @@ -793,6 +879,13 @@ def printMarkCase(self, catalog): obj.printMarkField(self.name, 3, 'val.') print(" break;") + def printCompareCase(self, catalog): + typeName = self.makeTypeName() + print(f" case {typeName}:") + obj = catalog.get(self.typeName) + obj.printCompareField(self.name, 3, 'val.') + print(" break;") + def printPrintCase(self, catalog): typeName = self.makeTypeName() print(f" case {typeName}:") @@ -845,6 +938,15 @@ def printMarkFunctionBody(self, catalog): print(' cant_happen("unrecognised type %d in mark{myName}", x->type);'.format(myName=self.getName())) print(" }") + def printCompareFunctionBody(self, catalog): + print(" if (a->type != b->type) return false;") + print(" switch(a->type) {") + for field in self.fields: + field.printCompareCase(catalog) + print(" default:") + print(' cant_happen("unrecognised type %d in eq{myName}", a->type);'.format(myName=self.getName())) + print(" }") + def printPrintFunctionBody(self, catalog): print(" switch(x->type) {") for field in self.fields: @@ -906,6 +1008,14 @@ def printTypedef(self, catalog): def isEnum(self): return True + def printCompareField(self, field, depth, prefix=''): + pad(depth) + print("switch (a->type) {") + for field in self.fields: + field.printCompareCase(depth + 1) + pad(depth) + print('}') + def printPrintField(self, field, depth, prefix=''): pad(depth) print('switch (x->type) {') @@ -964,6 +1074,10 @@ def __init__(self, name, data): else: self.printFn = data['printFn'] self.valued = data['valued'] + if 'compareFn' in data: + self.compareFn = data['compareFn'] + else: + self.compareFn = None def printMarkCase(self, catalog): if self.markFn is not None: @@ -980,8 +1094,14 @@ def printMarkField(self, field, depth, prefix=''): def getTypeDeclaration(self): return self.cname - def printPrintField(self, field, depth, prefix=''): + def printCompareField(self, field, depth, prefix=''): pad(depth) + if self.compareFn is None: + print(f"if 
(a->{prefix}{field} != b->{prefix}{field}) return false;") + else: + print(f"if (!{self.compareFn}(a->{prefix}{field}, b->{prefix}{field})) return false;") + + def printPrintField(self, field, depth, prefix=''): pad(depth) if self.printFn == 'printf': print('pad(depth + 1);') @@ -1153,28 +1273,34 @@ def printGpl(file, document): catalog.printTypeObjFunction() print("") elif args.type == 'debug_h': - print(f"#ifndef cekf_debug_{typeName}_h") - print(f"#define cekf_debug_{typeName}_h") + print(f"#ifndef cekf_{typeName}_debug_h") + print(f"#define cekf_{typeName}_debug_h") printGpl(args.yaml, document) print("") print(f'#include "{typeName}_helper.h"') for include in includes: - print(f'#include "debug_{include}"') + print(f'#include "{include[0:-2]}_debug.h"') for include in limited_includes: print(f'#include "{include}"') print("") catalog.printPrintDeclarations() print("") + catalog.printCompareDeclarations() + print("") print("#endif") elif args.type == 'debug_c': printGpl(args.yaml, document) print("") print('#include ') print("") - print(f'#include "debug_{typeName}.h"') + print(f'#include "{typeName}_debug.h"') for include in limited_includes: print(f'#include "{include}"') print("") print('static void pad(int depth) { eprintf("%*s", depth * 4, ""); }') print("") catalog.printPrintFunctions() + print("") + print("/***************************************/") + print("") + catalog.printCompareFunctions() From f1e56604b6b17b767a2aec40c97799cb58d3ded7 Mon Sep 17 00:00:00 2001 From: Bill Hails Date: Sat, 30 Dec 2023 17:20:57 +0000 Subject: [PATCH 2/4] type checking mostly working, save while we're ahead --- fn/infer.fn | 179 ++++++++++++++++++++ fn/listOfInt.fn | 1 + src/common.h | 2 +- src/debugging_off.h | 6 + src/debugging_on.h | 25 ++- src/lexer.l | 2 +- src/main.c | 17 +- src/module.c | 6 +- src/module.h | 1 + src/tc.c | 7 +- src/tc.h | 12 +- src/tc.yaml | 10 +- src/tc_analyze.c | 313 ++++++++++++++++++++++++++--------- src/tc_debug.c | 71 ++++---- 
src/tc_debug.h | 16 +- src/tc_helper.c | 32 +++- tests/src/test_typechecker.c | 180 ++++++++++++++------ tools/makeAST.py | 80 ++++++++- 18 files changed, 764 insertions(+), 196 deletions(-) create mode 100644 fn/infer.fn create mode 100644 fn/listOfInt.fn diff --git a/fn/infer.fn b/fn/infer.fn new file mode 100644 index 0000000..b04baaa --- /dev/null +++ b/fn/infer.fn @@ -0,0 +1,179 @@ +let + typedef typeExp { + varType(typeExp) + | operType(list(char), list(typeExp)) + | nullType + } + typedef copyEnv { + cpEnv(typeExp, typeExp, copyEnv) + | nullCpEnv + } + typedef typeCheckEnv { + tcEnv(list(char), typeExp, typeCheckEnv) + | nullTcEnv + } + typedef ExpClass { + ideClass(list(char)) + | condClass(ExpClass, ExpClass, ExpClass) + | lambClass(list(char), ExpClass) + | appClass(ExpClass, ExpClass) + | blockClass(DeclClass, ExpClass) + } + typedef DeclClass { + defClass(list(char), ExpClass) + | seqClass(DeclClass, DeclClass) + | recClass(Decl) + } + fn newTypeVar() { varType(nullType) } + fn prune { + (x = varType(nullType)) { x } + (varType(x)) { prune(x) } + (x) { x } + } + fn occursInType { + (var, type) { + switch(prune(type)) { + (x = varType(_)) { var == x } + (operType(_, args)) { occursInTypeList(var, args) } + } + } + } + fn occursInTypeList { + (_, []) { false } + (var, h @ t) { occursInType(var, h) or occursInTypeList(var, t) } + } + fn unifyType(exp1, exp2) { + switch (prune(exp1), prune(exp2)) { + (exp1=varType(_), exp2) { + if (occursInType(exp1, exp2)) { + exp1 == exp2 + } else { + varType(exp2); // exp1.instance := exp2 + true + } + } + (exp1, exp2=varType(_)) { + unifyType(exp2, exp1) + } + (operType(ide, args1), operType(ide, args2)) { + unifyArgs(args1, args2) + } + (_, _) { + false + } + } + } + fn unifyArgs { + ([], []) { true } + (h1 @ t1, h2 @ t2) { unifyType(h1, h2) and unifyArgs(t1, t2) } + (_, _) { false } + } + fn isGeneric(var, ng) { not occursInTypeList(var, ng) } + fn freshVar { + (typeVar, nullCpEnv, envt) { + fn (fresh) { + #(fresh, 
copyEnv(fresh, typevar, envt)) + } (newTypeVar()) + } + (typeVar, cpEnv(old, new, parent), topEnv) { + if (typeVar == old) { + (new, topEnv) + } else { + freshVar(typeVar, parent, topEnv) + } + } + } + fn fresh(typeExp, ng, envt) { + switch(prune(typeExp)) { + (x = varType(_)) { + if (isGeneric(x, ng)) { + freshVar(x, envt, envt) + } else { x } + } + (operType(ide, args)) { + operType(ide, freshList(args, ng, envt)) + } + } + } + fn freshList { + ([], _, _) { [] } + (h @ t, ng, envt) { + fresh(h, ng, envt) @ freshList(t, ng, envt) + } + } + fn freshType(typeExp, ng) { + fresh(typeExp, ng, nullCpEnv) + } + fn retrieve { + (ide, tcEnv(ide, exp, _), ng) { freshType(exp, ng) } + (_, nullTcEnv, _) { error("unbound ide") } + (ide, tcEnv(_, _, tail), ng) { retrieve(ide, tail, ng) } + } + fn funType(dom, cod) { operType("->", [dom, cod]) } + fn analyzeExp { + (expClass(ide), envt, ng) { retrieve(ide, envt, ng) } + (condClass(test, cons, alt), envt, ng) { + unifyType(test, boolType) and + unifyType(analyzeExp(cons, envt, ng), analyzeExp(alt, envt, ng)) + // return type of cons + } + (lambClass(binder, body), envt, ng) { + fn (typeOfBinder) { + funType(typeOfBinder, + analyzeExp(body, tcEnv(binder, typeOfBinder, envt), + typeOfBinder @ ng)) + } (newTypeVar()) + } + (appClass(fun, arg), envt, ng) { + fn (typeOfRes) { + unifyType( + analyzeExp(fun, envt, ng), + funType(analyzeExp(arg, envt, ng), typeOfRes)); + typeOfRes + } (newTypeVar()) + } + (blockClass(decl, scope), envt, ng) { + analyzeExp(scope, analyzeDecl(decl, envt, ng), ng) + } + } + fn analyzeDecl { + (defClass(binder, def), envt, ng) { + tcEnv(binder, analyzeExp(def, envt, ng), envt) + } + (seqClass(first, second), envt, ng) { + analyzeDecl(second, anayzeDecl(first, envt, ng), ng); + } + (recClass(rec), envt, ng) { + let + #(env2, ng2) = analyzeRecDeclBind(rec, envt, ng); + in + analyzeRecDecl(rec, env2, ng2); + env2 + } + } + fn analyzeRecDeclBind { + (defClass(binder, _), envt, ng) { + fn (fresh) { + 
#(tcEnv(binder, fresh, envt), fresh @ ng) + } (newTypeVar()) + } + (seqClass(first, second), envt, ng) { + let + #(env1, ng1) = analyzeRecDeclBind(first, envt, ng); + in + analyzeRecDeclBind(second, env1, ng1); + } + } + fn analyzeRecDecl { + (defClass(binder, def), envt, ng) { + unifyType(retrieve(binder, envt, ng), + analyzeExp(def, envt, ng)) + } + (seqClass(first, second), envt, ng) { + analyzeRecDecl(first, envt, ng) and + analyzeRecDecl(second, envt, ng) + } + (recClass(rec), envt, ng) { + analyzeRec(rec, envt, ng) + } + } diff --git a/fn/listOfInt.fn b/fn/listOfInt.fn new file mode 100644 index 0000000..7660873 --- /dev/null +++ b/fn/listOfInt.fn @@ -0,0 +1 @@ +[1] diff --git a/src/common.h b/src/common.h index 4a00a0a..6959216 100644 --- a/src/common.h +++ b/src/common.h @@ -45,7 +45,7 @@ // #define DEBUG_BYTECODE // define this to make fatal errors dump core (if ulimit allows) #define DEBUG_DUMP_CORE -#define DEBUG_TC +// #define DEBUG_TC // #define DEBUG_LAMBDA_CONVERT // #define DEBUG_LEAK // #define DEBUG_ANF diff --git a/src/debugging_off.h b/src/debugging_off.h index c806153..1a6907d 100644 --- a/src/debugging_off.h +++ b/src/debugging_off.h @@ -1,3 +1,5 @@ +#ifndef cekf_debugging +#define cekf_debugging /* * CEKF - VM supporting amb * Copyright (C) 2022-2023 Bill Hails @@ -23,3 +25,7 @@ #define DEBUG(...) #define IFDEBUG(x) #define NEWLINE() +#define DEBUGGING_ON() +#define DEBUGGING_OFF() + +#endif diff --git a/src/debugging_on.h b/src/debugging_on.h index 72d895b..c669c74 100644 --- a/src/debugging_on.h +++ b/src/debugging_on.h @@ -1,3 +1,5 @@ +#ifndef cekf_debugging +#define cekf_debugging /* * CEKF - VM supporting amb * Copyright (C) 2022-2023 Bill Hails @@ -17,12 +19,21 @@ */ static int _debugInvocationId = 0; +static bool _debuggingOn = true; +static int _debuggingDepth = 0; + #define DEBUG(...) 
do { \ - eprintf("*** %s:%-5d ", __FILE__, __LINE__); \ - eprintf(__VA_ARGS__); eprintf("\n"); \ + if (_debuggingOn) { \ + eprintf("%s:%-5d %*s", __FILE__, __LINE__, _debuggingDepth, ""); \ + eprintf(__VA_ARGS__); \ + eprintf("\n"); \ + } \ } while(0) -#define ENTER(name) int _debugMyId = _debugInvocationId++; \ - DEBUG("ENTER " #name " #%d", _debugMyId) -#define LEAVE(name) DEBUG("LEAVE " #name " #%d", _debugMyId) -#define NEWLINE() eprintf("\n") -#define IFDEBUG(x) do { x; NEWLINE(); } while(0) +#define ENTER(name) int _debugMyId = _debugInvocationId++; DEBUG("ENTER " #name " #%d", _debugMyId); _debuggingDepth++ +#define LEAVE(name) _debuggingDepth--; DEBUG("LEAVE " #name " #%d", _debugMyId) +#define NEWLINE() do { if (_debuggingOn) eprintf("\n"); } while(0) +#define IFDEBUG(x) do { if (_debuggingOn) { eprintf("%s:%-5d %*s", __FILE__, __LINE__, _debuggingDepth, ""); x; NEWLINE(); } } while(0) +#define DEBUGGING_ON() do { _debuggingOn = true; } while (0) +#define DEBUGGING_OFF() do { _debuggingOn = false; } while (0) + +#endif diff --git a/src/lexer.l b/src/lexer.l index bb9adc2..06f2684 100644 --- a/src/lexer.l +++ b/src/lexer.l @@ -21,7 +21,7 @@ struct PmModule *mod = yyextra; %} [ \t]+ {} -[\n] {} +[\n] { incLineNo(mod); } \/\/.* { } -?[0-9]+ { yylval->s = yytext; return NUMBER; } diff --git a/src/main.c b/src/main.c index b339afc..ed7c123 100644 --- a/src/main.c +++ b/src/main.c @@ -23,8 +23,8 @@ #include "common.h" #include "ast.h" -#include "debug_ast.h" -#include "debug_lambda.h" +#include "ast_debug.h" +#include "lambda_debug.h" #include "lambda_conversion.h" #include "module.h" // #include "parser.h" @@ -40,7 +40,7 @@ #include "anf.h" #include "bigint.h" #include "tc_analyze.h" -#include "debug_tc.h" +#include "tc_debug.h" #ifdef DEBUG_RUN_TESTS #if DEBUG_RUN_TESTS == 1 @@ -147,16 +147,19 @@ int main(int argc, char *argv[]) { eprintf("need filename\n"); exit(1); } + // parse => AST PmModule *mod = newPmToplevelFromFile(argv[optind]); PROTECT(mod); 
pmParseModule(mod); enableGC(); + // lambda conversion: AST => LamExp LamExp *exp = lamConvertNest(mod->nest, NULL); int save = PROTECT(exp); #ifdef DEBUG_LAMBDA_CONVERT ppLamExp(exp); eprintf("\n"); #endif + // type checking TcEnv *env = tc_init(); PROTECT(env); TcType *res = tc_analyze(exp, env); @@ -164,11 +167,13 @@ int main(int argc, char *argv[]) { return 1; } PROTECT(res); - printTcType(res, 0); + ppTcType(res); eprintf("\n"); + // normalization: LamExp => ANF Exp *anfExp = anfNormalize(exp); PROTECT(anfExp); disableGC(); + // desugaring anfExp = desugarExp(anfExp); PROTECT(anfExp); enableGC(); @@ -176,13 +181,17 @@ int main(int argc, char *argv[]) { printExp(anfExp); eprintf("\n"); #endif + // static analysis: ANF => annotated ANF (de bruijn) analizeExp(anfExp, NULL); + // byte code generation initByteCodeArray(&byteCodes); writeExp(anfExp, &byteCodes); writeEnd(&byteCodes); UNPROTECT(save); + // execution printContainedValue(run(byteCodes), 1); printf("\n"); + // report stats etc. 
if (report_mem_flag) reportMemory(); if (report_time_flag) { diff --git a/src/module.c b/src/module.c index dd03178..d69c28e 100644 --- a/src/module.c +++ b/src/module.c @@ -159,6 +159,10 @@ int popPmFile(PmModule *mod) { return 1; } +void incLineNo(PmModule *mod) { + if (mod != NULL && mod->bufStack != NULL) mod->bufStack->lineno++; +} + void showModuleState(FILE *fp, PmModule *mod) { if (mod == NULL) { fprintf(fp, "module is null\n"); @@ -168,5 +172,5 @@ void showModuleState(FILE *fp, PmModule *mod) { fprintf(fp, "module->bufStack is null\n"); return; } - fprintf(fp, "current file %s\n", mod->bufStack->filename); + fprintf(fp, "current file %s, line %d\n", mod->bufStack->filename, mod->bufStack->lineno + 1); } diff --git a/src/module.h b/src/module.h index dc3cfdf..a69da84 100644 --- a/src/module.h +++ b/src/module.h @@ -25,6 +25,7 @@ PmModule *newPmToplevelFromString(char *src, char *id); void markPmModule(Header *h); void freePmModule(Header *h); int pmParseModule(PmModule *mod); +void incLineNo(PmModule *mod); int popPmFile(PmModule *mod); void showModuleState(FILE *fp, PmModule *mod); diff --git a/src/tc.c b/src/tc.c index 711b1c6..89a4e9c 100644 --- a/src/tc.c +++ b/src/tc.c @@ -78,10 +78,11 @@ struct TcTypeDefArgs * newTcTypeDefArgs(struct TcType * type, struct TcTypeDefAr return x; } -struct TcVar * newTcVar(HashSymbol * name) { +struct TcVar * newTcVar(HashSymbol * name, int id) { struct TcVar * x = NEW(TcVar, OBJTYPE_TCVAR); DEBUG("new TcVar %pn", x); x->name = name; + x->id = id; x->instance = NULL; return x; } @@ -166,7 +167,9 @@ void markTcType(struct TcType * x) { case TCTYPE_TYPE_VAR: markTcVar(x->val.var); break; - case TCTYPE_TYPE_INTEGER: + case TCTYPE_TYPE_SMALLINTEGER: + break; + case TCTYPE_TYPE_BIGINTEGER: break; case TCTYPE_TYPE_CHARACTER: break; diff --git a/src/tc.h b/src/tc.h index 31d2049..014584d 100644 --- a/src/tc.h +++ b/src/tc.h @@ -30,7 +30,8 @@ typedef enum TcTypeType { TCTYPE_TYPE_FUNCTION, TCTYPE_TYPE_PAIR, TCTYPE_TYPE_VAR, - 
TCTYPE_TYPE_INTEGER, + TCTYPE_TYPE_SMALLINTEGER, + TCTYPE_TYPE_BIGINTEGER, TCTYPE_TYPE_CHARACTER, TCTYPE_TYPE_TYPEDEF, } TcTypeType; @@ -41,7 +42,8 @@ typedef union TcTypeVal { struct TcFunction * function; struct TcPair * pair; struct TcVar * var; - void * integer; + void * smallinteger; + void * biginteger; void * character; struct TcTypeDef * typeDef; } TcTypeVal; @@ -87,6 +89,7 @@ typedef struct TcTypeDefArgs { typedef struct TcVar { Header header; HashSymbol * name; + int id; struct TcType * instance; } TcVar; @@ -104,7 +107,7 @@ struct TcFunction * newTcFunction(struct TcType * arg, struct TcType * result); struct TcPair * newTcPair(struct TcType * first, struct TcType * second); struct TcTypeDef * newTcTypeDef(HashSymbol * name, struct TcTypeDefArgs * args); struct TcTypeDefArgs * newTcTypeDefArgs(struct TcType * type, struct TcTypeDefArgs * next); -struct TcVar * newTcVar(HashSymbol * name); +struct TcVar * newTcVar(HashSymbol * name, int id); struct TcType * newTcType(enum TcTypeType type, union TcTypeVal val); void markTcEnv(struct TcEnv * x); @@ -129,7 +132,8 @@ void freeTcType(struct TcType * x); #define TCTYPE_VAL_FUNCTION(x) ((union TcTypeVal ){.function = (x)}) #define TCTYPE_VAL_PAIR(x) ((union TcTypeVal ){.pair = (x)}) #define TCTYPE_VAL_VAR(x) ((union TcTypeVal ){.var = (x)}) -#define TCTYPE_VAL_INTEGER() ((union TcTypeVal ){.integer = (NULL)}) +#define TCTYPE_VAL_SMALLINTEGER() ((union TcTypeVal ){.smallinteger = (NULL)}) +#define TCTYPE_VAL_BIGINTEGER() ((union TcTypeVal ){.biginteger = (NULL)}) #define TCTYPE_VAL_CHARACTER() ((union TcTypeVal ){.character = (NULL)}) #define TCTYPE_VAL_TYPEDEF(x) ((union TcTypeVal ){.typeDef = (x)}) diff --git a/src/tc.yaml b/src/tc.yaml index 01da121..41c2c9d 100644 --- a/src/tc.yaml +++ b/src/tc.yaml @@ -21,6 +21,12 @@ config: description: Structures to support type inference +cmp: + extraArgs: + map: HashTable + bespokeImplementation: + - TcVar + structs: TcEnv: table: HashTable @@ -48,6 +54,7 @@ structs: 
TcVar: name: HashSymbol + id: int instance: TcType=NULL unions: @@ -55,7 +62,8 @@ unions: function: TcFunction pair: TcPair var: TcVar - integer: void_ptr + smallinteger: void_ptr + biginteger: void_ptr character: void_ptr typeDef: TcTypeDef diff --git a/src/tc_analyze.c b/src/tc_analyze.c index b573e87..e9edf23 100644 --- a/src/tc_analyze.c +++ b/src/tc_analyze.c @@ -28,6 +28,7 @@ #ifdef DEBUG_TC #include "debugging_on.h" +#include "lambda_pp.h" #else #include "debugging_off.h" #endif @@ -41,7 +42,8 @@ static void addToNg(TcNg *env, HashSymbol *symbol, TcType *type); static void addFreshVarToEnv(TcEnv *env, HashSymbol *key); static void addCmpToEnv(TcEnv *env, HashSymbol *key); static TcType *makeBoolean(void); -static TcType *makeInteger(void); +static TcType *makeSmallInteger(void); +static TcType *makeBigInteger(void); static TcType *makeCharacter(void); static TcType *makeFreshVar(void); static TcType *makeVar(HashSymbol *t); @@ -56,7 +58,7 @@ static void addThenToEnv(TcEnv *env); static TcType *analyzeExp(LamExp *exp, TcEnv *env, TcNg *ng); static TcType *analyzeLam(LamLam *lam, TcEnv *env, TcNg *ng); static TcType *analyzeVar(HashSymbol *var, TcEnv *env, TcNg *ng); -static TcType *analyzeStdInt(int val, TcEnv *env, TcNg *ng); +static TcType *analyzeSmallInteger(); static TcType *analyzeBigInteger(); static TcType *analyzePrim(LamPrimApp *app, TcEnv *env, TcNg *ng); static TcType *analyzeUnary(LamUnaryApp *app, TcEnv *env, TcNg *ng); @@ -75,7 +77,7 @@ static TcType *analyzeCond(LamCond *cond, TcEnv *env, TcNg *ng); static TcType *analyzeAnd(LamAnd *and, TcEnv *env, TcNg *ng); static TcType *analyzeOr(LamOr *or, TcEnv *env, TcNg *ng); static TcType *analyzeAmb(LamAmb *amb, TcEnv *env, TcNg *ng); -static TcType *analyzeCharacter(char c, TcEnv *env, TcNg *ng); +static TcType *analyzeCharacter(); static TcType *analyzeBack(); static TcType *analyzeError(); static bool unify(TcType *a, TcType *b); @@ -83,11 +85,16 @@ static TcType *prune(TcType *t); static bool 
occursInType(TcType *a, TcType *b); static bool occursIn(TcType *a, TcType *b); static bool sameType(TcType *a, TcType *b); -static TcType *analyzeIntegerExp(LamExp *exp, TcEnv *env, TcNg *ng); +static TcType *analyzeBigIntegerExp(LamExp *exp, TcEnv *env, TcNg *ng); +static TcType *analyzeSmallIntegerExp(LamExp *exp, TcEnv *env, TcNg *ng); static TcType *analyzeBooleanExp(LamExp *exp, TcEnv *env, TcNg *ng); static TcType *freshRec(TcType *type, TcNg *ng, HashTable *map); static TcType *lookup(TcEnv *env, HashSymbol *symbol, TcNg *ng); static TcType *makeTypeDef(HashSymbol *name, TcTypeDefArgs *args); +static void markType(void *ptr); +static void printType(void *ptr, int depth); + +static int id_counter = 0; TcEnv *tc_init(void) { TcEnv *env = extendEnv(NULL); @@ -119,6 +126,7 @@ TcEnv *tc_init(void) { TcType *tc_analyze(LamExp *exp, TcEnv *env) { TcNg *ng = extendNg(NULL); + IFDEBUG(ppLamExp(exp)); return prune(analyzeExp(exp, env, ng)); } @@ -130,7 +138,7 @@ static TcType *analyzeExp(LamExp *exp, TcEnv *env, TcNg *ng) { case LAMEXP_TYPE_VAR: return analyzeVar(exp->val.var, env, ng); case LAMEXP_TYPE_STDINT: - return analyzeStdInt(exp->val.stdint, env, ng); + return analyzeSmallInteger(); case LAMEXP_TYPE_BIGINTEGER: return analyzeBigInteger(); case LAMEXP_TYPE_PRIM: @@ -170,7 +178,7 @@ static TcType *analyzeExp(LamExp *exp, TcEnv *env, TcNg *ng) { case LAMEXP_TYPE_AMB: return analyzeAmb(exp->val.amb, env, ng); case LAMEXP_TYPE_CHARACTER: - return analyzeCharacter(exp->val.character, env, ng); + return analyzeCharacter(); case LAMEXP_TYPE_BACK: return analyzeBack(); case LAMEXP_TYPE_ERROR: @@ -184,7 +192,10 @@ static TcType *analyzeExp(LamExp *exp, TcEnv *env, TcNg *ng) { static TcType *makeFunctionType(LamVarList *args, TcEnv *env, TcType *returnType) { ENTER(makeFunctionType); - if (args == NULL) return returnType; + if (args == NULL) { + LEAVE(makeFunctionType); + return returnType; + } TcType *next = makeFunctionType(args->next, env, returnType); int save = 
PROTECT(next); TcType *this = NULL; @@ -199,15 +210,16 @@ static TcType *makeFunctionType(LamVarList *args, TcEnv *env, TcType *returnType static TcType *analyzeLam(LamLam *lam, TcEnv *env, TcNg *ng) { ENTER(analyzeLam); + IFDEBUG(ppLamLam(lam)); env = extendEnv(env); int save = PROTECT(env); ng = extendNg(ng); PROTECT(ng); for (LamVarList *args = lam->args; args != NULL; args = args->next) { - TcType *fresh = makeFreshVar(); - int save2 = PROTECT(fresh); - addToEnv(env, args->var, fresh); - addToNg(ng, fresh->val.var->name, fresh); + TcType *freshVar = makeFreshVar(); + int save2 = PROTECT(freshVar); + addToEnv(env, args->var, freshVar); + addToNg(ng, freshVar->val.var->name, freshVar); UNPROTECT(save2); } TcType *returnType = analyzeExp(lam->exp, env, ng); @@ -215,34 +227,45 @@ static TcType *analyzeLam(LamLam *lam, TcEnv *env, TcNg *ng) { TcType *functionType = makeFunctionType(lam->args, env, returnType); UNPROTECT(save); LEAVE(analyzeLam); + IFDEBUG(ppTcType(functionType)); return functionType; } static TcType *analyzeVar(HashSymbol *var, TcEnv *env, TcNg *ng) { ENTER(analyzeVar); + DEBUG("var: %s", var->name); TcType *res = lookup(env, var, ng); if (res == NULL) { can_happen("undefined variable %s in analyzeVar", var->name); } LEAVE(analyzeVar); + DEBUG("var: %s", var->name); + IFDEBUG(ppTcType(res)); return res; } -static TcType *analyzeStdInt(int val __attribute__((unused)), TcEnv *env __attribute__((unused)), TcNg *ng __attribute__((unused))) { - cant_happen("analyzeStdInt not implemented yet"); +static TcType *analyzeSmallInteger() { + return makeSmallInteger(); } static TcType *analyzeBigInteger() { ENTER(analyzeBigInteger); - TcType *res = makeInteger(); + TcType *res = makeBigInteger(); LEAVE(analyzeBigInteger); return res; } static TcType *analyzeBinaryArith(LamExp *exp1, LamExp *exp2, TcEnv *env, TcNg *ng) { ENTER(analyzeBinaryArith); - (void) analyzeIntegerExp(exp1, env, ng); - TcType *res = analyzeIntegerExp(exp2, env, ng); + (void) 
analyzeBigIntegerExp(exp1, env, ng); + TcType *res = analyzeBigIntegerExp(exp2, env, ng); + LEAVE(analyzeBinaryArith); + return res; +} + +static TcType *analyzeUnaryArith(LamExp *exp, TcEnv *env, TcNg *ng) { + ENTER(analyzeBinaryArith); + TcType *res = analyzeBigIntegerExp(exp, env, ng); LEAVE(analyzeBinaryArith); return res; } @@ -268,6 +291,13 @@ static TcType *analyzeBinaryBool(LamExp *exp1, LamExp *exp2, TcEnv *env, TcNg *n return res; } +static TcType *analyzeUnaryBool(LamExp *exp, TcEnv *env, TcNg *ng) { + ENTER(analyzeBinaryBool); + TcType *res = analyzeBooleanExp(exp, env, ng); + LEAVE(analyzeBinaryBool); + return res; +} + static TcType *analyzePrim(LamPrimApp *app, TcEnv *env, TcNg *ng) { ENTER(analyzePrim); switch (app->type) { @@ -296,7 +326,7 @@ static TcType *analyzePrim(LamPrimApp *app, TcEnv *env, TcNg *ng) { // is the (vec 0 ) to retrieve the tag of a constructor. All other calls to vec // should be hidden behind (deconstruct ...) syntax // plus in this case the exp2 arg will always be a simple var so analysis would not prove anything - TcType *res = makeInteger(); + TcType *res = makeSmallInteger(); LEAVE(analyzePrim); return res; } @@ -310,12 +340,29 @@ static TcType *analyzePrim(LamPrimApp *app, TcEnv *env, TcNg *ng) { } } -static TcType *analyzeUnary(LamUnaryApp *app __attribute__((unused)), TcEnv *env __attribute__((unused)), TcNg *ng __attribute__((unused))) { - cant_happen("analyzeUnary not implemented yet"); +static TcType *analyzeUnary(LamUnaryApp *app, TcEnv *env, TcNg *ng) { + ENTER(analyzeUnary); + TcType *res = NULL; + switch (app->type) { + case LAMUNARYOP_TYPE_NEG: + res = analyzeUnaryArith(app->exp, env, ng); + break; + case LAMUNARYOP_TYPE_NOT: + res = analyzeUnaryBool(app->exp, env, ng); + break; + case LAMUNARYOP_TYPE_PRINT: + res = analyzeExp(app->exp, env, ng); + break; + default: + cant_happen("unrecognized type %d in analyzeUnary", app->type); + } + LEAVE(analyzeUnary); + return res; } static TcType 
*analyzeSequence(LamSequence *sequence, TcEnv *env, TcNg *ng) { ENTER(analyzeSequence); + IFDEBUG(ppLamSequence(sequence)); if (sequence == NULL) { cant_happen("NULL sequence in analyzeSequence"); } @@ -385,6 +432,7 @@ static TcType *findResultType(TcType *fn) { static TcType *analyzeDeconstruct(LamDeconstruct *deconstruct, TcEnv *env, TcNg *ng) { ENTER(analyzeDeconstruct); + IFDEBUG(ppLamDeconstruct(deconstruct)); TcType *constructor = NULL; if (!getFromEnv(env, deconstruct->name, &constructor)) { can_happen("undefined type deconstructor %s", deconstruct->name->name); @@ -399,6 +447,8 @@ static TcType *analyzeDeconstruct(LamDeconstruct *deconstruct, TcEnv *env, TcNg unify(expType, resultType); UNPROTECT(save); LEAVE(analyzeDeconstruct); + IFDEBUG(ppTcType(fieldType)); + IFDEBUG(ppTcType(resultType)); return fieldType; } @@ -439,15 +489,22 @@ static LamApply *curryLamApply(LamApply *apply) { static TcType *analyzeApply(LamApply *apply, TcEnv *env, TcNg *ng) { ENTER(analyzeApply); + IFDEBUG(ppLamApply(apply)); switch (apply->nargs) { case 0: { + DEBUG("analyzeApply, nargs: 0"); TcType *res = analyzeExp(apply->function, env, ng); LEAVE(analyzeApply); + IFDEBUG(ppLamApply(apply)); + IFDEBUG(ppTcType(res)); return res; } case 1: { + DEBUG("analyzeApply, nargs: 1"); TcType *fn = analyzeExp(apply->function, env, ng); int save = PROTECT(fn); + DEBUG("analyzeApply function is"); + IFDEBUG(ppTcType(fn)); TcType *arg = analyzeExp(apply->args->exp, env, ng); PROTECT(arg); TcType *res = makeFreshVar(); @@ -457,14 +514,19 @@ static TcType *analyzeApply(LamApply *apply, TcEnv *env, TcNg *ng) { unify(fn, functionType); UNPROTECT(save); LEAVE(analyzeApply); + IFDEBUG(ppLamApply(apply)); + IFDEBUG(ppTcType(res)); return res; } default:{ + DEBUG("analyzeApply, nargs: %d", apply->nargs); LamApply *curried = curryLamApply(apply); int save = PROTECT(curried); TcType *res = analyzeApply(curried, env, ng); UNPROTECT(save); + IFDEBUG(ppLamApply(apply)); LEAVE(analyzeApply); + 
IFDEBUG(ppTcType(res)); return res; } } @@ -483,8 +545,21 @@ static TcType *analyzeIff(LamIff *iff, TcEnv *env, TcNg *ng) { return consequent; } -static TcType *analyzeCallCC(LamExp *called __attribute__((unused)), TcEnv *env __attribute__((unused)), TcNg *ng __attribute__((unused))) { - cant_happen("analyzeCallCC not implemented yet"); +static TcType *analyzeCallCC(LamExp *called, TcEnv *env, TcNg *ng) { + // 'call/cc' is ((a -> b) -> a) -> a + TcType *a = makeFreshVar(); + int save = PROTECT(a); + TcType *b = makeFreshVar(); + PROTECT(b); + TcType *ab = makeFn(a, b); + PROTECT(ab); + TcType *aba = makeFn(ab, a); + PROTECT(aba); + TcType *calledType = analyzeExp(called, env, ng); + PROTECT(calledType); + unify(calledType, aba); + UNPROTECT(save); + return a; } static TcType *analyzeLetRec(LamLetRec *letRec, TcEnv *env, TcNg *ng) { @@ -494,40 +569,47 @@ static TcType *analyzeLetRec(LamLetRec *letRec, TcEnv *env, TcNg *ng) { ng = extendNg(ng); PROTECT(ng); for (LamLetRecBindings *bindings = letRec->bindings; bindings != NULL; bindings = bindings->next) { - TcType *fresh = makeFreshVar(); - int save2 = PROTECT(fresh); - addToEnv(env, bindings->var, fresh); - addToNg(ng, fresh->val.var->name, fresh); + TcType *freshVar = makeFreshVar(); + int save2 = PROTECT(freshVar); + addToEnv(env, bindings->var, freshVar); + // addToNg(ng, freshVar->val.var->name, freshVar); UNPROTECT(save2); } for (LamLetRecBindings *bindings = letRec->bindings; bindings != NULL; bindings = bindings->next) { - TcType *fresh = lookup(env, bindings->var, ng); - if (fresh == NULL) { + DEBUG("analyzeLetRec considering %s", bindings->var->name); + TcType *freshVar = NULL; + if (!getFromEnv(env, bindings->var, &freshVar)) { cant_happen("failed to retrieve fresh var from env in analyzeLetRec"); } + int save2 = PROTECT(freshVar); TcType *type = analyzeExp(bindings->val, env, ng); - int save2 = PROTECT(type); - unify(fresh, type); + PROTECT(type); + unify(freshVar, type); + DEBUG("analyzeLetRec binding 
%s, result:", bindings->var->name); + IFDEBUG(ppTcType(freshVar)); + // addToEnv(env, bindings->var, freshVar); UNPROTECT(save2); } - printTcEnv(env, 0); - eprintf("\n"); - printTcNg(ng, 0); - eprintf("\n"); + IFDEBUG(printTcEnv(env, 0)); TcType *res = analyzeExp(letRec->body, env, ng); UNPROTECT(save); LEAVE(analyzeLetRec); return res; } -static TcTypeDefArgs *makeTcTypeDefArgs(LamTypeArgs *lamTypeArgs) { +static TcTypeDefArgs *makeTcTypeDefArgs(LamTypeArgs *lamTypeArgs, HashTable *map) { if (lamTypeArgs == NULL) { return NULL; } - TcTypeDefArgs *next = makeTcTypeDefArgs(lamTypeArgs->next); + TcTypeDefArgs *next = makeTcTypeDefArgs(lamTypeArgs->next, map); int save = PROTECT(next); - TcType *name = makeVar(lamTypeArgs->name); - PROTECT(name); + TcType *name = NULL; + if (!hashGet(map, lamTypeArgs->name, &name)) { + name = makeVar(lamTypeArgs->name); + int save2 = PROTECT(name); + hashSet(map, lamTypeArgs->name, &name); + UNPROTECT(save2); + } TcTypeDefArgs *this = newTcTypeDefArgs(name, next); UNPROTECT(save); return this; @@ -539,36 +621,37 @@ static TcType *makeTypeDef(HashSymbol *name, TcTypeDefArgs *args) { TcType *res = newTcType(TCTYPE_TYPE_TYPEDEF, TCTYPE_VAL_TYPEDEF(tcTypeDef)); UNPROTECT(save); DEBUG("makeTypeDef: %s %p", name->name, res); + IFDEBUG(ppTcTypeDef(tcTypeDef)); return res; } -static TcType *makeTcTypeDefType(LamType *lamType) { - TcTypeDefArgs *args = makeTcTypeDefArgs(lamType->args); +static TcType *makeTcTypeDefType(LamType *lamType, HashTable *map) { + TcTypeDefArgs *args = makeTcTypeDefArgs(lamType->args, map); int save = PROTECT(args); TcType *res = makeTypeDef(lamType->name, args); UNPROTECT(save); return res; } -static TcType *makeTypeConstructorArg(LamTypeConstructorType *arg); +static TcType *makeTypeConstructorArg(LamTypeConstructorType *arg, HashTable *map); -static TcTypeDefArgs *makeTypeDefArgs(LamTypeConstructorArgs *args) { +static TcTypeDefArgs *makeTypeDefArgs(LamTypeConstructorArgs *args, HashTable *map) { if (args == NULL) 
{ return NULL; } - TcTypeDefArgs *next = makeTypeDefArgs(args->next); + TcTypeDefArgs *next = makeTypeDefArgs(args->next, map); int save = PROTECT(next); - TcType *arg = makeTypeConstructorArg(args->arg); + TcType *arg = makeTypeConstructorArg(args->arg, map); PROTECT(arg); TcTypeDefArgs *this = newTcTypeDefArgs(arg, next); UNPROTECT(save); return this; } -static TcType *makeTypeConstructorApplication(LamTypeFunction *func) { +static TcType *makeTypeConstructorApplication(LamTypeFunction *func, HashTable *map) { // this code is building the inner application of a type, i.e. // list(t) in the context of t -> list(t) -> list(t) - TcTypeDefArgs *args = makeTypeDefArgs(func->args); + TcTypeDefArgs *args = makeTypeDefArgs(func->args, map); if (args == NULL) { cant_happen("null args to LamTypeDefFunction '%s' in makeTypeConstructorApplication", func->name->name); } @@ -578,20 +661,26 @@ static TcType *makeTypeConstructorApplication(LamTypeFunction *func) { return res; } -static TcType *makeTypeConstructorArg(LamTypeConstructorType *arg) { +static TcType *makeTypeConstructorArg(LamTypeConstructorType *arg, HashTable *map) { TcType *res = NULL; switch (arg->type) { case LAMTYPECONSTRUCTORTYPE_TYPE_INTEGER: - res = makeInteger(); + res = makeBigInteger(); break; case LAMTYPECONSTRUCTORTYPE_TYPE_CHARACTER: res = makeCharacter(); break; - case LAMTYPECONSTRUCTORTYPE_TYPE_VAR: - res = makeVar(arg->val.var); - break; + case LAMTYPECONSTRUCTORTYPE_TYPE_VAR: { + if (!hashGet(map, arg->val.var, &res)) { + res = makeVar(arg->val.var); + int save = PROTECT(res); + hashSet(map, arg->val.var, &res); + UNPROTECT(save); + } + } + break; case LAMTYPECONSTRUCTORTYPE_TYPE_FUNCTION: - res = makeTypeConstructorApplication(arg->val.function); + res = makeTypeConstructorApplication(arg->val.function, map); break; default: cant_happen("unrecognised type %d in collectTypeConstructorArg", arg->type); @@ -599,34 +688,36 @@ static TcType *makeTypeConstructorArg(LamTypeConstructorType *arg) { return 
res; } -static TcType *makeTypeDefConstructor(LamTypeConstructorArgs *args, TcType *result) { +static TcType *makeTypeDefConstructor(LamTypeConstructorArgs *args, TcType *result, HashTable *map) { // this code is building the top-level type of a type constructor, i.e. // pair => t -> list(t) -> list(t) if (args == NULL) { return result; } - TcType *next = makeTypeDefConstructor(args->next, result); + TcType *next = makeTypeDefConstructor(args->next, result, map); int save = PROTECT(next); - TcType *this = makeTypeConstructorArg(args->arg); + TcType *this = makeTypeConstructorArg(args->arg, map); PROTECT(this); TcType *res = makeFn(this, next); UNPROTECT(save); return res; } -static void collectTypeDefConstructor(LamTypeConstructor *constructor, TcType *type, TcEnv *env) { - TcType *res = makeTypeDefConstructor(constructor->args, type); +static void collectTypeDefConstructor(LamTypeConstructor *constructor, TcType *type, TcEnv *env, HashTable *map) { + TcType *res = makeTypeDefConstructor(constructor->args, type, map); int save = PROTECT(res); addToEnv(env, constructor->name, res); UNPROTECT(save); } static void collectTypeDef(LamTypeDef *lamTypeDef, TcEnv *env) { + HashTable *map = newHashTable(sizeof(TcType *), markType, printType); + int save = PROTECT(map); LamType *lamType = lamTypeDef->type; - TcType *tcType = makeTcTypeDefType(lamType); - int save = PROTECT(tcType); + TcType *tcType = makeTcTypeDefType(lamType, map); + PROTECT(tcType); for (LamTypeConstructorList *list = lamTypeDef->constructors; list != NULL; list = list->next) { - collectTypeDefConstructor(list->constructor, tcType, env); + collectTypeDefConstructor(list->constructor, tcType, env, map); } UNPROTECT(save); } @@ -675,10 +766,20 @@ static TcType *unifyMatchCases(LamMatchList *cases, TcEnv *env, TcNg *ng) { return this; } -static TcType *analyzeIntegerExp(LamExp *exp, TcEnv *env, TcNg *ng) { +static TcType *analyzeBigIntegerExp(LamExp *exp, TcEnv *env, TcNg *ng) { TcType *type = analyzeExp(exp, 
env, ng); int save = PROTECT(type); - TcType *integer = makeInteger(); + TcType *integer = makeBigInteger(); + PROTECT(integer); + unify(type, integer); + UNPROTECT(save); + return integer; +} + +static TcType *analyzeSmallIntegerExp(LamExp *exp, TcEnv *env, TcNg *ng) { + TcType *type = analyzeExp(exp, env, ng); + int save = PROTECT(type); + TcType *integer = makeSmallInteger(); PROTECT(integer); unify(type, integer); UNPROTECT(save); @@ -696,7 +797,7 @@ static TcType *analyzeBooleanExp(LamExp *exp, TcEnv *env, TcNg *ng) { } static TcType *analyzeMatch(LamMatch *match, TcEnv *env, TcNg *ng) { - (void) analyzeIntegerExp(match->index, env, ng); + (void) analyzeSmallIntegerExp(match->index, env, ng); TcType *res = unifyMatchCases(match->cases, env, ng); return res; } @@ -730,7 +831,7 @@ static TcType *analyzeCond(LamCond *cond, TcEnv *env, TcNg *ng) { PROTECT(value); switch (cond->cases->type) { case LAMCONDCASES_TYPE_INTEGERS: { - TcType *integer = makeInteger(); + TcType *integer = makeBigInteger(); PROTECT(integer); unify(value, integer); result = unifyIntCondCases(cond->cases->val.integers, env, ng); @@ -750,8 +851,9 @@ static TcType *analyzeCond(LamCond *cond, TcEnv *env, TcNg *ng) { return result; } -static TcType *analyzeAnd(LamAnd *and __attribute__((unused)), TcEnv *env __attribute__((unused)), TcNg *ng __attribute__((unused))) { - cant_happen("analyzeAnd not implemented yet"); +static TcType *analyzeAnd(LamAnd *and, TcEnv *env, TcNg *ng) { + TcType *res = analyzeBinaryBool(and->left, and->right, env, ng); + return res; } static TcType *analyzeOr(LamOr *or, TcEnv *env, TcNg *ng) { @@ -769,8 +871,11 @@ static TcType *analyzeAmb(LamAmb *amb, TcEnv *env, TcNg *ng) { return left; } -static TcType *analyzeCharacter(char c __attribute__((unused)), TcEnv *env __attribute__((unused)), TcNg *ng __attribute__((unused))) { - cant_happen("analyzeCharacter not implemented yet"); +static TcType *analyzeCharacter() { + ENTER(analyzeCharacter); + TcType *res = 
makeCharacter(); + LEAVE(analyzeCharacter); + return res; } static TcType *analyzeBack() { @@ -794,6 +899,8 @@ static void printType(void *ptr, int depth) { } static void addToEnv(TcEnv *env, HashSymbol *symbol, TcType *type) { + DEBUG("addToEnv %s =>", symbol->name); + IFDEBUG(ppTcType(type)); hashSet(env->table, symbol, &type); } @@ -854,10 +961,14 @@ static TcTypeDefArgs *freshTypeDefArgs(TcTypeDefArgs *args, TcNg *ng, HashTable } static TcType *freshTypeDef(TcTypeDef *typeDef, TcNg *ng, HashTable *map) { + ENTER(freshTypeDef); TcTypeDefArgs *args = freshTypeDefArgs(typeDef->args, ng, map); int save = PROTECT(args); TcType *res = makeTypeDef(typeDef->name, args); UNPROTECT(save); + LEAVE(freshTypeDef); + IFDEBUG(ppTcTypeDef(typeDef)); + IFDEBUG(ppTcType(res)); return res; } @@ -883,7 +994,7 @@ static TcType *typeGetOrPut(HashTable *map, TcType *typeVar, TcType *defaultValu if (hashGet(map, name, &res)) { return res; } - hashSet(map, name, defaultValue); + hashSet(map, name, &defaultValue); return defaultValue; } @@ -906,7 +1017,8 @@ static TcType *freshRec(TcType *type, TcNg *ng, HashTable *map) { return res; } return type; - case TCTYPE_TYPE_INTEGER: + case TCTYPE_TYPE_SMALLINTEGER: + case TCTYPE_TYPE_BIGINTEGER: case TCTYPE_TYPE_CHARACTER: return type; case TCTYPE_TYPE_TYPEDEF: { @@ -919,23 +1031,35 @@ static TcType *freshRec(TcType *type, TcNg *ng, HashTable *map) { } static TcType *fresh(TcType *type, TcNg *ng) { + ENTER(fresh); + IFDEBUG(ppTcType(type)); HashTable *map = makeTypeMap(); int save = PROTECT(map); TcType *res = freshRec(type, ng, map); UNPROTECT(save); + LEAVE(fresh); + IFDEBUG(ppTcType(res)); return res; } static TcType *lookup(TcEnv *env, HashSymbol *symbol, TcNg *ng) { + ENTER(lookup); + DEBUG("lookup: %s", symbol->name); TcType *type = NULL; if (getFromEnv(env, symbol, &type)) { TcType *res = fresh(type, ng); + LEAVE(lookup); + IFDEBUG(ppTcType(res)); return res; } + LEAVE(lookup); + DEBUG("NULL"); return NULL; } static void addToNg(TcNg 
*env, HashSymbol *symbol, TcType *type) { + DEBUG("addToNg %s =>", symbol->name); + IFDEBUG(ppTcType(type)); hashSet(env->table, symbol, &type); } @@ -973,11 +1097,11 @@ static TcNg *extendNg(TcNg *parent) { } static TcType *makeVar(HashSymbol *t) { - TcVar *var = newTcVar(t); + TcVar *var = newTcVar(t, id_counter++); int save = PROTECT(var); TcType *res = newTcType(TCTYPE_TYPE_VAR, TCTYPE_VAL_VAR(var)); UNPROTECT(save); - DEBUG("makeVar %p", res); + DEBUG("makeVar %s %p", t->name, res); return res; } @@ -985,9 +1109,15 @@ static TcType *makeFreshVar() { return makeVar(genSym("t$")); } -static TcType *makeInteger() { - TcType *res = newTcType(TCTYPE_TYPE_INTEGER, TCTYPE_VAL_INTEGER()); - DEBUG("makeInteger %p", res); +static TcType *makeSmallInteger() { + TcType *res = newTcType(TCTYPE_TYPE_SMALLINTEGER, TCTYPE_VAL_SMALLINTEGER()); + DEBUG("makeSmallInteger %p", res); + return res; +} + +static TcType *makeBigInteger() { + TcType *res = newTcType(TCTYPE_TYPE_BIGINTEGER, TCTYPE_VAL_BIGINTEGER()); + DEBUG("makeBigInteger %p", res); return res; } @@ -1005,7 +1135,7 @@ static void addUnOpToEnv(TcEnv *env, HashSymbol *symbol, TcType *type) { } static void addNegToEnv(TcEnv *env) { - TcType *integer = makeInteger(); + TcType *integer = makeBigInteger(); int save = PROTECT(integer); addUnOpToEnv(env, negSymbol(), integer); UNPROTECT(save); @@ -1084,7 +1214,7 @@ static void addBinOpToEnv(TcEnv *env, HashSymbol *symbol, TcType *type) { static void addIntBinOpToEnv(TcEnv *env, HashSymbol *symbol) { // int -> int -> int - TcType *integer = makeInteger(); + TcType *integer = makeBigInteger(); int save = PROTECT(integer); addBinOpToEnv(env, symbol, integer); UNPROTECT(save); @@ -1148,17 +1278,23 @@ static bool unifyTypeDefs(TcTypeDef *a, TcTypeDef *b) { static bool unify(TcType *a, TcType *b) { a = prune(a); b = prune(b); + DEBUG("UNIFY"); + IFDEBUG(ppTcType(a); eprintf(" WITH "); ppTcType(b)); if (a->type == TCTYPE_TYPE_VAR) { if (b->type != TCTYPE_TYPE_VAR) { if 
(occursInType(a, b)) { can_happen("occurs-in check failed"); return false; } + DEBUG("unify combining"); a->val.var->instance = b; + IFDEBUG(ppTcType(a)); return true; } if (a->val.var->name != b->val.var->name) { + DEBUG("unify combining"); a->val.var->instance = b; + IFDEBUG(ppTcType(a)); } return true; } else if (b->type == TCTYPE_TYPE_VAR) { @@ -1179,7 +1315,8 @@ static bool unify(TcType *a, TcType *b) { return unifyPairs(a->val.pair, b->val.pair); case TCTYPE_TYPE_VAR: cant_happen("encountered var in unify"); - case TCTYPE_TYPE_INTEGER: + case TCTYPE_TYPE_SMALLINTEGER: + case TCTYPE_TYPE_BIGINTEGER: case TCTYPE_TYPE_CHARACTER: return true; case TCTYPE_TYPE_TYPEDEF: @@ -1191,6 +1328,13 @@ static bool unify(TcType *a, TcType *b) { cant_happen("reached end of unify"); } +static void pruneTypeDefArgs(TcTypeDefArgs *args) { + while (args != NULL) { + args->type = prune(args->type); + args = args->next; + } +} + static TcType *prune(TcType *t) { if (t == NULL) return NULL; if (t->type == TCTYPE_TYPE_VAR) { @@ -1198,6 +1342,11 @@ static TcType *prune(TcType *t) { t->val.var->instance = prune(t->val.var->instance); return t->val.var->instance; } + } else if (t->type == TCTYPE_TYPE_TYPEDEF) { + pruneTypeDefArgs(t->val.typeDef->args); + } else if (t->type == TCTYPE_TYPE_FUNCTION) { + t->val.function->arg = prune(t->val.function->arg); + t->val.function->result = prune(t->val.function->result); } return t; } @@ -1228,6 +1377,8 @@ static bool sameTypeDefType(TcTypeDef *a, TcTypeDef *b) { } static bool sameType(TcType *a, TcType *b) { + a = prune(a); + b = prune(b); if (a == NULL || b == NULL) { cant_happen("NULL in sameType"); } @@ -1240,8 +1391,9 @@ static bool sameType(TcType *a, TcType *b) { case TCTYPE_TYPE_PAIR: return samePairType(a->val.pair, b->val.pair); case TCTYPE_TYPE_VAR: - return a->val.var->name == b->val.var->name; - case TCTYPE_TYPE_INTEGER: + return a->val.var->id == b->val.var->id; + case TCTYPE_TYPE_BIGINTEGER: + case TCTYPE_TYPE_SMALLINTEGER: case 
TCTYPE_TYPE_CHARACTER: return true; case TCTYPE_TYPE_TYPEDEF: @@ -1284,7 +1436,8 @@ static bool occursIn(TcType *a, TcType *b) { return occursInPair(a, b->val.pair); case TCTYPE_TYPE_VAR: cant_happen("occursIn 2nd arg should not be a var"); - case TCTYPE_TYPE_INTEGER: + case TCTYPE_TYPE_SMALLINTEGER: + case TCTYPE_TYPE_BIGINTEGER: case TCTYPE_TYPE_CHARACTER: return false; case TCTYPE_TYPE_TYPEDEF: diff --git a/src/tc_debug.c b/src/tc_debug.c index 04e59ac..583744b 100644 --- a/src/tc_debug.c +++ b/src/tc_debug.c @@ -103,6 +103,9 @@ void printTcVar(struct TcVar * x, int depth) { if (x == NULL) { eprintf("TcVar (NULL)"); return; } eprintf("TcVar[\n"); printAstSymbol(x->name, depth + 1); + eprintf("\n"); + pad(depth + 1); +eprintf("int %d", x->id); eprintf("\n"); printTcType(x->instance, depth + 1); eprintf("\n"); @@ -130,11 +133,17 @@ void printTcType(struct TcType * x, int depth) { eprintf("TCTYPE_TYPE_VAR\n"); printTcVar(x->val.var, depth + 1); break; - case TCTYPE_TYPE_INTEGER: + case TCTYPE_TYPE_SMALLINTEGER: pad(depth + 1); - eprintf("TCTYPE_TYPE_INTEGER\n"); + eprintf("TCTYPE_TYPE_SMALLINTEGER\n"); pad(depth + 1); -eprintf("void * %p", x->val.integer); +eprintf("void * %p", x->val.smallinteger); + break; + case TCTYPE_TYPE_BIGINTEGER: + pad(depth + 1); + eprintf("TCTYPE_TYPE_BIGINTEGER\n"); + pad(depth + 1); +eprintf("void * %p", x->val.biginteger); break; case TCTYPE_TYPE_CHARACTER: pad(depth + 1); @@ -158,84 +167,82 @@ eprintf("void * %p", x->val.character); /***************************************/ -bool eqTcEnv(struct TcEnv * a, struct TcEnv * b) { +bool eqTcEnv(struct TcEnv * a, struct TcEnv * b, HashTable *map) { if (a == b) return true; if (a == NULL || b == NULL) return false; if (a->table != b->table) return false; - if (!eqTcEnv(a->next, b->next)) return false; + if (!eqTcEnv(a->next, b->next, map)) return false; return true; } -bool eqTcNg(struct TcNg * a, struct TcNg * b) { +bool eqTcNg(struct TcNg * a, struct TcNg * b, HashTable *map) { if (a == b) 
return true; if (a == NULL || b == NULL) return false; if (a->table != b->table) return false; - if (!eqTcNg(a->next, b->next)) return false; + if (!eqTcNg(a->next, b->next, map)) return false; return true; } -bool eqTcFunction(struct TcFunction * a, struct TcFunction * b) { +bool eqTcFunction(struct TcFunction * a, struct TcFunction * b, HashTable *map) { if (a == b) return true; if (a == NULL || b == NULL) return false; - if (!eqTcType(a->arg, b->arg)) return false; - if (!eqTcType(a->result, b->result)) return false; + if (!eqTcType(a->arg, b->arg, map)) return false; + if (!eqTcType(a->result, b->result, map)) return false; return true; } -bool eqTcPair(struct TcPair * a, struct TcPair * b) { +bool eqTcPair(struct TcPair * a, struct TcPair * b, HashTable *map) { if (a == b) return true; if (a == NULL || b == NULL) return false; - if (!eqTcType(a->first, b->first)) return false; - if (!eqTcType(a->second, b->second)) return false; + if (!eqTcType(a->first, b->first, map)) return false; + if (!eqTcType(a->second, b->second, map)) return false; return true; } -bool eqTcTypeDef(struct TcTypeDef * a, struct TcTypeDef * b) { +bool eqTcTypeDef(struct TcTypeDef * a, struct TcTypeDef * b, HashTable *map) { if (a == b) return true; if (a == NULL || b == NULL) return false; if (a->name != b->name) return false; - if (!eqTcTypeDefArgs(a->args, b->args)) return false; + if (!eqTcTypeDefArgs(a->args, b->args, map)) return false; return true; } -bool eqTcTypeDefArgs(struct TcTypeDefArgs * a, struct TcTypeDefArgs * b) { +bool eqTcTypeDefArgs(struct TcTypeDefArgs * a, struct TcTypeDefArgs * b, HashTable *map) { if (a == b) return true; if (a == NULL || b == NULL) return false; - if (!eqTcType(a->type, b->type)) return false; - if (!eqTcTypeDefArgs(a->next, b->next)) return false; + if (!eqTcType(a->type, b->type, map)) return false; + if (!eqTcTypeDefArgs(a->next, b->next, map)) return false; return true; } -bool eqTcVar(struct TcVar * a, struct TcVar * b) { - if (a == b) 
return true; - if (a == NULL || b == NULL) return false; - if (a->name != b->name) return false; - if (!eqTcType(a->instance, b->instance)) return false; - return true; -} +// Bespoke implementation required for +// bool eqTcVar(struct TcVar * a, struct TcVar * b, HashTable *map) -bool eqTcType(struct TcType * a, struct TcType * b) { +bool eqTcType(struct TcType * a, struct TcType * b, HashTable *map) { if (a == b) return true; if (a == NULL || b == NULL) return false; if (a->type != b->type) return false; switch(a->type) { case TCTYPE_TYPE_FUNCTION: - if (!eqTcFunction(a->val.function, b->val.function)) return false; + if (!eqTcFunction(a->val.function, b->val.function, map)) return false; break; case TCTYPE_TYPE_PAIR: - if (!eqTcPair(a->val.pair, b->val.pair)) return false; + if (!eqTcPair(a->val.pair, b->val.pair, map)) return false; break; case TCTYPE_TYPE_VAR: - if (!eqTcVar(a->val.var, b->val.var)) return false; + if (!eqTcVar(a->val.var, b->val.var, map)) return false; + break; + case TCTYPE_TYPE_SMALLINTEGER: + if (a->val.smallinteger != b->val.smallinteger) return false; break; - case TCTYPE_TYPE_INTEGER: - if (a->val.integer != b->val.integer) return false; + case TCTYPE_TYPE_BIGINTEGER: + if (a->val.biginteger != b->val.biginteger) return false; break; case TCTYPE_TYPE_CHARACTER: if (a->val.character != b->val.character) return false; break; case TCTYPE_TYPE_TYPEDEF: - if (!eqTcTypeDef(a->val.typeDef, b->val.typeDef)) return false; + if (!eqTcTypeDef(a->val.typeDef, b->val.typeDef, map)) return false; break; default: cant_happen("unrecognised type %d in eqTcType", a->type); diff --git a/src/tc_debug.h b/src/tc_debug.h index 8b2f496..c26e491 100644 --- a/src/tc_debug.h +++ b/src/tc_debug.h @@ -33,13 +33,13 @@ void printTcTypeDefArgs(struct TcTypeDefArgs * x, int depth); void printTcVar(struct TcVar * x, int depth); void printTcType(struct TcType * x, int depth); -bool eqTcEnv(struct TcEnv * a, struct TcEnv * b); -bool eqTcNg(struct TcNg * a, struct TcNg * 
b); -bool eqTcFunction(struct TcFunction * a, struct TcFunction * b); -bool eqTcPair(struct TcPair * a, struct TcPair * b); -bool eqTcTypeDef(struct TcTypeDef * a, struct TcTypeDef * b); -bool eqTcTypeDefArgs(struct TcTypeDefArgs * a, struct TcTypeDefArgs * b); -bool eqTcVar(struct TcVar * a, struct TcVar * b); -bool eqTcType(struct TcType * a, struct TcType * b); +bool eqTcEnv(struct TcEnv * a, struct TcEnv * b, HashTable *map); +bool eqTcNg(struct TcNg * a, struct TcNg * b, HashTable *map); +bool eqTcFunction(struct TcFunction * a, struct TcFunction * b, HashTable *map); +bool eqTcPair(struct TcPair * a, struct TcPair * b, HashTable *map); +bool eqTcTypeDef(struct TcTypeDef * a, struct TcTypeDef * b, HashTable *map); +bool eqTcTypeDefArgs(struct TcTypeDefArgs * a, struct TcTypeDefArgs * b, HashTable *map); +bool eqTcVar(struct TcVar * a, struct TcVar * b, HashTable *map); +bool eqTcType(struct TcType * a, struct TcType * b, HashTable *map); #endif diff --git a/src/tc_helper.c b/src/tc_helper.c index 73fe837..e4bfe3e 100644 --- a/src/tc_helper.c +++ b/src/tc_helper.c @@ -17,6 +17,7 @@ */ #include "tc_helper.h" +#include "symbol.h" void ppTcType(TcType *type) { if (type == NULL) { @@ -33,8 +34,11 @@ void ppTcType(TcType *type) { case TCTYPE_TYPE_VAR: ppTcVar(type->val.var); break; - case TCTYPE_TYPE_INTEGER: - eprintf("int"); + case TCTYPE_TYPE_BIGINTEGER: + eprintf("bigint"); + break; + case TCTYPE_TYPE_SMALLINTEGER: + eprintf("smallint"); break; case TCTYPE_TYPE_CHARACTER: eprintf("char"); @@ -63,7 +67,7 @@ void ppTcPair(TcPair *pair) { } void ppTcVar(TcVar *var) { - eprintf("<%s>", var->name->name); + eprintf("<%s>%d", var->name->name, var->id); if (var->instance != NULL) { eprintf(" ["); ppTcType(var->instance); @@ -84,3 +88,25 @@ void ppTcTypeDef(TcTypeDef *typeDef) { ppTypeDefArgs(typeDef->args); eprintf(")"); } + +bool eqTcVar(struct TcVar * a, struct TcVar * b, HashTable *map) { + if (a == b) return true; + if (a->name == b->name) return true; + HashSymbol 
*common = NULL; + if (hashGet(map, a->name, &common)) { + HashSymbol *other = NULL; + if (hashGet(map, b->name, &other)) { + return common == other; + } else { + return false; + } + } else if (hashGet(map, b->name, &common)) { + return false; + } else { + // symmetric + common = genSym("tt$"); + hashSet(map, a->name, &common); + hashSet(map, b->name, &common); + } + return true; +} diff --git a/tests/src/test_typechecker.c b/tests/src/test_typechecker.c index cb4c812..0021cf5 100644 --- a/tests/src/test_typechecker.c +++ b/tests/src/test_typechecker.c @@ -17,6 +17,15 @@ */ #include "test.h" +#include "symbol.h" + +static bool compareTcTypes(TcType *a, TcType *b) { + HashTable *map = newHashTable(sizeof(HashSymbol *), NULL, NULL); + int save = PROTECT(map); + bool res = eqTcType(a, b, map); + UNPROTECT(save); + return res; +} static AstNest *parseWrapped(char *string) { disableGC(); @@ -30,7 +39,6 @@ static AstNest *parseWrapped(char *string) { return mod->nest; } -/* static AstNest *parseSolo(char *string) { disableGC(); PmModule *mod = newPmModuleFromString(string, string); @@ -42,90 +50,166 @@ static AstNest *parseSolo(char *string) { assert(mod->nest != NULL); return mod->nest; } -*/ -/* -static void test_car() { - printf("test_car\n"); - AstNest *result = parseWrapped("car"); - int save = PROTECT(result); - WResult *wr = WTop(result); - showTinMonoType(wr->monoType); - printf("\n"); - UNPROTECT(save); +static TcType *charToVar(char *name) { + HashSymbol *t = newSymbol(name); + TcVar *v = newTcVar(t, 0); + int save = PROTECT(v); + TcType *var = newTcType(TCTYPE_TYPE_VAR, TCTYPE_VAL_VAR(v)); + UNPROTECT(save); + return var; +} + +static TcType *listOf(TcType *type) { + TcTypeDefArgs *args = newTcTypeDefArgs(type, NULL); + int save = PROTECT(args); + HashSymbol *list = newSymbol("list"); + TcTypeDef *typeDef = newTcTypeDef(list, args); + PROTECT(typeDef); + TcType *td = newTcType(TCTYPE_TYPE_TYPEDEF, TCTYPE_VAL_TYPEDEF(typeDef)); + UNPROTECT(save); + return td; +} 
+ +static TcType *makeFunction2(TcType *arg, TcType *result) { + TcFunction *fun = newTcFunction(arg, result); + int save = PROTECT(fun); + TcType *f = newTcType(TCTYPE_TYPE_FUNCTION, TCTYPE_VAL_FUNCTION(fun)); + UNPROTECT(save); + return f; +} + +static TcType *makeFunction3(TcType *arg1, TcType *arg2, TcType *result) { + TcType *f1 = makeFunction2(arg2, result); + int save = PROTECT(f1); + TcType *f2 = makeFunction2(arg1, f1); + UNPROTECT(save); + return f2; +} + +static TcType *makeBigInteger() { + return newTcType(TCTYPE_TYPE_BIGINTEGER, TCTYPE_VAL_BIGINTEGER()); +} + +static TcType *analyze(AstNest *nest) { + LamExp *exp = lamConvertNest(nest, NULL); + int save = PROTECT(exp); + TcEnv *env = tc_init(); + PROTECT(env); + TcType *res = tc_analyze(exp, env); + PROTECT(res); + ppTcType(res); + eprintf("\n"); assert(!hadErrors()); + UNPROTECT(save); + return res; } static void test_cdr() { printf("test_cdr\n"); AstNest *result = parseWrapped("cdr"); int save = PROTECT(result); - WResult *wr = WTop(result); - showTinMonoType(wr->monoType); - printf("\n"); + TcType *res = analyze(result); + PROTECT(res); + TcType *var = charToVar("#t"); + PROTECT(var); + TcType *td = listOf(var); + PROTECT(td); + TcType *f = makeFunction2(td, td); + PROTECT(f); + assert(compareTcTypes(f, res)); UNPROTECT(save); - assert(!hadErrors()); } -*/ + +static void test_car() { + printf("test_car\n"); + AstNest *result = parseWrapped("car"); + int save = PROTECT(result); + TcType *res = analyze(result); + PROTECT(res); + TcType *var = charToVar("#t"); + PROTECT(var); + TcType *td = listOf(var); + PROTECT(td); + TcType *f = makeFunction2(td, var); + PROTECT(f); + assert(compareTcTypes(f, res)); + UNPROTECT(save); +} + static void test_car_of() { printf("test_car_of\n"); AstNest *result = parseWrapped("<[1]"); int save = PROTECT(result); - LamExp *exp = lamConvertNest(result, NULL); - PROTECT(exp); - ppLamExp(exp); - printf("\n"); - TcEnv *env = tc_init(); - PROTECT(env); - TcType *res = 
tc_analyze(exp, env); - assert(!hadErrors()); + TcType *res = analyze(result); PROTECT(res); - printTcType(res, 0); + TcType *expected = makeBigInteger(); + PROTECT(expected); + assert(compareTcTypes(res, expected)); UNPROTECT(save); } -/* + static void test_adder() { printf("test_adder\n"); AstNest *result = parseSolo("fn(a,b){a+b}"); int save = PROTECT(result); - WResult *wr = WTop(result); - showTinMonoType(wr->monoType); - printf("\n"); + TcType *res = analyze(result); + PROTECT(res); + TcType *bigInt = makeBigInteger(); + PROTECT(bigInt); + TcType *expected = makeFunction3(bigInt, bigInt, bigInt); + PROTECT(expected); + assert(compareTcTypes(res, expected)); UNPROTECT(save); - assert(!hadErrors()); } static void test_fact() { printf("test_fact\n"); AstNest *result = parseSolo("let fn fact {(0) {1} (n) {n * fact(n - 1)} } in fact"); int save = PROTECT(result); - WResult *wr = WTop(result); - showTinMonoType(wr->monoType); - printf("\n"); + TcType *res = analyze(result); + PROTECT(res); + TcType *bigInt = makeBigInteger(); + PROTECT(bigInt); + TcType *expected = makeFunction2(bigInt, bigInt); + PROTECT(expected); + assert(compareTcTypes(res, expected)); + UNPROTECT(save); +} + +static void test_add1() { + printf("test_add1\n"); + AstNest *result = parseSolo("let fn add1(x) { 1 + x } in add1(2)"); + int save = PROTECT(result); + TcType *res = analyze(result); + PROTECT(res); + TcType *expected = makeBigInteger(); + PROTECT(expected); + assert(compareTcTypes(res, expected)); UNPROTECT(save); - assert(!hadErrors()); } -static void test_small_car() { - printf("test_small_car\n"); - AstNest *result = parseSolo("let typedef mylist(#t) { mynull | mypair(#t, mylist(#t)) }" - " fn myhead { (mynull) { mynull } (mypair(a,b)) { a } }" - " in myhead(mypair(2, mynull))"); +static void test_caddr() { + printf("test_caddr\n"); + AstNest *result = parseWrapped("let x = [1, 2, 3, 4]; in <>>x"); int save = PROTECT(result); - WResult *wr = WTop(result); - 
showTinMonoType(wr->monoType); - printf("\n"); + TcType *res = analyze(result); + PROTECT(res); + TcType *expected = makeBigInteger(); + PROTECT(expected); + assert(compareTcTypes(res, expected)); UNPROTECT(save); - assert(!hadErrors()); } -*/ + int main(int argc __attribute__((unused)), char *argv[] __attribute__((unused))) { initProtection(); - // test_car(); - // test_cdr(); + test_car(); + test_cdr(); test_car_of(); - // test_adder(); - // test_small_car(); + test_adder(); + test_fact(); + test_add1(); + test_caddr(); } diff --git a/tools/makeAST.py b/tools/makeAST.py index 75a53b6..059269f 100644 --- a/tools/makeAST.py +++ b/tools/makeAST.py @@ -38,6 +38,16 @@ def tag(self, t): if t in self.contents: self.contents[t].tag() + def noteExtraCmpArgs(self, args): + for key in self.contents: + self.contents[key].noteExtraCmpArgs(args) + + def noteBespokeCmpImplementation(self, name): + if name in self.contents: + self.contents[name].noteBespokeCmpImplementation() + else: + raise Exception("bespoke cmp implementation declared for nonexistant entry " + name) + def get(self, key): if key in self.contents: return self.contents[key] @@ -176,6 +186,11 @@ class Base: def __init__(self, name): self.name = name self.tagged = False + self.bespokeCmpImplementation = False + self.extraCmpArgs = {} + + def noteExtraCmpArgs(self, args): + self.extraCmpArgs = args def objTypeArray(self): return [] @@ -252,9 +267,12 @@ def isUnion(self): def isStruct(self): return False - def isArray(selfself): + def isArray(self): return False + def noteBespokeCmpImplementation(self): + self.bespokeCmpImplementation = True + def printMarkField(self, field, depth, prefix=''): pass @@ -549,11 +567,30 @@ def getPrintSignature(self, catalog): myType = self.getTypeDeclaration() return "void print{myName}({myType} x, int depth)".format(myName=self.getName(), myType=myType) + def getCtype(self, astType, catalog): + return f"{astType} *" + + def getExtraCmpFargs(self, catalog): + extra = [] + for name in 
self.extraCmpArgs: + ctype = self.getCtype(self.extraCmpArgs[name], catalog) + extra += [f"{ctype}{name}"] + if len(extra) > 0: + return ", " + ", ".join(extra) + return "" + def getCompareSignature(self, catalog): myType = self.getTypeDeclaration() - return "bool eq{myName}({myType} a, {myType} b)".format(myName=self.getName(), myType=myType) + myName = self.getName() + extraCmpArgs = self.getExtraCmpFargs(catalog) + return f"bool eq{myName}({myType} a, {myType} b{extraCmpArgs})" def printCompareFunction(self, catalog): + if self.bespokeCmpImplementation: + print("// Bespoke implementation required for"); + print("// {decl}".format(decl=self.getCompareSignature(catalog))) + print("") + return myName = self.getName() print("{decl} {{".format(decl=self.getCompareSignature(catalog))) print(" if (a == b) return true;") @@ -683,9 +720,31 @@ def getPrintSignature(self, catalog): myType = self.getTypeDeclaration() return "void print{myName}({myType} x, int depth)".format(myName=self.getName(), myType=myType) + def getCtype(self, astType, catalog): + return f"{astType} *" + + def getExtraCmpFargs(self, catalog): + extra = [] + for name in self.extraCmpArgs: + ctype = self.getCtype(self.extraCmpArgs[name], catalog) + extra += [f"{ctype}{name}"] + if len(extra) > 0: + return ", " + ", ".join(extra) + return "" + + def getExtraCmpAargs(self, catalog): + extra = [] + for name in self.extraCmpArgs: + extra += [name] + if len(extra) > 0: + return ", " + ", ".join(extra) + return "" + def getCompareSignature(self, catalog): myType = self.getTypeDeclaration() - return "bool eq{myName}({myType} a, {myType} b)".format(myName=self.getName(), myType=myType) + myName = self.getName() + extraCmpArgs = self.getExtraCmpFargs(catalog) + return f"bool eq{myName}({myType} a, {myType} b{extraCmpArgs})" def getNewArgs(self): return [x for x in self.fields if x.default is None] @@ -750,8 +809,9 @@ def printMarkField(self, field, depth, prefix=''): def printCompareField(self, field, depth, 
prefix=''): myName=self.getName() + extraArgs = self.getExtraCmpAargs({}) pad(depth) - print(f"if (!eq{myName}(a->{prefix}{field}, b->{prefix}{field})) return false;") + print(f"if (!eq{myName}(a->{prefix}{field}, b->{prefix}{field}{extraArgs})) return false;") def printPrintField(self, field, depth, prefix=''): myName=self.getName() @@ -794,6 +854,11 @@ def printTypeObjCase(self, catalog): print('return "{name}";'.format(name=self.getName())) def printCompareFunction(self, catalog): + if self.bespokeCmpImplementation: + print("// Bespoke implementation required for"); + print("// {decl}".format(decl=self.getCompareSignature(catalog))) + print("") + return myName = self.getName() print("{decl} {{".format(decl=self.getCompareSignature(catalog))) print(" if (a == b) return true;") @@ -1198,6 +1263,13 @@ def printGpl(file, document): for tag in document["tags"]: catalog.tag(tag); +if "cmp" in document: + if "extraArgs" in document["cmp"]: + catalog.noteExtraCmpArgs(document["cmp"]["extraArgs"]) + if "bespokeImplementation" in document["cmp"]: + for bespoke in document["cmp"]["bespokeImplementation"]: + catalog.noteBespokeCmpImplementation(bespoke) + catalog.build() if args.type == "h": From bb8a895bebdd717daabed29083fc4140872a4907 Mon Sep 17 00:00:00 2001 From: Bill Hails Date: Sun, 31 Dec 2023 12:28:37 +0000 Subject: [PATCH 3/4] type checker fixes, checkpoint --- Makefile | 2 +- fn/map3.fn | 13 +++ src/common.h | 2 +- src/errors.c | 3 + src/lambda_conversion.c | 10 ++ src/tc_analyze.c | 55 ++++++---- tests/src/test_typechecker.c | 192 ++++++++++++++++++++++++++++++++++- 7 files changed, 251 insertions(+), 26 deletions(-) create mode 100644 fn/map3.fn diff --git a/Makefile b/Makefile index 8cfad18..ccb7b6f 100644 --- a/Makefile +++ b/Makefile @@ -55,7 +55,7 @@ $(EXTRA_OBJ): obj/%.o: tmp/%.c | obj $(LAXCC) -I src/ -I tmp/ -c $< -o $@ $(TEST_OBJ): obj/%.o: tests/src/%.c | obj - $(CC) -I src/ -I tmp/ -c $< -o $@ + $(LAXCC) -I src/ -I tmp/ -c $< -o $@ $(MAIN_DEP) $(DEP): 
dep/%.d: src/%.c | dep $(CC) -I tmp/ -I src/ -MM -MT $(patsubst dep/%,obj/%,$(patsubst %.d,%.o,$@)) -o $@ $< diff --git a/fn/map3.fn b/fn/map3.fn new file mode 100644 index 0000000..df36b69 --- /dev/null +++ b/fn/map3.fn @@ -0,0 +1,13 @@ +let + typedef colours { red | green | blue } + fn map { + (f, nil) { [] } + (f, h @ t) { f(h) @ map(f, t) } + } + fn toInt { + (red) { 0 } + (green) { 1 } + (blue) { 2 } + } +in + map(toInt, [red, green, blue]) diff --git a/src/common.h b/src/common.h index 6959216..4a00a0a 100644 --- a/src/common.h +++ b/src/common.h @@ -45,7 +45,7 @@ // #define DEBUG_BYTECODE // define this to make fatal errors dump core (if ulimit allows) #define DEBUG_DUMP_CORE -// #define DEBUG_TC +#define DEBUG_TC // #define DEBUG_LAMBDA_CONVERT // #define DEBUG_LEAK // #define DEBUG_ANF diff --git a/src/errors.c b/src/errors.c index 2844048..5adb061 100644 --- a/src/errors.c +++ b/src/errors.c @@ -48,6 +48,9 @@ void can_happen(const char *message, ...) { va_end(args); eprintf("\n"); errors = true; +#ifdef DEBUG_TC + abort(); +#endif } bool hadErrors() { diff --git a/src/lambda_conversion.c b/src/lambda_conversion.c index 6bc3ffe..722eba2 100644 --- a/src/lambda_conversion.c +++ b/src/lambda_conversion.c @@ -199,6 +199,13 @@ static LamMakeVec *performMakeVecSubstitutions(LamMakeVec *makeVec, HashTable *s return makeVec; } +static LamDeconstruct *performDeconstructSubstitutions(LamDeconstruct *deconstruct, HashTable *substitutions) { + ENTER(performDeconstructSubstitutions); + deconstruct->exp = lamPerformSubstitutions(deconstruct->exp, substitutions); + LEAVE(performDeconstructSubstitutions); + return deconstruct; +} + static LamConstruct *performConstructSubstitutions(LamConstruct *construct, HashTable *substitutions) { ENTER(performConstructSubstitutions); construct->args = performListSubstitutions(construct->args, substitutions); @@ -378,6 +385,9 @@ LamExp *lamPerformSubstitutions(LamExp *exp, HashTable *substitutions) { case LAMEXP_TYPE_MAKEVEC: 
exp->val.makeVec = performMakeVecSubstitutions(exp->val.makeVec, substitutions); break; + case LAMEXP_TYPE_DECONSTRUCT: + exp->val.deconstruct = performDeconstructSubstitutions(exp->val.deconstruct, substitutions); + break; case LAMEXP_TYPE_CONSTRUCT: exp->val.construct = performConstructSubstitutions(exp->val.construct, substitutions); break; diff --git a/src/tc_analyze.c b/src/tc_analyze.c index e9edf23..a573dea 100644 --- a/src/tc_analyze.c +++ b/src/tc_analyze.c @@ -126,8 +126,11 @@ TcEnv *tc_init(void) { TcType *tc_analyze(LamExp *exp, TcEnv *env) { TcNg *ng = extendNg(NULL); + int save = PROTECT(ng); IFDEBUG(ppLamExp(exp)); - return prune(analyzeExp(exp, env, ng)); + TcType *res = prune(analyzeExp(exp, env, ng)); + UNPROTECT(save); + return res; } static TcType *analyzeExp(LamExp *exp, TcEnv *env, TcNg *ng) { @@ -572,7 +575,6 @@ static TcType *analyzeLetRec(LamLetRec *letRec, TcEnv *env, TcNg *ng) { TcType *freshVar = makeFreshVar(); int save2 = PROTECT(freshVar); addToEnv(env, bindings->var, freshVar); - // addToNg(ng, freshVar->val.var->name, freshVar); UNPROTECT(save2); } for (LamLetRecBindings *bindings = letRec->bindings; bindings != NULL; bindings = bindings->next) { @@ -587,10 +589,8 @@ static TcType *analyzeLetRec(LamLetRec *letRec, TcEnv *env, TcNg *ng) { unify(freshVar, type); DEBUG("analyzeLetRec binding %s, result:", bindings->var->name); IFDEBUG(ppTcType(freshVar)); - // addToEnv(env, bindings->var, freshVar); UNPROTECT(save2); } - IFDEBUG(printTcEnv(env, 0)); TcType *res = analyzeExp(letRec->body, env, ng); UNPROTECT(save); LEAVE(analyzeLetRec); @@ -652,9 +652,6 @@ static TcType *makeTypeConstructorApplication(LamTypeFunction *func, HashTable * // this code is building the inner application of a type, i.e. 
// list(t) in the context of t -> list(t) -> list(t) TcTypeDefArgs *args = makeTypeDefArgs(func->args, map); - if (args == NULL) { - cant_happen("null args to LamTypeDefFunction '%s' in makeTypeConstructorApplication", func->name->name); - } int save = PROTECT(args); TcType *res = makeTypeDef(func->name, args); UNPROTECT(save); @@ -726,6 +723,7 @@ static TcType *analyzeTypeDefs(LamTypeDefs *typeDefs, TcEnv *env, TcNg *ng) { ENTER(analyzeTypeDefs); env = extendEnv(env); int save = PROTECT(env); + DEBUG("after extendEnv:"); for (LamTypeDefList *list = typeDefs->typeDefs; list != NULL; list = list->next) { collectTypeDef(list->typeDef, env); } @@ -973,19 +971,24 @@ static TcType *freshTypeDef(TcTypeDef *typeDef, TcNg *ng, HashTable *map) { } static bool isGeneric(TcType *typeVar, TcNg *ng) { - if (ng == NULL) { - return true; - } - int i = 0; - TcType *entry = NULL; - HashSymbol *s = NULL; - while ((s = iterateHashTable(ng->table, &i, &entry)) != NULL) { - if (occursInType(typeVar, entry)) { - return false; + ENTER(isGeneric); + IFDEBUG(ppTcType(typeVar)); + while (ng != NULL) { + int i = 0; + TcType *entry = NULL; + HashSymbol *s = NULL; + while ((s = iterateHashTable(ng->table, &i, &entry)) != NULL) { + if (occursInType(typeVar, entry)) { + LEAVE(isGeneric); + DEBUG("false"); + return false; + } } + ng = ng->next; } - bool res = isGeneric(typeVar, ng->next); - return res; + LEAVE(isGeneric); + DEBUG("true"); + return true; } static TcType *typeGetOrPut(HashTable *map, TcType *typeVar, TcType *defaultValue) { @@ -1057,10 +1060,10 @@ static TcType *lookup(TcEnv *env, HashSymbol *symbol, TcNg *ng) { return NULL; } -static void addToNg(TcNg *env, HashSymbol *symbol, TcType *type) { +static void addToNg(TcNg *ng, HashSymbol *symbol, TcType *type) { DEBUG("addToNg %s =>", symbol->name); IFDEBUG(ppTcType(type)); - hashSet(env->table, symbol, &type); + hashSet(ng->table, symbol, &type); } static TcType *makeBoolean() { @@ -1079,20 +1082,24 @@ static TcType *makeFn(TcType 
*arg, TcType *result) { } static TcEnv *extendEnv(TcEnv *parent) { + ENTER(extendEnv); HashTable *table = newHashTable(sizeof(TcType *), markType, printType); int save = PROTECT(table); table->shortEntries = true; TcEnv *env = newTcEnv(table, parent); UNPROTECT(save); + LEAVE(extendEnv); return env; } static TcNg *extendNg(TcNg *parent) { + ENTER(extendNg); HashTable *table = newHashTable(sizeof(TcType *), markType, printType); int save = PROTECT(table); table->shortEntries = true; TcNg *ng = newTcNg(table, parent); UNPROTECT(save); + LEAVE(extendNg); return ng; } @@ -1301,6 +1308,14 @@ static bool unify(TcType *a, TcType *b) { return unify(b, a); } else { if (a->type != b->type) { + if ((a->type == TCTYPE_TYPE_SMALLINTEGER && + b->type == TCTYPE_TYPE_TYPEDEF) || + (a->type == TCTYPE_TYPE_TYPEDEF && + b->type == TCTYPE_TYPE_SMALLINTEGER)) { + // small integers are *only* used as type tags + // so can unify with typeDefs + return true; + } can_happen("unification failed"); ppTcType(a); eprintf(" vs "); diff --git a/tests/src/test_typechecker.c b/tests/src/test_typechecker.c index 0021cf5..c8986a2 100644 --- a/tests/src/test_typechecker.c +++ b/tests/src/test_typechecker.c @@ -60,13 +60,19 @@ static TcType *charToVar(char *name) { return var; } +static TcType *makeTypeDef(char *name, TcTypeDefArgs *args) { + HashSymbol *sym = newSymbol(name); + TcTypeDef *typeDef = newTcTypeDef(sym, args); + int save = PROTECT(typeDef); + TcType *td = newTcType(TCTYPE_TYPE_TYPEDEF, TCTYPE_VAL_TYPEDEF(typeDef)); + UNPROTECT(save); + return td; +} + static TcType *listOf(TcType *type) { TcTypeDefArgs *args = newTcTypeDefArgs(type, NULL); int save = PROTECT(args); - HashSymbol *list = newSymbol("list"); - TcTypeDef *typeDef = newTcTypeDef(list, args); - PROTECT(typeDef); - TcType *td = newTcType(TCTYPE_TYPE_TYPEDEF, TCTYPE_VAL_TYPEDEF(typeDef)); + TcType *td = makeTypeDef("list", args); UNPROTECT(save); return td; } @@ -91,6 +97,10 @@ static TcType *makeBigInteger() { return 
newTcType(TCTYPE_TYPE_BIGINTEGER, TCTYPE_VAL_BIGINTEGER()); } +static TcType *makeCharacter() { + return newTcType(TCTYPE_TYPE_CHARACTER, TCTYPE_VAL_CHARACTER()); +} + static TcType *analyze(AstNest *nest) { LamExp *exp = lamConvertNest(nest, NULL); int save = PROTECT(exp); @@ -201,9 +211,174 @@ static void test_caddr() { UNPROTECT(save); } +static void test_curry() { + printf("test_curry\n"); + AstNest *result = parseWrapped("let fn add3(a, b, c) { a + b + c } in add3(1)(2)(3)"); + int save = PROTECT(result); + TcType *res = analyze(result); + PROTECT(res); + TcType *expected = makeBigInteger(); + PROTECT(expected); + assert(compareTcTypes(res, expected)); + UNPROTECT(save); +} + +static void test_here() { + printf("test_here\n"); + AstNest *result = parseWrapped( +"let" +" fn funky(k) { k(1) }" +"in" +" 4 + here fn (k) {" +" if (funky(k)) {" +" 2" +" } else {" +" 3" +" }" +" }" + ); + int save = PROTECT(result); + TcType *res = analyze(result); + PROTECT(res); + TcType *expected = makeBigInteger(); + PROTECT(expected); + assert(compareTcTypes(res, expected)); + UNPROTECT(save); +} + +static void test_if() { + printf("test_if\n"); + AstNest *result = parseWrapped("if (true and true) { 10 } else { 20 }"); + int save = PROTECT(result); + TcType *res = analyze(result); + PROTECT(res); + TcType *expected = makeBigInteger(); + PROTECT(expected); + assert(compareTcTypes(res, expected)); + UNPROTECT(save); +} + +static void test_id() { + printf("test_id\n"); + AstNest *result = parseWrapped( +"let" +" fn id (x) { x }" +"" +" fn length {" +" ([]) { 0 }" +" (_ @ t) { 1 + length(t) }" +" }" +"" +" fn even(n) { n % 2 == 0 }" +"" +" fn checkId(x) {" +" id(even(id(length(id(x)))))" +" }" +"" +"in" +" checkId(\"hello\")" + ); + int save = PROTECT(result); + TcType *res = analyze(result); + PROTECT(res); + TcType *expected = makeTypeDef("bool", NULL); + PROTECT(expected); + assert(compareTcTypes(res, expected)); + UNPROTECT(save); +} + +static void test_either_1() { + 
printf("test_either_1\n"); + AstNest *result = parseWrapped("let typedef either(#a, #b) { a(#a) | b(#b) } in a(1)"); + int save = PROTECT(result); + TcType *res = analyze(result); + PROTECT(res); + TcType *big = makeBigInteger(); + PROTECT(big); + TcType *var = charToVar("#t"); + PROTECT(var); + TcTypeDefArgs *args = newTcTypeDefArgs(var, NULL); + PROTECT(args); + args = newTcTypeDefArgs(big, args); + PROTECT(args); + TcType *expected = makeTypeDef("either", args); + PROTECT(expected); + assert(compareTcTypes(res, expected)); + UNPROTECT(save); +} + +static void test_tostr() { + printf("test_tostr\n"); + AstNest *result = parseWrapped( +"let" +" typedef colour { red | green | blue }" +" fn tostr {" +" (red) { \"red\" }" +" (green) { \"green\" }" +" (blue) { \"blue\" }" +" }" +"in" +" tostr(red)" + ); + int save = PROTECT(result); + TcType *res = analyze(result); + PROTECT(res); + TcType *character = makeCharacter(); + PROTECT(character); + TcType *expected = listOf(character); + PROTECT(expected); + assert(compareTcTypes(res, expected)); + UNPROTECT(save); +} + +static void test_lol() { + printf("test_lol\n"); + AstNest *result = parseWrapped("[[1]]"); + int save = PROTECT(result); + TcType *res = analyze(result); + PROTECT(res); + TcType *big = makeBigInteger(); + PROTECT(big); + TcType *loi = listOf(big); + PROTECT(loi); + TcType *expected = listOf(loi); + PROTECT(expected); + assert(compareTcTypes(res, expected)); + UNPROTECT(save); +} + +static void test_map() { + printf("test_lol\n"); + AstNest *result = parseWrapped( +"let" +" typedef colours { red | green | blue }" +" fn map {" +" (f, nil) { [] }" +" (f, h @ t) { f(h) @ map(f, t) }" +" }" +" fn toInt {" +" (red) { 0 }" +" (green) { 1 }" +" (blue) { 2 }" +" }" +"in" +" map(toInt, [red, green, blue])" + ); + int save = PROTECT(result); + TcType *res = analyze(result); + PROTECT(res); + TcType *big = makeBigInteger(); + PROTECT(big); + TcType *expected = listOf(big); + PROTECT(expected); + 
assert(compareTcTypes(res, expected)); + UNPROTECT(save); +} + int main(int argc __attribute__((unused)), char *argv[] __attribute__((unused))) { initProtection(); + /* test_car(); test_cdr(); test_car_of(); @@ -211,5 +386,14 @@ int main(int argc __attribute__((unused)), char *argv[] __attribute__((unused))) test_fact(); test_add1(); test_caddr(); + test_either_1(); + test_tostr(); + test_curry(); + test_here(); + test_id(); + test_if(); + test_lol(); + */ + test_map(); } From d3b2a6077c6a12c75b5279aefaa8c3a646872ce3 Mon Sep 17 00:00:00 2001 From: Bill Hails Date: Sun, 31 Dec 2023 18:04:16 +0000 Subject: [PATCH 4/4] type checker passes all tests --- src/common.h | 4 +- src/hash.c | 3 +- src/lambda_pp.c | 3 + src/tc_analyze.c | 116 +++++++++++++++++++++++------------ tests/src/test_typechecker.c | 36 +++++------ 5 files changed, 100 insertions(+), 62 deletions(-) diff --git a/src/common.h b/src/common.h index 4a00a0a..e0ee591 100644 --- a/src/common.h +++ b/src/common.h @@ -29,7 +29,7 @@ // #define DEBUG_STEP // if DEBUG_STEP is defined, this sleeps for 1 second between each machine step // #define DEBUG_SLOW_STEP -// define this to cause a GC at every possible step (catched memory leaks early) +// define this to cause a GC at every possible step (catches memory leaks early) #define DEBUG_STRESS_GC // #define DEBUG_LOG_GC // #define DEBUG_GC @@ -45,7 +45,7 @@ // #define DEBUG_BYTECODE // define this to make fatal errors dump core (if ulimit allows) #define DEBUG_DUMP_CORE -#define DEBUG_TC +// #define DEBUG_TC // #define DEBUG_LAMBDA_CONVERT // #define DEBUG_LEAK // #define DEBUG_ANF diff --git a/src/hash.c b/src/hash.c index f6e0d23..316e156 100644 --- a/src/hash.c +++ b/src/hash.c @@ -275,7 +275,7 @@ void printHashTable(HashTable *table, int depth) { eprintf("HashTable: (NULL)"); return; } - eprintf("{[id:%d]", table->id); + eprintf("HashTable %d: {", table->id); bool first = true; for (int i = 0; i < table->capacity; ++i) { if (table->keys[i] != NULL) { @@ -291,7 
+291,6 @@ void printHashTable(HashTable *table, int depth) { eprintf(" "); else eprintf("\n"); -DEBUG("printHashTable, index %d, valuePtr %p", i, valuePtr(table, i)); table->printfunction(valuePtr(table, i), table->shortEntries ? 0 : (depth + 2)); eprintf("\n"); } else { diff --git a/src/lambda_pp.c b/src/lambda_pp.c index 500c388..7112ccf 100644 --- a/src/lambda_pp.c +++ b/src/lambda_pp.c @@ -245,6 +245,9 @@ void ppLamPrimOp(LamPrimOp type) { case LAMPRIMOP_TYPE_MOD: eprintf("mod"); break; + case LAMPRIMOP_TYPE_POW: + eprintf("pow"); + break; default: cant_happen("unrecognised type %d in ppLamPrimOp", type); } diff --git a/src/tc_analyze.c b/src/tc_analyze.c index a573dea..1e9f2df 100644 --- a/src/tc_analyze.c +++ b/src/tc_analyze.c @@ -45,7 +45,7 @@ static TcType *makeBoolean(void); static TcType *makeSmallInteger(void); static TcType *makeBigInteger(void); static TcType *makeCharacter(void); -static TcType *makeFreshVar(void); +static TcType *makeFreshVar(char *); static TcType *makeVar(HashSymbol *t); static TcType *makeFn(TcType *arg, TcType *result); static void addBoolBinOpToEnv(TcEnv *env, HashSymbol *symbol); @@ -219,7 +219,7 @@ static TcType *analyzeLam(LamLam *lam, TcEnv *env, TcNg *ng) { ng = extendNg(ng); PROTECT(ng); for (LamVarList *args = lam->args; args != NULL; args = args->next) { - TcType *freshVar = makeFreshVar(); + TcType *freshVar = makeFreshVar(args->var->name); int save2 = PROTECT(freshVar); addToEnv(env, args->var, freshVar); addToNg(ng, freshVar->val.var->name, freshVar); @@ -436,17 +436,22 @@ static TcType *findResultType(TcType *fn) { static TcType *analyzeDeconstruct(LamDeconstruct *deconstruct, TcEnv *env, TcNg *ng) { ENTER(analyzeDeconstruct); IFDEBUG(ppLamDeconstruct(deconstruct)); + /* TcType *constructor = NULL; if (!getFromEnv(env, deconstruct->name, &constructor)) { + */ + TcType *constructor = lookup(env, deconstruct->name, ng); + int save = PROTECT(constructor); + if (constructor == NULL) { can_happen("undefined type 
deconstructor %s", deconstruct->name->name); - TcType *res = makeFreshVar(); + TcType *res = makeFreshVar(deconstruct->name->name); LEAVE(analyzeDeconstruct); return res; } TcType *fieldType = findNthArg(deconstruct->vec - 1, constructor); TcType *resultType = findResultType(constructor); TcType *expType = analyzeExp(deconstruct->exp, env, ng); - int save = PROTECT(expType); + PROTECT(expType); unify(expType, resultType); UNPROTECT(save); LEAVE(analyzeDeconstruct); @@ -460,7 +465,7 @@ static TcType *analyzeConstant(LamConstant *constant, TcEnv *env, TcNg *ng) { TcType *constType = lookup(env, constant->name, ng); if (constType == NULL) { can_happen("undefined constant %s", constant->name->name); - TcType *res = makeFreshVar(); + TcType *res = makeFreshVar("err"); LEAVE(analyzeConstant); return res; } @@ -510,7 +515,7 @@ static TcType *analyzeApply(LamApply *apply, TcEnv *env, TcNg *ng) { IFDEBUG(ppTcType(fn)); TcType *arg = analyzeExp(apply->args->exp, env, ng); PROTECT(arg); - TcType *res = makeFreshVar(); + TcType *res = makeFreshVar("apply"); PROTECT(res); TcType *functionType = makeFn(arg, res); PROTECT(functionType); @@ -550,9 +555,9 @@ static TcType *analyzeIff(LamIff *iff, TcEnv *env, TcNg *ng) { static TcType *analyzeCallCC(LamExp *called, TcEnv *env, TcNg *ng) { // 'call/cc' is ((a -> b) -> a) -> a - TcType *a = makeFreshVar(); + TcType *a = makeFreshVar("callccA"); int save = PROTECT(a); - TcType *b = makeFreshVar(); + TcType *b = makeFreshVar("callccB"); PROTECT(b); TcType *ab = makeFn(a, b); PROTECT(ab); @@ -566,13 +571,14 @@ static TcType *analyzeCallCC(LamExp *called, TcEnv *env, TcNg *ng) { } static TcType *analyzeLetRec(LamLetRec *letRec, TcEnv *env, TcNg *ng) { + DEBUG("***************************************"); ENTER(analyzeLetRec); env = extendEnv(env); int save = PROTECT(env); ng = extendNg(ng); PROTECT(ng); for (LamLetRecBindings *bindings = letRec->bindings; bindings != NULL; bindings = bindings->next) { - TcType *freshVar = makeFreshVar(); + 
TcType *freshVar = makeFreshVar(bindings->var->name); int save2 = PROTECT(freshVar); addToEnv(env, bindings->var, freshVar); UNPROTECT(save2); @@ -720,6 +726,7 @@ static void collectTypeDef(LamTypeDef *lamTypeDef, TcEnv *env) { } static TcType *analyzeTypeDefs(LamTypeDefs *typeDefs, TcEnv *env, TcNg *ng) { + DEBUG("***************************************"); ENTER(analyzeTypeDefs); env = extendEnv(env); int save = PROTECT(env); @@ -747,82 +754,101 @@ static TcType *analyzeLet(LamLet *let, TcEnv *env, TcNg *ng) { return res; } -static TcType *unifyMatchCases(LamMatchList *cases, TcEnv *env, TcNg *ng) { - ENTER(unifyMatchCases); +static TcType *analyzeMatchCases(LamMatchList *cases, TcEnv *env, TcNg *ng) { + ENTER(analyzeMatchCases); if (cases == NULL) { - TcType *res = makeFreshVar(); - LEAVE(unifyMatchCases); + TcType *res = makeFreshVar("matchCases"); + LEAVE(analyzeMatchCases); return res; } - TcType *rest = unifyMatchCases(cases->next, env, ng); + TcType *rest = analyzeMatchCases(cases->next, env, ng); int save = PROTECT(rest); TcType *this = analyzeExp(cases->body, env, ng); PROTECT(this); unify(this, rest); UNPROTECT(save); - LEAVE(unifyMatchCases); + LEAVE(analyzeMatchCases); return this; } static TcType *analyzeBigIntegerExp(LamExp *exp, TcEnv *env, TcNg *ng) { + ENTER(analyzeBigIntegerExp); TcType *type = analyzeExp(exp, env, ng); int save = PROTECT(type); TcType *integer = makeBigInteger(); PROTECT(integer); unify(type, integer); UNPROTECT(save); + LEAVE(analyzeBigIntegerExp); return integer; } static TcType *analyzeSmallIntegerExp(LamExp *exp, TcEnv *env, TcNg *ng) { + ENTER(analyzeSmallIntegerExp); TcType *type = analyzeExp(exp, env, ng); int save = PROTECT(type); TcType *integer = makeSmallInteger(); PROTECT(integer); unify(type, integer); UNPROTECT(save); + LEAVE(analyzeSmallIntegerExp); return integer; } static TcType *analyzeBooleanExp(LamExp *exp, TcEnv *env, TcNg *ng) { + ENTER(analyzeBooleanExp); TcType *type = analyzeExp(exp, env, ng); int save = 
PROTECT(type); TcType *boolean = makeBoolean(); PROTECT(boolean); unify(type, boolean); UNPROTECT(save); + LEAVE(analyzeBooleanExp); return boolean; } static TcType *analyzeMatch(LamMatch *match, TcEnv *env, TcNg *ng) { + ENTER(analyzeMatch); (void) analyzeSmallIntegerExp(match->index, env, ng); - TcType *res = unifyMatchCases(match->cases, env, ng); + TcType *res = analyzeMatchCases(match->cases, env, ng); + LEAVE(analyzeMatch); return res; } -static TcType *unifyIntCondCases(LamIntCondCases *cases, TcEnv *env, TcNg *ng) { - if (cases == NULL) return makeFreshVar(); - TcType *rest = unifyIntCondCases(cases->next, env, ng); +static TcType *analyzeIntCondCases(LamIntCondCases *cases, TcEnv *env, TcNg *ng) { + ENTER(analyzeIntCondCases); + if (cases == NULL) { + LEAVE(analyzeIntCondCases); + return makeFreshVar("intCondCases"); + } + TcType *rest = analyzeIntCondCases(cases->next, env, ng); int save = PROTECT(rest); TcType *this = analyzeExp(cases->body, env, ng); PROTECT(this); unify(this, rest); UNPROTECT(save); + LEAVE(analyzeIntCondCases); return this; } -static TcType *unifyCharCondCases(LamCharCondCases *cases, TcEnv *env, TcNg *ng) { - if (cases == NULL) return makeFreshVar(); - TcType *rest = unifyCharCondCases(cases->next, env, ng); +static TcType *analyzeCharCondCases(LamCharCondCases *cases, TcEnv *env, TcNg *ng) { + ENTER(analyzeCharCondCases); + if (cases == NULL) { + LEAVE(analyzeCharCondCases); + return makeFreshVar("charCondCases"); + } + TcType *rest = analyzeCharCondCases(cases->next, env, ng); int save = PROTECT(rest); TcType *this = analyzeExp(cases->body, env, ng); PROTECT(this); unify(this, rest); UNPROTECT(save); + LEAVE(analyzeCharCondCases); return this; } static TcType *analyzeCond(LamCond *cond, TcEnv *env, TcNg *ng) { + ENTER(analyzeCond); TcType *result = NULL; int save = PROTECT(result); TcType *value = analyzeExp(cond->value, env, ng); @@ -832,40 +858,47 @@ static TcType *analyzeCond(LamCond *cond, TcEnv *env, TcNg *ng) { TcType 
*integer = makeBigInteger(); PROTECT(integer); unify(value, integer); - result = unifyIntCondCases(cond->cases->val.integers, env, ng); + result = analyzeIntCondCases(cond->cases->val.integers, env, ng); } break; case LAMCONDCASES_TYPE_CHARACTERS: { TcType *character = makeCharacter(); PROTECT(character); unify(value, character); - result = unifyCharCondCases(cond->cases->val.characters, env, ng); + result = analyzeCharCondCases(cond->cases->val.characters, env, ng); } break; default: cant_happen("unrecognized type %d in analyzeCond", cond->cases->type); } UNPROTECT(save); + LEAVE(analyzeCond); return result; } static TcType *analyzeAnd(LamAnd *and, TcEnv *env, TcNg *ng) { + ENTER(analyzeAnd); TcType *res = analyzeBinaryBool(and->left, and->right, env, ng); + LEAVE(analyzeAnd); return res; } static TcType *analyzeOr(LamOr *or, TcEnv *env, TcNg *ng) { + ENTER(analyzeOr); TcType *res = analyzeBinaryBool(or->left, or->right, env, ng); + LEAVE(analyzeOr); return res; } static TcType *analyzeAmb(LamAmb *amb, TcEnv *env, TcNg *ng) { + ENTER(analyzeAmb); TcType *left = analyzeExp(amb->left, env, ng); int save = PROTECT(left); TcType *right = analyzeExp(amb->right, env, ng); PROTECT(right); unify(left, right); UNPROTECT(save); + LEAVE(analyzeAmb); return left; } @@ -877,16 +910,19 @@ static TcType *analyzeCharacter() { } static TcType *analyzeBack() { - TcType *res = makeFreshVar(); + ENTER(analyzeBack); + TcType *res = makeFreshVar("back"); + LEAVE(analyzeBack); return res; } static TcType *analyzeError() { - TcType *res = makeFreshVar(); + ENTER(analyzeError); + TcType *res = makeFreshVar("error"); + LEAVE(analyzeError); return res; } - static void markType(void *ptr) { markTcType(*((TcType **) ptr)); } @@ -973,6 +1009,7 @@ static TcType *freshTypeDef(TcTypeDef *typeDef, TcNg *ng, HashTable *map) { static bool isGeneric(TcType *typeVar, TcNg *ng) { ENTER(isGeneric); IFDEBUG(ppTcType(typeVar)); + IFDEBUG(printTcNg(ng, 0)); while (ng != NULL) { int i = 0; TcType *entry = 
NULL; @@ -1013,7 +1050,7 @@ static TcType *freshRec(TcType *type, TcNg *ng, HashTable *map) { } case TCTYPE_TYPE_VAR: if (isGeneric(type, ng)) { - TcType *freshVar = makeFreshVar(); + TcType *freshVar = makeFreshVar(type->val.var->name->name); int save = PROTECT(freshVar); TcType *res = typeGetOrPut(map, type, freshVar); UNPROTECT(save); @@ -1082,24 +1119,20 @@ static TcType *makeFn(TcType *arg, TcType *result) { } static TcEnv *extendEnv(TcEnv *parent) { - ENTER(extendEnv); HashTable *table = newHashTable(sizeof(TcType *), markType, printType); int save = PROTECT(table); table->shortEntries = true; TcEnv *env = newTcEnv(table, parent); UNPROTECT(save); - LEAVE(extendEnv); return env; } static TcNg *extendNg(TcNg *parent) { - ENTER(extendNg); HashTable *table = newHashTable(sizeof(TcType *), markType, printType); int save = PROTECT(table); table->shortEntries = true; TcNg *ng = newTcNg(table, parent); UNPROTECT(save); - LEAVE(extendNg); return ng; } @@ -1112,8 +1145,10 @@ static TcType *makeVar(HashSymbol *t) { return res; } -static TcType *makeFreshVar() { - return makeVar(genSym("t$")); +static TcType *makeFreshVar(char *name) { + static char buff[256]; + snprintf(buff, 256, "%s/", name); + return makeVar(genSym(buff)); } static TcType *makeSmallInteger() { @@ -1159,7 +1194,7 @@ static void addIfToEnv(TcEnv *env) { // 'if' is bool -> a -> a -> a TcType *boolean = makeBoolean(); int save = PROTECT(boolean); - TcType *a = makeFreshVar(); + TcType *a = makeFreshVar("if"); (void) PROTECT(a); TcType *aa = makeFn(a, a); (void) PROTECT(aa); @@ -1173,9 +1208,9 @@ static void addIfToEnv(TcEnv *env) { static void addHereToEnv(TcEnv *env) { // 'call/cc' is ((a -> b) -> a) -> a - TcType *a = makeFreshVar(); + TcType *a = makeFreshVar("hereA"); int save = PROTECT(a); - TcType *b = makeFreshVar(); + TcType *b = makeFreshVar("hereB"); (void) PROTECT(b); TcType *ab = makeFn(a, b); (void) PROTECT(ab); @@ -1189,7 +1224,7 @@ static void addHereToEnv(TcEnv *env) { static void 
addCmpToEnv(TcEnv *env, HashSymbol *symbol) { // all binary comparisons are a -> a -> bool - TcType *freshVar = makeFreshVar(); + TcType *freshVar = makeFreshVar(symbol->name); int save = PROTECT(freshVar); TcType *boolean = makeBoolean(); (void) PROTECT(boolean); @@ -1203,7 +1238,7 @@ static void addCmpToEnv(TcEnv *env, HashSymbol *symbol) { static void addFreshVarToEnv(TcEnv *env, HashSymbol *symbol) { // 'error' and 'back' both have unconstrained types - TcType *freshVar = makeFreshVar(); + TcType *freshVar = makeFreshVar(symbol->name); int save = PROTECT(freshVar); addToEnv(env, symbol, freshVar); UNPROTECT(save); @@ -1237,7 +1272,7 @@ static void addBoolBinOpToEnv(TcEnv *env, HashSymbol *symbol) { static void addThenToEnv(TcEnv *env) { // a -> a -> a - TcType *freshVar = makeFreshVar(); + TcType *freshVar = makeFreshVar("then"); int save = PROTECT(freshVar); addBinOpToEnv(env, thenSymbol(), freshVar); UNPROTECT(save); @@ -1287,6 +1322,7 @@ static bool unify(TcType *a, TcType *b) { b = prune(b); DEBUG("UNIFY"); IFDEBUG(ppTcType(a); eprintf(" WITH "); ppTcType(b)); + if (a == b) return true; if (a->type == TCTYPE_TYPE_VAR) { if (b->type != TCTYPE_TYPE_VAR) { if (occursInType(a, b)) { diff --git a/tests/src/test_typechecker.c b/tests/src/test_typechecker.c index c8986a2..0cf19bc 100644 --- a/tests/src/test_typechecker.c +++ b/tests/src/test_typechecker.c @@ -51,7 +51,7 @@ static AstNest *parseSolo(char *string) { return mod->nest; } -static TcType *charToVar(char *name) { +static TcType *makeVar(char *name) { HashSymbol *t = newSymbol(name); TcVar *v = newTcVar(t, 0); int save = PROTECT(v); @@ -121,7 +121,7 @@ static void test_cdr() { int save = PROTECT(result); TcType *res = analyze(result); PROTECT(res); - TcType *var = charToVar("#t"); + TcType *var = makeVar("#t"); PROTECT(var); TcType *td = listOf(var); PROTECT(td); @@ -137,7 +137,7 @@ static void test_car() { int save = PROTECT(result); TcType *res = analyze(result); PROTECT(res); - TcType *var = 
charToVar("#t"); + TcType *var = makeVar("#t"); PROTECT(var); TcType *td = listOf(var); PROTECT(td); @@ -295,7 +295,7 @@ static void test_either_1() { PROTECT(res); TcType *big = makeBigInteger(); PROTECT(big); - TcType *var = charToVar("#t"); + TcType *var = makeVar("#t"); PROTECT(var); TcTypeDefArgs *args = newTcTypeDefArgs(var, NULL); PROTECT(args); @@ -348,28 +348,30 @@ static void test_lol() { } static void test_map() { - printf("test_lol\n"); + printf("test_map\n"); AstNest *result = parseWrapped( "let" -" typedef colours { red | green | blue }" " fn map {" -" (f, nil) { [] }" +" (_, []) { [] }" " (f, h @ t) { f(h) @ map(f, t) }" " }" -" fn toInt {" -" (red) { 0 }" -" (green) { 1 }" -" (blue) { 2 }" -" }" "in" -" map(toInt, [red, green, blue])" +" map" ); int save = PROTECT(result); TcType *res = analyze(result); PROTECT(res); - TcType *big = makeBigInteger(); - PROTECT(big); - TcType *expected = listOf(big); + TcType *t1 = makeVar("#t1"); + PROTECT(t1); + TcType *t2 = makeVar("#t2"); + PROTECT(t2); + TcType *f = makeFunction2(t1, t2); + PROTECT(f); + TcType *listT1 = listOf(t1); + PROTECT(listT1); + TcType *listT2 = listOf(t2); + PROTECT(listT2); + TcType *expected = makeFunction3(f, listT1, listT2); PROTECT(expected); assert(compareTcTypes(res, expected)); UNPROTECT(save); @@ -378,7 +380,6 @@ static void test_map() { int main(int argc __attribute__((unused)), char *argv[] __attribute__((unused))) { initProtection(); - /* test_car(); test_cdr(); test_car_of(); @@ -393,7 +394,6 @@ int main(int argc __attribute__((unused)), char *argv[] __attribute__((unused))) test_id(); test_if(); test_lol(); - */ test_map(); }