Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for static values as array sizes. #28

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions include/artic/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -356,14 +356,15 @@ struct ArrayType : public Type {

/// Sized array type.
struct SizedArrayType : public ArrayType {
size_t size;
std::variant<size_t, ast::Path> size;
bool is_simd;

SizedArrayType(const Loc& loc, Ptr<Type>&& elem, size_t size, bool is_simd)
: ArrayType(loc, std::move(elem)), size(size), is_simd(is_simd)
SizedArrayType(const Loc& loc, Ptr<Type>&& elem, std::variant<size_t, ast::Path>&& size, bool is_simd)
: ArrayType(loc, std::move(elem)), size(std::move(size)), is_simd(is_simd)
{}

const artic::Type* infer(TypeChecker&) override;
void bind(NameBinder&) override;
void print(Printer&) const override;
};

Expand Down Expand Up @@ -673,11 +674,11 @@ struct ArrayExpr : public Expr {
/// Array expression repeating a given value a given number of times.
struct RepeatArrayExpr : public Expr {
Ptr<Expr> elem;
size_t size;
std::variant<size_t, ast::Path> size;
bool is_simd;

RepeatArrayExpr(const Loc& loc, Ptr<Expr>&& elem, size_t size, bool is_simd)
: Expr(loc), elem(std::move(elem)), size(size), is_simd(is_simd)
RepeatArrayExpr(const Loc& loc, Ptr<Expr>&& elem, std::variant<size_t, ast::Path>&& size, bool is_simd)
: Expr(loc), elem(std::move(elem)), size(std::move(size)), is_simd(is_simd)
{}

bool is_jumping() const override;
Expand Down
3 changes: 2 additions & 1 deletion include/artic/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,10 @@ class Parser : public Logger {
ast::AsmExpr::Constr parse_constr();
Literal parse_lit();
std::string parse_str();
std::optional<size_t> parse_array_size();
size_t parse_addr_space();

std::optional<std::variant<size_t, ast::Path>> parse_array_size();

std::pair<Ptr<ast::Expr>, Ptr<ast::Expr>> parse_cond_and_block();

struct Tracker {
Expand Down
8 changes: 8 additions & 0 deletions src/bind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,12 @@ void ArrayType::bind(NameBinder& binder) {
binder.bind(*elem);
}

void SizedArrayType::bind(NameBinder& binder) {
binder.bind(*elem);
if (std::holds_alternative<ast::Path>(size))
binder.bind(std::get<ast::Path>(size));
}

void FnType::bind(NameBinder& binder) {
binder.bind(*from);
if (to) binder.bind(*to);
Expand Down Expand Up @@ -179,6 +185,8 @@ void ArrayExpr::bind(NameBinder& binder) {

void RepeatArrayExpr::bind(NameBinder& binder) {
binder.bind(*elem);
if (std::holds_alternative<ast::Path>(size))
binder.bind(std::get<ast::Path>(size));
}

void FnExpr::bind(NameBinder& binder, bool in_for_loop) {
Expand Down
95 changes: 92 additions & 3 deletions src/check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -849,7 +849,37 @@ const artic::Type* SizedArrayType::infer(TypeChecker& checker) {
auto elem_type = checker.infer(*elem);
if (is_simd && !elem_type->isa<artic::PrimType>())
return checker.invalid_simd(loc, elem_type);
return checker.type_table.sized_array_type(elem_type, size, is_simd);

if (std::holds_alternative<ast::Path>(size)) {
auto &path = std::get<ast::Path>(size);
const auto* decl = path.start_decl;

for (size_t i = 0, n = path.elems.size(); i < n; ++i) {
if (path.elems[i].is_super())
decl = i == 0 ? path.start_decl : decl->as<ModDecl>()->super;
if (auto mod_type = path.elems[i].type->isa<ModType>()) {
decl = &mod_type->member(path.elems[i + 1].index);
} else if (!path.is_ctor) {
assert(path.elems[i].inferred_args.empty());
assert(decl->isa<StaticDecl>() && "The only supported type right now.");
break;
} else if (match_app<StructType>(path.elems[i].type).second) {
assert(false && "This is not supported as a size for repeated arrays.");
} else if (auto [type_app, enum_type] = match_app<artic::EnumType>(path.elems[i].type); enum_type) {
assert(false && "This is not supported as a size for repeated arrays.");
}
}

auto static_decl = decl->as<StaticDecl>();
assert(!static_decl->is_mut);
assert(static_decl->init);
auto& value = static_decl->init;
auto lit_value = value->as<LiteralExpr>()->lit;

size = lit_value.as_integer();
}

return checker.type_table.sized_array_type(elem_type, std::get<size_t>(size), is_simd);
}

const artic::Type* UnsizedArrayType::infer(TypeChecker& checker) {
Expand Down Expand Up @@ -994,12 +1024,71 @@ const artic::Type* RepeatArrayExpr::infer(TypeChecker& checker) {
auto elem_type = checker.deref(elem);
if (is_simd && !elem_type->isa<artic::PrimType>())
return checker.invalid_simd(loc, elem_type);
return checker.type_table.sized_array_type(elem_type, size, is_simd);

if (std::holds_alternative<ast::Path>(size)) {
auto &path = std::get<ast::Path>(size);
const auto* decl = path.start_decl;

for (size_t i = 0, n = path.elems.size(); i < n; ++i) {
if (path.elems[i].is_super())
decl = i == 0 ? path.start_decl : decl->as<ModDecl>()->super;
if (auto mod_type = path.elems[i].type->isa<ModType>()) {
decl = &mod_type->member(path.elems[i + 1].index);
} else if (!path.is_ctor) {
assert(path.elems[i].inferred_args.empty());
assert(decl->isa<StaticDecl>() && "The only supported type right now.");
break;
} else if (match_app<StructType>(path.elems[i].type).second) {
assert(false && "This is not supported as a size for repeated arrays.");
} else if (auto [type_app, enum_type] = match_app<artic::EnumType>(path.elems[i].type); enum_type) {
assert(false && "This is not supported as a size for repeated arrays.");
}
}

auto static_decl = decl->as<StaticDecl>();
assert(!static_decl->is_mut);
assert(static_decl->init);
auto& value = static_decl->init;
auto lit_value = value->as<LiteralExpr>()->lit;

size = lit_value.as_integer();
}

return checker.type_table.sized_array_type(elem_type, std::get<size_t>(size), is_simd);
}

const artic::Type* RepeatArrayExpr::check(TypeChecker& checker, const artic::Type* expected) {
if (std::holds_alternative<ast::Path>(size)) {
auto &path = std::get<ast::Path>(size);
const auto* decl = path.start_decl;

for (size_t i = 0, n = path.elems.size(); i < n; ++i) {
if (path.elems[i].is_super())
decl = i == 0 ? path.start_decl : decl->as<ModDecl>()->super;
if (auto mod_type = path.elems[i].type->isa<ModType>()) {
decl = &mod_type->member(path.elems[i + 1].index);
} else if (!path.is_ctor) {
assert(path.elems[i].inferred_args.empty());
assert(decl->isa<StaticDecl>() && "The only supported type right now.");
break;
} else if (match_app<StructType>(path.elems[i].type).second) {
assert(false && "This is not supported as a size for repeated arrays.");
} else if (auto [type_app, enum_type] = match_app<artic::EnumType>(path.elems[i].type); enum_type) {
assert(false && "This is not supported as a size for repeated arrays.");
}
}

auto static_decl = decl->as<StaticDecl>();
assert(!static_decl->is_mut);
assert(static_decl->init);
auto& value = static_decl->init;
auto lit_value = value->as<LiteralExpr>()->lit;

size = lit_value.as_integer();
}

return checker.check_array(loc, "array expression",
expected, size, is_simd, [&] (auto elem_type) {
expected, std::get<size_t>(size), is_simd, [&] (auto elem_type) {
checker.coerce(elem, elem_type);
});
}
Expand Down
2 changes: 1 addition & 1 deletion src/emit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1204,7 +1204,7 @@ const thorin::Def* ArrayExpr::emit(Emitter& emitter) const {
}

const thorin::Def* RepeatArrayExpr::emit(Emitter& emitter) const {
thorin::Array<const thorin::Def*> ops(size, emitter.emit(*elem));
thorin::Array<const thorin::Def*> ops(std::get<size_t>(size), emitter.emit(*elem));
return is_simd
? emitter.world.vector(ops, emitter.debug_info(*this))
: emitter.world.definite_array(ops, emitter.debug_info(*this));
Expand Down
28 changes: 14 additions & 14 deletions src/parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ Ptr<ast::Expr> Parser::parse_array_expr() {
auto size = parse_array_size();
expect(Token::RBracket);
if (size)
return make_ptr<ast::RepeatArrayExpr>(tracker(), std::move(elems.front()), *size, is_simd);
return make_ptr<ast::RepeatArrayExpr>(tracker(), std::move(elems.front()), std::move(*size), is_simd);
return make_ptr<ast::ArrayExpr>(tracker(), std::move(elems), is_simd);
} else if (accept(Token::Comma)) {
parse_list(Token::RBracket, Token::Comma, [&] {
Expand Down Expand Up @@ -1069,15 +1069,17 @@ Ptr<ast::ArrayType> Parser::parse_array_type() {
bool is_simd = accept(Token::Simd);
expect(Token::LBracket);
auto elem = parse_type();
std::optional<size_t> size;
if (is_simd || ahead().tag() == Token::Mul) {
expect(Token::Mul);
size = parse_array_size();
auto size = parse_array_size();
expect(Token::RBracket);
if (size)
return make_ptr<ast::SizedArrayType>(tracker(), std::move(elem), std::move(*size), is_simd);
return make_ptr<ast::UnsizedArrayType>(tracker(), std::move(elem));
} else {
expect(Token::RBracket);
return make_ptr<ast::UnsizedArrayType>(tracker(), std::move(elem));
}
expect(Token::RBracket);
if (size)
return make_ptr<ast::SizedArrayType>(tracker(), std::move(elem), *size, is_simd);
return make_ptr<ast::UnsizedArrayType>(tracker(), std::move(elem));
}

Ptr<ast::FnType> Parser::parse_fn_type() {
Expand Down Expand Up @@ -1246,17 +1248,15 @@ std::string Parser::parse_str() {
return str;
}

std::optional<size_t> Parser::parse_array_size() {
std::optional<size_t> size;
std::optional<std::variant<size_t, ast::Path>> Parser::parse_array_size() {
if (ahead().is_literal() && ahead().literal().is_integer()) {
size = ahead().literal().as_integer();
auto size = ahead().literal().as_integer();
eat(Token::Lit);
return size;
} else {
error(ahead().loc(), "expected integer literal as array size");
if (ahead().tag() != Token::RBracket)
next();
auto path = parse_path();
return path;
}
return size;
}

size_t Parser::parse_addr_space() {
Expand Down
20 changes: 18 additions & 2 deletions src/print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,15 @@ void RepeatArrayExpr::print(Printer& p) const {
p << log::keyword_style("simd");
p << '[';
elem->print(p);
p << "; " << size << ']';
p << "; ";
std::visit([&] (auto&& arg) {
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, size_t>)
p << arg;
else if constexpr (std::is_same_v<T, ast::Path&>)
arg->print(p);
}, size);
p << ']';
}

void FnExpr::print(Printer& p) const {
Expand Down Expand Up @@ -647,7 +655,15 @@ void SizedArrayType::print(Printer& p) const {
p << log::keyword_style("simd");
p << '[';
elem->print(p);
p << " * " << size << ']';
p << " * ";
std::visit([&] (auto&& arg) {
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, size_t>)
p << arg;
else if constexpr (std::is_same_v<T, ast::Path&>)
arg->print(p);
}, size);
p << ']';
}

void UnsizedArrayType::print(Printer& p) const {
Expand Down
Loading