From a55ed3aff1347f5290f2b05fea0a1be1472fccff Mon Sep 17 00:00:00 2001
From: Vic Nightfall <vic@nightfall.moe>
Date: Fri, 12 Jan 2024 13:26:56 +0100
Subject: [PATCH] Tuples

---
 src/codegen.pr      |   2 +-
 src/debug.pr        |   4 +-
 src/parser.pr       |  65 +++++++++++++++++++++----
 src/serialize.pr    |  12 +++++
 src/typechecking.pr | 116 +++++++++++++++++++++++++++++---------------
 5 files changed, 148 insertions(+), 51 deletions(-)

diff --git a/src/codegen.pr b/src/codegen.pr
index 4413a2c7..acdf3976 100644
--- a/src/codegen.pr
+++ b/src/codegen.pr
@@ -58,7 +58,7 @@ def type_to_str(tpe: &typechecking::Type) -> Str {
         case typechecking::TypeKind::STATIC_ARRAY:
             assert tpe.length !uint64 < std::MAX_UINT64
             ret = "[" + tpe.length + " x " + type_to_str(tpe.tpe) + ']'
-        case typechecking::TypeKind::STRUCT, typechecking::TypeKind::CLOSURE:
+        case typechecking::TypeKind::STRUCT, typechecking::TypeKind::CLOSURE, typechecking::TypeKind::TUPLE:
             if tpe.type_name {
                 ret = "%\"" + tpe.type_name + '"'
             } else {
diff --git a/src/debug.pr b/src/debug.pr
index 97e29a1a..684db1b7 100644
--- a/src/debug.pr
+++ b/src/debug.pr
@@ -734,7 +734,7 @@ def static_array_t_to_string(tpe: &typechecking::Type, full_name: bool) -> Str {
 }
 
 def tuple_t_to_string(tpe: &typechecking::Type, full_name: bool) -> Str {
-    var ret: StringBuffer = "("
+    var ret: StringBuffer = "["
     let len = vector::length(tpe.return_t)
     for var i in 0..len {
         let rtpe = tpe.return_t[i]
@@ -743,7 +743,7 @@ def tuple_t_to_string(tpe: &typechecking::Type, full_name: bool) -> Str {
             ret += ", "
         }
     }
-    ret += ')'
+    ret += ']'
     return ret
 }
 
diff --git a/src/parser.pr b/src/parser.pr
index 82cb2d2d..67a9f7ac 100644
--- a/src/parser.pr
+++ b/src/parser.pr
@@ -120,6 +120,7 @@ export type NodeKind = enum {
     STRUCTURAL_T_MEMBER
     TYPE_CONSTRUCTOR
     VARIANT_T
+    TUPLE_T // Uses variant_t
 }
 
 export type ShareMarker = enum {
@@ -569,7 +570,7 @@ export def destruct(node: *Node) {
             __destruct__(*node.value.t_arrs)
         case NodeKind::STRUCTURAL_T_MEMBER:
             __destruct__(*node.value.structural_member)
-        case NodeKind::VARIANT_T:
+        case NodeKind::VARIANT_T, NodeKind::TUPLE_T:
             __destruct__(*node.value.t_variant)
         case NodeKind::ERROR, NodeKind::DEFINED, NodeKind::SIZE_OF, NodeKind::ALIGN_OF, NodeKind::TYPE_OF_T,
             NodeKind::UADD..=NodeKind::NOT, NodeKind::ID_ASSIGN, NodeKind::UNSIGNED_T, NodeKind::TYPE_T, NodeKind::YIELD_FROM:
@@ -682,7 +683,7 @@ export def construct(copy: *Node, node: *Node) {
             copy.value.t_parr = node.value.t_parr
         case NodeKind::ARRAY_STATIC_T:
             copy.value.t_arrs = node.value.t_arrs
-        case NodeKind::VARIANT_T:
+        case NodeKind::VARIANT_T, NodeKind::TUPLE_T:
             copy.value.t_variant = node.value.t_variant
         case NodeKind::STRUCTURAL_T_MEMBER:
             copy.value.structural_member = node.value.structural_member
@@ -825,7 +826,7 @@ export def offset(node: &Node, changes: &[server::TextDocumentChangeEvent]) {
             offset(node.value.t_parr.tpe, changes)
         case NodeKind::ARRAY_STATIC_T:
             offset(node.value.t_arrs.tpe, changes)
-        case NodeKind::VARIANT_T:
+        case NodeKind::VARIANT_T, NodeKind::TUPLE_T:
             offset(node.value.t_variant.variants, changes)
         case NodeKind::STRUCTURAL_T_MEMBER:
             offset(node.value.structural_member.name, changes)
@@ -956,7 +957,7 @@ export def clear(node: &Node) {
             clear(node.value.t_parr.tpe)
         case NodeKind::ARRAY_STATIC_T:
             clear(node.value.t_arrs.tpe)
-        case NodeKind::VARIANT_T:
+        case NodeKind::VARIANT_T, NodeKind::TUPLE_T:
             clear(node.value.t_variant.variants)
         case NodeKind::STRUCTURAL_T_MEMBER:
             clear(node.value.structural_member.name)
@@ -1158,7 +1159,7 @@ export def find(node: &Node, line: int, column: int) -> &Node {
                 if n2 { return n2 }
                 n2 = find(node.value.t_arrs.tpe, line, column)
                 if n2 { return n2 }
-            case NodeKind::VARIANT_T:
+            case NodeKind::VARIANT_T, NodeKind::TUPLE_T:
                 var n2 = find(node.value.t_variant.variants, line, column)
                 if n2 { return n2 }
             case NodeKind::STRUCTURAL_T_MEMBER:
@@ -1364,7 +1365,7 @@ export def deep_copy_node(node: &Node, clear_svalue: bool = true) -> &Node {
             copy.value.t_parr.tpe = deep_copy_node(node.value.t_parr.tpe, clear_svalue)
         case NodeKind::ARRAY_STATIC_T:
             copy.value.t_arrs.tpe = deep_copy_node(node.value.t_arrs.tpe, clear_svalue)
-        case NodeKind::VARIANT_T:
+        case NodeKind::VARIANT_T, NodeKind::TUPLE_T:
             copy.value.t_variant.variants = deep_copy_vector_of_nodes(node.value.t_variant.variants, clear_svalue)
         case NodeKind::STRUCTURAL_T_MEMBER:
             copy.value.structural_member.name = deep_copy_node(node.value.structural_member.name, clear_svalue)
@@ -1718,7 +1719,7 @@ def parse_array_n(parse_state: &ParseState) -> &Node {
     return node
 }
 
-def expect_array(parse_state: &ParseState) -> &Node {
+def expect_array_or_tuple(parse_state: &ParseState) -> &Node {
     var tok = peek(parse_state)
     let line = tok.line
     let column = tok.column
@@ -1733,18 +1734,62 @@ def expect_array(parse_state: &ParseState) -> &Node {
 
     tok = expect(parse_state, lexer::TokenType::O_SQUARE, "Expected '['")
 
+    tok = peek(parse_state)
+    if tok.tpe == lexer::TokenType::C_SQUARE {
+        pop(parse_state)
+
+        node = make_node(NodeKind::TUPLE_T, line, column, parse_state)
+        node.value.t_variant = {
+            variants = vector::make(type &Node)
+        } !NodeVariantT
+        node._hash = node.kind !uint64
+
+        return node
+    }
+
     // [let T], [var T] and [T]
 
+    var may_be_tuple = true
+
     var kw = VarDecl::VAR
     tok = peek(parse_state)
     if tok.tpe == lexer::TokenType::K_VAR {
         pop(parse_state)
+        may_be_tuple = false
     } else if tok.tpe == lexer::TokenType::K_LET {
         pop(parse_state)
         kw = VarDecl::LET
+        may_be_tuple = false
     }
 
-    let tpe = expect_type(parse_state)
+    var tpe = expect_type(parse_state)
+    tok = peek(parse_state)
+
+    if may_be_tuple {
+        skip_newline(parse_state)
+        if tok.tpe == lexer::TokenType::COMMA {
+            let variants = vector::make(type &Node)
+
+            while tok.tpe == lexer::TokenType::COMMA {
+                variants.push(tpe)
+                pop(parse_state)
+
+                tpe = expect_type(parse_state)
+                skip_newline(parse_state)
+                tok = peek(parse_state)
+            }
+
+            tok = expect(parse_state, lexer::TokenType::C_SQUARE, "Expected ']'")
+            
+            node = make_node(NodeKind::TUPLE_T, line, column, parse_state)
+            node.value.t_variant = {
+                variants = variants
+            } !NodeVariantT
+            node._hash = combine_hashes(node.kind !uint64, hash(variants))
+            
+            return node
+        }
+    }
 
     tok = expect(parse_state, lexer::TokenType::C_SQUARE, "Expected ']'")
 
@@ -2288,7 +2333,7 @@ def parse_type2(parse_state: &ParseState, inline_types: bool) -> &Node {
         return node
     } else if tok.tpe == lexer::TokenType::O_SQUARE {
         back(parse_state)
-        return expect_array(parse_state)
+        return expect_array_or_tuple(parse_state)
     } else if tok.tpe == lexer::TokenType::OP_MUL or
         tok.tpe == lexer::TokenType::OP_BAND {
         back(parse_state)
@@ -2411,7 +2456,7 @@ def expect_array_lit(parse_state: &ParseState) -> &Node {
     // [var N]
     if token.tpe == lexer::TokenType::K_VAR or token.tpe == lexer::TokenType::K_LET {
         parse_state.tokens = tokens
-        return expect_array(parse_state)
+        return expect_array_or_tuple(parse_state)
     }
 
     if token.tpe != lexer::TokenType::C_SQUARE {
diff --git a/src/serialize.pr b/src/serialize.pr
index 15451af2..1e958bc5 100644
--- a/src/serialize.pr
+++ b/src/serialize.pr
@@ -280,6 +280,11 @@ def serialize_type(fp: File, tpe: &typechecking::Type, state: &Serialize) {
         } else if tpe.kind == typechecking::TypeKind::INTERFACE_IMPL {
             write_type(fp, tpe.tpe, state)
             write_type(fp, tpe.intf, state)
+        } else if tpe.kind == typechecking::TypeKind::TUPLE {
+            fp.write(*tpe.return_t.length)
+            for var e in tpe.return_t {
+                write_type(fp, e, state)
+            }
         } else {
             error(tpe.kind, "\n")
             assert
@@ -664,6 +669,13 @@ def deserialize_type(deserialize: &Deserialize, fp: File, tpe: &typechecking::Ty
         case typechecking::TypeKind::INTERFACE_IMPL:
             tpe._tpe = deserialize_type(deserialize, fp)
             tpe.intf = deserialize_type(deserialize, fp)
+        case typechecking::TypeKind::TUPLE:
+            tpe.return_t = vector::make(type &typechecking::Type)
+            var size: uint64
+            fp.read(*size)
+            for var i in 0..size {
+                tpe.return_t.push(deserialize_type(deserialize, fp))
+            }
         case:
             error(tpe.kind, "\n")
             assert
diff --git a/src/typechecking.pr b/src/typechecking.pr
index f8498cee..a2f15773 100644
--- a/src/typechecking.pr
+++ b/src/typechecking.pr
@@ -697,6 +697,14 @@ export def append_module(name: Str, module: Str) -> Str {
     return res
 }
 
+export def make_type_type(tpe: &Type) -> &Type {
+    return {
+        line = -1,
+        kind = TypeKind::TYPE,
+        _tpe = tpe
+    } !Type
+}
+
 export def make_type_raw(kind: TypeKind) -> &Type {
     return {
         line = -1,
@@ -1697,8 +1705,13 @@ export def is_polymorph(tpe: &Type, is_ref: bool = false) -> bool {
         return is_polymorph(tpe.tpe)
     } else if tpe.kind == TypeKind::REFERENCE {
         return is_polymorph(tpe.tpe, true)
-    } else if tpe.kind == TypeKind::FUNCTION or tpe.kind == TypeKind::TUPLE {
+    } else if tpe.kind == TypeKind::FUNCTION {
         return is_polymorph(tpe.parameter_t)
+    } else if tpe.kind == TypeKind::TUPLE {
+        for var tpe in tpe.return_t {
+            if is_polymorph(tpe) { return true }
+        }
+        return false
     }
     return false
 }
@@ -1948,7 +1961,14 @@ export def overload_score(
             // TODO This is not documented and weird in more than one place
             // I think the confusion comes from the fact that sometimes the type is stored in the NamedParameter
             // and sometimes in the type directly. The latter would make more sense
-            if equals(right.tpe, pointer(builtins::Type_)) {
+            if right.tpe.may_be_type {
+                // This is what it should look like for the others as well
+                if not left.tpe.tpe {
+                    score = 4
+                } else if right.tpe.may_be_type.tpe.tpe and equals(left.tpe.tpe, right.tpe.may_be_type.tpe.tpe) {
+                    score = 0
+                }
+            } else if equals(right.tpe, pointer(builtins::Type_)) {
                 if left.value {
                     if equals(left.value.value_tpe, right.tpe.tpe.tpe) {
                         score = 0
@@ -1968,13 +1988,6 @@ export def overload_score(
                 } else if left.tpe.tpe and right.tpe.tpe and equals(left.tpe.tpe, right.tpe.tpe) {
                     score = 0
                 }
-            } else if right.tpe.may_be_type {
-                // This is what it should look like for the others as well
-                if not left.tpe.tpe {
-                    score = 4
-                } else if right.tpe.may_be_type.tpe and equals(left.tpe.tpe, right.tpe.may_be_type.tpe) {
-                    score = 0
-                }
             }
         } else {
             score = convert_type_score(lvalue, right.tpe, module, impl = impl)
@@ -2307,31 +2320,36 @@ export def type_lookup(node: &parser::Node, state: &State, current_type: &Type =
     return tpe
 }
 
-def convert_ambiguous_expr_to_type(node: &parser::Node, state: &State) -> &Type {
+def convert_ambiguous_expr_to_type(node: &parser::Node, state: &State) -> &Node {
     if node.kind == parser::NodeKind::ARRAY_LIT {
-        var res = { kind = parser::NodeKind::ARRAY_T } !&parser::Node
-        res.value.t_parr.tpe = node.value.body[0]
-        res.value.t_parr.tpe.tpe = convert_ambiguous_expr_to_type(res.value.t_parr.tpe, state)
-        res.tpe = type_lookup(res, state)
-        let arr_tpe = make_type_raw(TypeKind::TYPE)
-        arr_tpe._tpe = res.tpe
-        res.tpe = arr_tpe
-        res.tpe.node = null
+        if node.value.body.length == 1 {
+            var res = { kind = parser::NodeKind::ARRAY_T } !&parser::Node
+            res.value.t_parr.tpe = convert_ambiguous_expr_to_type(node.value.body[0], state)
+            walk(null, res, state)
+            res.tpe.node = null // This may be needed to avoid a dangling weak reference
+            return res
+        } else {
+            var res = { kind = parser::NodeKind::TUPLE_T } !&parser::Node
+            var types = vector::make(type &parser::Node)
+            for var node in node.value.body {
+                types.push(convert_ambiguous_expr_to_type(node, state))
+            }
+            res.value.body = types
 
-        return res.tpe
+            walk(null, res, state)
+            res.tpe.node = null
+            return res
+        }
     } else if node.kind == parser::NodeKind::PTR {
         var res = { kind = parser::NodeKind::PTR_T } !&parser::Node
-        res.value.t_parr.tpe = node.value.expr
-        res.value.t_parr.tpe.tpe = convert_ambiguous_expr_to_type(res.value.t_parr.tpe, state)
-        let ptr_tpe = make_type_raw(TypeKind::TYPE)
-        res.tpe = type_lookup(res, state)
-        ptr_tpe._tpe = res.tpe
-        res.tpe = ptr_tpe
-        res.tpe.node = null
+        res.value.t_parr.tpe = convert_ambiguous_expr_to_type(node.value.expr, state)
 
-        return res.tpe
+        walk(null, res, state)
+        res.tpe.node = null
+        return res
     } else if node.kind == parser::NodeKind::IDENTIFIER {
-        return type_lookup(node, state)
+        type_lookup(node, state)
+        return node
     }
 }
 
@@ -2903,6 +2921,29 @@ export def do_type_lookup(node: &parser::Node, state: &State, current_type: &Typ
         tpe.variants = set::make(variants)
         tpe.line = node.loc.line
 
+        return tpe
+    } else if node.kind == parser::NodeKind::TUPLE_T {
+        let tpe = make_type_raw(TypeKind::TUPLE)
+
+        let fields = vector::make(StructMember)
+        let return_t = vector::make(type &Type)
+        for var i in 0..node.value.body.length {
+            var n = node.value.body[i]
+            var cur: &Type = null
+            if current_type { cur = current_type.field_types[i].tpe }
+
+            let e = type_lookup(n, state, cur, lookup_default, cache)
+            return_t.push(e)
+            fields.push({ tpe = e } !StructMember)
+        }
+
+
+        let struct_type = make_struct_type(fields.to_array()) // TODO Maybe a have a seperate function to calculate align and size
+        tpe.return_t = return_t
+        tpe.align = struct_type.align
+        tpe.size = struct_type.size
+        tpe.line = node.loc.line
+
         return tpe
     }
 
@@ -2961,7 +3002,7 @@ def walk_ArrayStaticT(node: &parser::Node, state: &State) {
         size = value.i
     }
 
-    let tpe2 = copy(builtins::Type_)
+    let tpe2 = copy(builtins::Type_) // TODO Use TypeKind::Type
     tpe2._tpe = type_lookup(node, state)
     node.tpe = pointer(tpe2)
 }
@@ -2969,7 +3010,7 @@ def walk_ArrayStaticT(node: &parser::Node, state: &State) {
 def walk_TypeOfT(node: &parser::Node, state: &State) {
     let expr = node.value.expr
     walk(node, expr, state)
-    let tpe = copy(builtins::Type_)
+    let tpe = copy(builtins::Type_) // TODO Use TypeKind::Type
     tpe._tpe = expr.tpe
     node.tpe = pointer(tpe)
 }
@@ -3091,7 +3132,7 @@ def walk_Identifier(node: &parser::Node, state: &State) {
     scope::add_reference(value, node)
 
     if value.tpe and value.tpe.kind == TypeKind::TYPE {
-        let tpe = copy(builtins::Type_)
+        let tpe = copy(builtins::Type_) // TODO Use TypeKind::Type
         tpe._tpe = value.value.value_tpe
         node.tpe = pointer(tpe)
     } else {
@@ -3127,7 +3168,7 @@ def walk_Identifier(node: &parser::Node, state: &State) {
 def implicit_conversion(node: &parser::Node, tpe: &Type, state: &State) {
     if not tpe { return }
     // Convert type
-    if tpe.kind == TypeKind::TYPE and node.tpe and node.tpe.may_be_type {
+    if node.tpe and node.tpe.may_be_type {
         node.tpe = node.tpe.may_be_type
     } else if node.kind == parser::NodeKind::NULL and is_pointer(tpe) or
         is_arithmetic(tpe) and (node.kind == parser::NodeKind::INTEGER or
@@ -4914,7 +4955,7 @@ def walk_Ptr(node: &parser::Node, state: &State) {
     node.tpe = pointer(tpe, tpe.kw)
 
     if equals(tpe, pointer(builtins::Type_)) or tpe.may_be_type {
-        node.tpe.may_be_type = convert_ambiguous_expr_to_type(node, state)
+        node.tpe.may_be_type = convert_ambiguous_expr_to_type(node, state).tpe
     }
 }
 
@@ -5263,7 +5304,7 @@ def walk_ArrayLit(node: &parser::Node, state: &State) {
 
     // Check if we have a type array, in this case it might also be a type
     if equals(tpe, pointer(builtins::Type_)) or tpe.may_be_type {
-        ret_tpe.may_be_type = convert_ambiguous_expr_to_type(node, state)
+        ret_tpe.may_be_type = convert_ambiguous_expr_to_type(node, state).tpe
     }
 
     node.tpe = ret_tpe
@@ -5502,8 +5543,8 @@ export def walk(parent: &parser::Node, node: &parser::Node, state: &State) {
         case parser::NodeKind::PTR_T, parser::NodeKind::REF_T, 
             parser::NodeKind::ARRAY_T, parser::NodeKind::WEAK_REF_T,
             parser::NodeKind::TYPE_CONSTRUCTOR, parser::NodeKind::FUNCTION_T,
-            parser::NodeKind::CLOSURE_T:
-            let tpe = copy(builtins::Type_)
+            parser::NodeKind::CLOSURE_T, parser::NodeKind::TUPLE_T:
+            let tpe = copy(builtins::Type_) // TODO use TypeKind::TYPE
             tpe._tpe = type_lookup(node, state)
             node.tpe = pointer(tpe)
         case parser::NodeKind::ARRAY_STATIC_T:
@@ -5648,8 +5689,7 @@ export def walk_Def_with_type_argument(node: &parser::Node, parameter_t: &Vector
         if left and equals(left.tpe, builtins::type_) {
             if np.tpe and np.tpe.may_be_type {
                 np._tpe = np.tpe.may_be_type
-                left.value = { kind = compiler::ValueKind::TYPE, tpe = builtins::type_, value_tpe = np.tpe.tpe } !&compiler::Value
-                left.tpe = np.tpe.may_be_type
+                left.value = { kind = compiler::ValueKind::TYPE, tpe = builtins::type_, value_tpe = np.tpe.tpe.tpe } !&compiler::Value
                 parameter_t[i] = np
             } else {
                 left.value = { kind = compiler::ValueKind::TYPE, tpe = builtins::type_, value_tpe = np.tpe.tpe.tpe } !&compiler::Value