From 95c7c74d3a20f8b946c5a5a4f29e0e3e834d332c Mon Sep 17 00:00:00 2001 From: Matthias Kurtenacker Date: Tue, 5 Nov 2024 15:12:40 +0100 Subject: [PATCH] Allow for static values as array sizes. This applies to both array types `[T * n]`, as well as RepeatArrayExpr `[E; n]`. In both cases, n must be defined as `static n = ...`, or the emitter might fail. --- include/artic/ast.h | 13 +++--- include/artic/parser.h | 3 +- src/bind.cpp | 8 ++++ src/check.cpp | 95 ++++++++++++++++++++++++++++++++++++++++-- src/emit.cpp | 2 +- src/parser.cpp | 28 ++++++------- src/print.cpp | 20 ++++++++- 7 files changed, 142 insertions(+), 27 deletions(-) diff --git a/include/artic/ast.h b/include/artic/ast.h index 0b1739f..584d740 100644 --- a/include/artic/ast.h +++ b/include/artic/ast.h @@ -356,14 +356,15 @@ struct ArrayType : public Type { /// Sized array type. struct SizedArrayType : public ArrayType { - size_t size; + std::variant size; bool is_simd; - SizedArrayType(const Loc& loc, Ptr&& elem, size_t size, bool is_simd) - : ArrayType(loc, std::move(elem)), size(size), is_simd(is_simd) + SizedArrayType(const Loc& loc, Ptr&& elem, std::variant&& size, bool is_simd) + : ArrayType(loc, std::move(elem)), size(std::move(size)), is_simd(is_simd) {} const artic::Type* infer(TypeChecker&) override; + void bind(NameBinder&) override; void print(Printer&) const override; }; @@ -662,11 +663,11 @@ struct ArrayExpr : public Expr { /// Array expression repeating a given value a given number of times. struct RepeatArrayExpr : public Expr { Ptr elem; - size_t size; + std::variant size; bool is_simd; - RepeatArrayExpr(const Loc& loc, Ptr&& elem, size_t size, bool is_simd) - : Expr(loc), elem(std::move(elem)), size(size), is_simd(is_simd) + RepeatArrayExpr(const Loc& loc, Ptr&& elem, std::variant&& size, bool is_simd) + : Expr(loc), elem(std::move(elem)), size(std::move(size)), is_simd(is_simd) {} bool is_jumping() const override; diff --git a/include/artic/parser.h b/include/artic/parser.h index 15b248f..8d5c0de 100644 --- a/include/artic/parser.h +++ b/include/artic/parser.h @@ -108,9 +108,10 @@ class Parser : public Logger { ast::AsmExpr::Constr parse_constr(); Literal parse_lit(); std::string parse_str(); - std::optional parse_array_size(); size_t parse_addr_space(); + std::optional> parse_array_size(); + std::pair, Ptr> parse_cond_and_block(); struct Tracker { diff --git a/src/bind.cpp b/src/bind.cpp index e094c60..a09d9be 100644 --- a/src/bind.cpp +++ b/src/bind.cpp @@ -113,6 +113,12 @@ void ArrayType::bind(NameBinder& binder) { binder.bind(*elem); } +void SizedArrayType::bind(NameBinder& binder) { + binder.bind(*elem); + if (std::holds_alternative(size)) + binder.bind(std::get(size)); +} + void FnType::bind(NameBinder& binder) { binder.bind(*from); if (to) binder.bind(*to); @@ -177,6 +183,8 @@ void ArrayExpr::bind(NameBinder& binder) { void RepeatArrayExpr::bind(NameBinder& binder) { binder.bind(*elem); + if (std::holds_alternative(size)) + binder.bind(std::get(size)); } void FnExpr::bind(NameBinder& binder, bool in_for_loop) { diff --git a/src/check.cpp b/src/check.cpp index ccbacbd..606d860 100644 --- a/src/check.cpp +++ b/src/check.cpp @@ -844,7 +844,37 @@ const artic::Type* SizedArrayType::infer(TypeChecker& checker) { auto elem_type = checker.infer(*elem); if (is_simd && !elem_type->isa()) return checker.invalid_simd(loc, elem_type); - return checker.type_table.sized_array_type(elem_type, size, is_simd); + + if (std::holds_alternative(size)) { + auto &path = std::get(size); + const auto* decl = path.start_decl; + + for (size_t i = 0, n = path.elems.size(); i < n; ++i) { + if (path.elems[i].is_super()) + decl = i == 0 ? path.start_decl : decl->as()->super; + if (auto mod_type = path.elems[i].type->isa()) { + decl = &mod_type->member(path.elems[i + 1].index); + } else if (!path.is_ctor) { + assert(path.elems[i].inferred_args.empty()); + assert(decl->isa() && "The only supported type right now."); + break; + } else if (match_app(path.elems[i].type).second) { + assert(false && "This is not supported as a size for repeated arrays."); + } else if (auto [type_app, enum_type] = match_app(path.elems[i].type); enum_type) { + assert(false && "This is not supported as a size for repeated arrays."); + } + } + + auto static_decl = decl->as(); + assert(!static_decl->is_mut); + assert(static_decl->init); + auto& value = static_decl->init; + auto lit_value = value->as()->lit; + + size = lit_value.as_integer(); + } + + return checker.type_table.sized_array_type(elem_type, std::get(size), is_simd); } const artic::Type* UnsizedArrayType::infer(TypeChecker& checker) { @@ -983,12 +1013,71 @@ const artic::Type* RepeatArrayExpr::infer(TypeChecker& checker) { auto elem_type = checker.deref(elem); if (is_simd && !elem_type->isa()) return checker.invalid_simd(loc, elem_type); - return checker.type_table.sized_array_type(elem_type, size, is_simd); + + if (std::holds_alternative(size)) { + auto &path = std::get(size); + const auto* decl = path.start_decl; + + for (size_t i = 0, n = path.elems.size(); i < n; ++i) { + if (path.elems[i].is_super()) + decl = i == 0 ? path.start_decl : decl->as()->super; + if (auto mod_type = path.elems[i].type->isa()) { + decl = &mod_type->member(path.elems[i + 1].index); + } else if (!path.is_ctor) { + assert(path.elems[i].inferred_args.empty()); + assert(decl->isa() && "The only supported type right now."); + break; + } else if (match_app(path.elems[i].type).second) { + assert(false && "This is not supported as a size for repeated arrays."); + } else if (auto [type_app, enum_type] = match_app(path.elems[i].type); enum_type) { + assert(false && "This is not supported as a size for repeated arrays."); + } + } + + auto static_decl = decl->as(); + assert(!static_decl->is_mut); + assert(static_decl->init); + auto& value = static_decl->init; + auto lit_value = value->as()->lit; + + size = lit_value.as_integer(); + } + + return checker.type_table.sized_array_type(elem_type, std::get(size), is_simd); } const artic::Type* RepeatArrayExpr::check(TypeChecker& checker, const artic::Type* expected) { + if (std::holds_alternative(size)) { + auto &path = std::get(size); + const auto* decl = path.start_decl; + + for (size_t i = 0, n = path.elems.size(); i < n; ++i) { + if (path.elems[i].is_super()) + decl = i == 0 ? path.start_decl : decl->as()->super; + if (auto mod_type = path.elems[i].type->isa()) { + decl = &mod_type->member(path.elems[i + 1].index); + } else if (!path.is_ctor) { + assert(path.elems[i].inferred_args.empty()); + assert(decl->isa() && "The only supported type right now."); + break; + } else if (match_app(path.elems[i].type).second) { + assert(false && "This is not supported as a size for repeated arrays."); + } else if (auto [type_app, enum_type] = match_app(path.elems[i].type); enum_type) { + assert(false && "This is not supported as a size for repeated arrays."); + } + } + + auto static_decl = decl->as(); + assert(!static_decl->is_mut); + assert(static_decl->init); + auto& value = static_decl->init; + auto lit_value = value->as()->lit; + + size = lit_value.as_integer(); + } + return checker.check_array(loc, "array expression", - expected, size, is_simd, [&] (auto elem_type) { + expected, std::get(size), is_simd, [&] (auto elem_type) { checker.coerce(elem, elem_type); }); } diff --git a/src/emit.cpp b/src/emit.cpp index 2e49d83..331a8fa 100644 --- a/src/emit.cpp +++ b/src/emit.cpp @@ -1204,7 +1204,7 @@ const thorin::Def* ArrayExpr::emit(Emitter& emitter) const { } const thorin::Def* RepeatArrayExpr::emit(Emitter& emitter) const { - thorin::Array ops(size, emitter.emit(*elem)); + thorin::Array ops(std::get(size), emitter.emit(*elem)); return is_simd ? emitter.world.vector(ops, emitter.debug_info(*this)) : emitter.world.definite_array(ops, emitter.debug_info(*this)); diff --git a/src/parser.cpp b/src/parser.cpp index 2f3634e..589d2df 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -563,7 +563,7 @@ Ptr Parser::parse_array_expr() { auto size = parse_array_size(); expect(Token::RBracket); if (size) - return make_ptr(tracker(), std::move(elems.front()), *size, is_simd); + return make_ptr(tracker(), std::move(elems.front()), std::move(*size), is_simd); return make_ptr(tracker(), std::move(elems), is_simd); } else if (accept(Token::Comma)) { parse_list(Token::RBracket, Token::Comma, [&] { @@ -1065,15 +1065,17 @@ Ptr Parser::parse_array_type() { bool is_simd = accept(Token::Simd); expect(Token::LBracket); auto elem = parse_type(); - std::optional size; if (is_simd || ahead().tag() == Token::Mul) { expect(Token::Mul); - size = parse_array_size(); + auto size = parse_array_size(); + expect(Token::RBracket); + if (size) + return make_ptr(tracker(), std::move(elem), std::move(*size), is_simd); + return make_ptr(tracker(), std::move(elem)); + } else { + expect(Token::RBracket); + return make_ptr(tracker(), std::move(elem)); } - expect(Token::RBracket); - if (size) - return make_ptr(tracker(), std::move(elem), *size, is_simd); - return make_ptr(tracker(), std::move(elem)); } Ptr Parser::parse_fn_type() { @@ -1238,17 +1240,15 @@ std::string Parser::parse_str() { return str; } -std::optional Parser::parse_array_size() { - std::optional size; +std::optional> Parser::parse_array_size() { if (ahead().is_literal() && ahead().literal().is_integer()) { - size = ahead().literal().as_integer(); + auto size = ahead().literal().as_integer(); eat(Token::Lit); + return size; } else { - error(ahead().loc(), "expected integer literal as array size"); - if (ahead().tag() != Token::RBracket) - next(); + auto path = parse_path(); + return path; } - return size; } size_t Parser::parse_addr_space() { diff --git a/src/print.cpp b/src/print.cpp index acecc94..15f3377 100644 --- a/src/print.cpp +++ b/src/print.cpp @@ -165,7 +165,15 @@ void RepeatArrayExpr::print(Printer& p) const { p << log::keyword_style("simd"); p << '['; elem->print(p); - p << "; " << size << ']'; + p << "; "; + std::visit([&] (auto&& arg) { + using T = std::decay_t; + if constexpr (std::is_same_v) + p << arg; + else if constexpr (std::is_same_v) + arg->print(p); + }, size); + p << ']'; } void FnExpr::print(Printer& p) const { @@ -647,7 +655,15 @@ void SizedArrayType::print(Printer& p) const { p << log::keyword_style("simd"); p << '['; elem->print(p); - p << " * " << size << ']'; + p << " * "; + std::visit([&] (auto&& arg) { + using T = std::decay_t; + if constexpr (std::is_same_v) + p << arg; + else if constexpr (std::is_same_v) + arg->print(p); + }, size); + p << ']'; } void UnsizedArrayType::print(Printer& p) const {