Skip to content

Commit

Permalink
Merge pull request #46 from Princess-org/lambdas
Browse files Browse the repository at this point in the history
Add Lambdas
  • Loading branch information
Victorious3 authored Mar 15, 2024
2 parents cefaf41 + 056caa1 commit 650b265
Show file tree
Hide file tree
Showing 14 changed files with 854 additions and 135 deletions.
2 changes: 2 additions & 0 deletions src/codegen.pr
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ def type_to_str(tpe: &typechecking::Type) -> Str {
ret = "%\"" + tpe.type_name + '"'
case typechecking::TypeKind::INTERFACE_IMPL
ret = type_to_str(tpe.tpe)
case typechecking::TypeKind::TO_INFER
ret = "<?>"
case
error(debug::type_to_str(tpe), " ", typechecking::is_polymorph(tpe), "\n")
error(tpe.kind, "\n")
Expand Down
82 changes: 66 additions & 16 deletions src/compiler.pr
Original file line number Diff line number Diff line change
Expand Up @@ -4743,6 +4743,8 @@ export def walk_expression(node: &parser::Node, state: &State) -> Value {
case parser::NodeKind::IF_EXPR
expr = walk_IfExpr(node, state)
case parser::NodeKind::RANGE, parser::NodeKind::RANGE_INC
case parser::NodeKind::LAMBDA
expr = walk_Lambda(node, state)
case;
error(node.kind, "\n")
assert(false)
Expand Down Expand Up @@ -5778,6 +5780,31 @@ export def walk_VarDecl(node: &parser::Node, state: &State, set_constant: bool =
}
}

def walk_Lambda(node: &parser::Node, state: &State) -> Value {
import_cstd_function("malloc", state)
let function = node.value.lambda.function
if not function { return NO_VALUE }

let loc = make_location(node, state)
let context_tpe = function.state

let ret = state.alloca(node.value.lambda.closure_type, loc)

typechecking::create_type_entry(typechecking::reference(context_tpe))
let context_ptr = create_closure_context(function, ret, loc, state, insert_temporary = true)

predeclare_function(function)
create_function(node, function.tpe, node.value.lambda.body, node.inner_scope, null, state, is_closure = true, params = node.value.lambda.parameters)
state.module.imported.add(node.tpe.type_name)

let context = create_closure_context_captures(function, loc, state)
state.store(context_ptr, context, loc)

push_local_var(context_tpe.type_name, reference(context_tpe), state.current_function)

return state.load(node.value.lambda.closure_type, ret, loc)
}

def walk_Def(node: &parser::Node, state: &State) {
import_cstd_function("malloc", state)
let function = node.value.def_.function
Expand Down Expand Up @@ -5809,6 +5836,21 @@ def walk_Def(node: &parser::Node, state: &State) {

push_declare(node, ret, value.name, state)

let context_ptr = create_closure_context(function, ret, loc, state)

predeclare_function(function)
create_function(node, node.tpe, node.value.def_.body, node.inner_scope, null, state, is_closure = true, params = node.value.def_.params)
state.module.imported.add(node.tpe.type_name)

let context = create_closure_context_captures(function, loc, state)
state.store(context_ptr, context, loc)

push_local_var(context_tpe.type_name, reference(context_tpe), state.current_function)
}

def create_closure_context(function: &Function, ret: Value, loc: &Value, state: &State, insert_temporary: bool = false) -> Value {
let context_tpe = function.state

let context_ptr_i8 = state.call("malloc", pointer(builtins::int8_), [[ kind = ValueKind::INT, tpe = builtins::size_t_, i = context_tpe.size ] !Value], loc)
let ref_count_i8 = state.call("malloc", pointer(builtins::int8_), [[ kind = ValueKind::INT, tpe = builtins::size_t_, i = builtins::int64_.size ] !Value], loc)
let ref_count = state.bitcast(pointer(builtins::int64_), ref_count_i8, loc)
Expand All @@ -5821,16 +5863,23 @@ def walk_Def(node: &parser::Node, state: &State) {
closure = state.insert_value(typechecking::reference(null), closure, context_ptr_i8, [1], loc)
closure = state.insert_value(typechecking::reference(null), closure, context_tpe_value, [2], loc)

let context_fun_ptr = state.gep(pointer(pointer(node.tpe)), value.tpe, ret, [make_int_value(0), make_int_value(0)], loc)
state.store(context_fun_ptr, [ kind = ValueKind::GLOBAL, tpe = pointer(node.tpe), name = node.tpe.type_name ] !Value, loc)
let context_ref_ptr = state.gep(pointer(typechecking::reference(null)), value.tpe, ret, [make_int_value(0), make_int_value(1)], loc)
let context_fun_ptr = state.gep(pointer(pointer(function.tpe)), ret.tpe.tpe, ret, [make_int_value(0), make_int_value(0)], loc)
state.store(context_fun_ptr, [ kind = ValueKind::GLOBAL, tpe = pointer(function.tpe), name = function.tpe.type_name ] !Value, loc)
let context_ref_ptr = state.gep(pointer(typechecking::reference(null)), ret.tpe.tpe, ret, [make_int_value(0), make_int_value(1)], loc)
state.store(context_ref_ptr, closure, loc)

predeclare_function(function)
create_function(node, node.tpe, node.value.def_.body, node.inner_scope, null, state, is_closure = true)
state.module.imported.add(node.tpe.type_name)
if insert_temporary {
let ctx = state.alloca(typechecking::reference(null), loc)
state.store(ctx, closure, loc)
create_temporary(ctx, closure, loc, state)
}

let context_ptr = state.bitcast(pointer(context_tpe), context_ptr_i8, loc)
return context_ptr
}

def create_closure_context_captures(function: &Function, loc: &Value, state: &State) -> Value {
let context_tpe = function.state
var context = [ kind = ValueKind::UNDEF, tpe = context_tpe ] !Value

for var i in 0..function.captures.length {
Expand All @@ -5854,9 +5903,7 @@ def walk_Def(node: &parser::Node, state: &State) {
}
context = state.insert_value(context_tpe, context, value, [i], loc)
}
state.store(context_ptr, context, loc)

push_local_var(context_tpe.type_name, reference(context_tpe), state.current_function)
return context
}

def walk_Defer(node: &parser::Node, state: &State) {
Expand Down Expand Up @@ -7180,7 +7227,8 @@ export def create_function(
block: &Block,
state: &State,
no_cleanup: bool = false,
is_closure: bool = false
is_closure: bool = false,
params: &Vector(&Node) = null
) {
if not tpe { return }
let function = state.module.result.functions.get_or_default(tpe.type_name, null)
Expand All @@ -7203,9 +7251,9 @@ export def create_function(
state.inline_start_block = block

function.locals = map::make(Str, type &typechecking::Type)
if node.value.def_.params {
for var i in 0..vector::length(node.value.def_.params) {
let param = node.value.def_.params(i).value.param.name
if params {
for var i in 0..vector::length(params) {
let param = params(i).value.param.name
if not param or not param.svalue { continue }
state.add_local(function, param.svalue, param.svalue.tpe)
}
Expand Down Expand Up @@ -7264,10 +7312,10 @@ export def create_function(
add_type_meta(tpe.parameter_t(i).tpe, state)
}

if node and node.value.def_.params {
if params {
errors::current_signature = node.signature_hash
for var i in 0..vector::length(node.value.def_.params) {
let param = node.value.def_.params(i).value.param.name
for var i in 0..vector::length(params) {
let param = params(i).value.param.name
if not param { continue }
scope::create_dependency(state.current_value(), param.svalue)
}
Expand Down Expand Up @@ -9090,6 +9138,8 @@ def do_create_type(tpe: &typechecking::Type, svalue: &scope::Value, module: &too
value.values(8) = [ kind = ValueKind::STRUCT, tpe = array(pointer(builtins::Type_)), values = array_values ] !Value

push_variants(tpe, global, module, state, cache)
case typechecking::TypeKind::TO_INFER
// This should only happen in an error case
case
error(tpe.kind, "\n")
assert(false)
Expand Down
27 changes: 26 additions & 1 deletion src/consteval.pr
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,28 @@ def unwrap_type_def(tpe: &typechecking::Type) -> &typechecking::Type {
return res
}

def walk_Lambda(node: &parser::Node, state: &typechecking::State) {
let body = node.value.lambda.body
node.inner_scope = scope::enter_function_scope(state.module.scope)

// Add return for single value
if body.length == 1 {
let expr = body(0)
if not expr { return }

switch expr.kind {
case parser::NodeKind::INTEGER..=parser::NodeKind::IDENTIFIER, parser::NodeKind::DEFINED,
parser::NodeKind::RANGE..=parser::NodeKind::SHR_EQ,
parser::NodeKind::ASSIGN, parser::NodeKind::FUNC_CALL, parser::NodeKind::LAMBDA

let n = [ kind = parser::NodeKind::RETURN] !&Node
n.value.body = vector::make(type &Node)
n.value.body.push(body(0))
body(0) = n
}
}
}

export def walk_Def(node: &parser::Node, state: &typechecking::State) {
let share = node.value.def_.share
let name = node.value.def_.name
Expand Down Expand Up @@ -793,7 +815,7 @@ export def compile_function(value: &scope::Value, context: &scope::Scope, argume
// This is a big ugly but what can we do
let debug = toolchain::debug_sym
toolchain::debug_sym = false
compiler::create_function(node, node.tpe, node.value.def_.body, node.inner_scope, null, compiler_state)
compiler::create_function(node, node.tpe, node.value.def_.body, node.inner_scope, null, compiler_state, params = node.value.def_.params)
toolchain::debug_sym = debug

if function.defer_functions {
Expand Down Expand Up @@ -1096,6 +1118,8 @@ def do_walk(node: &parser::Node, state: &typechecking::State) {
walk_Assert(node, state)
case parser::NodeKind::FROM
walk_From(node, state)
case parser::NodeKind::LAMBDA
walk_Lambda(node, state)
}
}

Expand Down Expand Up @@ -1154,6 +1178,7 @@ export def consteval(state: &typechecking::State) {
compiler::predeclare_functions(state.module)

let function = [
unmangled = "__main",
is_global = true,
locals = map::make(type &typechecking::Type)
] !&compiler::Function
Expand Down
7 changes: 7 additions & 0 deletions src/debug.pr
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,11 @@ export def node_to_json(node: &parser::Node, types: bool = false) -> &Json {
case parser::NodeKind::STAR
res = json::make_object()
res("kind") = "Star"
case parser::NodeKind::LAMBDA
res = json::make_object()
res("kind") = "Lambda"
res("parameters") = node_vec_to_json(node.value.lambda.parameters, types)
res("body") = node_vec_to_json(node.value.lambda.body, types)
case
error(node.kind, "\n")
assert
Expand Down Expand Up @@ -876,6 +881,8 @@ export def type_to_str(tpe: &typechecking::Type, full_name: bool = false) -> Str
return variant_t_to_string(tpe, full_name)
case typechecking::TypeKind::INTERFACE_IMPL
return type_to_str(tpe.tpe, full_name) + "/" + type_to_str(tpe.intf, full_name) + "#" + tpe.module.module
case typechecking::TypeKind::TO_INFER
return "<?>"
case
error(tpe.kind, "\n")
assert
Expand Down
45 changes: 34 additions & 11 deletions src/lexer.pr
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import util
import json
import vector

// Reserve 0 for empty token
export type TokenType = enum {
Expand Down Expand Up @@ -111,6 +112,12 @@ export type TokenType = enum {
EOF
}

export type Brace = enum {
PAREN
SQUARE
BRACE
}

export type TokenValue = struct #union {
str: StringSlice
ch: char
Expand Down Expand Up @@ -779,10 +786,6 @@ def parse_symbol(s: Str, i: *int, line: *int, column: *int) -> Token {
switch first {
case '>'; tt = TokenType::OP_GT
case '<'; tt = TokenType::OP_LT
case '{'; tt = TokenType::O_BRACE
case '}'; tt = TokenType::C_BRACE
case '['; tt = TokenType::O_SQUARE
case ']'; tt = TokenType::C_SQUARE
case '+'; tt = TokenType::OP_ADD
case '-'; tt = TokenType::OP_SUB
case '*'; tt = TokenType::OP_MUL
Expand Down Expand Up @@ -819,12 +822,12 @@ def is_whitespace(c: char) -> bool {
return c == ' ' or c == '\t' or c == '\r'
}

def parse_whitespace(depth: int, s: Str, i: *int, line: *int, column: *int) -> Token {
def parse_whitespace(brace_stack: &Vector(Brace), s: Str, i: *int, line: *int, column: *int) -> Token {
let start_line = @line
let start_column = @column

var c = peek_char(s, i, 0)
while is_whitespace(c) or (c == '\n' and depth > 0) {
while is_whitespace(c) or (c == '\n' and brace_stack.length > 0 and brace_stack.peek() == Brace::PAREN) {
var is_newline = c == '\n'
c = next_char(s, i, line, column)
if is_newline {
Expand All @@ -839,7 +842,7 @@ export def lex(s: Str, line: int = 0, column: int = 0, end_line: int = MAX_INT32

var token_list = zero_allocate(TokenList)
var head = token_list
var depth = 0
var brace_stack = vector::make(Brace)

var i = 0
var start_column = 0
Expand All @@ -863,24 +866,44 @@ export def lex(s: Str, line: int = 0, column: int = 0, end_line: int = MAX_INT32
let c = peek_char(s, *i, 0)

var token: Token
if is_whitespace(c) or c == '\n' and depth > 0 {
if is_whitespace(c) or c == '\n' and brace_stack.length > 0 and brace_stack.peek() == Brace::PAREN {
// TODO Make this work inside {}
token = parse_whitespace(depth, s, *i, *line, *column)
token = parse_whitespace(brace_stack, s, *i, *line, *column)
} else if c == '\n' {
token = simple_token(TokenType::NEW_LINE, line, column, line, column + 1)
column = 0
line += 1
i += 1
} else if c == '(' {
depth += 1
brace_stack.push(Brace::PAREN)
token = simple_token(TokenType::O_PAREN, line, column, line, column + 1)
i += 1
column += 1
} else if c == '[' {
brace_stack.push(Brace::SQUARE)
token = simple_token(TokenType::O_SQUARE, line, column, line, column + 1)
i += 1
column += 1
} else if c == '{' {
brace_stack.push(Brace::BRACE)
token = simple_token(TokenType::O_BRACE, line, column, line, column + 1)
i += 1
column += 1
} else if c == ')' {
depth -= 1
if brace_stack.length > 0 { brace_stack.pop() }
token = simple_token(TokenType::C_PAREN, line, column, line, column + 1)
i += 1
column += 1
} else if c == ']' {
if brace_stack.length > 0 { brace_stack.pop() }
token = simple_token(TokenType::C_SQUARE, line, column, line, column + 1)
i += 1
column += 1
} else if c == '}' {
if brace_stack.length > 0 { brace_stack.pop() }
token = simple_token(TokenType::C_BRACE, line, column, line, column + 1)
i += 1
column += 1
} else if c == '"' {
var triple_quoted = false
if peek_char(s, *i, 1) == '"' and peek_char(s, *i, 2) == '"' {
Expand Down
Loading

0 comments on commit 650b265

Please sign in to comment.