From 64a48d117e1c4e9a143e38cfbecadd5184016832 Mon Sep 17 00:00:00 2001 From: chh Date: Mon, 8 Jul 2024 21:41:49 +0800 Subject: [PATCH 01/17] [FrontendGen] Add Frontend Generator. --- CMakeLists.txt | 2 +- examples/FrontendGen/.gitignore | 6 +- examples/FrontendGen/example.fegen | 247 ++- examples/FrontendGen/function.fegen | 38 + examples/FrontendGen/makefile | 12 +- examples/FrontendGen/opDefine.fegen | 9 + examples/FrontendGen/rule.fegen | 14 + examples/FrontendGen/typeDefine.fegen | 17 + frontend/CMakeLists.txt | 4 +- frontend/FrontendGen/.gitignore | 1 + frontend/FrontendGen/CMakeLists.txt | 21 +- frontend/FrontendGen/README.md | 14 + frontend/FrontendGen/frontendgen.cpp | 121 +- frontend/FrontendGen/include/AST.h | 211 --- frontend/FrontendGen/include/CGModule.h | 73 - frontend/FrontendGen/include/Diagnostics.def | 10 - frontend/FrontendGen/include/Diagnostics.h | 52 - frontend/FrontendGen/include/FegenManager.h | 482 ++++++ frontend/FrontendGen/include/FegenVisitor.h | 775 +++++++++ frontend/FrontendGen/include/Lexer.h | 59 - frontend/FrontendGen/include/Parser.h | 60 - frontend/FrontendGen/include/Scope.h | 81 + frontend/FrontendGen/include/Sema.h | 35 - frontend/FrontendGen/include/Terminator.def | 22 - frontend/FrontendGen/include/Terminator.h | 75 - frontend/FrontendGen/include/Token.def | 37 - frontend/FrontendGen/include/Token.h | 56 - frontend/FrontendGen/include/TypeMap.def | 32 - frontend/FrontendGen/lib/CGModule.cpp | 422 ----- frontend/FrontendGen/lib/CMakeLists.txt | 62 +- frontend/FrontendGen/lib/Diagnostics.cpp | 43 - frontend/FrontendGen/lib/FegenLexer.g4 | 221 +++ frontend/FrontendGen/lib/FegenManager.cpp | 1480 ++++++++++++++++++ frontend/FrontendGen/lib/FegenParser.g4 | 443 ++++++ frontend/FrontendGen/lib/FegenVisitor.cpp | 11 + frontend/FrontendGen/lib/Lexer.cpp | 198 --- frontend/FrontendGen/lib/Parser.cpp | 403 ----- frontend/FrontendGen/lib/Scope.cpp | 117 ++ frontend/FrontendGen/lib/Sema.cpp | 54 - 39 files changed, 3919 insertions(+), 2101 deletions(-) create mode 100644 examples/FrontendGen/function.fegen create mode 100644 examples/FrontendGen/opDefine.fegen create mode 100644 examples/FrontendGen/rule.fegen create mode 100644 examples/FrontendGen/typeDefine.fegen create mode 100644 frontend/FrontendGen/.gitignore create mode 100644 frontend/FrontendGen/README.md delete mode 100644 frontend/FrontendGen/include/AST.h delete mode 100644 frontend/FrontendGen/include/CGModule.h delete mode 100644 frontend/FrontendGen/include/Diagnostics.def delete mode 100644 frontend/FrontendGen/include/Diagnostics.h create mode 100644 frontend/FrontendGen/include/FegenManager.h create mode 100644 frontend/FrontendGen/include/FegenVisitor.h delete mode 100644 frontend/FrontendGen/include/Lexer.h delete mode 100644 frontend/FrontendGen/include/Parser.h create mode 100644 frontend/FrontendGen/include/Scope.h delete mode 100644 frontend/FrontendGen/include/Sema.h delete mode 100644 frontend/FrontendGen/include/Terminator.def delete mode 100644 frontend/FrontendGen/include/Terminator.h delete mode 100644 frontend/FrontendGen/include/Token.def delete mode 100644 frontend/FrontendGen/include/Token.h delete mode 100644 frontend/FrontendGen/include/TypeMap.def delete mode 100644 frontend/FrontendGen/lib/CGModule.cpp delete mode 100644 frontend/FrontendGen/lib/Diagnostics.cpp create mode 100644 frontend/FrontendGen/lib/FegenLexer.g4 create mode 100644 frontend/FrontendGen/lib/FegenManager.cpp create mode 100644 frontend/FrontendGen/lib/FegenParser.g4 create mode 100644 frontend/FrontendGen/lib/FegenVisitor.cpp delete mode 100644 frontend/FrontendGen/lib/Lexer.cpp delete mode 100644 frontend/FrontendGen/lib/Parser.cpp create mode 100644 frontend/FrontendGen/lib/Scope.cpp delete mode 100644 frontend/FrontendGen/lib/Sema.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 83b7981421..2c38064a54 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -160,7 +160,7 @@ check_toolchain() # NB: currently, ANTLR is used in dsl examples only, # however, there is a plan to use in the frontend, # so it is kept in the top-level cmake -if(BUDDY_DSL_EXAMPLES) +if(BUDDY_DSL_EXAMPLES OR FeGen) list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/Antlr) # required if linking to static library diff --git a/examples/FrontendGen/.gitignore b/examples/FrontendGen/.gitignore index 34ec8116e7..72daf17d09 100644 --- a/examples/FrontendGen/.gitignore +++ b/examples/FrontendGen/.gitignore @@ -1,3 +1,3 @@ -Toy.g4 -MLIRToyVisitor.h - +test/ +*.g4 +*.td \ No newline at end of file diff --git a/examples/FrontendGen/example.fegen b/examples/FrontendGen/example.fegen index cc453795c1..55caa409d0 100644 --- a/examples/FrontendGen/example.fegen +++ b/examples/FrontendGen/example.fegen @@ -1,141 +1,106 @@ -dialect Toy_Dialect - : name = "toy" - : cppNamespace = "mlir::toy" - ; - -op ConstantOp - : arguments = (ins F64ElementsAttr : $value) - : results = (outs F64Tensor) - : builders = [ - OpBuilder<(ins "DenseElementsAttr" : $value), - [{ build($_builder, $_state, value.getType(), value); }]>, - OpBuilder<(ins "double":$value)>] - ; - -op AddOp - : arguments = (ins F64Tensor : $lhs, F64Tensor: $rhs) - : results = (outs F64Tensor) - : builders = [OpBuilder<(ins "Value" : $lhs, "Value" : $rhs)>] - ; - -op CastOp - : arguments = (ins F64Tensor:$input) - : results = (outs F64Tensor:$output) - ; - -op FuncOp - : arguments = (ins - SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type - ) - : builders = [ OpBuilder<(ins - "StringRef":$name, "FunctionType":$type, - CArg<"ArrayRef", "{}">:$attrs)> - ] - ; - -op MulOp - : arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs) - : results = (outs F64Tensor) - : builders = [ - OpBuilder<(ins "Value":$lhs, "Value":$rhs)> - ] - ; - -op PrintOp - : arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input) - ; - -op ReshapeOp - : arguments = (ins F64Tensor : $input) - : results = (outs StaticShapeTensorOf<[F64]>) - ; - -op ReturnOp - : arguments = (ins Variadic:$input) - : builders = [ - OpBuilder<(ins), [{ build($_builder, $_state, std::nullopt); }]> - ] - ; - -op GenericCallOp - : arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs) - : results = (outs F64Tensor) - : builders = [ - OpBuilder<(ins "StringRef":$callee, "ArrayRef":$arguments)> - ] - ; - -op TransposeOp - : arguments = (ins F64Tensor:$input) - : results = (outs F64Tensor) - : builders = [ - OpBuilder<(ins "Value":$input)> - ] - ; - - -rule module - : funDefine - ; - -rule expression - : Number - : tensorLiteral - : identifierExpr - : expression Add expression - ; - -rule returnExpr - : Return expression? - ; - -rule identifierExpr - : Identifier - : Identifier ParentheseOpen (expression (Comma expression) *)? ParentheseClose { - builder = GenericCallOp_1, PrintOp_0 - } - ; - -rule tensorLiteral - : SbracketOpen ( tensorLiteral ( Comma tensorLiteral ) *) ? SbracketClose - : Number - ; - -rule varDecl - : Var Identifier (type) ? (Equal expression) ? { - builder = ReshapeOp_0 - } - ; - -rule type - : AngleBracketOpen Number(Comma Number) * AngleBracketClose - ; - -rule funDefine - : prototype block { - builder = ReturnOp_1 - } - ; - -rule prototype - : Def Identifier ParentheseOpen declList ? ParentheseClose { - builder = FuncOp_0 - } - ; - -rule declList - : Identifier - : Identifier Comma declList - ; - -rule block - : BracketOpen(blockExpr Semi) * BracketClose - ; - -rule blockExpr - : varDecl - : returnExpr - : expression - ; - +fegen toy + +typedef struct { + parameters [list elementTypes] // ArrayParameter<'Type'> +} + +Type Toy_Type = any<[Tensor, struct]>; + +opdef constant { + arguments [operand list> numberAttr] // Variadic + results [operand Tensor res] + body { + list shape = shapeOf(res); + // full是一个内置函数,创建memref,并将每个元素都填充numberAttr + res = full(shape, numberAttr); + } +} + +opdef add { + arguments [operand Tensor lhs, operand Tensor rhs] + results [operand Tensor res] + body { + // 这个'+'也是一个内置的函数 + res = lhs + rhs; // res = builder.create(lsh, rhs); + } +} + +opdef mul { + arguments [operand Tensor lhs, operand Tensor rhs] + results [operand Tensor res] + body { + // 这个'*'也是一个内置的函数 + res = lhs * rhs; + } +} + +opdef reshape { + arguments [operand F64Tensor input] + results [operand F64Tensor output] + body { + list shape = shapeOf(output); + output = reshape(input, shape); + } +} + +double stod(string numStr){ + double res = 0; + int index; + int i; + for(i = 0; i <= len(numStr)-1; i=i+1){ + char c = numStr[0]; + int charNum; + if(c == '0'){ + charNum = 0; + }else if (c == '1'){ + charNum = 1; + }else if (c == '2'){ + charNum = 2; + }else if (c == '3'){ + charNum = 3; + }else if (c == '4'){ + charNum = 4; + }else if (c == '5'){ + charNum = 5; + }else if (c == '6'){ + charNum = 6; + }else if (c == '7'){ + charNum = 7; + }else if (c == '8'){ + charNum = 8; + }else if (c == '9'){ + charNum = 9; + }else if (c == '.'){ + index = i; + } + res = res * 10; + res = res + charNum; + } + res = res * 0.1**(len(numStr) - 1 - index); + return res; +} + + +module + : structDefine* funDefine+ + ; + +structDefine + : Struct Identifier BracketOpen (varDecl Semicolon)* BracketClose + ; + +// cpp value --get--> mlir::attribute || --constant Operation--> mlir::Value +// ======== || + +expression + : Number + { + returns [operand F64Tensor ret, operand F64Tensor ret] + actions { + // Type mlir::Value ret of operator | Attribute | Cpp Value + double numberAttr = stod($Number().getText()); + Type retType = Tensor<[], double>; + ret = constant(numberAttr, retType); + } + } + ; \ No newline at end of file diff --git a/examples/FrontendGen/function.fegen b/examples/FrontendGen/function.fegen new file mode 100644 index 0000000000..e2fc6e7bfa --- /dev/null +++ b/examples/FrontendGen/function.fegen @@ -0,0 +1,38 @@ +fegen toy + +double stod(string numStr){ + double res = 0; + int index; + int i; + for(i = 0; i <= len(numStr)-1; i=i+1){ + char c = numStr[0]; + int charNum; + if(c == '0'){ + charNum = 0; + }else if (c == '1'){ + charNum = 1; + }else if (c == '2'){ + charNum = 2; + }else if (c == '3'){ + charNum = 3; + }else if (c == '4'){ + charNum = 4; + }else if (c == '5'){ + charNum = 5; + }else if (c == '6'){ + charNum = 6; + }else if (c == '7'){ + charNum = 7; + }else if (c == '8'){ + charNum = 8; + }else if (c == '9'){ + charNum = 9; + }else if (c == '.'){ + index = i; + } + res = res * 10; + res = res + charNum; + } + res = res * 0.1**(len(numStr) - 1 - index); + return res; +} \ No newline at end of file diff --git a/examples/FrontendGen/makefile b/examples/FrontendGen/makefile index 26b6049727..6b65351f0b 100644 --- a/examples/FrontendGen/makefile +++ b/examples/FrontendGen/makefile @@ -1,12 +1,12 @@ #!/bin/bash BUDDY_FRONTEND_GEN := ../../build/bin/buddy-frontendgen -frontendgen-emit-ast: - @${BUDDY_FRONTEND_GEN} -f ./example.fegen -emit=ast +opDefine: + @${BUDDY_FRONTEND_GEN} -f ./opDefine.fegen -frontendgen-emit-antlr: - @${BUDDY_FRONTEND_GEN} -f ./example.fegen -emit=antlr -g Toy +typeDefine: + @${BUDDY_FRONTEND_GEN} -f ./typeDefine.fegen -frontendgen-emit-visitor: - @${BUDDY_FRONTEND_GEN} -f ./example.fegen -emit=visitor -g Toy +rule: + @${BUDDY_FRONTEND_GEN} -f ./rule.fegen diff --git a/examples/FrontendGen/opDefine.fegen b/examples/FrontendGen/opDefine.fegen new file mode 100644 index 0000000000..60fa22db92 --- /dev/null +++ b/examples/FrontendGen/opDefine.fegen @@ -0,0 +1,9 @@ +fegen toy + +opdef add { + arguments [operand Integer lhs, operand Integer rhs] + results [operand Integer res] + body { + res = lhs + rhs; + } +} \ No newline at end of file diff --git a/examples/FrontendGen/rule.fegen b/examples/FrontendGen/rule.fegen new file mode 100644 index 0000000000..912d495c8a --- /dev/null +++ b/examples/FrontendGen/rule.fegen @@ -0,0 +1,14 @@ +fegen toy + +module + : structDefine* funDefine+ + ; + +structDefine + : Struct Identifier BracketOpen (varDecl Semicolon)* BracketClose + ; + +expression + : Number + | Identifier + ; \ No newline at end of file diff --git a/examples/FrontendGen/typeDefine.fegen b/examples/FrontendGen/typeDefine.fegen new file mode 100644 index 0000000000..d4be6997d7 --- /dev/null +++ b/examples/FrontendGen/typeDefine.fegen @@ -0,0 +1,17 @@ +fegen toy + +typedef struct { + parameters [list elementTypes] +} + +typedef test1 { + parameters [Type e] +} + +typedef test2 { + parameters [list e] +} + +typedef test3 { + parameters [int e] +} \ No newline at end of file diff --git a/frontend/CMakeLists.txt b/frontend/CMakeLists.txt index 39e683c04b..7a2d235471 100644 --- a/frontend/CMakeLists.txt +++ b/frontend/CMakeLists.txt @@ -1,4 +1,6 @@ -add_subdirectory(FrontendGen) +if(FeGen) + add_subdirectory(FrontendGen) +endif() add_subdirectory(Interfaces) if(BUDDY_MLIR_ENABLE_PYTHON_PACKAGES) add_subdirectory(Python) diff --git a/frontend/FrontendGen/.gitignore b/frontend/FrontendGen/.gitignore new file mode 100644 index 0000000000..91827d60b3 --- /dev/null +++ b/frontend/FrontendGen/.gitignore @@ -0,0 +1 @@ +.antlr/ \ No newline at end of file diff --git a/frontend/FrontendGen/CMakeLists.txt b/frontend/FrontendGen/CMakeLists.txt index 0f67051931..ca9204b3a3 100644 --- a/frontend/FrontendGen/CMakeLists.txt +++ b/frontend/FrontendGen/CMakeLists.txt @@ -1,10 +1,19 @@ -include_directories("${CMAKE_CURRENT_SOURCE_DIR}/include") -link_directories("${CMAKE_CURRENT_BINARY_DIR}/lib") add_subdirectory(lib) set (LLVM_LINK_COMPONENTS -support -frontendgenlib -) + support +) + +include_directories(${ANTLR_FegenLexer_OUTPUT_DIR}) +include_directories(${ANTLR_FegenParser_OUTPUT_DIR}) +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/include") + add_llvm_tool(buddy-frontendgen -frontendgen.cpp + frontendgen.cpp ) + +target_link_libraries(buddy-frontendgen + PRIVATE + fegen_antlr_generated + fegenVisitor + antlr4_static +) \ No newline at end of file diff --git a/frontend/FrontendGen/README.md b/frontend/FrontendGen/README.md new file mode 100644 index 0000000000..13b5b2e698 --- /dev/null +++ b/frontend/FrontendGen/README.md @@ -0,0 +1,14 @@ +# How to build + +FrontendGen is designed for generate mlir project quickly by writing fegen files. + +The `FeGen` option needs to be enabled when building. + +``` bash +$ cmake -G Ninja .. \ + -DMLIR_DIR=$PWD/../llvm/build/lib/cmake/mlir \ + -DLLVM_DIR=$PWD/../llvm/build/lib/cmake/llvm \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DFeGen=ON \ + -DCMAKE_BUILD_TYPE=RELEASE +``` \ No newline at end of file diff --git a/frontend/FrontendGen/frontendgen.cpp b/frontend/FrontendGen/frontendgen.cpp index bfb6ef2002..51b29d46f6 100644 --- a/frontend/FrontendGen/frontendgen.cpp +++ b/frontend/FrontendGen/frontendgen.cpp @@ -18,111 +18,54 @@ // //===----------------------------------------------------------------------===// -#include "CGModule.h" -#include "Diagnostics.h" -#include "Lexer.h" -#include "Parser.h" +#include + #include "llvm/Support/CommandLine.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/raw_ostream.h" +#include "FegenLexer.h" +#include "FegenParser.h" +#include "FegenVisitor.h" +#include "antlr4-common.h" + llvm::cl::opt inputFileName("f", llvm::cl::desc("")); -llvm::cl::opt grammarName("g", llvm::cl::desc("")); namespace { enum Action { none, dumpAst, dumpAntlr, dumpAll, dumpVisitor }; } -llvm::cl::opt emitAction( - "emit", llvm::cl::desc("Select the kind of output desired"), - llvm::cl::values(clEnumValN(dumpAst, "ast", "Out put the ast")), - llvm::cl::values(clEnumValN(dumpAntlr, "antlr", "Out put the antlr file")), - llvm::cl::values(clEnumValN(dumpVisitor, "visitor", - "Out put the visitor file")), - llvm::cl::values(clEnumValN(dumpAll, "all", "put out all file"))); - -/// Control generation of ast, tablegen files and antlr files. -void emit(frontendgen::Module *module, frontendgen::Terminators &terminators) { - bool emitAst = emitAction == Action::dumpAst; - bool emitAntlr = - emitAction == Action::dumpAntlr || emitAction == Action::dumpAll; - bool emitVisitor = - emitAction == Action::dumpVisitor || emitAction == Action::dumpAll; - // Emit antlr file. - if (emitAntlr) { - if (grammarName.empty()) { - llvm::errs() << "if you want to emit g4 file you have to point out the " - "name of grammar.\n"; - return; - } - std::error_code EC; - llvm::sys::fs::OpenFlags openFlags = llvm::sys::fs::OpenFlags::OF_None; - std::string outputFileName = grammarName.c_str(); - outputFileName += ".g4"; - auto Out = llvm::ToolOutputFile(outputFileName, EC, openFlags); - frontendgen::CGModule CGmodule(module, Out.os(), terminators); - CGmodule.emitAntlr(grammarName); - Out.keep(); - } - // Emit antlr's AST. - if (emitAst && !module->getRules().empty()) { - llvm::raw_fd_ostream os(-1, true); - frontendgen::CGModule CGmodule(module, os, terminators); - CGmodule.emitAST(); - } - // Emit visitor file. - if (emitVisitor && !module->getRules().empty()) { - std::error_code EC; - llvm::sys::fs::OpenFlags openFlags = llvm::sys::fs::OpenFlags::OF_None; - std::string outputFileName("MLIR"); - outputFileName = outputFileName + grammarName + "Visitor.h"; - auto Out = llvm::ToolOutputFile(outputFileName, EC, openFlags); - frontendgen::CGModule CGmodule(module, Out.os(), terminators); - CGmodule.emitMLIRVisitor(grammarName); - Out.keep(); - } - // Free memory. - for (auto rule : module->getRules()) { - for (auto generatorsAndOthers : rule->getGeneratorsAndOthers()) { - for (auto element : generatorsAndOthers->getGenerator()) { - delete element; - } - delete generatorsAndOthers; - } - delete rule; - } +// llvm::cl::opt emitAction( +// "emit", llvm::cl::desc("Select the kind of output desired"), +// llvm::cl::values(clEnumValN(dumpAst, "ast", "Out put the ast")), +// llvm::cl::values(clEnumValN(dumpAntlr, "g4", "Out put the g4 file")), +// llvm::cl::values(clEnumValN(dumpVisitor, "visitor", +// "Out put the visitor file")), +// llvm::cl::values(clEnumValN(dumpAll, "all", "put out all file"))); - delete module->getDialect(); - for (auto op : module->getOps()) { - delete op->getArguments(); - delete op->getResults(); - for (auto builder : op->getBuilders()) { - delete builder->getDag(); - delete builder; - } - delete op; - } - delete module; +int dumpAST(fegen::FegenParser::FegenSpecContext *moduleAST) { + llvm::errs() << moduleAST->toStringTree(1 /* prety format*/) << "\n"; + return 0; } int main(int argc, char *argv[]) { llvm::cl::ParseCommandLineOptions(argc, argv); - llvm::ErrorOr> file = - llvm::MemoryBuffer::getFile(inputFileName.c_str()); - if (std::error_code bufferError = file.getError()) { - llvm::errs() << "error read: " << bufferError.message() << '\n'; - exit(1); - } - llvm::SourceMgr srcMgr; - srcMgr.AddNewSourceBuffer(std::move(*file), llvm::SMLoc()); - frontendgen::DiagnosticEngine diagnostic(srcMgr); - frontendgen::Lexer lexer(srcMgr, diagnostic); - frontendgen::Sema action; - frontendgen::Terminators terminators; - frontendgen::Parser parser(lexer, action, terminators); - frontendgen::Module *module = parser.parser(); - emit(module, terminators); + + // Parse the input file with ANTLR. + std::fstream in(inputFileName); + antlr4::ANTLRInputStream input(in); + fegen::FegenLexer lexer(&input); + antlr4::CommonTokenStream tokens(&lexer); + fegen::FegenParser parser(&tokens); + auto moduleAST = parser.fegenSpec(); + + fegen::FegenVisitor visitor; + visitor.visit(moduleAST); + visitor.emitG4(); + visitor.emitTypeDefination(); + visitor.emitDialectDefination(); + visitor.emitOpDefination(); return 0; } diff --git a/frontend/FrontendGen/include/AST.h b/frontend/FrontendGen/include/AST.h deleted file mode 100644 index 549f68f2a8..0000000000 --- a/frontend/FrontendGen/include/AST.h +++ /dev/null @@ -1,211 +0,0 @@ -//====- AST.h -------------------------------------------------------------===// -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#ifndef INCLUDE_AST_H -#define INCLUDE_AST_H -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringMap.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/SMLoc.h" -#include -namespace frontendgen { - -/// Base class for all generator nodes. -class AntlrBase { -public: - enum baseKind { rule, terminator, pbexpression }; - -private: - baseKind kind; - -protected: - llvm::StringRef name; - llvm::SMLoc loc; - -public: - virtual ~AntlrBase(){}; - AntlrBase(llvm::StringRef name, llvm::SMLoc loc, baseKind kind) - : kind(kind), name(name), loc(loc) {} - llvm::StringRef getName() { return name; } - llvm::SMLoc getLoc() { return loc; } - baseKind getKind() const { return kind; } -}; - -class GeneratorAndOthers { - std::vector generator; - llvm::SmallVector builderNames; - llvm::SmallVector builderIdxs; - -public: - void setbuilderNames(llvm::SmallVector &builderNames) { - this->builderNames = builderNames; - } - void setbuilderIdxs(llvm::SmallVector &builderIdxs) { - this->builderIdxs = builderIdxs; - } - std::vector &getGenerator() { return generator; } - llvm::SmallVector getBuilderNames() { - return this->builderNames; - } - llvm::SmallVector getBuilderIndices() { return this->builderIdxs; } -}; - -/// This class is used to mark the node in the generator as a rule, and can also -/// store the generators of a rule. -class Rule : public AntlrBase { - std::vector generatorsAndOthers; - -public: - Rule(llvm::StringRef name, llvm::SMLoc loc, baseKind kind) - : AntlrBase(name, loc, kind) {} - static bool classof(const AntlrBase *base) { - return base->getKind() == baseKind::rule; - } - void setGenerators(std::vector &generatorsAndOthers) { - this->generatorsAndOthers = generatorsAndOthers; - } - std::vector getGeneratorsAndOthers() { - return generatorsAndOthers; - } -}; -/// The class is used to mark the node in the generator as a terminator. -class Terminator : public AntlrBase { -public: - Terminator(llvm::StringRef name, llvm::SMLoc loc, baseKind kind) - : AntlrBase(name, loc, kind) {} - static bool classof(const AntlrBase *base) { - return base->getKind() == baseKind::terminator; - } -}; -/// The class is used to mark the node in the generator as regular expressions. -class PBExpression : public AntlrBase { -public: - PBExpression(llvm::StringRef name, llvm::SMLoc loc, baseKind kind) - : AntlrBase(name, loc, kind) {} - static bool classof(const AntlrBase *base) { - return base->getKind() == baseKind::terminator; - } -}; - -/// The class is used to store the information about Dialect class in the -/// TableGen. -class Dialect { - llvm::StringRef defName; - llvm::StringRef name; - llvm::StringRef cppNamespace; - -public: - Dialect() {} - llvm::StringRef getName() { return name; } - llvm::StringRef getCppNamespace() { return cppNamespace; } - llvm::StringRef getDefName() { return defName; } - void setName(llvm::StringRef name) { this->name = name; } - void setDefName(llvm::StringRef defName) { this->defName = defName; } - void setCppNamespace(llvm::StringRef cppNamespace) { - this->cppNamespace = cppNamespace; - } -}; - -class DAG { - llvm::StringRef dagOperator; - llvm::SmallVector operands; - llvm::SmallVector operandNames; - llvm::StringMap values; - -public: - DAG(){}; - DAG(const DAG &dag) { - this->dagOperator = dag.dagOperator; - this->operands = dag.operands; - this->operandNames = dag.operandNames; - this->values = dag.values; - } - - void addOperand(llvm::StringRef operand, llvm::StringRef operandName) { - operands.push_back(operand); - operandNames.push_back(operandName); - } - void setValue(llvm::StringRef operand, llvm::StringRef value) { - values[operand] = value; - } - llvm::StringRef findValue(llvm::StringRef operand) { - if (values.find(operand) == values.end()) - return llvm::StringRef(); - return values[operand]; - } - llvm::StringRef getDagOperater() { return dagOperator; } - void setDagOperatpr(llvm::StringRef dagOperator) { - this->dagOperator = dagOperator; - } - llvm::SmallVector getOperands() { return operands; } - llvm::SmallVector getOperandNames() { - return operandNames; - } -}; -/// The class is used to store builder in Op class. -class Builder { - DAG *dag = nullptr; - llvm::StringRef code; - -public: - Builder(DAG *dag, llvm::StringRef code) { - this->dag = dag; - this->code = code; - } - DAG *getDag() { return dag; } - llvm::StringRef getCode() { return code; } -}; - -/// The class is used to store information about Op class in the TableGen. -class Op { - llvm::StringRef opName; - DAG *arguments; - DAG *results; - std::vector builders; - -public: - llvm::StringRef getOpName() { return opName; } - DAG *getArguments() { return arguments; } - DAG *getResults() { return results; } - std::vector getBuilders() { return builders; } - - void setOpName(llvm::StringRef opName) { this->opName = opName; } - - void setArguments(DAG *arguments) { this->arguments = arguments; } - void setResults(DAG *results) { this->results = results; } - void setBuilders(std::vector &builders) { - this->builders = builders; - } -}; - -/// This class will become the root of a tree which contains all information we -/// need to generate code. -class Module { - std::vector rules; - Dialect *dialect; - std::vector ops; - -public: - std::vector &getRules() { return rules; } - Dialect *getDialect() { return dialect; } - std::vector &getOps() { return ops; } - void setRules(std::vector &rules) { this->rules = rules; } - void seDialect(Dialect *&dialect) { this->dialect = dialect; } - void setOps(std::vector &ops) { this->ops = ops; } -}; - -} // namespace frontendgen -#endif diff --git a/frontend/FrontendGen/include/CGModule.h b/frontend/FrontendGen/include/CGModule.h deleted file mode 100644 index 7fc769d94b..0000000000 --- a/frontend/FrontendGen/include/CGModule.h +++ /dev/null @@ -1,73 +0,0 @@ -//====- CGModule.h -------------------------------------------------------===// -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#ifndef INCLUDE_CGMODULE_H -#define INCLUDE_CGMODULE_H -#include "AST.h" -#include "Terminator.h" -#include "llvm/Support/raw_ostream.h" -namespace frontendgen { - -/// TypeMap is used to store type maps.The cppMap is used to map c++ types, -/// argumentsMap and resultsMap are used to map TableGen types. -class TypeMap { - llvm::StringMap cppMap; - llvm::StringMap argumentsMap; - llvm::StringMap resultsMap; - -public: - TypeMap() { -#define CPPMAP(key, value) cppMap.insert(std::pair(key, value)); -#define RESULTSMAP(key, value) resultsMap.insert(std::pair(key, value)); -#define ARGUMENTSMAP(key, value) argumentsMap.insert(std::pair(key, value)); -#include "TypeMap.def" - } - llvm::StringRef findCppMap(llvm::StringRef value); - llvm::StringRef findArgumentMap(llvm::StringRef value); - llvm::StringRef findResultsMap(llvm::StringRef value); -}; - -/// The class for code generation. -class CGModule { - Terminators &terminators; - Module *module; - llvm::raw_fd_ostream &os; - TypeMap typeMap; - -public: - CGModule(Module *module, llvm::raw_fd_ostream &os, Terminators &terminators) - : terminators(terminators), module(module), os(os) {} - void emitAST(); - void emitAntlr(llvm::StringRef grammarName); - void emit(const std::vector &rules); - void emit(const std::vector &generators); - void emit(const std::vector &generator); - void emitGrammar(llvm::StringRef grammarName); - void emitTerminators(); - void emitCustomTerminators(); - void emitWSAndComment(); - void emitIncludes(llvm::StringRef grammarName); - void emitMLIRVisitor(llvm::StringRef grammarName); - void emitClass(llvm::StringRef grammarName); - void emitRuleVisitor(llvm::StringRef grammarName, Rule *rule); - void emitBuilders(Rule *rule); - void emitBuilder(llvm::StringRef builderOp, int index); - Op *findOp(llvm::StringRef opName); - void emitOp(Op *op, int index); -}; -} // namespace frontendgen - -#endif diff --git a/frontend/FrontendGen/include/Diagnostics.def b/frontend/FrontendGen/include/Diagnostics.def deleted file mode 100644 index c2afe0f63e..0000000000 --- a/frontend/FrontendGen/include/Diagnostics.def +++ /dev/null @@ -1,10 +0,0 @@ -#ifndef DIAG -#define DIAG(ID, Level, Msg) -#endif -DIAG(err_expected, Error, "expected {0} but found {1}") -DIAG(err_no_mnemonic, Warning, "you should indicate mnemonic.") -DIAG(err_not_supported_element, Error, "the {0} is not supported." ) -DIAG(err_no_name, Error, "opinterface should indicate the interface name.") -DIAG(err_only_supported_builder, Error, "we are only support builder") -DIAG(err_builder_fail,Error, "builder indicate failed.") -#undef DIAG diff --git a/frontend/FrontendGen/include/Diagnostics.h b/frontend/FrontendGen/include/Diagnostics.h deleted file mode 100644 index 54e475d77b..0000000000 --- a/frontend/FrontendGen/include/Diagnostics.h +++ /dev/null @@ -1,52 +0,0 @@ -//====- Diagnostics.h -----------------------------------------------------===// -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#ifndef INCLUDE_DIAGNOSTIC_H -#define INCLUDE_DIAGNOSTIC_H -#include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/SourceMgr.h" - -/// When there is an error in the user's code, we can diagnose the error through -/// the class. -namespace frontendgen { -class DiagnosticEngine { - llvm::SourceMgr &SrcMgr; - static const char *getDiagnosticText(unsigned diagID); - llvm::SourceMgr::DiagKind getDiagnosticKind(unsigned diagID); - bool hasReport = false; - -public: - enum diagKind { -#define DIAG(ID, Level, Msg) ID, -#include "Diagnostics.def" - }; - DiagnosticEngine(llvm::SourceMgr &SrcMgr) : SrcMgr(SrcMgr) {} - - template - void report(llvm::SMLoc loc, unsigned diagID, Args &&...arguments) { - if (!hasReport) { - std::string Msg = llvm::formatv(getDiagnosticText(diagID), - std::forward(arguments)...) - .str(); - SrcMgr.PrintMessage(loc, getDiagnosticKind(diagID), Msg); - hasReport = true; - } - } -}; - -} // namespace frontendgen - -#endif diff --git a/frontend/FrontendGen/include/FegenManager.h b/frontend/FrontendGen/include/FegenManager.h new file mode 100644 index 0000000000..c5623e76a7 --- /dev/null +++ b/frontend/FrontendGen/include/FegenManager.h @@ -0,0 +1,482 @@ +#ifndef FEGEN_MANAGER_H +#define FEGEN_MANAGER_H + +#include +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" + +#include "FegenParser.h" +#include "ParserRuleContext.h" + +#define FEGEN_PLACEHOLDER "Placeholder" +#define FEGEN_TYPE "Type" +#define FEGEN_TYPETEMPLATE "TypeTemplate" +#define FEGEN_INTEGER "Integer" +#define FEGEN_FLOATPOINT "FloatPoint" +#define FEGEN_CHAR "Char" +#define FEGEN_STRING "String" +#define FEGEN_VECTOR "Vector" +#define FEGEN_TENSOR "Tensor" +#define FEGEN_LIST "List" +#define FEGEN_OPTINAL "Optional" +#define FEGEN_ANY "Any" + +namespace fegen { + +class FegenType; +class FegenManager; +class FegenValue; + +// binary operation + +enum class FegenOperator { + OR, + AND, + EQUAL, + NOT_EQUAL, + LESS, + LESS_EQUAL, + GREATER, + GREATER_EQUAL, + ADD, + SUB, + MUL, + DIV, + MOD, + POWER, + NEG, + NOT +}; + +// user defined function +class FegenFunction { +private: + // cpp function name + std::string name; + // input object + std::vector inputTypeList; + // return type + FegenType *returnType; + explicit FegenFunction(std::string name, + std::vector &&inputTypeList, + FegenType *returnType); + +public: + static FegenFunction *get(std::string name, + std::vector inputTypeList, + FegenType *returnType = nullptr); + ~FegenFunction() = default; + std::string getName(); + std::vector &getInputTypeList(); + FegenValue *getInputTypeList(size_t i); + FegenType *getReturnType(); +}; + +class FegenValue; + +// user defined operation +class FegenOperation { +private: + std::string dialectName; + std::string operationName; + // arguments of operation + std::vector arguments; + // results of operation + std::vector results; + // operation body context + FegenParser::BodySpecContext *ctx; + explicit FegenOperation(std::string dialectName, std::string operationName, + std::vector &&arguments, + std::vector &&results, + FegenParser::BodySpecContext *ctx); + +public: + void setOpName(std::string); + std::string getOpName(); + std::vector &getArguments(); + FegenValue *getArguments(size_t i); + std::vector &getResults(); + FegenValue *getResults(size_t i); + static FegenOperation *get(std::string operationName, + std::vector arguments, + std::vector results, + FegenParser::BodySpecContext *ctx); + ~FegenOperation() = default; +}; + +class FegenTypeDefination; + +class FegenType { + friend class FegenValue; + +public: + enum class TypeKind { ATTRIBUTE, OPERAND, CPP }; + +private: + TypeKind kind; + std::string typeName; + std::vector parameters; + FegenTypeDefination *typeDefine; + int typeLevel; + +public: + FegenType(TypeKind kind, std::string name, + std::vector parameters, FegenTypeDefination *tyDef, + int typeLevel); + FegenType(TypeKind kind, std::vector parameters, + FegenTypeDefination *tyDef, int typeLevel); + FegenType(const FegenType &); + FegenType(FegenType &&); + TypeKind getTypeKind(); + void setTypeKind(TypeKind kind); + std::vector &getParameters(); + FegenValue *getParameters(size_t i); + void setParameters(std::vector ¶ms); + FegenTypeDefination *getTypeDefination(); + void setTypeDefination(FegenTypeDefination *tyDef); + std::string getTypeName(); + int getTypeLevel(); + // for generating typedef td file. + std::string toStringForTypedef(); + // for generating op def td file. + std::string toStringForOpdef(); + static bool isSameType(FegenType *type1, FegenType *type2); + ~FegenType(); + // placeholder + static FegenType getPlaceHolder(); + // Type + static FegenType getMetaType(); + + // TypeTemplate + static FegenType getMetaTemplateType(); + + // int + static FegenType getInt32Type(); + + // float + static FegenType getFloatType(); + + // float + static FegenType getDoubleType(); + + // bool + static FegenType getBoolType(); + + // Integer + static FegenType getIntegerType(FegenValue *size); + + // FloatPoint + static FegenType getFloatPointType(FegenValue *size); + + // char + static FegenType getCharType(); + + // string + static FegenType getStringType(); + + // Vector + static FegenType getVectorType(FegenValue *size, FegenType elementType); + + // Tensor + static FegenType getTensorType(FegenValue *shape, FegenType elementType); + + // List + static FegenType getListType(FegenType elementType); + + // Optional + static FegenType getOptionalType(FegenType elementType); + + // Any + static FegenType getAnyType(std::vector elementTypes); + + static FegenType getIntegerTemplate(); + static FegenType getFloatPointTemplate(); + + static FegenType getInstanceType(FegenTypeDefination *typeDefination, + std::vector parameters); + + static FegenType getTemplateType(FegenTypeDefination *typeDefination); +}; + +class FegenTypeDefination { + friend class FegenManager; + +private: + std::string dialectName; + std::string name; + std::vector parameters; + FegenParser::TypeDefinationDeclContext *ctx; + bool ifCustome; + std::string mnemonic; + +public: + FegenTypeDefination(std::string dialectName, std::string name, + std::vector parameters, + FegenParser::TypeDefinationDeclContext *ctx, + bool ifCustome); + static FegenTypeDefination *get(std::string dialectName, std::string name, + std::vector parameters, + FegenParser::TypeDefinationDeclContext *ctx, + bool ifCustome = true); + std::string getDialectName(); + void setDialectName(std::string); + std::string getName(); + std::string getMnemonic(); + void setName(std::string); + const std::vector &getParameters(); + FegenParser::TypeDefinationDeclContext *getCtx(); + void setCtx(FegenParser::TypeDefinationDeclContext *); + bool isCustome(); +}; + +/// @brief Represent right value, and pass by value. +class FegenRightValue { + friend class FegenType; + friend class FegenValue; + +public: + enum class LiteralKind { + MONOSTATE, + INT, + FLOAT, + STRING, + TYPE, + VECTOR, + EXPRESSION, + LEFT_VAR + }; + + struct Expression { + bool ifTerminal; + LiteralKind kind; + FegenType exprType; + bool isLiteral; + bool ifConstexpr; + Expression(bool, LiteralKind, FegenType &, bool); + virtual ~Expression() = default; + virtual bool isTerminal(); + virtual std::string toString() = 0; + virtual std::string toStringForTypedef() = 0; + virtual std::string toStringForOpdef() = 0; + LiteralKind getKind(); + virtual std::any getContent() = 0; + virtual bool isConstexpr(); + }; + + struct ExpressionNode : public Expression { + using opType = + std::variant; + opType op; + std::vector params; + ExpressionNode(std::vector, opType, FegenType &, bool); + ExpressionNode(ExpressionNode &) = default; + ~ExpressionNode(); + virtual std::string toString() override; + virtual std::string toStringForTypedef() override; + virtual std::string toStringForOpdef() override; + virtual std::any getContent() override; + + /// @brief operate lhs and rhs using binary operator. + static ExpressionNode *binaryOperation(Expression *lhs, Expression *rhs, + FegenOperator op); + /// @brief operate expr using unary operator + static ExpressionNode *unaryOperation(Expression *, FegenOperator); + + // TODO: callFunction + static ExpressionNode *callFunction(std::vector, + FegenFunction *); + + // TODO: callOperation + static ExpressionNode *callOperation(std::vector, + FegenOperation *); + }; + + struct ExpressionTerminal : public Expression { + // monostate, int literal, float literal, string literal, type literal, list + // literal, reference of variable + using primLiteralType = + std::variant, FegenValue *>; + primLiteralType content; + ExpressionTerminal(primLiteralType, LiteralKind, FegenType, bool); + ExpressionTerminal(ExpressionTerminal &) = default; + ~ExpressionTerminal(); + virtual std::string toString() override; + virtual std::string toStringForTypedef() override; + virtual std::string toStringForOpdef() override; + virtual std::any getContent() override; + static ExpressionTerminal *get(std::monostate); + static ExpressionTerminal *get(int); + static ExpressionTerminal *get(float); + static ExpressionTerminal *get(std::string); + static ExpressionTerminal *get(FegenType &); + static ExpressionTerminal *get(std::vector &); + static ExpressionTerminal *get(fegen::FegenValue *); + }; + +public: + FegenRightValue(Expression *content); + FegenRightValue(const FegenRightValue &); + FegenRightValue(FegenRightValue &&); + FegenRightValue::LiteralKind getKind(); + std::string toString(); + std::string toStringForTypedef(); + std::string toStringForOpdef(); + std::any getContent(); + Expression *getExpr(); + + static FegenRightValue get(); + static FegenRightValue get(int content); + static FegenRightValue get(float content); + static FegenRightValue get(std::string content); + static FegenRightValue get(FegenType &content); + // list + static FegenRightValue get(std::vector &content); + static FegenRightValue get(fegen::FegenValue *content); + static FegenRightValue get(Expression *expr); + ~FegenRightValue(); + +private: + Expression *content; +}; + +class FegenValue { + friend class FegenType; + +private: + FegenType type; + std::string name; + FegenRightValue content; + +public: + FegenValue(FegenType type, std::string name, FegenRightValue content); + FegenValue(const FegenValue &rhs); + FegenValue(FegenValue &&rhs); + + static FegenValue *get(FegenType type, std::string name, + FegenRightValue constant); + + std::string getName(); + FegenType &getType(); + /// @brief return content of right value, get ExprssionNode* if kind is + /// EXPRESSION. + template T getContent() { + return std::any_cast(this->content.getContent()); + } + FegenRightValue::LiteralKind getContentKind(); + std::string getContentString(); + std::string getContentStringForTypedef(); + std::string getContentStringForOpdef(); + FegenRightValue::Expression *getExpr(); + ~FegenValue() = default; +}; + +class FegenNode; + +class FegenRule { + friend class FegenManager; + +private: + std::string content; + // from which node + FegenNode *src; + std::map inputs; + std::map returns; + // context in parser tree + antlr4::ParserRuleContext *ctx; + explicit FegenRule(std::string content, FegenNode *src, + antlr4::ParserRuleContext *ctx); + +public: + static FegenRule *get(std::string content, FegenNode *src, + antlr4::ParserRuleContext *ctx); + llvm::StringRef getContent(); + // check and add input value + bool addInput(FegenValue input); + // check and add return value + bool addReturn(FegenValue output); + // set source node + void setSrc(FegenNode *src); +}; + +class FegenNode { + friend class FegenManager; + +public: + enum class NodeType { PARSER_RULE, LEXER_RULE }; + +private: + std::vector rules; + antlr4::ParserRuleContext *ctx; + NodeType ntype; + explicit FegenNode(std::vector &&rules, + antlr4::ParserRuleContext *ctx, NodeType ntype); + +public: + static FegenNode *get(std::vector rules, + antlr4::ParserRuleContext *ctx, NodeType ntype); + static FegenNode *get(antlr4::ParserRuleContext *ctx, NodeType ntype); + void addFegenRule(FegenRule *rule); + // release rules first + ~FegenNode(); +}; + +class FegenVisitor; + +class FegenManager { + friend class FegenVisitor; + +private: + // ScopeStack &sstack; + FegenManager(); + FegenManager(const FegenManager &) = delete; + const FegenManager &operator=(const FegenManager &) = delete; + // release nodes, type, operation, function + ~FegenManager(); + void initbuiltinTypes(); + +public: + std::string moduleName; + std::vector headFiles; + std::map nodeMap; + llvm::StringMap typeMap; + std::map typeDefMap; + std::map operationMap; + std::map functionMap; + // stmt contents + std::unordered_map stmtContentMap; + void addStmtContent(antlr4::ParserRuleContext *ctx, std::any content); + template T getStmtContent(antlr4::ParserRuleContext *ctx) { + assert(this->stmtContentMap.count(ctx)); + return std::any_cast(this->stmtContentMap[ctx]); + } + + static FegenManager &getManager(); + void setModuleName(std::string name); + + FegenTypeDefination *getTypeDefination(std::string name); + bool addTypeDefination(FegenTypeDefination *tyDef); + + FegenOperation *getOperationDefination(std::string name); + bool addOperationDefination(FegenOperation *opDef); + void emitG4(); + void emitTypeDefination(); + void emitOpDefination(); + void emitDialectDefination(); + void emitTdFiles(); + void emitBuiltinFunction(); +}; + +FegenType inferenceType(std::vector, + FegenOperator); + +} // namespace fegen + +#endif \ No newline at end of file diff --git a/frontend/FrontendGen/include/FegenVisitor.h b/frontend/FrontendGen/include/FegenVisitor.h new file mode 100644 index 0000000000..82e384dc1a --- /dev/null +++ b/frontend/FrontendGen/include/FegenVisitor.h @@ -0,0 +1,775 @@ +#ifndef FEGEN_FEGENVISITOR_H +#define FEGEN_FEGENVISITOR_H + +#include +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" + +#include "FegenManager.h" +#include "FegenParser.h" +#include "FegenParserBaseVisitor.h" +#include "Scope.h" + +using namespace antlr4; + +namespace fegen { + +/// @brief check if params are right. +/// @param expected expected params. +/// @param actual actual params. +/// @return true if correct. +bool checkParams(std::vector &expected, + std::vector &actual); + +/// @brief check if the type of elements in list are correct. +bool checkListLiteral(std::vector listLiteral); + +class FegenVisitor : public FegenParserBaseVisitor { +private: + FegenManager &manager; + ScopeStack &sstack; + +public: + void emitG4() { this->manager.emitG4(); } + void emitTypeDefination() { this->manager.emitTypeDefination(); } + void emitDialectDefination() { this->manager.emitDialectDefination(); } + void emitOpDefination() { this->manager.emitOpDefination(); } + + FegenVisitor() + : manager(FegenManager::getManager()), + sstack(ScopeStack::getScopeStack()) { + this->manager.initbuiltinTypes(); + } + + std::any visitTypeDefinationDecl( + FegenParser::TypeDefinationDeclContext *ctx) override { + auto typeName = ctx->typeDefinationName()->getText(); + auto tyDef = std::any_cast( + this->visit(ctx->typeDefinationBlock())); + // set name and ctx for type defination + tyDef->setName(typeName); + tyDef->setCtx(ctx); + // add defination to manager map + this->manager.typeDefMap.insert({typeName, tyDef}); + return nullptr; + } + + // return FegenTypeDefination* + std::any visitTypeDefinationBlock( + FegenParser::TypeDefinationBlockContext *ctx) override { + auto params = std::any_cast>( + this->visit(ctx->parametersSpec())); + auto tyDef = + FegenTypeDefination::get(this->manager.moduleName, "", params, nullptr); + return tyDef; + } + + std::any visitFegenDecl(FegenParser::FegenDeclContext *ctx) override { + this->manager.setModuleName(ctx->identifier()->getText()); + return nullptr; + } + + std::any + visitParserRuleSpec(FegenParser::ParserRuleSpecContext *ctx) override { + auto ruleList = + std::any_cast>(this->visit(ctx->ruleBlock())); + auto ruleNode = + FegenNode::get(ruleList, ctx, FegenNode::NodeType::PARSER_RULE); + // set source node for rules + for (auto rule : ruleList) { + rule->setSrc(ruleNode); + } + this->manager.nodeMap.insert({ctx->ParserRuleName()->getText(), ruleNode}); + return nullptr; + } + + std::any visitRuleAltList(FegenParser::RuleAltListContext *ctx) override { + std::vector ruleList; + for (auto alt : ctx->actionAlt()) { + auto fegenRule = std::any_cast(this->visit(alt)); + ruleList.push_back(fegenRule); + } + return ruleList; + } + + std::any visitActionAlt(FegenParser::ActionAltContext *ctx) override { + auto rawRule = this->visit(ctx->alternative()); + if (ctx->actionBlock()) { + auto blockValues = std::any_cast< + std::tuple, std::vector>>( + this->visit(ctx->actionBlock())); + auto inputs = std::get<0>(blockValues); + auto returns = std::get<1>(blockValues); + auto rule = std::any_cast(rawRule); + for (auto in : inputs) { + auto flag = rule->addInput(*in); + if (!flag) { // TODO: error report + std::cerr << "input of " << rule->getContent().str() << " \"" + << in->getName() << "\" existed." << std::endl; + } + } + for (auto out : returns) { + auto flag = rule->addReturn(*out); + if (!flag) { // TODO: error report + std::cerr << "return of " << rule->getContent().str() << " \"" + << out->getName() << "\" existed." << std::endl; + } + } + } + return rawRule; + } + + // return tuple, vector> + std::any visitActionBlock(FegenParser::ActionBlockContext *ctx) override { + std::vector inputs; + std::vector returns; + if (ctx->inputsSpec()) { + inputs = std::any_cast>( + this->visit(ctx->inputsSpec())); + } + + if (ctx->returnsSpec()) { + returns = std::any_cast>( + this->visit(ctx->returnsSpec())); + } + + if (ctx->actionSpec()) { + this->visit(ctx->actionSpec()); + } + return std::tuple(inputs, returns); + } + + // return FegenRule Object + // TODO: do more check + std::any visitAlternative(FegenParser::AlternativeContext *ctx) override { + auto content = ctx->getText(); + auto rule = FegenRule::get(content, nullptr, ctx); + return rule; + } + + std::any visitLexerRuleSpec(FegenParser::LexerRuleSpecContext *ctx) override { + // create node, get rules from child, and insert to node map + auto ruleList = std::any_cast>( + this->visit(ctx->lexerRuleBlock())); + auto ruleNode = + FegenNode::get(ruleList, ctx, FegenNode::NodeType::LEXER_RULE); + // set source node for rules + for (auto rule : ruleList) { + rule->setSrc(ruleNode); + } + this->manager.nodeMap.insert({ctx->LexerRuleName()->getText(), ruleNode}); + return nullptr; + } + + std::any visitLexerAltList(FegenParser::LexerAltListContext *ctx) override { + std::vector ruleList; + for (auto alt : ctx->lexerAlt()) { + auto rule = fegen::FegenRule::get(alt->getText(), nullptr, alt); + ruleList.push_back(rule); + } + return ruleList; + } + + // return vector + std::any visitVarDecls(FegenParser::VarDeclsContext *ctx) override { + size_t varCount = ctx->typeSpec().size(); + std::vector valueList; + for (size_t i = 0; i <= varCount - 1; i++) { + auto ty = std::any_cast(this->visit(ctx->typeSpec(i))); + auto varName = ctx->identifier(i)->getText(); + auto var = + fegen::FegenValue::get(ty, varName, fegen::FegenRightValue::get()); + valueList.push_back(var); + } + + return valueList; + } + + // return fegen::FegenType + std::any + visitTypeInstanceSpec(FegenParser::TypeInstanceSpecContext *ctx) override { + auto valueKind = ctx->valueKind() + ? std::any_cast( + this->visit(ctx->valueKind())) + : fegen::FegenType::TypeKind::CPP; + auto typeInst = + std::any_cast(this->visit(ctx->typeInstance())); + typeInst.setTypeKind(valueKind); + return typeInst; + } + + // return fegen::FegenType::TypeKind + std::any visitValueKind(FegenParser::ValueKindContext *ctx) override { + auto kind = fegen::FegenType::TypeKind::ATTRIBUTE; + if (ctx->CPP()) { + kind = fegen::FegenType::TypeKind::CPP; + } else if (ctx->OPERAND()) { + kind = fegen::FegenType::TypeKind::OPERAND; + } + // otherwise: ATTRIBUTE + return kind; + } + + // return fegen::FegenType + std::any visitTypeInstance(FegenParser::TypeInstanceContext *ctx) override { + if (ctx->typeTemplate()) { // typeTemplate (Less typeTemplateParam (Comma + // typeTemplateParam)* Greater)? + auto typeTeplt = + std::any_cast(this->visit(ctx->typeTemplate())); + // get parameters + std::vector paramList; + for (auto paramCtx : ctx->typeTemplateParam()) { + auto tepltParams = + std::any_cast(this->visit(paramCtx)); + paramList.push_back(tepltParams); + } + + // check parameters + auto expectedParams = typeTeplt.getTypeDefination()->getParameters(); + if (!checkParams(expectedParams, paramList)) { + std::cerr << "parameters error in context: " << ctx->getText() + << std::endl; + exit(0); + } + // get FegenType of instance + auto typeInst = + FegenType::getInstanceType(typeTeplt.getTypeDefination(), paramList); + return typeInst; + } else if (ctx->identifier()) { // identifier + auto varName = ctx->identifier()->getText(); + auto var = this->sstack.attemptFindVar(varName); + if (var) { + if (var->getContentKind() == + fegen::FegenRightValue::LiteralKind::TYPE) { + return var->getContent(); + } else { + std::cerr << "variable " << varName + << " is not a Type or TypeTemplate." << std::endl; + exit(0); + return nullptr; + } + } else { // variable does not exist. + std::cerr << "undefined variable: " << varName << std::endl; + exit(0); + return nullptr; + } + } else { // builtinTypeInstances + return visitChildren(ctx); + } + } + + // return FegenValue* + std::any + visitTypeTemplateParam(FegenParser::TypeTemplateParamContext *ctx) override { + if (ctx->builtinTypeInstances()) { + auto ty = std::any_cast( + this->visit(ctx->builtinTypeInstances())); + return fegen::FegenValue::get(ty, "param", fegen::FegenRightValue::get()); + } else { + auto expr = std::any_cast( + this->visit(ctx->expression())); + return fegen::FegenValue::get(expr->exprType, "expression_tmp", + fegen::FegenRightValue(expr)); + } + } + + // return fegen::FegenType + std::any visitBuiltinTypeInstances( + FegenParser::BuiltinTypeInstancesContext *ctx) override { + if (ctx->BOOL()) { + return FegenType::getBoolType(); + } else if (ctx->INT()) { + return FegenType::getInt32Type(); + } else if (ctx->FLOAT()) { + return FegenType::getFloatType(); + } else if (ctx->DOUBLE()) { + return FegenType::getDoubleType(); + } else if (ctx->CHAR()) { + return FegenType::getCharType(); + } else if (ctx->STRING()) { + return FegenType::getStringType(); + } else { + std::cerr << "error builtin type." << std::endl; + return nullptr; + } + } + + // return FegenType + std::any visitTypeTemplate(FegenParser::TypeTemplateContext *ctx) override { + if (ctx->prefixedName()) { // prefixedName + if (ctx->prefixedName()->identifier().size() == 2) { // dialect.type + // TODO: return type from other dialect + return nullptr; + } else { // type + auto tyDef = this->sstack.attemptFindTypeDef( + ctx->prefixedName()->identifier(0)->getText()); + return fegen::FegenType::getTemplateType(tyDef); + } + } else if (ctx->builtinTypeTemplate()) { // builtinTypeTemplate + return this->visit(ctx->builtinTypeTemplate()); + } else { // TYPE + return fegen::FegenType::getMetaType(); + } + } + + // return FegenType + std::any visitBuiltinTypeTemplate( + FegenParser::BuiltinTypeTemplateContext *ctx) override { + if (ctx->INTEGER()) { + return fegen::FegenType::getIntegerTemplate(); + } else if (ctx->FLOATPOINT()) { + return fegen::FegenType::getFloatPointTemplate(); + } else if (ctx->TENSOR()) { + // return fegen::FegenType::getTensorTemplate(); + return fegen::FegenType::getPlaceHolder(); + } else if (ctx->VECTOR()) { + // return fegen::FegenType::getVectorTemplate(); + return fegen::FegenType::getPlaceHolder(); + } else { + return nullptr; + } + } + + // return FegenType + std::any + visitCollectTypeSpec(FegenParser::CollectTypeSpecContext *ctx) override { + auto kind = fegen::FegenType::TypeKind::CPP; + if (ctx->valueKind()) { + kind = std::any_cast( + this->visit(ctx->valueKind())); + } + auto ty = std::any_cast(this->visit(ctx->collectType())); + ty.setTypeKind(kind); + return ty; + } + + // return FegenType + std::any visitCollectType(FegenParser::CollectTypeContext *ctx) override { + auto expr = std::any_cast( + this->visit(ctx->expression())); + if (ctx->collectProtoType()->ANY()) { + std::vector tys; + // TODO: reprot error + assert(expr->getKind() == fegen::FegenRightValue::LiteralKind::VECTOR); + auto exprs = + std::any_cast>( + expr->getContent()); + for (auto expr : exprs) { + auto ty = std::any_cast(expr->getContent()); + tys.push_back(ty); + } + return fegen::FegenType::getAnyType(tys); + } else if (ctx->collectProtoType()->LIST()) { + auto ty = std::any_cast(expr->getContent()); + return fegen::FegenType::getListType(ty); + } else { // optional + auto ty = std::any_cast(expr->getContent()); + return fegen::FegenType::getOptionalType(ty); + } + } + + // return FegenRightValue::Expression* + std::any visitExpression(FegenParser::ExpressionContext *ctx) override { + auto expr = std::any_cast( + this->visit(ctx->andExpr(0))); + for (size_t i = 1; i <= ctx->andExpr().size() - 1; i++) { + auto rhs = std::any_cast( + this->visit(ctx->andExpr(i))); + expr = FegenRightValue::ExpressionNode::binaryOperation( + expr, rhs, FegenOperator::OR); + } + return expr; + } + + // return FegenRightValue::Expression* + std::any visitAndExpr(FegenParser::AndExprContext *ctx) override { + auto expr = std::any_cast( + this->visit(ctx->equExpr(0))); + for (size_t i = 1; i <= ctx->equExpr().size() - 1; i++) { + auto rhs = std::any_cast( + this->visit(ctx->equExpr(i))); + expr = FegenRightValue::ExpressionNode::binaryOperation( + expr, rhs, FegenOperator::AND); + } + return expr; + } + + // return FegenRightValue::Expression* + std::any visitEquExpr(FegenParser::EquExprContext *ctx) override { + auto expr = std::any_cast( + this->visit(ctx->compareExpr(0))); + for (size_t i = 1; i <= ctx->compareExpr().size() - 1; i++) { + FegenOperator op; + if (ctx->children[2 * i - 1]->getText() == "==") { + op = FegenOperator::EQUAL; + } else { + op = FegenOperator::NOT_EQUAL; + } + auto rhs = std::any_cast( + this->visit(ctx->compareExpr(i))); + expr = FegenRightValue::ExpressionNode::binaryOperation(expr, rhs, op); + } + return expr; + } + + // return FegenRightValue::Expression* + std::any visitCompareExpr(FegenParser::CompareExprContext *ctx) override { + auto expr = std::any_cast( + this->visit(ctx->addExpr(0))); + for (size_t i = 1; i <= ctx->addExpr().size() - 1; i++) { + FegenOperator op; + auto opStr = ctx->children[2 * i - 1]->getText(); + if (opStr == "<") { + op = FegenOperator::LESS; + } else if (opStr == "<=") { + op = FegenOperator::LESS_EQUAL; + } else if (opStr == "<=") { + op = FegenOperator::LESS_EQUAL; + } else if (opStr == ">") { + op = FegenOperator::GREATER; + } else { + op = FegenOperator::GREATER_EQUAL; + } + auto rhs = std::any_cast( + this->visit(ctx->addExpr(i))); + expr = FegenRightValue::ExpressionNode::binaryOperation(expr, rhs, op); + } + return expr; + } + + // return FegenRightValue::Expression* + std::any visitAddExpr(FegenParser::AddExprContext *ctx) override { + auto expr = + std::any_cast(this->visit(ctx->term(0))); + for (size_t i = 1; i <= ctx->term().size() - 1; i++) { + FegenOperator op; + auto opStr = ctx->children[2 * i - 1]->getText(); + if (opStr == "+") { + op = FegenOperator::ADD; + } else { + op = FegenOperator::SUB; + } + auto rhs = std::any_cast( + this->visit(ctx->term(i))); + expr = FegenRightValue::ExpressionNode::binaryOperation(expr, rhs, op); + } + return expr; + } + + // return FegenRightValue::Expression* + std::any visitTerm(FegenParser::TermContext *ctx) override { + auto expr = std::any_cast( + this->visit(ctx->powerExpr(0))); + for (size_t i = 1; i <= ctx->powerExpr().size() - 1; i++) { + FegenOperator op; + auto opStr = ctx->children[2 * i - 1]->getText(); + if (opStr == "*") { + op = FegenOperator::MUL; + } else if (opStr == "/") { + op = FegenOperator::DIV; + } else { + op = FegenOperator::MOD; + } + auto rhs = std::any_cast( + this->visit(ctx->powerExpr(i))); + expr = FegenRightValue::ExpressionNode::binaryOperation(expr, rhs, op); + } + return expr; + } + + // return FegenRightValue::Expression* + std::any visitPowerExpr(FegenParser::PowerExprContext *ctx) override { + auto expr = std::any_cast( + this->visit(ctx->unaryExpr(0))); + for (size_t i = 1; i <= ctx->unaryExpr().size() - 1; i++) { + auto rhs = std::any_cast( + this->visit(ctx->unaryExpr(i))); + expr = FegenRightValue::ExpressionNode::binaryOperation( + expr, rhs, FegenOperator::POWER); + } + return expr; + } + + // return FegenRightValue::Expression* + std::any visitUnaryExpr(FegenParser::UnaryExprContext *ctx) override { + if (ctx->children.size() == 1 || ctx->Plus()) { + return this->visit(ctx->primaryExpr()); + } + auto expr = std::any_cast( + this->visit(ctx->primaryExpr())); + FegenOperator op; + if (ctx->Minus()) { + op = FegenOperator::NEG; + } else { + op = FegenOperator::NOT; + } + expr = FegenRightValue::ExpressionNode::unaryOperation(expr, op); + return expr; + } + + // return FegenRightValue::Expression* + std::any visitParenSurroundedExpr( + FegenParser::ParenSurroundedExprContext *ctx) override { + return this->visit(ctx->expression()); + } + + // return FegenRightValue::Expression* + std::any visitPrimaryExpr(FegenParser::PrimaryExprContext *ctx) override { + if (ctx->identifier()) { + auto name = ctx->identifier()->getText(); + auto var = this->sstack.attemptFindVar(name); + if (var) { + return (FegenRightValue::Expression *) + fegen::FegenRightValue::ExpressionTerminal::get(var); + } else { + auto tyDef = this->sstack.attemptFindTypeDef(name); + if (tyDef) { + auto tyVar = fegen::FegenType::getTemplateType(tyDef); + return fegen::FegenValue::get(fegen::FegenType::getMetaTemplateType(), + "", fegen::FegenRightValue::get(tyVar)); + } else { + // TODO: error report + std::cerr << "can not find variable: " << ctx->identifier()->getText() + << "." << std::endl; + exit(0); + return nullptr; + } + } + } else if (ctx->typeSpec()) { + auto ty = std::any_cast(this->visit(ctx->typeSpec())); + return (FegenRightValue::Expression *) + FegenRightValue::ExpressionTerminal::get(ty); + } else { // constant, functionCall, parenSurroundedExpr,contextMethodInvoke, + // and variableAccess + return this->visit(ctx->children[0]); + } + } + + // return ExpressionTerminal* + std::any visitIntLiteral(FegenParser::IntLiteralContext *ctx) override { + int number = std::stoi(ctx->getText()); + return (FegenRightValue::Expression *) + fegen::FegenRightValue::ExpressionTerminal::get(number); + } + + // return ExpressionTerminal* + std::any visitRealLiteral(FegenParser::RealLiteralContext *ctx) override { + double number = std::stod(ctx->getText()); + return (FegenRightValue::Expression *) + fegen::FegenRightValue::ExpressionTerminal::get(float(number)); + } + + // return ExpressionTerminal* + std::any visitCharLiteral(FegenParser::CharLiteralContext *ctx) override { + std::string s = ctx->getText(); + // remove quotation marks + std::string strWithoutQuotation = s.substr(1, s.size() - 2); + return (FegenRightValue::Expression *) + fegen::FegenRightValue::ExpressionTerminal::get(strWithoutQuotation); + } + + // return ExpressionTerminal* + std::any visitBoolLiteral(FegenParser::BoolLiteralContext *ctx) override { + int content = 0; + if (ctx->getText() == "true") { + content = 1; + } + return (FegenRightValue::Expression *) + fegen::FegenRightValue::ExpressionTerminal::get(content); + } + + // return ExpressionTerminal* + std::any visitListLiteral(FegenParser::ListLiteralContext *ctx) override { + std::vector elements; + for (auto exprCtx : ctx->expression()) { + auto expr = std::any_cast( + this->visit(exprCtx)); + elements.push_back(expr); + } + return (FegenRightValue::Expression *) + fegen::FegenRightValue::ExpressionTerminal::get(elements); + } + + std::any visitActionSpec(FegenParser::ActionSpecContext *ctx) override { + return nullptr; + } + + std::any visitFunctionDecl(FegenParser::FunctionDeclContext *ctx) override{ + sstack.pushScope(); + auto returnType = std::any_cast(this->visit(ctx->typeSpec())); + auto functionName = std::any_cast(this->visit(ctx->funcName())); + auto hasfunc = manager.functionMap.find(functionName); + if(hasfunc != manager.functionMap.end()){ + std::cerr << "The function name \" " << functionName + << "\" has already been used. Please use another name." << std::endl; + exit(0); + return nullptr; + } + auto functionParams = std::any_cast>(this->visit(ctx->funcParams())); + this->visit(ctx->statementBlock()); + + fegen::FegenFunction* function = fegen::FegenFunction::get(functionName, functionParams, &returnType); + manager.functionMap.insert(std::pair{functionName, function}); + sstack.popScope(); + return nullptr; + } + + std::any visitFuncName(FegenParser::FuncNameContext *ctx) override{ + auto functionName = ctx->identifier()->getText(); + return functionName; + } + + std::any visitFuncParams(FegenParser::FuncParamsContext *ctx) override{ + std::vector paramsList = {}; + + for(size_t i = 0; i < ctx->typeSpec().size(); i++){ + auto paramType = std::any_cast(this->visit(ctx->typeSpec(i))); + auto paramName = ctx->identifier(i)->getText(); + auto param = fegen::FegenValue::get(paramType, paramName, nullptr); + paramsList.push_back(param); + sstack.attemptAddVar(param); + } + return paramsList; + } + + std::any visitVarDeclStmt(FegenParser::VarDeclStmtContext *ctx) override{ + auto varType = std::any_cast(this->visit(ctx->typeSpec())); + auto varName = ctx->identifier()->getText(); + fegen::FegenValue* var; + if(ctx->expression()){ + auto varcontent = std::any_cast(this->visit(ctx->expression())); + // TODO: check error + // if(!fegen::FegenType::isSameType(&varType, &varcontent->exprType)){ + // std::cerr << "The variabel \" " << varName + // << "\" need \"" << varType.getTypeName() << " \" type rightvalue." << std::endl; + // exit(0); + // return nullptr; + // } + var = fegen::FegenValue::get(varType, varName, varcontent); + } else { + var = fegen::FegenValue::get(varType, varName, nullptr); + } + sstack.attemptAddVar(var); + manager.stmtContentMap.insert(std::pair{ctx, var}); + return var; + } + + std::any visitAssignStmt(FegenParser::AssignStmtContext *ctx) override{ + auto varName = ctx->identifier()->getText(); + auto varcontent = std::any_cast(this->visit(ctx->expression())); + auto var = sstack.attemptFindVar(varName); + if(!fegen::FegenType::isSameType(&var->getType(), &varcontent->exprType)){ + std::cerr << "The variabel \" " << varName + << "\" need \"" << var->getType().getTypeName() << " \" type rightvalue." << std::endl; + exit(0); + return nullptr; + } + fegen::FegenValue * stmt = fegen::FegenValue::get(var->getType(), varName, varcontent); + manager.stmtContentMap.insert(std::pair{ctx, stmt}); + + return stmt; + } + + std::any visitFunctionCall(FegenParser::FunctionCallContext *ctx) override{ + std::vector parasList = {}; + auto functionName =std::any_cast(this->visit(ctx->funcName())); + auto hasFunc = manager.functionMap.at(functionName); + auto paramsNum = ctx->expression().size(); + auto paraList = hasFunc->getInputTypeList(); + if( paramsNum> 0){ + for(size_t i = 0; i < paramsNum; i++){ + auto oprand = std::any_cast(this->visit(ctx->expression(i))); + parasList.push_back(oprand); + } + size_t len1 = paraList.size(); + size_t len2 = parasList.size(); + if(len1 != len2){ + std::cerr << "The function \" " << functionName + << "\" parameter count mismatch." << std::endl; + exit(0); + return nullptr; + } + for(size_t i = 0; i < len1; i++){ + if(!fegen::FegenType::isSameType(¶List[i]->getType(), ¶sList[i]->exprType)){ + std::cerr << "The function \" " << functionName + << "\" parameter" << i << " type mismatch." << std::endl; + exit(0); + return nullptr; + } + } + } + auto returnType = hasFunc->getReturnType(); + fegen::FegenFunction *funcCall = fegen::FegenFunction::get(functionName, paraList, returnType); + manager.stmtContentMap.insert(std::pair{ctx, funcCall}); + return returnType; + } + + std::any visitOpInvokeStmt(FegenParser::OpInvokeStmtContext *ctx) override{ + return nullptr; + } + + std::any visitIfStmt(FegenParser::IfStmtContext *ctx) override{ + sstack.pushScope(); + this->visit(ctx->expression(0)); + this->visit(ctx->statementBlock(0)); + if(ctx->expression().size() > 1){ + for(size_t i = 1; i < ctx->expression().size(); i++){ + this->visit(ctx->expression(i)); + this->visit(ctx->statementBlock(i)); + } + } + if(ctx->statementBlock(ctx->expression().size()+1)) + this->visit(ctx->statementBlock(ctx->expression().size()+1)); + sstack.popScope(); + + return nullptr; + } + + std::any visitForStmt(FegenParser::ForStmtContext *ctx) override{ + sstack.pushScope(); + this->visit(ctx->assignStmt(0)); + this->visit(ctx->expression()); + this->visit(ctx->assignStmt(1)); + this->visit(ctx->statementBlock()); + sstack.popScope(); + + return nullptr; + } + + std::any visitOpDecl(FegenParser::OpDeclContext *ctx) override { + auto opName = ctx->opName()->getText(); + auto opDef = + std::any_cast(this->visit(ctx->opBlock())); + opDef->setOpName(opName); + bool success = this->manager.addOperationDefination(opDef); + if (!success) { + // TODO: error report + std::cerr << "operation " << opName << " already exist." << std::endl; + } + return nullptr; + } + + // return FegenOperation* + std::any visitOpBlock(FegenParser::OpBlockContext *ctx) override { + std::vector args; + std::vector res; + if (ctx->argumentSpec()) { + args = std::any_cast>( + this->visit(ctx->argumentSpec())); + } + if (ctx->resultSpec()) { + res = std::any_cast>( + this->visit(ctx->resultSpec())); + } + return fegen::FegenOperation::get("", args, res, ctx->bodySpec()); + } +}; +} // namespace fegen +#endif \ No newline at end of file diff --git a/frontend/FrontendGen/include/Lexer.h b/frontend/FrontendGen/include/Lexer.h deleted file mode 100644 index 4ec1a88ea3..0000000000 --- a/frontend/FrontendGen/include/Lexer.h +++ /dev/null @@ -1,59 +0,0 @@ -//====- Lexer.h ---------------------------------------------------------===// -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#ifndef INCLUDE_LEXER_H -#define INCLUDE_LEXER_H -#include "Diagnostics.h" -#include "Token.h" -#include "llvm/ADT/StringMap.h" -#include "llvm/Support/SourceMgr.h" -namespace frontendgen { - -/// Manage all keywords. -class KeyWordManager { - llvm::StringMap keywordMap; - void addKeyWords(); - -public: - KeyWordManager() { addKeyWords(); } - void addKeyWord(llvm::StringRef name, tokenKinds kind); - tokenKinds getKeyWord(llvm::StringRef name, tokenKinds kind); -}; - -class Lexer { - llvm::SourceMgr &srcMgr; - DiagnosticEngine &diagnostic; - const char *curPtr; - llvm::StringRef curBuffer; - KeyWordManager keywordManager; - -public: - Lexer(llvm::SourceMgr &srcMgr, DiagnosticEngine &diagnostic) - : srcMgr(srcMgr), diagnostic(diagnostic) { - curBuffer = srcMgr.getMemoryBuffer(srcMgr.getMainFileID())->getBuffer(); - curPtr = curBuffer.begin(); - } - DiagnosticEngine &getDiagnostic() { return diagnostic; } - void next(Token &token); - void identifier(Token &token); - void number(Token &token); - void formToken(Token &token, const char *tokenEnd, tokenKinds kind); - llvm::StringRef getMarkContent(std::string start, std::string end); - llvm::StringRef getEndChContent(const char *start, char ch); -}; - -} // namespace frontendgen -#endif diff --git a/frontend/FrontendGen/include/Parser.h b/frontend/FrontendGen/include/Parser.h deleted file mode 100644 index 90ebb3a5b3..0000000000 --- a/frontend/FrontendGen/include/Parser.h +++ /dev/null @@ -1,60 +0,0 @@ -//====- Parser.h --------------------------------------------------------===// -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#ifndef INCLUDE_PARSER_H -#define INCLUDE_PARSER_H -#include "AST.h" -#include "Lexer.h" -#include "Sema.h" -#include "Terminator.h" -#include "Token.h" -namespace frontendgen { - -/// A class for parsing tokens. -class Parser { - Lexer &lexer; - Token token; - Sema &action; - Terminators &terminators; - -public: - Parser(Lexer &lexer, Sema &action, Terminators &terminators) - : lexer(lexer), action(action), terminators(terminators) { - advance(); - } - bool consume(tokenKinds kind); - bool consumeNoAdvance(tokenKinds kind); - void advance(); - Module *parser(); - void compilEngine(Module *module); - void parserRules(Rule *rule); - void parserGenerator(GeneratorAndOthers *generatorAndOthers); - void lookToken(); - AntlrBase::baseKind getAntlrBaseKind(llvm::StringRef name); - void parserIdentifier(GeneratorAndOthers *generatorAndOthers); - void parserTerminator(GeneratorAndOthers *generatorAndOthers); - void parserPBExpression(GeneratorAndOthers *generatorAndOthers); - void parserDialect(Dialect *&dialect, llvm::StringRef defName); - bool parserOp(std::vector &ops, llvm::StringRef opName); - void parserCurlyBracketOpen(GeneratorAndOthers *generatorAndOthers); - void parserDAG(DAG *&dag); - void parserBuilders(std::vector &builders); - void parserCode(llvm::StringRef &code); - void parserCArg(llvm::StringRef &operand, llvm::StringRef &value); -}; -} // namespace frontendgen - -#endif diff --git a/frontend/FrontendGen/include/Scope.h b/frontend/FrontendGen/include/Scope.h new file mode 100644 index 0000000000..c8c46573a7 --- /dev/null +++ b/frontend/FrontendGen/include/Scope.h @@ -0,0 +1,81 @@ +#ifndef FEGEN_SCOPE_H +#define FEGEN_SCOPE_H + +#include "FegenManager.h" +#include + +namespace fegen { + +template class SymbolTable { +private: + std::map table; + +public: + SymbolTable() = default; + void add(std::string, T *e); + T *get(std::string name); + /// @brief return true if name exist in map. + bool exist(std::string name); + ~SymbolTable(); +}; + +class FegenScope { + using TypeDefTable = SymbolTable; + using VariableTable = SymbolTable; + friend class ScopeStack; + +private: + unsigned int scopeId; + FegenScope *parentScope; + TypeDefTable typeTable; + VariableTable varTable; + +public: + explicit FegenScope(unsigned int scopeId, FegenScope *parentScope); + ~FegenScope() = default; + + /// @brief this will not check. + FegenTypeDefination *findTypeDef(std::string name); + /// @brief this will not check whether tyDef is already existed or not. + void addTypeDef(FegenTypeDefination *tyDef); + /// @brief return true if exist. + bool isExistTypeDef(std::string name); + /// @brief this will not check. + FegenValue *findVar(std::string name); + /// @brief this will not check whether var is already existed or not. + void addVar(FegenValue *var); + /// @brief return true if exist. + bool isExistVar(std::string name); +}; + +class ScopeStack { +private: + std::vector scopes; + std::stack scopeStack; + + FegenScope *currentScope; + FegenScope *globalScope; + // scope total count + size_t count; + + ScopeStack(); + ~ScopeStack(); + ScopeStack(const ScopeStack &) = delete; + const ScopeStack &operator=(const ScopeStack &) = delete; + +public: + static ScopeStack &getScopeStack(); + void pushScope(); + void popScope(); + /// @brief check and add var to current scope, return false if failed. + bool attemptAddVar(FegenValue *var); + /// @brief check add find var from current scope, return nullptr if failed. + FegenValue *attemptFindVar(std::string name); + /// @brief check and add tyDef to current scope, return false if failed. + bool attemptAddTypeDef(FegenTypeDefination *tyDef); + /// @brief check and find tyDef from current scope, return nullptr if failed. + FegenTypeDefination *attemptFindTypeDef(std::string name); +}; +} // namespace fegen + +#endif \ No newline at end of file diff --git a/frontend/FrontendGen/include/Sema.h b/frontend/FrontendGen/include/Sema.h deleted file mode 100644 index e9d40881ca..0000000000 --- a/frontend/FrontendGen/include/Sema.h +++ /dev/null @@ -1,35 +0,0 @@ -//====- Sema.h ----------------------------------------------------------===// -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#ifndef INCLUDE_SEMA_H -#define INCLUDE_SEMA_H -#include "AST.h" - -namespace frontendgen { - -class Sema { -public: - void actOnModule(Module *module, std::vector &rules, - Dialect *&dialect, std::vector &ops); - void actOnRule(Rule *rule, std::vector &generators); - void actOnDialect(Dialect *dialect, llvm::StringRef defName, - llvm::StringRef name, llvm::StringRef cppNamespace); - void actOnOps(std::vector &ops, llvm::StringRef opName, DAG *arguments, - DAG *results, std::vector &builder); - void actOnDag(DAG *&arguments, DAG &dag); -}; -} // namespace frontendgen -#endif diff --git a/frontend/FrontendGen/include/Terminator.def b/frontend/FrontendGen/include/Terminator.def deleted file mode 100644 index 4423f42bd8..0000000000 --- a/frontend/FrontendGen/include/Terminator.def +++ /dev/null @@ -1,22 +0,0 @@ -#ifndef terminator -#define terminator(NAME) -#endif - -terminator(Var, 'var') -terminator(Add, 'add') -terminator(Sub, 'sub') -terminator(Def, 'def') -terminator(Return, 'return') -terminator(ParentheseOpen, '(') -terminator(ParentheseClose, ')') -terminator(Comma, ',') -terminator(BracketOpen, '{') -terminator(BracketClose, '}') -terminator(SbracketOpen, '[') -terminator(SbracketClose, ']') -terminator(Semi, ';') -terminator(AngleBracketOpen, '<') -terminator(AngleBracketClose, '>') -terminator(Number, [0-9]+) -terminator(Equal, '=') -#undef terminator diff --git a/frontend/FrontendGen/include/Terminator.h b/frontend/FrontendGen/include/Terminator.h deleted file mode 100644 index 400ec45c3e..0000000000 --- a/frontend/FrontendGen/include/Terminator.h +++ /dev/null @@ -1,75 +0,0 @@ -//====- Terminator.h -------------------------------------------------------===// -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#ifndef INCLUDE_TERMINATOR_H -#define INCLUDE_TERMINATOR_H -#include "llvm/ADT/SmallSet.h" -#include "llvm/ADT/StringMap.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/raw_ostream.h" - -namespace frontendgen { -class CGModule; -/// A class store antlr's terminators. -class Terminators { - friend class CGModule; - -private: - llvm::StringMap terminators; - llvm::SmallSet customTerminators; - -public: - Terminators() { -#define terminator(NAME, VALUE) terminators.insert(std::pair(#NAME, #VALUE)); -#include "Terminator.def" - } - /// Determine if it is a terminator. - bool isTerminator(llvm::StringRef terminator) { - std::string tmp = terminator.str(); - tmp[0] += 32; - if (customTerminators.contains(tmp)) - return true; - if (terminators.find(terminator) == terminators.end()) - return false; - return true; - } - - void addCustomTerminators(llvm::StringRef terminator) { - customTerminators.insert(terminator); - } - void addTerminator(llvm::StringRef terminator) { - terminators.insert(std::pair(terminator, terminator)); - } - /// Output all terminators. - void lookTerminators() { - llvm::outs() << "customTerminators\n"; - for (llvm::StringRef terminator : customTerminators) { - std::string terminatorName = terminator.str(); - terminatorName[0] -= 32; - llvm::outs() << "terminator name:" << terminatorName << ' ' - << "terminator content:" << terminator << '\n'; - } - for (auto start = terminators.begin(); start != terminators.end(); - ++start) { - llvm::outs() << "terminator name:" << start->first() << ' ' - << "terminator content:" << start->second << '\n'; - } - } -}; - -} // namespace frontendgen - -#endif diff --git a/frontend/FrontendGen/include/Token.def b/frontend/FrontendGen/include/Token.def deleted file mode 100644 index 253b5751bc..0000000000 --- a/frontend/FrontendGen/include/Token.def +++ /dev/null @@ -1,37 +0,0 @@ -#ifndef TOK -#define TOK(ID) -#endif -#ifndef PUNCTUATOR -#define PUNCTUATOR(ID, SP) TOK(ID) -#endif -#ifndef KEYWORD -#define KEYWORD(ID, FLAG) TOK(kw_ ## ID) -#endif -TOK(unknown) -TOK(eof) -TOK(identifier) -TOK(number) -PUNCTUATOR(semi, ";") -PUNCTUATOR(colon, ":") -PUNCTUATOR(apostrophe, "'") -PUNCTUATOR(asterisk, "*") -PUNCTUATOR(parentheseOpen, "(") -PUNCTUATOR(parentheseClose, ")") -PUNCTUATOR(questionMark, "?") -PUNCTUATOR(plus, "+") -PUNCTUATOR(equal, "=") -PUNCTUATOR(curlyBlacketOpen, "{") -PUNCTUATOR(curlyBlacketClose, "}") -PUNCTUATOR(dollar, "$") -PUNCTUATOR(comma, ",") -PUNCTUATOR(angleBracketOpen, "<") -PUNCTUATOR(angleBracketClose, ">") -PUNCTUATOR(squareBracketOpen, "[") -PUNCTUATOR(squareBracketClose, "]") -PUNCTUATOR(doubleQuotationMark, "\"") -KEYWORD(rule, KEYALL) -KEYWORD(op, KEYALL) -KEYWORD(dialect, KEYALL) -#undef TOK -#undef PUNCTUATOR -#undef KEYWORD diff --git a/frontend/FrontendGen/include/Token.h b/frontend/FrontendGen/include/Token.h deleted file mode 100644 index 96f322753a..0000000000 --- a/frontend/FrontendGen/include/Token.h +++ /dev/null @@ -1,56 +0,0 @@ -//====- Token.h --------------------------------------------------------===// -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#ifndef INCLUDE_TOKEN -#define INCLUDE_TOKEN -#include "Lexer.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/SMLoc.h" -#include "llvm/Support/raw_ostream.h" -namespace frontendgen { -enum tokenKinds { -#define TOK(ID) ID, -#include "Token.def" - NUM_TOKENS -}; -/// store token names. -static const char *tokenNameMap[] = { -#define TOK(ID) #ID, -#define KEYWORD(ID, FLAG) #ID, -#include "Token.def" - nullptr}; - -class Token { - friend class Lexer; - -private: - tokenKinds tokenKind; - const char *start; - int length; - -public: - void setTokenKind(tokenKinds kind) { tokenKind = kind; } - void setLength(int len) { length = len; } - - llvm::StringRef getContent() { return llvm::StringRef(start, length); } - tokenKinds getKind() { return tokenKind; } - const char *getTokenName() { return tokenNameMap[tokenKind]; } - bool is(tokenKinds kind); - llvm::SMLoc getLocation(); -}; - -} // namespace frontendgen -#endif diff --git a/frontend/FrontendGen/include/TypeMap.def b/frontend/FrontendGen/include/TypeMap.def deleted file mode 100644 index 5d7862b8df..0000000000 --- a/frontend/FrontendGen/include/TypeMap.def +++ /dev/null @@ -1,32 +0,0 @@ -#ifndef CPPMAP -#define CPPMAP(key, value) -#endif - -#ifndef ARGUMENTSMAP -#define ARGUMENTSMAP(key, value) -#endif - -#ifndef RESULTSMAP -#define RESULTSMAP(key, value) -#endif - -CPPMAP("\"StringRef\"", "llvm::StringRef") -CPPMAP("\"ArrayRef\"", "llvm::ArrayRef") -CPPMAP("\"FunctionType\"", "llvm::FunctionType") -CPPMAP("\"ArrayRef\"", "llvm::ArrayRef") -CPPMAP("\"Value\"", "mlir::Value") -CPPMAP("\"double\"", "double") -CPPMAP("\"DenseElementsAttr\"", "mlir::DenseElementsAttr") - -ARGUMENTSMAP("F64ElementsAttr", "mlir::Value") -ARGUMENTSMAP("F64Tensor", "mlir::Value") -ARGUMENTSMAP("Variadic", "mlir::Value") -ARGUMENTSMAP("SymbolNameAttr", "llvm::StringRef") -ARGUMENTSMAP("TypeAttrOf", "mlir::FunctionType") -ARGUMENTSMAP("F64MemRef", "mlir::Value") -RESULTSMAP("StaticShapeTensorOf<[F64]>", "mlir::Type") -RESULTSMAP("F64Tensor", "mlir::Type") - -#undef TYPEMAP -#undef ARGUMENTSMAP -#undef RESULTSMAP diff --git a/frontend/FrontendGen/lib/CGModule.cpp b/frontend/FrontendGen/lib/CGModule.cpp deleted file mode 100644 index 8d21ed6522..0000000000 --- a/frontend/FrontendGen/lib/CGModule.cpp +++ /dev/null @@ -1,422 +0,0 @@ -//====- CGModule.cpp ------------------------------------------------------===// -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#include "CGModule.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Casting.h" -#include -#include - -using namespace frontendgen; - -/// Emit the ast,currently only antlr's ast are supported. -void CGModule::emitAST() { - for (auto i : module->getRules()) { - llvm::outs() << "rule name: " << i->getName() << '\n'; - for (auto j : i->getGeneratorsAndOthers()) { - llvm::outs() << " generator: " << '\n' << " "; - for (auto k : j->getGenerator()) { - if (k->getKind() == AntlrBase::baseKind::rule) - llvm::outs() << "\"" << k->getName() << "\"(rule) "; - else if (k->getKind() == AntlrBase::baseKind::terminator) - llvm::outs() << "\"" << k->getName() << "\"(terminator) "; - else if (k->getKind() == AntlrBase::baseKind::pbexpression) - llvm::outs() << "\"" << k->getName() << "\"(bpExpression) "; - } - llvm::outs() << '\n'; - } - } -} - -/// Emit the code of antlr , emit the generative formula first, then emit -/// user-defined terminator , and finally emit the system-defined terminator. -void CGModule::emitAntlr(llvm::StringRef grammarName) { - emitGrammar(grammarName); - emit(module->getRules()); - emitCustomTerminators(); - emitTerminators(); -} -/// Emit the system-defined terminator. -void CGModule::emitTerminators() { - for (auto start = terminators.terminators.begin(); - start != terminators.terminators.end(); start++) { - os << start->first() << '\n'; - os << " : " << start->second << "\n ;\n\n"; - } - emitWSAndComment(); -} - -void CGModule::emitGrammar(llvm::StringRef grammarName) { - os << "grammar " << grammarName << ";\n\n"; -} - -/// Emit user-defined terminator. -void CGModule::emitCustomTerminators() { - for (auto terminator : terminators.customTerminators) { - std::string tmp = terminator.str(); - if (tmp[0] >= 'a' && tmp[0] <= 'z') - tmp[0] -= 32; - llvm::StringRef name(tmp); - os << name << '\n'; - os << " : \'" << terminator.str() << "\'\n ;\n\n"; - } -} - -/// Emit the generative formula. -void CGModule::emit(const std::vector &rules) { - for (Rule *rule : rules) { - os << rule->getName() << '\n'; - emit(rule->getGeneratorsAndOthers()); - os << '\n'; - } -} - -/// Emit all generative formulas in a rule. -void CGModule::emit( - const std::vector &generatorsAndOthers) { - for (GeneratorAndOthers *generatorAndOthers : generatorsAndOthers) { - if (generatorAndOthers == generatorsAndOthers[0]) - os << " : "; - else - os << " | "; - emit(generatorAndOthers->getGenerator()); - } - os << " ;\n"; -} - -/// Output the elements of the generated formula. -void CGModule::emit(const std::vector &generator) { - for (AntlrBase *base : generator) { - if (base->getKind() == AntlrBase::baseKind::terminator) { - std::string tmp = base->getName().str(); - // The terminator in antlr must be capitalized. - if (tmp[0] >= 'a' && tmp[0] <= 'z') - tmp[0] -= 32; - llvm::StringRef name(tmp); - os << name << " "; - } else if (base->getKind() == AntlrBase::baseKind::rule) { - os << base->getName() << " "; - } else if (base->getKind() == AntlrBase::baseKind::pbexpression) { - os << base->getName(); - } - } - os << '\n'; -} - -/// TODO: Supports user-defined comment whitespace. -void CGModule::emitWSAndComment() { - os << "Identifier\n : [a-zA-Z][a-zA-Z0-9_]*\n ;\n\n"; - os << "WS\n : [ \\r\\n\\t] -> skip\n ;\n\n"; - os << "Comment\n : '#' .*? \'\\n\' ->skip\n ;\n"; -} - -void CGModule::emitMLIRVisitor(llvm::StringRef grammarName) { - emitIncludes(grammarName); - emitClass(grammarName); -} - -void CGModule::emitIncludes(llvm::StringRef grammarName) { - os << "#include \"" << grammarName << "BaseVisitor.h\"\n"; - os << "#include \"" << grammarName << "Lexer.h\"\n"; - os << "#include \"" << grammarName << "Parser.h\"\n"; - os << "#include \"mlir/IR/Attributes.h\"\n"; - os << "#include \"mlir/IR/Builders.h\"\n"; - os << "#include \"mlir/IR/BuiltinOps.h\"\n"; - os << "#include \"mlir/IR/BuiltinTypes.h\"\n"; - os << "#include \"mlir/IR/MLIRContext.h\"\n"; - os << "#include \"mlir/IR/Verifier.h\"\n"; - os << "#include \"llvm/ADT/STLExtras.h\"\n"; - os << "#include \"llvm/ADT/ScopedHashTable.h\"\n"; - os << "#include \"llvm/ADT/StringRef.h\"\n"; - os << "#include \"llvm/Support/raw_ostream.h\"\n"; - os << "\n"; -} - -/// Emit visitor class. -void CGModule::emitClass(llvm::StringRef grammarName) { - os << "class MLIR" << grammarName << "Visitor : public " << grammarName - << "BaseVisitor {\n"; - - os << "mlir::ModuleOp theModule;\n"; - os << "mlir::OpBuilder builder;\n"; - os << "std::string fileName;\n\n"; - - os << "public:\n"; - os << "MLIR" << grammarName - << "Visitor(std::string filename, mlir::MLIRContext &context)\n" - << ": builder(&context), fileName(filename) " - << "{\n theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); " - "\n}\n\n"; - os << "mlir::ModuleOp getModule() { return theModule; }\n\n"; - - // Emit all virtual functions. - auto rules = module->getRules(); - for (auto rule : rules) { - emitRuleVisitor(grammarName, rule); - } - os << "};\n"; -} -/// Emit virtual function in antlr. -void CGModule::emitRuleVisitor(llvm::StringRef grammarName, Rule *rule) { - std::string ruleName = rule->getName().str(); - ruleName[0] = ruleName[0] - 32; - os << "virtual std::any visit" << ruleName; - os << "(" << grammarName << "Parser::" << ruleName << "Context *ctx) {\n"; - emitBuilders(rule); - os << " return visitChildren(ctx);\n"; - os << "}\n\n"; -} - -void CGModule::emitBuilders(Rule *rule) { - for (GeneratorAndOthers *generatorAndOthers : - rule->getGeneratorsAndOthers()) { - llvm::SmallVector builderOpNames = - generatorAndOthers->getBuilderNames(); - llvm::SmallVector indices = generatorAndOthers->getBuilderIndices(); - int size = builderOpNames.size(); - for (int start = 0; start < size; start++) - emitBuilder(builderOpNames[start], indices[start]); - } -} - -void CGModule::emitBuilder(llvm::StringRef builderOp, int index) { - Op *op = findOp(builderOp); - if (op == nullptr) { - llvm::errs() << builderOp << " is undefined!\n"; - return; - } - emitOp(op, index); -} - -Op *CGModule::findOp(llvm::StringRef opName) { - for (Op *op : module->getOps()) { - if (op->getOpName() == opName) - return op; - } - return nullptr; -} - -/// Emit the operation we want to create. -void CGModule::emitOp(Op *op, int index) { - // Emit the default builder function. - if (index == 0) { - DAG *arguments = op->getArguments(); - DAG *result = op->getResults(); - llvm::SmallVector argOperands; - llvm::SmallVector argOperandNames; - llvm::SmallVector resOperands; - llvm::SmallVector resOperandNames; - if (arguments) { - argOperands = arguments->getOperands(); - argOperandNames = arguments->getOperandNames(); - } - if (result) { - resOperands = result->getOperands(); - resOperandNames = result->getOperandNames(); - } - os << " {\n"; - // opArguments are used to store the names of the arguments needed to create - // the operation. - llvm::SmallVector opArguments; - // tmpStrings are used to store and own some computed string, keeping their - // lifetime longger than the StringRefs in opArguments - llvm::SmallVector tmpStrings; - // Emit variables for creation operation. - // Emit variables of result type. - for (size_t index = 0; index < resOperands.size(); index++) { - if (!typeMap.findResultsMap(resOperands[index]).empty()) { - os << " " << typeMap.findResultsMap(resOperands[index]) << " "; - if (!resOperandNames[index].empty()) { - os << resOperandNames[index] << ";\n"; - opArguments.push_back(resOperandNames[index]); - } else { - tmpStrings.emplace_back("res" + std::to_string(index)); - const auto &arg = tmpStrings.back(); - os << arg << ";\n"; - opArguments.push_back(arg); - } - } else if (resOperands[index].startswith("AnyTypeOf")) { - llvm::StringRef operand = resOperands[index]; - auto start = operand.find('[') + 1; - auto end = operand.find(']'); - auto cur = start; - if (start == std::string::npos || end == std::string::npos) { - return; - } - llvm::StringRef type; - while (cur <= end) { - if (operand[cur] == ',' || cur == end) { - std::string str(operand, start, cur - start); - str.erase(0, str.find_first_not_of(" ")); - str.erase(str.find_last_not_of(" ") + 1); - if (typeMap.findResultsMap(str).empty()) { - llvm::errs() << str << " in " << op->getOpName() - << " in results is not supported.\n"; - } - type = typeMap.findResultsMap(str); - start = cur + 1; - } - cur++; - } - os << " " << type << " "; - if (!resOperandNames[index].empty()) { - os << resOperandNames[index] << ";\n"; - opArguments.push_back(resOperandNames[index]); - } else { - tmpStrings.emplace_back("res" + std::to_string(index)); - const auto &arg = tmpStrings.back(); - os << arg << ";\n"; - opArguments.push_back(arg); - } - } else { - llvm::errs() << resOperands[index] << " in " << op->getOpName() - << " in results is not supported.\n"; - return; - } - } - // Emit variables of argument type. - for (size_t index = 0; index < argOperands.size(); index++) { - if (!typeMap.findArgumentMap(argOperands[index]).empty()) { - os << " " << typeMap.findArgumentMap(argOperands[index]) << " "; - if (!argOperandNames[index].empty()) { - os << argOperandNames[index] << ";\n"; - opArguments.push_back(argOperandNames[index]); - } else { - tmpStrings.emplace_back("arg" + std::to_string(index)); - const auto &arg = tmpStrings.back(); - os << arg << ";\n"; - opArguments.push_back(arg); - } - } else if (argOperands[index].startswith("AnyTypeOf")) { - llvm::StringRef operand = argOperands[index]; - auto start = operand.find('[') + 1; - auto end = operand.find(']'); - auto cur = start; - if (start == std::string::npos || end == std::string::npos) { - return; - } - llvm::StringRef type; - while (cur <= end) { - if (operand[cur] == ',' || cur == end) { - std::string str(operand, start, cur - start); - str.erase(0, str.find_first_not_of(" ")); - str.erase(str.find_last_not_of(" ") + 1); - if (typeMap.findArgumentMap(str).empty()) { - llvm::errs() << str << " in " << op->getOpName() - << " in arguments is not supported.\n"; - } - start = cur + 1; - type = typeMap.findArgumentMap(str); - } - cur++; - } - os << " " << type << " "; - if (!argOperandNames[index].empty()) { - os << argOperandNames[index] << ";\n"; - opArguments.push_back(argOperandNames[index]); - } else { - tmpStrings.emplace_back("arg" + std::to_string(index)); - const auto &arg = tmpStrings.back(); - os << arg << ";\n"; - opArguments.push_back(arg); - } - } else { - llvm::errs() << argOperands[index] << " in " << op->getOpName() - << " in arguments is not supported.\n"; - return; - } - } - // Emit the operation we want to create. - os << " mlir::Location location;\n"; - llvm::StringRef cppNameSpace( - module->getDialect()->getCppNamespace().data() + 1, - module->getDialect()->getCppNamespace().size() - 2); - os << " " - << "builder.create<" << cppNameSpace << "::" << op->getOpName() - << ">(location"; - if (opArguments.size()) - os << ", "; - for (size_t index = 0; index < opArguments.size(); index++) { - os << opArguments[index]; - if (index + 1 != opArguments.size()) - os << ", "; - } - os << ");\n"; - os << " }\n\n"; - } else if (index > 0) { - // Emit custom builder function. - index--; - // Emit the variables which are used to fill builder function. - llvm::SmallVector operands = - op->getBuilders()[index]->getDag()->getOperands(); - llvm::SmallVector operandNames = - op->getBuilders()[index]->getDag()->getOperandNames(); - llvm::SmallVector opArguments; - llvm::SmallVector tmpStrings; - os << " {\n"; - for (size_t index = 0; index < operands.size(); index++) { - if (!typeMap.findCppMap(operands[index]).empty()) - os << " " << typeMap.findCppMap(operands[index]); - else - os << " " << operands[index]; - if (!operandNames[index].empty()) { - os << " " << operandNames[index] << ";\n"; - opArguments.push_back(operandNames[index]); - } else { - tmpStrings.emplace_back("arg" + std::to_string(index)); - const auto &arg = tmpStrings.back(); - os << arg << ";\n"; - opArguments.push_back(arg); - } - } - // Emit the operation we want to create. - os << " mlir::Location location;\n"; - llvm::StringRef cppNameSpace( - module->getDialect()->getCppNamespace().data() + 1, - module->getDialect()->getCppNamespace().size() - 2); - os << " " - << "builder.create<" << cppNameSpace << "::" << op->getOpName() - << ">(location"; - if (!operandNames.empty()) { - os << ", "; - for (size_t index = 0; index < opArguments.size(); index++) { - os << opArguments[index]; - if (index + 1 != opArguments.size()) - os << ", "; - } - } - os << ");\n"; - os << " }\n\n"; - } -} - -llvm::StringRef TypeMap::findCppMap(llvm::StringRef key) { - if (cppMap.find(key) == cppMap.end()) - return llvm::StringRef(); - return cppMap[key]; -} - -llvm::StringRef TypeMap::findArgumentMap(llvm::StringRef key) { - if (argumentsMap.find(key) == argumentsMap.end()) - return llvm::StringRef(); - return argumentsMap[key]; -} - -llvm::StringRef TypeMap::findResultsMap(llvm::StringRef key) { - if (resultsMap.find(key) == resultsMap.end()) - return llvm::StringRef(); - return resultsMap[key]; -} diff --git a/frontend/FrontendGen/lib/CMakeLists.txt b/frontend/FrontendGen/lib/CMakeLists.txt index 90e9d45027..8907951231 100644 --- a/frontend/FrontendGen/lib/CMakeLists.txt +++ b/frontend/FrontendGen/lib/CMakeLists.txt @@ -1,12 +1,50 @@ -include_directories(../include) -set(LLVM_LINK_COMPONENTS -support) - -add_llvm_component_library(LLVMfrontendgenlib -CGModule.cpp -Lexer.cpp -Parser.cpp -Sema.cpp -Diagnostics.cpp -LINK_COMPONENTS -support) +antlr_target(FegenLexer FegenLexer.g4 + PACKAGE fegen + LEXER + ) + +antlr_target(FegenParser FegenParser.g4 + PACKAGE fegen + DEPENDS_ANTLR FegenLexer + PARSER + LISTENER + VISITOR + COMPILE_FLAGS -lib + ${ANTLR_FegenLexer_OUTPUT_DIR} + ) + +include_directories(${ANTLR_FegenLexer_OUTPUT_DIR}) +set(ANTLR_FegenLexer_OUTPUT_DIR ${ANTLR_FegenLexer_OUTPUT_DIR} CACHE STRING "ANTLR_FegenLexer_OUTPUT_DIR") +include_directories(${ANTLR_FegenParser_OUTPUT_DIR}) +set(ANTLR_FegenParser_OUTPUT_DIR ${ANTLR_FegenParser_OUTPUT_DIR} CACHE STRING "ANTLR_FegenParser_OUTPUT_DIR") + +add_library(fegen_antlr_generated + ${ANTLR_FegenLexer_CXX_OUTPUTS} + ${ANTLR_FegenParser_CXX_OUTPUTS} +) +add_dependencies(fegen_antlr_generated antlr4_runtime) + +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../include") +add_library(FegenSupport + FegenManager.cpp + Scope.cpp +) +add_dependencies(FegenSupport fegen_antlr_generated) + +llvm_map_components_to_libnames(llvm_libs support) + +target_link_libraries(FegenSupport + PRIVATE + ${llvm_libs} +) + +add_library(fegenVisitor + FegenVisitor.cpp +) + +target_link_libraries(fegenVisitor + PUBLIC + fegen_antlr_generated + antlr4_static + FegenSupport +) \ No newline at end of file diff --git a/frontend/FrontendGen/lib/Diagnostics.cpp b/frontend/FrontendGen/lib/Diagnostics.cpp deleted file mode 100644 index d759867098..0000000000 --- a/frontend/FrontendGen/lib/Diagnostics.cpp +++ /dev/null @@ -1,43 +0,0 @@ -//====- Diagnostics.cpp -------------------------------------------------===// -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#include "Diagnostics.h" -#include "llvm/Support/SourceMgr.h" - -using namespace frontendgen; -namespace { - -/// Storage the message of the diagnostic. -const char *diagnosticText[] = { -#define DIAG(ID, Level, Msg) Msg, -#include "Diagnostics.def" -}; - -/// Storage the kind of the diagnostic. -llvm::SourceMgr::DiagKind diagnosticKind[] = { -#define DIAG(ID, Level, Msg) llvm::SourceMgr::DK_##Level, -#include "Diagnostics.def" -}; -} // namespace - -/// Get the message of the diagnostic. -const char *DiagnosticEngine::getDiagnosticText(unsigned diagID) { - return diagnosticText[diagID]; -} -/// Get the kind of the diagnostic. -llvm::SourceMgr::DiagKind DiagnosticEngine::getDiagnosticKind(unsigned DiagID) { - return diagnosticKind[DiagID]; -} diff --git a/frontend/FrontendGen/lib/FegenLexer.g4 b/frontend/FrontendGen/lib/FegenLexer.g4 new file mode 100644 index 0000000000..5dc8006b21 --- /dev/null +++ b/frontend/FrontendGen/lib/FegenLexer.g4 @@ -0,0 +1,221 @@ +lexer grammar FegenLexer; + +fragment Schar: ~ ["\\\r\n]; + +fragment NONDIGIT: [a-zA-Z_]; + +fragment UPPERCASE: [A-Z]; + +fragment LOWERCASE: [a-z]; + +fragment ALLCASE: [a-zA-Z0-9_]; + +fragment NOZERODIGIT: [1-9]; + +fragment DIGIT: [0-9]; + +fragment SQuoteLiteral + : '\'' (('\\' ([btnfr"'\\] | . |EOF))|( ~ ['\r\n\\]))* '\'' + ; + +// literal + +UnsignedInt: NOZERODIGIT DIGIT* | '0'; + +ScienceReal : (Plus | Minus)? UnsignedInt Dot UnsignedInt ( 'E' (Plus | Minus)? UnsignedInt )?; + +ConstBoolean: 'true' | 'false'; + +// key words + +FEGEN: 'fegen'; + +INPUTS: 'inputs'; + +RETURNS: 'returns'; + +ACTIONS: 'actions'; + +IR: 'ir'; + +OPERAND_VALUE: 'operandValue'; + +ATTRIBUTE_VALUE: 'attributeValue'; + +CPP_VALUE: 'cppValue'; + +OPERATION: 'operation'; + +FUNCTION: 'function'; + +TYPEDEF: 'typedef'; + +OPDEF: 'opdef'; + +ARGUMENTS: 'arguments'; + +RESULTS: 'results'; + +BODY: 'body'; + +EMPTY: 'null'; + +PARAMETERS: 'parameters'; + +ASSEMBLY_FORMAT: 'assemblyFormat'; + + +// types +TYPE: 'Type'; + +TYPETEMPLATE: 'TypeTemplate'; + +BOOL: 'bool'; + +INT: 'int'; + +FLOAT: 'float'; + +DOUBLE: 'double'; + +// F64TENSOR: 'F64Tensor'; + +// F64VECTOR: 'F64Vector'; + +CHAR: 'char'; + +STRING: 'string'; + +LIST: 'list'; + +ANY: 'any'; + +OPTIONAL: 'optional'; + +INTEGER: 'Integer'; + +FLOATPOINT: 'FloatPoint'; + +TENSOR: 'Tensor'; + +VECTOR: 'Vector'; + +CPP: 'cpp'; + +OPERAND: 'operand'; + +ATTRIBUTE: 'attribute'; + +// stmt + +IF: 'if'; + +ELSE: 'else'; + +FOR: 'for'; + +IN: 'in'; + +WHILE: 'while'; + +// identifiers + +LexerRuleName: UPPERCASE (NONDIGIT | DIGIT)*; + +ParserRuleName: LOWERCASE (NONDIGIT | DIGIT)*; + +// literal + +StringLiteral + : SQuoteLiteral + ; + + +// marks + +AND: '&&'; + +Logic_OR: '||'; + +EQUAL: '=='; + +NOT_EQUAL: '!='; + +Less: '<'; + +LessEqual: '<='; + +Greater: '>'; + +GreaterEqual: '>='; + +Comma: ','; + +Semi: ';'; + +LeftParen: '('; + +RightParen: ')'; + +LeftBracket: '['; + +RightBracket: ']'; + +LeftBrace: '{'; + +RightBrace: '}'; + +Dot: '.'; + +Colon: ':'; + +OR: '|'; + +QuestionMark: '?'; + +Star: '*'; + +Div: '/'; + +Plus: '+'; + +Minus: '-'; + +Assign: '='; + +Dollar: '$'; + +StarStar: '**'; + +MOD: '%'; + +Arror: '->'; + +Underline: '_'; + +Tilde: '~'; + +Exclamation: '!'; + +Range: '..'; + +BeginInclude: '@header' LeftBrace -> pushMode (TargetLanguageAction); + +Whitespace: [ \t]+ -> skip; + +Newline: ('\r' '\n'? | '\n') -> skip; + +BlockComment: '/*' .*? '*/' -> skip; + +LineComment: '//' ~ [\r\n]* -> skip; + +mode TargetLanguageAction; + +EndInclude: RightBrace -> popMode; + +INCLUDE_CONTENT + : . + | '\n' + | ' ' + ; + diff --git a/frontend/FrontendGen/lib/FegenManager.cpp b/frontend/FrontendGen/lib/FegenManager.cpp new file mode 100644 index 0000000000..8a0fe23a6c --- /dev/null +++ b/frontend/FrontendGen/lib/FegenManager.cpp @@ -0,0 +1,1480 @@ +#include "FegenParserBaseVisitor.h" +#include "FegenManager.h" +#include "Scope.h" +#include +#include +#include +#include +#include +#include +#include + +fegen::FegenFunction::FegenFunction(std::string name, + std::vector &&inputTypeList, + FegenType *returnType) + : name(name), inputTypeList(inputTypeList), returnType(returnType) {} + +fegen::FegenFunction * +fegen::FegenFunction::get(std::string name, + std::vector inputTypeList, + FegenType *returnType) { + return new fegen::FegenFunction(name, std::move(inputTypeList), returnType); +} +std::string fegen::FegenFunction::getName() { this->name; } + +std::vector &fegen::FegenFunction::getInputTypeList() { + return this->inputTypeList; +} + +fegen::FegenValue *fegen::FegenFunction::getInputTypeList(size_t i) { + return this->inputTypeList[i]; +} + +fegen::FegenType *fegen::FegenFunction::getReturnType() { + return this->returnType; +} + +fegen::FegenOperation::FegenOperation(std::string dialectName, + std::string operationName, + std::vector &&arguments, + std::vector &&results, + fegen::FegenParser::BodySpecContext *ctx) + : dialectName(dialectName), arguments(arguments), results(results), + ctx(ctx) {} + +void fegen::FegenOperation::setOpName(std::string name) { + this->operationName = name; +} +std::string fegen::FegenOperation::getOpName() { return this->operationName; } + +std::vector &fegen::FegenOperation::getArguments() { + return this->arguments; +} + +fegen::FegenValue *fegen::FegenOperation::getArguments(size_t i) { + return this->arguments[i]; +} + +std::vector &fegen::FegenOperation::getResults() { + return this->results; +} + +fegen::FegenValue *fegen::FegenOperation::getResults(size_t i) { + return this->results[i]; +} + +fegen::FegenOperation *fegen::FegenOperation::get( + std::string operationName, std::vector arguments, + std::vector results, FegenParser::BodySpecContext *ctx) { + return new fegen::FegenOperation(fegen::FegenManager::getManager().moduleName, + operationName, std::move(arguments), + std::move(results), ctx); +} + +// class FegenType + +/// @brief get name of Type Instance by jointsing template name and parameters, +/// for example: Integer + 32 --> Integer<32> +/// @return joint name +std::string jointTypeName(std::string templateName, + const std::vector ¶meters) { + if (parameters.empty()) { + return templateName; + } + std::string res = templateName; + res.append("<"); + size_t count = parameters.size(); + auto firstParamStr = parameters[0]->getContentString(); + res.append(firstParamStr); + for (size_t i = 1; i <= count - 1; i++) { + auto paramStr = parameters[i]->getContentString(); + res.append(", "); + res.append(paramStr); + } + res.append(">"); + return res; +} + +fegen::FegenType::FegenType(TypeKind kind, std::string name, + std::vector parameters, + FegenTypeDefination *tyDef, int typeLevel) + : kind(kind), typeName(name), parameters(std::move(parameters)), + typeDefine(tyDef), typeLevel(typeLevel) {} + +fegen::FegenType::FegenType(fegen::FegenType::TypeKind kind, + std::vector parameters, + FegenTypeDefination *tyDef, int typeLevel) + : kind(kind), typeName(jointTypeName(tyDef->getName(), parameters)), + parameters(std::move(parameters)), typeDefine(tyDef), + typeLevel((typeLevel)) {} + +fegen::FegenType::FegenType(const fegen::FegenType &fty) + : kind(fty.kind), typeName(fty.typeName), typeDefine(fty.typeDefine), + typeLevel(fty.typeLevel) { + // deep copy parameters + for (auto paramPtr : fty.parameters) { + this->parameters.push_back(new fegen::FegenValue(*paramPtr)); + } +} + +fegen::FegenType::FegenType(fegen::FegenType &&fty) + : kind(fty.kind), typeName(std::move(fty.typeName)), + parameters(std::move(fty.parameters)), typeDefine(fty.typeDefine), + typeLevel(fty.typeLevel) {} + +fegen::FegenType::TypeKind fegen::FegenType::getTypeKind() { + return this->kind; +} + +void fegen::FegenType::setTypeKind(fegen::FegenType::TypeKind kind) { + this->kind = kind; +} + +std::vector &fegen::FegenType::getParameters() { + return this->parameters; +} + +fegen::FegenValue *fegen::FegenType::getParameters(size_t i) { + return this->parameters[i]; +} + +void fegen::FegenType::setParameters(std::vector ¶ms) { + this->parameters = params; + // set parameters and level up! + this->typeLevel++; +} + +fegen::FegenTypeDefination *fegen::FegenType::getTypeDefination() { + return this->typeDefine; +} + +void fegen::FegenType::setTypeDefination(fegen::FegenTypeDefination *tyDef) { + this->typeDefine = tyDef; +} + +std::string fegen::FegenType::getTypeName() { return this->typeName; } + +int fegen::FegenType::getTypeLevel() { return this->typeLevel; } + +bool fegen::FegenType::isSameType(fegen::FegenType *type1, + fegen::FegenType *type2) { + if (type1->getTypeName() == type2->getTypeName()) + return true; + else + return false; +} + +std::string fegen::FegenType::toStringForTypedef() { + // handle builtin type instance + auto typeName = this->typeName; + auto typedefName = this->typeDefine->getName(); + if (this->typeDefine->isCustome()) { + return this->typeDefine->getName(); + } else if (typedefName == FEGEN_TYPE) { + return "\"Type\""; + } else if (typedefName == FEGEN_LIST) { + std::string res = "ArrayRefParameter<"; + for (size_t i = 0; i <= this->parameters.size() - 1; i++) { + res.append(this->parameters[i]->getContentStringForTypedef()); + if (i != this->parameters.size() - 1) { + res.append(", "); + } + } + res.append(">"); + return res; + } else if (typedefName == FEGEN_INTEGER) { + if (this->parameters.size() == 0) { + return "Builtin_IntegerAttr"; + } else { + if (typeName == "int") { + return "\"int\""; + } else if (typeName == "bool") { + return "\"bool\""; + } + int size = this->getParameters(0)->getContent(); + if (size == 64) { + return "\"long\""; + } else if (size == 16) { + return "\"short\""; + } else { + std::cerr << "unsupport type: " << typeName << std::endl; + exit(0); + } + } + } else if (typedefName == FEGEN_FLOATPOINT) { + if (this->parameters.size() == 0) { + return "Builtin_FloatAttr"; + } else { + if (typeName == "float") { + return "\"float\""; + } else if (typeName == "double") { + return "\"double\""; + } else { + std::cerr << "unsupport type: " << typeName << std::endl; + exit(0); + } + } + } else { + std::cerr << "unsupport type: " << typeName << std::endl; + exit(0); + } +} + +std::string fegen::FegenType::toStringForOpdef() { + // handle builtin type instance + auto typeName = this->typeName; + auto typedefName = this->typeDefine->getName(); + if (this->typeDefine->isCustome()) { + return this->typeDefine->getName(); + } else if (typedefName == FEGEN_LIST) { + std::string res = "Variadic<"; + for (size_t i = 0; i <= this->parameters.size() - 1; i++) { + res.append(this->parameters[i]->getContentStringForTypedef()); + if (i != this->parameters.size() - 1) { + res.append(", "); + } + } + res.append(">"); + return res; + } else if (typedefName == FEGEN_INTEGER) { + if (this->parameters.size() == 0) { + return "Builtin_Integer"; + } else { + if (typeName == "int") { + return "I32"; + } + int size = this->getParameters(0)->getContent(); + if (size == 64) { + return "I64"; + } else if (size == 16) { + return "I16"; + } + } + } + + std::cerr << "unsupport type: " << typeName << std::endl; + exit(0); +} + +fegen::FegenType::~FegenType() { + for (auto p : this->parameters) { + delete p; + } +} + +fegen::FegenType fegen::FegenType::getPlaceHolder() { + return fegen::FegenType( + fegen::FegenType::TypeKind::CPP, {}, + fegen::FegenManager::getManager().getTypeDefination(FEGEN_PLACEHOLDER), + 0); +} + +fegen::FegenType fegen::FegenType::getMetaType() { + return fegen::FegenType( + fegen::FegenType::TypeKind::CPP, {}, + fegen::FegenManager::getManager().getTypeDefination(FEGEN_TYPE), 2); +} + +fegen::FegenType fegen::FegenType::getMetaTemplateType() { + return fegen::FegenType( + fegen::FegenType::TypeKind::CPP, {}, + fegen::FegenManager::getManager().getTypeDefination(FEGEN_TYPETEMPLATE), + 1); +} + +fegen::FegenType fegen::FegenType::getInt32Type() { + return fegen::FegenType( + fegen::FegenType::TypeKind::CPP, "int", + {fegen::FegenValue::get(fegen::FegenType::getPlaceHolder(), "size", + fegen::FegenRightValue::get())}, + fegen::FegenManager::getManager().getTypeDefination(FEGEN_INTEGER), 3); +} + +fegen::FegenType fegen::FegenType::getFloatType() { + return fegen::FegenType( + fegen::FegenType::TypeKind::CPP, "float", + {fegen::FegenValue::get(fegen::FegenType::getInt32Type(), "size", + fegen::FegenRightValue::get(32))}, + fegen::FegenManager::getManager().getTypeDefination(FEGEN_FLOATPOINT), 3); +} + +fegen::FegenType fegen::FegenType::getDoubleType() { + return fegen::FegenType( + fegen::FegenType::TypeKind::CPP, "double", + {fegen::FegenValue::get(fegen::FegenType::getInt32Type(), "size", + fegen::FegenRightValue::get(64))}, + fegen::FegenManager::getManager().getTypeDefination(FEGEN_FLOATPOINT), 3); +} + +fegen::FegenType fegen::FegenType::getBoolType() { + return fegen::FegenType( + fegen::FegenType::TypeKind::CPP, "bool", + {fegen::FegenValue::get(fegen::FegenType::getInt32Type(), "size", + fegen::FegenRightValue::get(1))}, + fegen::FegenManager::getManager().getTypeDefination(FEGEN_INTEGER), 3); +} + +fegen::FegenType fegen::FegenType::getIntegerType(fegen::FegenValue *size) { + if (size->getContent() == 32) + return fegen::FegenType::getInt32Type(); + return fegen::FegenType( + fegen::FegenType::TypeKind::CPP, {size}, + fegen::FegenManager::getManager().getTypeDefination(FEGEN_INTEGER), 3); +} + +fegen::FegenType fegen::FegenType::getFloatPointType(fegen::FegenValue *size) { + if (size->getContent() == 32) { + return fegen::FegenType::getFloatType(); + } else if (size->getContent() == 64) { + return fegen::FegenType::getDoubleType(); + } + return fegen::FegenType( + fegen::FegenType::TypeKind::CPP, {size}, + fegen::FegenManager::getManager().getTypeDefination(FEGEN_FLOATPOINT), 3); +} + +fegen::FegenType fegen::FegenType::getCharType() { + return fegen::FegenType( + fegen::FegenType::TypeKind::CPP, {}, + fegen::FegenManager::getManager().getTypeDefination(FEGEN_CHAR), 3); +} + +fegen::FegenType fegen::FegenType::getStringType() { + return fegen::FegenType( + fegen::FegenType::TypeKind::CPP, {}, + fegen::FegenManager::getManager().getTypeDefination(FEGEN_STRING), 3); +} + +fegen::FegenType fegen::FegenType::getVectorType(fegen::FegenValue *size, + fegen::FegenType elementType) { + assert(elementType.typeLevel == 3); + return fegen::FegenType( + fegen::FegenType::TypeKind::CPP, + {size, + fegen::FegenValue::get(fegen::FegenType::getMetaType(), "elementType", + fegen::FegenRightValue::get(elementType))}, + fegen::FegenManager::getManager().getTypeDefination(FEGEN_VECTOR), + elementType.typeLevel); +} + +fegen::FegenType fegen::FegenType::getTensorType(fegen::FegenValue *shape, + fegen::FegenType elementType) { + assert(elementType.typeLevel == 3); + return fegen::FegenType( + fegen::FegenType::TypeKind::CPP, + {shape, + fegen::FegenValue::get(fegen::FegenType::getMetaType(), "elementType", + fegen::FegenRightValue::get(elementType))}, + fegen::FegenManager::getManager().getTypeDefination(FEGEN_TENSOR), + elementType.typeLevel); +} + +// List +fegen::FegenType fegen::FegenType::getListType(fegen::FegenType elementType) { + assert(elementType.typeLevel == 2 || elementType.typeLevel == 3); + return fegen::FegenType( + fegen::FegenType::TypeKind::CPP, + {fegen::FegenValue::get( + elementType.typeLevel == 2 ? fegen::FegenType::getMetaTemplateType() + : fegen::FegenType::getMetaType(), + "elementType", fegen::FegenRightValue::get(elementType))}, + fegen::FegenManager::getManager().getTypeDefination(FEGEN_LIST), + elementType.typeLevel); +} + +// Optional +fegen::FegenType +fegen::FegenType::getOptionalType(fegen::FegenType elementType) { + assert(elementType.typeLevel == 2 || elementType.typeLevel == 3); + return fegen::FegenType( + fegen::FegenType::TypeKind::CPP, + {fegen::FegenValue::get( + elementType.typeLevel == 2 ? fegen::FegenType::getMetaTemplateType() + : fegen::FegenType::getMetaType(), + "elementType", fegen::FegenRightValue::get(elementType))}, + fegen::FegenManager::getManager().getTypeDefination(FEGEN_OPTINAL), + elementType.typeLevel); +} + +// Any +fegen::FegenType +fegen::FegenType::getAnyType(std::vector elementTypes) { + std::vector p_elemTy; + int i = 0; + std::string name("elementType_"); + auto tyLevel = elementTypes[0].typeLevel; + assert(tyLevel == 2 || tyLevel == 3); + auto tyty = tyLevel == 2 ? fegen::FegenType::getMetaTemplateType() + : fegen::FegenType::getMetaType(); + for (auto &ty : elementTypes) { + assert(ty.typeLevel == tyLevel); + p_elemTy.push_back(fegen::FegenValue::get(tyty, name + std::to_string(i), + fegen::FegenRightValue::get(ty))); + i++; + } + return fegen::FegenType( + fegen::FegenType::TypeKind::CPP, p_elemTy, + fegen::FegenManager::getManager().getTypeDefination(FEGEN_ANY), tyLevel); +} + +fegen::FegenType fegen::FegenType::getIntegerTemplate() { + return fegen::FegenType( + fegen::FegenType::TypeKind::CPP, {}, + fegen::FegenManager::getManager().getTypeDefination(FEGEN_INTEGER), 2); +} + +fegen::FegenType fegen::FegenType::getFloatPointTemplate() { + return fegen::FegenType( + fegen::FegenType::TypeKind::CPP, {}, + fegen::FegenManager::getManager().getTypeDefination(FEGEN_FLOATPOINT), 2); +} + +fegen::FegenType +fegen::FegenType::getInstanceType(fegen::FegenTypeDefination *typeDefination, + std::vector parameters) { + return fegen::FegenType(fegen::FegenType::TypeKind::CPP, parameters, + typeDefination, 3); +} +fegen::FegenType +fegen::FegenType::getTemplateType(fegen::FegenTypeDefination *typeDefination) { + return fegen::FegenType(fegen::FegenType::TypeKind::CPP, {}, typeDefination, + 2); +} + +// class FegenTypeDefination +fegen::FegenTypeDefination::FegenTypeDefination( + std::string dialectName, std::string name, + std::vector parameters, + FegenParser::TypeDefinationDeclContext *ctx, bool ifCustome) + : dialectName(std::move(dialectName)), name(std::move(name)), + parameters(std::move(parameters)), ctx(ctx), ifCustome(ifCustome) {} + +fegen::FegenTypeDefination * +fegen::FegenTypeDefination::get(std::string dialectName, std::string name, + std::vector parameters, + FegenParser::TypeDefinationDeclContext *ctx, + bool ifCustome) { + return new fegen::FegenTypeDefination(std::move(dialectName), std::move(name), + std::move(parameters), ctx, ifCustome); +} + +std::string fegen::FegenTypeDefination::getDialectName() { + return this->dialectName; +} + +void fegen::FegenTypeDefination::setDialectName(std::string name) { + this->dialectName = name; +} + +std::string fegen::FegenTypeDefination::getName() { return this->name; } + +std::string fegen::FegenTypeDefination::getMnemonic() { + if (this->mnemonic.empty()) { + this->mnemonic = this->name; + std::transform(this->mnemonic.begin(), this->mnemonic.end(), + this->mnemonic.begin(), ::tolower); + } + return this->mnemonic; +} + +void fegen::FegenTypeDefination::setName(std::string name) { + this->name = name; +} + +const std::vector & +fegen::FegenTypeDefination::getParameters() { + return this->parameters; +} + +fegen::FegenParser::TypeDefinationDeclContext * +fegen::FegenTypeDefination::getCtx() { + return this->ctx; +} + +void fegen::FegenTypeDefination::setCtx( + FegenParser::TypeDefinationDeclContext *ctx) { + this->ctx = ctx; +} + +bool fegen::FegenTypeDefination::isCustome() { return this->ifCustome; } + +// class Expression + +fegen::FegenRightValue::Expression::Expression(bool ifTerminal, + LiteralKind kind, + FegenType &exprTy, + bool isConstexpr) + : ifTerminal(ifTerminal), kind(kind), exprType(exprTy), + ifConstexpr(isConstexpr) {} + +bool fegen::FegenRightValue::Expression::isTerminal() { + return this->ifTerminal; +} + +fegen::FegenRightValue::LiteralKind +fegen::FegenRightValue::Expression::getKind() { + return this->kind; +} + +std::any fegen::FegenRightValue::Expression::getContent() { + if (this->ifTerminal) { + auto tPtr = + dynamic_cast(this); + return tPtr->content; + } else { + return dynamic_cast(this); + ; + } +} + +bool fegen::FegenRightValue::Expression::isConstexpr() { + return this->ifConstexpr; +} + +// class ExpressionNode + +fegen::FegenRightValue::ExpressionNode::ExpressionNode( + std::vector params, + std::variant + op, + FegenType &exprTy, bool ifConstexpr) + : Expression(false, fegen::FegenRightValue::LiteralKind::EXPRESSION, exprTy, + ifConstexpr), + op(op), params(params) {} + +fegen::FegenRightValue::ExpressionNode::~ExpressionNode() { + for (auto p : this->params) { + delete p; + } +} + +std::string fegen::FegenRightValue::ExpressionNode::toString() { + // TODO: toString + return "todo: fegen::FegenRightValue::ExpressionNode::toString"; +} + +inline bool isBinaryOperator(fegen::FegenOperator &op) { + switch (op) { + case fegen::FegenOperator::NEG: + case fegen::FegenOperator::NOT: + return false; + default: + return true; + } +} + +inline std::string OperatorToString(fegen::FegenOperator &op) { + switch (op) { + case fegen::FegenOperator::ADD: + return "+"; + case fegen::FegenOperator::SUB: + return "-"; + case fegen::FegenOperator::MUL: + return "*"; + case fegen::FegenOperator::DIV: + return "/"; + default: + std::cerr << "unsupproted operator." << std::endl; + exit(0); + } +} + +std::string fegen::FegenRightValue::ExpressionNode::toStringForTypedef() { + assert(false); + std::cerr << "error type." << std::endl; + exit(0); +} + +std::string fegen::FegenRightValue::ExpressionNode::toStringForOpdef() { + assert(false); + std::cerr << "error type." << std::endl; + exit(0); +} + +std::any fegen::FegenRightValue::ExpressionNode::getContent() { return this; } + +fegen::FegenRightValue::ExpressionNode * +fegen::FegenRightValue::ExpressionNode::binaryOperation( + fegen::FegenRightValue::Expression *lhs, + fegen::FegenRightValue::Expression *rhs, FegenOperator op) { + // TODO: infer type kind: cpp, attribute, or operand + FegenType resTy = fegen::inferenceType({lhs, rhs}, op); + return new fegen::FegenRightValue::ExpressionNode( + {lhs, rhs}, op, resTy, (lhs->isConstexpr() && rhs->isConstexpr())); +} + +fegen::FegenRightValue::ExpressionNode * +fegen::FegenRightValue::ExpressionNode::unaryOperation( + fegen::FegenRightValue::Expression *v, FegenOperator op) { + // TODO: infer type kind: cpp, attribute, or operand + FegenType resTy = fegen::inferenceType({v}, op); + return new fegen::FegenRightValue::ExpressionNode({v}, op, resTy, + v->isConstexpr()); +} + +// class ExpressionTerminal +fegen::FegenRightValue::ExpressionTerminal::ExpressionTerminal( + primLiteralType c, fegen::FegenRightValue::LiteralKind kind, + FegenType exprTy, bool ifConstexpr) + : Expression(true, kind, exprTy, ifConstexpr), content(c) {} + +fegen::FegenRightValue::ExpressionTerminal::~ExpressionTerminal() { + if (this->kind == fegen::FegenRightValue::LiteralKind::VECTOR) { + auto &v = std::get>(this->content); + for (auto p : v) { + delete p; + } + } +} + +std::string fegen::FegenRightValue::ExpressionTerminal::toString() { + // TODO: toString + return "todo: fegen::FegenRightValue::ExpressionTerminal::toString"; +} + +std::string fegen::FegenRightValue::ExpressionTerminal::toStringForTypedef() { + assert(this->isConstexpr()); + switch (this->kind) { + case fegen::FegenRightValue::LiteralKind::TYPE: { + auto ty = std::get(this->content); + return ty.toStringForTypedef(); + } + case fegen::FegenRightValue::LiteralKind::VECTOR: { + std::string res; + res.append("["); + auto exprs = std::get>(this->content); + for (size_t i = 0; i <= exprs.size() - 1; i++) { + res.append(exprs[i]->toStringForTypedef()); + if (i != exprs.size() - 1) { + res.append(", "); + } + } + res.append("]"); + return res; + } + default: { + std::cerr << "unsupport expression" << std::endl; + exit(0); + } + } +} + +std::string fegen::FegenRightValue::ExpressionTerminal::toStringForOpdef() { + assert(this->isConstexpr()); + switch (this->kind) { + case fegen::FegenRightValue::LiteralKind::TYPE: { + auto ty = std::get(this->content); + return ty.toStringForOpdef(); + } + case fegen::FegenRightValue::LiteralKind::VECTOR: { + std::string res; + res.append("["); + auto exprs = std::get>(this->content); + for (size_t i = 0; i <= exprs.size() - 1; i++) { + res.append(exprs[i]->toStringForOpdef()); + if (i != exprs.size() - 1) { + res.append(", "); + } + } + res.append("]"); + return res; + } + default: { + assert(false); + std::cerr << "unsupport expression" << std::endl; + exit(0); + } + } +} + +std::any fegen::FegenRightValue::ExpressionTerminal::getContent() { + switch (this->kind) { + case fegen::FegenRightValue::LiteralKind::INT: + return std::get(this->content); + case fegen::FegenRightValue::LiteralKind::FLOAT: + return std::get(this->content); + case fegen::FegenRightValue::LiteralKind::STRING: + return std::get(this->content); + case fegen::FegenRightValue::LiteralKind::TYPE: + return std::get(this->content); + case fegen::FegenRightValue::LiteralKind::VECTOR: + return std::get>(this->content); + case fegen::FegenRightValue::LiteralKind::LEFT_VAR: + return std::get(this->content); + default: + return std::monostate(); + } +} + +fegen::FegenRightValue::ExpressionTerminal * +fegen::FegenRightValue::ExpressionTerminal::get(std::monostate content) { + return new fegen::FegenRightValue::ExpressionTerminal( + content, fegen::FegenRightValue::LiteralKind::MONOSTATE, + fegen::FegenType::getPlaceHolder(), true); +} + +fegen::FegenRightValue::ExpressionTerminal * +fegen::FegenRightValue::ExpressionTerminal::get(int content) { + return new fegen::FegenRightValue::ExpressionTerminal( + content, fegen::FegenRightValue::LiteralKind::INT, + fegen::FegenType::getInt32Type(), true); +} + +fegen::FegenRightValue::ExpressionTerminal * +fegen::FegenRightValue::ExpressionTerminal::get(float content) { + return new fegen::FegenRightValue::ExpressionTerminal( + content, fegen::FegenRightValue::LiteralKind::FLOAT, + fegen::FegenType::getFloatType(), true); +} + +fegen::FegenRightValue::ExpressionTerminal * +fegen::FegenRightValue::ExpressionTerminal::get(std::string content) { + return new fegen::FegenRightValue::ExpressionTerminal( + content, fegen::FegenRightValue::LiteralKind::STRING, + fegen::FegenType::getStringType(), true); +} + +fegen::FegenRightValue::ExpressionTerminal * +fegen::FegenRightValue::ExpressionTerminal::get(fegen::FegenType &content) { + bool ifConstexpr = true; + for (auto param : content.getParameters()) { + if (!param->getExpr()->isConstexpr()) { + ifConstexpr = false; + break; + } + } + if (content.getTypeLevel() == 2) { + return new fegen::FegenRightValue::ExpressionTerminal( + content, fegen::FegenRightValue::LiteralKind::TYPE, + fegen::FegenType::getMetaTemplateType(), ifConstexpr); + } else if (content.getTypeLevel() == 3) { + return new fegen::FegenRightValue::ExpressionTerminal( + content, fegen::FegenRightValue::LiteralKind::TYPE, + fegen::FegenType::getMetaType(), ifConstexpr); + } else { + return new fegen::FegenRightValue::ExpressionTerminal( + content, fegen::FegenRightValue::LiteralKind::TYPE, + fegen::FegenType::getPlaceHolder(), ifConstexpr); + } +} + +fegen::FegenRightValue::ExpressionTerminal * +fegen::FegenRightValue::ExpressionTerminal::get( + std::vector &content) { + bool ifConstexpr = true; + for (auto p : content) { + if (!p->isConstexpr()) { + ifConstexpr = false; + break; + } + } + return new fegen::FegenRightValue::ExpressionTerminal( + content, fegen::FegenRightValue::LiteralKind::VECTOR, + fegen::FegenType::getListType(content[0]->exprType), ifConstexpr); +} + +fegen::FegenRightValue::ExpressionTerminal * +fegen::FegenRightValue::ExpressionTerminal::get(fegen::FegenValue *content) { + return new fegen::FegenRightValue::ExpressionTerminal( + content, fegen::FegenRightValue::LiteralKind::LEFT_VAR, + content->getType(), content->getExpr()->isConstexpr()); +} + +// class FegenRightValue +fegen::FegenRightValue::FegenRightValue( + fegen::FegenRightValue::Expression *content) + : content(content) {} + +fegen::FegenRightValue::FegenRightValue(const fegen::FegenRightValue &rhs) { + if (rhs.content->isTerminal()) { + auto expr = + dynamic_cast(rhs.content); + this->content = new fegen::FegenRightValue::ExpressionTerminal(*expr); + } else { + auto expr = + dynamic_cast(rhs.content); + this->content = new fegen::FegenRightValue::ExpressionNode(*expr); + } +} + +fegen::FegenRightValue::FegenRightValue(fegen::FegenRightValue &&rhs) { + this->content = rhs.content; + rhs.content = nullptr; +} + +fegen::FegenRightValue::LiteralKind fegen::FegenRightValue::getKind() { + return this->content->getKind(); +} + +std::string fegen::FegenRightValue::toString() { + return this->content->toString(); +} + +std::string fegen::FegenRightValue::toStringForTypedef() { + return this->content->toStringForTypedef(); +} + +std::string fegen::FegenRightValue::toStringForOpdef() { + return this->content->toStringForOpdef(); +} + +std::any fegen::FegenRightValue::getContent() { + return this->content->getContent(); +} + +fegen::FegenRightValue::Expression *fegen::FegenRightValue::getExpr() { + return this->content; +} + +fegen::FegenRightValue fegen::FegenRightValue::get() { + return fegen::FegenRightValue( + fegen::FegenRightValue::ExpressionTerminal::get(std::monostate())); +} + +fegen::FegenRightValue fegen::FegenRightValue::get(int content) { + return fegen::FegenRightValue( + fegen::FegenRightValue::ExpressionTerminal::get(content)); +} +fegen::FegenRightValue fegen::FegenRightValue::get(float content) { + return fegen::FegenRightValue( + fegen::FegenRightValue::ExpressionTerminal::get(content)); +} +fegen::FegenRightValue fegen::FegenRightValue::get(std::string content) { + return fegen::FegenRightValue( + fegen::FegenRightValue::ExpressionTerminal::get(content)); +} +fegen::FegenRightValue fegen::FegenRightValue::get(fegen::FegenType &content) { + return fegen::FegenRightValue( + fegen::FegenRightValue::ExpressionTerminal::get(content)); +} + +fegen::FegenRightValue fegen::FegenRightValue::get( + std::vector &content) { + return fegen::FegenRightValue( + fegen::FegenRightValue::ExpressionTerminal::get(content)); +} +fegen::FegenRightValue fegen::FegenRightValue::get(fegen::FegenValue *content) { + return fegen::FegenRightValue( + fegen::FegenRightValue::ExpressionTerminal::get(content)); +} + +fegen::FegenRightValue +fegen::FegenRightValue::get(fegen::FegenRightValue::Expression *expr) { + assert(expr != nullptr); + return fegen::FegenRightValue(expr); +} + +fegen::FegenRightValue::~FegenRightValue() { delete this->content; } + +// class FegenValue +fegen::FegenValue::FegenValue(fegen::FegenType type, std::string name, + fegen::FegenRightValue content) + : type(std::move(type)), name(std::move(name)), + content(std::move(content)) {} + +fegen::FegenValue::FegenValue(const fegen::FegenValue &rhs) + : type(rhs.type), name(rhs.name), content(rhs.content) {} +fegen::FegenValue::FegenValue(fegen::FegenValue &&rhs) + : type(std::move(rhs.type)), name(std::move(rhs.name)), + content(std::move(rhs.content)) {} + +fegen::FegenValue *fegen::FegenValue::get(fegen::FegenType type, + std::string name, + FegenRightValue content) { + return new fegen::FegenValue(std::move(type), std::move(name), + std::move(content)); +} + +fegen::FegenType &fegen::FegenValue::getType() { return this->type; } + +std::string fegen::FegenValue::getName() { return this->name; } + +fegen::FegenRightValue::LiteralKind fegen::FegenValue::getContentKind() { + return this->content.getKind(); +} + +std::string fegen::FegenValue::getContentString() { + return this->content.toString(); +} + +std::string fegen::FegenValue::getContentStringForTypedef() { + return this->content.toStringForTypedef(); +} + +std::string fegen::FegenValue::getContentStringForOpdef() { + return this->content.toStringForOpdef(); +} + +fegen::FegenRightValue::Expression *fegen::FegenValue::getExpr() { + return this->content.getExpr(); +} + +fegen::FegenRule::FegenRule(std::string content, fegen::FegenNode *src, + antlr4::ParserRuleContext *ctx) + : content(content), src(src), ctx(ctx) {} + +fegen::FegenRule *fegen::FegenRule::get(std::string content, + fegen::FegenNode *src, + antlr4::ParserRuleContext *ctx) { + return new fegen::FegenRule(content, src, ctx); +} + +llvm::StringRef fegen::FegenRule::getContent() { return this->content; } + +bool fegen::FegenRule::addInput(fegen::FegenValue input) { + auto name = input.getName(); + if (this->inputs.count(name) == 0) { + return false; + } + this->inputs.insert({name, new fegen::FegenValue(input)}); + return true; +} + +bool fegen::FegenRule::addReturn(fegen::FegenValue output) { + auto name = output.getName(); + if (this->returns.count(name) == 0) { + return false; + } + this->returns.insert({name, new fegen::FegenValue(output)}); + return true; +} + +void fegen::FegenRule::setSrc(FegenNode *src) { this->src = src; } + +fegen::FegenNode::FegenNode(std::vector &&rules, + antlr4::ParserRuleContext *ctx, + fegen::FegenNode::NodeType ntype) + : rules(rules), ctx(ctx), ntype(ntype) {} + +fegen::FegenNode *fegen::FegenNode::get(std::vector rules, + antlr4::ParserRuleContext *ctx, + fegen::FegenNode::NodeType ntype) { + return new fegen::FegenNode(std::move(rules), ctx, ntype); +} +fegen::FegenNode *fegen::FegenNode::get(antlr4::ParserRuleContext *ctx, + fegen::FegenNode::NodeType ntype) { + std::vector rules; + return new fegen::FegenNode(std::move(rules), ctx, ntype); +} + +void fegen::FegenNode::addFegenRule(fegen::FegenRule *rule) { + this->rules.push_back(rule); +} + +fegen::FegenNode::~FegenNode() { + for (auto rule : this->rules) { + delete rule; + } +} + +void fegen::FegenManager::setModuleName(std::string name) { + this->moduleName = name; +} + +std::string getChildrenText(antlr4::tree::ParseTree *ctx) { + std::string ruleText; + for (auto child : ctx->children) { + if (antlr4::tree::TerminalNode::is(child)) { + ruleText.append(child->getText()).append(" "); + } else { + ruleText.append(getChildrenText(child)).append(" "); + } + } + return ruleText; +} + +fegen::FegenManager::FegenManager() {} + +class Emitter { +private: + std::ostream &stream; + int tabCount; + bool isNewLine; + +public: + Emitter() = delete; + Emitter(Emitter &) = delete; + Emitter(Emitter &&) = delete; + Emitter(std::ostream &stream) + : stream(stream), tabCount(0), isNewLine(true) {} + void tab() { tabCount++; } + + void shiftTab() { + tabCount--; + if (tabCount < 0) { + tabCount = 0; + } + } + + void newLine() { + this->stream << std::endl; + isNewLine = true; + } + + std::ostream &operator<<(std::string s) { + if (this->isNewLine) { + for (int i = 0; i <= (this->tabCount - 1); i++) { + this->stream << '\t'; + } + this->isNewLine = false; + } + this->stream << s; + return this->stream; + } +}; + +void fegen::FegenManager::emitG4() { + std::ofstream fileStream; + fileStream.open(this->moduleName + ".g4"); + Emitter emitter(fileStream); + emitter << "grammar " << this->moduleName << ";"; + emitter.newLine(); + for (auto node_pair : this->nodeMap) { + auto nodeName = node_pair.first; + auto node = node_pair.second; + emitter << nodeName; + emitter.newLine(); + emitter.tab(); + auto ruleCount = node->rules.size(); + if (ruleCount > 0) { + emitter << ": " << getChildrenText(node->rules[0]->ctx); + emitter.newLine(); + for (size_t i = 1; i <= ruleCount - 1; i++) { + emitter << "| " << getChildrenText(node->rules[i]->ctx); + emitter.newLine(); + } + emitter << ";" << std::endl; + } + emitter.shiftTab(); + emitter.newLine(); + } +} + +// TODO: emit to file +void fegen::FegenManager::emitTypeDefination() { + std::ofstream fileStream; + fileStream.open(this->moduleName + "Types.td"); + Emitter emitter(fileStream); + // file head + std::string mn(this->moduleName); + std::transform(mn.begin(), mn.end(), mn.begin(), ::toupper); + emitter << "#ifndef " << mn << "_TYPE_TD"; + emitter.newLine(); + emitter << "#define " << mn << "_TYPE_TD"; + emitter << "\n"; + emitter.newLine(); + + // include files + emitter << "include \"mlir/IR/AttrTypeBase.td\""; + emitter.newLine(); + emitter << "include \"" << this->moduleName << "Dialect.td\""; + emitter << "\n"; + emitter.newLine(); + // Type class defination + std::string typeClassName = this->moduleName + "Type"; + emitter << "class " << typeClassName + << " traits = []>"; + emitter.tab(); + emitter << ": TypeDef {"; + emitter.newLine(); + emitter << "let mnemonic = typeMnemonic;"; + emitter.newLine(); + emitter.shiftTab(); + emitter << "}" << std::endl; + emitter.newLine(); + + for (auto pair : this->typeDefMap) { + auto tyDef = pair.second; + if (!tyDef->isCustome()) { + continue; + } + auto typeName = pair.first; + // head of typedef + emitter << "def " << typeName << " : " << typeClassName << "<\"" << typeName + << "\", \"" << tyDef->getMnemonic() << "\"> {"; + emitter.newLine(); + emitter.tab(); + // summary + emitter << "let summary = \"This is generated by buddy fegen.\";"; + emitter.newLine(); + // description + emitter << "let description = [{ This is generated by buddy fegen. }];"; + emitter.newLine(); + // parameters + emitter << "let parameters = ( ins"; + emitter.newLine(); + emitter.tab(); + for (size_t i = 0; i <= tyDef->getParameters().size() - 1; i++) { + auto param = tyDef->getParameters()[i]; + auto ¶mTy = param->getType(); + auto paramName = param->getName(); + auto paramTyStr = paramTy.toStringForTypedef(); + emitter << paramTyStr << ":" << "$" << paramName; + if (i != tyDef->getParameters().size() - 1) { + emitter << ", "; + } + emitter.newLine(); + } + emitter.shiftTab(); + emitter << ");"; + emitter.newLine(); + // assemblyFormat + // TODO: handle list, Type ... + emitter << "let assemblyFormat = [{"; + emitter.newLine(); + emitter.tab(); + emitter << "`<` "; + for (size_t i = 0; i <= tyDef->getParameters().size() - 1; i++) { + auto param = tyDef->getParameters()[i]; + auto paramName = param->getName(); + emitter << "$" << paramName << " "; + if (i != tyDef->getParameters().size() - 1) { + emitter << "`x` "; + } + } + emitter << "`>`"; + emitter.shiftTab(); + emitter.newLine(); + emitter << "}];"; + emitter.newLine(); + emitter.shiftTab(); + emitter << "}"; + emitter.newLine(); + } + emitter.shiftTab(); + emitter << "\n"; + emitter << "#endif // " << mn << "_TYPE_TD"; + fileStream.close(); +} + +void fegen::FegenManager::emitOpDefination() { + std::ofstream fileStream; + fileStream.open(this->moduleName + "Ops.td"); + Emitter emitter(fileStream); + + // file head + std::string mn(this->moduleName); + std::transform(mn.begin(), mn.end(), mn.begin(), ::toupper); + emitter << "#ifndef " << mn << "_OPS_TD"; + emitter.newLine(); + emitter << "#define " << mn << "_OPS_TD"; + emitter << "\n"; + emitter.newLine(); + + // TODO: custome include files + // include + emitter << "include \"mlir/IR/BuiltinAttributes.td\""; + emitter.newLine(); + emitter << "include \"mlir/IR/BuiltinTypes.td\""; + emitter.newLine(); + emitter << "include \"mlir/IR/CommonAttrConstraints.td\""; + emitter.newLine(); + emitter << "include \"" << this->moduleName << "Dialect.td\""; + emitter.newLine(); + emitter << "include \"" << this->moduleName << "Types.td\""; + emitter.newLine(); + emitter << "\n"; + + // op class defination + std::string classname = this->moduleName + "Op"; + emitter << "class " << classname + << " traits = []>:"; + emitter.newLine(); + emitter.tab(); + emitter << "Op;"; + emitter << "\n"; + emitter.shiftTab(); + emitter.newLine(); + + // op definations + for (auto pair : this->operationMap) { + auto opName = pair.first; + auto opDef = pair.second; + // head of def + emitter << "def " << opName << " : " << classname << "<\"" << opName + << "\", [Pure]> {"; + emitter.newLine(); + emitter.tab(); + // summary and description + emitter << "let summary = \"This is generated by buddy fegen.\";"; + emitter.newLine(); + emitter << "let description = [{This is generated by buddy fegen.}];"; + emitter.newLine(); + // arguments + emitter << "let arguments = ( ins "; + emitter.newLine(); + emitter.tab(); + for (auto param : opDef->getArguments()) { + auto paramTyStr = param->getType().toStringForOpdef(); + auto paramName = param->getName(); + emitter << paramTyStr << " : $" << paramName; + emitter.newLine(); + } + emitter.shiftTab(); + emitter << ");"; + emitter.newLine(); + // results + emitter << "let results = (outs "; + emitter.newLine(); + emitter.tab(); + for (auto param : opDef->getArguments()) { + auto paramTyStr = param->getType().toStringForOpdef(); + auto paramName = param->getName(); + emitter << paramTyStr << " : $" << paramName; + emitter.newLine(); + } + emitter.shiftTab(); + emitter << ");"; + emitter.newLine(); + // end of def + emitter.shiftTab(); + emitter << "}"; + emitter.newLine(); + } + + // end of file + emitter << "\n"; + emitter << "#endif // " << mn << "_DIALECT_TD"; + fileStream.close(); +} + +void fegen::FegenManager::emitDialectDefination() { + std::ofstream fileStream; + fileStream.open(this->moduleName + "Dialect.td"); + Emitter emitter(fileStream); + + // file head + std::string mn(this->moduleName); + std::transform(mn.begin(), mn.end(), mn.begin(), ::toupper); + emitter << "#ifndef " << mn << "_DIALECT_TD"; + emitter.newLine(); + emitter << "#define " << mn << "_DIALECT_TD"; + emitter << "\n"; + emitter.newLine(); + + // include + emitter << "include \"mlir/IR/OpBase.td\""; + emitter << "\n"; + emitter.newLine(); + + // dialect defination + emitter << "def " << this->moduleName << "_Dialect : Dialect {"; + emitter.newLine(); + emitter.tab(); + emitter << "let name = \"" << this->moduleName << "\";"; + emitter.newLine(); + emitter << "let summary = \"This is generated by buddy fegen.\";"; + emitter.newLine(); + emitter << "let description = [{This is generated by buddy fegen.}];"; + emitter.newLine(); + emitter << "let cppNamespace = \"::mlir::" << this->moduleName << "\";"; + emitter.newLine(); + emitter << "let extraClassDeclaration = [{"; + emitter.newLine(); + emitter.tab(); + emitter << "/// Register all types."; + emitter.newLine(); + emitter << "void registerTypes();"; + emitter.newLine(); + emitter.shiftTab(); + emitter << "}];"; + emitter.newLine(); + emitter.shiftTab(); + emitter << "}"; + emitter.newLine(); + + // end of file + emitter << "#endif // " << mn << "_DIALECT_TD"; + fileStream.close(); +} + +void fegen::FegenManager::emitTdFiles() { + this->emitDialectDefination(); + this->emitTypeDefination(); + this->emitOpDefination(); +} + +void fegen::FegenManager::initbuiltinTypes() { + // placeholder type + auto placeholderTypeDefination = fegen::FegenTypeDefination::get( + "fegen_builtin", FEGEN_PLACEHOLDER, {}, nullptr, false); + this->typeDefMap.insert({FEGEN_PLACEHOLDER, placeholderTypeDefination}); + + // Type + this->typeDefMap.insert( + {FEGEN_TYPE, fegen::FegenTypeDefination::get("fegen_builtin", FEGEN_TYPE, + {}, nullptr, false)}); + + // TypeTemplate + this->typeDefMap.insert( + {FEGEN_TYPETEMPLATE, + fegen::FegenTypeDefination::get("fegen_builtin", FEGEN_TYPETEMPLATE, {}, + nullptr, false)}); + + // recursive define Integer Type + // Integer>> + auto intTypeDefination = fegen::FegenTypeDefination::get( + "fegen_builtin", FEGEN_INTEGER, {}, nullptr, false); + auto intType = fegen::FegenType( + fegen::FegenType::TypeKind::CPP, + {fegen::FegenValue::get(fegen::FegenType::getPlaceHolder(), "size", + fegen::FegenRightValue::get())}, + intTypeDefination, false); + // parameters of Integer is int32(Integer<32>) + intTypeDefination->parameters.push_back( + fegen::FegenValue::get(intType, "size", fegen::FegenRightValue::get())); + this->typeDefMap.insert({FEGEN_INTEGER, intTypeDefination}); + + // FloatPoint + this->typeDefMap.insert( + {FEGEN_FLOATPOINT, + fegen::FegenTypeDefination::get( + "fegen_builtin", FEGEN_FLOATPOINT, + {fegen::FegenValue::get(fegen::FegenType::getInt32Type(), "size", + fegen::FegenRightValue::get())}, + nullptr, false)}); + + // Char + this->typeDefMap.insert( + {FEGEN_CHAR, fegen::FegenTypeDefination::get("fegen_builtin", FEGEN_CHAR, + {}, nullptr, false)}); + + // String + this->typeDefMap.insert( + {FEGEN_STRING, fegen::FegenTypeDefination::get( + "fegen_builtin", FEGEN_STRING, {}, nullptr, false)}); + + // Vector + this->typeDefMap.insert( + {FEGEN_VECTOR, + fegen::FegenTypeDefination::get( + "fegen_builtin", FEGEN_VECTOR, + {fegen::FegenValue::get(fegen::FegenType::getInt32Type(), "size", + fegen::FegenRightValue::get()), + fegen::FegenValue::get(fegen::FegenType::getMetaType(), + "elementType", + fegen::FegenRightValue::get())}, + nullptr, false)}); + + // List (this should be ahead of Tensor and Any Type defination) + this->typeDefMap.insert( + {FEGEN_LIST, fegen::FegenTypeDefination::get( + "fegen_builtin", FEGEN_LIST, + {fegen::FegenValue::get(fegen::FegenType::getMetaType(), + "elementType", + fegen::FegenRightValue::get())}, + nullptr, false)}); + + // Tensor + this->typeDefMap.insert( + {FEGEN_TENSOR, + fegen::FegenTypeDefination::get( + "fegen_builtin", FEGEN_TENSOR, + {fegen::FegenValue::get( + fegen::FegenType::getListType(fegen::FegenType::getInt32Type()), + "shape", fegen::FegenRightValue::get()), + fegen::FegenValue::get(fegen::FegenType::getMetaType(), + "elementType", + fegen::FegenRightValue::get())}, + nullptr, false)}); + + // Optional + this->typeDefMap.insert( + {FEGEN_OPTINAL, fegen::FegenTypeDefination::get( + "fegen_builtin", FEGEN_OPTINAL, + {fegen::FegenValue::get( + fegen::FegenType::getMetaType(), "elementType", + fegen::FegenRightValue::get())}, + nullptr, false)}); + + // Any + this->typeDefMap.insert( + {FEGEN_ANY, + fegen::FegenTypeDefination::get( + "fegen_builtin", FEGEN_ANY, + {fegen::FegenValue::get( + fegen::FegenType::getListType(fegen::FegenType::getMetaType()), + "elementType", fegen::FegenRightValue::get())}, + nullptr, false)}); +} + +fegen::FegenTypeDefination * +fegen::FegenManager::getTypeDefination(std::string name) { + return this->typeDefMap[name]; +} + +bool fegen::FegenManager::addTypeDefination(fegen::FegenTypeDefination *tyDef) { + if (this->typeDefMap.count(tyDef->name) != 0) { + return false; + } + this->typeDefMap[tyDef->name] = tyDef; + return true; +} + +fegen::FegenOperation * +fegen::FegenManager::getOperationDefination(std::string name) { + return this->operationMap[name]; +} + +bool fegen::FegenManager::addOperationDefination(fegen::FegenOperation *opDef) { + if (this->operationMap.count(opDef->getOpName()) != 0) { + return false; + } + this->operationMap[opDef->getOpName()] = opDef; + return true; +} + +void fegen::FegenManager::addStmtContent(antlr4::ParserRuleContext *ctx, + std::any content) { + this->stmtContentMap.insert({ctx, content}); +} + +fegen::FegenManager &fegen::FegenManager::getManager() { + static fegen::FegenManager fmg; + return fmg; +} + +fegen::FegenManager::~FegenManager() { + // release nodes + for (auto node_pair : this->nodeMap) { + delete node_pair.second; + } +} + +fegen::FegenType +fegen::inferenceType(std::vector operands, + fegen::FegenOperator op) { + // TODO: infer type + return fegen::FegenType::getInt32Type(); +} +namespace fegen{ + +// class StmtVisitor : public FegenParserBaseVisitor{ +// public: +// }; + +} +void fegen::FegenManager::emitBuiltinFunction() { + Emitter emitter(std::cout); + for (auto function_pair : this->functionMap) { + auto functionName = function_pair.first; + auto function = function_pair.second; + auto paraList = function->getInputTypeList(); + emitter << function->getReturnType()->toStringForTypedef() << " " + << functionName << "("; + for (auto para : paraList) { + emitter << para->getContentStringForTypedef() << " " << para->getName(); + if (para != paraList.back()) + emitter << ", "; + } + emitter << "){"; + emitter.newLine(); + emitter.tab(); + // TODO::function body + + emitter.shiftTab(); + emitter.newLine(); + emitter << "}"; + } +} \ No newline at end of file diff --git a/frontend/FrontendGen/lib/FegenParser.g4 b/frontend/FrontendGen/lib/FegenParser.g4 new file mode 100644 index 0000000000..23feea1b2a --- /dev/null +++ b/frontend/FrontendGen/lib/FegenParser.g4 @@ -0,0 +1,443 @@ +parser grammar FegenParser; + +options { + tokenVocab = FegenLexer; +} + +fegenSpec + : fegenDecl (prequelConstruct | functionDecl | typeDefinationDecl | statement | opDecl | rules)* EOF + ; + +fegenDecl + : FEGEN identifier + ; + +// preprocess declare +prequelConstruct + : BeginInclude INCLUDE_CONTENT* EndInclude + ; + +// function declare +functionDecl + : typeSpec funcName LeftParen funcParams? RightParen statementBlock + ; + +funcName + : identifier + ; + +funcParams + : typeSpec identifier (Comma typeSpec identifier)* + ; + +// typedef declare +typeDefinationDecl + : TYPEDEF typeDefinationName typeDefinationBlock + ; + +typeDefinationName + : identifier + ; + +typeDefinationBlock + : LeftBrace parametersSpec assemblyFormatSpec? RightBrace + ; + +parametersSpec + : PARAMETERS varDecls + ; + +assemblyFormatSpec + : ASSEMBLY_FORMAT LeftBracket StringLiteral RightBracket + ; + +// opdef declare +opDecl + : OPDEF opName opBlock + ; + +opName + : identifier + ; + +opBlock + : LeftBrace argumentSpec? resultSpec? bodySpec? RightBrace + ; + +argumentSpec + : ARGUMENTS varDecls + ; + +resultSpec + : RESULTS varDecls + ; + +bodySpec + : BODY statementBlock + ; + +// rule definations +rules + : ruleSpec+ + ; + +ruleSpec + : parserRuleSpec + | lexerRuleSpec + ; + +parserRuleSpec + : ParserRuleName Colon ruleBlock Semi + ; + +ruleBlock + : ruleAltList + ; + +ruleAltList + : actionAlt (OR actionAlt)* + ; + +actionAlt + : alternative actionBlock? + ; + +alternative + : element* + ; + +element + : atom (ebnfSuffix |) + | ebnf + ; + +atom + : terminalDef + | ruleref + | notSet + ; + +// terminal rule reference +terminalDef + : LexerRuleName + | StringLiteral + ; + +// parser rule reference +ruleref + : ParserRuleName + ; + +notSet + : Tilde setElement + | Tilde blockSet + ; + +setElement + : LexerRuleName + | StringLiteral + | characterRange + ; + +characterRange + : StringLiteral Range StringLiteral + ; + +blockSet + : LeftParen setElement (OR setElement)* RightParen + ; + +ebnfSuffix + : QuestionMark QuestionMark? + | Star QuestionMark? + | Plus QuestionMark? + ; + +ebnf + : block blockSuffix? + ; + +block + : LeftParen altList RightParen + ; + +blockSuffix + : ebnfSuffix + ; + +altList + : alternative (OR alternative)* + ; + +// lexer rule +lexerRuleSpec + : LexerRuleName Colon lexerRuleBlock Semi + ; + +lexerRuleBlock + : lexerAltList + ; + +lexerAltList + : lexerAlt (OR lexerAlt)* + ; + +lexerAlt + : lexerElements lexerCommands? + | + ; + +// E.g., channel(HIDDEN), skip, more, mode(INSIDE), push(INSIDE), pop +lexerCommands + : Arror lexerCommand (Comma lexerCommand)* + ; + +lexerCommand + : lexerCommandName + ; + +lexerCommandName + : identifier + ; + +lexerElements + : lexerElement+ + | + ; + +lexerElement + : lexerAtom ebnfSuffix? + | lexerBlock ebnfSuffix? + ; + +lexerAtom + : characterRange + | terminalDef + | notSet + | Dot + ; + +lexerBlock + : LeftParen lexerAltList RightParen + ; + +// action block declare +actionBlock + : LeftBrace inputsSpec? returnsSpec? actionSpec? RightBrace + ; + +inputsSpec + : INPUTS varDecls + ; + +varDecls + : LeftBracket typeSpec identifier (Comma typeSpec identifier)* RightBracket + ; + +prefixedName + : identifier (Dot identifier)? + ; + +identifier + : LexerRuleName + | ParserRuleName + ; + +returnsSpec + : RETURNS varDecls + ; + +actionSpec + : ACTIONS statementBlock + ; + +statementBlock + : LeftBrace statement* RightBrace + ; + +statement + : varDeclStmt Semi + | assignStmt Semi + | functionCall Semi + | opInvokeStmt Semi + | ifStmt + | forStmt + ; + +varDeclStmt + : typeSpec identifier (Assign expression)? + ; + +assignStmt + : identifier Assign expression + ; + +functionCall + : funcName LeftParen (expression (Comma expression)*)? RightParen + ; + +opInvokeStmt + : opName LeftParen opParams? (Comma opResTypeParams)? RightParen+ + ; + +opParams + : identifier (Comma identifier)* + ; + +opResTypeParams + : typeInstance (Comma typeInstance)* + ; + +ifStmt + : IF LeftParen expression RightParen statementBlock (ELSE IF LeftParen expression RightParen statementBlock)* (ELSE statementBlock)? + ; + +forStmt + : FOR LeftParen assignStmt Semi expression Semi assignStmt RightParen statementBlock + ; + +// expression +expression + : andExpr (Logic_OR andExpr)* + ; + +andExpr + : equExpr (AND equExpr )* + ; + +equExpr + : compareExpr ((EQUAL | NOT_EQUAL) compareExpr)* + ; + +compareExpr + : addExpr ((Less | LessEqual | Greater | GreaterEqual) addExpr)* + ; + +addExpr + : term ((Plus | Minus) term)* + ; + +term + : powerExpr ((Star | Div | MOD) powerExpr)* + ; + +powerExpr + : unaryExpr (StarStar unaryExpr)* + ; + +unaryExpr + : (Minus | Plus | Exclamation)? primaryExpr + ; + +parenSurroundedExpr + : LeftParen expression RightParen + ; + +primaryExpr + : constant + | identifier + | functionCall + | parenSurroundedExpr + | contextMethodInvoke + | typeSpec + | variableAccess + ; + +constant + : numericLiteral + | charLiteral + | boolLiteral + | listLiteral + ; + +// ex: $ctx(0).getText() +contextMethodInvoke + : Dollar identifier LeftParen intLiteral? RightParen Dot functionCall + ; + +variableAccess + : identifier LeftBracket expression RightBracket + ; + +numericLiteral + : intLiteral + | realLiteral + ; + +intLiteral + : UnsignedInt + | (Plus | Minus) UnsignedInt + ; + +realLiteral + : ScienceReal + ; + +charLiteral + : StringLiteral + ; + +boolLiteral + : ConstBoolean + ; + +listLiteral + : LeftBracket (expression (Comma expression)*)? RightBracket + ; + +// type system +typeSpec + : valueKind? typeInstance # typeInstanceSpec + | typeTemplate # typeTemplateSpce + | valueKind? collectType # collectTypeSpec + ; + +valueKind + : CPP + | OPERAND + | ATTRIBUTE + ; + +// 这里的identifier是不是没用? +typeInstance + : typeTemplate (Less typeTemplateParam (Comma typeTemplateParam)* Greater)? + | builtinTypeInstances + | identifier + ; + +typeTemplate + : prefixedName + | builtinTypeTemplate + | TYPE + ; + +typeTemplateParam + : expression + | builtinTypeInstances + ; + +builtinTypeInstances + : BOOL + | INT + | FLOAT + | DOUBLE + | CHAR + | STRING + ; + +builtinTypeTemplate + : INTEGER + | FLOATPOINT + | TENSOR + | VECTOR + ; + +collectType + : collectProtoType Less expression Greater + ; + +collectProtoType + : ANY + | LIST + | OPTIONAL + ; \ No newline at end of file diff --git a/frontend/FrontendGen/lib/FegenVisitor.cpp b/frontend/FrontendGen/lib/FegenVisitor.cpp new file mode 100644 index 0000000000..4cb330785a --- /dev/null +++ b/frontend/FrontendGen/lib/FegenVisitor.cpp @@ -0,0 +1,11 @@ +#include "FegenVisitor.h" + +bool fegen::checkParams(std::vector &expected, + std::vector &actual) { + return true; +} + +bool fegen::checkListLiteral( + std::vector listLiteral) { + return true; +} \ No newline at end of file diff --git a/frontend/FrontendGen/lib/Lexer.cpp b/frontend/FrontendGen/lib/Lexer.cpp deleted file mode 100644 index 6cce20df8a..0000000000 --- a/frontend/FrontendGen/lib/Lexer.cpp +++ /dev/null @@ -1,198 +0,0 @@ -//====- Lexer.cpp --------------------------------------------------------===// -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#include "Lexer.h" -#include "llvm/Support/raw_ostream.h" -using namespace frontendgen; -/// some function about handing characters. -namespace charinfo { -inline bool isASCLL(char ch) { return static_cast(ch) <= 127; } - -inline bool isWhitespace(char ch) { - return isASCLL(ch) && (ch == ' ' || ch == '\t' || ch == '\f' || ch == '\v' || - ch == '\r' || ch == '\n'); -} - -inline bool isIdentifierHead(char ch) { - return isASCLL(ch) && - (ch == '_' || (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z')); -} - -inline bool isDigit(char ch) { return isASCLL(ch) && (ch >= '0' && ch <= '9'); } - -inline bool isIdentifierBody(char ch) { - return isIdentifierHead(ch) || isDigit(ch); -} -} // namespace charinfo - -/// Add keyword to keywordmap. -void KeyWordManager::addKeyWord(llvm::StringRef name, tokenKinds kind) { - keywordMap.insert(std::make_pair(name, kind)); -} - -/// A function add all keywords. -void KeyWordManager::addKeyWords() { -#define KEYWORD(NAME, FLAG) addKeyWord(#NAME, tokenKinds::kw_##NAME); -#include "Token.def" -} - -/// Determine if a string is a keyword. -tokenKinds KeyWordManager::getKeyWord(llvm::StringRef name, tokenKinds kind) { - auto result = keywordMap.find(name); - if (result != keywordMap.end()) - return result->second; - return kind; -} - -bool Token::is(tokenKinds kind) { return kind == tokenKind; } - -llvm::SMLoc Token::getLocation() { return llvm::SMLoc::getFromPointer(start); } -//// Get next token. -void Lexer::next(Token &token) { - // Skip whitespace. - while (*curPtr && charinfo::isWhitespace(*curPtr)) - curPtr++; - if (!*curPtr) { - token.setTokenKind(tokenKinds::eof); - return; - } - // Get identifier. - if (charinfo::isIdentifierHead(*curPtr)) { - identifier(token); - return; - } else if (charinfo::isDigit(*curPtr)) { - number(token); - return; - } else if (*curPtr == ';') { - formToken(token, curPtr + 1, tokenKinds::semi); - return; - } else if (*curPtr == ':') { - formToken(token, curPtr + 1, tokenKinds::colon); - return; - } else if (*curPtr == '\'') { - formToken(token, curPtr + 1, tokenKinds::apostrophe); - return; - } else if (*curPtr == '(') { - formToken(token, curPtr + 1, tokenKinds::parentheseOpen); - return; - } else if (*curPtr == ')') { - formToken(token, curPtr + 1, tokenKinds::parentheseClose); - return; - } else if (*curPtr == '*') { - formToken(token, curPtr + 1, tokenKinds::asterisk); - return; - } else if (*curPtr == '?') { - formToken(token, curPtr + 1, tokenKinds::questionMark); - return; - } else if (*curPtr == '+') { - formToken(token, curPtr + 1, tokenKinds::plus); - return; - } else if (*curPtr == '=') { - formToken(token, curPtr + 1, tokenKinds::equal); - return; - } else if (*curPtr == '{') { - formToken(token, curPtr + 1, tokenKinds::curlyBlacketOpen); - return; - } else if (*curPtr == '}') { - formToken(token, curPtr + 1, tokenKinds::curlyBlacketClose); - return; - } else if (*curPtr == '$') { - formToken(token, curPtr + 1, tokenKinds::dollar); - return; - } else if (*curPtr == ',') { - formToken(token, curPtr + 1, tokenKinds::comma); - return; - } else if (*curPtr == '<') { - formToken(token, curPtr + 1, tokenKinds::angleBracketOpen); - return; - } else if (*curPtr == '>') { - formToken(token, curPtr + 1, tokenKinds::angleBracketClose); - return; - } else if (*curPtr == '[') { - formToken(token, curPtr + 1, tokenKinds::squareBracketOpen); - return; - } else if (*curPtr == ']') { - formToken(token, curPtr + 1, tokenKinds::squareBracketClose); - return; - } else if (*curPtr == '"') { - formToken(token, curPtr + 1, tokenKinds::doubleQuotationMark); - return; - } - token.tokenKind = tokenKinds::unknown; -} - -void Lexer::identifier(Token &token) { - const char *start = curPtr; - const char *end = curPtr + 1; - while (charinfo::isIdentifierBody(*end)) - ++end; - llvm::StringRef name(start, end - start); - tokenKinds kind = keywordManager.getKeyWord(name, tokenKinds::identifier); - formToken(token, end, kind); -} - -void Lexer::formToken(Token &token, const char *tokenEnd, tokenKinds kind) { - int length = tokenEnd - curPtr; - token.start = curPtr; - token.length = length; - token.tokenKind = kind; - curPtr = tokenEnd; -} - -void Lexer::number(Token &token) { - const char *end = curPtr; - end++; - while (charinfo::isDigit(*end)) - end++; - formToken(token, end, tokenKinds::number); -} -/// Get the corresponding content according to start and end. -llvm::StringRef Lexer::getMarkContent(std::string start, std::string end) { - while (*curPtr && charinfo::isWhitespace(*curPtr)) - curPtr++; - int index = start.find(*curPtr); - if (index == -1) - return llvm::StringRef(); - char s = start[index]; - char e = end[index]; - const char *endPtr = curPtr + 1; - int number = 1; - if (s == e) - while (*endPtr != e) - endPtr++; - else - while (number) { - if (*endPtr == s) - number++; - if (*endPtr == e) - number--; - if (number) - endPtr++; - } - endPtr++; - llvm::StringRef content(curPtr, endPtr - curPtr); - curPtr = endPtr; - return content; -} -/// Get the corresponding content according to statr and ch. -llvm::StringRef Lexer::getEndChContent(const char *start, char ch) { - const char *endPtr = curPtr; - while (*endPtr != ch) - endPtr++; - endPtr++; - curPtr = endPtr; - return llvm::StringRef(start, endPtr - start); -} diff --git a/frontend/FrontendGen/lib/Parser.cpp b/frontend/FrontendGen/lib/Parser.cpp deleted file mode 100644 index 152462fce1..0000000000 --- a/frontend/FrontendGen/lib/Parser.cpp +++ /dev/null @@ -1,403 +0,0 @@ -//====- Parser.cpp -------------------------------------------------------===// -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#include "Parser.h" -#include "AST.h" -#include "Lexer.h" -#include "Sema.h" -#include "unistd.h" -#include "llvm/Support/raw_ostream.h" -using namespace frontendgen; - -void Parser::advance() { lexer.next(token); } - -void Parser::lookToken() { - while (token.getKind() != tokenKinds::eof) { - llvm::outs() << token.getContent() << '\n'; - llvm::outs() << "token type:" << token.getTokenName() << '\n'; - advance(); - } -} - -/// If current token's kind is expected kind, get next token. -/// If not, an error is reported. -bool Parser::consume(tokenKinds expectTok) { - if (token.is(expectTok)) { - advance(); - return true; - } - lexer.getDiagnostic().report(token.getLocation(), - DiagnosticEngine::err_expected, - tokenNameMap[expectTok], token.getTokenName()); - return false; -} - -/// If current token's kind is expected kind, get next token. -/// If not, do nothing. -bool Parser::consumeNoAdvance(tokenKinds expectTok) { - if (token.is(expectTok)) - return true; - lexer.getDiagnostic().report(token.getLocation(), - DiagnosticEngine::err_expected, - tokenNameMap[expectTok], token.getTokenName()); - return false; -} - -/// Parser the file, and return a Module, it store all information -/// to generate code. -Module *Parser::parser() { - Module *module = new Module(); - compilEngine(module); - return module; -} - -/// Parse keyword op, dialect and rule. -void Parser::compilEngine(Module *module) { - // rules store all rule ast. - std::vector rules; - // A file can only store one dialect. - Dialect *dialect = nullptr; - // ops store all op. - std::vector ops; - while (token.getKind() != tokenKinds::eof) { - if (token.is(tokenKinds::kw_rule)) { - advance(); - if (!consumeNoAdvance(tokenKinds::identifier)) - return; - Rule *rule = - new Rule(token.getContent(), token.getLocation(), AntlrBase::rule); - advance(); - parserRules(rule); - rules.push_back(rule); - consume(tokenKinds::semi); - } else if (token.is(tokenKinds::kw_dialect)) { - advance(); - if (!consumeNoAdvance(tokenKinds::identifier)) - return; - llvm::StringRef defName = token.getContent(); - advance(); - parserDialect(dialect, defName); - } else if (token.is(tokenKinds::kw_op)) { - advance(); - if (!parserOp(ops, token.getContent())) { - action.actOnModule(module, rules, dialect, ops); - return; - } - } else { - lexer.getDiagnostic().report( - token.getLocation(), DiagnosticEngine::err_expected, - "keyword rule, dialect or op", token.getTokenName()); - action.actOnModule(module, rules, dialect, ops); - return; - } - } - action.actOnModule(module, rules, dialect, ops); -} - -/// Parser the rule and fill nodes of rule ast. -void Parser::parserRules(Rule *rule) { - if (!consumeNoAdvance(tokenKinds::colon)) - return; - // A rule contains many generative. - std::vector generators; - while (token.getKind() != tokenKinds::semi && - token.getKind() == tokenKinds::colon) { - advance(); - GeneratorAndOthers *generatorAndOthers = new GeneratorAndOthers(); - parserGenerator(generatorAndOthers); - generators.push_back(generatorAndOthers); - if (!token.is(tokenKinds::colon) && !token.is(tokenKinds::semi)) { - lexer.getDiagnostic().report(token.getLocation(), - DiagnosticEngine::err_expected, - "colon or semi", token.getTokenName()); - return; - } - } - - // Fill the rule ast. - action.actOnRule(rule, generators); -} - -/// Parser a generator and fill a node in generator. -void Parser::parserGenerator(GeneratorAndOthers *generatorAndOthers) { - while (token.is(tokenKinds::identifier) || token.is(tokenKinds::apostrophe) || - token.is(tokenKinds::plus) || token.is(tokenKinds::asterisk) || - token.is(tokenKinds::parentheseOpen) || - token.is(tokenKinds::parentheseClose) || - token.is(tokenKinds::questionMark) || - token.is(tokenKinds::curlyBlacketOpen)) { - if (token.is(tokenKinds::identifier)) - parserIdentifier(generatorAndOthers); - else if (token.is(tokenKinds::apostrophe)) - parserTerminator(generatorAndOthers); - else if (token.is(tokenKinds::curlyBlacketOpen)) - parserCurlyBracketOpen(generatorAndOthers); - else - parserPBExpression(generatorAndOthers); - } -} - -void Parser::parserCurlyBracketOpen(GeneratorAndOthers *generatorAndOthers) { - advance(); - llvm::SMLoc location = token.getLocation(); - if (token.getContent() == "builder") { - llvm::SmallVector builderNames; - llvm::SmallVector builderIdxs; - advance(); - if (!consume(tokenKinds::equal)) - return; - while (token.is(identifier)) { - int index; - if ((index = token.getContent().find('_')) == -1) - lexer.getDiagnostic().report(token.getLocation(), - DiagnosticEngine::err_builder_fail); - llvm::StringRef builderOpName = token.getContent().substr(0, index); - std::string opBulderIdx = - token.getContent() - .substr(index + 1, token.getContent().size() - index) - .str(); - builderNames.push_back(builderOpName); - builderIdxs.push_back(std::stoi(opBulderIdx)); - advance(); - if (token.is(tokenKinds::comma)) - advance(); - } - generatorAndOthers->setbuilderNames(builderNames); - generatorAndOthers->setbuilderIdxs(builderIdxs); - } else { - lexer.getDiagnostic().report(location, - DiagnosticEngine::err_only_supported_builder); - return; - } - - consume(tokenKinds::curlyBlacketClose); -} - -/// Check if the identifier is a terminator. -AntlrBase::baseKind Parser::getAntlrBaseKind(llvm::StringRef name) { - if (terminators.isTerminator(name)) - return AntlrBase::baseKind::terminator; - return AntlrBase::baseKind::rule; -} - -/// processing the identifier, get the identifier's kind which stores -/// in the ast. -void Parser::parserIdentifier(GeneratorAndOthers *generatorAndOthers) { - AntlrBase::baseKind baseKind = getAntlrBaseKind(token.getContent()); - AntlrBase *r = nullptr; - if (baseKind == AntlrBase::baseKind::rule) - r = new Rule(token.getContent(), token.getLocation(), baseKind); - else if (baseKind == AntlrBase::AntlrBase::terminator) - r = new Terminator(token.getContent(), token.getLocation(), baseKind); - generatorAndOthers->getGenerator().push_back(r); - advance(); -} - -/// We support user-defined terminator.For example, we can write a 'terminator' -/// in a rule. -void Parser::parserTerminator(GeneratorAndOthers *generatorAndOthers) { - advance(); - AntlrBase *terminator = new Terminator( - token.getContent(), token.getLocation(), AntlrBase::terminator); - generatorAndOthers->getGenerator().push_back(terminator); - terminators.addCustomTerminators(token.getContent()); - advance(); - consume(tokenKinds::apostrophe); -} - -void Parser::parserPBExpression(GeneratorAndOthers *generatorAndOthers) { - AntlrBase *r = new Terminator(token.getContent(), token.getLocation(), - AntlrBase::pbexpression); - generatorAndOthers->getGenerator().push_back(r); - advance(); -} -/// Parser dialect keyword and fill all information in the dialect. -void Parser::parserDialect(Dialect *&dialect, llvm::StringRef defName) { - dialect = new Dialect(); - llvm::StringRef name; - llvm::StringRef cppNamespace; - while (token.is(tokenKinds::colon)) { - advance(); - if (token.getContent().str() == "name") { - advance(); - consumeNoAdvance(tokenKinds::equal); - name = lexer.getMarkContent("\"", "\""); - advance(); - } else if (token.getContent().str() == "cppNamespace") { - advance(); - consumeNoAdvance(tokenKinds::equal); - cppNamespace = lexer.getMarkContent("\"", "\""); - advance(); - } - } - action.actOnDialect(dialect, defName, name, cppNamespace); - advance(); -} - -/// Parser op keyword and fill all information in the ops. -bool Parser::parserOp(std::vector &ops, llvm::StringRef opName) { - DAG *arguments = nullptr; - DAG *results = nullptr; - std::vector builders; - advance(); - while (token.is(tokenKinds::colon)) { - advance(); - if (token.getContent() == "arguments") { - advance(); - if (!consumeNoAdvance(tokenKinds::equal)) - return false; - parserDAG(arguments); - advance(); - } else if (token.getContent() == "results") { - advance(); - if (!consumeNoAdvance(tokenKinds::equal)) - return false; - parserDAG(results); - advance(); - } else if (token.getContent() == "builders") { - advance(); - if (!consume(tokenKinds::equal)) - return false; - parserBuilders(builders); - advance(); - } else { - lexer.getDiagnostic().report(token.getLocation(), - DiagnosticEngine::err_not_supported_element, - token.getContent()); - return false; - } - } - if (!consume(tokenKinds::semi)) { - llvm::outs() << token.getContent(); - return false; - } - // Fill all information in the ops. - action.actOnOps(ops, opName, arguments, results, builders); - return true; -} - -/// parser DAG structure and fill all information in the arguments. -void Parser::parserDAG(DAG *&arguments) { - DAG dag; - advance(); - consume(tokenKinds::parentheseOpen); - llvm::StringRef dagOperator = token.getContent(); - advance(); - while (token.is(tokenKinds::identifier) || - token.is(tokenKinds::doubleQuotationMark)) { - int number = 0; - llvm::StringRef operandName; - llvm::StringRef operand; - llvm::StringRef value; - // If the operand provides a default value. - if (token.getContent() == "CArg") { - parserCArg(operand, value); - } else if (token.getContent() == "AnyTypeOf") { - const char *start = token.getContent().data(); - advance(); - if (!consumeNoAdvance(tokenKinds::angleBracketOpen)) - return; - operand = llvm::StringRef( - start, - 9 + lexer.getEndChContent(token.getContent().data(), '>').size()); - advance(); - } else if (token.is(tokenKinds::doubleQuotationMark)) { - // If the operand's type is cpp type. - operand = lexer.getEndChContent(token.getContent().data(), '"'); - advance(); - } else { - // If the operand's type is TableGen type. - operand = token.getContent(); - advance(); - if (token.is(tokenKinds::angleBracketOpen)) { - number++; - advance(); - if (token.is(tokenKinds::squareBracketOpen)) { - advance(); - number++; - } - llvm::StringRef type = token.getContent(); - advance(); - if (token.is(tokenKinds::squareBracketClose)) { - advance(); - number++; - } - consume(tokenKinds::angleBracketClose); - number++; - operand = llvm::StringRef(operand.data(), - operand.size() + number + type.size()); - } - } - // If operand is named. - if (token.is(tokenKinds::colon)) { - advance(); - advance(); - operandName = token.getContent(); - advance(); - } - dag.addOperand(operand, operandName); - if (!value.empty()) - dag.setValue(operand, value); - if (token.is(tokenKinds::comma)) - advance(); - } - dag.setDagOperatpr(dagOperator); - consumeNoAdvance(tokenKinds::parentheseClose); - // fill all information in the arguments. - action.actOnDag(arguments, dag); -} - -/// Parser opBuilder in the op. -void Parser::parserBuilders(std::vector &builders) { - if (!consume(tokenKinds::squareBracketOpen)) - return; - while (token.getContent() == "OpBuilder") { - DAG *dag = nullptr; - llvm::StringRef code; - advance(); - if (!consumeNoAdvance(tokenKinds::angleBracketOpen)) - return; - // Parser DAG. - parserDAG(dag); - advance(); - if (token.is(tokenKinds::comma)) { - // Parser code. - parserCode(code); - advance(); - } - if (!consume(tokenKinds::angleBracketClose)) - return; - Builder *builder = new Builder(dag, code); - builders.push_back(builder); - if (token.is(tokenKinds::comma)) - advance(); - } - consumeNoAdvance(tokenKinds::squareBracketClose); -} - -void Parser::parserCode(llvm::StringRef &code) { - code = lexer.getMarkContent("[", "]"); -} - -void Parser::parserCArg(llvm::StringRef &operand, llvm::StringRef &value) { - advance(); - consumeNoAdvance(tokenKinds::angleBracketOpen); - operand = lexer.getMarkContent("\"", "\""); - advance(); - value = lexer.getMarkContent("\"", "\""); - advance(); - consume(tokenKinds::angleBracketClose); -} diff --git a/frontend/FrontendGen/lib/Scope.cpp b/frontend/FrontendGen/lib/Scope.cpp new file mode 100644 index 0000000000..ef0a1dbe72 --- /dev/null +++ b/frontend/FrontendGen/lib/Scope.cpp @@ -0,0 +1,117 @@ +#include "Scope.h" + +// SymbolTable +template void fegen::SymbolTable::add(std::string name, T *e) { + this->table.insert({name, e}); +} + +template T *fegen::SymbolTable::get(std::string name) { + return this->table[name]; +} + +template bool fegen::SymbolTable::exist(std::string name) { + return (this->table.count(name) > 0); +} + +template fegen::SymbolTable::~SymbolTable() { + for (auto pair : this->table) { + delete pair.second; + } +} + +// FegenScope +fegen::FegenScope::FegenScope(unsigned int scopeId, + fegen::FegenScope *parentScope) + : scopeId(scopeId), parentScope(parentScope) {} + +fegen::FegenTypeDefination *fegen::FegenScope::findTypeDef(std::string name) { + return this->typeTable.get(name); +} + +void fegen::FegenScope::addTypeDef(FegenTypeDefination *tyDef) { + this->typeTable.add(tyDef->getName(), tyDef); +} + +bool fegen::FegenScope::isExistTypeDef(std::string name) { + return this->typeTable.exist(name); +} + +fegen::FegenValue *fegen::FegenScope::findVar(std::string name) { + return this->varTable.get(name); +} + +void fegen::FegenScope::addVar(fegen::FegenValue *var) { + this->varTable.add(var->getName(), var); +} + +bool fegen::FegenScope::isExistVar(std::string name) { + return this->varTable.exist(name); +} + +fegen::ScopeStack::ScopeStack() : count(1) { + this->globalScope = new fegen::FegenScope(0, nullptr); + this->currentScope = this->globalScope; + this->scopeStack.push(this->globalScope); + this->scopes.push_back(this->globalScope); +} + +fegen::ScopeStack::~ScopeStack() { + for (auto scope : this->scopes) { + delete scope; + } +} + +fegen::ScopeStack &fegen::ScopeStack::getScopeStack() { + static fegen::ScopeStack sstack; + return sstack; +} + +void fegen::ScopeStack::pushScope() { + auto newScope = new fegen::FegenScope(this->count++, this->currentScope); + this->scopeStack.push(newScope); + this->scopes.push_back(newScope); + this->currentScope = newScope; +} + +void fegen::ScopeStack::popScope() { + this->scopeStack.pop(); + this->currentScope = this->scopeStack.top(); +} +bool fegen::ScopeStack::attemptAddVar(fegen::FegenValue *var) { + if (this->currentScope->isExistVar(var->getName())) { + return false; + } + this->currentScope->addVar(var); + return true; +} + +fegen::FegenValue *fegen::ScopeStack::attemptFindVar(std::string name) { + auto p = this->currentScope; + while (p != nullptr) { + if (p->isExistVar(name)) { + return p->findVar(name); + } + p = p->parentScope; + } + return nullptr; +} + +bool fegen::ScopeStack::attemptAddTypeDef(fegen::FegenTypeDefination *tyDef) { + if (this->currentScope->isExistTypeDef(tyDef->getName())) { + return false; + } + this->currentScope->addTypeDef(tyDef); + return true; +} + +fegen::FegenTypeDefination * +fegen::ScopeStack::attemptFindTypeDef(std::string name) { + auto p = this->currentScope; + while (p != nullptr) { + if (p->isExistTypeDef(name)) { + return p->findTypeDef(name); + } + p = p->parentScope; + } + return nullptr; +} \ No newline at end of file diff --git a/frontend/FrontendGen/lib/Sema.cpp b/frontend/FrontendGen/lib/Sema.cpp deleted file mode 100644 index 00cc005e2a..0000000000 --- a/frontend/FrontendGen/lib/Sema.cpp +++ /dev/null @@ -1,54 +0,0 @@ -//====- Sema.cpp ---------------------------------------------------------===// -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#include "Sema.h" -#include "llvm/Support/raw_ostream.h" -using namespace frontendgen; - -/// Set Module's nodes. -void Sema::actOnModule(Module *module, std::vector &rules, - Dialect *&dialect, std::vector &ops) { - module->setRules(rules); - module->seDialect(dialect); - module->setOps(ops); -} -/// Set Rule's node. -void Sema::actOnRule(Rule *rule, - std::vector &generators) { - rule->setGenerators(generators); -} - -/// Set Dialect's nodes. -void Sema::actOnDialect(Dialect *dialect, llvm::StringRef defName, - llvm::StringRef name, llvm::StringRef cppNamespace) { - dialect->setDefName(defName); - dialect->setName(name); - dialect->setCppNamespace(cppNamespace); -} - -/// Make a op and make it in the ops. -void Sema::actOnOps(std::vector &ops, llvm::StringRef opName, - DAG *arguments, DAG *results, - std::vector &builders) { - Op *op = new Op(); - op->setOpName(opName); - op->setArguments(arguments); - op->setResults(results); - op->setBuilders(builders); - ops.push_back(op); -} - -void Sema::actOnDag(DAG *&arguments, DAG &dag) { arguments = new DAG(dag); } From 93f69eb741f3f3408c5e302ee52296eb568ebce0 Mon Sep 17 00:00:00 2001 From: chh Date: Mon, 8 Jul 2024 22:33:31 +0800 Subject: [PATCH 02/17] [FrontendGen] Restructure class Expression. --- frontend/FrontendGen/include/FegenManager.h | 227 ++++-- frontend/FrontendGen/include/FegenVisitor.h | 454 ++++++----- frontend/FrontendGen/lib/FegenManager.cpp | 818 +++++++++++++------- frontend/FrontendGen/lib/FegenVisitor.cpp | 3 +- 4 files changed, 963 insertions(+), 539 deletions(-) diff --git a/frontend/FrontendGen/include/FegenManager.h b/frontend/FrontendGen/include/FegenManager.h index c5623e76a7..505cf7a81a 100644 --- a/frontend/FrontendGen/include/FegenManager.h +++ b/frontend/FrontendGen/include/FegenManager.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -28,6 +29,8 @@ #define FEGEN_OPTINAL "Optional" #define FEGEN_ANY "Any" +#define FEGEN_NOT_IMPLEMENTED_ERROR false + namespace fegen { class FegenType; @@ -73,10 +76,10 @@ class FegenFunction { std::vector inputTypeList, FegenType *returnType = nullptr); ~FegenFunction() = default; - std::string getName(); - std::vector &getInputTypeList(); - FegenValue *getInputTypeList(size_t i); - FegenType *getReturnType(); + std::string getName(); + std::vector &getInputTypeList(); + FegenValue *getInputTypeList(size_t i); + FegenType *getReturnType(); }; class FegenValue; @@ -147,6 +150,8 @@ class FegenType { std::string toStringForTypedef(); // for generating op def td file. std::string toStringForOpdef(); + // for generating cpp type kind. + std::string toStringForCppKind(); static bool isSameType(FegenType *type1, FegenType *type2); ~FegenType(); // placeholder @@ -249,10 +254,16 @@ class FegenRightValue { STRING, TYPE, VECTOR, - EXPRESSION, - LEFT_VAR + LEFT_VAR, + FUNC_CALL, + OPERATION_CALL, + OPERATOR_CALL }; - + struct ExpressionNode; + struct FunctionCall; + struct OperationCall; + struct OperatorCall; + struct ExpressionTerminal; struct Expression { bool ifTerminal; LiteralKind kind; @@ -265,86 +276,175 @@ class FegenRightValue { virtual std::string toString() = 0; virtual std::string toStringForTypedef() = 0; virtual std::string toStringForOpdef() = 0; + virtual std::string toStringForCppKind() = 0; LiteralKind getKind(); + FegenType &getType(); virtual std::any getContent() = 0; virtual bool isConstexpr(); + + /// @brief operate lhs and rhs using binary operator. + static std::shared_ptr + binaryOperation(std::shared_ptr lhs, + std::shared_ptr rhs, FegenOperator op); + /// @brief operate expr using unary operator + static std::shared_ptr + unaryOperation(std::shared_ptr, FegenOperator); + + // TODO: callFunction + static std::shared_ptr + callFunction(std::vector>, FegenFunction *); + + // TODO: callOperation + static std::shared_ptr + callOperation(std::vector>, FegenOperation *); + + static std::shared_ptr getPlaceHolder(); + static std::shared_ptr getInteger(long long int, + size_t size = 32); + static std::shared_ptr getFloatPoint(long double, + size_t size = 32); + static std::shared_ptr getString(std::string); + static std::shared_ptr getType(FegenType &); + static std::shared_ptr + getList(std::vector> &); + static std::shared_ptr + getLeftValue(fegen::FegenValue *); }; struct ExpressionNode : public Expression { - using opType = - std::variant; - opType op; - std::vector params; - ExpressionNode(std::vector, opType, FegenType &, bool); - ExpressionNode(ExpressionNode &) = default; - ~ExpressionNode(); + ExpressionNode(LiteralKind, FegenType, bool); virtual std::string toString() override; virtual std::string toStringForTypedef() override; virtual std::string toStringForOpdef() override; - virtual std::any getContent() override; + virtual std::string toStringForCppKind() override; + virtual std::any getContent() override = 0; + }; - /// @brief operate lhs and rhs using binary operator. - static ExpressionNode *binaryOperation(Expression *lhs, Expression *rhs, - FegenOperator op); - /// @brief operate expr using unary operator - static ExpressionNode *unaryOperation(Expression *, FegenOperator); + struct FunctionCall : public ExpressionNode { + FegenFunction *func; + std::vector> params; + FunctionCall(FegenFunction *, std::vector>); + virtual std::string toString() override; + virtual std::string toStringForTypedef() override; + virtual std::string toStringForOpdef() override; + virtual std::string toStringForCppKind() override; + virtual std::any getContent() override; + }; - // TODO: callFunction - static ExpressionNode *callFunction(std::vector, - FegenFunction *); + struct OperationCall : public ExpressionNode { + FegenOperation *op; + std::vector> params; + OperationCall(FegenOperation *, std::vector>); + virtual std::string toString() override; + virtual std::string toStringForTypedef() override; + virtual std::string toStringForOpdef() override; + virtual std::string toStringForCppKind() override; + virtual std::any getContent() override; + }; - // TODO: callOperation - static ExpressionNode *callOperation(std::vector, - FegenOperation *); + struct OperatorCall : public ExpressionNode { + FegenOperator op; + std::vector> params; + OperatorCall(FegenOperator, std::vector>); + virtual std::string toString() override; + virtual std::string toStringForTypedef() override; + virtual std::string toStringForOpdef() override; + virtual std::string toStringForCppKind() override; + virtual std::any getContent() override; }; struct ExpressionTerminal : public Expression { - // monostate, int literal, float literal, string literal, type literal, list - // literal, reference of variable - using primLiteralType = - std::variant, FegenValue *>; - primLiteralType content; - ExpressionTerminal(primLiteralType, LiteralKind, FegenType, bool); - ExpressionTerminal(ExpressionTerminal &) = default; - ~ExpressionTerminal(); + ExpressionTerminal(LiteralKind, FegenType, bool); + virtual std::string toString() override; + virtual std::string toStringForTypedef() override; + virtual std::string toStringForOpdef() override; + virtual std::string toStringForCppKind() override; + virtual std::any getContent() override = 0; + }; + + struct PlaceHolder : public ExpressionTerminal { + PlaceHolder(); + virtual std::any getContent() override; + virtual std::string toString() override; + }; + + struct IntegerLiteral : public ExpressionTerminal { + size_t size; + long long int content; + // size = 32 + IntegerLiteral(int content); + IntegerLiteral(long long int content, size_t size); + virtual std::any getContent() override; + virtual std::string toString() override; + }; + + struct FloatPointLiteral : public ExpressionTerminal { + size_t size; + long double content; + FloatPointLiteral(long double content, size_t size); + virtual std::any getContent() override; + virtual std::string toString() override; + }; + + struct StringLiteral : public ExpressionTerminal { + std::string content; + StringLiteral(std::string content); + virtual std::any getContent() override; + virtual std::string toString() override; + }; + + struct TypeLiteral : public ExpressionTerminal { + FegenType content; + TypeLiteral(FegenType &content); + virtual std::any getContent() override; virtual std::string toString() override; virtual std::string toStringForTypedef() override; virtual std::string toStringForOpdef() override; + virtual std::string toStringForCppKind() override; + }; + + struct ListLiteral : public ExpressionTerminal { + std::vector> content; + ListLiteral(std::vector> &content); virtual std::any getContent() override; - static ExpressionTerminal *get(std::monostate); - static ExpressionTerminal *get(int); - static ExpressionTerminal *get(float); - static ExpressionTerminal *get(std::string); - static ExpressionTerminal *get(FegenType &); - static ExpressionTerminal *get(std::vector &); - static ExpressionTerminal *get(fegen::FegenValue *); + virtual std::string toString() override; + virtual std::string toStringForTypedef() override; + virtual std::string toStringForOpdef() override; + }; + + struct LeftValue : public ExpressionTerminal { + FegenValue *content; + LeftValue(FegenValue *content); + virtual std::any getContent() override; + virtual std::string toString() override; }; public: - FegenRightValue(Expression *content); - FegenRightValue(const FegenRightValue &); - FegenRightValue(FegenRightValue &&); - FegenRightValue::LiteralKind getKind(); + FegenRightValue(std::shared_ptr); + FegenRightValue(const FegenRightValue &) = default; + FegenRightValue(FegenRightValue &&) = default; + FegenRightValue::LiteralKind getLiteralKind(); std::string toString(); std::string toStringForTypedef(); std::string toStringForOpdef(); + std::string toStringForCppKind(); std::any getContent(); - Expression *getExpr(); - - static FegenRightValue get(); - static FegenRightValue get(int content); - static FegenRightValue get(float content); - static FegenRightValue get(std::string content); - static FegenRightValue get(FegenType &content); - // list - static FegenRightValue get(std::vector &content); - static FegenRightValue get(fegen::FegenValue *content); - static FegenRightValue get(Expression *expr); - ~FegenRightValue(); + FegenType &getType(); + std::shared_ptr getExpr(); + + static FegenRightValue getPlaceHolder(); + static FegenRightValue getInteger(long long int content, size_t size = 32); + static FegenRightValue getFloatPoint(long double content, size_t size = 32); + static FegenRightValue getString(std::string content); + static FegenRightValue getType(FegenType &content); + static FegenRightValue + getList(std::vector> &content); + static FegenRightValue getLeftValue(fegen::FegenValue *content); + static FegenRightValue getByExpr(std::shared_ptr expr); + ~FegenRightValue() = default; private: - Expression *content; + std::shared_ptr content; }; class FegenValue { @@ -374,7 +474,8 @@ class FegenValue { std::string getContentString(); std::string getContentStringForTypedef(); std::string getContentStringForOpdef(); - FegenRightValue::Expression *getExpr(); + std::string getContentStringForCppKind(); + std::shared_ptr getExpr(); ~FegenValue() = default; }; @@ -434,7 +535,6 @@ class FegenManager { friend class FegenVisitor; private: - // ScopeStack &sstack; FegenManager(); FegenManager(const FegenManager &) = delete; const FegenManager &operator=(const FegenManager &) = delete; @@ -474,8 +574,9 @@ class FegenManager { void emitBuiltinFunction(); }; -FegenType inferenceType(std::vector, - FegenOperator); +FegenType + inferenceType(std::vector>, + FegenOperator); } // namespace fegen diff --git a/frontend/FrontendGen/include/FegenVisitor.h b/frontend/FrontendGen/include/FegenVisitor.h index 82e384dc1a..842db7203a 100644 --- a/frontend/FrontendGen/include/FegenVisitor.h +++ b/frontend/FrontendGen/include/FegenVisitor.h @@ -29,7 +29,9 @@ bool checkParams(std::vector &expected, std::vector &actual); /// @brief check if the type of elements in list are correct. -bool checkListLiteral(std::vector listLiteral); +bool checkListLiteral( + std::vector> + &listLiteral); class FegenVisitor : public FegenParserBaseVisitor { private: @@ -184,8 +186,8 @@ class FegenVisitor : public FegenParserBaseVisitor { for (size_t i = 0; i <= varCount - 1; i++) { auto ty = std::any_cast(this->visit(ctx->typeSpec(i))); auto varName = ctx->identifier(i)->getText(); - auto var = - fegen::FegenValue::get(ty, varName, fegen::FegenRightValue::get()); + auto var = fegen::FegenValue::get( + ty, varName, fegen::FegenRightValue::getPlaceHolder()); valueList.push_back(var); } @@ -271,12 +273,14 @@ class FegenVisitor : public FegenParserBaseVisitor { if (ctx->builtinTypeInstances()) { auto ty = std::any_cast( this->visit(ctx->builtinTypeInstances())); - return fegen::FegenValue::get(ty, "param", fegen::FegenRightValue::get()); + return fegen::FegenValue::get(ty, "param", + fegen::FegenRightValue::getPlaceHolder()); } else { - auto expr = std::any_cast( - this->visit(ctx->expression())); + auto expr = + std::any_cast>( + this->visit(ctx->expression())); return fegen::FegenValue::get(expr->exprType, "expression_tmp", - fegen::FegenRightValue(expr)); + fegen::FegenRightValue::getByExpr(expr)); } } @@ -352,59 +356,67 @@ class FegenVisitor : public FegenParserBaseVisitor { // return FegenType std::any visitCollectType(FegenParser::CollectTypeContext *ctx) override { - auto expr = std::any_cast( - this->visit(ctx->expression())); + auto expr = + std::any_cast>( + this->visit(ctx->expression())); if (ctx->collectProtoType()->ANY()) { std::vector tys; // TODO: reprot error assert(expr->getKind() == fegen::FegenRightValue::LiteralKind::VECTOR); - auto exprs = - std::any_cast>( - expr->getContent()); + auto exprs = std::any_cast< + std::vector>>( + expr->getContent()); for (auto expr : exprs) { auto ty = std::any_cast(expr->getContent()); tys.push_back(ty); } return fegen::FegenType::getAnyType(tys); } else if (ctx->collectProtoType()->LIST()) { + assert(expr->getKind() == fegen::FegenRightValue::LiteralKind::TYPE); auto ty = std::any_cast(expr->getContent()); return fegen::FegenType::getListType(ty); } else { // optional + assert(expr->getKind() == fegen::FegenRightValue::LiteralKind::TYPE); auto ty = std::any_cast(expr->getContent()); return fegen::FegenType::getOptionalType(ty); } } - // return FegenRightValue::Expression* + // return std::shared_ptr std::any visitExpression(FegenParser::ExpressionContext *ctx) override { - auto expr = std::any_cast( - this->visit(ctx->andExpr(0))); + auto expr = + std::any_cast>( + this->visit(ctx->andExpr(0))); for (size_t i = 1; i <= ctx->andExpr().size() - 1; i++) { - auto rhs = std::any_cast( - this->visit(ctx->andExpr(i))); + auto rhs = + std::any_cast>( + this->visit(ctx->andExpr(i))); expr = FegenRightValue::ExpressionNode::binaryOperation( expr, rhs, FegenOperator::OR); } return expr; } - // return FegenRightValue::Expression* + // return std::shared_ptr std::any visitAndExpr(FegenParser::AndExprContext *ctx) override { - auto expr = std::any_cast( - this->visit(ctx->equExpr(0))); + auto expr = + std::any_cast>( + this->visit(ctx->equExpr(0))); for (size_t i = 1; i <= ctx->equExpr().size() - 1; i++) { - auto rhs = std::any_cast( - this->visit(ctx->equExpr(i))); + auto rhs = + std::any_cast>( + this->visit(ctx->equExpr(i))); expr = FegenRightValue::ExpressionNode::binaryOperation( expr, rhs, FegenOperator::AND); } return expr; } - // return FegenRightValue::Expression* + // return std::shared_ptr std::any visitEquExpr(FegenParser::EquExprContext *ctx) override { - auto expr = std::any_cast( - this->visit(ctx->compareExpr(0))); + auto expr = + std::any_cast>( + this->visit(ctx->compareExpr(0))); for (size_t i = 1; i <= ctx->compareExpr().size() - 1; i++) { FegenOperator op; if (ctx->children[2 * i - 1]->getText() == "==") { @@ -412,17 +424,19 @@ class FegenVisitor : public FegenParserBaseVisitor { } else { op = FegenOperator::NOT_EQUAL; } - auto rhs = std::any_cast( - this->visit(ctx->compareExpr(i))); + auto rhs = + std::any_cast>( + this->visit(ctx->compareExpr(i))); expr = FegenRightValue::ExpressionNode::binaryOperation(expr, rhs, op); } return expr; } - // return FegenRightValue::Expression* + // return std::shared_ptr std::any visitCompareExpr(FegenParser::CompareExprContext *ctx) override { - auto expr = std::any_cast( - this->visit(ctx->addExpr(0))); + auto expr = + std::any_cast>( + this->visit(ctx->addExpr(0))); for (size_t i = 1; i <= ctx->addExpr().size() - 1; i++) { FegenOperator op; auto opStr = ctx->children[2 * i - 1]->getText(); @@ -437,17 +451,19 @@ class FegenVisitor : public FegenParserBaseVisitor { } else { op = FegenOperator::GREATER_EQUAL; } - auto rhs = std::any_cast( - this->visit(ctx->addExpr(i))); + auto rhs = + std::any_cast>( + this->visit(ctx->addExpr(i))); expr = FegenRightValue::ExpressionNode::binaryOperation(expr, rhs, op); } return expr; } - // return FegenRightValue::Expression* + // return std::shared_ptr std::any visitAddExpr(FegenParser::AddExprContext *ctx) override { auto expr = - std::any_cast(this->visit(ctx->term(0))); + std::any_cast>( + this->visit(ctx->term(0))); for (size_t i = 1; i <= ctx->term().size() - 1; i++) { FegenOperator op; auto opStr = ctx->children[2 * i - 1]->getText(); @@ -456,17 +472,19 @@ class FegenVisitor : public FegenParserBaseVisitor { } else { op = FegenOperator::SUB; } - auto rhs = std::any_cast( - this->visit(ctx->term(i))); + auto rhs = + std::any_cast>( + this->visit(ctx->term(i))); expr = FegenRightValue::ExpressionNode::binaryOperation(expr, rhs, op); } return expr; } - // return FegenRightValue::Expression* + // return std::shared_ptr std::any visitTerm(FegenParser::TermContext *ctx) override { - auto expr = std::any_cast( - this->visit(ctx->powerExpr(0))); + auto expr = + std::any_cast>( + this->visit(ctx->powerExpr(0))); for (size_t i = 1; i <= ctx->powerExpr().size() - 1; i++) { FegenOperator op; auto opStr = ctx->children[2 * i - 1]->getText(); @@ -477,33 +495,37 @@ class FegenVisitor : public FegenParserBaseVisitor { } else { op = FegenOperator::MOD; } - auto rhs = std::any_cast( - this->visit(ctx->powerExpr(i))); + auto rhs = + std::any_cast>( + this->visit(ctx->powerExpr(i))); expr = FegenRightValue::ExpressionNode::binaryOperation(expr, rhs, op); } return expr; } - // return FegenRightValue::Expression* + // return std::shared_ptr std::any visitPowerExpr(FegenParser::PowerExprContext *ctx) override { - auto expr = std::any_cast( - this->visit(ctx->unaryExpr(0))); + auto expr = + std::any_cast>( + this->visit(ctx->unaryExpr(0))); for (size_t i = 1; i <= ctx->unaryExpr().size() - 1; i++) { - auto rhs = std::any_cast( - this->visit(ctx->unaryExpr(i))); + auto rhs = + std::any_cast>( + this->visit(ctx->unaryExpr(i))); expr = FegenRightValue::ExpressionNode::binaryOperation( expr, rhs, FegenOperator::POWER); } return expr; } - // return FegenRightValue::Expression* + // return std::shared_ptr std::any visitUnaryExpr(FegenParser::UnaryExprContext *ctx) override { if (ctx->children.size() == 1 || ctx->Plus()) { return this->visit(ctx->primaryExpr()); } - auto expr = std::any_cast( - this->visit(ctx->primaryExpr())); + auto expr = + std::any_cast>( + this->visit(ctx->primaryExpr())); FegenOperator op; if (ctx->Minus()) { op = FegenOperator::NEG; @@ -514,26 +536,27 @@ class FegenVisitor : public FegenParserBaseVisitor { return expr; } - // return FegenRightValue::Expression* + // return std::shared_ptr std::any visitParenSurroundedExpr( FegenParser::ParenSurroundedExprContext *ctx) override { return this->visit(ctx->expression()); } - // return FegenRightValue::Expression* + // return std::shared_ptr std::any visitPrimaryExpr(FegenParser::PrimaryExprContext *ctx) override { if (ctx->identifier()) { auto name = ctx->identifier()->getText(); auto var = this->sstack.attemptFindVar(name); if (var) { - return (FegenRightValue::Expression *) - fegen::FegenRightValue::ExpressionTerminal::get(var); + return (std::shared_ptr) + fegen::FegenRightValue::ExpressionTerminal::getLeftValue(var); } else { - auto tyDef = this->sstack.attemptFindTypeDef(name); + // TODO + auto tyDef = this->manager.getTypeDefination(name); if (tyDef) { auto tyVar = fegen::FegenType::getTemplateType(tyDef); - return fegen::FegenValue::get(fegen::FegenType::getMetaTemplateType(), - "", fegen::FegenRightValue::get(tyVar)); + return (std::shared_ptr) + fegen::FegenRightValue::Expression::getType(tyVar); } else { // TODO: error report std::cerr << "can not find variable: " << ctx->identifier()->getText() @@ -544,204 +567,221 @@ class FegenVisitor : public FegenParserBaseVisitor { } } else if (ctx->typeSpec()) { auto ty = std::any_cast(this->visit(ctx->typeSpec())); - return (FegenRightValue::Expression *) - FegenRightValue::ExpressionTerminal::get(ty); + return (std::shared_ptr) + FegenRightValue::ExpressionTerminal::getType(ty); } else { // constant, functionCall, parenSurroundedExpr,contextMethodInvoke, // and variableAccess return this->visit(ctx->children[0]); } } - // return ExpressionTerminal* + // return std::shared_ptr std::any visitIntLiteral(FegenParser::IntLiteralContext *ctx) override { - int number = std::stoi(ctx->getText()); - return (FegenRightValue::Expression *) - fegen::FegenRightValue::ExpressionTerminal::get(number); + long long int number = std::stoi(ctx->getText()); + size_t size = 32; // TODO: Get size of number. + return (std::shared_ptr) + fegen::FegenRightValue::Expression::getInteger(number, size); } - // return ExpressionTerminal* + // return std::shared_ptr std::any visitRealLiteral(FegenParser::RealLiteralContext *ctx) override { - double number = std::stod(ctx->getText()); - return (FegenRightValue::Expression *) - fegen::FegenRightValue::ExpressionTerminal::get(float(number)); + long double number = std::stod(ctx->getText()); + size_t size = 32; // TODO: Get size of number. + return (std::shared_ptr) + fegen::FegenRightValue::Expression::getFloatPoint(number, size); } - // return ExpressionTerminal* + // return std::shared_ptr std::any visitCharLiteral(FegenParser::CharLiteralContext *ctx) override { std::string s = ctx->getText(); // remove quotation marks std::string strWithoutQuotation = s.substr(1, s.size() - 2); - return (FegenRightValue::Expression *) - fegen::FegenRightValue::ExpressionTerminal::get(strWithoutQuotation); + return (std::shared_ptr) + fegen::FegenRightValue::Expression::getString(strWithoutQuotation); } - // return ExpressionTerminal* + // return std::shared_ptr std::any visitBoolLiteral(FegenParser::BoolLiteralContext *ctx) override { int content = 0; if (ctx->getText() == "true") { content = 1; } - return (FegenRightValue::Expression *) - fegen::FegenRightValue::ExpressionTerminal::get(content); + return (std::shared_ptr) + fegen::FegenRightValue::Expression::getInteger(content, 1); } - // return ExpressionTerminal* + // return std::shared_ptr std::any visitListLiteral(FegenParser::ListLiteralContext *ctx) override { - std::vector elements; + std::vector> elements; for (auto exprCtx : ctx->expression()) { - auto expr = std::any_cast( - this->visit(exprCtx)); + auto expr = + std::any_cast>( + this->visit(exprCtx)); elements.push_back(expr); } - return (FegenRightValue::Expression *) - fegen::FegenRightValue::ExpressionTerminal::get(elements); + return (std::shared_ptr) + fegen::FegenRightValue::Expression::getList(elements); } std::any visitActionSpec(FegenParser::ActionSpecContext *ctx) override { return nullptr; } - std::any visitFunctionDecl(FegenParser::FunctionDeclContext *ctx) override{ - sstack.pushScope(); - auto returnType = std::any_cast(this->visit(ctx->typeSpec())); - auto functionName = std::any_cast(this->visit(ctx->funcName())); - auto hasfunc = manager.functionMap.find(functionName); - if(hasfunc != manager.functionMap.end()){ - std::cerr << "The function name \" " << functionName - << "\" has already been used. Please use another name." << std::endl; - exit(0); - return nullptr; - } - auto functionParams = std::any_cast>(this->visit(ctx->funcParams())); - this->visit(ctx->statementBlock()); - - fegen::FegenFunction* function = fegen::FegenFunction::get(functionName, functionParams, &returnType); - manager.functionMap.insert(std::pair{functionName, function}); - sstack.popScope(); - return nullptr; + std::any visitFunctionDecl(FegenParser::FunctionDeclContext *ctx) override { + sstack.pushScope(); + auto returnType = + std::any_cast(this->visit(ctx->typeSpec())); + auto functionName = + std::any_cast(this->visit(ctx->funcName())); + auto hasfunc = manager.functionMap.find(functionName); + if (hasfunc != manager.functionMap.end()) { + std::cerr << "The function name \" " << functionName + << "\" has already been used. Please use another name." + << std::endl; + exit(0); + return nullptr; } + auto functionParams = std::any_cast>( + this->visit(ctx->funcParams())); + this->visit(ctx->statementBlock()); - std::any visitFuncName(FegenParser::FuncNameContext *ctx) override{ - auto functionName = ctx->identifier()->getText(); - return functionName; + fegen::FegenFunction *function = + fegen::FegenFunction::get(functionName, functionParams, &returnType); + manager.functionMap.insert(std::pair{functionName, function}); + sstack.popScope(); + return nullptr; + } + + std::any visitFuncName(FegenParser::FuncNameContext *ctx) override { + auto functionName = ctx->identifier()->getText(); + return functionName; + } + + std::any visitFuncParams(FegenParser::FuncParamsContext *ctx) override { + std::vector paramsList = {}; + + for (size_t i = 0; i < ctx->typeSpec().size(); i++) { + auto paramType = + std::any_cast(this->visit(ctx->typeSpec(i))); + auto paramName = ctx->identifier(i)->getText(); + auto param = fegen::FegenValue::get(paramType, paramName, nullptr); + paramsList.push_back(param); + sstack.attemptAddVar(param); } + return paramsList; + } - std::any visitFuncParams(FegenParser::FuncParamsContext *ctx) override{ - std::vector paramsList = {}; - - for(size_t i = 0; i < ctx->typeSpec().size(); i++){ - auto paramType = std::any_cast(this->visit(ctx->typeSpec(i))); - auto paramName = ctx->identifier(i)->getText(); - auto param = fegen::FegenValue::get(paramType, paramName, nullptr); - paramsList.push_back(param); - sstack.attemptAddVar(param); - } - return paramsList; - } - - std::any visitVarDeclStmt(FegenParser::VarDeclStmtContext *ctx) override{ - auto varType = std::any_cast(this->visit(ctx->typeSpec())); - auto varName = ctx->identifier()->getText(); - fegen::FegenValue* var; - if(ctx->expression()){ - auto varcontent = std::any_cast(this->visit(ctx->expression())); - // TODO: check error - // if(!fegen::FegenType::isSameType(&varType, &varcontent->exprType)){ - // std::cerr << "The variabel \" " << varName - // << "\" need \"" << varType.getTypeName() << " \" type rightvalue." << std::endl; - // exit(0); - // return nullptr; - // } - var = fegen::FegenValue::get(varType, varName, varcontent); - } else { - var = fegen::FegenValue::get(varType, varName, nullptr); - } - sstack.attemptAddVar(var); - manager.stmtContentMap.insert(std::pair{ctx, var}); - return var; - } - - std::any visitAssignStmt(FegenParser::AssignStmtContext *ctx) override{ - auto varName = ctx->identifier()->getText(); - auto varcontent = std::any_cast(this->visit(ctx->expression())); - auto var = sstack.attemptFindVar(varName); - if(!fegen::FegenType::isSameType(&var->getType(), &varcontent->exprType)){ - std::cerr << "The variabel \" " << varName - << "\" need \"" << var->getType().getTypeName() << " \" type rightvalue." << std::endl; - exit(0); - return nullptr; - } - fegen::FegenValue * stmt = fegen::FegenValue::get(var->getType(), varName, varcontent); - manager.stmtContentMap.insert(std::pair{ctx, stmt}); - - return stmt; - } - - std::any visitFunctionCall(FegenParser::FunctionCallContext *ctx) override{ - std::vector parasList = {}; - auto functionName =std::any_cast(this->visit(ctx->funcName())); - auto hasFunc = manager.functionMap.at(functionName); - auto paramsNum = ctx->expression().size(); - auto paraList = hasFunc->getInputTypeList(); - if( paramsNum> 0){ - for(size_t i = 0; i < paramsNum; i++){ - auto oprand = std::any_cast(this->visit(ctx->expression(i))); - parasList.push_back(oprand); - } - size_t len1 = paraList.size(); - size_t len2 = parasList.size(); - if(len1 != len2){ - std::cerr << "The function \" " << functionName - << "\" parameter count mismatch." << std::endl; - exit(0); - return nullptr; - } - for(size_t i = 0; i < len1; i++){ - if(!fegen::FegenType::isSameType(¶List[i]->getType(), ¶sList[i]->exprType)){ - std::cerr << "The function \" " << functionName - << "\" parameter" << i << " type mismatch." << std::endl; - exit(0); - return nullptr; - } - } - } - auto returnType = hasFunc->getReturnType(); - fegen::FegenFunction *funcCall = fegen::FegenFunction::get(functionName, paraList, returnType); - manager.stmtContentMap.insert(std::pair{ctx, funcCall}); - return returnType; + std::any visitVarDeclStmt(FegenParser::VarDeclStmtContext *ctx) override { + auto varType = + std::any_cast(this->visit(ctx->typeSpec())); + auto varName = ctx->identifier()->getText(); + fegen::FegenValue *var; + if (ctx->expression()) { + auto varcontent = std::any_cast( + this->visit(ctx->expression())); + // TODO: check error + // if(!fegen::FegenType::isSameType(&varType, &varcontent->exprType)){ + // std::cerr << "The variabel \" " << varName + // << "\" need \"" << varType.getTypeName() << " \" type rightvalue." + // << std::endl; exit(0); return nullptr; + // } + var = fegen::FegenValue::get(varType, varName, varcontent); + } else { + var = fegen::FegenValue::get(varType, varName, nullptr); } + sstack.attemptAddVar(var); + manager.stmtContentMap.insert(std::pair{ctx, var}); + return var; + } - std::any visitOpInvokeStmt(FegenParser::OpInvokeStmtContext *ctx) override{ + std::any visitAssignStmt(FegenParser::AssignStmtContext *ctx) override { + auto varName = ctx->identifier()->getText(); + auto varcontent = std::any_cast( + this->visit(ctx->expression())); + auto var = sstack.attemptFindVar(varName); + if (!fegen::FegenType::isSameType(&var->getType(), &varcontent->exprType)) { + std::cerr << "The variabel \" " << varName << "\" need \"" + << var->getType().getTypeName() << " \" type rightvalue." + << std::endl; + exit(0); + return nullptr; + } + fegen::FegenValue *stmt = + fegen::FegenValue::get(var->getType(), varName, varcontent); + manager.stmtContentMap.insert(std::pair{ctx, stmt}); + + return stmt; + } + + std::any visitFunctionCall(FegenParser::FunctionCallContext *ctx) override { + std::vector parasList = {}; + auto functionName = + std::any_cast(this->visit(ctx->funcName())); + auto hasFunc = manager.functionMap.at(functionName); + auto paramsNum = ctx->expression().size(); + auto paraList = hasFunc->getInputTypeList(); + if (paramsNum > 0) { + for (size_t i = 0; i < paramsNum; i++) { + auto oprand = std::any_cast( + this->visit(ctx->expression(i))); + parasList.push_back(oprand); + } + size_t len1 = paraList.size(); + size_t len2 = parasList.size(); + if (len1 != len2) { + std::cerr << "The function \" " << functionName + << "\" parameter count mismatch." << std::endl; + exit(0); return nullptr; + } + for (size_t i = 0; i < len1; i++) { + if (!fegen::FegenType::isSameType(¶List[i]->getType(), + ¶sList[i]->exprType)) { + std::cerr << "The function \" " << functionName << "\" parameter" << i + << " type mismatch." << std::endl; + exit(0); + return nullptr; + } + } } + auto returnType = hasFunc->getReturnType(); + fegen::FegenFunction *funcCall = + fegen::FegenFunction::get(functionName, paraList, returnType); + manager.stmtContentMap.insert(std::pair{ctx, funcCall}); + return returnType; + } - std::any visitIfStmt(FegenParser::IfStmtContext *ctx) override{ - sstack.pushScope(); - this->visit(ctx->expression(0)); - this->visit(ctx->statementBlock(0)); - if(ctx->expression().size() > 1){ - for(size_t i = 1; i < ctx->expression().size(); i++){ - this->visit(ctx->expression(i)); - this->visit(ctx->statementBlock(i)); - } - } - if(ctx->statementBlock(ctx->expression().size()+1)) - this->visit(ctx->statementBlock(ctx->expression().size()+1)); - sstack.popScope(); + std::any visitOpInvokeStmt(FegenParser::OpInvokeStmtContext *ctx) override { + return nullptr; + } - return nullptr; - } + std::any visitIfStmt(FegenParser::IfStmtContext *ctx) override { + sstack.pushScope(); + this->visit(ctx->expression(0)); + this->visit(ctx->statementBlock(0)); + if (ctx->expression().size() > 1) { + for (size_t i = 1; i < ctx->expression().size(); i++) { + this->visit(ctx->expression(i)); + this->visit(ctx->statementBlock(i)); + } + } + if (ctx->statementBlock(ctx->expression().size() + 1)) + this->visit(ctx->statementBlock(ctx->expression().size() + 1)); + sstack.popScope(); - std::any visitForStmt(FegenParser::ForStmtContext *ctx) override{ - sstack.pushScope(); - this->visit(ctx->assignStmt(0)); - this->visit(ctx->expression()); - this->visit(ctx->assignStmt(1)); - this->visit(ctx->statementBlock()); - sstack.popScope(); + return nullptr; + } - return nullptr; - } + std::any visitForStmt(FegenParser::ForStmtContext *ctx) override { + sstack.pushScope(); + this->visit(ctx->assignStmt(0)); + this->visit(ctx->expression()); + this->visit(ctx->assignStmt(1)); + this->visit(ctx->statementBlock()); + sstack.popScope(); + + return nullptr; + } std::any visitOpDecl(FegenParser::OpDeclContext *ctx) override { auto opName = ctx->opName()->getText(); diff --git a/frontend/FrontendGen/lib/FegenManager.cpp b/frontend/FrontendGen/lib/FegenManager.cpp index 8a0fe23a6c..b2c03059c5 100644 --- a/frontend/FrontendGen/lib/FegenManager.cpp +++ b/frontend/FrontendGen/lib/FegenManager.cpp @@ -1,13 +1,17 @@ -#include "FegenParserBaseVisitor.h" #include "FegenManager.h" +#include "FegenParser.h" +#include "FegenParserBaseVisitor.h" #include "Scope.h" #include +#include #include #include #include +#include #include #include #include +#include fegen::FegenFunction::FegenFunction(std::string name, std::vector &&inputTypeList, @@ -20,7 +24,7 @@ fegen::FegenFunction::get(std::string name, FegenType *returnType) { return new fegen::FegenFunction(name, std::move(inputTypeList), returnType); } -std::string fegen::FegenFunction::getName() { this->name; } +std::string fegen::FegenFunction::getName() { return this->name; } std::vector &fegen::FegenFunction::getInputTypeList() { return this->inputTypeList; @@ -228,12 +232,8 @@ std::string fegen::FegenType::toStringForOpdef() { return this->typeDefine->getName(); } else if (typedefName == FEGEN_LIST) { std::string res = "Variadic<"; - for (size_t i = 0; i <= this->parameters.size() - 1; i++) { - res.append(this->parameters[i]->getContentStringForTypedef()); - if (i != this->parameters.size() - 1) { - res.append(", "); - } - } + assert(this->parameters.size() == 1); + res.append(this->parameters[0]->getContentStringForTypedef()); res.append(">"); return res; } else if (typedefName == FEGEN_INTEGER) { @@ -256,6 +256,40 @@ std::string fegen::FegenType::toStringForOpdef() { exit(0); } +std::string fegen::FegenType::toStringForCppKind() { + // handle builtin type instance + auto typeName = this->typeName; + auto typedefName = this->typeDefine->getName(); + if (typedefName == FEGEN_LIST) { + assert(this->parameters.size() == 1); + std::string res = "std::vector<"; + res.append(this->parameters[0]->getContentStringForTypedef()); + res.append(">"); + return res; + } else if (typedefName == FEGEN_INTEGER) { + assert(this->parameters.size() == 1); + if (typeName == "int") { + return "int"; + } + int size = this->getParameters(0)->getContent(); + if (size == 64) { + return "long"; + } else if (size == 16) { + return "short"; + } + } else if (typedefName == FEGEN_FLOATPOINT) { + assert(this->parameters.size() == 1); + if (typeName == "float") { + return "float"; + } else if (typeName == "double") { + return "double"; + } + } + std::cerr << "Unsupported type: " << typeName << "in generating cpp type." + << std::endl; + exit(0); +} + fegen::FegenType::~FegenType() { for (auto p : this->parameters) { delete p; @@ -286,7 +320,7 @@ fegen::FegenType fegen::FegenType::getInt32Type() { return fegen::FegenType( fegen::FegenType::TypeKind::CPP, "int", {fegen::FegenValue::get(fegen::FegenType::getPlaceHolder(), "size", - fegen::FegenRightValue::get())}, + fegen::FegenRightValue::getPlaceHolder())}, fegen::FegenManager::getManager().getTypeDefination(FEGEN_INTEGER), 3); } @@ -294,7 +328,7 @@ fegen::FegenType fegen::FegenType::getFloatType() { return fegen::FegenType( fegen::FegenType::TypeKind::CPP, "float", {fegen::FegenValue::get(fegen::FegenType::getInt32Type(), "size", - fegen::FegenRightValue::get(32))}, + fegen::FegenRightValue::getInteger(32))}, fegen::FegenManager::getManager().getTypeDefination(FEGEN_FLOATPOINT), 3); } @@ -302,7 +336,7 @@ fegen::FegenType fegen::FegenType::getDoubleType() { return fegen::FegenType( fegen::FegenType::TypeKind::CPP, "double", {fegen::FegenValue::get(fegen::FegenType::getInt32Type(), "size", - fegen::FegenRightValue::get(64))}, + fegen::FegenRightValue::getInteger(64))}, fegen::FegenManager::getManager().getTypeDefination(FEGEN_FLOATPOINT), 3); } @@ -310,7 +344,7 @@ fegen::FegenType fegen::FegenType::getBoolType() { return fegen::FegenType( fegen::FegenType::TypeKind::CPP, "bool", {fegen::FegenValue::get(fegen::FegenType::getInt32Type(), "size", - fegen::FegenRightValue::get(1))}, + fegen::FegenRightValue::getInteger(1))}, fegen::FegenManager::getManager().getTypeDefination(FEGEN_INTEGER), 3); } @@ -352,7 +386,7 @@ fegen::FegenType fegen::FegenType::getVectorType(fegen::FegenValue *size, fegen::FegenType::TypeKind::CPP, {size, fegen::FegenValue::get(fegen::FegenType::getMetaType(), "elementType", - fegen::FegenRightValue::get(elementType))}, + fegen::FegenRightValue::getType(elementType))}, fegen::FegenManager::getManager().getTypeDefination(FEGEN_VECTOR), elementType.typeLevel); } @@ -364,7 +398,7 @@ fegen::FegenType fegen::FegenType::getTensorType(fegen::FegenValue *shape, fegen::FegenType::TypeKind::CPP, {shape, fegen::FegenValue::get(fegen::FegenType::getMetaType(), "elementType", - fegen::FegenRightValue::get(elementType))}, + fegen::FegenRightValue::getType(elementType))}, fegen::FegenManager::getManager().getTypeDefination(FEGEN_TENSOR), elementType.typeLevel); } @@ -377,7 +411,7 @@ fegen::FegenType fegen::FegenType::getListType(fegen::FegenType elementType) { {fegen::FegenValue::get( elementType.typeLevel == 2 ? fegen::FegenType::getMetaTemplateType() : fegen::FegenType::getMetaType(), - "elementType", fegen::FegenRightValue::get(elementType))}, + "elementType", fegen::FegenRightValue::getType(elementType))}, fegen::FegenManager::getManager().getTypeDefination(FEGEN_LIST), elementType.typeLevel); } @@ -391,7 +425,7 @@ fegen::FegenType::getOptionalType(fegen::FegenType elementType) { {fegen::FegenValue::get( elementType.typeLevel == 2 ? fegen::FegenType::getMetaTemplateType() : fegen::FegenType::getMetaType(), - "elementType", fegen::FegenRightValue::get(elementType))}, + "elementType", fegen::FegenRightValue::getType(elementType))}, fegen::FegenManager::getManager().getTypeDefination(FEGEN_OPTINAL), elementType.typeLevel); } @@ -408,8 +442,8 @@ fegen::FegenType::getAnyType(std::vector elementTypes) { : fegen::FegenType::getMetaType(); for (auto &ty : elementTypes) { assert(ty.typeLevel == tyLevel); - p_elemTy.push_back(fegen::FegenValue::get(tyty, name + std::to_string(i), - fegen::FegenRightValue::get(ty))); + p_elemTy.push_back(fegen::FegenValue::get( + tyty, name + std::to_string(i), fegen::FegenRightValue::getType(ty))); i++; } return fegen::FegenType( @@ -516,42 +550,93 @@ fegen::FegenRightValue::Expression::getKind() { return this->kind; } -std::any fegen::FegenRightValue::Expression::getContent() { - if (this->ifTerminal) { - auto tPtr = - dynamic_cast(this); - return tPtr->content; - } else { - return dynamic_cast(this); - ; - } +fegen::FegenType &fegen::FegenRightValue::Expression::getType() { + return this->exprType; } bool fegen::FegenRightValue::Expression::isConstexpr() { return this->ifConstexpr; } -// class ExpressionNode +std::shared_ptr +fegen::FegenRightValue::Expression::getPlaceHolder() { + return std::make_shared(); +} -fegen::FegenRightValue::ExpressionNode::ExpressionNode( - std::vector params, - std::variant - op, - FegenType &exprTy, bool ifConstexpr) - : Expression(false, fegen::FegenRightValue::LiteralKind::EXPRESSION, exprTy, - ifConstexpr), - op(op), params(params) {} - -fegen::FegenRightValue::ExpressionNode::~ExpressionNode() { - for (auto p : this->params) { - delete p; - } +std::shared_ptr +fegen::FegenRightValue::Expression::getInteger(long long int content, + size_t size) { + return std::make_shared(content, + size); +} + +std::shared_ptr +fegen::FegenRightValue::Expression::getFloatPoint(long double content, + size_t size) { + return std::make_shared(content, + size); +} + +std::shared_ptr +fegen::FegenRightValue::Expression::getString(std::string content) { + return std::make_shared(content); +} + +std::shared_ptr +fegen::FegenRightValue::Expression::getType(fegen::FegenType &content) { + return std::make_shared(content); +} + +std::shared_ptr +fegen::FegenRightValue::Expression::getList( + std::vector> &content) { + return std::make_shared(content); +} + +std::shared_ptr +fegen::FegenRightValue::Expression::getLeftValue(fegen::FegenValue *content) { + return std::make_shared(content); +} + +std::shared_ptr +fegen::FegenRightValue::Expression::binaryOperation( + std::shared_ptr lhs, + std::shared_ptr rhs, FegenOperator op) { + FegenType resTy = fegen::inferenceType({lhs, rhs}, op); + return std::make_shared( + op, std::vector>{ + lhs, rhs}); } +std::shared_ptr +fegen::FegenRightValue::Expression::unaryOperation( + std::shared_ptr v, FegenOperator op) { + FegenType resTy = fegen::inferenceType({v}, op); + return std::make_shared( + op, std::vector>{v}); +} + +// class ExpressionNode + +fegen::FegenRightValue::ExpressionNode::ExpressionNode(LiteralKind kind, + FegenType exprTy, + bool ifConstexpr) + : Expression(false, kind, exprTy, ifConstexpr) {} + std::string fegen::FegenRightValue::ExpressionNode::toString() { - // TODO: toString - return "todo: fegen::FegenRightValue::ExpressionNode::toString"; + assert(FEGEN_NOT_IMPLEMENTED_ERROR); +} + +std::string fegen::FegenRightValue::ExpressionNode::toStringForTypedef() { + assert(FEGEN_NOT_IMPLEMENTED_ERROR); +} + +std::string fegen::FegenRightValue::ExpressionNode::toStringForOpdef() { + assert(FEGEN_NOT_IMPLEMENTED_ERROR); +} + +std::string fegen::FegenRightValue::ExpressionNode::toStringForCppKind() { + assert(FEGEN_NOT_IMPLEMENTED_ERROR); } inline bool isBinaryOperator(fegen::FegenOperator &op) { @@ -564,246 +649,377 @@ inline bool isBinaryOperator(fegen::FegenOperator &op) { } } -inline std::string OperatorToString(fegen::FegenOperator &op) { - switch (op) { - case fegen::FegenOperator::ADD: - return "+"; - case fegen::FegenOperator::SUB: - return "-"; - case fegen::FegenOperator::MUL: - return "*"; - case fegen::FegenOperator::DIV: - return "/"; - default: - std::cerr << "unsupproted operator." << std::endl; - exit(0); +std::string getCppOperator(fegen::FegenOperator op) { + // switch(op){ + // OR, + // AND, + // EQUAL, + // NOT_EQUAL, + // LESS, + // LESS_EQUAL, + // GREATER, + // GREATER_EQUAL, + // ADD, + // SUB, + // MUL, + // DIV, + // MOD, + // POWER, + // NEG, + // NOT + // } +} + +// std::string res; +// auto opKind = this->op.index(); +// if(opKind == 0){ // function +// auto func = std::get<0>(this->op); +// // res.append(func.) +// // TODO: add FegenFunction methods. +// }else if(opKind == 1) { // operation +// assert(false); +// return res; +// }else{ // operator +// auto op = std::get<2>(this->op); +// if(isBinaryOperator(op)){ +// assert(this->params.size() == 2); +// res.append(this->params[0]->toStringForCppKind()); +// switch(op){ +// case fegen::FegenOperator::ADD:{ +// res.append() +// } +// } +// res.append(this->params[1]->toStringForCppKind()); +// }else{ + +// } +// switch(op) { +// case fegen::FegenOperator::ADD: { + +// } +// } +// } + +// class FunctionCall +inline bool isFuncParamsAllConstant( + std::vector> ¶ms) { + for (auto param : params) { + if (!param->isConstexpr()) { + return false; + } } + return true; } -std::string fegen::FegenRightValue::ExpressionNode::toStringForTypedef() { - assert(false); - std::cerr << "error type." << std::endl; - exit(0); +// TODO: invoke methods of FegenFunction +fegen::FegenRightValue::FunctionCall::FunctionCall( + fegen::FegenFunction *func, + std::vector> params) + : ExpressionNode(fegen::FegenRightValue::LiteralKind::FUNC_CALL, + fegen::FegenType::getInt32Type(), + isFuncParamsAllConstant(params)), + func(func), params(std::move(params)) {} + +std::string fegen::FegenRightValue::FunctionCall::toString() { + return "FunctionCall::toString"; } -std::string fegen::FegenRightValue::ExpressionNode::toStringForOpdef() { - assert(false); - std::cerr << "error type." << std::endl; - exit(0); +std::string fegen::FegenRightValue::FunctionCall::toStringForTypedef() { + return "FunctionCall::toStringForTypedef"; } -std::any fegen::FegenRightValue::ExpressionNode::getContent() { return this; } +std::string fegen::FegenRightValue::FunctionCall::toStringForOpdef() { + return "FunctionCall::toStringForOpdef"; +} -fegen::FegenRightValue::ExpressionNode * -fegen::FegenRightValue::ExpressionNode::binaryOperation( - fegen::FegenRightValue::Expression *lhs, - fegen::FegenRightValue::Expression *rhs, FegenOperator op) { - // TODO: infer type kind: cpp, attribute, or operand - FegenType resTy = fegen::inferenceType({lhs, rhs}, op); - return new fegen::FegenRightValue::ExpressionNode( - {lhs, rhs}, op, resTy, (lhs->isConstexpr() && rhs->isConstexpr())); +std::string fegen::FegenRightValue::FunctionCall::toStringForCppKind() { + return "FunctionCall::toStringForCppKind"; } -fegen::FegenRightValue::ExpressionNode * -fegen::FegenRightValue::ExpressionNode::unaryOperation( - fegen::FegenRightValue::Expression *v, FegenOperator op) { - // TODO: infer type kind: cpp, attribute, or operand - FegenType resTy = fegen::inferenceType({v}, op); - return new fegen::FegenRightValue::ExpressionNode({v}, op, resTy, - v->isConstexpr()); +std::any fegen::FegenRightValue::FunctionCall::getContent() { return this; } + +// class OperationCall +fegen::FegenRightValue::OperationCall::OperationCall( + fegen::FegenOperation *op, + std::vector> params) + : ExpressionNode(fegen::FegenRightValue::LiteralKind::OPERATION_CALL, + fegen::FegenType::getInt32Type(), + isFuncParamsAllConstant(params)), + op(op), params(std::move(params)) {} + +std::string fegen::FegenRightValue::OperationCall::toString() { + return "OperationCall::toString"; +} + +std::string fegen::FegenRightValue::OperationCall::toStringForTypedef() { + return "OperationCall::toStringForTypedef"; +} + +std::string fegen::FegenRightValue::OperationCall::toStringForOpdef() { + return "OperationCall::toStringForOpdef"; +} + +std::string fegen::FegenRightValue::OperationCall::toStringForCppKind() { + return "OperationCall::toStringForCppKind"; +} + +std::any fegen::FegenRightValue::OperationCall::getContent() { return this; } + +// class OperatorCall +fegen::FegenRightValue::OperatorCall::OperatorCall( + fegen::FegenOperator op, + std::vector> params) + : ExpressionNode(fegen::FegenRightValue::LiteralKind::OPERATION_CALL, + fegen::inferenceType(params, op), + isFuncParamsAllConstant(params)), + op(op), params(std::move(params)) {} + +std::string fegen::FegenRightValue::OperatorCall::toString() { + return "OperatorCall::toString"; +} + +std::string fegen::FegenRightValue::OperatorCall::toStringForTypedef() { + return "OperatorCall::toStringForTypedef"; } +std::string fegen::FegenRightValue::OperatorCall::toStringForOpdef() { + return "OperatorCall::toStringForOpdef"; +} + +std::string fegen::FegenRightValue::OperatorCall::toStringForCppKind() { + return "OperatorCall::toStringForCppKind"; +} + +std::any fegen::FegenRightValue::OperatorCall::getContent() { return this; } + // class ExpressionTerminal fegen::FegenRightValue::ExpressionTerminal::ExpressionTerminal( - primLiteralType c, fegen::FegenRightValue::LiteralKind kind, - FegenType exprTy, bool ifConstexpr) - : Expression(true, kind, exprTy, ifConstexpr), content(c) {} - -fegen::FegenRightValue::ExpressionTerminal::~ExpressionTerminal() { - if (this->kind == fegen::FegenRightValue::LiteralKind::VECTOR) { - auto &v = std::get>(this->content); - for (auto p : v) { - delete p; - } - } -} + fegen::FegenRightValue::LiteralKind kind, FegenType exprTy, + bool ifConstexpr) + : Expression(true, kind, exprTy, ifConstexpr) {} std::string fegen::FegenRightValue::ExpressionTerminal::toString() { - // TODO: toString - return "todo: fegen::FegenRightValue::ExpressionTerminal::toString"; + assert(FEGEN_NOT_IMPLEMENTED_ERROR); } std::string fegen::FegenRightValue::ExpressionTerminal::toStringForTypedef() { - assert(this->isConstexpr()); - switch (this->kind) { - case fegen::FegenRightValue::LiteralKind::TYPE: { - auto ty = std::get(this->content); - return ty.toStringForTypedef(); - } - case fegen::FegenRightValue::LiteralKind::VECTOR: { - std::string res; - res.append("["); - auto exprs = std::get>(this->content); - for (size_t i = 0; i <= exprs.size() - 1; i++) { - res.append(exprs[i]->toStringForTypedef()); - if (i != exprs.size() - 1) { - res.append(", "); - } - } - res.append("]"); - return res; - } - default: { - std::cerr << "unsupport expression" << std::endl; - exit(0); - } - } + assert(FEGEN_NOT_IMPLEMENTED_ERROR); } std::string fegen::FegenRightValue::ExpressionTerminal::toStringForOpdef() { - assert(this->isConstexpr()); - switch (this->kind) { - case fegen::FegenRightValue::LiteralKind::TYPE: { - auto ty = std::get(this->content); - return ty.toStringForOpdef(); - } - case fegen::FegenRightValue::LiteralKind::VECTOR: { - std::string res; - res.append("["); - auto exprs = std::get>(this->content); - for (size_t i = 0; i <= exprs.size() - 1; i++) { - res.append(exprs[i]->toStringForOpdef()); - if (i != exprs.size() - 1) { - res.append(", "); - } - } - res.append("]"); - return res; - } - default: { - assert(false); - std::cerr << "unsupport expression" << std::endl; - exit(0); - } - } + assert(FEGEN_NOT_IMPLEMENTED_ERROR); } -std::any fegen::FegenRightValue::ExpressionTerminal::getContent() { - switch (this->kind) { - case fegen::FegenRightValue::LiteralKind::INT: - return std::get(this->content); - case fegen::FegenRightValue::LiteralKind::FLOAT: - return std::get(this->content); - case fegen::FegenRightValue::LiteralKind::STRING: - return std::get(this->content); - case fegen::FegenRightValue::LiteralKind::TYPE: - return std::get(this->content); - case fegen::FegenRightValue::LiteralKind::VECTOR: - return std::get>(this->content); - case fegen::FegenRightValue::LiteralKind::LEFT_VAR: - return std::get(this->content); - default: - return std::monostate(); - } +std::string fegen::FegenRightValue::ExpressionTerminal::toStringForCppKind() { + assert(FEGEN_NOT_IMPLEMENTED_ERROR); +} + +// class PlaceHolder +fegen::FegenRightValue::PlaceHolder::PlaceHolder() + : ExpressionTerminal(fegen::FegenRightValue::LiteralKind::MONOSTATE, + fegen::FegenType::getPlaceHolder(), true) {} + +std::any fegen::FegenRightValue::PlaceHolder::getContent() { + return std::monostate(); +} + +std::string fegen::FegenRightValue::PlaceHolder::toString() { return ""; } + +// class IntegerLiteral +fegen::FegenRightValue::IntegerLiteral::IntegerLiteral(int content) + : ExpressionTerminal(fegen::FegenRightValue::LiteralKind::INT, + fegen::FegenType::getInt32Type(), true), + content(content) {} + +fegen::FegenRightValue::IntegerLiteral::IntegerLiteral(long long int content, + size_t size) + : ExpressionTerminal( + fegen::FegenRightValue::LiteralKind::INT, + fegen::FegenType::getIntegerType(fegen::FegenValue::get( + fegen::FegenType::getInt32Type(), "size", + fegen::FegenRightValue::getByExpr( + std::make_shared( + size)))), + true), + content(content) {} + +std::any fegen::FegenRightValue::IntegerLiteral::getContent() { + return this->content; } -fegen::FegenRightValue::ExpressionTerminal * -fegen::FegenRightValue::ExpressionTerminal::get(std::monostate content) { - return new fegen::FegenRightValue::ExpressionTerminal( - content, fegen::FegenRightValue::LiteralKind::MONOSTATE, - fegen::FegenType::getPlaceHolder(), true); +std::string fegen::FegenRightValue::IntegerLiteral::toString() { + return std::to_string(this->content); } -fegen::FegenRightValue::ExpressionTerminal * -fegen::FegenRightValue::ExpressionTerminal::get(int content) { - return new fegen::FegenRightValue::ExpressionTerminal( - content, fegen::FegenRightValue::LiteralKind::INT, - fegen::FegenType::getInt32Type(), true); +// class FloatPointLiteral +fegen::FegenRightValue::FloatPointLiteral::FloatPointLiteral( + long double content, size_t size) + : ExpressionTerminal( + fegen::FegenRightValue::LiteralKind::FLOAT, + fegen::FegenType::getFloatPointType( + fegen::FegenValue::get(fegen::FegenType::getInt32Type(), "size", + fegen::FegenRightValue::getInteger(size))), + true), + content(content) {} + +std::any fegen::FegenRightValue::FloatPointLiteral::getContent() { + return this->content; } -fegen::FegenRightValue::ExpressionTerminal * -fegen::FegenRightValue::ExpressionTerminal::get(float content) { - return new fegen::FegenRightValue::ExpressionTerminal( - content, fegen::FegenRightValue::LiteralKind::FLOAT, - fegen::FegenType::getFloatType(), true); +std::string fegen::FegenRightValue::FloatPointLiteral::toString() { + return std::to_string(this->content); } -fegen::FegenRightValue::ExpressionTerminal * -fegen::FegenRightValue::ExpressionTerminal::get(std::string content) { - return new fegen::FegenRightValue::ExpressionTerminal( - content, fegen::FegenRightValue::LiteralKind::STRING, - fegen::FegenType::getStringType(), true); +// class StringLiteral +fegen::FegenRightValue::StringLiteral::StringLiteral(std::string content) + : ExpressionTerminal(fegen::FegenRightValue::LiteralKind::STRING, + fegen::FegenType::getStringType(), true), + content(content) {} + +std::any fegen::FegenRightValue::StringLiteral::getContent() { + return this->content; } -fegen::FegenRightValue::ExpressionTerminal * -fegen::FegenRightValue::ExpressionTerminal::get(fegen::FegenType &content) { - bool ifConstexpr = true; +std::string fegen::FegenRightValue::StringLiteral::toString() { + std::string res; + res.append("\""); + res.append(this->content); + res.append("\""); + return res; +} + +// class TypeLiteral + +// Check params of content and return ture if params are all const expr. +inline bool isParamsConstant(fegen::FegenType &content) { for (auto param : content.getParameters()) { if (!param->getExpr()->isConstexpr()) { - ifConstexpr = false; - break; + return false; } } + return true; +} + +// Get type of type literal. +fegen::FegenType getTypeLiteralType(fegen::FegenType &content) { if (content.getTypeLevel() == 2) { - return new fegen::FegenRightValue::ExpressionTerminal( - content, fegen::FegenRightValue::LiteralKind::TYPE, - fegen::FegenType::getMetaTemplateType(), ifConstexpr); + return fegen::FegenType::getMetaTemplateType(); } else if (content.getTypeLevel() == 3) { - return new fegen::FegenRightValue::ExpressionTerminal( - content, fegen::FegenRightValue::LiteralKind::TYPE, - fegen::FegenType::getMetaType(), ifConstexpr); + return fegen::FegenType::getMetaType(); } else { - return new fegen::FegenRightValue::ExpressionTerminal( - content, fegen::FegenRightValue::LiteralKind::TYPE, - fegen::FegenType::getPlaceHolder(), ifConstexpr); + return fegen::FegenType::getPlaceHolder(); } } -fegen::FegenRightValue::ExpressionTerminal * -fegen::FegenRightValue::ExpressionTerminal::get( - std::vector &content) { - bool ifConstexpr = true; +fegen::FegenRightValue::TypeLiteral::TypeLiteral(fegen::FegenType &content) + : ExpressionTerminal(fegen::FegenRightValue::LiteralKind::TYPE, + getTypeLiteralType(content), + isParamsConstant(content)), + content(content) {} + +std::any fegen::FegenRightValue::TypeLiteral::getContent() { + return this->content; +} + +std::string fegen::FegenRightValue::TypeLiteral::toString() { + return this->content.getTypeName(); +} + +std::string fegen::FegenRightValue::TypeLiteral::toStringForTypedef() { + return this->content.toStringForTypedef(); +} + +std::string fegen::FegenRightValue::TypeLiteral::toStringForOpdef() { + return this->content.toStringForOpdef(); +} + +std::string fegen::FegenRightValue::TypeLiteral::toStringForCppKind() { + return this->content.toStringForCppKind(); +} + +// class ExpressionTerminal + +// Return ture if all Expressions in content are all true. +bool isExpressionListConst( + std::vector> &content) { for (auto p : content) { if (!p->isConstexpr()) { - ifConstexpr = false; + return false; break; } } - return new fegen::FegenRightValue::ExpressionTerminal( - content, fegen::FegenRightValue::LiteralKind::VECTOR, - fegen::FegenType::getListType(content[0]->exprType), ifConstexpr); + return true; } -fegen::FegenRightValue::ExpressionTerminal * -fegen::FegenRightValue::ExpressionTerminal::get(fegen::FegenValue *content) { - return new fegen::FegenRightValue::ExpressionTerminal( - content, fegen::FegenRightValue::LiteralKind::LEFT_VAR, - content->getType(), content->getExpr()->isConstexpr()); +fegen::FegenRightValue::ListLiteral::ListLiteral( + std::vector> &content) + : ExpressionTerminal(fegen::FegenRightValue::LiteralKind::VECTOR, + content[0]->exprType, isExpressionListConst(content)), + content(content) {} + +std::any fegen::FegenRightValue::ListLiteral::getContent() { + return this->content; } -// class FegenRightValue -fegen::FegenRightValue::FegenRightValue( - fegen::FegenRightValue::Expression *content) - : content(content) {} +std::string fegen::FegenRightValue::ListLiteral::toString() { + std::string res; + res.append("["); + for (size_t i = 0; i <= this->content.size() - 1; i++) { + res.append(this->content[i]->toString()); + if (i != this->content.size() - 1) { + res.append(", "); + } + } + res.append("]"); + return res; +} -fegen::FegenRightValue::FegenRightValue(const fegen::FegenRightValue &rhs) { - if (rhs.content->isTerminal()) { - auto expr = - dynamic_cast(rhs.content); - this->content = new fegen::FegenRightValue::ExpressionTerminal(*expr); - } else { - auto expr = - dynamic_cast(rhs.content); - this->content = new fegen::FegenRightValue::ExpressionNode(*expr); +std::string fegen::FegenRightValue::ListLiteral::toStringForTypedef() { + std::string res; + res.append("["); + for (size_t i = 0; i <= this->content.size() - 1; i++) { + res.append(this->content[i]->toStringForTypedef()); + if (i != this->content.size() - 1) { + res.append(", "); + } } + res.append("]"); + return res; +} + +std::string fegen::FegenRightValue::ListLiteral::toStringForOpdef() { + std::string res; + res.append("["); + for (size_t i = 0; i <= this->content.size() - 1; i++) { + res.append(this->content[i]->toStringForOpdef()); + if (i != this->content.size() - 1) { + res.append(", "); + } + } + res.append("]"); + return res; +} + +// class LeftValue +fegen::FegenRightValue::LeftValue::LeftValue(fegen::FegenValue *content) + : ExpressionTerminal(fegen::FegenRightValue::LiteralKind::LEFT_VAR, + content->getType(), content->getExpr()->isConstexpr()), + content(content) {} + +std::any fegen::FegenRightValue::LeftValue::getContent() { + return this->content; } -fegen::FegenRightValue::FegenRightValue(fegen::FegenRightValue &&rhs) { - this->content = rhs.content; - rhs.content = nullptr; +std::string fegen::FegenRightValue::LeftValue::toString() { + return this->content->getName(); } -fegen::FegenRightValue::LiteralKind fegen::FegenRightValue::getKind() { +// class FegenRightValue +fegen::FegenRightValue::FegenRightValue( + std::shared_ptr content) + : content(content) {} + +fegen::FegenRightValue::LiteralKind fegen::FegenRightValue::getLiteralKind() { return this->content->getKind(); } @@ -819,54 +1035,66 @@ std::string fegen::FegenRightValue::toStringForOpdef() { return this->content->toStringForOpdef(); } +std::string fegen::FegenRightValue::toStringForCppKind() { + return this->content->toStringForCppKind(); +} + std::any fegen::FegenRightValue::getContent() { return this->content->getContent(); } -fegen::FegenRightValue::Expression *fegen::FegenRightValue::getExpr() { +fegen::FegenType &fegen::FegenRightValue::getType() { + return this->content->getType(); +} + +std::shared_ptr +fegen::FegenRightValue::getExpr() { return this->content; } -fegen::FegenRightValue fegen::FegenRightValue::get() { +fegen::FegenRightValue fegen::FegenRightValue::getPlaceHolder() { return fegen::FegenRightValue( - fegen::FegenRightValue::ExpressionTerminal::get(std::monostate())); + fegen::FegenRightValue::Expression::getPlaceHolder()); } -fegen::FegenRightValue fegen::FegenRightValue::get(int content) { +fegen::FegenRightValue fegen::FegenRightValue::getInteger(long long int content, + size_t size) { return fegen::FegenRightValue( - fegen::FegenRightValue::ExpressionTerminal::get(content)); + fegen::FegenRightValue::Expression::getInteger(content, size)); } -fegen::FegenRightValue fegen::FegenRightValue::get(float content) { + +fegen::FegenRightValue +fegen::FegenRightValue::getFloatPoint(long double content, size_t size) { return fegen::FegenRightValue( - fegen::FegenRightValue::ExpressionTerminal::get(content)); + fegen::FegenRightValue::Expression::getFloatPoint(content, size)); } -fegen::FegenRightValue fegen::FegenRightValue::get(std::string content) { +fegen::FegenRightValue fegen::FegenRightValue::getString(std::string content) { return fegen::FegenRightValue( - fegen::FegenRightValue::ExpressionTerminal::get(content)); + fegen::FegenRightValue::Expression::getString(content)); } -fegen::FegenRightValue fegen::FegenRightValue::get(fegen::FegenType &content) { +fegen::FegenRightValue +fegen::FegenRightValue::getType(fegen::FegenType &content) { return fegen::FegenRightValue( - fegen::FegenRightValue::ExpressionTerminal::get(content)); + fegen::FegenRightValue::Expression::getType(content)); } -fegen::FegenRightValue fegen::FegenRightValue::get( - std::vector &content) { +fegen::FegenRightValue fegen::FegenRightValue::getList( + std::vector> &content) { return fegen::FegenRightValue( - fegen::FegenRightValue::ExpressionTerminal::get(content)); + fegen::FegenRightValue::Expression::getList(content)); } -fegen::FegenRightValue fegen::FegenRightValue::get(fegen::FegenValue *content) { +fegen::FegenRightValue +fegen::FegenRightValue::getLeftValue(fegen::FegenValue *content) { return fegen::FegenRightValue( - fegen::FegenRightValue::ExpressionTerminal::get(content)); + fegen::FegenRightValue::Expression::getLeftValue(content)); } -fegen::FegenRightValue -fegen::FegenRightValue::get(fegen::FegenRightValue::Expression *expr) { +fegen::FegenRightValue fegen::FegenRightValue::getByExpr( + std::shared_ptr expr) { assert(expr != nullptr); return fegen::FegenRightValue(expr); } -fegen::FegenRightValue::~FegenRightValue() { delete this->content; } - // class FegenValue fegen::FegenValue::FegenValue(fegen::FegenType type, std::string name, fegen::FegenRightValue content) @@ -891,7 +1119,7 @@ fegen::FegenType &fegen::FegenValue::getType() { return this->type; } std::string fegen::FegenValue::getName() { return this->name; } fegen::FegenRightValue::LiteralKind fegen::FegenValue::getContentKind() { - return this->content.getKind(); + return this->content.getLiteralKind(); } std::string fegen::FegenValue::getContentString() { @@ -906,7 +1134,12 @@ std::string fegen::FegenValue::getContentStringForOpdef() { return this->content.toStringForOpdef(); } -fegen::FegenRightValue::Expression *fegen::FegenValue::getExpr() { +std::string fegen::FegenValue::getContentStringForCppKind() { + return this->content.toStringForCppKind(); +} + +std::shared_ptr +fegen::FegenValue::getExpr() { return this->content.getExpr(); } @@ -986,6 +1219,8 @@ std::string getChildrenText(antlr4::tree::ParseTree *ctx) { fegen::FegenManager::FegenManager() {} +namespace fegen { + class Emitter { private: std::ostream &stream; @@ -1024,10 +1259,56 @@ class Emitter { } }; +class StmtGenerator : FegenParserBaseVisitor { +private: + FegenManager &manager; + Emitter &emitter; + +public: + StmtGenerator(Emitter &emitter) + : manager(FegenManager::getManager()), emitter(emitter) {} + std::any visitVarDeclStmt(FegenParser::VarDeclStmtContext *ctx) override { + auto var = manager.getStmtContent(ctx->identifier()); + switch (var->getType().getTypeKind()) { + case fegen::FegenType::TypeKind::CPP: { + this->emitter << var->getType().toStringForCppKind() << " " + << var->getName(); + if (ctx->expression()) { + auto expr = this->manager.getStmtContent( + ctx->expression()); + this->emitter << " = " << expr->toStringForCppKind(); + } + this->emitter << ";"; + this->emitter.newLine(); + break; + } + case fegen::FegenType::TypeKind::ATTRIBUTE: { + break; + } + case fegen::FegenType::TypeKind::OPERAND: { + break; + } + } + return nullptr; + } + + std::any visitAssignStmt(FegenParser::AssignStmtContext *ctx) override {} + + std::any visitFunctionCall(FegenParser::FunctionCallContext *ctx) override {} + + std::any visitOpInvokeStmt(FegenParser::OpInvokeStmtContext *ctx) override {} + + std::any visitIfStmt(FegenParser::IfStmtContext *ctx) override {} + + std::any visitForStmt(FegenParser::ForStmtContext *ctx) override {} +}; + +} // namespace fegen + void fegen::FegenManager::emitG4() { std::ofstream fileStream; fileStream.open(this->moduleName + ".g4"); - Emitter emitter(fileStream); + fegen::Emitter emitter(fileStream); emitter << "grammar " << this->moduleName << ";"; emitter.newLine(); for (auto node_pair : this->nodeMap) { @@ -1055,7 +1336,7 @@ void fegen::FegenManager::emitG4() { void fegen::FegenManager::emitTypeDefination() { std::ofstream fileStream; fileStream.open(this->moduleName + "Types.td"); - Emitter emitter(fileStream); + fegen::Emitter emitter(fileStream); // file head std::string mn(this->moduleName); std::transform(mn.begin(), mn.end(), mn.begin(), ::toupper); @@ -1151,7 +1432,7 @@ void fegen::FegenManager::emitTypeDefination() { void fegen::FegenManager::emitOpDefination() { std::ofstream fileStream; fileStream.open(this->moduleName + "Ops.td"); - Emitter emitter(fileStream); + fegen::Emitter emitter(fileStream); // file head std::string mn(this->moduleName); @@ -1242,7 +1523,7 @@ void fegen::FegenManager::emitOpDefination() { void fegen::FegenManager::emitDialectDefination() { std::ofstream fileStream; fileStream.open(this->moduleName + "Dialect.td"); - Emitter emitter(fileStream); + fegen::Emitter emitter(fileStream); // file head std::string mn(this->moduleName); @@ -1319,11 +1600,11 @@ void fegen::FegenManager::initbuiltinTypes() { auto intType = fegen::FegenType( fegen::FegenType::TypeKind::CPP, {fegen::FegenValue::get(fegen::FegenType::getPlaceHolder(), "size", - fegen::FegenRightValue::get())}, + fegen::FegenRightValue::getPlaceHolder())}, intTypeDefination, false); // parameters of Integer is int32(Integer<32>) - intTypeDefination->parameters.push_back( - fegen::FegenValue::get(intType, "size", fegen::FegenRightValue::get())); + intTypeDefination->parameters.push_back(fegen::FegenValue::get( + intType, "size", fegen::FegenRightValue::getPlaceHolder())); this->typeDefMap.insert({FEGEN_INTEGER, intTypeDefination}); // FloatPoint @@ -1332,7 +1613,7 @@ void fegen::FegenManager::initbuiltinTypes() { fegen::FegenTypeDefination::get( "fegen_builtin", FEGEN_FLOATPOINT, {fegen::FegenValue::get(fegen::FegenType::getInt32Type(), "size", - fegen::FegenRightValue::get())}, + fegen::FegenRightValue::getPlaceHolder())}, nullptr, false)}); // Char @@ -1351,19 +1632,19 @@ void fegen::FegenManager::initbuiltinTypes() { fegen::FegenTypeDefination::get( "fegen_builtin", FEGEN_VECTOR, {fegen::FegenValue::get(fegen::FegenType::getInt32Type(), "size", - fegen::FegenRightValue::get()), + fegen::FegenRightValue::getPlaceHolder()), fegen::FegenValue::get(fegen::FegenType::getMetaType(), "elementType", - fegen::FegenRightValue::get())}, + fegen::FegenRightValue::getPlaceHolder())}, nullptr, false)}); // List (this should be ahead of Tensor and Any Type defination) this->typeDefMap.insert( {FEGEN_LIST, fegen::FegenTypeDefination::get( "fegen_builtin", FEGEN_LIST, - {fegen::FegenValue::get(fegen::FegenType::getMetaType(), - "elementType", - fegen::FegenRightValue::get())}, + {fegen::FegenValue::get( + fegen::FegenType::getMetaType(), "elementType", + fegen::FegenRightValue::getPlaceHolder())}, nullptr, false)}); // Tensor @@ -1373,10 +1654,10 @@ void fegen::FegenManager::initbuiltinTypes() { "fegen_builtin", FEGEN_TENSOR, {fegen::FegenValue::get( fegen::FegenType::getListType(fegen::FegenType::getInt32Type()), - "shape", fegen::FegenRightValue::get()), + "shape", fegen::FegenRightValue::getPlaceHolder()), fegen::FegenValue::get(fegen::FegenType::getMetaType(), "elementType", - fegen::FegenRightValue::get())}, + fegen::FegenRightValue::getPlaceHolder())}, nullptr, false)}); // Optional @@ -1385,7 +1666,7 @@ void fegen::FegenManager::initbuiltinTypes() { "fegen_builtin", FEGEN_OPTINAL, {fegen::FegenValue::get( fegen::FegenType::getMetaType(), "elementType", - fegen::FegenRightValue::get())}, + fegen::FegenRightValue::getPlaceHolder())}, nullptr, false)}); // Any @@ -1395,7 +1676,7 @@ void fegen::FegenManager::initbuiltinTypes() { "fegen_builtin", FEGEN_ANY, {fegen::FegenValue::get( fegen::FegenType::getListType(fegen::FegenType::getMetaType()), - "elementType", fegen::FegenRightValue::get())}, + "elementType", fegen::FegenRightValue::getPlaceHolder())}, nullptr, false)}); } @@ -1442,13 +1723,14 @@ fegen::FegenManager::~FegenManager() { } } -fegen::FegenType -fegen::inferenceType(std::vector operands, - fegen::FegenOperator op) { +fegen::FegenType fegen::inferenceType( + std::vector> operands, + fegen::FegenOperator op) { // TODO: infer type return fegen::FegenType::getInt32Type(); } -namespace fegen{ + +namespace fegen { // class StmtVisitor : public FegenParserBaseVisitor{ // public: diff --git a/frontend/FrontendGen/lib/FegenVisitor.cpp b/frontend/FrontendGen/lib/FegenVisitor.cpp index 4cb330785a..316b0ae1ad 100644 --- a/frontend/FrontendGen/lib/FegenVisitor.cpp +++ b/frontend/FrontendGen/lib/FegenVisitor.cpp @@ -6,6 +6,7 @@ bool fegen::checkParams(std::vector &expected, } bool fegen::checkListLiteral( - std::vector listLiteral) { + std::vector> + &listLiteral) { return true; } \ No newline at end of file From ed4fd9d79e9a903f0184e6efde492cc3b7525877 Mon Sep 17 00:00:00 2001 From: FloatingcloudKnight <1348185166@qq.com> Date: Thu, 11 Jul 2024 13:04:50 +0000 Subject: [PATCH 03/17] update fegen --- frontend/FrontendGen/include/FegenVisitor.h | 281 +++++++++++--------- frontend/FrontendGen/lib/FegenManager.cpp | 9 +- frontend/FrontendGen/lib/FegenParser.g4 | 2 +- 3 files changed, 160 insertions(+), 132 deletions(-) diff --git a/frontend/FrontendGen/include/FegenVisitor.h b/frontend/FrontendGen/include/FegenVisitor.h index 82e384dc1a..454fc5ba0e 100644 --- a/frontend/FrontendGen/include/FegenVisitor.h +++ b/frontend/FrontendGen/include/FegenVisitor.h @@ -601,147 +601,174 @@ class FegenVisitor : public FegenParserBaseVisitor { return nullptr; } - std::any visitFunctionDecl(FegenParser::FunctionDeclContext *ctx) override{ - sstack.pushScope(); - auto returnType = std::any_cast(this->visit(ctx->typeSpec())); - auto functionName = std::any_cast(this->visit(ctx->funcName())); - auto hasfunc = manager.functionMap.find(functionName); - if(hasfunc != manager.functionMap.end()){ - std::cerr << "The function name \" " << functionName - << "\" has already been used. Please use another name." << std::endl; - exit(0); - return nullptr; - } - auto functionParams = std::any_cast>(this->visit(ctx->funcParams())); - this->visit(ctx->statementBlock()); + std::any visitFunctionDecl(FegenParser::FunctionDeclContext *ctx) override { + sstack.pushScope(); + auto returnType = + std::any_cast(this->visit(ctx->typeSpec())); + auto functionName = + std::any_cast(this->visit(ctx->funcName())); + auto hasfunc = manager.functionMap.find(functionName); + if (hasfunc != manager.functionMap.end()) { + std::cerr << "The function name \" " << functionName + << "\" has already been used. Please use another name." + << std::endl; + exit(0); + return nullptr; + } + auto functionParams = std::any_cast>( + this->visit(ctx->funcParams())); + this->visit(ctx->statementBlock()); - fegen::FegenFunction* function = fegen::FegenFunction::get(functionName, functionParams, &returnType); - manager.functionMap.insert(std::pair{functionName, function}); - sstack.popScope(); - return nullptr; + fegen::FegenFunction *function = + fegen::FegenFunction::get(functionName, functionParams, &returnType); + manager.functionMap.insert(std::pair{functionName, function}); + sstack.popScope(); + return nullptr; + } + + std::any visitFuncName(FegenParser::FuncNameContext *ctx) override { + auto functionName = ctx->identifier()->getText(); + return functionName; + } + + std::any visitFuncParams(FegenParser::FuncParamsContext *ctx) override { + std::vector paramsList = {}; + + for (size_t i = 0; i < ctx->typeSpec().size(); i++) { + auto paramType = + std::any_cast(this->visit(ctx->typeSpec(i))); + auto paramName = ctx->identifier(i)->getText(); + auto param = fegen::FegenValue::get(paramType, paramName, nullptr); + paramsList.push_back(param); + sstack.attemptAddVar(param); } + return paramsList; + } - std::any visitFuncName(FegenParser::FuncNameContext *ctx) override{ - auto functionName = ctx->identifier()->getText(); - return functionName; + std::any visitVarDeclStmt(FegenParser::VarDeclStmtContext *ctx) override { + auto varType = + std::any_cast(this->visit(ctx->typeSpec())); + auto varName = ctx->identifier()->getText(); + fegen::FegenValue *var; + if (ctx->expression()) { + auto varContent = std::any_cast( + this->visit(ctx->expression())); + if (!fegen::FegenType::isSameType(&varType, &varContent->exprType)) { + std::cerr << "The variabel \"" << varName << "\" need \"" + << varType.getTypeName() << " \" type rightvalue. But now is " << varContent->exprType.getTypeName() + << std::endl; + exit(0); + return nullptr; + } + var = fegen::FegenValue::get(varType, varName, varContent); + } else { + var = fegen::FegenValue::get(varType, varName, nullptr); } + sstack.attemptAddVar(var); + manager.stmtContentMap.insert(std::pair{ctx, var}); + return var; + } - std::any visitFuncParams(FegenParser::FuncParamsContext *ctx) override{ - std::vector paramsList = {}; - - for(size_t i = 0; i < ctx->typeSpec().size(); i++){ - auto paramType = std::any_cast(this->visit(ctx->typeSpec(i))); - auto paramName = ctx->identifier(i)->getText(); - auto param = fegen::FegenValue::get(paramType, paramName, nullptr); - paramsList.push_back(param); - sstack.attemptAddVar(param); - } - return paramsList; - } - - std::any visitVarDeclStmt(FegenParser::VarDeclStmtContext *ctx) override{ - auto varType = std::any_cast(this->visit(ctx->typeSpec())); - auto varName = ctx->identifier()->getText(); - fegen::FegenValue* var; - if(ctx->expression()){ - auto varcontent = std::any_cast(this->visit(ctx->expression())); - // TODO: check error - // if(!fegen::FegenType::isSameType(&varType, &varcontent->exprType)){ - // std::cerr << "The variabel \" " << varName - // << "\" need \"" << varType.getTypeName() << " \" type rightvalue." << std::endl; - // exit(0); - // return nullptr; - // } - var = fegen::FegenValue::get(varType, varName, varcontent); - } else { - var = fegen::FegenValue::get(varType, varName, nullptr); - } - sstack.attemptAddVar(var); - manager.stmtContentMap.insert(std::pair{ctx, var}); - return var; - } - - std::any visitAssignStmt(FegenParser::AssignStmtContext *ctx) override{ - auto varName = ctx->identifier()->getText(); - auto varcontent = std::any_cast(this->visit(ctx->expression())); - auto var = sstack.attemptFindVar(varName); - if(!fegen::FegenType::isSameType(&var->getType(), &varcontent->exprType)){ - std::cerr << "The variabel \" " << varName - << "\" need \"" << var->getType().getTypeName() << " \" type rightvalue." << std::endl; - exit(0); - return nullptr; - } - fegen::FegenValue * stmt = fegen::FegenValue::get(var->getType(), varName, varcontent); - manager.stmtContentMap.insert(std::pair{ctx, stmt}); - - return stmt; - } - - std::any visitFunctionCall(FegenParser::FunctionCallContext *ctx) override{ - std::vector parasList = {}; - auto functionName =std::any_cast(this->visit(ctx->funcName())); - auto hasFunc = manager.functionMap.at(functionName); - auto paramsNum = ctx->expression().size(); - auto paraList = hasFunc->getInputTypeList(); - if( paramsNum> 0){ - for(size_t i = 0; i < paramsNum; i++){ - auto oprand = std::any_cast(this->visit(ctx->expression(i))); - parasList.push_back(oprand); - } - size_t len1 = paraList.size(); - size_t len2 = parasList.size(); - if(len1 != len2){ - std::cerr << "The function \" " << functionName - << "\" parameter count mismatch." << std::endl; - exit(0); - return nullptr; - } - for(size_t i = 0; i < len1; i++){ - if(!fegen::FegenType::isSameType(¶List[i]->getType(), ¶sList[i]->exprType)){ - std::cerr << "The function \" " << functionName - << "\" parameter" << i << " type mismatch." << std::endl; - exit(0); - return nullptr; - } - } - } - auto returnType = hasFunc->getReturnType(); - fegen::FegenFunction *funcCall = fegen::FegenFunction::get(functionName, paraList, returnType); - manager.stmtContentMap.insert(std::pair{ctx, funcCall}); - return returnType; + std::any visitAssignStmt(FegenParser::AssignStmtContext *ctx) override { + auto varName = ctx->identifier()->getText(); + auto varcontent = std::any_cast( + this->visit(ctx->expression())); + auto var = sstack.attemptFindVar(varName); + if (!fegen::FegenType::isSameType(&var->getType(), &varcontent->exprType)) { + std::cerr << "The variabel \" " << varName << "\" need \"" + << var->getType().getTypeName() << " \" type rightvalue." + << std::endl; + exit(0); + return nullptr; } + fegen::FegenValue *stmt = + fegen::FegenValue::get(var->getType(), varName, varcontent); + manager.stmtContentMap.insert(std::pair{ctx, stmt}); - std::any visitOpInvokeStmt(FegenParser::OpInvokeStmtContext *ctx) override{ + return stmt; + } + + std::any visitFunctionCall(FegenParser::FunctionCallContext *ctx) override { + std::vector parasList = {}; + fegen::FegenFunction *function; + auto functionName = + std::any_cast(this->visit(ctx->funcName())); + auto hasFunc = manager.functionMap.find(functionName); + if(hasFunc == manager.functionMap.end()){ + std::cerr << "The called function \"" << functionName + << "\" is not exist." << std::endl; + exit(0); return nullptr; } - - std::any visitIfStmt(FegenParser::IfStmtContext *ctx) override{ - sstack.pushScope(); - this->visit(ctx->expression(0)); - this->visit(ctx->statementBlock(0)); - if(ctx->expression().size() > 1){ - for(size_t i = 1; i < ctx->expression().size(); i++){ - this->visit(ctx->expression(i)); - this->visit(ctx->statementBlock(i)); - } + function = hasFunc->second; + auto paramsNum = ctx->expression().size(); + auto paraList = function->getInputTypeList(); + if (paramsNum > 0) { + for (size_t i = 0; i < paramsNum; i++) { + auto oprand = std::any_cast( + this->visit(ctx->expression(i))); + parasList.push_back(oprand); + } + size_t len1 = paraList.size(); + size_t len2 = parasList.size(); + if (len1 != len2) { + std::cerr << "The function \"" << functionName + << "\" parameter count mismatch." << std::endl; + exit(0); + return nullptr; + } + for (size_t i = 0; i < len1; i++) { + if (!fegen::FegenType::isSameType(¶List[i]->getType(), + ¶sList[i]->exprType)) { + std::cerr << "The function \"" << functionName << "\" parameter" << i + << " type mismatch." << std::endl; + exit(0); + return nullptr; } - if(ctx->statementBlock(ctx->expression().size()+1)) - this->visit(ctx->statementBlock(ctx->expression().size()+1)); - sstack.popScope(); + } + } + auto returnType = function->getReturnType(); + fegen::FegenFunction *funcCall = + fegen::FegenFunction::get(functionName, paraList, returnType); + manager.stmtContentMap.insert(std::pair{ctx, funcCall}); + return returnType; + } - return nullptr; - } + std::any visitOpInvokeStmt(FegenParser::OpInvokeStmtContext *ctx) override { + return nullptr; + } - std::any visitForStmt(FegenParser::ForStmtContext *ctx) override{ - sstack.pushScope(); - this->visit(ctx->assignStmt(0)); - this->visit(ctx->expression()); - this->visit(ctx->assignStmt(1)); - this->visit(ctx->statementBlock()); - sstack.popScope(); + std::any visitIfStmt(FegenParser::IfStmtContext *ctx) override { + sstack.pushScope(); + this->visit(ctx->expression(0)); + this->visit(ctx->statementBlock(0)); + for (size_t i = 1; i <= ctx->expression().size() - 1; i++) { + this->visit(ctx->expression(i)); + this->visit(ctx->statementBlock(i)); + } + if (ctx->statementBlock(ctx->expression().size() + 1)) + this->visit(ctx->statementBlock(ctx->expression().size() + 1)); + sstack.popScope(); - return nullptr; + return nullptr; + } + + std::any visitForStmt(FegenParser::ForStmtContext *ctx) override { + sstack.pushScope(); + if (ctx->varDeclStmt()) { + this->visit(ctx->varDeclStmt()); + this->visit(ctx->expression()); + this->visit(ctx->assignStmt(0)); + } else { + this->visit(ctx->assignStmt(0)); + this->visit(ctx->expression()); + this->visit(ctx->assignStmt(1)); } + this->visit(ctx->statementBlock()); + sstack.popScope(); + + return nullptr; + } std::any visitOpDecl(FegenParser::OpDeclContext *ctx) override { auto opName = ctx->opName()->getText(); diff --git a/frontend/FrontendGen/lib/FegenManager.cpp b/frontend/FrontendGen/lib/FegenManager.cpp index 8a0fe23a6c..29bc9bf993 100644 --- a/frontend/FrontendGen/lib/FegenManager.cpp +++ b/frontend/FrontendGen/lib/FegenManager.cpp @@ -1,5 +1,5 @@ -#include "FegenParserBaseVisitor.h" #include "FegenManager.h" +#include "FegenParserBaseVisitor.h" #include "Scope.h" #include #include @@ -20,7 +20,7 @@ fegen::FegenFunction::get(std::string name, FegenType *returnType) { return new fegen::FegenFunction(name, std::move(inputTypeList), returnType); } -std::string fegen::FegenFunction::getName() { this->name; } +std::string fegen::FegenFunction::getName() { return this->name; } std::vector &fegen::FegenFunction::getInputTypeList() { return this->inputTypeList; @@ -778,7 +778,8 @@ fegen::FegenRightValue::ExpressionTerminal * fegen::FegenRightValue::ExpressionTerminal::get(fegen::FegenValue *content) { return new fegen::FegenRightValue::ExpressionTerminal( content, fegen::FegenRightValue::LiteralKind::LEFT_VAR, - content->getType(), content->getExpr()->isConstexpr()); + content->getType(), + content->getExpr()->isConstexpr()); } // class FegenRightValue @@ -1448,7 +1449,7 @@ fegen::inferenceType(std::vector operands, // TODO: infer type return fegen::FegenType::getInt32Type(); } -namespace fegen{ +namespace fegen { // class StmtVisitor : public FegenParserBaseVisitor{ // public: diff --git a/frontend/FrontendGen/lib/FegenParser.g4 b/frontend/FrontendGen/lib/FegenParser.g4 index 23feea1b2a..048d104a0e 100644 --- a/frontend/FrontendGen/lib/FegenParser.g4 +++ b/frontend/FrontendGen/lib/FegenParser.g4 @@ -293,7 +293,7 @@ ifStmt ; forStmt - : FOR LeftParen assignStmt Semi expression Semi assignStmt RightParen statementBlock + : FOR LeftParen (assignStmt | varDeclStmt) Semi expression Semi assignStmt RightParen statementBlock ; // expression From 4a6cf8ce5bc46620745576cff22f14b5e5274c4b Mon Sep 17 00:00:00 2001 From: chh Date: Mon, 15 Jul 2024 20:18:07 +0800 Subject: [PATCH 04/17] [FrontendGen] Fix type error in visitor. --- frontend/FrontendGen/include/FegenVisitor.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/frontend/FrontendGen/include/FegenVisitor.h b/frontend/FrontendGen/include/FegenVisitor.h index 842db7203a..61d467d65e 100644 --- a/frontend/FrontendGen/include/FegenVisitor.h +++ b/frontend/FrontendGen/include/FegenVisitor.h @@ -664,7 +664,7 @@ class FegenVisitor : public FegenParserBaseVisitor { auto paramType = std::any_cast(this->visit(ctx->typeSpec(i))); auto paramName = ctx->identifier(i)->getText(); - auto param = fegen::FegenValue::get(paramType, paramName, nullptr); + auto param = fegen::FegenValue::get(paramType, paramName, fegen::FegenRightValue::getPlaceHolder()); paramsList.push_back(param); sstack.attemptAddVar(param); } @@ -677,7 +677,7 @@ class FegenVisitor : public FegenParserBaseVisitor { auto varName = ctx->identifier()->getText(); fegen::FegenValue *var; if (ctx->expression()) { - auto varcontent = std::any_cast( + auto varcontent = std::any_cast>( this->visit(ctx->expression())); // TODO: check error // if(!fegen::FegenType::isSameType(&varType, &varcontent->exprType)){ @@ -685,9 +685,9 @@ class FegenVisitor : public FegenParserBaseVisitor { // << "\" need \"" << varType.getTypeName() << " \" type rightvalue." // << std::endl; exit(0); return nullptr; // } - var = fegen::FegenValue::get(varType, varName, varcontent); + var = fegen::FegenValue::get(varType, varName, fegen::FegenRightValue::getByExpr(varcontent)); } else { - var = fegen::FegenValue::get(varType, varName, nullptr); + var = fegen::FegenValue::get(varType, varName, fegen::FegenRightValue::getPlaceHolder()); } sstack.attemptAddVar(var); manager.stmtContentMap.insert(std::pair{ctx, var}); @@ -696,7 +696,7 @@ class FegenVisitor : public FegenParserBaseVisitor { std::any visitAssignStmt(FegenParser::AssignStmtContext *ctx) override { auto varName = ctx->identifier()->getText(); - auto varcontent = std::any_cast( + auto varcontent = std::any_cast>( this->visit(ctx->expression())); auto var = sstack.attemptFindVar(varName); if (!fegen::FegenType::isSameType(&var->getType(), &varcontent->exprType)) { @@ -707,7 +707,7 @@ class FegenVisitor : public FegenParserBaseVisitor { return nullptr; } fegen::FegenValue *stmt = - fegen::FegenValue::get(var->getType(), varName, varcontent); + fegen::FegenValue::get(var->getType(), varName, fegen::FegenRightValue::getByExpr(varcontent)); manager.stmtContentMap.insert(std::pair{ctx, stmt}); return stmt; From 2c78e0aa2ca835eb0d3634e5b346fd893774499f Mon Sep 17 00:00:00 2001 From: chh Date: Mon, 15 Jul 2024 21:10:00 +0800 Subject: [PATCH 05/17] [FrontendGen] Rename: remove 'Fegen' prefix. --- frontend/FrontendGen/include/FegenManager.h | 324 +++---- frontend/FrontendGen/include/FegenVisitor.h | 293 +++---- frontend/FrontendGen/include/Scope.h | 20 +- frontend/FrontendGen/lib/FegenManager.cpp | 912 ++++++++++---------- frontend/FrontendGen/lib/FegenVisitor.cpp | 6 +- frontend/FrontendGen/lib/Scope.cpp | 16 +- 6 files changed, 791 insertions(+), 780 deletions(-) diff --git a/frontend/FrontendGen/include/FegenManager.h b/frontend/FrontendGen/include/FegenManager.h index 505cf7a81a..7400e96cc8 100644 --- a/frontend/FrontendGen/include/FegenManager.h +++ b/frontend/FrontendGen/include/FegenManager.h @@ -33,9 +33,9 @@ namespace fegen { -class FegenType; -class FegenManager; -class FegenValue; +class Type; +class Manager; +class Value; // binary operation @@ -59,65 +59,65 @@ enum class FegenOperator { }; // user defined function -class FegenFunction { +class Function { private: // cpp function name std::string name; // input object - std::vector inputTypeList; + std::vector inputTypeList; // return type - FegenType *returnType; - explicit FegenFunction(std::string name, - std::vector &&inputTypeList, - FegenType *returnType); + Type *returnType; + explicit Function(std::string name, + std::vector &&inputTypeList, + Type *returnType); public: - static FegenFunction *get(std::string name, - std::vector inputTypeList, - FegenType *returnType = nullptr); - ~FegenFunction() = default; + static Function *get(std::string name, + std::vector inputTypeList, + Type *returnType = nullptr); + ~Function() = default; std::string getName(); - std::vector &getInputTypeList(); - FegenValue *getInputTypeList(size_t i); - FegenType *getReturnType(); + std::vector &getInputTypeList(); + Value *getInputTypeList(size_t i); + Type *getReturnType(); }; -class FegenValue; +class Value; // user defined operation -class FegenOperation { +class Operation { private: std::string dialectName; std::string operationName; // arguments of operation - std::vector arguments; + std::vector arguments; // results of operation - std::vector results; + std::vector results; // operation body context FegenParser::BodySpecContext *ctx; - explicit FegenOperation(std::string dialectName, std::string operationName, - std::vector &&arguments, - std::vector &&results, + explicit Operation(std::string dialectName, std::string operationName, + std::vector &&arguments, + std::vector &&results, FegenParser::BodySpecContext *ctx); public: void setOpName(std::string); std::string getOpName(); - std::vector &getArguments(); - FegenValue *getArguments(size_t i); - std::vector &getResults(); - FegenValue *getResults(size_t i); - static FegenOperation *get(std::string operationName, - std::vector arguments, - std::vector results, + std::vector &getArguments(); + Value *getArguments(size_t i); + std::vector &getResults(); + Value *getResults(size_t i); + static Operation *get(std::string operationName, + std::vector arguments, + std::vector results, FegenParser::BodySpecContext *ctx); - ~FegenOperation() = default; + ~Operation() = default; }; -class FegenTypeDefination; +class TypeDefination; -class FegenType { - friend class FegenValue; +class Type { + friend class Value; public: enum class TypeKind { ATTRIBUTE, OPERAND, CPP }; @@ -125,25 +125,25 @@ class FegenType { private: TypeKind kind; std::string typeName; - std::vector parameters; - FegenTypeDefination *typeDefine; + std::vector parameters; + TypeDefination *typeDefine; int typeLevel; public: - FegenType(TypeKind kind, std::string name, - std::vector parameters, FegenTypeDefination *tyDef, + Type(TypeKind kind, std::string name, + std::vector parameters, TypeDefination *tyDef, int typeLevel); - FegenType(TypeKind kind, std::vector parameters, - FegenTypeDefination *tyDef, int typeLevel); - FegenType(const FegenType &); - FegenType(FegenType &&); + Type(TypeKind kind, std::vector parameters, + TypeDefination *tyDef, int typeLevel); + Type(const Type &); + Type(Type &&); TypeKind getTypeKind(); void setTypeKind(TypeKind kind); - std::vector &getParameters(); - FegenValue *getParameters(size_t i); - void setParameters(std::vector ¶ms); - FegenTypeDefination *getTypeDefination(); - void setTypeDefination(FegenTypeDefination *tyDef); + std::vector &getParameters(); + Value *getParameters(size_t i); + void setParameters(std::vector ¶ms); + TypeDefination *getTypeDefination(); + void setTypeDefination(TypeDefination *tyDef); std::string getTypeName(); int getTypeLevel(); // for generating typedef td file. @@ -152,82 +152,82 @@ class FegenType { std::string toStringForOpdef(); // for generating cpp type kind. std::string toStringForCppKind(); - static bool isSameType(FegenType *type1, FegenType *type2); - ~FegenType(); + static bool isSameType(Type *type1, Type *type2); + ~Type(); // placeholder - static FegenType getPlaceHolder(); + static Type getPlaceHolder(); // Type - static FegenType getMetaType(); + static Type getMetaType(); // TypeTemplate - static FegenType getMetaTemplateType(); + static Type getMetaTemplateType(); // int - static FegenType getInt32Type(); + static Type getInt32Type(); // float - static FegenType getFloatType(); + static Type getFloatType(); // float - static FegenType getDoubleType(); + static Type getDoubleType(); // bool - static FegenType getBoolType(); + static Type getBoolType(); // Integer - static FegenType getIntegerType(FegenValue *size); + static Type getIntegerType(Value *size); // FloatPoint - static FegenType getFloatPointType(FegenValue *size); + static Type getFloatPointType(Value *size); // char - static FegenType getCharType(); + static Type getCharType(); // string - static FegenType getStringType(); + static Type getStringType(); // Vector - static FegenType getVectorType(FegenValue *size, FegenType elementType); + static Type getVectorType(Value *size, Type elementType); // Tensor - static FegenType getTensorType(FegenValue *shape, FegenType elementType); + static Type getTensorType(Value *shape, Type elementType); // List - static FegenType getListType(FegenType elementType); + static Type getListType(Type elementType); // Optional - static FegenType getOptionalType(FegenType elementType); + static Type getOptionalType(Type elementType); // Any - static FegenType getAnyType(std::vector elementTypes); + static Type getAnyType(std::vector elementTypes); - static FegenType getIntegerTemplate(); - static FegenType getFloatPointTemplate(); + static Type getIntegerTemplate(); + static Type getFloatPointTemplate(); - static FegenType getInstanceType(FegenTypeDefination *typeDefination, - std::vector parameters); + static Type getInstanceType(TypeDefination *typeDefination, + std::vector parameters); - static FegenType getTemplateType(FegenTypeDefination *typeDefination); + static Type getTemplateType(TypeDefination *typeDefination); }; -class FegenTypeDefination { - friend class FegenManager; +class TypeDefination { + friend class Manager; private: std::string dialectName; std::string name; - std::vector parameters; + std::vector parameters; FegenParser::TypeDefinationDeclContext *ctx; bool ifCustome; std::string mnemonic; public: - FegenTypeDefination(std::string dialectName, std::string name, - std::vector parameters, + TypeDefination(std::string dialectName, std::string name, + std::vector parameters, FegenParser::TypeDefinationDeclContext *ctx, bool ifCustome); - static FegenTypeDefination *get(std::string dialectName, std::string name, - std::vector parameters, + static TypeDefination *get(std::string dialectName, std::string name, + std::vector parameters, FegenParser::TypeDefinationDeclContext *ctx, bool ifCustome = true); std::string getDialectName(); @@ -235,16 +235,16 @@ class FegenTypeDefination { std::string getName(); std::string getMnemonic(); void setName(std::string); - const std::vector &getParameters(); + const std::vector &getParameters(); FegenParser::TypeDefinationDeclContext *getCtx(); void setCtx(FegenParser::TypeDefinationDeclContext *); bool isCustome(); }; /// @brief Represent right value, and pass by value. -class FegenRightValue { - friend class FegenType; - friend class FegenValue; +class RightValue { + friend class Type; + friend class Value; public: enum class LiteralKind { @@ -267,10 +267,10 @@ class FegenRightValue { struct Expression { bool ifTerminal; LiteralKind kind; - FegenType exprType; + Type exprType; bool isLiteral; bool ifConstexpr; - Expression(bool, LiteralKind, FegenType &, bool); + Expression(bool, LiteralKind, Type &, bool); virtual ~Expression() = default; virtual bool isTerminal(); virtual std::string toString() = 0; @@ -278,7 +278,7 @@ class FegenRightValue { virtual std::string toStringForOpdef() = 0; virtual std::string toStringForCppKind() = 0; LiteralKind getKind(); - FegenType &getType(); + Type &getType(); virtual std::any getContent() = 0; virtual bool isConstexpr(); @@ -292,11 +292,11 @@ class FegenRightValue { // TODO: callFunction static std::shared_ptr - callFunction(std::vector>, FegenFunction *); + callFunction(std::vector>, Function *); // TODO: callOperation static std::shared_ptr - callOperation(std::vector>, FegenOperation *); + callOperation(std::vector>, Operation *); static std::shared_ptr getPlaceHolder(); static std::shared_ptr getInteger(long long int, @@ -304,15 +304,15 @@ class FegenRightValue { static std::shared_ptr getFloatPoint(long double, size_t size = 32); static std::shared_ptr getString(std::string); - static std::shared_ptr getType(FegenType &); + static std::shared_ptr getType(Type &); static std::shared_ptr getList(std::vector> &); static std::shared_ptr - getLeftValue(fegen::FegenValue *); + getLeftValue(fegen::Value *); }; struct ExpressionNode : public Expression { - ExpressionNode(LiteralKind, FegenType, bool); + ExpressionNode(LiteralKind, Type, bool); virtual std::string toString() override; virtual std::string toStringForTypedef() override; virtual std::string toStringForOpdef() override; @@ -321,9 +321,9 @@ class FegenRightValue { }; struct FunctionCall : public ExpressionNode { - FegenFunction *func; + Function *func; std::vector> params; - FunctionCall(FegenFunction *, std::vector>); + FunctionCall(Function *, std::vector>); virtual std::string toString() override; virtual std::string toStringForTypedef() override; virtual std::string toStringForOpdef() override; @@ -332,9 +332,9 @@ class FegenRightValue { }; struct OperationCall : public ExpressionNode { - FegenOperation *op; + Operation *op; std::vector> params; - OperationCall(FegenOperation *, std::vector>); + OperationCall(Operation *, std::vector>); virtual std::string toString() override; virtual std::string toStringForTypedef() override; virtual std::string toStringForOpdef() override; @@ -354,7 +354,7 @@ class FegenRightValue { }; struct ExpressionTerminal : public Expression { - ExpressionTerminal(LiteralKind, FegenType, bool); + ExpressionTerminal(LiteralKind, Type, bool); virtual std::string toString() override; virtual std::string toStringForTypedef() override; virtual std::string toStringForOpdef() override; @@ -394,8 +394,8 @@ class FegenRightValue { }; struct TypeLiteral : public ExpressionTerminal { - FegenType content; - TypeLiteral(FegenType &content); + Type content; + TypeLiteral(Type &content); virtual std::any getContent() override; virtual std::string toString() override; virtual std::string toStringForTypedef() override; @@ -413,143 +413,145 @@ class FegenRightValue { }; struct LeftValue : public ExpressionTerminal { - FegenValue *content; - LeftValue(FegenValue *content); + Value *content; + LeftValue(Value *content); virtual std::any getContent() override; virtual std::string toString() override; }; public: - FegenRightValue(std::shared_ptr); - FegenRightValue(const FegenRightValue &) = default; - FegenRightValue(FegenRightValue &&) = default; - FegenRightValue::LiteralKind getLiteralKind(); + RightValue(std::shared_ptr); + RightValue(const RightValue &) = default; + RightValue(RightValue &&) = default; + RightValue &operator=(const RightValue &another) = default; + RightValue::LiteralKind getLiteralKind(); std::string toString(); std::string toStringForTypedef(); std::string toStringForOpdef(); std::string toStringForCppKind(); std::any getContent(); - FegenType &getType(); + Type &getType(); std::shared_ptr getExpr(); - static FegenRightValue getPlaceHolder(); - static FegenRightValue getInteger(long long int content, size_t size = 32); - static FegenRightValue getFloatPoint(long double content, size_t size = 32); - static FegenRightValue getString(std::string content); - static FegenRightValue getType(FegenType &content); - static FegenRightValue + static RightValue getPlaceHolder(); + static RightValue getInteger(long long int content, size_t size = 32); + static RightValue getFloatPoint(long double content, size_t size = 32); + static RightValue getString(std::string content); + static RightValue getType(Type &content); + static RightValue getList(std::vector> &content); - static FegenRightValue getLeftValue(fegen::FegenValue *content); - static FegenRightValue getByExpr(std::shared_ptr expr); - ~FegenRightValue() = default; + static RightValue getLeftValue(fegen::Value *content); + static RightValue getByExpr(std::shared_ptr expr); + ~RightValue() = default; private: std::shared_ptr content; }; -class FegenValue { - friend class FegenType; +class Value { + friend class Type; private: - FegenType type; + Type type; std::string name; - FegenRightValue content; + RightValue content; public: - FegenValue(FegenType type, std::string name, FegenRightValue content); - FegenValue(const FegenValue &rhs); - FegenValue(FegenValue &&rhs); + Value(Type type, std::string name, RightValue content); + Value(const Value &rhs); + Value(Value &&rhs); - static FegenValue *get(FegenType type, std::string name, - FegenRightValue constant); + static Value *get(Type type, std::string name, + RightValue constant); std::string getName(); - FegenType &getType(); + Type &getType(); /// @brief return content of right value, get ExprssionNode* if kind is /// EXPRESSION. template T getContent() { return std::any_cast(this->content.getContent()); } - FegenRightValue::LiteralKind getContentKind(); + void setContent(fegen::RightValue content); + RightValue::LiteralKind getContentKind(); std::string getContentString(); std::string getContentStringForTypedef(); std::string getContentStringForOpdef(); std::string getContentStringForCppKind(); - std::shared_ptr getExpr(); - ~FegenValue() = default; + std::shared_ptr getExpr(); + ~Value() = default; }; -class FegenNode; +class ParserNode; -class FegenRule { - friend class FegenManager; +class ParserRule { + friend class Manager; private: std::string content; // from which node - FegenNode *src; - std::map inputs; - std::map returns; + ParserNode *src; + std::map inputs; + std::map returns; // context in parser tree antlr4::ParserRuleContext *ctx; - explicit FegenRule(std::string content, FegenNode *src, + explicit ParserRule(std::string content, ParserNode *src, antlr4::ParserRuleContext *ctx); public: - static FegenRule *get(std::string content, FegenNode *src, + static ParserRule *get(std::string content, ParserNode *src, antlr4::ParserRuleContext *ctx); llvm::StringRef getContent(); // check and add input value - bool addInput(FegenValue input); + bool addInput(Value input); // check and add return value - bool addReturn(FegenValue output); + bool addReturn(Value output); // set source node - void setSrc(FegenNode *src); + void setSrc(ParserNode *src); }; -class FegenNode { - friend class FegenManager; +class ParserNode { + friend class Manager; public: enum class NodeType { PARSER_RULE, LEXER_RULE }; private: - std::vector rules; + std::vector rules; antlr4::ParserRuleContext *ctx; NodeType ntype; - explicit FegenNode(std::vector &&rules, + explicit ParserNode(std::vector &&rules, antlr4::ParserRuleContext *ctx, NodeType ntype); public: - static FegenNode *get(std::vector rules, + static ParserNode *get(std::vector rules, antlr4::ParserRuleContext *ctx, NodeType ntype); - static FegenNode *get(antlr4::ParserRuleContext *ctx, NodeType ntype); - void addFegenRule(FegenRule *rule); + static ParserNode *get(antlr4::ParserRuleContext *ctx, NodeType ntype); + void addFegenRule(ParserRule *rule); // release rules first - ~FegenNode(); + ~ParserNode(); }; class FegenVisitor; -class FegenManager { +class Manager { friend class FegenVisitor; private: - FegenManager(); - FegenManager(const FegenManager &) = delete; - const FegenManager &operator=(const FegenManager &) = delete; + Manager(); + Manager(const Manager &) = delete; + const Manager &operator=(const Manager &) = delete; // release nodes, type, operation, function - ~FegenManager(); + ~Manager(); void initbuiltinTypes(); public: std::string moduleName; std::vector headFiles; - std::map nodeMap; - llvm::StringMap typeMap; - std::map typeDefMap; - std::map operationMap; - std::map functionMap; + std::map nodeMap; + llvm::StringMap typeMap; + std::map typeDefMap; + std::map operationMap; + std::map functionMap; // stmt contents std::unordered_map stmtContentMap; void addStmtContent(antlr4::ParserRuleContext *ctx, std::any content); @@ -558,14 +560,14 @@ class FegenManager { return std::any_cast(this->stmtContentMap[ctx]); } - static FegenManager &getManager(); + static Manager &getManager(); void setModuleName(std::string name); - FegenTypeDefination *getTypeDefination(std::string name); - bool addTypeDefination(FegenTypeDefination *tyDef); + TypeDefination *getTypeDefination(std::string name); + bool addTypeDefination(TypeDefination *tyDef); - FegenOperation *getOperationDefination(std::string name); - bool addOperationDefination(FegenOperation *opDef); + Operation *getOperationDefination(std::string name); + bool addOperationDefination(Operation *opDef); void emitG4(); void emitTypeDefination(); void emitOpDefination(); @@ -574,8 +576,8 @@ class FegenManager { void emitBuiltinFunction(); }; -FegenType - inferenceType(std::vector>, +Type + inferenceType(std::vector>, FegenOperator); } // namespace fegen diff --git a/frontend/FrontendGen/include/FegenVisitor.h b/frontend/FrontendGen/include/FegenVisitor.h index 61d467d65e..2b2724cce2 100644 --- a/frontend/FrontendGen/include/FegenVisitor.h +++ b/frontend/FrontendGen/include/FegenVisitor.h @@ -25,17 +25,17 @@ namespace fegen { /// @param expected expected params. /// @param actual actual params. /// @return true if correct. -bool checkParams(std::vector &expected, - std::vector &actual); +bool checkParams(std::vector &expected, + std::vector &actual); /// @brief check if the type of elements in list are correct. bool checkListLiteral( - std::vector> + std::vector> &listLiteral); class FegenVisitor : public FegenParserBaseVisitor { private: - FegenManager &manager; + Manager &manager; ScopeStack &sstack; public: @@ -45,7 +45,7 @@ class FegenVisitor : public FegenParserBaseVisitor { void emitOpDefination() { this->manager.emitOpDefination(); } FegenVisitor() - : manager(FegenManager::getManager()), + : manager(Manager::getManager()), sstack(ScopeStack::getScopeStack()) { this->manager.initbuiltinTypes(); } @@ -53,7 +53,7 @@ class FegenVisitor : public FegenParserBaseVisitor { std::any visitTypeDefinationDecl( FegenParser::TypeDefinationDeclContext *ctx) override { auto typeName = ctx->typeDefinationName()->getText(); - auto tyDef = std::any_cast( + auto tyDef = std::any_cast( this->visit(ctx->typeDefinationBlock())); // set name and ctx for type defination tyDef->setName(typeName); @@ -66,10 +66,10 @@ class FegenVisitor : public FegenParserBaseVisitor { // return FegenTypeDefination* std::any visitTypeDefinationBlock( FegenParser::TypeDefinationBlockContext *ctx) override { - auto params = std::any_cast>( + auto params = std::any_cast>( this->visit(ctx->parametersSpec())); auto tyDef = - FegenTypeDefination::get(this->manager.moduleName, "", params, nullptr); + TypeDefination::get(this->manager.moduleName, "", params, nullptr); return tyDef; } @@ -81,9 +81,9 @@ class FegenVisitor : public FegenParserBaseVisitor { std::any visitParserRuleSpec(FegenParser::ParserRuleSpecContext *ctx) override { auto ruleList = - std::any_cast>(this->visit(ctx->ruleBlock())); + std::any_cast>(this->visit(ctx->ruleBlock())); auto ruleNode = - FegenNode::get(ruleList, ctx, FegenNode::NodeType::PARSER_RULE); + ParserNode::get(ruleList, ctx, ParserNode::NodeType::PARSER_RULE); // set source node for rules for (auto rule : ruleList) { rule->setSrc(ruleNode); @@ -93,9 +93,9 @@ class FegenVisitor : public FegenParserBaseVisitor { } std::any visitRuleAltList(FegenParser::RuleAltListContext *ctx) override { - std::vector ruleList; + std::vector ruleList; for (auto alt : ctx->actionAlt()) { - auto fegenRule = std::any_cast(this->visit(alt)); + auto fegenRule = std::any_cast(this->visit(alt)); ruleList.push_back(fegenRule); } return ruleList; @@ -105,11 +105,11 @@ class FegenVisitor : public FegenParserBaseVisitor { auto rawRule = this->visit(ctx->alternative()); if (ctx->actionBlock()) { auto blockValues = std::any_cast< - std::tuple, std::vector>>( + std::tuple, std::vector>>( this->visit(ctx->actionBlock())); auto inputs = std::get<0>(blockValues); auto returns = std::get<1>(blockValues); - auto rule = std::any_cast(rawRule); + auto rule = std::any_cast(rawRule); for (auto in : inputs) { auto flag = rule->addInput(*in); if (!flag) { // TODO: error report @@ -130,15 +130,15 @@ class FegenVisitor : public FegenParserBaseVisitor { // return tuple, vector> std::any visitActionBlock(FegenParser::ActionBlockContext *ctx) override { - std::vector inputs; - std::vector returns; + std::vector inputs; + std::vector returns; if (ctx->inputsSpec()) { - inputs = std::any_cast>( + inputs = std::any_cast>( this->visit(ctx->inputsSpec())); } if (ctx->returnsSpec()) { - returns = std::any_cast>( + returns = std::any_cast>( this->visit(ctx->returnsSpec())); } @@ -152,16 +152,16 @@ class FegenVisitor : public FegenParserBaseVisitor { // TODO: do more check std::any visitAlternative(FegenParser::AlternativeContext *ctx) override { auto content = ctx->getText(); - auto rule = FegenRule::get(content, nullptr, ctx); + auto rule = ParserRule::get(content, nullptr, ctx); return rule; } std::any visitLexerRuleSpec(FegenParser::LexerRuleSpecContext *ctx) override { // create node, get rules from child, and insert to node map - auto ruleList = std::any_cast>( + auto ruleList = std::any_cast>( this->visit(ctx->lexerRuleBlock())); auto ruleNode = - FegenNode::get(ruleList, ctx, FegenNode::NodeType::LEXER_RULE); + ParserNode::get(ruleList, ctx, ParserNode::NodeType::LEXER_RULE); // set source node for rules for (auto rule : ruleList) { rule->setSrc(ruleNode); @@ -171,9 +171,9 @@ class FegenVisitor : public FegenParserBaseVisitor { } std::any visitLexerAltList(FegenParser::LexerAltListContext *ctx) override { - std::vector ruleList; + std::vector ruleList; for (auto alt : ctx->lexerAlt()) { - auto rule = fegen::FegenRule::get(alt->getText(), nullptr, alt); + auto rule = fegen::ParserRule::get(alt->getText(), nullptr, alt); ruleList.push_back(rule); } return ruleList; @@ -182,12 +182,12 @@ class FegenVisitor : public FegenParserBaseVisitor { // return vector std::any visitVarDecls(FegenParser::VarDeclsContext *ctx) override { size_t varCount = ctx->typeSpec().size(); - std::vector valueList; + std::vector valueList; for (size_t i = 0; i <= varCount - 1; i++) { - auto ty = std::any_cast(this->visit(ctx->typeSpec(i))); + auto ty = std::any_cast(this->visit(ctx->typeSpec(i))); auto varName = ctx->identifier(i)->getText(); - auto var = fegen::FegenValue::get( - ty, varName, fegen::FegenRightValue::getPlaceHolder()); + auto var = fegen::Value::get( + ty, varName, fegen::RightValue::getPlaceHolder()); valueList.push_back(var); } @@ -198,22 +198,22 @@ class FegenVisitor : public FegenParserBaseVisitor { std::any visitTypeInstanceSpec(FegenParser::TypeInstanceSpecContext *ctx) override { auto valueKind = ctx->valueKind() - ? std::any_cast( + ? std::any_cast( this->visit(ctx->valueKind())) - : fegen::FegenType::TypeKind::CPP; + : fegen::Type::TypeKind::CPP; auto typeInst = - std::any_cast(this->visit(ctx->typeInstance())); + std::any_cast(this->visit(ctx->typeInstance())); typeInst.setTypeKind(valueKind); return typeInst; } // return fegen::FegenType::TypeKind std::any visitValueKind(FegenParser::ValueKindContext *ctx) override { - auto kind = fegen::FegenType::TypeKind::ATTRIBUTE; + auto kind = fegen::Type::TypeKind::ATTRIBUTE; if (ctx->CPP()) { - kind = fegen::FegenType::TypeKind::CPP; + kind = fegen::Type::TypeKind::CPP; } else if (ctx->OPERAND()) { - kind = fegen::FegenType::TypeKind::OPERAND; + kind = fegen::Type::TypeKind::OPERAND; } // otherwise: ATTRIBUTE return kind; @@ -224,12 +224,12 @@ class FegenVisitor : public FegenParserBaseVisitor { if (ctx->typeTemplate()) { // typeTemplate (Less typeTemplateParam (Comma // typeTemplateParam)* Greater)? auto typeTeplt = - std::any_cast(this->visit(ctx->typeTemplate())); + std::any_cast(this->visit(ctx->typeTemplate())); // get parameters - std::vector paramList; + std::vector paramList; for (auto paramCtx : ctx->typeTemplateParam()) { auto tepltParams = - std::any_cast(this->visit(paramCtx)); + std::any_cast(this->visit(paramCtx)); paramList.push_back(tepltParams); } @@ -242,15 +242,15 @@ class FegenVisitor : public FegenParserBaseVisitor { } // get FegenType of instance auto typeInst = - FegenType::getInstanceType(typeTeplt.getTypeDefination(), paramList); + Type::getInstanceType(typeTeplt.getTypeDefination(), paramList); return typeInst; } else if (ctx->identifier()) { // identifier auto varName = ctx->identifier()->getText(); auto var = this->sstack.attemptFindVar(varName); if (var) { if (var->getContentKind() == - fegen::FegenRightValue::LiteralKind::TYPE) { - return var->getContent(); + fegen::RightValue::LiteralKind::TYPE) { + return var->getContent(); } else { std::cerr << "variable " << varName << " is not a Type or TypeTemplate." << std::endl; @@ -271,16 +271,16 @@ class FegenVisitor : public FegenParserBaseVisitor { std::any visitTypeTemplateParam(FegenParser::TypeTemplateParamContext *ctx) override { if (ctx->builtinTypeInstances()) { - auto ty = std::any_cast( + auto ty = std::any_cast( this->visit(ctx->builtinTypeInstances())); - return fegen::FegenValue::get(ty, "param", - fegen::FegenRightValue::getPlaceHolder()); + return fegen::Value::get(ty, "param", + fegen::RightValue::getPlaceHolder()); } else { auto expr = - std::any_cast>( + std::any_cast>( this->visit(ctx->expression())); - return fegen::FegenValue::get(expr->exprType, "expression_tmp", - fegen::FegenRightValue::getByExpr(expr)); + return fegen::Value::get(expr->exprType, "expression_tmp", + fegen::RightValue::getByExpr(expr)); } } @@ -288,17 +288,17 @@ class FegenVisitor : public FegenParserBaseVisitor { std::any visitBuiltinTypeInstances( FegenParser::BuiltinTypeInstancesContext *ctx) override { if (ctx->BOOL()) { - return FegenType::getBoolType(); + return Type::getBoolType(); } else if (ctx->INT()) { - return FegenType::getInt32Type(); + return Type::getInt32Type(); } else if (ctx->FLOAT()) { - return FegenType::getFloatType(); + return Type::getFloatType(); } else if (ctx->DOUBLE()) { - return FegenType::getDoubleType(); + return Type::getDoubleType(); } else if (ctx->CHAR()) { - return FegenType::getCharType(); + return Type::getCharType(); } else if (ctx->STRING()) { - return FegenType::getStringType(); + return Type::getStringType(); } else { std::cerr << "error builtin type." << std::endl; return nullptr; @@ -314,12 +314,12 @@ class FegenVisitor : public FegenParserBaseVisitor { } else { // type auto tyDef = this->sstack.attemptFindTypeDef( ctx->prefixedName()->identifier(0)->getText()); - return fegen::FegenType::getTemplateType(tyDef); + return fegen::Type::getTemplateType(tyDef); } } else if (ctx->builtinTypeTemplate()) { // builtinTypeTemplate return this->visit(ctx->builtinTypeTemplate()); } else { // TYPE - return fegen::FegenType::getMetaType(); + return fegen::Type::getMetaType(); } } @@ -327,15 +327,15 @@ class FegenVisitor : public FegenParserBaseVisitor { std::any visitBuiltinTypeTemplate( FegenParser::BuiltinTypeTemplateContext *ctx) override { if (ctx->INTEGER()) { - return fegen::FegenType::getIntegerTemplate(); + return fegen::Type::getIntegerTemplate(); } else if (ctx->FLOATPOINT()) { - return fegen::FegenType::getFloatPointTemplate(); + return fegen::Type::getFloatPointTemplate(); } else if (ctx->TENSOR()) { // return fegen::FegenType::getTensorTemplate(); - return fegen::FegenType::getPlaceHolder(); + return fegen::Type::getPlaceHolder(); } else if (ctx->VECTOR()) { // return fegen::FegenType::getVectorTemplate(); - return fegen::FegenType::getPlaceHolder(); + return fegen::Type::getPlaceHolder(); } else { return nullptr; } @@ -344,12 +344,12 @@ class FegenVisitor : public FegenParserBaseVisitor { // return FegenType std::any visitCollectTypeSpec(FegenParser::CollectTypeSpecContext *ctx) override { - auto kind = fegen::FegenType::TypeKind::CPP; + auto kind = fegen::Type::TypeKind::CPP; if (ctx->valueKind()) { - kind = std::any_cast( + kind = std::any_cast( this->visit(ctx->valueKind())); } - auto ty = std::any_cast(this->visit(ctx->collectType())); + auto ty = std::any_cast(this->visit(ctx->collectType())); ty.setTypeKind(kind); return ty; } @@ -357,41 +357,41 @@ class FegenVisitor : public FegenParserBaseVisitor { // return FegenType std::any visitCollectType(FegenParser::CollectTypeContext *ctx) override { auto expr = - std::any_cast>( + std::any_cast>( this->visit(ctx->expression())); if (ctx->collectProtoType()->ANY()) { - std::vector tys; + std::vector tys; // TODO: reprot error - assert(expr->getKind() == fegen::FegenRightValue::LiteralKind::VECTOR); + assert(expr->getKind() == fegen::RightValue::LiteralKind::VECTOR); auto exprs = std::any_cast< - std::vector>>( + std::vector>>( expr->getContent()); for (auto expr : exprs) { - auto ty = std::any_cast(expr->getContent()); + auto ty = std::any_cast(expr->getContent()); tys.push_back(ty); } - return fegen::FegenType::getAnyType(tys); + return fegen::Type::getAnyType(tys); } else if (ctx->collectProtoType()->LIST()) { - assert(expr->getKind() == fegen::FegenRightValue::LiteralKind::TYPE); - auto ty = std::any_cast(expr->getContent()); - return fegen::FegenType::getListType(ty); + assert(expr->getKind() == fegen::RightValue::LiteralKind::TYPE); + auto ty = std::any_cast(expr->getContent()); + return fegen::Type::getListType(ty); } else { // optional - assert(expr->getKind() == fegen::FegenRightValue::LiteralKind::TYPE); - auto ty = std::any_cast(expr->getContent()); - return fegen::FegenType::getOptionalType(ty); + assert(expr->getKind() == fegen::RightValue::LiteralKind::TYPE); + auto ty = std::any_cast(expr->getContent()); + return fegen::Type::getOptionalType(ty); } } // return std::shared_ptr std::any visitExpression(FegenParser::ExpressionContext *ctx) override { auto expr = - std::any_cast>( + std::any_cast>( this->visit(ctx->andExpr(0))); for (size_t i = 1; i <= ctx->andExpr().size() - 1; i++) { auto rhs = - std::any_cast>( + std::any_cast>( this->visit(ctx->andExpr(i))); - expr = FegenRightValue::ExpressionNode::binaryOperation( + expr = RightValue::ExpressionNode::binaryOperation( expr, rhs, FegenOperator::OR); } return expr; @@ -400,13 +400,13 @@ class FegenVisitor : public FegenParserBaseVisitor { // return std::shared_ptr std::any visitAndExpr(FegenParser::AndExprContext *ctx) override { auto expr = - std::any_cast>( + std::any_cast>( this->visit(ctx->equExpr(0))); for (size_t i = 1; i <= ctx->equExpr().size() - 1; i++) { auto rhs = - std::any_cast>( + std::any_cast>( this->visit(ctx->equExpr(i))); - expr = FegenRightValue::ExpressionNode::binaryOperation( + expr = RightValue::ExpressionNode::binaryOperation( expr, rhs, FegenOperator::AND); } return expr; @@ -415,7 +415,7 @@ class FegenVisitor : public FegenParserBaseVisitor { // return std::shared_ptr std::any visitEquExpr(FegenParser::EquExprContext *ctx) override { auto expr = - std::any_cast>( + std::any_cast>( this->visit(ctx->compareExpr(0))); for (size_t i = 1; i <= ctx->compareExpr().size() - 1; i++) { FegenOperator op; @@ -425,9 +425,9 @@ class FegenVisitor : public FegenParserBaseVisitor { op = FegenOperator::NOT_EQUAL; } auto rhs = - std::any_cast>( + std::any_cast>( this->visit(ctx->compareExpr(i))); - expr = FegenRightValue::ExpressionNode::binaryOperation(expr, rhs, op); + expr = RightValue::ExpressionNode::binaryOperation(expr, rhs, op); } return expr; } @@ -435,7 +435,7 @@ class FegenVisitor : public FegenParserBaseVisitor { // return std::shared_ptr std::any visitCompareExpr(FegenParser::CompareExprContext *ctx) override { auto expr = - std::any_cast>( + std::any_cast>( this->visit(ctx->addExpr(0))); for (size_t i = 1; i <= ctx->addExpr().size() - 1; i++) { FegenOperator op; @@ -452,9 +452,9 @@ class FegenVisitor : public FegenParserBaseVisitor { op = FegenOperator::GREATER_EQUAL; } auto rhs = - std::any_cast>( + std::any_cast>( this->visit(ctx->addExpr(i))); - expr = FegenRightValue::ExpressionNode::binaryOperation(expr, rhs, op); + expr = RightValue::ExpressionNode::binaryOperation(expr, rhs, op); } return expr; } @@ -462,7 +462,7 @@ class FegenVisitor : public FegenParserBaseVisitor { // return std::shared_ptr std::any visitAddExpr(FegenParser::AddExprContext *ctx) override { auto expr = - std::any_cast>( + std::any_cast>( this->visit(ctx->term(0))); for (size_t i = 1; i <= ctx->term().size() - 1; i++) { FegenOperator op; @@ -473,9 +473,9 @@ class FegenVisitor : public FegenParserBaseVisitor { op = FegenOperator::SUB; } auto rhs = - std::any_cast>( + std::any_cast>( this->visit(ctx->term(i))); - expr = FegenRightValue::ExpressionNode::binaryOperation(expr, rhs, op); + expr = RightValue::ExpressionNode::binaryOperation(expr, rhs, op); } return expr; } @@ -483,7 +483,7 @@ class FegenVisitor : public FegenParserBaseVisitor { // return std::shared_ptr std::any visitTerm(FegenParser::TermContext *ctx) override { auto expr = - std::any_cast>( + std::any_cast>( this->visit(ctx->powerExpr(0))); for (size_t i = 1; i <= ctx->powerExpr().size() - 1; i++) { FegenOperator op; @@ -496,9 +496,9 @@ class FegenVisitor : public FegenParserBaseVisitor { op = FegenOperator::MOD; } auto rhs = - std::any_cast>( + std::any_cast>( this->visit(ctx->powerExpr(i))); - expr = FegenRightValue::ExpressionNode::binaryOperation(expr, rhs, op); + expr = RightValue::ExpressionNode::binaryOperation(expr, rhs, op); } return expr; } @@ -506,13 +506,13 @@ class FegenVisitor : public FegenParserBaseVisitor { // return std::shared_ptr std::any visitPowerExpr(FegenParser::PowerExprContext *ctx) override { auto expr = - std::any_cast>( + std::any_cast>( this->visit(ctx->unaryExpr(0))); for (size_t i = 1; i <= ctx->unaryExpr().size() - 1; i++) { auto rhs = - std::any_cast>( + std::any_cast>( this->visit(ctx->unaryExpr(i))); - expr = FegenRightValue::ExpressionNode::binaryOperation( + expr = RightValue::ExpressionNode::binaryOperation( expr, rhs, FegenOperator::POWER); } return expr; @@ -524,7 +524,7 @@ class FegenVisitor : public FegenParserBaseVisitor { return this->visit(ctx->primaryExpr()); } auto expr = - std::any_cast>( + std::any_cast>( this->visit(ctx->primaryExpr())); FegenOperator op; if (ctx->Minus()) { @@ -532,7 +532,7 @@ class FegenVisitor : public FegenParserBaseVisitor { } else { op = FegenOperator::NOT; } - expr = FegenRightValue::ExpressionNode::unaryOperation(expr, op); + expr = RightValue::ExpressionNode::unaryOperation(expr, op); return expr; } @@ -548,15 +548,15 @@ class FegenVisitor : public FegenParserBaseVisitor { auto name = ctx->identifier()->getText(); auto var = this->sstack.attemptFindVar(name); if (var) { - return (std::shared_ptr) - fegen::FegenRightValue::ExpressionTerminal::getLeftValue(var); + return (std::shared_ptr) + fegen::RightValue::ExpressionTerminal::getLeftValue(var); } else { // TODO auto tyDef = this->manager.getTypeDefination(name); if (tyDef) { - auto tyVar = fegen::FegenType::getTemplateType(tyDef); - return (std::shared_ptr) - fegen::FegenRightValue::Expression::getType(tyVar); + auto tyVar = fegen::Type::getTemplateType(tyDef); + return (std::shared_ptr) + fegen::RightValue::Expression::getType(tyVar); } else { // TODO: error report std::cerr << "can not find variable: " << ctx->identifier()->getText() @@ -566,9 +566,9 @@ class FegenVisitor : public FegenParserBaseVisitor { } } } else if (ctx->typeSpec()) { - auto ty = std::any_cast(this->visit(ctx->typeSpec())); - return (std::shared_ptr) - FegenRightValue::ExpressionTerminal::getType(ty); + auto ty = std::any_cast(this->visit(ctx->typeSpec())); + return (std::shared_ptr) + RightValue::ExpressionTerminal::getType(ty); } else { // constant, functionCall, parenSurroundedExpr,contextMethodInvoke, // and variableAccess return this->visit(ctx->children[0]); @@ -579,16 +579,16 @@ class FegenVisitor : public FegenParserBaseVisitor { std::any visitIntLiteral(FegenParser::IntLiteralContext *ctx) override { long long int number = std::stoi(ctx->getText()); size_t size = 32; // TODO: Get size of number. - return (std::shared_ptr) - fegen::FegenRightValue::Expression::getInteger(number, size); + return (std::shared_ptr) + fegen::RightValue::Expression::getInteger(number, size); } // return std::shared_ptr std::any visitRealLiteral(FegenParser::RealLiteralContext *ctx) override { long double number = std::stod(ctx->getText()); size_t size = 32; // TODO: Get size of number. - return (std::shared_ptr) - fegen::FegenRightValue::Expression::getFloatPoint(number, size); + return (std::shared_ptr) + fegen::RightValue::Expression::getFloatPoint(number, size); } // return std::shared_ptr @@ -596,8 +596,8 @@ class FegenVisitor : public FegenParserBaseVisitor { std::string s = ctx->getText(); // remove quotation marks std::string strWithoutQuotation = s.substr(1, s.size() - 2); - return (std::shared_ptr) - fegen::FegenRightValue::Expression::getString(strWithoutQuotation); + return (std::shared_ptr) + fegen::RightValue::Expression::getString(strWithoutQuotation); } // return std::shared_ptr @@ -606,21 +606,21 @@ class FegenVisitor : public FegenParserBaseVisitor { if (ctx->getText() == "true") { content = 1; } - return (std::shared_ptr) - fegen::FegenRightValue::Expression::getInteger(content, 1); + return (std::shared_ptr) + fegen::RightValue::Expression::getInteger(content, 1); } // return std::shared_ptr std::any visitListLiteral(FegenParser::ListLiteralContext *ctx) override { - std::vector> elements; + std::vector> elements; for (auto exprCtx : ctx->expression()) { auto expr = - std::any_cast>( + std::any_cast>( this->visit(exprCtx)); elements.push_back(expr); } - return (std::shared_ptr) - fegen::FegenRightValue::Expression::getList(elements); + return (std::shared_ptr) + fegen::RightValue::Expression::getList(elements); } std::any visitActionSpec(FegenParser::ActionSpecContext *ctx) override { @@ -630,7 +630,7 @@ class FegenVisitor : public FegenParserBaseVisitor { std::any visitFunctionDecl(FegenParser::FunctionDeclContext *ctx) override { sstack.pushScope(); auto returnType = - std::any_cast(this->visit(ctx->typeSpec())); + std::any_cast(this->visit(ctx->typeSpec())); auto functionName = std::any_cast(this->visit(ctx->funcName())); auto hasfunc = manager.functionMap.find(functionName); @@ -641,12 +641,12 @@ class FegenVisitor : public FegenParserBaseVisitor { exit(0); return nullptr; } - auto functionParams = std::any_cast>( + auto functionParams = std::any_cast>( this->visit(ctx->funcParams())); this->visit(ctx->statementBlock()); - fegen::FegenFunction *function = - fegen::FegenFunction::get(functionName, functionParams, &returnType); + fegen::Function *function = + fegen::Function::get(functionName, functionParams, &returnType); manager.functionMap.insert(std::pair{functionName, function}); sstack.popScope(); return nullptr; @@ -658,13 +658,14 @@ class FegenVisitor : public FegenParserBaseVisitor { } std::any visitFuncParams(FegenParser::FuncParamsContext *ctx) override { - std::vector paramsList = {}; + std::vector paramsList = {}; for (size_t i = 0; i < ctx->typeSpec().size(); i++) { auto paramType = - std::any_cast(this->visit(ctx->typeSpec(i))); + std::any_cast(this->visit(ctx->typeSpec(i))); auto paramName = ctx->identifier(i)->getText(); - auto param = fegen::FegenValue::get(paramType, paramName, fegen::FegenRightValue::getPlaceHolder()); + auto param = fegen::Value::get( + paramType, paramName, fegen::RightValue::getPlaceHolder()); paramsList.push_back(param); sstack.attemptAddVar(param); } @@ -673,21 +674,24 @@ class FegenVisitor : public FegenParserBaseVisitor { std::any visitVarDeclStmt(FegenParser::VarDeclStmtContext *ctx) override { auto varType = - std::any_cast(this->visit(ctx->typeSpec())); + std::any_cast(this->visit(ctx->typeSpec())); auto varName = ctx->identifier()->getText(); - fegen::FegenValue *var; + fegen::Value *var; if (ctx->expression()) { - auto varcontent = std::any_cast>( - this->visit(ctx->expression())); + auto varcontent = + std::any_cast>( + this->visit(ctx->expression())); // TODO: check error // if(!fegen::FegenType::isSameType(&varType, &varcontent->exprType)){ // std::cerr << "The variabel \" " << varName // << "\" need \"" << varType.getTypeName() << " \" type rightvalue." // << std::endl; exit(0); return nullptr; // } - var = fegen::FegenValue::get(varType, varName, fegen::FegenRightValue::getByExpr(varcontent)); + var = fegen::Value::get( + varType, varName, fegen::RightValue::getByExpr(varcontent)); } else { - var = fegen::FegenValue::get(varType, varName, fegen::FegenRightValue::getPlaceHolder()); + var = fegen::Value::get(varType, varName, + fegen::RightValue::getPlaceHolder()); } sstack.attemptAddVar(var); manager.stmtContentMap.insert(std::pair{ctx, var}); @@ -696,25 +700,26 @@ class FegenVisitor : public FegenParserBaseVisitor { std::any visitAssignStmt(FegenParser::AssignStmtContext *ctx) override { auto varName = ctx->identifier()->getText(); - auto varcontent = std::any_cast>( - this->visit(ctx->expression())); + auto varcontent = + std::any_cast>( + this->visit(ctx->expression())); auto var = sstack.attemptFindVar(varName); - if (!fegen::FegenType::isSameType(&var->getType(), &varcontent->exprType)) { + if (!fegen::Type::isSameType(&var->getType(), &varcontent->exprType)) { std::cerr << "The variabel \" " << varName << "\" need \"" << var->getType().getTypeName() << " \" type rightvalue." << std::endl; exit(0); return nullptr; } - fegen::FegenValue *stmt = - fegen::FegenValue::get(var->getType(), varName, fegen::FegenRightValue::getByExpr(varcontent)); + fegen::Value *stmt = fegen::Value::get( + var->getType(), varName, fegen::RightValue::getByExpr(varcontent)); manager.stmtContentMap.insert(std::pair{ctx, stmt}); return stmt; } std::any visitFunctionCall(FegenParser::FunctionCallContext *ctx) override { - std::vector parasList = {}; + std::vector parasList = {}; auto functionName = std::any_cast(this->visit(ctx->funcName())); auto hasFunc = manager.functionMap.at(functionName); @@ -722,7 +727,7 @@ class FegenVisitor : public FegenParserBaseVisitor { auto paraList = hasFunc->getInputTypeList(); if (paramsNum > 0) { for (size_t i = 0; i < paramsNum; i++) { - auto oprand = std::any_cast( + auto oprand = std::any_cast( this->visit(ctx->expression(i))); parasList.push_back(oprand); } @@ -735,7 +740,7 @@ class FegenVisitor : public FegenParserBaseVisitor { return nullptr; } for (size_t i = 0; i < len1; i++) { - if (!fegen::FegenType::isSameType(¶List[i]->getType(), + if (!fegen::Type::isSameType(¶List[i]->getType(), ¶sList[i]->exprType)) { std::cerr << "The function \" " << functionName << "\" parameter" << i << " type mismatch." << std::endl; @@ -745,8 +750,8 @@ class FegenVisitor : public FegenParserBaseVisitor { } } auto returnType = hasFunc->getReturnType(); - fegen::FegenFunction *funcCall = - fegen::FegenFunction::get(functionName, paraList, returnType); + fegen::Function *funcCall = + fegen::Function::get(functionName, paraList, returnType); manager.stmtContentMap.insert(std::pair{ctx, funcCall}); return returnType; } @@ -786,7 +791,7 @@ class FegenVisitor : public FegenParserBaseVisitor { std::any visitOpDecl(FegenParser::OpDeclContext *ctx) override { auto opName = ctx->opName()->getText(); auto opDef = - std::any_cast(this->visit(ctx->opBlock())); + std::any_cast(this->visit(ctx->opBlock())); opDef->setOpName(opName); bool success = this->manager.addOperationDefination(opDef); if (!success) { @@ -798,17 +803,17 @@ class FegenVisitor : public FegenParserBaseVisitor { // return FegenOperation* std::any visitOpBlock(FegenParser::OpBlockContext *ctx) override { - std::vector args; - std::vector res; + std::vector args; + std::vector res; if (ctx->argumentSpec()) { - args = std::any_cast>( + args = std::any_cast>( this->visit(ctx->argumentSpec())); } if (ctx->resultSpec()) { - res = std::any_cast>( + res = std::any_cast>( this->visit(ctx->resultSpec())); } - return fegen::FegenOperation::get("", args, res, ctx->bodySpec()); + return fegen::Operation::get("", args, res, ctx->bodySpec()); } }; } // namespace fegen diff --git a/frontend/FrontendGen/include/Scope.h b/frontend/FrontendGen/include/Scope.h index c8c46573a7..1d1acc283f 100644 --- a/frontend/FrontendGen/include/Scope.h +++ b/frontend/FrontendGen/include/Scope.h @@ -20,8 +20,8 @@ template class SymbolTable { }; class FegenScope { - using TypeDefTable = SymbolTable; - using VariableTable = SymbolTable; + using TypeDefTable = SymbolTable; + using VariableTable = SymbolTable; friend class ScopeStack; private: @@ -35,15 +35,15 @@ class FegenScope { ~FegenScope() = default; /// @brief this will not check. - FegenTypeDefination *findTypeDef(std::string name); + TypeDefination *findTypeDef(std::string name); /// @brief this will not check whether tyDef is already existed or not. - void addTypeDef(FegenTypeDefination *tyDef); + void addTypeDef(TypeDefination *tyDef); /// @brief return true if exist. bool isExistTypeDef(std::string name); /// @brief this will not check. - FegenValue *findVar(std::string name); + Value *findVar(std::string name); /// @brief this will not check whether var is already existed or not. - void addVar(FegenValue *var); + void addVar(Value *var); /// @brief return true if exist. bool isExistVar(std::string name); }; @@ -68,13 +68,13 @@ class ScopeStack { void pushScope(); void popScope(); /// @brief check and add var to current scope, return false if failed. - bool attemptAddVar(FegenValue *var); + bool attemptAddVar(Value *var); /// @brief check add find var from current scope, return nullptr if failed. - FegenValue *attemptFindVar(std::string name); + Value *attemptFindVar(std::string name); /// @brief check and add tyDef to current scope, return false if failed. - bool attemptAddTypeDef(FegenTypeDefination *tyDef); + bool attemptAddTypeDef(TypeDefination *tyDef); /// @brief check and find tyDef from current scope, return nullptr if failed. - FegenTypeDefination *attemptFindTypeDef(std::string name); + TypeDefination *attemptFindTypeDef(std::string name); }; } // namespace fegen diff --git a/frontend/FrontendGen/lib/FegenManager.cpp b/frontend/FrontendGen/lib/FegenManager.cpp index b2c03059c5..ac5fa0803b 100644 --- a/frontend/FrontendGen/lib/FegenManager.cpp +++ b/frontend/FrontendGen/lib/FegenManager.cpp @@ -13,64 +13,64 @@ #include #include -fegen::FegenFunction::FegenFunction(std::string name, - std::vector &&inputTypeList, - FegenType *returnType) +fegen::Function::Function(std::string name, + std::vector &&inputTypeList, + Type *returnType) : name(name), inputTypeList(inputTypeList), returnType(returnType) {} -fegen::FegenFunction * -fegen::FegenFunction::get(std::string name, - std::vector inputTypeList, - FegenType *returnType) { - return new fegen::FegenFunction(name, std::move(inputTypeList), returnType); +fegen::Function * +fegen::Function::get(std::string name, + std::vector inputTypeList, + Type *returnType) { + return new fegen::Function(name, std::move(inputTypeList), returnType); } -std::string fegen::FegenFunction::getName() { return this->name; } +std::string fegen::Function::getName() { return this->name; } -std::vector &fegen::FegenFunction::getInputTypeList() { +std::vector &fegen::Function::getInputTypeList() { return this->inputTypeList; } -fegen::FegenValue *fegen::FegenFunction::getInputTypeList(size_t i) { +fegen::Value *fegen::Function::getInputTypeList(size_t i) { return this->inputTypeList[i]; } -fegen::FegenType *fegen::FegenFunction::getReturnType() { +fegen::Type *fegen::Function::getReturnType() { return this->returnType; } -fegen::FegenOperation::FegenOperation(std::string dialectName, +fegen::Operation::Operation(std::string dialectName, std::string operationName, - std::vector &&arguments, - std::vector &&results, + std::vector &&arguments, + std::vector &&results, fegen::FegenParser::BodySpecContext *ctx) : dialectName(dialectName), arguments(arguments), results(results), ctx(ctx) {} -void fegen::FegenOperation::setOpName(std::string name) { +void fegen::Operation::setOpName(std::string name) { this->operationName = name; } -std::string fegen::FegenOperation::getOpName() { return this->operationName; } +std::string fegen::Operation::getOpName() { return this->operationName; } -std::vector &fegen::FegenOperation::getArguments() { +std::vector &fegen::Operation::getArguments() { return this->arguments; } -fegen::FegenValue *fegen::FegenOperation::getArguments(size_t i) { +fegen::Value *fegen::Operation::getArguments(size_t i) { return this->arguments[i]; } -std::vector &fegen::FegenOperation::getResults() { +std::vector &fegen::Operation::getResults() { return this->results; } -fegen::FegenValue *fegen::FegenOperation::getResults(size_t i) { +fegen::Value *fegen::Operation::getResults(size_t i) { return this->results[i]; } -fegen::FegenOperation *fegen::FegenOperation::get( - std::string operationName, std::vector arguments, - std::vector results, FegenParser::BodySpecContext *ctx) { - return new fegen::FegenOperation(fegen::FegenManager::getManager().moduleName, +fegen::Operation *fegen::Operation::get( + std::string operationName, std::vector arguments, + std::vector results, FegenParser::BodySpecContext *ctx) { + return new fegen::Operation(fegen::Manager::getManager().moduleName, operationName, std::move(arguments), std::move(results), ctx); } @@ -81,7 +81,7 @@ fegen::FegenOperation *fegen::FegenOperation::get( /// for example: Integer + 32 --> Integer<32> /// @return joint name std::string jointTypeName(std::string templateName, - const std::vector ¶meters) { + const std::vector ¶meters) { if (parameters.empty()) { return templateName; } @@ -99,76 +99,76 @@ std::string jointTypeName(std::string templateName, return res; } -fegen::FegenType::FegenType(TypeKind kind, std::string name, - std::vector parameters, - FegenTypeDefination *tyDef, int typeLevel) +fegen::Type::Type(TypeKind kind, std::string name, + std::vector parameters, + TypeDefination *tyDef, int typeLevel) : kind(kind), typeName(name), parameters(std::move(parameters)), typeDefine(tyDef), typeLevel(typeLevel) {} -fegen::FegenType::FegenType(fegen::FegenType::TypeKind kind, - std::vector parameters, - FegenTypeDefination *tyDef, int typeLevel) +fegen::Type::Type(fegen::Type::TypeKind kind, + std::vector parameters, + TypeDefination *tyDef, int typeLevel) : kind(kind), typeName(jointTypeName(tyDef->getName(), parameters)), parameters(std::move(parameters)), typeDefine(tyDef), typeLevel((typeLevel)) {} -fegen::FegenType::FegenType(const fegen::FegenType &fty) +fegen::Type::Type(const fegen::Type &fty) : kind(fty.kind), typeName(fty.typeName), typeDefine(fty.typeDefine), typeLevel(fty.typeLevel) { // deep copy parameters for (auto paramPtr : fty.parameters) { - this->parameters.push_back(new fegen::FegenValue(*paramPtr)); + this->parameters.push_back(new fegen::Value(*paramPtr)); } } -fegen::FegenType::FegenType(fegen::FegenType &&fty) +fegen::Type::Type(fegen::Type &&fty) : kind(fty.kind), typeName(std::move(fty.typeName)), parameters(std::move(fty.parameters)), typeDefine(fty.typeDefine), typeLevel(fty.typeLevel) {} -fegen::FegenType::TypeKind fegen::FegenType::getTypeKind() { +fegen::Type::TypeKind fegen::Type::getTypeKind() { return this->kind; } -void fegen::FegenType::setTypeKind(fegen::FegenType::TypeKind kind) { +void fegen::Type::setTypeKind(fegen::Type::TypeKind kind) { this->kind = kind; } -std::vector &fegen::FegenType::getParameters() { +std::vector &fegen::Type::getParameters() { return this->parameters; } -fegen::FegenValue *fegen::FegenType::getParameters(size_t i) { +fegen::Value *fegen::Type::getParameters(size_t i) { return this->parameters[i]; } -void fegen::FegenType::setParameters(std::vector ¶ms) { +void fegen::Type::setParameters(std::vector ¶ms) { this->parameters = params; // set parameters and level up! this->typeLevel++; } -fegen::FegenTypeDefination *fegen::FegenType::getTypeDefination() { +fegen::TypeDefination *fegen::Type::getTypeDefination() { return this->typeDefine; } -void fegen::FegenType::setTypeDefination(fegen::FegenTypeDefination *tyDef) { +void fegen::Type::setTypeDefination(fegen::TypeDefination *tyDef) { this->typeDefine = tyDef; } -std::string fegen::FegenType::getTypeName() { return this->typeName; } +std::string fegen::Type::getTypeName() { return this->typeName; } -int fegen::FegenType::getTypeLevel() { return this->typeLevel; } +int fegen::Type::getTypeLevel() { return this->typeLevel; } -bool fegen::FegenType::isSameType(fegen::FegenType *type1, - fegen::FegenType *type2) { +bool fegen::Type::isSameType(fegen::Type *type1, + fegen::Type *type2) { if (type1->getTypeName() == type2->getTypeName()) return true; else return false; } -std::string fegen::FegenType::toStringForTypedef() { +std::string fegen::Type::toStringForTypedef() { // handle builtin type instance auto typeName = this->typeName; auto typedefName = this->typeDefine->getName(); @@ -224,7 +224,7 @@ std::string fegen::FegenType::toStringForTypedef() { } } -std::string fegen::FegenType::toStringForOpdef() { +std::string fegen::Type::toStringForOpdef() { // handle builtin type instance auto typeName = this->typeName; auto typedefName = this->typeDefine->getName(); @@ -256,7 +256,7 @@ std::string fegen::FegenType::toStringForOpdef() { exit(0); } -std::string fegen::FegenType::toStringForCppKind() { +std::string fegen::Type::toStringForCppKind() { // handle builtin type instance auto typeName = this->typeName; auto typedefName = this->typeDefine->getName(); @@ -290,219 +290,219 @@ std::string fegen::FegenType::toStringForCppKind() { exit(0); } -fegen::FegenType::~FegenType() { +fegen::Type::~Type() { for (auto p : this->parameters) { delete p; } } -fegen::FegenType fegen::FegenType::getPlaceHolder() { - return fegen::FegenType( - fegen::FegenType::TypeKind::CPP, {}, - fegen::FegenManager::getManager().getTypeDefination(FEGEN_PLACEHOLDER), +fegen::Type fegen::Type::getPlaceHolder() { + return fegen::Type( + fegen::Type::TypeKind::CPP, {}, + fegen::Manager::getManager().getTypeDefination(FEGEN_PLACEHOLDER), 0); } -fegen::FegenType fegen::FegenType::getMetaType() { - return fegen::FegenType( - fegen::FegenType::TypeKind::CPP, {}, - fegen::FegenManager::getManager().getTypeDefination(FEGEN_TYPE), 2); +fegen::Type fegen::Type::getMetaType() { + return fegen::Type( + fegen::Type::TypeKind::CPP, {}, + fegen::Manager::getManager().getTypeDefination(FEGEN_TYPE), 2); } -fegen::FegenType fegen::FegenType::getMetaTemplateType() { - return fegen::FegenType( - fegen::FegenType::TypeKind::CPP, {}, - fegen::FegenManager::getManager().getTypeDefination(FEGEN_TYPETEMPLATE), +fegen::Type fegen::Type::getMetaTemplateType() { + return fegen::Type( + fegen::Type::TypeKind::CPP, {}, + fegen::Manager::getManager().getTypeDefination(FEGEN_TYPETEMPLATE), 1); } -fegen::FegenType fegen::FegenType::getInt32Type() { - return fegen::FegenType( - fegen::FegenType::TypeKind::CPP, "int", - {fegen::FegenValue::get(fegen::FegenType::getPlaceHolder(), "size", - fegen::FegenRightValue::getPlaceHolder())}, - fegen::FegenManager::getManager().getTypeDefination(FEGEN_INTEGER), 3); +fegen::Type fegen::Type::getInt32Type() { + return fegen::Type( + fegen::Type::TypeKind::CPP, "int", + {fegen::Value::get(fegen::Type::getPlaceHolder(), "size", + fegen::RightValue::getPlaceHolder())}, + fegen::Manager::getManager().getTypeDefination(FEGEN_INTEGER), 3); } -fegen::FegenType fegen::FegenType::getFloatType() { - return fegen::FegenType( - fegen::FegenType::TypeKind::CPP, "float", - {fegen::FegenValue::get(fegen::FegenType::getInt32Type(), "size", - fegen::FegenRightValue::getInteger(32))}, - fegen::FegenManager::getManager().getTypeDefination(FEGEN_FLOATPOINT), 3); +fegen::Type fegen::Type::getFloatType() { + return fegen::Type( + fegen::Type::TypeKind::CPP, "float", + {fegen::Value::get(fegen::Type::getInt32Type(), "size", + fegen::RightValue::getInteger(32))}, + fegen::Manager::getManager().getTypeDefination(FEGEN_FLOATPOINT), 3); } -fegen::FegenType fegen::FegenType::getDoubleType() { - return fegen::FegenType( - fegen::FegenType::TypeKind::CPP, "double", - {fegen::FegenValue::get(fegen::FegenType::getInt32Type(), "size", - fegen::FegenRightValue::getInteger(64))}, - fegen::FegenManager::getManager().getTypeDefination(FEGEN_FLOATPOINT), 3); +fegen::Type fegen::Type::getDoubleType() { + return fegen::Type( + fegen::Type::TypeKind::CPP, "double", + {fegen::Value::get(fegen::Type::getInt32Type(), "size", + fegen::RightValue::getInteger(64))}, + fegen::Manager::getManager().getTypeDefination(FEGEN_FLOATPOINT), 3); } -fegen::FegenType fegen::FegenType::getBoolType() { - return fegen::FegenType( - fegen::FegenType::TypeKind::CPP, "bool", - {fegen::FegenValue::get(fegen::FegenType::getInt32Type(), "size", - fegen::FegenRightValue::getInteger(1))}, - fegen::FegenManager::getManager().getTypeDefination(FEGEN_INTEGER), 3); +fegen::Type fegen::Type::getBoolType() { + return fegen::Type( + fegen::Type::TypeKind::CPP, "bool", + {fegen::Value::get(fegen::Type::getInt32Type(), "size", + fegen::RightValue::getInteger(1))}, + fegen::Manager::getManager().getTypeDefination(FEGEN_INTEGER), 3); } -fegen::FegenType fegen::FegenType::getIntegerType(fegen::FegenValue *size) { +fegen::Type fegen::Type::getIntegerType(fegen::Value *size) { if (size->getContent() == 32) - return fegen::FegenType::getInt32Type(); - return fegen::FegenType( - fegen::FegenType::TypeKind::CPP, {size}, - fegen::FegenManager::getManager().getTypeDefination(FEGEN_INTEGER), 3); + return fegen::Type::getInt32Type(); + return fegen::Type( + fegen::Type::TypeKind::CPP, {size}, + fegen::Manager::getManager().getTypeDefination(FEGEN_INTEGER), 3); } -fegen::FegenType fegen::FegenType::getFloatPointType(fegen::FegenValue *size) { +fegen::Type fegen::Type::getFloatPointType(fegen::Value *size) { if (size->getContent() == 32) { - return fegen::FegenType::getFloatType(); + return fegen::Type::getFloatType(); } else if (size->getContent() == 64) { - return fegen::FegenType::getDoubleType(); + return fegen::Type::getDoubleType(); } - return fegen::FegenType( - fegen::FegenType::TypeKind::CPP, {size}, - fegen::FegenManager::getManager().getTypeDefination(FEGEN_FLOATPOINT), 3); + return fegen::Type( + fegen::Type::TypeKind::CPP, {size}, + fegen::Manager::getManager().getTypeDefination(FEGEN_FLOATPOINT), 3); } -fegen::FegenType fegen::FegenType::getCharType() { - return fegen::FegenType( - fegen::FegenType::TypeKind::CPP, {}, - fegen::FegenManager::getManager().getTypeDefination(FEGEN_CHAR), 3); +fegen::Type fegen::Type::getCharType() { + return fegen::Type( + fegen::Type::TypeKind::CPP, {}, + fegen::Manager::getManager().getTypeDefination(FEGEN_CHAR), 3); } -fegen::FegenType fegen::FegenType::getStringType() { - return fegen::FegenType( - fegen::FegenType::TypeKind::CPP, {}, - fegen::FegenManager::getManager().getTypeDefination(FEGEN_STRING), 3); +fegen::Type fegen::Type::getStringType() { + return fegen::Type( + fegen::Type::TypeKind::CPP, {}, + fegen::Manager::getManager().getTypeDefination(FEGEN_STRING), 3); } -fegen::FegenType fegen::FegenType::getVectorType(fegen::FegenValue *size, - fegen::FegenType elementType) { +fegen::Type fegen::Type::getVectorType(fegen::Value *size, + fegen::Type elementType) { assert(elementType.typeLevel == 3); - return fegen::FegenType( - fegen::FegenType::TypeKind::CPP, + return fegen::Type( + fegen::Type::TypeKind::CPP, {size, - fegen::FegenValue::get(fegen::FegenType::getMetaType(), "elementType", - fegen::FegenRightValue::getType(elementType))}, - fegen::FegenManager::getManager().getTypeDefination(FEGEN_VECTOR), + fegen::Value::get(fegen::Type::getMetaType(), "elementType", + fegen::RightValue::getType(elementType))}, + fegen::Manager::getManager().getTypeDefination(FEGEN_VECTOR), elementType.typeLevel); } -fegen::FegenType fegen::FegenType::getTensorType(fegen::FegenValue *shape, - fegen::FegenType elementType) { +fegen::Type fegen::Type::getTensorType(fegen::Value *shape, + fegen::Type elementType) { assert(elementType.typeLevel == 3); - return fegen::FegenType( - fegen::FegenType::TypeKind::CPP, + return fegen::Type( + fegen::Type::TypeKind::CPP, {shape, - fegen::FegenValue::get(fegen::FegenType::getMetaType(), "elementType", - fegen::FegenRightValue::getType(elementType))}, - fegen::FegenManager::getManager().getTypeDefination(FEGEN_TENSOR), + fegen::Value::get(fegen::Type::getMetaType(), "elementType", + fegen::RightValue::getType(elementType))}, + fegen::Manager::getManager().getTypeDefination(FEGEN_TENSOR), elementType.typeLevel); } // List -fegen::FegenType fegen::FegenType::getListType(fegen::FegenType elementType) { +fegen::Type fegen::Type::getListType(fegen::Type elementType) { assert(elementType.typeLevel == 2 || elementType.typeLevel == 3); - return fegen::FegenType( - fegen::FegenType::TypeKind::CPP, - {fegen::FegenValue::get( - elementType.typeLevel == 2 ? fegen::FegenType::getMetaTemplateType() - : fegen::FegenType::getMetaType(), - "elementType", fegen::FegenRightValue::getType(elementType))}, - fegen::FegenManager::getManager().getTypeDefination(FEGEN_LIST), + return fegen::Type( + fegen::Type::TypeKind::CPP, + {fegen::Value::get( + elementType.typeLevel == 2 ? fegen::Type::getMetaTemplateType() + : fegen::Type::getMetaType(), + "elementType", fegen::RightValue::getType(elementType))}, + fegen::Manager::getManager().getTypeDefination(FEGEN_LIST), elementType.typeLevel); } // Optional -fegen::FegenType -fegen::FegenType::getOptionalType(fegen::FegenType elementType) { +fegen::Type +fegen::Type::getOptionalType(fegen::Type elementType) { assert(elementType.typeLevel == 2 || elementType.typeLevel == 3); - return fegen::FegenType( - fegen::FegenType::TypeKind::CPP, - {fegen::FegenValue::get( - elementType.typeLevel == 2 ? fegen::FegenType::getMetaTemplateType() - : fegen::FegenType::getMetaType(), - "elementType", fegen::FegenRightValue::getType(elementType))}, - fegen::FegenManager::getManager().getTypeDefination(FEGEN_OPTINAL), + return fegen::Type( + fegen::Type::TypeKind::CPP, + {fegen::Value::get( + elementType.typeLevel == 2 ? fegen::Type::getMetaTemplateType() + : fegen::Type::getMetaType(), + "elementType", fegen::RightValue::getType(elementType))}, + fegen::Manager::getManager().getTypeDefination(FEGEN_OPTINAL), elementType.typeLevel); } // Any -fegen::FegenType -fegen::FegenType::getAnyType(std::vector elementTypes) { - std::vector p_elemTy; +fegen::Type +fegen::Type::getAnyType(std::vector elementTypes) { + std::vector p_elemTy; int i = 0; std::string name("elementType_"); auto tyLevel = elementTypes[0].typeLevel; assert(tyLevel == 2 || tyLevel == 3); - auto tyty = tyLevel == 2 ? fegen::FegenType::getMetaTemplateType() - : fegen::FegenType::getMetaType(); + auto tyty = tyLevel == 2 ? fegen::Type::getMetaTemplateType() + : fegen::Type::getMetaType(); for (auto &ty : elementTypes) { assert(ty.typeLevel == tyLevel); - p_elemTy.push_back(fegen::FegenValue::get( - tyty, name + std::to_string(i), fegen::FegenRightValue::getType(ty))); + p_elemTy.push_back(fegen::Value::get( + tyty, name + std::to_string(i), fegen::RightValue::getType(ty))); i++; } - return fegen::FegenType( - fegen::FegenType::TypeKind::CPP, p_elemTy, - fegen::FegenManager::getManager().getTypeDefination(FEGEN_ANY), tyLevel); + return fegen::Type( + fegen::Type::TypeKind::CPP, p_elemTy, + fegen::Manager::getManager().getTypeDefination(FEGEN_ANY), tyLevel); } -fegen::FegenType fegen::FegenType::getIntegerTemplate() { - return fegen::FegenType( - fegen::FegenType::TypeKind::CPP, {}, - fegen::FegenManager::getManager().getTypeDefination(FEGEN_INTEGER), 2); +fegen::Type fegen::Type::getIntegerTemplate() { + return fegen::Type( + fegen::Type::TypeKind::CPP, {}, + fegen::Manager::getManager().getTypeDefination(FEGEN_INTEGER), 2); } -fegen::FegenType fegen::FegenType::getFloatPointTemplate() { - return fegen::FegenType( - fegen::FegenType::TypeKind::CPP, {}, - fegen::FegenManager::getManager().getTypeDefination(FEGEN_FLOATPOINT), 2); +fegen::Type fegen::Type::getFloatPointTemplate() { + return fegen::Type( + fegen::Type::TypeKind::CPP, {}, + fegen::Manager::getManager().getTypeDefination(FEGEN_FLOATPOINT), 2); } -fegen::FegenType -fegen::FegenType::getInstanceType(fegen::FegenTypeDefination *typeDefination, - std::vector parameters) { - return fegen::FegenType(fegen::FegenType::TypeKind::CPP, parameters, +fegen::Type +fegen::Type::getInstanceType(fegen::TypeDefination *typeDefination, + std::vector parameters) { + return fegen::Type(fegen::Type::TypeKind::CPP, parameters, typeDefination, 3); } -fegen::FegenType -fegen::FegenType::getTemplateType(fegen::FegenTypeDefination *typeDefination) { - return fegen::FegenType(fegen::FegenType::TypeKind::CPP, {}, typeDefination, +fegen::Type +fegen::Type::getTemplateType(fegen::TypeDefination *typeDefination) { + return fegen::Type(fegen::Type::TypeKind::CPP, {}, typeDefination, 2); } // class FegenTypeDefination -fegen::FegenTypeDefination::FegenTypeDefination( +fegen::TypeDefination::TypeDefination( std::string dialectName, std::string name, - std::vector parameters, + std::vector parameters, FegenParser::TypeDefinationDeclContext *ctx, bool ifCustome) : dialectName(std::move(dialectName)), name(std::move(name)), parameters(std::move(parameters)), ctx(ctx), ifCustome(ifCustome) {} -fegen::FegenTypeDefination * -fegen::FegenTypeDefination::get(std::string dialectName, std::string name, - std::vector parameters, +fegen::TypeDefination * +fegen::TypeDefination::get(std::string dialectName, std::string name, + std::vector parameters, FegenParser::TypeDefinationDeclContext *ctx, bool ifCustome) { - return new fegen::FegenTypeDefination(std::move(dialectName), std::move(name), + return new fegen::TypeDefination(std::move(dialectName), std::move(name), std::move(parameters), ctx, ifCustome); } -std::string fegen::FegenTypeDefination::getDialectName() { +std::string fegen::TypeDefination::getDialectName() { return this->dialectName; } -void fegen::FegenTypeDefination::setDialectName(std::string name) { +void fegen::TypeDefination::setDialectName(std::string name) { this->dialectName = name; } -std::string fegen::FegenTypeDefination::getName() { return this->name; } +std::string fegen::TypeDefination::getName() { return this->name; } -std::string fegen::FegenTypeDefination::getMnemonic() { +std::string fegen::TypeDefination::getMnemonic() { if (this->mnemonic.empty()) { this->mnemonic = this->name; std::transform(this->mnemonic.begin(), this->mnemonic.end(), @@ -511,131 +511,131 @@ std::string fegen::FegenTypeDefination::getMnemonic() { return this->mnemonic; } -void fegen::FegenTypeDefination::setName(std::string name) { +void fegen::TypeDefination::setName(std::string name) { this->name = name; } -const std::vector & -fegen::FegenTypeDefination::getParameters() { +const std::vector & +fegen::TypeDefination::getParameters() { return this->parameters; } fegen::FegenParser::TypeDefinationDeclContext * -fegen::FegenTypeDefination::getCtx() { +fegen::TypeDefination::getCtx() { return this->ctx; } -void fegen::FegenTypeDefination::setCtx( +void fegen::TypeDefination::setCtx( FegenParser::TypeDefinationDeclContext *ctx) { this->ctx = ctx; } -bool fegen::FegenTypeDefination::isCustome() { return this->ifCustome; } +bool fegen::TypeDefination::isCustome() { return this->ifCustome; } // class Expression -fegen::FegenRightValue::Expression::Expression(bool ifTerminal, +fegen::RightValue::Expression::Expression(bool ifTerminal, LiteralKind kind, - FegenType &exprTy, + Type &exprTy, bool isConstexpr) : ifTerminal(ifTerminal), kind(kind), exprType(exprTy), ifConstexpr(isConstexpr) {} -bool fegen::FegenRightValue::Expression::isTerminal() { +bool fegen::RightValue::Expression::isTerminal() { return this->ifTerminal; } -fegen::FegenRightValue::LiteralKind -fegen::FegenRightValue::Expression::getKind() { +fegen::RightValue::LiteralKind +fegen::RightValue::Expression::getKind() { return this->kind; } -fegen::FegenType &fegen::FegenRightValue::Expression::getType() { +fegen::Type &fegen::RightValue::Expression::getType() { return this->exprType; } -bool fegen::FegenRightValue::Expression::isConstexpr() { +bool fegen::RightValue::Expression::isConstexpr() { return this->ifConstexpr; } -std::shared_ptr -fegen::FegenRightValue::Expression::getPlaceHolder() { - return std::make_shared(); +std::shared_ptr +fegen::RightValue::Expression::getPlaceHolder() { + return std::make_shared(); } -std::shared_ptr -fegen::FegenRightValue::Expression::getInteger(long long int content, +std::shared_ptr +fegen::RightValue::Expression::getInteger(long long int content, size_t size) { - return std::make_shared(content, + return std::make_shared(content, size); } -std::shared_ptr -fegen::FegenRightValue::Expression::getFloatPoint(long double content, +std::shared_ptr +fegen::RightValue::Expression::getFloatPoint(long double content, size_t size) { - return std::make_shared(content, + return std::make_shared(content, size); } -std::shared_ptr -fegen::FegenRightValue::Expression::getString(std::string content) { - return std::make_shared(content); +std::shared_ptr +fegen::RightValue::Expression::getString(std::string content) { + return std::make_shared(content); } -std::shared_ptr -fegen::FegenRightValue::Expression::getType(fegen::FegenType &content) { - return std::make_shared(content); +std::shared_ptr +fegen::RightValue::Expression::getType(fegen::Type &content) { + return std::make_shared(content); } -std::shared_ptr -fegen::FegenRightValue::Expression::getList( - std::vector> &content) { - return std::make_shared(content); +std::shared_ptr +fegen::RightValue::Expression::getList( + std::vector> &content) { + return std::make_shared(content); } -std::shared_ptr -fegen::FegenRightValue::Expression::getLeftValue(fegen::FegenValue *content) { - return std::make_shared(content); +std::shared_ptr +fegen::RightValue::Expression::getLeftValue(fegen::Value *content) { + return std::make_shared(content); } -std::shared_ptr -fegen::FegenRightValue::Expression::binaryOperation( - std::shared_ptr lhs, - std::shared_ptr rhs, FegenOperator op) { - FegenType resTy = fegen::inferenceType({lhs, rhs}, op); - return std::make_shared( - op, std::vector>{ +std::shared_ptr +fegen::RightValue::Expression::binaryOperation( + std::shared_ptr lhs, + std::shared_ptr rhs, FegenOperator op) { + Type resTy = fegen::inferenceType({lhs, rhs}, op); + return std::make_shared( + op, std::vector>{ lhs, rhs}); } -std::shared_ptr -fegen::FegenRightValue::Expression::unaryOperation( - std::shared_ptr v, FegenOperator op) { - FegenType resTy = fegen::inferenceType({v}, op); - return std::make_shared( - op, std::vector>{v}); +std::shared_ptr +fegen::RightValue::Expression::unaryOperation( + std::shared_ptr v, FegenOperator op) { + Type resTy = fegen::inferenceType({v}, op); + return std::make_shared( + op, std::vector>{v}); } // class ExpressionNode -fegen::FegenRightValue::ExpressionNode::ExpressionNode(LiteralKind kind, - FegenType exprTy, +fegen::RightValue::ExpressionNode::ExpressionNode(LiteralKind kind, + Type exprTy, bool ifConstexpr) : Expression(false, kind, exprTy, ifConstexpr) {} -std::string fegen::FegenRightValue::ExpressionNode::toString() { +std::string fegen::RightValue::ExpressionNode::toString() { assert(FEGEN_NOT_IMPLEMENTED_ERROR); } -std::string fegen::FegenRightValue::ExpressionNode::toStringForTypedef() { +std::string fegen::RightValue::ExpressionNode::toStringForTypedef() { assert(FEGEN_NOT_IMPLEMENTED_ERROR); } -std::string fegen::FegenRightValue::ExpressionNode::toStringForOpdef() { +std::string fegen::RightValue::ExpressionNode::toStringForOpdef() { assert(FEGEN_NOT_IMPLEMENTED_ERROR); } -std::string fegen::FegenRightValue::ExpressionNode::toStringForCppKind() { +std::string fegen::RightValue::ExpressionNode::toStringForCppKind() { assert(FEGEN_NOT_IMPLEMENTED_ERROR); } @@ -702,7 +702,7 @@ std::string getCppOperator(fegen::FegenOperator op) { // class FunctionCall inline bool isFuncParamsAllConstant( - std::vector> ¶ms) { + std::vector> ¶ms) { for (auto param : params) { if (!param->isConstexpr()) { return false; @@ -712,175 +712,175 @@ inline bool isFuncParamsAllConstant( } // TODO: invoke methods of FegenFunction -fegen::FegenRightValue::FunctionCall::FunctionCall( - fegen::FegenFunction *func, - std::vector> params) - : ExpressionNode(fegen::FegenRightValue::LiteralKind::FUNC_CALL, - fegen::FegenType::getInt32Type(), +fegen::RightValue::FunctionCall::FunctionCall( + fegen::Function *func, + std::vector> params) + : ExpressionNode(fegen::RightValue::LiteralKind::FUNC_CALL, + fegen::Type::getInt32Type(), isFuncParamsAllConstant(params)), func(func), params(std::move(params)) {} -std::string fegen::FegenRightValue::FunctionCall::toString() { +std::string fegen::RightValue::FunctionCall::toString() { return "FunctionCall::toString"; } -std::string fegen::FegenRightValue::FunctionCall::toStringForTypedef() { +std::string fegen::RightValue::FunctionCall::toStringForTypedef() { return "FunctionCall::toStringForTypedef"; } -std::string fegen::FegenRightValue::FunctionCall::toStringForOpdef() { +std::string fegen::RightValue::FunctionCall::toStringForOpdef() { return "FunctionCall::toStringForOpdef"; } -std::string fegen::FegenRightValue::FunctionCall::toStringForCppKind() { +std::string fegen::RightValue::FunctionCall::toStringForCppKind() { return "FunctionCall::toStringForCppKind"; } -std::any fegen::FegenRightValue::FunctionCall::getContent() { return this; } +std::any fegen::RightValue::FunctionCall::getContent() { return this; } // class OperationCall -fegen::FegenRightValue::OperationCall::OperationCall( - fegen::FegenOperation *op, - std::vector> params) - : ExpressionNode(fegen::FegenRightValue::LiteralKind::OPERATION_CALL, - fegen::FegenType::getInt32Type(), +fegen::RightValue::OperationCall::OperationCall( + fegen::Operation *op, + std::vector> params) + : ExpressionNode(fegen::RightValue::LiteralKind::OPERATION_CALL, + fegen::Type::getInt32Type(), isFuncParamsAllConstant(params)), op(op), params(std::move(params)) {} -std::string fegen::FegenRightValue::OperationCall::toString() { +std::string fegen::RightValue::OperationCall::toString() { return "OperationCall::toString"; } -std::string fegen::FegenRightValue::OperationCall::toStringForTypedef() { +std::string fegen::RightValue::OperationCall::toStringForTypedef() { return "OperationCall::toStringForTypedef"; } -std::string fegen::FegenRightValue::OperationCall::toStringForOpdef() { +std::string fegen::RightValue::OperationCall::toStringForOpdef() { return "OperationCall::toStringForOpdef"; } -std::string fegen::FegenRightValue::OperationCall::toStringForCppKind() { +std::string fegen::RightValue::OperationCall::toStringForCppKind() { return "OperationCall::toStringForCppKind"; } -std::any fegen::FegenRightValue::OperationCall::getContent() { return this; } +std::any fegen::RightValue::OperationCall::getContent() { return this; } // class OperatorCall -fegen::FegenRightValue::OperatorCall::OperatorCall( +fegen::RightValue::OperatorCall::OperatorCall( fegen::FegenOperator op, - std::vector> params) - : ExpressionNode(fegen::FegenRightValue::LiteralKind::OPERATION_CALL, + std::vector> params) + : ExpressionNode(fegen::RightValue::LiteralKind::OPERATION_CALL, fegen::inferenceType(params, op), isFuncParamsAllConstant(params)), op(op), params(std::move(params)) {} -std::string fegen::FegenRightValue::OperatorCall::toString() { +std::string fegen::RightValue::OperatorCall::toString() { return "OperatorCall::toString"; } -std::string fegen::FegenRightValue::OperatorCall::toStringForTypedef() { +std::string fegen::RightValue::OperatorCall::toStringForTypedef() { return "OperatorCall::toStringForTypedef"; } -std::string fegen::FegenRightValue::OperatorCall::toStringForOpdef() { +std::string fegen::RightValue::OperatorCall::toStringForOpdef() { return "OperatorCall::toStringForOpdef"; } -std::string fegen::FegenRightValue::OperatorCall::toStringForCppKind() { +std::string fegen::RightValue::OperatorCall::toStringForCppKind() { return "OperatorCall::toStringForCppKind"; } -std::any fegen::FegenRightValue::OperatorCall::getContent() { return this; } +std::any fegen::RightValue::OperatorCall::getContent() { return this; } // class ExpressionTerminal -fegen::FegenRightValue::ExpressionTerminal::ExpressionTerminal( - fegen::FegenRightValue::LiteralKind kind, FegenType exprTy, +fegen::RightValue::ExpressionTerminal::ExpressionTerminal( + fegen::RightValue::LiteralKind kind, Type exprTy, bool ifConstexpr) : Expression(true, kind, exprTy, ifConstexpr) {} -std::string fegen::FegenRightValue::ExpressionTerminal::toString() { +std::string fegen::RightValue::ExpressionTerminal::toString() { assert(FEGEN_NOT_IMPLEMENTED_ERROR); } -std::string fegen::FegenRightValue::ExpressionTerminal::toStringForTypedef() { +std::string fegen::RightValue::ExpressionTerminal::toStringForTypedef() { assert(FEGEN_NOT_IMPLEMENTED_ERROR); } -std::string fegen::FegenRightValue::ExpressionTerminal::toStringForOpdef() { +std::string fegen::RightValue::ExpressionTerminal::toStringForOpdef() { assert(FEGEN_NOT_IMPLEMENTED_ERROR); } -std::string fegen::FegenRightValue::ExpressionTerminal::toStringForCppKind() { +std::string fegen::RightValue::ExpressionTerminal::toStringForCppKind() { assert(FEGEN_NOT_IMPLEMENTED_ERROR); } // class PlaceHolder -fegen::FegenRightValue::PlaceHolder::PlaceHolder() - : ExpressionTerminal(fegen::FegenRightValue::LiteralKind::MONOSTATE, - fegen::FegenType::getPlaceHolder(), true) {} +fegen::RightValue::PlaceHolder::PlaceHolder() + : ExpressionTerminal(fegen::RightValue::LiteralKind::MONOSTATE, + fegen::Type::getPlaceHolder(), true) {} -std::any fegen::FegenRightValue::PlaceHolder::getContent() { +std::any fegen::RightValue::PlaceHolder::getContent() { return std::monostate(); } -std::string fegen::FegenRightValue::PlaceHolder::toString() { return ""; } +std::string fegen::RightValue::PlaceHolder::toString() { return ""; } // class IntegerLiteral -fegen::FegenRightValue::IntegerLiteral::IntegerLiteral(int content) - : ExpressionTerminal(fegen::FegenRightValue::LiteralKind::INT, - fegen::FegenType::getInt32Type(), true), +fegen::RightValue::IntegerLiteral::IntegerLiteral(int content) + : ExpressionTerminal(fegen::RightValue::LiteralKind::INT, + fegen::Type::getInt32Type(), true), content(content) {} -fegen::FegenRightValue::IntegerLiteral::IntegerLiteral(long long int content, +fegen::RightValue::IntegerLiteral::IntegerLiteral(long long int content, size_t size) : ExpressionTerminal( - fegen::FegenRightValue::LiteralKind::INT, - fegen::FegenType::getIntegerType(fegen::FegenValue::get( - fegen::FegenType::getInt32Type(), "size", - fegen::FegenRightValue::getByExpr( - std::make_shared( + fegen::RightValue::LiteralKind::INT, + fegen::Type::getIntegerType(fegen::Value::get( + fegen::Type::getInt32Type(), "size", + fegen::RightValue::getByExpr( + std::make_shared( size)))), true), content(content) {} -std::any fegen::FegenRightValue::IntegerLiteral::getContent() { +std::any fegen::RightValue::IntegerLiteral::getContent() { return this->content; } -std::string fegen::FegenRightValue::IntegerLiteral::toString() { +std::string fegen::RightValue::IntegerLiteral::toString() { return std::to_string(this->content); } // class FloatPointLiteral -fegen::FegenRightValue::FloatPointLiteral::FloatPointLiteral( +fegen::RightValue::FloatPointLiteral::FloatPointLiteral( long double content, size_t size) : ExpressionTerminal( - fegen::FegenRightValue::LiteralKind::FLOAT, - fegen::FegenType::getFloatPointType( - fegen::FegenValue::get(fegen::FegenType::getInt32Type(), "size", - fegen::FegenRightValue::getInteger(size))), + fegen::RightValue::LiteralKind::FLOAT, + fegen::Type::getFloatPointType( + fegen::Value::get(fegen::Type::getInt32Type(), "size", + fegen::RightValue::getInteger(size))), true), content(content) {} -std::any fegen::FegenRightValue::FloatPointLiteral::getContent() { +std::any fegen::RightValue::FloatPointLiteral::getContent() { return this->content; } -std::string fegen::FegenRightValue::FloatPointLiteral::toString() { +std::string fegen::RightValue::FloatPointLiteral::toString() { return std::to_string(this->content); } // class StringLiteral -fegen::FegenRightValue::StringLiteral::StringLiteral(std::string content) - : ExpressionTerminal(fegen::FegenRightValue::LiteralKind::STRING, - fegen::FegenType::getStringType(), true), +fegen::RightValue::StringLiteral::StringLiteral(std::string content) + : ExpressionTerminal(fegen::RightValue::LiteralKind::STRING, + fegen::Type::getStringType(), true), content(content) {} -std::any fegen::FegenRightValue::StringLiteral::getContent() { +std::any fegen::RightValue::StringLiteral::getContent() { return this->content; } -std::string fegen::FegenRightValue::StringLiteral::toString() { +std::string fegen::RightValue::StringLiteral::toString() { std::string res; res.append("\""); res.append(this->content); @@ -891,7 +891,7 @@ std::string fegen::FegenRightValue::StringLiteral::toString() { // class TypeLiteral // Check params of content and return ture if params are all const expr. -inline bool isParamsConstant(fegen::FegenType &content) { +inline bool isParamsConstant(fegen::Type &content) { for (auto param : content.getParameters()) { if (!param->getExpr()->isConstexpr()) { return false; @@ -901,39 +901,39 @@ inline bool isParamsConstant(fegen::FegenType &content) { } // Get type of type literal. -fegen::FegenType getTypeLiteralType(fegen::FegenType &content) { +fegen::Type getTypeLiteralType(fegen::Type &content) { if (content.getTypeLevel() == 2) { - return fegen::FegenType::getMetaTemplateType(); + return fegen::Type::getMetaTemplateType(); } else if (content.getTypeLevel() == 3) { - return fegen::FegenType::getMetaType(); + return fegen::Type::getMetaType(); } else { - return fegen::FegenType::getPlaceHolder(); + return fegen::Type::getPlaceHolder(); } } -fegen::FegenRightValue::TypeLiteral::TypeLiteral(fegen::FegenType &content) - : ExpressionTerminal(fegen::FegenRightValue::LiteralKind::TYPE, +fegen::RightValue::TypeLiteral::TypeLiteral(fegen::Type &content) + : ExpressionTerminal(fegen::RightValue::LiteralKind::TYPE, getTypeLiteralType(content), isParamsConstant(content)), content(content) {} -std::any fegen::FegenRightValue::TypeLiteral::getContent() { +std::any fegen::RightValue::TypeLiteral::getContent() { return this->content; } -std::string fegen::FegenRightValue::TypeLiteral::toString() { +std::string fegen::RightValue::TypeLiteral::toString() { return this->content.getTypeName(); } -std::string fegen::FegenRightValue::TypeLiteral::toStringForTypedef() { +std::string fegen::RightValue::TypeLiteral::toStringForTypedef() { return this->content.toStringForTypedef(); } -std::string fegen::FegenRightValue::TypeLiteral::toStringForOpdef() { +std::string fegen::RightValue::TypeLiteral::toStringForOpdef() { return this->content.toStringForOpdef(); } -std::string fegen::FegenRightValue::TypeLiteral::toStringForCppKind() { +std::string fegen::RightValue::TypeLiteral::toStringForCppKind() { return this->content.toStringForCppKind(); } @@ -941,7 +941,7 @@ std::string fegen::FegenRightValue::TypeLiteral::toStringForCppKind() { // Return ture if all Expressions in content are all true. bool isExpressionListConst( - std::vector> &content) { + std::vector> &content) { for (auto p : content) { if (!p->isConstexpr()) { return false; @@ -951,17 +951,17 @@ bool isExpressionListConst( return true; } -fegen::FegenRightValue::ListLiteral::ListLiteral( +fegen::RightValue::ListLiteral::ListLiteral( std::vector> &content) - : ExpressionTerminal(fegen::FegenRightValue::LiteralKind::VECTOR, + : ExpressionTerminal(fegen::RightValue::LiteralKind::VECTOR, content[0]->exprType, isExpressionListConst(content)), content(content) {} -std::any fegen::FegenRightValue::ListLiteral::getContent() { +std::any fegen::RightValue::ListLiteral::getContent() { return this->content; } -std::string fegen::FegenRightValue::ListLiteral::toString() { +std::string fegen::RightValue::ListLiteral::toString() { std::string res; res.append("["); for (size_t i = 0; i <= this->content.size() - 1; i++) { @@ -974,7 +974,7 @@ std::string fegen::FegenRightValue::ListLiteral::toString() { return res; } -std::string fegen::FegenRightValue::ListLiteral::toStringForTypedef() { +std::string fegen::RightValue::ListLiteral::toStringForTypedef() { std::string res; res.append("["); for (size_t i = 0; i <= this->content.size() - 1; i++) { @@ -987,7 +987,7 @@ std::string fegen::FegenRightValue::ListLiteral::toStringForTypedef() { return res; } -std::string fegen::FegenRightValue::ListLiteral::toStringForOpdef() { +std::string fegen::RightValue::ListLiteral::toStringForOpdef() { std::string res; res.append("["); for (size_t i = 0; i <= this->content.size() - 1; i++) { @@ -1001,207 +1001,211 @@ std::string fegen::FegenRightValue::ListLiteral::toStringForOpdef() { } // class LeftValue -fegen::FegenRightValue::LeftValue::LeftValue(fegen::FegenValue *content) - : ExpressionTerminal(fegen::FegenRightValue::LiteralKind::LEFT_VAR, +fegen::RightValue::LeftValue::LeftValue(fegen::Value *content) + : ExpressionTerminal(fegen::RightValue::LiteralKind::LEFT_VAR, content->getType(), content->getExpr()->isConstexpr()), content(content) {} -std::any fegen::FegenRightValue::LeftValue::getContent() { +std::any fegen::RightValue::LeftValue::getContent() { return this->content; } -std::string fegen::FegenRightValue::LeftValue::toString() { +std::string fegen::RightValue::LeftValue::toString() { return this->content->getName(); } // class FegenRightValue -fegen::FegenRightValue::FegenRightValue( - std::shared_ptr content) +fegen::RightValue::RightValue( + std::shared_ptr content) : content(content) {} -fegen::FegenRightValue::LiteralKind fegen::FegenRightValue::getLiteralKind() { +fegen::RightValue::LiteralKind fegen::RightValue::getLiteralKind() { return this->content->getKind(); } -std::string fegen::FegenRightValue::toString() { +std::string fegen::RightValue::toString() { return this->content->toString(); } -std::string fegen::FegenRightValue::toStringForTypedef() { +std::string fegen::RightValue::toStringForTypedef() { return this->content->toStringForTypedef(); } -std::string fegen::FegenRightValue::toStringForOpdef() { +std::string fegen::RightValue::toStringForOpdef() { return this->content->toStringForOpdef(); } -std::string fegen::FegenRightValue::toStringForCppKind() { +std::string fegen::RightValue::toStringForCppKind() { return this->content->toStringForCppKind(); } -std::any fegen::FegenRightValue::getContent() { +std::any fegen::RightValue::getContent() { return this->content->getContent(); } -fegen::FegenType &fegen::FegenRightValue::getType() { +fegen::Type &fegen::RightValue::getType() { return this->content->getType(); } -std::shared_ptr -fegen::FegenRightValue::getExpr() { +std::shared_ptr +fegen::RightValue::getExpr() { return this->content; } -fegen::FegenRightValue fegen::FegenRightValue::getPlaceHolder() { - return fegen::FegenRightValue( - fegen::FegenRightValue::Expression::getPlaceHolder()); +fegen::RightValue fegen::RightValue::getPlaceHolder() { + return fegen::RightValue( + fegen::RightValue::Expression::getPlaceHolder()); } -fegen::FegenRightValue fegen::FegenRightValue::getInteger(long long int content, +fegen::RightValue fegen::RightValue::getInteger(long long int content, size_t size) { - return fegen::FegenRightValue( - fegen::FegenRightValue::Expression::getInteger(content, size)); + return fegen::RightValue( + fegen::RightValue::Expression::getInteger(content, size)); } -fegen::FegenRightValue -fegen::FegenRightValue::getFloatPoint(long double content, size_t size) { - return fegen::FegenRightValue( - fegen::FegenRightValue::Expression::getFloatPoint(content, size)); +fegen::RightValue +fegen::RightValue::getFloatPoint(long double content, size_t size) { + return fegen::RightValue( + fegen::RightValue::Expression::getFloatPoint(content, size)); } -fegen::FegenRightValue fegen::FegenRightValue::getString(std::string content) { - return fegen::FegenRightValue( - fegen::FegenRightValue::Expression::getString(content)); +fegen::RightValue fegen::RightValue::getString(std::string content) { + return fegen::RightValue( + fegen::RightValue::Expression::getString(content)); } -fegen::FegenRightValue -fegen::FegenRightValue::getType(fegen::FegenType &content) { - return fegen::FegenRightValue( - fegen::FegenRightValue::Expression::getType(content)); +fegen::RightValue +fegen::RightValue::getType(fegen::Type &content) { + return fegen::RightValue( + fegen::RightValue::Expression::getType(content)); } -fegen::FegenRightValue fegen::FegenRightValue::getList( - std::vector> &content) { - return fegen::FegenRightValue( - fegen::FegenRightValue::Expression::getList(content)); +fegen::RightValue fegen::RightValue::getList( + std::vector> &content) { + return fegen::RightValue( + fegen::RightValue::Expression::getList(content)); } -fegen::FegenRightValue -fegen::FegenRightValue::getLeftValue(fegen::FegenValue *content) { - return fegen::FegenRightValue( - fegen::FegenRightValue::Expression::getLeftValue(content)); +fegen::RightValue +fegen::RightValue::getLeftValue(fegen::Value *content) { + return fegen::RightValue( + fegen::RightValue::Expression::getLeftValue(content)); } -fegen::FegenRightValue fegen::FegenRightValue::getByExpr( - std::shared_ptr expr) { +fegen::RightValue fegen::RightValue::getByExpr( + std::shared_ptr expr) { assert(expr != nullptr); - return fegen::FegenRightValue(expr); + return fegen::RightValue(expr); } // class FegenValue -fegen::FegenValue::FegenValue(fegen::FegenType type, std::string name, - fegen::FegenRightValue content) +fegen::Value::Value(fegen::Type type, std::string name, + fegen::RightValue content) : type(std::move(type)), name(std::move(name)), content(std::move(content)) {} -fegen::FegenValue::FegenValue(const fegen::FegenValue &rhs) +fegen::Value::Value(const fegen::Value &rhs) : type(rhs.type), name(rhs.name), content(rhs.content) {} -fegen::FegenValue::FegenValue(fegen::FegenValue &&rhs) +fegen::Value::Value(fegen::Value &&rhs) : type(std::move(rhs.type)), name(std::move(rhs.name)), content(std::move(rhs.content)) {} -fegen::FegenValue *fegen::FegenValue::get(fegen::FegenType type, +fegen::Value *fegen::Value::get(fegen::Type type, std::string name, - FegenRightValue content) { - return new fegen::FegenValue(std::move(type), std::move(name), + RightValue content) { + return new fegen::Value(std::move(type), std::move(name), std::move(content)); } -fegen::FegenType &fegen::FegenValue::getType() { return this->type; } +fegen::Type &fegen::Value::getType() { return this->type; } -std::string fegen::FegenValue::getName() { return this->name; } +std::string fegen::Value::getName() { return this->name; } -fegen::FegenRightValue::LiteralKind fegen::FegenValue::getContentKind() { +void fegen::Value::setContent(fegen::RightValue content) { + this->content = content; +} + +fegen::RightValue::LiteralKind fegen::Value::getContentKind() { return this->content.getLiteralKind(); } -std::string fegen::FegenValue::getContentString() { +std::string fegen::Value::getContentString() { return this->content.toString(); } -std::string fegen::FegenValue::getContentStringForTypedef() { +std::string fegen::Value::getContentStringForTypedef() { return this->content.toStringForTypedef(); } -std::string fegen::FegenValue::getContentStringForOpdef() { +std::string fegen::Value::getContentStringForOpdef() { return this->content.toStringForOpdef(); } -std::string fegen::FegenValue::getContentStringForCppKind() { +std::string fegen::Value::getContentStringForCppKind() { return this->content.toStringForCppKind(); } -std::shared_ptr -fegen::FegenValue::getExpr() { +std::shared_ptr +fegen::Value::getExpr() { return this->content.getExpr(); } -fegen::FegenRule::FegenRule(std::string content, fegen::FegenNode *src, +fegen::ParserRule::ParserRule(std::string content, fegen::ParserNode *src, antlr4::ParserRuleContext *ctx) : content(content), src(src), ctx(ctx) {} -fegen::FegenRule *fegen::FegenRule::get(std::string content, - fegen::FegenNode *src, +fegen::ParserRule *fegen::ParserRule::get(std::string content, + fegen::ParserNode *src, antlr4::ParserRuleContext *ctx) { - return new fegen::FegenRule(content, src, ctx); + return new fegen::ParserRule(content, src, ctx); } -llvm::StringRef fegen::FegenRule::getContent() { return this->content; } +llvm::StringRef fegen::ParserRule::getContent() { return this->content; } -bool fegen::FegenRule::addInput(fegen::FegenValue input) { +bool fegen::ParserRule::addInput(fegen::Value input) { auto name = input.getName(); if (this->inputs.count(name) == 0) { return false; } - this->inputs.insert({name, new fegen::FegenValue(input)}); + this->inputs.insert({name, new fegen::Value(input)}); return true; } -bool fegen::FegenRule::addReturn(fegen::FegenValue output) { +bool fegen::ParserRule::addReturn(fegen::Value output) { auto name = output.getName(); if (this->returns.count(name) == 0) { return false; } - this->returns.insert({name, new fegen::FegenValue(output)}); + this->returns.insert({name, new fegen::Value(output)}); return true; } -void fegen::FegenRule::setSrc(FegenNode *src) { this->src = src; } +void fegen::ParserRule::setSrc(ParserNode *src) { this->src = src; } -fegen::FegenNode::FegenNode(std::vector &&rules, +fegen::ParserNode::ParserNode(std::vector &&rules, antlr4::ParserRuleContext *ctx, - fegen::FegenNode::NodeType ntype) + fegen::ParserNode::NodeType ntype) : rules(rules), ctx(ctx), ntype(ntype) {} -fegen::FegenNode *fegen::FegenNode::get(std::vector rules, +fegen::ParserNode *fegen::ParserNode::get(std::vector rules, antlr4::ParserRuleContext *ctx, - fegen::FegenNode::NodeType ntype) { - return new fegen::FegenNode(std::move(rules), ctx, ntype); + fegen::ParserNode::NodeType ntype) { + return new fegen::ParserNode(std::move(rules), ctx, ntype); } -fegen::FegenNode *fegen::FegenNode::get(antlr4::ParserRuleContext *ctx, - fegen::FegenNode::NodeType ntype) { - std::vector rules; - return new fegen::FegenNode(std::move(rules), ctx, ntype); +fegen::ParserNode *fegen::ParserNode::get(antlr4::ParserRuleContext *ctx, + fegen::ParserNode::NodeType ntype) { + std::vector rules; + return new fegen::ParserNode(std::move(rules), ctx, ntype); } -void fegen::FegenNode::addFegenRule(fegen::FegenRule *rule) { +void fegen::ParserNode::addFegenRule(fegen::ParserRule *rule) { this->rules.push_back(rule); } -fegen::FegenNode::~FegenNode() { +fegen::ParserNode::~ParserNode() { for (auto rule : this->rules) { delete rule; } } -void fegen::FegenManager::setModuleName(std::string name) { +void fegen::Manager::setModuleName(std::string name) { this->moduleName = name; } @@ -1217,7 +1221,7 @@ std::string getChildrenText(antlr4::tree::ParseTree *ctx) { return ruleText; } -fegen::FegenManager::FegenManager() {} +fegen::Manager::Manager() {} namespace fegen { @@ -1261,20 +1265,20 @@ class Emitter { class StmtGenerator : FegenParserBaseVisitor { private: - FegenManager &manager; + Manager &manager; Emitter &emitter; public: StmtGenerator(Emitter &emitter) - : manager(FegenManager::getManager()), emitter(emitter) {} + : manager(Manager::getManager()), emitter(emitter) {} std::any visitVarDeclStmt(FegenParser::VarDeclStmtContext *ctx) override { - auto var = manager.getStmtContent(ctx->identifier()); + auto var = manager.getStmtContent(ctx->identifier()); switch (var->getType().getTypeKind()) { - case fegen::FegenType::TypeKind::CPP: { + case fegen::Type::TypeKind::CPP: { this->emitter << var->getType().toStringForCppKind() << " " << var->getName(); if (ctx->expression()) { - auto expr = this->manager.getStmtContent( + auto expr = this->manager.getStmtContent( ctx->expression()); this->emitter << " = " << expr->toStringForCppKind(); } @@ -1282,10 +1286,10 @@ class StmtGenerator : FegenParserBaseVisitor { this->emitter.newLine(); break; } - case fegen::FegenType::TypeKind::ATTRIBUTE: { + case fegen::Type::TypeKind::ATTRIBUTE: { break; } - case fegen::FegenType::TypeKind::OPERAND: { + case fegen::Type::TypeKind::OPERAND: { break; } } @@ -1305,7 +1309,7 @@ class StmtGenerator : FegenParserBaseVisitor { } // namespace fegen -void fegen::FegenManager::emitG4() { +void fegen::Manager::emitG4() { std::ofstream fileStream; fileStream.open(this->moduleName + ".g4"); fegen::Emitter emitter(fileStream); @@ -1333,7 +1337,7 @@ void fegen::FegenManager::emitG4() { } // TODO: emit to file -void fegen::FegenManager::emitTypeDefination() { +void fegen::Manager::emitTypeDefination() { std::ofstream fileStream; fileStream.open(this->moduleName + "Types.td"); fegen::Emitter emitter(fileStream); @@ -1429,7 +1433,7 @@ void fegen::FegenManager::emitTypeDefination() { fileStream.close(); } -void fegen::FegenManager::emitOpDefination() { +void fegen::Manager::emitOpDefination() { std::ofstream fileStream; fileStream.open(this->moduleName + "Ops.td"); fegen::Emitter emitter(fileStream); @@ -1520,7 +1524,7 @@ void fegen::FegenManager::emitOpDefination() { fileStream.close(); } -void fegen::FegenManager::emitDialectDefination() { +void fegen::Manager::emitDialectDefination() { std::ofstream fileStream; fileStream.open(this->moduleName + "Dialect.td"); fegen::Emitter emitter(fileStream); @@ -1570,122 +1574,122 @@ void fegen::FegenManager::emitDialectDefination() { fileStream.close(); } -void fegen::FegenManager::emitTdFiles() { +void fegen::Manager::emitTdFiles() { this->emitDialectDefination(); this->emitTypeDefination(); this->emitOpDefination(); } -void fegen::FegenManager::initbuiltinTypes() { +void fegen::Manager::initbuiltinTypes() { // placeholder type - auto placeholderTypeDefination = fegen::FegenTypeDefination::get( + auto placeholderTypeDefination = fegen::TypeDefination::get( "fegen_builtin", FEGEN_PLACEHOLDER, {}, nullptr, false); this->typeDefMap.insert({FEGEN_PLACEHOLDER, placeholderTypeDefination}); // Type this->typeDefMap.insert( - {FEGEN_TYPE, fegen::FegenTypeDefination::get("fegen_builtin", FEGEN_TYPE, + {FEGEN_TYPE, fegen::TypeDefination::get("fegen_builtin", FEGEN_TYPE, {}, nullptr, false)}); // TypeTemplate this->typeDefMap.insert( {FEGEN_TYPETEMPLATE, - fegen::FegenTypeDefination::get("fegen_builtin", FEGEN_TYPETEMPLATE, {}, + fegen::TypeDefination::get("fegen_builtin", FEGEN_TYPETEMPLATE, {}, nullptr, false)}); // recursive define Integer Type // Integer>> - auto intTypeDefination = fegen::FegenTypeDefination::get( + auto intTypeDefination = fegen::TypeDefination::get( "fegen_builtin", FEGEN_INTEGER, {}, nullptr, false); - auto intType = fegen::FegenType( - fegen::FegenType::TypeKind::CPP, - {fegen::FegenValue::get(fegen::FegenType::getPlaceHolder(), "size", - fegen::FegenRightValue::getPlaceHolder())}, + auto intType = fegen::Type( + fegen::Type::TypeKind::CPP, + {fegen::Value::get(fegen::Type::getPlaceHolder(), "size", + fegen::RightValue::getPlaceHolder())}, intTypeDefination, false); // parameters of Integer is int32(Integer<32>) - intTypeDefination->parameters.push_back(fegen::FegenValue::get( - intType, "size", fegen::FegenRightValue::getPlaceHolder())); + intTypeDefination->parameters.push_back(fegen::Value::get( + intType, "size", fegen::RightValue::getPlaceHolder())); this->typeDefMap.insert({FEGEN_INTEGER, intTypeDefination}); // FloatPoint this->typeDefMap.insert( {FEGEN_FLOATPOINT, - fegen::FegenTypeDefination::get( + fegen::TypeDefination::get( "fegen_builtin", FEGEN_FLOATPOINT, - {fegen::FegenValue::get(fegen::FegenType::getInt32Type(), "size", - fegen::FegenRightValue::getPlaceHolder())}, + {fegen::Value::get(fegen::Type::getInt32Type(), "size", + fegen::RightValue::getPlaceHolder())}, nullptr, false)}); // Char this->typeDefMap.insert( - {FEGEN_CHAR, fegen::FegenTypeDefination::get("fegen_builtin", FEGEN_CHAR, + {FEGEN_CHAR, fegen::TypeDefination::get("fegen_builtin", FEGEN_CHAR, {}, nullptr, false)}); // String this->typeDefMap.insert( - {FEGEN_STRING, fegen::FegenTypeDefination::get( + {FEGEN_STRING, fegen::TypeDefination::get( "fegen_builtin", FEGEN_STRING, {}, nullptr, false)}); // Vector this->typeDefMap.insert( {FEGEN_VECTOR, - fegen::FegenTypeDefination::get( + fegen::TypeDefination::get( "fegen_builtin", FEGEN_VECTOR, - {fegen::FegenValue::get(fegen::FegenType::getInt32Type(), "size", - fegen::FegenRightValue::getPlaceHolder()), - fegen::FegenValue::get(fegen::FegenType::getMetaType(), + {fegen::Value::get(fegen::Type::getInt32Type(), "size", + fegen::RightValue::getPlaceHolder()), + fegen::Value::get(fegen::Type::getMetaType(), "elementType", - fegen::FegenRightValue::getPlaceHolder())}, + fegen::RightValue::getPlaceHolder())}, nullptr, false)}); // List (this should be ahead of Tensor and Any Type defination) this->typeDefMap.insert( - {FEGEN_LIST, fegen::FegenTypeDefination::get( + {FEGEN_LIST, fegen::TypeDefination::get( "fegen_builtin", FEGEN_LIST, - {fegen::FegenValue::get( - fegen::FegenType::getMetaType(), "elementType", - fegen::FegenRightValue::getPlaceHolder())}, + {fegen::Value::get( + fegen::Type::getMetaType(), "elementType", + fegen::RightValue::getPlaceHolder())}, nullptr, false)}); // Tensor this->typeDefMap.insert( {FEGEN_TENSOR, - fegen::FegenTypeDefination::get( + fegen::TypeDefination::get( "fegen_builtin", FEGEN_TENSOR, - {fegen::FegenValue::get( - fegen::FegenType::getListType(fegen::FegenType::getInt32Type()), - "shape", fegen::FegenRightValue::getPlaceHolder()), - fegen::FegenValue::get(fegen::FegenType::getMetaType(), + {fegen::Value::get( + fegen::Type::getListType(fegen::Type::getInt32Type()), + "shape", fegen::RightValue::getPlaceHolder()), + fegen::Value::get(fegen::Type::getMetaType(), "elementType", - fegen::FegenRightValue::getPlaceHolder())}, + fegen::RightValue::getPlaceHolder())}, nullptr, false)}); // Optional this->typeDefMap.insert( - {FEGEN_OPTINAL, fegen::FegenTypeDefination::get( + {FEGEN_OPTINAL, fegen::TypeDefination::get( "fegen_builtin", FEGEN_OPTINAL, - {fegen::FegenValue::get( - fegen::FegenType::getMetaType(), "elementType", - fegen::FegenRightValue::getPlaceHolder())}, + {fegen::Value::get( + fegen::Type::getMetaType(), "elementType", + fegen::RightValue::getPlaceHolder())}, nullptr, false)}); // Any this->typeDefMap.insert( {FEGEN_ANY, - fegen::FegenTypeDefination::get( + fegen::TypeDefination::get( "fegen_builtin", FEGEN_ANY, - {fegen::FegenValue::get( - fegen::FegenType::getListType(fegen::FegenType::getMetaType()), - "elementType", fegen::FegenRightValue::getPlaceHolder())}, + {fegen::Value::get( + fegen::Type::getListType(fegen::Type::getMetaType()), + "elementType", fegen::RightValue::getPlaceHolder())}, nullptr, false)}); } -fegen::FegenTypeDefination * -fegen::FegenManager::getTypeDefination(std::string name) { +fegen::TypeDefination * +fegen::Manager::getTypeDefination(std::string name) { return this->typeDefMap[name]; } -bool fegen::FegenManager::addTypeDefination(fegen::FegenTypeDefination *tyDef) { +bool fegen::Manager::addTypeDefination(fegen::TypeDefination *tyDef) { if (this->typeDefMap.count(tyDef->name) != 0) { return false; } @@ -1693,12 +1697,12 @@ bool fegen::FegenManager::addTypeDefination(fegen::FegenTypeDefination *tyDef) { return true; } -fegen::FegenOperation * -fegen::FegenManager::getOperationDefination(std::string name) { +fegen::Operation * +fegen::Manager::getOperationDefination(std::string name) { return this->operationMap[name]; } -bool fegen::FegenManager::addOperationDefination(fegen::FegenOperation *opDef) { +bool fegen::Manager::addOperationDefination(fegen::Operation *opDef) { if (this->operationMap.count(opDef->getOpName()) != 0) { return false; } @@ -1706,28 +1710,28 @@ bool fegen::FegenManager::addOperationDefination(fegen::FegenOperation *opDef) { return true; } -void fegen::FegenManager::addStmtContent(antlr4::ParserRuleContext *ctx, +void fegen::Manager::addStmtContent(antlr4::ParserRuleContext *ctx, std::any content) { this->stmtContentMap.insert({ctx, content}); } -fegen::FegenManager &fegen::FegenManager::getManager() { - static fegen::FegenManager fmg; +fegen::Manager &fegen::Manager::getManager() { + static fegen::Manager fmg; return fmg; } -fegen::FegenManager::~FegenManager() { +fegen::Manager::~Manager() { // release nodes for (auto node_pair : this->nodeMap) { delete node_pair.second; } } -fegen::FegenType fegen::inferenceType( - std::vector> operands, +fegen::Type fegen::inferenceType( + std::vector> operands, fegen::FegenOperator op) { // TODO: infer type - return fegen::FegenType::getInt32Type(); + return fegen::Type::getInt32Type(); } namespace fegen { @@ -1737,7 +1741,7 @@ namespace fegen { // }; } -void fegen::FegenManager::emitBuiltinFunction() { +void fegen::Manager::emitBuiltinFunction() { Emitter emitter(std::cout); for (auto function_pair : this->functionMap) { auto functionName = function_pair.first; diff --git a/frontend/FrontendGen/lib/FegenVisitor.cpp b/frontend/FrontendGen/lib/FegenVisitor.cpp index 316b0ae1ad..5dc096eb8a 100644 --- a/frontend/FrontendGen/lib/FegenVisitor.cpp +++ b/frontend/FrontendGen/lib/FegenVisitor.cpp @@ -1,12 +1,12 @@ #include "FegenVisitor.h" -bool fegen::checkParams(std::vector &expected, - std::vector &actual) { +bool fegen::checkParams(std::vector &expected, + std::vector &actual) { return true; } bool fegen::checkListLiteral( - std::vector> + std::vector> &listLiteral) { return true; } \ No newline at end of file diff --git a/frontend/FrontendGen/lib/Scope.cpp b/frontend/FrontendGen/lib/Scope.cpp index ef0a1dbe72..443536f11e 100644 --- a/frontend/FrontendGen/lib/Scope.cpp +++ b/frontend/FrontendGen/lib/Scope.cpp @@ -24,11 +24,11 @@ fegen::FegenScope::FegenScope(unsigned int scopeId, fegen::FegenScope *parentScope) : scopeId(scopeId), parentScope(parentScope) {} -fegen::FegenTypeDefination *fegen::FegenScope::findTypeDef(std::string name) { +fegen::TypeDefination *fegen::FegenScope::findTypeDef(std::string name) { return this->typeTable.get(name); } -void fegen::FegenScope::addTypeDef(FegenTypeDefination *tyDef) { +void fegen::FegenScope::addTypeDef(TypeDefination *tyDef) { this->typeTable.add(tyDef->getName(), tyDef); } @@ -36,11 +36,11 @@ bool fegen::FegenScope::isExistTypeDef(std::string name) { return this->typeTable.exist(name); } -fegen::FegenValue *fegen::FegenScope::findVar(std::string name) { +fegen::Value *fegen::FegenScope::findVar(std::string name) { return this->varTable.get(name); } -void fegen::FegenScope::addVar(fegen::FegenValue *var) { +void fegen::FegenScope::addVar(fegen::Value *var) { this->varTable.add(var->getName(), var); } @@ -77,7 +77,7 @@ void fegen::ScopeStack::popScope() { this->scopeStack.pop(); this->currentScope = this->scopeStack.top(); } -bool fegen::ScopeStack::attemptAddVar(fegen::FegenValue *var) { +bool fegen::ScopeStack::attemptAddVar(fegen::Value *var) { if (this->currentScope->isExistVar(var->getName())) { return false; } @@ -85,7 +85,7 @@ bool fegen::ScopeStack::attemptAddVar(fegen::FegenValue *var) { return true; } -fegen::FegenValue *fegen::ScopeStack::attemptFindVar(std::string name) { +fegen::Value *fegen::ScopeStack::attemptFindVar(std::string name) { auto p = this->currentScope; while (p != nullptr) { if (p->isExistVar(name)) { @@ -96,7 +96,7 @@ fegen::FegenValue *fegen::ScopeStack::attemptFindVar(std::string name) { return nullptr; } -bool fegen::ScopeStack::attemptAddTypeDef(fegen::FegenTypeDefination *tyDef) { +bool fegen::ScopeStack::attemptAddTypeDef(fegen::TypeDefination *tyDef) { if (this->currentScope->isExistTypeDef(tyDef->getName())) { return false; } @@ -104,7 +104,7 @@ bool fegen::ScopeStack::attemptAddTypeDef(fegen::FegenTypeDefination *tyDef) { return true; } -fegen::FegenTypeDefination * +fegen::TypeDefination * fegen::ScopeStack::attemptFindTypeDef(std::string name) { auto p = this->currentScope; while (p != nullptr) { From 0968b60d7934a57ed963eeeab7b63d48b3f8deb7 Mon Sep 17 00:00:00 2001 From: FloatingcloudKnight <1348185166@qq.com> Date: Mon, 15 Jul 2024 13:46:04 +0000 Subject: [PATCH 06/17] [FrontendGen] add function codegen --- examples/FrontendGen/function.fegen | 43 +++------ frontend/FrontendGen/include/FegenVisitor.h | 41 +++++--- frontend/FrontendGen/lib/FegenManager.cpp | 102 ++++++++++++++++++-- frontend/FrontendGen/lib/FegenParser.g4 | 10 +- 4 files changed, 141 insertions(+), 55 deletions(-) diff --git a/examples/FrontendGen/function.fegen b/examples/FrontendGen/function.fegen index e2fc6e7bfa..87e33db88b 100644 --- a/examples/FrontendGen/function.fegen +++ b/examples/FrontendGen/function.fegen @@ -1,38 +1,17 @@ fegen toy double stod(string numStr){ - double res = 0; - int index; - int i; - for(i = 0; i <= len(numStr)-1; i=i+1){ - char c = numStr[0]; - int charNum; - if(c == '0'){ - charNum = 0; - }else if (c == '1'){ - charNum = 1; - }else if (c == '2'){ - charNum = 2; - }else if (c == '3'){ - charNum = 3; - }else if (c == '4'){ - charNum = 4; - }else if (c == '5'){ - charNum = 5; - }else if (c == '6'){ - charNum = 6; - }else if (c == '7'){ - charNum = 7; - }else if (c == '8'){ - charNum = 8; - }else if (c == '9'){ - charNum = 9; - }else if (c == '.'){ - index = i; - } - res = res * 10; - res = res + charNum; + double res = 0.0; + int index = 1; + if(c == '0'){ + int charNum = 0; + int intNum = 1; + intNum = 1; + }else if (c == '1'){ + int charNum = 1; + }else { + int charNum = 2; } - res = res * 0.1**(len(numStr) - 1 - index); + return res; } \ No newline at end of file diff --git a/frontend/FrontendGen/include/FegenVisitor.h b/frontend/FrontendGen/include/FegenVisitor.h index 454fc5ba0e..79e7243982 100644 --- a/frontend/FrontendGen/include/FegenVisitor.h +++ b/frontend/FrontendGen/include/FegenVisitor.h @@ -655,8 +655,8 @@ class FegenVisitor : public FegenParserBaseVisitor { this->visit(ctx->expression())); if (!fegen::FegenType::isSameType(&varType, &varContent->exprType)) { std::cerr << "The variabel \"" << varName << "\" need \"" - << varType.getTypeName() << " \" type rightvalue. But now is " << varContent->exprType.getTypeName() - << std::endl; + << varType.getTypeName() << " \" type rightvalue. But now is " + << varContent->exprType.getTypeName() << std::endl; exit(0); return nullptr; } @@ -694,11 +694,11 @@ class FegenVisitor : public FegenParserBaseVisitor { auto functionName = std::any_cast(this->visit(ctx->funcName())); auto hasFunc = manager.functionMap.find(functionName); - if(hasFunc == manager.functionMap.end()){ - std::cerr << "The called function \"" << functionName - << "\" is not exist." << std::endl; - exit(0); - return nullptr; + if (hasFunc == manager.functionMap.end()) { + std::cerr << "The called function \"" << functionName + << "\" is not exist." << std::endl; + exit(0); + return nullptr; } function = hasFunc->second; auto paramsNum = ctx->expression().size(); @@ -739,20 +739,31 @@ class FegenVisitor : public FegenParserBaseVisitor { } std::any visitIfStmt(FegenParser::IfStmtContext *ctx) override { + for (size_t i = 0; i < ctx->ifBlock().size(); i++) { + this->visit(ctx->ifBlock(i)); + } + + if (ctx->elseBlock()) { + this->visit(ctx->elseBlock()); + } + return nullptr; + } + + std::any visitIfBlock(FegenParser::IfBlockContext *ctx) override { sstack.pushScope(); - this->visit(ctx->expression(0)); - this->visit(ctx->statementBlock(0)); - for (size_t i = 1; i <= ctx->expression().size() - 1; i++) { - this->visit(ctx->expression(i)); - this->visit(ctx->statementBlock(i)); - } - if (ctx->statementBlock(ctx->expression().size() + 1)) - this->visit(ctx->statementBlock(ctx->expression().size() + 1)); + this->visit(ctx->expression()); + this->visit(ctx->statementBlock()); sstack.popScope(); return nullptr; } + std::any visitElseBlock(FegenParser::ElseBlockContext *ctx) override { + sstack.pushScope(); + this->visit(ctx->statementBlock()); + sstack.popScope(); + } + std::any visitForStmt(FegenParser::ForStmtContext *ctx) override { sstack.pushScope(); if (ctx->varDeclStmt()) { diff --git a/frontend/FrontendGen/lib/FegenManager.cpp b/frontend/FrontendGen/lib/FegenManager.cpp index 29bc9bf993..2cba2127d8 100644 --- a/frontend/FrontendGen/lib/FegenManager.cpp +++ b/frontend/FrontendGen/lib/FegenManager.cpp @@ -778,8 +778,7 @@ fegen::FegenRightValue::ExpressionTerminal * fegen::FegenRightValue::ExpressionTerminal::get(fegen::FegenValue *content) { return new fegen::FegenRightValue::ExpressionTerminal( content, fegen::FegenRightValue::LiteralKind::LEFT_VAR, - content->getType(), - content->getExpr()->isConstexpr()); + content->getType(), content->getExpr()->isConstexpr()); } // class FegenRightValue @@ -1451,13 +1450,68 @@ fegen::inferenceType(std::vector operands, } namespace fegen { -// class StmtVisitor : public FegenParserBaseVisitor{ -// public: -// }; +class StmtVisitor : public FegenParserBaseVisitor { +private: + FegenManager &manager; -} +public: + StmtVisitor() : manager(FegenManager::getManager()) {} + std::any visitVarDeclStmt(FegenParser::VarDeclStmtContext *ctx) override { + Emitter emitter(std::cout); + auto varDecl = + std::any_cast(manager.stmtContentMap[ctx]); + emitter << varDecl->getType().getTypeName() << " " << varDecl->getName() + << " = " << varDecl->getContentString() << ";"; + emitter.newLine(); + return nullptr; + } + std::any visitAssignStmt(FegenParser::AssignStmtContext *ctx) override { + Emitter emitter(std::cout); + auto assignStmt = + std::any_cast(manager.stmtContentMap[ctx]); + emitter << assignStmt->getName() << " = " << assignStmt->getContentString() + << ";"; + emitter.newLine(); + return nullptr; + } + std::any visitFunctionCall(FegenParser::FunctionCallContext *ctx) override { + Emitter emitter(std::cout); + auto function = + std::any_cast(manager.stmtContentMap[ctx]); + emitter << function->getName() << " ("; + for (auto para : function->getInputTypeList()) { + emitter << para->getName(); + if (para != function->getInputTypeList().back()) + emitter << ", "; + } + // TODO:补充functioncall作为操作数的情况 + emitter << ");"; + emitter.newLine(); + return nullptr; + } + std::any visitIfBlock(FegenParser::IfBlockContext *ctx) override { + Emitter emitter(std::cout); + auto expr = std::any_cast( + manager.stmtContentMap[ctx]); + + emitter << "if (" << expr->toString() << "){"; + emitter.newLine(); + emitter.tab(); + return nullptr; + } + // TODO: 支持for循环 + std::any visitForStmt(FegenParser::ForStmtContext *ctx) override { + Emitter emitter(std::cout); + emitter << "for ("; + return nullptr; + } +}; + +} // namespace fegen void fegen::FegenManager::emitBuiltinFunction() { Emitter emitter(std::cout); + fegen::StmtVisitor visitor; + for (auto function_pair : this->functionMap) { auto functionName = function_pair.first; auto function = function_pair.second; @@ -1473,7 +1527,41 @@ void fegen::FegenManager::emitBuiltinFunction() { emitter.newLine(); emitter.tab(); // TODO::function body - + auto blockNum = 0; + auto expressionNum = 1; + FegenParser::IfBlockContext *ifBlock = nullptr; + for (auto stmt : stmtContentMap) { + visitor.visit(stmt.first); + if (stmt.second.type().name() == "string") { + if (std::any_cast(stmt.second) == "IF") { + ifBlock = std::any_cast(stmt.first); + // blockNum = ifBlock->statement().size(); + continue; + } else if (std::any_cast(stmt.second) == "FOR") { + // TODO: 支持for循环 + continue; + } + } + if (blockNum > 0) + blockNum--; + if (blockNum > 1) { + emitter.shiftTab(); + // emitter << "} else if (" << + // ifStmt->expression(expressionNum)->toString() << "){"; + emitter.newLine(); + emitter.tab(); + expressionNum++; + } else if (blockNum == 1) { + emitter.shiftTab(); + emitter << "} else {"; + emitter.newLine(); + emitter.tab(); + } else if (blockNum == 0) { + emitter.shiftTab(); + emitter << "}"; + expressionNum = 1; + } + } emitter.shiftTab(); emitter.newLine(); emitter << "}"; diff --git a/frontend/FrontendGen/lib/FegenParser.g4 b/frontend/FrontendGen/lib/FegenParser.g4 index 048d104a0e..5e1cc4a54f 100644 --- a/frontend/FrontendGen/lib/FegenParser.g4 +++ b/frontend/FrontendGen/lib/FegenParser.g4 @@ -289,7 +289,15 @@ opResTypeParams ; ifStmt - : IF LeftParen expression RightParen statementBlock (ELSE IF LeftParen expression RightParen statementBlock)* (ELSE statementBlock)? + : ifBlock (ELSE ifBlock)* (elseBlock)? + ; + +ifBlock: + IF LeftParen expression RightParen statementBlock + ; + +elseBlock + : ELSE statementBlock ; forStmt From 0e0a16da7cc515189303a65438736016a636c224 Mon Sep 17 00:00:00 2001 From: FloatingcloudKnight <1348185166@qq.com> Date: Mon, 15 Jul 2024 13:57:20 +0000 Subject: [PATCH 07/17] [FrontendGen] Merge branch 'main' of https://github.com/CBalaa/buddy-mlir into fegen-cww --- frontend/FrontendGen/include/FegenVisitor.h | 2 +- frontend/FrontendGen/lib/FegenManager.cpp | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/frontend/FrontendGen/include/FegenVisitor.h b/frontend/FrontendGen/include/FegenVisitor.h index 197d28f873..04df570393 100644 --- a/frontend/FrontendGen/include/FegenVisitor.h +++ b/frontend/FrontendGen/include/FegenVisitor.h @@ -645,7 +645,7 @@ class FegenVisitor : public FegenParserBaseVisitor { this->visit(ctx->funcParams())); this->visit(ctx->statementBlock()); - fegen::FegenFunction* function = fegen::FegenFunction::get(functionName, functionParams, &returnType); + fegen::Function* function = fegen::Function::get(functionName, functionParams, &returnType); manager.functionMap.insert(std::pair{functionName, function}); sstack.popScope(); return nullptr; diff --git a/frontend/FrontendGen/lib/FegenManager.cpp b/frontend/FrontendGen/lib/FegenManager.cpp index cd4a47d222..b1514a2795 100644 --- a/frontend/FrontendGen/lib/FegenManager.cpp +++ b/frontend/FrontendGen/lib/FegenManager.cpp @@ -1738,14 +1738,14 @@ namespace fegen { class StmtVisitor : public FegenParserBaseVisitor { private: - FegenManager &manager; + Manager &manager; public: - StmtVisitor() : manager(FegenManager::getManager()) {} + StmtVisitor() : manager(Manager::getManager()) {} std::any visitVarDeclStmt(FegenParser::VarDeclStmtContext *ctx) override { Emitter emitter(std::cout); auto varDecl = - std::any_cast(manager.stmtContentMap[ctx]); + std::any_cast(manager.stmtContentMap[ctx]); emitter << varDecl->getType().getTypeName() << " " << varDecl->getName() << " = " << varDecl->getContentString() << ";"; emitter.newLine(); @@ -1754,7 +1754,7 @@ class StmtVisitor : public FegenParserBaseVisitor { std::any visitAssignStmt(FegenParser::AssignStmtContext *ctx) override { Emitter emitter(std::cout); auto assignStmt = - std::any_cast(manager.stmtContentMap[ctx]); + std::any_cast(manager.stmtContentMap[ctx]); emitter << assignStmt->getName() << " = " << assignStmt->getContentString() << ";"; emitter.newLine(); @@ -1763,7 +1763,7 @@ class StmtVisitor : public FegenParserBaseVisitor { std::any visitFunctionCall(FegenParser::FunctionCallContext *ctx) override { Emitter emitter(std::cout); auto function = - std::any_cast(manager.stmtContentMap[ctx]); + std::any_cast(manager.stmtContentMap[ctx]); emitter << function->getName() << " ("; for (auto para : function->getInputTypeList()) { emitter << para->getName(); @@ -1777,7 +1777,7 @@ class StmtVisitor : public FegenParserBaseVisitor { } std::any visitIfBlock(FegenParser::IfBlockContext *ctx) override { Emitter emitter(std::cout); - auto expr = std::any_cast( + auto expr = std::any_cast( manager.stmtContentMap[ctx]); emitter << "if (" << expr->toString() << "){"; @@ -1794,7 +1794,7 @@ class StmtVisitor : public FegenParserBaseVisitor { }; } // namespace fegen -void fegen::FegenManager::emitBuiltinFunction() { +void fegen::Manager::emitBuiltinFunction() { Emitter emitter(std::cout); fegen::StmtVisitor visitor; From 5a8614702c1fba9ac7399e9d06003f96732c606e Mon Sep 17 00:00:00 2001 From: FloatingcloudKnight <1348185166@qq.com> Date: Wed, 17 Jul 2024 07:04:28 +0000 Subject: [PATCH 08/17] [FrontendGen] update function codegen --- examples/FrontendGen/function.fegen | 8 +- frontend/FrontendGen/frontendgen.cpp | 1 + frontend/FrontendGen/include/FegenManager.h | 2 +- frontend/FrontendGen/include/FegenVisitor.h | 222 +++++++++----------- frontend/FrontendGen/lib/FegenManager.cpp | 88 ++++---- frontend/FrontendGen/lib/FegenParser.g4 | 1 - 6 files changed, 144 insertions(+), 178 deletions(-) diff --git a/examples/FrontendGen/function.fegen b/examples/FrontendGen/function.fegen index 87e33db88b..35da97395d 100644 --- a/examples/FrontendGen/function.fegen +++ b/examples/FrontendGen/function.fegen @@ -1,13 +1,13 @@ fegen toy double stod(string numStr){ - double res = 0.0; - int index = 1; - if(c == '0'){ + float res = 0.0; + int c = 1; + if(c == 0){ int charNum = 0; int intNum = 1; intNum = 1; - }else if (c == '1'){ + }else if (c == 1){ int charNum = 1; }else { int charNum = 2; diff --git a/frontend/FrontendGen/frontendgen.cpp b/frontend/FrontendGen/frontendgen.cpp index 51b29d46f6..8490405f48 100644 --- a/frontend/FrontendGen/frontendgen.cpp +++ b/frontend/FrontendGen/frontendgen.cpp @@ -67,5 +67,6 @@ int main(int argc, char *argv[]) { visitor.emitTypeDefination(); visitor.emitDialectDefination(); visitor.emitOpDefination(); + visitor.emitBuiltinFunction(moduleAST); return 0; } diff --git a/frontend/FrontendGen/include/FegenManager.h b/frontend/FrontendGen/include/FegenManager.h index 7400e96cc8..c7796659d3 100644 --- a/frontend/FrontendGen/include/FegenManager.h +++ b/frontend/FrontendGen/include/FegenManager.h @@ -573,7 +573,7 @@ class Manager { void emitOpDefination(); void emitDialectDefination(); void emitTdFiles(); - void emitBuiltinFunction(); + void emitBuiltinFunction(fegen::FegenParser::FegenSpecContext *); }; Type diff --git a/frontend/FrontendGen/include/FegenVisitor.h b/frontend/FrontendGen/include/FegenVisitor.h index 04df570393..10b0d302eb 100644 --- a/frontend/FrontendGen/include/FegenVisitor.h +++ b/frontend/FrontendGen/include/FegenVisitor.h @@ -25,13 +25,11 @@ namespace fegen { /// @param expected expected params. /// @param actual actual params. /// @return true if correct. -bool checkParams(std::vector &expected, - std::vector &actual); +bool checkParams(std::vector &expected, std::vector &actual); /// @brief check if the type of elements in list are correct. bool checkListLiteral( - std::vector> - &listLiteral); + std::vector> &listLiteral); class FegenVisitor : public FegenParserBaseVisitor { private: @@ -43,10 +41,10 @@ class FegenVisitor : public FegenParserBaseVisitor { void emitTypeDefination() { this->manager.emitTypeDefination(); } void emitDialectDefination() { this->manager.emitDialectDefination(); } void emitOpDefination() { this->manager.emitOpDefination(); } + void emitBuiltinFunction(fegen::FegenParser::FegenSpecContext *moduleAST){this->manager.emitBuiltinFunction(moduleAST);} FegenVisitor() - : manager(Manager::getManager()), - sstack(ScopeStack::getScopeStack()) { + : manager(Manager::getManager()), sstack(ScopeStack::getScopeStack()) { this->manager.initbuiltinTypes(); } @@ -66,8 +64,8 @@ class FegenVisitor : public FegenParserBaseVisitor { // return FegenTypeDefination* std::any visitTypeDefinationBlock( FegenParser::TypeDefinationBlockContext *ctx) override { - auto params = std::any_cast>( - this->visit(ctx->parametersSpec())); + auto params = + std::any_cast>(this->visit(ctx->parametersSpec())); auto tyDef = TypeDefination::get(this->manager.moduleName, "", params, nullptr); return tyDef; @@ -104,9 +102,9 @@ class FegenVisitor : public FegenParserBaseVisitor { std::any visitActionAlt(FegenParser::ActionAltContext *ctx) override { auto rawRule = this->visit(ctx->alternative()); if (ctx->actionBlock()) { - auto blockValues = std::any_cast< - std::tuple, std::vector>>( - this->visit(ctx->actionBlock())); + auto blockValues = + std::any_cast, std::vector>>( + this->visit(ctx->actionBlock())); auto inputs = std::get<0>(blockValues); auto returns = std::get<1>(blockValues); auto rule = std::any_cast(rawRule); @@ -133,13 +131,13 @@ class FegenVisitor : public FegenParserBaseVisitor { std::vector inputs; std::vector returns; if (ctx->inputsSpec()) { - inputs = std::any_cast>( - this->visit(ctx->inputsSpec())); + inputs = + std::any_cast>(this->visit(ctx->inputsSpec())); } if (ctx->returnsSpec()) { - returns = std::any_cast>( - this->visit(ctx->returnsSpec())); + returns = + std::any_cast>(this->visit(ctx->returnsSpec())); } if (ctx->actionSpec()) { @@ -186,21 +184,19 @@ class FegenVisitor : public FegenParserBaseVisitor { for (size_t i = 0; i <= varCount - 1; i++) { auto ty = std::any_cast(this->visit(ctx->typeSpec(i))); auto varName = ctx->identifier(i)->getText(); - auto var = fegen::Value::get( - ty, varName, fegen::RightValue::getPlaceHolder()); + auto var = + fegen::Value::get(ty, varName, fegen::RightValue::getPlaceHolder()); valueList.push_back(var); } - return valueList; } // return fegen::FegenType std::any visitTypeInstanceSpec(FegenParser::TypeInstanceSpecContext *ctx) override { - auto valueKind = ctx->valueKind() - ? std::any_cast( - this->visit(ctx->valueKind())) - : fegen::Type::TypeKind::CPP; + auto valueKind = ctx->valueKind() ? std::any_cast( + this->visit(ctx->valueKind())) + : fegen::Type::TypeKind::CPP; auto typeInst = std::any_cast(this->visit(ctx->typeInstance())); typeInst.setTypeKind(valueKind); @@ -228,8 +224,7 @@ class FegenVisitor : public FegenParserBaseVisitor { // get parameters std::vector paramList; for (auto paramCtx : ctx->typeTemplateParam()) { - auto tepltParams = - std::any_cast(this->visit(paramCtx)); + auto tepltParams = std::any_cast(this->visit(paramCtx)); paramList.push_back(tepltParams); } @@ -248,8 +243,7 @@ class FegenVisitor : public FegenParserBaseVisitor { auto varName = ctx->identifier()->getText(); auto var = this->sstack.attemptFindVar(varName); if (var) { - if (var->getContentKind() == - fegen::RightValue::LiteralKind::TYPE) { + if (var->getContentKind() == fegen::RightValue::LiteralKind::TYPE) { return var->getContent(); } else { std::cerr << "variable " << varName @@ -271,16 +265,15 @@ class FegenVisitor : public FegenParserBaseVisitor { std::any visitTypeTemplateParam(FegenParser::TypeTemplateParamContext *ctx) override { if (ctx->builtinTypeInstances()) { - auto ty = std::any_cast( - this->visit(ctx->builtinTypeInstances())); + auto ty = + std::any_cast(this->visit(ctx->builtinTypeInstances())); return fegen::Value::get(ty, "param", - fegen::RightValue::getPlaceHolder()); + fegen::RightValue::getPlaceHolder()); } else { - auto expr = - std::any_cast>( - this->visit(ctx->expression())); + auto expr = std::any_cast>( + this->visit(ctx->expression())); return fegen::Value::get(expr->exprType, "expression_tmp", - fegen::RightValue::getByExpr(expr)); + fegen::RightValue::getByExpr(expr)); } } @@ -314,7 +307,10 @@ class FegenVisitor : public FegenParserBaseVisitor { } else { // type auto tyDef = this->sstack.attemptFindTypeDef( ctx->prefixedName()->identifier(0)->getText()); - return fegen::Type::getTemplateType(tyDef); + if (tyDef != nullptr) + return fegen::Type::getTemplateType(tyDef); + else + return fegen::Type::getPlaceHolder(); } } else if (ctx->builtinTypeTemplate()) { // builtinTypeTemplate return this->visit(ctx->builtinTypeTemplate()); @@ -346,8 +342,8 @@ class FegenVisitor : public FegenParserBaseVisitor { visitCollectTypeSpec(FegenParser::CollectTypeSpecContext *ctx) override { auto kind = fegen::Type::TypeKind::CPP; if (ctx->valueKind()) { - kind = std::any_cast( - this->visit(ctx->valueKind())); + kind = + std::any_cast(this->visit(ctx->valueKind())); } auto ty = std::any_cast(this->visit(ctx->collectType())); ty.setTypeKind(kind); @@ -356,9 +352,8 @@ class FegenVisitor : public FegenParserBaseVisitor { // return FegenType std::any visitCollectType(FegenParser::CollectTypeContext *ctx) override { - auto expr = - std::any_cast>( - this->visit(ctx->expression())); + auto expr = std::any_cast>( + this->visit(ctx->expression())); if (ctx->collectProtoType()->ANY()) { std::vector tys; // TODO: reprot error @@ -384,39 +379,35 @@ class FegenVisitor : public FegenParserBaseVisitor { // return std::shared_ptr std::any visitExpression(FegenParser::ExpressionContext *ctx) override { - auto expr = - std::any_cast>( - this->visit(ctx->andExpr(0))); + auto expr = std::any_cast>( + this->visit(ctx->andExpr(0))); for (size_t i = 1; i <= ctx->andExpr().size() - 1; i++) { - auto rhs = - std::any_cast>( - this->visit(ctx->andExpr(i))); - expr = RightValue::ExpressionNode::binaryOperation( - expr, rhs, FegenOperator::OR); + auto rhs = std::any_cast>( + this->visit(ctx->andExpr(i))); + expr = RightValue::ExpressionNode::binaryOperation(expr, rhs, + FegenOperator::OR); } + manager.addStmtContent(ctx, expr); return expr; } // return std::shared_ptr std::any visitAndExpr(FegenParser::AndExprContext *ctx) override { - auto expr = - std::any_cast>( - this->visit(ctx->equExpr(0))); + auto expr = std::any_cast>( + this->visit(ctx->equExpr(0))); for (size_t i = 1; i <= ctx->equExpr().size() - 1; i++) { - auto rhs = - std::any_cast>( - this->visit(ctx->equExpr(i))); - expr = RightValue::ExpressionNode::binaryOperation( - expr, rhs, FegenOperator::AND); + auto rhs = std::any_cast>( + this->visit(ctx->equExpr(i))); + expr = RightValue::ExpressionNode::binaryOperation(expr, rhs, + FegenOperator::AND); } return expr; } // return std::shared_ptr std::any visitEquExpr(FegenParser::EquExprContext *ctx) override { - auto expr = - std::any_cast>( - this->visit(ctx->compareExpr(0))); + auto expr = std::any_cast>( + this->visit(ctx->compareExpr(0))); for (size_t i = 1; i <= ctx->compareExpr().size() - 1; i++) { FegenOperator op; if (ctx->children[2 * i - 1]->getText() == "==") { @@ -424,9 +415,8 @@ class FegenVisitor : public FegenParserBaseVisitor { } else { op = FegenOperator::NOT_EQUAL; } - auto rhs = - std::any_cast>( - this->visit(ctx->compareExpr(i))); + auto rhs = std::any_cast>( + this->visit(ctx->compareExpr(i))); expr = RightValue::ExpressionNode::binaryOperation(expr, rhs, op); } return expr; @@ -434,9 +424,8 @@ class FegenVisitor : public FegenParserBaseVisitor { // return std::shared_ptr std::any visitCompareExpr(FegenParser::CompareExprContext *ctx) override { - auto expr = - std::any_cast>( - this->visit(ctx->addExpr(0))); + auto expr = std::any_cast>( + this->visit(ctx->addExpr(0))); for (size_t i = 1; i <= ctx->addExpr().size() - 1; i++) { FegenOperator op; auto opStr = ctx->children[2 * i - 1]->getText(); @@ -451,9 +440,8 @@ class FegenVisitor : public FegenParserBaseVisitor { } else { op = FegenOperator::GREATER_EQUAL; } - auto rhs = - std::any_cast>( - this->visit(ctx->addExpr(i))); + auto rhs = std::any_cast>( + this->visit(ctx->addExpr(i))); expr = RightValue::ExpressionNode::binaryOperation(expr, rhs, op); } return expr; @@ -461,9 +449,8 @@ class FegenVisitor : public FegenParserBaseVisitor { // return std::shared_ptr std::any visitAddExpr(FegenParser::AddExprContext *ctx) override { - auto expr = - std::any_cast>( - this->visit(ctx->term(0))); + auto expr = std::any_cast>( + this->visit(ctx->term(0))); for (size_t i = 1; i <= ctx->term().size() - 1; i++) { FegenOperator op; auto opStr = ctx->children[2 * i - 1]->getText(); @@ -472,9 +459,8 @@ class FegenVisitor : public FegenParserBaseVisitor { } else { op = FegenOperator::SUB; } - auto rhs = - std::any_cast>( - this->visit(ctx->term(i))); + auto rhs = std::any_cast>( + this->visit(ctx->term(i))); expr = RightValue::ExpressionNode::binaryOperation(expr, rhs, op); } return expr; @@ -482,9 +468,8 @@ class FegenVisitor : public FegenParserBaseVisitor { // return std::shared_ptr std::any visitTerm(FegenParser::TermContext *ctx) override { - auto expr = - std::any_cast>( - this->visit(ctx->powerExpr(0))); + auto expr = std::any_cast>( + this->visit(ctx->powerExpr(0))); for (size_t i = 1; i <= ctx->powerExpr().size() - 1; i++) { FegenOperator op; auto opStr = ctx->children[2 * i - 1]->getText(); @@ -495,9 +480,8 @@ class FegenVisitor : public FegenParserBaseVisitor { } else { op = FegenOperator::MOD; } - auto rhs = - std::any_cast>( - this->visit(ctx->powerExpr(i))); + auto rhs = std::any_cast>( + this->visit(ctx->powerExpr(i))); expr = RightValue::ExpressionNode::binaryOperation(expr, rhs, op); } return expr; @@ -505,15 +489,13 @@ class FegenVisitor : public FegenParserBaseVisitor { // return std::shared_ptr std::any visitPowerExpr(FegenParser::PowerExprContext *ctx) override { - auto expr = - std::any_cast>( - this->visit(ctx->unaryExpr(0))); + auto expr = std::any_cast>( + this->visit(ctx->unaryExpr(0))); for (size_t i = 1; i <= ctx->unaryExpr().size() - 1; i++) { - auto rhs = - std::any_cast>( - this->visit(ctx->unaryExpr(i))); - expr = RightValue::ExpressionNode::binaryOperation( - expr, rhs, FegenOperator::POWER); + auto rhs = std::any_cast>( + this->visit(ctx->unaryExpr(i))); + expr = RightValue::ExpressionNode::binaryOperation(expr, rhs, + FegenOperator::POWER); } return expr; } @@ -523,9 +505,8 @@ class FegenVisitor : public FegenParserBaseVisitor { if (ctx->children.size() == 1 || ctx->Plus()) { return this->visit(ctx->primaryExpr()); } - auto expr = - std::any_cast>( - this->visit(ctx->primaryExpr())); + auto expr = std::any_cast>( + this->visit(ctx->primaryExpr())); FegenOperator op; if (ctx->Minus()) { op = FegenOperator::NEG; @@ -614,9 +595,8 @@ class FegenVisitor : public FegenParserBaseVisitor { std::any visitListLiteral(FegenParser::ListLiteralContext *ctx) override { std::vector> elements; for (auto exprCtx : ctx->expression()) { - auto expr = - std::any_cast>( - this->visit(exprCtx)); + auto expr = std::any_cast>( + this->visit(exprCtx)); elements.push_back(expr); } return (std::shared_ptr) @@ -629,8 +609,7 @@ class FegenVisitor : public FegenParserBaseVisitor { std::any visitFunctionDecl(FegenParser::FunctionDeclContext *ctx) override { sstack.pushScope(); - auto returnType = - std::any_cast(this->visit(ctx->typeSpec())); + auto returnType = std::any_cast(this->visit(ctx->typeSpec())); auto functionName = std::any_cast(this->visit(ctx->funcName())); auto hasfunc = manager.functionMap.find(functionName); @@ -644,12 +623,11 @@ class FegenVisitor : public FegenParserBaseVisitor { auto functionParams = std::any_cast>( this->visit(ctx->funcParams())); this->visit(ctx->statementBlock()); - - fegen::Function* function = fegen::Function::get(functionName, functionParams, &returnType); - manager.functionMap.insert(std::pair{functionName, function}); - sstack.popScope(); - return nullptr; - } + auto function = fegen::Function::get(functionName, functionParams, &returnType); + manager.functionMap.insert(std::pair{functionName, function}); + sstack.popScope(); + return nullptr; + } std::any visitFuncName(FegenParser::FuncNameContext *ctx) override { auto functionName = ctx->identifier()->getText(); @@ -663,8 +641,8 @@ class FegenVisitor : public FegenParserBaseVisitor { auto paramType = std::any_cast(this->visit(ctx->typeSpec(i))); auto paramName = ctx->identifier(i)->getText(); - auto param = fegen::Value::get( - paramType, paramName, fegen::RightValue::getPlaceHolder()); + auto param = fegen::Value::get(paramType, paramName, + fegen::RightValue::getPlaceHolder()); paramsList.push_back(param); sstack.attemptAddVar(param); } @@ -672,8 +650,7 @@ class FegenVisitor : public FegenParserBaseVisitor { } std::any visitVarDeclStmt(FegenParser::VarDeclStmtContext *ctx) override { - auto varType = - std::any_cast(this->visit(ctx->typeSpec())); + auto varType = std::any_cast(this->visit(ctx->typeSpec())); auto varName = ctx->identifier()->getText(); fegen::Value *var; if (ctx->expression()) { @@ -681,19 +658,22 @@ class FegenVisitor : public FegenParserBaseVisitor { std::any_cast>( this->visit(ctx->expression())); // TODO: check error - // if(!fegen::FegenType::isSameType(&varType, &varcontent->exprType)){ - // std::cerr << "The variabel \" " << varName - // << "\" need \"" << varType.getTypeName() << " \" type rightvalue." - // << std::endl; exit(0); return nullptr; - // } - var = fegen::Value::get( - varType, varName, fegen::RightValue::getByExpr(varcontent)); + if (!fegen::Type::isSameType(&varType, &varcontent->exprType)) { + std::cerr << "The variabel \" " << varName << "\" need \"" + << varType.getTypeName() + << " \" type rightvalue. Now the expression is " + << varcontent->exprType.getTypeName() << "." << std::endl; + exit(0); + return nullptr; + } + var = fegen::Value::get(varType, varName, + fegen::RightValue::getByExpr(varcontent)); } else { var = fegen::Value::get(varType, varName, - fegen::RightValue::getPlaceHolder()); + fegen::RightValue::getPlaceHolder()); } sstack.attemptAddVar(var); - manager.stmtContentMap.insert(std::pair{ctx, var}); + manager.addStmtContent(ctx, varType); return var; } @@ -710,15 +690,12 @@ class FegenVisitor : public FegenParserBaseVisitor { exit(0); return nullptr; } - fegen::Value *stmt = fegen::Value::get( - var->getType(), varName, fegen::RightValue::getByExpr(varcontent)); - manager.stmtContentMap.insert(std::pair{ctx, stmt}); - return stmt; + return var; } std::any visitFunctionCall(FegenParser::FunctionCallContext *ctx) override { - std::vector parasList = {}; + std::vector> parasList = {}; auto functionName = std::any_cast(this->visit(ctx->funcName())); auto hasFunc = manager.functionMap.at(functionName); @@ -726,7 +703,7 @@ class FegenVisitor : public FegenParserBaseVisitor { auto paraList = hasFunc->getInputTypeList(); if (paramsNum > 0) { for (size_t i = 0; i < paramsNum; i++) { - auto oprand = std::any_cast( + auto oprand = std::any_cast>( this->visit(ctx->expression(i))); parasList.push_back(oprand); } @@ -740,7 +717,7 @@ class FegenVisitor : public FegenParserBaseVisitor { } for (size_t i = 0; i < len1; i++) { if (!fegen::Type::isSameType(¶List[i]->getType(), - ¶sList[i]->exprType)) { + ¶sList[i]->exprType)) { std::cerr << "The function \" " << functionName << "\" parameter" << i << " type mismatch." << std::endl; exit(0); @@ -804,8 +781,7 @@ class FegenVisitor : public FegenParserBaseVisitor { std::any visitOpDecl(FegenParser::OpDeclContext *ctx) override { auto opName = ctx->opName()->getText(); - auto opDef = - std::any_cast(this->visit(ctx->opBlock())); + auto opDef = std::any_cast(this->visit(ctx->opBlock())); opDef->setOpName(opName); bool success = this->manager.addOperationDefination(opDef); if (!success) { diff --git a/frontend/FrontendGen/lib/FegenManager.cpp b/frontend/FrontendGen/lib/FegenManager.cpp index b1514a2795..e513468303 100644 --- a/frontend/FrontendGen/lib/FegenManager.cpp +++ b/frontend/FrontendGen/lib/FegenManager.cpp @@ -349,7 +349,7 @@ fegen::Type fegen::Type::getBoolType() { } fegen::Type fegen::Type::getIntegerType(fegen::Value *size) { - if (size->getContent() == 32) + if (size->getContent() == 32) return fegen::Type::getInt32Type(); return fegen::Type( fegen::Type::TypeKind::CPP, {size}, @@ -357,9 +357,9 @@ fegen::Type fegen::Type::getIntegerType(fegen::Value *size) { } fegen::Type fegen::Type::getFloatPointType(fegen::Value *size) { - if (size->getContent() == 32) { + if (size->getContent() == 32) { return fegen::Type::getFloatType(); - } else if (size->getContent() == 64) { + } else if (size->getContent() == 64) { return fegen::Type::getDoubleType(); } return fegen::Type( @@ -1744,19 +1744,19 @@ class StmtVisitor : public FegenParserBaseVisitor { StmtVisitor() : manager(Manager::getManager()) {} std::any visitVarDeclStmt(FegenParser::VarDeclStmtContext *ctx) override { Emitter emitter(std::cout); - auto varDecl = + auto varType = std::any_cast(manager.stmtContentMap[ctx]); - emitter << varDecl->getType().getTypeName() << " " << varDecl->getName() - << " = " << varDecl->getContentString() << ";"; + auto varName = ctx->identifier()->toString(); + auto expr = std::any_cast>(manager.stmtContentMap[ctx->expression()]); + emitter << varType->getName() << " " << varName << " = " << expr->toString() << ";"; emitter.newLine(); return nullptr; } std::any visitAssignStmt(FegenParser::AssignStmtContext *ctx) override { Emitter emitter(std::cout); - auto assignStmt = - std::any_cast(manager.stmtContentMap[ctx]); - emitter << assignStmt->getName() << " = " << assignStmt->getContentString() - << ";"; + auto varName = ctx->identifier()->toString(); + auto expr = std::any_cast>(manager.stmtContentMap[ctx->expression()]); + emitter << varName << " = " << expr->toString() << ";"; emitter.newLine(); return nullptr; } @@ -1775,16 +1775,40 @@ class StmtVisitor : public FegenParserBaseVisitor { emitter.newLine(); return nullptr; } + std::any visitIfStmt(FegenParser::IfStmtContext *ctx) override { + Emitter emitter(std::cout); + this->visit(ctx->ifBlock(0)); + for(size_t i = 1; i < ctx->ifBlock().size(); i++){ + emitter << " else "; + this->visit(ctx->ifBlock(i)); + } + if(ctx->elseBlock()) this->visit(ctx->elseBlock()); + return nullptr; + } std::any visitIfBlock(FegenParser::IfBlockContext *ctx) override { Emitter emitter(std::cout); - auto expr = std::any_cast( - manager.stmtContentMap[ctx]); + auto expr = std::any_cast>( + manager.stmtContentMap[ctx->expression()]); emitter << "if (" << expr->toString() << "){"; emitter.newLine(); emitter.tab(); + this->visit(ctx->statementBlock()); + emitter.shiftTab(); + emitter << "}"; return nullptr; } + std::any visitElseBlock(FegenParser::ElseBlockContext *ctx) override { + Emitter emitter(std::cout); + emitter << "else {"; + emitter.newLine(); + emitter.tab(); + this->visit(ctx->statementBlock()); + emitter.newLine(); + emitter.shiftTab(); + emitter << "}"; + emitter.newLine(); + } // TODO: 支持for循环 std::any visitForStmt(FegenParser::ForStmtContext *ctx) override { Emitter emitter(std::cout); @@ -1794,7 +1818,7 @@ class StmtVisitor : public FegenParserBaseVisitor { }; } // namespace fegen -void fegen::Manager::emitBuiltinFunction() { +void fegen::Manager::emitBuiltinFunction(fegen::FegenParser::FegenSpecContext *moduleAST) { Emitter emitter(std::cout); fegen::StmtVisitor visitor; @@ -1802,7 +1826,7 @@ void fegen::Manager::emitBuiltinFunction() { auto functionName = function_pair.first; auto function = function_pair.second; auto paraList = function->getInputTypeList(); - emitter << function->getReturnType()->toStringForTypedef() << " " + emitter << function->getReturnType()->getTypeName() << " " << functionName << "("; for (auto para : paraList) { emitter << para->getContentStringForTypedef() << " " << para->getName(); @@ -1813,41 +1837,7 @@ void fegen::Manager::emitBuiltinFunction() { emitter.newLine(); emitter.tab(); // TODO::function body - auto blockNum = 0; - auto expressionNum = 1; - FegenParser::IfBlockContext *ifBlock = nullptr; - for (auto stmt : stmtContentMap) { - visitor.visit(stmt.first); - if (stmt.second.type().name() == "string") { - if (std::any_cast(stmt.second) == "IF") { - ifBlock = std::any_cast(stmt.first); - // blockNum = ifBlock->statement().size(); - continue; - } else if (std::any_cast(stmt.second) == "FOR") { - // TODO: 支持for循环 - continue; - } - } - if (blockNum > 0) - blockNum--; - if (blockNum > 1) { - emitter.shiftTab(); - // emitter << "} else if (" << - // ifStmt->expression(expressionNum)->toString() << "){"; - emitter.newLine(); - emitter.tab(); - expressionNum++; - } else if (blockNum == 1) { - emitter.shiftTab(); - emitter << "} else {"; - emitter.newLine(); - emitter.tab(); - } else if (blockNum == 0) { - emitter.shiftTab(); - emitter << "}"; - expressionNum = 1; - } - } + visitor.visit(moduleAST); emitter.shiftTab(); emitter.newLine(); emitter << "}"; diff --git a/frontend/FrontendGen/lib/FegenParser.g4 b/frontend/FrontendGen/lib/FegenParser.g4 index 5e1cc4a54f..8a6c671888 100644 --- a/frontend/FrontendGen/lib/FegenParser.g4 +++ b/frontend/FrontendGen/lib/FegenParser.g4 @@ -406,7 +406,6 @@ valueKind | ATTRIBUTE ; -// 这里的identifier是不是没用? typeInstance : typeTemplate (Less typeTemplateParam (Comma typeTemplateParam)* Greater)? | builtinTypeInstances From d390aad5a7da5106f2f27c5db68afec33e33fd1d Mon Sep 17 00:00:00 2001 From: chh Date: Wed, 17 Jul 2024 20:26:27 +0800 Subject: [PATCH 09/17] [FrontendGen] Refactor fegen Type. --- examples/FrontendGen/makefile | 2 + frontend/FrontendGen/include/FegenManager.h | 256 ++-- frontend/FrontendGen/include/FegenVisitor.h | 136 +- frontend/FrontendGen/include/Scope.h | 12 - frontend/FrontendGen/lib/FegenManager.cpp | 1238 +++++++++---------- frontend/FrontendGen/lib/FegenParser.g4 | 6 +- frontend/FrontendGen/lib/FegenVisitor.cpp | 2 +- frontend/FrontendGen/lib/Scope.cpp | 32 - 8 files changed, 853 insertions(+), 831 deletions(-) diff --git a/examples/FrontendGen/makefile b/examples/FrontendGen/makefile index 6b65351f0b..43c29b9586 100644 --- a/examples/FrontendGen/makefile +++ b/examples/FrontendGen/makefile @@ -10,3 +10,5 @@ typeDefine: rule: @${BUDDY_FRONTEND_GEN} -f ./rule.fegen +clean: + rm -f ./toy* \ No newline at end of file diff --git a/frontend/FrontendGen/include/FegenManager.h b/frontend/FrontendGen/include/FegenManager.h index 7400e96cc8..4a84bbe2b9 100644 --- a/frontend/FrontendGen/include/FegenManager.h +++ b/frontend/FrontendGen/include/FegenManager.h @@ -28,15 +28,19 @@ #define FEGEN_LIST "List" #define FEGEN_OPTINAL "Optional" #define FEGEN_ANY "Any" - +#define FEGEN_DIALECT_NAME "fegen_builtin" #define FEGEN_NOT_IMPLEMENTED_ERROR false -namespace fegen { +namespace fegen { class Type; class Manager; class Value; +class RightValue; +class Expression; +using TypePtr = std::shared_ptr; +using largestInt = long long int; // binary operation enum class FegenOperator { @@ -66,20 +70,20 @@ class Function { // input object std::vector inputTypeList; // return type - Type *returnType; + TypePtr returnType; explicit Function(std::string name, std::vector &&inputTypeList, - Type *returnType); + TypePtr returnType); public: static Function *get(std::string name, std::vector inputTypeList, - Type *returnType = nullptr); + TypePtr returnType = nullptr); ~Function() = default; std::string getName(); std::vector &getInputTypeList(); Value *getInputTypeList(size_t i); - Type *getReturnType(); + TypePtr getReturnType(); }; class Value; @@ -115,7 +119,7 @@ class Operation { }; class TypeDefination; - +class RightValue; class Type { friend class Value; @@ -125,89 +129,84 @@ class Type { private: TypeKind kind; std::string typeName; - std::vector parameters; + // std::vector parameters; TypeDefination *typeDefine; int typeLevel; + bool isConstType; public: - Type(TypeKind kind, std::string name, - std::vector parameters, TypeDefination *tyDef, - int typeLevel); - Type(TypeKind kind, std::vector parameters, - TypeDefination *tyDef, int typeLevel); - Type(const Type &); - Type(Type &&); + Type(TypeKind kind, std::string name, TypeDefination *tyDef, int typeLevel, bool isConstType); + + Type(const Type &) = default; + Type(Type &&) = default; TypeKind getTypeKind(); void setTypeKind(TypeKind kind); - std::vector &getParameters(); - Value *getParameters(size_t i); - void setParameters(std::vector ¶ms); TypeDefination *getTypeDefination(); void setTypeDefination(TypeDefination *tyDef); std::string getTypeName(); int getTypeLevel(); + bool isConstant(); // for generating typedef td file. - std::string toStringForTypedef(); + virtual std::string toStringForTypedef(); // for generating op def td file. - std::string toStringForOpdef(); + virtual std::string toStringForOpdef(); // for generating cpp type kind. - std::string toStringForCppKind(); + virtual std::string toStringForCppKind(); static bool isSameType(Type *type1, Type *type2); - ~Type(); + virtual ~Type() = default; + // placeholder - static Type getPlaceHolder(); + static TypePtr getPlaceHolder(); + // Type - static Type getMetaType(); + static TypePtr getMetaType(); // TypeTemplate - static Type getMetaTemplateType(); + static TypePtr getMetaTemplateType(); // int - static Type getInt32Type(); + static TypePtr getInt32Type(); // float - static Type getFloatType(); + static TypePtr getFloatType(); // float - static Type getDoubleType(); + static TypePtr getDoubleType(); // bool - static Type getBoolType(); + static TypePtr getBoolType(); // Integer - static Type getIntegerType(Value *size); + static TypePtr getIntegerType(RightValue size); // FloatPoint - static Type getFloatPointType(Value *size); - - // char - static Type getCharType(); + static TypePtr getFloatPointType(RightValue size); // string - static Type getStringType(); + static TypePtr getStringType(); + + // List + static TypePtr getListType(TypePtr elementType); + static TypePtr getListType(RightValue elementType); // Vector - static Type getVectorType(Value *size, Type elementType); + static TypePtr getVectorType(TypePtr elementType, RightValue size); + static TypePtr getVectorType(RightValue elementType, RightValue size); // Tensor - static Type getTensorType(Value *shape, Type elementType); - - // List - static Type getListType(Type elementType); + static TypePtr getTensorType(TypePtr elementType, RightValue shape); + static TypePtr getTensorType(RightValue elementType, RightValue shape); // Optional - static Type getOptionalType(Type elementType); + static TypePtr getOptionalType(TypePtr elementType); + static TypePtr getOptionalType(RightValue elementType); - // Any - static Type getAnyType(std::vector elementTypes); + // Any<[elementType1, elementType2, ...]> + static TypePtr getAnyType(RightValue elementTypes); - static Type getIntegerTemplate(); - static Type getFloatPointTemplate(); + static TypePtr getCustomeType(std::vector params, TypeDefination* tydef); - static Type getInstanceType(TypeDefination *typeDefination, - std::vector parameters); - - static Type getTemplateType(TypeDefination *typeDefination); + static TypePtr getTemplateType(TypeDefination *typeDefination); }; class TypeDefination { @@ -245,7 +244,7 @@ class TypeDefination { class RightValue { friend class Type; friend class Value; - + public: enum class LiteralKind { MONOSTATE, @@ -267,10 +266,9 @@ class RightValue { struct Expression { bool ifTerminal; LiteralKind kind; - Type exprType; bool isLiteral; bool ifConstexpr; - Expression(bool, LiteralKind, Type &, bool); + Expression(bool, LiteralKind, bool); virtual ~Expression() = default; virtual bool isTerminal(); virtual std::string toString() = 0; @@ -278,7 +276,7 @@ class RightValue { virtual std::string toStringForOpdef() = 0; virtual std::string toStringForCppKind() = 0; LiteralKind getKind(); - Type &getType(); + virtual TypePtr getType() = 0; virtual std::any getContent() = 0; virtual bool isConstexpr(); @@ -299,12 +297,12 @@ class RightValue { callOperation(std::vector>, Operation *); static std::shared_ptr getPlaceHolder(); - static std::shared_ptr getInteger(long long int, + static std::shared_ptr getInteger(largestInt, size_t size = 32); static std::shared_ptr getFloatPoint(long double, size_t size = 32); static std::shared_ptr getString(std::string); - static std::shared_ptr getType(Type &); + static std::shared_ptr getTypeRightValue(TypePtr); static std::shared_ptr getList(std::vector> &); static std::shared_ptr @@ -312,12 +310,13 @@ class RightValue { }; struct ExpressionNode : public Expression { - ExpressionNode(LiteralKind, Type, bool); + ExpressionNode(LiteralKind, bool); virtual std::string toString() override; virtual std::string toStringForTypedef() override; virtual std::string toStringForOpdef() override; virtual std::string toStringForCppKind() override; virtual std::any getContent() override = 0; + virtual TypePtr getType() override; }; struct FunctionCall : public ExpressionNode { @@ -329,6 +328,7 @@ class RightValue { virtual std::string toStringForOpdef() override; virtual std::string toStringForCppKind() override; virtual std::any getContent() override; + virtual TypePtr getType() override; }; struct OperationCall : public ExpressionNode { @@ -340,6 +340,7 @@ class RightValue { virtual std::string toStringForOpdef() override; virtual std::string toStringForCppKind() override; virtual std::any getContent() override; + virtual TypePtr getType() override; }; struct OperatorCall : public ExpressionNode { @@ -351,15 +352,17 @@ class RightValue { virtual std::string toStringForOpdef() override; virtual std::string toStringForCppKind() override; virtual std::any getContent() override; + virtual TypePtr getType() override; }; struct ExpressionTerminal : public Expression { - ExpressionTerminal(LiteralKind, Type, bool); + ExpressionTerminal(LiteralKind, bool); virtual std::string toString() override; virtual std::string toStringForTypedef() override; virtual std::string toStringForOpdef() override; virtual std::string toStringForCppKind() override; virtual std::any getContent() override = 0; + virtual TypePtr getType() override; }; struct PlaceHolder : public ExpressionTerminal { @@ -370,12 +373,11 @@ class RightValue { struct IntegerLiteral : public ExpressionTerminal { size_t size; - long long int content; - // size = 32 - IntegerLiteral(int content); - IntegerLiteral(long long int content, size_t size); + largestInt content; + IntegerLiteral(largestInt content, size_t size); virtual std::any getContent() override; virtual std::string toString() override; + virtual TypePtr getType() override; }; struct FloatPointLiteral : public ExpressionTerminal { @@ -384,6 +386,7 @@ class RightValue { FloatPointLiteral(long double content, size_t size); virtual std::any getContent() override; virtual std::string toString() override; + virtual TypePtr getType() override; }; struct StringLiteral : public ExpressionTerminal { @@ -391,16 +394,18 @@ class RightValue { StringLiteral(std::string content); virtual std::any getContent() override; virtual std::string toString() override; + virtual TypePtr getType() override; }; struct TypeLiteral : public ExpressionTerminal { - Type content; - TypeLiteral(Type &content); + TypePtr content; + TypeLiteral(TypePtr content); virtual std::any getContent() override; virtual std::string toString() override; virtual std::string toStringForTypedef() override; virtual std::string toStringForOpdef() override; virtual std::string toStringForCppKind() override; + virtual TypePtr getType() override; }; struct ListLiteral : public ExpressionTerminal { @@ -410,6 +415,7 @@ class RightValue { virtual std::string toString() override; virtual std::string toStringForTypedef() override; virtual std::string toStringForOpdef() override; + virtual TypePtr getType() override; }; struct LeftValue : public ExpressionTerminal { @@ -417,9 +423,11 @@ class RightValue { LeftValue(Value *content); virtual std::any getContent() override; virtual std::string toString() override; + virtual TypePtr getType() override; }; public: + using ExprPtr = std::shared_ptr; RightValue(std::shared_ptr); RightValue(const RightValue &) = default; RightValue(RightValue &&) = default; @@ -430,14 +438,15 @@ class RightValue { std::string toStringForOpdef(); std::string toStringForCppKind(); std::any getContent(); - Type &getType(); + TypePtr getType(); std::shared_ptr getExpr(); + bool isConstant(); static RightValue getPlaceHolder(); - static RightValue getInteger(long long int content, size_t size = 32); + static RightValue getInteger(largestInt content, size_t size = 32); static RightValue getFloatPoint(long double content, size_t size = 32); static RightValue getString(std::string content); - static RightValue getType(Type &content); + static RightValue getTypeRightValue(TypePtr content); static RightValue getList(std::vector> &content); static RightValue getLeftValue(fegen::Value *content); @@ -448,24 +457,129 @@ class RightValue { std::shared_ptr content; }; +// PlaceHolder +class PlaceHolderType : public Type { + public: + PlaceHolderType(); +}; + +// Type +class MetaType : public Type { + public: + MetaType(); + // for generating typedef td file. + virtual std::string toStringForTypedef() override; + +}; +// Template +class MetaTemplate : public Type { + public: + MetaTemplate(); +}; +// Integer +class IntegerType : public Type { + RightValue size; + public: + IntegerType(RightValue size, TypeDefination* tyDef); + IntegerType(RightValue size); + // for generating typedef td file. + virtual std::string toStringForTypedef() override; + // for generating op def td file. + virtual std::string toStringForOpdef() override; + // for generating cpp type kind. + virtual std::string toStringForCppKind() override; +}; +// FloatPoint +class FloatPointType : public Type { + RightValue size; + public: + FloatPointType(RightValue size); + // for generating typedef td file. + virtual std::string toStringForTypedef() override; + // for generating op def td file. + virtual std::string toStringForOpdef() override; + // for generating cpp type kind. + virtual std::string toStringForCppKind() override; +}; +// String +class StringType : public Type { + public: + StringType(); +}; +// List +class ListType : public Type { + RightValue elementType; + public: + ListType(RightValue elementType); + // for generating typedef td file. + virtual std::string toStringForTypedef() override; + // for generating op def td file. + virtual std::string toStringForOpdef() override; + // for generating cpp type kind. + virtual std::string toStringForCppKind() override; +}; +// Vector +class VectorType : public Type { + RightValue elementType; + RightValue size; + public: + VectorType(RightValue elementType, RightValue size); +}; +// Tensor +class TensorType : public Type { + RightValue elementType; + RightValue shape; + public: + TensorType(RightValue elementType, RightValue shape); +}; +// Optional +class OptionalType : public Type { + RightValue elementType; + public: + OptionalType(RightValue elementType); +}; +// Any<[ty1, ty2, ...]> +class AnyType : public Type { + RightValue elementTypes; + public: + AnyType(RightValue elementTypes); +}; +// custome type +class CustomeType : public Type { + std::vector params; + public: + CustomeType(std::vector params, TypeDefination* tydef); +}; + +class TemplateType : public Type { + public: + TemplateType(TypeDefination* tydef); + TypePtr instantiate(std::vector params); + // for generating typedef td file. + virtual std::string toStringForTypedef() override; + // for generating op def td file. + virtual std::string toStringForOpdef() override; +}; + + class Value { friend class Type; private: - Type type; + TypePtr type; std::string name; RightValue content; public: - Value(Type type, std::string name, RightValue content); + Value(TypePtr type, std::string name, RightValue content); Value(const Value &rhs); Value(Value &&rhs); - static Value *get(Type type, std::string name, + static Value *get(TypePtr type, std::string name, RightValue constant); std::string getName(); - Type &getType(); + TypePtr getType(); /// @brief return content of right value, get ExprssionNode* if kind is /// EXPRESSION. template T getContent() { @@ -576,7 +690,7 @@ class Manager { void emitBuiltinFunction(); }; -Type +TypePtr inferenceType(std::vector>, FegenOperator); diff --git a/frontend/FrontendGen/include/FegenVisitor.h b/frontend/FrontendGen/include/FegenVisitor.h index 2b2724cce2..1c3f40a9aa 100644 --- a/frontend/FrontendGen/include/FegenVisitor.h +++ b/frontend/FrontendGen/include/FegenVisitor.h @@ -2,6 +2,7 @@ #define FEGEN_FEGENVISITOR_H #include +#include #include #include #include @@ -26,7 +27,7 @@ namespace fegen { /// @param actual actual params. /// @return true if correct. bool checkParams(std::vector &expected, - std::vector &actual); + std::vector &actual); /// @brief check if the type of elements in list are correct. bool checkListLiteral( @@ -184,17 +185,16 @@ class FegenVisitor : public FegenParserBaseVisitor { size_t varCount = ctx->typeSpec().size(); std::vector valueList; for (size_t i = 0; i <= varCount - 1; i++) { - auto ty = std::any_cast(this->visit(ctx->typeSpec(i))); + auto ty = std::any_cast(this->visit(ctx->typeSpec(i))); auto varName = ctx->identifier(i)->getText(); auto var = fegen::Value::get( ty, varName, fegen::RightValue::getPlaceHolder()); valueList.push_back(var); } - return valueList; } - // return fegen::FegenType + // return fegen::TypePtr std::any visitTypeInstanceSpec(FegenParser::TypeInstanceSpecContext *ctx) override { auto valueKind = ctx->valueKind() @@ -202,8 +202,8 @@ class FegenVisitor : public FegenParserBaseVisitor { this->visit(ctx->valueKind())) : fegen::Type::TypeKind::CPP; auto typeInst = - std::any_cast(this->visit(ctx->typeInstance())); - typeInst.setTypeKind(valueKind); + std::any_cast(this->visit(ctx->typeInstance())); + typeInst->setTypeKind(valueKind); return typeInst; } @@ -219,30 +219,33 @@ class FegenVisitor : public FegenParserBaseVisitor { return kind; } - // return fegen::FegenType + // return fegen::TypePtr std::any visitTypeInstance(FegenParser::TypeInstanceContext *ctx) override { if (ctx->typeTemplate()) { // typeTemplate (Less typeTemplateParam (Comma // typeTemplateParam)* Greater)? - auto typeTeplt = - std::any_cast(this->visit(ctx->typeTemplate())); + auto typeTeplt = + std::any_cast(this->visit(ctx->typeTemplate())); + if(ctx->typeTemplate()->TYPE()){ + return typeTeplt; + } + auto teplt = std::dynamic_pointer_cast(typeTeplt); // get parameters - std::vector paramList; + std::vector paramList; for (auto paramCtx : ctx->typeTemplateParam()) { auto tepltParams = - std::any_cast(this->visit(paramCtx)); + std::any_cast(this->visit(paramCtx)); paramList.push_back(tepltParams); } // check parameters - auto expectedParams = typeTeplt.getTypeDefination()->getParameters(); + auto expectedParams = teplt->getTypeDefination()->getParameters(); if (!checkParams(expectedParams, paramList)) { std::cerr << "parameters error in context: " << ctx->getText() << std::endl; exit(0); } - // get FegenType of instance - auto typeInst = - Type::getInstanceType(typeTeplt.getTypeDefination(), paramList); + // get instance + auto typeInst = teplt->instantiate(paramList); return typeInst; } else if (ctx->identifier()) { // identifier auto varName = ctx->identifier()->getText(); @@ -250,7 +253,7 @@ class FegenVisitor : public FegenParserBaseVisitor { if (var) { if (var->getContentKind() == fegen::RightValue::LiteralKind::TYPE) { - return var->getContent(); + return var->getContent(); } else { std::cerr << "variable " << varName << " is not a Type or TypeTemplate." << std::endl; @@ -267,20 +270,18 @@ class FegenVisitor : public FegenParserBaseVisitor { } } - // return FegenValue* + // return RightValue std::any visitTypeTemplateParam(FegenParser::TypeTemplateParamContext *ctx) override { if (ctx->builtinTypeInstances()) { - auto ty = std::any_cast( + auto ty = std::any_cast( this->visit(ctx->builtinTypeInstances())); - return fegen::Value::get(ty, "param", - fegen::RightValue::getPlaceHolder()); + return fegen::RightValue::getTypeRightValue(ty); } else { auto expr = std::any_cast>( this->visit(ctx->expression())); - return fegen::Value::get(expr->exprType, "expression_tmp", - fegen::RightValue::getByExpr(expr)); + return fegen::RightValue::getByExpr(expr); } } @@ -295,8 +296,6 @@ class FegenVisitor : public FegenParserBaseVisitor { return Type::getFloatType(); } else if (ctx->DOUBLE()) { return Type::getDoubleType(); - } else if (ctx->CHAR()) { - return Type::getCharType(); } else if (ctx->STRING()) { return Type::getStringType(); } else { @@ -305,15 +304,14 @@ class FegenVisitor : public FegenParserBaseVisitor { } } - // return FegenType + // return TypePtr std::any visitTypeTemplate(FegenParser::TypeTemplateContext *ctx) override { if (ctx->prefixedName()) { // prefixedName if (ctx->prefixedName()->identifier().size() == 2) { // dialect.type // TODO: return type from other dialect return nullptr; } else { // type - auto tyDef = this->sstack.attemptFindTypeDef( - ctx->prefixedName()->identifier(0)->getText()); + auto tyDef = this->manager.getTypeDefination(ctx->prefixedName()->identifier(0)->getText()); return fegen::Type::getTemplateType(tyDef); } } else if (ctx->builtinTypeTemplate()) { // builtinTypeTemplate @@ -323,25 +321,24 @@ class FegenVisitor : public FegenParserBaseVisitor { } } - // return FegenType + // return TypePtr std::any visitBuiltinTypeTemplate( FegenParser::BuiltinTypeTemplateContext *ctx) override { if (ctx->INTEGER()) { - return fegen::Type::getIntegerTemplate(); + return fegen::Type::getTemplateType(this->manager.getTypeDefination(FEGEN_INTEGER)); } else if (ctx->FLOATPOINT()) { - return fegen::Type::getFloatPointTemplate(); + return fegen::Type::getTemplateType(this->manager.getTypeDefination(FEGEN_FLOATPOINT)); } else if (ctx->TENSOR()) { // return fegen::FegenType::getTensorTemplate(); - return fegen::Type::getPlaceHolder(); + return fegen::Type::getTemplateType(this->manager.getTypeDefination(FEGEN_TENSOR)); } else if (ctx->VECTOR()) { - // return fegen::FegenType::getVectorTemplate(); - return fegen::Type::getPlaceHolder(); + return fegen::Type::getTemplateType(this->manager.getTypeDefination(FEGEN_VECTOR)); } else { return nullptr; } } - // return FegenType + // return TypePtr std::any visitCollectTypeSpec(FegenParser::CollectTypeSpecContext *ctx) override { auto kind = fegen::Type::TypeKind::CPP; @@ -349,36 +346,26 @@ class FegenVisitor : public FegenParserBaseVisitor { kind = std::any_cast( this->visit(ctx->valueKind())); } - auto ty = std::any_cast(this->visit(ctx->collectType())); - ty.setTypeKind(kind); + auto ty = std::any_cast(this->visit(ctx->collectType())); + ty->setTypeKind(kind); return ty; } - // return FegenType + // return TypePtr std::any visitCollectType(FegenParser::CollectTypeContext *ctx) override { auto expr = std::any_cast>( this->visit(ctx->expression())); if (ctx->collectProtoType()->ANY()) { std::vector tys; - // TODO: reprot error assert(expr->getKind() == fegen::RightValue::LiteralKind::VECTOR); - auto exprs = std::any_cast< - std::vector>>( - expr->getContent()); - for (auto expr : exprs) { - auto ty = std::any_cast(expr->getContent()); - tys.push_back(ty); - } - return fegen::Type::getAnyType(tys); + return fegen::Type::getAnyType(fegen::RightValue::getByExpr(expr)); } else if (ctx->collectProtoType()->LIST()) { assert(expr->getKind() == fegen::RightValue::LiteralKind::TYPE); - auto ty = std::any_cast(expr->getContent()); - return fegen::Type::getListType(ty); + return fegen::Type::getListType(fegen::RightValue::getByExpr(expr)); } else { // optional assert(expr->getKind() == fegen::RightValue::LiteralKind::TYPE); - auto ty = std::any_cast(expr->getContent()); - return fegen::Type::getOptionalType(ty); + return fegen::Type::getOptionalType(fegen::RightValue::getByExpr(expr)); } } @@ -556,19 +543,19 @@ class FegenVisitor : public FegenParserBaseVisitor { if (tyDef) { auto tyVar = fegen::Type::getTemplateType(tyDef); return (std::shared_ptr) - fegen::RightValue::Expression::getType(tyVar); + fegen::RightValue::Expression::getTypeRightValue(tyVar); } else { // TODO: error report std::cerr << "can not find variable: " << ctx->identifier()->getText() << "." << std::endl; - exit(0); + assert(false); return nullptr; } } } else if (ctx->typeSpec()) { - auto ty = std::any_cast(this->visit(ctx->typeSpec())); + auto ty = std::any_cast(this->visit(ctx->typeSpec())); return (std::shared_ptr) - RightValue::ExpressionTerminal::getType(ty); + RightValue::Expression::getTypeRightValue(ty); } else { // constant, functionCall, parenSurroundedExpr,contextMethodInvoke, // and variableAccess return this->visit(ctx->children[0]); @@ -630,7 +617,7 @@ class FegenVisitor : public FegenParserBaseVisitor { std::any visitFunctionDecl(FegenParser::FunctionDeclContext *ctx) override { sstack.pushScope(); auto returnType = - std::any_cast(this->visit(ctx->typeSpec())); + std::any_cast(this->visit(ctx->typeSpec())); auto functionName = std::any_cast(this->visit(ctx->funcName())); auto hasfunc = manager.functionMap.find(functionName); @@ -646,7 +633,7 @@ class FegenVisitor : public FegenParserBaseVisitor { this->visit(ctx->statementBlock()); fegen::Function *function = - fegen::Function::get(functionName, functionParams, &returnType); + fegen::Function::get(functionName, functionParams, returnType); manager.functionMap.insert(std::pair{functionName, function}); sstack.popScope(); return nullptr; @@ -662,7 +649,7 @@ class FegenVisitor : public FegenParserBaseVisitor { for (size_t i = 0; i < ctx->typeSpec().size(); i++) { auto paramType = - std::any_cast(this->visit(ctx->typeSpec(i))); + std::any_cast(this->visit(ctx->typeSpec(i))); auto paramName = ctx->identifier(i)->getText(); auto param = fegen::Value::get( paramType, paramName, fegen::RightValue::getPlaceHolder()); @@ -674,7 +661,7 @@ class FegenVisitor : public FegenParserBaseVisitor { std::any visitVarDeclStmt(FegenParser::VarDeclStmtContext *ctx) override { auto varType = - std::any_cast(this->visit(ctx->typeSpec())); + std::any_cast(this->visit(ctx->typeSpec())); auto varName = ctx->identifier()->getText(); fegen::Value *var; if (ctx->expression()) { @@ -704,13 +691,14 @@ class FegenVisitor : public FegenParserBaseVisitor { std::any_cast>( this->visit(ctx->expression())); auto var = sstack.attemptFindVar(varName); - if (!fegen::Type::isSameType(&var->getType(), &varcontent->exprType)) { - std::cerr << "The variabel \" " << varName << "\" need \"" - << var->getType().getTypeName() << " \" type rightvalue." - << std::endl; - exit(0); - return nullptr; - } + // TODO + // if (!fegen::Type::isSameType(&var->getType(), &varcontent->exprType)) { + // std::cerr << "The variabel \" " << varName << "\" need \"" + // << var->getType().getTypeName() << " \" type rightvalue." + // << std::endl; + // exit(0); + // return nullptr; + // } fegen::Value *stmt = fegen::Value::get( var->getType(), varName, fegen::RightValue::getByExpr(varcontent)); manager.stmtContentMap.insert(std::pair{ctx, stmt}); @@ -739,15 +727,15 @@ class FegenVisitor : public FegenParserBaseVisitor { exit(0); return nullptr; } - for (size_t i = 0; i < len1; i++) { - if (!fegen::Type::isSameType(¶List[i]->getType(), - ¶sList[i]->exprType)) { - std::cerr << "The function \" " << functionName << "\" parameter" << i - << " type mismatch." << std::endl; - exit(0); - return nullptr; - } - } + // for (size_t i = 0; i < len1; i++) { + // if (!fegen::Type::isSameType(¶List[i]->getType(), + // ¶sList[i]->exprType)) { + // std::cerr << "The function \" " << functionName << "\" parameter" << i + // << " type mismatch." << std::endl; + // exit(0); + // return nullptr; + // } + // } } auto returnType = hasFunc->getReturnType(); fegen::Function *funcCall = diff --git a/frontend/FrontendGen/include/Scope.h b/frontend/FrontendGen/include/Scope.h index 1d1acc283f..56e5eecb21 100644 --- a/frontend/FrontendGen/include/Scope.h +++ b/frontend/FrontendGen/include/Scope.h @@ -20,26 +20,18 @@ template class SymbolTable { }; class FegenScope { - using TypeDefTable = SymbolTable; using VariableTable = SymbolTable; friend class ScopeStack; private: unsigned int scopeId; FegenScope *parentScope; - TypeDefTable typeTable; VariableTable varTable; public: explicit FegenScope(unsigned int scopeId, FegenScope *parentScope); ~FegenScope() = default; - /// @brief this will not check. - TypeDefination *findTypeDef(std::string name); - /// @brief this will not check whether tyDef is already existed or not. - void addTypeDef(TypeDefination *tyDef); - /// @brief return true if exist. - bool isExistTypeDef(std::string name); /// @brief this will not check. Value *findVar(std::string name); /// @brief this will not check whether var is already existed or not. @@ -71,10 +63,6 @@ class ScopeStack { bool attemptAddVar(Value *var); /// @brief check add find var from current scope, return nullptr if failed. Value *attemptFindVar(std::string name); - /// @brief check and add tyDef to current scope, return false if failed. - bool attemptAddTypeDef(TypeDefination *tyDef); - /// @brief check and find tyDef from current scope, return nullptr if failed. - TypeDefination *attemptFindTypeDef(std::string name); }; } // namespace fegen diff --git a/frontend/FrontendGen/lib/FegenManager.cpp b/frontend/FrontendGen/lib/FegenManager.cpp index ac5fa0803b..7b9c52bc75 100644 --- a/frontend/FrontendGen/lib/FegenManager.cpp +++ b/frontend/FrontendGen/lib/FegenManager.cpp @@ -10,18 +10,18 @@ #include #include #include +#include #include #include fegen::Function::Function(std::string name, - std::vector &&inputTypeList, - Type *returnType) + std::vector &&inputTypeList, + TypePtr returnType) : name(name), inputTypeList(inputTypeList), returnType(returnType) {} -fegen::Function * -fegen::Function::get(std::string name, - std::vector inputTypeList, - Type *returnType) { +fegen::Function *fegen::Function::get(std::string name, + std::vector inputTypeList, + TypePtr returnType) { return new fegen::Function(name, std::move(inputTypeList), returnType); } std::string fegen::Function::getName() { return this->name; } @@ -34,15 +34,12 @@ fegen::Value *fegen::Function::getInputTypeList(size_t i) { return this->inputTypeList[i]; } -fegen::Type *fegen::Function::getReturnType() { - return this->returnType; -} +fegen::TypePtr fegen::Function::getReturnType() { return this->returnType; } -fegen::Operation::Operation(std::string dialectName, - std::string operationName, - std::vector &&arguments, - std::vector &&results, - fegen::FegenParser::BodySpecContext *ctx) +fegen::Operation::Operation(std::string dialectName, std::string operationName, + std::vector &&arguments, + std::vector &&results, + fegen::FegenParser::BodySpecContext *ctx) : dialectName(dialectName), arguments(arguments), results(results), ctx(ctx) {} @@ -67,86 +64,25 @@ fegen::Value *fegen::Operation::getResults(size_t i) { return this->results[i]; } -fegen::Operation *fegen::Operation::get( - std::string operationName, std::vector arguments, - std::vector results, FegenParser::BodySpecContext *ctx) { +fegen::Operation *fegen::Operation::get(std::string operationName, + std::vector arguments, + std::vector results, + FegenParser::BodySpecContext *ctx) { return new fegen::Operation(fegen::Manager::getManager().moduleName, - operationName, std::move(arguments), - std::move(results), ctx); + operationName, std::move(arguments), + std::move(results), ctx); } // class FegenType -/// @brief get name of Type Instance by jointsing template name and parameters, -/// for example: Integer + 32 --> Integer<32> -/// @return joint name -std::string jointTypeName(std::string templateName, - const std::vector ¶meters) { - if (parameters.empty()) { - return templateName; - } - std::string res = templateName; - res.append("<"); - size_t count = parameters.size(); - auto firstParamStr = parameters[0]->getContentString(); - res.append(firstParamStr); - for (size_t i = 1; i <= count - 1; i++) { - auto paramStr = parameters[i]->getContentString(); - res.append(", "); - res.append(paramStr); - } - res.append(">"); - return res; -} - -fegen::Type::Type(TypeKind kind, std::string name, - std::vector parameters, - TypeDefination *tyDef, int typeLevel) - : kind(kind), typeName(name), parameters(std::move(parameters)), - typeDefine(tyDef), typeLevel(typeLevel) {} - -fegen::Type::Type(fegen::Type::TypeKind kind, - std::vector parameters, - TypeDefination *tyDef, int typeLevel) - : kind(kind), typeName(jointTypeName(tyDef->getName(), parameters)), - parameters(std::move(parameters)), typeDefine(tyDef), - typeLevel((typeLevel)) {} - -fegen::Type::Type(const fegen::Type &fty) - : kind(fty.kind), typeName(fty.typeName), typeDefine(fty.typeDefine), - typeLevel(fty.typeLevel) { - // deep copy parameters - for (auto paramPtr : fty.parameters) { - this->parameters.push_back(new fegen::Value(*paramPtr)); - } -} +fegen::Type::Type(TypeKind kind, std::string name, TypeDefination *tyDef, + int typeLevel, bool isConstType) + : kind(kind), typeName(name), typeDefine(tyDef), typeLevel(typeLevel), + isConstType(isConstType) {} -fegen::Type::Type(fegen::Type &&fty) - : kind(fty.kind), typeName(std::move(fty.typeName)), - parameters(std::move(fty.parameters)), typeDefine(fty.typeDefine), - typeLevel(fty.typeLevel) {} +fegen::Type::TypeKind fegen::Type::getTypeKind() { return this->kind; } -fegen::Type::TypeKind fegen::Type::getTypeKind() { - return this->kind; -} - -void fegen::Type::setTypeKind(fegen::Type::TypeKind kind) { - this->kind = kind; -} - -std::vector &fegen::Type::getParameters() { - return this->parameters; -} - -fegen::Value *fegen::Type::getParameters(size_t i) { - return this->parameters[i]; -} - -void fegen::Type::setParameters(std::vector ¶ms) { - this->parameters = params; - // set parameters and level up! - this->typeLevel++; -} +void fegen::Type::setTypeKind(fegen::Type::TypeKind kind) { this->kind = kind; } fegen::TypeDefination *fegen::Type::getTypeDefination() { return this->typeDefine; @@ -160,8 +96,9 @@ std::string fegen::Type::getTypeName() { return this->typeName; } int fegen::Type::getTypeLevel() { return this->typeLevel; } -bool fegen::Type::isSameType(fegen::Type *type1, - fegen::Type *type2) { +bool fegen::Type::isConstant() { return this->isConstType; } + +bool fegen::Type::isSameType(fegen::Type *type1, fegen::Type *type2) { if (type1->getTypeName() == type2->getTypeName()) return true; else @@ -169,310 +106,409 @@ bool fegen::Type::isSameType(fegen::Type *type1, } std::string fegen::Type::toStringForTypedef() { - // handle builtin type instance - auto typeName = this->typeName; - auto typedefName = this->typeDefine->getName(); - if (this->typeDefine->isCustome()) { - return this->typeDefine->getName(); - } else if (typedefName == FEGEN_TYPE) { - return "\"Type\""; - } else if (typedefName == FEGEN_LIST) { - std::string res = "ArrayRefParameter<"; - for (size_t i = 0; i <= this->parameters.size() - 1; i++) { - res.append(this->parameters[i]->getContentStringForTypedef()); - if (i != this->parameters.size() - 1) { - res.append(", "); - } - } - res.append(">"); - return res; - } else if (typedefName == FEGEN_INTEGER) { - if (this->parameters.size() == 0) { - return "Builtin_IntegerAttr"; - } else { - if (typeName == "int") { - return "\"int\""; - } else if (typeName == "bool") { - return "\"bool\""; - } - int size = this->getParameters(0)->getContent(); - if (size == 64) { - return "\"long\""; - } else if (size == 16) { - return "\"short\""; - } else { - std::cerr << "unsupport type: " << typeName << std::endl; - exit(0); - } - } - } else if (typedefName == FEGEN_FLOATPOINT) { - if (this->parameters.size() == 0) { - return "Builtin_FloatAttr"; - } else { - if (typeName == "float") { - return "\"float\""; - } else if (typeName == "double") { - return "\"double\""; - } else { - std::cerr << "unsupport type: " << typeName << std::endl; - exit(0); - } - } - } else { - std::cerr << "unsupport type: " << typeName << std::endl; - exit(0); - } + assert(FEGEN_NOT_IMPLEMENTED_ERROR); } std::string fegen::Type::toStringForOpdef() { - // handle builtin type instance - auto typeName = this->typeName; - auto typedefName = this->typeDefine->getName(); - if (this->typeDefine->isCustome()) { - return this->typeDefine->getName(); - } else if (typedefName == FEGEN_LIST) { - std::string res = "Variadic<"; - assert(this->parameters.size() == 1); - res.append(this->parameters[0]->getContentStringForTypedef()); - res.append(">"); - return res; - } else if (typedefName == FEGEN_INTEGER) { - if (this->parameters.size() == 0) { - return "Builtin_Integer"; - } else { - if (typeName == "int") { - return "I32"; - } - int size = this->getParameters(0)->getContent(); - if (size == 64) { - return "I64"; - } else if (size == 16) { - return "I16"; - } - } - } - - std::cerr << "unsupport type: " << typeName << std::endl; - exit(0); + assert(FEGEN_NOT_IMPLEMENTED_ERROR); } std::string fegen::Type::toStringForCppKind() { - // handle builtin type instance - auto typeName = this->typeName; - auto typedefName = this->typeDefine->getName(); - if (typedefName == FEGEN_LIST) { - assert(this->parameters.size() == 1); - std::string res = "std::vector<"; - res.append(this->parameters[0]->getContentStringForTypedef()); - res.append(">"); - return res; - } else if (typedefName == FEGEN_INTEGER) { - assert(this->parameters.size() == 1); - if (typeName == "int") { - return "int"; - } - int size = this->getParameters(0)->getContent(); - if (size == 64) { - return "long"; - } else if (size == 16) { - return "short"; - } - } else if (typedefName == FEGEN_FLOATPOINT) { - assert(this->parameters.size() == 1); - if (typeName == "float") { - return "float"; - } else if (typeName == "double") { - return "double"; - } - } - std::cerr << "Unsupported type: " << typeName << "in generating cpp type." - << std::endl; - exit(0); + assert(FEGEN_NOT_IMPLEMENTED_ERROR); } -fegen::Type::~Type() { - for (auto p : this->parameters) { - delete p; - } +fegen::TypePtr fegen::Type::getPlaceHolder() { + return std::make_shared(); } -fegen::Type fegen::Type::getPlaceHolder() { - return fegen::Type( - fegen::Type::TypeKind::CPP, {}, - fegen::Manager::getManager().getTypeDefination(FEGEN_PLACEHOLDER), - 0); +fegen::TypePtr fegen::Type::getMetaType() { + return std::make_shared(); } -fegen::Type fegen::Type::getMetaType() { - return fegen::Type( - fegen::Type::TypeKind::CPP, {}, - fegen::Manager::getManager().getTypeDefination(FEGEN_TYPE), 2); +fegen::TypePtr fegen::Type::getMetaTemplateType() { + return std::make_shared(); } -fegen::Type fegen::Type::getMetaTemplateType() { - return fegen::Type( - fegen::Type::TypeKind::CPP, {}, - fegen::Manager::getManager().getTypeDefination(FEGEN_TYPETEMPLATE), - 1); +fegen::TypePtr fegen::Type::getInt32Type() { + return std::make_shared(RightValue::getInteger(32)); } -fegen::Type fegen::Type::getInt32Type() { - return fegen::Type( - fegen::Type::TypeKind::CPP, "int", - {fegen::Value::get(fegen::Type::getPlaceHolder(), "size", - fegen::RightValue::getPlaceHolder())}, - fegen::Manager::getManager().getTypeDefination(FEGEN_INTEGER), 3); +fegen::TypePtr fegen::Type::getFloatType() { + return std::make_shared(RightValue::getInteger(32)); +} + +fegen::TypePtr fegen::Type::getDoubleType() { + return std::make_shared(RightValue::getInteger(64)); +} + +fegen::TypePtr fegen::Type::getBoolType() { + return std::make_shared(RightValue::getInteger(1)); +} + +fegen::TypePtr fegen::Type::getIntegerType(fegen::RightValue size) { + return std::make_shared(size); +} + +fegen::TypePtr fegen::Type::getFloatPointType(fegen::RightValue size) { + return std::make_shared(size); +} + +fegen::TypePtr fegen::Type::getStringType() { + return std::make_shared(); +} + +fegen::TypePtr fegen::Type::getListType(fegen::TypePtr elementType) { + assert(elementType->typeLevel == 2 || elementType->typeLevel == 3); + return std::make_shared( + fegen::RightValue::getTypeRightValue(elementType)); +} + +fegen::TypePtr fegen::Type::getListType(RightValue elementType) { + auto ty = std::any_cast(elementType.getContent()); + return Type::getListType(ty); } -fegen::Type fegen::Type::getFloatType() { - return fegen::Type( - fegen::Type::TypeKind::CPP, "float", - {fegen::Value::get(fegen::Type::getInt32Type(), "size", - fegen::RightValue::getInteger(32))}, - fegen::Manager::getManager().getTypeDefination(FEGEN_FLOATPOINT), 3); +fegen::TypePtr fegen::Type::getVectorType(fegen::TypePtr elementType, + fegen::RightValue size) { + assert(elementType->typeLevel == 3); + return std::make_shared( + fegen::RightValue::getTypeRightValue(elementType), size); } -fegen::Type fegen::Type::getDoubleType() { - return fegen::Type( - fegen::Type::TypeKind::CPP, "double", - {fegen::Value::get(fegen::Type::getInt32Type(), "size", - fegen::RightValue::getInteger(64))}, - fegen::Manager::getManager().getTypeDefination(FEGEN_FLOATPOINT), 3); +fegen::TypePtr fegen::Type::getVectorType(RightValue elementType, + RightValue size) { + auto ty = std::any_cast(elementType.getContent()); + return Type::getVectorType(ty, size); } -fegen::Type fegen::Type::getBoolType() { - return fegen::Type( - fegen::Type::TypeKind::CPP, "bool", - {fegen::Value::get(fegen::Type::getInt32Type(), "size", - fegen::RightValue::getInteger(1))}, - fegen::Manager::getManager().getTypeDefination(FEGEN_INTEGER), 3); +fegen::TypePtr fegen::Type::getTensorType(fegen::TypePtr elementType, + fegen::RightValue shape) { + assert(elementType->typeLevel == 3); + return std::make_shared( + fegen::RightValue::getTypeRightValue(elementType), shape); } -fegen::Type fegen::Type::getIntegerType(fegen::Value *size) { - if (size->getContent() == 32) - return fegen::Type::getInt32Type(); - return fegen::Type( - fegen::Type::TypeKind::CPP, {size}, - fegen::Manager::getManager().getTypeDefination(FEGEN_INTEGER), 3); +fegen::TypePtr fegen::Type::getTensorType(RightValue elementType, + RightValue shape) { + auto ty = std::any_cast(elementType.getContent()); + return Type::getTensorType(ty, shape); } -fegen::Type fegen::Type::getFloatPointType(fegen::Value *size) { - if (size->getContent() == 32) { - return fegen::Type::getFloatType(); - } else if (size->getContent() == 64) { - return fegen::Type::getDoubleType(); +fegen::TypePtr fegen::Type::getOptionalType(fegen::TypePtr elementType) { + assert(elementType->typeLevel == 2 || elementType->typeLevel == 3); + return std::make_shared( + RightValue::getTypeRightValue(elementType)); +} + +fegen::TypePtr fegen::Type::getOptionalType(RightValue elementType) { + auto ty = std::any_cast(elementType.getContent()); + return Type::getOptionalType(ty); +} + +fegen::TypePtr fegen::Type::getAnyType(fegen::RightValue elementTypes) { + return std::make_shared(elementTypes); +} + +fegen::TypePtr +fegen::Type::getCustomeType(std::vector params, + fegen::TypeDefination *tydef) { + return std::make_shared(params, tydef); +} + +fegen::TypePtr +fegen::Type::getTemplateType(fegen::TypeDefination *typeDefination) { + return std::make_shared(typeDefination); +} + +/// @brief get name of Type Instance by jointsing template name and parameters, +/// for example: Integer + 32 --> Integer<32> +/// @return joint name +std::string jointTypeName(std::string templateName, + std::vector parameters) { + if (parameters.empty()) { + return templateName; } - return fegen::Type( - fegen::Type::TypeKind::CPP, {size}, - fegen::Manager::getManager().getTypeDefination(FEGEN_FLOATPOINT), 3); -} - -fegen::Type fegen::Type::getCharType() { - return fegen::Type( - fegen::Type::TypeKind::CPP, {}, - fegen::Manager::getManager().getTypeDefination(FEGEN_CHAR), 3); -} - -fegen::Type fegen::Type::getStringType() { - return fegen::Type( - fegen::Type::TypeKind::CPP, {}, - fegen::Manager::getManager().getTypeDefination(FEGEN_STRING), 3); -} - -fegen::Type fegen::Type::getVectorType(fegen::Value *size, - fegen::Type elementType) { - assert(elementType.typeLevel == 3); - return fegen::Type( - fegen::Type::TypeKind::CPP, - {size, - fegen::Value::get(fegen::Type::getMetaType(), "elementType", - fegen::RightValue::getType(elementType))}, - fegen::Manager::getManager().getTypeDefination(FEGEN_VECTOR), - elementType.typeLevel); -} - -fegen::Type fegen::Type::getTensorType(fegen::Value *shape, - fegen::Type elementType) { - assert(elementType.typeLevel == 3); - return fegen::Type( - fegen::Type::TypeKind::CPP, - {shape, - fegen::Value::get(fegen::Type::getMetaType(), "elementType", - fegen::RightValue::getType(elementType))}, - fegen::Manager::getManager().getTypeDefination(FEGEN_TENSOR), - elementType.typeLevel); -} - -// List -fegen::Type fegen::Type::getListType(fegen::Type elementType) { - assert(elementType.typeLevel == 2 || elementType.typeLevel == 3); - return fegen::Type( - fegen::Type::TypeKind::CPP, - {fegen::Value::get( - elementType.typeLevel == 2 ? fegen::Type::getMetaTemplateType() - : fegen::Type::getMetaType(), - "elementType", fegen::RightValue::getType(elementType))}, - fegen::Manager::getManager().getTypeDefination(FEGEN_LIST), - elementType.typeLevel); -} - -// Optional -fegen::Type -fegen::Type::getOptionalType(fegen::Type elementType) { - assert(elementType.typeLevel == 2 || elementType.typeLevel == 3); - return fegen::Type( - fegen::Type::TypeKind::CPP, - {fegen::Value::get( - elementType.typeLevel == 2 ? fegen::Type::getMetaTemplateType() - : fegen::Type::getMetaType(), - "elementType", fegen::RightValue::getType(elementType))}, - fegen::Manager::getManager().getTypeDefination(FEGEN_OPTINAL), - elementType.typeLevel); -} - -// Any -fegen::Type -fegen::Type::getAnyType(std::vector elementTypes) { - std::vector p_elemTy; - int i = 0; - std::string name("elementType_"); - auto tyLevel = elementTypes[0].typeLevel; - assert(tyLevel == 2 || tyLevel == 3); - auto tyty = tyLevel == 2 ? fegen::Type::getMetaTemplateType() - : fegen::Type::getMetaType(); - for (auto &ty : elementTypes) { - assert(ty.typeLevel == tyLevel); - p_elemTy.push_back(fegen::Value::get( - tyty, name + std::to_string(i), fegen::RightValue::getType(ty))); - i++; + std::string res = templateName; + res.append("<"); + size_t count = parameters.size(); + auto firstParamStr = parameters[0].toString(); + res.append(firstParamStr); + for (size_t i = 1; i <= count - 1; i++) { + auto paramStr = parameters[i].toString(); + res.append(", "); + res.append(paramStr); } - return fegen::Type( - fegen::Type::TypeKind::CPP, p_elemTy, - fegen::Manager::getManager().getTypeDefination(FEGEN_ANY), tyLevel); + res.append(">"); + return res; } -fegen::Type fegen::Type::getIntegerTemplate() { - return fegen::Type( - fegen::Type::TypeKind::CPP, {}, - fegen::Manager::getManager().getTypeDefination(FEGEN_INTEGER), 2); +// class PlaceHolderType +fegen::PlaceHolderType::PlaceHolderType() + : Type(fegen::Type::TypeKind::CPP, FEGEN_PLACEHOLDER, + fegen::Manager::getManager().getTypeDefination(FEGEN_PLACEHOLDER), 0, + true) {} + +// class MetaType +fegen::MetaType::MetaType() + : Type(fegen::Type::TypeKind::CPP, FEGEN_TYPE, + fegen::Manager::getManager().getTypeDefination(FEGEN_TYPE), 2, + true) {} + +std::string fegen::MetaType::toStringForTypedef() { return "\"Type\""; } + +// class MetaTemplate +fegen::MetaTemplate::MetaTemplate() + : Type(fegen::Type::TypeKind::CPP, FEGEN_TYPETEMPLATE, + fegen::Manager::getManager().getTypeDefination(FEGEN_TYPETEMPLATE), + 1, true) {} + +// class IntegerType + +fegen::IntegerType::IntegerType(RightValue size, TypeDefination *tyDef) + : Type(fegen::Type::TypeKind::CPP, jointTypeName(FEGEN_INTEGER, {size}), + tyDef, 3, size.isConstant()), + size(size) {} + +fegen::IntegerType::IntegerType(fegen::RightValue size) + : Type(fegen::Type::TypeKind::CPP, jointTypeName(FEGEN_INTEGER, {size}), + fegen::Manager::getManager().getTypeDefination(FEGEN_INTEGER), 3, + size.isConstant()), + size(size) {} + +std::string fegen::IntegerType::toStringForTypedef() { + auto content = std::any_cast(this->size.getContent()); + if (content == 32) { + return "\"int\""; + } else if (content == 1) { + return "\"bool\""; + } else if (content == 64) { + return "\"long\""; + } else if (content == 16) { + return "\"short\""; + } else { + std::cerr << "unsupport type: " << this->getTypeName() << std::endl; + assert(false); + } } -fegen::Type fegen::Type::getFloatPointTemplate() { - return fegen::Type( - fegen::Type::TypeKind::CPP, {}, - fegen::Manager::getManager().getTypeDefination(FEGEN_FLOATPOINT), 2); +std::string fegen::IntegerType::toStringForOpdef() { + auto content = std::any_cast(this->size.getContent()); + if (content == 32) { + return "I32"; + } else if (content == 64) { + return "I64"; + } else if (content == 16) { + return "I16"; + } else { + std::cerr << "unsupport type: " << this->getTypeName() << std::endl; + assert(false); + } } -fegen::Type -fegen::Type::getInstanceType(fegen::TypeDefination *typeDefination, - std::vector parameters) { - return fegen::Type(fegen::Type::TypeKind::CPP, parameters, - typeDefination, 3); +std::string fegen::IntegerType::toStringForCppKind() { + auto content = std::any_cast(this->size.getContent()); + if (content == 32) { + return "int"; + } + if (content == 64) { + return "long"; + } else if (content == 16) { + return "short"; + } else { + std::cerr << "unsupport type: " << this->getTypeName() << std::endl; + assert(false); + } } -fegen::Type -fegen::Type::getTemplateType(fegen::TypeDefination *typeDefination) { - return fegen::Type(fegen::Type::TypeKind::CPP, {}, typeDefination, - 2); + +// class FloatPointType +fegen::FloatPointType::FloatPointType(fegen::RightValue size) + : Type(fegen::Type::TypeKind::CPP, jointTypeName(FEGEN_FLOATPOINT, {size}), + fegen::Manager::getManager().getTypeDefination(FEGEN_FLOATPOINT), 3, + size.isConstant()), + size(size) {} + +std::string fegen::FloatPointType::toStringForTypedef() { + auto content = std::any_cast(this->size.getContent()); + if (content == 32) { + return "\"float\""; + } else if (content == 64) { + return "\"double\""; + } else { + std::cerr << "unsupport type: " << this->getTypeName() << std::endl; + assert(false); + } +} + +std::string fegen::FloatPointType::toStringForOpdef() { + return "FloatPointType::toStringForOpdef"; +} + +std::string fegen::FloatPointType::toStringForCppKind() { + auto content = std::any_cast(this->size.getContent()); + if (content == 32) { + return "float"; + } + if (content == 64) { + return "double"; + } else { + std::cerr << "unsupport type: " << this->getTypeName() << std::endl; + assert(false); + } +} + +// class StringType +fegen::StringType::StringType() + : Type(fegen::Type::TypeKind::CPP, FEGEN_STRING, + fegen::Manager::getManager().getTypeDefination(FEGEN_STRING), 3, + true) {} + +// class ListType +fegen::ListType::ListType(fegen::RightValue elementType) + : Type(fegen::Type::TypeKind::CPP, jointTypeName(FEGEN_LIST, {elementType}), + fegen::Manager::getManager().getTypeDefination(FEGEN_LIST), + std::any_cast(elementType.getContent())->getTypeLevel(), + elementType.isConstant()), + elementType(elementType) {} + +std::string fegen::ListType::toStringForTypedef() { + std::string res = "ArrayRefParameter<"; + res.append(this->elementType.toStringForTypedef()); + res.append(">"); + return res; +} + +std::string fegen::ListType::toStringForOpdef() { + std::string res = "Variadic<"; + res.append(this->elementType.toStringForOpdef()); + res.append(">"); + return res; +} + +std::string fegen::ListType::toStringForCppKind() { + std::string res = "std::vector<"; + res.append(this->elementType.toStringForCppKind()); + res.append(">"); + return res; +} + +// class VectorType +fegen::VectorType::VectorType(RightValue elementType, RightValue size) + : Type(fegen::Type::TypeKind::CPP, + jointTypeName(FEGEN_VECTOR, {elementType, size}), + fegen::Manager::getManager().getTypeDefination(FEGEN_VECTOR), 3, + (elementType.isConstant() && size.isConstant())), + elementType(elementType), size(size) {} + +// class TensorType +fegen::TensorType::TensorType(RightValue elementType, RightValue shape) + : Type(fegen::Type::TypeKind::CPP, + jointTypeName(FEGEN_TENSOR, {elementType, shape}), + fegen::Manager::getManager().getTypeDefination(FEGEN_TENSOR), 3, + (elementType.isConstant() && shape.isConstant())), + elementType(elementType), shape(shape) {} + +// class OptionalType +fegen::OptionalType::OptionalType(RightValue elementType) + : Type(fegen::Type::TypeKind::CPP, + jointTypeName(FEGEN_OPTINAL, {elementType}), + fegen::Manager::getManager().getTypeDefination(FEGEN_OPTINAL), + std::any_cast(elementType.getContent())->getTypeLevel(), + elementType.isConstant()), + elementType(elementType) {} + +// class AnyType + +inline int getTypeLevelOfListType(fegen::RightValue& elementTypes) { + auto listContent = std::any_cast>(elementTypes.getContent()); + fegen::TypePtr ty = std::any_cast(listContent[0]->getContent()); + return ty->getTypeLevel(); +} + +fegen::AnyType::AnyType(RightValue elementTypes) + : Type(fegen::Type::TypeKind::CPP, jointTypeName(FEGEN_ANY, {elementTypes}), + fegen::Manager::getManager().getTypeDefination(FEGEN_ANY), + getTypeLevelOfListType(elementTypes), + elementTypes.isConstant()), + elementTypes(elementTypes) {} + +// class CustomeType +inline bool isAllConstant(std::vector ¶ms) { + for (auto v : params) { + if (!v.isConstant()) { + return false; + } + } + return true; +} + +fegen::CustomeType::CustomeType(std::vector params, + TypeDefination *tydef) + : Type(fegen::Type::TypeKind::CPP, jointTypeName(FEGEN_ANY, params), tydef, + 3, isAllConstant(params)), + params(params) {} + +// class TemplateType +fegen::TemplateType::TemplateType(TypeDefination *tydef) + : Type(fegen::Type::TypeKind::CPP, tydef->getName(), tydef, 2, true) {} + +fegen::TypePtr +fegen::TemplateType::instantiate(std::vector params) { + auto tydef = this->getTypeDefination(); + if (tydef->isCustome()) { + return Type::getCustomeType(params, tydef); + } else if (tydef->getName() == FEGEN_INTEGER) { + assert(params.size() == 1); + return Type::getIntegerType(params[0]); + } else if (tydef->getName() == FEGEN_FLOATPOINT) { + assert(params.size() == 1); + return Type::getFloatPointType(params[0]); + } else if (tydef->getName() == FEGEN_STRING) { + assert(params.size() == 0); + return Type::getStringType(); + } else if (tydef->getName() == FEGEN_LIST) { + assert(params.size() == 1); + return Type::getListType(params[0]); + } else if (tydef->getName() == FEGEN_VECTOR) { + assert(params.size() == 2); + return Type::getVectorType(params[0], params[1]); + } else if (tydef->getName() == FEGEN_TENSOR) { + assert(params.size() == 2); + return Type::getTensorType(params[0], params[1]); + } else if (tydef->getName() == FEGEN_OPTINAL) { + assert(params.size() == 1); + return Type::getOptionalType(params[0]); + } else if (tydef->getName() == FEGEN_ANY) { + assert(params.size() == 1); + return Type::getAnyType(params[0]); + } else { + assert(false); + } +} + +std::string fegen::TemplateType::toStringForTypedef() { + auto tyd = this->getTypeDefination(); + if (tyd->isCustome()) { + return this->getTypeDefination()->getName(); + } else if (tyd->getName() == FEGEN_INTEGER) { + return "Builtin_IntegerAttr"; + } else if (tyd->getName() == FEGEN_FLOATPOINT) { + return "Builtin_FloatAttr"; + } else { + std::cerr << "unsupport type: " << this->getTypeName() << std::endl; + assert(false); + } +} + +std::string fegen::TemplateType::toStringForOpdef() { + auto tyd = this->getTypeDefination(); + if (tyd->isCustome()) { + return this->getTypeDefination()->getName(); + } else if (tyd->getName() == FEGEN_INTEGER) { + return "Builtin_Integer"; + } else { + std::cerr << "unsupport type: " << this->getTypeName() << std::endl; + assert(false); + } } // class FegenTypeDefination @@ -485,11 +521,11 @@ fegen::TypeDefination::TypeDefination( fegen::TypeDefination * fegen::TypeDefination::get(std::string dialectName, std::string name, - std::vector parameters, - FegenParser::TypeDefinationDeclContext *ctx, - bool ifCustome) { + std::vector parameters, + FegenParser::TypeDefinationDeclContext *ctx, + bool ifCustome) { return new fegen::TypeDefination(std::move(dialectName), std::move(name), - std::move(parameters), ctx, ifCustome); + std::move(parameters), ctx, ifCustome); } std::string fegen::TypeDefination::getDialectName() { @@ -511,17 +547,13 @@ std::string fegen::TypeDefination::getMnemonic() { return this->mnemonic; } -void fegen::TypeDefination::setName(std::string name) { - this->name = name; -} +void fegen::TypeDefination::setName(std::string name) { this->name = name; } -const std::vector & -fegen::TypeDefination::getParameters() { +const std::vector &fegen::TypeDefination::getParameters() { return this->parameters; } -fegen::FegenParser::TypeDefinationDeclContext * -fegen::TypeDefination::getCtx() { +fegen::FegenParser::TypeDefinationDeclContext *fegen::TypeDefination::getCtx() { return this->ctx; } @@ -534,29 +566,17 @@ bool fegen::TypeDefination::isCustome() { return this->ifCustome; } // class Expression -fegen::RightValue::Expression::Expression(bool ifTerminal, - LiteralKind kind, - Type &exprTy, - bool isConstexpr) - : ifTerminal(ifTerminal), kind(kind), exprType(exprTy), - ifConstexpr(isConstexpr) {} +fegen::RightValue::Expression::Expression(bool ifTerminal, LiteralKind kind, + bool isConstexpr) + : ifTerminal(ifTerminal), kind(kind), ifConstexpr(isConstexpr) {} -bool fegen::RightValue::Expression::isTerminal() { - return this->ifTerminal; -} +bool fegen::RightValue::Expression::isTerminal() { return this->ifTerminal; } -fegen::RightValue::LiteralKind -fegen::RightValue::Expression::getKind() { +fegen::RightValue::LiteralKind fegen::RightValue::Expression::getKind() { return this->kind; } -fegen::Type &fegen::RightValue::Expression::getType() { - return this->exprType; -} - -bool fegen::RightValue::Expression::isConstexpr() { - return this->ifConstexpr; -} +bool fegen::RightValue::Expression::isConstexpr() { return this->ifConstexpr; } std::shared_ptr fegen::RightValue::Expression::getPlaceHolder() { @@ -564,17 +584,13 @@ fegen::RightValue::Expression::getPlaceHolder() { } std::shared_ptr -fegen::RightValue::Expression::getInteger(long long int content, - size_t size) { - return std::make_shared(content, - size); +fegen::RightValue::Expression::getInteger(largestInt content, size_t size) { + return std::make_shared(content, size); } std::shared_ptr -fegen::RightValue::Expression::getFloatPoint(long double content, - size_t size) { - return std::make_shared(content, - size); +fegen::RightValue::Expression::getFloatPoint(long double content, size_t size) { + return std::make_shared(content, size); } std::shared_ptr @@ -583,7 +599,7 @@ fegen::RightValue::Expression::getString(std::string content) { } std::shared_ptr -fegen::RightValue::Expression::getType(fegen::Type &content) { +fegen::RightValue::Expression::getTypeRightValue(fegen::TypePtr content) { return std::make_shared(content); } @@ -602,16 +618,16 @@ std::shared_ptr fegen::RightValue::Expression::binaryOperation( std::shared_ptr lhs, std::shared_ptr rhs, FegenOperator op) { - Type resTy = fegen::inferenceType({lhs, rhs}, op); + TypePtr resTy = fegen::inferenceType({lhs, rhs}, op); return std::make_shared( - op, std::vector>{ - lhs, rhs}); + op, + std::vector>{lhs, rhs}); } std::shared_ptr fegen::RightValue::Expression::unaryOperation( std::shared_ptr v, FegenOperator op) { - Type resTy = fegen::inferenceType({v}, op); + TypePtr resTy = fegen::inferenceType({v}, op); return std::make_shared( op, std::vector>{v}); } @@ -619,9 +635,8 @@ fegen::RightValue::Expression::unaryOperation( // class ExpressionNode fegen::RightValue::ExpressionNode::ExpressionNode(LiteralKind kind, - Type exprTy, - bool ifConstexpr) - : Expression(false, kind, exprTy, ifConstexpr) {} + bool ifConstexpr) + : Expression(false, kind, ifConstexpr) {} std::string fegen::RightValue::ExpressionNode::toString() { assert(FEGEN_NOT_IMPLEMENTED_ERROR); @@ -639,6 +654,10 @@ std::string fegen::RightValue::ExpressionNode::toStringForCppKind() { assert(FEGEN_NOT_IMPLEMENTED_ERROR); } +fegen::TypePtr fegen::RightValue::ExpressionNode::getType() { + assert(FEGEN_NOT_IMPLEMENTED_ERROR); +} + inline bool isBinaryOperator(fegen::FegenOperator &op) { switch (op) { case fegen::FegenOperator::NEG: @@ -649,57 +668,6 @@ inline bool isBinaryOperator(fegen::FegenOperator &op) { } } -std::string getCppOperator(fegen::FegenOperator op) { - // switch(op){ - // OR, - // AND, - // EQUAL, - // NOT_EQUAL, - // LESS, - // LESS_EQUAL, - // GREATER, - // GREATER_EQUAL, - // ADD, - // SUB, - // MUL, - // DIV, - // MOD, - // POWER, - // NEG, - // NOT - // } -} - -// std::string res; -// auto opKind = this->op.index(); -// if(opKind == 0){ // function -// auto func = std::get<0>(this->op); -// // res.append(func.) -// // TODO: add FegenFunction methods. -// }else if(opKind == 1) { // operation -// assert(false); -// return res; -// }else{ // operator -// auto op = std::get<2>(this->op); -// if(isBinaryOperator(op)){ -// assert(this->params.size() == 2); -// res.append(this->params[0]->toStringForCppKind()); -// switch(op){ -// case fegen::FegenOperator::ADD:{ -// res.append() -// } -// } -// res.append(this->params[1]->toStringForCppKind()); -// }else{ - -// } -// switch(op) { -// case fegen::FegenOperator::ADD: { - -// } -// } -// } - // class FunctionCall inline bool isFuncParamsAllConstant( std::vector> ¶ms) { @@ -716,7 +684,6 @@ fegen::RightValue::FunctionCall::FunctionCall( fegen::Function *func, std::vector> params) : ExpressionNode(fegen::RightValue::LiteralKind::FUNC_CALL, - fegen::Type::getInt32Type(), isFuncParamsAllConstant(params)), func(func), params(std::move(params)) {} @@ -738,12 +705,15 @@ std::string fegen::RightValue::FunctionCall::toStringForCppKind() { std::any fegen::RightValue::FunctionCall::getContent() { return this; } +fegen::TypePtr fegen::RightValue::FunctionCall::getType() { + return this->func->getReturnType(); +} + // class OperationCall fegen::RightValue::OperationCall::OperationCall( fegen::Operation *op, std::vector> params) : ExpressionNode(fegen::RightValue::LiteralKind::OPERATION_CALL, - fegen::Type::getInt32Type(), isFuncParamsAllConstant(params)), op(op), params(std::move(params)) {} @@ -765,12 +735,15 @@ std::string fegen::RightValue::OperationCall::toStringForCppKind() { std::any fegen::RightValue::OperationCall::getContent() { return this; } +fegen::TypePtr fegen::RightValue::OperationCall::getType() { + assert(FEGEN_NOT_IMPLEMENTED_ERROR); +} + // class OperatorCall fegen::RightValue::OperatorCall::OperatorCall( fegen::FegenOperator op, std::vector> params) : ExpressionNode(fegen::RightValue::LiteralKind::OPERATION_CALL, - fegen::inferenceType(params, op), isFuncParamsAllConstant(params)), op(op), params(std::move(params)) {} @@ -792,11 +765,14 @@ std::string fegen::RightValue::OperatorCall::toStringForCppKind() { std::any fegen::RightValue::OperatorCall::getContent() { return this; } +fegen::TypePtr fegen::RightValue::OperatorCall::getType() { + return inferenceType(this->params, this->op); +} + // class ExpressionTerminal fegen::RightValue::ExpressionTerminal::ExpressionTerminal( - fegen::RightValue::LiteralKind kind, Type exprTy, - bool ifConstexpr) - : Expression(true, kind, exprTy, ifConstexpr) {} + fegen::RightValue::LiteralKind kind, bool ifConstexpr) + : Expression(true, kind, ifConstexpr) {} std::string fegen::RightValue::ExpressionTerminal::toString() { assert(FEGEN_NOT_IMPLEMENTED_ERROR); @@ -814,10 +790,13 @@ std::string fegen::RightValue::ExpressionTerminal::toStringForCppKind() { assert(FEGEN_NOT_IMPLEMENTED_ERROR); } +fegen::TypePtr fegen::RightValue::ExpressionTerminal::getType() { + assert(FEGEN_NOT_IMPLEMENTED_ERROR); +} + // class PlaceHolder fegen::RightValue::PlaceHolder::PlaceHolder() - : ExpressionTerminal(fegen::RightValue::LiteralKind::MONOSTATE, - fegen::Type::getPlaceHolder(), true) {} + : ExpressionTerminal(fegen::RightValue::LiteralKind::MONOSTATE, true) {} std::any fegen::RightValue::PlaceHolder::getContent() { return std::monostate(); @@ -826,21 +805,9 @@ std::any fegen::RightValue::PlaceHolder::getContent() { std::string fegen::RightValue::PlaceHolder::toString() { return ""; } // class IntegerLiteral -fegen::RightValue::IntegerLiteral::IntegerLiteral(int content) - : ExpressionTerminal(fegen::RightValue::LiteralKind::INT, - fegen::Type::getInt32Type(), true), - content(content) {} - -fegen::RightValue::IntegerLiteral::IntegerLiteral(long long int content, - size_t size) - : ExpressionTerminal( - fegen::RightValue::LiteralKind::INT, - fegen::Type::getIntegerType(fegen::Value::get( - fegen::Type::getInt32Type(), "size", - fegen::RightValue::getByExpr( - std::make_shared( - size)))), - true), +fegen::RightValue::IntegerLiteral::IntegerLiteral(largestInt content, + size_t size) + : ExpressionTerminal(fegen::RightValue::LiteralKind::INT, true), content(content) {} std::any fegen::RightValue::IntegerLiteral::getContent() { @@ -851,15 +818,14 @@ std::string fegen::RightValue::IntegerLiteral::toString() { return std::to_string(this->content); } +fegen::TypePtr fegen::RightValue::IntegerLiteral::getType() { + return fegen::Type::getIntegerType(fegen::RightValue::getInteger(this->size)); +} + // class FloatPointLiteral -fegen::RightValue::FloatPointLiteral::FloatPointLiteral( - long double content, size_t size) - : ExpressionTerminal( - fegen::RightValue::LiteralKind::FLOAT, - fegen::Type::getFloatPointType( - fegen::Value::get(fegen::Type::getInt32Type(), "size", - fegen::RightValue::getInteger(size))), - true), +fegen::RightValue::FloatPointLiteral::FloatPointLiteral(long double content, + size_t size) + : ExpressionTerminal(fegen::RightValue::LiteralKind::FLOAT, true), content(content) {} std::any fegen::RightValue::FloatPointLiteral::getContent() { @@ -870,10 +836,14 @@ std::string fegen::RightValue::FloatPointLiteral::toString() { return std::to_string(this->content); } +fegen::TypePtr fegen::RightValue::FloatPointLiteral::getType() { + return fegen::Type::getFloatPointType( + fegen::RightValue::getInteger(this->size)); +} + // class StringLiteral fegen::RightValue::StringLiteral::StringLiteral(std::string content) - : ExpressionTerminal(fegen::RightValue::LiteralKind::STRING, - fegen::Type::getStringType(), true), + : ExpressionTerminal(fegen::RightValue::LiteralKind::STRING, true), content(content) {} std::any fegen::RightValue::StringLiteral::getContent() { @@ -888,56 +858,65 @@ std::string fegen::RightValue::StringLiteral::toString() { return res; } +fegen::TypePtr fegen::RightValue::StringLiteral::getType() { + return fegen::Type::getStringType(); +} + // class TypeLiteral // Check params of content and return ture if params are all const expr. -inline bool isParamsConstant(fegen::Type &content) { - for (auto param : content.getParameters()) { - if (!param->getExpr()->isConstexpr()) { - return false; - } - } +inline bool isParamsConstant(fegen::TypePtr content) { + // for (auto param : content.getParameters()) { + // if (!param->getExpr()->isConstexpr()) { + // return false; + // } + // } return true; } // Get type of type literal. -fegen::Type getTypeLiteralType(fegen::Type &content) { - if (content.getTypeLevel() == 2) { +fegen::TypePtr getTypeLiteralType(fegen::TypePtr content) { + if (content->getTypeLevel() == 2) { return fegen::Type::getMetaTemplateType(); - } else if (content.getTypeLevel() == 3) { + } else if (content->getTypeLevel() == 3) { return fegen::Type::getMetaType(); } else { return fegen::Type::getPlaceHolder(); } } -fegen::RightValue::TypeLiteral::TypeLiteral(fegen::Type &content) +fegen::RightValue::TypeLiteral::TypeLiteral(fegen::TypePtr content) : ExpressionTerminal(fegen::RightValue::LiteralKind::TYPE, - getTypeLiteralType(content), - isParamsConstant(content)), + content->isConstant()), content(content) {} -std::any fegen::RightValue::TypeLiteral::getContent() { - return this->content; -} +std::any fegen::RightValue::TypeLiteral::getContent() { return this->content; } std::string fegen::RightValue::TypeLiteral::toString() { - return this->content.getTypeName(); + return this->content->getTypeName(); } std::string fegen::RightValue::TypeLiteral::toStringForTypedef() { - return this->content.toStringForTypedef(); + return this->content->toStringForTypedef(); } std::string fegen::RightValue::TypeLiteral::toStringForOpdef() { - return this->content.toStringForOpdef(); + return this->content->toStringForOpdef(); } std::string fegen::RightValue::TypeLiteral::toStringForCppKind() { - return this->content.toStringForCppKind(); + return this->content->toStringForCppKind(); } -// class ExpressionTerminal +fegen::TypePtr fegen::RightValue::TypeLiteral::getType() { + if (this->content->getTypeLevel() == 2) { + return fegen::Type::getMetaTemplateType(); + } else if (this->content->getTypeLevel() == 3) { + return fegen::Type::getMetaType(); + } else { + assert(false); + } +} // Return ture if all Expressions in content are all true. bool isExpressionListConst( @@ -954,12 +933,10 @@ bool isExpressionListConst( fegen::RightValue::ListLiteral::ListLiteral( std::vector> &content) : ExpressionTerminal(fegen::RightValue::LiteralKind::VECTOR, - content[0]->exprType, isExpressionListConst(content)), + isExpressionListConst(content)), content(content) {} -std::any fegen::RightValue::ListLiteral::getContent() { - return this->content; -} +std::any fegen::RightValue::ListLiteral::getContent() { return this->content; } std::string fegen::RightValue::ListLiteral::toString() { std::string res; @@ -1000,20 +977,26 @@ std::string fegen::RightValue::ListLiteral::toStringForOpdef() { return res; } +fegen::TypePtr fegen::RightValue::ListLiteral::getType() { + return fegen::Type::getListType(this->content[0]->getType()); +} + // class LeftValue fegen::RightValue::LeftValue::LeftValue(fegen::Value *content) : ExpressionTerminal(fegen::RightValue::LiteralKind::LEFT_VAR, - content->getType(), content->getExpr()->isConstexpr()), + content->getExpr()->isConstexpr()), content(content) {} -std::any fegen::RightValue::LeftValue::getContent() { - return this->content; -} +std::any fegen::RightValue::LeftValue::getContent() { return this->content; } std::string fegen::RightValue::LeftValue::toString() { return this->content->getName(); } +fegen::TypePtr fegen::RightValue::LeftValue::getType() { + return this->content->getType(); +} + // class FegenRightValue fegen::RightValue::RightValue( std::shared_ptr content) @@ -1023,9 +1006,7 @@ fegen::RightValue::LiteralKind fegen::RightValue::getLiteralKind() { return this->content->getKind(); } -std::string fegen::RightValue::toString() { - return this->content->toString(); -} +std::string fegen::RightValue::toString() { return this->content->toString(); } std::string fegen::RightValue::toStringForTypedef() { return this->content->toStringForTypedef(); @@ -1039,52 +1020,44 @@ std::string fegen::RightValue::toStringForCppKind() { return this->content->toStringForCppKind(); } -std::any fegen::RightValue::getContent() { - return this->content->getContent(); -} +std::any fegen::RightValue::getContent() { return this->content->getContent(); } -fegen::Type &fegen::RightValue::getType() { - return this->content->getType(); -} +fegen::TypePtr fegen::RightValue::getType() { return this->content->getType(); } -std::shared_ptr -fegen::RightValue::getExpr() { +std::shared_ptr fegen::RightValue::getExpr() { return this->content; } +bool fegen::RightValue::isConstant() { return this->content->isConstexpr(); } + fegen::RightValue fegen::RightValue::getPlaceHolder() { - return fegen::RightValue( - fegen::RightValue::Expression::getPlaceHolder()); + return fegen::RightValue(fegen::RightValue::Expression::getPlaceHolder()); } -fegen::RightValue fegen::RightValue::getInteger(long long int content, - size_t size) { +fegen::RightValue fegen::RightValue::getInteger(largestInt content, + size_t size) { return fegen::RightValue( fegen::RightValue::Expression::getInteger(content, size)); } -fegen::RightValue -fegen::RightValue::getFloatPoint(long double content, size_t size) { +fegen::RightValue fegen::RightValue::getFloatPoint(long double content, + size_t size) { return fegen::RightValue( fegen::RightValue::Expression::getFloatPoint(content, size)); } fegen::RightValue fegen::RightValue::getString(std::string content) { - return fegen::RightValue( - fegen::RightValue::Expression::getString(content)); + return fegen::RightValue(fegen::RightValue::Expression::getString(content)); } -fegen::RightValue -fegen::RightValue::getType(fegen::Type &content) { +fegen::RightValue fegen::RightValue::getTypeRightValue(fegen::TypePtr content) { return fegen::RightValue( - fegen::RightValue::Expression::getType(content)); + fegen::RightValue::Expression::getTypeRightValue(content)); } fegen::RightValue fegen::RightValue::getList( std::vector> &content) { - return fegen::RightValue( - fegen::RightValue::Expression::getList(content)); + return fegen::RightValue(fegen::RightValue::Expression::getList(content)); } -fegen::RightValue -fegen::RightValue::getLeftValue(fegen::Value *content) { +fegen::RightValue fegen::RightValue::getLeftValue(fegen::Value *content) { return fegen::RightValue( fegen::RightValue::Expression::getLeftValue(content)); } @@ -1096,10 +1069,9 @@ fegen::RightValue fegen::RightValue::getByExpr( } // class FegenValue -fegen::Value::Value(fegen::Type type, std::string name, - fegen::RightValue content) - : type(std::move(type)), name(std::move(name)), - content(std::move(content)) {} +fegen::Value::Value(fegen::TypePtr type, std::string name, + fegen::RightValue content) + : type(type), name(std::move(name)), content(std::move(content)) {} fegen::Value::Value(const fegen::Value &rhs) : type(rhs.type), name(rhs.name), content(rhs.content) {} @@ -1107,14 +1079,12 @@ fegen::Value::Value(fegen::Value &&rhs) : type(std::move(rhs.type)), name(std::move(rhs.name)), content(std::move(rhs.content)) {} -fegen::Value *fegen::Value::get(fegen::Type type, - std::string name, - RightValue content) { - return new fegen::Value(std::move(type), std::move(name), - std::move(content)); +fegen::Value *fegen::Value::get(fegen::TypePtr type, std::string name, + RightValue content) { + return new fegen::Value(type, std::move(name), std::move(content)); } -fegen::Type &fegen::Value::getType() { return this->type; } +fegen::TypePtr fegen::Value::getType() { return this->type; } std::string fegen::Value::getName() { return this->name; } @@ -1142,18 +1112,17 @@ std::string fegen::Value::getContentStringForCppKind() { return this->content.toStringForCppKind(); } -std::shared_ptr -fegen::Value::getExpr() { +std::shared_ptr fegen::Value::getExpr() { return this->content.getExpr(); } fegen::ParserRule::ParserRule(std::string content, fegen::ParserNode *src, - antlr4::ParserRuleContext *ctx) + antlr4::ParserRuleContext *ctx) : content(content), src(src), ctx(ctx) {} fegen::ParserRule *fegen::ParserRule::get(std::string content, - fegen::ParserNode *src, - antlr4::ParserRuleContext *ctx) { + fegen::ParserNode *src, + antlr4::ParserRuleContext *ctx) { return new fegen::ParserRule(content, src, ctx); } @@ -1180,17 +1149,18 @@ bool fegen::ParserRule::addReturn(fegen::Value output) { void fegen::ParserRule::setSrc(ParserNode *src) { this->src = src; } fegen::ParserNode::ParserNode(std::vector &&rules, - antlr4::ParserRuleContext *ctx, - fegen::ParserNode::NodeType ntype) + antlr4::ParserRuleContext *ctx, + fegen::ParserNode::NodeType ntype) : rules(rules), ctx(ctx), ntype(ntype) {} -fegen::ParserNode *fegen::ParserNode::get(std::vector rules, - antlr4::ParserRuleContext *ctx, - fegen::ParserNode::NodeType ntype) { +fegen::ParserNode * +fegen::ParserNode::get(std::vector rules, + antlr4::ParserRuleContext *ctx, + fegen::ParserNode::NodeType ntype) { return new fegen::ParserNode(std::move(rules), ctx, ntype); } fegen::ParserNode *fegen::ParserNode::get(antlr4::ParserRuleContext *ctx, - fegen::ParserNode::NodeType ntype) { + fegen::ParserNode::NodeType ntype) { std::vector rules; return new fegen::ParserNode(std::move(rules), ctx, ntype); } @@ -1273,9 +1243,9 @@ class StmtGenerator : FegenParserBaseVisitor { : manager(Manager::getManager()), emitter(emitter) {} std::any visitVarDeclStmt(FegenParser::VarDeclStmtContext *ctx) override { auto var = manager.getStmtContent(ctx->identifier()); - switch (var->getType().getTypeKind()) { + switch (var->getType()->getTypeKind()) { case fegen::Type::TypeKind::CPP: { - this->emitter << var->getType().toStringForCppKind() << " " + this->emitter << var->getType()->toStringForCppKind() << " " << var->getName(); if (ctx->expression()) { auto expr = this->manager.getStmtContent( @@ -1392,10 +1362,11 @@ void fegen::Manager::emitTypeDefination() { emitter.tab(); for (size_t i = 0; i <= tyDef->getParameters().size() - 1; i++) { auto param = tyDef->getParameters()[i]; - auto ¶mTy = param->getType(); + auto paramTy = param->getType(); auto paramName = param->getName(); - auto paramTyStr = paramTy.toStringForTypedef(); - emitter << paramTyStr << ":" << "$" << paramName; + auto paramTyStr = paramTy->toStringForTypedef(); + emitter << paramTyStr << ":" + << "$" << paramName; if (i != tyDef->getParameters().size() - 1) { emitter << ", "; } @@ -1491,7 +1462,7 @@ void fegen::Manager::emitOpDefination() { emitter.newLine(); emitter.tab(); for (auto param : opDef->getArguments()) { - auto paramTyStr = param->getType().toStringForOpdef(); + auto paramTyStr = param->getType()->toStringForOpdef(); auto paramName = param->getName(); emitter << paramTyStr << " : $" << paramName; emitter.newLine(); @@ -1504,7 +1475,7 @@ void fegen::Manager::emitOpDefination() { emitter.newLine(); emitter.tab(); for (auto param : opDef->getArguments()) { - auto paramTyStr = param->getType().toStringForOpdef(); + auto paramTyStr = param->getType()->toStringForOpdef(); auto paramName = param->getName(); emitter << paramTyStr << " : $" << paramName; emitter.newLine(); @@ -1583,110 +1554,102 @@ void fegen::Manager::emitTdFiles() { void fegen::Manager::initbuiltinTypes() { // placeholder type auto placeholderTypeDefination = fegen::TypeDefination::get( - "fegen_builtin", FEGEN_PLACEHOLDER, {}, nullptr, false); + FEGEN_DIALECT_NAME, FEGEN_PLACEHOLDER, {}, nullptr, false); this->typeDefMap.insert({FEGEN_PLACEHOLDER, placeholderTypeDefination}); // Type this->typeDefMap.insert( - {FEGEN_TYPE, fegen::TypeDefination::get("fegen_builtin", FEGEN_TYPE, - {}, nullptr, false)}); + {FEGEN_TYPE, fegen::TypeDefination::get(FEGEN_DIALECT_NAME, FEGEN_TYPE, + {}, nullptr, false)}); // TypeTemplate this->typeDefMap.insert( {FEGEN_TYPETEMPLATE, - fegen::TypeDefination::get("fegen_builtin", FEGEN_TYPETEMPLATE, {}, - nullptr, false)}); - - // recursive define Integer Type - // Integer>> - auto intTypeDefination = fegen::TypeDefination::get( - "fegen_builtin", FEGEN_INTEGER, {}, nullptr, false); - auto intType = fegen::Type( - fegen::Type::TypeKind::CPP, - {fegen::Value::get(fegen::Type::getPlaceHolder(), "size", - fegen::RightValue::getPlaceHolder())}, - intTypeDefination, false); - // parameters of Integer is int32(Integer<32>) - intTypeDefination->parameters.push_back(fegen::Value::get( - intType, "size", fegen::RightValue::getPlaceHolder())); - this->typeDefMap.insert({FEGEN_INTEGER, intTypeDefination}); + fegen::TypeDefination::get(FEGEN_DIALECT_NAME, FEGEN_TYPETEMPLATE, {}, + nullptr, false)}); + + // Integer + auto intTydef = fegen::TypeDefination::get(FEGEN_DIALECT_NAME, FEGEN_INTEGER, + {}, nullptr, false); + auto paramOfIntTydef = Value::get( + std::make_shared(RightValue::getInteger(32), intTydef), + "size", fegen::RightValue::getPlaceHolder()); + intTydef->parameters.push_back(paramOfIntTydef); + this->typeDefMap.insert({FEGEN_INTEGER, intTydef}); // FloatPoint this->typeDefMap.insert( {FEGEN_FLOATPOINT, fegen::TypeDefination::get( - "fegen_builtin", FEGEN_FLOATPOINT, + FEGEN_DIALECT_NAME, FEGEN_FLOATPOINT, {fegen::Value::get(fegen::Type::getInt32Type(), "size", - fegen::RightValue::getPlaceHolder())}, + fegen::RightValue::getPlaceHolder())}, nullptr, false)}); - // Char - this->typeDefMap.insert( - {FEGEN_CHAR, fegen::TypeDefination::get("fegen_builtin", FEGEN_CHAR, - {}, nullptr, false)}); - // String - this->typeDefMap.insert( - {FEGEN_STRING, fegen::TypeDefination::get( - "fegen_builtin", FEGEN_STRING, {}, nullptr, false)}); + this->typeDefMap.insert({FEGEN_STRING, fegen::TypeDefination::get( + FEGEN_DIALECT_NAME, FEGEN_STRING, + {}, nullptr, false)}); // Vector this->typeDefMap.insert( {FEGEN_VECTOR, fegen::TypeDefination::get( - "fegen_builtin", FEGEN_VECTOR, - {fegen::Value::get(fegen::Type::getInt32Type(), "size", - fegen::RightValue::getPlaceHolder()), - fegen::Value::get(fegen::Type::getMetaType(), - "elementType", - fegen::RightValue::getPlaceHolder())}, + FEGEN_DIALECT_NAME, FEGEN_VECTOR, + { + fegen::Value::get(fegen::Type::getMetaType(), "elementType", + fegen::RightValue::getPlaceHolder()), + fegen::Value::get(fegen::Type::getInt32Type(), "size", + fegen::RightValue::getPlaceHolder()), + }, nullptr, false)}); // List (this should be ahead of Tensor and Any Type defination) this->typeDefMap.insert( - {FEGEN_LIST, fegen::TypeDefination::get( - "fegen_builtin", FEGEN_LIST, - {fegen::Value::get( - fegen::Type::getMetaType(), "elementType", - fegen::RightValue::getPlaceHolder())}, - nullptr, false)}); + {FEGEN_LIST, + fegen::TypeDefination::get( + FEGEN_DIALECT_NAME, FEGEN_LIST, + {fegen::Value::get(fegen::Type::getMetaType(), "elementType", + fegen::RightValue::getPlaceHolder())}, + nullptr, false)}); // Tensor this->typeDefMap.insert( {FEGEN_TENSOR, fegen::TypeDefination::get( - "fegen_builtin", FEGEN_TENSOR, - {fegen::Value::get( - fegen::Type::getListType(fegen::Type::getInt32Type()), - "shape", fegen::RightValue::getPlaceHolder()), - fegen::Value::get(fegen::Type::getMetaType(), - "elementType", - fegen::RightValue::getPlaceHolder())}, + FEGEN_DIALECT_NAME, FEGEN_TENSOR, + {fegen::Value::get(fegen::Type::getMetaType(), "elementType", + fegen::RightValue::getPlaceHolder()), + fegen::Value::get( + fegen::Type::getListType(fegen::Type::getInt32Type()), "shape", + fegen::RightValue::getPlaceHolder())}, nullptr, false)}); // Optional this->typeDefMap.insert( - {FEGEN_OPTINAL, fegen::TypeDefination::get( - "fegen_builtin", FEGEN_OPTINAL, - {fegen::Value::get( - fegen::Type::getMetaType(), "elementType", + {FEGEN_OPTINAL, + fegen::TypeDefination::get( + FEGEN_DIALECT_NAME, FEGEN_OPTINAL, + {fegen::Value::get(fegen::Type::getMetaType(), "elementType", fegen::RightValue::getPlaceHolder())}, - nullptr, false)}); + nullptr, false)}); // Any this->typeDefMap.insert( - {FEGEN_ANY, - fegen::TypeDefination::get( - "fegen_builtin", FEGEN_ANY, - {fegen::Value::get( - fegen::Type::getListType(fegen::Type::getMetaType()), - "elementType", fegen::RightValue::getPlaceHolder())}, - nullptr, false)}); -} - -fegen::TypeDefination * -fegen::Manager::getTypeDefination(std::string name) { - return this->typeDefMap[name]; + {FEGEN_ANY, fegen::TypeDefination::get( + FEGEN_DIALECT_NAME, FEGEN_ANY, + {fegen::Value::get( + fegen::Type::getListType(fegen::Type::getMetaType()), + "elementType", fegen::RightValue::getPlaceHolder())}, + nullptr, false)}); +} + +fegen::TypeDefination *fegen::Manager::getTypeDefination(std::string name) { + auto it = this->typeDefMap.find(name); + if (it != this->typeDefMap.end()) { + return it->second; + } + assert(false); } bool fegen::Manager::addTypeDefination(fegen::TypeDefination *tyDef) { @@ -1697,8 +1660,7 @@ bool fegen::Manager::addTypeDefination(fegen::TypeDefination *tyDef) { return true; } -fegen::Operation * -fegen::Manager::getOperationDefination(std::string name) { +fegen::Operation *fegen::Manager::getOperationDefination(std::string name) { return this->operationMap[name]; } @@ -1711,7 +1673,7 @@ bool fegen::Manager::addOperationDefination(fegen::Operation *opDef) { } void fegen::Manager::addStmtContent(antlr4::ParserRuleContext *ctx, - std::any content) { + std::any content) { this->stmtContentMap.insert({ctx, content}); } @@ -1727,7 +1689,7 @@ fegen::Manager::~Manager() { } } -fegen::Type fegen::inferenceType( +fegen::TypePtr fegen::inferenceType( std::vector> operands, fegen::FegenOperator op) { // TODO: infer type diff --git a/frontend/FrontendGen/lib/FegenParser.g4 b/frontend/FrontendGen/lib/FegenParser.g4 index 23feea1b2a..b7b14433e4 100644 --- a/frontend/FrontendGen/lib/FegenParser.g4 +++ b/frontend/FrontendGen/lib/FegenParser.g4 @@ -388,7 +388,7 @@ listLiteral // type system typeSpec : valueKind? typeInstance # typeInstanceSpec - | typeTemplate # typeTemplateSpce + | valueKind? typeTemplate # typeTemplateSpce | valueKind? collectType # collectTypeSpec ; @@ -398,9 +398,9 @@ valueKind | ATTRIBUTE ; -// 这里的identifier是不是没用? + typeInstance - : typeTemplate (Less typeTemplateParam (Comma typeTemplateParam)* Greater)? + : typeTemplate Less typeTemplateParam (Comma typeTemplateParam)* Greater | builtinTypeInstances | identifier ; diff --git a/frontend/FrontendGen/lib/FegenVisitor.cpp b/frontend/FrontendGen/lib/FegenVisitor.cpp index 5dc096eb8a..882246dedf 100644 --- a/frontend/FrontendGen/lib/FegenVisitor.cpp +++ b/frontend/FrontendGen/lib/FegenVisitor.cpp @@ -1,7 +1,7 @@ #include "FegenVisitor.h" bool fegen::checkParams(std::vector &expected, - std::vector &actual) { + std::vector &actual) { return true; } diff --git a/frontend/FrontendGen/lib/Scope.cpp b/frontend/FrontendGen/lib/Scope.cpp index 443536f11e..9294cb8897 100644 --- a/frontend/FrontendGen/lib/Scope.cpp +++ b/frontend/FrontendGen/lib/Scope.cpp @@ -24,18 +24,6 @@ fegen::FegenScope::FegenScope(unsigned int scopeId, fegen::FegenScope *parentScope) : scopeId(scopeId), parentScope(parentScope) {} -fegen::TypeDefination *fegen::FegenScope::findTypeDef(std::string name) { - return this->typeTable.get(name); -} - -void fegen::FegenScope::addTypeDef(TypeDefination *tyDef) { - this->typeTable.add(tyDef->getName(), tyDef); -} - -bool fegen::FegenScope::isExistTypeDef(std::string name) { - return this->typeTable.exist(name); -} - fegen::Value *fegen::FegenScope::findVar(std::string name) { return this->varTable.get(name); } @@ -94,24 +82,4 @@ fegen::Value *fegen::ScopeStack::attemptFindVar(std::string name) { p = p->parentScope; } return nullptr; -} - -bool fegen::ScopeStack::attemptAddTypeDef(fegen::TypeDefination *tyDef) { - if (this->currentScope->isExistTypeDef(tyDef->getName())) { - return false; - } - this->currentScope->addTypeDef(tyDef); - return true; -} - -fegen::TypeDefination * -fegen::ScopeStack::attemptFindTypeDef(std::string name) { - auto p = this->currentScope; - while (p != nullptr) { - if (p->isExistTypeDef(name)) { - return p->findTypeDef(name); - } - p = p->parentScope; - } - return nullptr; } \ No newline at end of file From da506b7a210e3ca32014c28d875883f37fa495c4 Mon Sep 17 00:00:00 2001 From: FloatingcloudKnight <1348185166@qq.com> Date: Thu, 18 Jul 2024 08:49:19 +0000 Subject: [PATCH 10/17] [FrontendGen] update function codegen --- examples/FrontendGen/function.fegen | 19 +-- frontend/FrontendGen/include/FegenVisitor.h | 14 ++- frontend/FrontendGen/lib/FegenLexer.g4 | 2 + frontend/FrontendGen/lib/FegenManager.cpp | 122 +++++++++++++------- frontend/FrontendGen/lib/FegenParser.g4 | 5 + 5 files changed, 107 insertions(+), 55 deletions(-) diff --git a/examples/FrontendGen/function.fegen b/examples/FrontendGen/function.fegen index 35da97395d..1c3745e083 100644 --- a/examples/FrontendGen/function.fegen +++ b/examples/FrontendGen/function.fegen @@ -3,15 +3,16 @@ fegen toy double stod(string numStr){ float res = 0.0; int c = 1; - if(c == 0){ - int charNum = 0; - int intNum = 1; - intNum = 1; - }else if (c == 1){ - int charNum = 1; - }else { - int charNum = 2; + for(int i = 0; i < 3; i = i+1){ + if(c == 0){ + int charNum = 0; + int intNum = 1; + intNum = 1; + }else if (c == 1){ + int charNum = 1; + }else { + int charNum = 2; + } } - return res; } \ No newline at end of file diff --git a/frontend/FrontendGen/include/FegenVisitor.h b/frontend/FrontendGen/include/FegenVisitor.h index 10b0d302eb..633029d669 100644 --- a/frontend/FrontendGen/include/FegenVisitor.h +++ b/frontend/FrontendGen/include/FegenVisitor.h @@ -610,6 +610,7 @@ class FegenVisitor : public FegenParserBaseVisitor { std::any visitFunctionDecl(FegenParser::FunctionDeclContext *ctx) override { sstack.pushScope(); auto returnType = std::any_cast(this->visit(ctx->typeSpec())); + manager.addStmtContent(ctx, returnType); auto functionName = std::any_cast(this->visit(ctx->funcName())); auto hasfunc = manager.functionMap.find(functionName); @@ -624,13 +625,14 @@ class FegenVisitor : public FegenParserBaseVisitor { this->visit(ctx->funcParams())); this->visit(ctx->statementBlock()); auto function = fegen::Function::get(functionName, functionParams, &returnType); - manager.functionMap.insert(std::pair{functionName, function}); + manager.functionMap.insert({functionName, function}); sstack.popScope(); return nullptr; } std::any visitFuncName(FegenParser::FuncNameContext *ctx) override { auto functionName = ctx->identifier()->getText(); + manager.addStmtContent(ctx, functionName); return functionName; } @@ -640,17 +642,20 @@ class FegenVisitor : public FegenParserBaseVisitor { for (size_t i = 0; i < ctx->typeSpec().size(); i++) { auto paramType = std::any_cast(this->visit(ctx->typeSpec(i))); + // manager.addStmtContent(ctx, paramType); auto paramName = ctx->identifier(i)->getText(); auto param = fegen::Value::get(paramType, paramName, fegen::RightValue::getPlaceHolder()); paramsList.push_back(param); sstack.attemptAddVar(param); } + manager.addStmtContent(ctx, paramsList); return paramsList; } std::any visitVarDeclStmt(FegenParser::VarDeclStmtContext *ctx) override { auto varType = std::any_cast(this->visit(ctx->typeSpec())); + manager.addStmtContent(ctx, varType); auto varName = ctx->identifier()->getText(); fegen::Value *var; if (ctx->expression()) { @@ -673,7 +678,7 @@ class FegenVisitor : public FegenParserBaseVisitor { fegen::RightValue::getPlaceHolder()); } sstack.attemptAddVar(var); - manager.addStmtContent(ctx, varType); + return var; } @@ -779,6 +784,11 @@ class FegenVisitor : public FegenParserBaseVisitor { return nullptr; } + std::any visitReturnBlock(FegenParser::ReturnBlockContext *ctx) override { + this->visit(ctx->expression()); + return nullptr; + } + std::any visitOpDecl(FegenParser::OpDeclContext *ctx) override { auto opName = ctx->opName()->getText(); auto opDef = std::any_cast(this->visit(ctx->opBlock())); diff --git a/frontend/FrontendGen/lib/FegenLexer.g4 b/frontend/FrontendGen/lib/FegenLexer.g4 index 5dc8006b21..6b5e0aa2a3 100644 --- a/frontend/FrontendGen/lib/FegenLexer.g4 +++ b/frontend/FrontendGen/lib/FegenLexer.g4 @@ -118,6 +118,8 @@ IN: 'in'; WHILE: 'while'; +RETURN: 'return'; + // identifiers LexerRuleName: UPPERCASE (NONDIGIT | DIGIT)*; diff --git a/frontend/FrontendGen/lib/FegenManager.cpp b/frontend/FrontendGen/lib/FegenManager.cpp index e513468303..1763d7dde9 100644 --- a/frontend/FrontendGen/lib/FegenManager.cpp +++ b/frontend/FrontendGen/lib/FegenManager.cpp @@ -1334,6 +1334,7 @@ void fegen::Manager::emitG4() { emitter.shiftTab(); emitter.newLine(); } + fileStream.close(); } // TODO: emit to file @@ -1739,29 +1740,57 @@ namespace fegen { class StmtVisitor : public FegenParserBaseVisitor { private: Manager &manager; + Emitter &emitter; public: - StmtVisitor() : manager(Manager::getManager()) {} + StmtVisitor(Emitter &emitter) : manager(Manager::getManager()), emitter(emitter) {} + std::any visitFunctionDecl(FegenParser::FunctionDeclContext *ctx) override { + auto returnType = std::any_cast(manager.stmtContentMap[ctx]); + auto functionName = std::any_cast(manager.stmtContentMap[ctx->funcName()]); + emitter << returnType.getTypeName() << " " + << functionName << "("; + auto paraList = std::any_cast>(manager.stmtContentMap[ctx->funcParams()]); + for (auto para : paraList) { + emitter << para->getType().getTypeName() << " " << para->getName(); + if (para != paraList.back()) + emitter << ", "; + } + emitter << "){"; + emitter.tab(); + emitter.newLine(); + this->visit(ctx->statementBlock()); + emitter.shiftTab(); + emitter << "}"; + emitter.newLine(); + return nullptr; + } + std::any visitStatementBlock(FegenParser::StatementBlockContext *ctx) override { + for(size_t i = 0; i < ctx->statement().size(); i++){ + this->visit(ctx->statement(i)); + if(!(ctx->statement(i)->ifStmt()||ctx->statement(i)->forStmt())) + emitter << ";"; + emitter.newLine(); + } + return nullptr; + } std::any visitVarDeclStmt(FegenParser::VarDeclStmtContext *ctx) override { - Emitter emitter(std::cout); auto varType = - std::any_cast(manager.stmtContentMap[ctx]); - auto varName = ctx->identifier()->toString(); - auto expr = std::any_cast>(manager.stmtContentMap[ctx->expression()]); - emitter << varType->getName() << " " << varName << " = " << expr->toString() << ";"; - emitter.newLine(); + std::any_cast(manager.stmtContentMap[ctx]); + auto varName = ctx->identifier()->getText(); + emitter << varType.getTypeName() << " " << varName; + if(ctx->expression()){ + auto expr = std::any_cast>(manager.stmtContentMap[ctx->expression()]); + emitter << " = " << expr->toString(); + } return nullptr; } std::any visitAssignStmt(FegenParser::AssignStmtContext *ctx) override { - Emitter emitter(std::cout); - auto varName = ctx->identifier()->toString(); + auto varName = ctx->identifier()->getText(); auto expr = std::any_cast>(manager.stmtContentMap[ctx->expression()]); - emitter << varName << " = " << expr->toString() << ";"; - emitter.newLine(); + emitter << varName << " = " << expr->toString(); return nullptr; } std::any visitFunctionCall(FegenParser::FunctionCallContext *ctx) override { - Emitter emitter(std::cout); auto function = std::any_cast(manager.stmtContentMap[ctx]); emitter << function->getName() << " ("; @@ -1776,7 +1805,6 @@ class StmtVisitor : public FegenParserBaseVisitor { return nullptr; } std::any visitIfStmt(FegenParser::IfStmtContext *ctx) override { - Emitter emitter(std::cout); this->visit(ctx->ifBlock(0)); for(size_t i = 1; i < ctx->ifBlock().size(); i++){ emitter << " else "; @@ -1786,60 +1814,66 @@ class StmtVisitor : public FegenParserBaseVisitor { return nullptr; } std::any visitIfBlock(FegenParser::IfBlockContext *ctx) override { - Emitter emitter(std::cout); auto expr = std::any_cast>( manager.stmtContentMap[ctx->expression()]); emitter << "if (" << expr->toString() << "){"; - emitter.newLine(); emitter.tab(); + emitter.newLine(); this->visit(ctx->statementBlock()); emitter.shiftTab(); emitter << "}"; return nullptr; } std::any visitElseBlock(FegenParser::ElseBlockContext *ctx) override { - Emitter emitter(std::cout); emitter << "else {"; - emitter.newLine(); emitter.tab(); - this->visit(ctx->statementBlock()); emitter.newLine(); + this->visit(ctx->statementBlock()); emitter.shiftTab(); emitter << "}"; - emitter.newLine(); + return nullptr; } // TODO: 支持for循环 std::any visitForStmt(FegenParser::ForStmtContext *ctx) override { - Emitter emitter(std::cout); - emitter << "for ("; + if (ctx->varDeclStmt()) { + emitter << "for ("; + this->visit(ctx->varDeclStmt()); + emitter << "; "; + auto expr = std::any_cast>(manager.stmtContentMap[ctx->expression()]); + emitter << expr->toString() << "; "; + this->visit(ctx->assignStmt(0)); + emitter << ") {"; + } else { + this->visit(ctx->assignStmt(0)); + emitter << " "; + auto expr = std::any_cast>(manager.stmtContentMap[ctx->expression()]); + emitter << expr->toString() << "; "; + this->visit(ctx->assignStmt(1)); + emitter << ") {"; + } + emitter.tab(); + emitter.newLine(); + this->visit(ctx->statementBlock()); + emitter.shiftTab(); + emitter << "}"; return nullptr; } + std::any visitReturnBlock(FegenParser::ReturnBlockContext *ctx) override { + auto expr = std::any_cast>( + manager.stmtContentMap[ctx->expression()]); + emitter << "return " << expr->toString(); + return nullptr; + } }; } // namespace fegen void fegen::Manager::emitBuiltinFunction(fegen::FegenParser::FegenSpecContext *moduleAST) { - Emitter emitter(std::cout); - fegen::StmtVisitor visitor; - - for (auto function_pair : this->functionMap) { - auto functionName = function_pair.first; - auto function = function_pair.second; - auto paraList = function->getInputTypeList(); - emitter << function->getReturnType()->getTypeName() << " " - << functionName << "("; - for (auto para : paraList) { - emitter << para->getContentStringForTypedef() << " " << para->getName(); - if (para != paraList.back()) - emitter << ", "; - } - emitter << "){"; - emitter.newLine(); - emitter.tab(); - // TODO::function body - visitor.visit(moduleAST); - emitter.shiftTab(); - emitter.newLine(); - emitter << "}"; - } + std::ofstream fileStream; + fileStream.open(this->moduleName + "Function.cpp"); + fegen::Emitter emitter(fileStream); + //Emitter emitter(std::cout); + StmtVisitor visitor(emitter); + visitor.visit(moduleAST); + fileStream.close(); } \ No newline at end of file diff --git a/frontend/FrontendGen/lib/FegenParser.g4 b/frontend/FrontendGen/lib/FegenParser.g4 index 8a6c671888..d37dcb8afc 100644 --- a/frontend/FrontendGen/lib/FegenParser.g4 +++ b/frontend/FrontendGen/lib/FegenParser.g4 @@ -262,6 +262,7 @@ statement | opInvokeStmt Semi | ifStmt | forStmt + | returnBlock Semi ; varDeclStmt @@ -304,6 +305,10 @@ forStmt : FOR LeftParen (assignStmt | varDeclStmt) Semi expression Semi assignStmt RightParen statementBlock ; +returnBlock + : RETURN expression + ; + // expression expression : andExpr (Logic_OR andExpr)* From c6dfcf7ae1d57817a6a138f6dade36e2bb5dcc38 Mon Sep 17 00:00:00 2001 From: chh Date: Thu, 18 Jul 2024 20:09:26 +0800 Subject: [PATCH 11/17] [FrontendGen] Refactor fegen::Type, add TypeTemplate subclasses, which representing template types. --- frontend/FrontendGen/include/FegenManager.h | 114 ++++++- frontend/FrontendGen/include/FegenVisitor.h | 52 ++- frontend/FrontendGen/lib/FegenManager.cpp | 355 +++++++++++++++----- 3 files changed, 416 insertions(+), 105 deletions(-) diff --git a/frontend/FrontendGen/include/FegenManager.h b/frontend/FrontendGen/include/FegenManager.h index 4a84bbe2b9..dd3d19fb50 100644 --- a/frontend/FrontendGen/include/FegenManager.h +++ b/frontend/FrontendGen/include/FegenManager.h @@ -10,6 +10,7 @@ #include #include +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" @@ -21,7 +22,6 @@ #define FEGEN_TYPETEMPLATE "TypeTemplate" #define FEGEN_INTEGER "Integer" #define FEGEN_FLOATPOINT "FloatPoint" -#define FEGEN_CHAR "Char" #define FEGEN_STRING "String" #define FEGEN_VECTOR "Vector" #define FEGEN_TENSOR "Tensor" @@ -37,7 +37,6 @@ class Type; class Manager; class Value; class RightValue; -class Expression; using TypePtr = std::shared_ptr; using largestInt = long long int; @@ -206,7 +205,33 @@ class Type { static TypePtr getCustomeType(std::vector params, TypeDefination* tydef); - static TypePtr getTemplateType(TypeDefination *typeDefination); + // Integer + static TypePtr getIntegerTemplate(); + + // FloatPoint + static TypePtr getFloatPointTemplate(); + + // string + static TypePtr getStringTemplate(); + + // List (elementType is template) + static TypePtr getListTemplate(TypePtr elementType); + static TypePtr getListTemplate(RightValue elementType); + + // Vector + static TypePtr getVectorTemplate(); + + // Tensor + static TypePtr getTensorTemplate(); + + // Optional (elementType is template) + static TypePtr getOptionalTemplate(TypePtr elementType); + static TypePtr getOptionalTemplate(RightValue elementType); + + // Any<[elementType1, elementType2, ...]> (elementType* is template) + static TypePtr getAnyTemplate(RightValue elementTypes); + + static TypePtr getCustomeTemplate(TypeDefination* tydef); }; class TypeDefination { @@ -554,7 +579,76 @@ class CustomeType : public Type { class TemplateType : public Type { public: TemplateType(TypeDefination* tydef); - TypePtr instantiate(std::vector params); + virtual TypePtr instantiate(std::vector params) = 0; + virtual ~TemplateType() = default; +}; + +// Integer +class IntegerTemplateType : public TemplateType { + public: + IntegerTemplateType(); + virtual TypePtr instantiate(std::vector params) override; + // for generating typedef td file. + virtual std::string toStringForTypedef() override; + // for generating op def td file. + virtual std::string toStringForOpdef() override; +}; +// FloatPoint +class FloatPointTemplateType : public TemplateType { + public: + FloatPointTemplateType(); + virtual TypePtr instantiate(std::vector params) override; + // for generating typedef td file. + virtual std::string toStringForTypedef() override; +}; +// String +class StringTemplateType : public TemplateType { + public: + StringTemplateType(); + virtual TypePtr instantiate(std::vector params) override; + // for generating typedef td file. + virtual std::string toStringForTypedef() override; +}; +// List (ty is a template) +class ListTemplateType : public TemplateType { + RightValue elementType; + public: + ListTemplateType(RightValue elementType); + virtual TypePtr instantiate(std::vector params) override; + virtual std::string toStringForTypedef() override; + virtual std::string toStringForOpdef() override; +}; +// Vector +class VectorTemplateType : public TemplateType { + public: + VectorTemplateType(); + virtual TypePtr instantiate(std::vector params) override; +}; +// Tensor +class TensorTemplateType : public TemplateType { + public: + TensorTemplateType(); + virtual TypePtr instantiate(std::vector params) override; +}; +// Optional (ty is a template) +class OptionalTemplateType : public TemplateType { + RightValue elementType; + public: + OptionalTemplateType(RightValue elementType); + virtual TypePtr instantiate(std::vector params) override; +}; +// Any<[ty1, ty2, ...]> (ty* is a template) +class AnyTemplateType : public TemplateType { + RightValue elementTypes; + public: + AnyTemplateType(RightValue elementTypes); + virtual TypePtr instantiate(std::vector params) override; +}; +// custome type +class CustomeTemplateType : public TemplateType { + public: + CustomeTemplateType(TypeDefination* tydef); + virtual TypePtr instantiate(std::vector params) override; // for generating typedef td file. virtual std::string toStringForTypedef() override; // for generating op def td file. @@ -649,8 +743,16 @@ class FegenVisitor; class Manager { friend class FegenVisitor; +private: +struct OverloadedType { + llvm::SmallVector tys; + OverloadedType(TypeDefination *); + OverloadedType(std::initializer_list&&); + TypeDefination* get(unsigned i); +}; private: + std::map typeDefMap; Manager(); Manager(const Manager &) = delete; const Manager &operator=(const Manager &) = delete; @@ -663,7 +765,7 @@ class Manager { std::vector headFiles; std::map nodeMap; llvm::StringMap typeMap; - std::map typeDefMap; + std::map operationMap; std::map functionMap; // stmt contents @@ -678,7 +780,9 @@ class Manager { void setModuleName(std::string name); TypeDefination *getTypeDefination(std::string name); + TypeDefination* getOverloadedTypeDefination(std::string name); bool addTypeDefination(TypeDefination *tyDef); + bool addOverloadedTypeDefination(TypeDefination *tyDef); Operation *getOperationDefination(std::string name); bool addOperationDefination(Operation *opDef); diff --git a/frontend/FrontendGen/include/FegenVisitor.h b/frontend/FrontendGen/include/FegenVisitor.h index 1c3f40a9aa..7e46f5ca76 100644 --- a/frontend/FrontendGen/include/FegenVisitor.h +++ b/frontend/FrontendGen/include/FegenVisitor.h @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -60,7 +61,7 @@ class FegenVisitor : public FegenParserBaseVisitor { tyDef->setName(typeName); tyDef->setCtx(ctx); // add defination to manager map - this->manager.typeDefMap.insert({typeName, tyDef}); + this->manager.addTypeDefination(tyDef); return nullptr; } @@ -312,7 +313,7 @@ class FegenVisitor : public FegenParserBaseVisitor { return nullptr; } else { // type auto tyDef = this->manager.getTypeDefination(ctx->prefixedName()->identifier(0)->getText()); - return fegen::Type::getTemplateType(tyDef); + return fegen::Type::getCustomeTemplate(tyDef); } } else if (ctx->builtinTypeTemplate()) { // builtinTypeTemplate return this->visit(ctx->builtinTypeTemplate()); @@ -325,14 +326,14 @@ class FegenVisitor : public FegenParserBaseVisitor { std::any visitBuiltinTypeTemplate( FegenParser::BuiltinTypeTemplateContext *ctx) override { if (ctx->INTEGER()) { - return fegen::Type::getTemplateType(this->manager.getTypeDefination(FEGEN_INTEGER)); + return fegen::Type::getIntegerTemplate(); } else if (ctx->FLOATPOINT()) { - return fegen::Type::getTemplateType(this->manager.getTypeDefination(FEGEN_FLOATPOINT)); + return fegen::Type::getFloatPointTemplate(); } else if (ctx->TENSOR()) { // return fegen::FegenType::getTensorTemplate(); - return fegen::Type::getTemplateType(this->manager.getTypeDefination(FEGEN_TENSOR)); + return fegen::Type::getTensorTemplate(); } else if (ctx->VECTOR()) { - return fegen::Type::getTemplateType(this->manager.getTypeDefination(FEGEN_VECTOR)); + return fegen::Type::getVectorTemplate(); } else { return nullptr; } @@ -356,16 +357,39 @@ class FegenVisitor : public FegenParserBaseVisitor { auto expr = std::any_cast>( this->visit(ctx->expression())); + if (ctx->collectProtoType()->ANY()) { - std::vector tys; - assert(expr->getKind() == fegen::RightValue::LiteralKind::VECTOR); - return fegen::Type::getAnyType(fegen::RightValue::getByExpr(expr)); + // check to get list type. + std::vector tyexpr = std::any_cast>(expr); + int level = std::any_cast(tyexpr[0]->getContent())->getTypeLevel(); + for(size_t i = 1; i <= tyexpr.size()-1; i++){ + auto expr = tyexpr[i]; + auto t = std::any_cast(expr->getContent()); + if(level != t->getTypeLevel()){ + assert(false); + } + } + if(level == 1 || level == 2){ // template -> any template + return fegen::Type::getAnyTemplate(fegen::RightValue::getByExpr(expr)); + }else{ // instance -> any instance + return fegen::Type::getAnyType(fegen::RightValue::getByExpr(expr)); + } } else if (ctx->collectProtoType()->LIST()) { - assert(expr->getKind() == fegen::RightValue::LiteralKind::TYPE); - return fegen::Type::getListType(fegen::RightValue::getByExpr(expr)); + // the same as any + int level = std::any_cast(expr->getContent())->getTypeLevel(); + if(level == 1 || level == 2){ + return fegen::Type::getListTemplate(fegen::RightValue::getByExpr(expr)); + }else{ + return fegen::Type::getListType(fegen::RightValue::getByExpr(expr)); + } } else { // optional - assert(expr->getKind() == fegen::RightValue::LiteralKind::TYPE); - return fegen::Type::getOptionalType(fegen::RightValue::getByExpr(expr)); + // the same as any + int level = std::any_cast(expr->getContent())->getTypeLevel(); + if(level == 1 || level == 2){ + return fegen::Type::getOptionalTemplate(fegen::RightValue::getByExpr(expr)); + }else{ + return fegen::Type::getOptionalType(fegen::RightValue::getByExpr(expr)); + } } } @@ -541,7 +565,7 @@ class FegenVisitor : public FegenParserBaseVisitor { // TODO auto tyDef = this->manager.getTypeDefination(name); if (tyDef) { - auto tyVar = fegen::Type::getTemplateType(tyDef); + auto tyVar = fegen::Type::getCustomeTemplate(tyDef); return (std::shared_ptr) fegen::RightValue::Expression::getTypeRightValue(tyVar); } else { diff --git a/frontend/FrontendGen/lib/FegenManager.cpp b/frontend/FrontendGen/lib/FegenManager.cpp index 7b9c52bc75..fc6383ead4 100644 --- a/frontend/FrontendGen/lib/FegenManager.cpp +++ b/frontend/FrontendGen/lib/FegenManager.cpp @@ -106,6 +106,7 @@ bool fegen::Type::isSameType(fegen::Type *type1, fegen::Type *type2) { } std::string fegen::Type::toStringForTypedef() { + std::cerr << this->getTypeName() <typeLevel == 2 || elementType->typeLevel == 3); - return std::make_shared( + assert(elementType->typeLevel == 3); + return std::make_shared( fegen::RightValue::getTypeRightValue(elementType)); } @@ -195,7 +196,7 @@ fegen::TypePtr fegen::Type::getTensorType(RightValue elementType, } fegen::TypePtr fegen::Type::getOptionalType(fegen::TypePtr elementType) { - assert(elementType->typeLevel == 2 || elementType->typeLevel == 3); + assert(elementType->typeLevel == 3); return std::make_shared( RightValue::getTypeRightValue(elementType)); } @@ -215,9 +216,62 @@ fegen::Type::getCustomeType(std::vector params, return std::make_shared(params, tydef); } -fegen::TypePtr -fegen::Type::getTemplateType(fegen::TypeDefination *typeDefination) { - return std::make_shared(typeDefination); +// Integer +fegen::TypePtr fegen::Type::getIntegerTemplate() { + return std::make_shared(); +} + +// FloatPoint +fegen::TypePtr fegen::Type::getFloatPointTemplate() { + return std::make_shared(); +} + +// string +fegen::TypePtr fegen::Type::getStringTemplate() { + return std::make_shared(); +} + +// List +fegen::TypePtr fegen::Type::getListTemplate(TypePtr elementType) { + assert(elementType->typeLevel == 2 || elementType->typeLevel == 1); + return std::make_shared( + fegen::RightValue::getTypeRightValue(elementType)); +} + +fegen::TypePtr fegen::Type::getListTemplate(RightValue elementType) { + auto ty = std::any_cast(elementType.getContent()); + return Type::getListTemplate(ty); +} + +// Vector +fegen::TypePtr fegen::Type::getVectorTemplate() { + return std::make_shared(); +} + +// Tensor +fegen::TypePtr fegen::Type::getTensorTemplate() { + return std::make_shared(); +} + +// Optional +fegen::TypePtr fegen::Type::getOptionalTemplate(TypePtr elementType) { + assert(elementType->typeLevel == 2); + return std::make_shared( + fegen::RightValue::getTypeRightValue(elementType)); +} +fegen::TypePtr fegen::Type::getOptionalTemplate(RightValue elementType) { + auto ty = std::any_cast(elementType.getContent()); + return Type::getOptionalTemplate(ty); +} + +// Any<[elementType1, elementType2, ...]> +fegen::TypePtr fegen::Type::getAnyTemplate(RightValue elementTypes) { + return std::make_shared(elementTypes); +} + +fegen::TypePtr fegen::Type::getCustomeTemplate(TypeDefination *tydef) { + assert(tydef->isCustome()); + return std::make_shared(tydef); } /// @brief get name of Type Instance by jointsing template name and parameters, @@ -365,8 +419,7 @@ fegen::StringType::StringType() // class ListType fegen::ListType::ListType(fegen::RightValue elementType) : Type(fegen::Type::TypeKind::CPP, jointTypeName(FEGEN_LIST, {elementType}), - fegen::Manager::getManager().getTypeDefination(FEGEN_LIST), - std::any_cast(elementType.getContent())->getTypeLevel(), + fegen::Manager::getManager().getTypeDefination(FEGEN_LIST), 3, elementType.isConstant()), elementType(elementType) {} @@ -411,23 +464,23 @@ fegen::TensorType::TensorType(RightValue elementType, RightValue shape) fegen::OptionalType::OptionalType(RightValue elementType) : Type(fegen::Type::TypeKind::CPP, jointTypeName(FEGEN_OPTINAL, {elementType}), - fegen::Manager::getManager().getTypeDefination(FEGEN_OPTINAL), - std::any_cast(elementType.getContent())->getTypeLevel(), + fegen::Manager::getManager().getTypeDefination(FEGEN_OPTINAL), 3, elementType.isConstant()), elementType(elementType) {} // class AnyType -inline int getTypeLevelOfListType(fegen::RightValue& elementTypes) { - auto listContent = std::any_cast>(elementTypes.getContent()); - fegen::TypePtr ty = std::any_cast(listContent[0]->getContent()); +inline int getTypeLevelOfListType(fegen::RightValue &elementTypes) { + auto listContent = std::any_cast>( + elementTypes.getContent()); + fegen::TypePtr ty = + std::any_cast(listContent[0]->getContent()); return ty->getTypeLevel(); } fegen::AnyType::AnyType(RightValue elementTypes) : Type(fegen::Type::TypeKind::CPP, jointTypeName(FEGEN_ANY, {elementTypes}), - fegen::Manager::getManager().getTypeDefination(FEGEN_ANY), - getTypeLevelOfListType(elementTypes), + fegen::Manager::getManager().getTypeDefination(FEGEN_ANY), 3, elementTypes.isConstant()), elementTypes(elementTypes) {} @@ -451,64 +504,140 @@ fegen::CustomeType::CustomeType(std::vector params, fegen::TemplateType::TemplateType(TypeDefination *tydef) : Type(fegen::Type::TypeKind::CPP, tydef->getName(), tydef, 2, true) {} +// class IntegerTemplateType +fegen::IntegerTemplateType::IntegerTemplateType() + : TemplateType( + fegen::Manager::getManager().getTypeDefination(FEGEN_INTEGER)) {} + fegen::TypePtr -fegen::TemplateType::instantiate(std::vector params) { - auto tydef = this->getTypeDefination(); - if (tydef->isCustome()) { - return Type::getCustomeType(params, tydef); - } else if (tydef->getName() == FEGEN_INTEGER) { - assert(params.size() == 1); - return Type::getIntegerType(params[0]); - } else if (tydef->getName() == FEGEN_FLOATPOINT) { - assert(params.size() == 1); - return Type::getFloatPointType(params[0]); - } else if (tydef->getName() == FEGEN_STRING) { - assert(params.size() == 0); - return Type::getStringType(); - } else if (tydef->getName() == FEGEN_LIST) { - assert(params.size() == 1); - return Type::getListType(params[0]); - } else if (tydef->getName() == FEGEN_VECTOR) { - assert(params.size() == 2); - return Type::getVectorType(params[0], params[1]); - } else if (tydef->getName() == FEGEN_TENSOR) { - assert(params.size() == 2); - return Type::getTensorType(params[0], params[1]); - } else if (tydef->getName() == FEGEN_OPTINAL) { - assert(params.size() == 1); - return Type::getOptionalType(params[0]); - } else if (tydef->getName() == FEGEN_ANY) { - assert(params.size() == 1); - return Type::getAnyType(params[0]); - } else { - assert(false); - } +fegen::IntegerTemplateType::instantiate(std::vector params) { + assert(params.size() == 1); + return Type::getIntegerType(params[0]); } -std::string fegen::TemplateType::toStringForTypedef() { - auto tyd = this->getTypeDefination(); - if (tyd->isCustome()) { - return this->getTypeDefination()->getName(); - } else if (tyd->getName() == FEGEN_INTEGER) { - return "Builtin_IntegerAttr"; - } else if (tyd->getName() == FEGEN_FLOATPOINT) { - return "Builtin_FloatAttr"; - } else { - std::cerr << "unsupport type: " << this->getTypeName() << std::endl; - assert(false); - } +std::string fegen::IntegerTemplateType::toStringForTypedef() { + return "Builtin_IntegerAttr"; } -std::string fegen::TemplateType::toStringForOpdef() { - auto tyd = this->getTypeDefination(); - if (tyd->isCustome()) { - return this->getTypeDefination()->getName(); - } else if (tyd->getName() == FEGEN_INTEGER) { - return "Builtin_Integer"; - } else { - std::cerr << "unsupport type: " << this->getTypeName() << std::endl; - assert(false); - } +std::string fegen::IntegerTemplateType::toStringForOpdef() { + return "Builtin_Integer"; +} + +// class FloatPointTemplateType +fegen::FloatPointTemplateType::FloatPointTemplateType() + : TemplateType( + fegen::Manager::getManager().getTypeDefination(FEGEN_FLOATPOINT)) {} + +fegen::TypePtr +fegen::FloatPointTemplateType::instantiate(std::vector params) { + assert(params.size() == 1); + return Type::getFloatPointType(params[0]); +} + +std::string fegen::FloatPointTemplateType::toStringForTypedef() { + return "Builtin_FloatAttr"; +} + +// class StringTemplateType +fegen::StringTemplateType::StringTemplateType() + : TemplateType( + fegen::Manager::getManager().getTypeDefination(FEGEN_STRING)) {} + +fegen::TypePtr +fegen::StringTemplateType::instantiate(std::vector params) { + assert(params.size() == 0); + return Type::getStringType(); +} + +std::string fegen::StringTemplateType::toStringForTypedef() { + return "Builtin_StringAttr"; +} + +// class ListTemplateType +fegen::ListTemplateType::ListTemplateType(fegen::RightValue elementType) + : TemplateType(fegen::Manager::getManager().getTypeDefination(FEGEN_LIST)), + elementType(elementType) {} + +fegen::TypePtr +fegen::ListTemplateType::instantiate(std::vector params) { + assert(params.size() == 1); + return Type::getListType(params[0]); +} + +std::string fegen::ListTemplateType::toStringForTypedef() { + std::string res = "ArrayRefParameter<"; + res.append(this->elementType.toStringForTypedef()); + res.append(">"); + return res; +} + +std::string fegen::ListTemplateType::toStringForOpdef() { + std::string res = "Variadic<"; + res.append(this->elementType.toStringForOpdef()); + res.append(">"); + return res; +} + +// class VectorTemplateType +fegen::VectorTemplateType::VectorTemplateType() + : TemplateType( + fegen::Manager::getManager().getTypeDefination(FEGEN_VECTOR)) {} + +fegen::TypePtr +fegen::VectorTemplateType::instantiate(std::vector params) { + assert(params.size() == 2); + return Type::getVectorType(params[0], params[1]); +} + +// class TensorTemplateType +fegen::TensorTemplateType::TensorTemplateType() + : TemplateType( + fegen::Manager::getManager().getTypeDefination(FEGEN_TENSOR)) {} + +fegen::TypePtr +fegen::TensorTemplateType::instantiate(std::vector params) { + assert(params.size() == 2); + return Type::getTensorType(params[0], params[1]); +} + +// class OptionalTemplateType +fegen::OptionalTemplateType::OptionalTemplateType(RightValue elementType) + : TemplateType( + fegen::Manager::getManager().getTypeDefination(FEGEN_OPTINAL)), + elementType(elementType) {} + +fegen::TypePtr +fegen::OptionalTemplateType::instantiate(std::vector params) { + assert(params.size() == 1); + return Type::getOptionalType(params[0]); +} + +// class AnyTemplateType +fegen::AnyTemplateType::AnyTemplateType(RightValue elementTypes) + : TemplateType(fegen::Manager::getManager().getTypeDefination(FEGEN_ANY)), + elementTypes(elementTypes) {} + +fegen::TypePtr +fegen::AnyTemplateType::instantiate(std::vector params) { + assert(params.size() == 1); + return Type::getAnyType(params[0]); +} + +// class CustomeTemplateType +fegen::CustomeTemplateType::CustomeTemplateType(TypeDefination *tydef) + : TemplateType(tydef) {} + +fegen::TypePtr +fegen::CustomeTemplateType::instantiate(std::vector params) { + return Type::getCustomeType(params, this->getTypeDefination()); +} + +std::string fegen::CustomeTemplateType::toStringForTypedef() { + return this->getTypeDefination()->getName(); +} + +std::string fegen::CustomeTemplateType::toStringForOpdef() { + return this->getTypeDefination()->getName(); } // class FegenTypeDefination @@ -1191,6 +1320,17 @@ std::string getChildrenText(antlr4::tree::ParseTree *ctx) { return ruleText; } +fegen::Manager::OverloadedType::OverloadedType(TypeDefination *ty) + : tys({ty}) {} +fegen::Manager::OverloadedType::OverloadedType( + std::initializer_list &&tys) + : tys(tys) {} + +fegen::TypeDefination * +fegen::Manager::OverloadedType::OverloadedType::get(unsigned i) { + return this->tys[i]; +} + fegen::Manager::Manager() {} namespace fegen { @@ -1340,7 +1480,7 @@ void fegen::Manager::emitTypeDefination() { emitter.newLine(); for (auto pair : this->typeDefMap) { - auto tyDef = pair.second; + auto tyDef = pair.second.get(0); if (!tyDef->isCustome()) { continue; } @@ -1605,13 +1745,19 @@ void fegen::Manager::initbuiltinTypes() { nullptr, false)}); // List (this should be ahead of Tensor and Any Type defination) - this->typeDefMap.insert( - {FEGEN_LIST, - fegen::TypeDefination::get( + this->typeDefMap.insert({ + FEGEN_LIST, + {fegen::TypeDefination::get( FEGEN_DIALECT_NAME, FEGEN_LIST, {fegen::Value::get(fegen::Type::getMetaType(), "elementType", fegen::RightValue::getPlaceHolder())}, - nullptr, false)}); + nullptr, false), // element type is type instance + fegen::TypeDefination::get( + FEGEN_DIALECT_NAME, FEGEN_LIST, + {fegen::Value::get(fegen::Type::getMetaTemplateType(), "elementType", + fegen::RightValue::getPlaceHolder())}, + nullptr, false)} // element type is type template + }); // Tensor this->typeDefMap.insert( @@ -1628,26 +1774,55 @@ void fegen::Manager::initbuiltinTypes() { // Optional this->typeDefMap.insert( {FEGEN_OPTINAL, - fegen::TypeDefination::get( - FEGEN_DIALECT_NAME, FEGEN_OPTINAL, - {fegen::Value::get(fegen::Type::getMetaType(), "elementType", - fegen::RightValue::getPlaceHolder())}, - nullptr, false)}); + { + fegen::TypeDefination::get( + FEGEN_DIALECT_NAME, FEGEN_OPTINAL, + {fegen::Value::get(fegen::Type::getMetaType(), "elementType", + fegen::RightValue::getPlaceHolder())}, + nullptr, false), // element type is type instance + fegen::TypeDefination::get( + FEGEN_DIALECT_NAME, FEGEN_OPTINAL, + {fegen::Value::get(fegen::Type::getMetaTemplateType(), + "elementType", + fegen::RightValue::getPlaceHolder())}, + nullptr, false) // element type is type template + }}); // Any this->typeDefMap.insert( - {FEGEN_ANY, fegen::TypeDefination::get( - FEGEN_DIALECT_NAME, FEGEN_ANY, - {fegen::Value::get( - fegen::Type::getListType(fegen::Type::getMetaType()), - "elementType", fegen::RightValue::getPlaceHolder())}, - nullptr, false)}); + {FEGEN_ANY, + { + fegen::TypeDefination::get( + FEGEN_DIALECT_NAME, FEGEN_ANY, + {fegen::Value::get( + fegen::Type::getListTemplate(fegen::Type::getMetaType()), + "elementType", fegen::RightValue::getPlaceHolder())}, + nullptr, false), // elements are Type, ex: Any<[Integer<32>, + // FloatPoint<32>]> + fegen::TypeDefination::get( + FEGEN_DIALECT_NAME, FEGEN_ANY, + {fegen::Value::get(fegen::Type::getListTemplate( + fegen::Type::getMetaTemplateType()), + "elementType", + fegen::RightValue::getPlaceHolder())}, + nullptr, false) // elements are TypeTemplate, ex: Any<[Integer, + // FloatPoint]> + }}); } fegen::TypeDefination *fegen::Manager::getTypeDefination(std::string name) { auto it = this->typeDefMap.find(name); if (it != this->typeDefMap.end()) { - return it->second; + return it->second.get(0); + } + assert(false); +} + +fegen::TypeDefination * +fegen::Manager::getOverloadedTypeDefination(std::string name) { + auto it = this->typeDefMap.find(name); + if (it != this->typeDefMap.end()) { + return it->second.get(1); } assert(false); } @@ -1656,10 +1831,18 @@ bool fegen::Manager::addTypeDefination(fegen::TypeDefination *tyDef) { if (this->typeDefMap.count(tyDef->name) != 0) { return false; } - this->typeDefMap[tyDef->name] = tyDef; + this->typeDefMap.insert({tyDef->name, {tyDef}}); return true; } +bool fegen::Manager::addOverloadedTypeDefination(TypeDefination *tyDef) { + auto it = this->typeDefMap.find(tyDef->name); + if (it != this->typeDefMap.end()) { + it->second.tys[1] = tyDef; + } + assert(false); +} + fegen::Operation *fegen::Manager::getOperationDefination(std::string name) { return this->operationMap[name]; } From 446a1cfc0abe0330c31776def7db50526c93d1f1 Mon Sep 17 00:00:00 2001 From: chh Date: Fri, 19 Jul 2024 16:09:43 +0800 Subject: [PATCH 12/17] [FrontendGen] Add function generation. --- examples/FrontendGen/.gitignore | 3 +- examples/FrontendGen/makefile | 3 + frontend/FrontendGen/CMakeLists.txt | 2 +- frontend/FrontendGen/include/FegenVisitor.h | 79 ++++++------ frontend/FrontendGen/lib/FegenManager.cpp | 126 +++++++++++--------- 5 files changed, 119 insertions(+), 94 deletions(-) diff --git a/examples/FrontendGen/.gitignore b/examples/FrontendGen/.gitignore index 72daf17d09..72d4770020 100644 --- a/examples/FrontendGen/.gitignore +++ b/examples/FrontendGen/.gitignore @@ -1,3 +1,4 @@ test/ *.g4 -*.td \ No newline at end of file +*.td +*.cpp \ No newline at end of file diff --git a/examples/FrontendGen/makefile b/examples/FrontendGen/makefile index 43c29b9586..9fbf4d7466 100644 --- a/examples/FrontendGen/makefile +++ b/examples/FrontendGen/makefile @@ -10,5 +10,8 @@ typeDefine: rule: @${BUDDY_FRONTEND_GEN} -f ./rule.fegen +function: + @${BUDDY_FRONTEND_GEN} -f ./function.fegen + clean: rm -f ./toy* \ No newline at end of file diff --git a/frontend/FrontendGen/CMakeLists.txt b/frontend/FrontendGen/CMakeLists.txt index ca9204b3a3..8294448b47 100644 --- a/frontend/FrontendGen/CMakeLists.txt +++ b/frontend/FrontendGen/CMakeLists.txt @@ -16,4 +16,4 @@ target_link_libraries(buddy-frontendgen fegen_antlr_generated fegenVisitor antlr4_static -) \ No newline at end of file +) diff --git a/frontend/FrontendGen/include/FegenVisitor.h b/frontend/FrontendGen/include/FegenVisitor.h index eecc310412..f99ca3ece5 100644 --- a/frontend/FrontendGen/include/FegenVisitor.h +++ b/frontend/FrontendGen/include/FegenVisitor.h @@ -44,7 +44,9 @@ class FegenVisitor : public FegenParserBaseVisitor { void emitTypeDefination() { this->manager.emitTypeDefination(); } void emitDialectDefination() { this->manager.emitDialectDefination(); } void emitOpDefination() { this->manager.emitOpDefination(); } - void emitBuiltinFunction(fegen::FegenParser::FegenSpecContext *moduleAST){this->manager.emitBuiltinFunction(moduleAST);} + void emitBuiltinFunction(fegen::FegenParser::FegenSpecContext *moduleAST) { + this->manager.emitBuiltinFunction(moduleAST); + } FegenVisitor() : manager(Manager::getManager()), sstack(ScopeStack::getScopeStack()) { @@ -222,9 +224,9 @@ class FegenVisitor : public FegenParserBaseVisitor { std::any visitTypeInstance(FegenParser::TypeInstanceContext *ctx) override { if (ctx->typeTemplate()) { // typeTemplate (Less typeTemplateParam (Comma // typeTemplateParam)* Greater)? - auto typeTeplt = + auto typeTeplt = std::any_cast(this->visit(ctx->typeTemplate())); - if(ctx->typeTemplate()->TYPE()){ + if (ctx->typeTemplate()->TYPE()) { return typeTeplt; } auto teplt = std::dynamic_pointer_cast(typeTeplt); @@ -250,8 +252,7 @@ class FegenVisitor : public FegenParserBaseVisitor { auto varName = ctx->identifier()->getText(); auto var = this->sstack.attemptFindVar(varName); if (var) { - if (var->getContentKind() == - fegen::RightValue::LiteralKind::TYPE) { + if (var->getContentKind() == fegen::RightValue::LiteralKind::TYPE) { return var->getContent(); } else { std::cerr << "variable " << varName @@ -277,9 +278,8 @@ class FegenVisitor : public FegenParserBaseVisitor { this->visit(ctx->builtinTypeInstances())); return fegen::RightValue::getTypeRightValue(ty); } else { - auto expr = - std::any_cast>( - this->visit(ctx->expression())); + auto expr = std::any_cast>( + this->visit(ctx->expression())); return fegen::RightValue::getByExpr(expr); } } @@ -310,7 +310,8 @@ class FegenVisitor : public FegenParserBaseVisitor { // TODO: return type from other dialect return nullptr; } else { // type - auto tyDef = this->manager.getTypeDefination(ctx->prefixedName()->identifier(0)->getText()); + auto tyDef = this->manager.getTypeDefination( + ctx->prefixedName()->identifier(0)->getText()); return fegen::Type::getCustomeTemplate(tyDef); } } else if (ctx->builtinTypeTemplate()) { // builtinTypeTemplate @@ -352,40 +353,44 @@ class FegenVisitor : public FegenParserBaseVisitor { // return TypePtr std::any visitCollectType(FegenParser::CollectTypeContext *ctx) override { - auto expr = - std::any_cast>( - this->visit(ctx->expression())); - + auto expr = std::any_cast>( + this->visit(ctx->expression())); + if (ctx->collectProtoType()->ANY()) { // check to get list type. - std::vector tyexpr = std::any_cast>(expr); - int level = std::any_cast(tyexpr[0]->getContent())->getTypeLevel(); - for(size_t i = 1; i <= tyexpr.size()-1; i++){ + std::vector tyexpr = + std::any_cast>(expr); + int level = std::any_cast(tyexpr[0]->getContent()) + ->getTypeLevel(); + for (size_t i = 1; i <= tyexpr.size() - 1; i++) { auto expr = tyexpr[i]; auto t = std::any_cast(expr->getContent()); - if(level != t->getTypeLevel()){ + if (level != t->getTypeLevel()) { assert(false); } } - if(level == 1 || level == 2){ // template -> any template + if (level == 1 || level == 2) { // template -> any template return fegen::Type::getAnyTemplate(fegen::RightValue::getByExpr(expr)); - }else{ // instance -> any instance + } else { // instance -> any instance return fegen::Type::getAnyType(fegen::RightValue::getByExpr(expr)); } } else if (ctx->collectProtoType()->LIST()) { // the same as any - int level = std::any_cast(expr->getContent())->getTypeLevel(); - if(level == 1 || level == 2){ + int level = + std::any_cast(expr->getContent())->getTypeLevel(); + if (level == 1 || level == 2) { return fegen::Type::getListTemplate(fegen::RightValue::getByExpr(expr)); - }else{ + } else { return fegen::Type::getListType(fegen::RightValue::getByExpr(expr)); } } else { // optional // the same as any - int level = std::any_cast(expr->getContent())->getTypeLevel(); - if(level == 1 || level == 2){ - return fegen::Type::getOptionalTemplate(fegen::RightValue::getByExpr(expr)); - }else{ + int level = + std::any_cast(expr->getContent())->getTypeLevel(); + if (level == 1 || level == 2) { + return fegen::Type::getOptionalTemplate( + fegen::RightValue::getByExpr(expr)); + } else { return fegen::Type::getOptionalType(fegen::RightValue::getByExpr(expr)); } } @@ -623,7 +628,8 @@ class FegenVisitor : public FegenParserBaseVisitor { std::any visitFunctionDecl(FegenParser::FunctionDeclContext *ctx) override { sstack.pushScope(); - auto returnType = std::any_cast(this->visit(ctx->typeSpec())); + auto returnType = + std::any_cast(this->visit(ctx->typeSpec())); manager.addStmtContent(ctx, returnType); auto functionName = std::any_cast(this->visit(ctx->funcName())); @@ -658,7 +664,7 @@ class FegenVisitor : public FegenParserBaseVisitor { for (size_t i = 0; i < ctx->typeSpec().size(); i++) { auto paramType = std::any_cast(this->visit(ctx->typeSpec(i))); - // manager.addStmtContent(ctx, paramType); + // manager.addStmtContent(ctx, paramType); auto paramName = ctx->identifier(i)->getText(); auto param = fegen::Value::get(paramType, paramName, fegen::RightValue::getPlaceHolder()); @@ -694,7 +700,7 @@ class FegenVisitor : public FegenParserBaseVisitor { fegen::RightValue::getPlaceHolder()); } sstack.attemptAddVar(var); - + return var; } @@ -715,7 +721,7 @@ class FegenVisitor : public FegenParserBaseVisitor { fegen::Value *stmt = fegen::Value::get( var->getType(), varName, fegen::RightValue::getByExpr(varcontent)); manager.stmtContentMap.insert(std::pair{ctx, stmt}); - + manager.addStmtContent(ctx->expression(), varcontent); return var; } @@ -728,8 +734,9 @@ class FegenVisitor : public FegenParserBaseVisitor { auto paraList = hasFunc->getInputTypeList(); if (paramsNum > 0) { for (size_t i = 0; i < paramsNum; i++) { - auto oprand = std::any_cast>( - this->visit(ctx->expression(i))); + auto oprand = + std::any_cast>( + this->visit(ctx->expression(i))); parasList.push_back(oprand); } size_t len1 = paraList.size(); @@ -743,7 +750,8 @@ class FegenVisitor : public FegenParserBaseVisitor { // for (size_t i = 0; i < len1; i++) { // if (!fegen::Type::isSameType(¶List[i]->getType(), // ¶sList[i]->exprType)) { - // std::cerr << "The function \" " << functionName << "\" parameter" << i + // std::cerr << "The function \" " << functionName << "\" parameter" + // << i // << " type mismatch." << std::endl; // exit(0); // return nullptr; @@ -782,9 +790,10 @@ class FegenVisitor : public FegenParserBaseVisitor { } std::any visitElseBlock(FegenParser::ElseBlockContext *ctx) override { - sstack.pushScope(); + this->sstack.pushScope(); this->visit(ctx->statementBlock()); - sstack.popScope(); + this->sstack.popScope(); + return nullptr; } std::any visitForStmt(FegenParser::ForStmtContext *ctx) override { diff --git a/frontend/FrontendGen/lib/FegenManager.cpp b/frontend/FrontendGen/lib/FegenManager.cpp index 7d9a225448..a8ed9e0b7c 100644 --- a/frontend/FrontendGen/lib/FegenManager.cpp +++ b/frontend/FrontendGen/lib/FegenManager.cpp @@ -106,7 +106,7 @@ bool fegen::Type::isSameType(fegen::Type *type1, fegen::Type *type2) { } std::string fegen::Type::toStringForTypedef() { - std::cerr << this->getTypeName() <getTypeName() << std::endl; assert(FEGEN_NOT_IMPLEMENTED_ERROR); } @@ -160,7 +160,7 @@ fegen::TypePtr fegen::Type::getStringType() { fegen::TypePtr fegen::Type::getListType(fegen::TypePtr elementType) { assert(elementType->typeLevel == 3); - return std::make_shared( + return std::make_shared( fegen::RightValue::getTypeRightValue(elementType)); } @@ -1506,8 +1506,7 @@ void fegen::Manager::emitTypeDefination() { auto paramTy = param->getType(); auto paramName = param->getName(); auto paramTyStr = paramTy->toStringForTypedef(); - emitter << paramTyStr << ":" - << "$" << paramName; + emitter << paramTyStr << ":" << "$" << paramName; if (i != tyDef->getParameters().size() - 1) { emitter << ", "; } @@ -1888,50 +1887,55 @@ class StmtVisitor : public FegenParserBaseVisitor { Emitter &emitter; public: - StmtVisitor(Emitter &emitter) : manager(Manager::getManager()), emitter(emitter) {} - std::any visitFunctionDecl(FegenParser::FunctionDeclContext *ctx) override { - auto returnType = std::any_cast(manager.stmtContentMap[ctx]); - auto functionName = std::any_cast(manager.stmtContentMap[ctx->funcName()]); - emitter << returnType.getTypeName() << " " - << functionName << "("; - auto paraList = std::any_cast>(manager.stmtContentMap[ctx->funcParams()]); - for (auto para : paraList) { - emitter << para->getType().getTypeName() << " " << para->getName(); - if (para != paraList.back()) - emitter << ", "; - } - emitter << "){"; - emitter.tab(); - emitter.newLine(); - this->visit(ctx->statementBlock()); - emitter.shiftTab(); - emitter << "}"; - emitter.newLine(); - return nullptr; - } - std::any visitStatementBlock(FegenParser::StatementBlockContext *ctx) override { - for(size_t i = 0; i < ctx->statement().size(); i++){ - this->visit(ctx->statement(i)); - if(!(ctx->statement(i)->ifStmt()||ctx->statement(i)->forStmt())) - emitter << ";"; - emitter.newLine(); + StmtVisitor(Emitter &emitter) + : manager(Manager::getManager()), emitter(emitter) {} + std::any visitFunctionDecl(FegenParser::FunctionDeclContext *ctx) override { + auto returnType = + std::any_cast(manager.stmtContentMap[ctx]); + auto functionName = + std::any_cast(manager.stmtContentMap[ctx->funcName()]); + emitter << returnType->getTypeName() << " " << functionName << "("; + auto paraList = std::any_cast>( + manager.stmtContentMap[ctx->funcParams()]); + for (auto para : paraList) { + emitter << para->getType()->getTypeName() << " " << para->getName(); + if (para != paraList.back()) + emitter << ", "; } + emitter << "){"; + emitter.tab(); + emitter.newLine(); + this->visit(ctx->statementBlock()); + emitter.shiftTab(); + emitter << "}"; + emitter.newLine(); return nullptr; + } + std::any + visitStatementBlock(FegenParser::StatementBlockContext *ctx) override { + for (size_t i = 0; i < ctx->statement().size(); i++) { + this->visit(ctx->statement(i)); + if (!(ctx->statement(i)->ifStmt() || ctx->statement(i)->forStmt())) + emitter << ";"; + emitter.newLine(); } + return nullptr; + } std::any visitVarDeclStmt(FegenParser::VarDeclStmtContext *ctx) override { - auto varType = - std::any_cast(manager.stmtContentMap[ctx]); + auto varType = std::any_cast(manager.stmtContentMap[ctx]); auto varName = ctx->identifier()->getText(); - emitter << varType.getTypeName() << " " << varName; - if(ctx->expression()){ - auto expr = std::any_cast>(manager.stmtContentMap[ctx->expression()]); - emitter << " = " << expr->toString(); + emitter << varType->getTypeName() << " " << varName; + if (ctx->expression()) { + auto expr = std::any_cast>( + manager.stmtContentMap[ctx->expression()]); + emitter << " = " << expr->toString(); } return nullptr; } std::any visitAssignStmt(FegenParser::AssignStmtContext *ctx) override { auto varName = ctx->identifier()->getText(); - auto expr = std::any_cast>(manager.stmtContentMap[ctx->expression()]); + auto expr = this->manager.getStmtContent( + ctx->expression()); emitter << varName << " = " << expr->toString(); return nullptr; } @@ -1951,11 +1955,12 @@ class StmtVisitor : public FegenParserBaseVisitor { } std::any visitIfStmt(FegenParser::IfStmtContext *ctx) override { this->visit(ctx->ifBlock(0)); - for(size_t i = 1; i < ctx->ifBlock().size(); i++){ - emitter << " else "; - this->visit(ctx->ifBlock(i)); + for (size_t i = 1; i < ctx->ifBlock().size(); i++) { + emitter << " else "; + this->visit(ctx->ifBlock(i)); } - if(ctx->elseBlock()) this->visit(ctx->elseBlock()); + if (ctx->elseBlock()) + this->visit(ctx->elseBlock()); return nullptr; } std::any visitIfBlock(FegenParser::IfBlockContext *ctx) override { @@ -1971,28 +1976,30 @@ class StmtVisitor : public FegenParserBaseVisitor { return nullptr; } std::any visitElseBlock(FegenParser::ElseBlockContext *ctx) override { - emitter << "else {"; - emitter.tab(); - emitter.newLine(); - this->visit(ctx->statementBlock()); - emitter.shiftTab(); - emitter << "}"; - return nullptr; + emitter << "else {"; + emitter.tab(); + emitter.newLine(); + this->visit(ctx->statementBlock()); + emitter.shiftTab(); + emitter << "}"; + return nullptr; } // TODO: 支持for循环 std::any visitForStmt(FegenParser::ForStmtContext *ctx) override { if (ctx->varDeclStmt()) { - emitter << "for ("; + emitter << "for ("; this->visit(ctx->varDeclStmt()); emitter << "; "; - auto expr = std::any_cast>(manager.stmtContentMap[ctx->expression()]); + auto expr = std::any_cast>( + manager.stmtContentMap[ctx->expression()]); emitter << expr->toString() << "; "; this->visit(ctx->assignStmt(0)); emitter << ") {"; } else { this->visit(ctx->assignStmt(0)); emitter << " "; - auto expr = std::any_cast>(manager.stmtContentMap[ctx->expression()]); + auto expr = std::any_cast>( + manager.stmtContentMap[ctx->expression()]); emitter << expr->toString() << "; "; this->visit(ctx->assignStmt(1)); emitter << ") {"; @@ -2005,19 +2012,24 @@ class StmtVisitor : public FegenParserBaseVisitor { return nullptr; } std::any visitReturnBlock(FegenParser::ReturnBlockContext *ctx) override { - auto expr = std::any_cast>( + auto expr = std::any_cast>( manager.stmtContentMap[ctx->expression()]); - emitter << "return " << expr->toString(); - return nullptr; + emitter << "return " << expr->toString(); + return nullptr; + } + + std::any visitOpDecl(FegenParser::OpDeclContext *ctx) override { + return nullptr; } }; } // namespace fegen -void fegen::Manager::emitBuiltinFunction(fegen::FegenParser::FegenSpecContext *moduleAST) { +void fegen::Manager::emitBuiltinFunction( + fegen::FegenParser::FegenSpecContext *moduleAST) { std::ofstream fileStream; fileStream.open(this->moduleName + "Function.cpp"); fegen::Emitter emitter(fileStream); - //Emitter emitter(std::cout); + // Emitter emitter(std::cout); StmtVisitor visitor(emitter); visitor.visit(moduleAST); fileStream.close(); From 2ea7663593ed4862c64e075265d8824f44a21603 Mon Sep 17 00:00:00 2001 From: chh Date: Mon, 22 Jul 2024 10:11:04 +0800 Subject: [PATCH 13/17] [FrontendGen] Fix FloatPointType getter. --- examples/FrontendGen/opDefine.fegen | 5 + frontend/FrontendGen/include/FegenManager.h | 17 +- frontend/FrontendGen/include/FegenVisitor.h | 1 - frontend/FrontendGen/lib/FegenManager.cpp | 306 ++++++++++++-------- 4 files changed, 209 insertions(+), 120 deletions(-) diff --git a/examples/FrontendGen/opDefine.fegen b/examples/FrontendGen/opDefine.fegen index 60fa22db92..65d21e6369 100644 --- a/examples/FrontendGen/opDefine.fegen +++ b/examples/FrontendGen/opDefine.fegen @@ -6,4 +6,9 @@ opdef add { body { res = lhs + rhs; } +} + +opdef constant { + arguments [attribute double value] + results [operand Tensor> res] } \ No newline at end of file diff --git a/frontend/FrontendGen/include/FegenManager.h b/frontend/FrontendGen/include/FegenManager.h index 83197141aa..d26886032d 100644 --- a/frontend/FrontendGen/include/FegenManager.h +++ b/frontend/FrontendGen/include/FegenManager.h @@ -193,8 +193,8 @@ class Type { static TypePtr getVectorType(RightValue elementType, RightValue size); // Tensor - static TypePtr getTensorType(TypePtr elementType, RightValue shape); - static TypePtr getTensorType(RightValue elementType, RightValue shape); + static TypePtr getTensorType(TypePtr elementType); + static TypePtr getTensorType(RightValue elementType); // Optional static TypePtr getOptionalType(TypePtr elementType); @@ -369,6 +369,7 @@ class RightValue { }; struct OperatorCall : public ExpressionNode { + static std::unordered_map cppOperatorMap; FegenOperator op; std::vector> params; OperatorCall(FegenOperator, std::vector>); @@ -402,6 +403,7 @@ class RightValue { IntegerLiteral(largestInt content, size_t size); virtual std::any getContent() override; virtual std::string toString() override; + virtual std::string toStringForCppKind() override; virtual TypePtr getType() override; }; @@ -411,6 +413,7 @@ class RightValue { FloatPointLiteral(long double content, size_t size); virtual std::any getContent() override; virtual std::string toString() override; + virtual std::string toStringForCppKind() override; virtual TypePtr getType() override; }; @@ -419,6 +422,7 @@ class RightValue { StringLiteral(std::string content); virtual std::any getContent() override; virtual std::string toString() override; + virtual std::string toStringForCppKind() override; virtual TypePtr getType() override; }; @@ -448,6 +452,7 @@ class RightValue { LeftValue(Value *content); virtual std::any getContent() override; virtual std::string toString() override; + virtual std::string toStringForCppKind() override; virtual TypePtr getType() override; }; @@ -507,6 +512,7 @@ class IntegerType : public Type { public: IntegerType(RightValue size, TypeDefination* tyDef); IntegerType(RightValue size); + largestInt getSize(); // for generating typedef td file. virtual std::string toStringForTypedef() override; // for generating op def td file. @@ -519,6 +525,7 @@ class FloatPointType : public Type { RightValue size; public: FloatPointType(RightValue size); + largestInt getSize(); // for generating typedef td file. virtual std::string toStringForTypedef() override; // for generating op def td file. @@ -550,12 +557,12 @@ class VectorType : public Type { public: VectorType(RightValue elementType, RightValue size); }; -// Tensor +// Tensor class TensorType : public Type { RightValue elementType; - RightValue shape; public: - TensorType(RightValue elementType, RightValue shape); + TensorType(RightValue elementType); + virtual std::string toStringForOpdef() override; }; // Optional class OptionalType : public Type { diff --git a/frontend/FrontendGen/include/FegenVisitor.h b/frontend/FrontendGen/include/FegenVisitor.h index f99ca3ece5..ba06dffe43 100644 --- a/frontend/FrontendGen/include/FegenVisitor.h +++ b/frontend/FrontendGen/include/FegenVisitor.h @@ -329,7 +329,6 @@ class FegenVisitor : public FegenParserBaseVisitor { } else if (ctx->FLOATPOINT()) { return fegen::Type::getFloatPointTemplate(); } else if (ctx->TENSOR()) { - // return fegen::FegenType::getTensorTemplate(); return fegen::Type::getTensorTemplate(); } else if (ctx->VECTOR()) { return fegen::Type::getVectorTemplate(); diff --git a/frontend/FrontendGen/lib/FegenManager.cpp b/frontend/FrontendGen/lib/FegenManager.cpp index a8ed9e0b7c..5026773e9b 100644 --- a/frontend/FrontendGen/lib/FegenManager.cpp +++ b/frontend/FrontendGen/lib/FegenManager.cpp @@ -151,7 +151,7 @@ fegen::TypePtr fegen::Type::getIntegerType(fegen::RightValue size) { } fegen::TypePtr fegen::Type::getFloatPointType(fegen::RightValue size) { - return std::make_shared(size); + return std::make_shared(size); } fegen::TypePtr fegen::Type::getStringType() { @@ -182,17 +182,15 @@ fegen::TypePtr fegen::Type::getVectorType(RightValue elementType, return Type::getVectorType(ty, size); } -fegen::TypePtr fegen::Type::getTensorType(fegen::TypePtr elementType, - fegen::RightValue shape) { +fegen::TypePtr fegen::Type::getTensorType(fegen::TypePtr elementType) { assert(elementType->typeLevel == 3); return std::make_shared( - fegen::RightValue::getTypeRightValue(elementType), shape); + fegen::RightValue::getTypeRightValue(elementType)); } -fegen::TypePtr fegen::Type::getTensorType(RightValue elementType, - RightValue shape) { +fegen::TypePtr fegen::Type::getTensorType(RightValue elementType) { auto ty = std::any_cast(elementType.getContent()); - return Type::getTensorType(ty, shape); + return Type::getTensorType(ty); } fegen::TypePtr fegen::Type::getOptionalType(fegen::TypePtr elementType) { @@ -329,6 +327,11 @@ fegen::IntegerType::IntegerType(fegen::RightValue size) size.isConstant()), size(size) {} +fegen::largestInt fegen::IntegerType::getSize() { + assert(this->size.getLiteralKind() == RightValue::LiteralKind::INT); + return std::any_cast(this->size.getContent()); +} + std::string fegen::IntegerType::toStringForTypedef() { auto content = std::any_cast(this->size.getContent()); if (content == 32) { @@ -381,6 +384,11 @@ fegen::FloatPointType::FloatPointType(fegen::RightValue size) size.isConstant()), size(size) {} +fegen::largestInt fegen::FloatPointType::getSize() { + assert(this->size.getLiteralKind() == RightValue::LiteralKind::INT); + return std::any_cast(this->size.getContent()); +} + std::string fegen::FloatPointType::toStringForTypedef() { auto content = std::any_cast(this->size.getContent()); if (content == 32) { @@ -394,16 +402,39 @@ std::string fegen::FloatPointType::toStringForTypedef() { } std::string fegen::FloatPointType::toStringForOpdef() { - return "FloatPointType::toStringForOpdef"; + auto content = std::any_cast(this->size.getContent()); + switch (this->getTypeKind()) { + case Type::TypeKind::ATTRIBUTE: { + if (content == 32) { + return "F32ElementsAttr"; + } else if (content == 64) { + return "F64ElementsAttr"; + } + break; + } + case Type::TypeKind::OPERAND: { + if (content == 32) { + return "F32"; + } else if (content == 64) { + return "F64"; + } + break; + } + default: { + assert(false); + } + } + assert(false); } std::string fegen::FloatPointType::toStringForCppKind() { auto content = std::any_cast(this->size.getContent()); if (content == 32) { return "float"; - } - if (content == 64) { + } else if (content == 64) { return "double"; + } else if (content == 128) { + return "long double"; } else { std::cerr << "unsupport type: " << this->getTypeName() << std::endl; assert(false); @@ -453,12 +484,55 @@ fegen::VectorType::VectorType(RightValue elementType, RightValue size) elementType(elementType), size(size) {} // class TensorType -fegen::TensorType::TensorType(RightValue elementType, RightValue shape) +fegen::TensorType::TensorType(RightValue elementType) : Type(fegen::Type::TypeKind::CPP, - jointTypeName(FEGEN_TENSOR, {elementType, shape}), + jointTypeName(FEGEN_TENSOR, {elementType}), fegen::Manager::getManager().getTypeDefination(FEGEN_TENSOR), 3, - (elementType.isConstant() && shape.isConstant())), - elementType(elementType), shape(shape) {} + elementType.isConstant()), + elementType(elementType) {} + +std::string fegen::TensorType::toStringForOpdef() { + auto elemTy = std::any_cast(this->elementType.getContent()); + auto elemTyName = elemTy->getTypeDefination()->getName(); + if (elemTyName != FEGEN_INTEGER && elemTyName != FEGEN_FLOATPOINT) { + assert(false); + } + if (elemTyName == FEGEN_INTEGER) { + auto intTy = std::dynamic_pointer_cast(elemTy); + auto size = intTy->getSize(); + switch (size) { + case 1: + return "I1Tensor"; + case 8: + return "I8Tensor"; + case 16: + return "I16Tensor"; + case 32: + return "I32Tensor"; + case 64: + return "I64Tensor"; + default: { + std::cerr << "unsupprot type: " << this->getTypeName() << std::endl; + exit(0); + } + } + } else { + auto floatTy = std::dynamic_pointer_cast(elemTy); + auto size = floatTy->getSize(); + switch (size) { + case 16: + return "F16Tensor"; + case 32: + return "F32Tensor"; + case 64: + return "F64Tensor"; + default: { + std::cerr << "unsupprot type: " << this->getTypeName() << std::endl; + exit(0); + } + } + } +} // class OptionalType fegen::OptionalType::OptionalType(RightValue elementType) @@ -596,8 +670,8 @@ fegen::TensorTemplateType::TensorTemplateType() fegen::TypePtr fegen::TensorTemplateType::instantiate(std::vector params) { - assert(params.size() == 2); - return Type::getTensorType(params[0], params[1]); + assert(params.size() == 1); + return Type::getTensorType(params[0]); } // class OptionalTemplateType @@ -787,16 +861,6 @@ fegen::TypePtr fegen::RightValue::ExpressionNode::getType() { assert(FEGEN_NOT_IMPLEMENTED_ERROR); } -inline bool isBinaryOperator(fegen::FegenOperator &op) { - switch (op) { - case fegen::FegenOperator::NEG: - case fegen::FegenOperator::NOT: - return false; - default: - return true; - } -} - // class FunctionCall inline bool isFuncParamsAllConstant( std::vector> ¶ms) { @@ -888,8 +952,48 @@ std::string fegen::RightValue::OperatorCall::toStringForOpdef() { return "OperatorCall::toStringForOpdef"; } +inline bool isBinaryOperator(fegen::FegenOperator &op) { + switch (op) { + case fegen::FegenOperator::NEG: + case fegen::FegenOperator::NOT: + return false; + default: + return true; + } +} + +std::unordered_map + fegen::RightValue::OperatorCall::cppOperatorMap = { + {fegen::FegenOperator::OR, "||"}, + {fegen::FegenOperator::AND, "&&"}, + {fegen::FegenOperator::EQUAL, "=="}, + {fegen::FegenOperator::NOT_EQUAL, "!="}, + {fegen::FegenOperator::LESS, "<"}, + {fegen::FegenOperator::LESS_EQUAL, "<="}, + {fegen::FegenOperator::GREATER, ">"}, + {fegen::FegenOperator::GREATER_EQUAL, ">="}, + {fegen::FegenOperator::ADD, "+"}, + {fegen::FegenOperator::SUB, "-"}, + {fegen::FegenOperator::MUL, "*"}, + {fegen::FegenOperator::DIV, "/"}, + {fegen::FegenOperator::MOD, "%"}, + {fegen::FegenOperator::POWER, "pow"}, + {fegen::FegenOperator::NEG, "-"}, + {fegen::FegenOperator::NOT, "!"}}; + std::string fegen::RightValue::OperatorCall::toStringForCppKind() { - return "OperatorCall::toStringForCppKind"; + std::string res; + if (isBinaryOperator(this->op)) { + res.append(this->params[0]->toStringForCppKind()); + res.append(" "); + res.append(OperatorCall::cppOperatorMap[this->op]); + res.append(" "); + res.append(this->params[1]->toStringForCppKind()); + } else { + res.append(OperatorCall::cppOperatorMap[this->op]); + res.append(this->params[0]->toStringForCppKind()); + } + return res; } std::any fegen::RightValue::OperatorCall::getContent() { return this; } @@ -947,6 +1051,10 @@ std::string fegen::RightValue::IntegerLiteral::toString() { return std::to_string(this->content); } +std::string fegen::RightValue::IntegerLiteral::toStringForCppKind() { + return std::to_string(this->content); +} + fegen::TypePtr fegen::RightValue::IntegerLiteral::getType() { return fegen::Type::getIntegerType(fegen::RightValue::getInteger(this->size)); } @@ -965,6 +1073,10 @@ std::string fegen::RightValue::FloatPointLiteral::toString() { return std::to_string(this->content); } +std::string fegen::RightValue::FloatPointLiteral::toStringForCppKind() { + return std::to_string(this->content); +} + fegen::TypePtr fegen::RightValue::FloatPointLiteral::getType() { return fegen::Type::getFloatPointType( fegen::RightValue::getInteger(this->size)); @@ -987,6 +1099,10 @@ std::string fegen::RightValue::StringLiteral::toString() { return res; } +std::string fegen::RightValue::StringLiteral::toStringForCppKind() { + return "\"" + this->content + "\""; +} + fegen::TypePtr fegen::RightValue::StringLiteral::getType() { return fegen::Type::getStringType(); } @@ -1122,6 +1238,10 @@ std::string fegen::RightValue::LeftValue::toString() { return this->content->getName(); } +std::string fegen::RightValue::LeftValue::toStringForCppKind() { + return this->content->getName(); +} + fegen::TypePtr fegen::RightValue::LeftValue::getType() { return this->content->getType(); } @@ -1372,51 +1492,6 @@ class Emitter { return this->stream; } }; - -class StmtGenerator : FegenParserBaseVisitor { -private: - Manager &manager; - Emitter &emitter; - -public: - StmtGenerator(Emitter &emitter) - : manager(Manager::getManager()), emitter(emitter) {} - std::any visitVarDeclStmt(FegenParser::VarDeclStmtContext *ctx) override { - auto var = manager.getStmtContent(ctx->identifier()); - switch (var->getType()->getTypeKind()) { - case fegen::Type::TypeKind::CPP: { - this->emitter << var->getType()->toStringForCppKind() << " " - << var->getName(); - if (ctx->expression()) { - auto expr = this->manager.getStmtContent( - ctx->expression()); - this->emitter << " = " << expr->toStringForCppKind(); - } - this->emitter << ";"; - this->emitter.newLine(); - break; - } - case fegen::Type::TypeKind::ATTRIBUTE: { - break; - } - case fegen::Type::TypeKind::OPERAND: { - break; - } - } - return nullptr; - } - - std::any visitAssignStmt(FegenParser::AssignStmtContext *ctx) override {} - - std::any visitFunctionCall(FegenParser::FunctionCallContext *ctx) override {} - - std::any visitOpInvokeStmt(FegenParser::OpInvokeStmtContext *ctx) override {} - - std::any visitIfStmt(FegenParser::IfStmtContext *ctx) override {} - - std::any visitForStmt(FegenParser::ForStmtContext *ctx) override {} -}; - } // namespace fegen void fegen::Manager::emitG4() { @@ -1591,40 +1666,46 @@ void fegen::Manager::emitOpDefination() { emitter << "def " << opName << " : " << classname << "<\"" << opName << "\", [Pure]> {"; emitter.newLine(); - emitter.tab(); - // summary and description - emitter << "let summary = \"This is generated by buddy fegen.\";"; - emitter.newLine(); - emitter << "let description = [{This is generated by buddy fegen.}];"; - emitter.newLine(); - // arguments - emitter << "let arguments = ( ins "; - emitter.newLine(); - emitter.tab(); - for (auto param : opDef->getArguments()) { - auto paramTyStr = param->getType()->toStringForOpdef(); - auto paramName = param->getName(); - emitter << paramTyStr << " : $" << paramName; + { + emitter.tab(); + // summary and description + emitter << "let summary = \"This is generated by buddy fegen.\";"; emitter.newLine(); - } - emitter.shiftTab(); - emitter << ");"; - emitter.newLine(); - // results - emitter << "let results = (outs "; - emitter.newLine(); - emitter.tab(); - for (auto param : opDef->getArguments()) { - auto paramTyStr = param->getType()->toStringForOpdef(); - auto paramName = param->getName(); - emitter << paramTyStr << " : $" << paramName; + emitter << "let description = [{This is generated by buddy fegen.}];"; + emitter.newLine(); + // arguments + emitter << "let arguments = ( ins "; + emitter.newLine(); + { + emitter.tab(); + for (auto param : opDef->getArguments()) { + auto paramTyStr = param->getType()->toStringForOpdef(); + auto paramName = param->getName(); + emitter << paramTyStr << " : $" << paramName; + emitter.newLine(); + } + emitter.shiftTab(); + } + emitter << ");"; + emitter.newLine(); + // results + emitter << "let results = (outs "; emitter.newLine(); + { + emitter.tab(); + for (auto param : opDef->getResults()) { + auto paramTyStr = param->getType()->toStringForOpdef(); + auto paramName = param->getName(); + emitter << paramTyStr << " : $" << paramName; + emitter.newLine(); + } + emitter.shiftTab(); + } + emitter << ");"; + emitter.newLine(); + // end of def + emitter.shiftTab(); } - emitter.shiftTab(); - emitter << ");"; - emitter.newLine(); - // end of def - emitter.shiftTab(); emitter << "}"; emitter.newLine(); } @@ -1765,10 +1846,7 @@ void fegen::Manager::initbuiltinTypes() { fegen::TypeDefination::get( FEGEN_DIALECT_NAME, FEGEN_TENSOR, {fegen::Value::get(fegen::Type::getMetaType(), "elementType", - fegen::RightValue::getPlaceHolder()), - fegen::Value::get( - fegen::Type::getListType(fegen::Type::getInt32Type()), "shape", - fegen::RightValue::getPlaceHolder())}, + fegen::RightValue::getPlaceHolder())}, nullptr, false)}); // Optional @@ -1924,11 +2002,11 @@ class StmtVisitor : public FegenParserBaseVisitor { std::any visitVarDeclStmt(FegenParser::VarDeclStmtContext *ctx) override { auto varType = std::any_cast(manager.stmtContentMap[ctx]); auto varName = ctx->identifier()->getText(); - emitter << varType->getTypeName() << " " << varName; + emitter << varType->toStringForCppKind() << " " << varName; if (ctx->expression()) { auto expr = std::any_cast>( manager.stmtContentMap[ctx->expression()]); - emitter << " = " << expr->toString(); + emitter << " = " << expr->toStringForCppKind(); } return nullptr; } @@ -1936,7 +2014,7 @@ class StmtVisitor : public FegenParserBaseVisitor { auto varName = ctx->identifier()->getText(); auto expr = this->manager.getStmtContent( ctx->expression()); - emitter << varName << " = " << expr->toString(); + emitter << varName << " = " << expr->toStringForCppKind(); return nullptr; } std::any visitFunctionCall(FegenParser::FunctionCallContext *ctx) override { @@ -1967,7 +2045,7 @@ class StmtVisitor : public FegenParserBaseVisitor { auto expr = std::any_cast>( manager.stmtContentMap[ctx->expression()]); - emitter << "if (" << expr->toString() << "){"; + emitter << "if (" << expr->toStringForCppKind() << "){"; emitter.tab(); emitter.newLine(); this->visit(ctx->statementBlock()); @@ -1992,7 +2070,7 @@ class StmtVisitor : public FegenParserBaseVisitor { emitter << "; "; auto expr = std::any_cast>( manager.stmtContentMap[ctx->expression()]); - emitter << expr->toString() << "; "; + emitter << expr->toStringForCppKind() << "; "; this->visit(ctx->assignStmt(0)); emitter << ") {"; } else { @@ -2000,7 +2078,7 @@ class StmtVisitor : public FegenParserBaseVisitor { emitter << " "; auto expr = std::any_cast>( manager.stmtContentMap[ctx->expression()]); - emitter << expr->toString() << "; "; + emitter << expr->toStringForCppKind() << "; "; this->visit(ctx->assignStmt(1)); emitter << ") {"; } @@ -2014,7 +2092,7 @@ class StmtVisitor : public FegenParserBaseVisitor { std::any visitReturnBlock(FegenParser::ReturnBlockContext *ctx) override { auto expr = std::any_cast>( manager.stmtContentMap[ctx->expression()]); - emitter << "return " << expr->toString(); + emitter << "return " << expr->toStringForCppKind(); return nullptr; } From 4d6dfed4147e931236f51b3c9db8abf2c7c2d32a Mon Sep 17 00:00:00 2001 From: FloatingcloudKnight <1348185166@qq.com> Date: Tue, 23 Jul 2024 08:40:11 +0000 Subject: [PATCH 14/17] [FrontendGen] update function generation. --- frontend/FrontendGen/include/FegenManager.h | 4 +++- frontend/FrontendGen/include/FegenVisitor.h | 23 +++++++++++++-------- frontend/FrontendGen/lib/FegenManager.cpp | 23 +++++++++++++-------- 3 files changed, 31 insertions(+), 19 deletions(-) diff --git a/frontend/FrontendGen/include/FegenManager.h b/frontend/FrontendGen/include/FegenManager.h index 83197141aa..bc7c8c4de7 100644 --- a/frontend/FrontendGen/include/FegenManager.h +++ b/frontend/FrontendGen/include/FegenManager.h @@ -151,7 +151,7 @@ class Type { virtual std::string toStringForOpdef(); // for generating cpp type kind. virtual std::string toStringForCppKind(); - static bool isSameType(Type *type1, Type *type2); + static bool isSameType(TypePtr type1, TypePtr type2); virtual ~Type() = default; // placeholder @@ -530,6 +530,8 @@ class FloatPointType : public Type { class StringType : public Type { public: StringType(); + // for generating cpp type kind. + virtual std::string toStringForCppKind() override; }; // List class ListType : public Type { diff --git a/frontend/FrontendGen/include/FegenVisitor.h b/frontend/FrontendGen/include/FegenVisitor.h index f99ca3ece5..327b142b17 100644 --- a/frontend/FrontendGen/include/FegenVisitor.h +++ b/frontend/FrontendGen/include/FegenVisitor.h @@ -664,7 +664,6 @@ class FegenVisitor : public FegenParserBaseVisitor { for (size_t i = 0; i < ctx->typeSpec().size(); i++) { auto paramType = std::any_cast(this->visit(ctx->typeSpec(i))); - // manager.addStmtContent(ctx, paramType); auto paramName = ctx->identifier(i)->getText(); auto param = fegen::Value::get(paramType, paramName, fegen::RightValue::getPlaceHolder()); @@ -684,8 +683,8 @@ class FegenVisitor : public FegenParserBaseVisitor { auto varcontent = std::any_cast>( this->visit(ctx->expression())); - // TODO: check error - // if (!fegen::Type::isSameType(&varType, &varcontent->exprType)) { + // TODO: 支持获取expression的type后,可正常使用 + // if (!fegen::Type::isSameType(var->getType(), varcontent->getType())) { // std::cerr << "The variabel \" " << varName << "\" need \"" // << varType.getTypeName() // << " \" type rightvalue. Now the expression is " @@ -710,14 +709,16 @@ class FegenVisitor : public FegenParserBaseVisitor { std::any_cast>( this->visit(ctx->expression())); auto var = sstack.attemptFindVar(varName); - // TODO - // if (!fegen::Type::isSameType(&var->getType(), &varcontent->exprType)) { + + // TODO: 支持获取expression的type后,可正常使用 + // if (!fegen::Type::isSameType(var->getType(), varcontent->getType())) { // std::cerr << "The variabel \" " << varName << "\" need \"" - // << var->getType().getTypeName() << " \" type rightvalue." + // << var->getType()->toStringForCppKind() << " \" type rightvalue." // << std::endl; // exit(0); // return nullptr; // } + fegen::Value *stmt = fegen::Value::get( var->getType(), varName, fegen::RightValue::getByExpr(varcontent)); manager.stmtContentMap.insert(std::pair{ctx, stmt}); @@ -725,6 +726,7 @@ class FegenVisitor : public FegenParserBaseVisitor { return var; } + // TODO:测试并补足函数调用 std::any visitFunctionCall(FegenParser::FunctionCallContext *ctx) override { std::vector> parasList = {}; auto functionName = @@ -747,9 +749,11 @@ class FegenVisitor : public FegenParserBaseVisitor { exit(0); return nullptr; } + + // TODO: check parameter type // for (size_t i = 0; i < len1; i++) { - // if (!fegen::Type::isSameType(¶List[i]->getType(), - // ¶sList[i]->exprType)) { + // if (!fegen::Type::isSameType(paraList[i]->getType(), + // parasList[i]->exprType)) { // std::cerr << "The function \" " << functionName << "\" parameter" // << i // << " type mismatch." << std::endl; @@ -764,7 +768,8 @@ class FegenVisitor : public FegenParserBaseVisitor { manager.stmtContentMap.insert(std::pair{ctx, funcCall}); return returnType; } - + + // TODO:add op invoke std::any visitOpInvokeStmt(FegenParser::OpInvokeStmtContext *ctx) override { return nullptr; } diff --git a/frontend/FrontendGen/lib/FegenManager.cpp b/frontend/FrontendGen/lib/FegenManager.cpp index a8ed9e0b7c..70bbcfb30a 100644 --- a/frontend/FrontendGen/lib/FegenManager.cpp +++ b/frontend/FrontendGen/lib/FegenManager.cpp @@ -98,9 +98,12 @@ int fegen::Type::getTypeLevel() { return this->typeLevel; } bool fegen::Type::isConstant() { return this->isConstType; } -bool fegen::Type::isSameType(fegen::Type *type1, fegen::Type *type2) { - if (type1->getTypeName() == type2->getTypeName()) +bool fegen::Type::isSameType(fegen::TypePtr type1, fegen::TypePtr type2) { + if (type1->getTypeName() == type2->getTypeName()){ + std::cout << "1" << std::endl; return true; + } + else return false; } @@ -416,6 +419,10 @@ fegen::StringType::StringType() fegen::Manager::getManager().getTypeDefination(FEGEN_STRING), 3, true) {} +std::string fegen::StringType::toStringForCppKind() { + return "string"; +} + // class ListType fegen::ListType::ListType(fegen::RightValue elementType) : Type(fegen::Type::TypeKind::CPP, jointTypeName(FEGEN_LIST, {elementType}), @@ -1894,11 +1901,11 @@ class StmtVisitor : public FegenParserBaseVisitor { std::any_cast(manager.stmtContentMap[ctx]); auto functionName = std::any_cast(manager.stmtContentMap[ctx->funcName()]); - emitter << returnType->getTypeName() << " " << functionName << "("; + emitter << returnType->toStringForCppKind() << " " << functionName << "("; auto paraList = std::any_cast>( manager.stmtContentMap[ctx->funcParams()]); for (auto para : paraList) { - emitter << para->getType()->getTypeName() << " " << para->getName(); + emitter << para->getType()->toStringForCppKind() << " " << para->getName(); if (para != paraList.back()) emitter << ", "; } @@ -1924,7 +1931,7 @@ class StmtVisitor : public FegenParserBaseVisitor { std::any visitVarDeclStmt(FegenParser::VarDeclStmtContext *ctx) override { auto varType = std::any_cast(manager.stmtContentMap[ctx]); auto varName = ctx->identifier()->getText(); - emitter << varType->getTypeName() << " " << varName; + emitter << varType->toStringForCppKind() << " " << varName; if (ctx->expression()) { auto expr = std::any_cast>( manager.stmtContentMap[ctx->expression()]); @@ -1939,6 +1946,7 @@ class StmtVisitor : public FegenParserBaseVisitor { emitter << varName << " = " << expr->toString(); return nullptr; } + // TODO:测试并补足函数调用 std::any visitFunctionCall(FegenParser::FunctionCallContext *ctx) override { auto function = std::any_cast(manager.stmtContentMap[ctx]); @@ -1984,7 +1992,6 @@ class StmtVisitor : public FegenParserBaseVisitor { emitter << "}"; return nullptr; } - // TODO: 支持for循环 std::any visitForStmt(FegenParser::ForStmtContext *ctx) override { if (ctx->varDeclStmt()) { emitter << "for ("; @@ -2018,9 +2025,7 @@ class StmtVisitor : public FegenParserBaseVisitor { return nullptr; } - std::any visitOpDecl(FegenParser::OpDeclContext *ctx) override { - return nullptr; - } + // TODO: add op declaration/invoke }; } // namespace fegen From 87bc2faf5e665dbaf599f966f28f50f44078ee5b Mon Sep 17 00:00:00 2001 From: FloatingcloudKnight <1348185166@qq.com> Date: Tue, 23 Jul 2024 08:43:04 +0000 Subject: [PATCH 15/17] [FrontendGen] Add function generation. --- frontend/FrontendGen/include/FegenManager.h | 146 ++++++++++---------- frontend/FrontendGen/include/FegenVisitor.h | 7 +- frontend/FrontendGen/lib/FegenManager.cpp | 11 +- 3 files changed, 84 insertions(+), 80 deletions(-) diff --git a/frontend/FrontendGen/include/FegenManager.h b/frontend/FrontendGen/include/FegenManager.h index bc7c8c4de7..bca53b11f2 100644 --- a/frontend/FrontendGen/include/FegenManager.h +++ b/frontend/FrontendGen/include/FegenManager.h @@ -31,7 +31,6 @@ #define FEGEN_DIALECT_NAME "fegen_builtin" #define FEGEN_NOT_IMPLEMENTED_ERROR false - namespace fegen { class Type; class Manager; @@ -70,14 +69,12 @@ class Function { std::vector inputTypeList; // return type TypePtr returnType; - explicit Function(std::string name, - std::vector &&inputTypeList, - TypePtr returnType); + explicit Function(std::string name, std::vector &&inputTypeList, + TypePtr returnType); public: - static Function *get(std::string name, - std::vector inputTypeList, - TypePtr returnType = nullptr); + static Function *get(std::string name, std::vector inputTypeList, + TypePtr returnType = nullptr); ~Function() = default; std::string getName(); std::vector &getInputTypeList(); @@ -99,9 +96,9 @@ class Operation { // operation body context FegenParser::BodySpecContext *ctx; explicit Operation(std::string dialectName, std::string operationName, - std::vector &&arguments, - std::vector &&results, - FegenParser::BodySpecContext *ctx); + std::vector &&arguments, + std::vector &&results, + FegenParser::BodySpecContext *ctx); public: void setOpName(std::string); @@ -111,9 +108,9 @@ class Operation { std::vector &getResults(); Value *getResults(size_t i); static Operation *get(std::string operationName, - std::vector arguments, - std::vector results, - FegenParser::BodySpecContext *ctx); + std::vector arguments, + std::vector results, + FegenParser::BodySpecContext *ctx); ~Operation() = default; }; @@ -134,7 +131,8 @@ class Type { bool isConstType; public: - Type(TypeKind kind, std::string name, TypeDefination *tyDef, int typeLevel, bool isConstType); + Type(TypeKind kind, std::string name, TypeDefination *tyDef, int typeLevel, + bool isConstType); Type(const Type &) = default; Type(Type &&) = default; @@ -203,7 +201,8 @@ class Type { // Any<[elementType1, elementType2, ...]> static TypePtr getAnyType(RightValue elementTypes); - static TypePtr getCustomeType(std::vector params, TypeDefination* tydef); + static TypePtr getCustomeType(std::vector params, + TypeDefination *tydef); // Integer static TypePtr getIntegerTemplate(); @@ -231,7 +230,7 @@ class Type { // Any<[elementType1, elementType2, ...]> (elementType* is template) static TypePtr getAnyTemplate(RightValue elementTypes); - static TypePtr getCustomeTemplate(TypeDefination* tydef); + static TypePtr getCustomeTemplate(TypeDefination *tydef); }; class TypeDefination { @@ -247,13 +246,12 @@ class TypeDefination { public: TypeDefination(std::string dialectName, std::string name, - std::vector parameters, - FegenParser::TypeDefinationDeclContext *ctx, - bool ifCustome); + std::vector parameters, + FegenParser::TypeDefinationDeclContext *ctx, bool ifCustome); static TypeDefination *get(std::string dialectName, std::string name, - std::vector parameters, - FegenParser::TypeDefinationDeclContext *ctx, - bool ifCustome = true); + std::vector parameters, + FegenParser::TypeDefinationDeclContext *ctx, + bool ifCustome = true); std::string getDialectName(); void setDialectName(std::string); std::string getName(); @@ -269,7 +267,7 @@ class TypeDefination { class RightValue { friend class Type; friend class Value; - + public: enum class LiteralKind { MONOSTATE, @@ -330,8 +328,7 @@ class RightValue { static std::shared_ptr getTypeRightValue(TypePtr); static std::shared_ptr getList(std::vector> &); - static std::shared_ptr - getLeftValue(fegen::Value *); + static std::shared_ptr getLeftValue(fegen::Value *); }; struct ExpressionNode : public Expression { @@ -472,8 +469,7 @@ class RightValue { static RightValue getFloatPoint(long double content, size_t size = 32); static RightValue getString(std::string content); static RightValue getTypeRightValue(TypePtr content); - static RightValue - getList(std::vector> &content); + static RightValue getList(std::vector> &content); static RightValue getLeftValue(fegen::Value *content); static RightValue getByExpr(std::shared_ptr expr); ~RightValue() = default; @@ -484,28 +480,28 @@ class RightValue { // PlaceHolder class PlaceHolderType : public Type { - public: +public: PlaceHolderType(); }; // Type class MetaType : public Type { - public: +public: MetaType(); // for generating typedef td file. virtual std::string toStringForTypedef() override; - }; // Template class MetaTemplate : public Type { - public: +public: MetaTemplate(); }; // Integer class IntegerType : public Type { RightValue size; - public: - IntegerType(RightValue size, TypeDefination* tyDef); + +public: + IntegerType(RightValue size, TypeDefination *tyDef); IntegerType(RightValue size); // for generating typedef td file. virtual std::string toStringForTypedef() override; @@ -517,7 +513,8 @@ class IntegerType : public Type { // FloatPoint class FloatPointType : public Type { RightValue size; - public: + +public: FloatPointType(RightValue size); // for generating typedef td file. virtual std::string toStringForTypedef() override; @@ -528,7 +525,7 @@ class FloatPointType : public Type { }; // String class StringType : public Type { - public: +public: StringType(); // for generating cpp type kind. virtual std::string toStringForCppKind() override; @@ -536,7 +533,8 @@ class StringType : public Type { // List class ListType : public Type { RightValue elementType; - public: + +public: ListType(RightValue elementType); // for generating typedef td file. virtual std::string toStringForTypedef() override; @@ -549,45 +547,50 @@ class ListType : public Type { class VectorType : public Type { RightValue elementType; RightValue size; - public: + +public: VectorType(RightValue elementType, RightValue size); }; // Tensor class TensorType : public Type { RightValue elementType; RightValue shape; - public: + +public: TensorType(RightValue elementType, RightValue shape); }; // Optional class OptionalType : public Type { RightValue elementType; - public: + +public: OptionalType(RightValue elementType); }; // Any<[ty1, ty2, ...]> class AnyType : public Type { RightValue elementTypes; - public: + +public: AnyType(RightValue elementTypes); }; // custome type class CustomeType : public Type { std::vector params; - public: - CustomeType(std::vector params, TypeDefination* tydef); + +public: + CustomeType(std::vector params, TypeDefination *tydef); }; class TemplateType : public Type { - public: - TemplateType(TypeDefination* tydef); +public: + TemplateType(TypeDefination *tydef); virtual TypePtr instantiate(std::vector params) = 0; virtual ~TemplateType() = default; }; // Integer class IntegerTemplateType : public TemplateType { - public: +public: IntegerTemplateType(); virtual TypePtr instantiate(std::vector params) override; // for generating typedef td file. @@ -597,7 +600,7 @@ class IntegerTemplateType : public TemplateType { }; // FloatPoint class FloatPointTemplateType : public TemplateType { - public: +public: FloatPointTemplateType(); virtual TypePtr instantiate(std::vector params) override; // for generating typedef td file. @@ -605,7 +608,7 @@ class FloatPointTemplateType : public TemplateType { }; // String class StringTemplateType : public TemplateType { - public: +public: StringTemplateType(); virtual TypePtr instantiate(std::vector params) override; // for generating typedef td file. @@ -614,7 +617,8 @@ class StringTemplateType : public TemplateType { // List (ty is a template) class ListTemplateType : public TemplateType { RightValue elementType; - public: + +public: ListTemplateType(RightValue elementType); virtual TypePtr instantiate(std::vector params) override; virtual std::string toStringForTypedef() override; @@ -622,34 +626,36 @@ class ListTemplateType : public TemplateType { }; // Vector class VectorTemplateType : public TemplateType { - public: +public: VectorTemplateType(); virtual TypePtr instantiate(std::vector params) override; }; // Tensor class TensorTemplateType : public TemplateType { - public: +public: TensorTemplateType(); virtual TypePtr instantiate(std::vector params) override; }; // Optional (ty is a template) class OptionalTemplateType : public TemplateType { RightValue elementType; - public: + +public: OptionalTemplateType(RightValue elementType); virtual TypePtr instantiate(std::vector params) override; }; // Any<[ty1, ty2, ...]> (ty* is a template) class AnyTemplateType : public TemplateType { RightValue elementTypes; - public: + +public: AnyTemplateType(RightValue elementTypes); virtual TypePtr instantiate(std::vector params) override; }; // custome type class CustomeTemplateType : public TemplateType { - public: - CustomeTemplateType(TypeDefination* tydef); +public: + CustomeTemplateType(TypeDefination *tydef); virtual TypePtr instantiate(std::vector params) override; // for generating typedef td file. virtual std::string toStringForTypedef() override; @@ -657,7 +663,6 @@ class CustomeTemplateType : public TemplateType { virtual std::string toStringForOpdef() override; }; - class Value { friend class Type; @@ -671,8 +676,7 @@ class Value { Value(const Value &rhs); Value(Value &&rhs); - static Value *get(TypePtr type, std::string name, - RightValue constant); + static Value *get(TypePtr type, std::string name, RightValue constant); std::string getName(); TypePtr getType(); @@ -705,11 +709,11 @@ class ParserRule { // context in parser tree antlr4::ParserRuleContext *ctx; explicit ParserRule(std::string content, ParserNode *src, - antlr4::ParserRuleContext *ctx); + antlr4::ParserRuleContext *ctx); public: static ParserRule *get(std::string content, ParserNode *src, - antlr4::ParserRuleContext *ctx); + antlr4::ParserRuleContext *ctx); llvm::StringRef getContent(); // check and add input value bool addInput(Value input); @@ -730,11 +734,11 @@ class ParserNode { antlr4::ParserRuleContext *ctx; NodeType ntype; explicit ParserNode(std::vector &&rules, - antlr4::ParserRuleContext *ctx, NodeType ntype); + antlr4::ParserRuleContext *ctx, NodeType ntype); public: static ParserNode *get(std::vector rules, - antlr4::ParserRuleContext *ctx, NodeType ntype); + antlr4::ParserRuleContext *ctx, NodeType ntype); static ParserNode *get(antlr4::ParserRuleContext *ctx, NodeType ntype); void addFegenRule(ParserRule *rule); // release rules first @@ -745,13 +749,14 @@ class FegenVisitor; class Manager { friend class FegenVisitor; + private: -struct OverloadedType { - llvm::SmallVector tys; - OverloadedType(TypeDefination *); - OverloadedType(std::initializer_list&&); - TypeDefination* get(unsigned i); -}; + struct OverloadedType { + llvm::SmallVector tys; + OverloadedType(TypeDefination *); + OverloadedType(std::initializer_list &&); + TypeDefination *get(unsigned i); + }; private: std::map typeDefMap; @@ -782,7 +787,7 @@ struct OverloadedType { void setModuleName(std::string name); TypeDefination *getTypeDefination(std::string name); - TypeDefination* getOverloadedTypeDefination(std::string name); + TypeDefination *getOverloadedTypeDefination(std::string name); bool addTypeDefination(TypeDefination *tyDef); bool addOverloadedTypeDefination(TypeDefination *tyDef); @@ -796,9 +801,8 @@ struct OverloadedType { void emitBuiltinFunction(fegen::FegenParser::FegenSpecContext *); }; -TypePtr - inferenceType(std::vector>, - FegenOperator); +TypePtr inferenceType(std::vector>, + FegenOperator); } // namespace fegen diff --git a/frontend/FrontendGen/include/FegenVisitor.h b/frontend/FrontendGen/include/FegenVisitor.h index 327b142b17..039fd0a86c 100644 --- a/frontend/FrontendGen/include/FegenVisitor.h +++ b/frontend/FrontendGen/include/FegenVisitor.h @@ -709,11 +709,12 @@ class FegenVisitor : public FegenParserBaseVisitor { std::any_cast>( this->visit(ctx->expression())); auto var = sstack.attemptFindVar(varName); - + // TODO: 支持获取expression的type后,可正常使用 // if (!fegen::Type::isSameType(var->getType(), varcontent->getType())) { // std::cerr << "The variabel \" " << varName << "\" need \"" - // << var->getType()->toStringForCppKind() << " \" type rightvalue." + // << var->getType()->toStringForCppKind() << " \" type + // rightvalue." // << std::endl; // exit(0); // return nullptr; @@ -768,7 +769,7 @@ class FegenVisitor : public FegenParserBaseVisitor { manager.stmtContentMap.insert(std::pair{ctx, funcCall}); return returnType; } - + // TODO:add op invoke std::any visitOpInvokeStmt(FegenParser::OpInvokeStmtContext *ctx) override { return nullptr; diff --git a/frontend/FrontendGen/lib/FegenManager.cpp b/frontend/FrontendGen/lib/FegenManager.cpp index 70bbcfb30a..62a691a242 100644 --- a/frontend/FrontendGen/lib/FegenManager.cpp +++ b/frontend/FrontendGen/lib/FegenManager.cpp @@ -99,11 +99,11 @@ int fegen::Type::getTypeLevel() { return this->typeLevel; } bool fegen::Type::isConstant() { return this->isConstType; } bool fegen::Type::isSameType(fegen::TypePtr type1, fegen::TypePtr type2) { - if (type1->getTypeName() == type2->getTypeName()){ + if (type1->getTypeName() == type2->getTypeName()) { std::cout << "1" << std::endl; return true; } - + else return false; } @@ -419,9 +419,7 @@ fegen::StringType::StringType() fegen::Manager::getManager().getTypeDefination(FEGEN_STRING), 3, true) {} -std::string fegen::StringType::toStringForCppKind() { - return "string"; -} +std::string fegen::StringType::toStringForCppKind() { return "string"; } // class ListType fegen::ListType::ListType(fegen::RightValue elementType) @@ -1905,7 +1903,8 @@ class StmtVisitor : public FegenParserBaseVisitor { auto paraList = std::any_cast>( manager.stmtContentMap[ctx->funcParams()]); for (auto para : paraList) { - emitter << para->getType()->toStringForCppKind() << " " << para->getName(); + emitter << para->getType()->toStringForCppKind() << " " + << para->getName(); if (para != paraList.back()) emitter << ", "; } From cde5c7eca125b45662bc97c594b9deedf5cd8714 Mon Sep 17 00:00:00 2001 From: FloatingcloudKnight <1348185166@qq.com> Date: Tue, 23 Jul 2024 08:46:45 +0000 Subject: [PATCH 16/17] [FrontendGen] update function generation. --- frontend/FrontendGen/include/FegenManager.h | 2 +- frontend/FrontendGen/include/FegenVisitor.h | 2 +- frontend/FrontendGen/lib/FegenManager.cpp | 47 +-------------------- 3 files changed, 3 insertions(+), 48 deletions(-) diff --git a/frontend/FrontendGen/include/FegenManager.h b/frontend/FrontendGen/include/FegenManager.h index bca53b11f2..ada44b40a1 100644 --- a/frontend/FrontendGen/include/FegenManager.h +++ b/frontend/FrontendGen/include/FegenManager.h @@ -806,4 +806,4 @@ TypePtr inferenceType(std::vector>, } // namespace fegen -#endif \ No newline at end of file +#endif diff --git a/frontend/FrontendGen/include/FegenVisitor.h b/frontend/FrontendGen/include/FegenVisitor.h index 039fd0a86c..1bfe3e1148 100644 --- a/frontend/FrontendGen/include/FegenVisitor.h +++ b/frontend/FrontendGen/include/FegenVisitor.h @@ -852,4 +852,4 @@ class FegenVisitor : public FegenParserBaseVisitor { } }; } // namespace fegen -#endif \ No newline at end of file +#endif diff --git a/frontend/FrontendGen/lib/FegenManager.cpp b/frontend/FrontendGen/lib/FegenManager.cpp index 62a691a242..96631ccc95 100644 --- a/frontend/FrontendGen/lib/FegenManager.cpp +++ b/frontend/FrontendGen/lib/FegenManager.cpp @@ -1377,51 +1377,6 @@ class Emitter { return this->stream; } }; - -class StmtGenerator : FegenParserBaseVisitor { -private: - Manager &manager; - Emitter &emitter; - -public: - StmtGenerator(Emitter &emitter) - : manager(Manager::getManager()), emitter(emitter) {} - std::any visitVarDeclStmt(FegenParser::VarDeclStmtContext *ctx) override { - auto var = manager.getStmtContent(ctx->identifier()); - switch (var->getType()->getTypeKind()) { - case fegen::Type::TypeKind::CPP: { - this->emitter << var->getType()->toStringForCppKind() << " " - << var->getName(); - if (ctx->expression()) { - auto expr = this->manager.getStmtContent( - ctx->expression()); - this->emitter << " = " << expr->toStringForCppKind(); - } - this->emitter << ";"; - this->emitter.newLine(); - break; - } - case fegen::Type::TypeKind::ATTRIBUTE: { - break; - } - case fegen::Type::TypeKind::OPERAND: { - break; - } - } - return nullptr; - } - - std::any visitAssignStmt(FegenParser::AssignStmtContext *ctx) override {} - - std::any visitFunctionCall(FegenParser::FunctionCallContext *ctx) override {} - - std::any visitOpInvokeStmt(FegenParser::OpInvokeStmtContext *ctx) override {} - - std::any visitIfStmt(FegenParser::IfStmtContext *ctx) override {} - - std::any visitForStmt(FegenParser::ForStmtContext *ctx) override {} -}; - } // namespace fegen void fegen::Manager::emitG4() { @@ -2037,4 +1992,4 @@ void fegen::Manager::emitBuiltinFunction( StmtVisitor visitor(emitter); visitor.visit(moduleAST); fileStream.close(); -} \ No newline at end of file +} From 10601780aaff765e6983938eab16386439ad4116 Mon Sep 17 00:00:00 2001 From: chh Date: Tue, 23 Jul 2024 17:00:59 +0800 Subject: [PATCH 17/17] [Frontend] Update function generation. --- frontend/FrontendGen/include/FegenVisitor.h | 2 +- frontend/FrontendGen/lib/FegenManager.cpp | 3 +++ frontend/FrontendGen/lib/FegenVisitor.cpp | 3 +-- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/frontend/FrontendGen/include/FegenVisitor.h b/frontend/FrontendGen/include/FegenVisitor.h index 1bfe3e1148..3db3ca1aec 100644 --- a/frontend/FrontendGen/include/FegenVisitor.h +++ b/frontend/FrontendGen/include/FegenVisitor.h @@ -722,7 +722,7 @@ class FegenVisitor : public FegenParserBaseVisitor { fegen::Value *stmt = fegen::Value::get( var->getType(), varName, fegen::RightValue::getByExpr(varcontent)); - manager.stmtContentMap.insert(std::pair{ctx, stmt}); + manager.addStmtContent(ctx, stmt); manager.addStmtContent(ctx->expression(), varcontent); return var; } diff --git a/frontend/FrontendGen/lib/FegenManager.cpp b/frontend/FrontendGen/lib/FegenManager.cpp index 96631ccc95..998b7fd542 100644 --- a/frontend/FrontendGen/lib/FegenManager.cpp +++ b/frontend/FrontendGen/lib/FegenManager.cpp @@ -1978,6 +1978,9 @@ class StmtVisitor : public FegenParserBaseVisitor { emitter << "return " << expr->toString(); return nullptr; } + std::any visitOpDecl(FegenParser::OpDeclContext *ctx) override { + return nullptr; + } // TODO: add op declaration/invoke }; diff --git a/frontend/FrontendGen/lib/FegenVisitor.cpp b/frontend/FrontendGen/lib/FegenVisitor.cpp index 882246dedf..761fe0529d 100644 --- a/frontend/FrontendGen/lib/FegenVisitor.cpp +++ b/frontend/FrontendGen/lib/FegenVisitor.cpp @@ -6,7 +6,6 @@ bool fegen::checkParams(std::vector &expected, } bool fegen::checkListLiteral( - std::vector> - &listLiteral) { + std::vector> &listLiteral) { return true; } \ No newline at end of file