diff --git a/src/codegen.pr b/src/codegen.pr index 6994729..f8ae454 100644 --- a/src/codegen.pr +++ b/src/codegen.pr @@ -796,10 +796,6 @@ export def gen(module: &toolchain::Module) { gen_header(fp, module) if module.module == "main" { - module.imported.add("malloc") - module.imported.add("free") - module.imported.add("strlen") - gen_main_function(fp) } diff --git a/src/compiler.pr b/src/compiler.pr index c4c1879..a561bb9 100644 --- a/src/compiler.pr +++ b/src/compiler.pr @@ -741,6 +741,7 @@ def meta_to_debug_value(meta: &Value) -> DebugValue { } export def make_location(node: &parser::Node, state: &State) -> &Value { + if not node { return null } let discope = vector::peek(state.discope) if state.discope.length > 0 else null !&Value if not toolchain::debug_sym { return null } @@ -9533,19 +9534,26 @@ def generate_vtable_function(function: &Function, tpe: &typechecking::Type, stat if is_const { state.ret(@const_field.value) } else { - var deref = state.extract_value(pointer(type_entry.tpe.tpe), reference, [1]) - var value = state.load(type_entry.tpe.tpe, deref) - var findex: size_t = 0 - var ftpe: &typechecking::Type - for var field in @type_entry.tpe.tpe.fields { - if field.name == name { - findex = field.index - ftpe = field.tpe - } + var stpe = type_entry.tpe.tpe + // Resolve member + var deref = state.extract_value(pointer(stpe), reference, [1]) + var value = state.load(stpe, deref) + value.addr = deref + + let vec = vector::make(Member) + if not resolve_member(vec, stpe, name) { + return + } + + let len = vector::length(vec) + for var i in 0..len { + let j = len - i - 1 + let member = vec(j) + value = walk_MemberAccess_struct(null, stpe, member, value, state) + stpe = member.tpe } - value = state.extract_value(ftpe, value, [findex !int]) - state.ret(value) + state.ret(load_value(value, null, state)) } } } @@ -9555,7 +9563,7 @@ def generate_vtable_function(function: &Function, tpe: &typechecking::Type, stat push_label(end_label, state) swtch.value.switch_.otherwise = end_label - state.module.imported.add("abort") + import_cstd_function("abort", state) state.call("abort", null, [] ![Value]) state.module.imported.add(function.name) @@ -9621,6 +9629,11 @@ export def compile(module: &toolchain::Module) { export def compile(state: &State, is_main: bool, no_cleanup: bool = false) { toolchain::progress_update(state.module, toolchain::ProgressUpdate::START) + // Import required functions + import_cstd_function("malloc", state) + import_cstd_function("free", state) + import_cstd_function("strlen", state) + let node = state.module.node assert(node.kind == parser::NodeKind::PROGRAM)