diff --git a/src/compiler/CodeGen.cpp b/src/compiler/CodeGen.cpp index 1523179..e8e0d4c 100644 --- a/src/compiler/CodeGen.cpp +++ b/src/compiler/CodeGen.cpp @@ -1,6 +1,7 @@ #include #include #include "binaryen-c.h" +#include "compiler/Compiler.hpp" #include "lexer/Lexemes.hpp" #include "StandardLibrary.hpp" #include "parser/ast/ASTNodeList.hpp" @@ -26,6 +27,8 @@ namespace Theta { } BinaryenExpressionRef CodeGen::generate(shared_ptr node, BinaryenModuleRef &module) { + if (node->hasOwnScope()) scope.enterScope(); + if (node->getNodeType() == ASTNode::SOURCE) { generateSource(dynamic_pointer_cast(node), module); } else if (node->getNodeType() == ASTNode::CAPSULE) { @@ -34,6 +37,10 @@ namespace Theta { return generateBlock(dynamic_pointer_cast(node), module); } else if (node->getNodeType() == ASTNode::RETURN) { return generateReturn(dynamic_pointer_cast(node), module); + } else if (node->getNodeType() == ASTNode::FUNCTION_INVOCATION) { + return generateFunctionInvocation(dynamic_pointer_cast(node), module); + } else if (node->getNodeType() == ASTNode::IDENTIFIER) { + return generateIdentifier(dynamic_pointer_cast(node), module); } else if (node->getNodeType() == ASTNode::BINARY_OPERATION) { return generateBinaryOperation(dynamic_pointer_cast(node), module); } else if (node->getNodeType() == ASTNode::UNARY_OPERATION) { @@ -46,41 +53,82 @@ namespace Theta { return generateBooleanLiteral(dynamic_pointer_cast(node), module); } + if (node->hasOwnScope()) scope.exitScope(); + return nullptr; } BinaryenExpressionRef CodeGen::generateCapsule(shared_ptr capsuleNode, BinaryenModuleRef &module) { vector> capsuleElements = dynamic_pointer_cast(capsuleNode->getValue())->getElements(); + hoistCapsuleElements(capsuleElements); + for (auto elem : capsuleElements) { string elemType = dynamic_pointer_cast(elem->getResolvedType())->getType(); if (elem->getNodeType() == ASTNode::ASSIGNMENT) { shared_ptr identNode = dynamic_pointer_cast(elem->getLeft()); if (elemType == DataTypes::FUNCTION) { - shared_ptr fnDeclNode = dynamic_pointer_cast(elem->getRight()); + generateFunctionDeclaration( + identNode->getIdentifier(), + dynamic_pointer_cast(elem->getRight()), + module, + true + ); + } + } + } + } + BinaryenExpressionRef CodeGen::generateFunctionDeclaration( + string identifier, + shared_ptr fnDeclNode, + BinaryenModuleRef &module, + bool addToExports + ) { + scope.enterScope(); + + BinaryenType parameterType = BinaryenTypeNone(); + int totalParams = fnDeclNode->getParameters()->getElements().size(); - string functionName = capsuleNode->getName() + "." + identNode->getIdentifier(); - - cout << "it is" << functionName.c_str() << " " << functionName.c_str() << endl; + if (totalParams > 0) { + BinaryenType* types = new BinaryenType[totalParams]; - BinaryenExpressionRef body = generate(fnDeclNode->getDefinition(), module); + for (int i = 0; i < totalParams; i++) { + shared_ptr identNode = dynamic_pointer_cast(fnDeclNode->getParameters()->getElements().at(i)); - BinaryenFunctionRef fn = BinaryenAddFunction( - module, - functionName.c_str(), - BinaryenTypeNone(), - getBinaryenTypeFromTypeDeclaration(dynamic_pointer_cast(fnDeclNode->getResolvedType()->getValue())), - NULL, - 0, - body - ); + identNode->setMappedBinaryenIndex(i); - BinaryenAddFunctionExport(module, functionName.c_str(), functionName.c_str()); - } + scope.insert(identNode->getIdentifier(), identNode); + types[i] = getBinaryenTypeFromTypeDeclaration( + + dynamic_pointer_cast(fnDeclNode->getParameters()->getElements().at(i)->getValue()) + ); } + + parameterType = BinaryenTypeCreate(types, totalParams); } + + string functionName = Compiler::getQualifiedFunctionIdentifier( + identifier, + dynamic_pointer_cast(fnDeclNode) + ); + + BinaryenFunctionRef fn = BinaryenAddFunction( + module, + functionName.c_str(), + parameterType, + getBinaryenTypeFromTypeDeclaration(dynamic_pointer_cast(fnDeclNode->getResolvedType()->getValue())), + NULL, + 0, + generate(fnDeclNode->getDefinition(), module) + ); + + if (addToExports) { + BinaryenAddFunctionExport(module, functionName.c_str(), functionName.c_str()); + } + + scope.exitScope(); } BinaryenExpressionRef CodeGen::generateBlock(shared_ptr blockNode, BinaryenModuleRef &module) { @@ -103,6 +151,37 @@ namespace Theta { return BinaryenReturn(module, generate(returnNode->getValue(), module)); } + BinaryenExpressionRef CodeGen::generateFunctionInvocation(shared_ptr funcInvNode, BinaryenModuleRef &module) { + BinaryenExpressionRef* arguments = new BinaryenExpressionRef[funcInvNode->getParameters()->getElements().size()]; + + string funcName = Compiler::getQualifiedFunctionIdentifier( + dynamic_pointer_cast(funcInvNode->getIdentifier())->getIdentifier(), + funcInvNode + ); + + for (int i = 0; i < funcInvNode->getParameters()->getElements().size(); i++) { + arguments[i] = generate(funcInvNode->getParameters()->getElements().at(i), module); + } + + return BinaryenCall( + module, + funcName.c_str(), + arguments, + funcInvNode->getParameters()->getElements().size(), + getBinaryenTypeFromTypeDeclaration(dynamic_pointer_cast(funcInvNode->getResolvedType())) + ); + } + + BinaryenExpressionRef CodeGen::generateIdentifier(shared_ptr identNode, BinaryenModuleRef &module) { + shared_ptr identInScope = scope.lookup(identNode->getIdentifier()); + + return BinaryenLocalGet( + module, + identInScope->getMappedBinaryenIndex(), + getBinaryenTypeFromTypeDeclaration(dynamic_pointer_cast(identNode->getResolvedType())) + ); + } + BinaryenExpressionRef CodeGen::generateBinaryOperation(shared_ptr binOpNode, BinaryenModuleRef &module) { if (binOpNode->getOperator() == Lexemes::EXPONENT) { return generateExponentOperation(binOpNode, module); @@ -113,12 +192,13 @@ namespace Theta { BinaryenExpressionRef binaryenLeft = generate(binOpNode->getLeft(), module); BinaryenExpressionRef binaryenRight = generate(binOpNode->getRight(), module); + cout << binOpNode->toJSON() << endl; + if (!binaryenLeft || !binaryenRight) { throw runtime_error("Invalid operand types for binary operation"); } - // TODO: This wont work if we have nested operations on either side - if (binOpNode->getLeft()->getNodeType() == ASTNode::STRING_LITERAL) { + if (dynamic_pointer_cast(binOpNode->getResolvedType())->getType() == DataTypes::STRING) { return BinaryenStringConcat( module, binaryenLeft, @@ -243,4 +323,20 @@ namespace Theta { if (typeDeclaration->getType() == DataTypes::STRING) return BinaryenTypeStringref(); if (typeDeclaration->getType() == DataTypes::BOOLEAN) return BinaryenTypeInt32(); } + + void CodeGen::hoistCapsuleElements(vector> elements) { + scope.enterScope(); + + for (auto ast : elements) bindIdentifierToScope(ast); + } + + void CodeGen::bindIdentifierToScope(shared_ptr ast) { + string identifier = dynamic_pointer_cast(ast->getLeft())->getIdentifier(); + + if (ast->getRight()->getNodeType() == ASTNode::FUNCTION_DECLARATION) { + identifier = Compiler::getQualifiedFunctionIdentifier(identifier, ast->getRight()); + } + + scope.insert(identifier, ast->getRight()); + } } diff --git a/src/compiler/CodeGen.hpp b/src/compiler/CodeGen.hpp index 4ec7a81..f201cfd 100644 --- a/src/compiler/CodeGen.hpp +++ b/src/compiler/CodeGen.hpp @@ -6,10 +6,14 @@ #include "../parser/ast/UnaryOperationNode.hpp" #include "../parser/ast/LiteralNode.hpp" #include "../parser/ast/SourceNode.hpp" +#include "compiler/SymbolTableStack.hpp" #include "parser/ast/ASTNodeList.hpp" #include "parser/ast/CapsuleNode.hpp" +#include "parser/ast/FunctionDeclarationNode.hpp" +#include "parser/ast/IdentifierNode.hpp" #include "parser/ast/ReturnNode.hpp" #include "parser/ast/TypeDeclarationNode.hpp" +#include "parser/ast/FunctionInvocationNode.hpp" #include using namespace std; @@ -19,21 +23,29 @@ namespace Theta { public: // using GenerateResult = std::variant; - static BinaryenModuleRef generateWasmFromAST(shared_ptr ast); - static BinaryenExpressionRef generate(shared_ptr node, BinaryenModuleRef &module); - static BinaryenExpressionRef generateCapsule(shared_ptr node, BinaryenModuleRef &module); - static BinaryenExpressionRef generateBlock(shared_ptr node, BinaryenModuleRef &module); - static BinaryenExpressionRef generateReturn(shared_ptr node, BinaryenModuleRef &module); - static BinaryenExpressionRef generateBinaryOperation(shared_ptr node, BinaryenModuleRef &module); - static BinaryenExpressionRef generateUnaryOperation(shared_ptr node, BinaryenModuleRef &module); - static BinaryenExpressionRef generateNumberLiteral(shared_ptr node, BinaryenModuleRef &module); - static BinaryenExpressionRef generateStringLiteral(shared_ptr node, BinaryenModuleRef &module); - static BinaryenExpressionRef generateBooleanLiteral(shared_ptr node, BinaryenModuleRef &module); - static BinaryenExpressionRef generateExponentOperation(shared_ptr node, BinaryenModuleRef &module); - static void generateSource(shared_ptr node, BinaryenModuleRef &module); + BinaryenModuleRef generateWasmFromAST(shared_ptr ast); + BinaryenExpressionRef generate(shared_ptr node, BinaryenModuleRef &module); + BinaryenExpressionRef generateCapsule(shared_ptr node, BinaryenModuleRef &module); + BinaryenExpressionRef generateBlock(shared_ptr node, BinaryenModuleRef &module); + BinaryenExpressionRef generateReturn(shared_ptr node, BinaryenModuleRef &module); + BinaryenExpressionRef generateFunctionDeclaration(string identifier, shared_ptr node, BinaryenModuleRef &module, bool addToExports = false); + BinaryenExpressionRef generateFunctionInvocation(shared_ptr node, BinaryenModuleRef &module); + BinaryenExpressionRef generateIdentifier(shared_ptr node, BinaryenModuleRef &module); + BinaryenExpressionRef generateBinaryOperation(shared_ptr node, BinaryenModuleRef &module); + BinaryenExpressionRef generateUnaryOperation(shared_ptr node, BinaryenModuleRef &module); + BinaryenExpressionRef generateNumberLiteral(shared_ptr node, BinaryenModuleRef &module); + BinaryenExpressionRef generateStringLiteral(shared_ptr node, BinaryenModuleRef &module); + BinaryenExpressionRef generateBooleanLiteral(shared_ptr node, BinaryenModuleRef &module); + BinaryenExpressionRef generateExponentOperation(shared_ptr node, BinaryenModuleRef &module); + void generateSource(shared_ptr node, BinaryenModuleRef &module); private: + SymbolTableStack scope; + static BinaryenOp getBinaryenOpFromBinOpNode(shared_ptr node); static BinaryenType getBinaryenTypeFromTypeDeclaration(shared_ptr node); + + void hoistCapsuleElements(vector> elements); + void bindIdentifierToScope(shared_ptr ast); }; } diff --git a/src/compiler/Compiler.cpp b/src/compiler/Compiler.cpp index ee42701..25198a4 100644 --- a/src/compiler/Compiler.cpp +++ b/src/compiler/Compiler.cpp @@ -31,7 +31,8 @@ namespace Theta { if (!isTypeValid) return; - BinaryenModuleRef module = CodeGen::generateWasmFromAST(programAST); + CodeGen codeGen; + BinaryenModuleRef module = codeGen.generateWasmFromAST(programAST); if (isEmitWAT) { cout << "Generated WAT for \"" + entrypoint + "\":" << endl; @@ -57,7 +58,8 @@ namespace Theta { if (!isTypeValid) return ast; - BinaryenModuleRef module = CodeGen::generateWasmFromAST(ast); + CodeGen codeGen; + BinaryenModuleRef module = codeGen.generateWasmFromAST(ast); cout << "-> " + ast->toJSON() << endl; cout << "-> "; @@ -214,4 +216,30 @@ namespace Theta { cout << "Could not parse AST for file " + fileName << endl; } } + + string Compiler::getQualifiedFunctionIdentifier(string variableName, shared_ptr node) { + vector> params; + + if (node->getNodeType() == ASTNode::FUNCTION_DECLARATION) { + shared_ptr declarationNode = dynamic_pointer_cast(node); + params = declarationNode->getParameters()->getElements(); + } else { + shared_ptr invocationNode = dynamic_pointer_cast(node); + params = invocationNode->getParameters()->getElements(); + } + + string functionIdentifier = variableName + to_string(params.size()); + + for (int i = 0; i < params.size(); i++) { + if (node->getNodeType() == ASTNode::FUNCTION_DECLARATION) { + shared_ptr paramType = dynamic_pointer_cast(params.at(i)->getValue()); + functionIdentifier += paramType->getType(); + } else { + shared_ptr paramType = dynamic_pointer_cast(params.at(i)->getResolvedType()); + functionIdentifier += paramType->getType(); + } + } + + return functionIdentifier; + } } diff --git a/src/compiler/Compiler.hpp b/src/compiler/Compiler.hpp index dca9e27..926bd3b 100644 --- a/src/compiler/Compiler.hpp +++ b/src/compiler/Compiler.hpp @@ -101,6 +101,16 @@ namespace Theta { */ bool optimizeAST(shared_ptr &ast, bool silenceErrors = false); + + /** + * @brief Generates a unique function identifier based on the function's name and its parameters to handle overloading. + * + * @param variableName The base name of the function. + * @param declarationNode The function declaration node containing the parameters. + * @return string The unique identifier for the function. + */ + static string getQualifiedFunctionIdentifier(string variableName, shared_ptr node); + shared_ptr> filesByCapsuleName; private: /** diff --git a/src/compiler/optimization/LiteralInlinerPass.cpp b/src/compiler/optimization/LiteralInlinerPass.cpp index 9034ca7..c275103 100644 --- a/src/compiler/optimization/LiteralInlinerPass.cpp +++ b/src/compiler/optimization/LiteralInlinerPass.cpp @@ -67,7 +67,7 @@ void LiteralInlinerPass::bindIdentifierToScope(shared_ptr &ast, SymbolT string identifier = dynamic_pointer_cast(ast->getLeft())->getIdentifier(); if (ast->getRight()->getNodeType() == ASTNode::FUNCTION_DECLARATION) { - string uniqueFuncIdentifier = getDeterministicFunctionIdentifier(identifier, ast->getRight()); + string uniqueFuncIdentifier = Compiler::getQualifiedFunctionIdentifier(identifier, ast->getRight()); shared_ptr existingFuncIdentifierInScope = scope.lookup(uniqueFuncIdentifier); diff --git a/src/compiler/optimization/OptimizationPass.cpp b/src/compiler/optimization/OptimizationPass.cpp index 49f0ccb..2fd9f87 100644 --- a/src/compiler/optimization/OptimizationPass.cpp +++ b/src/compiler/optimization/OptimizationPass.cpp @@ -88,30 +88,3 @@ shared_ptr OptimizationPass::lookupInScope(string identifierName) { return foindHoisted; } - - -string OptimizationPass::getDeterministicFunctionIdentifier(string variableName, shared_ptr node) { - vector> params; - - if (node->getNodeType() == ASTNode::FUNCTION_DECLARATION) { - shared_ptr declarationNode = dynamic_pointer_cast(node); - params = declarationNode->getParameters()->getElements(); - } else { - shared_ptr invocationNode = dynamic_pointer_cast(node); - params = invocationNode->getParameters()->getElements(); - } - - string functionIdentifier = variableName + to_string(params.size()); - - for (int i = 0; i < params.size(); i++) { - if (node->getNodeType() == ASTNode::FUNCTION_DECLARATION) { - shared_ptr paramType = dynamic_pointer_cast(params.at(i)->getValue()); - functionIdentifier += paramType->getType(); - } else { - shared_ptr paramType = dynamic_pointer_cast(params.at(i)->getResolvedType()); - functionIdentifier += paramType->getType(); - } - } - - return functionIdentifier; -} diff --git a/src/compiler/optimization/OptimizationPass.hpp b/src/compiler/optimization/OptimizationPass.hpp index 787d0d3..129c2a8 100644 --- a/src/compiler/optimization/OptimizationPass.hpp +++ b/src/compiler/optimization/OptimizationPass.hpp @@ -46,15 +46,6 @@ namespace Theta { */ shared_ptr lookupInScope(string identifier); - /** - * @brief Generates a unique function identifier based on the function's name and its parameters to handle overloading. - * - * @param variableName The base name of the function. - * @param declarationNode The function declaration node containing the parameters. - * @return string The unique identifier for the function. - */ - string getDeterministicFunctionIdentifier(string variableName, shared_ptr node); - private: /** * @brief Pure virtual function to be implemented by derived classes for performing specific optimizations on the AST. diff --git a/src/parser/ast/ASTNode.hpp b/src/parser/ast/ASTNode.hpp index 89ac196..c1df4b3 100644 --- a/src/parser/ast/ASTNode.hpp +++ b/src/parser/ast/ASTNode.hpp @@ -53,31 +53,31 @@ namespace Theta { shared_ptr left; shared_ptr right; shared_ptr resolvedType; - + int mappedBinaryenIndex; + ASTNode(ASTNode::Types type) : nodeType(type), value(nullptr) {}; virtual void setValue(shared_ptr childNode) { value = childNode; } - virtual shared_ptr& getValue() { return value; } virtual void setLeft(shared_ptr childNode) { left = childNode; } - virtual shared_ptr& getLeft() { return left; } virtual void setRight(shared_ptr childNode) { right = childNode; } - virtual shared_ptr& getRight() { return right; } + virtual int getMappedBinaryenIndex() { return mappedBinaryenIndex; } + virtual void setMappedBinaryenIndex(int idx) { mappedBinaryenIndex = idx; } + + void setResolvedType(shared_ptr typeNode) { resolvedType = typeNode; } + shared_ptr getResolvedType() { return resolvedType; } + virtual bool hasMany() { return false; } virtual bool hasOwnScope() { return false; } virtual ~ASTNode() = default; - void setResolvedType(shared_ptr typeNode) { resolvedType = typeNode; } - - shared_ptr getResolvedType() { return resolvedType; } - static string nodeTypeToString(ASTNode::Types nodeType) { static map typesMap = { { ASTNode::ASSIGNMENT, "Assignment" }, @@ -116,4 +116,4 @@ namespace Theta { }; } -#endif \ No newline at end of file +#endif