From d885f90a0a88d21713c9f2b1b3d838a850ac6599 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Tue, 23 Jan 2024 21:41:23 -0300 Subject: [PATCH] big code reorganization: TypeChecker record Several big changes, that were done in tandem, and which would be too troublesome to break into separate commits. The goal here is to ultimately be able to break tl.tl into multiple files (because its size started hitting limits in both Lua 5.1 (number of upvalues) and Lua 5.4 (number of locals). Here's a high-level summary of the changes: * new Errors record, encapsulating error-reporting concerns; * all Type occurrences have unique objects reporting their locations (no more singletons for base types such as BOOLEAN and INVALID); * some enums renamed for more consistency across Gen and Feat options; * TypeCheckOptions and EnvOptions tables reorganized for easier forwarding of options across them; * simplifications in the various function signatures of the public API; * all Types and Nodes store filename, line and column location (`f`, `y`, `x`); * Scope is now a record containing the variables map and unresolved items -- no more "@unresolved" pseudo-variable and `unresolved` pseudo-type for storing this data in the symbols table; * `type_check` now uses a TypeChecker object for storing all state, instead of relying on closures and function nesting (that's a bit sad is it ended up spreading `self:` and extra function arguments everywhere, but I guess state management will be more explicit for others reading the code now...); * all Fact objects have a Where location as well, and supressions of inference data in error messages for widened-back types is marked explicitly with `no_infer` instead of missing a `w` field; * general simplification of the sourcing of error locations (though I would still like to improve that further); --- spec/api/gen_spec.lua | 4 +- spec/api/get_types_spec.lua | 4 +- spec/api/pretty_print_ast.lua | 2 +- spec/call/generic_function_spec.lua | 2 +- spec/cli/types_spec.lua | 20 +- spec/declaration/record_method_spec.lua | 4 +- spec/parser/parser_error_spec.lua | 4 +- spec/parser/parser_spec.lua | 1 + spec/stdlib/require_spec.lua | 2 +- spec/stdlib/xpcall_spec.lua | 2 +- spec/util.lua | 10 +- tl | 10 +- tl.lua | 7604 ++++++++++++----------- tl.tl | 4964 ++++++++------- 14 files changed, 6465 insertions(+), 6168 deletions(-) diff --git a/spec/api/gen_spec.lua b/spec/api/gen_spec.lua index a53538d46..baee93bcc 100644 --- a/spec/api/gen_spec.lua +++ b/spec/api/gen_spec.lua @@ -69,7 +69,7 @@ describe("tl.gen", function() print(math.floor(2)) ]] - local env = tl.init_env(true, true) + local env = tl.init_env(true, false) local output, result = tl.gen(input, env) assert.equal('print(math.floor(2))', output) @@ -83,7 +83,7 @@ describe("tl.gen", function() print(math.floor(2))]] - local env = tl.init_env(true, true) + local env = tl.init_env(true, false) local output, result = tl.gen(input, env) assert.equal(input, output) diff --git a/spec/api/get_types_spec.lua b/spec/api/get_types_spec.lua index 26bbf1d05..b6f55ec83 100644 --- a/spec/api/get_types_spec.lua +++ b/spec/api/get_types_spec.lua @@ -8,7 +8,7 @@ describe("tl.get_types", function() local function a() ::continue:: end - ]], false, env)) + ]], env)) local tr, trenv = tl.get_types(result) assert(tr) @@ -25,7 +25,7 @@ describe("tl.get_types", function() end R.f("hello") - ]], false, env)) + ]], env)) local tr, trenv = tl.get_types(result) local y = 6 diff --git a/spec/api/pretty_print_ast.lua b/spec/api/pretty_print_ast.lua index d87d1ea86..d1d149786 100644 --- a/spec/api/pretty_print_ast.lua +++ b/spec/api/pretty_print_ast.lua @@ -4,7 +4,7 @@ local util = require("spec.util") describe("tl.pretty_print_ast", function() it("returns error for attribute on non 5.4 target", function() local input = [[local x = io.open("foobar", "r")]] - local result = tl.process_string(input, false, tl.init_env(false, "off", "5.4"), "foo.tl") + local result = tl.process_string(input, tl.init_env(false, "off", "5.4"), "foo.tl") local output, err = tl.pretty_print_ast(result.ast, "5.3") assert.is_nil(output) diff --git a/spec/call/generic_function_spec.lua b/spec/call/generic_function_spec.lua index 2fb8cf4d6..ec68bb3ff 100644 --- a/spec/call/generic_function_spec.lua +++ b/spec/call/generic_function_spec.lua @@ -370,7 +370,7 @@ describe("generic function", function() recurse_node(ast, visit_node, visit_type) end ]], { - { x = 40, msg = "argument 3: in map value: type parameter : got number, expected string" } + { y = 16, x = 40, msg = "argument 3: in map value: type parameter : got number, expected string" } })) it("inference trickles down to function arguments, pass", util.check([[ diff --git a/spec/cli/types_spec.lua b/spec/cli/types_spec.lua index a792f8433..c94b51d7a 100644 --- a/spec/cli/types_spec.lua +++ b/spec/cli/types_spec.lua @@ -199,7 +199,6 @@ describe("tl types works like check", function() local by_pos = types.by_pos[next(types.by_pos)] assert(by_pos["1"]) assert(by_pos["1"]["13"]) -- require - assert(by_pos["1"]["20"]) -- ( assert(by_pos["1"]["21"]) -- "os" assert(by_pos["1"]["26"]) -- . end) @@ -217,18 +216,17 @@ describe("tl types works like check", function() assert(types.by_pos) local by_pos = types.by_pos[next(types.by_pos)] assert.same({ - ["19"] = 2, - ["20"] = 5, - ["22"] = 2, - ["39"] = 6, - ["41"] = 2, + ["19"] = 8, + ["22"] = 8, + ["23"] = 6, + ["30"] = 2, + ["41"] = 8, }, by_pos["1"]) assert.same({ - ["17"] = 3, - ["20"] = 4, - ["25"] = 17, - ["30"] = 16, - ["31"] = 2, + ["17"] = 6, + ["20"] = 2, + ["25"] = 9, + ["31"] = 8, }, by_pos["2"]) end) end) diff --git a/spec/declaration/record_method_spec.lua b/spec/declaration/record_method_spec.lua index 20cbde3dc..7f8cf2db6 100644 --- a/spec/declaration/record_method_spec.lua +++ b/spec/declaration/record_method_spec.lua @@ -239,8 +239,8 @@ describe("record method", function() return "hello" end ]], { - { msg = "in assignment: incompatible number of returns: got 0 (), expected 1 (string)" }, - { msg = "excess return values, expected 0 (), got 1 (string \"hello\")" }, + { y = 5, msg = "in assignment: incompatible number of returns: got 0 (), expected 1 (string)" }, + { y = 6, msg = "excess return values, expected 0 (), got 1 (string \"hello\")" }, })) it("allows functions declared on method tables (#27)", function() diff --git a/spec/parser/parser_error_spec.lua b/spec/parser/parser_error_spec.lua index ed50e80c9..cfd2e077c 100644 --- a/spec/parser/parser_error_spec.lua +++ b/spec/parser/parser_error_spec.lua @@ -2,7 +2,7 @@ local tl = require("tl") describe("parser errors", function() it("parse errors include filename", function () - local result = tl.process_string("local x 1", false, nil, "foo.tl") + local result = tl.process_string("local x 1", nil, "foo.tl") assert.same("foo.tl", result.syntax_errors[1].filename, "parse errors should contain .filename property") end) @@ -30,7 +30,7 @@ describe("parser errors", function() local code = [[ local bar = require "bar" ]] - local result = tl.process_string(code, true, nil, "foo.tl") + local result = tl.process_string(code, nil, "foo.tl") assert.is_not_nil(string.match(result.env.loaded["./bar.tl"].syntax_errors[1].filename, "bar.tl$"), "errors should contain .filename property") end) end) diff --git a/spec/parser/parser_spec.lua b/spec/parser/parser_spec.lua index d1e66fb38..870260f90 100644 --- a/spec/parser/parser_spec.lua +++ b/spec/parser/parser_spec.lua @@ -19,6 +19,7 @@ describe("parser", function() assert.same({ kind = "statements", tk = "$EOF$", + f = "", x = 1, y = 1, xend = 5, diff --git a/spec/stdlib/require_spec.lua b/spec/stdlib/require_spec.lua index 999cedfd2..17d0d1ba0 100644 --- a/spec/stdlib/require_spec.lua +++ b/spec/stdlib/require_spec.lua @@ -401,7 +401,7 @@ describe("require", function() local result, err = tl.process("foo.tl") assert.same(0, #result.syntax_errors) - assert.same(0, #result.env.loaded["foo.tl"].type_errors) + assert.same({}, result.env.loaded["foo.tl"].type_errors) assert.same(1, #result.env.loaded["./box.tl"].type_errors) assert.match("cannot use operator ..", result.env.loaded["./box.tl"].type_errors[1].msg) end) diff --git a/spec/stdlib/xpcall_spec.lua b/spec/stdlib/xpcall_spec.lua index 87089f162..16e911a7f 100644 --- a/spec/stdlib/xpcall_spec.lua +++ b/spec/stdlib/xpcall_spec.lua @@ -105,7 +105,7 @@ describe("xpcall", function() { msg = "xyz: got boolean, expected number" } })) - it("type checks the message handler", util.check_type_error([[ + it("#only type checks the message handler", util.check_type_error([[ local function f(a: string, b: number) end diff --git a/spec/util.lua b/spec/util.lua index fb9aeeab3..ccaf59e7f 100644 --- a/spec/util.lua +++ b/spec/util.lua @@ -435,7 +435,7 @@ local function check(lax, code, unknowns, gen_target) if gen_target == "5.4" then gen_compat = "off" end - local result = tl.type_check(ast, { filename = "foo.lua", lax = lax, gen_target = gen_target, gen_compat = gen_compat }) + local result = tl.type_check(ast, "foo.lua", { feat_lax = lax and "on" or "off", gen_target = gen_target, gen_compat = gen_compat }) batch:add(assert.same, {}, result.type_errors) if unknowns then @@ -456,7 +456,7 @@ local function check_type_error(lax, code, type_errors, gen_target) if gen_target == "5.4" then gen_compat = "off" end - local result = tl.type_check(ast, { filename = "foo.tl", lax = lax, gen_target = gen_target, gen_compat = gen_compat }) + local result = tl.type_check(ast, "foo.tl", { feat_lax = lax and "on" or "off", gen_target = gen_target, gen_compat = gen_compat }) local result_type_errors = combine_result(result, "type_errors") batch_compare(batch, "type errors", type_errors, result_type_errors) @@ -525,7 +525,7 @@ function util.check_syntax_error(code, syntax_errors) local batch = batch_assertions() batch_compare(batch, "syntax errors", syntax_errors, errors) batch:assert() - tl.type_check(ast, { filename = "foo.tl", lax = false }) + tl.type_check(ast, "foo.tl", { feat_lax = "off" }) end end @@ -564,7 +564,7 @@ function util.check_types(code, types) local batch = batch_assertions() local env = tl.init_env() env.report_types = true - local result = tl.type_check(ast, { filename = "foo.tl", env = env, lax = false }) + local result = tl.type_check(ast, "foo.tl", { feat_lax = "off" }, env) batch:add(assert.same, {}, result.type_errors, "Code was not expected to have type errors") local tr = env.reporter:get_report() @@ -596,7 +596,7 @@ local function gen(lax, code, expected, gen_target) return function() local ast, syntax_errors = tl.parse(code, "foo.tl") assert.same({}, syntax_errors, "Code was not expected to have syntax errors") - local result = tl.type_check(ast, { filename = "foo.tl", lax = lax, gen_target = gen_target }) + local result = tl.type_check(ast, "foo.tl", { feat_lax = lax and "on" or "off", gen_target = gen_target }) assert.same({}, result.type_errors) local output_code = tl.pretty_print_ast(ast) diff --git a/tl b/tl index 8892d1831..d8516b8f5 100755 --- a/tl +++ b/tl @@ -163,10 +163,12 @@ local function setup_env(tlconfig, filename) end local opts = { - lax_mode = lax_mode, - feat_arity = tlconfig["feat_arity"], - gen_compat = tlconfig["gen_compat"], - gen_target = tlconfig["gen_target"], + defaults = { + feat_lax = lax_mode and "on" or "off", + feat_arity = tlconfig["feat_arity"], + gen_compat = tlconfig["gen_compat"], + gen_target = tlconfig["gen_target"], + }, predefined_modules = tlconfig._init_env_modules, } diff --git a/tl.lua b/tl.lua index 47281d4da..f87e65196 100644 --- a/tl.lua +++ b/tl.lua @@ -1,4 +1,4 @@ -local _tl_compat; if (tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3 then local p, m = pcall(require, 'compat53.module'); if p then _tl_compat = m end end; local assert = _tl_compat and _tl_compat.assert or assert; local debug = _tl_compat and _tl_compat.debug or debug; local io = _tl_compat and _tl_compat.io or io; local ipairs = _tl_compat and _tl_compat.ipairs or ipairs; local load = _tl_compat and _tl_compat.load or load; local math = _tl_compat and _tl_compat.math or math; local _tl_math_maxinteger = math.maxinteger or math.pow(2, 53); local os = _tl_compat and _tl_compat.os or os; local package = _tl_compat and _tl_compat.package or package; local pairs = _tl_compat and _tl_compat.pairs or pairs; local string = _tl_compat and _tl_compat.string or string; local table = _tl_compat and _tl_compat.table or table; local _tl_table_unpack = unpack or table.unpack +local _tl_compat; if (tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3 then local p, m = pcall(require, 'compat53.module'); if p then _tl_compat = m end end; local assert = _tl_compat and _tl_compat.assert or assert; local debug = _tl_compat and _tl_compat.debug or debug; local io = _tl_compat and _tl_compat.io or io; local ipairs = _tl_compat and _tl_compat.ipairs or ipairs; local load = _tl_compat and _tl_compat.load or load; local math = _tl_compat and _tl_compat.math or math; local _tl_math_maxinteger = math.maxinteger or math.pow(2, 53); local os = _tl_compat and _tl_compat.os or os; local package = _tl_compat and _tl_compat.package or package; local pairs = _tl_compat and _tl_compat.pairs or pairs; local string = _tl_compat and _tl_compat.string or string; local table = _tl_compat and _tl_compat.table or table local VERSION = "0.15.3+dev" local stdlib = [=====[ @@ -481,10 +481,16 @@ end -local tl = {PrettyPrintOptions = {}, TypeCheckOptions = {}, Env = {}, Result = {}, Error = {}, TypeInfo = {}, TypeReport = {}, EnvOptions = {}, } +local Errors = {} + + + +local tl = {PrettyPrintOptions = {}, TypeCheckOptions = {}, Env = {}, Result = {}, Error = {}, TypeInfo = {}, TypeReport = {}, EnvOptions = {}, } + + @@ -629,6 +635,7 @@ local TypeReporter = {} + tl.version = function() return VERSION end @@ -699,6 +706,12 @@ tl.typecodes = { +local DEFAULT_GEN_COMPAT = "optional" +local DEFAULT_GEN_TARGET = "5.3" + + + + @@ -1517,7 +1530,6 @@ end - local table_types = { @@ -1552,7 +1564,6 @@ local table_types = { ["any"] = false, ["unknown"] = false, ["invalid"] = false, - ["unresolved"] = false, ["none"] = false, ["*"] = false, } @@ -1577,6 +1588,9 @@ local table_types = { +local function is_numeric_type(t) + return t.typename == "number" or t.typename == "integer" +end @@ -1852,14 +1866,12 @@ local table_types = { -local TruthyFact = {} -local NotFact = {} @@ -1868,7 +1880,6 @@ local NotFact = {} -local AndFact = {} @@ -1878,33 +1889,34 @@ local AndFact = {} -local OrFact = {} +local TruthyFact = {} +local NotFact = {} -local EqFact = {} +local AndFact = {} -local IsFact = {} +local OrFact = {} @@ -1914,22 +1926,17 @@ local IsFact = {} +local EqFact = {} -local attributes = { - ["const"] = true, - ["close"] = true, - ["total"] = true, -} -local is_attribute = attributes -local Node = {ExpectedContext = {}, } +local IsFact = {} @@ -1951,6 +1958,15 @@ local Node = {ExpectedContext = {}, } +local attributes = { + ["const"] = true, + ["close"] = true, + ["total"] = true, +} +local is_attribute = attributes + +local Node = {ExpectedContext = {}, } + @@ -2032,9 +2048,6 @@ local Node = {ExpectedContext = {}, } -local function is_number_type(t) - return t.typename == "number" or t.typename == "integer" -end @@ -2051,95 +2064,34 @@ end -local parse_type_list -local parse_expression -local parse_expression_and_tk -local parse_statements -local parse_argument_list -local parse_argument_type_list -local parse_type -local parse_newtype -local parse_interface_name -local parse_enum_body -local parse_record_body -local parse_type_body_fns -local function fail(ps, i, msg) - if not ps.tokens[i] then - local eof = ps.tokens[#ps.tokens] - table.insert(ps.errs, { filename = ps.filename, y = eof.y, x = eof.x, msg = msg or "unexpected end of file" }) - return #ps.tokens - end - table.insert(ps.errs, { filename = ps.filename, y = ps.tokens[i].y, x = ps.tokens[i].x, msg = assert(msg, "syntax error, but no error message provided") }) - return math.min(#ps.tokens, i + 1) -end -local function end_at(node, tk) - node.yend = tk.y - node.xend = tk.x + #tk.tk - 1 -end -local function verify_tk(ps, i, tk) - if ps.tokens[i].tk == tk then - return i + 1 - end - return fail(ps, i, "syntax error, expected '" .. tk .. "'") -end -local function verify_end(ps, i, istart, node) - if ps.tokens[i].tk == "end" then - local endy, endx = ps.tokens[i].y, ps.tokens[i].x - node.yend = endy - node.xend = endx + 2 - if node.kind ~= "function" and endy ~= node.y and endx ~= node.x then - if not ps.end_alignment_hint then - ps.end_alignment_hint = { filename = ps.filename, y = node.y, x = node.x, msg = "syntax error hint: construct starting here is not aligned with its 'end' at " .. ps.filename .. ":" .. endy .. ":" .. endx .. ":" } - end - end - return i + 1 - end - end_at(node, ps.tokens[i]) - if ps.end_alignment_hint then - table.insert(ps.errs, ps.end_alignment_hint) - ps.end_alignment_hint = nil - end - return fail(ps, i, "syntax error, expected 'end' to close construct started at " .. ps.filename .. ":" .. ps.tokens[istart].y .. ":" .. ps.tokens[istart].x .. ":") -end -local function new_node(tokens, i, kind) - local t = tokens[i] - return { y = t.y, x = t.x, tk = t.tk, kind = kind or (t.kind) } -end -local function a_type(typename, t) + +local function a_type(w, typename, t) t.typeid = new_typeid() + t.f = w.f + t.x = w.x + t.y = w.y t.typename = typename return t end -local function edit_type(t, typename) +local function edit_type(w, t, typename) t.typeid = new_typeid() + t.f = w.f + t.x = w.x + t.y = w.y t.typename = typename return t end -local function new_type(ps, i, typename) - local token = ps.tokens[i] - return a_type(typename, { - filename = ps.filename, - y = token.y, - x = token.x, - - }) -end -local function new_typedecl(ps, i, def) - local t = new_type(ps, i, "typedecl") - t.def = def - return t -end @@ -2151,20 +2103,28 @@ end +local function a_function(w, t) + assert(t.min_arity) + return a_type(w, "function", t) +end +local function a_vararg(w, t) + local typ = a_type(w, "tuple", { tuple = t }) + typ.is_va = true + return typ +end -local function a_function(t) - assert(t.min_arity) - return a_type("function", t) -end +local function a_nominal(n, names) + return a_type(n, "nominal", { names = names }) +end @@ -2174,16 +2134,63 @@ end +local an_operator +local function shallow_copy_new_type(t) + local copy = {} + for k, v in pairs(t) do + copy[k] = v + end + copy.typeid = new_typeid() + return copy +end +local function shallow_copy_table(t) + local copy = {} + for k, v in pairs(t) do + copy[k] = v + end + return copy +end -local function va_args(args) - args.is_va = true - return args +local function clear_redundant_errors(errors) + local redundant = {} + local lastx, lasty = 0, 0 + for i, err in ipairs(errors) do + err.i = i + end + table.sort(errors, function(a, b) + local af = assert(a.filename) + local bf = assert(b.filename) + return af < bf or + (af == bf and (a.y < b.y or + (a.y == b.y and (a.x < b.x or + (a.x == b.x and (a.i < b.i)))))) + end) + for i, err in ipairs(errors) do + err.i = nil + if err.x == lastx and err.y == lasty then + table.insert(redundant, i) + end + lastx, lasty = err.x, err.y + end + for i = #redundant, 1, -1 do + table.remove(errors, redundant[i]) + end end +local simple_types = { + ["nil"] = true, + ["any"] = true, + ["number"] = true, + ["string"] = true, + ["thread"] = true, + ["boolean"] = true, + ["integer"] = true, +} +do @@ -2191,194 +2198,232 @@ end -local function a_fn(f) - local args_t = a_type("tuple", { tuple = {} }) - local tup = args_t.tuple - args_t.is_va = f.args.is_va - local min_arity = f.args.is_va and -1 or 0 - for _, a in ipairs(f.args) do - if a.opttype then - table.insert(tup, a.opttype) - else - table.insert(tup, a) - min_arity = min_arity + 1 - end - end - local rets_t = a_type("tuple", { tuple = {} }) - tup = rets_t.tuple - rets_t.is_va = f.rets.is_va - for _, a in ipairs(f.rets) do - assert(a.typename) - table.insert(tup, a) - end - return a_type("function", { - args = args_t, - rets = rets_t, - min_arity = min_arity, - needs_compat = f.needs_compat, - typeargs = f.typeargs, - }) -end -local function a_vararg(t) - local typ = a_type("tuple", { tuple = t }) - typ.is_va = true - return typ -end + local parse_type_list + local parse_expression + local parse_expression_and_tk + local parse_statements + local parse_argument_list + local parse_argument_type_list + local parse_type + local parse_newtype + local parse_interface_name + local parse_enum_body + local parse_record_body + local parse_type_body_fns -local NIL = a_type("nil", {}) -local ANY = a_type("any", {}) -local TABLE = a_type("map", { keys = ANY, values = ANY }) -local NUMBER = a_type("number", {}) -local STRING = a_type("string", {}) -local THREAD = a_type("thread", {}) -local BOOLEAN = a_type("boolean", {}) -local INTEGER = a_type("integer", {}) + local function fail(ps, i, msg) + if not ps.tokens[i] then + local eof = ps.tokens[#ps.tokens] + table.insert(ps.errs, { filename = ps.filename, y = eof.y, x = eof.x, msg = msg or "unexpected end of file" }) + return #ps.tokens + end + table.insert(ps.errs, { filename = ps.filename, y = ps.tokens[i].y, x = ps.tokens[i].x, msg = assert(msg, "syntax error, but no error message provided") }) + return math.min(#ps.tokens, i + 1) + end -local function shallow_copy_new_type(t) - local copy = {} - for k, v in pairs(t) do - copy[k] = v + local function end_at(node, tk) + node.yend = tk.y + node.xend = tk.x + #tk.tk - 1 end - copy.typeid = new_typeid() - return copy -end -local function shallow_copy_table(t) - local copy = {} - for k, v in pairs(t) do - copy[k] = v + local function verify_tk(ps, i, tk) + if ps.tokens[i].tk == tk then + return i + 1 + end + return fail(ps, i, "syntax error, expected '" .. tk .. "'") + end + + local function verify_end(ps, i, istart, node) + if ps.tokens[i].tk == "end" then + local endy, endx = ps.tokens[i].y, ps.tokens[i].x + node.yend = endy + node.xend = endx + 2 + if node.kind ~= "function" and endy ~= node.y and endx ~= node.x then + if not ps.end_alignment_hint then + ps.end_alignment_hint = { filename = ps.filename, y = node.y, x = node.x, msg = "syntax error hint: construct starting here is not aligned with its 'end' at " .. ps.filename .. ":" .. endy .. ":" .. endx .. ":" } + end + end + return i + 1 + end + end_at(node, ps.tokens[i]) + if ps.end_alignment_hint then + table.insert(ps.errs, ps.end_alignment_hint) + ps.end_alignment_hint = nil + end + return fail(ps, i, "syntax error, expected 'end' to close construct started at " .. ps.filename .. ":" .. ps.tokens[istart].y .. ":" .. ps.tokens[istart].x .. ":") end - return copy -end -local function verify_kind(ps, i, kind, node_kind) - if ps.tokens[i].kind == kind then - return i + 1, new_node(ps.tokens, i, node_kind) + local function new_node(ps, i, kind) + local t = ps.tokens[i] + return { f = ps.filename, y = t.y, x = t.x, tk = t.tk, kind = kind or (t.kind) } end - return fail(ps, i, "syntax error, expected " .. kind) -end + local function new_type(ps, i, typename) + local token = ps.tokens[i] + local t = {} + t.typeid = new_typeid() + t.f = ps.filename + t.x = token.x + t.y = token.y + t.typename = typename + return t + end + local function new_typedecl(ps, i, def) + local t = new_type(ps, i, "typedecl") + t.def = def + return t + end -local function skip(ps, i, skip_fn) - local err_ps = { - filename = ps.filename, - tokens = ps.tokens, - errs = {}, - required_modules = {}, - } - return skip_fn(err_ps, i) -end + local function new_tuple(ps, i, types, is_va) + local t = new_type(ps, i, "tuple") + t.is_va = is_va + t.tuple = types or {} + return t, t.tuple + end -local function failskip(ps, i, msg, skip_fn, starti) - local skip_i = skip(ps, starti or i, skip_fn) - fail(ps, i, msg) - return skip_i -end + local function new_typealias(ps, i, alias_to) + local t = new_type(ps, i, "typealias") + t.alias_to = alias_to + return t + end -local function skip_type_body(ps, i) - local tn = ps.tokens[i].tk - i = i + 1 - assert(parse_type_body_fns[tn], tn .. " has no parse body function") - return parse_type_body_fns[tn](ps, i, {}, { kind = "function" }) -end + local function new_nominal(ps, i, name) + local t = new_type(ps, i, "nominal") + if name then + t.names = { name } + end + return t + end -local function parse_table_value(ps, i) - local next_word = ps.tokens[i].tk - if next_word == "record" or next_word == "interface" then - local skip_i, e = skip(ps, i, skip_type_body) - if e then - fail(ps, i, next_word == "record" and - "syntax error: this syntax is no longer valid; declare nested record inside a record" or - "syntax error: cannot declare interface inside a table; use a statement") - return skip_i, new_node(ps.tokens, i, "error_node") + local function verify_kind(ps, i, kind, node_kind) + if ps.tokens[i].kind == kind then + return i + 1, new_node(ps, i, node_kind) end - elseif next_word == "enum" and ps.tokens[i + 1].kind == "string" then - i = failskip(ps, i, "syntax error: this syntax is no longer valid; declare nested enum inside a record", skip_type_body) - return i, new_node(ps.tokens, i - 1, "error_node") + return fail(ps, i, "syntax error, expected " .. kind) end - local e - i, e = parse_expression(ps, i) - if not e then - e = new_node(ps.tokens, i - 1, "error_node") + + + local function skip(ps, i, skip_fn) + local err_ps = { + filename = ps.filename, + tokens = ps.tokens, + errs = {}, + required_modules = {}, + } + return skip_fn(err_ps, i) end - return i, e -end -local function parse_table_item(ps, i, n) - local node = new_node(ps.tokens, i, "literal_table_item") - if ps.tokens[i].kind == "$EOF$" then - return fail(ps, i, "unexpected eof") + local function failskip(ps, i, msg, skip_fn, starti) + local skip_i = skip(ps, starti or i, skip_fn) + fail(ps, i, msg) + return skip_i end - if ps.tokens[i].tk == "[" then - node.key_parsed = "long" + local function skip_type_body(ps, i) + local tn = ps.tokens[i].tk i = i + 1 - i, node.key = parse_expression_and_tk(ps, i, "]") - i = verify_tk(ps, i, "=") - i, node.value = parse_table_value(ps, i) - return i, node, n - elseif ps.tokens[i].kind == "identifier" then - if ps.tokens[i + 1].tk == "=" then - node.key_parsed = "short" - i, node.key = verify_kind(ps, i, "identifier", "string") - node.key.conststr = node.key.tk - node.key.tk = '"' .. node.key.tk .. '"' + assert(parse_type_body_fns[tn], tn .. " has no parse body function") + return parse_type_body_fns[tn](ps, i, {}, { kind = "function" }) + end + + local function parse_table_value(ps, i) + local next_word = ps.tokens[i].tk + if next_word == "record" or next_word == "interface" then + local skip_i, e = skip(ps, i, skip_type_body) + if e then + fail(ps, i, next_word == "record" and + "syntax error: this syntax is no longer valid; declare nested record inside a record" or + "syntax error: cannot declare interface inside a table; use a statement") + return skip_i, new_node(ps, i, "error_node") + end + elseif next_word == "enum" and ps.tokens[i + 1].kind == "string" then + i = failskip(ps, i, "syntax error: this syntax is no longer valid; declare nested enum inside a record", skip_type_body) + return i, new_node(ps, i - 1, "error_node") + end + + local e + i, e = parse_expression(ps, i) + if not e then + e = new_node(ps, i - 1, "error_node") + end + return i, e + end + + local function parse_table_item(ps, i, n) + local node = new_node(ps, i, "literal_table_item") + if ps.tokens[i].kind == "$EOF$" then + return fail(ps, i, "unexpected eof") + end + + if ps.tokens[i].tk == "[" then + node.key_parsed = "long" + i = i + 1 + i, node.key = parse_expression_and_tk(ps, i, "]") i = verify_tk(ps, i, "=") i, node.value = parse_table_value(ps, i) return i, node, n - elseif ps.tokens[i + 1].tk == ":" then - node.key_parsed = "short" - local orig_i = i - local try_ps = { - filename = ps.filename, - tokens = ps.tokens, - errs = {}, - required_modules = ps.required_modules, - } - i, node.key = verify_kind(try_ps, i, "identifier", "string") - node.key.conststr = node.key.tk - node.key.tk = '"' .. node.key.tk .. '"' - i = verify_tk(try_ps, i, ":") - i, node.itemtype = parse_type(try_ps, i) - if node.itemtype and ps.tokens[i].tk == "=" then - i = verify_tk(try_ps, i, "=") - i, node.value = parse_table_value(try_ps, i) - if node.value then - for _, e in ipairs(try_ps.errs) do - table.insert(ps.errs, e) + elseif ps.tokens[i].kind == "identifier" then + if ps.tokens[i + 1].tk == "=" then + node.key_parsed = "short" + i, node.key = verify_kind(ps, i, "identifier", "string") + node.key.conststr = node.key.tk + node.key.tk = '"' .. node.key.tk .. '"' + i = verify_tk(ps, i, "=") + i, node.value = parse_table_value(ps, i) + return i, node, n + elseif ps.tokens[i + 1].tk == ":" then + node.key_parsed = "short" + local orig_i = i + local try_ps = { + filename = ps.filename, + tokens = ps.tokens, + errs = {}, + required_modules = ps.required_modules, + } + i, node.key = verify_kind(try_ps, i, "identifier", "string") + node.key.conststr = node.key.tk + node.key.tk = '"' .. node.key.tk .. '"' + i = verify_tk(try_ps, i, ":") + i, node.itemtype = parse_type(try_ps, i) + if node.itemtype and ps.tokens[i].tk == "=" then + i = verify_tk(try_ps, i, "=") + i, node.value = parse_table_value(try_ps, i) + if node.value then + for _, e in ipairs(try_ps.errs) do + table.insert(ps.errs, e) + end + return i, node, n end - return i, node, n end - end - node.itemtype = nil - i = orig_i + node.itemtype = nil + i = orig_i + end end - end - node.key = new_node(ps.tokens, i, "integer") - node.key_parsed = "implicit" - node.key.constnum = n - node.key.tk = tostring(n) - i, node.value = parse_expression(ps, i) - if not node.value then - return fail(ps, i, "expected an expression") + node.key = new_node(ps, i, "integer") + node.key_parsed = "implicit" + node.key.constnum = n + node.key.tk = tostring(n) + i, node.value = parse_expression(ps, i) + if not node.value then + return fail(ps, i, "expected an expression") + end + return i, node, n + 1 end - return i, node, n + 1 -end @@ -2387,786 +2432,772 @@ end -local function parse_list(ps, i, list, close, sep, parse_item) - local n = 1 - while ps.tokens[i].kind ~= "$EOF$" do - if close[ps.tokens[i].tk] then - end_at(list, ps.tokens[i]) - break - end - local item - local oldn = n - i, item, n = parse_item(ps, i, n) - n = n or oldn - table.insert(list, item) - if ps.tokens[i].tk == "," then - i = i + 1 - if sep == "sep" and close[ps.tokens[i].tk] then - fail(ps, i, "unexpected '" .. ps.tokens[i].tk .. "'") - return i, list - end - elseif sep == "term" and ps.tokens[i].tk == ";" then - i = i + 1 - elseif not close[ps.tokens[i].tk] then - local options = {} - for k, _ in pairs(close) do - table.insert(options, "'" .. k .. "'") - end - table.sort(options) - local first = options[1]:sub(2, -2) - local msg - - if first == ")" and ps.tokens[i].tk == "=" then - msg = "syntax error, cannot perform an assignment here (did you mean '=='?)" - i = failskip(ps, i, msg, parse_expression, i + 1) - else - table.insert(options, "','") - msg = "syntax error, expected one of: " .. table.concat(options, ", ") - fail(ps, i, msg) + local function parse_list(ps, i, list, close, sep, parse_item) + local n = 1 + while ps.tokens[i].kind ~= "$EOF$" do + if close[ps.tokens[i].tk] then + end_at(list, ps.tokens[i]) + break end + local item + local oldn = n + i, item, n = parse_item(ps, i, n) + n = n or oldn + table.insert(list, item) + if ps.tokens[i].tk == "," then + i = i + 1 + if sep == "sep" and close[ps.tokens[i].tk] then + fail(ps, i, "unexpected '" .. ps.tokens[i].tk .. "'") + return i, list + end + elseif sep == "term" and ps.tokens[i].tk == ";" then + i = i + 1 + elseif not close[ps.tokens[i].tk] then + local options = {} + for k, _ in pairs(close) do + table.insert(options, "'" .. k .. "'") + end + table.sort(options) + local first = options[1]:sub(2, -2) + local msg + + if first == ")" and ps.tokens[i].tk == "=" then + msg = "syntax error, cannot perform an assignment here (did you mean '=='?)" + i = failskip(ps, i, msg, parse_expression, i + 1) + else + table.insert(options, "','") + msg = "syntax error, expected one of: " .. table.concat(options, ", ") + fail(ps, i, msg) + end - if first ~= "}" and ps.tokens[i].y ~= ps.tokens[i - 1].y then + if first ~= "}" and ps.tokens[i].y ~= ps.tokens[i - 1].y then - table.insert(ps.tokens, i, { tk = first, y = ps.tokens[i - 1].y, x = ps.tokens[i - 1].x + 1, kind = "keyword" }) - return i, list + table.insert(ps.tokens, i, { tk = first, y = ps.tokens[i - 1].y, x = ps.tokens[i - 1].x + 1, kind = "keyword" }) + return i, list + end end end + return i, list end - return i, list -end -local function parse_bracket_list(ps, i, list, open, close, sep, parse_item) - i = verify_tk(ps, i, open) - i = parse_list(ps, i, list, { [close] = true }, sep, parse_item) - i = verify_tk(ps, i, close) - return i, list -end + local function parse_bracket_list(ps, i, list, open, close, sep, parse_item) + i = verify_tk(ps, i, open) + i = parse_list(ps, i, list, { [close] = true }, sep, parse_item) + i = verify_tk(ps, i, close) + return i, list + end -local function parse_table_literal(ps, i) - local node = new_node(ps.tokens, i, "literal_table") - return parse_bracket_list(ps, i, node, "{", "}", "term", parse_table_item) -end + local function parse_table_literal(ps, i) + local node = new_node(ps, i, "literal_table") + return parse_bracket_list(ps, i, node, "{", "}", "term", parse_table_item) + end -local function parse_trying_list(ps, i, list, parse_item) - local try_ps = { - filename = ps.filename, - tokens = ps.tokens, - errs = {}, - required_modules = ps.required_modules, - } - local tryi, item = parse_item(try_ps, i) - if not item then + local function parse_trying_list(ps, i, list, parse_item) + local try_ps = { + filename = ps.filename, + tokens = ps.tokens, + errs = {}, + required_modules = ps.required_modules, + } + local tryi, item = parse_item(try_ps, i) + if not item then + return i, list + end + for _, e in ipairs(try_ps.errs) do + table.insert(ps.errs, e) + end + i = tryi + table.insert(list, item) + if ps.tokens[i].tk == "," then + while ps.tokens[i].tk == "," do + i = i + 1 + i, item = parse_item(ps, i) + table.insert(list, item) + end + end return i, list end - for _, e in ipairs(try_ps.errs) do - table.insert(ps.errs, e) - end - i = tryi - table.insert(list, item) - if ps.tokens[i].tk == "," then - while ps.tokens[i].tk == "," do + + local function parse_anglebracket_list(ps, i, parse_item) + if ps.tokens[i + 1].tk == ">" then + return fail(ps, i + 1, "type argument list cannot be empty") + end + local types = {} + i = verify_tk(ps, i, "<") + i = parse_list(ps, i, types, { [">"] = true, [">>"] = true }, "sep", parse_item) + if ps.tokens[i].tk == ">" then i = i + 1 - i, item = parse_item(ps, i) - table.insert(list, item) + elseif ps.tokens[i].tk == ">>" then + + ps.tokens[i].tk = ">" + else + return fail(ps, i, "syntax error, expected '>'") end + return i, types end - return i, list -end -local function parse_anglebracket_list(ps, i, parse_item) - if ps.tokens[i + 1].tk == ">" then - return fail(ps, i + 1, "type argument list cannot be empty") + local function parse_typearg(ps, i) + local name = ps.tokens[i].tk + local constraint + i = verify_kind(ps, i, "identifier") + if ps.tokens[i].tk == "is" then + i = i + 1 + i, constraint = parse_interface_name(ps, i) + end + local t = new_type(ps, i, "typearg") + t.typearg = name + t.constraint = constraint + return i, t end - local types = {} - i = verify_tk(ps, i, "<") - i = parse_list(ps, i, types, { [">"] = true, [">>"] = true }, "sep", parse_item) - if ps.tokens[i].tk == ">" then - i = i + 1 - elseif ps.tokens[i].tk == ">>" then - ps.tokens[i].tk = ">" - else - return fail(ps, i, "syntax error, expected '>'") + local function parse_return_types(ps, i) + local iprev = i - 1 + local t + i, t = parse_type_list(ps, i, "rets") + if #t.tuple == 0 then + t.x = ps.tokens[iprev].x + t.y = ps.tokens[iprev].y + end + return i, t end - return i, types -end -local function parse_typearg(ps, i) - local name = ps.tokens[i].tk - local constraint - i = verify_kind(ps, i, "identifier") - if ps.tokens[i].tk == "is" then + local function parse_function_type(ps, i) + local typ = new_type(ps, i, "function") i = i + 1 - i, constraint = parse_interface_name(ps, i) + if ps.tokens[i].tk == "<" then + i, typ.typeargs = parse_anglebracket_list(ps, i, parse_typearg) + end + if ps.tokens[i].tk == "(" then + i, typ.args, typ.is_method, typ.min_arity = parse_argument_type_list(ps, i) + i, typ.rets = parse_return_types(ps, i) + else + typ.args = new_tuple(ps, i, { new_type(ps, i, "any") }, true) + typ.rets = new_tuple(ps, i, { new_type(ps, i, "any") }, true) + end + return i, typ end - return i, a_type("typearg", { - y = ps.tokens[i - 2].y, - x = ps.tokens[i - 2].x, - typearg = name, - constraint = constraint, - }) -end - -local function parse_return_types(ps, i) - return parse_type_list(ps, i, "rets") -end -local function parse_function_type(ps, i) - local typ = new_type(ps, i, "function") - i = i + 1 - if ps.tokens[i].tk == "<" then - i, typ.typeargs = parse_anglebracket_list(ps, i, parse_typearg) - end - if ps.tokens[i].tk == "(" then - i, typ.args, typ.is_method, typ.min_arity = parse_argument_type_list(ps, i) - i, typ.rets = parse_return_types(ps, i) - else - typ.args = a_vararg({ ANY }) - typ.rets = a_vararg({ ANY }) - end - return i, typ -end + local function parse_simple_type_or_nominal(ps, i) + local tk = ps.tokens[i].tk + local st = simple_types[tk] + if st then + return i + 1, new_type(ps, i, tk) + elseif tk == "table" then + local typ = new_type(ps, i, "map") + typ.keys = new_type(ps, i, "any") + typ.values = new_type(ps, i, "any") + return i + 1, typ + end -local simple_types = { - ["nil"] = NIL, - ["any"] = ANY, - ["table"] = TABLE, - ["number"] = NUMBER, - ["string"] = STRING, - ["thread"] = THREAD, - ["boolean"] = BOOLEAN, - ["integer"] = INTEGER, -} + local typ = new_nominal(ps, i, tk) + i = i + 1 + while ps.tokens[i].tk == "." do + i = i + 1 + if ps.tokens[i].kind == "identifier" then + table.insert(typ.names, ps.tokens[i].tk) + i = i + 1 + else + return fail(ps, i, "syntax error, expected identifier") + end + end -local function parse_simple_type_or_nominal(ps, i) - local tk = ps.tokens[i].tk - local st = simple_types[tk] - if st then - return i + 1, st + if ps.tokens[i].tk == "<" then + i, typ.typevals = parse_anglebracket_list(ps, i, parse_type) + end + return i, typ end - local typ = new_type(ps, i, "nominal") - typ.names = { tk } - i = i + 1 - while ps.tokens[i].tk == "." do - i = i + 1 + + local function parse_base_type(ps, i) + local tk = ps.tokens[i].tk if ps.tokens[i].kind == "identifier" then - table.insert(typ.names, ps.tokens[i].tk) + return parse_simple_type_or_nominal(ps, i) + elseif tk == "{" then + local istart = i i = i + 1 - else - return fail(ps, i, "syntax error, expected identifier") + local t + i, t = parse_type(ps, i) + if not t then + return i + end + if ps.tokens[i].tk == "}" then + local decl = new_type(ps, istart, "array") + decl.elements = t + end_at(decl, ps.tokens[i]) + i = verify_tk(ps, i, "}") + return i, decl + elseif ps.tokens[i].tk == "," then + local decl = new_type(ps, istart, "tupletable") + decl.types = { t } + local n = 2 + repeat + i = i + 1 + i, decl.types[n] = parse_type(ps, i) + if not decl.types[n] then + break + end + n = n + 1 + until ps.tokens[i].tk ~= "," + end_at(decl, ps.tokens[i]) + i = verify_tk(ps, i, "}") + return i, decl + elseif ps.tokens[i].tk == ":" then + local decl = new_type(ps, istart, "map") + i = i + 1 + decl.keys = t + i, decl.values = parse_type(ps, i) + if not decl.values then + return i + end + end_at(decl, ps.tokens[i]) + i = verify_tk(ps, i, "}") + return i, decl + end + return fail(ps, i, "syntax error; did you forget a '}'?") + elseif tk == "function" then + return parse_function_type(ps, i) + elseif tk == "nil" then + return i + 1, new_type(ps, i, "nil") end + return fail(ps, i, "expected a type") end - if ps.tokens[i].tk == "<" then - i, typ.typevals = parse_anglebracket_list(ps, i, parse_type) - end - return i, typ -end + parse_type = function(ps, i) + if ps.tokens[i].tk == "(" then + i = i + 1 + local t + i, t = parse_type(ps, i) + i = verify_tk(ps, i, ")") + return i, t + end -local function parse_base_type(ps, i) - local tk = ps.tokens[i].tk - if ps.tokens[i].kind == "identifier" then - return parse_simple_type_or_nominal(ps, i) - elseif tk == "{" then + local bt local istart = i - i = i + 1 - local t - i, t = parse_type(ps, i) - if not t then + i, bt = parse_base_type(ps, i) + if not bt then return i end - if ps.tokens[i].tk == "}" then - local decl = new_type(ps, istart, "array") - decl.elements = t - end_at(decl, ps.tokens[i]) - i = verify_tk(ps, i, "}") - return i, decl - elseif ps.tokens[i].tk == "," then - local decl = new_type(ps, istart, "tupletable") - decl.types = { t } - local n = 2 - repeat + if ps.tokens[i].tk == "|" then + local u = new_type(ps, istart, "union") + u.types = { bt } + while ps.tokens[i].tk == "|" do i = i + 1 - i, decl.types[n] = parse_type(ps, i) - if not decl.types[n] then - break + i, bt = parse_base_type(ps, i) + if not bt then + return i end - n = n + 1 - until ps.tokens[i].tk ~= "," - end_at(decl, ps.tokens[i]) - i = verify_tk(ps, i, "}") - return i, decl - elseif ps.tokens[i].tk == ":" then - local decl = new_type(ps, istart, "map") - i = i + 1 - decl.keys = t - i, decl.values = parse_type(ps, i) - if not decl.values then - return i + table.insert(u.types, bt) end - end_at(decl, ps.tokens[i]) - i = verify_tk(ps, i, "}") - return i, decl - end - return fail(ps, i, "syntax error; did you forget a '}'?") - elseif tk == "function" then - return parse_function_type(ps, i) - elseif tk == "nil" then - return i + 1, simple_types["nil"] - elseif tk == "table" then - local typ = new_type(ps, i, "map") - typ.keys = ANY - typ.values = ANY - return i + 1, typ - end - return fail(ps, i, "expected a type") -end - -parse_type = function(ps, i) - if ps.tokens[i].tk == "(" then - i = i + 1 - local t - i, t = parse_type(ps, i) - i = verify_tk(ps, i, ")") - return i, t + bt = u + end + return i, bt end - local bt - local istart = i - i, bt = parse_base_type(ps, i) - if not bt then - return i - end - if ps.tokens[i].tk == "|" then - local u = new_type(ps, istart, "union") - u.types = { bt } - while ps.tokens[i].tk == "|" do - i = i + 1 - i, bt = parse_base_type(ps, i) - if not bt then - return i + parse_type_list = function(ps, i, mode) + local t, list = new_tuple(ps, i) + + local first_token = ps.tokens[i].tk + if mode == "rets" or mode == "decltuple" then + if first_token == ":" then + i = i + 1 + else + return i, t end - table.insert(u.types, bt) end - bt = u - end - return i, bt -end -local function new_tuple(ps, i) - local t = new_type(ps, i, "tuple") - t.tuple = {} - return t, t.tuple -end + local optional_paren = false + if ps.tokens[i].tk == "(" then + optional_paren = true + i = i + 1 + end -parse_type_list = function(ps, i, mode) - local t, list = new_tuple(ps, i) + local prev_i = i + i = parse_trying_list(ps, i, list, parse_type) + if i == prev_i and ps.tokens[i].tk ~= ")" then + fail(ps, i - 1, "expected a type list") + end - local first_token = ps.tokens[i].tk - if mode == "rets" or mode == "decltuple" then - if first_token == ":" then + if mode == "rets" and ps.tokens[i].tk == "..." then i = i + 1 - else - return i, t + local nrets = #list + if nrets > 0 then + t.is_va = true + else + fail(ps, i, "unexpected '...'") + end end - end - local optional_paren = false - if ps.tokens[i].tk == "(" then - optional_paren = true - i = i + 1 - end + if optional_paren then + i = verify_tk(ps, i, ")") + end - local prev_i = i - i = parse_trying_list(ps, i, list, parse_type) - if i == prev_i and ps.tokens[i].tk ~= ")" then - fail(ps, i - 1, "expected a type list") + return i, t end - if mode == "rets" and ps.tokens[i].tk == "..." then - i = i + 1 - local nrets = #list - if nrets > 0 then - t.is_va = true - else - fail(ps, i, "unexpected '...'") + local function parse_function_args_rets_body(ps, i, node) + local istart = i - 1 + if ps.tokens[i].tk == "<" then + i, node.typeargs = parse_anglebracket_list(ps, i, parse_typearg) end + i, node.args, node.min_arity = parse_argument_list(ps, i) + i, node.rets = parse_return_types(ps, i) + i, node.body = parse_statements(ps, i) + end_at(node, ps.tokens[i]) + i = verify_end(ps, i, istart, node) + return i, node end - if optional_paren then - i = verify_tk(ps, i, ")") + local function parse_function_value(ps, i) + local node = new_node(ps, i, "function") + i = verify_tk(ps, i, "function") + return parse_function_args_rets_body(ps, i, node) end - return i, t -end - -local function parse_function_args_rets_body(ps, i, node) - local istart = i - 1 - if ps.tokens[i].tk == "<" then - i, node.typeargs = parse_anglebracket_list(ps, i, parse_typearg) - end - i, node.args, node.min_arity = parse_argument_list(ps, i) - i, node.rets = parse_return_types(ps, i) - i, node.body = parse_statements(ps, i) - end_at(node, ps.tokens[i]) - i = verify_end(ps, i, istart, node) - return i, node -end - -local function parse_function_value(ps, i) - local node = new_node(ps.tokens, i, "function") - i = verify_tk(ps, i, "function") - return parse_function_args_rets_body(ps, i, node) -end - -local function unquote(str) - local f = str:sub(1, 1) - if f == '"' or f == "'" then - return str:sub(2, -2), false + local function unquote(str) + local f = str:sub(1, 1) + if f == '"' or f == "'" then + return str:sub(2, -2), false + end + f = str:match("^%[=*%[") + local l = #f + 1 + return str:sub(l, -l), true end - f = str:match("^%[=*%[") - local l = #f + 1 - return str:sub(l, -l), true -end - -local function parse_literal(ps, i) - local tk = ps.tokens[i].tk - local kind = ps.tokens[i].kind - if kind == "identifier" then - return verify_kind(ps, i, "identifier", "variable") - elseif kind == "string" then - local node = new_node(ps.tokens, i, "string") - node.conststr, node.is_longstring = unquote(tk) - return i + 1, node - elseif kind == "number" or kind == "integer" then - local n = tonumber(tk) - local node - i, node = verify_kind(ps, i, kind) - node.constnum = n - return i, node - elseif tk == "true" then - return verify_kind(ps, i, "keyword", "boolean") - elseif tk == "false" then - return verify_kind(ps, i, "keyword", "boolean") - elseif tk == "nil" then - return verify_kind(ps, i, "keyword", "nil") - elseif tk == "function" then - return parse_function_value(ps, i) - elseif tk == "{" then - return parse_table_literal(ps, i) - elseif kind == "..." then - return verify_kind(ps, i, "...") - elseif kind == "$ERR invalid_string$" then - return fail(ps, i, "malformed string") - elseif kind == "$ERR invalid_number$" then - return fail(ps, i, "malformed number") - end - return fail(ps, i, "syntax error") -end -local function node_is_require_call(n) - if n.e1 and n.e2 and - n.e1.kind == "variable" and n.e1.tk == "require" and - n.e2.kind == "expression_list" and #n.e2 == 1 and - n.e2[1].kind == "string" then - - return n.e2[1].conststr - elseif n.op and n.op.op == "@funcall" and - n.e1 and n.e1.tk == "pcall" and - n.e2 and #n.e2 == 2 and - n.e2[1].kind == "variable" and n.e2[1].tk == "require" and - n.e2[2].kind == "string" and n.e2[2].conststr then - - return n.e2[2].conststr - else - return nil + local function parse_literal(ps, i) + local tk = ps.tokens[i].tk + local kind = ps.tokens[i].kind + if kind == "identifier" then + return verify_kind(ps, i, "identifier", "variable") + elseif kind == "string" then + local node = new_node(ps, i, "string") + node.conststr, node.is_longstring = unquote(tk) + return i + 1, node + elseif kind == "number" or kind == "integer" then + local n = tonumber(tk) + local node + i, node = verify_kind(ps, i, kind) + node.constnum = n + return i, node + elseif tk == "true" then + return verify_kind(ps, i, "keyword", "boolean") + elseif tk == "false" then + return verify_kind(ps, i, "keyword", "boolean") + elseif tk == "nil" then + return verify_kind(ps, i, "keyword", "nil") + elseif tk == "function" then + return parse_function_value(ps, i) + elseif tk == "{" then + return parse_table_literal(ps, i) + elseif kind == "..." then + return verify_kind(ps, i, "...") + elseif kind == "$ERR invalid_string$" then + return fail(ps, i, "malformed string") + elseif kind == "$ERR invalid_number$" then + return fail(ps, i, "malformed number") + end + return fail(ps, i, "syntax error") + end + + local function node_is_require_call(n) + if n.e1 and n.e2 and + n.e1.kind == "variable" and n.e1.tk == "require" and + n.e2.kind == "expression_list" and #n.e2 == 1 and + n.e2[1].kind == "string" then + + return n.e2[1].conststr + elseif n.op and n.op.op == "@funcall" and + n.e1 and n.e1.tk == "pcall" and + n.e2 and #n.e2 == 2 and + n.e2[1].kind == "variable" and n.e2[1].tk == "require" and + n.e2[2].kind == "string" and n.e2[2].conststr then + + return n.e2[2].conststr + else + return nil + end end -end - -local an_operator -do - local precedences = { - [1] = { - ["not"] = 11, - ["#"] = 11, - ["-"] = 11, - ["~"] = 11, - }, - [2] = { - ["or"] = 1, - ["and"] = 2, - ["is"] = 3, - ["<"] = 3, - [">"] = 3, - ["<="] = 3, - [">="] = 3, - ["~="] = 3, - ["=="] = 3, - ["|"] = 4, - ["~"] = 5, - ["&"] = 6, - ["<<"] = 7, - [">>"] = 7, - [".."] = 8, - ["+"] = 9, - ["-"] = 9, - ["*"] = 10, - ["/"] = 10, - ["//"] = 10, - ["%"] = 10, - ["^"] = 12, - ["as"] = 50, - ["@funcall"] = 100, - ["@index"] = 100, - ["."] = 100, - [":"] = 100, - }, - } - - local is_right_assoc = { - ["^"] = true, - [".."] = true, - } + do + local precedences = { + [1] = { + ["not"] = 11, + ["#"] = 11, + ["-"] = 11, + ["~"] = 11, + }, + [2] = { + ["or"] = 1, + ["and"] = 2, + ["is"] = 3, + ["<"] = 3, + [">"] = 3, + ["<="] = 3, + [">="] = 3, + ["~="] = 3, + ["=="] = 3, + ["|"] = 4, + ["~"] = 5, + ["&"] = 6, + ["<<"] = 7, + [">>"] = 7, + [".."] = 8, + ["+"] = 9, + ["-"] = 9, + ["*"] = 10, + ["/"] = 10, + ["//"] = 10, + ["%"] = 10, + ["^"] = 12, + ["as"] = 50, + ["@funcall"] = 100, + ["@index"] = 100, + ["."] = 100, + [":"] = 100, + }, + } - local function new_operator(tk, arity, op) - return { y = tk.y, x = tk.x, arity = arity, op = op, prec = precedences[arity][op] } - end + local is_right_assoc = { + ["^"] = true, + [".."] = true, + } - an_operator = function(node, arity, op) - return { y = node.y, x = node.x, arity = arity, op = op, prec = precedences[arity][op] } - end + local function new_operator(tk, arity, op) + return { y = tk.y, x = tk.x, arity = arity, op = op, prec = precedences[arity][op] } + end - local args_starters = { - ["("] = true, - ["{"] = true, - ["string"] = true, - } + an_operator = function(node, arity, op) + return { y = node.y, x = node.x, arity = arity, op = op, prec = precedences[arity][op] } + end - local E + local args_starters = { + ["("] = true, + ["{"] = true, + ["string"] = true, + } - local function after_valid_prefixexp(ps, prevnode, i) - return ps.tokens[i - 1].kind == ")" or - (prevnode.kind == "op" and - (prevnode.op.op == "@funcall" or - prevnode.op.op == "@index" or - prevnode.op.op == "." or - prevnode.op.op == ":")) or + local E - prevnode.kind == "identifier" or - prevnode.kind == "variable" - end + local function after_valid_prefixexp(ps, prevnode, i) + return ps.tokens[i - 1].kind == ")" or + (prevnode.kind == "op" and + (prevnode.op.op == "@funcall" or + prevnode.op.op == "@index" or + prevnode.op.op == "." or + prevnode.op.op == ":")) or + prevnode.kind == "identifier" or + prevnode.kind == "variable" + end - local function failstore(tkop, e1) - return { y = tkop.y, x = tkop.x, kind = "paren", e1 = e1, failstore = true } - end - local function P(ps, i) - if ps.tokens[i].kind == "$EOF$" then - return i + local function failstore(ps, tkop, e1) + return { f = ps.filename, y = tkop.y, x = tkop.x, kind = "paren", e1 = e1, failstore = true } end - local e1 - local t1 = ps.tokens[i] - if precedences[1][t1.tk] ~= nil then - local op = new_operator(t1, 1, t1.tk) - i = i + 1 - local prev_i = i - i, e1 = P(ps, i) - if not e1 then - fail(ps, prev_i, "expected an expression") + + local function P(ps, i) + if ps.tokens[i].kind == "$EOF$" then return i end - e1 = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1 } - elseif ps.tokens[i].tk == "(" then - i = i + 1 - local prev_i = i - i, e1 = parse_expression_and_tk(ps, i, ")") + local e1 + local t1 = ps.tokens[i] + if precedences[1][t1.tk] ~= nil then + local op = new_operator(t1, 1, t1.tk) + i = i + 1 + local prev_i = i + i, e1 = P(ps, i) + if not e1 then + fail(ps, prev_i, "expected an expression") + return i + end + e1 = { f = ps.filename, y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1 } + elseif ps.tokens[i].tk == "(" then + i = i + 1 + local prev_i = i + i, e1 = parse_expression_and_tk(ps, i, ")") + if not e1 then + fail(ps, prev_i, "expected an expression") + return i + end + e1 = { f = ps.filename, y = t1.y, x = t1.x, kind = "paren", e1 = e1 } + else + i, e1 = parse_literal(ps, i) + end + if not e1 then - fail(ps, prev_i, "expected an expression") return i end - e1 = { y = t1.y, x = t1.x, kind = "paren", e1 = e1 } - else - i, e1 = parse_literal(ps, i) - end - - if not e1 then - return i - end - while true do - local tkop = ps.tokens[i] - if tkop.kind == "," or tkop.kind == ")" then - break - end - if tkop.tk == "." or tkop.tk == ":" then - local op = new_operator(tkop, 2, tkop.tk) + while true do + local tkop = ps.tokens[i] + if tkop.kind == "," or tkop.kind == ")" then + break + end + if tkop.tk == "." or tkop.tk == ":" then + local op = new_operator(tkop, 2, tkop.tk) - local prev_i = i + local prev_i = i - local key - i = i + 1 - if ps.tokens[i].kind ~= "identifier" then - local skipped = skip(ps, i, parse_type) - if skipped > i + 1 then - fail(ps, i, "syntax error, cannot declare a type here (missing 'local' or 'global'?)") - return skipped, failstore(tkop, e1) + local key + i = i + 1 + if ps.tokens[i].kind ~= "identifier" then + local skipped = skip(ps, i, parse_type) + if skipped > i + 1 then + fail(ps, i, "syntax error, cannot declare a type here (missing 'local' or 'global'?)") + return skipped, failstore(ps, tkop, e1) + end + end + i, key = verify_kind(ps, i, "identifier") + if not key then + return i, failstore(ps, tkop, e1) end - end - i, key = verify_kind(ps, i, "identifier") - if not key then - return i, failstore(tkop, e1) - end - if op.op == ":" then - if not args_starters[ps.tokens[i].kind] then - if ps.tokens[i].tk == "=" then - fail(ps, i, "syntax error, cannot perform an assignment here (missing 'local' or 'global'?)") - else - fail(ps, i, "expected a function call for a method") + if op.op == ":" then + if not args_starters[ps.tokens[i].kind] then + if ps.tokens[i].tk == "=" then + fail(ps, i, "syntax error, cannot perform an assignment here (missing 'local' or 'global'?)") + else + fail(ps, i, "expected a function call for a method") + end + return i, failstore(ps, tkop, e1) end - return i, failstore(tkop, e1) - end - if not after_valid_prefixexp(ps, e1, prev_i) then - fail(ps, prev_i, "cannot call a method on this expression") - return i, failstore(tkop, e1) + if not after_valid_prefixexp(ps, e1, prev_i) then + fail(ps, prev_i, "cannot call a method on this expression") + return i, failstore(ps, tkop, e1) + end end - end - e1 = { y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = key } - elseif tkop.tk == "(" then - local op = new_operator(tkop, 2, "@funcall") + e1 = { f = ps.filename, y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = key } + elseif tkop.tk == "(" then + local op = new_operator(tkop, 2, "@funcall") - local prev_i = i + local prev_i = i - local args = new_node(ps.tokens, i, "expression_list") - i, args = parse_bracket_list(ps, i, args, "(", ")", "sep", parse_expression) + local args = new_node(ps, i, "expression_list") + i, args = parse_bracket_list(ps, i, args, "(", ")", "sep", parse_expression) - if not after_valid_prefixexp(ps, e1, prev_i) then - fail(ps, prev_i, "cannot call this expression") - return i, failstore(tkop, e1) - end - - e1 = { y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } + if not after_valid_prefixexp(ps, e1, prev_i) then + fail(ps, prev_i, "cannot call this expression") + return i, failstore(ps, tkop, e1) + end - table.insert(ps.required_modules, node_is_require_call(e1)) - elseif tkop.tk == "[" then - local op = new_operator(tkop, 2, "@index") + e1 = { f = ps.filename, y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } - local prev_i = i + table.insert(ps.required_modules, node_is_require_call(e1)) + elseif tkop.tk == "[" then + local op = new_operator(tkop, 2, "@index") - local idx - i = i + 1 - i, idx = parse_expression_and_tk(ps, i, "]") + local prev_i = i - if not after_valid_prefixexp(ps, e1, prev_i) then - fail(ps, prev_i, "cannot index this expression") - return i, failstore(tkop, e1) - end + local idx + i = i + 1 + i, idx = parse_expression_and_tk(ps, i, "]") - e1 = { y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = idx } - elseif tkop.kind == "string" or tkop.kind == "{" then - local op = new_operator(tkop, 2, "@funcall") + if not after_valid_prefixexp(ps, e1, prev_i) then + fail(ps, prev_i, "cannot index this expression") + return i, failstore(ps, tkop, e1) + end - local prev_i = i + e1 = { f = ps.filename, y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = idx } + elseif tkop.kind == "string" or tkop.kind == "{" then + local op = new_operator(tkop, 2, "@funcall") - local args = new_node(ps.tokens, i, "expression_list") - local argument - if tkop.kind == "string" then - argument = new_node(ps.tokens, i) - argument.conststr = unquote(tkop.tk) - i = i + 1 - else - i, argument = parse_table_literal(ps, i) - end + local prev_i = i - if not after_valid_prefixexp(ps, e1, prev_i) then + local args = new_node(ps, i, "expression_list") + local argument if tkop.kind == "string" then - fail(ps, prev_i, "cannot use a string here; if you're trying to call the previous expression, wrap it in parentheses") + argument = new_node(ps, i) + argument.conststr = unquote(tkop.tk) + i = i + 1 else - fail(ps, prev_i, "cannot use a table here; if you're trying to call the previous expression, wrap it in parentheses") + i, argument = parse_table_literal(ps, i) + end + + if not after_valid_prefixexp(ps, e1, prev_i) then + if tkop.kind == "string" then + fail(ps, prev_i, "cannot use a string here; if you're trying to call the previous expression, wrap it in parentheses") + else + fail(ps, prev_i, "cannot use a table here; if you're trying to call the previous expression, wrap it in parentheses") + end + return i, failstore(ps, tkop, e1) end - return i, failstore(tkop, e1) - end - table.insert(args, argument) - e1 = { y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } + table.insert(args, argument) + e1 = { f = ps.filename, y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } - table.insert(ps.required_modules, node_is_require_call(e1)) - elseif tkop.tk == "as" or tkop.tk == "is" then - local op = new_operator(tkop, 2, tkop.tk) + table.insert(ps.required_modules, node_is_require_call(e1)) + elseif tkop.tk == "as" or tkop.tk == "is" then + local op = new_operator(tkop, 2, tkop.tk) - i = i + 1 - local cast = new_node(ps.tokens, i, "cast") - if ps.tokens[i].tk == "(" then - i, cast.casttype = parse_type_list(ps, i, "casttype") + i = i + 1 + local cast = new_node(ps, i, "cast") + if ps.tokens[i].tk == "(" then + i, cast.casttype = parse_type_list(ps, i, "casttype") + else + i, cast.casttype = parse_type(ps, i) + end + if not cast.casttype then + return i, failstore(ps, tkop, e1) + end + e1 = { f = ps.filename, y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = cast, conststr = e1.conststr } else - i, cast.casttype = parse_type(ps, i) - end - if not cast.casttype then - return i, failstore(tkop, e1) + break end - e1 = { y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = cast, conststr = e1.conststr } - else - break end - end - return i, e1 - end + return i, e1 + end - E = function(ps, i, lhs, min_precedence) - local lookahead = ps.tokens[i].tk - while precedences[2][lookahead] and precedences[2][lookahead] >= min_precedence do - local t1 = ps.tokens[i] - local op = new_operator(t1, 2, t1.tk) - i = i + 1 - local rhs - i, rhs = P(ps, i) - if not rhs then - fail(ps, i, "expected an expression") - return i - end - lookahead = ps.tokens[i].tk - while precedences[2][lookahead] and ((precedences[2][lookahead] > (precedences[2][op.op])) or - (is_right_assoc[lookahead] and (precedences[2][lookahead] == precedences[2][op.op]))) do - i, rhs = E(ps, i, rhs, precedences[2][lookahead]) + E = function(ps, i, lhs, min_precedence) + local lookahead = ps.tokens[i].tk + while precedences[2][lookahead] and precedences[2][lookahead] >= min_precedence do + local t1 = ps.tokens[i] + local op = new_operator(t1, 2, t1.tk) + i = i + 1 + local rhs + i, rhs = P(ps, i) if not rhs then fail(ps, i, "expected an expression") return i end lookahead = ps.tokens[i].tk + while precedences[2][lookahead] and ((precedences[2][lookahead] > (precedences[2][op.op])) or + (is_right_assoc[lookahead] and (precedences[2][lookahead] == precedences[2][op.op]))) do + i, rhs = E(ps, i, rhs, precedences[2][lookahead]) + if not rhs then + fail(ps, i, "expected an expression") + return i + end + lookahead = ps.tokens[i].tk + end + lhs = { f = ps.filename, y = t1.y, x = t1.x, kind = "op", op = op, e1 = lhs, e2 = rhs } end - lhs = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = lhs, e2 = rhs } + return i, lhs end - return i, lhs - end - parse_expression = function(ps, i) - local lhs - local istart = i - i, lhs = P(ps, i) - if lhs then - i, lhs = E(ps, i, lhs, 0) - end - if lhs then - return i, lhs, 0 - end + parse_expression = function(ps, i) + local lhs + local istart = i + i, lhs = P(ps, i) + if lhs then + i, lhs = E(ps, i, lhs, 0) + end + if lhs then + return i, lhs, 0 + end - if i == istart then - i = fail(ps, i, "expected an expression") + if i == istart then + i = fail(ps, i, "expected an expression") + end + return i end - return i end -end -parse_expression_and_tk = function(ps, i, tk) - local e - i, e = parse_expression(ps, i) - if not e then - e = new_node(ps.tokens, i - 1, "error_node") - end - if ps.tokens[i].tk == tk then - i = i + 1 - else - local msg = "syntax error, expected '" .. tk .. "'" - if ps.tokens[i].tk == "=" then - msg = "syntax error, cannot perform an assignment here (did you mean '=='?)" + parse_expression_and_tk = function(ps, i, tk) + local e + i, e = parse_expression(ps, i) + if not e then + e = new_node(ps, i - 1, "error_node") end + if ps.tokens[i].tk == tk then + i = i + 1 + else + local msg = "syntax error, expected '" .. tk .. "'" + if ps.tokens[i].tk == "=" then + msg = "syntax error, cannot perform an assignment here (did you mean '=='?)" + end - for n = 0, 19 do - local t = ps.tokens[i + n] - if t.kind == "$EOF$" then - break - end - if t.tk == tk then - fail(ps, i, msg) - return i + n + 1, e + for n = 0, 19 do + local t = ps.tokens[i + n] + if t.kind == "$EOF$" then + break + end + if t.tk == tk then + fail(ps, i, msg) + return i + n + 1, e + end end + i = fail(ps, i, msg) end - i = fail(ps, i, msg) + return i, e end - return i, e -end -local function parse_variable_name(ps, i) - local node - i, node = verify_kind(ps, i, "identifier") - if not node then - return i - end - if ps.tokens[i].tk == "<" then - i = i + 1 - local annotation - i, annotation = verify_kind(ps, i, "identifier") - if annotation then - if not is_attribute[annotation.tk] then - fail(ps, i, "unknown variable annotation: " .. annotation.tk) + local function parse_variable_name(ps, i) + local node + i, node = verify_kind(ps, i, "identifier") + if not node then + return i + end + if ps.tokens[i].tk == "<" then + i = i + 1 + local annotation + i, annotation = verify_kind(ps, i, "identifier") + if annotation then + if not is_attribute[annotation.tk] then + fail(ps, i, "unknown variable annotation: " .. annotation.tk) + end + node.attribute = annotation.tk + else + fail(ps, i, "expected a variable annotation") end - node.attribute = annotation.tk - else - fail(ps, i, "expected a variable annotation") + i = verify_tk(ps, i, ">") end - i = verify_tk(ps, i, ">") + return i, node end - return i, node -end -local function parse_argument(ps, i) - local node - if ps.tokens[i].tk == "..." then - i, node = verify_kind(ps, i, "...", "argument") - node.opt = true - else - i, node = verify_kind(ps, i, "identifier", "argument") - end - if ps.tokens[i].tk == "..." then - fail(ps, i, "'...' needs to be declared as a typed argument") - end - if ps.tokens[i].tk == "?" then - i = i + 1 - node.opt = true - end - if ps.tokens[i].tk == ":" then - i = i + 1 - local argtype + local function parse_argument(ps, i) + local node + if ps.tokens[i].tk == "..." then + i, node = verify_kind(ps, i, "...", "argument") + node.opt = true + else + i, node = verify_kind(ps, i, "identifier", "argument") + end + if ps.tokens[i].tk == "..." then + fail(ps, i, "'...' needs to be declared as a typed argument") + end + if ps.tokens[i].tk == "?" then + i = i + 1 + node.opt = true + end + if ps.tokens[i].tk == ":" then + i = i + 1 + local argtype - i, argtype = parse_type(ps, i) + i, argtype = parse_type(ps, i) - if node then - node.argtype = argtype + if node then + node.argtype = argtype + end end + return i, node, 0 end - return i, node, 0 -end -parse_argument_list = function(ps, i) - local node = new_node(ps.tokens, i, "argument_list") - i, node = parse_bracket_list(ps, i, node, "(", ")", "sep", parse_argument) - local opts = false - local min_arity = 0 - for a, fnarg in ipairs(node) do - if fnarg.tk == "..." then - if a ~= #node then - fail(ps, i, "'...' can only be last argument") - break + parse_argument_list = function(ps, i) + local node = new_node(ps, i, "argument_list") + i, node = parse_bracket_list(ps, i, node, "(", ")", "sep", parse_argument) + local opts = false + local min_arity = 0 + for a, fnarg in ipairs(node) do + if fnarg.tk == "..." then + if a ~= #node then + fail(ps, i, "'...' can only be last argument") + break + end + elseif fnarg.opt then + opts = true + elseif opts then + return fail(ps, i, "non-optional arguments cannot follow optional arguments") + else + min_arity = min_arity + 1 end - elseif fnarg.opt then - opts = true - elseif opts then - return fail(ps, i, "non-optional arguments cannot follow optional arguments") - else - min_arity = min_arity + 1 end + return i, node, min_arity end - return i, node, min_arity -end @@ -3176,1014 +3207,982 @@ end -local function parse_argument_type(ps, i) - local opt = false - local is_va = false - local is_self = false - local argument_name = nil + local function parse_argument_type(ps, i) + local opt = false + local is_va = false + local is_self = false + local argument_name = nil - if ps.tokens[i].kind == "identifier" then - argument_name = ps.tokens[i].tk - if ps.tokens[i + 1].tk == "?" then + if ps.tokens[i].kind == "identifier" then + argument_name = ps.tokens[i].tk + if ps.tokens[i + 1].tk == "?" then + opt = true + if ps.tokens[i + 2].tk == ":" then + i = i + 3 + end + elseif ps.tokens[i + 1].tk == ":" then + i = i + 2 + end + elseif ps.tokens[i].kind == "?" then opt = true - if ps.tokens[i + 2].tk == ":" then - i = i + 3 + i = i + 1 + elseif ps.tokens[i].tk == "..." then + if ps.tokens[i + 1].tk == ":" then + i = i + 2 + is_va = true + else + return fail(ps, i, "cannot have untyped '...' when declaring the type of an argument") end - elseif ps.tokens[i + 1].tk == ":" then - i = i + 2 end - elseif ps.tokens[i].kind == "?" then - opt = true - i = i + 1 - elseif ps.tokens[i].tk == "..." then - if ps.tokens[i + 1].tk == ":" then - i = i + 2 - is_va = true - else - return fail(ps, i, "cannot have untyped '...' when declaring the type of an argument") - end - end - local typ; i, typ = parse_type(ps, i) - if typ then - if not is_va and ps.tokens[i].tk == "..." then - i = i + 1 - is_va = true - end + local typ; i, typ = parse_type(ps, i) + if typ then + if not is_va and ps.tokens[i].tk == "..." then + i = i + 1 + is_va = true + end - if argument_name == "self" then - is_self = true + if argument_name == "self" then + is_self = true + end end - end - return i, { i = i, type = typ, is_va = is_va, is_self = is_self, opt = opt or is_va }, 0 -end + return i, { i = i, type = typ, is_va = is_va, is_self = is_self, opt = opt or is_va }, 0 + end -parse_argument_type_list = function(ps, i) - local ars = {} - i = parse_bracket_list(ps, i, ars, "(", ")", "sep", parse_argument_type) - local t, list = new_tuple(ps, i) - local n = #ars - local min_arity = 0 - for l, ar in ipairs(ars) do - list[l] = ar.type - if ar.is_va and l < n then - fail(ps, ar.i, "'...' can only be last argument") + parse_argument_type_list = function(ps, i) + local ars = {} + i = parse_bracket_list(ps, i, ars, "(", ")", "sep", parse_argument_type) + local t, list = new_tuple(ps, i) + local n = #ars + local min_arity = 0 + for l, ar in ipairs(ars) do + list[l] = ar.type + if ar.is_va and l < n then + fail(ps, ar.i, "'...' can only be last argument") + end + if not ar.opt then + min_arity = min_arity + 1 + end end - if not ar.opt then - min_arity = min_arity + 1 + if n > 0 and ars[n].is_va then + t.is_va = true end + return i, t, (n > 0 and ars[1].is_self), min_arity end - if n > 0 and ars[n].is_va then - t.is_va = true + + local function parse_identifier(ps, i) + if ps.tokens[i].kind == "identifier" then + return i + 1, new_node(ps, i, "identifier") + end + i = fail(ps, i, "syntax error, expected identifier") + return i, new_node(ps, i, "error_node") end - return i, t, (n > 0 and ars[1].is_self), min_arity -end -local function parse_identifier(ps, i) - if ps.tokens[i].kind == "identifier" then - return i + 1, new_node(ps.tokens, i, "identifier") + local function parse_local_function(ps, i) + i = verify_tk(ps, i, "local") + i = verify_tk(ps, i, "function") + local node = new_node(ps, i - 2, "local_function") + i, node.name = parse_identifier(ps, i) + return parse_function_args_rets_body(ps, i, node) end - i = fail(ps, i, "syntax error, expected identifier") - return i, new_node(ps.tokens, i, "error_node") -end -local function parse_local_function(ps, i) - i = verify_tk(ps, i, "local") - i = verify_tk(ps, i, "function") - local node = new_node(ps.tokens, i - 2, "local_function") - i, node.name = parse_identifier(ps, i) - return parse_function_args_rets_body(ps, i, node) -end + local function parse_function(ps, i, fk) + local orig_i = i + i = verify_tk(ps, i, "function") + local fn = new_node(ps, i - 1, "global_function") + local names = {} + i, names[1] = parse_identifier(ps, i) + while ps.tokens[i].tk == "." do + i = i + 1 + i, names[#names + 1] = parse_identifier(ps, i) + end + if ps.tokens[i].tk == ":" then + i = i + 1 + i, names[#names + 1] = parse_identifier(ps, i) + fn.is_method = true + end -local function parse_function(ps, i, fk) - local orig_i = i - i = verify_tk(ps, i, "function") - local fn = new_node(ps.tokens, i - 1, "global_function") - local names = {} - i, names[1] = parse_identifier(ps, i) - while ps.tokens[i].tk == "." do - i = i + 1 - i, names[#names + 1] = parse_identifier(ps, i) - end - if ps.tokens[i].tk == ":" then - i = i + 1 - i, names[#names + 1] = parse_identifier(ps, i) - fn.is_method = true - end + if #names > 1 then + fn.kind = "record_function" + local owner = names[1] + owner.kind = "type_identifier" + for i2 = 2, #names - 1 do + local dot = an_operator(names[i2], 2, ".") + names[i2].kind = "identifier" + owner = { f = ps.filename, y = names[i2].y, x = names[i2].x, kind = "op", op = dot, e1 = owner, e2 = names[i2] } + end + fn.fn_owner = owner + end + fn.name = names[#names] - if #names > 1 then - fn.kind = "record_function" - local owner = names[1] - owner.kind = "type_identifier" - for i2 = 2, #names - 1 do - local dot = an_operator(names[i2], 2, ".") - names[i2].kind = "identifier" - owner = { y = names[i2].y, x = names[i2].x, kind = "op", op = dot, e1 = owner, e2 = names[i2] } + local selfx, selfy = ps.tokens[i].x, ps.tokens[i].y + i = parse_function_args_rets_body(ps, i, fn) + if fn.is_method and fn.args then + table.insert(fn.args, 1, { f = ps.filename, x = selfx, y = selfy, tk = "self", kind = "identifier", is_self = true }) + fn.min_arity = fn.min_arity + 1 end - fn.fn_owner = owner - end - fn.name = names[#names] - local selfx, selfy = ps.tokens[i].x, ps.tokens[i].y - i = parse_function_args_rets_body(ps, i, fn) - if fn.is_method then - table.insert(fn.args, 1, { x = selfx, y = selfy, tk = "self", kind = "identifier", is_self = true }) - fn.min_arity = fn.min_arity + 1 - end + if not fn.name then + return orig_i + 1 + end - if not fn.name then - return orig_i + 1 - end + if fn.kind == "record_function" and fk == "global" then + fail(ps, orig_i, "record functions cannot be annotated as 'global'") + elseif fn.kind == "global_function" and fk == "record" then + fn.implicit_global_function = true + end - if fn.kind == "record_function" and fk == "global" then - fail(ps, orig_i, "record functions cannot be annotated as 'global'") - elseif fn.kind == "global_function" and fk == "record" then - fn.implicit_global_function = true + return i, fn end - return i, fn -end - -local function parse_if_block(ps, i, n, node, is_else) - local block = new_node(ps.tokens, i, "if_block") - i = i + 1 - block.if_parent = node - block.if_block_n = n - if not is_else then - i, block.exp = parse_expression_and_tk(ps, i, "then") - if not block.exp then + local function parse_if_block(ps, i, n, node, is_else) + local block = new_node(ps, i, "if_block") + i = i + 1 + block.if_parent = node + block.if_block_n = n + if not is_else then + i, block.exp = parse_expression_and_tk(ps, i, "then") + if not block.exp then + return i + end + end + i, block.body = parse_statements(ps, i) + if not block.body then return i end + end_at(block.body, ps.tokens[i - 1]) + block.yend, block.xend = block.body.yend, block.body.xend + table.insert(node.if_blocks, block) + return i, node end - i, block.body = parse_statements(ps, i) - if not block.body then - return i - end - end_at(block.body, ps.tokens[i - 1]) - block.yend, block.xend = block.body.yend, block.body.xend - table.insert(node.if_blocks, block) - return i, node -end -local function parse_if(ps, i) - local istart = i - local node = new_node(ps.tokens, i, "if") - node.if_blocks = {} - i, node = parse_if_block(ps, i, 1, node) - if not node then - return i - end - local n = 2 - while ps.tokens[i].tk == "elseif" do - i, node = parse_if_block(ps, i, n, node) + local function parse_if(ps, i) + local istart = i + local node = new_node(ps, i, "if") + node.if_blocks = {} + i, node = parse_if_block(ps, i, 1, node) if not node then return i end - n = n + 1 + local n = 2 + while ps.tokens[i].tk == "elseif" do + i, node = parse_if_block(ps, i, n, node) + if not node then + return i + end + n = n + 1 + end + if ps.tokens[i].tk == "else" then + i, node = parse_if_block(ps, i, n, node, true) + if not node then + return i + end + end + i = verify_end(ps, i, istart, node) + return i, node + end + + local function parse_while(ps, i) + local istart = i + local node = new_node(ps, i, "while") + i = verify_tk(ps, i, "while") + i, node.exp = parse_expression_and_tk(ps, i, "do") + i, node.body = parse_statements(ps, i) + i = verify_end(ps, i, istart, node) + return i, node end - if ps.tokens[i].tk == "else" then - i, node = parse_if_block(ps, i, n, node, true) - if not node then - return i + + local function parse_fornum(ps, i) + local istart = i + local node = new_node(ps, i, "fornum") + i = i + 1 + i, node.var = parse_identifier(ps, i) + i = verify_tk(ps, i, "=") + i, node.from = parse_expression_and_tk(ps, i, ",") + i, node.to = parse_expression(ps, i) + if ps.tokens[i].tk == "," then + i = i + 1 + i, node.step = parse_expression_and_tk(ps, i, "do") + else + i = verify_tk(ps, i, "do") end + i, node.body = parse_statements(ps, i) + i = verify_end(ps, i, istart, node) + return i, node end - i = verify_end(ps, i, istart, node) - return i, node -end - -local function parse_while(ps, i) - local istart = i - local node = new_node(ps.tokens, i, "while") - i = verify_tk(ps, i, "while") - i, node.exp = parse_expression_and_tk(ps, i, "do") - i, node.body = parse_statements(ps, i) - i = verify_end(ps, i, istart, node) - return i, node -end -local function parse_fornum(ps, i) - local istart = i - local node = new_node(ps.tokens, i, "fornum") - i = i + 1 - i, node.var = parse_identifier(ps, i) - i = verify_tk(ps, i, "=") - i, node.from = parse_expression_and_tk(ps, i, ",") - i, node.to = parse_expression(ps, i) - if ps.tokens[i].tk == "," then + local function parse_forin(ps, i) + local istart = i + local node = new_node(ps, i, "forin") i = i + 1 - i, node.step = parse_expression_and_tk(ps, i, "do") - else + node.vars = new_node(ps, i, "variable_list") + i, node.vars = parse_list(ps, i, node.vars, { ["in"] = true }, "sep", parse_identifier) + i = verify_tk(ps, i, "in") + node.exps = new_node(ps, i, "expression_list") + i = parse_list(ps, i, node.exps, { ["do"] = true }, "sep", parse_expression) + if #node.exps < 1 then + return fail(ps, i, "missing iterator expression in generic for") + elseif #node.exps > 3 then + return fail(ps, i, "too many expressions in generic for") + end i = verify_tk(ps, i, "do") + i, node.body = parse_statements(ps, i) + i = verify_end(ps, i, istart, node) + return i, node end - i, node.body = parse_statements(ps, i) - i = verify_end(ps, i, istart, node) - return i, node -end - -local function parse_forin(ps, i) - local istart = i - local node = new_node(ps.tokens, i, "forin") - i = i + 1 - node.vars = new_node(ps.tokens, i, "variable_list") - i, node.vars = parse_list(ps, i, node.vars, { ["in"] = true }, "sep", parse_identifier) - i = verify_tk(ps, i, "in") - node.exps = new_node(ps.tokens, i, "expression_list") - i = parse_list(ps, i, node.exps, { ["do"] = true }, "sep", parse_expression) - if #node.exps < 1 then - return fail(ps, i, "missing iterator expression in generic for") - elseif #node.exps > 3 then - return fail(ps, i, "too many expressions in generic for") - end - i = verify_tk(ps, i, "do") - i, node.body = parse_statements(ps, i) - i = verify_end(ps, i, istart, node) - return i, node -end -local function parse_for(ps, i) - if ps.tokens[i + 1].kind == "identifier" and ps.tokens[i + 2].tk == "=" then - return parse_fornum(ps, i) - else - return parse_forin(ps, i) + local function parse_for(ps, i) + if ps.tokens[i + 1].kind == "identifier" and ps.tokens[i + 2].tk == "=" then + return parse_fornum(ps, i) + else + return parse_forin(ps, i) + end end -end -local function parse_repeat(ps, i) - local node = new_node(ps.tokens, i, "repeat") - i = verify_tk(ps, i, "repeat") - i, node.body = parse_statements(ps, i) - node.body.is_repeat = true - i = verify_tk(ps, i, "until") - i, node.exp = parse_expression(ps, i) - end_at(node, ps.tokens[i - 1]) - return i, node -end + local function parse_repeat(ps, i) + local node = new_node(ps, i, "repeat") + i = verify_tk(ps, i, "repeat") + i, node.body = parse_statements(ps, i) + node.body.is_repeat = true + i = verify_tk(ps, i, "until") + i, node.exp = parse_expression(ps, i) + end_at(node, ps.tokens[i - 1]) + return i, node + end -local function parse_do(ps, i) - local istart = i - local node = new_node(ps.tokens, i, "do") - i = verify_tk(ps, i, "do") - i, node.body = parse_statements(ps, i) - i = verify_end(ps, i, istart, node) - return i, node -end + local function parse_do(ps, i) + local istart = i + local node = new_node(ps, i, "do") + i = verify_tk(ps, i, "do") + i, node.body = parse_statements(ps, i) + i = verify_end(ps, i, istart, node) + return i, node + end -local function parse_break(ps, i) - local node = new_node(ps.tokens, i, "break") - i = verify_tk(ps, i, "break") - return i, node -end + local function parse_break(ps, i) + local node = new_node(ps, i, "break") + i = verify_tk(ps, i, "break") + return i, node + end -local function parse_goto(ps, i) - local node = new_node(ps.tokens, i, "goto") - i = verify_tk(ps, i, "goto") - node.label = ps.tokens[i].tk - i = verify_kind(ps, i, "identifier") - return i, node -end + local function parse_goto(ps, i) + local node = new_node(ps, i, "goto") + i = verify_tk(ps, i, "goto") + node.label = ps.tokens[i].tk + i = verify_kind(ps, i, "identifier") + return i, node + end -local function parse_label(ps, i) - local node = new_node(ps.tokens, i, "label") - i = verify_tk(ps, i, "::") - node.label = ps.tokens[i].tk - i = verify_kind(ps, i, "identifier") - i = verify_tk(ps, i, "::") - return i, node -end + local function parse_label(ps, i) + local node = new_node(ps, i, "label") + i = verify_tk(ps, i, "::") + node.label = ps.tokens[i].tk + i = verify_kind(ps, i, "identifier") + i = verify_tk(ps, i, "::") + return i, node + end -local stop_statement_list = { - ["end"] = true, - ["else"] = true, - ["elseif"] = true, - ["until"] = true, -} + local stop_statement_list = { + ["end"] = true, + ["else"] = true, + ["elseif"] = true, + ["until"] = true, + } -local stop_return_list = { - [";"] = true, - ["$EOF$"] = true, -} + local stop_return_list = { + [";"] = true, + ["$EOF$"] = true, + } -for k, v in pairs(stop_statement_list) do - stop_return_list[k] = v -end + for k, v in pairs(stop_statement_list) do + stop_return_list[k] = v + end -local function parse_return(ps, i) - local node = new_node(ps.tokens, i, "return") - i = verify_tk(ps, i, "return") - node.exps = new_node(ps.tokens, i, "expression_list") - i = parse_list(ps, i, node.exps, stop_return_list, "sep", parse_expression) - if ps.tokens[i].kind == ";" then - i = i + 1 + local function parse_return(ps, i) + local node = new_node(ps, i, "return") + i = verify_tk(ps, i, "return") + node.exps = new_node(ps, i, "expression_list") + i = parse_list(ps, i, node.exps, stop_return_list, "sep", parse_expression) + if ps.tokens[i].kind == ";" then + i = i + 1 + end + return i, node end - return i, node -end -local function store_field_in_record(ps, i, field_name, t, fields, field_order) - if not fields[field_name] then - fields[field_name] = t - table.insert(field_order, field_name) - else - local prev_t = fields[field_name] - if t.typename == "function" and prev_t.typename == "function" then - local p = new_type(ps, i, "poly") - p.types = { prev_t, t } - fields[field_name] = p - elseif t.typename == "function" and prev_t.typename == "poly" then - table.insert(prev_t.types, t) + local function store_field_in_record(ps, i, field_name, t, fields, field_order) + if not fields[field_name] then + fields[field_name] = t + table.insert(field_order, field_name) else - fail(ps, i, "attempt to redeclare field '" .. field_name .. "' (only functions can be overloaded)") - return false + local prev_t = fields[field_name] + if t.typename == "function" and prev_t.typename == "function" then + local p = new_type(ps, i, "poly") + p.types = { prev_t, t } + fields[field_name] = p + elseif t.typename == "function" and prev_t.typename == "poly" then + table.insert(prev_t.types, t) + else + fail(ps, i, "attempt to redeclare field '" .. field_name .. "' (only functions can be overloaded)") + return false + end end + return true end - return true -end -local function parse_nested_type(ps, i, def, typename, parse_body) - i = i + 1 - local iv = i + local function parse_nested_type(ps, i, def, typename, parse_body) + i = i + 1 + local iv = i + + local v + i, v = verify_kind(ps, i, "identifier", "type_identifier") + if not v then + return fail(ps, i, "expected a variable name") + end - local v - i, v = verify_kind(ps, i, "identifier", "type_identifier") - if not v then - return fail(ps, i, "expected a variable name") - end + local nt = new_node(ps, i - 2, "newtype") + local ndef = new_type(ps, i, typename) + local itype = i + local iok = parse_body(ps, i, ndef, nt) + if iok then + i = iok + nt.newtype = new_typedecl(ps, itype, ndef) + end - local nt = new_node(ps.tokens, i - 2, "newtype") - local ndef = new_type(ps, i, typename) - local iok = parse_body(ps, i, ndef, nt) - if iok then - i = iok - nt.newtype = new_typedecl(ps, i, ndef) + store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) + return i end - store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) - return i -end - -parse_enum_body = function(ps, i, def, node) - local istart = i - 1 - def.enumset = {} - while ps.tokens[i].tk ~= "$EOF$" and ps.tokens[i].tk ~= "end" do - local item - i, item = verify_kind(ps, i, "string", "enum_item") - if item then - table.insert(node, item) - def.enumset[unquote(item.tk)] = true + parse_enum_body = function(ps, i, def, node) + local istart = i - 1 + def.enumset = {} + while ps.tokens[i].tk ~= "$EOF$" and ps.tokens[i].tk ~= "end" do + local item + i, item = verify_kind(ps, i, "string", "enum_item") + if item then + table.insert(node, item) + def.enumset[unquote(item.tk)] = true + end end + i = verify_end(ps, i, istart, node) + return i, node end - i = verify_end(ps, i, istart, node) - return i, node -end - -local metamethod_names = { - ["__add"] = true, - ["__sub"] = true, - ["__mul"] = true, - ["__div"] = true, - ["__mod"] = true, - ["__pow"] = true, - ["__unm"] = true, - ["__idiv"] = true, - ["__band"] = true, - ["__bor"] = true, - ["__bxor"] = true, - ["__bnot"] = true, - ["__shl"] = true, - ["__shr"] = true, - ["__concat"] = true, - ["__len"] = true, - ["__eq"] = true, - ["__lt"] = true, - ["__le"] = true, - ["__index"] = true, - ["__newindex"] = true, - ["__call"] = true, - ["__tostring"] = true, - ["__pairs"] = true, - ["__gc"] = true, - ["__close"] = true, - ["__is"] = true, -} - -local function parse_macroexp(ps, istart, iargs) + local metamethod_names = { + ["__add"] = true, + ["__sub"] = true, + ["__mul"] = true, + ["__div"] = true, + ["__mod"] = true, + ["__pow"] = true, + ["__unm"] = true, + ["__idiv"] = true, + ["__band"] = true, + ["__bor"] = true, + ["__bxor"] = true, + ["__bnot"] = true, + ["__shl"] = true, + ["__shr"] = true, + ["__concat"] = true, + ["__len"] = true, + ["__eq"] = true, + ["__lt"] = true, + ["__le"] = true, + ["__index"] = true, + ["__newindex"] = true, + ["__call"] = true, + ["__tostring"] = true, + ["__pairs"] = true, + ["__gc"] = true, + ["__close"] = true, + ["__is"] = true, + } + local function parse_macroexp(ps, istart, iargs) - local node = new_node(ps.tokens, istart, "macroexp") - local i - i, node.args, node.min_arity = parse_argument_list(ps, iargs) - i, node.rets = parse_return_types(ps, i) - i = verify_tk(ps, i, "return") - i, node.exp = parse_expression(ps, i) - end_at(node, ps.tokens[i]) - i = verify_end(ps, i, istart, node) - return i, node -end -local function parse_where_clause(ps, i) - local node = new_node(ps.tokens, i, "macroexp") - - local selftype = new_type(ps, i, "nominal") - selftype.names = { "@self" } - - node.args = new_node(ps.tokens, i, "argument_list") - node.args[1] = new_node(ps.tokens, i, "argument") - node.args[1].tk = "self" - node.args[1].argtype = selftype - node.min_arity = 1 - node.rets = new_tuple(ps, i) - node.rets.tuple[1] = BOOLEAN - i, node.exp = parse_expression(ps, i) - end_at(node, ps.tokens[i - 1]) - return i, node -end -parse_interface_name = function(ps, i) - local istart = i - local typ - i, typ = parse_simple_type_or_nominal(ps, i) - if not (typ.typename == "nominal") then - return fail(ps, istart, "expected an interface") + local node = new_node(ps, istart, "macroexp") + local i + i, node.args, node.min_arity = parse_argument_list(ps, iargs) + i, node.rets = parse_return_types(ps, i) + i = verify_tk(ps, i, "return") + i, node.exp = parse_expression(ps, i) + end_at(node, ps.tokens[i]) + i = verify_end(ps, i, istart, node) + return i, node end - return i, typ -end -local function parse_array_interface_type(ps, i, def) - if def.interface_list then - local first = def.interface_list[1] - if first.typename == "array" then - return failskip(ps, i, "duplicated declaration of array element type", parse_type) - end - end - local t - i, t = parse_base_type(ps, i) - if not t then - return i - end - if not (t.typename == "array") then - fail(ps, i, "expected an array declaration") - return i + local function parse_where_clause(ps, i) + local node = new_node(ps, i, "macroexp") + node.args = new_node(ps, i, "argument_list") + node.args[1] = new_node(ps, i, "argument") + node.args[1].tk = "self" + node.args[1].argtype = new_nominal(ps, i, "@self") + node.min_arity = 1 + node.rets = new_tuple(ps, i) + node.rets.tuple[1] = new_type(ps, i, "boolean") + i, node.exp = parse_expression(ps, i) + end_at(node, ps.tokens[i - 1]) + return i, node end - def.elements = t.elements - return i, t -end - -parse_record_body = function(ps, i, def, node) - local istart = i - 1 - def.fields = {} - def.field_order = {} - if ps.tokens[i].tk == "<" then - i, def.typeargs = parse_anglebracket_list(ps, i, parse_typearg) + parse_interface_name = function(ps, i) + local istart = i + local typ + i, typ = parse_simple_type_or_nominal(ps, i) + if not (typ.typename == "nominal") then + return fail(ps, istart, "expected an interface") + end + return i, typ end - if ps.tokens[i].tk == "{" then - local atype - i, atype = parse_array_interface_type(ps, i, def) - if atype then - def.interface_list = { atype } + local function parse_array_interface_type(ps, i, def) + if def.interface_list then + local first = def.interface_list[1] + if first.typename == "array" then + return failskip(ps, i, "duplicated declaration of array element type", parse_type) + end + end + local t + i, t = parse_base_type(ps, i) + if not t then + return i + end + if not (t.typename == "array") then + fail(ps, i, "expected an array declaration") + return i end + def.elements = t.elements + return i, t end - if ps.tokens[i].tk == "is" then - i = i + 1 + parse_record_body = function(ps, i, def, node) + local istart = i - 1 + def.fields = {} + def.field_order = {} + + if ps.tokens[i].tk == "<" then + i, def.typeargs = parse_anglebracket_list(ps, i, parse_typearg) + end if ps.tokens[i].tk == "{" then local atype i, atype = parse_array_interface_type(ps, i, def) - if ps.tokens[i].tk == "," then - i = i + 1 - i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) - else - def.interface_list = {} - end if atype then - table.insert(def.interface_list, 1, atype) + def.interface_list = { atype } end - else - i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) end - end - if ps.tokens[i].tk == "where" then - local wstart = i - i = i + 1 - local where_macroexp - i, where_macroexp = parse_where_clause(ps, i) - - local typ = new_type(ps, wstart, "function") - typ.is_method = true - typ.min_arity = 1 - typ.args = a_type("tuple", { tuple = { - a_type("nominal", { - y = typ.y, - x = typ.x, - filename = ps.filename, - names = { "@self" }, - }), - } }) - typ.rets = a_type("tuple", { tuple = { BOOLEAN } }) - typ.macroexp = where_macroexp - - def.meta_fields = {} - def.meta_field_order = {} - store_field_in_record(ps, i, "__is", typ, def.meta_fields, def.meta_field_order) - end - - while not (ps.tokens[i].kind == "$EOF$" or ps.tokens[i].tk == "end") do - local tn = ps.tokens[i].tk - if ps.tokens[i].tk == "userdata" and ps.tokens[i + 1].tk ~= ":" then - if def.is_userdata then - fail(ps, i, "duplicated 'userdata' declaration") + if ps.tokens[i].tk == "is" then + i = i + 1 + + if ps.tokens[i].tk == "{" then + local atype + i, atype = parse_array_interface_type(ps, i, def) + if ps.tokens[i].tk == "," then + i = i + 1 + i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) + else + def.interface_list = {} + end + if atype then + table.insert(def.interface_list, 1, atype) + end else - def.is_userdata = true + i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) end + end + + if ps.tokens[i].tk == "where" then + local wstart = i i = i + 1 - elseif ps.tokens[i].tk == "{" then - return fail(ps, i, "syntax error: this syntax is no longer valid; declare array interface at the top with 'is {...}'") - elseif ps.tokens[i].tk == "type" and ps.tokens[i + 1].tk ~= ":" then - i = i + 1 - local iv = i - local v - i, v = verify_kind(ps, i, "identifier", "type_identifier") - if not v then - return fail(ps, i, "expected a variable name") - end - i = verify_tk(ps, i, "=") - local nt - i, nt = parse_newtype(ps, i) - if not nt or not nt.newtype then - return fail(ps, i, "expected a type definition") - end + local where_macroexp + i, where_macroexp = parse_where_clause(ps, i) + + local typ = new_type(ps, wstart, "function") + typ.is_method = true + typ.min_arity = 1 + typ.args = new_tuple(ps, wstart, { + a_nominal(where_macroexp, { "@self" }), + }) + typ.rets = new_tuple(ps, wstart, { new_type(ps, wstart, "boolean") }) + typ.macroexp = where_macroexp - local ntt = nt.newtype - if ntt.typename == "typealias" then - ntt.is_nested_alias = true - end + def.meta_fields = {} + def.meta_field_order = {} + store_field_in_record(ps, i, "__is", typ, def.meta_fields, def.meta_field_order) + end - store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) - elseif parse_type_body_fns[tn] and ps.tokens[i + 1].tk ~= ":" then - i = parse_nested_type(ps, i, def, tn, parse_type_body_fns[tn]) - else - local is_metamethod = false - if ps.tokens[i].tk == "metamethod" and ps.tokens[i + 1].tk ~= ":" then - is_metamethod = true + while not (ps.tokens[i].kind == "$EOF$" or ps.tokens[i].tk == "end") do + local tn = ps.tokens[i].tk + if ps.tokens[i].tk == "userdata" and ps.tokens[i + 1].tk ~= ":" then + if def.is_userdata then + fail(ps, i, "duplicated 'userdata' declaration") + else + def.is_userdata = true + end i = i + 1 - end + elseif ps.tokens[i].tk == "{" then + return fail(ps, i, "syntax error: this syntax is no longer valid; declare array interface at the top with 'is {...}'") + elseif ps.tokens[i].tk == "type" and ps.tokens[i + 1].tk ~= ":" then + i = i + 1 + local iv = i + local v + i, v = verify_kind(ps, i, "identifier", "type_identifier") + if not v then + return fail(ps, i, "expected a variable name") + end + i = verify_tk(ps, i, "=") + local nt + i, nt = parse_newtype(ps, i) + if not nt or not nt.newtype then + return fail(ps, i, "expected a type definition") + end - local v - if ps.tokens[i].tk == "[" then - i, v = parse_literal(ps, i + 1) - if v and not v.conststr then - return fail(ps, i, "expected a string literal") + local ntt = nt.newtype + if ntt.typename == "typealias" then + ntt.is_nested_alias = true end - i = verify_tk(ps, i, "]") + + store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) + elseif parse_type_body_fns[tn] and ps.tokens[i + 1].tk ~= ":" then + i = parse_nested_type(ps, i, def, tn, parse_type_body_fns[tn]) else - i, v = verify_kind(ps, i, "identifier", "variable") - end - local iv = i - if not v then - return fail(ps, i, "expected a variable name") - end + local is_metamethod = false + if ps.tokens[i].tk == "metamethod" and ps.tokens[i + 1].tk ~= ":" then + is_metamethod = true + i = i + 1 + end - if ps.tokens[i].tk == ":" then - i = i + 1 - local t - i, t = parse_type(ps, i) - if not t then - return fail(ps, i, "expected a type") + local v + if ps.tokens[i].tk == "[" then + i, v = parse_literal(ps, i + 1) + if v and not v.conststr then + return fail(ps, i, "expected a string literal") + end + i = verify_tk(ps, i, "]") + else + i, v = verify_kind(ps, i, "identifier", "variable") + end + local iv = i + if not v then + return fail(ps, i, "expected a variable name") end - local field_name = v.conststr or v.tk - local fields = def.fields - local field_order = def.field_order - if is_metamethod then - if not def.meta_fields then - def.meta_fields = {} - def.meta_field_order = {} + if ps.tokens[i].tk == ":" then + i = i + 1 + local t + i, t = parse_type(ps, i) + if not t then + return fail(ps, i, "expected a type") end - fields = def.meta_fields - field_order = def.meta_field_order - if not metamethod_names[field_name] then - fail(ps, i - 1, "not a valid metamethod: " .. field_name) + + local field_name = v.conststr or v.tk + local fields = def.fields + local field_order = def.field_order + if is_metamethod then + if not def.meta_fields then + def.meta_fields = {} + def.meta_field_order = {} + end + fields = def.meta_fields + field_order = def.meta_field_order + if not metamethod_names[field_name] then + fail(ps, i - 1, "not a valid metamethod: " .. field_name) + end end - end - if ps.tokens[i].tk == "=" and ps.tokens[i + 1].tk == "macroexp" then - if not (t.typename == "function") then - fail(ps, i + 1, "macroexp must have a function type") - else - i, t.macroexp = parse_macroexp(ps, i + 1, i + 2) + if ps.tokens[i].tk == "=" and ps.tokens[i + 1].tk == "macroexp" then + if not (t.typename == "function") then + fail(ps, i + 1, "macroexp must have a function type") + else + i, t.macroexp = parse_macroexp(ps, i + 1, i + 2) + end end - end - store_field_in_record(ps, iv, field_name, t, fields, field_order) - elseif ps.tokens[i].tk == "=" then - local next_word = ps.tokens[i + 1].tk - if next_word == "record" or next_word == "enum" then - return fail(ps, i, "syntax error: this syntax is no longer valid; use '" .. next_word .. " " .. v.tk .. "'") - elseif next_word == "functiontype" then - return fail(ps, i, "syntax error: this syntax is no longer valid; use 'type " .. v.tk .. " = function('...") + store_field_in_record(ps, iv, field_name, t, fields, field_order) + elseif ps.tokens[i].tk == "=" then + local next_word = ps.tokens[i + 1].tk + if next_word == "record" or next_word == "enum" then + return fail(ps, i, "syntax error: this syntax is no longer valid; use '" .. next_word .. " " .. v.tk .. "'") + elseif next_word == "functiontype" then + return fail(ps, i, "syntax error: this syntax is no longer valid; use 'type " .. v.tk .. " = function('...") + else + return fail(ps, i, "syntax error: this syntax is no longer valid; use 'type " .. v.tk .. " = '...") + end else - return fail(ps, i, "syntax error: this syntax is no longer valid; use 'type " .. v.tk .. " = '...") + fail(ps, i, "syntax error: expected ':' for an attribute or '=' for a nested type") end - else - fail(ps, i, "syntax error: expected ':' for an attribute or '=' for a nested type") end end + i = verify_end(ps, i, istart, node) + return i, node end - i = verify_end(ps, i, istart, node) - return i, node -end - -parse_type_body_fns = { - ["interface"] = parse_record_body, - ["record"] = parse_record_body, - ["enum"] = parse_enum_body, -} -parse_newtype = function(ps, i) - local node = new_node(ps.tokens, i, "newtype") - local def - local tn = ps.tokens[i].tk - local itype = i - if parse_type_body_fns[tn] then - def = new_type(ps, i, tn) - i = i + 1 - i = parse_type_body_fns[tn](ps, i, def, node) - if not def then - return fail(ps, i, "expected a type") - end + parse_type_body_fns = { + ["interface"] = parse_record_body, + ["record"] = parse_record_body, + ["enum"] = parse_enum_body, + } - node.newtype = new_typedecl(ps, itype, def) - return i, node - else - i, def = parse_type(ps, i) - if not def then - return fail(ps, i, "expected a type") - end + parse_newtype = function(ps, i) + local node = new_node(ps, i, "newtype") + local def + local tn = ps.tokens[i].tk + local itype = i + if parse_type_body_fns[tn] then + def = new_type(ps, i, tn) + i = i + 1 + i = parse_type_body_fns[tn](ps, i, def, node) + if not def then + return fail(ps, i, "expected a type") + end - if def.typename == "nominal" then - local typealias = new_type(ps, itype, "typealias") - typealias.alias_to = def - node.newtype = typealias - else node.newtype = new_typedecl(ps, itype, def) - end - - return i, node - end -end + return i, node + else + i, def = parse_type(ps, i) + if not def then + return fail(ps, i, "expected a type") + end -local function parse_assignment_expression_list(ps, i, asgn) - asgn.exps = new_node(ps.tokens, i, "expression_list") - repeat - i = i + 1 - local val - i, val = parse_expression(ps, i) - if not val then - if #asgn.exps == 0 then - asgn.exps = nil + if def.typename == "nominal" then + node.newtype = new_typealias(ps, itype, def) + else + node.newtype = new_typedecl(ps, itype, def) end - return i - end - table.insert(asgn.exps, val) - until ps.tokens[i].tk ~= "," - return i, asgn -end -local parse_call_or_assignment -do - local function is_lvalue(node) - node.is_lvalue = node.kind == "variable" or - (node.kind == "op" and - (node.op.op == "@index" or node.op.op == ".")) - return node.is_lvalue + return i, node + end end - local function parse_variable(ps, i) - local node - i, node = parse_expression(ps, i) - if not (node and is_lvalue(node)) then - return fail(ps, i, "expected a variable") - end - return i, node + local function parse_assignment_expression_list(ps, i, asgn) + asgn.exps = new_node(ps, i, "expression_list") + repeat + i = i + 1 + local val + i, val = parse_expression(ps, i) + if not val then + if #asgn.exps == 0 then + asgn.exps = nil + end + return i + end + table.insert(asgn.exps, val) + until ps.tokens[i].tk ~= "," + return i, asgn end - parse_call_or_assignment = function(ps, i) - local exp - local istart = i - i, exp = parse_expression(ps, i) - if not exp then - return i + local parse_call_or_assignment + do + local function is_lvalue(node) + node.is_lvalue = node.kind == "variable" or + (node.kind == "op" and + (node.op.op == "@index" or node.op.op == ".")) + return node.is_lvalue end - if (exp.op and exp.op.op == "@funcall") or exp.failstore then - return i, exp + local function parse_variable(ps, i) + local node + i, node = parse_expression(ps, i) + if not (node and is_lvalue(node)) then + return fail(ps, i, "expected a variable") + end + return i, node end - if not is_lvalue(exp) then - return fail(ps, i, "syntax error") - end + parse_call_or_assignment = function(ps, i) + local exp + local istart = i + i, exp = parse_expression(ps, i) + if not exp then + return i + end - local asgn = new_node(ps.tokens, istart, "assignment") - asgn.vars = new_node(ps.tokens, istart, "variable_list") - asgn.vars[1] = exp - if ps.tokens[i].tk == "," then - i = i + 1 - i = parse_trying_list(ps, i, asgn.vars, parse_variable) - if #asgn.vars < 2 then - return fail(ps, i, "syntax error") + if (exp.op and exp.op.op == "@funcall") or exp.failstore then + return i, exp end - end - if ps.tokens[i].tk ~= "=" then - verify_tk(ps, i, "=") - return i - end + if not is_lvalue(exp) then + return fail(ps, i, "syntax error") + end - i, asgn = parse_assignment_expression_list(ps, i, asgn) - return i, asgn - end -end + local asgn = new_node(ps, istart, "assignment") + asgn.vars = new_node(ps, istart, "variable_list") + asgn.vars[1] = exp + if ps.tokens[i].tk == "," then + i = i + 1 + i = parse_trying_list(ps, i, asgn.vars, parse_variable) + if #asgn.vars < 2 then + return fail(ps, i, "syntax error") + end + end -local function parse_variable_declarations(ps, i, node_name) - local asgn = new_node(ps.tokens, i, node_name) + if ps.tokens[i].tk ~= "=" then + verify_tk(ps, i, "=") + return i + end - asgn.vars = new_node(ps.tokens, i, "variable_list") - i = parse_trying_list(ps, i, asgn.vars, parse_variable_name) - if #asgn.vars == 0 then - return fail(ps, i, "expected a local variable definition") + i, asgn = parse_assignment_expression_list(ps, i, asgn) + return i, asgn + end end - i, asgn.decltuple = parse_type_list(ps, i, "decltuple") + local function parse_variable_declarations(ps, i, node_name) + local asgn = new_node(ps, i, node_name) - if ps.tokens[i].tk == "=" then - - local next_word = ps.tokens[i + 1].tk - local tn = next_word - if parse_type_body_fns[tn] then - local scope = node_name == "local_declaration" and "local" or "global" - return failskip(ps, i + 1, "syntax error: this syntax is no longer valid; use '" .. scope .. " " .. next_word .. " " .. asgn.vars[1].tk .. "'", skip_type_body) - elseif next_word == "functiontype" then - local scope = node_name == "local_declaration" and "local" or "global" - return failskip(ps, i + 1, "syntax error: this syntax is no longer valid; use '" .. scope .. " type " .. asgn.vars[1].tk .. " = function('...", parse_function_type) + asgn.vars = new_node(ps, i, "variable_list") + i = parse_trying_list(ps, i, asgn.vars, parse_variable_name) + if #asgn.vars == 0 then + return fail(ps, i, "expected a local variable definition") end - i, asgn = parse_assignment_expression_list(ps, i, asgn) - end - return i, asgn -end + i, asgn.decltuple = parse_type_list(ps, i, "decltuple") -local function parse_type_declaration(ps, i, node_name) - i = i + 2 + if ps.tokens[i].tk == "=" then - local asgn = new_node(ps.tokens, i, node_name) - i, asgn.var = parse_variable_name(ps, i) - if not asgn.var then - return fail(ps, i, "expected a type name") - end + local next_word = ps.tokens[i + 1].tk + local tn = next_word + if parse_type_body_fns[tn] then + local scope = node_name == "local_declaration" and "local" or "global" + return failskip(ps, i + 1, "syntax error: this syntax is no longer valid; use '" .. scope .. " " .. next_word .. " " .. asgn.vars[1].tk .. "'", skip_type_body) + elseif next_word == "functiontype" then + local scope = node_name == "local_declaration" and "local" or "global" + return failskip(ps, i + 1, "syntax error: this syntax is no longer valid; use '" .. scope .. " type " .. asgn.vars[1].tk .. " = function('...", parse_function_type) + end - if node_name == "global_type" and ps.tokens[i].tk ~= "=" then + i, asgn = parse_assignment_expression_list(ps, i, asgn) + end return i, asgn end - i = verify_tk(ps, i, "=") + local function parse_type_declaration(ps, i, node_name) + i = i + 2 - if ps.tokens[i].kind == "identifier" and ps.tokens[i].tk == "require" then - local istart = i - i, asgn.value = parse_call_or_assignment(ps, i) - if asgn.value and not node_is_require_call(asgn.value) then - fail(ps, istart, "require() for type declarations must have a literal argument") + local asgn = new_node(ps, i, node_name) + i, asgn.var = parse_variable_name(ps, i) + if not asgn.var then + return fail(ps, i, "expected a type name") end - return i, asgn - end - i, asgn.value = parse_newtype(ps, i) - if not asgn.value then - return i - end + if node_name == "global_type" and ps.tokens[i].tk ~= "=" then + return i, asgn + end - local nt = asgn.value.newtype - if nt.typename == "typedecl" then - local def = nt.def - if def.fields or def.typename == "enum" then - if not def.declname then - def.declname = asgn.var.tk + i = verify_tk(ps, i, "=") + + if ps.tokens[i].kind == "identifier" and ps.tokens[i].tk == "require" then + local istart = i + i, asgn.value = parse_call_or_assignment(ps, i) + if asgn.value and not node_is_require_call(asgn.value) then + fail(ps, istart, "require() for type declarations must have a literal argument") end + return i, asgn end - end - - return i, asgn -end -local function parse_type_constructor(ps, i, node_name, type_name, parse_body) - local asgn = new_node(ps.tokens, i, node_name) - local nt = new_node(ps.tokens, i, "newtype") - asgn.value = nt - local itype = i - local def = new_type(ps, i, type_name) + i, asgn.value = parse_newtype(ps, i) + if not asgn.value then + return i + end - i = i + 2 + local nt = asgn.value.newtype + if nt.typename == "typedecl" then + local def = nt.def + if def.fields or def.typename == "enum" then + if not def.declname then + def.declname = asgn.var.tk + end + end + end - i, asgn.var = verify_kind(ps, i, "identifier") - if not asgn.var then - return fail(ps, i, "expected a type name") + return i, asgn end - assert(def.typename == "record" or def.typename == "interface" or def.typename == "enum") - def.declname = asgn.var.tk + local function parse_type_constructor(ps, i, node_name, type_name, parse_body) + local asgn = new_node(ps, i, node_name) + local nt = new_node(ps, i, "newtype") + asgn.value = nt + local itype = i + local def = new_type(ps, i, type_name) - i = parse_body(ps, i, def, nt) + i = i + 2 - nt.newtype = new_typedecl(ps, itype, def) + i, asgn.var = verify_kind(ps, i, "identifier") + if not asgn.var then + return fail(ps, i, "expected a type name") + end - return i, asgn -end + assert(def.typename == "record" or def.typename == "interface" or def.typename == "enum") + def.declname = asgn.var.tk -local function skip_type_declaration(ps, i) - return parse_type_declaration(ps, i - 1, "local_type") -end + i = parse_body(ps, i, def, nt) -local function parse_local_macroexp(ps, i) - local istart = i - i = i + 2 - local node = new_node(ps.tokens, i, "local_macroexp") - i, node.name = parse_identifier(ps, i) - i, node.macrodef = parse_macroexp(ps, istart, i) - end_at(node, ps.tokens[i - 1]) - return i, node -end + nt.newtype = new_typedecl(ps, itype, def) -local function parse_local(ps, i) - local ntk = ps.tokens[i + 1].tk - local tn = ntk - if ntk == "function" then - return parse_local_function(ps, i) - elseif ntk == "type" and ps.tokens[i + 2].kind == "identifier" then - return parse_type_declaration(ps, i, "local_type") - elseif ntk == "macroexp" and ps.tokens[i + 2].kind == "identifier" then - return parse_local_macroexp(ps, i) - elseif parse_type_body_fns[tn] and ps.tokens[i + 2].kind == "identifier" then - return parse_type_constructor(ps, i, "local_type", tn, parse_type_body_fns[tn]) - end - return parse_variable_declarations(ps, i + 1, "local_declaration") -end + return i, asgn + end -local function parse_global(ps, i) - local ntk = ps.tokens[i + 1].tk - local tn = ntk - if ntk == "function" then - return parse_function(ps, i + 1, "global") - elseif ntk == "type" and ps.tokens[i + 2].kind == "identifier" then - return parse_type_declaration(ps, i, "global_type") - elseif parse_type_body_fns[tn] and ps.tokens[i + 2].kind == "identifier" then - return parse_type_constructor(ps, i, "global_type", tn, parse_type_body_fns[tn]) - elseif ps.tokens[i + 1].kind == "identifier" then - return parse_variable_declarations(ps, i + 1, "global_declaration") - end - return parse_call_or_assignment(ps, i) -end + local function skip_type_declaration(ps, i) + return parse_type_declaration(ps, i - 1, "local_type") + end -local function parse_record_function(ps, i) - return parse_function(ps, i, "record") -end + local function parse_local_macroexp(ps, i) + local istart = i + i = i + 2 + local node = new_node(ps, i, "local_macroexp") + i, node.name = parse_identifier(ps, i) + i, node.macrodef = parse_macroexp(ps, istart, i) + end_at(node, ps.tokens[i - 1]) + return i, node + end -local parse_statement_fns = { - ["::"] = parse_label, - ["do"] = parse_do, - ["if"] = parse_if, - ["for"] = parse_for, - ["goto"] = parse_goto, - ["local"] = parse_local, - ["while"] = parse_while, - ["break"] = parse_break, - ["global"] = parse_global, - ["repeat"] = parse_repeat, - ["return"] = parse_return, - ["function"] = parse_record_function, -} + local function parse_local(ps, i) + local ntk = ps.tokens[i + 1].tk + local tn = ntk + if ntk == "function" then + return parse_local_function(ps, i) + elseif ntk == "type" and ps.tokens[i + 2].kind == "identifier" then + return parse_type_declaration(ps, i, "local_type") + elseif ntk == "macroexp" and ps.tokens[i + 2].kind == "identifier" then + return parse_local_macroexp(ps, i) + elseif parse_type_body_fns[tn] and ps.tokens[i + 2].kind == "identifier" then + return parse_type_constructor(ps, i, "local_type", tn, parse_type_body_fns[tn]) + end + return parse_variable_declarations(ps, i + 1, "local_declaration") + end + + local function parse_global(ps, i) + local ntk = ps.tokens[i + 1].tk + local tn = ntk + if ntk == "function" then + return parse_function(ps, i + 1, "global") + elseif ntk == "type" and ps.tokens[i + 2].kind == "identifier" then + return parse_type_declaration(ps, i, "global_type") + elseif parse_type_body_fns[tn] and ps.tokens[i + 2].kind == "identifier" then + return parse_type_constructor(ps, i, "global_type", tn, parse_type_body_fns[tn]) + elseif ps.tokens[i + 1].kind == "identifier" then + return parse_variable_declarations(ps, i + 1, "global_declaration") + end + return parse_call_or_assignment(ps, i) + end + + local function parse_record_function(ps, i) + return parse_function(ps, i, "record") + end + + local parse_statement_fns = { + ["::"] = parse_label, + ["do"] = parse_do, + ["if"] = parse_if, + ["for"] = parse_for, + ["goto"] = parse_goto, + ["local"] = parse_local, + ["while"] = parse_while, + ["break"] = parse_break, + ["global"] = parse_global, + ["repeat"] = parse_repeat, + ["return"] = parse_return, + ["function"] = parse_record_function, + } -local function type_needs_local_or_global(ps, i) - local tk = ps.tokens[i].tk - return failskip(ps, i, ("%s needs to be declared with 'local %s' or 'global %s'"):format(tk, tk, tk), skip_type_body) -end + local function type_needs_local_or_global(ps, i) + local tk = ps.tokens[i].tk + return failskip(ps, i, ("%s needs to be declared with 'local %s' or 'global %s'"):format(tk, tk, tk), skip_type_body) + end -local needs_local_or_global = { - ["type"] = function(ps, i) - return failskip(ps, i, "types need to be declared with 'local type' or 'global type'", skip_type_declaration) - end, - ["record"] = type_needs_local_or_global, - ["enum"] = type_needs_local_or_global, -} + local needs_local_or_global = { + ["type"] = function(ps, i) + return failskip(ps, i, "types need to be declared with 'local type' or 'global type'", skip_type_declaration) + end, + ["record"] = type_needs_local_or_global, + ["enum"] = type_needs_local_or_global, + } -parse_statements = function(ps, i, toplevel) - local node = new_node(ps.tokens, i, "statements") - local item - while true do - while ps.tokens[i].kind == ";" do - i = i + 1 - if item then - item.semicolon = true + parse_statements = function(ps, i, toplevel) + local node = new_node(ps, i, "statements") + local item + while true do + while ps.tokens[i].kind == ";" do + i = i + 1 + if item then + item.semicolon = true + end end - end - if ps.tokens[i].kind == "$EOF$" then - break - end - local tk = ps.tokens[i].tk - if (not toplevel) and stop_statement_list[tk] then - break - end + if ps.tokens[i].kind == "$EOF$" then + break + end + local tk = ps.tokens[i].tk + if (not toplevel) and stop_statement_list[tk] then + break + end - local fn = parse_statement_fns[tk] - if not fn then - local skip_fn = needs_local_or_global[tk] - if skip_fn and ps.tokens[i + 1].kind == "identifier" then - fn = skip_fn - else - fn = parse_call_or_assignment + local fn = parse_statement_fns[tk] + if not fn then + local skip_fn = needs_local_or_global[tk] + if skip_fn and ps.tokens[i + 1].kind == "identifier" then + fn = skip_fn + else + fn = parse_call_or_assignment + end end - end - i, item = fn(ps, i) + i, item = fn(ps, i) - if item then - table.insert(node, item) - elseif i > 1 then + if item then + table.insert(node, item) + elseif i > 1 then - local lasty = ps.tokens[i - 1].y - while ps.tokens[i].kind ~= "$EOF$" and ps.tokens[i].y == lasty do - i = i + 1 + local lasty = ps.tokens[i - 1].y + while ps.tokens[i].kind ~= "$EOF$" and ps.tokens[i].y == lasty do + i = i + 1 + end end end - end - - end_at(node, ps.tokens[i]) - return i, node -end -local function clear_redundant_errors(errors) - local redundant = {} - local lastx, lasty = 0, 0 - for i, err in ipairs(errors) do - err.i = i + end_at(node, ps.tokens[i]) + return i, node end - table.sort(errors, function(a, b) - local af = a.filename or "" - local bf = b.filename or "" - return af < bf or - (af == bf and (a.y < b.y or - (a.y == b.y and (a.x < b.x or - (a.x == b.x and (a.i < b.i)))))) - end) - for i, err in ipairs(errors) do - err.i = nil - if err.x == lastx and err.y == lasty then - table.insert(redundant, i) + + function tl.parse_program(tokens, errs, filename) + errs = errs or {} + local ps = { + tokens = tokens, + errs = errs, + filename = filename or "", + required_modules = {}, + } + local i = 1 + local hashbang + if ps.tokens[i].kind == "hashbang" then + hashbang = ps.tokens[i].tk + i = i + 1 + end + local _, node = parse_statements(ps, i, true) + if hashbang then + node.hashbang = hashbang end - lastx, lasty = err.x, err.y - end - for i = #redundant, 1, -1 do - table.remove(errors, redundant[i]) - end -end -function tl.parse_program(tokens, errs, filename) - errs = errs or {} - local ps = { - tokens = tokens, - errs = errs, - filename = filename or "", - required_modules = {}, - } - local i = 1 - local hashbang - if ps.tokens[i].kind == "hashbang" then - hashbang = ps.tokens[i].tk - i = i + 1 + clear_redundant_errors(errs) + return node, ps.required_modules end - local _, node = parse_statements(ps, i, true) - if hashbang then - node.hashbang = hashbang + + function tl.parse(input, filename) + local tokens, errs = tl.lex(input, filename) + local node, required_modules = tl.parse_program(tokens, errs, filename) + return node, errs, required_modules end - clear_redundant_errors(errs) - return node, ps.required_modules end -function tl.parse(input, filename) - local tokens, errs = tl.lex(input, filename) - local node, required_modules = tl.parse_program(tokens, errs, filename) - return node, errs, required_modules -end + @@ -4296,7 +4295,7 @@ local function tl_debug_indent_pop(mark, single, y, x, fmt, ...) end end -local function recurse_type(ast, visit) +local function recurse_type(s, ast, visit) local kind = ast.typename if TL_DEBUG then @@ -4308,7 +4307,7 @@ local function recurse_type(ast, visit) if cbkind then local cbkind_before = cbkind.before if cbkind_before then - cbkind_before(ast) + cbkind_before(s, ast) end end @@ -4316,90 +4315,90 @@ local function recurse_type(ast, visit) if ast.typename == "tuple" then for i, child in ipairs(ast.tuple) do - xs[i] = recurse_type(child, visit) + xs[i] = recurse_type(s, child, visit) end elseif ast.types then for _, child in ipairs(ast.types) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end elseif ast.typename == "map" then - table.insert(xs, recurse_type(ast.keys, visit)) - table.insert(xs, recurse_type(ast.values, visit)) + table.insert(xs, recurse_type(s, ast.keys, visit)) + table.insert(xs, recurse_type(s, ast.values, visit)) elseif ast.fields then if ast.typeargs then for _, child in ipairs(ast.typeargs) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.interface_list then for _, child in ipairs(ast.interface_list) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.elements then - table.insert(xs, recurse_type(ast.elements, visit)) + table.insert(xs, recurse_type(s, ast.elements, visit)) end if ast.fields then for _, child in fields_of(ast) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.meta_fields then for _, child in fields_of(ast, "meta") do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end elseif ast.typename == "function" then if ast.typeargs then for _, child in ipairs(ast.typeargs) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.args then for _, child in ipairs(ast.args.tuple) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.rets then for _, child in ipairs(ast.rets.tuple) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end elseif ast.typename == "nominal" then if ast.typevals then for _, child in ipairs(ast.typevals) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end elseif ast.typename == "typearg" then if ast.constraint then - table.insert(xs, recurse_type(ast.constraint, visit)) + table.insert(xs, recurse_type(s, ast.constraint, visit)) end elseif ast.typename == "array" then if ast.elements then - table.insert(xs, recurse_type(ast.elements, visit)) + table.insert(xs, recurse_type(s, ast.elements, visit)) end elseif ast.typename == "literal_table_item" then if ast.ktype then - table.insert(xs, recurse_type(ast.ktype, visit)) + table.insert(xs, recurse_type(s, ast.ktype, visit)) end if ast.vtype then - table.insert(xs, recurse_type(ast.vtype, visit)) + table.insert(xs, recurse_type(s, ast.vtype, visit)) end elseif ast.typename == "typealias" then - table.insert(xs, recurse_type(ast.alias_to, visit)) + table.insert(xs, recurse_type(s, ast.alias_to, visit)) elseif ast.typename == "typedecl" then - table.insert(xs, recurse_type(ast.def, visit)) + table.insert(xs, recurse_type(s, ast.def, visit)) end local ret local cbkind_after = cbkind and cbkind.after if cbkind_after then - ret = cbkind_after(ast, xs) + ret = cbkind_after(s, ast, xs) end local visit_after = visit.after if visit_after then - ret = visit_after(ast, xs, ret) + ret = visit_after(s, ast, xs, ret) end if TL_DEBUG then @@ -4409,15 +4408,16 @@ local function recurse_type(ast, visit) return ret end -local function recurse_typeargs(ast, visit_type) +local function recurse_typeargs(s, ast, visit_type) if ast.typeargs then for _, typearg in ipairs(ast.typeargs) do - recurse_type(typearg, visit_type) + recurse_type(s, typearg, visit_type) end end end local function extra_callback(name, + s, ast, xs, visit_node) @@ -4427,7 +4427,7 @@ local function extra_callback(name, if not nbs then return end local bs = nbs[name] if not bs then return end - bs(ast, xs) + bs(s, ast, xs) end local no_recurse_node = { @@ -4447,7 +4447,7 @@ local no_recurse_node = { ["type_identifier"] = true, } -local function recurse_node(root, +local function recurse_node(s, root, visit_node, visit_type) if not root then @@ -4466,9 +4466,9 @@ local function recurse_node(root, local function walk_vars_exps(ast, xs) xs[1] = recurse(ast.vars) if ast.decltuple then - xs[2] = recurse_type(ast.decltuple, visit_type) + xs[2] = recurse_type(s, ast.decltuple, visit_type) end - extra_callback("before_exp", ast, xs, visit_node) + extra_callback("before_exp", s, ast, xs, visit_node) if ast.exps then xs[3] = recurse(ast.exps) end @@ -4480,11 +4480,11 @@ local function recurse_node(root, end local function walk_named_function(ast, xs) - recurse_typeargs(ast, visit_type) + recurse_typeargs(s, ast, visit_type) xs[1] = recurse(ast.name) xs[2] = recurse(ast.args) - xs[3] = recurse_type(ast.rets, visit_type) - extra_callback("before_statements", ast, xs, visit_node) + xs[3] = recurse_type(s, ast.rets, visit_type) + extra_callback("before_statements", s, ast, xs, visit_node) xs[4] = recurse(ast.body) end @@ -4497,9 +4497,9 @@ local function recurse_node(root, end xs[2] = p1 if ast.op.arity == 2 then - extra_callback("before_e2", ast, xs, visit_node) + extra_callback("before_e2", s, ast, xs, visit_node) if ast.op.op == "is" or ast.op.op == "as" then - xs[3] = recurse_type(ast.e2.casttype, visit_type) + xs[3] = recurse_type(s, ast.e2.casttype, visit_type) else xs[3] = recurse(ast.e2) end @@ -4517,7 +4517,7 @@ local function recurse_node(root, xs[1] = recurse(ast.key) xs[2] = recurse(ast.value) if ast.itemtype then - xs[3] = recurse_type(ast.itemtype, visit_type) + xs[3] = recurse_type(s, ast.itemtype, visit_type) end end, @@ -4543,13 +4543,13 @@ local function recurse_node(root, if ast.exp then xs[1] = recurse(ast.exp) end - extra_callback("before_statements", ast, xs, visit_node) + extra_callback("before_statements", s, ast, xs, visit_node) xs[2] = recurse(ast.body) end, ["while"] = function(ast, xs) xs[1] = recurse(ast.exp) - extra_callback("before_statements", ast, xs, visit_node) + extra_callback("before_statements", s, ast, xs, visit_node) xs[2] = recurse(ast.body) end, @@ -4559,45 +4559,45 @@ local function recurse_node(root, end, ["macroexp"] = function(ast, xs) - recurse_typeargs(ast, visit_type) + recurse_typeargs(s, ast, visit_type) xs[1] = recurse(ast.args) - xs[2] = recurse_type(ast.rets, visit_type) - extra_callback("before_exp", ast, xs, visit_node) + xs[2] = recurse_type(s, ast.rets, visit_type) + extra_callback("before_exp", s, ast, xs, visit_node) xs[3] = recurse(ast.exp) end, ["function"] = function(ast, xs) - recurse_typeargs(ast, visit_type) + recurse_typeargs(s, ast, visit_type) xs[1] = recurse(ast.args) - xs[2] = recurse_type(ast.rets, visit_type) - extra_callback("before_statements", ast, xs, visit_node) + xs[2] = recurse_type(s, ast.rets, visit_type) + extra_callback("before_statements", s, ast, xs, visit_node) xs[3] = recurse(ast.body) end, ["local_function"] = walk_named_function, ["global_function"] = walk_named_function, ["record_function"] = function(ast, xs) - recurse_typeargs(ast, visit_type) + recurse_typeargs(s, ast, visit_type) xs[1] = recurse(ast.fn_owner) xs[2] = recurse(ast.name) - extra_callback("before_arguments", ast, xs, visit_node) + extra_callback("before_arguments", s, ast, xs, visit_node) xs[3] = recurse(ast.args) - xs[4] = recurse_type(ast.rets, visit_type) - extra_callback("before_statements", ast, xs, visit_node) + xs[4] = recurse_type(s, ast.rets, visit_type) + extra_callback("before_statements", s, ast, xs, visit_node) xs[5] = recurse(ast.body) end, ["local_macroexp"] = function(ast, xs) xs[1] = recurse(ast.name) xs[2] = recurse(ast.macrodef.args) - xs[3] = recurse_type(ast.macrodef.rets, visit_type) - extra_callback("before_exp", ast, xs, visit_node) + xs[3] = recurse_type(s, ast.macrodef.rets, visit_type) + extra_callback("before_exp", s, ast, xs, visit_node) xs[4] = recurse(ast.macrodef.exp) end, ["forin"] = function(ast, xs) xs[1] = recurse(ast.vars) xs[2] = recurse(ast.exps) - extra_callback("before_statements", ast, xs, visit_node) + extra_callback("before_statements", s, ast, xs, visit_node) xs[3] = recurse(ast.body) end, @@ -4606,7 +4606,7 @@ local function recurse_node(root, xs[2] = recurse(ast.from) xs[3] = recurse(ast.to) xs[4] = ast.step and recurse(ast.step) - extra_callback("before_statements", ast, xs, visit_node) + extra_callback("before_statements", s, ast, xs, visit_node) xs[5] = recurse(ast.body) end, @@ -4623,12 +4623,12 @@ local function recurse_node(root, end, ["newtype"] = function(ast, xs) - xs[1] = recurse_type(ast.newtype, visit_type) + xs[1] = recurse_type(s, ast.newtype, visit_type) end, ["argument"] = function(ast, xs) if ast.argtype then - xs[1] = recurse_type(ast.argtype, visit_type) + xs[1] = recurse_type(s, ast.argtype, visit_type) end end, } @@ -4647,7 +4647,7 @@ local function recurse_node(root, local cbkind = cbs and cbs[kind] if cbkind then if cbkind.before then - cbkind.before(ast) + cbkind.before(s, ast) end end @@ -4671,10 +4671,10 @@ local function recurse_node(root, local ret local cbkind_after = cbkind and cbkind.after if cbkind_after then - ret = cbkind_after(ast, xs) + ret = cbkind_after(s, ast, xs) end if visit_after then - ret = visit_after(ast, xs, ret) + ret = visit_after(s, ast, xs, ret) end if TL_DEBUG then @@ -4778,7 +4778,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) local save_indent = {} - local function increment_indent(node) + local function increment_indent(_, node) local child = node.body or node[1] if not child then return @@ -4881,7 +4881,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) visit_node.cbs = { ["statements"] = { - after = function(node, children) + after = function(_, node, children) local out if opts.preserve_hashbang and node.hashbang then out = { y = 1, h = 0 } @@ -4903,7 +4903,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["local_declaration"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "local ") for i, var in ipairs(node.vars) do @@ -4929,7 +4929,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["local_type"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } if not node.var.elide_type then table.insert(out, "local") @@ -4941,7 +4941,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["global_type"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } if children[2] then add_child(out, children[1]) @@ -4952,7 +4952,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["global_declaration"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } if children[3] then add_child(out, children[1]) @@ -4963,7 +4963,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["assignment"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } add_child(out, children[1]) table.insert(out, " =") @@ -4972,7 +4972,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["if"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } for i, child in ipairs(children) do add_child(out, child, i > 1 and " ", child.y ~= node.y and indent) @@ -4983,7 +4983,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["if_block"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } if node.if_block_n == 1 then table.insert(out, "if") @@ -5003,7 +5003,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["while"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "while") add_child(out, children[1], " ") @@ -5016,7 +5016,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["repeat"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "repeat") add_child(out, children[1], " ") @@ -5028,7 +5028,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["do"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "do") add_child(out, children[1], " ") @@ -5039,7 +5039,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["forin"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "for") add_child(out, children[1], " ") @@ -5054,7 +5054,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["fornum"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "for") add_child(out, children[1], " ") @@ -5074,7 +5074,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["return"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "return") if #children[1] > 0 then @@ -5084,14 +5084,14 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["break"] = { - after = function(node, _children) + after = function(_, node, _children) local out = { y = node.y, h = 0 } table.insert(out, "break") return out end, }, ["variable_list"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } local space for i, child in ipairs(children) do @@ -5106,7 +5106,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["literal_table"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } if #children == 0 then table.insert(out, "{}") @@ -5126,7 +5126,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["literal_table_item"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } if node.key_parsed ~= "implicit" then if node.key_parsed == "short" then @@ -5149,13 +5149,13 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["local_macroexp"] = { before = increment_indent, - after = function(node, _children) + after = function(_, node, _children) return { y = node.y, h = 0 } end, }, ["local_function"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "local function") add_child(out, children[1], " ") @@ -5170,7 +5170,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["global_function"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "function") add_child(out, children[1], " ") @@ -5185,7 +5185,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["record_function"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "function") add_child(out, children[1], " ") @@ -5210,7 +5210,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["function"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "function(") add_child(out, children[1]) @@ -5224,7 +5224,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) ["cast"] = {}, ["paren"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "(") add_child(out, children[1], "", indent) @@ -5233,7 +5233,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["op"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } if node.op.op == "@funcall" then add_child(out, children[1], "", indent) @@ -5294,14 +5294,14 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["variable"] = { - after = function(node, _children) + after = function(_, node, _children) local out = { y = node.y, h = 0 } add_string(out, node.tk) return out end, }, ["newtype"] = { - after = function(node, _children) + after = function(_, node, _children) local out = { y = node.y, h = 0 } local nt = node.newtype if nt.typename == "typealias" then @@ -5318,7 +5318,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["goto"] = { - after = function(node, _children) + after = function(_, node, _children) local out = { y = node.y, h = 0 } table.insert(out, "goto ") table.insert(out, node.label) @@ -5326,7 +5326,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["label"] = { - after = function(node, _children) + after = function(_, node, _children) local out = { y = node.y, h = 0 } table.insert(out, "::") table.insert(out, node.label) @@ -5339,7 +5339,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) local visit_type = {} visit_type.cbs = {} local default_type_visitor = { - after = function(typ, _children) + after = function(_, typ, _children) local out = { y = typ.y or -1, h = 0 } local r = typ.typename == "nominal" and typ.resolved or typ local lua_type = primitive[r.typename] or "table" @@ -5377,7 +5377,6 @@ function tl.pretty_print_ast(ast, gen_target, mode) visit_type.cbs["any"] = default_type_visitor visit_type.cbs["unknown"] = default_type_visitor visit_type.cbs["invalid"] = default_type_visitor - visit_type.cbs["unresolved"] = default_type_visitor visit_type.cbs["none"] = default_type_visitor visit_node.cbs["expression_list"] = visit_node.cbs["variable_list"] @@ -5392,7 +5391,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) visit_node.cbs["argument"] = visit_node.cbs["variable"] visit_node.cbs["type_identifier"] = visit_node.cbs["variable"] - local out = recurse_node(ast, visit_node, visit_type) + local out = recurse_node(nil, ast, visit_node, visit_type) if err then return nil, err end @@ -5442,7 +5441,6 @@ local typename_to_typecode = { ["none"] = tl.typecodes.UNKNOWN, ["tuple"] = tl.typecodes.UNKNOWN, ["literal_table_item"] = tl.typecodes.UNKNOWN, - ["unresolved"] = tl.typecodes.UNKNOWN, ["typedecl"] = tl.typecodes.UNKNOWN, ["typealias"] = tl.typecodes.UNKNOWN, ["*"] = tl.typecodes.UNKNOWN, @@ -5450,8 +5448,8 @@ local typename_to_typecode = { local skip_types = { ["none"] = true, + ["tuple"] = true, ["literal_table_item"] = true, - ["unresolved"] = true, } local function sorted_keys(m) @@ -5474,6 +5472,7 @@ function tl.new_type_reporter() local self = { next_num = 1, typeid_to_num = {}, + typename_to_num = {}, tr = { by_pos = {}, types = {}, @@ -5481,6 +5480,24 @@ function tl.new_type_reporter() globals = {}, }, } + + local names = {} + for name, _ in pairs(simple_types) do + table.insert(names, name) + end + table.sort(names) + + for _, name in ipairs(names) do + local ti = { + t = assert(typename_to_typecode[name]), + str = name, + } + local n = self.next_num + self.typename_to_num[name] = n + self.tr.types[n] = ti + self.next_num = self.next_num + 1 + end + return setmetatable(self, { __index = TypeReporter }) end @@ -5500,9 +5517,15 @@ function TypeReporter:store_function(ti, rt) end function TypeReporter:get_typenum(t) + + local n = self.typename_to_num[t.typename] + if n then + return n + end + assert(t.typeid) - local n = self.typeid_to_num[t.typeid] + n = self.typeid_to_num[t.typeid] if n then return n end @@ -5526,7 +5549,7 @@ function TypeReporter:get_typenum(t) local ti = { t = assert(typename_to_typecode[rt.typename]), str = show_type(t, true), - file = t.filename, + file = t.f, y = t.y, x = t.x, } @@ -5596,7 +5619,7 @@ end function TypeReporter:get_collector(filename) - local tc = { + local collector = { filename = filename, symbol_list = {}, } @@ -5604,10 +5627,10 @@ function TypeReporter:get_collector(filename) local ft = {} self.tr.by_pos[filename] = ft - local symbol_list = tc.symbol_list + local symbol_list = collector.symbol_list local symbol_list_n = 0 - tc.store_type = function(y, x, typ) + collector.store_type = function(y, x, typ) if not typ or skip_types[typ.typename] then return end @@ -5621,12 +5644,12 @@ function TypeReporter:get_collector(filename) yt[x] = self:get_typenum(typ) end - tc.reserve_symbol_list_slot = function(node) + collector.reserve_symbol_list_slot = function(node) symbol_list_n = symbol_list_n + 1 node.symbol_list_slot = symbol_list_n end - tc.add_to_symbol_list = function(node, name, t) + collector.add_to_symbol_list = function(node, name, t) if not node then return end @@ -5640,12 +5663,12 @@ function TypeReporter:get_collector(filename) symbol_list[slot] = { y = node.y, x = node.x, name = name, typ = t } end - tc.begin_symbol_list_scope = function(node) + collector.begin_symbol_list_scope = function(node) symbol_list_n = symbol_list_n + 1 symbol_list[symbol_list_n] = { y = node.y, x = node.x, name = "@{" } end - tc.end_symbol_list_scope = function(node) + collector.end_symbol_list_scope = function(node) if symbol_list[symbol_list_n].name == "@{" then symbol_list[symbol_list_n] = nil symbol_list_n = symbol_list_n - 1 @@ -5655,14 +5678,14 @@ function TypeReporter:get_collector(filename) end end - return tc + return collector end -function TypeReporter:store_result(tc, globals) +function TypeReporter:store_result(collector, globals) local tr = self.tr - local filename = tc.filename - local symbol_list = tc.symbol_list + local filename = collector.filename + local symbol_list = collector.symbol_list tr.by_pos[filename][0] = nil @@ -5731,143 +5754,445 @@ function TypeReporter:get_report() end -function tl.get_types(result) - return result.env.reporter:get_report(), result.env.reporter + + + + +function tl.symbols_in_scope(tr, y, x) + local function find(symbols, at_y, at_x) + local function le(a, b) + return a[1] < b[1] or + (a[1] == b[1] and a[2] <= b[2]) + end + return binary_search(symbols, { at_y, at_x }, le) or 0 + end + + local ret = {} + + local n = find(tr.symbols, y, x) + + local symbols = tr.symbols + while n >= 1 do + local s = symbols[n] + if s[3] == "@{" then + n = n - 1 + elseif s[3] == "@}" then + n = s[4] + else + ret[s[3]] = s[4] + n = n - 1 + end + end + + return ret +end + + + + + +function Errors.new(filename) + local self = { + errors = {}, + warnings = {}, + unknown_dots = {}, + filename = filename, + } + return setmetatable(self, { __index = Errors }) +end + +local function Err(msg, t1, t2, t3) + if t1 then + local s1, s2, s3 + if t1.typename == "invalid" then + return nil + end + s1 = show_type(t1) + if t2 then + if t2.typename == "invalid" then + return nil + end + s2 = show_type(t2) + end + if t3 then + if t3.typename == "invalid" then + return nil + end + s3 = show_type(t3) + end + msg = msg:format(s1, s2, s3) + return { + msg = msg, + x = t1.x, + y = t1.y, + filename = t1.f, + } + end + + return { + msg = msg, + } +end + +local function insert_error(self, y, x, err) + err.y = assert(y) + err.x = assert(x) + err.filename = self.filename + + if TL_DEBUG then + io.stderr:write("ERROR:" .. err.y .. ":" .. err.x .. ": " .. err.msg .. "\n") + end + + table.insert(self.errors, err) +end + +function Errors:add(w, msg, ...) + local e = Err(msg, ...) + if e then + insert_error(self, w.y, w.x, e) + end +end + +local context_name = { + ["local_declaration"] = "in local declaration", + ["global_declaration"] = "in global declaration", + ["assignment"] = "in assignment", + ["literal_table_item"] = "in table item", +} + +function Errors:get_context(ctx, name) + if not ctx then + return "" + end + local ec = (ctx.kind ~= nil) and ctx.expected_context + local cn = (type(ctx) == "string") and ctx or + (ctx.kind ~= nil) and context_name[ec and ec.kind or ctx.kind] + return (cn and cn .. ": " or "") .. (ec and ec.name and ec.name .. ": " or "") .. (name and name .. ": " or "") +end + +function Errors:add_in_context(w, ctx, msg, ...) + local prefix = self:get_context(ctx) + msg = prefix .. msg + + local e = Err(msg, ...) + if e then + insert_error(self, w.y, w.x, e) + end +end + + +function Errors:collect(errs) + for _, e in ipairs(errs) do + insert_error(self, e.y, e.x, e) + end +end + +function Errors:add_warning(tag, w, fmt, ...) + assert(w.y) + table.insert(self.warnings, { + y = w.y, + x = w.x, + msg = fmt:format(...), + filename = self.filename, + tag = tag, + }) +end + +function Errors:invalid_at(w, msg, ...) + self:add(w, msg, ...) + return a_type(w, "invalid", {}) +end + +function Errors:add_unknown(node, name) + self:add_warning("unknown", node, "unknown variable: %s", name) +end + +function Errors:redeclaration_warning(node, old_var) + if node.tk:sub(1, 1) == "_" then return end + + local var_kind = "variable" + local var_name = node.tk + if node.kind == "local_function" or node.kind == "record_function" then + var_kind = "function" + var_name = node.name.tk + end + + local short_error = "redeclaration of " .. var_kind .. " '%s'" + if old_var and old_var.declared_at then + self:add_warning("redeclaration", node, short_error .. " (originally declared at %d:%d)", var_name, old_var.declared_at.y, old_var.declared_at.x) + else + self:add_warning("redeclaration", node, short_error, var_name) + end +end + +function Errors:unused_warning(name, var) + local prefix = name:sub(1, 1) + if var.declared_at and + var.is_narrowed ~= "narrow" and + prefix ~= "_" and + prefix ~= "@" then + + local t = var.t + self:add_warning( + "unused", + var.declared_at, + "unused %s %s: %s", + var.is_func_arg and "argument" or + t.typename == "function" and "function" or + t.typename == "typedecl" and "type" or + t.typename == "typealias" and "type" or + "variable", + name, + show_type(var.t)) + + end +end + +function Errors:add_prefixing(w, src, prefix, dst) + if not src then + return + end + + for _, err in ipairs(src) do + err.msg = prefix .. err.msg + if w and ( + (err.filename ~= w.f) or + (not err.y) or + (w.y > err.y or (w.y == err.y and w.x > err.x))) then + + err.y = w.y + err.x = w.x + err.filename = w.f + end + + if dst then + table.insert(dst, err) + else + insert_error(self, err.y, err.x, err) + end + end +end + + + + + + + + +local function check_for_unused_vars(scope, is_global) + local vars = scope.vars + if not next(vars) then + return + end + local list + for name, var in pairs(vars) do + local t = var.t + if var.declared_at and not var.used then + if var.used_as_type then + var.declared_at.elide_type = true + else + if (t.typename == "typedecl" or t.typename == "typealias") and not is_global then + var.declared_at.elide_type = true + end + list = list or {} + table.insert(list, { y = var.declared_at.y, x = var.declared_at.x, name = name, var = var }) + end + elseif var.used and (t.typename == "typedecl" or t.typename == "typealias") and var.aliasing then + var.aliasing.used = true + var.aliasing.declared_at.elide_type = false + end + end + if list then + table.sort(list, function(a, b) + return a.y < b.y or (a.y == b.y and a.x < b.x) + end) + end + return list +end + +function Errors:warn_unused_vars(scope, is_global) + local unused = check_for_unused_vars(scope, is_global) + if unused then + for _, u in ipairs(unused) do + self:unused_warning(u.name, u.var) + end + end + + if scope.labels then + for name, node in pairs(scope.labels) do + if not node.used_label then + self:add_warning("unused", node, "unused label ::%s::", name) + end + end + end end +function Errors:add_unknown_dot(node, name) + if not self.unknown_dots[name] then + self.unknown_dots[name] = true + self:add_unknown(node, name) + end +end +function Errors:fail_unresolved_labels(scope) + if scope.pending_labels then + for name, nodes in pairs(scope.pending_labels) do + for _, node in ipairs(nodes) do + self:add(node, "no visible label '" .. name .. "' for goto") + end + end + end +end +function Errors:fail_unresolved_nominals(scope, global_scope) + if global_scope and scope.pending_nominals then + for name, types in pairs(scope.pending_nominals) do + if not global_scope.pending_global_types[name] then + for _, typ in ipairs(types) do + assert(typ.x) + assert(typ.y) + self:add(typ, "unknown type %s", typ) + end + end + end + end +end -local NONE = a_type("none", {}) -local INVALID = a_type("invalid", {}) -local UNKNOWN = a_type("unknown", {}) -local CIRCULAR_REQUIRE = a_type("circular_require", {}) -local FUNCTION = a_fn({ args = va_args({ ANY }), rets = va_args({ ANY }) }) +function Errors:check_redeclared_key(w, ctx, seen_keys, key) + if key ~= nil then + local s = seen_keys[key] + if s then + self:add_in_context(w, ctx, "redeclared key " .. tostring(key) .. " (previously declared at " .. self.filename .. ":" .. s.y .. ":" .. s.x .. ")") + else + seen_keys[key] = w + end + end +end -local XPCALL_MSGH_FUNCTION = a_fn({ args = { ANY }, rets = {} }) local numeric_binop = { ["number"] = { - ["number"] = NUMBER, - ["integer"] = NUMBER, + ["number"] = "number", + ["integer"] = "number", }, ["integer"] = { - ["integer"] = INTEGER, - ["number"] = NUMBER, + ["integer"] = "integer", + ["number"] = "number", }, } local float_binop = { ["number"] = { - ["number"] = NUMBER, - ["integer"] = NUMBER, + ["number"] = "number", + ["integer"] = "number", }, ["integer"] = { - ["integer"] = NUMBER, - ["number"] = NUMBER, + ["integer"] = "number", + ["number"] = "number", }, } local integer_binop = { ["number"] = { - ["number"] = INTEGER, - ["integer"] = INTEGER, + ["number"] = "integer", + ["integer"] = "integer", }, ["integer"] = { - ["integer"] = INTEGER, - ["number"] = INTEGER, + ["integer"] = "integer", + ["number"] = "integer", }, } local relational_binop = { ["number"] = { - ["integer"] = BOOLEAN, - ["number"] = BOOLEAN, + ["integer"] = "boolean", + ["number"] = "boolean", }, ["integer"] = { - ["number"] = BOOLEAN, - ["integer"] = BOOLEAN, + ["number"] = "boolean", + ["integer"] = "boolean", }, ["string"] = { - ["string"] = BOOLEAN, + ["string"] = "boolean", }, ["boolean"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, } local equality_binop = { ["number"] = { - ["number"] = BOOLEAN, - ["integer"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["number"] = "boolean", + ["integer"] = "boolean", + ["nil"] = "boolean", }, ["integer"] = { - ["number"] = BOOLEAN, - ["integer"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["number"] = "boolean", + ["integer"] = "boolean", + ["nil"] = "boolean", }, ["string"] = { - ["string"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["string"] = "boolean", + ["nil"] = "boolean", }, ["boolean"] = { - ["boolean"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["boolean"] = "boolean", + ["nil"] = "boolean", }, ["record"] = { - ["emptytable"] = BOOLEAN, - ["record"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["emptytable"] = "boolean", + ["record"] = "boolean", + ["nil"] = "boolean", }, ["array"] = { - ["emptytable"] = BOOLEAN, - ["array"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["emptytable"] = "boolean", + ["array"] = "boolean", + ["nil"] = "boolean", }, ["map"] = { - ["emptytable"] = BOOLEAN, - ["map"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["emptytable"] = "boolean", + ["map"] = "boolean", + ["nil"] = "boolean", }, ["thread"] = { - ["thread"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["thread"] = "boolean", + ["nil"] = "boolean", }, } local unop_types = { ["#"] = { - ["string"] = INTEGER, - ["array"] = INTEGER, - ["tupletable"] = INTEGER, - ["map"] = INTEGER, - ["emptytable"] = INTEGER, + ["string"] = "integer", + ["array"] = "integer", + ["tupletable"] = "integer", + ["map"] = "integer", + ["emptytable"] = "integer", }, ["-"] = { - ["number"] = NUMBER, - ["integer"] = INTEGER, + ["number"] = "number", + ["integer"] = "integer", }, ["~"] = { - ["number"] = INTEGER, - ["integer"] = INTEGER, + ["number"] = "integer", + ["integer"] = "integer", }, ["not"] = { - ["string"] = BOOLEAN, - ["number"] = BOOLEAN, - ["integer"] = BOOLEAN, - ["boolean"] = BOOLEAN, - ["record"] = BOOLEAN, - ["array"] = BOOLEAN, - ["tupletable"] = BOOLEAN, - ["map"] = BOOLEAN, - ["emptytable"] = BOOLEAN, - ["thread"] = BOOLEAN, + ["string"] = "boolean", + ["number"] = "boolean", + ["integer"] = "boolean", + ["boolean"] = "boolean", + ["record"] = "boolean", + ["array"] = "boolean", + ["tupletable"] = "boolean", + ["map"] = "boolean", + ["emptytable"] = "boolean", + ["thread"] = "boolean", }, } @@ -5898,67 +6223,66 @@ local binop_types = { [">"] = relational_binop, ["or"] = { ["boolean"] = { - ["boolean"] = BOOLEAN, - ["function"] = FUNCTION, + ["boolean"] = "boolean", }, ["number"] = { - ["integer"] = NUMBER, - ["number"] = NUMBER, - ["boolean"] = BOOLEAN, + ["integer"] = "number", + ["number"] = "number", + ["boolean"] = "boolean", }, ["integer"] = { - ["integer"] = INTEGER, - ["number"] = NUMBER, - ["boolean"] = BOOLEAN, + ["integer"] = "integer", + ["number"] = "number", + ["boolean"] = "boolean", }, ["string"] = { - ["string"] = STRING, - ["boolean"] = BOOLEAN, - ["enum"] = STRING, + ["string"] = "string", + ["boolean"] = "boolean", + ["enum"] = "string", }, ["function"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, ["array"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, ["record"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, ["map"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, ["enum"] = { - ["string"] = STRING, + ["string"] = "string", }, ["thread"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, }, [".."] = { ["string"] = { - ["string"] = STRING, - ["enum"] = STRING, - ["number"] = STRING, - ["integer"] = STRING, + ["string"] = "string", + ["enum"] = "string", + ["number"] = "string", + ["integer"] = "string", }, ["number"] = { - ["integer"] = STRING, - ["number"] = STRING, - ["string"] = STRING, - ["enum"] = STRING, + ["integer"] = "string", + ["number"] = "string", + ["string"] = "string", + ["enum"] = "string", }, ["integer"] = { - ["integer"] = STRING, - ["number"] = STRING, - ["string"] = STRING, - ["enum"] = STRING, + ["integer"] = "string", + ["number"] = "string", + ["string"] = "string", + ["enum"] = "string", }, ["enum"] = { - ["number"] = STRING, - ["integer"] = STRING, - ["string"] = STRING, - ["enum"] = STRING, + ["number"] = "string", + ["integer"] = "string", + ["string"] = "string", + ["enum"] = "string", }, }, } @@ -6166,8 +6490,8 @@ local function show_type_base(t, short, seen) end end -local function inferred_msg(t) - return " (inferred at " .. t.inferred_at.filename .. ":" .. t.inferred_at.y .. ":" .. t.inferred_at.x .. ")" +local function inferred_msg(t, prefix) + return " (" .. (prefix or "") .. "inferred at " .. t.inferred_at.f .. ":" .. t.inferred_at.y .. ":" .. t.inferred_at.x .. ")" end show_type = function(t, short, seen) @@ -6219,28 +6543,29 @@ function tl.search_module(module_name, search_dtl) return nil, nil, tried end -local function require_module(module_name, lax, env) +local function require_module(w, module_name, feat_lax, env) local mod = env.modules[module_name] if mod then - return mod, true + return mod, env.module_filenames[module_name] end local found, fd = tl.search_module(module_name, true) - if found and (lax or found:match("tl$")) then + if found and (feat_lax or found:match("tl$")) then - env.modules[module_name] = a_type("typedecl", { def = CIRCULAR_REQUIRE }) + env.module_filenames[module_name] = found + env.modules[module_name] = a_type(w, "typedecl", { def = a_type(w, "circular_require", {}) }) local found_result, err = tl.process(found, env, fd) assert(found_result, err) env.modules[module_name] = found_result.type - return found_result.type, true + return found_result.type, found elseif fd then fd:close() end - return INVALID, found ~= nil + return a_type(w, "invalid", {}), found end local compat_code_cache = {} @@ -6262,7 +6587,7 @@ local function add_compat_entries(program, used_set, gen_compat) local code = compat_code_cache[name] if not code then code = tl.parse(text, "@internal") - tl.type_check(code, { filename = "", lax = false, gen_compat = "off" }) + tl.type_check(code, "@internal", { feat_lax = "off", gen_compat = "off" }) compat_code_cache[name] = code end for _, c in ipairs(code) do @@ -6301,32 +6626,26 @@ local function add_compat_entries(program, used_set, gen_compat) TL_DEBUG = tl_debug end -local function get_stdlib_compat(lax) - if lax then - return { - ["utf8"] = true, - } - else - return { - ["io"] = true, - ["math"] = true, - ["string"] = true, - ["table"] = true, - ["utf8"] = true, - ["coroutine"] = true, - ["os"] = true, - ["package"] = true, - ["debug"] = true, - ["load"] = true, - ["loadfile"] = true, - ["assert"] = true, - ["pairs"] = true, - ["ipairs"] = true, - ["pcall"] = true, - ["xpcall"] = true, - ["rawlen"] = true, - } - end +local function get_stdlib_compat() + return { + ["io"] = true, + ["math"] = true, + ["string"] = true, + ["table"] = true, + ["utf8"] = true, + ["coroutine"] = true, + ["os"] = true, + ["package"] = true, + ["debug"] = true, + ["load"] = true, + ["loadfile"] = true, + ["assert"] = true, + ["pairs"] = true, + ["ipairs"] = true, + ["pcall"] = true, + ["xpcall"] = true, + ["rawlen"] = true, + } end local bit_operators = { @@ -6337,14 +6656,21 @@ local bit_operators = { ["<<"] = "lshift", } +local function node_at(w, n) + n.f = assert(w.f) + n.x = w.x + n.y = w.y + return n +end + local function convert_node_to_compat_call(node, mod_name, fn_name, e1, e2) node.op.op = "@funcall" node.op.arity = 2 node.op.prec = 100 - node.e1 = { y = node.y, x = node.x, kind = "op", op = an_operator(node, 2, ".") } - node.e1.e1 = { y = node.y, x = node.x, kind = "identifier", tk = mod_name } - node.e1.e2 = { y = node.y, x = node.x, kind = "identifier", tk = fn_name } - node.e2 = { y = node.y, x = node.x, kind = "expression_list" } + node.e1 = node_at(node, { kind = "op", op = an_operator(node, 2, ".") }) + node.e1.e1 = node_at(node, { kind = "identifier", tk = mod_name }) + node.e1.e2 = node_at(node, { kind = "identifier", tk = fn_name }) + node.e2 = node_at(node, { kind = "expression_list" }) node.e2[1] = e1 node.e2[2] = e2 end @@ -6353,10 +6679,10 @@ local function convert_node_to_compat_mt_call(node, mt_name, which_self, e1, e2) node.op.op = "@funcall" node.op.arity = 2 node.op.prec = 100 - node.e1 = { y = node.y, x = node.x, kind = "identifier", tk = "_tl_mt" } - node.e2 = { y = node.y, x = node.x, kind = "expression_list" } - node.e2[1] = { y = node.y, x = node.x, kind = "string", tk = "\"" .. mt_name .. "\"" } - node.e2[2] = { y = node.y, x = node.x, kind = "integer", tk = tostring(which_self) } + node.e1 = node_at(node, { kind = "identifier", tk = "_tl_mt" }) + node.e2 = node_at(node, { kind = "expression_list" }) + node.e2[1] = node_at(node, { kind = "string", tk = "\"" .. mt_name .. "\"" }) + node.e2[2] = node_at(node, { kind = "integer", tk = tostring(which_self) }) node.e2[3] = e1 node.e2[4] = e2 end @@ -6365,25 +6691,6 @@ local stdlib_globals = nil local globals_typeid = new_typeid() local fresh_typevar_ctr = 1 -local function set_feat(feat, default) - if feat then - return (feat == "on") - else - return default - end -end - -tl.new_env = function(opts) - local env, err = tl.init_env(opts.lax_mode, opts.gen_compat, opts.gen_target, opts.predefined_modules) - if not env then - return nil, err - end - - env.feat_arity = set_feat(opts.feat_arity, true) - - return env -end - local function assert_no_stdlib_errors(errors, name) if #errors ~= 0 then local out = {} @@ -6394,46 +6701,31 @@ local function assert_no_stdlib_errors(errors, name) end end -tl.init_env = function(lax, gen_compat, gen_target, predefined) - if gen_compat == true or gen_compat == nil then - gen_compat = "optional" - elseif gen_compat == false then - gen_compat = "off" - end - gen_compat = gen_compat - - if not gen_target then - if _VERSION == "Lua 5.1" or _VERSION == "Lua 5.2" then - gen_target = "5.1" - else - gen_target = "5.3" - end - end - - if gen_target == "5.4" and gen_compat ~= "off" then - return nil, "gen-compat must be explicitly 'off' when gen-target is '5.4'" - end +tl.new_env = function(opts) + opts = opts or {} local env = { modules = {}, + module_filenames = {}, loaded = {}, loaded_order = {}, globals = {}, - gen_compat = gen_compat, - gen_target = gen_target, + defaults = opts.defaults or {}, } + if env.defaults.gen_target == "5.4" and env.defaults.gen_compat ~= "off" then + return nil, "gen-compat must be explicitly 'off' when gen-target is '5.4'" + end + + local w = { f = "@stdlib", x = 1, y = 1 } + if not stdlib_globals then local tl_debug = TL_DEBUG TL_DEBUG = nil local program, syntax_errors = tl.parse(stdlib, "stdlib.d.tl") assert_no_stdlib_errors(syntax_errors, "syntax errors") - - local result = tl.type_check(program, { - filename = "@stdlib", - env = env, - }) + local result = tl.type_check(program, "@stdlib", {}, env) assert_no_stdlib_errors(result.type_errors, "type errors") stdlib_globals = env.globals @@ -6442,21 +6734,20 @@ tl.init_env = function(lax, gen_compat, gen_target, predefined) local math_t = (stdlib_globals["math"].t).def local table_t = (stdlib_globals["table"].t).def - local integer_compat = a_type("integer", { needs_compat = true }) - math_t.fields["maxinteger"] = integer_compat - math_t.fields["mininteger"] = integer_compat + math_t.fields["maxinteger"].needs_compat = true + math_t.fields["mininteger"].needs_compat = true table_t.fields["unpack"].needs_compat = true - stdlib_globals["..."] = { t = a_vararg({ STRING }) } - stdlib_globals["@is_va"] = { t = ANY } + stdlib_globals["..."] = { t = a_vararg(w, { a_type(w, "string", {}) }) } + stdlib_globals["@is_va"] = { t = a_type(w, "any", {}) } env.globals = {} end - local stdlib_compat = get_stdlib_compat(lax) + local stdlib_compat = get_stdlib_compat() for name, var in pairs(stdlib_globals) do env.globals[name] = var var.needs_compat = stdlib_compat[name] @@ -6467,53 +6758,40 @@ tl.init_env = function(lax, gen_compat, gen_target, predefined) end end - if predefined then - for _, name in ipairs(predefined) do - local module_type = require_module(name, lax, env) + if opts.predefined_modules then + for _, name in ipairs(opts.predefined_modules) do + local module_type = require_module(w, name, env.defaults.feat_lax == "on", env) - if module_type == INVALID then + if module_type.typename == "invalid" then return nil, string.format("Error: could not predefine module '%s'", name) end end end - env.feat_arity = true - return env end -tl.type_check = function(ast, opts) - opts = opts or {} - local env = opts.env - if not env then - local err - env, err = tl.init_env(opts.lax, opts.gen_compat, opts.gen_target) - if err then - return nil, err - end - end +do + + + + local TypeChecker = {} + + + + + + + + - local lax = opts.lax - local feat_arity = env.feat_arity - local filename = opts.filename - local st = { env.globals } - local all_needs_compat = {} - local dependencies = {} - local warnings = {} - local errors = {} - local module_type - local tc - if env.report_types then - env.reporter = env.reporter or tl.new_type_reporter() - tc = env.reporter:get_collector(filename or "?") - end @@ -6522,10 +6800,24 @@ tl.type_check = function(ast, opts) - local function find_var(name, use) - for i = #st, 1, -1 do - local scope = st[i] - local var = scope[name] + + + + + + + + + + + + + + + function TypeChecker:find_var(name, use) + for i = #self.st, 1, -1 do + local scope = self.st[i] + local var = scope.vars[name] if var then if use == "lvalue" and var.is_narrowed then if var.narrowed_from then @@ -6534,7 +6826,7 @@ tl.type_check = function(ast, opts) end else if i == 1 and var.needs_compat then - all_needs_compat[name] = true + self.all_needs_compat[name] = true end if use == "use_type" then var.used_as_type = true @@ -6547,10 +6839,10 @@ tl.type_check = function(ast, opts) end end - local function simulate_g() + function TypeChecker:simulate_g() local globals = {} - for k, v in pairs(st[1]) do + for k, v in pairs(self.st[1].vars) do if k:sub(1, 1) ~= "@" then globals[k] = v.t end @@ -6564,100 +6856,60 @@ tl.type_check = function(ast, opts) end - local resolve_typevars + local typevar_resolver - local function fresh_typevar(t) - return a_type("typevar", { + local function fresh_typevar(_, t) + return a_type(t, "typevar", { typevar = (t.typevar:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, constraint = t.constraint, }) end - local function fresh_typearg(t) - return a_type("typearg", { + local function fresh_typearg(_, t) + return a_type(t, "typearg", { typearg = (t.typearg:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, constraint = t.constraint, }) end - local function ensure_fresh_typeargs(t) + function TypeChecker:ensure_fresh_typeargs(t) if not t.typeargs then return t end fresh_typevar_ctr = fresh_typevar_ctr + 1 local ok - ok, t = resolve_typevars(t, fresh_typevar, fresh_typearg) + ok, t = typevar_resolver(nil, t, fresh_typevar, fresh_typearg) assert(ok, "Internal Compiler Error: error creating fresh type variables") return t end - local function find_var_type(name, use) - local var = find_var(name, use) + function TypeChecker:find_var_type(name, use) + local var = self:find_var(name, use) if var then local t = var.t if t.typename == "unresolved_typearg" then return nil, nil, t.constraint end - t = ensure_fresh_typeargs(t) + t = self:ensure_fresh_typeargs(t) return t, var.attribute end end - local function Err(where, msg, ...) - local n = select("#", ...) - if n > 0 then - local showt = {} - for i = 1, n do - local t = select(i, ...) - if t then - if t.typename == "invalid" then - return nil - end - showt[i] = show_type(t) - end - end - msg = msg:format(_tl_table_unpack(showt)) - end - local name = where.filename or filename - - if TL_DEBUG then - io.stderr:write("ERROR:" .. (where.y or -1) .. ":" .. (where.x or -1) .. ": " .. msg .. "\n") - end - - return { - y = where.y, - x = where.x, - msg = msg, - filename = name, - } - end - - local function error_at(w, msg, ...) - assert(w.y) - - local e = Err(w, msg, ...) - if e then - table.insert(errors, e) - return true - else - return false - end - end - - local function ensure_not_abstract(where, t) + local function ensure_not_abstract(t) if t.typename == "function" and t.macroexp then - error_at(where, "macroexps are abstract; consider using a concrete function") + return nil, "macroexps are abstract; consider using a concrete function" elseif t.typename == "typedecl" then local def = t.def if def.typename == "interface" then - error_at(where, "interfaces are abstract; consider using a concrete record") + return nil, "interfaces are abstract; consider using a concrete record" end end + return true end - local function find_type(names, accept_typearg) - local typ = find_var_type(names[1], "use_type") + function TypeChecker:find_type(names, accept_typearg) + local typ = self:find_var_type(names[1], "use_type") if not typ then return nil end @@ -6679,7 +6931,7 @@ tl.type_check = function(ast, opts) return nil end - typ = ensure_fresh_typeargs(typ) + typ = self:ensure_fresh_typeargs(typ) if typ.typename == "nominal" and typ.found then typ = typ.found end @@ -6691,19 +6943,19 @@ tl.type_check = function(ast, opts) end end - local function union_type(t) + local function type_for_union(t) if t.typename == "typedecl" then - return union_type(t.def), t.def + return type_for_union(t.def), t.def elseif t.typename == "typealias" then - return union_type(t.alias_to), t.alias_to + return type_for_union(t.alias_to), t.alias_to elseif t.typename == "tuple" then - return union_type(t.tuple[1]), t.tuple[1] + return type_for_union(t.tuple[1]), t.tuple[1] elseif t.typename == "nominal" then local typedecl = t.found if not typedecl then return "invalid" end - return union_type(typedecl) + return type_for_union(typedecl) elseif t.fields then if t.is_userdata then return "userdata", t @@ -6727,7 +6979,7 @@ tl.type_check = function(ast, opts) local n_string_enum = 0 local has_primitive_string_type = false for _, t in ipairs(typ.types) do - local ut, rt = union_type(t) + local ut, rt = type_for_union(t) if ut == "userdata" then assert(rt.fields) if rt.meta_fields and rt.meta_fields["__is"] then @@ -6808,24 +7060,11 @@ tl.type_check = function(ast, opts) ["unknown"] = true, } - local function default_resolve_typevars_callback(t) - local rt = find_var_type(t.typevar) - if not rt then - return nil - elseif rt.typename == "string" then - - return STRING - end - return rt - end - - resolve_typevars = function(typ, fn_var, fn_arg) + typevar_resolver = function(self, typ, fn_var, fn_arg) local errs local seen = {} local resolved = {} - fn_var = fn_var or default_resolve_typevars_callback - local function resolve(t, all_same) local same = true @@ -6840,7 +7079,7 @@ tl.type_check = function(ast, opts) local orig_t = t if t.typename == "typevar" then - local rt = fn_var(t) + local rt = fn_var(self, t) if rt then resolved[t.typevar] = true if no_nested_types[rt.typename] or (rt.typename == "nominal" and not rt.typevals) then @@ -6856,7 +7095,7 @@ tl.type_check = function(ast, opts) seen[orig_t] = copy copy.typename = t.typename - copy.filename = t.filename + copy.f = t.f copy.x = t.x copy.y = t.y @@ -6867,7 +7106,7 @@ tl.type_check = function(ast, opts) elseif t.typename == "typearg" then if fn_arg then - copy = fn_arg(t) + copy = fn_arg(self, t) else assert(copy.typename == "typearg") copy.typearg = t.typearg @@ -6960,7 +7199,7 @@ tl.type_check = function(ast, opts) local _, err = is_valid_union(copy) if err then errs = errs or {} - table.insert(errs, Err(t, err, copy)) + table.insert(errs, Err(err, copy)) end elseif t.typename == "poly" then assert(copy.typename == "poly") @@ -6970,6 +7209,7 @@ tl.type_check = function(ast, opts) end elseif t.typename == "tupletable" then assert(copy.typename == "tupletable") + copy.inferred_at = t.inferred_at copy.types = {} for i, tf in ipairs(t.types) do copy.types[i], same = resolve(tf, same) @@ -6989,7 +7229,7 @@ tl.type_check = function(ast, opts) local copy, same = resolve(typ, true) if errs then - return false, INVALID, errs + return false, a_type(typ, "invalid", {}), errs end if (not same) and @@ -7008,144 +7248,72 @@ tl.type_check = function(ast, opts) return true, copy end - local function infer_emptytable(emptytable, fresh_t) - local is_global = (emptytable.declared_at and emptytable.declared_at.kind == "global_declaration") - local nst = is_global and 1 or #st - for i = nst, 1, -1 do - local scope = st[i] - if scope[emptytable.assigned_to] then - scope[emptytable.assigned_to] = { t = fresh_t } - end - end - end + local function resolve_typevar(tc, t) + local rt = tc:find_var_type(t.typevar) + if not rt then + return nil + elseif rt.typename == "string" then - local function resolve_tuple(t) - if t.typename == "tuple" then - t = t.tuple[1] - end - if t == nil then - return NIL + return a_type(rt, "string", {}) end - return t - end - - local function add_warning(tag, where, fmt, ...) - table.insert(warnings, { - y = where.y, - x = where.x, - msg = fmt:format(...), - filename = where.filename or filename, - tag = tag, - }) - end - - local function invalid_at(where, msg, ...) - error_at(where, msg, ...) - return INVALID - end - - local function add_unknown(node, name) - add_warning("unknown", node, "unknown variable: %s", name) + return rt end - local function redeclaration_warning(node, old_var) - if node.tk:sub(1, 1) == "_" then return end - local var_kind = "variable" - local var_name = node.tk - if node.kind == "local_function" or node.kind == "record_function" then - var_kind = "function" - var_name = node.name.tk - end - local short_error = "redeclaration of " .. var_kind .. " '%s'" - if old_var and old_var.declared_at then - add_warning("redeclaration", node, short_error .. " (originally declared at %d:%d)", var_name, old_var.declared_at.y, old_var.declared_at.x) - else - add_warning("redeclaration", node, short_error, var_name) + function TypeChecker:infer_emptytable(emptytable, fresh_t) + local is_global = (emptytable.declared_at and emptytable.declared_at.kind == "global_declaration") + local nst = is_global and 1 or #self.st + for i = nst, 1, -1 do + local scope = self.st[i] + if scope.vars[emptytable.assigned_to] then + scope.vars[emptytable.assigned_to] = { t = fresh_t } + end end end - local function check_if_redeclaration(new_name, at) - local old = find_var(new_name, "check_only") - if old then - redeclaration_warning(at, old) + local function resolve_tuple(t) + local rt = t + if rt.typename == "tuple" then + rt = rt.tuple[1] end - end - - local function unused_warning(name, var) - local prefix = name:sub(1, 1) - if var.declared_at and - var.is_narrowed ~= "narrow" and - prefix ~= "_" and - prefix ~= "@" then - - if name:sub(1, 2) == "::" then - add_warning("unused", var.declared_at, "unused label %s", name) - else - local t = var.t - add_warning( - "unused", - var.declared_at, - "unused %s %s: %s", - var.is_func_arg and "argument" or - t.typename == "function" and "function" or - t.typename == "typedecl" and "type" or - t.typename == "typealias" and "type" or - "variable", - name, - show_type(var.t)) - - end + if rt == nil then + return a_type(t, "nil", {}) end + return rt end - local function add_errs_prefixing(where, src, dst, prefix) - assert(where == nil or where.y ~= nil) - - if not src then - return - end - for _, err in ipairs(src) do - err.msg = prefix .. err.msg - - if where and ( - (err.filename ~= filename) or - (not err.y) or - (where.y > err.y or (where.y == err.y and where.x > err.x))) then - - err.y = where.y - err.x = where.x - err.filename = filename - end - table.insert(dst, err) + function TypeChecker:check_if_redeclaration(new_name, at) + local old = self:find_var(new_name, "check_only") + if old then + self.errs:redeclaration_warning(at, old) end end + local function type_at(w, t) t.x = w.x t.y = w.y - t.filename = filename return t end - local function resolve_typevars_at(where, t) - assert(where) - local ok, ret, errs = resolve_typevars(t) + function TypeChecker:resolve_typevars_at(w, t) + assert(w) + local ok, ret, errs = typevar_resolver(self, t, resolve_typevar) if not ok then - assert(where.y) - add_errs_prefixing(where, errs, errors, "") + assert(w.y) + self.errs:add_prefixing(w, errs, "") end if ret == t or t.typename == "typevar" then ret = shallow_copy_table(ret) end - return type_at(where, ret) + return type_at(w, ret) end - local function infer_at(where, t) - local ret = resolve_typevars_at(where, t) + function TypeChecker:infer_at(w, t) + local ret = self:resolve_typevars_at(w, t) if ret.typename == "invalid" then ret = t end @@ -7153,8 +7321,8 @@ tl.type_check = function(ast, opts) if ret == t or t.typename == "typevar" then ret = shallow_copy_table(ret) end - ret.inferred_at = where - ret.inferred_at.filename = filename + assert(w.f) + ret.inferred_at = w return ret end @@ -7167,12 +7335,9 @@ tl.type_check = function(ast, opts) return t end - local get_unresolved - local find_unresolved - - local function add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration) - local scope = st[#st] - local var = scope[name] + function TypeChecker:add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration) + local scope = self.st[#self.st] + local var = scope.vars[name] if narrow then if var then if var.is_narrowed then @@ -7185,11 +7350,11 @@ tl.type_check = function(ast, opts) var.t = t else var = { t = t, attribute = attribute, is_narrowed = narrow, declared_at = node } - scope[name] = var + scope.vars[name] = var end - local unresolved = get_unresolved(scope) - unresolved.narrows[name] = true + scope.narrows = scope.narrows or {} + scope.narrows[name] = true return var end @@ -7200,37 +7365,33 @@ tl.type_check = function(ast, opts) name ~= "..." and name:sub(1, 1) ~= "@" then - check_if_redeclaration(name, node) + self:check_if_redeclaration(name, node) end if var and not var.used then - unused_warning(name, var) + self.errs:unused_warning(name, var) end var = { t = t, attribute = attribute, is_narrowed = nil, declared_at = node } - scope[name] = var + scope.vars[name] = var return var end - local function add_var(node, name, t, attribute, narrow, dont_check_redeclaration) - if lax and node and is_unknown(t) and (name ~= "self" and name ~= "...") and not narrow then - add_unknown(node, name) + function TypeChecker:add_var(node, name, t, attribute, narrow, dont_check_redeclaration) + if self.feat_lax and node and is_unknown(t) and (name ~= "self" and name ~= "...") and not narrow then + self.errs:add_unknown(node, name) end if not attribute then t = drop_constant_value(t) end - local var = add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration) - - if t.typename == "unresolved" or t.typename == "none" then - return var - end + local var = self:add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration) - if tc and node then - tc.add_to_symbol_list(node, name, t) + if self.collector and node then + self.collector.add_to_symbol_list(node, name, t) end return var @@ -7238,8 +7399,6 @@ tl.type_check = function(ast, opts) - local same_type - local is_a @@ -7253,39 +7412,38 @@ tl.type_check = function(ast, opts) - - local function arg_check(where, all_errs, a, b, v, mode, n) + function TypeChecker:arg_check(w, all_errs, a, b, v, mode, n) local ok, errs if v == "covariant" then - ok, errs = is_a(a, b) + ok, errs = self:is_a(a, b) elseif v == "contravariant" then - ok, errs = is_a(b, a) + ok, errs = self:is_a(b, a) elseif v == "bivariant" then - ok, errs = is_a(a, b) + ok, errs = self:is_a(a, b) if ok then return true end - ok = is_a(b, a) + ok = self:is_a(b, a) if ok then return true end elseif v == "invariant" then - ok, errs = same_type(a, b) + ok, errs = self:same_type(a, b) end if not ok then - add_errs_prefixing(where, errs, all_errs, mode .. (n and " " .. n or "") .. ": ") + self.errs:add_prefixing(w, errs, mode .. (n and " " .. n or "") .. ": ", all_errs) return false end return true end - local function has_all_types_of(t1s, t2s) + function TypeChecker:has_all_types_of(t1s, t2s) for _, t1 in ipairs(t1s) do local found = false for _, t2 in ipairs(t2s) do - if same_type(t2, t1) then + if self:same_type(t2, t1) then found = true break end @@ -7317,8 +7475,8 @@ tl.type_check = function(ast, opts) end end - local function close_types(vars) - for _, var in pairs(vars) do + local function close_types(scope) + for _, var in pairs(scope.vars) do local t = var.t if t.typename == "typedecl" then t.closed = true @@ -7330,161 +7488,96 @@ tl.type_check = function(ast, opts) end end + function TypeChecker:begin_scope(node) + table.insert(self.st, { vars = {} }) - - - - - - - local function check_for_unused_vars(vars, is_global) - if not next(vars) then - return - end - local list = {} - for name, var in pairs(vars) do - local t = var.t - if var.declared_at and not var.used then - if var.used_as_type then - var.declared_at.elide_type = true - else - if (t.typename == "typedecl" or t.typename == "typealias") and not is_global then - var.declared_at.elide_type = true - end - table.insert(list, { y = var.declared_at.y, x = var.declared_at.x, name = name, var = var }) - end - elseif var.used and (t.typename == "typedecl" or t.typename == "typealias") and var.aliasing then - var.aliasing.used = true - var.aliasing.declared_at.elide_type = false - end - end - if list[1] then - table.sort(list, function(a, b) - return a.y < b.y or (a.y == b.y and a.x < b.x) - end) - for _, u in ipairs(list) do - unused_warning(u.name, u.var) - end - end - end - - get_unresolved = function(scope) - local unresolved - if scope then - local unr = scope["@unresolved"] - unresolved = unr and unr.t - else - unresolved = find_var_type("@unresolved") - end - if not unresolved then - unresolved = a_type("unresolved", { - labels = {}, - nominals = {}, - global_types = {}, - narrows = {}, - }) - add_var(nil, "@unresolved", unresolved) - end - return unresolved - end - - find_unresolved = function(level) - local u = st[level or #st]["@unresolved"] - if u then - return u.t - end - end - - local function begin_scope(node) - table.insert(st, {}) - - if tc and node then - tc.begin_symbol_list_scope(node) + if self.collector and node then + self.collector.begin_symbol_list_scope(node) end end - local function end_scope(node) + function TypeChecker:end_scope(node) + local st = self.st local scope = st[#st] - local unresolved = scope["@unresolved"] - if unresolved then - local unrt = unresolved.t - local next_scope = st[#st - 1] - local upper = next_scope["@unresolved"] - if upper then - local uppert = upper.t - for name, nodes in pairs(unrt.labels) do + local next_scope = st[#st - 1] + + if next_scope then + if scope.pending_labels then + next_scope.pending_labels = next_scope.pending_labels or {} + for name, nodes in pairs(scope.pending_labels) do for _, n in ipairs(nodes) do - uppert.labels[name] = uppert.labels[name] or {} - table.insert(uppert.labels[name], n) + next_scope.pending_labels[name] = next_scope.pending_labels[name] or {} + table.insert(next_scope.pending_labels[name], n) end end - for name, types in pairs(unrt.nominals) do + scope.pending_labels = nil + end + if scope.pending_nominals then + next_scope.pending_nominals = next_scope.pending_nominals or {} + for name, types in pairs(scope.pending_nominals) do for _, typ in ipairs(types) do - uppert.nominals[name] = uppert.nominals[name] or {} - table.insert(uppert.nominals[name], typ) + next_scope.pending_nominals[name] = next_scope.pending_nominals[name] or {} + table.insert(next_scope.pending_nominals[name], typ) end end - for name, _ in pairs(unrt.global_types) do - uppert.global_types[name] = true - end - else - next_scope["@unresolved"] = unresolved - unrt.narrows = {} + scope.pending_nominals = nil end end + close_types(scope) - check_for_unused_vars(scope) + self.errs:warn_unused_vars(scope) + table.remove(st) - if tc and node then - tc.end_symbol_list_scope(node) + if self.collector and node then + self.collector.end_symbol_list_scope(node) end end - local end_scope_and_none_type = function(node, _children) - end_scope(node) + + local NONE = a_type({ f = "@none", x = -1, y = -1 }, "none", {}) + + local function end_scope_and_none_type(self, node, _children) + self:end_scope(node) return NONE end - local resolve_nominal - local resolve_typealias do - local function match_typevals(t, def) + local function match_typevals(self, t, def) if t.typevals and def.typeargs then if #t.typevals ~= #def.typeargs then - error_at(t, "mismatch in number of type arguments") + self.errs:add(t, "mismatch in number of type arguments") return nil end - begin_scope() + self:begin_scope() for i, tt in ipairs(t.typevals) do - add_var(nil, def.typeargs[i].typearg, tt) + self:add_var(nil, def.typeargs[i].typearg, tt) end - local ret = resolve_typevars_at(t, def) - end_scope() + local ret = self:resolve_typevars_at(t, def) + self:end_scope() return ret elseif t.typevals then - error_at(t, "spurious type arguments") + self.errs:add(t, "spurious type arguments") return nil elseif def.typeargs then - error_at(t, "missing type arguments in %s", def) + self.errs:add(t, "missing type arguments in %s", def) return nil else return def end end - local function find_nominal_type_decl(t) + local function find_nominal_type_decl(self, t) if t.resolved then return t.resolved end - local found = t.found or find_type(t.names) + local found = t.found or self:find_type(t.names) if not found then - error_at(t, "unknown type %s", t) - return INVALID + return self.errs:invalid_at(t, "unknown type %s", t) end if found.typename == "typealias" then @@ -7492,8 +7585,7 @@ tl.type_check = function(ast, opts) end if not (found.typename == "typedecl") then - error_at(t, table.concat(t.names, ".") .. " is not a type") - return INVALID + return self.errs:invalid_at(t, table.concat(t.names, ".") .. " is not a type") end local def = found.def @@ -7508,44 +7600,35 @@ tl.type_check = function(ast, opts) return nil, found end - local function resolve_decl_into_nominal(t, found) + local function resolve_decl_into_nominal(self, t, found) local def = found.def local resolved if def.typename == "record" or def.typename == "function" then - resolved = match_typevals(t, def) + resolved = match_typevals(self, t, def) if not resolved then - error_at(t, table.concat(t.names, ".") .. " cannot be resolved in scope") - return INVALID + return self.errs:invalid_at(t, table.concat(t.names, ".") .. " cannot be resolved in scope") end else resolved = def end - if not t.filename then - t.filename = resolved.filename - if t.x == nil and t.y == nil then - t.x = resolved.x - t.y = resolved.y - end - end - t.resolved = resolved return resolved end - resolve_nominal = function(t) - local immediate, found = find_nominal_type_decl(t) + function TypeChecker:resolve_nominal(t) + local immediate, found = find_nominal_type_decl(self, t) if immediate then return immediate end - return resolve_decl_into_nominal(t, found) + return resolve_decl_into_nominal(self, t, found) end - resolve_typealias = function(typealias) + function TypeChecker:resolve_typealias(typealias) local t = typealias.alias_to - local immediate, found = find_nominal_type_decl(t) + local immediate, found = find_nominal_type_decl(self, t) if immediate then return immediate end @@ -7554,90 +7637,92 @@ tl.type_check = function(ast, opts) return found end - local resolved = resolve_decl_into_nominal(t, found) + local resolved = resolve_decl_into_nominal(self, t, found) - local typedecl = a_type("typedecl", { def = resolved }) + local typedecl = a_type(typealias, "typedecl", { def = resolved }) t.resolved = typedecl return typedecl end end - local function are_same_unresolved_global_type(t1, t2) - if t1.names[1] == t2.names[1] then - local unresolved = get_unresolved() - if unresolved.global_types[t1.names[1]] then - return true + do + local function are_same_unresolved_global_type(self, t1, t2) + if t1.names[1] == t2.names[1] then + local global_scope = self.st[1] + if global_scope.pending_global_types[t1.names[1]] then + return true + end end + return false end - return false - end - local function fail_nominals(t1, t2) - local t1name = show_type(t1) - local t2name = show_type(t2) - if t1name == t2name then - local t1r = resolve_nominal(t1) - if t1r.filename then - t1name = t1name .. " (defined in " .. t1r.filename .. ":" .. t1r.y .. ")" - end - local t2r = resolve_nominal(t2) - if t2r.filename then - t2name = t2name .. " (defined in " .. t2r.filename .. ":" .. t2r.y .. ")" + local function fail_nominals(self, t1, t2) + local t1name = show_type(t1) + local t2name = show_type(t2) + if t1name == t2name then + self:resolve_nominal(t1) + if t1.found then + t1name = t1name .. " (defined in " .. t1.found.f .. ":" .. t1.found.y .. ")" + end + self:resolve_nominal(t2) + if t2.found then + t2name = t2name .. " (defined in " .. t2.found.f .. ":" .. t2.found.y .. ")" + end end + return false, { Err(t1name .. " is not a " .. t2name) } end - return false, { Err(t1, t1name .. " is not a " .. t2name) } - end - local function are_same_nominals(t1, t2) - local same_names - if t1.found and t2.found then - same_names = t1.found.typeid == t2.found.typeid - else - local ft1 = t1.found or find_type(t1.names) - local ft2 = t2.found or find_type(t2.names) - if ft1 and ft2 then - same_names = ft1.typeid == ft2.typeid + function TypeChecker:are_same_nominals(t1, t2) + local same_names + if t1.found and t2.found then + same_names = t1.found.typeid == t2.found.typeid else - if are_same_unresolved_global_type(t1, t2) then - return true - end + local ft1 = t1.found or self:find_type(t1.names) + local ft2 = t2.found or self:find_type(t2.names) + if ft1 and ft2 then + same_names = ft1.typeid == ft2.typeid + else + if are_same_unresolved_global_type(self, t1, t2) then + return true + end - if not ft1 then - error_at(t1, "unknown type %s", t1) - end - if not ft2 then - error_at(t2, "unknown type %s", t2) + if not ft1 then + self.errs:add(t1, "unknown type %s", t1) + end + if not ft2 then + self.errs:add(t2, "unknown type %s", t2) + end + return false, {} end - return false, {} end - end - if not same_names then - return fail_nominals(t1, t2) - elseif t1.typevals == nil and t2.typevals == nil then - return true - elseif t1.typevals and t2.typevals and #t1.typevals == #t2.typevals then - local errs = {} - for i = 1, #t1.typevals do - local _, typeval_errs = same_type(t1.typevals[i], t2.typevals[i]) - add_errs_prefixing(t1, typeval_errs, errs, "type parameter <" .. show_type(t2.typevals[i]) .. ">: ") + if not same_names then + return fail_nominals(self, t1, t2) + elseif t1.typevals == nil and t2.typevals == nil then + return true + elseif t1.typevals and t2.typevals and #t1.typevals == #t2.typevals then + local errs = {} + for i = 1, #t1.typevals do + local _, typeval_errs = self:same_type(t1.typevals[i], t2.typevals[i]) + self.errs:add_prefixing(nil, typeval_errs, "type parameter <" .. show_type(t2.typevals[i]) .. ">: ", errs) + end + return any_errors(errs) end - return any_errors(errs) + return true end - return true end local is_lua_table_type - local function to_structural(t) + function TypeChecker:to_structural(t) assert(not (t.typename == "tuple")) if t.typename == "nominal" then - return resolve_nominal(t) + return self:resolve_nominal(t) end return t end - local function unite(types, flatten_constants) + local function unite(w, types, flatten_constants) if #types == 1 then return types[1] end @@ -7648,7 +7733,6 @@ tl.type_check = function(ast, opts) local types_seen = {} - types_seen[NIL.typeid] = true types_seen["nil"] = true local i = 1 @@ -7684,14 +7768,14 @@ tl.type_check = function(ast, opts) end end - if types_seen[INVALID.typeid] then - return INVALID + if types_seen["invalid"] then + return a_type(w, "invalid", {}) end if #ts == 1 then return ts[1] else - return a_type("union", { types = ts }) + return a_type(w, "union", { types = ts }) end end @@ -7711,21 +7795,20 @@ tl.type_check = function(ast, opts) end end - local expand_type - local function arraytype_from_tuple(where, tupletype) + function TypeChecker:arraytype_from_tuple(w, tupletype) - local element_type = unite(tupletype.types, true) + local element_type = unite(w, tupletype.types, true) local valid = (not (element_type.typename == "union")) and true or is_valid_union(element_type) if valid then - return a_type("array", { elements = element_type }) + return a_type(w, "array", { elements = element_type }) end - local arr_type = a_type("array", { elements = tupletype.types[1] }) + local arr_type = a_type(w, "array", { elements = tupletype.types[1] }) for i = 2, #tupletype.types do - local expanded = expand_type(where, arr_type, a_type("array", { elements = tupletype.types[i] })) + local expanded = self:expand_type(w, arr_type, a_type(w, "array", { elements = tupletype.types[i] })) if not (expanded.typename == "array") then - return nil, { Err(tupletype, "unable to convert tuple %s to array", tupletype) } + return nil, { Err("unable to convert tuple %s to array", tupletype) } end arr_type = expanded end @@ -7736,33 +7819,33 @@ tl.type_check = function(ast, opts) return t.typename == "nominal" and t.names[1] == "@self" end - local function compare_true(_, _) + local function compare_true(_, _, _) return true end - local function subtype_nominal(a, b) + function TypeChecker:subtype_nominal(a, b) if is_self(a) and is_self(b) then return true end - local ra = a.typename == "nominal" and resolve_nominal(a) or a - local rb = b.typename == "nominal" and resolve_nominal(b) or b - local ok, errs = is_a(ra, rb) + local ra = a.typename == "nominal" and self:resolve_nominal(a) or a + local rb = b.typename == "nominal" and self:resolve_nominal(b) or b + local ok, errs = self:is_a(ra, rb) if errs and #errs == 1 and errs[1].msg:match("^got ") then return false end return ok, errs end - local function subtype_array(a, b) - if (not a.elements) or (not is_a(a.elements, b.elements)) then + function TypeChecker:subtype_array(a, b) + if (not a.elements) or (not self:is_a(a.elements, b.elements)) then return false end if a.consttypes and #a.consttypes > 1 then for _, e in ipairs(a.consttypes) do - if not is_a(e, b.elements) then - return false, { Err(a, "%s is not a member of %s", e, b.elements) } + if not self:is_a(e, b.elements) then + return false, { Err("%s is not a member of %s", e, b.elements) } end end end @@ -7784,16 +7867,16 @@ tl.type_check = function(ast, opts) return nil end - local function subtype_record(a, b) + function TypeChecker:subtype_record(a, b) if a.elements and b.elements then - if not is_a(a.elements, b.elements) then - return false, { Err(a, "array parts have incompatible element types") } + if not self:is_a(a.elements, b.elements) then + return false, { Err("array parts have incompatible element types") } end end if a.is_userdata ~= b.is_userdata then - return false, { Err(a, a.is_userdata and "userdata is not a record" or + return false, { Err(a.is_userdata and "userdata is not a record" or "record is not a userdata"), } end @@ -7802,9 +7885,9 @@ tl.type_check = function(ast, opts) local ak = a.fields[k] local bk = b.fields[k] if bk then - local ok, fielderrs = is_a(ak, bk) + local ok, fielderrs = self:is_a(ak, bk) if not ok then - add_errs_prefixing(nil, fielderrs, errs, "record field doesn't match: " .. k .. ": ") + self.errs:add_prefixing(nil, fielderrs, "record field doesn't match: " .. k .. ": ", errs) end end end @@ -7818,32 +7901,32 @@ tl.type_check = function(ast, opts) return true end - local eqtype_record = function(a, b) + function TypeChecker:eqtype_record(a, b) if (a.elements ~= nil) ~= (b.elements ~= nil) then - return false, { Err(a, "types do not have the same array interface") } + return false, { Err("types do not have the same array interface") } end if a.elements then - local ok, errs = same_type(a.elements, b.elements) + local ok, errs = self:same_type(a.elements, b.elements) if not ok then return ok, errs end end - local ok, errs = subtype_record(a, b) + local ok, errs = self:subtype_record(a, b) if not ok then return ok, errs end - ok, errs = subtype_record(b, a) + ok, errs = self:subtype_record(b, a) if not ok then return ok, errs end return true end - local function compare_map(ak, bk, av, bv, no_hack) - local ok1, errs_k = same_type(ak, bk) - local ok2, errs_v = same_type(av, bv) + local function compare_map(self, ak, bk, av, bv, no_hack) + local ok1, errs_k = self:same_type(ak, bk) + local ok2, errs_v = self:same_type(av, bv) if bk.typename == "any" and not no_hack then @@ -7873,25 +7956,25 @@ tl.type_check = function(ast, opts) return false, errs_k or errs_v end - local function compare_or_infer_typevar(typevar, a, b, cmp) + function TypeChecker:compare_or_infer_typevar(typevar, a, b, cmp) - local vt, _, constraint = find_var_type(typevar) + local vt, _, constraint = self:find_var_type(typevar) if vt then - return cmp(a or vt, b or vt) + return cmp(self, a or vt, b or vt) else local other = a or b if constraint then - if not is_a(other, constraint) then - return false, { Err(other, "given type %s does not satisfy %s constraint in type variable " .. display_typevar(typevar), other, constraint) } + if not self:is_a(other, constraint) then + return false, { Err("given type %s does not satisfy %s constraint in type variable " .. display_typevar(typevar), other, constraint) } end - if same_type(other, constraint) then + if self:same_type(other, constraint) then @@ -7899,22 +7982,22 @@ tl.type_check = function(ast, opts) end end - local ok, r, errs = resolve_typevars(other) + local ok, r, errs = typevar_resolver(self, other, resolve_typevar) if not ok then return false, errs end if r.typename == "typevar" and r.typevar == typevar then return true end - add_var(nil, typevar, r) + self:add_var(nil, typevar, r) return true end end - local function exists_supertype_in(t, xs) + function TypeChecker:exists_supertype_in(t, xs) for _, x in ipairs(xs.types) do - if is_a(t, x) then + if self:is_a(t, x) then return x end end @@ -7925,143 +8008,139 @@ tl.type_check = function(ast, opts) ["array"] = compare_true, ["map"] = compare_true, ["tupletable"] = compare_true, - ["interface"] = function(_a, b) + ["interface"] = function(_self, _a, b) return not b.is_userdata end, - ["record"] = function(_a, b) + ["record"] = function(_self, _a, b) return not b.is_userdata end, } - - - local eqtype_relations - eqtype_relations = { + TypeChecker.eqtype_relations = { ["typevar"] = { - ["typevar"] = function(a, b) + ["typevar"] = function(self, a, b) if a.typevar == b.typevar then return true end - return compare_or_infer_typevar(b.typevar, a, nil, same_type) + return self:compare_or_infer_typevar(b.typevar, a, nil, self.same_type) end, - ["*"] = function(a, b) - return compare_or_infer_typevar(a.typevar, nil, b, same_type) + ["*"] = function(self, a, b) + return self:compare_or_infer_typevar(a.typevar, nil, b, self.same_type) end, }, ["emptytable"] = emptytable_relations, ["tupletable"] = { - ["tupletable"] = function(a, b) + ["tupletable"] = function(self, a, b) for i = 1, math.min(#a.types, #b.types) do - if not same_type(a.types[i], b.types[i]) then - return false, { Err(a, "in tuple entry " .. tostring(i) .. ": got %s, expected %s", a.types[i], b.types[i]) } + if not self:same_type(a.types[i], b.types[i]) then + return false, { Err("in tuple entry " .. tostring(i) .. ": got %s, expected %s", a.types[i], b.types[i]) } end end if #a.types ~= #b.types then - return false, { Err(a, "tuples have different size", a, b) } + return false, { Err("tuples have different size", a, b) } end return true end, }, ["array"] = { - ["array"] = function(a, b) - return same_type(a.elements, b.elements) + ["array"] = function(self, a, b) + return self:same_type(a.elements, b.elements) end, }, ["map"] = { - ["map"] = function(a, b) - return compare_map(a.keys, b.keys, a.values, b.values, true) + ["map"] = function(self, a, b) + return compare_map(self, a.keys, b.keys, a.values, b.values, true) end, }, ["union"] = { - ["union"] = function(a, b) - return (has_all_types_of(a.types, b.types) and - has_all_types_of(b.types, a.types)) + ["union"] = function(self, a, b) + return (self:has_all_types_of(a.types, b.types) and + self:has_all_types_of(b.types, a.types)) end, }, ["nominal"] = { - ["nominal"] = are_same_nominals, + ["nominal"] = TypeChecker.are_same_nominals, }, ["record"] = { - ["record"] = eqtype_record, + ["record"] = TypeChecker.eqtype_record, }, ["interface"] = { - ["interface"] = function(a, b) + ["interface"] = function(_self, a, b) return a.typeid == b.typeid end, }, ["function"] = { - ["function"] = function(a, b) + ["function"] = function(self, a, b) local argdelta = a.is_method and 1 or 0 local naargs, nbargs = #a.args.tuple, #b.args.tuple if naargs ~= nbargs then if (not not a.is_method) ~= (not not b.is_method) then - return false, { Err(a, "different number of input arguments: method and non-method are not the same type") } + return false, { Err("different number of input arguments: method and non-method are not the same type") } end - return false, { Err(a, "different number of input arguments: got " .. naargs - argdelta .. ", expected " .. nbargs - argdelta) } + return false, { Err("different number of input arguments: got " .. naargs - argdelta .. ", expected " .. nbargs - argdelta) } end local narets, nbrets = #a.rets.tuple, #b.rets.tuple if narets ~= nbrets then - return false, { Err(a, "different number of return values: got " .. narets .. ", expected " .. nbrets) } + return false, { Err("different number of return values: got " .. narets .. ", expected " .. nbrets) } end local errs = {} for i = 1, naargs do - arg_check(a, errs, a.args.tuple[i], b.args.tuple[i], "invariant", "argument", i - argdelta) + self:arg_check(a, errs, a.args.tuple[i], b.args.tuple[i], "invariant", "argument", i - argdelta) end for i = 1, narets do - arg_check(a, errs, a.rets.tuple[i], b.rets.tuple[i], "invariant", "return", i) + self:arg_check(a, errs, a.rets.tuple[i], b.rets.tuple[i], "invariant", "return", i) end return any_errors(errs) end, }, ["*"] = { - ["typevar"] = function(a, b) - return compare_or_infer_typevar(b.typevar, a, nil, same_type) + ["typevar"] = function(self, a, b) + return self:compare_or_infer_typevar(b.typevar, a, nil, self.same_type) end, }, } - local subtype_relations - subtype_relations = { + TypeChecker.subtype_relations = { ["tuple"] = { - ["tuple"] = function(a, b) + ["tuple"] = function(self, a, b) local at, bt = a.tuple, b.tuple if #at ~= #bt then return false end for i = 1, #at do - if not is_a(at[i], bt[i]) then + if not self:is_a(at[i], bt[i]) then return false end end return true end, - ["*"] = function(a, b) - return is_a(resolve_tuple(a), b) + ["*"] = function(self, a, b) + return self:is_a(resolve_tuple(a), b) end, }, ["typevar"] = { - ["typevar"] = function(a, b) + ["typevar"] = function(self, a, b) if a.typevar == b.typevar then return true end - return compare_or_infer_typevar(b.typevar, a, nil, is_a) + return self:compare_or_infer_typevar(b.typevar, a, nil, self.is_a) end, - ["*"] = function(a, b) - return compare_or_infer_typevar(a.typevar, nil, b, is_a) + ["*"] = function(self, a, b) + return self:compare_or_infer_typevar(a.typevar, nil, b, self.is_a) end, }, ["nil"] = { ["*"] = compare_true, }, ["union"] = { - ["union"] = function(a, b) + ["union"] = function(self, a, b) local used = {} for _, t in ipairs(a.types) do - begin_scope() - local u = exists_supertype_in(t, b) - end_scope() + self:begin_scope() + local u = self:exists_supertype_in(t, b) + self:end_scope() if not u then return false end @@ -8070,13 +8149,13 @@ tl.type_check = function(ast, opts) end end for u, t in pairs(used) do - is_a(t, u) + self:is_a(t, u) end return true end, - ["*"] = function(a, b) + ["*"] = function(self, a, b) for _, t in ipairs(a.types) do - if not is_a(t, b) then + if not self:is_a(t, b) then return false end end @@ -8084,212 +8163,212 @@ tl.type_check = function(ast, opts) end, }, ["poly"] = { - ["*"] = function(a, b) - if exists_supertype_in(b, a) then + ["*"] = function(self, a, b) + if self:exists_supertype_in(b, a) then return true end - return false, { Err(a, "cannot match against any alternatives of the polymorphic type") } + return false, { Err("cannot match against any alternatives of the polymorphic type") } end, }, ["nominal"] = { - ["nominal"] = function(a, b) - local ok, errs = are_same_nominals(a, b) + ["nominal"] = function(self, a, b) + local ok, errs = self:are_same_nominals(a, b) if ok then return true end - local rb = resolve_nominal(b) + local rb = self:resolve_nominal(b) if rb.typename == "interface" then - return is_a(a, rb) + return self:is_a(a, rb) end - local ra = resolve_nominal(a) + local ra = self:resolve_nominal(a) if ra.typename == "union" or rb.typename == "union" then - return is_a(ra, rb) + return self:is_a(ra, rb) end return ok, errs end, - ["*"] = subtype_nominal, + ["*"] = TypeChecker.subtype_nominal, }, ["enum"] = { ["string"] = compare_true, }, ["string"] = { - ["enum"] = function(a, b) + ["enum"] = function(_self, a, b) if not a.literal then - return false, { Err(a, "string is not a %s", b) } + return false, { Err("%s is not a %s", a, b) } end if b.enumset[a.literal] then return true end - return false, { Err(a, "%s is not a member of %s", a, b) } + return false, { Err("%s is not a member of %s", a, b) } end, }, ["integer"] = { ["number"] = compare_true, }, ["interface"] = { - ["interface"] = function(a, b) - if find_in_interface_list(a, function(t) return (is_a(t, b)) end) then + ["interface"] = function(self, a, b) + if find_in_interface_list(a, function(t) return (self:is_a(t, b)) end) then return true end - return same_type(a, b) + return self:same_type(a, b) end, - ["array"] = subtype_array, - ["record"] = subtype_record, - ["tupletable"] = function(a, b) - return subtype_relations["record"]["tupletable"](a, b) + ["array"] = TypeChecker.subtype_array, + ["record"] = TypeChecker.subtype_record, + ["tupletable"] = function(self, a, b) + return self.subtype_relations["record"]["tupletable"](self, a, b) end, }, ["emptytable"] = emptytable_relations, ["tupletable"] = { - ["tupletable"] = function(a, b) + ["tupletable"] = function(self, a, b) for i = 1, math.min(#a.types, #b.types) do - if not is_a(a.types[i], b.types[i]) then - return false, { Err(a, "in tuple entry " .. + if not self:is_a(a.types[i], b.types[i]) then + return false, { Err("in tuple entry " .. tostring(i) .. ": got %s, expected %s", a.types[i], b.types[i]), } end end if #a.types > #b.types then - return false, { Err(a, "tuple %s is too big for tuple %s", a, b) } + return false, { Err("tuple %s is too big for tuple %s", a, b) } end return true end, - ["record"] = function(a, b) + ["record"] = function(self, a, b) if b.elements then - return subtype_relations["tupletable"]["array"](a, b) + return self.subtype_relations["tupletable"]["array"](self, a, b) end end, - ["array"] = function(a, b) + ["array"] = function(self, a, b) if b.inferred_len and b.inferred_len > #a.types then - return false, { Err(a, "incompatible length, expected maximum length of " .. tostring(#a.types) .. ", got " .. tostring(b.inferred_len)) } + return false, { Err("incompatible length, expected maximum length of " .. tostring(#a.types) .. ", got " .. tostring(b.inferred_len)) } end - local aa, err = arraytype_from_tuple(a.inferred_at, a) + local aa, err = self:arraytype_from_tuple(a.inferred_at or a, a) if not aa then return false, err end - if not is_a(aa, b) then - return false, { Err(a, "got %s (from %s), expected %s", aa, a, b) } + if not self:is_a(aa, b) then + return false, { Err("got %s (from %s), expected %s", aa, a, b) } end return true end, - ["map"] = function(a, b) - local aa = arraytype_from_tuple(a.inferred_at, a) + ["map"] = function(self, a, b) + local aa = self:arraytype_from_tuple(a.inferred_at or a, a) if not aa then - return false, { Err(a, "Unable to convert tuple %s to map", a) } + return false, { Err("Unable to convert tuple %s to map", a) } end - return compare_map(INTEGER, b.keys, aa.elements, b.values) + return compare_map(self, a_type(a, "integer", {}), b.keys, aa.elements, b.values) end, }, ["record"] = { - ["record"] = subtype_record, - ["interface"] = function(a, b) - if find_in_interface_list(a, function(t) return (is_a(t, b)) end) then + ["record"] = TypeChecker.subtype_record, + ["interface"] = function(self, a, b) + if find_in_interface_list(a, function(t) return (self:is_a(t, b)) end) then return true end if not a.declname then - return subtype_record(a, b) + return self:subtype_record(a, b) end end, - ["array"] = subtype_array, - ["map"] = function(a, b) - if not is_a(b.keys, STRING) then - return false, { Err(a, "can't match a record to a map with non-string keys") } + ["array"] = TypeChecker.subtype_array, + ["map"] = function(self, a, b) + if not self:is_a(b.keys, a_type(b, "string", {})) then + return false, { Err("can't match a record to a map with non-string keys") } end for _, k in ipairs(a.field_order) do local bk = b.keys if bk.typename == "enum" and not bk.enumset[k] then - return false, { Err(a, "key is not an enum value: " .. k) } + return false, { Err("key is not an enum value: " .. k) } end - if not is_a(a.fields[k], b.values) then - return false, { Err(a, "record is not a valid map; not all fields have the same type") } + if not self:is_a(a.fields[k], b.values) then + return false, { Err("record is not a valid map; not all fields have the same type") } end end return true end, - ["tupletable"] = function(a, b) + ["tupletable"] = function(self, a, b) if a.elements then - return subtype_relations["array"]["tupletable"](a, b) + return self.subtype_relations["array"]["tupletable"](self, a, b) end end, }, ["array"] = { - ["array"] = subtype_array, - ["record"] = function(a, b) + ["array"] = TypeChecker.subtype_array, + ["record"] = function(self, a, b) if b.elements then - return subtype_array(a, b) + return self:subtype_array(a, b) end end, - ["map"] = function(a, b) - return compare_map(INTEGER, b.keys, a.elements, b.values) + ["map"] = function(self, a, b) + return compare_map(self, a_type(a, "integer", {}), b.keys, a.elements, b.values) end, - ["tupletable"] = function(a, b) + ["tupletable"] = function(self, a, b) local alen = a.inferred_len or 0 if alen > #b.types then - return false, { Err(a, "incompatible length, expected maximum length of " .. tostring(#b.types) .. ", got " .. tostring(alen)) } + return false, { Err("incompatible length, expected maximum length of " .. tostring(#b.types) .. ", got " .. tostring(alen)) } end for i = 1, (alen > 0) and alen or #b.types do - if not is_a(a.elements, b.types[i]) then - return false, { Err(a, "tuple entry " .. i .. " of type %s does not match type of array elements, which is %s", b.types[i], a.elements) } + if not self:is_a(a.elements, b.types[i]) then + return false, { Err("tuple entry " .. i .. " of type %s does not match type of array elements, which is %s", b.types[i], a.elements) } end end return true end, }, ["map"] = { - ["map"] = function(a, b) - return compare_map(a.keys, b.keys, a.values, b.values) + ["map"] = function(self, a, b) + return compare_map(self, a.keys, b.keys, a.values, b.values) end, - ["array"] = function(a, b) - return compare_map(a.keys, INTEGER, a.values, b.elements) + ["array"] = function(self, a, b) + return compare_map(self, a.keys, a_type(b, "integer", {}), a.values, b.elements) end, }, ["typedecl"] = { - ["record"] = function(a, b) + ["record"] = function(self, a, b) local def = a.def if def.fields then - return subtype_record(def, b) + return self:subtype_record(def, b) end end, }, ["function"] = { - ["function"] = function(a, b) + ["function"] = function(self, a, b) local errs = {} local aa, ba = a.args.tuple, b.args.tuple if (not b.args.is_va) and a.min_arity > b.min_arity then - table.insert(errs, Err(a, "incompatible number of arguments: got " .. show_arity(a) .. " %s, expected " .. show_arity(b) .. " %s", a.args, b.args)) + table.insert(errs, Err("incompatible number of arguments: got " .. show_arity(a) .. " %s, expected " .. show_arity(b) .. " %s", a.args, b.args)) else for i = ((a.is_method or b.is_method) and 2 or 1), #aa do - arg_check(nil, errs, aa[i], ba[i] or ba[#ba], "bivariant", "argument", i) + self:arg_check(nil, errs, aa[i], ba[i] or ba[#ba], "bivariant", "argument", i) end end local ar, br = a.rets.tuple, b.rets.tuple local diff_by_va = #br - #ar == 1 and b.rets.is_va if #ar < #br and not diff_by_va then - table.insert(errs, Err(a, "incompatible number of returns: got " .. #ar .. " %s, expected " .. #br .. " %s", a.rets, b.rets)) + table.insert(errs, Err("incompatible number of returns: got " .. #ar .. " %s, expected " .. #br .. " %s", a.rets, b.rets)) else local nrets = #br if diff_by_va then nrets = nrets - 1 end for i = 1, nrets do - arg_check(nil, errs, ar[i], br[i], "bivariant", "return", i) + self:arg_check(nil, errs, ar[i], br[i], "bivariant", "return", i) end end @@ -8297,36 +8376,36 @@ a.types[i], b.types[i]), } end, }, ["typearg"] = { - ["typearg"] = function(a, b) + ["typearg"] = function(_self, a, b) return a.typearg == b.typearg end, - ["*"] = function(a, b) + ["*"] = function(self, a, b) if a.constraint then - return is_a(a.constraint, b) + return self:is_a(a.constraint, b) end end, }, ["*"] = { ["any"] = compare_true, - ["tuple"] = function(a, b) - return is_a(a_type("tuple", { tuple = { a } }), b) + ["tuple"] = function(self, a, b) + return self:is_a(a_type(a, "tuple", { tuple = { a } }), b) end, - ["typevar"] = function(a, b) - return compare_or_infer_typevar(b.typevar, a, nil, is_a) + ["typevar"] = function(self, a, b) + return self:compare_or_infer_typevar(b.typevar, a, nil, self.is_a) end, - ["typearg"] = function(a, b) + ["typearg"] = function(self, a, b) if b.constraint then - return is_a(a, b.constraint) + return self:is_a(a, b.constraint) end end, - ["union"] = exists_supertype_in, + ["union"] = TypeChecker.exists_supertype_in, - ["nominal"] = subtype_nominal, - ["poly"] = function(a, b) + ["nominal"] = TypeChecker.subtype_nominal, + ["poly"] = function(self, a, b) for _, t in ipairs(b.types) do - if not is_a(a, t) then - return false, { Err(a, "cannot match against all alternatives of the polymorphic type") } + if not self:is_a(a, t) then + return false, { Err("cannot match against all alternatives of the polymorphic type") } end end return true @@ -8335,7 +8414,7 @@ a.types[i], b.types[i]), } } - local type_priorities = { + TypeChecker.type_priorities = { ["tuple"] = 2, ["typevar"] = 3, @@ -8364,19 +8443,7 @@ a.types[i], b.types[i]), } ["function"] = 14, } - if lax then - type_priorities["unknown"] = 0 - - subtype_relations["unknown"] = {} - subtype_relations["unknown"]["*"] = compare_true - subtype_relations["*"]["unknown"] = compare_true - - subtype_relations["boolean"] = {} - subtype_relations["boolean"]["boolean"] = compare_true - subtype_relations["*"]["boolean"] = compare_true - end - - local function compare_types(relations, t1, t2) + local function compare_types(self, relations, t1, t2) if t1.typeid == t2.typeid then return true end @@ -8384,8 +8451,8 @@ a.types[i], b.types[i]), } local s1 = relations[t1.typename] local fn = s1 and s1[t2.typename] if not fn then - local p1 = type_priorities[t1.typename] or 999 - local p2 = type_priorities[t2.typename] or 999 + local p1 = self.type_priorities[t1.typename] or 999 + local p2 = self.type_priorities[t2.typename] or 999 fn = (p1 < p2 and (s1 and s1["*"]) or (relations["*"][t2.typename])) end @@ -8394,32 +8461,32 @@ a.types[i], b.types[i]), } if fn == compare_true then return true end - ok, err = fn(t1, t2) + ok, err = fn(self, t1, t2) else ok = t1.typename == t2.typename end if (not ok) and not err then - return false, { Err(t1, "got %s, expected %s", t1, t2) } + return false, { Err("got %s, expected %s", t1, t2) } end return ok, err end - is_a = function(t1, t2) - return compare_types(subtype_relations, t1, t2) + function TypeChecker:is_a(t1, t2) + return compare_types(self, self.subtype_relations, t1, t2) end - same_type = function(t1, t2) + function TypeChecker:same_type(t1, t2) - return compare_types(eqtype_relations, t1, t2) + return compare_types(self, self.eqtype_relations, t1, t2) end if TL_DEBUG then - local orig_is_a = is_a - is_a = function(t1, t2) + local orig_is_a = TypeChecker.is_a + TypeChecker.is_a = function(self, t1, t2) assert(type(t1) == "table") assert(type(t2) == "table") @@ -8429,14 +8496,14 @@ a.types[i], b.types[i]), } return true end - return orig_is_a(t1, t2) + return orig_is_a(self, t1, t2) end end - local function assert_is_a(where, t1, t2, context, name) + function TypeChecker:assert_is_a(w, t1, t2, ctx, name) t1 = resolve_tuple(t1) t2 = resolve_tuple(t2) - if lax and (is_unknown(t1) or is_unknown(t2)) then + if self.feat_lax and (is_unknown(t1) or is_unknown(t2)) then return true end @@ -8444,24 +8511,27 @@ a.types[i], b.types[i]), } if t1.typename == "nil" then return true elseif t2.typename == "unresolved_emptytable_value" then - if is_number_type(t2.emptytable_type.keys) then - infer_emptytable(t2.emptytable_type, infer_at(where, a_type("array", { elements = t1 }))) + local t2keys = t2.emptytable_type.keys + if is_numeric_type(t2keys) then + self:infer_emptytable(t2.emptytable_type, self:infer_at(w, a_type(w, "array", { elements = t1 }))) else - infer_emptytable(t2.emptytable_type, infer_at(where, a_type("map", { keys = t2.emptytable_type.keys, values = t1 }))) + self:infer_emptytable(t2.emptytable_type, self:infer_at(w, a_type(w, "map", { keys = t2keys, values = t1 }))) end return true elseif t2.typename == "emptytable" then if is_lua_table_type(t1) then - infer_emptytable(t2, infer_at(where, t1)) + self:infer_emptytable(t2, self:infer_at(w, t1)) elseif not (t1.typename == "emptytable") then - error_at(where, context .. ": " .. (name and (name .. ": ") or "") .. "assigning %s to a variable declared with {}", t1) + self.errs:add(w, self.errs:get_context(ctx, name) .. "assigning %s to a variable declared with {}", t1) return false end return true end - local ok, match_errs = is_a(t1, t2) - add_errs_prefixing(where, match_errs, errors, context .. ": " .. (name and (name .. ": ") or "")) + local ok, match_errs = self:is_a(t1, t2) + if not ok then + self.errs:add_prefixing(w, match_errs, self.errs:get_context(ctx, name)) + end return ok end @@ -8469,11 +8539,11 @@ a.types[i], b.types[i]), } if t.typename == "invalid" then return false end - if same_type(t, NIL) then + if t.typename == "nil" then return true end if t.typename == "nominal" then - t = resolve_nominal(t) + t = assert(t.resolved) end if t.fields then return t.meta_fields and t.meta_fields["__close"] ~= nil @@ -8491,36 +8561,27 @@ a.types[i], b.types[i]), } return definitely_not_closable_exprs[e.kind] end - local unknown_dots = {} - - local function add_unknown_dot(node, name) - if not unknown_dots[name] then - unknown_dots[name] = true - add_unknown(node, name) - end - end - - local function same_in_all_union_entries(u, check) + function TypeChecker:same_in_all_union_entries(u, check) local t1, f = check(u.types[1]) if not t1 then return nil end for i = 2, #u.types do local t2 = check(u.types[i]) - if not t2 or not same_type(t1, t2) then + if not t2 or not self:same_type(t1, t2) then return nil end end return f or t1 end - local function same_call_mt_in_all_union_entries(u) - return same_in_all_union_entries(u, function(t) - t = to_structural(t) + function TypeChecker:same_call_mt_in_all_union_entries(u) + return self:same_in_all_union_entries(u, function(t) + t = self:to_structural(t) if t.fields then local call_mt = t.meta_fields and t.meta_fields["__call"] if call_mt.typename == "function" then - local args_tuple = a_type("tuple", { tuple = {} }) + local args_tuple = a_type(u, "tuple", { tuple = {} }) for i = 2, #call_mt.args.tuple do table.insert(args_tuple.tuple, call_mt.args.tuple[i]) end @@ -8530,20 +8591,21 @@ a.types[i], b.types[i]), } end) end - local function resolve_for_call(func, args, is_method) + function TypeChecker:resolve_for_call(func, args, is_method) - if lax and is_unknown(func) then - func = a_fn({ args = va_args({ UNKNOWN }), rets = va_args({ UNKNOWN }) }) + if self.feat_lax and is_unknown(func) then + local unk = func + func = a_function(func, { min_arity = 0, args = a_vararg(func, { unk }), rets = a_vararg(func, { unk }) }) end - func = to_structural(func) + func = self:to_structural(func) if func.typename ~= "function" and func.typename ~= "poly" then if func.typename == "union" then - local r = same_call_mt_in_all_union_entries(func) + local r = self:same_call_mt_in_all_union_entries(func) if r then table.insert(args.tuple, 1, func.types[1]) - return to_structural(r), true + return self:to_structural(r), true end end @@ -8557,7 +8619,7 @@ a.types[i], b.types[i]), } if func.fields and func.meta_fields and func.meta_fields["__call"] then table.insert(args.tuple, 1, func) func = func.meta_fields["__call"] - func = to_structural(func) + func = self:to_structural(func) is_method = true end end @@ -8577,7 +8639,7 @@ a.types[i], b.types[i]), } local visit_node = { cbs = { ["variable"] = { - after = function(node, _children) + after = function(_, node, _children) local i = argnames[node.tk] if not i then return nil @@ -8590,7 +8652,7 @@ a.types[i], b.types[i]), } after = on_node, } - return recurse_node(root, visit_node, {}) + return recurse_node(nil, root, visit_node, {}) end local function expand_macroexp(orignode, args, macroexp) @@ -8598,7 +8660,7 @@ a.types[i], b.types[i]), } return { Node, args[i] } end - local on_node = function(node, children, ret) + local on_node = function(_, node, children, ret) local orig = ret and ret[2] or node local out = shallow_copy_table(orig) @@ -8627,12 +8689,12 @@ a.types[i], b.types[i]), } orignode.expanded = p[2] end - local function check_macroexp_arg_use(macroexp) + function TypeChecker:check_macroexp_arg_use(macroexp) local used = {} local on_arg_id = function(node, _i) if used[node.tk] then - error_at(node, "cannot use argument '" .. node.tk .. "' multiple times in macroexp") + self.errs:add(node, "cannot use argument '" .. node.tk .. "' multiple times in macroexp") else used[node.tk] = true end @@ -8655,18 +8717,15 @@ a.types[i], b.types[i]), } orignode.known = saveknown end - - - local type_check_function_call do - local function mark_invalid_typeargs(f) + local function mark_invalid_typeargs(self, f) if f.typeargs then for _, a in ipairs(f.typeargs) do - if not find_var_type(a.typearg) then + if not self:find_var_type(a.typearg) then if a.constraint then - add_var(nil, a.typearg, a.constraint) + self:add_var(nil, a.typearg, a.constraint) else - add_var(nil, a.typearg, lax and UNKNOWN or a_type("unresolvable_typearg", { + self:add_var(nil, a.typearg, self.feat_lax and a_type(a, "unknown", {}) or a_type(a, "unresolvable_typearg", { typearg = a.typearg, })) end @@ -8675,7 +8734,7 @@ a.types[i], b.types[i]), } end end - local function infer_emptytables(where, wheres, xs, ys, delta) + local function infer_emptytables(self, w, wheres, xs, ys, delta) local xt, yt = xs.tuple, ys.tuple local n_xs = #xt local n_ys = #yt @@ -8685,9 +8744,9 @@ a.types[i], b.types[i]), } if x.typename == "emptytable" then local y = yt[i] or (ys.is_va and yt[n_ys]) if y then - local w = wheres and wheres[i + delta] or where - local inferred_y = infer_at(w, y) - infer_emptytable(x, inferred_y) + local iw = wheres and wheres[i + delta] or w + local inferred_y = self:infer_at(iw, y) + self:infer_emptytable(x, inferred_y) xt[i] = inferred_y end end @@ -8697,7 +8756,7 @@ a.types[i], b.types[i]), } local check_args_rets do - local function check_func_type_list(where, wheres, xs, ys, from, delta, v, mode) + local function check_func_type_list(self, w, wheres, xs, ys, from, delta, v, mode) assert(xs.typename == "tuple", xs.typename) assert(ys.typename == "tuple", ys.typename) @@ -8708,11 +8767,11 @@ a.types[i], b.types[i]), } for i = from, math.max(n_xs, n_ys) do local pos = i + delta - local x = xt[i] or (xs.is_va and xt[n_xs]) or NIL + local x = xt[i] or (xs.is_va and xt[n_xs]) or a_type(w, "nil", {}) local y = yt[i] or (ys.is_va and yt[n_ys]) if y then - local w = wheres and wheres[pos] or where - if not arg_check(w, errs, x, y, v, mode, pos) then + local iw = wheres and wheres[pos] or w + if not self:arg_check(iw, errs, x, y, v, mode, pos) then return nil, errs end end @@ -8721,7 +8780,7 @@ a.types[i], b.types[i]), } return true end - check_args_rets = function(where, where_args, f, args, expected_rets, argdelta) + check_args_rets = function(self, w, where_args, f, args, expected_rets, argdelta) local rets_ok = true local rets_errs local args_ok @@ -8732,19 +8791,19 @@ a.types[i], b.types[i]), } if argdelta == -1 then from = 2 local errs = {} - if (not is_self(fargs[1])) and not arg_check(where, errs, fargs[1], args.tuple[1], "contravariant", "self") then + if (not is_self(fargs[1])) and not self:arg_check(w, errs, fargs[1], args.tuple[1], "contravariant", "self") then return nil, errs end end if expected_rets then - expected_rets = infer_at(where, expected_rets) - infer_emptytables(where, nil, expected_rets, f.rets, 0) + expected_rets = self:infer_at(w, expected_rets) + infer_emptytables(self, w, nil, expected_rets, f.rets, 0) - rets_ok, rets_errs = check_func_type_list(where, nil, f.rets, expected_rets, 1, 0, "covariant", "return") + rets_ok, rets_errs = check_func_type_list(self, w, nil, f.rets, expected_rets, 1, 0, "covariant", "return") end - args_ok, args_errs = check_func_type_list(where, where_args, f.args, args, from, argdelta, "contravariant", "argument") + args_ok, args_errs = check_func_type_list(self, w, where_args, f.args, args, from, argdelta, "contravariant", "argument") if (not args_ok) or (not rets_ok) then return nil, args_errs or {} end @@ -8752,29 +8811,29 @@ a.types[i], b.types[i]), } - infer_emptytables(where, where_args, args, f.args, argdelta) + infer_emptytables(self, w, where_args, args, f.args, argdelta) - mark_invalid_typeargs(f) + mark_invalid_typeargs(self, f) - return resolve_typevars_at(where, f.rets) + return self:resolve_typevars_at(w, f.rets) end end - local function push_typeargs(func) + local function push_typeargs(self, func) if func.typeargs then for _, fnarg in ipairs(func.typeargs) do - add_var(nil, fnarg.typearg, a_type("unresolved_typearg", { + self:add_var(nil, fnarg.typearg, a_type(fnarg, "unresolved_typearg", { constraint = fnarg.constraint, })) end end end - local function pop_typeargs(func) + local function pop_typeargs(self, func) if func.typeargs then for _, fnarg in ipairs(func.typeargs) do - if st[#st][fnarg.typearg] then - st[#st][fnarg.typearg] = nil + if self.st[#self.st].vars[fnarg.typearg] then + self.st[#self.st].vars[fnarg.typearg] = nil end end end @@ -8788,12 +8847,9 @@ a.types[i], b.types[i]), } end end - local function fail_call(where, func, nargs, errs) + local function fail_call(self, w, func, nargs, errs) if errs then - - for _, err in ipairs(errs) do - table.insert(errors, err) - end + self.errs:collect(errs) else local expects = {} @@ -8810,34 +8866,34 @@ a.types[i], b.types[i]), } else table.insert(expects, show_arity(func)) end - error_at(where, "wrong number of arguments (given " .. nargs .. ", expects " .. table.concat(expects, " or ") .. ")") + self.errs:add(w, "wrong number of arguments (given " .. nargs .. ", expects " .. table.concat(expects, " or ") .. ")") end local f = resolve_function_type(func, 1) - mark_invalid_typeargs(f) + mark_invalid_typeargs(self, f) - return resolve_typevars_at(where, f.rets) + return self:resolve_typevars_at(w, f.rets) end - local function check_call(where, where_args, func, args, expected_rets, is_typedecl_funcall, argdelta) + local function check_call(self, w, where_args, func, args, expected_rets, is_typedecl_funcall, argdelta) assert(type(func) == "table") assert(type(args) == "table") local is_method = (argdelta == -1) if not (func.typename == "function" or func.typename == "poly") then - func, is_method = resolve_for_call(func, args, is_method) + func, is_method = self:resolve_for_call(func, args, is_method) if is_method then argdelta = -1 end if not (func.typename == "function" or func.typename == "poly") then - return invalid_at(where, "not a function: %s", func) + return self.errs:invalid_at(w, "not a function: %s", func) end end if is_method and args.tuple[1] then - add_var(nil, "@self", type_at(where, a_type("typedecl", { def = args.tuple[1] }))) + self:add_var(nil, "@self", a_type(w, "typedecl", { def = args.tuple[1] })) end local passes, n = 1, 1 @@ -8854,30 +8910,30 @@ a.types[i], b.types[i]), } local f = resolve_function_type(func, i) local fargs = f.args.tuple if f.is_method and not is_method then - if args.tuple[1] and is_a(args.tuple[1], fargs[1]) then + if args.tuple[1] and self:is_a(args.tuple[1], fargs[1]) then if not is_typedecl_funcall then - add_warning("hint", where, "invoked method as a regular function: consider using ':' instead of '.'") + self.errs:add_warning("hint", w, "invoked method as a regular function: consider using ':' instead of '.'") end else - return invalid_at(where, "invoked method as a regular function: use ':' instead of '.'") + return self.errs:invalid_at(w, "invoked method as a regular function: use ':' instead of '.'") end end local wanted = #fargs - local min_arity = feat_arity and f.min_arity or 0 + local min_arity = self.feat_arity and f.min_arity or 0 - if (passes == 1 and ((given <= wanted and given >= min_arity) or (f.args.is_va and given > wanted) or (lax and given <= wanted))) or + if (passes == 1 and ((given <= wanted and given >= min_arity) or (f.args.is_va and given > wanted) or (self.feat_lax and given <= wanted))) or (passes == 3 and ((pass == 1 and given == wanted) or - (pass == 2 and given < wanted and (lax or given >= min_arity)) or + (pass == 2 and given < wanted and (self.feat_lax or given >= min_arity)) or (pass == 3 and f.args.is_va and given > wanted))) then - push_typeargs(f) + push_typeargs(self, f) - local matched, errs = check_args_rets(where, where_args, f, args, expected_rets, argdelta) + local matched, errs = check_args_rets(self, w, where_args, f, args, expected_rets, argdelta) if matched then return matched, f @@ -8886,23 +8942,23 @@ a.types[i], b.types[i]), } if expected_rets then - infer_emptytables(where, where_args, f.rets, f.rets, argdelta) + infer_emptytables(self, w, where_args, f.rets, f.rets, argdelta) end if passes == 3 then tried = tried or {} tried[i] = true - pop_typeargs(f) + pop_typeargs(self, f) end end end end end - return fail_call(where, func, given, first_errs) + return fail_call(self, w, func, given, first_errs) end - type_check_function_call = function(node, func, args, argdelta, e1, e2) + function TypeChecker:type_check_function_call(node, func, args, argdelta, e1, e2) e1 = e1 or node.e1 e2 = e2 or node.e2 @@ -8911,14 +8967,14 @@ a.types[i], b.types[i]), } if expected and expected.typename == "tuple" then expected_rets = expected else - expected_rets = a_type("tuple", { tuple = { node.expected } }) + expected_rets = a_type(node, "tuple", { tuple = { node.expected } }) end - begin_scope() + self:begin_scope() local is_typedecl_funcall - if node.kind == "op" and node.op.op == "@funcall" and node.e1 and node.e1.receiver then - local receiver = node.e1.receiver + if node.kind == "op" and node.op.op == "@funcall" and e1 and e1.receiver then + local receiver = e1.receiver if receiver.typename == "nominal" then local resolved = receiver.resolved if resolved and resolved.typename == "typedecl" then @@ -8927,12 +8983,12 @@ a.types[i], b.types[i]), } end end - local ret, f = check_call(node, e2, func, args, expected_rets, is_typedecl_funcall, argdelta or 0) - ret = resolve_typevars_at(node, ret) - end_scope() + local ret, f = check_call(self, node, e2, func, args, expected_rets, is_typedecl_funcall, argdelta or 0) + ret = self:resolve_typevars_at(node, ret) + self:end_scope() - if tc and e1 then - tc.store_type(e1.y, e1.x, f) + if self.collector then + self.collector.store_type(e1.y, e1.x, f) end if f and f.macroexp then @@ -8943,9 +8999,9 @@ a.types[i], b.types[i]), } end end - local function check_metamethod(node, method_name, a, b, orig_a, orig_b) - if lax and ((a and is_unknown(a)) or (b and is_unknown(b))) then - return UNKNOWN, nil + function TypeChecker:check_metamethod(node, method_name, a, b, orig_a, orig_b) + if self.feat_lax and ((a and is_unknown(a)) or (b and is_unknown(b))) then + return a_type(node, "unknown", {}), nil end local ameta = a.fields and a.meta_fields local bmeta = b and b.fields and b.meta_fields @@ -8966,26 +9022,26 @@ a.types[i], b.types[i]), } if metamethod then local e2 = { node.e1 } - local args = a_type("tuple", { tuple = { orig_a } }) + local args = a_type(node, "tuple", { tuple = { orig_a } }) if b and method_name ~= "__is" then e2[2] = node.e2 args.tuple[2] = orig_b end - return to_structural(resolve_tuple((type_check_function_call(node, metamethod, args, -1, node, e2)))), meta_on_operator + return self:to_structural(resolve_tuple((self:type_check_function_call(node, metamethod, args, -1, node, e2)))), meta_on_operator else return nil, nil end end - local function match_record_key(tbl, rec, key) + function TypeChecker:match_record_key(tbl, rec, key) assert(type(tbl) == "table") assert(type(rec) == "table") assert(type(key) == "string") - tbl = to_structural(tbl) + tbl = self:to_structural(tbl) if tbl.typename == "string" or tbl.typename == "enum" then - tbl = find_var_type("string") + tbl = self:find_var_type("string") end if tbl.typename == "typedecl" then @@ -8994,13 +9050,13 @@ a.types[i], b.types[i]), } if tbl.is_nested_alias then return nil, "cannot use a nested type alias as a concrete value" else - tbl = resolve_nominal(tbl.alias_to) + tbl = self:resolve_nominal(tbl.alias_to) end end if tbl.typename == "union" then - local t = same_in_all_union_entries(tbl, function(t) - return (match_record_key(t, rec, key)) + local t = self:same_in_all_union_entries(tbl, function(t) + return (self:match_record_key(t, rec, key)) end) if t then @@ -9009,7 +9065,7 @@ a.types[i], b.types[i]), } end if (tbl.typename == "typevar" or tbl.typename == "typearg") and tbl.constraint then - local t = match_record_key(tbl.constraint, rec, key) + local t = self:match_record_key(tbl.constraint, rec, key) if t then return t @@ -9023,7 +9079,8 @@ a.types[i], b.types[i]), } return tbl.fields[key] end - local meta_t = check_metamethod(rec, "__index", tbl, STRING, tbl, STRING) + local str = a_type(rec, "string", {}) + local meta_t = self:check_metamethod(rec, "__index", tbl, str, tbl, str) if meta_t then return meta_t end @@ -9034,8 +9091,8 @@ a.types[i], b.types[i]), } return nil, "invalid key '" .. key .. "' in type %s" end elseif tbl.typename == "emptytable" or is_unknown(tbl) then - if lax then - return INVALID + if self.feat_lax then + return a_type(rec, "unknown", {}) end return nil, "cannot index a value of unknown type" end @@ -9047,30 +9104,35 @@ a.types[i], b.types[i]), } end end - local function widen_in_scope(scope, var) - assert(scope[var], "no " .. var .. " in scope") - local narrow_mode = scope[var].is_narrowed - if narrow_mode and narrow_mode ~= "declaration" then - if scope[var].narrowed_from then - scope[var].t = scope[var].narrowed_from - scope[var].narrowed_from = nil - scope[var].is_narrowed = nil - else - scope[var] = nil - end + function TypeChecker:widen_in_scope(scope, var) + local v = scope.vars[var] + assert(v, "no " .. var .. " in scope") + local narrow_mode = scope.vars[var].is_narrowed + if (not narrow_mode) or narrow_mode == "declaration" then + return false + end - local unresolved = get_unresolved(scope) - unresolved.narrows[var] = nil - return true + if v.narrowed_from then + v.t = v.narrowed_from + v.narrowed_from = nil + v.is_narrowed = nil + else + scope.vars[var] = nil + end + + if scope.narrows then + scope.narrows[var] = nil end - return false + + return true end - local function widen_back_var(name) + function TypeChecker:widen_back_var(name) local widened = false - for i = #st, 1, -1 do - if st[i][name] then - if widen_in_scope(st[i], name) then + for i = #self.st, 1, -1 do + local scope = self.st[i] + if scope.vars[name] then + if self:widen_in_scope(scope, name) then widened = true else break @@ -9084,7 +9146,7 @@ a.types[i], b.types[i]), } local visit_node = { cbs = { ["assignment"] = { - after = function(node, _children) + after = function(_, node, _children) for _, v in ipairs(node.vars) do if v.kind == "variable" and v.tk == name then return true @@ -9094,7 +9156,7 @@ a.types[i], b.types[i]), } end, }, }, - after = function(_node, children, ret) + after = function(_, _node, children, ret) ret = ret or false for _, c in ipairs(children) do local ca = c @@ -9112,118 +9174,82 @@ a.types[i], b.types[i]), } end, } - return recurse_node(root, visit_node, visit_type) + return recurse_node(nil, root, visit_node, visit_type) end - local function widen_all_unions(node) - for i = #st, 1, -1 do - local scope = st[i] - local unresolved = find_unresolved(i) - if unresolved and unresolved.narrows then - for name, _ in pairs(unresolved.narrows) do + function TypeChecker:widen_all_unions(node) + for i = #self.st, 1, -1 do + local scope = self.st[i] + if scope.narrows then + for name, _ in pairs(scope.narrows) do if not node or assigned_anywhere(name, node) then - widen_in_scope(scope, name) + self:widen_in_scope(scope, name) end end end end end - local function add_global(node, var, valtype, is_assigning) - if lax and is_unknown(valtype) and (var ~= "self" and var ~= "...") then - add_unknown(node, var) + function TypeChecker:add_global(node, varname, valtype, is_assigning) + if self.feat_lax and is_unknown(valtype) and (varname ~= "self" and varname ~= "...") then + self.errs:add_unknown(node, varname) end local is_const = node.attribute ~= nil - local existing, scope, existing_attr = find_var(var) + local existing, scope, existing_attr = self:find_var(varname) if existing then if scope > 1 then - error_at(node, "cannot define a global when a local with the same name is in scope") + self.errs:add(node, "cannot define a global when a local with the same name is in scope") elseif is_assigning and existing_attr then - error_at(node, "cannot reassign to <" .. existing_attr .. "> global: " .. var) + self.errs:add(node, "cannot reassign to <" .. existing_attr .. "> global: " .. varname) elseif existing_attr and not is_const then - error_at(node, "global was previously declared as <" .. existing_attr .. ">: " .. var) + self.errs:add(node, "global was previously declared as <" .. existing_attr .. ">: " .. varname) elseif (not existing_attr) and is_const then - error_at(node, "global was previously declared as not <" .. node.attribute .. ">: " .. var) - elseif valtype and not same_type(existing.t, valtype) then - error_at(node, "cannot redeclare global with a different type: previous type of " .. var .. " is %s", existing.t) + self.errs:add(node, "global was previously declared as not <" .. node.attribute .. ">: " .. varname) + elseif valtype and not self:same_type(existing.t, valtype) then + self.errs:add(node, "cannot redeclare global with a different type: previous type of " .. varname .. " is %s", existing.t) end return nil end - st[1][var] = { t = valtype, attribute = is_const and "const" or nil } - - return st[1][var] - end + local var = { t = valtype, attribute = is_const and "const" or nil } + self.st[1].vars[varname] = var - local get_rets - if lax then - get_rets = function(rets) - if #rets.tuple == 0 then - return a_vararg({ UNKNOWN }) - end - return rets - end - else - get_rets = function(rets) - return rets - end + return var end - local function add_internal_function_variables(node, args) - add_var(nil, "@is_va", args.is_va and ANY or NIL) - add_var(nil, "@return", node.rets or a_type("tuple", { tuple = {} })) + function TypeChecker:add_internal_function_variables(node, args) + self:add_var(nil, "@is_va", a_type(node, args.is_va and "any" or "nil", {})) + self:add_var(nil, "@return", node.rets or a_type(node, "tuple", { tuple = {} })) if node.typeargs then for _, t in ipairs(node.typeargs) do - local v = find_var(t.typearg, "check_only") + local v = self:find_var(t.typearg, "check_only") if not v or not v.used_as_type then - error_at(t, "type argument '%s' is not used in function signature", t) - end - end - end - end - - local function add_function_definition_for_recursion(node, fnargs) - add_var(nil, node.name.tk, type_at(node, a_function({ - min_arity = node.min_arity, - typeargs = node.typeargs, - args = fnargs, - rets = get_rets(node.rets), - }))) - end - - local function fail_unresolved() - local unresolved = st[#st]["@unresolved"] - if unresolved then - st[#st]["@unresolved"] = nil - local unrt = unresolved.t - for name, nodes in pairs(unrt.labels) do - for _, node in ipairs(nodes) do - error_at(node, "no visible label '" .. name .. "' for goto") - end - end - for name, types in pairs(unrt.nominals) do - if not unrt.global_types[name] then - for _, typ in ipairs(types) do - assert(typ.x) - assert(typ.y) - error_at(typ, "unknown type %s", typ) - end + self.errs:add(t, "type argument '%s' is not used in function signature", t) end end end end - local function end_function_scope(node) - fail_unresolved() - end_scope(node) + function TypeChecker:add_function_definition_for_recursion(node, fnargs) + self:add_var(nil, node.name.tk, a_function(node, { + min_arity = node.min_arity, + typeargs = node.typeargs, + args = fnargs, + rets = self.get_rets(node.rets), + })) + end + + function TypeChecker:end_function_scope(node) + self.errs:fail_unresolved_labels(self.st[#self.st]) + self:end_scope(node) end local function flatten_tuple(vals) local vt = vals.tuple local n_vals = #vt - local ret = a_type("tuple", { tuple = {} }) + local ret = a_type(vals, "tuple", { tuple = {} }) local rt = ret.tuple if n_vals == 0 then @@ -9251,9 +9277,9 @@ a.types[i], b.types[i]), } return ret end - local function get_assignment_values(vals, wanted) + local function get_assignment_values(w, vals, wanted) if vals == nil then - return a_type("tuple", { tuple = {} }) + return a_type(w, "tuple", { tuple = {} }) end local ret = flatten_tuple(vals) @@ -9272,14 +9298,14 @@ a.types[i], b.types[i]), } return ret end - local function match_all_record_field_names(node, a, field_names, errmsg) + function TypeChecker:match_all_record_field_names(node, a, field_names, errmsg) local t for _, k in ipairs(field_names) do local f = a.fields[k] if not t then t = f else - if not same_type(f, t) then + if not self:same_type(f, t) then errmsg = errmsg .. string.format(" (types of fields '%s' and '%s' do not match)", field_names[1], k) t = nil break @@ -9289,26 +9315,26 @@ a.types[i], b.types[i]), } if t then return t else - return invalid_at(node, errmsg) + return self.errs:invalid_at(node, errmsg) end end - local function type_check_index(anode, bnode, a, b) + function TypeChecker:type_check_index(anode, bnode, a, b) assert(not (a.typename == "tuple")) assert(not (b.typename == "tuple")) - local ra = resolve_typedecl(to_structural(a)) - local rb = to_structural(b) + local ra = resolve_typedecl(self:to_structural(a)) + local rb = self:to_structural(b) - if lax and is_unknown(a) then - return UNKNOWN + if self.feat_lax and is_unknown(a) then + return a end local errm local erra local errb - if ra.typename == "tupletable" and is_a(rb, INTEGER) then + if ra.typename == "tupletable" and rb.typename == "integer" then if bnode.constnum then if bnode.constnum >= 1 and bnode.constnum <= #ra.types and bnode.constnum == math.floor(bnode.constnum) then return ra.types[bnode.constnum] @@ -9316,38 +9342,35 @@ a.types[i], b.types[i]), } errm, erra = "index " .. tostring(bnode.constnum) .. " out of range for tuple %s", ra else - local array_type = arraytype_from_tuple(bnode, ra) + local array_type = self:arraytype_from_tuple(bnode, ra) if array_type then return array_type.elements end errm = "cannot index this tuple with a variable because it would produce a union type that cannot be discriminated at runtime" end - elseif ra.elements and is_a(rb, INTEGER) then + elseif ra.elements and rb.typename == "integer" then return ra.elements elseif ra.typename == "emptytable" then if ra.keys == nil then - ra.keys = infer_at(anode, b) + ra.keys = self:infer_at(bnode, b) end - if is_a(b, ra.keys) then - return type_at(anode, a_type("unresolved_emptytable_value", { + if self:is_a(b, ra.keys) then + return a_type(anode, "unresolved_emptytable_value", { emptytable_type = ra, - })) + }) end - errm, erra, errb = "inconsistent index type: got %s, expected %s (type of keys inferred at " .. - ra.keys.inferred_at.filename .. ":" .. - ra.keys.inferred_at.y .. ":" .. - ra.keys.inferred_at.x .. ": )", b, ra.keys + errm, erra, errb = "inconsistent index type: got %s, expected %s" .. inferred_msg(ra.keys, "type of keys "), b, ra.keys elseif ra.typename == "map" then - if is_a(b, ra.keys) then + if self:is_a(b, ra.keys) then return ra.values end errm, erra, errb = "wrong index type: got %s, expected %s", b, ra.keys elseif rb.typename == "string" and rb.literal then - local t, e = match_record_key(a, anode, rb.literal) + local t, e = self:match_record_key(a, anode, rb.literal) if t then return t end @@ -9363,10 +9386,10 @@ a.types[i], b.types[i]), } end end if not errm then - return match_all_record_field_names(bnode, ra, field_names, + return self:match_all_record_field_names(bnode, ra, field_names, "cannot index, not all enum values map to record fields of the same type") end - elseif is_a(rb, STRING) then + elseif rb.typename == "string" then errm, erra = "cannot index object of type %s with a string, consider using an enum", a else errm, erra, errb = "cannot index object of type %s with %s", a, b @@ -9375,28 +9398,28 @@ a.types[i], b.types[i]), } errm, erra, errb = "cannot index object of type %s with %s", a, b end - local meta_t = check_metamethod(anode, "__index", ra, b, a, b) + local meta_t = self:check_metamethod(anode, "__index", ra, b, a, b) if meta_t then return meta_t end - return invalid_at(bnode, errm, erra, errb) + return self.errs:invalid_at(bnode, errm, erra, errb) end - expand_type = function(where, old, new) + function TypeChecker:expand_type(w, old, new) if not old or old.typename == "nil" then return new else - if not is_a(new, old) then + if not self:is_a(new, old) then if old.typename == "map" and new.fields then local old_keys = old.keys if old_keys.typename == "string" then for _, ftype in fields_of(new) do - old.values = expand_type(where, old.values, ftype) + old.values = self:expand_type(w, old.values, ftype) end - edit_type(old, "map") + edit_type(w, old, "map") else - error_at(where, "cannot determine table literal type") + self.errs:add(w, "cannot determine table literal type") end elseif old.fields and new.fields then local values @@ -9404,14 +9427,14 @@ a.types[i], b.types[i]), } if not values then values = ftype else - values = expand_type(where, values, ftype) + values = self:expand_type(w, values, ftype) end end for _, ftype in fields_of(new) do if not values then values = ftype else - values = expand_type(where, values, ftype) + values = self:expand_type(w, values, ftype) end end old.fields = nil @@ -9419,25 +9442,25 @@ a.types[i], b.types[i]), } old.meta_fields = nil old.meta_fields = nil - edit_type(old, "map") + edit_type(w, old, "map") local map = old - map.keys = STRING + map.keys = a_type(w, "string", {}) map.values = values elseif old.typename == "union" then - edit_type(old, "union") + edit_type(w, old, "union") table.insert(old.types, drop_constant_value(new)) else - return unite({ old, new }, true) + return unite(w, { old, new }, true) end end end return old end - local function find_record_to_extend(exp) + function TypeChecker:find_record_to_extend(exp) if exp.kind == "type_identifier" then - local v = find_var(exp.tk) + local v = self:find_var(exp.tk) if not v then return nil, nil, exp.tk end @@ -9454,7 +9477,7 @@ a.types[i], b.types[i]), } return t, v, exp.tk elseif exp.kind == "op" then - local t, v, rname = find_record_to_extend(exp.e1) + local t, v, rname = self:find_record_to_extend(exp.e1) local fname = exp.e2.tk local dname = rname .. "." .. fname if not t then @@ -9475,30 +9498,29 @@ a.types[i], b.types[i]), } end end - local function typedecl_to_nominal(where, name, t, resolved) + local function typedecl_to_nominal(node, name, t, resolved) local typevals local def = t.def if def.typeargs then typevals = {} for _, a in ipairs(def.typeargs) do - table.insert(typevals, a_type("typevar", { + table.insert(typevals, a_type(a, "typevar", { typevar = a.typearg, constraint = a.constraint, })) end end - return type_at(where, a_type("nominal", { - typevals = typevals, - names = { name }, - found = t, - resolved = resolved, - })) + local nom = a_nominal(node, { name }) + nom.typevals = typevals + nom.found = t + nom.resolved = resolved + return nom end - local function get_self_type(exp) + function TypeChecker:get_self_type(exp) if exp.kind == "type_identifier" then - local t = find_var_type(exp.tk) + local t = self:find_var_type(exp.tk) if not t then return nil end @@ -9510,7 +9532,7 @@ a.types[i], b.types[i]), } end elseif exp.kind == "op" then - local t = get_self_type(exp.e1) + local t = self:get_self_type(exp.e1) if not t then return nil end @@ -9542,7 +9564,6 @@ a.types[i], b.types[i]), } local facts_and local facts_or local facts_not - local apply_facts local FACT_TRUTHY do local IsFact_mt = { @@ -9554,6 +9575,7 @@ a.types[i], b.types[i]), } setmetatable(IsFact, { __call = function(_, fact) fact.fact = "is" + assert(fact.w) return setmetatable(fact, IsFact_mt) end, }) @@ -9567,6 +9589,7 @@ a.types[i], b.types[i]), } setmetatable(EqFact, { __call = function(_, fact) fact.fact = "==" + assert(fact.w) return setmetatable(fact, EqFact_mt) end, }) @@ -9625,57 +9648,57 @@ a.types[i], b.types[i]), } FACT_TRUTHY = TruthyFact({}) - facts_and = function(where, f1, f2) - return AndFact({ f1 = f1, f2 = f2, where = where }) + facts_and = function(w, f1, f2) + return AndFact({ f1 = f1, f2 = f2, w = w }) end - facts_or = function(where, f1, f2) + facts_or = function(w, f1, f2) if f1 and f2 then - return OrFact({ f1 = f1, f2 = f2, where = where }) + return OrFact({ f1 = f1, f2 = f2, w = w }) else return nil end end - facts_not = function(where, f1) + facts_not = function(w, f1) if f1 then - return NotFact({ f1 = f1, where = where }) + return NotFact({ f1 = f1, w = w }) else return nil end end - local function unite_types(t1, t2) - return unite({ t2, t1 }) + local function unite_types(w, t1, t2) + return unite(w, { t2, t1 }) end - local function intersect_types(t1, t2) + local function intersect_types(self, w, t1, t2) if t2.typename == "union" then t1, t2 = t2, t1 end if t1.typename == "union" then local out = {} for _, t in ipairs(t1.types) do - if is_a(t, t2) then + if self:is_a(t, t2) then table.insert(out, t) end end - return unite(out) + return unite(w, out) else - if is_a(t1, t2) then + if self:is_a(t1, t2) then return t1 - elseif is_a(t2, t1) then + elseif self:is_a(t2, t1) then return t2 else - return NIL + return a_type(w, "nil", {}) end end end - local function resolve_if_union(t) - local rt = to_structural(t) + function TypeChecker:resolve_if_union(t) + local rt = self:to_structural(t) if rt.typename == "union" then return rt end @@ -9683,23 +9706,23 @@ a.types[i], b.types[i]), } end - local function subtract_types(t1, t2) + local function subtract_types(self, w, t1, t2) local types = {} - t1 = resolve_if_union(t1) + t1 = self:resolve_if_union(t1) if not (t1.typename == "union") then return t1 end - t2 = resolve_if_union(t2) + t2 = self:resolve_if_union(t2) local t2types = t2.typename == "union" and t2.types or { t2 } for _, at in ipairs(t1.types) do local not_present = true for _, bt in ipairs(t2types) do - if same_type(at, bt) then + if self:same_type(at, bt) then not_present = false break end @@ -9710,10 +9733,10 @@ a.types[i], b.types[i]), } end if #types == 0 then - return NIL + return a_type(w, "nil", {}) end - return unite(types) + return unite(w, types) end local eval_not @@ -9723,65 +9746,65 @@ a.types[i], b.types[i]), } local eval_fact local function invalid_from(f) - return IsFact({ fact = "is", var = f.var, typ = INVALID, where = f.where }) + return IsFact({ fact = "is", var = f.var, typ = a_type(f.w, "invalid", {}), w = f.w }) end - not_facts = function(fs) + not_facts = function(self, fs) local ret = {} for var, f in pairs(fs) do - local typ = find_var_type(f.var, "check_only") + local typ = self:find_var_type(f.var, "check_only") if not typ then - ret[var] = EqFact({ var = var, typ = INVALID, where = f.where }) + ret[var] = EqFact({ var = var, typ = a_type(f.w, "invalid", {}), w = f.w, no_infer = f.no_infer }) elseif f.fact == "==" then - ret[var] = EqFact({ var = var, typ = typ }) + ret[var] = EqFact({ var = var, typ = typ, w = f.w, no_infer = true }) elseif typ.typename == "typevar" then assert(f.fact == "is") - ret[var] = EqFact({ var = var, typ = typ }) - elseif not is_a(f.typ, typ) then + ret[var] = EqFact({ var = var, typ = typ, w = f.w, no_infer = true }) + elseif not self:is_a(f.typ, typ) then assert(f.fact == "is") - add_warning("branch", f.where, f.var .. " (of type %s) can never be a %s", show_type(typ), show_type(f.typ)) - ret[var] = EqFact({ var = var, typ = INVALID, where = f.where }) + self.errs:add_warning("branch", f.w, f.var .. " (of type %s) can never be a %s", show_type(typ), show_type(f.typ)) + ret[var] = EqFact({ var = var, typ = a_type(f.w, "invalid", {}), w = f.w, no_infer = f.no_infer }) else assert(f.fact == "is") - ret[var] = IsFact({ var = var, typ = subtract_types(typ, f.typ), where = f.where }) + ret[var] = IsFact({ var = var, typ = subtract_types(self, f.w, typ, f.typ), w = f.w, no_infer = f.no_infer }) end end return ret end - eval_not = function(f) + eval_not = function(self, f) if not f then return {} elseif f.fact == "is" then - return not_facts({ [f.var] = f }) + return not_facts(self, { [f.var] = f }) elseif f.fact == "not" then - return eval_fact(f.f1) + return eval_fact(self, f.f1) elseif f.fact == "and" and f.f2 and f.f2.fact == "truthy" then - return eval_not(f.f1) + return eval_not(self, f.f1) elseif f.fact == "or" and f.f2 and f.f2.fact == "truthy" then - return eval_fact(f.f1) + return eval_fact(self, f.f1) elseif f.fact == "and" then - return or_facts(not_facts(eval_fact(f.f1)), not_facts(eval_fact(f.f2))) + return or_facts(self, not_facts(self, eval_fact(self, f.f1)), not_facts(self, eval_fact(self, f.f2))) elseif f.fact == "or" then - return and_facts(not_facts(eval_fact(f.f1)), not_facts(eval_fact(f.f2))) + return and_facts(self, not_facts(self, eval_fact(self, f.f1)), not_facts(self, eval_fact(self, f.f2))) else - return not_facts(eval_fact(f)) + return not_facts(self, eval_fact(self, f)) end end - or_facts = function(fs1, fs2) + or_facts = function(_self, fs1, fs2) local ret = {} for var, f in pairs(fs2) do if fs1[var] then - local united = unite_types(f.typ, fs1[var].typ) + local united = unite_types(f.w, f.typ, fs1[var].typ) if fs1[var].fact == "is" and f.fact == "is" then - ret[var] = IsFact({ var = var, typ = united, where = f.where }) + ret[var] = IsFact({ var = var, typ = united, w = f.w }) else - ret[var] = EqFact({ var = var, typ = united, where = f.where }) + ret[var] = EqFact({ var = var, typ = united, w = f.w }) end end end @@ -9789,7 +9812,7 @@ a.types[i], b.types[i]), } return ret end - and_facts = function(fs1, fs2) + and_facts = function(self, fs1, fs2) local ret = {} local has = {} @@ -9800,18 +9823,18 @@ a.types[i], b.types[i]), } if fs2[var].fact == "is" and f.fact == "is" then ctor = IsFact end - rt = intersect_types(f.typ, fs2[var].typ) + rt = intersect_types(self, f.w, f.typ, fs2[var].typ) else rt = f.typ end - local ff = ctor({ var = var, typ = rt, where = f.where }) + local ff = ctor({ var = var, typ = rt, w = f.w, no_infer = f.no_infer }) ret[var] = ff has[ff.fact] = true end for var, f in pairs(fs2) do if not fs1[var] then - ret[var] = EqFact({ var = var, typ = f.typ, where = f.where }) + ret[var] = EqFact({ var = var, typ = f.typ, w = f.w, no_infer = f.no_infer }) has["=="] = true end end @@ -9825,21 +9848,21 @@ a.types[i], b.types[i]), } return ret end - eval_fact = function(f) + eval_fact = function(self, f) if not f then return {} elseif f.fact == "is" then - local typ = find_var_type(f.var, "check_only") + local typ = self:find_var_type(f.var, "check_only") if not typ then return { [f.var] = invalid_from(f) } end if typ.typename ~= "typevar" then - if is_a(typ, f.typ) then + if self:is_a(typ, f.typ) then return { [f.var] = f } - elseif not is_a(f.typ, typ) then - error_at(f.where, f.var .. " (of type %s) can never be a %s", typ, f.typ) + elseif not self:is_a(f.typ, typ) then + self.errs:add(f.w, f.var .. " (of type %s) can never be a %s", typ, f.typ) return { [f.var] = invalid_from(f) } end end @@ -9847,63 +9870,60 @@ a.types[i], b.types[i]), } elseif f.fact == "==" then return { [f.var] = f } elseif f.fact == "not" then - return eval_not(f.f1) + return eval_not(self, f.f1) elseif f.fact == "truthy" then return {} elseif f.fact == "and" and f.f2 and f.f2.fact == "truthy" then - return eval_fact(f.f1) + return eval_fact(self, f.f1) elseif f.fact == "or" and f.f2 and f.f2.fact == "truthy" then - return eval_not(f.f1) + return eval_not(self, f.f1) elseif f.fact == "and" then - return and_facts(eval_fact(f.f1), eval_fact(f.f2)) + return and_facts(self, eval_fact(self, f.f1), eval_fact(self, f.f2)) elseif f.fact == "or" then - return or_facts(eval_fact(f.f1), eval_fact(f.f2)) + return or_facts(self, eval_fact(self, f.f1), eval_fact(self, f.f2)) end end - apply_facts = function(where, known) + function TypeChecker:apply_facts(w, known) if not known then return end - local facts = eval_fact(known) + local facts = eval_fact(self, known) for v, f in pairs(facts) do if f.typ.typename == "invalid" then - error_at(where, "cannot resolve a type for " .. v .. " here") + self.errs:add(w, "cannot resolve a type for " .. v .. " here") end - local t = infer_at(where, f.typ) - if not f.where then + local t = f.no_infer and f.typ or self:infer_at(w, f.typ) + if f.no_infer then t.inferred_at = nil end - add_var(nil, v, t, "const", "narrow") + self:add_var(nil, v, t, "const", "narrow") end end end - local function dismiss_unresolved(name) - for i = #st, 1, -1 do - local unresolved = find_unresolved(i) - if unresolved then - local uses = unresolved.nominals[name] - if uses then - for _, t in ipairs(uses) do - resolve_nominal(t) - end - unresolved.nominals[name] = nil - return + function TypeChecker:dismiss_unresolved(name) + for i = #self.st, 1, -1 do + local scope = self.st[i] + local uses = scope.pending_nominals and scope.pending_nominals[name] + if uses then + for _, t in ipairs(uses) do + self:resolve_nominal(t) end + scope.pending_nominals[name] = nil + return end end end - local type_check_funcall - - local function special_pcall_xpcall(node, _a, b, argdelta) + local function special_pcall_xpcall(self, node, _a, b, argdelta) local base_nargs = (node.e1.tk == "xpcall") and 2 or 1 + local bool = a_type(node, "boolean", {}) if #node.e2 < base_nargs then - error_at(node, "wrong number of arguments (given " .. #node.e2 .. ", expects at least " .. base_nargs .. ")") - return a_type("tuple", { tuple = { BOOLEAN } }) + self.errs:add(node, "wrong number of arguments (given " .. #node.e2 .. ", expects at least " .. base_nargs .. ")") + return a_type(node, "tuple", { tuple = { bool } }) end @@ -9915,137 +9935,142 @@ a.types[i], b.types[i]), } ftype.is_method = false end - local fe2 = {} + local fe2 = node_at(node.e2, {}) if node.e1.tk == "xpcall" then base_nargs = 2 + local arg2 = node.e2[2] local msgh = table.remove(b.tuple, 1) - assert_is_a(node.e2[2], msgh, XPCALL_MSGH_FUNCTION, "in message handler") + local msgh_type = a_function(arg2, { + min_arity = 1, + args = a_type(arg2, "tuple", { tuple = { a_type(arg2, "any", {}) } }), + rets = a_type(arg2, "tuple", { tuple = {} }), + }) + self:assert_is_a(arg2, msgh, msgh_type, "in message handler") end for i = base_nargs + 1, #node.e2 do table.insert(fe2, node.e2[i]) end - local fnode = { - y = node.y, - x = node.x, + local fnode = node_at(node, { kind = "op", op = { op = "@funcall" }, e1 = node.e2[1], e2 = fe2, - } - local rets = type_check_funcall(fnode, ftype, b, argdelta + base_nargs) + }) + local rets = self:type_check_funcall(fnode, ftype, b, argdelta + base_nargs) if rets.typename == "invalid" then return rets end - table.insert(rets.tuple, 1, BOOLEAN) + table.insert(rets.tuple, 1, bool) return rets end local special_functions = { - ["pairs"] = function(node, a, b, argdelta) + ["pairs"] = function(self, node, a, b, argdelta) if not b.tuple[1] then - return invalid_at(node, "pairs requires an argument") + return self.errs:invalid_at(node, "pairs requires an argument") end - local t = to_structural(b.tuple[1]) + local t = self:to_structural(b.tuple[1]) if t.elements then - add_warning("hint", node, "hint: applying pairs on an array: did you intend to apply ipairs?") + self.errs:add_warning("hint", node, "hint: applying pairs on an array: did you intend to apply ipairs?") end if t.typename ~= "map" then - if not (lax and is_unknown(t)) then + if not (self.feat_lax and is_unknown(t)) then if t.fields then - match_all_record_field_names(node.e2, t, t.field_order, + self:match_all_record_field_names(node.e2, t, t.field_order, "attempting pairs on a record with attributes of different types") local ct = t.typename == "record" and "{string:any}" or "{any:any}" - add_warning("hint", node.e2, "hint: if you want to iterate over fields of a record, cast it to " .. ct) + self.errs:add_warning("hint", node.e2, "hint: if you want to iterate over fields of a record, cast it to " .. ct) else - error_at(node.e2, "cannot apply pairs on values of type: %s", t) + self.errs:add(node.e2, "cannot apply pairs on values of type: %s", t) end end end - return (type_check_function_call(node, a, b, argdelta)) + return (self:type_check_function_call(node, a, b, argdelta)) end, - ["ipairs"] = function(node, a, b, argdelta) + ["ipairs"] = function(self, node, a, b, argdelta) if not b.tuple[1] then - return invalid_at(node, "ipairs requires an argument") + return self.errs:invalid_at(node, "ipairs requires an argument") end local orig_t = b.tuple[1] - local t = to_structural(orig_t) + local t = self:to_structural(orig_t) if t.typename == "tupletable" then - local arr_type = arraytype_from_tuple(node.e2, t) + local arr_type = self:arraytype_from_tuple(node.e2, t) if not arr_type then - return invalid_at(node.e2, "attempting ipairs on tuple that's not a valid array: %s", orig_t) + return self.errs:invalid_at(node.e2, "attempting ipairs on tuple that's not a valid array: %s", orig_t) end elseif not t.elements then - if not (lax and (is_unknown(t) or t.typename == "emptytable")) then - return invalid_at(node.e2, "attempting ipairs on something that's not an array: %s", orig_t) + if not (self.feat_lax and (is_unknown(t) or t.typename == "emptytable")) then + return self.errs:invalid_at(node.e2, "attempting ipairs on something that's not an array: %s", orig_t) end end - return (type_check_function_call(node, a, b, argdelta)) + return (self:type_check_function_call(node, a, b, argdelta)) end, - ["rawget"] = function(node, _a, b, _argdelta) + ["rawget"] = function(self, node, _a, b, _argdelta) if #b.tuple == 2 then - return a_type("tuple", { tuple = { type_check_index(node.e2[1], node.e2[2], b.tuple[1], b.tuple[2]) } }) + return a_type(node, "tuple", { tuple = { self:type_check_index(node.e2[1], node.e2[2], b.tuple[1], b.tuple[2]) } }) else - return invalid_at(node, "rawget expects two arguments") + return self.errs:invalid_at(node, "rawget expects two arguments") end end, - ["require"] = function(node, _a, b, _argdelta) + ["require"] = function(self, node, _a, b, _argdelta) if #b.tuple ~= 1 then - return invalid_at(node, "require expects one literal argument") + return self.errs:invalid_at(node, "require expects one literal argument") end if node.e2[1].kind ~= "string" then - return invalid_at(node, "don't know how to resolve a dynamic require") + return self.errs:invalid_at(node, "don't know how to resolve a dynamic require") end local module_name = assert(node.e2[1].conststr) - local t, found = require_module(module_name, lax, env) - if not found then - return invalid_at(node, "module not found: '" .. module_name .. "'") - end + local t, module_filename = require_module(node, module_name, self.feat_lax, self.env) if t.typename == "invalid" then - if lax then - return a_type("tuple", { tuple = { UNKNOWN } }) + if not module_filename then + return self.errs:invalid_at(node, "module not found: '" .. module_name .. "'") + end + + if self.feat_lax then + return a_type(node, "tuple", { tuple = { a_type(node, "unknown", {}) } }) end - return invalid_at(node, "no type information for required module: '" .. module_name .. "'") + return self.errs:invalid_at(node, "no type information for required module: '" .. module_name .. "'") end - dependencies[module_name] = t.filename - return type_at(node, a_type("tuple", { tuple = { t } })) + self.dependencies[module_name] = module_filename + return a_type(node, "tuple", { tuple = { t } }) end, ["pcall"] = special_pcall_xpcall, ["xpcall"] = special_pcall_xpcall, - ["assert"] = function(node, a, b, argdelta) + ["assert"] = function(self, node, a, b, argdelta) node.known = FACT_TRUTHY - local r = type_check_function_call(node, a, b, argdelta) - apply_facts(node, node.e2[1].known) + local r = self:type_check_function_call(node, a, b, argdelta) + self:apply_facts(node, node.e2[1].known) return r end, } - type_check_funcall = function(node, a, b, argdelta) + function TypeChecker:type_check_funcall(node, a, b, argdelta) argdelta = argdelta or 0 if node.e1.kind == "variable" then local special = special_functions[node.e1.tk] if special then - return special(node, a, b, argdelta) + return special(self, node, a, b, argdelta) else - return (type_check_function_call(node, a, b, argdelta)) + return (self:type_check_function_call(node, a, b, argdelta)) end elseif node.e1.op and node.e1.op.op == ":" then table.insert(b.tuple, 1, node.e1.receiver) - return (type_check_function_call(node, a, b, -1)) + return (self:type_check_function_call(node, a, b, -1)) else - return (type_check_function_call(node, a, b, argdelta)) + return (self:type_check_function_call(node, a, b, argdelta)) end end @@ -10057,19 +10082,19 @@ a.types[i], b.types[i]), } node.exps[i].tk == node.vars[i].tk end - local function missing_initializer(node, i, name) - if lax then - return UNKNOWN + function TypeChecker:missing_initializer(node, i, name) + if self.feat_lax then + return a_type(node, "unknown", {}) else if node.exps then - return invalid_at(node.vars[i], "assignment in declaration did not produce an initial value for variable '" .. name .. "'") + return self.errs:invalid_at(node.vars[i], "assignment in declaration did not produce an initial value for variable '" .. name .. "'") else - return invalid_at(node.vars[i], "variable '" .. name .. "' has no type or initial value") + return self.errs:invalid_at(node.vars[i], "variable '" .. name .. "' has no type or initial value") end end end - local function set_expected_types_to_decltuple(node, children) + local function set_expected_types_to_decltuple(_, node, children) local decltuple = node.kind == "assignment" and children[1] or node.decltuple assert(decltuple.typename == "tuple") local decls = decltuple.tuple @@ -10081,7 +10106,7 @@ a.types[i], b.types[i]), } typ = decls[i] if typ then if i == nexps and ndecl > nexps then - typ = type_at(node, a_type("tuple", { tuple = {} })) + typ = a_type(node, "tuple", { tuple = {} }) for a = i, ndecl do table.insert(typ.tuple, decls[a]) end @@ -10097,38 +10122,7 @@ a.types[i], b.types[i]), } return n and n >= 1 and math.floor(n) == n end - local context_name = { - ["local_declaration"] = "in local declaration", - ["global_declaration"] = "in global declaration", - ["assignment"] = "in assignment", - } - - local function in_context(ctx, msg) - if not ctx then - return msg - end - local where = context_name[ctx.kind] - if where then - return where .. ": " .. (ctx.name and ctx.name .. ": " or "") .. msg - else - return msg - end - end - - - - local function check_redeclared_key(where, ctx, seen_keys, key) - if key ~= nil then - local s = seen_keys[key] - if s then - error_at(where, in_context(ctx, "redeclared key " .. tostring(key) .. " (previously declared at " .. filename .. ":" .. s.y .. ":" .. s.x .. ")")) - else - seen_keys[key] = where - end - end - end - - local function infer_table_literal(node, children) + local function infer_table_literal(self, node, children) local is_record = false local is_array = false local is_map = false @@ -10153,14 +10147,15 @@ a.types[i], b.types[i]), } for i, child in ipairs(children) do local ck = child.kname + local cktype = child.ktype local n = node[i].key.constnum local b = nil - if child.ktype.typename == "boolean" then + if cktype.typename == "boolean" then b = (node[i].key.tk == "true") end local key = ck or n or b - check_redeclared_key(node[i], nil, seen_keys, key) + self.errs:check_redeclared_key(node[i], nil, seen_keys, key) local uvtype = resolve_tuple(child.vtype) if ck then @@ -10171,7 +10166,7 @@ a.types[i], b.types[i]), } end fields[ck] = uvtype table.insert(field_order, ck) - elseif is_number_type(child.ktype) then + elseif is_numeric_type(cktype) then is_array = true if not is_not_tuple then is_tuple = true @@ -10185,25 +10180,25 @@ a.types[i], b.types[i]), } if i == #children and cv.typename == "tuple" then for _, c in ipairs(cv.tuple) do - elements = expand_type(node, elements, c) + elements = self:expand_type(node, elements, c) types[last_array_idx] = resolve_tuple(c) last_array_idx = last_array_idx + 1 end else types[last_array_idx] = uvtype last_array_idx = last_array_idx + 1 - elements = expand_type(node, elements, uvtype) + elements = self:expand_type(node, elements, uvtype) end else if not is_positive_int(n) then - elements = expand_type(node, elements, uvtype) + elements = self:expand_type(node, elements, uvtype) is_not_tuple = true elseif n then types[n] = uvtype if n > largest_array_idx then largest_array_idx = n end - elements = expand_type(node, elements, uvtype) + elements = self:expand_type(node, elements, uvtype) end end @@ -10215,37 +10210,37 @@ a.types[i], b.types[i]), } end else is_map = true - keys = expand_type(node, keys, drop_constant_value(child.ktype)) - values = expand_type(node, values, uvtype) + keys = self:expand_type(node, keys, drop_constant_value(cktype)) + values = self:expand_type(node, values, uvtype) end end local t if is_array and is_map then - error_at(node, "cannot determine type of table literal") - t = a_type("map", { keys = -expand_type(node, keys, INTEGER), values = + self.errs:add(node, "cannot determine type of table literal") + t = a_type(node, "map", { keys = +self:expand_type(node, keys, a_type(node, "integer", {})), values = -expand_type(node, values, elements) }) +self:expand_type(node, values, elements) }) elseif is_record and is_array then - t = a_type("record", { + t = a_type(node, "record", { fields = fields, field_order = field_order, elements = elements, interface_list = { - type_at(node, a_type("array", { elements = elements })), + a_type(node, "array", { elements = elements }), }, }) elseif is_record and is_map then if keys.typename == "string" then for _, fname in ipairs(field_order) do - values = expand_type(node, values, fields[fname]) + values = self:expand_type(node, values, fields[fname]) end - t = a_type("map", { keys = keys, values = values }) + t = a_type(node, "map", { keys = keys, values = values }) else - error_at(node, "cannot determine type of table literal") + self.errs:add(node, "cannot determine type of table literal") end elseif is_array then local pure_array = true @@ -10253,7 +10248,7 @@ expand_type(node, values, elements) }) local last_t for _, current_t in pairs(types) do if last_t then - if not same_type(last_t, current_t) then + if not self:same_type(last_t, current_t) then pure_array = false break end @@ -10262,69 +10257,70 @@ expand_type(node, values, elements) }) end end if pure_array then - t = a_type("array", { elements = elements }) + t = a_type(node, "array", { elements = elements }) t.consttypes = types t.inferred_len = largest_array_idx - 1 else - t = a_type("tupletable", {}) + t = a_type(node, "tupletable", { inferred_at = node }) t.types = types end elseif is_record then - t = a_type("record", { + t = a_type(node, "record", { fields = fields, field_order = field_order, }) elseif is_map then - t = a_type("map", { keys = keys, values = values }) + t = a_type(node, "map", { keys = keys, values = values }) elseif is_tuple then - t = a_type("tupletable", {}) + t = a_type(node, "tupletable", { inferred_at = node }) t.types = types if not types or #types == 0 then - error_at(node, "cannot determine type of tuple elements") + self.errs:add(node, "cannot determine type of tuple elements") end end if not t then - t = a_type("emptytable", {}) + t = a_type(node, "emptytable", {}) end return type_at(node, t) end - local function infer_negation_of_if_blocks(where, ifnode, n) - local f = facts_not(where, ifnode.if_blocks[1].exp.known) + function TypeChecker:infer_negation_of_if_blocks(w, ifnode, n) + local f = facts_not(w, ifnode.if_blocks[1].exp.known) for e = 2, n do local b = ifnode.if_blocks[e] if b.exp then - f = facts_and(where, f, facts_not(where, b.exp.known)) + f = facts_and(w, f, facts_not(w, b.exp.known)) end end - apply_facts(where, f) + self:apply_facts(w, f) end - local function determine_declaration_type(var, node, infertypes, i) + function TypeChecker:determine_declaration_type(var, node, infertypes, i) local ok = true local name = var.tk local infertype = infertypes and infertypes.tuple[i] - if lax and infertype and infertype.typename == "nil" then + if self.feat_lax and infertype and infertype.typename == "nil" then infertype = nil end local decltype = node.decltuple and node.decltuple.tuple[i] if decltype then - if to_structural(decltype) == INVALID then - decltype = INVALID + local rdecltype = self:to_structural(decltype) + if rdecltype.typename == "invalid" then + decltype = rdecltype end if infertype then - ok = assert_is_a(node.vars[i], infertype, decltype, context_name[node.kind], name) + local w = node.exps and node.exps[i] or node.vars[i] + ok = self:assert_is_a(w, infertype, decltype, context_name[node.kind], name) end else if infertype then if infertype.typename == "unresolvable_typearg" then - error_at(node.vars[i], "cannot infer declaration type; an explicit type annotation is necessary") ok = false - infertype = INVALID + infertype = self.errs:invalid_at(node.vars[i], "cannot infer declaration type; an explicit type annotation is necessary") elseif infertype.typename == "function" and infertype.is_method then @@ -10336,17 +10332,17 @@ expand_type(node, values, elements) }) end if var.attribute == "total" then - local rd = decltype and to_structural(decltype) + local rd = decltype and self:to_structural(decltype) if rd and (rd.typename ~= "map" and rd.typename ~= "record") then - error_at(var, "attribute only applies to maps and records") + self.errs:add(var, "attribute only applies to maps and records") ok = false elseif not infertype then - error_at(var, "variable declared does not declare an initialization value") + self.errs:add(var, "variable declared does not declare an initialization value") ok = false else local valnode = node.exps[i] if not valnode or valnode.kind ~= "literal_table" then - error_at(var, "attribute only applies to literal tables") + self.errs:add(var, "attribute only applies to literal tables") ok = false else if not valnode.is_total then @@ -10354,12 +10350,12 @@ expand_type(node, values, elements) }) if valnode.missing then missing = " (missing: " .. table.concat(valnode.missing, ", ") .. ")" end - local ri = to_structural(infertype) + local ri = self:to_structural(infertype) if ri.typename == "map" then - error_at(var, "map variable declared does not declare values for all possible keys" .. missing) + self.errs:add(var, "map variable declared does not declare values for all possible keys" .. missing) ok = false elseif ri.typename == "record" then - error_at(var, "record variable declared does not declare values for all fields" .. missing) + self.errs:add(var, "record variable declared does not declare values for all fields" .. missing) ok = false end end @@ -10369,34 +10365,36 @@ expand_type(node, values, elements) }) local t = decltype or infertype if t == nil then - t = missing_initializer(node, i, name) + t = self:missing_initializer(node, i, name) elseif t.typename == "emptytable" then t.declared_at = node t.assigned_to = name elseif t.elements then t.inferred_len = nil + elseif t.typename == "nominal" then + self:resolve_nominal(t) end return ok, t, infertype ~= nil end - local function get_typedecl(value) + function TypeChecker:get_typedecl(value) if value.kind == "op" and value.op.op == "@funcall" and value.e1.kind == "variable" and value.e1.tk == "require" then - local t = special_functions["require"](value, find_var_type("require"), a_type("tuple", { tuple = { STRING } }), 0) + local t = special_functions["require"](self, value, self:find_var_type("require"), a_type(value.e2, "tuple", { tuple = { a_type(value.e2[1], "string", {}) } }), 0) local ty = t.typename == "tuple" and t.tuple[1] or t - ty = (ty.typename == "typealias") and resolve_typealias(ty) or ty - local td = (ty.typename == "typedecl") and ty or a_type("typedecl", { def = ty }) + ty = (ty.typename == "typealias") and self:resolve_typealias(ty) or ty + local td = (ty.typename == "typedecl") and ty or a_type(value, "typedecl", { def = ty }) return td else local newtype = value.newtype if newtype.typename == "typealias" then - local aliasing = find_var(newtype.alias_to.names[1], "use_type") - return resolve_typealias(newtype), aliasing - else + local aliasing = self:find_var(newtype.alias_to.names[1], "use_type") + return self:resolve_typealias(newtype), aliasing + elseif newtype.typename == "typedecl" then return newtype, nil end end @@ -10427,15 +10425,14 @@ expand_type(node, values, elements) }) return is_total, missing end - local function total_map_check(t, seen_keys) - local k = to_structural(t.keys) + local function total_map_check(keys, seen_keys) local is_total = true local missing - if k.typename == "enum" then - for _, key in ipairs(sorted_keys(k.enumset)) do + if keys.typename == "enum" then + for _, key in ipairs(sorted_keys(keys.enumset)) do is_total, missing = total_check_key(key, seen_keys, is_total, missing) end - elseif k.typename == "boolean" then + elseif keys.typename == "boolean" then for _, key in ipairs({ true, false }) do is_total, missing = total_check_key(key, seen_keys, is_total, missing) end @@ -10449,35 +10446,38 @@ expand_type(node, values, elements) }) - local function check_assignment(where, vartype, valtype, varname, attr) + function TypeChecker:check_assignment(varnode, vartype, valtype) + local varname = varnode.tk + local attr = varnode.attribute + if varname then - if widen_back_var(varname) then - vartype, attr = find_var_type(varname) + if self:widen_back_var(varname) then + vartype, attr = self:find_var_type(varname) if not vartype then - error_at(where, "unknown variable") + self.errs:add(varnode, "unknown variable") return nil end end end if attr == "close" or attr == "const" or attr == "total" then - error_at(where, "cannot assign to <" .. attr .. "> variable") + self.errs:add(varnode, "cannot assign to <" .. attr .. "> variable") return nil end - local var = to_structural(vartype) + local var = self:to_structural(vartype) if var.typename == "typedecl" or var.typename == "typealias" then - error_at(where, "cannot reassign a type") + self.errs:add(varnode, "cannot reassign a type") return nil end if not valtype then - error_at(where, "variable is not being assigned a value") + self.errs:add(varnode, "variable is not being assigned a value") return nil, nil, "missing" end - assert_is_a(where, valtype, vartype, "in assignment") + self:assert_is_a(varnode, valtype, vartype, "in assignment") - local val = to_structural(valtype) + local val = self:to_structural(valtype) return var, val end @@ -10493,181 +10493,182 @@ expand_type(node, values, elements) }) visit_node.cbs = { ["statements"] = { - before = function(node) - begin_scope(node) + before = function(self, node) + self:begin_scope(node) end, - after = function(node, _children) + after = function(self, node, _children) - if #st == 2 then - fail_unresolved() + if #self.st == 2 then + self.errs:fail_unresolved_labels(self.st[2]) + self.errs:fail_unresolved_nominals(self.st[2], self.st[1]) end if not node.is_repeat then - end_scope(node) + self:end_scope(node) end return NONE end, }, ["local_type"] = { - before = function(node) + before = function(self, node) local name = node.var.tk - local resolved, aliasing = get_typedecl(node.value) - local var = add_var(node.var, name, resolved, node.var.attribute) + local resolved, aliasing = self:get_typedecl(node.value) + local var = self:add_var(node.var, name, resolved, node.var.attribute) if aliasing then var.aliasing = aliasing end end, - after = function(node, _children) - dismiss_unresolved(node.var.tk) + after = function(self, node, _children) + self:dismiss_unresolved(node.var.tk) return NONE end, }, ["global_type"] = { - before = function(node) + before = function(self, node) + local global_scope = self.st[1] local name = node.var.tk - local unresolved = get_unresolved() if node.value then - local resolved, aliasing = get_typedecl(node.value) - local added = add_global(node.var, name, resolved) + local resolved, aliasing = self:get_typedecl(node.value) + local added = self:add_global(node.var, name, resolved) node.value.newtype = resolved if aliasing then added.aliasing = aliasing end - if added and unresolved.global_types[name] then - unresolved.global_types[name] = nil + if global_scope.pending_global_types[name] then + global_scope.pending_global_types[name] = nil end else - if not st[1][name] then - unresolved.global_types[name] = true + if not self.st[1].vars[name] then + global_scope.pending_global_types[name] = true end end end, - after = function(node, _children) - dismiss_unresolved(node.var.tk) + after = function(self, node, _children) + self:dismiss_unresolved(node.var.tk) return NONE end, }, ["local_declaration"] = { - before = function(node) - if tc then + before = function(self, node) + if self.collector then for _, var in ipairs(node.vars) do - tc.reserve_symbol_list_slot(var) + self.collector.reserve_symbol_list_slot(var) end end end, before_exp = set_expected_types_to_decltuple, - after = function(node, children) + after = function(self, node, children) local valtuple = children[3] local encountered_close = false - local infertypes = get_assignment_values(valtuple, #node.vars) + local infertypes = get_assignment_values(node, valtuple, #node.vars) for i, var in ipairs(node.vars) do if var.attribute == "close" then - if opts.gen_target == "5.4" then + if self.gen_target == "5.4" then if encountered_close then - error_at(var, "only one per declaration is allowed") + self.errs:add(var, "only one per declaration is allowed") else encountered_close = true end else - error_at(var, " attribute is only valid for Lua 5.4 (current target is " .. tostring(opts.gen_target) .. ")") + self.errs:add(var, " attribute is only valid for Lua 5.4 (current target is " .. tostring(self.gen_target) .. ")") end end - local ok, t = determine_declaration_type(var, node, infertypes, i) + local ok, t = self:determine_declaration_type(var, node, infertypes, i) if var.attribute == "close" then if not type_is_closable(t) then - error_at(var, "to-be-closed variable " .. var.tk .. " has a non-closable type %s", t) + self.errs:add(var, "to-be-closed variable " .. var.tk .. " has a non-closable type %s", t) elseif node.exps and node.exps[i] and expr_is_definitely_not_closable(node.exps[i]) then - error_at(var, "to-be-closed variable " .. var.tk .. " assigned a non-closable value") + self.errs:add(var, "to-be-closed variable " .. var.tk .. " assigned a non-closable value") end end assert(var) - add_var(var, var.tk, t, var.attribute, is_localizing_a_variable(node, i) and "declaration") + self:add_var(var, var.tk, t, var.attribute, is_localizing_a_variable(node, i) and "declaration") local infertype = infertypes.tuple[i] if ok and infertype then - local where = node.exps[i] or node.exps + local w = node.exps[i] or node.exps - local rt = to_structural(t) + local rt = self:to_structural(t) if (not (rt.typename == "enum")) and ((not (t.typename == "nominal")) or (rt.typename == "union")) and - not same_type(t, infertype) then + not self:same_type(t, infertype) then - t = infer_at(where, infertype) - add_var(where, var.tk, t, "const", "narrowed_declaration") + t = self:infer_at(w, infertype) + self:add_var(w, var.tk, t, "const", "narrowed_declaration") end end - if tc then - tc.store_type(var.y, var.x, t) + if self.collector then + self.collector.store_type(var.y, var.x, t) end - dismiss_unresolved(var.tk) + self:dismiss_unresolved(var.tk) end return NONE end, }, ["global_declaration"] = { before_exp = set_expected_types_to_decltuple, - after = function(node, children) + after = function(self, node, children) local valtuple = children[3] - local infertypes = get_assignment_values(valtuple, #node.vars) + local infertypes = get_assignment_values(node, valtuple, #node.vars) for i, var in ipairs(node.vars) do - local _, t, is_inferred = determine_declaration_type(var, node, infertypes, i) + local _, t, is_inferred = self:determine_declaration_type(var, node, infertypes, i) if var.attribute == "close" then - error_at(var, "globals may not be ") + self.errs:add(var, "globals may not be ") end - add_global(var, var.tk, t, is_inferred) + self:add_global(var, var.tk, t, is_inferred) - dismiss_unresolved(var.tk) + self:dismiss_unresolved(var.tk) end return NONE end, }, ["assignment"] = { before_exp = set_expected_types_to_decltuple, - after = function(node, children) + after = function(self, node, children) local vartuple = children[1] assert(vartuple.typename == "tuple") local vartypes = vartuple.tuple local valtuple = children[3] assert(valtuple.typename == "tuple") - local valtypes = get_assignment_values(valtuple, #vartypes) + local valtypes = get_assignment_values(node, valtuple, #vartypes) for i, vartype in ipairs(vartypes) do local varnode = node.vars[i] local varname = varnode.tk local valtype = valtypes.tuple[i] - local rvar, rval, err = check_assignment(varnode, vartype, valtype, varname, varnode.attribute) + local rvar, rval, err = self:check_assignment(varnode, vartype, valtype) if err == "missing" then if #node.exps == 1 and node.exps[1].kind == "op" and node.exps[1].op.op == "@funcall" then local msg = #valtuple.tuple == 1 and "only 1 value is returned by the function" or ("only " .. #valtuple.tuple .. " values are returned by the function") - add_warning("hint", varnode, msg) + self.errs:add_warning("hint", varnode, msg) end end if rval and rvar then if rval.typename == "function" then - widen_all_unions() + self:widen_all_unions() end if varname and (rvar.typename == "union" or rvar.typename == "interface") then - add_var(varnode, varname, rval, nil, "narrow") + self:add_var(varnode, varname, rval, nil, "narrow") end - if tc then - tc.store_type(varnode.y, varnode.x, valtype) + if self.collector then + self.collector.store_type(varnode.y, varnode.x, valtype) end end end @@ -10676,7 +10677,7 @@ expand_type(node, values, elements) }) end, }, ["if"] = { - after = function(node, _children) + after = function(self, node, _children) local all_return = true for _, b in ipairs(node.if_blocks) do if not b.block_returns then @@ -10686,26 +10687,26 @@ expand_type(node, values, elements) }) end if all_return then node.block_returns = true - infer_negation_of_if_blocks(node, node, #node.if_blocks) + self:infer_negation_of_if_blocks(node, node, #node.if_blocks) end return NONE end, }, ["if_block"] = { - before = function(node) - begin_scope(node) + before = function(self, node) + self:begin_scope(node) if node.if_block_n > 1 then - infer_negation_of_if_blocks(node, node.if_parent, node.if_block_n - 1) + self:infer_negation_of_if_blocks(node, node.if_parent, node.if_block_n - 1) end end, - before_statements = function(node) + before_statements = function(self, node) if node.exp then - apply_facts(node.exp, node.exp.known) + self:apply_facts(node.exp, node.exp.known) end end, - after = function(node, _children) - end_scope(node) + after = function(self, node, _children) + self:end_scope(node) if #node.body > 0 and node.body[#node.body].block_returns then node.block_returns = true @@ -10715,76 +10716,96 @@ expand_type(node, values, elements) }) end, }, ["while"] = { - before = function(node) + before = function(self, node) - widen_all_unions(node) + self:widen_all_unions(node) end, - before_statements = function(node) - begin_scope(node) - apply_facts(node.exp, node.exp.known) + before_statements = function(self, node) + self:begin_scope(node) + self:apply_facts(node.exp, node.exp.known) end, after = end_scope_and_none_type, }, ["label"] = { - before = function(node) - - widen_all_unions() - local label_id = "::" .. node.label .. "::" - if st[#st][label_id] then - error_at(node, "label '" .. node.label .. "' already defined at " .. filename) - end - local unresolved = find_unresolved() - local var = add_var(node, label_id, type_at(node, a_type("none", {}))) - if unresolved then - if unresolved.labels[node.label] then - var.used = true + before = function(self, node) + + self:widen_all_unions() + local label_id = node.label + do + local scope = self.st[#self.st] + scope.labels = scope.labels or {} + if scope.labels[label_id] then + self.errs:add(node, "label '" .. node.label .. "' already defined") + else + scope.labels[label_id] = node end - unresolved.labels[node.label] = nil end + + + local scope = self.st[#self.st] + if scope.pending_labels and scope.pending_labels[label_id] then + node.used_label = true + scope.pending_labels[label_id] = nil + + end + end, after = function() return NONE end, }, ["goto"] = { - after = function(node, _children) - if not find_var_type("::" .. node.label .. "::") then - local unresolved = get_unresolved(st[#st]) - unresolved.labels[node.label] = unresolved.labels[node.label] or {} - table.insert(unresolved.labels[node.label], node) + after = function(self, node, _children) + local label_id = node.label + local found_label + for i = #self.st, 1, -1 do + local scope = self.st[i] + if scope.labels and scope.labels[label_id] then + found_label = scope.labels[label_id] + break + end + end + + if found_label then + found_label.used_label = true + else + local scope = self.st[#self.st] + scope.pending_labels = scope.pending_labels or {} + scope.pending_labels[label_id] = scope.pending_labels[label_id] or {} + table.insert(scope.pending_labels[label_id], node) end return NONE end, }, ["repeat"] = { - before = function(node) + before = function(self, node) - widen_all_unions(node) + self:widen_all_unions(node) end, after = end_scope_and_none_type, }, ["forin"] = { - before = function(node) - begin_scope(node) + before = function(self, node) + self:begin_scope(node) end, - before_statements = function(node, children) + before_statements = function(self, node, children) local exptuple = children[2] assert(exptuple.typename == "tuple") local exptypes = exptuple.tuple - widen_all_unions(node) + self:widen_all_unions(node) local exp1 = node.exps[1] - local args = a_type("tuple", { tuple = { + local args = a_type(node.exps, "tuple", { tuple = { node.exps[2] and exptypes[2], node.exps[3] and exptypes[3], } }) - local exp1type = resolve_for_call(exptypes[1], args, false) + local exp1type = self:resolve_for_call(exptypes[1], args, false) if exp1type.typename == "poly" then local _ - _, exp1type = type_check_function_call(exp1, exp1type, args, 0, exp1, { node.exps[2], node.exps[3] }) + _, exp1type = self:type_check_function_call(exp1, exp1type, args, 0, exp1, { node.exps[2], node.exps[3] }) end if exp1type.typename == "function" then @@ -10797,69 +10818,69 @@ expand_type(node, values, elements) }) if rets.is_va then r = last else - r = lax and UNKNOWN or INVALID + r = self.feat_lax and a_type(v, "unknown", {}) or a_type(v, "invalid", {}) end end - add_var(v, v.tk, r) + self:add_var(v, v.tk, r) - if tc then - tc.store_type(v.y, v.x, r) + if self.collector then + self.collector.store_type(v.y, v.x, r) end last = r end local nrets = #rets.tuple - if (not lax) and (not rets.is_va and #node.vars > nrets) then + if (not self.feat_lax) and (not rets.is_va and #node.vars > nrets) then local at = node.vars[nrets + 1] local n_values = nrets == 1 and "1 value" or tostring(nrets) .. " values" - error_at(at, "too many variables for this iterator; it produces " .. n_values) + self.errs:add(at, "too many variables for this iterator; it produces " .. n_values) end else - if not (lax and is_unknown(exp1type)) then - error_at(exp1, "expression in for loop does not return an iterator") + if not (self.feat_lax and is_unknown(exp1type)) then + self.errs:add(exp1, "expression in for loop does not return an iterator") end end end, after = end_scope_and_none_type, }, ["fornum"] = { - before_statements = function(node, children) - widen_all_unions(node) - begin_scope(node) - local from_t = to_structural(resolve_tuple(children[2])) - local to_t = to_structural(resolve_tuple(children[3])) - local step_t = children[4] and to_structural(children[4]) - local t = (from_t.typename == "integer" and + before_statements = function(self, node, children) + self:widen_all_unions(node) + self:begin_scope(node) + local from_t = self:to_structural(resolve_tuple(children[2])) + local to_t = self:to_structural(resolve_tuple(children[3])) + local step_t = children[4] and self:to_structural(children[4]) + local typename = (from_t.typename == "integer" and to_t.typename == "integer" and (not step_t or step_t.typename == "integer")) and - INTEGER or - NUMBER - add_var(node.var, node.var.tk, t) + "integer" or + "number" + self:add_var(node.var, node.var.tk, a_type(node.var, typename, {})) end, after = end_scope_and_none_type, }, ["return"] = { - before = function(node) - local rets = find_var_type("@return") + before = function(self, node) + local rets = self:find_var_type("@return") if rets and rets.typename == "tuple" then for i, exp in ipairs(node.exps) do exp.expected = rets.tuple[i] end end end, - after = function(node, children) + after = function(self, node, children) local got = children[1] assert(got.typename == "tuple") local got_t = got.tuple local n_got = #got_t node.block_returns = true - local expected = find_var_type("@return") + local expected = self:find_var_type("@return") if not expected then - expected = infer_at(node, got) - module_type = drop_constant_value(to_structural(resolve_tuple(expected))) - st[2]["@return"] = { t = expected } + expected = self:infer_at(node, got) + self.module_type = drop_constant_value(self:to_structural(resolve_tuple(expected))) + self.st[2].vars["@return"] = { t = expected } end local expected_t = expected.tuple @@ -10874,8 +10895,8 @@ expand_type(node, values, elements) }) vatype = expected.is_va and expected.tuple[n_expected] end - if n_got > n_expected and (not lax) and not vatype then - error_at(node, what .. ": excess return values, expected " .. n_expected .. " %s, got " .. n_got .. " %s", expected, got) + if n_got > n_expected and (not self.feat_lax) and not vatype then + self.errs:add(node, what .. ": excess return values, expected " .. n_expected .. " %s, got " .. n_got .. " %s", expected, got) end if n_expected > 1 and @@ -10883,18 +10904,18 @@ expand_type(node, values, elements) }) node.exps[1].kind == "op" and (node.exps[1].op.op == "and" or node.exps[1].op.op == "or") and node.exps[1].discarded_tuple then - add_warning("hint", node.exps[1].e2, "additional return values are being discarded due to '" .. node.exps[1].op.op .. "' expression; suggest parentheses if intentional") + self.errs:add_warning("hint", node.exps[1].e2, "additional return values are being discarded due to '" .. node.exps[1].op.op .. "' expression; suggest parentheses if intentional") end for i = 1, n_got do local e = expected_t[i] or vatype if e then e = resolve_tuple(e) - local where = (node.exps[i] and node.exps[i].x) and + local w = (node.exps[i] and node.exps[i].x) and node.exps[i] or node.exps - assert(where and where.x) - assert_is_a(where, got_t[i], e, what) + assert(w and w.x) + self:assert_is_a(w, got_t[i], e, what) end end @@ -10902,25 +10923,28 @@ expand_type(node, values, elements) }) end, }, ["variable_list"] = { - after = function(node, children) - local tuple = a_type("tuple", { tuple = children }) + after = function(self, node, children) + local tuple = a_type(node, "tuple", { tuple = children }) tuple = flatten_tuple(tuple) for i, t in ipairs(tuple.tuple) do - ensure_not_abstract(node[i], t) + local ok, err = ensure_not_abstract(t) + if not ok then + self.errs:add(node[i], err) + end end return tuple end, }, ["literal_table"] = { - before = function(node) + before = function(self, node) if node.expected then - local decltype = to_structural(node.expected) + local decltype = self:to_structural(node.expected) if decltype.typename == "typevar" and decltype.constraint then - decltype = resolve_typedecl(to_structural(decltype.constraint)) + decltype = resolve_typedecl(self:to_structural(decltype.constraint)) end if decltype.typename == "tupletable" then @@ -10952,19 +10976,19 @@ expand_type(node, values, elements) }) end end end, - after = function(node, children) + after = function(self, node, children) node.known = FACT_TRUTHY if not node.expected then - return infer_table_literal(node, children) + return infer_table_literal(self, node, children) end - local decltype = to_structural(node.expected) + local decltype = self:to_structural(node.expected) local constraint if decltype.typename == "typevar" and decltype.constraint then constraint = resolve_typedecl(decltype.constraint) - decltype = to_structural(constraint) + decltype = self:to_structural(constraint) end if decltype.typename == "union" then @@ -10972,7 +10996,7 @@ expand_type(node, values, elements) }) local single_table_rt for _, t in ipairs(decltype.types) do - local rt = to_structural(t) + local rt = self:to_structural(t) if is_lua_table_type(rt) then if single_table_type then @@ -10993,7 +11017,7 @@ expand_type(node, values, elements) }) end if not is_lua_table_type(decltype) then - return infer_table_literal(node, children) + return infer_table_literal(self, node, children) end local force_array = nil @@ -11003,73 +11027,75 @@ expand_type(node, values, elements) }) for i, child in ipairs(children) do local cvtype = resolve_tuple(child.vtype) local ck = child.kname + local cktype = child.ktype local n = node[i].key.constnum local b = nil - if child.ktype.typename == "boolean" then + if cktype.typename == "boolean" then b = (node[i].key.tk == "true") end - check_redeclared_key(node[i], node.expected_context, seen_keys, ck or n or b) + self.errs:check_redeclared_key(node[i], node, seen_keys, ck or n or b) if decltype.fields and ck then local df = decltype.fields[ck] if not df then - error_at(node[i], in_context(node.expected_context, "unknown field " .. ck)) + self.errs:add_in_context(node[i], node, "unknown field " .. ck) else if df.typename == "typedecl" or df.typename == "typealias" then - error_at(node[i], in_context(node.expected_context, "cannot reassign a type")) + self.errs:add_in_context(node[i], node, "cannot reassign a type") else - assert_is_a(node[i], cvtype, df, "in record field", ck) + self:assert_is_a(node[i], cvtype, df, "in record field", ck) end end - elseif decltype.typename == "tupletable" and is_number_type(child.ktype) then + elseif decltype.typename == "tupletable" and is_numeric_type(cktype) then local dt = decltype.types[n] if not n then - error_at(node[i], in_context(node.expected_context, "unknown index in tuple %s"), decltype) + self.errs:add_in_context(node[i], node, "unknown index in tuple %s", decltype) elseif not dt then - error_at(node[i], in_context(node.expected_context, "unexpected index " .. n .. " in tuple %s"), decltype) + self.errs:add_in_context(node[i], node, "unexpected index " .. n .. " in tuple %s", decltype) else - assert_is_a(node[i], cvtype, dt, in_context(node.expected_context, "in tuple"), "at index " .. tostring(n)) + self:assert_is_a(node[i], cvtype, dt, node, "in tuple: at index " .. tostring(n)) end - elseif decltype.elements and is_number_type(child.ktype) then + elseif decltype.elements and is_numeric_type(cktype) then local cv = child.vtype if cv.typename == "tuple" and i == #children and node[i].key_parsed == "implicit" then for ti, tt in ipairs(cv.tuple) do - assert_is_a(node[i], tt, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(i + ti - 1)) + self:assert_is_a(node[i], tt, decltype.elements, node, "expected an array: at index " .. tostring(i + ti - 1)) end else - assert_is_a(node[i], cvtype, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(n)) + self:assert_is_a(node[i], cvtype, decltype.elements, node, "expected an array: at index " .. tostring(n)) end elseif node[i].key_parsed == "implicit" then if decltype.typename == "map" then - assert_is_a(node[i], INTEGER, decltype.keys, in_context(node.expected_context, "in map key")) - assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value")) + self:assert_is_a(node[i].key, a_type(node[i].key, "integer", {}), decltype.keys, node, "in map key") + self:assert_is_a(node[i].value, cvtype, decltype.values, node, "in map value") end - force_array = expand_type(node[i], force_array, child.vtype) + force_array = self:expand_type(node[i], force_array, child.vtype) elseif decltype.typename == "map" then force_array = nil - assert_is_a(node[i], child.ktype, decltype.keys, in_context(node.expected_context, "in map key")) - assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value")) + self:assert_is_a(node[i].key, cktype, decltype.keys, node, "in map key") + self:assert_is_a(node[i].value, cvtype, decltype.values, node, "in map value") else - error_at(node[i], in_context(node.expected_context, "unexpected key of type %s in table of type %s"), child.ktype, decltype) + self.errs:add_in_context(node[i], node, "unexpected key of type %s in table of type %s", cktype, decltype) end end local t if force_array then - t = infer_at(node, a_type("array", { elements = force_array })) + t = self:infer_at(node, a_type(node, "array", { elements = force_array })) else - t = resolve_typevars_at(node, node.expected) + t = self:resolve_typevars_at(node, node.expected) end if decltype.typename == "record" then - local rt = to_structural(t) + local rt = self:to_structural(t) if rt.typename == "record" then node.is_total, node.missing = total_record_check(decltype, seen_keys) end elseif decltype.typename == "map" then - local rt = to_structural(t) + local rt = self:to_structural(t) if rt.typename == "map" then - node.is_total, node.missing = total_map_check(decltype, seen_keys) + local rk = self:to_structural(rt.keys) + node.is_total, node.missing = total_map_check(rk, seen_keys) end end @@ -11081,13 +11107,13 @@ expand_type(node, values, elements) }) end, }, ["literal_table_item"] = { - after = function(node, children) + after = function(self, node, children) local kname = node.key.conststr local ktype = children[1] local vtype = children[2] if node.itemtype then vtype = node.itemtype - assert_is_a(node.value, children[2], node.itemtype, "in table item") + self:assert_is_a(node.value, children[2], node.itemtype, node) end if vtype.typename == "function" and vtype.is_method then @@ -11096,210 +11122,210 @@ expand_type(node, values, elements) }) vtype = shallow_copy_new_type(vtype) vtype.is_method = false end - return type_at(node, a_type("literal_table_item", { + return a_type(node, "literal_table_item", { kname = kname, ktype = ktype, vtype = vtype, - })) + }) end, }, ["local_function"] = { - before = function(node) - widen_all_unions() - if tc then - tc.reserve_symbol_list_slot(node) + before = function(self, node) + self:widen_all_unions() + if self.collector then + self.collector.reserve_symbol_list_slot(node) end - begin_scope(node) + self:begin_scope(node) end, - before_statements = function(node, children) + before_statements = function(self, node, children) local args = children[2] assert(args.typename == "tuple") - add_internal_function_variables(node, args) - add_function_definition_for_recursion(node, args) + self:add_internal_function_variables(node, args) + self:add_function_definition_for_recursion(node, args) end, - after = function(node, children) + after = function(self, node, children) local args = children[2] assert(args.typename == "tuple") local rets = children[3] assert(rets.typename == "tuple") - end_function_scope(node) + self:end_function_scope(node) - local t = type_at(node, ensure_fresh_typeargs(a_function({ + local t = self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, typeargs = node.typeargs, args = args, - rets = get_rets(rets), - }))) + rets = self.get_rets(rets), + })) - add_var(node, node.name.tk, t) + self:add_var(node, node.name.tk, t) return t end, }, ["local_macroexp"] = { - before = function(node) - widen_all_unions() - if tc then - tc.reserve_symbol_list_slot(node) + before = function(self, node) + self:widen_all_unions() + if self.collector then + self.collector.reserve_symbol_list_slot(node) end - begin_scope(node) + self:begin_scope(node) end, - after = function(node, children) + after = function(self, node, children) local args = children[2] assert(args.typename == "tuple") local rets = children[3] assert(rets.typename == "tuple") - end_function_scope(node) + self:end_function_scope(node) - check_macroexp_arg_use(node.macrodef) + self:check_macroexp_arg_use(node.macrodef) - local t = type_at(node, ensure_fresh_typeargs(a_function({ + local t = self:ensure_fresh_typeargs(a_function(node, { min_arity = node.macrodef.min_arity, typeargs = node.typeargs, args = args, - rets = get_rets(rets), + rets = self.get_rets(rets), macroexp = node.macrodef, - }))) + })) - add_var(node, node.name.tk, t) + self:add_var(node, node.name.tk, t) return t end, }, ["global_function"] = { - before = function(node) - widen_all_unions() - begin_scope(node) + before = function(self, node) + self:widen_all_unions() + self:begin_scope(node) if node.implicit_global_function then - local typ = find_var_type(node.name.tk) + local typ = self:find_var_type(node.name.tk) if typ then if typ.typename == "function" then node.is_predeclared_local_function = true - elseif not lax then - error_at(node, "cannot declare function: type of " .. node.name.tk .. " is %s", typ) + elseif not self.feat_lax then + self.errs:add(node, "cannot declare function: type of " .. node.name.tk .. " is %s", typ) end - elseif not lax then - error_at(node, "functions need an explicit 'local' or 'global' annotation") + elseif not self.feat_lax then + self.errs:add(node, "functions need an explicit 'local' or 'global' annotation") end end end, - before_statements = function(node, children) + before_statements = function(self, node, children) local args = children[2] assert(args.typename == "tuple") - add_internal_function_variables(node, args) - add_function_definition_for_recursion(node, args) + self:add_internal_function_variables(node, args) + self:add_function_definition_for_recursion(node, args) end, - after = function(node, children) + after = function(self, node, children) local args = children[2] assert(args.typename == "tuple") local rets = children[3] assert(rets.typename == "tuple") - end_function_scope(node) + self:end_function_scope(node) if node.is_predeclared_local_function then return NONE end - add_global(node, node.name.tk, type_at(node, ensure_fresh_typeargs(a_function({ + self:add_global(node, node.name.tk, self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, typeargs = node.typeargs, args = args, - rets = get_rets(rets), - })))) + rets = self.get_rets(rets), + }))) return NONE end, }, ["record_function"] = { - before = function(node) - widen_all_unions() - begin_scope(node) + before = function(self, node) + self:widen_all_unions() + self:begin_scope(node) end, - before_arguments = function(_node, children) - local rtype = to_structural(resolve_typedecl(children[1])) + before_arguments = function(self, _node, children) + local rtype = self:to_structural(resolve_typedecl(children[1])) if rtype.fields and rtype.typeargs then for _, typ in ipairs(rtype.typeargs) do - add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { + self:add_var(nil, typ.typearg, a_type(typ, "typearg", { typearg = typ.typearg, constraint = typ.constraint, - }))) + })) end end end, - before_statements = function(node, children) + before_statements = function(self, node, children) local args = children[3] assert(args.typename == "tuple") local rets = children[4] assert(rets.typename == "tuple") - local rtype = to_structural(resolve_typedecl(children[1])) + local rtype = self:to_structural(resolve_typedecl(children[1])) - if lax and rtype.typename == "unknown" then + if self.feat_lax and rtype.typename == "unknown" then return end if rtype.typename == "emptytable" then - edit_type(rtype, "record") + edit_type(rtype, rtype, "record") local r = rtype r.fields = {} r.field_order = {} end if not rtype.fields then - error_at(node, "not a record: %s", rtype) + self.errs:add(node, "not a record: %s", rtype) return end - local selftype = get_self_type(node.fn_owner) + local selftype = self:get_self_type(node.fn_owner) if node.is_method then if not selftype then - error_at(node, "could not resolve type of self") + self.errs:add(node, "could not resolve type of self") return end args.tuple[1] = selftype - add_var(nil, "self", selftype) + self:add_var(nil, "self", selftype) end - local fn_type = type_at(node, ensure_fresh_typeargs(a_function({ + local fn_type = self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, is_method = node.is_method, typeargs = node.typeargs, args = args, - rets = get_rets(rets), - }))) + rets = self.get_rets(rets), + })) - local open_t, open_v, owner_name = find_record_to_extend(node.fn_owner) + local open_t, open_v, owner_name = self:find_record_to_extend(node.fn_owner) local open_k = owner_name .. "." .. node.name.tk local rfieldtype = rtype.fields[node.name.tk] if rfieldtype then - rfieldtype = to_structural(rfieldtype) + rfieldtype = self:to_structural(rfieldtype) if open_v and open_v.implemented and open_v.implemented[open_k] then - redeclaration_warning(node) + self.errs:redeclaration_warning(node) end - local ok, err = same_type(fn_type, rfieldtype) + local ok, err = self:same_type(fn_type, rfieldtype) if not ok then if rfieldtype.typename == "poly" then - add_errs_prefixing(node, err, errors, "type signature does not match declaration: field has multiple function definitions (such polymorphic declarations are intended for Lua module interoperability)") + self.errs:add_prefixing(node, err, "type signature does not match declaration: field has multiple function definitions (such polymorphic declarations are intended for Lua module interoperability): ") return end local shortname = selftype and show_type(selftype) or owner_name local msg = "type signature of '" .. node.name.tk .. "' does not match its declaration in " .. shortname .. ": " - add_errs_prefixing(node, err, errors, msg) + self.errs:add_prefixing(node, err, msg) return end else - if lax or rtype == open_t then + if self.feat_lax or rtype == open_t then rtype.fields[node.name.tk] = fn_type table.insert(rtype.field_order, node.name.tk) else - error_at(node, "cannot add undeclared function '" .. node.name.tk .. "' outside of the scope where '" .. owner_name .. "' was originally declared") + self.errs:add(node, "cannot add undeclared function '" .. node.name.tk .. "' outside of the scope where '" .. owner_name .. "' was originally declared") return end @@ -11312,82 +11338,82 @@ expand_type(node, values, elements) }) open_v.implemented[open_k] = true end - add_internal_function_variables(node, args) + self:add_internal_function_variables(node, args) end, - after = function(node, _children) - end_function_scope(node) + after = function(self, node, _children) + self:end_function_scope(node) return NONE end, }, ["function"] = { - before = function(node) - widen_all_unions(node) - begin_scope(node) + before = function(self, node) + self:widen_all_unions(node) + self:begin_scope(node) end, - before_statements = function(node, children) + before_statements = function(self, node, children) local args = children[1] assert(args.typename == "tuple") - add_internal_function_variables(node, args) + self:add_internal_function_variables(node, args) end, - after = function(node, children) + after = function(self, node, children) local args = children[1] assert(args.typename == "tuple") local rets = children[2] assert(rets.typename == "tuple") - end_function_scope(node) - return type_at(node, ensure_fresh_typeargs(a_function({ + self:end_function_scope(node) + return self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, typeargs = node.typeargs, args = args, rets = rets, - }))) + })) end, }, ["macroexp"] = { - before = function(node) - widen_all_unions(node) - begin_scope(node) + before = function(self, node) + self:widen_all_unions(node) + self:begin_scope(node) end, - before_exp = function(node, children) + before_exp = function(self, node, children) local args = children[1] assert(args.typename == "tuple") - add_internal_function_variables(node, args) + self:add_internal_function_variables(node, args) end, - after = function(node, children) + after = function(self, node, children) local args = children[1] assert(args.typename == "tuple") local rets = children[2] assert(rets.typename == "tuple") - end_function_scope(node) - return type_at(node, ensure_fresh_typeargs(a_function({ + self:end_function_scope(node) + return self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, typeargs = node.typeargs, args = args, rets = rets, - }))) + })) end, }, ["cast"] = { - after = function(node, _children) + after = function(_self, node, _children) return node.casttype end, }, ["paren"] = { - before = function(node) + before = function(_self, node) node.e1.expected = node.expected end, - after = function(node, children) + after = function(_self, node, children) node.known = node.e1 and node.e1.known return resolve_tuple(children[1]) end, }, ["op"] = { - before = function(node) - begin_scope() + before = function(self, node) + self:begin_scope() if node.expected then if node.op.op == "and" then node.e2.expected = node.expected @@ -11399,18 +11425,19 @@ expand_type(node, values, elements) }) end end end, - before_e2 = function(node, children) + before_e2 = function(self, node, children) local e1type = children[1] if node.op.op == "and" then - apply_facts(node, node.e1.known) + self:apply_facts(node, node.e1.known) elseif node.op.op == "or" then - apply_facts(node, facts_not(node, node.e1.known)) + self:apply_facts(node, facts_not(node, node.e1.known)) elseif node.op.op == "@funcall" then if e1type.typename == "function" then local argdelta = (node.e1.op and node.e1.op.op == ":") and -1 or 0 if node.expected then - is_a(e1type.rets, node.expected) + + self:is_a(e1type.rets, node.expected) end local e1args = e1type.args.tuple local at = argdelta @@ -11433,8 +11460,8 @@ expand_type(node, values, elements) }) end end end, - after = function(node, children) - end_scope() + after = function(self, node, children) + self:end_scope() local ga = children[1] @@ -11445,29 +11472,34 @@ expand_type(node, values, elements) }) local ub - local ra = to_structural(ua) + local ra = self:to_structural(ua) local rb if ra.typename == "circular_require" or (ra.typename == "typedecl" and ra.def and ra.def.typename == "circular_require") then - return invalid_at(node, "cannot dereference a type from a circular require") + return self.errs:invalid_at(node, "cannot dereference a type from a circular require") end if node.op.op == "@funcall" then - if lax and is_unknown(ua) then + if self.feat_lax and is_unknown(ua) then if node.e1.op and node.e1.op.op == ":" and node.e1.e1.kind == "variable" then - add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk) + self.errs:add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk) end end - local t = type_check_funcall(node, ua, gb) + assert(gb.typename == "tuple") + assert(node.f) + local t = self:type_check_funcall(node, ua, gb) return t elseif node.op.op == "as" then return gb end - local expected = node.expected and to_structural(resolve_tuple(node.expected)) + local expected = node.expected and self:to_structural(resolve_tuple(node.expected)) - ensure_not_abstract(node.e1, ra) + local ok, err = ensure_not_abstract(ra) + if not ok then + self.errs:add(node.e1, err) + end if ra.typename == "typedecl" and ra.def.typename == "record" then ra = ra.def end @@ -11476,8 +11508,11 @@ expand_type(node, values, elements) }) if gb then ub = resolve_tuple(gb) - rb = to_structural(ub) - ensure_not_abstract(node.e2, rb) + rb = self:to_structural(ub) + ok, err = ensure_not_abstract(rb) + if not ok then + self.errs:add(node.e2, err) + end if rb.typename == "typedecl" and rb.def.typename == "record" then rb = rb.def end @@ -11487,22 +11522,20 @@ expand_type(node, values, elements) }) node.receiver = ua assert(node.e2.kind == "identifier") - local bnode = { - y = node.e2.y, - x = node.e2.x, + local bnode = node_at(node.e2, { tk = node.e2.tk, kind = "string", - } - local btype = type_at(node.e2, a_type("string", { literal = node.e2.tk })) - local t = type_check_index(node.e1, bnode, ua, btype) + }) + local btype = a_type(node.e2, "string", { literal = node.e2.tk }) + local t = self:type_check_index(node.e1, bnode, ua, btype) - if t.needs_compat and opts.gen_compat ~= "off" then + if t.needs_compat and self.gen_compat ~= "off" then if node.e1.kind == "variable" and node.e2.kind == "identifier" then local key = node.e1.tk .. "." .. node.e2.tk node.kind = "variable" node.tk = "_tl_" .. node.e1.tk .. "_" .. node.e2.tk - all_needs_compat[key] = true + self.all_needs_compat[key] = true end end @@ -11510,22 +11543,22 @@ expand_type(node, values, elements) }) end if node.op.op == "@index" then - return type_check_index(node.e1, node.e2, ua, ub) + return self:type_check_index(node.e1, node.e2, ua, ub) end if node.op.op == "is" then if rb.typename == "integer" then - all_needs_compat["math"] = true + self.all_needs_compat["math"] = true end if ra.typename == "typedecl" then - error_at(node, "can only use 'is' on variables, not types") + self.errs:add(node, "can only use 'is' on variables, not types") elseif node.e1.kind == "variable" then - check_metamethod(node, "__is", ra, resolve_typedecl(rb), ua, ub) - node.known = IsFact({ var = node.e1.tk, typ = ub, where = node }) + self:check_metamethod(node, "__is", ra, resolve_typedecl(rb), ua, ub) + node.known = IsFact({ var = node.e1.tk, typ = ub, w = node }) else - error_at(node, "can only use 'is' on variables") + self.errs:add(node, "can only use 'is' on variables") end - return BOOLEAN + return a_type(node, "boolean", {}) end if node.op.op == ":" then @@ -11533,16 +11566,16 @@ expand_type(node, values, elements) }) - if lax and (is_unknown(ua) or ua.typename == "typevar") then + if self.feat_lax and (is_unknown(ua) or ua.typename == "typevar") then if node.e1.kind == "variable" then - add_unknown_dot(node.e1, node.e1.tk .. "." .. node.e2.tk) + self.errs:add_unknown_dot(node.e1, node.e1.tk .. "." .. node.e2.tk) end - return UNKNOWN + return a_type(node, "unknown", {}) end - local t, e = match_record_key(ra, node.e1, node.e2.conststr or node.e2.tk) + local t, e = self:match_record_key(ra, node.e1, node.e2.conststr or node.e2.tk) if not t then - return invalid_at(node.e2, e, ua) + return self.errs:invalid_at(node.e2, e, ua) end return t @@ -11550,7 +11583,7 @@ expand_type(node, values, elements) }) if node.op.op == "not" then node.known = facts_not(node, node.e1.known) - return BOOLEAN + return a_type(node, "boolean", {}) end if node.op.op == "and" then @@ -11568,33 +11601,33 @@ expand_type(node, values, elements) }) node.known = nil t = ua - elseif ((ra.typename == "enum" and rb.typename == "string" and is_a(rb, ra)) or - (ra.typename == "string" and rb.typename == "enum" and is_a(ra, rb))) then + elseif ((ra.typename == "enum" and rb.typename == "string" and self:is_a(rb, ra)) or + (ra.typename == "string" and rb.typename == "enum" and self:is_a(ra, rb))) then node.known = nil t = (ra.typename == "enum" and ra or rb) elseif expected and expected.typename == "union" then node.known = facts_or(node, node.e1.known, node.e2.known) - local u = unite({ ra, rb }, true) + local u = unite(node, { ra, rb }, true) if u.typename == "union" then - local ok, err = is_valid_union(u) + ok, err = is_valid_union(u) if not ok then - u = err and invalid_at(node, err, u) or INVALID + u = err and self.errs:invalid_at(node, err, u) or a_type(node, "invalid", {}) end end t = u else - local a_ge_b = is_a(rb, ra) - local b_ge_a = is_a(ra, rb) + local a_ge_b = self:is_a(rb, ra) + local b_ge_a = self:is_a(ra, rb) if a_ge_b or b_ge_a then node.known = facts_or(node, node.e1.known, node.e2.known) if expected then - local a_is = is_a(ua, expected) - local b_is = is_a(ub, expected) + local a_is = self:is_a(ua, expected) + local b_is = self:is_a(ub, expected) if a_is and b_is then - t = resolve_typevars_at(node, expected) + t = self:resolve_typevars_at(node, expected) end end if not t then @@ -11618,39 +11651,41 @@ expand_type(node, values, elements) }) if ra.typename == "enum" and rb.typename == "string" then if not (rb.literal and ra.enumset[rb.literal]) then - return invalid_at(node, "%s is not a member of %s", ub, ua) + return self.errs:invalid_at(node, "%s is not a member of %s", ub, ua) end elseif ra.typename == "tupletable" and rb.typename == "tupletable" and #ra.types ~= #rb.types then - return invalid_at(node, "tuples are not the same size") - elseif is_a(ub, ua) or ua.typename == "typevar" then + return self.errs:invalid_at(node, "tuples are not the same size") + elseif self:is_a(ub, ua) or ua.typename == "typevar" then if node.op.op == "==" and node.e1.kind == "variable" then - node.known = EqFact({ var = node.e1.tk, typ = ub, where = node }) + node.known = EqFact({ var = node.e1.tk, typ = ub, w = node }) end - elseif is_a(ua, ub) or ub.typename == "typevar" then + elseif self:is_a(ua, ub) or ub.typename == "typevar" then if node.op.op == "==" and node.e2.kind == "variable" then - node.known = EqFact({ var = node.e2.tk, typ = ua, where = node }) + node.known = EqFact({ var = node.e2.tk, typ = ua, w = node }) end - elseif lax and (is_unknown(ua) or is_unknown(ub)) then - return UNKNOWN + elseif self.feat_lax and (is_unknown(ua) or is_unknown(ub)) then + return a_type(node, "unknown", {}) else - return invalid_at(node, "types are not comparable for equality: %s and %s", ua, ub) + return self.errs:invalid_at(node, "types are not comparable for equality: %s and %s", ua, ub) end - return BOOLEAN + return a_type(node, "boolean", {}) end if node.op.arity == 1 and unop_types[node.op.op] then if ra.typename == "union" then - ra = unite(ra.types, true) + ra = unite(node, ra.types, true) end local types_op = unop_types[node.op.op] - local t = types_op[ra.typename] + local tn = types_op[ra.typename] + local t = tn and a_type(node, tn, {}) if not t and ra.fields then t = find_in_interface_list(ra, function(ty) - return types_op[ty.typename] + local tname = types_op[ty.typename] + return tname and a_type(node, tname, {}) end) end @@ -11658,19 +11693,18 @@ expand_type(node, values, elements) }) if not t then local mt_name = unop_to_metamethod[node.op.op] if mt_name then - t, meta_on_operator = check_metamethod(node, mt_name, ra, nil, ua, nil) + t, meta_on_operator = self:check_metamethod(node, mt_name, ra, nil, ua, nil) end if not t then - error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", ua) - t = INVALID + t = self.errs:invalid_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", ua) end end if ra.typename == "map" then if ra.keys.typename == "number" or ra.keys.typename == "integer" then - add_warning("hint", node, "using the '#' operator on a map with numeric key type may produce unexpected results") + self.errs:add_warning("hint", node, "using the '#' operator on a map with numeric key type may produce unexpected results") else - error_at(node, "using the '#' operator on this map will always return 0") + self.errs:add(node, "using the '#' operator on this map will always return 0") end end @@ -11678,12 +11712,12 @@ expand_type(node, values, elements) }) node.known = FACT_TRUTHY end - if node.op.op == "~" and env.gen_target == "5.1" then + if node.op.op == "~" and self.gen_target == "5.1" then if meta_on_operator then - all_needs_compat["mt"] = true + self.all_needs_compat["mt"] = true convert_node_to_compat_mt_call(node, unop_to_metamethod[node.op.op], 1, node.e1) else - all_needs_compat["bit32"] = true + self.all_needs_compat["bit32"] = true convert_node_to_compat_call(node, "bit32", "bnot", node.e1) end end @@ -11697,39 +11731,39 @@ expand_type(node, values, elements) }) end if ra.typename == "union" then - ra = unite(ra.types, true) + ra = unite(ra, ra.types, true) end if rb.typename == "union" then - rb = unite(rb.types, true) + rb = unite(rb, rb.types, true) end local types_op = binop_types[node.op.op] - local t = types_op[ra.typename] and types_op[ra.typename][rb.typename] + local tn = types_op[ra.typename] and types_op[ra.typename][rb.typename] + local t = tn and a_type(node, tn, {}) local meta_on_operator if not t then local mt_name = binop_to_metamethod[node.op.op] if mt_name then - t, meta_on_operator = check_metamethod(node, mt_name, ra, rb, ua, ub) + t, meta_on_operator = self:check_metamethod(node, mt_name, ra, rb, ua, ub) end if not t then - error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", ua, ub) - t = INVALID + t = self.errs:invalid_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", ua, ub) if node.op.op == "or" then - local u = unite({ ua, ub }) + local u = unite(node, { ua, ub }) if u.typename == "union" and is_valid_union(u) then - add_warning("hint", node, "if a union type was intended, consider declaring it explicitly") + self.errs:add_warning("hint", node, "if a union type was intended, consider declaring it explicitly") end end end end if ua.typename == "nominal" and ub.typename == "nominal" and not meta_on_operator then - if is_a(ua, ub) then + if self:is_a(ua, ub) then t = ua else - error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for distinct nominal types %s and %s", ua, ub) + self.errs:add(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for distinct nominal types %s and %s", ua, ub) end end @@ -11737,20 +11771,20 @@ expand_type(node, values, elements) }) node.known = FACT_TRUTHY end - if node.op.op == "//" and env.gen_target == "5.1" then + if node.op.op == "//" and self.gen_target == "5.1" then if meta_on_operator then - all_needs_compat["mt"] = true + self.all_needs_compat["mt"] = true convert_node_to_compat_mt_call(node, "__idiv", meta_on_operator, node.e1, node.e2) else - local div = { y = node.y, x = node.x, kind = "op", op = an_operator(node, 2, "/"), e1 = node.e1, e2 = node.e2 } + local div = node_at(node, { kind = "op", op = an_operator(node, 2, "/"), e1 = node.e1, e2 = node.e2 }) convert_node_to_compat_call(node, "math", "floor", div) end - elseif bit_operators[node.op.op] and env.gen_target == "5.1" then + elseif bit_operators[node.op.op] and self.gen_target == "5.1" then if meta_on_operator then - all_needs_compat["mt"] = true + self.all_needs_compat["mt"] = true convert_node_to_compat_mt_call(node, binop_to_metamethod[node.op.op], meta_on_operator, node.e1, node.e2) else - all_needs_compat["bit32"] = true + self.all_needs_compat["bit32"] = true convert_node_to_compat_call(node, "bit32", bit_operators[node.op.op], node.e1, node.e2) end end @@ -11762,28 +11796,28 @@ expand_type(node, values, elements) }) end, }, ["variable"] = { - after = function(node, _children) + after = function(self, node, _children) if node.tk == "..." then - local va_sentinel = find_var_type("@is_va") + local va_sentinel = self:find_var_type("@is_va") if not va_sentinel or va_sentinel.typename == "nil" then - return invalid_at(node, "cannot use '...' outside a vararg function") + return self.errs:invalid_at(node, "cannot use '...' outside a vararg function") end end local t if node.tk == "_G" then - t, node.attribute = simulate_g() + t, node.attribute = self:simulate_g() else local use = node.is_lvalue and "lvalue" or "use" - t, node.attribute = find_var_type(node.tk, use) + t, node.attribute = self:find_var_type(node.tk, use) end if not t then - if lax then - add_unknown(node, node.tk) - return UNKNOWN + if self.feat_lax then + self.errs:add_unknown(node, node.tk) + return a_type(node, "unknown", {}) end - return invalid_at(node, "unknown variable: " .. node.tk) + return self.errs:invalid_at(node, "unknown variable: " .. node.tk) end if t.typename == "typedecl" then @@ -11794,70 +11828,70 @@ expand_type(node, values, elements) }) end, }, ["type_identifier"] = { - after = function(node, _children) - local typ, attr = find_var_type(node.tk) + after = function(self, node, _children) + local typ, attr = self:find_var_type(node.tk) node.attribute = attr if typ then return typ end - if lax then - add_unknown(node, node.tk) - return UNKNOWN + if self.feat_lax then + self.errs:add_unknown(node, node.tk) + return a_type(node, "unknown", {}) end - return invalid_at(node, "unknown variable: " .. node.tk) + return self.errs:invalid_at(node, "unknown variable: " .. node.tk) end, }, ["argument"] = { - after = function(node, children) + after = function(self, node, children) local t = children[1] if not t then - t = UNKNOWN + t = a_type(node, "unknown", {}) end if node.tk == "..." then - t = a_vararg({ t }) + t = a_vararg(node, { t }) end - add_var(node, node.tk, t).is_func_arg = true + self:add_var(node, node.tk, t).is_func_arg = true return t end, }, ["identifier"] = { - after = function(_node, _children) + after = function(_self, _node, _children) return NONE end, }, ["newtype"] = { - after = function(node, _children) + after = function(_self, node, _children) return node.newtype end, }, ["error_node"] = { - after = function(_node, _children) - return INVALID + after = function(_self, node, _children) + return a_type(node, "invalid", {}) end, }, } visit_node.cbs["break"] = { - after = function(_node, _children) + after = function(_self, _node, _children) return NONE end, } visit_node.cbs["do"] = visit_node.cbs["break"] - local function after_literal(node) + local function after_literal(_self, node) node.known = FACT_TRUTHY - return type_at(node, a_type(node.kind, {})) + return a_type(node, node.kind, {}) end visit_node.cbs["string"] = { - after = function(node, _children) - local t = after_literal(node) + after = function(self, node, _children) + local t = after_literal(self, node) t.literal = node.conststr - local expected = node.expected and to_structural(node.expected) - if expected and expected.typename == "enum" and is_a(t, expected) then + local expected = node.expected and self:to_structural(node.expected) + if expected and expected.typename == "enum" and self:is_a(t, expected) then return node.expected end @@ -11868,8 +11902,8 @@ expand_type(node, values, elements) }) visit_node.cbs["integer"] = { after = after_literal } visit_node.cbs["boolean"] = { - after = function(node, _children) - local t = after_literal(node) + after = function(self, node, _children) + local t = after_literal(self, node) node.known = (node.tk == "true") and FACT_TRUTHY or nil return t end, @@ -11880,7 +11914,7 @@ expand_type(node, values, elements) }) visit_node.cbs["argument_list"] = visit_node.cbs["variable_list"] visit_node.cbs["expression_list"] = visit_node.cbs["variable_list"] - visit_node.after = function(node, _children, t) + visit_node.after = function(_self, node, _children, t) if node.expanded then apply_macroexp(node) end @@ -11888,13 +11922,12 @@ expand_type(node, values, elements) }) return t end - local expand_interfaces do - local function add_interface_fields(what, fields, field_order, resolved, named, list) + local function add_interface_fields(self, what, fields, field_order, resolved, named, list) for fname, ftype in fields_of(resolved, list) do if fields[fname] then - if not is_a(fields[fname], ftype) then - error_at(fields[fname], what .. " '" .. fname .. "' does not match definition in interface %s", named) + if not self:is_a(fields[fname], ftype) then + self.errs:add(fields[fname], what .. " '" .. fname .. "' does not match definition in interface %s", named) end else table.insert(field_order, fname) @@ -11903,18 +11936,21 @@ expand_type(node, values, elements) }) end end - local function collect_interfaces(list, t, seen) + local function collect_interfaces(self, list, t, seen) if t.interface_list then for _, iface in ipairs(t.interface_list) do if iface.typename == "nominal" then - local ri = resolve_nominal(iface) + local ri = self:resolve_nominal(iface) if not (ri.typename == "invalid") then - assert(ri.typename == "interface", "nominal resolved to " .. ri.typename) - if not ri.interfaces_expanded and not seen[ri] then - seen[ri] = true - collect_interfaces(list, ri, seen) + if ri.typename == "interface" then + if not ri.interfaces_expanded and not seen[ri] then + seen[ri] = true + collect_interfaces(self, list, ri, seen) + end + table.insert(list, iface) + else + self.errs:add(iface, "attempted to use %s as interface, but its type is %s", iface, ri) end - table.insert(list, iface) end else if not seen[iface] then @@ -11927,30 +11963,30 @@ expand_type(node, values, elements) }) return list end - expand_interfaces = function(t) + function TypeChecker:expand_interfaces(t) if t.interfaces_expanded then return end t.interfaces_expanded = true - t.interface_list = collect_interfaces({}, t, {}) + t.interface_list = collect_interfaces(self, {}, t, {}) for _, iface in ipairs(t.interface_list) do if iface.typename == "nominal" then - local ri = resolve_nominal(iface) + local ri = self:resolve_nominal(iface) assert(ri.typename == "interface") - add_interface_fields("field", t.fields, t.field_order, ri, iface) + add_interface_fields(self, "field", t.fields, t.field_order, ri, iface) if ri.meta_fields then t.meta_fields = t.meta_fields or {} t.meta_field_order = t.meta_field_order or {} - add_interface_fields("metamethod", t.meta_fields, t.meta_field_order, ri, iface, "meta") + add_interface_fields(self, "metamethod", t.meta_fields, t.meta_field_order, ri, iface, "meta") end else if not t.elements then t.elements = iface else - if not same_type(iface.elements, t.elements) then - error_at(t, "incompatible array interfaces") + if not self:same_type(iface.elements, t.elements) then + self.errs:add(t, "incompatible array interfaces") end end end @@ -11962,29 +11998,29 @@ expand_type(node, values, elements) }) visit_type = { cbs = { ["function"] = { - before = function(_typ) - begin_scope() + before = function(self, _typ) + self:begin_scope() end, - after = function(typ, _children) - end_scope() - return ensure_fresh_typeargs(typ) + after = function(self, typ, _children) + self:end_scope() + return self:ensure_fresh_typeargs(typ) end, }, ["record"] = { - before = function(typ) - begin_scope() - add_var(nil, "@self", type_at(typ, a_type("typedecl", { def = typ }))) + before = function(self, typ) + self:begin_scope() + self:add_var(nil, "@self", type_at(typ, a_type(typ, "typedecl", { def = typ }))) for fname, ftype in fields_of(typ) do if ftype.typename == "typealias" then - resolve_nominal(ftype.alias_to) - add_var(nil, fname, ftype) + self:resolve_nominal(ftype.alias_to) + self:add_var(nil, fname, ftype) elseif ftype.typename == "typedecl" then - add_var(nil, fname, ftype) + self:add_var(nil, fname, ftype) end end end, - after = function(typ, children) + after = function(self, typ, children) local i = 1 if typ.typeargs then for _, _ in ipairs(typ.typeargs) do @@ -11998,11 +12034,11 @@ expand_type(node, values, elements) }) if iface.typename == "array" then typ.interface_list[j] = iface elseif iface.typename == "nominal" then - local ri = resolve_nominal(iface) + local ri = self:resolve_nominal(iface) if ri.typename == "interface" then typ.interface_list[j] = iface else - error_at(children[i], "%s is not an interface", children[i]) + self.errs:add(children[i], "%s is not an interface", children[i]) end end i = i + 1 @@ -12042,7 +12078,7 @@ expand_type(node, values, elements) }) end end elseif ftype.typename == "typealias" then - resolve_typealias(ftype) + self:resolve_typealias(ftype) end typ.fields[name] = ftype @@ -12061,55 +12097,55 @@ expand_type(node, values, elements) }) end if typ.interface_list then - expand_interfaces(typ) + self:expand_interfaces(typ) end if fmacros then for _, t in ipairs(fmacros) do - local macroexp_type = recurse_node(t.macroexp, visit_node, visit_type) + local macroexp_type = recurse_node(self, t.macroexp, visit_node, visit_type) - check_macroexp_arg_use(t.macroexp) + self:check_macroexp_arg_use(t.macroexp) - if not is_a(macroexp_type, t) then - error_at(macroexp_type, "macroexp type does not match declaration") + if not self:is_a(macroexp_type, t) then + self.errs:add(macroexp_type, "macroexp type does not match declaration") end end end - end_scope() + self:end_scope() return typ end, }, ["typearg"] = { - after = function(typ, _children) - add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { + after = function(self, typ, _children) + self:add_var(nil, typ.typearg, a_type(typ, "typearg", { typearg = typ.typearg, constraint = typ.constraint, - }))) + })) return typ end, }, ["typevar"] = { - after = function(typ, _children) - if not find_var_type(typ.typevar) then - error_at(typ, "undefined type variable " .. typ.typevar) + after = function(self, typ, _children) + if not self:find_var_type(typ.typevar) then + self.errs:add(typ, "undefined type variable " .. typ.typevar) end return typ end, }, ["nominal"] = { - after = function(typ, _children) + after = function(self, typ, _children) if typ.found then return typ end - local t = find_type(typ.names, true) + local t = self:find_type(typ.names, true) if t then if t.typename == "typearg" then typ.names = nil - edit_type(typ, "typevar") + edit_type(typ, typ, "typevar") local tv = typ tv.typevar = t.typearg tv.constraint = t.constraint @@ -12120,18 +12156,19 @@ expand_type(node, values, elements) }) end else local name = typ.names[1] - local unresolved = get_unresolved() - unresolved.nominals[name] = unresolved.nominals[name] or {} - table.insert(unresolved.nominals[name], typ) + local scope = self.st[#self.st] + scope.pending_nominals = scope.pending_nominals or {} + scope.pending_nominals[name] = scope.pending_nominals[name] or {} + table.insert(scope.pending_nominals[name], typ) end return typ end, }, ["union"] = { - after = function(typ, _children) + after = function(self, typ, _children) local ok, err = is_valid_union(typ) if not ok then - return err and invalid_at(typ, err, typ) or INVALID + return err and self.errs:invalid_at(typ, err, typ) or a_type(typ, "invalid", {}) end return typ end, @@ -12139,15 +12176,47 @@ expand_type(node, values, elements) }) }, } + local default_type_visitor = { + after = function(_self, typ, _children) + return typ + end, + } + + visit_type.cbs["interface"] = visit_type.cbs["record"] + + visit_type.cbs["string"] = default_type_visitor + visit_type.cbs["tupletable"] = default_type_visitor + visit_type.cbs["typedecl"] = default_type_visitor + visit_type.cbs["typealias"] = default_type_visitor + visit_type.cbs["array"] = default_type_visitor + visit_type.cbs["map"] = default_type_visitor + visit_type.cbs["enum"] = default_type_visitor + visit_type.cbs["boolean"] = default_type_visitor + visit_type.cbs["nil"] = default_type_visitor + visit_type.cbs["number"] = default_type_visitor + visit_type.cbs["integer"] = default_type_visitor + visit_type.cbs["thread"] = default_type_visitor + visit_type.cbs["emptytable"] = default_type_visitor + visit_type.cbs["literal_table_item"] = default_type_visitor + visit_type.cbs["unresolved_emptytable_value"] = default_type_visitor + visit_type.cbs["tuple"] = default_type_visitor + visit_type.cbs["poly"] = default_type_visitor + visit_type.cbs["any"] = default_type_visitor + visit_type.cbs["unknown"] = default_type_visitor + visit_type.cbs["invalid"] = default_type_visitor + visit_type.cbs["none"] = default_type_visitor + + + local function internal_compiler_check(fn) - return function(w, children, t) - t = fn and fn(w, children, t) or t + return function(s, n, children, t) + t = fn and fn(s, n, children, t) or t if type(t) ~= "table" then - error(((w).kind or (w).typename) .. " did not produce a type") + error(((n).kind or (n).typename) .. " did not produce a type") end if type(t.typename) ~= "string" then - error(((w).kind or (w).typename) .. " type does not have a typename") + error(((n).kind or (n).typename) .. " type does not have a typename") end return t @@ -12155,13 +12224,13 @@ expand_type(node, values, elements) }) end local function store_type_after(fn) - return function(w, children, t) - t = fn and fn(w, children, t) or t + return function(self, n, children, t) + t = fn and fn(self, n, children, t) or t - local where = w + local w = n - if where.y then - tc.store_type(where.y, where.x, t) + if w.y then + self.collector.store_type(w.y, w.x, t) end return t @@ -12169,119 +12238,167 @@ expand_type(node, values, elements) }) end local function debug_type_after(fn) - return function(node, children, t) - t = fn and fn(node, children, t) or t + return function(s, node, children, t) + t = fn and fn(s, node, children, t) or t + node.debug_type = t return t end end - if opts.run_internal_compiler_checks then - visit_node.after = internal_compiler_check(visit_node.after) - visit_type.after = internal_compiler_check(visit_type.after) - end + local function patch_visitors(my_visit_node, + after_node, + my_visit_type, + after_type) - if tc then - visit_node.after = store_type_after(visit_node.after) - visit_type.after = store_type_after(visit_type.after) + + if my_visit_node == visit_node then + my_visit_node = shallow_copy_table(my_visit_node) + end + my_visit_node.after = after_node(my_visit_node.after) + if my_visit_type then + if my_visit_type == visit_type then + my_visit_type = shallow_copy_table(my_visit_type) + end + my_visit_type.after = after_type(my_visit_type.after) + else + my_visit_type = visit_type + end + return my_visit_node, my_visit_type end - if TL_DEBUG then - visit_node.after = debug_type_after(visit_node.after) + local function set_feat(feat, default) + if feat then + return (feat == "on") + else + return default + end end - local default_type_visitor = { - after = function(typ, _children) - return typ - end, - } + tl.type_check = function(ast, filename, opts, env) + assert(type(filename) == "string", "tl.type_check signature has changed, pass filename separately") + assert((not opts) or (not (opts).env), "tl.type_check signature has changed, pass env separately") - visit_type.cbs["interface"] = visit_type.cbs["record"] + filename = filename or "?" - visit_type.cbs["string"] = default_type_visitor - visit_type.cbs["tupletable"] = default_type_visitor - visit_type.cbs["typedecl"] = default_type_visitor - visit_type.cbs["typealias"] = default_type_visitor - visit_type.cbs["array"] = default_type_visitor - visit_type.cbs["map"] = default_type_visitor - visit_type.cbs["enum"] = default_type_visitor - visit_type.cbs["boolean"] = default_type_visitor - visit_type.cbs["nil"] = default_type_visitor - visit_type.cbs["number"] = default_type_visitor - visit_type.cbs["integer"] = default_type_visitor - visit_type.cbs["thread"] = default_type_visitor - visit_type.cbs["emptytable"] = default_type_visitor - visit_type.cbs["literal_table_item"] = default_type_visitor - visit_type.cbs["unresolved_emptytable_value"] = default_type_visitor - visit_type.cbs["tuple"] = default_type_visitor - visit_type.cbs["poly"] = default_type_visitor - visit_type.cbs["any"] = default_type_visitor - visit_type.cbs["unknown"] = default_type_visitor - visit_type.cbs["invalid"] = default_type_visitor - visit_type.cbs["unresolved"] = default_type_visitor - visit_type.cbs["none"] = default_type_visitor + opts = opts or {} + + if not env then + local err + env, err = tl.new_env({ defaults = opts }) + if err then + return nil, err + end + end - assert(ast.kind == "statements") - recurse_node(ast, visit_node, visit_type) + local self = { + filename = filename, + env = env, + st = { + { + vars = env.globals, + pending_global_types = {}, + }, + }, + errs = Errors.new(filename), + all_needs_compat = {}, + dependencies = {}, + subtype_relations = TypeChecker.subtype_relations, + eqtype_relations = TypeChecker.eqtype_relations, + type_priorities = TypeChecker.type_priorities, + } - close_types(st[1]) - check_for_unused_vars(st[1], true) + setmetatable(self, { __index = TypeChecker }) - clear_redundant_errors(errors) + self.feat_lax = set_feat(opts.feat_lax or env.defaults.feat_lax, false) + self.feat_arity = set_feat(opts.feat_arity or env.defaults.feat_arity, true) + self.gen_compat = opts.gen_compat or env.defaults.gen_compat or DEFAULT_GEN_COMPAT + self.gen_target = opts.gen_target or env.defaults.gen_target or DEFAULT_GEN_TARGET - add_compat_entries(ast, all_needs_compat, env.gen_compat) + if self.gen_target == "5.4" and self.gen_compat ~= "off" then + return nil, "gen-compat must be explicitly 'off' when gen-target is '5.4'" + end - local result = { - ast = ast, - env = env, - type = module_type or BOOLEAN, - filename = filename, - warnings = warnings, - type_errors = errors, - dependencies = dependencies, - } + if self.feat_lax then + self.type_priorities = shallow_copy_table(self.type_priorities) + self.type_priorities["unknown"] = 0 - env.loaded[filename] = result - table.insert(env.loaded_order, filename) + self.subtype_relations = shallow_copy_table(self.subtype_relations) - if tc then - env.reporter:store_result(tc, env.globals) - end + self.subtype_relations["unknown"] = {} + self.subtype_relations["unknown"]["*"] = compare_true - return result -end + self.subtype_relations["*"] = shallow_copy_table(self.subtype_relations["*"]) + self.subtype_relations["*"]["unknown"] = compare_true + + self.subtype_relations["*"]["boolean"] = compare_true + + self.get_rets = function(rets) + if #rets.tuple == 0 then + return a_vararg(rets, { a_type(rets, "unknown", {}) }) + end + return rets + end + else + self.get_rets = function(rets) + return rets + end + end + if env.report_types then + env.reporter = env.reporter or tl.new_type_reporter() + self.collector = env.reporter:get_collector(filename) + end + local visit_node, visit_type = visit_node, visit_type + if opts.run_internal_compiler_checks then + visit_node, visit_type = patch_visitors( + visit_node, internal_compiler_check, + visit_type, internal_compiler_check) + end + if self.collector then + visit_node, visit_type = patch_visitors( + visit_node, store_type_after, + visit_type, store_type_after) + end + if TL_DEBUG then + visit_node, visit_type = patch_visitors( + visit_node, debug_type_after) -function tl.symbols_in_scope(tr, y, x) - local function find(symbols, at_y, at_x) - local function le(a, b) - return a[1] < b[1] or - (a[1] == b[1] and a[2] <= b[2]) end - return binary_search(symbols, { at_y, at_x }, le) or 0 - end - local ret = {} + assert(ast.kind == "statements") + recurse_node(self, ast, visit_node, visit_type) - local n = find(tr.symbols, y, x) + local global_scope = self.st[1] + close_types(global_scope) + self.errs:warn_unused_vars(global_scope, true) - local symbols = tr.symbols - while n >= 1 do - local s = symbols[n] - if s[3] == "@{" then - n = n - 1 - elseif s[3] == "@}" then - n = s[4] - else - ret[s[3]] = s[4] - n = n - 1 + clear_redundant_errors(self.errs.errors) + + add_compat_entries(ast, self.all_needs_compat, self.gen_compat) + + local result = { + ast = ast, + env = env, + type = self.module_type or a_type(ast, "boolean", {}), + filename = filename, + warnings = self.errs.warnings, + type_errors = self.errs.errors, + dependencies = self.dependencies, + } + + env.loaded[filename] = result + table.insert(env.loaded_order, filename or "") + + if self.collector then + env.reporter:store_result(self.collector, env.globals) end - end - return ret + return result + end end @@ -12297,9 +12414,24 @@ local function read_full_file(fd) return content, err end -tl.process = function(filename, env, fd) - assert((not fd or type(fd) ~= "string"), "fd must be a file") +local function feat_lax_heuristic(filename, input) + if filename then + local _, extension = filename:match("(.*)%.([a-z]+)$") + extension = extension and extension:lower() + + if extension == "tl" then + return "off" + elseif extension == "lua" then + return "on" + end + end + if input then + return (input:match("^#![^\n]*lua[^\n]*\n")) and "on" or "off" + end + return "off" +end +tl.process = function(filename, env, fd) if env and env.loaded and env.loaded[filename] then return env.loaded[filename] end @@ -12319,23 +12451,38 @@ tl.process = function(filename, env, fd) return nil, "could not read " .. filename .. ": " .. err end - local _, extension = filename:match("(.*)%.([a-z]+)$") - extension = extension and extension:lower() + return tl.process_string(input, env, filename) +end - local is_lua - if extension == "tl" then - is_lua = false - elseif extension == "lua" then - is_lua = true - else - is_lua = input:match("^#![^\n]*lua[^\n]*\n") +function tl.target_from_lua_version(str) + if str == "Lua 5.1" or + str == "Lua 5.2" then + return "5.1" + elseif str == "Lua 5.3" then + return "5.3" + elseif str == "Lua 5.4" then + return "5.4" end +end - return tl.process_string(input, is_lua, env, filename) +local function default_env_opts(runtime, filename, input) + local gen_target = runtime and tl.target_from_lua_version(_VERSION) or DEFAULT_GEN_TARGET + local gen_compat = (gen_target == "5.4") and "off" or DEFAULT_GEN_COMPAT + return { + defaults = { + feat_lax = feat_lax_heuristic(filename, input), + gen_target = gen_target, + gen_compat = gen_compat, + run_internal_compiler_checks = false, + }, + } end -function tl.process_string(input, is_lua, env, filename) - env = env or tl.init_env(is_lua) +function tl.process_string(input, env, filename) + assert(type(env) ~= "boolean", "tl.process_string signature has changed") + + env = env or tl.new_env(default_env_opts(false, filename, input)) + if env.loaded and env.loaded[filename] then return env.loaded[filename] end @@ -12347,7 +12494,7 @@ function tl.process_string(input, is_lua, env, filename) local result = { ok = false, filename = filename, - type = BOOLEAN, + type = a_type({ f = filename, y = 1, x = 1 }, "boolean", {}), type_errors = {}, syntax_errors = syntax_errors, env = env, @@ -12357,14 +12504,7 @@ function tl.process_string(input, is_lua, env, filename) return result end - local opts = { - filename = filename, - lax = is_lua, - gen_compat = env.gen_compat, - gen_target = env.gen_target, - env = env, - } - local result = tl.type_check(program, opts) + local result = tl.type_check(program, filename, env.defaults, env) result.syntax_errors = syntax_errors @@ -12372,15 +12512,15 @@ function tl.process_string(input, is_lua, env, filename) end tl.gen = function(input, env, pp) - env = env or assert(tl.init_env(), "Default environment initialization failed") - local result = tl.process_string(input, false, env) + env = env or assert(tl.new_env(default_env_opts(false, nil, input)), "Default environment initialization failed") + local result = tl.process_string(input, env) if (not result.ast) or #result.syntax_errors > 0 then return nil, result end local code - code, result.gen_error = tl.pretty_print_ast(result.ast, env.gen_target, pp) + code, result.gen_error = tl.pretty_print_ast(result.ast, env.defaults.gen_target, pp) return code, result end @@ -12396,28 +12536,25 @@ local function tl_package_loader(module_name) if #errs > 0 then error(found_filename .. ":" .. errs[1].y .. ":" .. errs[1].x .. ": " .. errs[1].msg) end - local lax = not not found_filename:match("lua$") local env = tl.package_loader_env if not env then - tl.package_loader_env = tl.init_env(lax) + tl.package_loader_env = assert(tl.new_env(), "Default environment initialization failed") env = tl.package_loader_env end - env.modules[module_name] = a_type("typedecl", { def = CIRCULAR_REQUIRE }) + local opts = default_env_opts(true, found_filename) - local result = tl.type_check(program, { - lax = lax, - filename = found_filename, - env = env, - run_internal_compiler_checks = false, - }) + local w = { f = found_filename, x = 1, y = 1 } + env.modules[module_name] = a_type(w, "typedecl", { def = a_type(w, "circular_require", {}) }) + + local result = tl.type_check(program, found_filename, opts.defaults, env) env.modules[module_name] = result.type - local code = assert(tl.pretty_print_ast(program, env.gen_target, true)) + local code = assert(tl.pretty_print_ast(program, opts.defaults.gen_target, true)) local chunk, err = load(code, "@" .. found_filename, "t") if chunk then return function(modname, loader_data) @@ -12443,21 +12580,10 @@ function tl.loader() end end -function tl.target_from_lua_version(str) - if str == "Lua 5.1" or - str == "Lua 5.2" then - return "5.1" - elseif str == "Lua 5.3" then - return "5.3" - elseif str == "Lua 5.4" then - return "5.4" - end -end - -local function env_for(lax, env_tbl) +local function env_for(opts, env_tbl) if not env_tbl then if not tl.package_loader_env then - tl.package_loader_env = tl.init_env(lax) + tl.package_loader_env = tl.new_env(opts) end return tl.package_loader_env end @@ -12466,7 +12592,7 @@ local function env_for(lax, env_tbl) tl.load_envs = setmetatable({}, { __mode = "k" }) end - tl.load_envs[env_tbl] = tl.load_envs[env_tbl] or tl.init_env(lax) + tl.load_envs[env_tbl] = tl.load_envs[env_tbl] or tl.new_env(opts) return tl.load_envs[env_tbl] end @@ -12476,17 +12602,14 @@ tl.load = function(input, chunkname, mode, ...) return nil, (chunkname or "") .. ":" .. errs[1].y .. ":" .. errs[1].x .. ": " .. errs[1].msg end - local lax = chunkname and not not chunkname:match("lua$") + local opts = default_env_opts(true, chunkname) + if not tl.package_loader_env then - tl.package_loader_env = tl.init_env(lax) + tl.package_loader_env = tl.new_env(opts) end - local result = tl.type_check(program, { - lax = lax, - filename = chunkname or ("string \"" .. input:sub(45) .. (#input > 45 and "..." or "") .. "\""), - env = env_for(lax, ...), - run_internal_compiler_checks = false, - }) + local filename = chunkname or ("string \"" .. input:sub(45) .. (#input > 45 and "..." or "") .. "\"") + local result = tl.type_check(program, filename, opts.defaults, env_for(opts, ...)) if mode and mode:match("c") then if #result.type_errors > 0 then @@ -12500,7 +12623,7 @@ tl.load = function(input, chunkname, mode, ...) mode = mode:gsub("c", "") end - local code, err = tl.pretty_print_ast(program, tl.target_from_lua_version(_VERSION), true) + local code, err = tl.pretty_print_ast(program, opts.defaults.gen_target, true) if not code then return nil, err end @@ -12508,4 +12631,29 @@ tl.load = function(input, chunkname, mode, ...) return load(code, chunkname, mode, ...) end + + + + +function tl.get_types(result) + return result.env.reporter:get_report(), result.env.reporter +end + +tl.init_env = function(lax, gen_compat, gen_target, predefined) + local opts = { + defaults = { + feat_lax = (lax and "on" or "off"), + gen_compat = ((type(gen_compat) == "string") and gen_compat) or + (gen_compat == false and "off") or + (gen_compat == true or gen_compat == nil) and "optional", + gen_target = gen_target or + ((_VERSION == "Lua 5.1" or _VERSION == "Lua 5.2") and "5.1") or + "5.3", + }, + predefined_modules = predefined, + } + + return tl.new_env(opts) +end + return tl diff --git a/tl.tl b/tl.tl index a8b612ec6..00400e2ab 100644 --- a/tl.tl +++ b/tl.tl @@ -476,9 +476,16 @@ end ]=====] local interface Where + f: string y: integer x: integer +end + +local record Errors filename: string + errors: {Error} + warnings: {Error} + unknown_dots: {string:boolean} end local record tl @@ -492,13 +499,13 @@ local record tl end type LoadFunction = function(...:any): any... - enum CompatMode + enum GenCompat "off" "optional" "required" end - enum TargetMode + enum GenTarget "5.1" "5.3" "5.4" @@ -516,25 +523,23 @@ local record tl end record TypeCheckOptions - lax: boolean - filename: string - gen_compat: CompatMode - gen_target: TargetMode - env: Env + feat_lax: Feat + feat_arity: Feat + gen_compat: GenCompat + gen_target: GenTarget run_internal_compiler_checks: boolean end record Env globals: {string:Variable} modules: {string:Type} + module_filenames: {string:string} loaded: {string:Result} loaded_order: {string} reporter: TypeReporter - gen_compat: CompatMode - gen_target: TargetMode keep_going: boolean report_types: boolean - feat_arity: boolean + defaults: TypeCheckOptions end record Result @@ -571,6 +576,8 @@ local record tl i: integer end + type errors = Errors + typecodes: {string:integer} record TypeInfo @@ -601,28 +608,28 @@ local record tl end record EnvOptions - lax_mode: boolean - gen_compat: CompatMode - gen_target: TargetMode - feat_arity: Feat + defaults: TypeCheckOptions predefined_modules: {string} end load: function(string, string, LoadMode, {any:any}): LoadFunction, string process: function(string, Env, ? FILE): (Result, string) - process_string: function(string, boolean, Env, ? string): Result + process_string: function(string, Env, ? string): Result gen: function(string, Env, PrettyPrintOptions): string, Result - type_check: function(Node, TypeCheckOptions): Result, string - new_env: function(EnvOptions): Env, string - init_env: function(? boolean, ? boolean | CompatMode, ? TargetMode, ? {string}): Env, string + type_check: function(Node, string, TypeCheckOptions, ? Env): Result, string + new_env: function(? EnvOptions): Env, string version: function(): string + -- Backwards compatibility + init_env: function(? boolean, ? boolean | GenCompat, ? GenTarget, ? {string}): Env, string + package_loader_env: Env load_envs: { {any:any} : Env } end local record TypeReporter typeid_to_num: {integer: integer} + typename_to_num: {TypeName: integer} next_num: integer tr: TypeReport @@ -684,17 +691,23 @@ tl.typecodes = { INVALID = 0x80000000, } -local type Result = tl.Result local type Env = tl.Env +local type EnvOptions = tl.EnvOptions local type Error = tl.Error -local type CompatMode = tl.CompatMode +local type Feat = tl.Feat +local type GenCompat = tl.GenCompat +local type GenTarget = tl.GenTarget +local type LoadFunction = tl.LoadFunction +local type LoadMode = tl.LoadMode local type PrettyPrintOptions = tl.PrettyPrintOptions +local type Result = tl.Result local type TypeCheckOptions = tl.TypeCheckOptions -local type LoadMode = tl.LoadMode -local type LoadFunction = tl.LoadFunction -local type TargetMode = tl.TargetMode local type TypeInfo = tl.TypeInfo local type TypeReport = tl.TypeReport +local type WarningKind = tl.WarningKind + +local DEFAULT_GEN_COMPAT : GenCompat = "optional" +local DEFAULT_GEN_TARGET : GenTarget = "5.3" local enum Narrow "narrow" @@ -1515,7 +1528,6 @@ local enum TypeName "any" "unknown" -- to be used in lax mode only "invalid" -- producing a new value of this type (not propagating) must always produce a type error - "unresolved" "none" "*" end @@ -1552,7 +1564,6 @@ local table_types : {TypeName:boolean} = { ["any"] = false, ["unknown"] = false, ["invalid"] = false, - ["unresolved"] = false, ["none"] = false, ["*"] = false, } @@ -1561,6 +1572,9 @@ local interface Type is Where where self.typename + y: integer + x: integer + typename: TypeName -- discriminator typeid: integer -- unique identifier inferred_at: Where -- for error messages @@ -1574,7 +1588,24 @@ local record StringType literal: string end -local type TypeType = TypeAliasType | TypeDeclType +local function is_numeric_type(t:Type): boolean + return t.typename == "number" or t.typename == "integer" +end + +local interface NumericType + is Type + where is_numeric_type(self) +end + +local record IntegerType + is NumericType + where self.typename == "integer" +end + +local record BooleanType + is Type + where self.typename == "boolean" +end local record TypeDeclType is Type @@ -1592,6 +1623,8 @@ local record TypeAliasType is_nested_alias: boolean end +local type TypeType = TypeDeclType | TypeAliasType + local record LiteralTableItemType is Type where self.typename == "literal_table_item" @@ -1602,13 +1635,12 @@ local record LiteralTableItemType vtype: Type end -local record UnresolvedType - is Type - where self.typename == "unresolved" - - labels: {string:{Node}} - nominals: {string:{NominalType}} - global_types: {string:boolean} +local record Scope + vars: {string:Variable} + labels: {string:Node} + pending_labels: {string:{Node}} + pending_nominals: {string:{NominalType}} + pending_global_types: {string:boolean} narrows: {string:boolean} end @@ -1675,6 +1707,11 @@ local record InvalidType where self.typename == "invalid" end +local record UnknownType + is Type + where self.typename == "unknown" +end + local record TupleType is Type where self.typename == "tuple" @@ -1849,7 +1886,8 @@ local interface Fact where self.fact fact: FactType - where: Where + w: Where + no_infer: boolean end local record TruthyFact @@ -2014,6 +2052,9 @@ local record Node -- goto label: string + -- label + used_label: boolean + casttype: Type -- variable @@ -2032,10 +2073,125 @@ local record Node debug_type: Type end -local function is_number_type(t:Type): boolean - return t.typename == "number" or t.typename == "integer" +local function a_type(w: Where, typename: TypeName, t: T): T + t.typeid = new_typeid() + t.f = w.f + t.x = w.x + t.y = w.y + t.typename = typename + return t +end + +local function edit_type(w: Where, t: Type, typename: TypeName): Type + t.typeid = new_typeid() + t.f = w.f + t.x = w.x + t.y = w.y + t.typename = typename + return t +end + +local macroexp a_typedecl(w: Where, def: Type): TypeDeclType + return a_type(w, "typedecl", { def = def } as TypeDeclType) +end + +local macroexp a_tuple(w: Where, t: {Type}): TupleType + return a_type(w, "tuple", { tuple = t } as TupleType) +end + +local macroexp a_union(w: Where, t: {Type}): UnionType + return a_type(w, "union", { types = t } as UnionType) +end + +local function a_function(w: Where, t: FunctionType): FunctionType + assert(t.min_arity) + return a_type(w, "function", t) +end + +local function a_vararg(w: Where, t: {Type}): TupleType + local typ = a_tuple(w, t) + typ.is_va = true + return typ +end + +local macroexp an_array(w: Where, t: Type): ArrayType + return a_type(w, "array", { elements = t } as ArrayType) +end + +local macroexp a_map(w: Where, k: Type, v: Type): MapType + return a_type(w, "map", { keys = k, values = v } as MapType) +end + +local function a_nominal(n: Node, names: {string}): NominalType + return a_type(n, "nominal", { names = names } as NominalType) end +local macroexp an_invalid(w: Where): InvalidType + return a_type(w, "invalid", {} as InvalidType) +end + +local macroexp an_unknown(w: Where): UnknownType + return a_type(w, "unknown", {} as UnknownType) +end + +local an_operator: function(Node, integer, string): Operator + +local function shallow_copy_new_type(t: T): T + local copy: {any:any} = {} + for k, v in pairs(t as {any:any}) do + copy[k] = v + end + copy.typeid = new_typeid() + return copy as T +end + +local function shallow_copy_table(t: T): T + local copy: {any:any} = {} + for k, v in pairs(t as {any:any}) do + copy[k] = v + end + return copy as T +end + +-- TODO move to Errors module +local function clear_redundant_errors(errors: {Error}) + local redundant: {integer} = {} + local lastx, lasty = 0, 0 + for i, err in ipairs(errors) do + err.i = i + end + table.sort(errors, function(a: Error, b: Error): boolean + local af = assert(a.filename) + local bf = assert(b.filename) + return af < bf + or (af == bf and (a.y < b.y + or (a.y == b.y and (a.x < b.x + or (a.x == b.x and (a.i < b.i)))))) + end) + for i, err in ipairs(errors) do + err.i = nil + if err.x == lastx and err.y == lasty then + table.insert(redundant, i) + end + lastx, lasty = err.x, err.y + end + for i = #redundant, 1, -1 do + table.remove(errors, redundant[i]) + end +end + +local simple_types: {TypeName:boolean} = { + ["nil"] = true, + ["any"] = true, + ["number"] = true, + ["string"] = true, + ["thread"] = true, + ["boolean"] = true, + ["integer"] = true, +} + +do ----------------------------------------------------------------------------- + local record ParseState tokens: {Token} errs: {Error} @@ -2108,163 +2264,52 @@ local function verify_end(ps: ParseState, i: integer, istart: integer, node: Nod return fail(ps, i, "syntax error, expected 'end' to close construct started at " .. ps.filename .. ":" .. ps.tokens[istart].y .. ":" .. ps.tokens[istart].x .. ":") end -local function new_node(tokens: {Token}, i: integer, kind?: NodeKind): Node - local t = tokens[i] - return { y = t.y, x = t.x, tk = t.tk, kind = kind or (t.kind as NodeKind) } -end - -local function a_type(typename: TypeName, t: T): T - t.typeid = new_typeid() - t.typename = typename - return t +local function new_node(ps: ParseState, i: integer, kind?: NodeKind): Node + local t = ps.tokens[i] + return { f = ps.filename, y = t.y, x = t.x, tk = t.tk, kind = kind or (t.kind as NodeKind) } end -local function edit_type(t: Type, typename: TypeName): Type +local function new_type(ps: ParseState, i: integer, typename: TypeName): Type + local token = ps.tokens[i] + local t: Type = {} t.typeid = new_typeid() + t.f = ps.filename + t.x = token.x + t.y = token.y t.typename = typename return t end -local function new_type(ps: ParseState, i: integer, typename: TypeName): Type - local token = ps.tokens[i] - return a_type(typename, { - filename = ps.filename, - y = token.y, - x = token.x, - --tk = token.tk - }) -end - local function new_typedecl(ps: ParseState, i: integer, def: Type): TypeDeclType local t = new_type(ps, i, "typedecl") as TypeDeclType t.def = def return t end -local macroexp a_typedecl(def: Type): TypeDeclType - return a_type("typedecl", { def = def } as TypeDeclType) -end - -local macroexp a_tuple(t: {Type}): TupleType - return a_type("tuple", { tuple = t } as TupleType) -end - -local macroexp a_union(t: {Type}): UnionType - return a_type("union", { types = t } as UnionType) -end - ---local macroexp a_poly(t: {FunctionType}): PolyType --- return a_type("poly", { types = t } as PolyType) ---end --- -local function a_function(t: FunctionType): FunctionType - assert(t.min_arity) - return a_type("function", t) -end - -local record Opt - where self.opttype - - opttype: Type -end - ---local function OPT(t: Type): Opt --- return { opttype = t } ---end --- -local record Args - is {Type|Opt} - - is_va: boolean -end - -local function va_args(args: Args): Args - args.is_va = true - return args -end - -local record FuncArgs - is HasTypeArgs - - args: Args - rets: Args - needs_compat: boolean -end - -local function a_fn(f: FuncArgs): FunctionType - local args_t = a_tuple {} - local tup = args_t.tuple - args_t.is_va = f.args.is_va - local min_arity = f.args.is_va and -1 or 0 - for _, a in ipairs(f.args) do - if a is Opt then - table.insert(tup, a.opttype) - else - table.insert(tup, a) - min_arity = min_arity + 1 - end - end - - local rets_t = a_tuple {} - tup = rets_t.tuple - rets_t.is_va = f.rets.is_va - for _, a in ipairs(f.rets) do - assert(a is Type) - table.insert(tup, a) - end - - return a_type("function", { - args = args_t, - rets = rets_t, - min_arity = min_arity, - needs_compat = f.needs_compat, - typeargs = f.typeargs, - } as FunctionType) -end - -local function a_vararg(t: {Type}): TupleType - local typ = a_tuple(t) - typ.is_va = true - return typ -end - -local macroexp an_array(t: Type): ArrayType - return a_type("array", { elements = t } as ArrayType) -end - -local macroexp a_map(k: Type, v: Type): MapType - return a_type("map", { keys = k, values = v } as MapType) +local function new_tuple(ps: ParseState, i: integer, types?: {Type}, is_va?: boolean): TupleType, {Type} + local t = new_type(ps, i, "tuple") as TupleType + t.is_va = is_va + t.tuple = types or {} + return t, t.tuple end -local NIL = a_type("nil", {}) -local ANY = a_type("any", {}) -local TABLE = a_map(ANY, ANY) -local NUMBER = a_type("number", {}) -local STRING = a_type("string", {}) -local THREAD = a_type("thread", {}) -local BOOLEAN = a_type("boolean", {}) -local INTEGER = a_type("integer", {}) - -local function shallow_copy_new_type(t: T): T - local copy: {any:any} = {} - for k, v in pairs(t as {any:any}) do - copy[k] = v - end - copy.typeid = new_typeid() - return copy as T +local function new_typealias(ps: ParseState, i: integer, alias_to: NominalType): TypeAliasType + local t = new_type(ps, i, "typealias") as TypeAliasType + t.alias_to = alias_to + return t end -local function shallow_copy_table(t: T): T - local copy: {any:any} = {} - for k, v in pairs(t as {any:any}) do - copy[k] = v +local function new_nominal(ps: ParseState, i: integer, name?: string): NominalType + local t = new_type(ps, i, "nominal") as NominalType + if name then + t.names = { name } end - return copy as T + return t end local function verify_kind(ps: ParseState, i: integer, kind: TokenKind, node_kind?: NodeKind): integer, Node if ps.tokens[i].kind == kind then - return i + 1, new_node(ps.tokens, i, node_kind) + return i + 1, new_node(ps, i, node_kind) end return fail(ps, i, "syntax error, expected " .. kind) end @@ -2302,23 +2347,23 @@ local function parse_table_value(ps: ParseState, i: integer): integer, Node, int fail(ps, i, next_word == "record" and "syntax error: this syntax is no longer valid; declare nested record inside a record" or "syntax error: cannot declare interface inside a table; use a statement") - return skip_i, new_node(ps.tokens, i, "error_node") + return skip_i, new_node(ps, i, "error_node") end elseif next_word == "enum" and ps.tokens[i + 1].kind == "string" then i = failskip(ps, i, "syntax error: this syntax is no longer valid; declare nested enum inside a record", skip_type_body) - return i, new_node(ps.tokens, i - 1, "error_node") + return i, new_node(ps, i - 1, "error_node") end local e: Node i, e = parse_expression(ps, i) if not e then - e = new_node(ps.tokens, i - 1, "error_node") + e = new_node(ps, i - 1, "error_node") end return i, e end local function parse_table_item(ps: ParseState, i: integer, n?: integer): integer, Node, integer - local node = new_node(ps.tokens, i, "literal_table_item") + local node = new_node(ps, i, "literal_table_item") if ps.tokens[i].kind == "$EOF$" then return fail(ps, i, "unexpected eof") end @@ -2369,7 +2414,7 @@ local function parse_table_item(ps: ParseState, i: integer, n?: integer): intege end end - node.key = new_node(ps.tokens, i, "integer") + node.key = new_node(ps, i, "integer") node.key_parsed = "implicit" node.key.constnum = n node.key.tk = tostring(n) @@ -2445,7 +2490,7 @@ local function parse_bracket_list(ps: ParseState, i: integer, list: {T}, open end local function parse_table_literal(ps: ParseState, i: integer): integer, Node - local node = new_node(ps.tokens, i, "literal_table") + local node = new_node(ps, i, "literal_table") return parse_bracket_list(ps, i, node, "{", "}", "term", parse_table_item) end @@ -2501,16 +2546,21 @@ local function parse_typearg(ps: ParseState, i: integer): integer, TypeArgType, i = i + 1 i, constraint = parse_interface_name(ps, i) -- FIXME what about generic interfaces end - return i, a_type("typearg", { - y = ps.tokens[i - 2].y, - x = ps.tokens[i - 2].x, - typearg = name, - constraint = constraint, - } as TypeArgType) + local t = new_type(ps, i, "typearg") as TypeArgType + t.typearg = name + t.constraint = constraint + return i, t end local function parse_return_types(ps: ParseState, i: integer): integer, TupleType - return parse_type_list(ps, i, "rets") + local iprev = i - 1 + local t: TupleType + i, t = parse_type_list(ps, i, "rets") + if #t.tuple == 0 then + t.x = ps.tokens[iprev].x + t.y = ps.tokens[iprev].y + end + return i, t end local function parse_function_type(ps: ParseState, i: integer): integer, FunctionType @@ -2523,31 +2573,25 @@ local function parse_function_type(ps: ParseState, i: integer): integer, Functio i, typ.args, typ.is_method, typ.min_arity = parse_argument_type_list(ps, i) i, typ.rets = parse_return_types(ps, i) else - typ.args = a_vararg { ANY } - typ.rets = a_vararg { ANY } + typ.args = new_tuple(ps, i, { new_type(ps, i, "any") }, true) + typ.rets = new_tuple(ps, i, { new_type(ps, i, "any") }, true) end return i, typ end -local simple_types: {string:Type} = { - ["nil"] = NIL, - ["any"] = ANY, - ["table"] = TABLE, - ["number"] = NUMBER, - ["string"] = STRING, - ["thread"] = THREAD, - ["boolean"] = BOOLEAN, - ["integer"] = INTEGER, -} - local function parse_simple_type_or_nominal(ps: ParseState, i: integer): integer, Type local tk = ps.tokens[i].tk - local st = simple_types[tk] + local st = simple_types[tk as TypeName] if st then - return i + 1, st + return i + 1, new_type(ps, i, tk as TypeName) + elseif tk == "table" then + local typ = new_type(ps, i, "map") as MapType + typ.keys = new_type(ps, i, "any") + typ.values = new_type(ps, i, "any") + return i + 1, typ end - local typ = new_type(ps, i, "nominal") as NominalType - typ.names = { tk } + + local typ = new_nominal(ps, i, tk) i = i + 1 while ps.tokens[i].tk == "." do i = i + 1 @@ -2614,12 +2658,7 @@ local function parse_base_type(ps: ParseState, i: integer): integer, Type, integ elseif tk == "function" then return parse_function_type(ps, i) elseif tk == "nil" then - return i + 1, simple_types["nil"] - elseif tk == "table" then - local typ = new_type(ps, i, "map") as MapType - typ.keys = ANY - typ.values = ANY - return i + 1, typ + return i + 1, new_type(ps, i, "nil") end return fail(ps, i, "expected a type") end @@ -2655,12 +2694,6 @@ parse_type = function(ps: ParseState, i: integer): integer, Type, integer return i, bt end -local function new_tuple(ps: ParseState, i: integer): TupleType, {Type} - local t = new_type(ps, i, "tuple") as TupleType - t.tuple = {} - return t, t.tuple -end - parse_type_list = function(ps: ParseState, i: integer, mode: ParseTypeListMode): integer, TupleType local t, list = new_tuple(ps, i) @@ -2716,7 +2749,7 @@ local function parse_function_args_rets_body(ps: ParseState, i: integer, node: N end local function parse_function_value(ps: ParseState, i: integer): integer, Node - local node = new_node(ps.tokens, i, "function") + local node = new_node(ps, i, "function") i = verify_tk(ps, i, "function") return parse_function_args_rets_body(ps, i, node) end @@ -2737,7 +2770,7 @@ local function parse_literal(ps: ParseState, i: integer): integer, Node if kind == "identifier" then return verify_kind(ps, i, "identifier", "variable") elseif kind == "string" then - local node = new_node(ps.tokens, i, "string") + local node = new_node(ps, i, "string") node.conststr, node.is_longstring = unquote(tk) return i + 1, node elseif kind == "number" or kind == "integer" then @@ -2785,8 +2818,6 @@ local function node_is_require_call(n: Node): string end end -local an_operator: function(Node, integer, string): Operator - do local precedences: {integer:{string:integer}} = { [1] = { @@ -2861,8 +2892,8 @@ do -- small hack: for the sake of `tl types`, parse an invalid binary exp -- as a paren to produce a unary indirection on e1 and save its location. - local function failstore(tkop: Token, e1: Node): Node - return { y = tkop.y, x = tkop.x, kind = "paren", e1 = e1, failstore = true } + local function failstore(ps: ParseState, tkop: Token, e1: Node): Node + return { f = ps.filename, y = tkop.y, x = tkop.x, kind = "paren", e1 = e1, failstore = true } end local function P(ps: ParseState, i: integer): integer, Node @@ -2880,7 +2911,7 @@ do fail(ps, prev_i, "expected an expression") return i end - e1 = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1 } + e1 = { f = ps.filename, y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1 } elseif ps.tokens[i].tk == "(" then i = i + 1 local prev_i = i @@ -2889,7 +2920,7 @@ do fail(ps, prev_i, "expected an expression") return i end - e1 = { y = t1.y, x = t1.x, kind = "paren", e1 = e1 } + e1 = { f = ps.filename, y = t1.y, x = t1.x, kind = "paren", e1 = e1 } else i, e1 = parse_literal(ps, i) end @@ -2914,12 +2945,12 @@ do local skipped = skip(ps, i, parse_type as SkipFunction) if skipped > i + 1 then fail(ps, i, "syntax error, cannot declare a type here (missing 'local' or 'global'?)") - return skipped, failstore(tkop, e1) + return skipped, failstore(ps, tkop, e1) end end i, key = verify_kind(ps, i, "identifier") if not key then - return i, failstore(tkop, e1) + return i, failstore(ps, tkop, e1) end if op.op == ":" then @@ -2929,30 +2960,30 @@ do else fail(ps, i, "expected a function call for a method") end - return i, failstore(tkop, e1) + return i, failstore(ps, tkop, e1) end if not after_valid_prefixexp(ps, e1, prev_i) then fail(ps, prev_i, "cannot call a method on this expression") - return i, failstore(tkop, e1) + return i, failstore(ps, tkop, e1) end end - e1 = { y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = key } + e1 = { f = ps.filename, y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = key } elseif tkop.tk == "(" then local op: Operator = new_operator(tkop, 2, "@funcall") local prev_i = i - local args = new_node(ps.tokens, i, "expression_list") + local args = new_node(ps, i, "expression_list") i, args = parse_bracket_list(ps, i, args, "(", ")", "sep", parse_expression) if not after_valid_prefixexp(ps, e1, prev_i) then fail(ps, prev_i, "cannot call this expression") - return i, failstore(tkop, e1) + return i, failstore(ps, tkop, e1) end - e1 = { y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } + e1 = { f = ps.filename, y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } table.insert(ps.required_modules, node_is_require_call(e1)) elseif tkop.tk == "[" then @@ -2966,19 +2997,19 @@ do if not after_valid_prefixexp(ps, e1, prev_i) then fail(ps, prev_i, "cannot index this expression") - return i, failstore(tkop, e1) + return i, failstore(ps, tkop, e1) end - e1 = { y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = idx } + e1 = { f = ps.filename, y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = idx } elseif tkop.kind == "string" or tkop.kind == "{" then local op: Operator = new_operator(tkop, 2, "@funcall") local prev_i = i - local args = new_node(ps.tokens, i, "expression_list") + local args = new_node(ps, i, "expression_list") local argument: Node if tkop.kind == "string" then - argument = new_node(ps.tokens, i) + argument = new_node(ps, i) argument.conststr = unquote(tkop.tk) i = i + 1 else @@ -2991,27 +3022,27 @@ do else fail(ps, prev_i, "cannot use a table here; if you're trying to call the previous expression, wrap it in parentheses") end - return i, failstore(tkop, e1) + return i, failstore(ps, tkop, e1) end table.insert(args, argument) - e1 = { y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } + e1 = { f = ps.filename, y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } table.insert(ps.required_modules, node_is_require_call(e1)) elseif tkop.tk == "as" or tkop.tk == "is" then local op: Operator = new_operator(tkop, 2, tkop.tk) i = i + 1 - local cast = new_node(ps.tokens, i, "cast") + local cast = new_node(ps, i, "cast") if ps.tokens[i].tk == "(" then i, cast.casttype = parse_type_list(ps, i, "casttype") else i, cast.casttype = parse_type(ps, i) end if not cast.casttype then - return i, failstore(tkop, e1) + return i, failstore(ps, tkop, e1) end - e1 = { y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = cast, conststr = e1.conststr } + e1 = { f = ps.filename, y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = cast, conststr = e1.conststr } else break end @@ -3042,7 +3073,7 @@ do end lookahead = ps.tokens[i].tk end - lhs = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = lhs, e2 = rhs, } + lhs = { f = ps.filename, y = t1.y, x = t1.x, kind = "op", op = op, e1 = lhs, e2 = rhs, } end return i, lhs end @@ -3069,7 +3100,7 @@ parse_expression_and_tk = function(ps: ParseState, i: integer, tk: string): inte local e: Node i, e = parse_expression(ps, i) if not e then - e = new_node(ps.tokens, i - 1, "error_node") + e = new_node(ps, i - 1, "error_node") end if ps.tokens[i].tk == tk then i = i + 1 @@ -3147,7 +3178,7 @@ local function parse_argument(ps: ParseState, i: integer): integer, Node, intege end parse_argument_list = function(ps: ParseState, i: integer): integer, Node, integer - local node = new_node(ps.tokens, i, "argument_list") + local node = new_node(ps, i, "argument_list") i, node = parse_bracket_list(ps, i, node, "(", ")", "sep", parse_argument) local opts = false local min_arity = 0 @@ -3242,16 +3273,16 @@ end local function parse_identifier(ps: ParseState, i: integer): integer, Node, integer if ps.tokens[i].kind == "identifier" then - return i + 1, new_node(ps.tokens, i, "identifier") + return i + 1, new_node(ps, i, "identifier") end i = fail(ps, i, "syntax error, expected identifier") - return i, new_node(ps.tokens, i, "error_node") + return i, new_node(ps, i, "error_node") end local function parse_local_function(ps: ParseState, i: integer): integer, Node i = verify_tk(ps, i, "local") i = verify_tk(ps, i, "function") - local node = new_node(ps.tokens, i - 2, "local_function") + local node = new_node(ps, i - 2, "local_function") i, node.name = parse_identifier(ps, i) return parse_function_args_rets_body(ps, i, node) end @@ -3264,7 +3295,7 @@ end local function parse_function(ps: ParseState, i: integer, fk: FunctionKind): integer, Node local orig_i = i i = verify_tk(ps, i, "function") - local fn = new_node(ps.tokens, i - 1, "global_function") + local fn = new_node(ps, i - 1, "global_function") local names: {Node} = {} i, names[1] = parse_identifier(ps, i) while ps.tokens[i].tk == "." do @@ -3284,7 +3315,7 @@ local function parse_function(ps: ParseState, i: integer, fk: FunctionKind): int for i2 = 2, #names - 1 do local dot = an_operator(names[i2], 2, ".") names[i2].kind = "identifier" - owner = { y = names[i2].y, x = names[i2].x, kind = "op", op = dot, e1 = owner, e2 = names[i2] } + owner = { f = ps.filename, y = names[i2].y, x = names[i2].x, kind = "op", op = dot, e1 = owner, e2 = names[i2] } end fn.fn_owner = owner end @@ -3292,8 +3323,8 @@ local function parse_function(ps: ParseState, i: integer, fk: FunctionKind): int local selfx, selfy = ps.tokens[i].x, ps.tokens[i].y i = parse_function_args_rets_body(ps, i, fn) - if fn.is_method then - table.insert(fn.args, 1, { x = selfx, y = selfy, tk = "self", kind = "identifier", is_self = true }) + if fn.is_method and fn.args then + table.insert(fn.args, 1, { f = ps.filename, x = selfx, y = selfy, tk = "self", kind = "identifier", is_self = true }) fn.min_arity = fn.min_arity + 1 end @@ -3311,7 +3342,7 @@ local function parse_function(ps: ParseState, i: integer, fk: FunctionKind): int end local function parse_if_block(ps: ParseState, i: integer, n: integer, node: Node, is_else?: boolean): integer, Node - local block = new_node(ps.tokens, i, "if_block") + local block = new_node(ps, i, "if_block") i = i + 1 block.if_parent = node block.if_block_n = n @@ -3333,7 +3364,7 @@ end local function parse_if(ps: ParseState, i: integer): integer, Node local istart = i - local node = new_node(ps.tokens, i, "if") + local node = new_node(ps, i, "if") node.if_blocks = {} i, node = parse_if_block(ps, i, 1, node) if not node then @@ -3359,7 +3390,7 @@ end local function parse_while(ps: ParseState, i: integer): integer, Node local istart = i - local node = new_node(ps.tokens, i, "while") + local node = new_node(ps, i, "while") i = verify_tk(ps, i, "while") i, node.exp = parse_expression_and_tk(ps, i, "do") i, node.body = parse_statements(ps, i) @@ -3369,7 +3400,7 @@ end local function parse_fornum(ps: ParseState, i: integer): integer, Node local istart = i - local node = new_node(ps.tokens, i, "fornum") + local node = new_node(ps, i, "fornum") i = i + 1 i, node.var = parse_identifier(ps, i) i = verify_tk(ps, i, "=") @@ -3388,12 +3419,12 @@ end local function parse_forin(ps: ParseState, i: integer): integer, Node local istart = i - local node = new_node(ps.tokens, i, "forin") + local node = new_node(ps, i, "forin") i = i + 1 - node.vars = new_node(ps.tokens, i, "variable_list") + node.vars = new_node(ps, i, "variable_list") i, node.vars = parse_list(ps, i, node.vars, { ["in"] = true }, "sep", parse_identifier) i = verify_tk(ps, i, "in") - node.exps = new_node(ps.tokens, i, "expression_list") + node.exps = new_node(ps, i, "expression_list") i = parse_list(ps, i, node.exps, { ["do"] = true }, "sep", parse_expression) if #node.exps < 1 then return fail(ps, i, "missing iterator expression in generic for") @@ -3415,7 +3446,7 @@ local function parse_for(ps: ParseState, i: integer): integer, Node end local function parse_repeat(ps: ParseState, i: integer): integer, Node - local node = new_node(ps.tokens, i, "repeat") + local node = new_node(ps, i, "repeat") i = verify_tk(ps, i, "repeat") i, node.body = parse_statements(ps, i) node.body.is_repeat = true @@ -3427,7 +3458,7 @@ end local function parse_do(ps: ParseState, i: integer): integer, Node local istart = i - local node = new_node(ps.tokens, i, "do") + local node = new_node(ps, i, "do") i = verify_tk(ps, i, "do") i, node.body = parse_statements(ps, i) i = verify_end(ps, i, istart, node) @@ -3435,13 +3466,13 @@ local function parse_do(ps: ParseState, i: integer): integer, Node end local function parse_break(ps: ParseState, i: integer): integer, Node - local node = new_node(ps.tokens, i, "break") + local node = new_node(ps, i, "break") i = verify_tk(ps, i, "break") return i, node end local function parse_goto(ps: ParseState, i: integer): integer, Node - local node = new_node(ps.tokens, i, "goto") + local node = new_node(ps, i, "goto") i = verify_tk(ps, i, "goto") node.label = ps.tokens[i].tk i = verify_kind(ps, i, "identifier") @@ -3449,7 +3480,7 @@ local function parse_goto(ps: ParseState, i: integer): integer, Node end local function parse_label(ps: ParseState, i: integer): integer, Node - local node = new_node(ps.tokens, i, "label") + local node = new_node(ps, i, "label") i = verify_tk(ps, i, "::") node.label = ps.tokens[i].tk i = verify_kind(ps, i, "identifier") @@ -3474,9 +3505,9 @@ for k, v in pairs(stop_statement_list) do end local function parse_return(ps: ParseState, i: integer): integer, Node - local node = new_node(ps.tokens, i, "return") + local node = new_node(ps, i, "return") i = verify_tk(ps, i, "return") - node.exps = new_node(ps.tokens, i, "expression_list") + node.exps = new_node(ps, i, "expression_list") i = parse_list(ps, i, node.exps, stop_return_list, "sep", parse_expression) if ps.tokens[i].kind == ";" then i = i + 1 @@ -3514,12 +3545,13 @@ local function parse_nested_type(ps: ParseState, i: integer, def: RecordLikeType return fail(ps, i, "expected a variable name") end - local nt: Node = new_node(ps.tokens, i - 2, "newtype") + local nt: Node = new_node(ps, i - 2, "newtype") local ndef = new_type(ps, i, typename) + local itype = i local iok = parse_body(ps, i, ndef, nt) if iok then i = iok - nt.newtype = new_typedecl(ps, i, ndef) + nt.newtype = new_typedecl(ps, itype, ndef) end store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) @@ -3576,7 +3608,7 @@ local function parse_macroexp(ps: ParseState, istart: integer, iargs: integer): -- if ps.tokens[i].tk == "<" then -- i, node.typeargs = parse_anglebracket_list(ps, i, parse_typearg) -- end - local node = new_node(ps.tokens, istart, "macroexp") + local node = new_node(ps, istart, "macroexp") local i: integer i, node.args, node.min_arity = parse_argument_list(ps, iargs) i, node.rets = parse_return_types(ps, i) @@ -3588,18 +3620,14 @@ local function parse_macroexp(ps: ParseState, istart: integer, iargs: integer): end local function parse_where_clause(ps: ParseState, i: integer): integer, Node - local node = new_node(ps.tokens, i, "macroexp") - - local selftype = new_type(ps, i, "nominal") as NominalType - selftype.names = { "@self" } - - node.args = new_node(ps.tokens, i, "argument_list") - node.args[1] = new_node(ps.tokens, i, "argument") + local node = new_node(ps, i, "macroexp") + node.args = new_node(ps, i, "argument_list") + node.args[1] = new_node(ps, i, "argument") node.args[1].tk = "self" - node.args[1].argtype = selftype + node.args[1].argtype = new_nominal(ps, i, "@self") node.min_arity = 1 node.rets = new_tuple(ps, i) - node.rets.tuple[1] = BOOLEAN + node.rets.tuple[1] = new_type(ps, i, "boolean") i, node.exp = parse_expression(ps, i) end_at(node, ps.tokens[i - 1]) return i, node @@ -3681,15 +3709,10 @@ parse_record_body = function(ps: ParseState, i: integer, def: RecordLikeType, no local typ = new_type(ps, wstart, "function") as FunctionType typ.is_method = true typ.min_arity = 1 - typ.args = a_tuple { - a_type("nominal", { - y = typ.y, - x = typ.x, - filename = ps.filename, - names = { "@self" } - } as NominalType) - } - typ.rets = a_tuple { BOOLEAN } + typ.args = new_tuple(ps, wstart, { + a_nominal(where_macroexp, { "@self" }) + }) + typ.rets = new_tuple(ps, wstart, { new_type(ps, wstart, "boolean") }) typ.macroexp = where_macroexp def.meta_fields = {} @@ -3810,7 +3833,7 @@ parse_type_body_fns = { } parse_newtype = function(ps: ParseState, i: integer): integer, Node - local node: Node = new_node(ps.tokens, i, "newtype") + local node: Node = new_node(ps, i, "newtype") local def: Type local tn = ps.tokens[i].tk as TypeName local itype = i @@ -3831,9 +3854,7 @@ parse_newtype = function(ps: ParseState, i: integer): integer, Node end if def is NominalType then - local typealias = new_type(ps, itype, "typealias") as TypeAliasType - typealias.alias_to = def - node.newtype = typealias + node.newtype = new_typealias(ps, itype, def) else node.newtype = new_typedecl(ps, itype, def) end @@ -3843,7 +3864,7 @@ parse_newtype = function(ps: ParseState, i: integer): integer, Node end local function parse_assignment_expression_list(ps: ParseState, i: integer, asgn: Node): integer, Node - asgn.exps = new_node(ps.tokens, i, "expression_list") + asgn.exps = new_node(ps, i, "expression_list") repeat i = i + 1 local val: Node @@ -3893,8 +3914,8 @@ do return fail(ps, i, "syntax error") end - local asgn: Node = new_node(ps.tokens, istart, "assignment") - asgn.vars = new_node(ps.tokens, istart, "variable_list") + local asgn: Node = new_node(ps, istart, "assignment") + asgn.vars = new_node(ps, istart, "variable_list") asgn.vars[1] = exp if ps.tokens[i].tk == "," then i = i + 1 @@ -3915,9 +3936,9 @@ do end local function parse_variable_declarations(ps: ParseState, i: integer, node_name: NodeKind): integer, Node - local asgn: Node = new_node(ps.tokens, i, node_name) + local asgn: Node = new_node(ps, i, node_name) - asgn.vars = new_node(ps.tokens, i, "variable_list") + asgn.vars = new_node(ps, i, "variable_list") i = parse_trying_list(ps, i, asgn.vars, parse_variable_name) if #asgn.vars == 0 then return fail(ps, i, "expected a local variable definition") @@ -3945,7 +3966,7 @@ end local function parse_type_declaration(ps: ParseState, i: integer, node_name: NodeKind): integer, Node i = i + 2 -- skip `local` or `global`, and `type` - local asgn: Node = new_node(ps.tokens, i, node_name) + local asgn: Node = new_node(ps, i, node_name) i, asgn.var = parse_variable_name(ps, i) if not asgn.var then return fail(ps, i, "expected a type name") @@ -3985,8 +4006,8 @@ local function parse_type_declaration(ps: ParseState, i: integer, node_name: Nod end local function parse_type_constructor(ps: ParseState, i: integer, node_name: NodeKind, type_name: TypeName, parse_body: ParseBody): integer, Node - local asgn: Node = new_node(ps.tokens, i, node_name) - local nt: Node = new_node(ps.tokens, i, "newtype") + local asgn: Node = new_node(ps, i, node_name) + local nt: Node = new_node(ps, i, "newtype") asgn.value = nt local itype = i local def = new_type(ps, i, type_name) @@ -4015,7 +4036,7 @@ end local function parse_local_macroexp(ps: ParseState, i: integer): integer, Node local istart = i i = i + 2 -- skip `local` - local node = new_node(ps.tokens, i, "local_macroexp") + local node = new_node(ps, i, "local_macroexp") i, node.name = parse_identifier(ps, i) i, node.macrodef = parse_macroexp(ps, istart, i) end_at(node, ps.tokens[i - 1]) @@ -4085,7 +4106,7 @@ local needs_local_or_global: {string : function(ParseState, integer):(integer, N } parse_statements = function(ps: ParseState, i: integer, toplevel?: boolean): integer, Node - local node = new_node(ps.tokens, i, "statements") + local node = new_node(ps, i, "statements") local item: Node while true do while ps.tokens[i].kind == ";" do @@ -4130,32 +4151,6 @@ parse_statements = function(ps: ParseState, i: integer, toplevel?: boolean): int return i, node end -local function clear_redundant_errors(errors: {Error}) - local redundant: {integer} = {} - local lastx, lasty = 0, 0 - for i, err in ipairs(errors) do - err.i = i - end - table.sort(errors, function(a: Error, b: Error): boolean - local af = a.filename or "" - local bf = b.filename or "" - return af < bf - or (af == bf and (a.y < b.y - or (a.y == b.y and (a.x < b.x - or (a.x == b.x and (a.i < b.i)))))) - end) - for i, err in ipairs(errors) do - err.i = nil - if err.x == lastx and err.y == lasty then - table.insert(redundant, i) - end - lastx, lasty = err.x, err.y - end - for i = #redundant, 1, -1 do - table.remove(errors, redundant[i]) - end -end - function tl.parse_program(tokens: {Token}, errs: {Error}, filename: string): Node, {string} errs = errs or {} local ps: ParseState = { @@ -4185,17 +4180,19 @@ function tl.parse(input: string, filename: string): Node, {Error}, {string} return node, errs, required_modules end +end ---------------------------------------------------------------------------- + -------------------------------------------------------------------------------- -- AST traversal -------------------------------------------------------------------------------- -local record VisitorCallbacks - before: function(N) - before_exp: function({N}, {T}) - before_arguments: function({N}, {T}) - before_statements: function({N}, {T}) - before_e2: function({N}, {T}) - after: function(N, {T}): T +local record VisitorCallbacks + before: function(S, N) + before_exp: function(S, {N}, {T}) + before_arguments: function(S, {N}, {T}) + before_statements: function(S, {N}, {T}) + before_e2: function(S, {N}, {T}) + after: function(S, N, {T}): T end local enum VisitorExtraCallback @@ -4205,9 +4202,11 @@ local enum VisitorExtraCallback "before_e2" end -local record Visitor - cbs: {K:VisitorCallbacks} - after: function(N, {T}, T): T +local type VisitorAfter = function(S, N, {T}, T): T + +local record Visitor + cbs: {K:VisitorCallbacks} + after: VisitorAfter allow_missing_cbs: boolean end @@ -4296,7 +4295,7 @@ local function tl_debug_indent_pop(mark: string, single: string, y: integer, x: end end -local function recurse_type(ast: Type, visit: Visitor): T +local function recurse_type(s: S, ast: Type, visit: Visitor): T local kind = ast.typename if TL_DEBUG then @@ -4308,7 +4307,7 @@ local function recurse_type(ast: Type, visit: Visitor): T if cbkind then local cbkind_before = cbkind.before if cbkind_before then - cbkind_before(ast) + cbkind_before(s, ast) end end @@ -4316,90 +4315,90 @@ local function recurse_type(ast: Type, visit: Visitor): T if ast is TupleType then for i, child in ipairs(ast.tuple) do - xs[i] = recurse_type(child, visit) + xs[i] = recurse_type(s, child, visit) end elseif ast is AggregateType then for _, child in ipairs(ast.types) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end elseif ast is MapType then - table.insert(xs, recurse_type(ast.keys, visit)) - table.insert(xs, recurse_type(ast.values, visit)) + table.insert(xs, recurse_type(s, ast.keys, visit)) + table.insert(xs, recurse_type(s, ast.values, visit)) elseif ast is RecordLikeType then if ast.typeargs then for _, child in ipairs(ast.typeargs) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.interface_list then for _, child in ipairs(ast.interface_list) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.elements then - table.insert(xs, recurse_type(ast.elements, visit)) + table.insert(xs, recurse_type(s, ast.elements, visit)) end if ast.fields then for _, child in fields_of(ast) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.meta_fields then for _, child in fields_of(ast, "meta") do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end elseif ast is FunctionType then if ast.typeargs then for _, child in ipairs(ast.typeargs) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.args then for _, child in ipairs(ast.args.tuple) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.rets then for _, child in ipairs(ast.rets.tuple) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end elseif ast is NominalType then if ast.typevals then for _, child in ipairs(ast.typevals) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end elseif ast is TypeArgType then if ast.constraint then - table.insert(xs, recurse_type(ast.constraint, visit)) + table.insert(xs, recurse_type(s, ast.constraint, visit)) end elseif ast is ArrayType then if ast.elements then - table.insert(xs, recurse_type(ast.elements, visit)) + table.insert(xs, recurse_type(s, ast.elements, visit)) end elseif ast is LiteralTableItemType then if ast.ktype then - table.insert(xs, recurse_type(ast.ktype, visit)) + table.insert(xs, recurse_type(s, ast.ktype, visit)) end if ast.vtype then - table.insert(xs, recurse_type(ast.vtype, visit)) + table.insert(xs, recurse_type(s, ast.vtype, visit)) end elseif ast is TypeAliasType then - table.insert(xs, recurse_type(ast.alias_to, visit)) + table.insert(xs, recurse_type(s, ast.alias_to, visit)) elseif ast is TypeDeclType then - table.insert(xs, recurse_type(ast.def, visit)) + table.insert(xs, recurse_type(s, ast.def, visit)) end local ret: T local cbkind_after = cbkind and cbkind.after if cbkind_after then - ret = cbkind_after(ast, xs) + ret = cbkind_after(s, ast, xs) end local visit_after = visit.after if visit_after then - ret = visit_after(ast, xs, ret) + ret = visit_after(s, ast, xs, ret) end if TL_DEBUG then @@ -4409,25 +4408,26 @@ local function recurse_type(ast: Type, visit: Visitor): T return ret end -local function recurse_typeargs(ast: Node, visit_type: Visitor) +local function recurse_typeargs(s: S, ast: Node, visit_type: Visitor) if ast.typeargs then for _, typearg in ipairs(ast.typeargs) do - recurse_type(typearg, visit_type) + recurse_type(s, typearg, visit_type) end end end -local function extra_callback(name: VisitorExtraCallback, - ast: Node, - xs: {T}, - visit_node: Visitor) +local function extra_callback(name: VisitorExtraCallback, + s: S, + ast: Node, + xs: {T}, + visit_node: Visitor) local cbs = visit_node.cbs if not cbs then return end local nbs = cbs[ast.kind] if not nbs then return end local bs = nbs[name] if not bs then return end - bs(ast, xs) + bs(s, ast, xs) end local no_recurse_node: {NodeKind : boolean} = { @@ -4447,9 +4447,9 @@ local no_recurse_node: {NodeKind : boolean} = { ["type_identifier"] = true, } -local function recurse_node(root: Node, - visit_node: Visitor, - visit_type: Visitor): T +local function recurse_node(s: S, root: Node, + visit_node: Visitor, + visit_type: Visitor): T if not root then -- parse error return @@ -4466,9 +4466,9 @@ local function recurse_node(root: Node, local function walk_vars_exps(ast: Node, xs: {T}) xs[1] = recurse(ast.vars) if ast.decltuple then - xs[2] = recurse_type(ast.decltuple, visit_type) + xs[2] = recurse_type(s, ast.decltuple, visit_type) end - extra_callback("before_exp", ast, xs, visit_node) + extra_callback("before_exp", s, ast, xs, visit_node) if ast.exps then xs[3] = recurse(ast.exps) end @@ -4480,11 +4480,11 @@ local function recurse_node(root: Node, end local function walk_named_function(ast: Node, xs: {T}) - recurse_typeargs(ast, visit_type) + recurse_typeargs(s, ast, visit_type) xs[1] = recurse(ast.name) xs[2] = recurse(ast.args) - xs[3] = recurse_type(ast.rets, visit_type) - extra_callback("before_statements", ast, xs, visit_node) + xs[3] = recurse_type(s, ast.rets, visit_type) + extra_callback("before_statements", s, ast, xs, visit_node) xs[4] = recurse(ast.body) end @@ -4497,9 +4497,9 @@ local function recurse_node(root: Node, end xs[2] = p1 as T if ast.op.arity == 2 then - extra_callback("before_e2", ast, xs, visit_node) + extra_callback("before_e2", s, ast, xs, visit_node) if ast.op.op == "is" or ast.op.op == "as" then - xs[3] = recurse_type(ast.e2.casttype, visit_type) + xs[3] = recurse_type(s, ast.e2.casttype, visit_type) else xs[3] = recurse(ast.e2) end @@ -4517,7 +4517,7 @@ local function recurse_node(root: Node, xs[1] = recurse(ast.key) xs[2] = recurse(ast.value) if ast.itemtype then - xs[3] = recurse_type(ast.itemtype, visit_type) + xs[3] = recurse_type(s, ast.itemtype, visit_type) end end, @@ -4543,13 +4543,13 @@ local function recurse_node(root: Node, if ast.exp then xs[1] = recurse(ast.exp) end - extra_callback("before_statements", ast, xs, visit_node) + extra_callback("before_statements", s, ast, xs, visit_node) xs[2] = recurse(ast.body) end, ["while"] = function(ast: Node, xs: {T}) xs[1] = recurse(ast.exp) - extra_callback("before_statements", ast, xs, visit_node) + extra_callback("before_statements", s, ast, xs, visit_node) xs[2] = recurse(ast.body) end, @@ -4559,45 +4559,45 @@ local function recurse_node(root: Node, end, ["macroexp"] = function(ast: Node, xs: {T}) - recurse_typeargs(ast, visit_type) + recurse_typeargs(s, ast, visit_type) xs[1] = recurse(ast.args) - xs[2] = recurse_type(ast.rets, visit_type) - extra_callback("before_exp", ast, xs, visit_node) + xs[2] = recurse_type(s, ast.rets, visit_type) + extra_callback("before_exp", s, ast, xs, visit_node) xs[3] = recurse(ast.exp) end, ["function"] = function(ast: Node, xs: {T}) - recurse_typeargs(ast, visit_type) + recurse_typeargs(s, ast, visit_type) xs[1] = recurse(ast.args) - xs[2] = recurse_type(ast.rets, visit_type) - extra_callback("before_statements", ast, xs, visit_node) + xs[2] = recurse_type(s, ast.rets, visit_type) + extra_callback("before_statements", s, ast, xs, visit_node) xs[3] = recurse(ast.body) end, ["local_function"] = walk_named_function, ["global_function"] = walk_named_function, ["record_function"] = function(ast: Node, xs: {T}) - recurse_typeargs(ast, visit_type) + recurse_typeargs(s, ast, visit_type) xs[1] = recurse(ast.fn_owner) xs[2] = recurse(ast.name) - extra_callback("before_arguments", ast, xs, visit_node) + extra_callback("before_arguments", s, ast, xs, visit_node) xs[3] = recurse(ast.args) - xs[4] = recurse_type(ast.rets, visit_type) - extra_callback("before_statements", ast, xs, visit_node) + xs[4] = recurse_type(s, ast.rets, visit_type) + extra_callback("before_statements", s, ast, xs, visit_node) xs[5] = recurse(ast.body) end, ["local_macroexp"] = function(ast: Node, xs: {T}) -- TODO: generic macroexp xs[1] = recurse(ast.name) xs[2] = recurse(ast.macrodef.args) - xs[3] = recurse_type(ast.macrodef.rets, visit_type) - extra_callback("before_exp", ast, xs, visit_node) + xs[3] = recurse_type(s, ast.macrodef.rets, visit_type) + extra_callback("before_exp", s, ast, xs, visit_node) xs[4] = recurse(ast.macrodef.exp) end, ["forin"] = function(ast: Node, xs: {T}) xs[1] = recurse(ast.vars) xs[2] = recurse(ast.exps) - extra_callback("before_statements", ast, xs, visit_node) + extra_callback("before_statements", s, ast, xs, visit_node) xs[3] = recurse(ast.body) end, @@ -4606,7 +4606,7 @@ local function recurse_node(root: Node, xs[2] = recurse(ast.from) xs[3] = recurse(ast.to) xs[4] = ast.step and recurse(ast.step) - extra_callback("before_statements", ast, xs, visit_node) + extra_callback("before_statements", s, ast, xs, visit_node) xs[5] = recurse(ast.body) end, @@ -4623,12 +4623,12 @@ local function recurse_node(root: Node, end, ["newtype"] = function(ast: Node, xs:{T}) - xs[1] = recurse_type(ast.newtype, visit_type) + xs[1] = recurse_type(s, ast.newtype, visit_type) end, ["argument"] = function(ast: Node, xs:{T}) if ast.argtype then - xs[1] = recurse_type(ast.argtype, visit_type) + xs[1] = recurse_type(s, ast.argtype, visit_type) end end, } @@ -4647,7 +4647,7 @@ local function recurse_node(root: Node, local cbkind = cbs and cbs[kind] if cbkind then if cbkind.before then - cbkind.before(ast) + cbkind.before(s, ast) end end @@ -4671,10 +4671,10 @@ local function recurse_node(root: Node, local ret: T local cbkind_after = cbkind and cbkind.after if cbkind_after then - ret = cbkind_after(ast, xs) + ret = cbkind_after(s, ast, xs) end if visit_after then - ret = visit_after(ast, xs, ret) + ret = visit_after(s, ast, xs, ret) end if TL_DEBUG then @@ -4757,7 +4757,7 @@ local primitive: {TypeName:string} = { ["thread"] = "thread", } -function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | PrettyPrintOptions): string, string +function tl.pretty_print_ast(ast: Node, gen_target: GenTarget, mode: boolean | PrettyPrintOptions): string, string local err: string local indent = 0 @@ -4778,7 +4778,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | local save_indent: {integer} = {} - local function increment_indent(node: Node) + local function increment_indent(_: nil, node: Node) local child = node.body or node[1] if not child then return @@ -4871,7 +4871,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | return table.concat(out) end - local visit_node: Visitor = {} + local visit_node: Visitor = {} local lua_54_attribute : {Attribute:string} = { ["const"] = " ", @@ -4881,7 +4881,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | visit_node.cbs = { ["statements"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output if opts.preserve_hashbang and node.hashbang then out = { y = 1, h = 0 } @@ -4903,7 +4903,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end }, ["local_declaration"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "local ") for i, var in ipairs(node.vars) do @@ -4929,7 +4929,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["local_type"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } if not node.var.elide_type then table.insert(out, "local") @@ -4941,7 +4941,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["global_type"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } if children[2] then add_child(out, children[1]) @@ -4952,7 +4952,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["global_declaration"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } if children[3] then add_child(out, children[1]) @@ -4963,7 +4963,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["assignment"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } add_child(out, children[1]) table.insert(out, " =") @@ -4972,7 +4972,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["if"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } for i, child in ipairs(children) do add_child(out, child, i > 1 and " ", child.y ~= node.y and indent) @@ -4983,7 +4983,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["if_block"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } if node.if_block_n == 1 then table.insert(out, "if") @@ -5003,7 +5003,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["while"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "while") add_child(out, children[1], " ") @@ -5016,7 +5016,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["repeat"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "repeat") add_child(out, children[1], " ") @@ -5028,7 +5028,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["do"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "do") add_child(out, children[1], " ") @@ -5039,7 +5039,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["forin"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "for") add_child(out, children[1], " ") @@ -5054,7 +5054,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["fornum"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "for") add_child(out, children[1], " ") @@ -5074,7 +5074,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["return"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "return") if #children[1] > 0 then @@ -5084,14 +5084,14 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["break"] = { - after = function(node: Node, _children: {Output}): Output + after = function(_: nil, node: Node, _children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "break") return out end, }, ["variable_list"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } local space: string for i, child in ipairs(children) do @@ -5106,7 +5106,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["literal_table"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } if #children == 0 then table.insert(out, "{}") @@ -5126,7 +5126,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["literal_table_item"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } if node.key_parsed ~= "implicit" then if node.key_parsed == "short" then @@ -5149,13 +5149,13 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["local_macroexp"] = { before = increment_indent, - after = function(node: Node, _children: {Output}): Output + after = function(_: nil, node: Node, _children: {Output}): Output return { y = node.y, h = 0 } end, }, ["local_function"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "local function") add_child(out, children[1], " ") @@ -5170,7 +5170,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["global_function"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "function") add_child(out, children[1], " ") @@ -5185,7 +5185,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["record_function"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "function") add_child(out, children[1], " ") @@ -5210,7 +5210,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["function"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "function(") add_child(out, children[1]) @@ -5224,7 +5224,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | ["cast"] = { }, ["paren"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "(") add_child(out, children[1], "", indent) @@ -5233,7 +5233,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["op"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } if node.op.op == "@funcall" then add_child(out, children[1], "", indent) @@ -5294,14 +5294,14 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["variable"] = { - after = function(node: Node, _children: {Output}): Output + after = function(_: nil, node: Node, _children: {Output}): Output local out: Output = { y = node.y, h = 0 } add_string(out, node.tk) return out end, }, ["newtype"] = { - after = function(node: Node, _children: {Output}): Output + after = function(_: nil, node: Node, _children: {Output}): Output local out: Output = { y = node.y, h = 0 } local nt = node.newtype if nt is TypeAliasType then @@ -5318,7 +5318,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["goto"] = { - after = function(node: Node, _children: {Output}): Output + after = function(_: nil, node: Node, _children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "goto ") table.insert(out, node.label) @@ -5326,7 +5326,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["label"] = { - after = function(node: Node, _children: {Output}): Output + after = function(_: nil, node: Node, _children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "::") table.insert(out, node.label) @@ -5336,10 +5336,10 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, } - local visit_type: Visitor = {} + local visit_type: Visitor = {} visit_type.cbs = {} local default_type_visitor = { - after = function(typ: Type, _children: {Output}): Output + after = function(_: nil, typ: Type, _children: {Output}): Output local out: Output = { y = typ.y or -1, h = 0 } local r = typ is NominalType and typ.resolved or typ local lua_type = primitive[r.typename] or "table" @@ -5377,7 +5377,6 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | visit_type.cbs["any"] = default_type_visitor visit_type.cbs["unknown"] = default_type_visitor visit_type.cbs["invalid"] = default_type_visitor - visit_type.cbs["unresolved"] = default_type_visitor visit_type.cbs["none"] = default_type_visitor visit_node.cbs["expression_list"] = visit_node.cbs["variable_list"] @@ -5392,7 +5391,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | visit_node.cbs["argument"] = visit_node.cbs["variable"] visit_node.cbs["type_identifier"] = visit_node.cbs["variable"] - local out = recurse_node(ast, visit_node, visit_type) + local out = recurse_node(nil, ast, visit_node, visit_type) if err then return nil, err end @@ -5442,7 +5441,6 @@ local typename_to_typecode : {TypeName:integer} = { ["none"] = tl.typecodes.UNKNOWN, ["tuple"] = tl.typecodes.UNKNOWN, ["literal_table_item"] = tl.typecodes.UNKNOWN, - ["unresolved"] = tl.typecodes.UNKNOWN, ["typedecl"] = tl.typecodes.UNKNOWN, ["typealias"] = tl.typecodes.UNKNOWN, ["*"] = tl.typecodes.UNKNOWN, @@ -5450,8 +5448,8 @@ local typename_to_typecode : {TypeName:integer} = { local skip_types: {TypeName: boolean} = { ["none"] = true, + ["tuple"] = true, ["literal_table_item"] = true, - ["unresolved"] = true, } local function sorted_keys(m: {A:B}):{A} @@ -5474,6 +5472,7 @@ function tl.new_type_reporter(): TypeReporter local self: TypeReporter = { next_num = 1, typeid_to_num = {}, + typename_to_num = {}, tr = { by_pos = {}, types = {}, @@ -5481,6 +5480,24 @@ function tl.new_type_reporter(): TypeReporter globals = {}, }, } + + local names = {} + for name, _ in pairs(simple_types) do + table.insert(names, name) + end + table.sort(names) + + for _, name in ipairs(names) do + local ti: TypeInfo = { + t = assert(typename_to_typecode[name]), + str = name, + } + local n = self.next_num + self.typename_to_num[name] = n + self.tr.types[n] = ti + self.next_num = self.next_num + 1 + end + return setmetatable(self, { __index = TypeReporter }) end @@ -5500,9 +5517,15 @@ function TypeReporter:store_function(ti: TypeInfo, rt: FunctionType) end function TypeReporter:get_typenum(t: Type): integer + -- try simple types first + local n = self.typename_to_num[t.typename] + if n then + return n + end + assert(t.typeid) -- try by typeid - local n = self.typeid_to_num[t.typeid] + n = self.typeid_to_num[t.typeid] if n then return n end @@ -5526,7 +5549,7 @@ function TypeReporter:get_typenum(t: Type): integer local ti: TypeInfo = { t = assert(typename_to_typecode[rt.typename]), str = show_type(t, true), - file = t.filename, + file = t.f, y = t.y, x = t.x, } @@ -5596,7 +5619,7 @@ local record TypeCollector end function TypeReporter:get_collector(filename: string): TypeCollector - local tc: TypeCollector = { + local collector: TypeCollector = { filename = filename, symbol_list = {}, } @@ -5604,10 +5627,10 @@ function TypeReporter:get_collector(filename: string): TypeCollector local ft: {integer:{integer:integer}} = {} self.tr.by_pos[filename] = ft - local symbol_list = tc.symbol_list + local symbol_list = collector.symbol_list local symbol_list_n = 0 - tc.store_type = function(y: integer, x: integer, typ: Type) + collector.store_type = function(y: integer, x: integer, typ: Type) if not typ or skip_types[typ.typename] then return end @@ -5621,12 +5644,12 @@ function TypeReporter:get_collector(filename: string): TypeCollector yt[x] = self:get_typenum(typ) end - tc.reserve_symbol_list_slot = function(node: Node) + collector.reserve_symbol_list_slot = function(node: Node) symbol_list_n = symbol_list_n + 1 node.symbol_list_slot = symbol_list_n end - tc.add_to_symbol_list = function(node: Node, name: string, t: Type) + collector.add_to_symbol_list = function(node: Node, name: string, t: Type) if not node then return end @@ -5640,12 +5663,12 @@ function TypeReporter:get_collector(filename: string): TypeCollector symbol_list[slot] = { y = node.y, x = node.x, name = name, typ = t } end - tc.begin_symbol_list_scope = function(node: Node) + collector.begin_symbol_list_scope = function(node: Node) symbol_list_n = symbol_list_n + 1 symbol_list[symbol_list_n] = { y = node.y, x = node.x, name = "@{" } end - tc.end_symbol_list_scope = function(node: Node) + collector.end_symbol_list_scope = function(node: Node) if symbol_list[symbol_list_n].name == "@{" then symbol_list[symbol_list_n] = nil symbol_list_n = symbol_list_n - 1 @@ -5655,14 +5678,14 @@ function TypeReporter:get_collector(filename: string): TypeCollector end end - return tc + return collector end -function TypeReporter:store_result(tc: TypeCollector, globals: {string:Variable}) +function TypeReporter:store_result(collector: TypeCollector, globals: {string:Variable}) local tr = self.tr - local filename = tc.filename - local symbol_list = tc.symbol_list + local filename = collector.filename + local symbol_list = collector.symbol_list tr.by_pos[filename][0] = nil @@ -5730,144 +5753,446 @@ function TypeReporter:get_report(): TypeReport return self.tr end --- backwards compatibility -function tl.get_types(result: Result): TypeReport, TypeReporter - return result.env.reporter:get_report(), result.env.reporter -end -------------------------------------------------------------------------------- --- Type check +-- Report types -------------------------------------------------------------------------------- -local NONE = a_type("none", {}) -local INVALID = a_type("invalid", {} as InvalidType) -local UNKNOWN = a_type("unknown", {}) -local CIRCULAR_REQUIRE = a_type("circular_require", {}) - -local FUNCTION = a_fn { args = va_args { ANY }, rets = va_args { ANY } } - ---local NOMINAL_FILE = a_type("nominal", { names = {"FILE"} } as NominalType) -local XPCALL_MSGH_FUNCTION = a_fn { args = { ANY }, rets = { } } - ---local USERDATA = ANY -- Placeholder for maybe having a userdata "primitive" type - -local numeric_binop = { +function tl.symbols_in_scope(tr: TypeReport, y: integer, x: integer): {string:integer} + local function find(symbols: {{integer, integer, string, integer}}, at_y: integer, at_x: integer): integer + local function le(a: {integer, integer}, b: {integer, integer}): boolean + return a[1] < b[1] + or (a[1] == b[1] and a[2] <= b[2]) + end + return binary_search(symbols, {at_y, at_x}, le) or 0 + end + + local ret: {string:integer} = {} + + local n = find(tr.symbols, y, x) + + local symbols = tr.symbols + while n >= 1 do + local s = symbols[n] + if s[3] == "@{" then + n = n - 1 + elseif s[3] == "@}" then + n = s[4] + else + ret[s[3]] = s[4] + n = n - 1 + end + end + + return ret +end + +-------------------------------------------------------------------------------- +-- Errors +-------------------------------------------------------------------------------- + +function Errors.new(filename: string): Errors + local self = { + errors = {}, + warnings = {}, + unknown_dots = {}, + filename = filename, + } + return setmetatable(self, { __index = Errors }) +end + +local function Err(msg: string, t1?: Type, t2?: Type, t3?: Type): Error + if t1 then + local s1, s2, s3: string, string, string + if t1 is InvalidType then + return nil + end + s1 = show_type(t1) + if t2 then + if t2 is InvalidType then + return nil + end + s2 = show_type(t2) + end + if t3 then + if t3 is InvalidType then + return nil + end + s3 = show_type(t3) + end + msg = msg:format(s1, s2, s3) + return { + msg = msg, + x = t1.x, + y = t1.y, + filename = t1.f, + } + end + + return { + msg = msg, + } +end + +local function insert_error(self: Errors, y: integer, x: integer, err: Error) + err.y = assert(y) + err.x = assert(x) + err.filename = self.filename + + if TL_DEBUG then + io.stderr:write("ERROR:" .. err.y .. ":" .. err.x .. ": " .. err.msg .. "\n") + end + + table.insert(self.errors, err) +end + +function Errors:add(w: Where, msg: string, ...:Type) + local e = Err(msg, ...) + if e then + insert_error(self, w.y, w.x, e) + end +end + +local context_name: {NodeKind: string} = { + ["local_declaration"] = "in local declaration", + ["global_declaration"] = "in global declaration", + ["assignment"] = "in assignment", + ["literal_table_item"] = "in table item", +} + +function Errors:get_context(ctx: Node|string, name?: string): string + if not ctx then + return "" + end + local ec = (ctx is Node) and ctx.expected_context + local cn = (ctx is string) and ctx or + (ctx is Node) and context_name[ec and ec.kind or ctx.kind] + return (cn and cn .. ": " or "") .. (ec and ec.name and ec.name .. ": " or "") .. (name and name .. ": " or "") +end + +function Errors:add_in_context(w: Where, ctx: Node, msg: string, ...:Type) + local prefix = self:get_context(ctx) + msg = prefix .. msg + + local e = Err(msg, ...) + if e then + insert_error(self, w.y, w.x, e) + end +end + + +function Errors:collect(errs: {Error}) + for _, e in ipairs(errs) do + insert_error(self, e.y, e.x, e) + end +end + +function Errors:add_warning(tag: WarningKind, w: Where, fmt: string, ...: any) + assert(w.y) + table.insert(self.warnings, { + y = w.y, + x = w.x, + msg = fmt:format(...), + filename = self.filename, + tag = tag, + }) +end + +function Errors:invalid_at(w: Where, msg: string, ...:Type): InvalidType + self:add(w, msg, ...) + return an_invalid(w) +end + +function Errors:add_unknown(node: Node, name: string) + self:add_warning("unknown", node, "unknown variable: %s", name) +end + +function Errors:redeclaration_warning(node: Node, old_var?: Variable) + if node.tk:sub(1, 1) == "_" then return end + + local var_kind = "variable" + local var_name = node.tk + if node.kind == "local_function" or node.kind == "record_function" then + var_kind = "function" + var_name = node.name.tk + end + + local short_error = "redeclaration of " .. var_kind .. " '%s'" + if old_var and old_var.declared_at then + self:add_warning("redeclaration", node, short_error .. " (originally declared at %d:%d)", var_name, old_var.declared_at.y, old_var.declared_at.x) + else + self:add_warning("redeclaration", node, short_error, var_name) + end +end + +function Errors:unused_warning(name: string, var: Variable) + local prefix = name:sub(1,1) + if var.declared_at + and var.is_narrowed ~= "narrow" + and prefix ~= "_" + and prefix ~= "@" + then + local t = var.t + self:add_warning( + "unused", + var.declared_at, + "unused %s %s: %s", + var.is_func_arg and "argument" + or t is FunctionType and "function" + or t is TypeDeclType and "type" + or t is TypeAliasType and "type" + or "variable", + name, + show_type(var.t) + ) + end +end + +function Errors:add_prefixing(w: Where, src: {Error}, prefix: string, dst?: {Error}) + if not src then + return + end + + for _, err in ipairs(src) do + err.msg = prefix .. err.msg + if w and ( + (err.filename ~= w.f) + or (not err.y) + or (w.y > err.y or (w.y == err.y and w.x > err.x)) + ) then + err.y = w.y + err.x = w.x + err.filename = w.f + end + + if dst then + table.insert(dst, err) + else + insert_error(self, err.y, err.x, err) + end + end +end + +local record Unused + y: integer + x: integer + name: string + var: Variable +end + +local function check_for_unused_vars(scope: Scope, is_global?: boolean): {Unused} + local vars = scope.vars + if not next(vars) then + return + end + local list: {Unused} + for name, var in pairs(vars) do + local t = var.t + if var.declared_at and not var.used then + if var.used_as_type then + var.declared_at.elide_type = true + else + if (t is TypeDeclType or t is TypeAliasType) and not is_global then + var.declared_at.elide_type = true + end + list = list or {} + table.insert(list, { y = var.declared_at.y, x = var.declared_at.x, name = name, var = var }) + end + elseif var.used and (t is TypeDeclType or t is TypeAliasType) and var.aliasing then + var.aliasing.used = true + var.aliasing.declared_at.elide_type = false + end + end + if list then + table.sort(list, function(a: Unused, b: Unused): boolean + return a.y < b.y or (a.y == b.y and a.x < b.x) + end) + end + return list +end + +function Errors:warn_unused_vars(scope: Scope, is_global?: boolean) + local unused = check_for_unused_vars(scope, is_global) + if unused then + for _, u in ipairs(unused) do + self:unused_warning(u.name, u.var) + end + end + + if scope.labels then + for name, node in pairs(scope.labels) do + if not node.used_label then + self:add_warning("unused", node, "unused label ::%s::", name) + end + end + end +end + +function Errors:add_unknown_dot(node: Node, name: string) + if not self.unknown_dots[name] then + self.unknown_dots[name] = true + self:add_unknown(node, name) + end +end + +function Errors:fail_unresolved_labels(scope: Scope) + if scope.pending_labels then + for name, nodes in pairs(scope.pending_labels) do + for _, node in ipairs(nodes) do + self:add(node, "no visible label '" .. name .. "' for goto") + end + end + end +end + +function Errors:fail_unresolved_nominals(scope: Scope, global_scope: Scope) + if global_scope and scope.pending_nominals then + for name, types in pairs(scope.pending_nominals) do + if not global_scope.pending_global_types[name] then + for _, typ in ipairs(types) do + assert(typ.x) + assert(typ.y) + self:add(typ, "unknown type %s", typ) + end + end + end + end +end + +local type CheckableKey = string | number | boolean + +function Errors:check_redeclared_key(w: Where, ctx: Node, seen_keys: {CheckableKey:Where}, key: CheckableKey) + if key ~= nil then + local s = seen_keys[key] + if s then + self:add_in_context(w, ctx, "redeclared key " .. tostring(key) .. " (previously declared at " .. self.filename .. ":" .. s.y .. ":" .. s.x .. ")") + else + seen_keys[key] = w + end + end +end + +-------------------------------------------------------------------------------- +-- Type check +-------------------------------------------------------------------------------- + +local numeric_binop = { ["number"] = { - ["number"] = NUMBER, - ["integer"] = NUMBER, + ["number"] = "number", + ["integer"] = "number", }, ["integer"] = { - ["integer"] = INTEGER, - ["number"] = NUMBER, + ["integer"] = "integer", + ["number"] = "number", }, } local float_binop = { ["number"] = { - ["number"] = NUMBER, - ["integer"] = NUMBER, + ["number"] = "number", + ["integer"] = "number", }, ["integer"] = { - ["integer"] = NUMBER, - ["number"] = NUMBER, + ["integer"] = "number", + ["number"] = "number", }, } local integer_binop = { ["number"] = { - ["number"] = INTEGER, - ["integer"] = INTEGER, + ["number"] = "integer", + ["integer"] = "integer", }, ["integer"] = { - ["integer"] = INTEGER, - ["number"] = INTEGER, + ["integer"] = "integer", + ["number"] = "integer", }, } local relational_binop = { ["number"] = { - ["integer"] = BOOLEAN, - ["number"] = BOOLEAN, + ["integer"] = "boolean", + ["number"] = "boolean", }, ["integer"] = { - ["number"] = BOOLEAN, - ["integer"] = BOOLEAN, + ["number"] = "boolean", + ["integer"] = "boolean", }, ["string"] = { - ["string"] = BOOLEAN, + ["string"] = "boolean", }, ["boolean"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, } local equality_binop = { ["number"] = { - ["number"] = BOOLEAN, - ["integer"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["number"] = "boolean", + ["integer"] = "boolean", + ["nil"] = "boolean", }, ["integer"] = { - ["number"] = BOOLEAN, - ["integer"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["number"] = "boolean", + ["integer"] = "boolean", + ["nil"] = "boolean", }, ["string"] = { - ["string"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["string"] = "boolean", + ["nil"] = "boolean", }, ["boolean"] = { - ["boolean"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["boolean"] = "boolean", + ["nil"] = "boolean", }, ["record"] = { - ["emptytable"] = BOOLEAN, - ["record"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["emptytable"] = "boolean", + ["record"] = "boolean", + ["nil"] = "boolean", }, ["array"] = { - ["emptytable"] = BOOLEAN, - ["array"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["emptytable"] = "boolean", + ["array"] = "boolean", + ["nil"] = "boolean", }, ["map"] = { - ["emptytable"] = BOOLEAN, - ["map"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["emptytable"] = "boolean", + ["map"] = "boolean", + ["nil"] = "boolean", }, ["thread"] = { - ["thread"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["thread"] = "boolean", + ["nil"] = "boolean", } } -local unop_types: {string:{string:Type}} = { +local unop_types: {string:{TypeName:TypeName}} = { ["#"] = { - ["string"] = INTEGER, - ["array"] = INTEGER, - ["tupletable"] = INTEGER, - ["map"] = INTEGER, - ["emptytable"] = INTEGER, + ["string"] = "integer", + ["array"] = "integer", + ["tupletable"] = "integer", + ["map"] = "integer", + ["emptytable"] = "integer", }, ["-"] = { - ["number"] = NUMBER, - ["integer"] = INTEGER, + ["number"] = "number", + ["integer"] = "integer", }, ["~"] = { - ["number"] = INTEGER, - ["integer"] = INTEGER, + ["number"] = "integer", + ["integer"] = "integer", }, ["not"] = { - ["string"] = BOOLEAN, - ["number"] = BOOLEAN, - ["integer"] = BOOLEAN, - ["boolean"] = BOOLEAN, - ["record"] = BOOLEAN, - ["array"] = BOOLEAN, - ["tupletable"] = BOOLEAN, - ["map"] = BOOLEAN, - ["emptytable"] = BOOLEAN, - ["thread"] = BOOLEAN, + ["string"] = "boolean", + ["number"] = "boolean", + ["integer"] = "boolean", + ["boolean"] = "boolean", + ["record"] = "boolean", + ["array"] = "boolean", + ["tupletable"] = "boolean", + ["map"] = "boolean", + ["emptytable"] = "boolean", + ["thread"] = "boolean", }, } @@ -5877,7 +6202,7 @@ local unop_to_metamethod: {string:string} = { ["~"] = "__bnot", } -local binop_types: {string:{TypeName:{TypeName:Type}}} = { +local binop_types: {string:{TypeName:{TypeName:TypeName}}} = { ["+"] = numeric_binop, ["-"] = numeric_binop, ["*"] = numeric_binop, @@ -5898,67 +6223,66 @@ local binop_types: {string:{TypeName:{TypeName:Type}}} = { [">"] = relational_binop, ["or"] = { ["boolean"] = { - ["boolean"] = BOOLEAN, - ["function"] = FUNCTION, -- HACK + ["boolean"] = "boolean", }, ["number"] = { - ["integer"] = NUMBER, - ["number"] = NUMBER, - ["boolean"] = BOOLEAN, + ["integer"] = "number", + ["number"] = "number", + ["boolean"] = "boolean", }, ["integer"] = { - ["integer"] = INTEGER, - ["number"] = NUMBER, - ["boolean"] = BOOLEAN, + ["integer"] = "integer", + ["number"] = "number", + ["boolean"] = "boolean", }, ["string"] = { - ["string"] = STRING, - ["boolean"] = BOOLEAN, - ["enum"] = STRING, + ["string"] = "string", + ["boolean"] = "boolean", + ["enum"] = "string", }, ["function"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, ["array"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, ["record"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, ["map"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, ["enum"] = { - ["string"] = STRING, + ["string"] = "string", }, ["thread"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", } }, [".."] = { ["string"] = { - ["string"] = STRING, - ["enum"] = STRING, - ["number"] = STRING, - ["integer"] = STRING, + ["string"] = "string", + ["enum"] = "string", + ["number"] = "string", + ["integer"] = "string", }, ["number"] = { - ["integer"] = STRING, - ["number"] = STRING, - ["string"] = STRING, - ["enum"] = STRING, + ["integer"] = "string", + ["number"] = "string", + ["string"] = "string", + ["enum"] = "string", }, ["integer"] = { - ["integer"] = STRING, - ["number"] = STRING, - ["string"] = STRING, - ["enum"] = STRING, + ["integer"] = "string", + ["number"] = "string", + ["string"] = "string", + ["enum"] = "string", }, ["enum"] = { - ["number"] = STRING, - ["integer"] = STRING, - ["string"] = STRING, - ["enum"] = STRING, + ["number"] = "string", + ["integer"] = "string", + ["string"] = "string", + ["enum"] = "string", } }, } @@ -6166,8 +6490,8 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str end end -local function inferred_msg(t: Type): string - return " (inferred at "..t.inferred_at.filename..":"..t.inferred_at.y..":"..t.inferred_at.x..")" +local function inferred_msg(t: Type, prefix?: string): string + return " (" .. (prefix or "") .. "inferred at "..t.inferred_at.f..":"..t.inferred_at.y..":"..t.inferred_at.x..")" end show_type = function(t: Type, short?: boolean, seen?: {Type:string}): string @@ -6219,33 +6543,34 @@ function tl.search_module(module_name: string, search_dtl: boolean): string, FIL return nil, nil, tried end -local function require_module(module_name: string, lax: boolean, env: Env): Type, boolean +local function require_module(w: Where, module_name: string, feat_lax: boolean, env: Env): Type, string local mod = env.modules[module_name] if mod then - return mod, true + return mod, env.module_filenames[module_name] end local found, fd = tl.search_module(module_name, true) - if found and (lax or found:match("tl$") as boolean) then + if found and (feat_lax or found:match("tl$") as boolean) then - env.modules[module_name] = a_typedecl(CIRCULAR_REQUIRE) + env.module_filenames[module_name] = found + env.modules[module_name] = a_typedecl(w, a_type(w, "circular_require", {})) local found_result, err: Result, string = tl.process(found, env, fd) assert(found_result, err) env.modules[module_name] = found_result.type - return found_result.type, true + return found_result.type, found elseif fd then fd:close() end - return INVALID, found ~= nil + return an_invalid(w), found end local compat_code_cache: {string:Node} = {} -local function add_compat_entries(program: Node, used_set: {string: boolean}, gen_compat: CompatMode) +local function add_compat_entries(program: Node, used_set: {string: boolean}, gen_compat: GenCompat) if gen_compat == "off" or not next(used_set) then return end @@ -6262,7 +6587,7 @@ local function add_compat_entries(program: Node, used_set: {string: boolean}, ge local code: Node = compat_code_cache[name] if not code then code = tl.parse(text, "@internal") - tl.type_check(code, { filename = "", lax = false, gen_compat = "off" }) + tl.type_check(code, "@internal", { feat_lax = "off", gen_compat = "off" }) compat_code_cache[name] = code end for _, c in ipairs(code) do @@ -6301,32 +6626,26 @@ local function add_compat_entries(program: Node, used_set: {string: boolean}, ge TL_DEBUG = tl_debug end -local function get_stdlib_compat(lax: boolean): {string:boolean} - if lax then - return { - ["utf8"] = true, - } - else - return { - ["io"] = true, - ["math"] = true, - ["string"] = true, - ["table"] = true, - ["utf8"] = true, - ["coroutine"] = true, - ["os"] = true, - ["package"] = true, - ["debug"] = true, - ["load"] = true, - ["loadfile"] = true, - ["assert"] = true, - ["pairs"] = true, - ["ipairs"] = true, - ["pcall"] = true, - ["xpcall"] = true, - ["rawlen"] = true, - } - end +local function get_stdlib_compat(): {string:boolean} + return { + ["io"] = true, + ["math"] = true, + ["string"] = true, + ["table"] = true, + ["utf8"] = true, + ["coroutine"] = true, + ["os"] = true, + ["package"] = true, + ["debug"] = true, + ["load"] = true, + ["loadfile"] = true, + ["assert"] = true, + ["pairs"] = true, + ["ipairs"] = true, + ["pcall"] = true, + ["xpcall"] = true, + ["rawlen"] = true, + } end local bit_operators: {string:string} = { @@ -6337,14 +6656,21 @@ local bit_operators: {string:string} = { ["<<"] = "lshift", } +local function node_at(w: Where, n: Node): Node + n.f = assert(w.f) + n.x = w.x + n.y = w.y + return n +end + local function convert_node_to_compat_call(node: Node, mod_name: string, fn_name: string, e1: Node, e2?: Node) node.op.op = "@funcall" node.op.arity = 2 node.op.prec = 100 - node.e1 = { y = node.y, x = node.x, kind = "op", op = an_operator(node, 2, ".") } - node.e1.e1 = { y = node.y, x = node.x, kind = "identifier", tk = mod_name } - node.e1.e2 = { y = node.y, x = node.x, kind = "identifier", tk = fn_name } - node.e2 = { y = node.y, x = node.x, kind = "expression_list" } + node.e1 = node_at(node, { kind = "op", op = an_operator(node, 2, ".") }) + node.e1.e1 = node_at(node, { kind = "identifier", tk = mod_name }) + node.e1.e2 = node_at(node, { kind = "identifier", tk = fn_name }) + node.e2 = node_at(node, { kind = "expression_list" }) node.e2[1] = e1 node.e2[2] = e2 end @@ -6353,10 +6679,10 @@ local function convert_node_to_compat_mt_call(node: Node, mt_name: string, which node.op.op = "@funcall" node.op.arity = 2 node.op.prec = 100 - node.e1 = { y = node.y, x = node.x, kind = "identifier", tk = "_tl_mt" } - node.e2 = { y = node.y, x = node.x, kind = "expression_list" } - node.e2[1] = { y = node.y, x = node.x, kind = "string", tk = "\"" .. mt_name .. "\"" } - node.e2[2] = { y = node.y, x = node.x, kind = "integer", tk = tostring(which_self) } + node.e1 = node_at(node, { kind = "identifier", tk = "_tl_mt" }) + node.e2 = node_at(node, { kind = "expression_list" }) + node.e2[1] = node_at(node, { kind = "string", tk = "\"" .. mt_name .. "\"" }) + node.e2[2] = node_at(node, { kind = "integer", tk = tostring(which_self) }) node.e2[3] = e1 node.e2[4] = e2 end @@ -6365,25 +6691,6 @@ local stdlib_globals: {string:Variable} = nil local globals_typeid = new_typeid() local fresh_typevar_ctr = 1 -local function set_feat(feat: tl.Feat, default: boolean): boolean - if feat then - return (feat == "on") - else - return default - end -end - -tl.new_env = function(opts: tl.EnvOptions): Env, string - local env, err = tl.init_env(opts.lax_mode, opts.gen_compat, opts.gen_target, opts.predefined_modules) - if not env then - return nil, err - end - - env.feat_arity = set_feat(opts.feat_arity, true) - - return env -end - local function assert_no_stdlib_errors(errors: {Error}, name: string) if #errors ~= 0 then local out = {} @@ -6394,46 +6701,31 @@ local function assert_no_stdlib_errors(errors: {Error}, name: string) end end -tl.init_env = function(lax?: boolean, gen_compat?: boolean | CompatMode, gen_target?: TargetMode, predefined?: {string}): Env, string - if gen_compat == true or gen_compat == nil then - gen_compat = "optional" - elseif gen_compat == false then - gen_compat = "off" - end - gen_compat = gen_compat as CompatMode - - if not gen_target then - if _VERSION == "Lua 5.1" or _VERSION == "Lua 5.2" then - gen_target = "5.1" - else - gen_target = "5.3" - end - end - - if gen_target == "5.4" and gen_compat ~= "off" then - return nil, "gen-compat must be explicitly 'off' when gen-target is '5.4'" - end +tl.new_env = function(opts?: EnvOptions): Env, string + opts = opts or {} local env: Env = { modules = {}, + module_filenames = {}, loaded = {}, loaded_order = {}, globals = {}, - gen_compat = gen_compat, - gen_target = gen_target, + defaults = opts.defaults or {}, } + if env.defaults.gen_target == "5.4" and env.defaults.gen_compat ~= "off" then + return nil, "gen-compat must be explicitly 'off' when gen-target is '5.4'" + end + + local w: Where = { f = "@stdlib", x = 1, y = 1 } + if not stdlib_globals then local tl_debug = TL_DEBUG TL_DEBUG = nil local program, syntax_errors = tl.parse(stdlib, "stdlib.d.tl") assert_no_stdlib_errors(syntax_errors, "syntax errors") - - local result = tl.type_check(program, { - filename = "@stdlib", - env = env - }) + local result = tl.type_check(program, "@stdlib", {}, env) assert_no_stdlib_errors(result.type_errors, "type errors") stdlib_globals = env.globals @@ -6442,21 +6734,20 @@ tl.init_env = function(lax?: boolean, gen_compat?: boolean | CompatMode, gen_tar -- special cases for compatibility local math_t = (stdlib_globals["math"].t as TypeDeclType).def as RecordType local table_t = (stdlib_globals["table"].t as TypeDeclType).def as RecordType - local integer_compat = a_type("integer", { needs_compat = true }) - math_t.fields["maxinteger"] = integer_compat - math_t.fields["mininteger"] = integer_compat + math_t.fields["maxinteger"].needs_compat = true + math_t.fields["mininteger"].needs_compat = true table_t.fields["unpack"].needs_compat = true -- only global scope and vararg functions accept `...`: -- `@is_va` is an internal sentinel value which is -- `any` if `...` is accepted in this scope or `nil` if it isn't. - stdlib_globals["..."] = { t = a_vararg { STRING } } - stdlib_globals["@is_va"] = { t = ANY } + stdlib_globals["..."] = { t = a_vararg(w, { a_type(w, "string", {}) }) } + stdlib_globals["@is_va"] = { t = a_type(w, "any", {}) } env.globals = {} end - local stdlib_compat = get_stdlib_compat(lax) + local stdlib_compat = get_stdlib_compat() for name, var in pairs(stdlib_globals) do env.globals[name] = var var.needs_compat = stdlib_compat[name] @@ -6467,52 +6758,53 @@ tl.init_env = function(lax?: boolean, gen_compat?: boolean | CompatMode, gen_tar end end - if predefined then - for _, name in ipairs(predefined) do - local module_type = require_module(name, lax, env) + if opts.predefined_modules then + for _, name in ipairs(opts.predefined_modules) do + local module_type = require_module(w, name, env.defaults.feat_lax == "on", env) - if module_type == INVALID then + if module_type is InvalidType then return nil, string.format("Error: could not predefine module '%s'", name) end end end - env.feat_arity = true - return env end -tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string - opts = opts or {} - local env = opts.env - if not env then - local err: string - env, err = tl.init_env(opts.lax, opts.gen_compat, opts.gen_target) - if err then - return nil, err - end - end +do + local type TypeRelations = {TypeName:{TypeName:CompareTypes}} + local type InvalidOrTupleType = InvalidType | TupleType - local lax = opts.lax - local feat_arity = env.feat_arity - local filename = opts.filename + local record TypeChecker + env: Env + st: {Scope} + + filename: string + errs: Errors + module_type: Type - local type Scope = {string:Variable} - local st: {Scope} = { env.globals } + subtype_relations: TypeRelations + eqtype_relations: TypeRelations + type_priorities: {TypeName:integer} - local all_needs_compat = {} + all_needs_compat: {string:boolean} + dependencies: {string:string} + collector: TypeCollector + + gen_compat: GenCompat + gen_target: GenTarget + feat_arity: boolean + feat_lax: boolean - local dependencies: {string:string} = {} - local warnings: {Error} = {} - local errors: {Error} = {} + same_type: function(TypeChecker, Type, Type): boolean, {Error} + is_a: function(TypeChecker, Type, Type): boolean, {Error} - local module_type: Type + type_check_funcall: function(TypeChecker, node: Node, a: Type, b: TupleType, argdelta?: integer): InvalidOrTupleType - local tc: TypeCollector - if env.report_types then - env.reporter = env.reporter or tl.new_type_reporter() - tc = env.reporter:get_collector(filename or "?") + expand_type: function(TypeChecker, w: Where, old: Type, new: Type): Type + + get_rets: function(TupleType): TupleType end local enum VarUse @@ -6522,10 +6814,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string "check_only" end - local function find_var(name: string, use?: VarUse): Variable, integer, Attribute - for i = #st, 1, -1 do - local scope = st[i] - local var = scope[name] + function TypeChecker:find_var(name: string, use?: VarUse): Variable, integer, Attribute + for i = #self.st, 1, -1 do + local scope = self.st[i] + local var = scope.vars[name] if var then if use == "lvalue" and var.is_narrowed then if var.narrowed_from then @@ -6534,7 +6826,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end else if i == 1 and var.needs_compat then - all_needs_compat[name] = true + self.all_needs_compat[name] = true end if use == "use_type" then var.used_as_type = true @@ -6547,10 +6839,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function simulate_g(): RecordType, Attribute + function TypeChecker:simulate_g(): RecordType, Attribute -- this is a static approximation of _G local globals: {string:Type} = {} - for k, v in pairs(st[1]) do + for k, v in pairs(self.st[1].vars) do if k:sub(1,1) ~= "@" then globals[k] = v.t end @@ -6563,101 +6855,61 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string }, nil end - local type ResolveType = function(Type): Type - local resolve_typevars: function (typ: Type, fn_var?: ResolveType, fn_arg?: ResolveType): boolean, Type, {Error} + local type ResolveType = function(S, Type): Type + local typevar_resolver: function(s: S, typ: Type, fn_var?: ResolveType, fn_arg?: ResolveType): boolean, Type, {Error} - local function fresh_typevar(t: TypeVarType): Type, Type, boolean - return a_type("typevar", { + local function fresh_typevar(_: nil, t: TypeVarType): Type, Type, boolean + return a_type(t, "typevar", { typevar = (t.typevar:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, constraint = t.constraint, } as TypeVarType) end - local function fresh_typearg(t: TypeArgType): Type - return a_type("typearg", { + local function fresh_typearg(_: nil, t: TypeArgType): Type + return a_type(t, "typearg", { typearg = (t.typearg:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, constraint = t.constraint, } as TypeArgType) end - local function ensure_fresh_typeargs(t: T): T + function TypeChecker:ensure_fresh_typeargs(t: T): T if not t is HasTypeArgs then return t end fresh_typevar_ctr = fresh_typevar_ctr + 1 local ok: boolean - ok, t = resolve_typevars(t, fresh_typevar, fresh_typearg) + ok, t = typevar_resolver(nil, t, fresh_typevar, fresh_typearg) assert(ok, "Internal Compiler Error: error creating fresh type variables") return t end - local function find_var_type(name: string, use?: VarUse): Type, Attribute, Type - local var = find_var(name, use) + function TypeChecker:find_var_type(name: string, use?: VarUse): Type, Attribute, Type + local var = self:find_var(name, use) if var then local t = var.t if t is UnresolvedTypeArgType then return nil, nil, t.constraint end - t = ensure_fresh_typeargs(t) + t = self:ensure_fresh_typeargs(t) return t, var.attribute end end - local function Err(where: Where, msg: string, ...: Type): Error - local n = select("#", ...) - if n > 0 then - local showt = {} - for i = 1, n do - local t = select(i, ...) - if t then - if t.typename == "invalid" then - return nil - end - showt[i] = show_type(t) - end - end - msg = msg:format(table.unpack(showt)) - end - local name = where.filename or filename - - if TL_DEBUG then - io.stderr:write("ERROR:" .. (where.y or -1) .. ":" .. (where.x or -1) .. ": " .. msg .. "\n") - end - - return { - y = where.y, - x = where.x, - msg = msg, - filename = name, - } - end - - local function error_at(w: Where, msg: string, ...:Type): boolean - assert(w.y) - - local e = Err(w, msg, ...) - if e then - table.insert(errors, e) - return true - else - return false - end - end - - local function ensure_not_abstract(where: Where, t: Type) + local function ensure_not_abstract(t: Type): boolean, string if t is FunctionType and t.macroexp then - error_at(where, "macroexps are abstract; consider using a concrete function") + return nil, "macroexps are abstract; consider using a concrete function" elseif t is TypeDeclType then local def = t.def if def is InterfaceType then - error_at(where, "interfaces are abstract; consider using a concrete record") + return nil, "interfaces are abstract; consider using a concrete record" end end + return true end - local function find_type(names: {string}, accept_typearg?: boolean): Type - local typ = find_var_type(names[1], "use_type") + function TypeChecker:find_type(names: {string}, accept_typearg?: boolean): Type + local typ = self:find_var_type(names[1], "use_type") if not typ then return nil end @@ -6679,7 +6931,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return nil end - typ = ensure_fresh_typeargs(typ) + typ = self:ensure_fresh_typeargs(typ) if typ is NominalType and typ.found then typ = typ.found end @@ -6691,19 +6943,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function union_type(t: Type): string, Type + local function type_for_union(t: Type): string, Type if t is TypeDeclType then - return union_type(t.def), t.def + return type_for_union(t.def), t.def elseif t is TypeAliasType then - return union_type(t.alias_to), t.alias_to + return type_for_union(t.alias_to), t.alias_to elseif t is TupleType then - return union_type(t.tuple[1]), t.tuple[1] + return type_for_union(t.tuple[1]), t.tuple[1] elseif t is NominalType then local typedecl = t.found if not typedecl then return "invalid" end - return union_type(typedecl) + return type_for_union(typedecl) elseif t is RecordLikeType then if t.is_userdata then return "userdata", t @@ -6727,7 +6979,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local n_string_enum = 0 local has_primitive_string_type = false for _, t in ipairs(typ.types) do - local ut, rt = union_type(t) + local ut, rt = type_for_union(t) if ut == "userdata" then -- must be tested before table_types assert(rt is RecordLikeType) if rt.meta_fields and rt.meta_fields["__is"] then @@ -6808,24 +7060,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["unknown"] = true, } - local function default_resolve_typevars_callback(t: TypeVarType): Type - local rt = find_var_type(t.typevar) - if not rt then - return nil - elseif rt is StringType then - -- tk is not propagated - return STRING - end - return rt - end - - resolve_typevars = function(typ: Type, fn_var?: ResolveType, fn_arg?: ResolveType): boolean, Type, {Error} + typevar_resolver = function(self: S, typ: Type, fn_var?: ResolveType, fn_arg?: ResolveType): boolean, Type, {Error} local errs: {Error} local seen: {Type:Type} = {} local resolved: {string:boolean} = {} - fn_var = fn_var or default_resolve_typevars_callback - local function resolve(t: T, all_same: boolean): T, boolean local same = true @@ -6840,7 +7079,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local orig_t = t if t is TypeVarType then - local rt = fn_var(t) + local rt = fn_var(self, t) if rt then resolved[t.typevar] = true if no_nested_types[rt.typename] or (rt is NominalType and not rt.typevals) then @@ -6856,7 +7095,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string seen[orig_t] = copy copy.typename = t.typename - copy.filename = t.filename + copy.f = t.f copy.x = t.x copy.y = t.y @@ -6867,7 +7106,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- inferred_len is not propagated elseif t is TypeArgType then if fn_arg then - copy = fn_arg(t) + copy = fn_arg(self, t) else assert(copy is TypeArgType) copy.typearg = t.typearg @@ -6960,7 +7199,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local _, err = is_valid_union(copy) if err then errs = errs or {} - table.insert(errs, Err(t, err, copy)) + table.insert(errs, Err(err, copy)) end elseif t is PolyType then assert(copy is PolyType) @@ -6970,6 +7209,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end elseif t is TupleTableType then assert(copy is TupleTableType) + copy.inferred_at = t.inferred_at copy.types = {} for i, tf in ipairs(t.types) do copy.types[i], same = resolve(tf, same) @@ -6989,7 +7229,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local copy, same = resolve(typ, true) if errs then - return false, INVALID, errs + return false, an_invalid(typ), errs end if (not same) and @@ -7008,153 +7248,81 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true, copy end - local function infer_emptytable(emptytable: EmptyTableType, fresh_t: Type) + local function resolve_typevar(tc: TypeChecker, t: TypeVarType): Type + local rt = tc:find_var_type(t.typevar) + if not rt then + return nil + elseif rt is StringType then + -- tk is not propagated + return a_type(rt, "string", {}) + end + return rt + end + + + + function TypeChecker:infer_emptytable(emptytable: EmptyTableType, fresh_t: Type) local is_global = (emptytable.declared_at and emptytable.declared_at.kind == "global_declaration") - local nst = is_global and 1 or #st + local nst = is_global and 1 or #self.st for i = nst, 1, -1 do - local scope = st[i] - if scope[emptytable.assigned_to] then - scope[emptytable.assigned_to] = { t = fresh_t } + local scope = self.st[i] + if scope.vars[emptytable.assigned_to] then + scope.vars[emptytable.assigned_to] = { t = fresh_t } end end end local function resolve_tuple(t: Type): Type - if t is TupleType then - t = t.tuple[1] + local rt = t + if rt is TupleType then + rt = rt.tuple[1] end - if t == nil then - return NIL + if rt == nil then + return a_type(t, "nil", {}) end - return t - end - - local function add_warning(tag: tl.WarningKind, where: Where, fmt: string, ...: any) - table.insert(warnings, { - y = where.y, - x = where.x, - msg = fmt:format(...), - filename = where.filename or filename, - tag = tag, - }) - end - - local function invalid_at(where: Where, msg: string, ...:Type): InvalidType - error_at(where, msg, ...) - return INVALID - end - - local function add_unknown(node: Node, name: string) - add_warning("unknown", node, "unknown variable: %s", name) + return rt end - local function redeclaration_warning(node: Node, old_var?: Variable) - if node.tk:sub(1, 1) == "_" then return end - - local var_kind = "variable" - local var_name = node.tk - if node.kind == "local_function" or node.kind == "record_function" then - var_kind = "function" - var_name = node.name.tk - end - - local short_error = "redeclaration of " .. var_kind .. " '%s'" - if old_var and old_var.declared_at then - add_warning("redeclaration", node, short_error .. " (originally declared at %d:%d)", var_name, old_var.declared_at.y, old_var.declared_at.x) - else - add_warning("redeclaration", node, short_error, var_name) - end - end - local function check_if_redeclaration(new_name: string, at: Node) - local old = find_var(new_name, "check_only") + function TypeChecker:check_if_redeclaration(new_name: string, at: Node) + local old = self:find_var(new_name, "check_only") if old then - redeclaration_warning(at, old) + self.errs:redeclaration_warning(at, old) end end - local function unused_warning(name: string, var: Variable) - local prefix = name:sub(1,1) - if var.declared_at - and var.is_narrowed ~= "narrow" - and prefix ~= "_" - and prefix ~= "@" - then - if name:sub(1, 2) == "::" then - add_warning("unused", var.declared_at, "unused label %s", name) - else - local t = var.t - add_warning( - "unused", - var.declared_at, - "unused %s %s: %s", - var.is_func_arg and "argument" - or t is FunctionType and "function" - or t is TypeDeclType and "type" - or t is TypeAliasType and "type" - or "variable", - name, - show_type(var.t) - ) - end - end - end - - local function add_errs_prefixing(where: Where, src: {Error}, dst: {Error}, prefix: string) - assert(where == nil or where.y ~= nil) - - if not src then - return - end - for _, err in ipairs(src) do - err.msg = prefix .. err.msg - - if where and ( - (err.filename ~= filename) - or (not err.y) - or (where.y > err.y or (where.y == err.y and where.x > err.x)) - ) then - err.y = where.y - err.x = where.x - err.filename = filename - end - - table.insert(dst, err) - end - end local function type_at(w: Where, t: T): T t.x = w.x t.y = w.y - t.filename = filename return t end - local function resolve_typevars_at(where: Where, t: T): T - assert(where) - local ok, ret, errs = resolve_typevars(t) + function TypeChecker:resolve_typevars_at(w: Where, t: T): T + assert(w) + local ok, ret, errs = typevar_resolver(self, t, resolve_typevar) if not ok then - assert(where.y) - add_errs_prefixing(where, errs, errors, "") + assert(w.y) + self.errs:add_prefixing(w, errs, "") end - if ret == t or t.typename == "typevar" then + if ret == t or t is TypeVarType then ret = shallow_copy_table(ret) end - return type_at(where, ret) + return type_at(w, ret) end - local function infer_at(where: Where, t: T): T - local ret = resolve_typevars_at(where, t) - if ret.typename == "invalid" then + function TypeChecker:infer_at(w: Where, t: T): T + local ret = self:resolve_typevars_at(w, t) + if ret is InvalidType then ret = t -- errors are produced by resolve_typevars_at end - if ret == t or t.typename == "typevar" then + if ret == t or t is TypeVarType then ret = shallow_copy_table(ret) end - ret.inferred_at = where - ret.inferred_at.filename = filename + assert(w.f) + ret.inferred_at = w return ret end @@ -7167,12 +7335,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t end - local get_unresolved: function(scope?: Scope): UnresolvedType - local find_unresolved: function(level?: integer): UnresolvedType - - local function add_to_scope(node: Node, name: string, t: Type, attribute: Attribute, narrow: Narrow, dont_check_redeclaration: boolean): Variable - local scope = st[#st] - local var = scope[name] + function TypeChecker:add_to_scope(node: Node, name: string, t: Type, attribute: Attribute, narrow: Narrow, dont_check_redeclaration: boolean): Variable + local scope = self.st[#self.st] + local var = scope.vars[name] if narrow then if var then if var.is_narrowed then @@ -7185,11 +7350,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string var.t = t else var = { t = t, attribute = attribute, is_narrowed = narrow, declared_at = node } - scope[name] = var + scope.vars[name] = var end - local unresolved = get_unresolved(scope) - unresolved.narrows[name] = true + scope.narrows = scope.narrows or {} + scope.narrows[name] = true return var end @@ -7200,46 +7365,39 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string and name ~= "..." and name:sub(1, 1) ~= "@" then - check_if_redeclaration(name, node) + self:check_if_redeclaration(name, node) end if var and not var.used then -- the old var is removed from the scope and won't be checked when it closes, -- so check it here - unused_warning(name, var) + self.errs:unused_warning(name, var) end var = { t = t, attribute = attribute, is_narrowed = nil, declared_at = node } - scope[name] = var + scope.vars[name] = var return var end - local function add_var(node: Node, name: string, t: Type, attribute?: Attribute, narrow?: Narrow, dont_check_redeclaration?: boolean): Variable - if lax and node and is_unknown(t) and (name ~= "self" and name ~= "...") and not narrow then - add_unknown(node, name) + function TypeChecker:add_var(node: Node, name: string, t: Type, attribute?: Attribute, narrow?: Narrow, dont_check_redeclaration?: boolean): Variable + if self.feat_lax and node and is_unknown(t) and (name ~= "self" and name ~= "...") and not narrow then + self.errs:add_unknown(node, name) end if not attribute then t = drop_constant_value(t) end - local var = add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration) - - if t is UnresolvedType or t.typename == "none" then - return var - end + local var = self:add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration) - if tc and node then - tc.add_to_symbol_list(node, name, t) + if self.collector and node then + self.collector.add_to_symbol_list(node, name, t) end return var end - local type CompareTypes = function(Type, Type): boolean, {Error} - - local same_type: function(t1: Type, t2: Type): boolean, {Error} - local is_a: function(Type, Type): boolean, {Error} + local type CompareTypes = function(TypeChecker, Type, Type): boolean, {Error} local enum ArgCheckMode "argument" @@ -7254,38 +7412,38 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string "invariant" end - local function arg_check(where: Where, all_errs: {Error}, a: Type, b: Type, v: VarianceMode, mode: ArgCheckMode, n?: integer): boolean + function TypeChecker:arg_check(w: Where, all_errs: {Error}, a: Type, b: Type, v: VarianceMode, mode: ArgCheckMode, n?: integer): boolean local ok, errs: boolean, {Error} if v == "covariant" then - ok, errs = is_a(a, b) + ok, errs = self:is_a(a, b) elseif v == "contravariant" then - ok, errs = is_a(b, a) + ok, errs = self:is_a(b, a) elseif v == "bivariant" then - ok, errs = is_a(a, b) + ok, errs = self:is_a(a, b) if ok then return true end - ok = is_a(b, a) + ok = self:is_a(b, a) if ok then return true end elseif v == "invariant" then - ok, errs = same_type(a, b) + ok, errs = self:same_type(a, b) end if not ok then - add_errs_prefixing(where, errs, all_errs, mode .. (n and " " .. n or "") .. ": ") + self.errs:add_prefixing(w, errs, mode .. (n and " " .. n or "") .. ": ", all_errs) return false end return true end - local function has_all_types_of(t1s: {Type}, t2s: {Type}): boolean + function TypeChecker:has_all_types_of(t1s: {Type}, t2s: {Type}): boolean for _, t1 in ipairs(t1s) do local found = false for _, t2 in ipairs(t2s) do - if same_type(t2, t1) then + if self:same_type(t2, t1) then found = true break end @@ -7317,8 +7475,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function close_types(vars: {string:Variable}) - for _, var in pairs(vars) do + local function close_types(scope: Scope) + for _, var in pairs(scope.vars) do local t = var.t if t is TypeDeclType then t.closed = true @@ -7330,161 +7488,96 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local record Unused - y: integer - x: integer - name: string - var: Variable - end - - local function check_for_unused_vars(vars: {string:Variable}, is_global?: boolean) - if not next(vars) then - return - end - local list: {Unused} = {} - for name, var in pairs(vars) do - local t = var.t - if var.declared_at and not var.used then - if var.used_as_type then - var.declared_at.elide_type = true - else - if (t is TypeDeclType or t is TypeAliasType) and not is_global then - var.declared_at.elide_type = true - end - table.insert(list, { y = var.declared_at.y, x = var.declared_at.x, name = name, var = var }) - end - elseif var.used and (t is TypeDeclType or t is TypeAliasType) and var.aliasing then - var.aliasing.used = true - var.aliasing.declared_at.elide_type = false - end - end - if list[1] then - table.sort(list, function(a: Unused, b: Unused): boolean - return a.y < b.y or (a.y == b.y and a.x < b.x) - end) - for _, u in ipairs(list) do - unused_warning(u.name, u.var) - end - end - end - - get_unresolved = function(scope?: Scope): UnresolvedType - local unresolved: UnresolvedType - if scope then - local unr = scope["@unresolved"] - unresolved = unr and unr.t as UnresolvedType - else - unresolved = find_var_type("@unresolved") as UnresolvedType - end - if not unresolved then - unresolved = a_type("unresolved", { - labels = {}, - nominals = {}, - global_types = {}, - narrows = {}, - } as UnresolvedType) - add_var(nil, "@unresolved", unresolved) - end - return unresolved - end - - find_unresolved = function(level?: integer): UnresolvedType - local u = st[level or #st]["@unresolved"] - if u then - return u.t as UnresolvedType - end - end - - local function begin_scope(node?: Node) - table.insert(st, {}) + function TypeChecker:begin_scope(node?: Node) + table.insert(self.st, { vars = {} }) - if tc and node then - tc.begin_symbol_list_scope(node) + if self.collector and node then + self.collector.begin_symbol_list_scope(node) end end - local function end_scope(node?: Node) + function TypeChecker:end_scope(node?: Node) + local st = self.st local scope = st[#st] - local unresolved = scope["@unresolved"] - if unresolved then - local unrt = unresolved.t as UnresolvedType - local next_scope = st[#st - 1] - local upper = next_scope["@unresolved"] - if upper then - local uppert = upper.t as UnresolvedType - for name, nodes in pairs(unrt.labels) do + local next_scope = st[#st - 1] + + if next_scope then + if scope.pending_labels then + next_scope.pending_labels = next_scope.pending_labels or {} + for name, nodes in pairs(scope.pending_labels) do for _, n in ipairs(nodes) do - uppert.labels[name] = uppert.labels[name] or {} - table.insert(uppert.labels[name], n) + next_scope.pending_labels[name] = next_scope.pending_labels[name] or {} + table.insert(next_scope.pending_labels[name], n) end end - for name, types in pairs(unrt.nominals) do + scope.pending_labels = nil + end + if scope.pending_nominals then + next_scope.pending_nominals = next_scope.pending_nominals or {} + for name, types in pairs(scope.pending_nominals) do for _, typ in ipairs(types) do - uppert.nominals[name] = uppert.nominals[name] or {} - table.insert(uppert.nominals[name], typ) + next_scope.pending_nominals[name] = next_scope.pending_nominals[name] or {} + table.insert(next_scope.pending_nominals[name], typ) end end - for name, _ in pairs(unrt.global_types) do - uppert.global_types[name] = true - end - else - next_scope["@unresolved"] = unresolved - unrt.narrows = {} + scope.pending_nominals = nil end end + close_types(scope) - check_for_unused_vars(scope) + self.errs:warn_unused_vars(scope) + table.remove(st) - if tc and node then - tc.end_symbol_list_scope(node) + if self.collector and node then + self.collector.end_symbol_list_scope(node) end end - local end_scope_and_none_type = function(node: Node, _children: {Type}): Type - end_scope(node) + -- This type must never be used for any values + local NONE = a_type({ f = "@none", x = -1, y = -1 }, "none", {}) + + local function end_scope_and_none_type(self: TypeChecker, node: Node, _children: {Type}): Type + self:end_scope(node) return NONE end local type InvalidOrTypeDeclType = InvalidType | TypeDeclType - local resolve_nominal: function(t: NominalType): Type - local resolve_typealias: function(t: TypeAliasType): InvalidOrTypeDeclType do - local function match_typevals(t: NominalType, def: RecordLikeType | FunctionType): Type + local function match_typevals(self: TypeChecker, t: NominalType, def: RecordLikeType | FunctionType): Type if t.typevals and def.typeargs then if #t.typevals ~= #def.typeargs then - error_at(t, "mismatch in number of type arguments") + self.errs:add(t, "mismatch in number of type arguments") return nil end - begin_scope() + self:begin_scope() for i, tt in ipairs(t.typevals) do - add_var(nil, def.typeargs[i].typearg, tt) + self:add_var(nil, def.typeargs[i].typearg, tt) end - local ret = resolve_typevars_at(t, def) - end_scope() + local ret = self:resolve_typevars_at(t, def) + self:end_scope() return ret elseif t.typevals then - error_at(t, "spurious type arguments") + self.errs:add(t, "spurious type arguments") return nil elseif def.typeargs then - error_at(t, "missing type arguments in %s", def) + self.errs:add(t, "missing type arguments in %s", def) return nil else return def end end - local function find_nominal_type_decl(t: NominalType): Type, TypeDeclType + local function find_nominal_type_decl(self: TypeChecker, t: NominalType): Type, TypeDeclType if t.resolved then return t.resolved end - local found = t.found or find_type(t.names) + local found = t.found or self:find_type(t.names) if not found then - error_at(t, "unknown type %s", t) - return INVALID + return self.errs:invalid_at(t, "unknown type %s", t) end if found is TypeAliasType then @@ -7492,8 +7585,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if not found is TypeDeclType then - error_at(t, table.concat(t.names, ".") .. " is not a type") - return INVALID + return self.errs:invalid_at(t, table.concat(t.names, ".") .. " is not a type") end local def = found.def @@ -7508,44 +7600,35 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return nil, found end - local function resolve_decl_into_nominal(t: NominalType, found: TypeDeclType): Type + local function resolve_decl_into_nominal(self: TypeChecker, t: NominalType, found: TypeDeclType): Type local def = found.def local resolved: Type if def is RecordType or def is FunctionType then - resolved = match_typevals(t, def) + resolved = match_typevals(self, t, def) if not resolved then - error_at(t, table.concat(t.names, ".") .. " cannot be resolved in scope") - return INVALID + return self.errs:invalid_at(t, table.concat(t.names, ".") .. " cannot be resolved in scope") end else resolved = def end - if not t.filename then - t.filename = resolved.filename - if t.x == nil and t.y == nil then - t.x = resolved.x - t.y = resolved.y - end - end - t.resolved = resolved return resolved end - resolve_nominal = function(t: NominalType): Type - local immediate, found = find_nominal_type_decl(t) + function TypeChecker:resolve_nominal(t: NominalType): Type + local immediate, found = find_nominal_type_decl(self, t) if immediate then return immediate end - return resolve_decl_into_nominal(t, found) + return resolve_decl_into_nominal(self, t, found) end - resolve_typealias = function(typealias: TypeAliasType): InvalidOrTypeDeclType + function TypeChecker:resolve_typealias(typealias: TypeAliasType): InvalidOrTypeDeclType local t = typealias.alias_to - local immediate, found = find_nominal_type_decl(t) + local immediate, found = find_nominal_type_decl(self, t) if immediate then return immediate end @@ -7554,90 +7637,92 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return found end - local resolved = resolve_decl_into_nominal(t, found) + local resolved = resolve_decl_into_nominal(self, t, found) - local typedecl = a_type("typedecl", { def = resolved } as TypeDeclType) + local typedecl = a_type(typealias, "typedecl", { def = resolved } as TypeDeclType) t.resolved = typedecl return typedecl end end - local function are_same_unresolved_global_type(t1: NominalType, t2: NominalType): boolean - if t1.names[1] == t2.names[1] then - local unresolved = get_unresolved() - if unresolved.global_types[t1.names[1]] then - return true + do + local function are_same_unresolved_global_type(self: TypeChecker, t1: NominalType, t2: NominalType): boolean + if t1.names[1] == t2.names[1] then + local global_scope = self.st[1] + if global_scope.pending_global_types[t1.names[1]] then + return true + end end + return false end - return false - end - local function fail_nominals(t1: NominalType, t2: NominalType): boolean, {Error} - local t1name = show_type(t1) - local t2name = show_type(t2) - if t1name == t2name then - local t1r = resolve_nominal(t1) - if t1r.filename then - t1name = t1name .. " (defined in " .. t1r.filename .. ":" .. t1r.y .. ")" - end - local t2r = resolve_nominal(t2) - if t2r.filename then - t2name = t2name .. " (defined in " .. t2r.filename .. ":" .. t2r.y .. ")" + local function fail_nominals(self: TypeChecker, t1: NominalType, t2: NominalType): boolean, {Error} + local t1name = show_type(t1) + local t2name = show_type(t2) + if t1name == t2name then + self:resolve_nominal(t1) + if t1.found then + t1name = t1name .. " (defined in " .. t1.found.f .. ":" .. t1.found.y .. ")" + end + self:resolve_nominal(t2) + if t2.found then + t2name = t2name .. " (defined in " .. t2.found.f .. ":" .. t2.found.y .. ")" + end end + return false, { Err(t1name .. " is not a " .. t2name) } end - return false, { Err(t1, t1name .. " is not a " .. t2name) } - end - local function are_same_nominals(t1: NominalType, t2: NominalType): boolean, {Error} - local same_names: boolean - if t1.found and t2.found then - same_names = t1.found.typeid == t2.found.typeid - else - local ft1 = t1.found or find_type(t1.names) - local ft2 = t2.found or find_type(t2.names) - if ft1 and ft2 then - same_names = ft1.typeid == ft2.typeid + function TypeChecker:are_same_nominals(t1: NominalType, t2: NominalType): boolean, {Error} + local same_names: boolean + if t1.found and t2.found then + same_names = t1.found.typeid == t2.found.typeid else - if are_same_unresolved_global_type(t1, t2) then - return true - end + local ft1 = t1.found or self:find_type(t1.names) + local ft2 = t2.found or self:find_type(t2.names) + if ft1 and ft2 then + same_names = ft1.typeid == ft2.typeid + else + if are_same_unresolved_global_type(self, t1, t2) then + return true + end - if not ft1 then - error_at(t1, "unknown type %s", t1) - end - if not ft2 then - error_at(t2, "unknown type %s", t2) + if not ft1 then + self.errs:add(t1, "unknown type %s", t1) + end + if not ft2 then + self.errs:add(t2, "unknown type %s", t2) + end + return false, {} -- errors were already produced end - return false, {} -- errors were already produced end - end - if not same_names then - return fail_nominals(t1, t2) - elseif t1.typevals == nil and t2.typevals == nil then - return true - elseif t1.typevals and t2.typevals and #t1.typevals == #t2.typevals then - local errs = {} - for i = 1, #t1.typevals do - local _, typeval_errs = same_type(t1.typevals[i], t2.typevals[i]) - add_errs_prefixing(t1, typeval_errs, errs, "type parameter <" .. show_type(t2.typevals[i]) .. ">: ") + if not same_names then + return fail_nominals(self, t1, t2) + elseif t1.typevals == nil and t2.typevals == nil then + return true + elseif t1.typevals and t2.typevals and #t1.typevals == #t2.typevals then + local errs = {} + for i = 1, #t1.typevals do + local _, typeval_errs = self:same_type(t1.typevals[i], t2.typevals[i]) + self.errs:add_prefixing(nil, typeval_errs, "type parameter <" .. show_type(t2.typevals[i]) .. ">: ", errs) + end + return any_errors(errs) end - return any_errors(errs) + return true end - return true end local is_lua_table_type: function(t: Type): boolean - local function to_structural(t: Type): Type + function TypeChecker:to_structural(t: Type): Type assert(not t is TupleType) if t is NominalType then - return resolve_nominal(t) + return self:resolve_nominal(t) end return t end - local function unite(types: {Type}, flatten_constants?: boolean): Type + local function unite(w: Where, types: {Type}, flatten_constants?: boolean): Type if #types == 1 then return types[1] end @@ -7648,7 +7733,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- Make things like number | number resolve to number local types_seen: {(integer|string):boolean} = {} -- but never add nil as a type in the union - types_seen[NIL.typeid] = true types_seen["nil"] = true local i = 1 @@ -7684,14 +7768,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - if types_seen[INVALID.typeid] then - return INVALID + if types_seen["invalid"] then + return a_type(w, "invalid", {}) end if #ts == 1 then return ts[1] else - return a_union(ts) + return a_union(w, ts) end end @@ -7711,21 +7795,20 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local expand_type: function(where: Where, old: Type, new: Type): Type - local function arraytype_from_tuple(where: Where, tupletype: TupleTableType): ArrayType, {Error} + function TypeChecker:arraytype_from_tuple(w: Where, tupletype: TupleTableType): ArrayType, {Error} -- first just try a basic union - local element_type = unite(tupletype.types, true) + local element_type = unite(w, tupletype.types, true) local valid = (not element_type is UnionType) and true or is_valid_union(element_type) if valid then - return an_array(element_type) + return an_array(w, element_type) end -- failing a basic union, expand the types - local arr_type = an_array(tupletype.types[1]) + local arr_type = an_array(w, tupletype.types[1]) for i = 2, #tupletype.types do - local expanded = expand_type(where, arr_type, an_array(tupletype.types[i])) + local expanded = self:expand_type(w, arr_type, an_array(w, tupletype.types[i])) if not expanded is ArrayType then - return nil, { Err(tupletype, "unable to convert tuple %s to array", tupletype) } + return nil, { Err("unable to convert tuple %s to array", tupletype) } end arr_type = expanded end @@ -7736,33 +7819,33 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t is NominalType and t.names[1] == "@self" end - local function compare_true(_: Type, _: Type): boolean, {Error} + local function compare_true(_: TypeChecker, _: Type, _: Type): boolean, {Error} return true end - local function subtype_nominal(a: Type, b: Type): boolean, {Error} + function TypeChecker:subtype_nominal(a: Type, b: Type): boolean, {Error} if is_self(a) and is_self(b) then return true end - local ra = a is NominalType and resolve_nominal(a) or a - local rb = b is NominalType and resolve_nominal(b) or b - local ok, errs = is_a(ra, rb) + local ra = a is NominalType and self:resolve_nominal(a) or a + local rb = b is NominalType and self:resolve_nominal(b) or b + local ok, errs = self:is_a(ra, rb) if errs and #errs == 1 and errs[1].msg:match("^got ") then return false -- translate to got-expected error with unresolved types end return ok, errs end - local function subtype_array(a: ArrayLikeType, b: ArrayLikeType): boolean, {Error} - if (not a.elements) or (not is_a(a.elements, b.elements)) then + function TypeChecker:subtype_array(a: ArrayLikeType, b: ArrayLikeType): boolean, {Error} + if (not a.elements) or (not self:is_a(a.elements, b.elements)) then return false end if a.consttypes and #a.consttypes > 1 then -- constant array, check elements (useful for array of enums) for _, e in ipairs(a.consttypes) do - if not is_a(e, b.elements) then - return false, { Err(a, "%s is not a member of %s", e, b.elements) } + if not self:is_a(e, b.elements) then + return false, { Err("%s is not a member of %s", e, b.elements) } end end end @@ -7784,16 +7867,16 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return nil end - local function subtype_record(a: RecordLikeType, b: RecordLikeType): boolean, {Error} + function TypeChecker:subtype_record(a: RecordLikeType, b: RecordLikeType): boolean, {Error} -- assert(b.typename == "record") if a.elements and b.elements then - if not is_a(a.elements, b.elements) then - return false, { Err(a, "array parts have incompatible element types") } + if not self:is_a(a.elements, b.elements) then + return false, { Err("array parts have incompatible element types") } end end if a.is_userdata ~= b.is_userdata then - return false, { Err(a, a.is_userdata and "userdata is not a record" + return false, { Err(a.is_userdata and "userdata is not a record" or "record is not a userdata") } end @@ -7802,9 +7885,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local ak = a.fields[k] local bk = b.fields[k] if bk then - local ok, fielderrs = is_a(ak, bk) + local ok, fielderrs = self:is_a(ak, bk) if not ok then - add_errs_prefixing(nil, fielderrs, errs, "record field doesn't match: " .. k .. ": ") + self.errs:add_prefixing(nil, fielderrs, "record field doesn't match: " .. k .. ": ", errs) end end end @@ -7818,32 +7901,32 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true end - local eqtype_record = function(a: RecordType, b: RecordType): boolean, {Error} + function TypeChecker:eqtype_record(a: RecordType, b: RecordType): boolean, {Error} -- checking array interface if (a.elements ~= nil) ~= (b.elements ~= nil) then - return false, { Err(a, "types do not have the same array interface") } + return false, { Err("types do not have the same array interface") } end if a.elements then - local ok, errs = same_type(a.elements, b.elements) + local ok, errs = self:same_type(a.elements, b.elements) if not ok then return ok, errs end end - local ok, errs = subtype_record(a, b) + local ok, errs = self:subtype_record(a, b) if not ok then return ok, errs end - ok, errs = subtype_record(b, a) + ok, errs = self:subtype_record(b, a) if not ok then return ok, errs end return true end - local function compare_map(ak: Type, bk: Type, av: Type, bv: Type, no_hack?: boolean): boolean, {Error} - local ok1, errs_k = same_type(ak, bk) - local ok2, errs_v = same_type(av, bv) + local function compare_map(self: TypeChecker, ak: Type, bk: Type, av: Type, bv: Type, no_hack?: boolean): boolean, {Error} + local ok1, errs_k = self:same_type(ak, bk) + local ok2, errs_v = self:same_type(av, bv) -- FIXME hack for {any:any} if bk.typename == "any" and not no_hack then @@ -7873,25 +7956,25 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return false, errs_k or errs_v end - local function compare_or_infer_typevar(typevar: string, a: Type, b: Type, cmp: CompareTypes): boolean, {Error} + function TypeChecker:compare_or_infer_typevar(typevar: string, a: Type, b: Type, cmp: CompareTypes): boolean, {Error} -- assert((a == nil and b ~= nil) or (a ~= nil and b == nil)) -- does the typevar currently match to a type? - local vt, _, constraint = find_var_type(typevar) + local vt, _, constraint = self:find_var_type(typevar) if vt then -- If so, compare it to the other type - return cmp(a or vt, b or vt) + return cmp(self, a or vt, b or vt) else -- otherwise, infer it to the other type local other = a or b -- but check interface constraint first if present if constraint then - if not is_a(other, constraint) then - return false, { Err(other, "given type %s does not satisfy %s constraint in type variable " .. display_typevar(typevar), other, constraint) } + if not self:is_a(other, constraint) then + return false, { Err("given type %s does not satisfy %s constraint in type variable " .. display_typevar(typevar), other, constraint) } end - if same_type(other, constraint) then + if self:same_type(other, constraint) then -- do not infer to some type as constraint right away, -- to give a chance to more specific inferences -- in other arguments/returns @@ -7899,22 +7982,22 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local ok, r, errs = resolve_typevars(other) + local ok, r, errs = typevar_resolver(self, other, resolve_typevar) if not ok then return false, errs end if r is TypeVarType and r.typevar == typevar then return true end - add_var(nil, typevar, r) + self:add_var(nil, typevar, r) return true end end -- ∃ x ∈ xs. t <: x - local function exists_supertype_in(t: Type, xs: AggregateType): Type + function TypeChecker:exists_supertype_in(t: Type, xs: AggregateType): Type for _, x in ipairs(xs.types) do - if is_a(t, x) then + if self:is_a(t, x) then return x end end @@ -7925,143 +8008,139 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["array"] = compare_true, ["map"] = compare_true, ["tupletable"] = compare_true, - ["interface"] = function(_a: Type, b: InterfaceType): boolean, {Error} + ["interface"] = function(_self: TypeChecker, _a: Type, b: InterfaceType): boolean, {Error} return not b.is_userdata end, - ["record"] = function(_a: Type, b: RecordType): boolean, {Error} + ["record"] = function(_self: TypeChecker, _a: Type, b: RecordType): boolean, {Error} return not b.is_userdata end, } - local type TypeRelations = {TypeName:{TypeName:CompareTypes}} - - local eqtype_relations: TypeRelations - eqtype_relations = { + TypeChecker.eqtype_relations = { ["typevar"] = { - ["typevar"] = function(a: TypeVarType, b: TypeVarType): boolean, {Error} + ["typevar"] = function(self: TypeChecker, a: TypeVarType, b: TypeVarType): boolean, {Error} if a.typevar == b.typevar then return true end - return compare_or_infer_typevar(b.typevar, a, nil, same_type) + return self:compare_or_infer_typevar(b.typevar, a, nil, self.same_type) end, - ["*"] = function(a: TypeVarType, b: Type): boolean, {Error} - return compare_or_infer_typevar(a.typevar, nil, b, same_type) + ["*"] = function(self: TypeChecker, a: TypeVarType, b: Type): boolean, {Error} + return self:compare_or_infer_typevar(a.typevar, nil, b, self.same_type) end, }, ["emptytable"] = emptytable_relations, ["tupletable"] = { - ["tupletable"] = function(a: TupleTableType, b: TupleTableType): boolean, {Error} + ["tupletable"] = function(self: TypeChecker, a: TupleTableType, b: TupleTableType): boolean, {Error} for i = 1, math.min(#a.types, #b.types) do - if not same_type(a.types[i], b.types[i]) then - return false, { Err(a, "in tuple entry " .. tostring(i) .. ": got %s, expected %s", a.types[i], b.types[i]) } + if not self:same_type(a.types[i], b.types[i]) then + return false, { Err("in tuple entry " .. tostring(i) .. ": got %s, expected %s", a.types[i], b.types[i]) } end end if #a.types ~= #b.types then - return false, { Err(a, "tuples have different size", a, b) } + return false, { Err("tuples have different size", a, b) } end return true end, }, ["array"] = { - ["array"] = function(a: ArrayType, b: ArrayType): boolean, {Error} - return same_type(a.elements, b.elements) + ["array"] = function(self: TypeChecker, a: ArrayType, b: ArrayType): boolean, {Error} + return self:same_type(a.elements, b.elements) end, }, ["map"] = { - ["map"] = function(a: MapType, b: MapType): boolean, {Error} - return compare_map(a.keys, b.keys, a.values, b.values, true) + ["map"] = function(self: TypeChecker, a: MapType, b: MapType): boolean, {Error} + return compare_map(self, a.keys, b.keys, a.values, b.values, true) end, }, ["union"] = { - ["union"] = function(a: UnionType, b: UnionType): boolean, {Error} - return (has_all_types_of(a.types, b.types) - and has_all_types_of(b.types, a.types)) + ["union"] = function(self: TypeChecker, a: UnionType, b: UnionType): boolean, {Error} + return (self:has_all_types_of(a.types, b.types) + and self:has_all_types_of(b.types, a.types)) end, }, ["nominal"] = { - ["nominal"] = are_same_nominals, + ["nominal"] = TypeChecker.are_same_nominals, }, ["record"] = { - ["record"] = eqtype_record, + ["record"] = TypeChecker.eqtype_record, }, ["interface"] = { - ["interface"] = function(a: InterfaceType, b: InterfaceType): boolean, {Error} + ["interface"] = function(_self:TypeChecker, a: InterfaceType, b: InterfaceType): boolean, {Error} return a.typeid == b.typeid end, }, ["function"] = { - ["function"] = function(a: FunctionType, b: FunctionType): boolean, {Error} + ["function"] = function(self:TypeChecker, a: FunctionType, b: FunctionType): boolean, {Error} local argdelta = a.is_method and 1 or 0 local naargs, nbargs = #a.args.tuple, #b.args.tuple if naargs ~= nbargs then if (not not a.is_method) ~= (not not b.is_method) then - return false, { Err(a, "different number of input arguments: method and non-method are not the same type") } + return false, { Err("different number of input arguments: method and non-method are not the same type") } end - return false, { Err(a, "different number of input arguments: got " .. naargs - argdelta .. ", expected " .. nbargs - argdelta) } + return false, { Err("different number of input arguments: got " .. naargs - argdelta .. ", expected " .. nbargs - argdelta) } end local narets, nbrets = #a.rets.tuple, #b.rets.tuple if narets ~= nbrets then - return false, { Err(a, "different number of return values: got " .. narets .. ", expected " .. nbrets) } + return false, { Err("different number of return values: got " .. narets .. ", expected " .. nbrets) } end local errs = {} for i = 1, naargs do - arg_check(a, errs, a.args.tuple[i], b.args.tuple[i], "invariant", "argument", i - argdelta) + self:arg_check(a, errs, a.args.tuple[i], b.args.tuple[i], "invariant", "argument", i - argdelta) end for i = 1, narets do - arg_check(a, errs, a.rets.tuple[i], b.rets.tuple[i], "invariant", "return", i) + self:arg_check(a, errs, a.rets.tuple[i], b.rets.tuple[i], "invariant", "return", i) end return any_errors(errs) end, }, ["*"] = { - ["typevar"] = function(a: Type, b: TypeVarType): boolean, {Error} - return compare_or_infer_typevar(b.typevar, a, nil, same_type) + ["typevar"] = function(self: TypeChecker, a: Type, b: TypeVarType): boolean, {Error} + return self:compare_or_infer_typevar(b.typevar, a, nil, self.same_type) end, }, } - local subtype_relations: TypeRelations - subtype_relations = { + TypeChecker.subtype_relations = { ["tuple"] = { - ["tuple"] = function(a: TupleType, b: TupleType): boolean, {Error} -- ∀ a[i] ∈ a, b[i] ∈ b. a[i] <: b[i] + ["tuple"] = function(self: TypeChecker, a: TupleType, b: TupleType): boolean, {Error} -- ∀ a[i] ∈ a, b[i] ∈ b. a[i] <: b[i] local at, bt = a.tuple, b.tuple -- ────────────────────────────────── if #at ~= #bt then -- a tuple <: b tuple return false end for i = 1, #at do - if not is_a(at[i], bt[i]) then + if not self:is_a(at[i], bt[i]) then return false end end return true end, - ["*"] = function(a: Type, b: Type): boolean, {Error} - return is_a(resolve_tuple(a), b) + ["*"] = function(self: TypeChecker, a: Type, b: Type): boolean, {Error} + return self:is_a(resolve_tuple(a), b) end, }, ["typevar"] = { - ["typevar"] = function(a: TypeVarType, b: TypeVarType): boolean, {Error} + ["typevar"] = function(self: TypeChecker, a: TypeVarType, b: TypeVarType): boolean, {Error} if a.typevar == b.typevar then return true end - return compare_or_infer_typevar(b.typevar, a, nil, is_a) + return self:compare_or_infer_typevar(b.typevar, a, nil, self.is_a) end, - ["*"] = function(a: TypeVarType, b: Type): boolean, {Error} - return compare_or_infer_typevar(a.typevar, nil, b, is_a) + ["*"] = function(self: TypeChecker, a: TypeVarType, b: Type): boolean, {Error} + return self:compare_or_infer_typevar(a.typevar, nil, b, self.is_a) end, }, ["nil"] = { ["*"] = compare_true, }, ["union"] = { - ["union"] = function(a: UnionType, b: UnionType): boolean, {Error} -- ∀ t ∈ a. ∃ u ∈ b. t <: u + ["union"] = function(self: TypeChecker, a: UnionType, b: UnionType): boolean, {Error} -- ∀ t ∈ a. ∃ u ∈ b. t <: u local used = {} -- ──────────────────────── for _, t in ipairs(a.types) do -- a union <: b union - begin_scope() - local u = exists_supertype_in(t, b) - end_scope() -- don't preserve failed inferences + self:begin_scope() + local u = self:exists_supertype_in(t, b) + self:end_scope() -- don't preserve failed inferences if not u then return false end @@ -8070,13 +8149,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end for u, t in pairs(used) do - is_a(t, u) -- preserve valid inferences + self:is_a(t, u) -- preserve valid inferences end return true end, - ["*"] = function(a: UnionType, b: Type): boolean, {Error} -- ∀ t ∈ a, t <: b - for _, t in ipairs(a.types) do -- ──────────────── - if not is_a(t, b) then -- a union <: b + ["*"] = function(self: TypeChecker, a: UnionType, b: Type): boolean, {Error} -- ∀ t ∈ a, t <: b + for _, t in ipairs(a.types) do -- ──────────────── + if not self:is_a(t, b) then -- a union <: b return false end end @@ -8084,212 +8163,212 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["poly"] = { - ["*"] = function(a: PolyType, b: Type): boolean, {Error} -- ∃ t ∈ a, t <: b - if exists_supertype_in(b, a) then -- ─────────────── - return true -- a poly <: b + ["*"] = function(self: TypeChecker, a: PolyType, b: Type): boolean, {Error} -- ∃ t ∈ a, t <: b + if self:exists_supertype_in(b, a) then -- ─────────────── + return true -- a poly <: b end - return false, { Err(a, "cannot match against any alternatives of the polymorphic type") } + return false, { Err("cannot match against any alternatives of the polymorphic type") } end, }, ["nominal"] = { - ["nominal"] = function(a: NominalType, b: NominalType): boolean, {Error} - local ok, errs = are_same_nominals(a, b) + ["nominal"] = function(self: TypeChecker, a: NominalType, b: NominalType): boolean, {Error} + local ok, errs = self:are_same_nominals(a, b) if ok then return true end - local rb = resolve_nominal(b) + local rb = self:resolve_nominal(b) if rb is InterfaceType then -- match interface subtyping - return is_a(a, rb) + return self:is_a(a, rb) end - local ra = resolve_nominal(a) + local ra = self:resolve_nominal(a) if ra is UnionType or rb is UnionType then -- match unions structurally - return is_a(ra, rb) + return self:is_a(ra, rb) end -- all other types nominally return ok, errs end, - ["*"] = subtype_nominal, + ["*"] = TypeChecker.subtype_nominal, }, ["enum"] = { ["string"] = compare_true, }, ["string"] = { - ["enum"] = function(a: StringType, b: EnumType): boolean, {Error} + ["enum"] = function(_self: TypeChecker, a: StringType, b: EnumType): boolean, {Error} if not a.literal then - return false, { Err(a, "string is not a %s", b) } + return false, { Err("%s is not a %s", a, b) } end if b.enumset[a.literal] then return true end - return false, { Err(a, "%s is not a member of %s", a, b) } + return false, { Err("%s is not a member of %s", a, b) } end, }, ["integer"] = { ["number"] = compare_true, }, ["interface"] = { - ["interface"] = function(a: InterfaceType, b: InterfaceType): boolean, {Error} - if find_in_interface_list(a, function(t: Type): boolean return (is_a(t, b)) end) then + ["interface"] = function(self: TypeChecker, a: InterfaceType, b: InterfaceType): boolean, {Error} + if find_in_interface_list(a, function(t: Type): boolean return (self:is_a(t, b)) end) then return true end - return same_type(a, b) + return self:same_type(a, b) end, - ["array"] = subtype_array, - ["record"] = subtype_record, - ["tupletable"] = function(a: Type, b: Type): boolean, {Error} - return subtype_relations["record"]["tupletable"](a, b) + ["array"] = TypeChecker.subtype_array, + ["record"] = TypeChecker.subtype_record, + ["tupletable"] = function(self: TypeChecker, a: Type, b: Type): boolean, {Error} + return self.subtype_relations["record"]["tupletable"](self, a, b) end, }, ["emptytable"] = emptytable_relations, ["tupletable"] = { - ["tupletable"] = function(a: TupleTableType, b: TupleTableType): boolean, {Error} + ["tupletable"] = function(self: TypeChecker, a: TupleTableType, b: TupleTableType): boolean, {Error} for i = 1, math.min(#a.types, #b.types) do - if not is_a(a.types[i], b.types[i]) then - return false, { Err(a, "in tuple entry " + if not self:is_a(a.types[i], b.types[i]) then + return false, { Err("in tuple entry " .. tostring(i) .. ": got %s, expected %s", a.types[i], b.types[i]) } end end if #a.types > #b.types then - return false, { Err(a, "tuple %s is too big for tuple %s", a, b) } + return false, { Err("tuple %s is too big for tuple %s", a, b) } end return true end, - ["record"] = function(a: Type, b: RecordType): boolean, {Error} + ["record"] = function(self: TypeChecker, a: Type, b: RecordType): boolean, {Error} if b.elements then - return subtype_relations["tupletable"]["array"](a, b) + return self.subtype_relations["tupletable"]["array"](self, a, b) end end, - ["array"] = function(a: TupleTableType, b: ArrayType): boolean, {Error} + ["array"] = function(self: TypeChecker, a: TupleTableType, b: ArrayType): boolean, {Error} if b.inferred_len and b.inferred_len > #a.types then - return false, { Err(a, "incompatible length, expected maximum length of " .. tostring(#a.types) .. ", got " .. tostring(b.inferred_len)) } + return false, { Err("incompatible length, expected maximum length of " .. tostring(#a.types) .. ", got " .. tostring(b.inferred_len)) } end - local aa, err = arraytype_from_tuple(a.inferred_at, a) + local aa, err = self:arraytype_from_tuple(a.inferred_at or a, a) if not aa then return false, err end - if not is_a(aa, b) then - return false, { Err(a, "got %s (from %s), expected %s", aa, a, b) } + if not self:is_a(aa, b) then + return false, { Err("got %s (from %s), expected %s", aa, a, b) } end return true end, - ["map"] = function(a: TupleTableType, b: MapType): boolean, {Error} - local aa = arraytype_from_tuple(a.inferred_at, a) + ["map"] = function(self: TypeChecker, a: TupleTableType, b: MapType): boolean, {Error} + local aa = self:arraytype_from_tuple(a.inferred_at or a, a) if not aa then - return false, { Err(a, "Unable to convert tuple %s to map", a) } + return false, { Err("Unable to convert tuple %s to map", a) } end - return compare_map(INTEGER, b.keys, aa.elements, b.values) + return compare_map(self, a_type(a, "integer", {}), b.keys, aa.elements, b.values) end, }, ["record"] = { - ["record"] = subtype_record, - ["interface"] = function(a: RecordType, b: InterfaceType): boolean, {Error} - if find_in_interface_list(a, function(t: Type): boolean return (is_a(t, b)) end) then + ["record"] = TypeChecker.subtype_record, + ["interface"] = function(self: TypeChecker, a: RecordType, b: InterfaceType): boolean, {Error} + if find_in_interface_list(a, function(t: Type): boolean return (self:is_a(t, b)) end) then return true end if not a.declname then -- match inferred table (anonymous record) structurally to interface - return subtype_record(a, b) + return self:subtype_record(a, b) end end, - ["array"] = subtype_array, - ["map"] = function(a: RecordType, b: MapType): boolean, {Error} - if not is_a(b.keys, STRING) then - return false, { Err(a, "can't match a record to a map with non-string keys") } + ["array"] = TypeChecker.subtype_array, + ["map"] = function(self: TypeChecker, a: RecordType, b: MapType): boolean, {Error} + if not self:is_a(b.keys, a_type(b, "string", {})) then + return false, { Err("can't match a record to a map with non-string keys") } end for _, k in ipairs(a.field_order) do local bk = b.keys if bk is EnumType and not bk.enumset[k] then - return false, { Err(a, "key is not an enum value: " .. k) } + return false, { Err("key is not an enum value: " .. k) } end - if not is_a(a.fields[k], b.values) then - return false, { Err(a, "record is not a valid map; not all fields have the same type") } + if not self:is_a(a.fields[k], b.values) then + return false, { Err("record is not a valid map; not all fields have the same type") } end end return true end, - ["tupletable"] = function(a: RecordType, b: Type): boolean, {Error} + ["tupletable"] = function(self: TypeChecker, a: RecordType, b: Type): boolean, {Error} if a.elements then - return subtype_relations["array"]["tupletable"](a, b) + return self.subtype_relations["array"]["tupletable"](self, a, b) end end, }, ["array"] = { - ["array"] = subtype_array, - ["record"] = function(a: ArrayType, b: RecordType): boolean, {Error} + ["array"] = TypeChecker.subtype_array, + ["record"] = function(self: TypeChecker, a: ArrayType, b: RecordType): boolean, {Error} if b.elements then - return subtype_array(a, b) + return self:subtype_array(a, b) end end, - ["map"] = function(a: ArrayType, b: MapType): boolean, {Error} - return compare_map(INTEGER, b.keys, a.elements, b.values) + ["map"] = function(self: TypeChecker, a: ArrayType, b: MapType): boolean, {Error} + return compare_map(self, a_type(a, "integer", {}), b.keys, a.elements, b.values) end, - ["tupletable"] = function(a: ArrayType, b: TupleTableType): boolean, {Error} + ["tupletable"] = function(self: TypeChecker, a: ArrayType, b: TupleTableType): boolean, {Error} local alen = a.inferred_len or 0 if alen > #b.types then - return false, { Err(a, "incompatible length, expected maximum length of " .. tostring(#b.types) .. ", got " .. tostring(alen)) } + return false, { Err("incompatible length, expected maximum length of " .. tostring(#b.types) .. ", got " .. tostring(alen)) } end -- for array literals (which is the only case where inferred_len is defined), -- only check the entries that are present for i = 1, (alen > 0) and alen or #b.types do - if not is_a(a.elements, b.types[i]) then - return false, { Err(a, "tuple entry " .. i .. " of type %s does not match type of array elements, which is %s", b.types[i], a.elements) } + if not self:is_a(a.elements, b.types[i]) then + return false, { Err("tuple entry " .. i .. " of type %s does not match type of array elements, which is %s", b.types[i], a.elements) } end end return true end, }, ["map"] = { - ["map"] = function(a: MapType, b: MapType): boolean, {Error} - return compare_map(a.keys, b.keys, a.values, b.values) + ["map"] = function(self: TypeChecker, a: MapType, b: MapType): boolean, {Error} + return compare_map(self, a.keys, b.keys, a.values, b.values) end, - ["array"] = function(a: MapType, b: ArrayType): boolean, {Error} - return compare_map(a.keys, INTEGER, a.values, b.elements) + ["array"] = function(self: TypeChecker, a: MapType, b: ArrayType): boolean, {Error} + return compare_map(self, a.keys, a_type(b, "integer", {}), a.values, b.elements) end, }, ["typedecl"] = { - ["record"] = function(a: TypeDeclType, b: RecordType): boolean, {Error} + ["record"] = function(self: TypeChecker, a: TypeDeclType, b: RecordType): boolean, {Error} local def = a.def if def is RecordLikeType then - return subtype_record(def, b) -- record as prototype + return self:subtype_record(def, b) -- record as prototype end end, }, ["function"] = { - ["function"] = function(a: FunctionType, b: FunctionType): boolean, {Error} + ["function"] = function(self: TypeChecker, a: FunctionType, b: FunctionType): boolean, {Error} local errs = {} local aa, ba = a.args.tuple, b.args.tuple if (not b.args.is_va) and a.min_arity > b.min_arity then - table.insert(errs, Err(a, "incompatible number of arguments: got " .. show_arity(a) .. " %s, expected " .. show_arity(b) .. " %s", a.args, b.args)) + table.insert(errs, Err("incompatible number of arguments: got " .. show_arity(a) .. " %s, expected " .. show_arity(b) .. " %s", a.args, b.args)) else for i = ((a.is_method or b.is_method) and 2 or 1), #aa do - arg_check(nil, errs, aa[i], ba[i] or ba[#ba], "bivariant", "argument", i) + self:arg_check(nil, errs, aa[i], ba[i] or ba[#ba], "bivariant", "argument", i) end end local ar, br = a.rets.tuple, b.rets.tuple local diff_by_va = #br - #ar == 1 and b.rets.is_va if #ar < #br and not diff_by_va then - table.insert(errs, Err(a, "incompatible number of returns: got " .. #ar .. " %s, expected " .. #br .. " %s", a.rets, b.rets)) + table.insert(errs, Err("incompatible number of returns: got " .. #ar .. " %s, expected " .. #br .. " %s", a.rets, b.rets)) else local nrets = #br if diff_by_va then nrets = nrets - 1 end for i = 1, nrets do - arg_check(nil, errs, ar[i], br[i], "bivariant", "return", i) + self:arg_check(nil, errs, ar[i], br[i], "bivariant", "return", i) end end @@ -8297,36 +8376,36 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["typearg"] = { - ["typearg"] = function(a: TypeArgType, b: TypeArgType): boolean, {Error} + ["typearg"] = function(_self: TypeChecker, a: TypeArgType, b: TypeArgType): boolean, {Error} return a.typearg == b.typearg end, - ["*"] = function(a: TypeArgType, b: Type): boolean, {Error} + ["*"] = function(self: TypeChecker, a: TypeArgType, b: Type): boolean, {Error} if a.constraint then - return is_a(a.constraint, b) + return self:is_a(a.constraint, b) end end, }, ["*"] = { ["any"] = compare_true, - ["tuple"] = function(a: Type, b: Type): boolean, {Error} - return is_a(a_tuple({a}), b) + ["tuple"] = function(self: TypeChecker, a: Type, b: Type): boolean, {Error} + return self:is_a(a_tuple(a, {a}), b) end, - ["typevar"] = function(a: Type, b: TypeVarType): boolean, {Error} - return compare_or_infer_typevar(b.typevar, a, nil, is_a) + ["typevar"] = function(self: TypeChecker, a: Type, b: TypeVarType): boolean, {Error} + return self:compare_or_infer_typevar(b.typevar, a, nil, self.is_a) end, - ["typearg"] = function(a: Type, b: TypeArgType): boolean, {Error} + ["typearg"] = function(self: TypeChecker, a: Type, b: TypeArgType): boolean, {Error} if b.constraint then - return is_a(a, b.constraint) + return self:is_a(a, b.constraint) end end, - ["union"] = exists_supertype_in as CompareTypes, -- ∃ t ∈ b, a <: t - -- ─────────────── - -- a <: b union - ["nominal"] = subtype_nominal, - ["poly"] = function(a: Type, b: PolyType): boolean, {Error} -- ∀ t ∈ b, a <: t - for _, t in ipairs(b.types) do -- ─────────────── - if not is_a(a, t) then -- a <: b poly - return false, { Err(a, "cannot match against all alternatives of the polymorphic type") } + ["union"] = TypeChecker.exists_supertype_in as CompareTypes, -- ∃ t ∈ b, a <: t + -- ─────────────── + -- a <: b union + ["nominal"] = TypeChecker.subtype_nominal, + ["poly"] = function(self: TypeChecker, a: Type, b: PolyType): boolean, {Error} -- ∀ t ∈ b, a <: t + for _, t in ipairs(b.types) do -- ─────────────── + if not self:is_a(a, t) then -- a <: b poly + return false, { Err("cannot match against all alternatives of the polymorphic type") } end end return true @@ -8335,7 +8414,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string } -- evaluation strategy - local type_priorities: {TypeName:integer} = { + TypeChecker.type_priorities = { -- types that have catch-all rules evaluate first ["tuple"] = 2, ["typevar"] = 3, @@ -8364,19 +8443,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["function"] = 14, } - if lax then - type_priorities["unknown"] = 0 - - subtype_relations["unknown"] = {} - subtype_relations["unknown"]["*"] = compare_true - subtype_relations["*"]["unknown"] = compare_true - -- in .lua files, all values can be used in a boolean context - subtype_relations["boolean"] = {} - subtype_relations["boolean"]["boolean"] = compare_true - subtype_relations["*"]["boolean"] = compare_true - end - - local function compare_types(relations: TypeRelations, t1: Type, t2: Type): boolean, {Error} + local function compare_types(self: TypeChecker, relations: TypeRelations, t1: Type, t2: Type): boolean, {Error} if t1.typeid == t2.typeid then return true end @@ -8384,8 +8451,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local s1 = relations[t1.typename] local fn = s1 and s1[t2.typename] if not fn then - local p1 = type_priorities[t1.typename] or 999 - local p2 = type_priorities[t2.typename] or 999 + local p1 = self.type_priorities[t1.typename] or 999 + local p2 = self.type_priorities[t2.typename] or 999 fn = (p1 < p2 and (s1 and s1["*"]) or (relations["*"][t2.typename])) end @@ -8394,32 +8461,32 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if fn == compare_true then return true end - ok, err = fn(t1, t2) + ok, err = fn(self, t1, t2) else ok = t1.typename == t2.typename end if (not ok) and not err then - return false, { Err(t1, "got %s, expected %s", t1, t2) } + return false, { Err("got %s, expected %s", t1, t2) } end return ok, err end -- subtyping comparison - is_a = function(t1: Type, t2: Type): boolean, {Error} - return compare_types(subtype_relations, t1, t2) + function TypeChecker:is_a(t1: Type, t2: Type): boolean, {Error} + return compare_types(self, self.subtype_relations, t1, t2) end -- invariant type comparison - same_type = function(t1: Type, t2: Type): boolean, {Error} + function TypeChecker:same_type(t1: Type, t2: Type): boolean, {Error} -- except for error messages, behavior is the same as - -- `return (is_a(t1, t2) and is_a(t2, t1))` - return compare_types(eqtype_relations, t1, t2) + -- `return (is_a(t1, t2) and self:is_a(t2, t1))` + return compare_types(self, self.eqtype_relations, t1, t2) end if TL_DEBUG then - local orig_is_a = is_a - is_a = function(t1: Type, t2: Type): boolean, {Error} + local orig_is_a = TypeChecker.is_a + TypeChecker.is_a = function(self: TypeChecker, t1: Type, t2: Type): boolean, {Error} assert(type(t1) == "table") assert(type(t2) == "table") @@ -8429,14 +8496,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true end - return orig_is_a(t1, t2) + return orig_is_a(self, t1, t2) end end - local function assert_is_a(where: Where, t1: Type, t2: Type, context: string, name?: string): boolean + function TypeChecker:assert_is_a(w: Where, t1: Type, t2: Type, ctx?: string | Node, name?: string): boolean t1 = resolve_tuple(t1) t2 = resolve_tuple(t2) - if lax and (is_unknown(t1) or is_unknown(t2)) then + if self.feat_lax and (is_unknown(t1) or is_unknown(t2)) then return true end @@ -8444,24 +8511,27 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if t1.typename == "nil" then return true elseif t2 is UnresolvedEmptyTableValueType then - if is_number_type(t2.emptytable_type.keys) then -- ideally integer only - infer_emptytable(t2.emptytable_type, infer_at(where, an_array(t1))) + local t2keys = t2.emptytable_type.keys + if t2keys is NumericType then -- ideally integer only + self:infer_emptytable(t2.emptytable_type, self:infer_at(w, an_array(w, t1))) else - infer_emptytable(t2.emptytable_type, infer_at(where, a_map(t2.emptytable_type.keys, t1))) + self:infer_emptytable(t2.emptytable_type, self:infer_at(w, a_map(w, t2keys, t1))) end return true elseif t2 is EmptyTableType then if is_lua_table_type(t1) then - infer_emptytable(t2, infer_at(where, t1)) + self:infer_emptytable(t2, self:infer_at(w, t1)) elseif not t1 is EmptyTableType then - error_at(where, context .. ": " .. (name and (name .. ": ") or "") .. "assigning %s to a variable declared with {}", t1) + self.errs:add(w, self.errs:get_context(ctx, name) .. "assigning %s to a variable declared with {}", t1) return false end return true end - local ok, match_errs = is_a(t1, t2) - add_errs_prefixing(where, match_errs, errors, context .. ": ".. (name and (name .. ": ") or "")) + local ok, match_errs = self:is_a(t1, t2) + if not ok then + self.errs:add_prefixing(w, match_errs, self.errs:get_context(ctx, name)) + end return ok end @@ -8469,11 +8539,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if t is InvalidType then return false end - if same_type(t, NIL) then + if t.typename == "nil" then return true end if t is NominalType then - t = resolve_nominal(t) + t = assert(t.resolved) end if t is RecordLikeType then return t.meta_fields and t.meta_fields["__close"] ~= nil @@ -8487,40 +8557,31 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["boolean"] = true, ["literal_table"] = true, } - local function expr_is_definitely_not_closable(e: Node): boolean - return definitely_not_closable_exprs[e.kind] - end - - local unknown_dots: {string:boolean} = {} - - local function add_unknown_dot(node: Node, name: string) - if not unknown_dots[name] then - unknown_dots[name] = true - add_unknown(node, name) - end + local function expr_is_definitely_not_closable(e: Node): boolean + return definitely_not_closable_exprs[e.kind] end - local function same_in_all_union_entries(u: UnionType, check: function(Type): (Type, Type)): Type + function TypeChecker:same_in_all_union_entries(u: UnionType, check: function(Type): (Type, Type)): Type local t1, f = check(u.types[1]) if not t1 then return nil end for i = 2, #u.types do local t2 = check(u.types[i]) - if not t2 or not same_type(t1, t2) then + if not t2 or not self:same_type(t1, t2) then return nil end end return f or t1 end - local function same_call_mt_in_all_union_entries(u: UnionType): Type - return same_in_all_union_entries(u, function(t: Type): (Type, Type) - t = to_structural(t) + function TypeChecker:same_call_mt_in_all_union_entries(u: UnionType): Type + return self:same_in_all_union_entries(u, function(t: Type): (Type, Type) + t = self:to_structural(t) if t is RecordLikeType then local call_mt = t.meta_fields and t.meta_fields["__call"] if call_mt is FunctionType then - local args_tuple = a_tuple({}) + local args_tuple = a_tuple(u, {}) for i = 2, #call_mt.args.tuple do table.insert(args_tuple.tuple, call_mt.args.tuple[i]) end @@ -8530,20 +8591,21 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end) end - local function resolve_for_call(func: Type, args: TupleType, is_method: boolean): Type, boolean + function TypeChecker:resolve_for_call(func: Type, args: TupleType, is_method: boolean): Type, boolean -- resolve unknown in lax mode, produce a general unknown function - if lax and is_unknown(func) then - func = a_fn { args = va_args { UNKNOWN }, rets = va_args { UNKNOWN } } + if self.feat_lax and is_unknown(func) then + local unk = func + func = a_function(func, { min_arity = 0, args = a_vararg(func, { unk }), rets = a_vararg(func, { unk }) }) end -- unwrap if tuple, resolve if nominal - func = to_structural(func) + func = self:to_structural(func) if func.typename ~= "function" and func.typename ~= "poly" then -- resolve if union if func is UnionType then - local r = same_call_mt_in_all_union_entries(func) + local r = self:same_call_mt_in_all_union_entries(func) if r then table.insert(args.tuple, 1, func.types[1]) -- FIXME: is this right? - return to_structural(r), true + return self:to_structural(r), true end end -- resolve if prototype @@ -8557,7 +8619,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if func is RecordLikeType and func.meta_fields and func.meta_fields["__call"] then table.insert(args.tuple, 1, func) func = func.meta_fields["__call"] - func = to_structural(func) + func = self:to_structural(func) is_method = true end end @@ -8565,19 +8627,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local type OnArgId = function(node: Node, i: integer): T - local type OnNode = function(node: Node, children: {T}, ret: T): T + local type OnNode = function(s: S, node: Node, children: {T}, ret: T): T - local function traverse_macroexp(macroexp: Node, on_arg_id: OnArgId, on_node: OnNode): T + local function traverse_macroexp(macroexp: Node, on_arg_id: OnArgId, on_node: OnNode): T local root = macroexp.exp local argnames = {} for i, a in ipairs(macroexp.args) do argnames[a.tk] = i end - local visit_node: Visitor = { + local visit_node: Visitor = { cbs = { ["variable"] = { - after = function(node: Node, _children: {T}): T + after = function(_: nil, node: Node, _children: {T}): T local i = argnames[node.tk] if not i then return nil @@ -8587,10 +8649,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end } }, - after = on_node, + after = on_node as VisitorAfter, } - return recurse_node(root, visit_node, {}) + return recurse_node(nil, root, visit_node, {}) end local function expand_macroexp(orignode: Node, args: {Node}, macroexp: Node) @@ -8598,7 +8660,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return { Node, args[i] } end - local on_node = function(node: Node, children: {{Node, Node}}, ret: {Node, Node}): {Node, Node} + local on_node = function(_: nil, node: Node, children: {{Node, Node}}, ret: {Node, Node}): {Node, Node} local orig = ret and ret[2] or node local out = shallow_copy_table(orig) @@ -8627,12 +8689,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string orignode.expanded = p[2] end - local function check_macroexp_arg_use(macroexp: Node) + function TypeChecker:check_macroexp_arg_use(macroexp: Node) local used: {string:boolean} = {} local on_arg_id = function(node: Node, _i: integer): {Node, Node} if used[node.tk] then - error_at(node, "cannot use argument '" .. node.tk .. "' multiple times in macroexp") + self.errs:add(node, "cannot use argument '" .. node.tk .. "' multiple times in macroexp") else used[node.tk] = true end @@ -8655,18 +8717,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string orignode.known = saveknown end - local type InvalidOrTupleType = InvalidType | TupleType - - local type_check_function_call: function(Node, Type, TupleType, ? integer, ? Node, ? {Node}): InvalidOrTupleType, FunctionType do - local function mark_invalid_typeargs(f: FunctionType) + local function mark_invalid_typeargs(self: TypeChecker, f: FunctionType) if f.typeargs then for _, a in ipairs(f.typeargs) do - if not find_var_type(a.typearg) then + if not self:find_var_type(a.typearg) then if a.constraint then - add_var(nil, a.typearg, a.constraint) + self:add_var(nil, a.typearg, a.constraint) else - add_var(nil, a.typearg, lax and UNKNOWN or a_type("unresolvable_typearg", { + self:add_var(nil, a.typearg, self.feat_lax and an_unknown(a) or a_type(a, "unresolvable_typearg", { typearg = a.typearg } as UnresolvableTypeArgType)) end @@ -8675,7 +8734,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function infer_emptytables(where: Where, wheres: {Where}, xs: TupleType, ys: TupleType, delta: integer) + local function infer_emptytables(self: TypeChecker, w: Where, wheres: {Where}, xs: TupleType, ys: TupleType, delta: integer) local xt, yt = xs.tuple, ys.tuple local n_xs = #xt local n_ys = #yt @@ -8685,19 +8744,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if x is EmptyTableType then local y = yt[i] or (ys.is_va and yt[n_ys]) if y then -- y may not be present when inferring returns - local w = wheres and wheres[i + delta] or where -- for self, a + argdelta is 0 - local inferred_y = infer_at(w, y) - infer_emptytable(x, inferred_y) + local iw = wheres and wheres[i + delta] or w -- for self, a + argdelta is 0 + local inferred_y = self:infer_at(iw, y) + self:infer_emptytable(x, inferred_y) xt[i] = inferred_y end end end end - local check_args_rets: function(where: Where, where_args: {Node}, f: Type, args: TupleType, expected_rets: TupleType, argdelta: integer): TupleType, {Error} + local check_args_rets: function(TypeChecker, w: Where, where_args: {Node}, f: FunctionType, args: TupleType, expected_rets: TupleType, argdelta: integer): TupleType, {Error} do -- check if a tuple `xs` matches tuple `ys` - local function check_func_type_list(where: Where, wheres: {Where}, xs: TupleType, ys: TupleType, from: integer, delta: integer, v: VarianceMode, mode: ArgCheckMode): boolean, {Error} + local function check_func_type_list(self: TypeChecker, w: Where, wheres: {Where}, xs: TupleType, ys: TupleType, from: integer, delta: integer, v: VarianceMode, mode: ArgCheckMode): boolean, {Error} assert(xs.typename == "tuple", xs.typename) assert(ys.typename == "tuple", ys.typename) @@ -8708,11 +8767,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string for i = from, math.max(n_xs, n_ys) do local pos = i + delta - local x = xt[i] or (xs.is_va and xt[n_xs]) or NIL + local x = xt[i] or (xs.is_va and xt[n_xs]) or a_type(w, "nil", {}) local y = yt[i] or (ys.is_va and yt[n_ys]) if y then - local w = wheres and wheres[pos] or where - if not arg_check(w, errs, x, y, v, mode, pos) then + local iw = wheres and wheres[pos] or w + if not self:arg_check(iw, errs, x, y, v, mode, pos) then return nil, errs end end @@ -8721,7 +8780,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true end - check_args_rets = function(where: Where, where_args: {Node}, f: FunctionType, args: TupleType, expected_rets: TupleType, argdelta: integer): TupleType, {Error} + check_args_rets = function(self: TypeChecker, w: Where, where_args: {Node}, f: FunctionType, args: TupleType, expected_rets: TupleType, argdelta: integer): TupleType, {Error} local rets_ok = true local rets_errs: {Error} local args_ok: boolean @@ -8732,19 +8791,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if argdelta == -1 then from = 2 local errs = {} - if (not is_self(fargs[1])) and not arg_check(where, errs, fargs[1], args.tuple[1], "contravariant", "self") then + if (not is_self(fargs[1])) and not self:arg_check(w, errs, fargs[1], args.tuple[1], "contravariant", "self") then return nil, errs end end if expected_rets then - expected_rets = infer_at(where, expected_rets) - infer_emptytables(where, nil, expected_rets, f.rets, 0) + expected_rets = self:infer_at(w, expected_rets) + infer_emptytables(self, w, nil, expected_rets, f.rets, 0) - rets_ok, rets_errs = check_func_type_list(where, nil, f.rets, expected_rets, 1, 0, "covariant", "return") + rets_ok, rets_errs = check_func_type_list(self, w, nil, f.rets, expected_rets, 1, 0, "covariant", "return") end - args_ok, args_errs = check_func_type_list(where, where_args, f.args, args, from, argdelta, "contravariant", "argument") + args_ok, args_errs = check_func_type_list(self, w, where_args, f.args, args, from, argdelta, "contravariant", "argument") if (not args_ok) or (not rets_ok) then return nil, args_errs or {} end @@ -8752,29 +8811,29 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- if we got to this point without returning, -- we got a valid function match - infer_emptytables(where, where_args, args, f.args, argdelta) + infer_emptytables(self, w, where_args, args, f.args, argdelta) - mark_invalid_typeargs(f) + mark_invalid_typeargs(self, f) - return resolve_typevars_at(where, f.rets) + return self:resolve_typevars_at(w, f.rets) end end - local function push_typeargs(func: FunctionType) + local function push_typeargs(self: TypeChecker, func: FunctionType) if func.typeargs then for _, fnarg in ipairs(func.typeargs) do - add_var(nil, fnarg.typearg, a_type("unresolved_typearg", { + self:add_var(nil, fnarg.typearg, a_type(fnarg, "unresolved_typearg", { constraint = fnarg.constraint, } as UnresolvedTypeArgType)) end end end - local function pop_typeargs(func: FunctionType) + local function pop_typeargs(self: TypeChecker, func: FunctionType) if func.typeargs then for _, fnarg in ipairs(func.typeargs) do - if st[#st][fnarg.typearg] then - st[#st][fnarg.typearg] = nil + if self.st[#self.st].vars[fnarg.typearg] then + self.st[#self.st].vars[fnarg.typearg] = nil end end end @@ -8788,12 +8847,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function fail_call(where: Where, func: FunctionType | PolyType, nargs: integer, errs: {Error}): TupleType + local function fail_call(self: TypeChecker, w: Where, func: FunctionType | PolyType, nargs: integer, errs: {Error}): TupleType if errs then - -- report the errors from the first match - for _, err in ipairs(errs) do - table.insert(errors, err) - end + self.errs:collect(errs) else -- found no arity match to try local expects: {string} = {} @@ -8810,34 +8866,34 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string else table.insert(expects, show_arity(func)) end - error_at(where, "wrong number of arguments (given " .. nargs .. ", expects " .. table.concat(expects, " or ") .. ")") + self.errs:add(w, "wrong number of arguments (given " .. nargs .. ", expects " .. table.concat(expects, " or ") .. ")") end local f = resolve_function_type(func, 1) - mark_invalid_typeargs(f) + mark_invalid_typeargs(self, f) - return resolve_typevars_at(where, f.rets) + return self:resolve_typevars_at(w, f.rets) end - local function check_call(where: Where, where_args: {Node}, func: Type, args: TupleType, expected_rets: TupleType, is_typedecl_funcall: boolean, argdelta: integer): InvalidOrTupleType, FunctionType + local function check_call(self: TypeChecker, w: Where, where_args: {Node}, func: Type, args: TupleType, expected_rets: TupleType, is_typedecl_funcall: boolean, argdelta: integer): InvalidOrTupleType, FunctionType assert(type(func) == "table") assert(type(args) == "table") local is_method = (argdelta == -1) if not (func is FunctionType or func is PolyType) then - func, is_method = resolve_for_call(func, args, is_method) + func, is_method = self:resolve_for_call(func, args, is_method) if is_method then argdelta = -1 end if not (func is FunctionType or func is PolyType) then - return invalid_at(where, "not a function: %s", func) + return self.errs:invalid_at(w, "not a function: %s", func) end end if is_method and args.tuple[1] then - add_var(nil, "@self", type_at(where, a_typedecl(args.tuple[1]))) + self:add_var(nil, "@self", a_typedecl(w, args.tuple[1])) end local passes, n = 1, 1 @@ -8854,30 +8910,30 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local f = resolve_function_type(func, i) local fargs = f.args.tuple if f.is_method and not is_method then - if args.tuple[1] and is_a(args.tuple[1], fargs[1]) then + if args.tuple[1] and self:is_a(args.tuple[1], fargs[1]) then -- a non-"@funcall" means a synthesized call, e.g. from a metamethod if not is_typedecl_funcall then - add_warning("hint", where, "invoked method as a regular function: consider using ':' instead of '.'") + self.errs:add_warning("hint", w, "invoked method as a regular function: consider using ':' instead of '.'") end else - return invalid_at(where, "invoked method as a regular function: use ':' instead of '.'") + return self.errs:invalid_at(w, "invoked method as a regular function: use ':' instead of '.'") end end local wanted = #fargs - local min_arity = feat_arity and f.min_arity or 0 + local min_arity = self.feat_arity and f.min_arity or 0 -- simple functions: - if (passes == 1 and ((given <= wanted and given >= min_arity) or (f.args.is_va and given > wanted) or (lax and given <= wanted))) + if (passes == 1 and ((given <= wanted and given >= min_arity) or (f.args.is_va and given > wanted) or (self.feat_lax and given <= wanted))) -- poly, pass 1: try exact arity matches first or (passes == 3 and ((pass == 1 and given == wanted) -- poly, pass 2: then try adjusting with nils to missing arguments or using '...' - or (pass == 2 and given < wanted and (lax or given >= min_arity)) + or (pass == 2 and given < wanted and (self.feat_lax or given >= min_arity)) -- poly, pass 3: then finally try vararg functions or (pass == 3 and f.args.is_va and given > wanted))) then - push_typeargs(f) + push_typeargs(self, f) - local matched, errs = check_args_rets(where, where_args, f, args, expected_rets, argdelta) + local matched, errs = check_args_rets(self, w, where_args, f, args, expected_rets, argdelta) if matched then -- success! return matched, f @@ -8886,23 +8942,23 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if expected_rets then -- revert inferred returns - infer_emptytables(where, where_args, f.rets, f.rets, argdelta) + infer_emptytables(self, w, where_args, f.rets, f.rets, argdelta) end if passes == 3 then tried = tried or {} tried[i] = true - pop_typeargs(f) + pop_typeargs(self, f) end end end end end - return fail_call(where, func, given, first_errs) + return fail_call(self, w, func, given, first_errs) end - type_check_function_call = function(node: Node, func: Type, args: TupleType, argdelta?: integer, e1?: Node, e2?: {Node}): InvalidOrTupleType, FunctionType + function TypeChecker:type_check_function_call(node: Node, func: Type, args: TupleType, argdelta?: integer, e1?: Node, e2?: {Node}): InvalidOrTupleType, FunctionType e1 = e1 or node.e1 e2 = e2 or node.e2 @@ -8911,14 +8967,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if expected and expected is TupleType then expected_rets = expected else - expected_rets = a_tuple { node.expected } + expected_rets = a_tuple(node, { node.expected }) end - begin_scope() + self:begin_scope() local is_typedecl_funcall: boolean - if node.kind == "op" and node.op.op == "@funcall" and node.e1 and node.e1.receiver then - local receiver = node.e1.receiver + if node.kind == "op" and node.op.op == "@funcall" and e1 and e1.receiver then + local receiver = e1.receiver if receiver is NominalType then local resolved = receiver.resolved if resolved and resolved is TypeDeclType then @@ -8927,12 +8983,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local ret, f = check_call(node, e2, func, args, expected_rets, is_typedecl_funcall, argdelta or 0) - ret = resolve_typevars_at(node, ret) - end_scope() + local ret, f = check_call(self, node, e2, func, args, expected_rets, is_typedecl_funcall, argdelta or 0) + ret = self:resolve_typevars_at(node, ret) + self:end_scope() - if tc and e1 then - tc.store_type(e1.y, e1.x, f) + if self.collector then + self.collector.store_type(e1.y, e1.x, f) end if f and f.macroexp then @@ -8943,9 +8999,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function check_metamethod(node: Node, method_name: string, a: Type, b: Type, orig_a: Type, orig_b: Type): Type, integer - if lax and ((a and is_unknown(a)) or (b and is_unknown(b))) then - return UNKNOWN, nil + function TypeChecker:check_metamethod(node: Node, method_name: string, a: Type, b: Type, orig_a: Type, orig_b: Type): Type, integer + if self.feat_lax and ((a and is_unknown(a)) or (b and is_unknown(b))) then + return an_unknown(node), nil end local ameta = a is RecordLikeType and a.meta_fields local bmeta = b and b is RecordLikeType and b.meta_fields @@ -8966,26 +9022,26 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if metamethod then local e2 = { node.e1 } - local args = a_tuple { orig_a } + local args = a_tuple(node, { orig_a }) if b and method_name ~= "__is" then e2[2] = node.e2 args.tuple[2] = orig_b end - return to_structural(resolve_tuple((type_check_function_call(node, metamethod, args, -1, node, e2)))), meta_on_operator + return self:to_structural(resolve_tuple((self:type_check_function_call(node, metamethod, args, -1, node, e2)))), meta_on_operator else return nil, nil end end - local function match_record_key(tbl: Type, rec: Node, key: string): Type, string + function TypeChecker:match_record_key(tbl: Type, rec: Node, key: string): Type, string assert(type(tbl) == "table") assert(type(rec) == "table") assert(type(key) == "string") - tbl = to_structural(tbl) + tbl = self:to_structural(tbl) if tbl is StringType or tbl is EnumType then - tbl = find_var_type("string") -- simulate string metatable + tbl = self:find_var_type("string") -- simulate string metatable end if tbl is TypeDeclType then @@ -8994,13 +9050,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if tbl.is_nested_alias then return nil, "cannot use a nested type alias as a concrete value" else - tbl = resolve_nominal(tbl.alias_to) + tbl = self:resolve_nominal(tbl.alias_to) end end if tbl is UnionType then - local t = same_in_all_union_entries(tbl, function(t: Type): (Type, Type) - return (match_record_key(t, rec, key)) + local t = self:same_in_all_union_entries(tbl, function(t: Type): (Type, Type) + return (self:match_record_key(t, rec, key)) end) if t then @@ -9009,7 +9065,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if (tbl is TypeVarType or tbl is TypeArgType) and tbl.constraint then - local t = match_record_key(tbl.constraint, rec, key) + local t = self:match_record_key(tbl.constraint, rec, key) if t then return t @@ -9023,7 +9079,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return tbl.fields[key] end - local meta_t = check_metamethod(rec, "__index", tbl, STRING, tbl, STRING) + local str = a_type(rec, "string", {}) + local meta_t = self:check_metamethod(rec, "__index", tbl, str, tbl, str) if meta_t then return meta_t end @@ -9034,8 +9091,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return nil, "invalid key '" .. key .. "' in type %s" end elseif tbl is EmptyTableType or is_unknown(tbl) then - if lax then - return INVALID + if self.feat_lax then + return an_unknown(rec) end return nil, "cannot index a value of unknown type" end @@ -9047,30 +9104,35 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function widen_in_scope(scope: Scope, var: string): boolean - assert(scope[var], "no " .. var .. " in scope") - local narrow_mode = scope[var].is_narrowed - if narrow_mode and narrow_mode ~= "declaration" then - if scope[var].narrowed_from then - scope[var].t = scope[var].narrowed_from - scope[var].narrowed_from = nil - scope[var].is_narrowed = nil - else - scope[var] = nil - end + function TypeChecker:widen_in_scope(scope: Scope, var: string): boolean + local v = scope.vars[var] + assert(v, "no " .. var .. " in scope") + local narrow_mode = scope.vars[var].is_narrowed + if (not narrow_mode) or narrow_mode == "declaration" then + return false + end - local unresolved = get_unresolved(scope) - unresolved.narrows[var] = nil - return true + if v.narrowed_from then + v.t = v.narrowed_from + v.narrowed_from = nil + v.is_narrowed = nil + else + scope.vars[var] = nil + end + + if scope.narrows then + scope.narrows[var] = nil end - return false + + return true end - local function widen_back_var(name: string): boolean + function TypeChecker:widen_back_var(name: string): boolean local widened = false - for i = #st, 1, -1 do - if st[i][name] then - if widen_in_scope(st[i], name) then + for i = #self.st, 1, -1 do + local scope = self.st[i] + if scope.vars[name] then + if self:widen_in_scope(scope, name) then widened = true else break @@ -9081,10 +9143,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local function assigned_anywhere(name: string, root: Node): boolean - local visit_node: Visitor = { + local visit_node: Visitor = { cbs = { ["assignment"] = { - after = function(node: Node, _children: {boolean}): boolean + after = function(_: nil, node: Node, _children: {boolean}): boolean for _, v in ipairs(node.vars) do if v.kind == "variable" and v.tk == name then return true @@ -9094,7 +9156,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end } }, - after = function(_node: Node, children: {boolean}, ret: boolean): boolean + after = function(_: nil, _node: Node, children: {boolean}, ret: boolean): boolean ret = ret or false for _, c in ipairs(children) do local ca = c as any @@ -9106,124 +9168,88 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end } - local visit_type: Visitor = { + local visit_type: Visitor = { after = function(): boolean return false end } - return recurse_node(root, visit_node, visit_type) + return recurse_node(nil, root, visit_node, visit_type) end - local function widen_all_unions(node?: Node) - for i = #st, 1, -1 do - local scope = st[i] - local unresolved = find_unresolved(i) - if unresolved and unresolved.narrows then - for name, _ in pairs(unresolved.narrows) do + function TypeChecker:widen_all_unions(node?: Node) + for i = #self.st, 1, -1 do + local scope = self.st[i] + if scope.narrows then + for name, _ in pairs(scope.narrows) do if not node or assigned_anywhere(name, node) then - widen_in_scope(scope, name) + self:widen_in_scope(scope, name) end end end end end - local function add_global(node: Node, var: string, valtype: Type, is_assigning?: boolean): Variable - if lax and is_unknown(valtype) and (var ~= "self" and var ~= "...") then - add_unknown(node, var) + function TypeChecker:add_global(node: Node, varname: string, valtype: Type, is_assigning?: boolean): Variable + if self.feat_lax and is_unknown(valtype) and (varname ~= "self" and varname ~= "...") then + self.errs:add_unknown(node, varname) end local is_const = node.attribute ~= nil - local existing, scope, existing_attr = find_var(var) + local existing, scope, existing_attr = self:find_var(varname) if existing then if scope > 1 then - error_at(node, "cannot define a global when a local with the same name is in scope") + self.errs:add(node, "cannot define a global when a local with the same name is in scope") elseif is_assigning and existing_attr then - error_at(node, "cannot reassign to <" .. existing_attr .. "> global: " .. var) + self.errs:add(node, "cannot reassign to <" .. existing_attr .. "> global: " .. varname) elseif existing_attr and not is_const then - error_at(node, "global was previously declared as <" .. existing_attr .. ">: " .. var) + self.errs:add(node, "global was previously declared as <" .. existing_attr .. ">: " .. varname) elseif (not existing_attr) and is_const then - error_at(node, "global was previously declared as not <" .. node.attribute .. ">: " .. var) - elseif valtype and not same_type(existing.t, valtype) then - error_at(node, "cannot redeclare global with a different type: previous type of " .. var .. " is %s", existing.t) + self.errs:add(node, "global was previously declared as not <" .. node.attribute .. ">: " .. varname) + elseif valtype and not self:same_type(existing.t, valtype) then + self.errs:add(node, "cannot redeclare global with a different type: previous type of " .. varname .. " is %s", existing.t) end return nil end - st[1][var] = { t = valtype, attribute = is_const and "const" or nil } - - return st[1][var] - end + local var = { t = valtype, attribute = is_const and "const" or nil } + self.st[1].vars[varname] = var - local get_rets: function(TupleType): TupleType - if lax then - get_rets = function(rets: TupleType): TupleType - if #rets.tuple == 0 then - return a_vararg { UNKNOWN } - end - return rets - end - else - get_rets = function(rets: TupleType): TupleType - return rets - end + return var end - local function add_internal_function_variables(node: Node, args: TupleType) - add_var(nil, "@is_va", args.is_va and ANY or NIL) - add_var(nil, "@return", node.rets or a_tuple({})) + function TypeChecker:add_internal_function_variables(node: Node, args: TupleType) + self:add_var(nil, "@is_va", a_type(node, args.is_va and "any" or "nil", {})) + self:add_var(nil, "@return", node.rets or a_tuple(node, {})) if node.typeargs then for _, t in ipairs(node.typeargs) do - local v = find_var(t.typearg, "check_only") + local v = self:find_var(t.typearg, "check_only") if not v or not v.used_as_type then - error_at(t, "type argument '%s' is not used in function signature", t) + self.errs:add(t, "type argument '%s' is not used in function signature", t) end end end end - local function add_function_definition_for_recursion(node: Node, fnargs: TupleType) - add_var(nil, node.name.tk, type_at(node, a_function { + function TypeChecker:add_function_definition_for_recursion(node: Node, fnargs: TupleType) + self:add_var(nil, node.name.tk, a_function(node, { min_arity = node.min_arity, typeargs = node.typeargs, args = fnargs, - rets = get_rets(node.rets), + rets = self.get_rets(node.rets), })) end - local function fail_unresolved() - local unresolved = st[#st]["@unresolved"] - if unresolved then - st[#st]["@unresolved"] = nil - local unrt = unresolved.t as UnresolvedType - for name, nodes in pairs(unrt.labels) do - for _, node in ipairs(nodes) do - error_at(node, "no visible label '" .. name .. "' for goto") - end - end - for name, types in pairs(unrt.nominals) do - if not unrt.global_types[name] then - for _, typ in ipairs(types) do - assert(typ.x) - assert(typ.y) - error_at(typ, "unknown type %s", typ) - end - end - end - end - end - - local function end_function_scope(node: Node) - fail_unresolved() - end_scope(node) + function TypeChecker:end_function_scope(node: Node) + self.errs:fail_unresolved_labels(self.st[#self.st]) + self:end_scope(node) end local function flatten_tuple(vals: TupleType): TupleType local vt = vals.tuple local n_vals = #vt - local ret = a_tuple {} + local ret = a_tuple(vals, {}) local rt = ret.tuple if n_vals == 0 then @@ -9251,9 +9277,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return ret end - local function get_assignment_values(vals: TupleType, wanted: integer): TupleType + local function get_assignment_values(w: Where, vals: TupleType, wanted: integer): TupleType if vals == nil then - return a_tuple {} + return a_tuple(w, {}) end local ret = flatten_tuple(vals) @@ -9272,14 +9298,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return ret end - local function match_all_record_field_names(node: Node, a: RecordLikeType, field_names: {string}, errmsg: string): Type + function TypeChecker:match_all_record_field_names(node: Node, a: RecordLikeType, field_names: {string}, errmsg: string): Type local t: Type for _, k in ipairs(field_names) do local f = a.fields[k] if not t then t = f else - if not same_type(f, t) then + if not self:same_type(f, t) then errmsg = errmsg .. string.format(" (types of fields '%s' and '%s' do not match)", field_names[1], k) t = nil break @@ -9289,26 +9315,26 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if t then return t else - return invalid_at(node, errmsg) + return self.errs:invalid_at(node, errmsg) end end - local function type_check_index(anode: Node, bnode: Node, a: Type, b: Type): Type + function TypeChecker:type_check_index(anode: Node, bnode: Node, a: Type, b: Type): Type assert(not a is TupleType) assert(not b is TupleType) - local ra = resolve_typedecl(to_structural(a)) - local rb = to_structural(b) + local ra = resolve_typedecl(self:to_structural(a)) + local rb = self:to_structural(b) - if lax and is_unknown(a) then - return UNKNOWN + if self.feat_lax and is_unknown(a) then + return a end local errm: string local erra: Type local errb: Type - if ra is TupleTableType and is_a(rb, INTEGER) then + if ra is TupleTableType and rb is IntegerType then if bnode.constnum then if bnode.constnum >= 1 and bnode.constnum <= #ra.types and bnode.constnum == math.floor(bnode.constnum) then return ra.types[bnode.constnum as integer] @@ -9316,38 +9342,35 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string errm, erra = "index " .. tostring(bnode.constnum) .. " out of range for tuple %s", ra else - local array_type = arraytype_from_tuple(bnode, ra) + local array_type = self:arraytype_from_tuple(bnode, ra) if array_type then return array_type.elements end errm = "cannot index this tuple with a variable because it would produce a union type that cannot be discriminated at runtime" end - elseif ra is ArrayLikeType and is_a(rb, INTEGER) then + elseif ra is ArrayLikeType and rb is IntegerType then return ra.elements elseif ra is EmptyTableType then if ra.keys == nil then - ra.keys = infer_at(anode, b) + ra.keys = self:infer_at(bnode, b) end - if is_a(b, ra.keys) then - return type_at(anode, a_type("unresolved_emptytable_value", { + if self:is_a(b, ra.keys) then + return a_type(anode, "unresolved_emptytable_value", { emptytable_type = ra - } as UnresolvedEmptyTableValueType)) + } as UnresolvedEmptyTableValueType) end - errm, erra, errb = "inconsistent index type: got %s, expected %s (type of keys inferred at " - .. ra.keys.inferred_at.filename .. ":" - .. ra.keys.inferred_at.y .. ":" - .. ra.keys.inferred_at.x .. ": )", b, ra.keys + errm, erra, errb = "inconsistent index type: got %s, expected %s" .. inferred_msg(ra.keys, "type of keys "), b, ra.keys elseif ra is MapType then - if is_a(b, ra.keys) then + if self:is_a(b, ra.keys) then return ra.values end errm, erra, errb = "wrong index type: got %s, expected %s", b, ra.keys elseif rb is StringType and rb.literal then - local t, e = match_record_key(a, anode, rb.literal) + local t, e = self:match_record_key(a, anode, rb.literal) if t then return t end @@ -9363,10 +9386,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end if not errm then - return match_all_record_field_names(bnode, ra, field_names, + return self:match_all_record_field_names(bnode, ra, field_names, "cannot index, not all enum values map to record fields of the same type") end - elseif is_a(rb, STRING) then + elseif rb is StringType then errm, erra = "cannot index object of type %s with a string, consider using an enum", a else errm, erra, errb = "cannot index object of type %s with %s", a, b @@ -9375,28 +9398,28 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string errm, erra, errb = "cannot index object of type %s with %s", a, b end - local meta_t = check_metamethod(anode, "__index", ra, b, a, b) + local meta_t = self:check_metamethod(anode, "__index", ra, b, a, b) if meta_t then return meta_t end - return invalid_at(bnode, errm, erra, errb) + return self.errs:invalid_at(bnode, errm, erra, errb) end - expand_type = function(where: Where, old: Type, new: Type): Type + function TypeChecker:expand_type(w: Where, old: Type, new: Type): Type if not old or old.typename == "nil" then return new else - if not is_a(new, old) then + if not self:is_a(new, old) then if old is MapType and new is RecordLikeType then local old_keys = old.keys if old_keys is StringType then for _, ftype in fields_of(new) do - old.values = expand_type(where, old.values, ftype) + old.values = self:expand_type(w, old.values, ftype) end - edit_type(old, "map") -- map changed, refresh typeid + edit_type(w, old, "map") -- map changed, refresh typeid else - error_at(where, "cannot determine table literal type") + self.errs:add(w, "cannot determine table literal type") end elseif old is RecordLikeType and new is RecordLikeType then local values: Type @@ -9404,14 +9427,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not values then values = ftype else - values = expand_type(where, values, ftype) + values = self:expand_type(w, values, ftype) end end for _, ftype in fields_of(new) do if not values then values = ftype else - values = expand_type(where, values, ftype) + values = self:expand_type(w, values, ftype) end end old.fields = nil @@ -9419,25 +9442,25 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string old.meta_fields = nil old.meta_fields = nil - edit_type(old, "map") + edit_type(w, old, "map") local map = old as MapType - map.keys = STRING + map.keys = a_type(w, "string", {}) map.values = values elseif old is UnionType then - edit_type(old, "union") + edit_type(w, old, "union") table.insert(old.types, drop_constant_value(new)) else - return unite({ old, new }, true) + return unite(w, { old, new }, true) end end end return old end - local function find_record_to_extend(exp: Node): Type, Variable, string + function TypeChecker:find_record_to_extend(exp: Node): Type, Variable, string -- base if exp.kind == "type_identifier" then - local v = find_var(exp.tk) + local v = self:find_var(exp.tk) if not v then return nil, nil, exp.tk end @@ -9454,7 +9477,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t, v, exp.tk -- recurse elseif exp.kind == "op" then -- assert(exp.op.op == ".") - local t, v, rname = find_record_to_extend(exp.e1) + local t, v, rname = self:find_record_to_extend(exp.e1) local fname = exp.e2.tk local dname = rname .. "." .. fname if not t then @@ -9475,30 +9498,29 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function typedecl_to_nominal(where: Where, name: string, t: TypeDeclType, resolved?: Type): Type + local function typedecl_to_nominal(node: Node, name: string, t: TypeDeclType, resolved?: Type): Type local typevals: {Type} local def = t.def if def is HasTypeArgs then typevals = {} for _, a in ipairs(def.typeargs) do - table.insert(typevals, a_type("typevar", { + table.insert(typevals, a_type(a, "typevar", { typevar = a.typearg, constraint = a.constraint, } as TypeVarType)) end end - return type_at(where, a_type("nominal", { - typevals = typevals, - names = { name }, - found = t, - resolved = resolved, - } as NominalType)) + local nom = a_nominal(node, { name }) + nom.typevals = typevals + nom.found = t + nom.resolved = resolved + return nom end - local function get_self_type(exp: Node): Type + function TypeChecker:get_self_type(exp: Node): Type -- base if exp.kind == "type_identifier" then - local t = find_var_type(exp.tk) + local t = self:find_var_type(exp.tk) if not t then return nil end @@ -9510,7 +9532,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end -- recurse elseif exp.kind == "op" then -- assert(exp.op.op == ".") - local t = get_self_type(exp.e1) + local t = self:get_self_type(exp.e1) if not t then return nil end @@ -9539,10 +9561,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end -- Inference engine for 'is' operator - local facts_and: function(where: Where, f1: Fact, f2: Fact): Fact - local facts_or: function(where: Where, f1: Fact, f2: Fact): Fact - local facts_not: function(where: Where, f1: Fact): Fact - local apply_facts: function(where: Where, known: Fact) + local facts_and: function(w: Where, f1: Fact, f2: Fact): Fact + local facts_or: function(w: Where, f1: Fact, f2: Fact): Fact + local facts_not: function(w: Where, f1: Fact): Fact local FACT_TRUTHY: Fact do local IsFact_mt: metatable = { @@ -9554,6 +9575,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string setmetatable(IsFact, { __call = function(_: IsFact, fact: Fact): IsFact fact.fact = "is" + assert(fact.w) return setmetatable(fact as IsFact, IsFact_mt) end, }) @@ -9567,6 +9589,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string setmetatable(EqFact, { __call = function(_: EqFact, fact: Fact): EqFact fact.fact = "==" + assert(fact.w) return setmetatable(fact as EqFact, EqFact_mt) end, }) @@ -9625,57 +9648,57 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string FACT_TRUTHY = TruthyFact {} - facts_and = function(where: Where, f1: Fact, f2: Fact): Fact - return AndFact { f1 = f1, f2 = f2, where = where } + facts_and = function(w: Where, f1: Fact, f2: Fact): Fact + return AndFact { f1 = f1, f2 = f2, w = w } end - facts_or = function(where: Where, f1: Fact, f2: Fact): Fact + facts_or = function(w: Where, f1: Fact, f2: Fact): Fact if f1 and f2 then - return OrFact { f1 = f1, f2 = f2, where = where } + return OrFact { f1 = f1, f2 = f2, w = w } else return nil end end - facts_not = function(where: Where, f1: Fact): Fact + facts_not = function(w: Where, f1: Fact): Fact if f1 then - return NotFact { f1 = f1, where = where } + return NotFact { f1 = f1, w = w } else return nil end end -- t1 ∪ t2 - local function unite_types(t1: Type, t2: Type): Type, string - return unite({t2, t1}) + local function unite_types(w: Where, t1: Type, t2: Type): Type, string + return unite(w, {t2, t1}) end -- t1 ∩ t2 - local function intersect_types(t1: Type, t2: Type): Type, string + local function intersect_types(self: TypeChecker, w: Where, t1: Type, t2: Type): Type, string if t2 is UnionType then t1, t2 = t2, t1 end if t1 is UnionType then local out = {} for _, t in ipairs(t1.types) do - if is_a(t, t2) then + if self:is_a(t, t2) then table.insert(out, t) end end - return unite(out) + return unite(w, out) else - if is_a(t1, t2) then + if self:is_a(t1, t2) then return t1 - elseif is_a(t2, t1) then + elseif self:is_a(t2, t1) then return t2 else - return NIL -- because of implicit nil in all unions + return a_type(w, "nil", {}) -- because of implicit nil in all unions end end end - local function resolve_if_union(t: Type): Type - local rt = to_structural(t) + function TypeChecker:resolve_if_union(t: Type): Type + local rt = self:to_structural(t) if rt is UnionType then return rt end @@ -9683,23 +9706,23 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end -- t1 - t2 - local function subtract_types(t1: Type, t2: Type): Type + local function subtract_types(self: TypeChecker, w: Where, t1: Type, t2: Type): Type local types: {Type} = {} - t1 = resolve_if_union(t1) + t1 = self:resolve_if_union(t1) -- poly are not first-class, so we don't handle them here if not t1 is UnionType then return t1 end - t2 = resolve_if_union(t2) + t2 = self:resolve_if_union(t2) local t2types = t2 is UnionType and t2.types or { t2 } for _, at in ipairs(t1.types) do local not_present = true for _, bt in ipairs(t2types) do - if same_type(at, bt) then + if self:same_type(at, bt) then not_present = false break end @@ -9710,78 +9733,78 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if #types == 0 then - return NIL -- because of implicit nil in all unions + return a_type(w, "nil", {}) -- because of implicit nil in all unions end - return unite(types) + return unite(w, types) end - local eval_not: function(f: Fact): {string:IsFact|EqFact} - local not_facts: function(fs: {string:IsFact|EqFact}): {string:IsFact|EqFact} - local or_facts: function(fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} - local and_facts: function(fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} - local eval_fact: function(f: Fact): {string:IsFact|EqFact} + local eval_not: function(TypeChecker, f: Fact): {string:IsFact|EqFact} + local not_facts: function(TypeChecker, fs: {string:IsFact|EqFact}): {string:IsFact|EqFact} + local or_facts: function(TypeChecker, fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} + local and_facts: function(TypeChecker, fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} + local eval_fact: function(TypeChecker, f: Fact): {string:IsFact|EqFact} local function invalid_from(f: IsFact): IsFact - return IsFact { fact = "is", var = f.var, typ = INVALID, where = f.where } + return IsFact { fact = "is", var = f.var, typ = a_type(f.w, "invalid", {}), w = f.w } end - not_facts = function(fs: {string:IsFact|EqFact}): {string:IsFact|EqFact} + not_facts = function(self: TypeChecker, fs: {string:IsFact|EqFact}): {string:IsFact|EqFact} local ret: {string:IsFact|EqFact} = {} for var, f in pairs(fs) do - local typ = find_var_type(f.var, "check_only") + local typ = self:find_var_type(f.var, "check_only") if not typ then - ret[var] = EqFact { var = var, typ = INVALID, where = f.where } + ret[var] = EqFact { var = var, typ = an_invalid(f.w), w = f.w, no_infer = f.no_infer } elseif f is EqFact then -- nothing is known from negation of equality; widen back - ret[var] = EqFact { var = var, typ = typ } - elseif typ.typename == "typevar" then + ret[var] = EqFact { var = var, typ = typ, w = f.w, no_infer = true } + elseif typ is TypeVarType then assert(f.fact == "is") - -- nothing is known from negation on typeargs; widen back (no 'where') - ret[var] = EqFact { var = var, typ = typ } - elseif not is_a(f.typ, typ) then + -- nothing is known from negation on typeargs; widen back + ret[var] = EqFact { var = var, typ = typ, w = f.w, no_infer = true } + elseif not self:is_a(f.typ, typ) then assert(f.fact == "is") - add_warning("branch", f.where, f.var .. " (of type %s) can never be a %s", show_type(typ), show_type(f.typ)) - ret[var] = EqFact { var = var, typ = INVALID, where = f.where } + self.errs:add_warning("branch", f.w, f.var .. " (of type %s) can never be a %s", show_type(typ), show_type(f.typ)) + ret[var] = EqFact { var = var, typ = an_invalid(f.w), w = f.w, no_infer = f.no_infer } else assert(f.fact == "is") - ret[var] = IsFact { var = var, typ = subtract_types(typ, f.typ), where = f.where } + ret[var] = IsFact { var = var, typ = subtract_types(self, f.w, typ, f.typ), w = f.w, no_infer = f.no_infer } end end return ret end - eval_not = function(f: Fact): {string:IsFact|EqFact} + eval_not = function(self: TypeChecker, f: Fact): {string:IsFact|EqFact} if not f then return {} elseif f is IsFact then - return not_facts({[f.var] = f}) + return not_facts(self, {[f.var] = f}) elseif f is NotFact then - return eval_fact(f.f1) + return eval_fact(self, f.f1) elseif f is AndFact and f.f2 and f.f2.fact == "truthy" then - return eval_not(f.f1) + return eval_not(self, f.f1) elseif f is OrFact and f.f2 and f.f2.fact == "truthy" then - return eval_fact(f.f1) + return eval_fact(self, f.f1) elseif f is AndFact then - return or_facts(not_facts(eval_fact(f.f1)), not_facts(eval_fact(f.f2))) + return or_facts(self, not_facts(self, eval_fact(self, f.f1)), not_facts(self, eval_fact(self, f.f2))) elseif f is OrFact then - return and_facts(not_facts(eval_fact(f.f1)), not_facts(eval_fact(f.f2))) + return and_facts(self, not_facts(self, eval_fact(self, f.f1)), not_facts(self, eval_fact(self, f.f2))) else - return not_facts(eval_fact(f)) + return not_facts(self, eval_fact(self, f)) end end - or_facts = function(fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} + or_facts = function(_self: TypeChecker, fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} local ret: {string:IsFact|EqFact} = {} for var, f in pairs(fs2) do if fs1[var] then - local united = unite_types(f.typ, fs1[var].typ) + local united = unite_types(f.w, f.typ, fs1[var].typ) if fs1[var].fact == "is" and f.fact == "is" then - ret[var] = IsFact { var = var, typ = united, where = f.where } + ret[var] = IsFact { var = var, typ = united, w = f.w } else - ret[var] = EqFact { var = var, typ = united, where = f.where } + ret[var] = EqFact { var = var, typ = united, w = f.w } end end end @@ -9789,7 +9812,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return ret end - and_facts = function(fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} + and_facts = function(self: TypeChecker, fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} local ret: {string:IsFact|EqFact} = {} local has: {FactType:boolean} = {} @@ -9800,18 +9823,18 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if fs2[var].fact == "is" and f.fact == "is" then ctor = IsFact end - rt = intersect_types(f.typ, fs2[var].typ) + rt = intersect_types(self, f.w, f.typ, fs2[var].typ) else rt = f.typ end - local ff = ctor { var = var, typ = rt, where = f.where } + local ff = ctor { var = var, typ = rt, w = f.w, no_infer = f.no_infer } ret[var] = ff has[ff.fact] = true end for var, f in pairs(fs2) do if not fs1[var] then - ret[var] = EqFact { var = var, typ = f.typ, where = f.where } + ret[var] = EqFact { var = var, typ = f.typ, w = f.w, no_infer = f.no_infer } has["=="] = true end end @@ -9825,21 +9848,21 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return ret end - eval_fact = function(f: Fact): {string:IsFact|EqFact} + eval_fact = function(self: TypeChecker, f: Fact): {string:IsFact|EqFact} if not f then return {} elseif f is IsFact then - local typ = find_var_type(f.var, "check_only") + local typ = self:find_var_type(f.var, "check_only") if not typ then return { [f.var] = invalid_from(f) } end if typ.typename ~= "typevar" then - if is_a(typ, f.typ) then + if self:is_a(typ, f.typ) then -- drop this warning because of implicit nil in all unions - -- add_warning("branch", f.where, f.var .. " (of type %s) is always a %s", show_type(typ), show_type(f.typ)) + -- self.errs:add_warning("branch", f.w, f.var .. " (of type %s) is always a %s", show_type(typ), show_type(f.typ)) return { [f.var] = f } - elseif not is_a(f.typ, typ) then - error_at(f.where, f.var .. " (of type %s) can never be a %s", typ, f.typ) + elseif not self:is_a(f.typ, typ) then + self.errs:add(f.w, f.var .. " (of type %s) can never be a %s", typ, f.typ) return { [f.var] = invalid_from(f) } end end @@ -9847,63 +9870,60 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string elseif f is EqFact then return { [f.var] = f } elseif f is NotFact then - return eval_not(f.f1) + return eval_not(self, f.f1) elseif f is TruthyFact then return {} elseif f is AndFact and f.f2 and f.f2.fact == "truthy" then - return eval_fact(f.f1) + return eval_fact(self, f.f1) elseif f is OrFact and f.f2 and f.f2.fact == "truthy" then - return eval_not(f.f1) + return eval_not(self, f.f1) elseif f is AndFact then - return and_facts(eval_fact(f.f1), eval_fact(f.f2)) + return and_facts(self, eval_fact(self, f.f1), eval_fact(self, f.f2)) elseif f is OrFact then - return or_facts(eval_fact(f.f1), eval_fact(f.f2)) + return or_facts(self, eval_fact(self, f.f1), eval_fact(self, f.f2)) end end - apply_facts = function(where: Where, known: Fact) + function TypeChecker:apply_facts(w: Where, known: Fact) if not known then return end - local facts = eval_fact(known) + local facts = eval_fact(self, known) for v, f in pairs(facts) do if f.typ.typename == "invalid" then - error_at(where, "cannot resolve a type for " .. v .. " here") + self.errs:add(w, "cannot resolve a type for " .. v .. " here") end - local t = infer_at(where, f.typ) - if not f.where then + local t = f.no_infer and f.typ or self:infer_at(w, f.typ) + if f.no_infer then t.inferred_at = nil end - add_var(nil, v, t, "const", "narrow") + self:add_var(nil, v, t, "const", "narrow") end end end - local function dismiss_unresolved(name: string) - for i = #st, 1, -1 do - local unresolved = find_unresolved(i) - if unresolved then - local uses = unresolved.nominals[name] - if uses then - for _, t in ipairs(uses) do - resolve_nominal(t) - end - unresolved.nominals[name] = nil - return + function TypeChecker:dismiss_unresolved(name: string) + for i = #self.st, 1, -1 do + local scope = self.st[i] + local uses = scope.pending_nominals and scope.pending_nominals[name] + if uses then + for _, t in ipairs(uses) do + self:resolve_nominal(t) end + scope.pending_nominals[name] = nil + return end end end - local type_check_funcall: function(node: Node, a: Type, b: Type, argdelta?: integer): InvalidOrTupleType - - local function special_pcall_xpcall(node: Node, _a: Type, b: TupleType, argdelta: integer): Type + local function special_pcall_xpcall(self: TypeChecker, node: Node, _a: Type, b: TupleType, argdelta: integer): Type local base_nargs = (node.e1.tk == "xpcall") and 2 or 1 + local bool = a_type(node, "boolean", {}) if #node.e2 < base_nargs then - error_at(node, "wrong number of arguments (given " .. #node.e2 .. ", expects at least " .. base_nargs .. ")") - return a_tuple { BOOLEAN } + self.errs:add(node, "wrong number of arguments (given " .. #node.e2 .. ", expects at least " .. base_nargs .. ")") + return a_tuple(node, { bool }) end -- The function called by pcall/xpcall is invoked as a regular function, @@ -9915,137 +9935,142 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ftype.is_method = false end - local fe2: Node = {} + local fe2: Node = node_at(node.e2, {}) if node.e1.tk == "xpcall" then base_nargs = 2 + local arg2 = node.e2[2] local msgh = table.remove(b.tuple, 1) - assert_is_a(node.e2[2], msgh, XPCALL_MSGH_FUNCTION, "in message handler") + local msgh_type = a_function(arg2, { + min_arity = 1, + args = a_tuple(arg2, { a_type(arg2, "any", {}) }), + rets = a_tuple(arg2, {}) + }) + self:assert_is_a(arg2, msgh, msgh_type, "in message handler") end for i = base_nargs + 1, #node.e2 do table.insert(fe2, node.e2[i]) end - local fnode: Node = { - y = node.y, - x = node.x, + local fnode: Node = node_at(node, { kind = "op", op = { op = "@funcall" }, e1 = node.e2[1], e2 = fe2, - } - local rets = type_check_funcall(fnode, ftype, b, argdelta + base_nargs) + }) + local rets = self:type_check_funcall(fnode, ftype, b, argdelta + base_nargs) if rets is InvalidType then return rets end - table.insert(rets.tuple, 1, BOOLEAN) + table.insert(rets.tuple, 1, bool) return rets end - local special_functions: {string : function(Node,Type,TupleType,integer):InvalidOrTupleType } = { - ["pairs"] = function(node: Node, a: Type, b: TupleType, argdelta: integer): InvalidOrTupleType + local special_functions: {string : function(TypeChecker, Node,Type,TupleType,integer):InvalidOrTupleType } = { + ["pairs"] = function(self: TypeChecker, node: Node, a: Type, b: TupleType, argdelta: integer): InvalidOrTupleType if not b.tuple[1] then - return invalid_at(node, "pairs requires an argument") + return self.errs:invalid_at(node, "pairs requires an argument") end - local t = to_structural(b.tuple[1]) + local t = self:to_structural(b.tuple[1]) if t is ArrayLikeType then - add_warning("hint", node, "hint: applying pairs on an array: did you intend to apply ipairs?") + self.errs:add_warning("hint", node, "hint: applying pairs on an array: did you intend to apply ipairs?") end if t.typename ~= "map" then - if not (lax and is_unknown(t)) then + if not (self.feat_lax and is_unknown(t)) then if t is RecordLikeType then - match_all_record_field_names(node.e2, t, t.field_order, + self:match_all_record_field_names(node.e2, t, t.field_order, "attempting pairs on a record with attributes of different types") local ct = t.typename == "record" and "{string:any}" or "{any:any}" - add_warning("hint", node.e2, "hint: if you want to iterate over fields of a record, cast it to " .. ct) + self.errs:add_warning("hint", node.e2, "hint: if you want to iterate over fields of a record, cast it to " .. ct) else - error_at(node.e2, "cannot apply pairs on values of type: %s", t) + self.errs:add(node.e2, "cannot apply pairs on values of type: %s", t) end end end - return (type_check_function_call(node, a, b, argdelta)) + return (self:type_check_function_call(node, a, b, argdelta)) end, - ["ipairs"] = function(node: Node, a: Type, b: TupleType, argdelta: integer): InvalidOrTupleType + ["ipairs"] = function(self: TypeChecker, node: Node, a: Type, b: TupleType, argdelta: integer): InvalidOrTupleType if not b.tuple[1] then - return invalid_at(node, "ipairs requires an argument") + return self.errs:invalid_at(node, "ipairs requires an argument") end local orig_t = b.tuple[1] - local t = to_structural(orig_t) + local t = self:to_structural(orig_t) if t is TupleTableType then - local arr_type = arraytype_from_tuple(node.e2, t) + local arr_type = self:arraytype_from_tuple(node.e2, t) if not arr_type then - return invalid_at(node.e2, "attempting ipairs on tuple that's not a valid array: %s", orig_t) + return self.errs:invalid_at(node.e2, "attempting ipairs on tuple that's not a valid array: %s", orig_t) end elseif not t is ArrayLikeType then - if not (lax and (is_unknown(t) or t is EmptyTableType)) then - return invalid_at(node.e2, "attempting ipairs on something that's not an array: %s", orig_t) + if not (self.feat_lax and (is_unknown(t) or t is EmptyTableType)) then + return self.errs:invalid_at(node.e2, "attempting ipairs on something that's not an array: %s", orig_t) end end - return (type_check_function_call(node, a, b, argdelta)) + return (self:type_check_function_call(node, a, b, argdelta)) end, - ["rawget"] = function(node: Node, _a: Type, b: TupleType, _argdelta: integer): InvalidOrTupleType + ["rawget"] = function(self: TypeChecker, node: Node, _a: Type, b: TupleType, _argdelta: integer): InvalidOrTupleType -- TODO should those offsets be fixed by _argdelta? if #b.tuple == 2 then - return a_tuple({ type_check_index(node.e2[1], node.e2[2], b.tuple[1], b.tuple[2]) }) + return a_tuple(node, { self:type_check_index(node.e2[1], node.e2[2], b.tuple[1], b.tuple[2]) }) else - return invalid_at(node, "rawget expects two arguments") + return self.errs:invalid_at(node, "rawget expects two arguments") end end, - ["require"] = function(node: Node, _a: Type, b: TupleType, _argdelta: integer): InvalidOrTupleType + ["require"] = function(self: TypeChecker, node: Node, _a: Type, b: TupleType, _argdelta: integer): InvalidOrTupleType if #b.tuple ~= 1 then - return invalid_at(node, "require expects one literal argument") + return self.errs:invalid_at(node, "require expects one literal argument") end if node.e2[1].kind ~= "string" then - return invalid_at(node, "don't know how to resolve a dynamic require") + return self.errs:invalid_at(node, "don't know how to resolve a dynamic require") end local module_name = assert(node.e2[1].conststr) - local t, found = require_module(module_name, lax, env) - if not found then - return invalid_at(node, "module not found: '" .. module_name .. "'") - end + local t, module_filename = require_module(node, module_name, self.feat_lax, self.env) if t.typename == "invalid" then - if lax then - return a_tuple({ UNKNOWN }) + if not module_filename then + return self.errs:invalid_at(node, "module not found: '" .. module_name .. "'") + end + + if self.feat_lax then + return a_tuple(node, { an_unknown(node) }) end - return invalid_at(node, "no type information for required module: '" .. module_name .. "'") + return self.errs:invalid_at(node, "no type information for required module: '" .. module_name .. "'") end - dependencies[module_name] = t.filename - return type_at(node, a_tuple({ t })) + self.dependencies[module_name] = module_filename + return a_tuple(node, { t }) end, ["pcall"] = special_pcall_xpcall, ["xpcall"] = special_pcall_xpcall, - ["assert"] = function(node: Node, a: Type, b: TupleType, argdelta: integer): InvalidOrTupleType + ["assert"] = function(self: TypeChecker, node: Node, a: Type, b: TupleType, argdelta: integer): InvalidOrTupleType node.known = FACT_TRUTHY - local r = type_check_function_call(node, a, b, argdelta) - apply_facts(node, node.e2[1].known) + local r = self:type_check_function_call(node, a, b, argdelta) + self:apply_facts(node, node.e2[1].known) return r end, } - type_check_funcall = function(node: Node, a: Type, b: TupleType, argdelta?: integer): InvalidOrTupleType + function TypeChecker:type_check_funcall(node: Node, a: Type, b: TupleType, argdelta?: integer): InvalidOrTupleType argdelta = argdelta or 0 if node.e1.kind == "variable" then local special = special_functions[node.e1.tk] if special then - return special(node, a, b, argdelta) + return special(self, node, a, b, argdelta) else - return (type_check_function_call(node, a, b, argdelta)) + return (self:type_check_function_call(node, a, b, argdelta)) end elseif node.e1.op and node.e1.op.op == ":" then table.insert(b.tuple, 1, node.e1.receiver) - return (type_check_function_call(node, a, b, -1)) + return (self:type_check_function_call(node, a, b, -1)) else - return (type_check_function_call(node, a, b, argdelta)) + return (self:type_check_function_call(node, a, b, argdelta)) end end @@ -10057,19 +10082,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string and node.exps[i].tk == node.vars[i].tk end - local function missing_initializer(node: Node, i: integer, name: string): Type - if lax then - return UNKNOWN + function TypeChecker:missing_initializer(node: Node, i: integer, name: string): (InvalidType | UnknownType) + if self.feat_lax then + return an_unknown(node) else if node.exps then - return invalid_at(node.vars[i], "assignment in declaration did not produce an initial value for variable '" .. name .. "'") + return self.errs:invalid_at(node.vars[i], "assignment in declaration did not produce an initial value for variable '" .. name .. "'") else - return invalid_at(node.vars[i], "variable '" .. name .. "' has no type or initial value") + return self.errs:invalid_at(node.vars[i], "variable '" .. name .. "' has no type or initial value") end end end - local function set_expected_types_to_decltuple(node: Node, children: {Type}) + local function set_expected_types_to_decltuple(_: TypeChecker, node: Node, children: {Type}) local decltuple = node.kind == "assignment" and children[1] or node.decltuple assert(decltuple is TupleType) local decls = decltuple.tuple @@ -10081,7 +10106,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string typ = decls[i] if typ then if i == nexps and ndecl > nexps then - typ = type_at(node, a_tuple {}) + typ = a_tuple(node, {}) for a = i, ndecl do table.insert(typ.tuple, decls[a]) end @@ -10097,38 +10122,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return n and n >= 1 and math.floor(n) == n end - local context_name: {NodeKind: string} = { - ["local_declaration"] = "in local declaration", - ["global_declaration"] = "in global declaration", - ["assignment"] = "in assignment", - } - - local function in_context(ctx: Node.ExpectedContext, msg: string): string - if not ctx then - return msg - end - local where = context_name[ctx.kind] - if where then - return where .. ": " .. (ctx.name and ctx.name .. ": " or "") .. msg - else - return msg - end - end - - local type CheckableKey = string | number | boolean - - local function check_redeclared_key(where: Where, ctx: Node.ExpectedContext, seen_keys: {CheckableKey:Where}, key: CheckableKey) - if key ~= nil then - local s = seen_keys[key] - if s then - error_at(where, in_context(ctx, "redeclared key " .. tostring(key) .. " (previously declared at " .. filename .. ":" .. s.y .. ":" .. s.x .. ")")) - else - seen_keys[key] = where - end - end - end - - local function infer_table_literal(node: Node, children: {LiteralTableItemType}): Type + local function infer_table_literal(self: TypeChecker, node: Node, children: {LiteralTableItemType}): Type local is_record = false local is_array = false local is_map = false @@ -10153,14 +10147,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string for i, child in ipairs(children) do local ck = child.kname + local cktype = child.ktype local n = node[i].key.constnum local b: boolean = nil - if child.ktype.typename == "boolean" then + if cktype is BooleanType then b = (node[i].key.tk == "true") end local key: CheckableKey = ck or n or b - check_redeclared_key(node[i], nil, seen_keys, key) + self.errs:check_redeclared_key(node[i], nil, seen_keys, key) local uvtype = resolve_tuple(child.vtype) if ck then @@ -10171,7 +10166,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end fields[ck] = uvtype table.insert(field_order, ck) - elseif is_number_type(child.ktype) then + elseif cktype is NumericType then is_array = true if not is_not_tuple then is_tuple = true @@ -10185,25 +10180,25 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if i == #children and cv is TupleType then -- need to expand last item in an array (e.g { 1, 2, 3, f() }) for _, c in ipairs(cv.tuple) do - elements = expand_type(node, elements, c) + elements = self:expand_type(node, elements, c) types[last_array_idx] = resolve_tuple(c) last_array_idx = last_array_idx + 1 end else types[last_array_idx] = uvtype last_array_idx = last_array_idx + 1 - elements = expand_type(node, elements, uvtype) + elements = self:expand_type(node, elements, uvtype) end else -- explicit if not is_positive_int(n) then - elements = expand_type(node, elements, uvtype) + elements = self:expand_type(node, elements, uvtype) is_not_tuple = true elseif n then types[n as integer] = uvtype if n > largest_array_idx then largest_array_idx = n as integer end - elements = expand_type(node, elements, uvtype) + elements = self:expand_type(node, elements, uvtype) end end @@ -10215,37 +10210,37 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end else is_map = true - keys = expand_type(node, keys, drop_constant_value(child.ktype)) - values = expand_type(node, values, uvtype) + keys = self:expand_type(node, keys, drop_constant_value(cktype)) + values = self:expand_type(node, values, uvtype) end end local t: Type if is_array and is_map then - error_at(node, "cannot determine type of table literal") - t = a_map( - expand_type(node, keys, INTEGER), - expand_type(node, values, elements) + self.errs:add(node, "cannot determine type of table literal") + t = a_map(node, + self:expand_type(node, keys, a_type(node, "integer", {})), + self:expand_type(node, values, elements) ) elseif is_record and is_array then - t = a_type("record", { + t = a_type(node, "record", { fields = fields, field_order = field_order, elements = elements, interface_list = { - type_at(node, an_array(elements)) + an_array(node, elements) } } as RecordType) - -- TODO adopt logic from is_array below when we accept tupletable as an interface + -- TODO adopt logic from self:is_array below when we accept tupletable as an interface elseif is_record and is_map then if keys is StringType then for _, fname in ipairs(field_order) do - values = expand_type(node, values, fields[fname]) + values = self:expand_type(node, values, fields[fname]) end - t = a_map(keys, values) + t = a_map(node, keys, values) else - error_at(node, "cannot determine type of table literal") + self.errs:add(node, "cannot determine type of table literal") end elseif is_array then local pure_array = true @@ -10253,7 +10248,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local last_t: Type for _, current_t in pairs(types as {integer:Type}) do if last_t then - if not same_type(last_t, current_t) then + if not self:same_type(last_t, current_t) then pure_array = false break end @@ -10262,69 +10257,70 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end if pure_array then - t = an_array(elements) + t = an_array(node, elements) t.consttypes = types t.inferred_len = largest_array_idx - 1 else - t = a_type("tupletable", {}) as TupleTableType + t = a_type(node, "tupletable", { inferred_at = node }) as TupleTableType t.types = types end elseif is_record then - t = a_type("record", { + t = a_type(node, "record", { fields = fields, field_order = field_order, } as RecordType) elseif is_map then - t = a_map(keys, values) + t = a_map(node, keys, values) elseif is_tuple then - t = a_type("tupletable", {}) as TupleTableType + t = a_type(node, "tupletable", { inferred_at = node }) as TupleTableType t.types = types if not types or #types == 0 then - error_at(node, "cannot determine type of tuple elements") + self.errs:add(node, "cannot determine type of tuple elements") end end if not t then - t = a_type("emptytable", {}) + t = a_type(node, "emptytable", {}) end return type_at(node, t) end - local function infer_negation_of_if_blocks(where: Where, ifnode: Node, n: integer) - local f = facts_not(where, ifnode.if_blocks[1].exp.known) + function TypeChecker:infer_negation_of_if_blocks(w: Where, ifnode: Node, n: integer) + local f = facts_not(w, ifnode.if_blocks[1].exp.known) for e = 2, n do local b = ifnode.if_blocks[e] if b.exp then - f = facts_and(where, f, facts_not(where, b.exp.known)) + f = facts_and(w, f, facts_not(w, b.exp.known)) end end - apply_facts(where, f) + self:apply_facts(w, f) end - local function determine_declaration_type(var: Node, node: Node, infertypes: TupleType, i: integer): boolean, Type, boolean + function TypeChecker:determine_declaration_type(var: Node, node: Node, infertypes: TupleType, i: integer): boolean, Type, boolean local ok = true local name = var.tk local infertype = infertypes and infertypes.tuple[i] - if lax and infertype and infertype.typename == "nil" then + if self.feat_lax and infertype and infertype.typename == "nil" then infertype = nil end local decltype = node.decltuple and node.decltuple.tuple[i] if decltype then - if to_structural(decltype) == INVALID then - decltype = INVALID + local rdecltype = self:to_structural(decltype) + if rdecltype is InvalidType then + decltype = rdecltype end if infertype then - ok = assert_is_a(node.vars[i], infertype, decltype, context_name[node.kind], name) + local w = node.exps and node.exps[i] or node.vars[i] + ok = self:assert_is_a(w, infertype, decltype, context_name[node.kind], name) end else if infertype then if infertype is UnresolvableTypeArgType then - error_at(node.vars[i], "cannot infer declaration type; an explicit type annotation is necessary") ok = false - infertype = INVALID + infertype = self.errs:invalid_at(node.vars[i], "cannot infer declaration type; an explicit type annotation is necessary") elseif infertype is FunctionType and infertype.is_method then -- If we assign a method to a variable, e.g: -- `local myfunc = myobj.dothing`, @@ -10336,17 +10332,17 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if var.attribute == "total" then - local rd = decltype and to_structural(decltype) + local rd = decltype and self:to_structural(decltype) if rd and (rd.typename ~= "map" and rd.typename ~= "record") then - error_at(var, "attribute only applies to maps and records") + self.errs:add(var, "attribute only applies to maps and records") ok = false elseif not infertype then - error_at(var, "variable declared does not declare an initialization value") + self.errs:add(var, "variable declared does not declare an initialization value") ok = false else local valnode = node.exps[i] if not valnode or valnode.kind ~= "literal_table" then - error_at(var, "attribute only applies to literal tables") + self.errs:add(var, "attribute only applies to literal tables") ok = false else if not valnode.is_total then @@ -10354,12 +10350,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if valnode.missing then missing = " (missing: " .. table.concat(valnode.missing, ", ") .. ")" end - local ri = to_structural(infertype) + local ri = self:to_structural(infertype) if ri is MapType then - error_at(var, "map variable declared does not declare values for all possible keys" .. missing) + self.errs:add(var, "map variable declared does not declare values for all possible keys" .. missing) ok = false elseif ri is RecordType then - error_at(var, "record variable declared does not declare values for all fields" .. missing) + self.errs:add(var, "record variable declared does not declare values for all fields" .. missing) ok = false end end @@ -10369,34 +10365,36 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local t = decltype or infertype if t == nil then - t = missing_initializer(node, i, name) + t = self:missing_initializer(node, i, name) elseif t is EmptyTableType then t.declared_at = node t.assigned_to = name elseif t is ArrayLikeType then t.inferred_len = nil + elseif t is NominalType then + self:resolve_nominal(t) end return ok, t, infertype ~= nil end - local function get_typedecl(value: Node): TypeDeclType, Variable + function TypeChecker:get_typedecl(value: Node): TypeDeclType, Variable if value.kind == "op" and value.op.op == "@funcall" and value.e1.kind == "variable" and value.e1.tk == "require" then - local t = special_functions["require"](value, find_var_type("require"), a_tuple { STRING }, 0) + local t = special_functions["require"](self, value, self:find_var_type("require"), a_tuple(value.e2, { a_type(value.e2[1], "string", {}) }), 0) local ty = t is TupleType and t.tuple[1] or t - ty = (ty is TypeAliasType) and resolve_typealias(ty) or ty - local td = (ty is TypeDeclType) and ty or a_type("typedecl", { def = ty } as TypeDeclType) + ty = (ty is TypeAliasType) and self:resolve_typealias(ty) or ty + local td = (ty is TypeDeclType) and ty or a_type(value, "typedecl", { def = ty } as TypeDeclType) return td else local newtype = value.newtype if newtype is TypeAliasType then - local aliasing = find_var(newtype.alias_to.names[1], "use_type") - return resolve_typealias(newtype), aliasing - else + local aliasing = self:find_var(newtype.alias_to.names[1], "use_type") + return self:resolve_typealias(newtype), aliasing + elseif newtype is TypeDeclType then return newtype, nil end end @@ -10427,15 +10425,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return is_total, missing end - local function total_map_check(t: MapType, seen_keys: {CheckableKey:Where}): boolean, {string} - local k = to_structural(t.keys) + local function total_map_check(keys: Type, seen_keys: {CheckableKey:Where}): boolean, {string} local is_total = true local missing: {string} - if k is EnumType then - for _, key in ipairs(sorted_keys(k.enumset)) do + if keys is EnumType then + for _, key in ipairs(sorted_keys(keys.enumset)) do is_total, missing = total_check_key(key, seen_keys, is_total, missing) end - elseif k.typename == "boolean" then + elseif keys.typename == "boolean" then for _, key in ipairs({ true, false }) do is_total, missing = total_check_key(key, seen_keys, is_total, missing) end @@ -10449,35 +10446,38 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string "missing" end - local function check_assignment(where: Where, vartype: Type, valtype: Type, varname: string, attr: Attribute): Type, Type, MissingError + function TypeChecker:check_assignment(varnode: Node, vartype: Type, valtype: Type): Type, Type, MissingError + local varname = varnode.tk + local attr = varnode.attribute + if varname then - if widen_back_var(varname) then - vartype, attr = find_var_type(varname) + if self:widen_back_var(varname) then + vartype, attr = self:find_var_type(varname) if not vartype then - error_at(where, "unknown variable") + self.errs:add(varnode, "unknown variable") return nil end end end if attr == "close" or attr == "const" or attr == "total" then - error_at(where, "cannot assign to <" .. attr .. "> variable") + self.errs:add(varnode, "cannot assign to <" .. attr .. "> variable") return nil end - local var = to_structural(vartype) + local var = self:to_structural(vartype) if var is TypeDeclType or var is TypeAliasType then - error_at(where, "cannot reassign a type") + self.errs:add(varnode, "cannot reassign a type") return nil end if not valtype then - error_at(where, "variable is not being assigned a value") + self.errs:add(varnode, "variable is not being assigned a value") return nil, nil, "missing" end - assert_is_a(where, valtype, vartype, "in assignment") + self:assert_is_a(varnode, valtype, vartype, "in assignment") - local val = to_structural(valtype) + local val = self:to_structural(valtype) return var, val end @@ -10489,185 +10489,186 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return resolve_tuple(t) end - local visit_node: Visitor = {} + local visit_node: Visitor = {} visit_node.cbs = { ["statements"] = { - before = function(node: Node) - begin_scope(node) + before = function(self: TypeChecker, node: Node) + self:begin_scope(node) end, - after = function(node: Node, _children: {Type}): Type + after = function(self: TypeChecker, node: Node, _children: {Type}): Type -- if at the top level - if #st == 2 then - fail_unresolved() + if #self.st == 2 then + self.errs:fail_unresolved_labels(self.st[2]) + self.errs:fail_unresolved_nominals(self.st[2], self.st[1]) end if not node.is_repeat then - end_scope(node) + self:end_scope(node) end - -- TODO extract node type from `return` + return NONE end }, ["local_type"] = { - before = function(node: Node) + before = function(self: TypeChecker, node: Node) local name = node.var.tk - local resolved, aliasing = get_typedecl(node.value) - local var = add_var(node.var, name, resolved, node.var.attribute) + local resolved, aliasing = self:get_typedecl(node.value) + local var = self:add_var(node.var, name, resolved, node.var.attribute) if aliasing then var.aliasing = aliasing end end, - after = function(node: Node, _children: {Type}): Type - dismiss_unresolved(node.var.tk) + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + self:dismiss_unresolved(node.var.tk) return NONE end, }, ["global_type"] = { - before = function(node: Node) + before = function(self: TypeChecker, node: Node) + local global_scope = self.st[1] local name = node.var.tk - local unresolved = get_unresolved() if node.value then - local resolved, aliasing = get_typedecl(node.value) - local added = add_global(node.var, name, resolved) + local resolved, aliasing = self:get_typedecl(node.value) + local added = self:add_global(node.var, name, resolved) node.value.newtype = resolved if aliasing then added.aliasing = aliasing end - if added and unresolved.global_types[name] then - unresolved.global_types[name] = nil + if global_scope.pending_global_types[name] then + global_scope.pending_global_types[name] = nil end else - if not st[1][name] then - unresolved.global_types[name] = true + if not self.st[1].vars[name] then + global_scope.pending_global_types[name] = true end end end, - after = function(node: Node, _children: {Type}): Type - dismiss_unresolved(node.var.tk) + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + self:dismiss_unresolved(node.var.tk) return NONE end, }, ["local_declaration"] = { - before = function(node: Node) - if tc then + before = function(self: TypeChecker, node: Node) + if self.collector then for _, var in ipairs(node.vars) do - tc.reserve_symbol_list_slot(var) + self.collector.reserve_symbol_list_slot(var) end end end, before_exp = set_expected_types_to_decltuple, - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local valtuple = children[3] as TupleType -- may be nil local encountered_close = false - local infertypes = get_assignment_values(valtuple, #node.vars) + local infertypes = get_assignment_values(node, valtuple, #node.vars) for i, var in ipairs(node.vars) do if var.attribute == "close" then - if opts.gen_target == "5.4" then + if self.gen_target == "5.4" then if encountered_close then - error_at(var, "only one per declaration is allowed") + self.errs:add(var, "only one per declaration is allowed") else encountered_close = true end else - error_at(var, " attribute is only valid for Lua 5.4 (current target is " .. tostring(opts.gen_target) .. ")") + self.errs:add(var, " attribute is only valid for Lua 5.4 (current target is " .. tostring(self.gen_target) .. ")") end end - local ok, t = determine_declaration_type(var, node, infertypes, i) + local ok, t = self:determine_declaration_type(var, node, infertypes, i) if var.attribute == "close" then if not type_is_closable(t) then - error_at(var, "to-be-closed variable " .. var.tk .. " has a non-closable type %s", t) + self.errs:add(var, "to-be-closed variable " .. var.tk .. " has a non-closable type %s", t) elseif node.exps and node.exps[i] and expr_is_definitely_not_closable(node.exps[i]) then - error_at(var, "to-be-closed variable " .. var.tk .. " assigned a non-closable value") + self.errs:add(var, "to-be-closed variable " .. var.tk .. " assigned a non-closable value") end end assert(var) - add_var(var, var.tk, t, var.attribute, is_localizing_a_variable(node, i) and "declaration") + self:add_var(var, var.tk, t, var.attribute, is_localizing_a_variable(node, i) and "declaration") local infertype = infertypes.tuple[i] if ok and infertype then - local where = node.exps[i] or node.exps + local w = node.exps[i] or node.exps - local rt = to_structural(t) + local rt = self:to_structural(t) if (not rt is EnumType) and ((not t is NominalType) or (rt is UnionType)) - and not same_type(t, infertype) + and not self:same_type(t, infertype) then - t = infer_at(where, infertype) - add_var(where, var.tk, t, "const", "narrowed_declaration") + t = self:infer_at(w, infertype) + self:add_var(w, var.tk, t, "const", "narrowed_declaration") end end - if tc then - tc.store_type(var.y, var.x, t) + if self.collector then + self.collector.store_type(var.y, var.x, t) end - dismiss_unresolved(var.tk) + self:dismiss_unresolved(var.tk) end return NONE end, }, ["global_declaration"] = { before_exp = set_expected_types_to_decltuple, - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local valtuple = children[3] as TupleType -- may be nil - local infertypes = get_assignment_values(valtuple, #node.vars) + local infertypes = get_assignment_values(node, valtuple, #node.vars) for i, var in ipairs(node.vars) do - local _, t, is_inferred = determine_declaration_type(var, node, infertypes, i) + local _, t, is_inferred = self:determine_declaration_type(var, node, infertypes, i) if var.attribute == "close" then - error_at(var, "globals may not be ") + self.errs:add(var, "globals may not be ") end - add_global(var, var.tk, t, is_inferred) + self:add_global(var, var.tk, t, is_inferred) - dismiss_unresolved(var.tk) + self:dismiss_unresolved(var.tk) end return NONE end, }, ["assignment"] = { before_exp = set_expected_types_to_decltuple, - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local vartuple = children[1] assert(vartuple is TupleType) local vartypes = vartuple.tuple local valtuple = children[3] assert(valtuple is TupleType) - local valtypes = get_assignment_values(valtuple, #vartypes) + local valtypes = get_assignment_values(node, valtuple, #vartypes) for i, vartype in ipairs(vartypes) do local varnode = node.vars[i] local varname = varnode.tk local valtype = valtypes.tuple[i] - local rvar, rval, err = check_assignment(varnode, vartype, valtype, varname, varnode.attribute) + local rvar, rval, err = self:check_assignment(varnode, vartype, valtype) if err == "missing" then if #node.exps == 1 and node.exps[1].kind == "op" and node.exps[1].op.op == "@funcall" then local msg = #valtuple.tuple == 1 and "only 1 value is returned by the function" or ("only " .. #valtuple.tuple .. " values are returned by the function") - add_warning("hint", varnode, msg) + self.errs:add_warning("hint", varnode, msg) end end if rval and rvar then -- assigning a function if rval is FunctionType then - widen_all_unions() + self:widen_all_unions() end if varname and (rvar is UnionType or rvar is InterfaceType) then -- narrow unions and interfaces - add_var(varnode, varname, rval, nil, "narrow") + self:add_var(varnode, varname, rval, nil, "narrow") end - if tc then - tc.store_type(varnode.y, varnode.x, valtype) + if self.collector then + self.collector.store_type(varnode.y, varnode.x, valtype) end end end @@ -10676,7 +10677,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["if"] = { - after = function(node: Node, _children: {Type}): Type + after = function(self: TypeChecker, node: Node, _children: {Type}): Type local all_return = true for _, b in ipairs(node.if_blocks) do if not b.block_returns then @@ -10686,26 +10687,26 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if all_return then node.block_returns = true - infer_negation_of_if_blocks(node, node, #node.if_blocks) + self:infer_negation_of_if_blocks(node, node, #node.if_blocks) end return NONE end, }, ["if_block"] = { - before = function(node: Node) - begin_scope(node) + before = function(self: TypeChecker, node: Node) + self:begin_scope(node) if node.if_block_n > 1 then - infer_negation_of_if_blocks(node, node.if_parent, node.if_block_n - 1) + self:infer_negation_of_if_blocks(node, node.if_parent, node.if_block_n - 1) end end, - before_statements = function(node: Node) + before_statements = function(self: TypeChecker, node: Node) if node.exp then - apply_facts(node.exp, node.exp.known) + self:apply_facts(node.exp, node.exp.known) end end, - after = function(node: Node, _children: {Type}): Type - end_scope(node) + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + self:end_scope(node) if #node.body > 0 and node.body[#node.body].block_returns then node.block_returns = true @@ -10715,76 +10716,96 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end }, ["while"] = { - before = function(node: Node) + before = function(self: TypeChecker, node: Node) -- widen all narrowed variables because we don't calculate a fixpoint yet - widen_all_unions(node) + self:widen_all_unions(node) end, - before_statements = function(node: Node) - begin_scope(node) - apply_facts(node.exp, node.exp.known) + before_statements = function(self: TypeChecker, node: Node) + self:begin_scope(node) + self:apply_facts(node.exp, node.exp.known) end, after = end_scope_and_none_type, }, ["label"] = { - before = function(node: Node) + before = function(self: TypeChecker, node: Node) -- widen all narrowed variables because we don't calculate a fixpoint yet - widen_all_unions() - local label_id = "::" .. node.label .. "::" - if st[#st][label_id] then - error_at(node, "label '" .. node.label .. "' already defined at " .. filename ) - end - local unresolved = find_unresolved() - local var = add_var(node, label_id, type_at(node, a_type("none", {}))) - if unresolved then - if unresolved.labels[node.label] then - var.used = true + self:widen_all_unions() + local label_id = node.label + do + local scope = self.st[#self.st] + scope.labels = scope.labels or {} + if scope.labels[label_id] then + self.errs:add(node, "label '" .. node.label .. "' already defined") + else + scope.labels[label_id] = node end - unresolved.labels[node.label] = nil end + + --for i = #self.st, 1, -1 do + local scope = self.st[#self.st] + if scope.pending_labels and scope.pending_labels[label_id] then + node.used_label = true + scope.pending_labels[label_id] = nil + --break + end + --end end, after = function(): Type return NONE end }, ["goto"] = { - after = function(node: Node, _children: {Type}): Type - if not find_var_type("::" .. node.label .. "::") then - local unresolved = get_unresolved(st[#st]) - unresolved.labels[node.label] = unresolved.labels[node.label] or {} - table.insert(unresolved.labels[node.label], node) + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + local label_id = node.label + local found_label: Node + for i = #self.st, 1, -1 do + local scope = self.st[i] + if scope.labels and scope.labels[label_id] then + found_label = scope.labels[label_id] + break + end + end + + if found_label then + found_label.used_label = true + else + local scope = self.st[#self.st] + scope.pending_labels = scope.pending_labels or {} + scope.pending_labels[label_id] = scope.pending_labels[label_id] or {} + table.insert(scope.pending_labels[label_id], node) end return NONE end, }, ["repeat"] = { - before = function(node: Node) + before = function(self: TypeChecker, node: Node) -- widen all narrowed variables because we don't calculate a fixpoint yet - widen_all_unions(node) + self:widen_all_unions(node) end, -- only end scope after checking `until`, `statements` in repeat body has is_repeat == true after = end_scope_and_none_type, }, ["forin"] = { - before = function(node: Node) - begin_scope(node) + before = function(self: TypeChecker, node: Node) + self:begin_scope(node) end, - before_statements = function(node: Node, children: {Type}) + before_statements = function(self: TypeChecker, node: Node, children: {Type}) local exptuple = children[2] assert(exptuple is TupleType) local exptypes = exptuple.tuple - widen_all_unions(node) + self:widen_all_unions(node) local exp1 = node.exps[1] - local args = a_tuple { + local args = a_tuple(node.exps, { node.exps[2] and exptypes[2], node.exps[3] and exptypes[3] - } - local exp1type = resolve_for_call(exptypes[1], args, false) + }) + local exp1type = self:resolve_for_call(exptypes[1], args, false) if exp1type is PolyType then local _: Type - _, exp1type = type_check_function_call(exp1, exp1type, args, 0, exp1, {node.exps[2], node.exps[3]}) + _, exp1type = self:type_check_function_call(exp1, exp1type, args, 0, exp1, {node.exps[2], node.exps[3]}) end if exp1type is FunctionType then @@ -10797,69 +10818,69 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if rets.is_va then r = last else - r = lax and UNKNOWN or INVALID + r = self.feat_lax and an_unknown(v) or an_invalid(v) end end - add_var(v, v.tk, r) + self:add_var(v, v.tk, r) - if tc then - tc.store_type(v.y, v.x, r) + if self.collector then + self.collector.store_type(v.y, v.x, r) end last = r end local nrets = #rets.tuple - if (not lax) and (not rets.is_va and #node.vars > nrets) then + if (not self.feat_lax) and (not rets.is_va and #node.vars > nrets) then local at = node.vars[nrets + 1] local n_values = nrets == 1 and "1 value" or tostring(nrets) .. " values" - error_at(at, "too many variables for this iterator; it produces " .. n_values) + self.errs:add(at, "too many variables for this iterator; it produces " .. n_values) end else - if not (lax and is_unknown(exp1type)) then - error_at(exp1, "expression in for loop does not return an iterator") + if not (self.feat_lax and is_unknown(exp1type)) then + self.errs:add(exp1, "expression in for loop does not return an iterator") end end end, after = end_scope_and_none_type, }, ["fornum"] = { - before_statements = function(node: Node, children: {Type}) - widen_all_unions(node) - begin_scope(node) - local from_t = to_structural(resolve_tuple(children[2])) - local to_t = to_structural(resolve_tuple(children[3])) - local step_t = children[4] and to_structural(children[4]) - local t = (from_t.typename == "integer" and - to_t.typename == "integer" and - (not step_t or step_t.typename == "integer")) - and INTEGER - or NUMBER - add_var(node.var, node.var.tk, t) + before_statements = function(self: TypeChecker, node: Node, children: {Type}) + self:widen_all_unions(node) + self:begin_scope(node) + local from_t = self:to_structural(resolve_tuple(children[2])) + local to_t = self:to_structural(resolve_tuple(children[3])) + local step_t = children[4] and self:to_structural(children[4]) + local typename: TypeName = (from_t.typename == "integer" and + to_t.typename == "integer" and + (not step_t or step_t.typename == "integer")) + and "integer" + or "number" + self:add_var(node.var, node.var.tk, a_type(node.var, typename, {})) end, after = end_scope_and_none_type, }, ["return"] = { - before = function(node: Node) - local rets = find_var_type("@return") + before = function(self: TypeChecker, node: Node) + local rets = self:find_var_type("@return") if rets and rets is TupleType then for i, exp in ipairs(node.exps) do exp.expected = rets.tuple[i] end end end, - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local got = children[1] assert(got is TupleType) local got_t = got.tuple local n_got = #got_t node.block_returns = true - local expected = find_var_type("@return") as TupleType + local expected = self:find_var_type("@return") as TupleType if not expected then -- if at the toplevel - expected = infer_at(node, got) - module_type = drop_constant_value(to_structural(resolve_tuple(expected))) - st[2]["@return"] = { t = expected } + expected = self:infer_at(node, got) + self.module_type = drop_constant_value(self:to_structural(resolve_tuple(expected))) + self.st[2].vars["@return"] = { t = expected } end local expected_t = expected.tuple @@ -10874,8 +10895,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string vatype = expected.is_va and expected.tuple[n_expected] end - if n_got > n_expected and (not lax) and not vatype then - error_at(node, what ..": excess return values, expected " .. n_expected .. " %s, got " .. n_got .. " %s", expected, got) + if n_got > n_expected and (not self.feat_lax) and not vatype then + self.errs:add(node, what ..": excess return values, expected " .. n_expected .. " %s, got " .. n_got .. " %s", expected, got) end if n_expected > 1 @@ -10883,18 +10904,18 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string and node.exps[1].kind == "op" and (node.exps[1].op.op == "and" or node.exps[1].op.op == "or") and node.exps[1].discarded_tuple then - add_warning("hint", node.exps[1].e2, "additional return values are being discarded due to '" .. node.exps[1].op.op .. "' expression; suggest parentheses if intentional") + self.errs:add_warning("hint", node.exps[1].e2, "additional return values are being discarded due to '" .. node.exps[1].op.op .. "' expression; suggest parentheses if intentional") end for i = 1, n_got do local e = expected_t[i] or vatype if e then e = resolve_tuple(e) - local where = (node.exps[i] and node.exps[i].x) - and node.exps[i] - or node.exps - assert(where and where.x) - assert_is_a(where, got_t[i], e, what) + local w = (node.exps[i] and node.exps[i].x) + and node.exps[i] + or node.exps + assert(w and w.x) + self:assert_is_a(w, got_t[i], e, what) end end @@ -10902,25 +10923,28 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["variable_list"] = { - after = function(node: Node, children: {Type}): Type - local tuple = a_tuple(children) + after = function(self: TypeChecker, node: Node, children: {Type}): Type + local tuple = a_tuple(node, children) tuple = flatten_tuple(tuple) for i, t in ipairs(tuple.tuple) do - ensure_not_abstract(node[i], t) + local ok, err = ensure_not_abstract(t) + if not ok then + self.errs:add(node[i], err) + end end return tuple end, }, ["literal_table"] = { - before = function(node: Node) + before = function(self: TypeChecker, node: Node) if node.expected then - local decltype = to_structural(node.expected) + local decltype = self:to_structural(node.expected) if decltype is TypeVarType and decltype.constraint then - decltype = resolve_typedecl(to_structural(decltype.constraint)) + decltype = resolve_typedecl(self:to_structural(decltype.constraint)) end if decltype is TupleTableType then @@ -10952,19 +10976,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end end, - after = function(node: Node, children: {LiteralTableItemType}): Type + after = function(self: TypeChecker, node: Node, children: {LiteralTableItemType}): Type node.known = FACT_TRUTHY if not node.expected then - return infer_table_literal(node, children) + return infer_table_literal(self, node, children) end - local decltype = to_structural(node.expected) + local decltype = self:to_structural(node.expected) local constraint: Type if decltype is TypeVarType and decltype.constraint then constraint = resolve_typedecl(decltype.constraint) - decltype = to_structural(constraint) + decltype = self:to_structural(constraint) end if decltype is UnionType then @@ -10972,7 +10996,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local single_table_rt: Type for _, t in ipairs(decltype.types) do - local rt = to_structural(t) + local rt = self:to_structural(t) if is_lua_table_type(rt) then if single_table_type then -- multiple table types in union, give up @@ -10993,7 +11017,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if not is_lua_table_type(decltype) then - return infer_table_literal(node, children) + return infer_table_literal(self, node, children) end local force_array: Type = nil @@ -11003,73 +11027,75 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string for i, child in ipairs(children) do local cvtype = resolve_tuple(child.vtype) local ck = child.kname + local cktype = child.ktype local n = node[i].key.constnum local b: boolean = nil - if child.ktype.typename == "boolean" then + if cktype is BooleanType then b = (node[i].key.tk == "true") end - check_redeclared_key(node[i], node.expected_context, seen_keys, ck or n or b) + self.errs:check_redeclared_key(node[i], node, seen_keys, ck or n or b) if decltype is RecordLikeType and ck then local df = decltype.fields[ck] if not df then - error_at(node[i], in_context(node.expected_context, "unknown field " .. ck)) + self.errs:add_in_context(node[i], node, "unknown field " .. ck) else if df is TypeDeclType or df is TypeAliasType then - error_at(node[i], in_context(node.expected_context, "cannot reassign a type")) + self.errs:add_in_context(node[i], node, "cannot reassign a type") else - assert_is_a(node[i], cvtype, df, "in record field", ck) + self:assert_is_a(node[i], cvtype, df, "in record field", ck) end end - elseif decltype is TupleTableType and is_number_type(child.ktype) then + elseif decltype is TupleTableType and cktype is NumericType then local dt = decltype.types[n as integer] if not n then - error_at(node[i], in_context(node.expected_context, "unknown index in tuple %s"), decltype) + self.errs:add_in_context(node[i], node, "unknown index in tuple %s", decltype) elseif not dt then - error_at(node[i], in_context(node.expected_context, "unexpected index " .. n .. " in tuple %s"), decltype) + self.errs:add_in_context(node[i], node, "unexpected index " .. n .. " in tuple %s", decltype) else - assert_is_a(node[i], cvtype, dt, in_context(node.expected_context, "in tuple"), "at index " .. tostring(n)) + self:assert_is_a(node[i], cvtype, dt, node, "in tuple: at index " .. tostring(n)) end - elseif decltype is ArrayLikeType and is_number_type(child.ktype) then + elseif decltype is ArrayLikeType and cktype is NumericType then local cv = child.vtype if cv is TupleType and i == #children and node[i].key_parsed == "implicit" then -- need to expand last item in an array (e.g { 1, 2, 3, f() }) for ti, tt in ipairs(cv.tuple) do - assert_is_a(node[i], tt, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(i + ti - 1)) + self:assert_is_a(node[i], tt, decltype.elements, node, "expected an array: at index " .. tostring(i + ti - 1)) end else - assert_is_a(node[i], cvtype, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(n)) + self:assert_is_a(node[i], cvtype, decltype.elements, node, "expected an array: at index " .. tostring(n)) end elseif node[i].key_parsed == "implicit" then if decltype is MapType then - assert_is_a(node[i], INTEGER, decltype.keys, in_context(node.expected_context, "in map key")) - assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value")) + self:assert_is_a(node[i].key, a_type(node[i].key, "integer", {}), decltype.keys, node, "in map key") + self:assert_is_a(node[i].value, cvtype, decltype.values, node, "in map value") end - force_array = expand_type(node[i], force_array, child.vtype) + force_array = self:expand_type(node[i], force_array, child.vtype) elseif decltype is MapType then force_array = nil - assert_is_a(node[i], child.ktype, decltype.keys, in_context(node.expected_context, "in map key")) - assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value")) + self:assert_is_a(node[i].key, cktype, decltype.keys, node, "in map key") + self:assert_is_a(node[i].value, cvtype, decltype.values, node, "in map value") else - error_at(node[i], in_context(node.expected_context, "unexpected key of type %s in table of type %s"), child.ktype, decltype) + self.errs:add_in_context(node[i], node, "unexpected key of type %s in table of type %s", cktype, decltype) end end local t: Type if force_array then - t = infer_at(node, an_array(force_array)) + t = self:infer_at(node, an_array(node, force_array)) else - t = resolve_typevars_at(node, node.expected) + t = self:resolve_typevars_at(node, node.expected) end if decltype is RecordType then - local rt = to_structural(t) + local rt = self:to_structural(t) if rt is RecordType then node.is_total, node.missing = total_record_check(decltype, seen_keys) end elseif decltype is MapType then - local rt = to_structural(t) + local rt = self:to_structural(t) if rt is MapType then - node.is_total, node.missing = total_map_check(decltype, seen_keys) + local rk = self:to_structural(rt.keys) + node.is_total, node.missing = total_map_check(rk, seen_keys) end end @@ -11081,13 +11107,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["literal_table_item"] = { - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local kname = node.key.conststr local ktype = children[1] local vtype = children[2] if node.itemtype then vtype = node.itemtype - assert_is_a(node.value, children[2], node.itemtype, "in table item") + self:assert_is_a(node.value, children[2], node.itemtype, node) end if vtype is FunctionType and vtype.is_method then -- If we assign a method to a table item, e.g. @@ -11096,210 +11122,210 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string vtype = shallow_copy_new_type(vtype) vtype.is_method = false end - return type_at(node, a_type("literal_table_item", { + return a_type(node, "literal_table_item", { kname = kname, ktype = ktype, vtype = vtype, - } as LiteralTableItemType)) + } as LiteralTableItemType) end, }, ["local_function"] = { - before = function(node: Node) - widen_all_unions() - if tc then - tc.reserve_symbol_list_slot(node) + before = function(self: TypeChecker, node: Node) + self:widen_all_unions() + if self.collector then + self.collector.reserve_symbol_list_slot(node) end - begin_scope(node) + self:begin_scope(node) end, - before_statements = function(node: Node, children: {Type}) + before_statements = function(self: TypeChecker, node: Node, children: {Type}) local args = children[2] assert(args is TupleType) - add_internal_function_variables(node, args) - add_function_definition_for_recursion(node, args) + self:add_internal_function_variables(node, args) + self:add_function_definition_for_recursion(node, args) end, - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local args = children[2] assert(args is TupleType) local rets = children[3] assert(rets is TupleType) - end_function_scope(node) + self:end_function_scope(node) - local t = type_at(node, ensure_fresh_typeargs(a_function { + local t = self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, typeargs = node.typeargs, args = args, - rets = get_rets(rets), + rets = self.get_rets(rets), })) - add_var(node, node.name.tk, t) + self:add_var(node, node.name.tk, t) return t end, }, ["local_macroexp"] = { - before = function(node: Node) - widen_all_unions() - if tc then - tc.reserve_symbol_list_slot(node) + before = function(self: TypeChecker, node: Node) + self:widen_all_unions() + if self.collector then + self.collector.reserve_symbol_list_slot(node) end - begin_scope(node) + self:begin_scope(node) end, - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local args = children[2] assert(args is TupleType) local rets = children[3] assert(rets is TupleType) - end_function_scope(node) + self:end_function_scope(node) - check_macroexp_arg_use(node.macrodef) + self:check_macroexp_arg_use(node.macrodef) - local t = type_at(node, ensure_fresh_typeargs(a_function { + local t = self:ensure_fresh_typeargs(a_function(node, { min_arity = node.macrodef.min_arity, typeargs = node.typeargs, args = args, - rets = get_rets(rets), + rets = self.get_rets(rets), macroexp = node.macrodef, })) - add_var(node, node.name.tk, t) + self:add_var(node, node.name.tk, t) return t end, }, ["global_function"] = { - before = function(node: Node) - widen_all_unions() - begin_scope(node) + before = function(self: TypeChecker, node: Node) + self:widen_all_unions() + self:begin_scope(node) if node.implicit_global_function then - local typ = find_var_type(node.name.tk) + local typ = self:find_var_type(node.name.tk) if typ then if typ is FunctionType then node.is_predeclared_local_function = true - elseif not lax then - error_at(node, "cannot declare function: type of " .. node.name.tk .. " is %s", typ) + elseif not self.feat_lax then + self.errs:add(node, "cannot declare function: type of " .. node.name.tk .. " is %s", typ) end - elseif not lax then - error_at(node, "functions need an explicit 'local' or 'global' annotation") + elseif not self.feat_lax then + self.errs:add(node, "functions need an explicit 'local' or 'global' annotation") end end end, - before_statements = function(node: Node, children: {Type}) + before_statements = function(self: TypeChecker, node: Node, children: {Type}) local args = children[2] assert(args is TupleType) - add_internal_function_variables(node, args) - add_function_definition_for_recursion(node, args) + self:add_internal_function_variables(node, args) + self:add_function_definition_for_recursion(node, args) end, - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local args = children[2] assert(args is TupleType) local rets = children[3] assert(rets is TupleType) - end_function_scope(node) + self:end_function_scope(node) if node.is_predeclared_local_function then return NONE end - add_global(node, node.name.tk, type_at(node, ensure_fresh_typeargs(a_function { + self:add_global(node, node.name.tk, self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, typeargs = node.typeargs, args = args, - rets = get_rets(rets), + rets = self.get_rets(rets), }))) return NONE end, }, ["record_function"] = { - before = function(node: Node) - widen_all_unions() - begin_scope(node) + before = function(self: TypeChecker, node: Node) + self:widen_all_unions() + self:begin_scope(node) end, - before_arguments = function(_node: Node, children: {Type}) - local rtype = to_structural(resolve_typedecl(children[1])) + before_arguments = function(self: TypeChecker, _node: Node, children: {Type}) + local rtype = self:to_structural(resolve_typedecl(children[1])) -- add type arguments from the record implicitly if rtype is RecordLikeType and rtype.typeargs then for _, typ in ipairs(rtype.typeargs) do - add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { + self:add_var(nil, typ.typearg, a_type(typ, "typearg", { typearg = typ.typearg, constraint = typ.constraint, - } as TypeArgType))) + } as TypeArgType)) end end end, - before_statements = function(node: Node, children: {Type}) + before_statements = function(self: TypeChecker, node: Node, children: {Type}) local args = children[3] assert(args is TupleType) local rets = children[4] assert(rets is TupleType) - local rtype = to_structural(resolve_typedecl(children[1])) + local rtype = self:to_structural(resolve_typedecl(children[1])) - if lax and rtype.typename == "unknown" then + if self.feat_lax and rtype is UnknownType then return end if rtype is EmptyTableType then - edit_type(rtype, "record") + edit_type(rtype, rtype, "record") local r = rtype as RecordType r.fields = {} r.field_order = {} end if not rtype is RecordLikeType then - error_at(node, "not a record: %s", rtype) + self.errs:add(node, "not a record: %s", rtype) return end - local selftype = get_self_type(node.fn_owner) + local selftype = self:get_self_type(node.fn_owner) if node.is_method then if not selftype then - error_at(node, "could not resolve type of self") + self.errs:add(node, "could not resolve type of self") return end args.tuple[1] = selftype - add_var(nil, "self", selftype) + self:add_var(nil, "self", selftype) end - local fn_type = type_at(node, ensure_fresh_typeargs(a_function { + local fn_type = self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, is_method = node.is_method, typeargs = node.typeargs, args = args, - rets = get_rets(rets), + rets = self.get_rets(rets), })) - local open_t, open_v, owner_name = find_record_to_extend(node.fn_owner) + local open_t, open_v, owner_name = self:find_record_to_extend(node.fn_owner) local open_k = owner_name .. "." .. node.name.tk local rfieldtype = rtype.fields[node.name.tk] if rfieldtype then - rfieldtype = to_structural(rfieldtype) + rfieldtype = self:to_structural(rfieldtype) if open_v and open_v.implemented and open_v.implemented[open_k] then - redeclaration_warning(node) + self.errs:redeclaration_warning(node) end - local ok, err = same_type(fn_type, rfieldtype) + local ok, err = self:same_type(fn_type, rfieldtype) if not ok then if rfieldtype is PolyType then - add_errs_prefixing(node, err, errors, "type signature does not match declaration: field has multiple function definitions (such polymorphic declarations are intended for Lua module interoperability)") + self.errs:add_prefixing(node, err, "type signature does not match declaration: field has multiple function definitions (such polymorphic declarations are intended for Lua module interoperability): ") return end local shortname = selftype and show_type(selftype) or owner_name local msg = "type signature of '" .. node.name.tk .. "' does not match its declaration in " .. shortname .. ": " - add_errs_prefixing(node, err, errors, msg) + self.errs:add_prefixing(node, err, msg) return end else - if lax or rtype == open_t then + if self.feat_lax or rtype == open_t then rtype.fields[node.name.tk] = fn_type table.insert(rtype.field_order, node.name.tk) else - error_at(node, "cannot add undeclared function '" .. node.name.tk .. "' outside of the scope where '" .. owner_name .. "' was originally declared") + self.errs:add(node, "cannot add undeclared function '" .. node.name.tk .. "' outside of the scope where '" .. owner_name .. "' was originally declared") return end @@ -11312,32 +11338,32 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string open_v.implemented[open_k] = true end - add_internal_function_variables(node, args) + self:add_internal_function_variables(node, args) end, - after = function(node: Node, _children: {Type}): Type - end_function_scope(node) + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + self:end_function_scope(node) return NONE end, }, ["function"] = { - before = function(node: Node) - widen_all_unions(node) - begin_scope(node) + before = function(self: TypeChecker, node: Node) + self:widen_all_unions(node) + self:begin_scope(node) end, - before_statements = function(node: Node, children: {Type}) + before_statements = function(self: TypeChecker, node: Node, children: {Type}) local args = children[1] assert(args is TupleType) - add_internal_function_variables(node, args) + self:add_internal_function_variables(node, args) end, - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local args = children[1] assert(args is TupleType) local rets = children[2] assert(rets is TupleType) - end_function_scope(node) - return type_at(node, ensure_fresh_typeargs(a_function { + self:end_function_scope(node) + return self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, typeargs = node.typeargs, args = args, @@ -11346,24 +11372,24 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["macroexp"] = { - before = function(node: Node) - widen_all_unions(node) - begin_scope(node) + before = function(self: TypeChecker, node: Node) + self:widen_all_unions(node) + self:begin_scope(node) end, - before_exp = function(node: Node, children: {Type}) + before_exp = function(self: TypeChecker, node: Node, children: {Type}) local args = children[1] assert(args is TupleType) - add_internal_function_variables(node, args) + self:add_internal_function_variables(node, args) end, - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local args = children[1] assert(args is TupleType) local rets = children[2] assert(rets is TupleType) - end_function_scope(node) - return type_at(node, ensure_fresh_typeargs(a_function { + self:end_function_scope(node) + return self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, typeargs = node.typeargs, args = args, @@ -11372,22 +11398,22 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["cast"] = { - after = function(node: Node, _children: {Type}): Type + after = function(_self: TypeChecker, node: Node, _children: {Type}): Type return node.casttype end }, ["paren"] = { - before = function(node: Node) + before = function(_self: TypeChecker, node: Node) node.e1.expected = node.expected end, - after = function(node: Node, children: {Type}): Type + after = function(_self: TypeChecker, node: Node, children: {Type}): Type node.known = node.e1 and node.e1.known return resolve_tuple(children[1]) end, }, ["op"] = { - before = function(node: Node) - begin_scope() + before = function(self: TypeChecker, node: Node) + self:begin_scope() if node.expected then if node.op.op == "and" then node.e2.expected = node.expected @@ -11399,18 +11425,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end end, - before_e2 = function(node: Node, children: {Type}) + before_e2 = function(self: TypeChecker, node: Node, children: {Type}) local e1type = children[1] if node.op.op == "and" then - apply_facts(node, node.e1.known) + self:apply_facts(node, node.e1.known) elseif node.op.op == "or" then - apply_facts(node, facts_not(node, node.e1.known)) + self:apply_facts(node, facts_not(node, node.e1.known)) elseif node.op.op == "@funcall" then if e1type is FunctionType then local argdelta = (node.e1.op and node.e1.op.op == ":") and -1 or 0 if node.expected then - is_a(e1type.rets, node.expected) + -- this forces typevars in function return types + self:is_a(e1type.rets, node.expected) end local e1args = e1type.args.tuple local at = argdelta @@ -11433,8 +11460,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end end, - after = function(node: Node, children: {Type}): Type - end_scope() + after = function(self: TypeChecker, node: Node, children: {Type}): Type + self:end_scope() -- given a and b: may be TupleType local ga: Type = children[1] @@ -11445,29 +11472,34 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local ub: Type -- resolved a and b: not NominalType - local ra: Type = to_structural(ua) + local ra: Type = self:to_structural(ua) local rb: Type if ra.typename == "circular_require" or (ra is TypeDeclType and ra.def and ra.def.typename == "circular_require") then - return invalid_at(node, "cannot dereference a type from a circular require") + return self.errs:invalid_at(node, "cannot dereference a type from a circular require") end if node.op.op == "@funcall" then - if lax and is_unknown(ua) then + if self.feat_lax and is_unknown(ua) then if node.e1.op and node.e1.op.op == ":" and node.e1.e1.kind == "variable" then - add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk) + self.errs:add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk) end end - local t = type_check_funcall(node, ua, gb) + assert(gb is TupleType) +assert(node.f) + local t = self:type_check_funcall(node, ua, gb) return t elseif node.op.op == "as" then return gb end - local expected = node.expected and to_structural(resolve_tuple(node.expected)) + local expected = node.expected and self:to_structural(resolve_tuple(node.expected)) - ensure_not_abstract(node.e1, ra) + local ok, err = ensure_not_abstract(ra) + if not ok then + self.errs:add(node.e1, err) + end if ra is TypeDeclType and ra.def.typename == "record" then ra = ra.def end @@ -11476,8 +11508,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- after they are handled above, we can resolve b's tuple and only use that instead. if gb then ub = resolve_tuple(gb) - rb = to_structural(ub) - ensure_not_abstract(node.e2, rb) + rb = self:to_structural(ub) + ok, err = ensure_not_abstract(rb) + if not ok then + self.errs:add(node.e2, err) + end if rb is TypeDeclType and rb.def.typename == "record" then rb = rb.def end @@ -11487,22 +11522,20 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.receiver = ua assert(node.e2.kind == "identifier") - local bnode: Node = { - y = node.e2.y, - x = node.e2.x, + local bnode: Node = node_at(node.e2, { tk = node.e2.tk, kind = "string", - } - local btype = type_at(node.e2, a_type("string", { literal = node.e2.tk } as StringType)) - local t = type_check_index(node.e1, bnode, ua, btype) + }) + local btype = a_type(node.e2, "string", { literal = node.e2.tk } as StringType) + local t = self:type_check_index(node.e1, bnode, ua, btype) - if t.needs_compat and opts.gen_compat ~= "off" then + if t.needs_compat and self.gen_compat ~= "off" then -- only apply to a literal use, not a propagated type if node.e1.kind == "variable" and node.e2.kind == "identifier" then local key = node.e1.tk .. "." .. node.e2.tk node.kind = "variable" node.tk = "_tl_" .. node.e1.tk .. "_" .. node.e2.tk - all_needs_compat[key] = true + self.all_needs_compat[key] = true end end @@ -11510,22 +11543,22 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if node.op.op == "@index" then - return type_check_index(node.e1, node.e2, ua, ub) + return self:type_check_index(node.e1, node.e2, ua, ub) end if node.op.op == "is" then if rb.typename == "integer" then - all_needs_compat["math"] = true + self.all_needs_compat["math"] = true end if ra is TypeDeclType then - error_at(node, "can only use 'is' on variables, not types") + self.errs:add(node, "can only use 'is' on variables, not types") elseif node.e1.kind == "variable" then - check_metamethod(node, "__is", ra, resolve_typedecl(rb), ua, ub) - node.known = IsFact { var = node.e1.tk, typ = ub, where = node } + self:check_metamethod(node, "__is", ra, resolve_typedecl(rb), ua, ub) + node.known = IsFact { var = node.e1.tk, typ = ub, w = node } else - error_at(node, "can only use 'is' on variables") + self.errs:add(node, "can only use 'is' on variables") end - return BOOLEAN + return a_type(node, "boolean", {}) end if node.op.op == ":" then @@ -11533,16 +11566,16 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- we handle ':' separately from '.' because ':' is specific to records, -- so we produce different error messages - if lax and (is_unknown(ua) or ua.typename == "typevar") then + if self.feat_lax and (is_unknown(ua) or ua is TypeVarType) then if node.e1.kind == "variable" then - add_unknown_dot(node.e1, node.e1.tk .. "." .. node.e2.tk) + self.errs:add_unknown_dot(node.e1, node.e1.tk .. "." .. node.e2.tk) end - return UNKNOWN + return an_unknown(node) end - local t, e = match_record_key(ra, node.e1, node.e2.conststr or node.e2.tk) + local t, e = self:match_record_key(ra, node.e1, node.e2.conststr or node.e2.tk) if not t then - return invalid_at(node.e2, e, ua) + return self.errs:invalid_at(node.e2, e, ua) end return t @@ -11550,7 +11583,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if node.op.op == "not" then node.known = facts_not(node, node.e1.known) - return BOOLEAN + return a_type(node, "boolean", {}) end if node.op.op == "and" then @@ -11568,33 +11601,33 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.known = nil t = ua - elseif ((ra is EnumType and rb is StringType and is_a(rb, ra)) - or (ra is StringType and rb is EnumType and is_a(ra, rb))) then + elseif ((ra is EnumType and rb is StringType and self:is_a(rb, ra)) + or (ra is StringType and rb is EnumType and self:is_a(ra, rb))) then node.known = nil t = (ra is EnumType and ra or rb) elseif expected and expected is UnionType then -- must be checked after string/enum above node.known = facts_or(node, node.e1.known, node.e2.known) - local u = unite({ra, rb}, true) + local u = unite(node, {ra, rb}, true) if u is UnionType then - local ok, err = is_valid_union(u) + ok, err = is_valid_union(u) if not ok then - u = err and invalid_at(node, err, u) or INVALID + u = err and self.errs:invalid_at(node, err, u) or an_invalid(node) end end t = u else - local a_ge_b = is_a(rb, ra) - local b_ge_a = is_a(ra, rb) + local a_ge_b = self:is_a(rb, ra) + local b_ge_a = self:is_a(ra, rb) if a_ge_b or b_ge_a then node.known = facts_or(node, node.e1.known, node.e2.known) if expected then - local a_is = is_a(ua, expected) - local b_is = is_a(ub, expected) + local a_is = self:is_a(ua, expected) + local b_is = self:is_a(ub, expected) if a_is and b_is then - t = resolve_typevars_at(node, expected) + t = self:resolve_typevars_at(node, expected) end end if not t then @@ -11613,44 +11646,46 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if node.op.op == "==" or node.op.op == "~=" then -- if is_lua_table_type(ra) and is_lua_table_type(rb) then --- check_metamethod(node, binop_to_metamethod[node.op.op], ra, rb) +-- self:check_metamethod(node, binop_to_metamethod[node.op.op], ra, rb) -- end if ra is EnumType and rb is StringType then if not (rb.literal and ra.enumset[rb.literal]) then - return invalid_at(node, "%s is not a member of %s", ub, ua) + return self.errs:invalid_at(node, "%s is not a member of %s", ub, ua) end elseif ra is TupleTableType and rb is TupleTableType and #ra.types ~= #rb.types then - return invalid_at(node, "tuples are not the same size") - elseif is_a(ub, ua) or ua.typename == "typevar" then + return self.errs:invalid_at(node, "tuples are not the same size") + elseif self:is_a(ub, ua) or ua is TypeVarType then if node.op.op == "==" and node.e1.kind == "variable" then - node.known = EqFact { var = node.e1.tk, typ = ub, where = node } + node.known = EqFact { var = node.e1.tk, typ = ub, w = node } end - elseif is_a(ua, ub) or ub.typename == "typevar" then + elseif self:is_a(ua, ub) or ub is TypeVarType then if node.op.op == "==" and node.e2.kind == "variable" then - node.known = EqFact { var = node.e2.tk, typ = ua, where = node } + node.known = EqFact { var = node.e2.tk, typ = ua, w = node } end - elseif lax and (is_unknown(ua) or is_unknown(ub)) then - return UNKNOWN + elseif self.feat_lax and (is_unknown(ua) or is_unknown(ub)) then + return an_unknown(node) else - return invalid_at(node, "types are not comparable for equality: %s and %s", ua, ub) + return self.errs:invalid_at(node, "types are not comparable for equality: %s and %s", ua, ub) end - return BOOLEAN + return a_type(node, "boolean", {}) end if node.op.arity == 1 and unop_types[node.op.op] then if ra is UnionType then - ra = unite(ra.types, true) -- squash unions of string constants + ra = unite(node, ra.types, true) -- squash unions of string constants end local types_op = unop_types[node.op.op] - local t = types_op[ra.typename] + local tn = types_op[ra.typename] + local t = tn and a_type(node, tn, {}) if not t and ra is RecordLikeType then t = find_in_interface_list(ra, function(ty: Type): Type - return types_op[ty.typename] + local tname = types_op[ty.typename] + return tname and a_type(node, tname, {}) end) end @@ -11658,19 +11693,18 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not t then local mt_name = unop_to_metamethod[node.op.op] if mt_name then - t, meta_on_operator = check_metamethod(node, mt_name, ra, nil, ua, nil) + t, meta_on_operator = self:check_metamethod(node, mt_name, ra, nil, ua, nil) end if not t then - error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", ua) - t = INVALID + t = self.errs:invalid_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", ua) end end if ra is MapType then if ra.keys.typename == "number" or ra.keys.typename == "integer" then - add_warning("hint", node, "using the '#' operator on a map with numeric key type may produce unexpected results") + self.errs:add_warning("hint", node, "using the '#' operator on a map with numeric key type may produce unexpected results") else - error_at(node, "using the '#' operator on this map will always return 0") + self.errs:add(node, "using the '#' operator on this map will always return 0") end end @@ -11678,12 +11712,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.known = FACT_TRUTHY end - if node.op.op == "~" and env.gen_target == "5.1" then + if node.op.op == "~" and self.gen_target == "5.1" then if meta_on_operator then - all_needs_compat["mt"] = true + self.all_needs_compat["mt"] = true convert_node_to_compat_mt_call(node, unop_to_metamethod[node.op.op], 1, node.e1) else - all_needs_compat["bit32"] = true + self.all_needs_compat["bit32"] = true convert_node_to_compat_call(node, "bit32", "bnot", node.e1) end end @@ -11697,39 +11731,39 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if ra is UnionType then - ra = unite(ra.types, true) -- squash unions of string constants + ra = unite(ra, ra.types, true) -- squash unions of string constants end if rb is UnionType then - rb = unite(rb.types, true) -- squash unions of string constants + rb = unite(rb, rb.types, true) -- squash unions of string constants end local types_op = binop_types[node.op.op] - local t = types_op[ra.typename] and types_op[ra.typename][rb.typename] + local tn = types_op[ra.typename] and types_op[ra.typename][rb.typename] + local t = tn and a_type(node, tn, {}) local meta_on_operator: integer if not t then local mt_name = binop_to_metamethod[node.op.op] if mt_name then - t, meta_on_operator = check_metamethod(node, mt_name, ra, rb, ua, ub) + t, meta_on_operator = self:check_metamethod(node, mt_name, ra, rb, ua, ub) end if not t then - error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", ua, ub) - t = INVALID + t = self.errs:invalid_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", ua, ub) if node.op.op == "or" then - local u = unite({ua, ub}) + local u = unite(node, {ua, ub}) if u is UnionType and is_valid_union(u) then - add_warning("hint", node, "if a union type was intended, consider declaring it explicitly") + self.errs:add_warning("hint", node, "if a union type was intended, consider declaring it explicitly") end end end end if ua is NominalType and ub is NominalType and not meta_on_operator then - if is_a(ua, ub) then + if self:is_a(ua, ub) then t = ua else - error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for distinct nominal types %s and %s", ua, ub) + self.errs:add(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for distinct nominal types %s and %s", ua, ub) end end @@ -11737,20 +11771,20 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.known = FACT_TRUTHY end - if node.op.op == "//" and env.gen_target == "5.1" then + if node.op.op == "//" and self.gen_target == "5.1" then if meta_on_operator then - all_needs_compat["mt"] = true + self.all_needs_compat["mt"] = true convert_node_to_compat_mt_call(node, "__idiv", meta_on_operator, node.e1, node.e2) else - local div: Node = { y = node.y, x = node.x, kind = "op", op = an_operator(node, 2, "/"), e1 = node.e1, e2 = node.e2 } + local div: Node = node_at(node, { kind = "op", op = an_operator(node, 2, "/"), e1 = node.e1, e2 = node.e2 }) convert_node_to_compat_call(node, "math", "floor", div) end - elseif bit_operators[node.op.op] and env.gen_target == "5.1" then + elseif bit_operators[node.op.op] and self.gen_target == "5.1" then if meta_on_operator then - all_needs_compat["mt"] = true + self.all_needs_compat["mt"] = true convert_node_to_compat_mt_call(node, binop_to_metamethod[node.op.op], meta_on_operator, node.e1, node.e2) else - all_needs_compat["bit32"] = true + self.all_needs_compat["bit32"] = true convert_node_to_compat_call(node, "bit32", bit_operators[node.op.op], node.e1, node.e2) end end @@ -11762,28 +11796,28 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["variable"] = { - after = function(node: Node, _children: {Type}): Type + after = function(self: TypeChecker, node: Node, _children: {Type}): Type if node.tk == "..." then - local va_sentinel = find_var_type("@is_va") + local va_sentinel = self:find_var_type("@is_va") if not va_sentinel or va_sentinel.typename == "nil" then - return invalid_at(node, "cannot use '...' outside a vararg function") + return self.errs:invalid_at(node, "cannot use '...' outside a vararg function") end end local t: Type if node.tk == "_G" then - t, node.attribute = simulate_g() + t, node.attribute = self:simulate_g() else local use: VarUse = node.is_lvalue and "lvalue" or "use" - t, node.attribute = find_var_type(node.tk, use) + t, node.attribute = self:find_var_type(node.tk, use) end if not t then - if lax then - add_unknown(node, node.tk) - return UNKNOWN + if self.feat_lax then + self.errs:add_unknown(node, node.tk) + return an_unknown(node) end - return invalid_at(node, "unknown variable: " .. node.tk) + return self.errs:invalid_at(node, "unknown variable: " .. node.tk) end if t is TypeDeclType then @@ -11794,70 +11828,70 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["type_identifier"] = { - after = function(node: Node, _children: {Type}): Type - local typ, attr = find_var_type(node.tk) + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + local typ, attr = self:find_var_type(node.tk) node.attribute = attr if typ then return typ end - if lax then - add_unknown(node, node.tk) - return UNKNOWN + if self.feat_lax then + self.errs:add_unknown(node, node.tk) + return an_unknown(node) end - return invalid_at(node, "unknown variable: " .. node.tk) + return self.errs:invalid_at(node, "unknown variable: " .. node.tk) end, }, ["argument"] = { - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local t = children[1] if not t then - t = UNKNOWN + t = an_unknown(node) end if node.tk == "..." then - t = a_vararg { t } + t = a_vararg(node, { t }) end - add_var(node, node.tk, t).is_func_arg = true + self:add_var(node, node.tk, t).is_func_arg = true return t end, }, ["identifier"] = { - after = function(_node: Node, _children: {Type}): Type + after = function(_self: TypeChecker, _node: Node, _children: {Type}): Type return NONE -- type is resolved elsewhere end, }, ["newtype"] = { - after = function(node: Node, _children: {Type}): Type + after = function(_self: TypeChecker, node: Node, _children: {Type}): Type return node.newtype end, }, ["error_node"] = { - after = function(_node: Node, _children: {Type}): Type - return INVALID + after = function(_self: TypeChecker, node: Node, _children: {Type}): Type + return an_invalid(node) end, } } visit_node.cbs["break"] = { - after = function(_node: Node, _children: {Type}): Type + after = function(_self: TypeChecker, _node: Node, _children: {Type}): Type return NONE end, } visit_node.cbs["do"] = visit_node.cbs["break"] - local function after_literal(node: Node): Type + local function after_literal(_self: TypeChecker, node: Node): Type node.known = FACT_TRUTHY - return type_at(node, a_type(node.kind as TypeName, {})) + return a_type(node, node.kind as TypeName, {}) end visit_node.cbs["string"] = { - after = function(node: Node, _children: {Type}): Type - local t = after_literal(node) as StringType + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + local t = after_literal(self, node) as StringType t.literal = node.conststr - local expected = node.expected and to_structural(node.expected) - if expected and expected is EnumType and is_a(t, expected) then + local expected = node.expected and self:to_structural(node.expected) + if expected and expected is EnumType and self:is_a(t, expected) then return node.expected end @@ -11868,8 +11902,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string visit_node.cbs["integer"] = { after = after_literal } visit_node.cbs["boolean"] = { - after = function(node: Node, _children: {Type}): Type - local t = after_literal(node) + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + local t = after_literal(self, node) node.known = (node.tk == "true") and FACT_TRUTHY or nil return t end, @@ -11880,7 +11914,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string visit_node.cbs["argument_list"] = visit_node.cbs["variable_list"] visit_node.cbs["expression_list"] = visit_node.cbs["variable_list"] - visit_node.after = function(node: Node, _children: {Type}, t: Type): Type + visit_node.after = function(_self: TypeChecker, node: Node, _children: {Type}, t: Type): Type if node.expanded then apply_macroexp(node) end @@ -11888,13 +11922,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t end - local expand_interfaces: function(Type) do - local function add_interface_fields(what: string, fields: {string:Type}, field_order: {string}, resolved: RecordLikeType, named: NominalType, list?: MetaMode) + local function add_interface_fields(self: TypeChecker, what: string, fields: {string:Type}, field_order: {string}, resolved: RecordLikeType, named: NominalType, list?: MetaMode) for fname, ftype in fields_of(resolved, list) do if fields[fname] then - if not is_a(fields[fname], ftype) then - error_at(fields[fname], what .." '" .. fname .. "' does not match definition in interface %s", named) + if not self:is_a(fields[fname], ftype) then + self.errs:add(fields[fname], what .." '" .. fname .. "' does not match definition in interface %s", named) end else table.insert(field_order, fname) @@ -11903,18 +11936,21 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function collect_interfaces(list: {ArrayType | NominalType}, t: RecordLikeType, seen:{Type:boolean}): {ArrayType | NominalType} + local function collect_interfaces(self: TypeChecker, list: {ArrayType | NominalType}, t: RecordLikeType, seen:{Type:boolean}): {ArrayType | NominalType} if t.interface_list then for _, iface in ipairs(t.interface_list) do if iface is NominalType then - local ri = resolve_nominal(iface) + local ri = self:resolve_nominal(iface) if not (ri.typename == "invalid") then - assert(ri is InterfaceType, "nominal resolved to " .. ri.typename) - if not ri.interfaces_expanded and not seen[ri] then - seen[ri] = true - collect_interfaces(list, ri, seen) + if ri is InterfaceType then + if not ri.interfaces_expanded and not seen[ri] then + seen[ri] = true + collect_interfaces(self, list, ri, seen) + end + table.insert(list, iface) + else + self.errs:add(iface, "attempted to use %s as interface, but its type is %s", iface, ri) end - table.insert(list, iface) end else if not seen[iface] then @@ -11927,30 +11963,30 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return list end - expand_interfaces = function(t: RecordLikeType) + function TypeChecker:expand_interfaces(t: RecordLikeType) if t.interfaces_expanded then return end t.interfaces_expanded = true - t.interface_list = collect_interfaces({}, t, {}) + t.interface_list = collect_interfaces(self, {}, t, {}) for _, iface in ipairs(t.interface_list) do if iface is NominalType then - local ri = resolve_nominal(iface) + local ri = self:resolve_nominal(iface) assert(ri is InterfaceType) - add_interface_fields("field", t.fields, t.field_order, ri, iface) + add_interface_fields(self, "field", t.fields, t.field_order, ri, iface) if ri.meta_fields then t.meta_fields = t.meta_fields or {} t.meta_field_order = t.meta_field_order or {} - add_interface_fields("metamethod", t.meta_fields, t.meta_field_order, ri, iface, "meta") + add_interface_fields(self, "metamethod", t.meta_fields, t.meta_field_order, ri, iface, "meta") end else if not t.elements then t.elements = iface else - if not same_type(iface.elements, t.elements) then - error_at(t, "incompatible array interfaces") + if not self:same_type(iface.elements, t.elements) then + self.errs:add(t, "incompatible array interfaces") end end end @@ -11958,33 +11994,33 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local visit_type: Visitor + local visit_type: Visitor visit_type = { cbs = { ["function"] = { - before = function(_typ: Type) - begin_scope() + before = function(self: TypeChecker, _typ: Type) + self:begin_scope() end, - after = function(typ: Type, _children: {Type}): Type - end_scope() - return ensure_fresh_typeargs(typ) + after = function(self: TypeChecker, typ: Type, _children: {Type}): Type + self:end_scope() + return self:ensure_fresh_typeargs(typ) end, }, ["record"] = { - before = function(typ: RecordType) - begin_scope() - add_var(nil, "@self", type_at(typ, a_typedecl(typ))) + before = function(self: TypeChecker, typ: RecordType) + self:begin_scope() + self:add_var(nil, "@self", type_at(typ, a_typedecl(typ, typ))) for fname, ftype in fields_of(typ) do if ftype is TypeAliasType then - resolve_nominal(ftype.alias_to) - add_var(nil, fname, ftype) + self:resolve_nominal(ftype.alias_to) + self:add_var(nil, fname, ftype) elseif ftype is TypeDeclType then - add_var(nil, fname, ftype) + self:add_var(nil, fname, ftype) end end end, - after = function(typ: RecordType, children: {Type}): Type + after = function(self: TypeChecker, typ: RecordType, children: {Type}): Type local i = 1 if typ.typeargs then for _, _ in ipairs(typ.typeargs) do @@ -11998,11 +12034,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if iface is ArrayType then typ.interface_list[j] = iface elseif iface is NominalType then - local ri = resolve_nominal(iface) + local ri = self:resolve_nominal(iface) if ri is InterfaceType then typ.interface_list[j] = iface else - error_at(children[i], "%s is not an interface", children[i]) + self.errs:add(children[i], "%s is not an interface", children[i]) end end i = i + 1 @@ -12042,7 +12078,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end elseif ftype is TypeAliasType then - resolve_typealias(ftype) + self:resolve_typealias(ftype) end typ.fields[name] = ftype @@ -12061,55 +12097,55 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if typ.interface_list then - expand_interfaces(typ) + self:expand_interfaces(typ) end if fmacros then for _, t in ipairs(fmacros) do - local macroexp_type = recurse_node(t.macroexp, visit_node, visit_type) + local macroexp_type = recurse_node(self, t.macroexp, visit_node, visit_type) - check_macroexp_arg_use(t.macroexp) + self:check_macroexp_arg_use(t.macroexp) - if not is_a(macroexp_type, t) then - error_at(macroexp_type, "macroexp type does not match declaration") + if not self:is_a(macroexp_type, t) then + self.errs:add(macroexp_type, "macroexp type does not match declaration") end end end - end_scope() + self:end_scope() return typ end, }, ["typearg"] = { - after = function(typ: TypeArgType, _children: {Type}): Type - add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { + after = function(self: TypeChecker, typ: TypeArgType, _children: {Type}): Type + self:add_var(nil, typ.typearg, a_type(typ, "typearg", { typearg = typ.typearg, constraint = typ.constraint, - } as TypeArgType))) + } as TypeArgType)) return typ end, }, ["typevar"] = { - after = function(typ: TypeVarType, _children: {Type}): Type - if not find_var_type(typ.typevar) then - error_at(typ, "undefined type variable " .. typ.typevar) + after = function(self: TypeChecker, typ: TypeVarType, _children: {Type}): Type + if not self:find_var_type(typ.typevar) then + self.errs:add(typ, "undefined type variable " .. typ.typevar) end return typ end, }, ["nominal"] = { - after = function(typ: NominalType, _children: {Type}): Type + after = function(self: TypeChecker, typ: NominalType, _children: {Type}): Type if typ.found then return typ end - local t = find_type(typ.names, true) + local t = self:find_type(typ.names, true) if t then if t is TypeArgType then -- convert nominal into a typevar typ.names = nil - edit_type(typ, "typevar") + edit_type(typ, typ, "typevar") local tv = typ as TypeVarType tv.typevar = t.typearg tv.constraint = t.constraint @@ -12120,18 +12156,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end else local name = typ.names[1] - local unresolved = get_unresolved() - unresolved.nominals[name] = unresolved.nominals[name] or {} - table.insert(unresolved.nominals[name], typ) + local scope = self.st[#self.st] + scope.pending_nominals = scope.pending_nominals or {} + scope.pending_nominals[name] = scope.pending_nominals[name] or {} + table.insert(scope.pending_nominals[name], typ) end return typ end, }, ["union"] = { - after = function(typ: UnionType, _children: {Type}): Type + after = function(self: TypeChecker, typ: UnionType, _children: {Type}): Type local ok, err = is_valid_union(typ) if not ok then - return err and invalid_at(typ, err, typ) or INVALID + return err and self.errs:invalid_at(typ, err, typ) or an_invalid(typ) end return typ end @@ -12139,59 +12176,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string }, } - local function internal_compiler_check(fn: function(W, {Type}, Type): (Type)): (function(W, {Type}, Type): (Type)) - return function(w: W, children: {Type}, t: Type): Type - t = fn and fn(w, children, t) or t - - if type(t) ~= "table" then - error(((w as Node).kind or (w as Type).typename) .. " did not produce a type") - end - if type(t.typename) ~= "string" then - error(((w as Node).kind or (w as Type).typename) .. " type does not have a typename") - end - - return t - end - end - - local function store_type_after(fn: function(W, {Type}, Type): (Type)): (function(W, {Type}, Type): (Type)) - return function(w: W, children: {Type}, t: Type): Type - t = fn and fn(w, children, t) or t - - local where = w as Where - - if where.y then - tc.store_type(where.y, where.x, t) - end - - return t - end - end - - local function debug_type_after(fn: function(Node, {Type}, Type): (Type)): (function(Node, {Type}, Type): (Type)) - return function(node: Node, children: {Type}, t: Type): Type - t = fn and fn(node, children, t) or t - node.debug_type = t - return t - end - end - - if opts.run_internal_compiler_checks then - visit_node.after = internal_compiler_check(visit_node.after) - visit_type.after = internal_compiler_check(visit_type.after) - end - - if tc then - visit_node.after = store_type_after(visit_node.after) - visit_type.after = store_type_after(visit_type.after) - end - - if TL_DEBUG then - visit_node.after = debug_type_after(visit_node.after) - end - local default_type_visitor = { - after = function(typ: Type, _children: {Type}): Type + after = function(_self: TypeChecker, typ: Type, _children: {Type}): Type return typ end, } @@ -12218,70 +12204,201 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string visit_type.cbs["any"] = default_type_visitor visit_type.cbs["unknown"] = default_type_visitor visit_type.cbs["invalid"] = default_type_visitor - visit_type.cbs["unresolved"] = default_type_visitor visit_type.cbs["none"] = default_type_visitor - assert(ast.kind == "statements") - recurse_node(ast, visit_node, visit_type) + local type VisitorAfterPatcher = function(VisitorAfter): VisitorAfter - close_types(st[1]) - check_for_unused_vars(st[1], true) + local function internal_compiler_check(fn: VisitorAfter): VisitorAfter + return function(s: S, n: N, children: {Type}, t: Type): Type + t = fn and fn(s, n, children, t) or t + + if type(t) ~= "table" then + error(((n as Node).kind or (n as Type).typename) .. " did not produce a type") + end + if type(t.typename) ~= "string" then + error(((n as Node).kind or (n as Type).typename) .. " type does not have a typename") + end - clear_redundant_errors(errors) + return t + end + end - add_compat_entries(ast, all_needs_compat, env.gen_compat) + local function store_type_after(fn: VisitorAfter): VisitorAfter + return function(self: TypeChecker, n: N, children: {Type}, t: Type): Type + t = fn and fn(self, n, children, t) or t - local result = { - ast = ast, - env = env, - type = module_type or BOOLEAN, - filename = filename, - warnings = warnings, - type_errors = errors, - dependencies = dependencies, - } + local w = n as Where - env.loaded[filename] = result - table.insert(env.loaded_order, filename) + if w.y then + self.collector.store_type(w.y, w.x, t) + end - if tc then - env.reporter:store_result(tc, env.globals) + return t + end end - return result -end + local function debug_type_after(fn: VisitorAfter): VisitorAfter + return function(s: S, node: Node, children: {Type}, t: Type): Type + t = fn and fn(s, node, children, t) or t --------------------------------------------------------------------------------- --- Report types --------------------------------------------------------------------------------- + node.debug_type = t + return t + end + end -function tl.symbols_in_scope(tr: TypeReport, y: integer, x: integer): {string:integer} - local function find(symbols: {{integer, integer, string, integer}}, at_y: integer, at_x: integer): integer - local function le(a: {integer, integer}, b: {integer, integer}): boolean - return a[1] < b[1] - or (a[1] == b[1] and a[2] <= b[2]) + local function patch_visitors(my_visit_node: Visitor, + after_node: VisitorAfterPatcher, + my_visit_type?: Visitor, + after_type?: VisitorAfterPatcher): + Visitor, + Visitor + if my_visit_node == visit_node then + my_visit_node = shallow_copy_table(my_visit_node) end - return binary_search(symbols, {at_y, at_x}, le) or 0 + my_visit_node.after = after_node(my_visit_node.after) + if my_visit_type then + if my_visit_type == visit_type then + my_visit_type = shallow_copy_table(my_visit_type) + end + my_visit_type.after = after_type(my_visit_type.after) + else + my_visit_type = visit_type + end + return my_visit_node, my_visit_type end - local ret: {string:integer} = {} + local function set_feat(feat: Feat, default: boolean): boolean + if feat then + return (feat == "on") + else + return default + end + end - local n = find(tr.symbols, y, x) + tl.type_check = function(ast: Node, filename: string, opts: TypeCheckOptions, env?: Env): Result, string + assert(filename is string, "tl.type_check signature has changed, pass filename separately") + assert((not opts) or (not (opts as {any:any}).env), "tl.type_check signature has changed, pass env separately") - local symbols = tr.symbols - while n >= 1 do - local s = symbols[n] - if s[3] == "@{" then - n = n - 1 - elseif s[3] == "@}" then - n = s[4] + filename = filename or "?" + + opts = opts or {} + + if not env then + local err: string + env, err = tl.new_env({ defaults = opts }) + if err then + return nil, err + end + end + + local self: TypeChecker = { + filename = filename, + env = env, + st = { + { + vars = env.globals, + pending_global_types = {}, + }, + }, + errs = Errors.new(filename), + all_needs_compat = {}, + dependencies = {}, + subtype_relations = TypeChecker.subtype_relations, + eqtype_relations = TypeChecker.eqtype_relations, + type_priorities = TypeChecker.type_priorities, + } + + setmetatable(self, { __index = TypeChecker }) + + self.feat_lax = set_feat(opts.feat_lax or env.defaults.feat_lax, false) + self.feat_arity = set_feat(opts.feat_arity or env.defaults.feat_arity, true) + self.gen_compat = opts.gen_compat or env.defaults.gen_compat or DEFAULT_GEN_COMPAT + self.gen_target = opts.gen_target or env.defaults.gen_target or DEFAULT_GEN_TARGET + + if self.gen_target == "5.4" and self.gen_compat ~= "off" then + return nil, "gen-compat must be explicitly 'off' when gen-target is '5.4'" + end + + if self.feat_lax then + self.type_priorities = shallow_copy_table(self.type_priorities) + self.type_priorities["unknown"] = 0 + + self.subtype_relations = shallow_copy_table(self.subtype_relations) + + self.subtype_relations["unknown"] = {} + self.subtype_relations["unknown"]["*"] = compare_true + + self.subtype_relations["*"] = shallow_copy_table(self.subtype_relations["*"]) + self.subtype_relations["*"]["unknown"] = compare_true + -- in .lua files, all values can be used in a boolean context + self.subtype_relations["*"]["boolean"] = compare_true + + self.get_rets = function(rets: TupleType): TupleType + if #rets.tuple == 0 then + return a_vararg(rets, { an_unknown(rets) }) + end + return rets + end else - ret[s[3]] = s[4] - n = n - 1 + self.get_rets = function(rets: TupleType): TupleType + return rets + end end - end - return ret + if env.report_types then + env.reporter = env.reporter or tl.new_type_reporter() + self.collector = env.reporter:get_collector(filename) + end + + local visit_node, visit_type = visit_node, visit_type + if opts.run_internal_compiler_checks then + visit_node, visit_type = patch_visitors( + visit_node, internal_compiler_check, + visit_type, internal_compiler_check + ) + end + if self.collector then + visit_node, visit_type = patch_visitors( + visit_node, store_type_after, + visit_type, store_type_after + ) + end + if TL_DEBUG then + visit_node, visit_type = patch_visitors( + visit_node, debug_type_after + ) + end + + assert(ast.kind == "statements") + recurse_node(self, ast, visit_node, visit_type) + + local global_scope = self.st[1] + close_types(global_scope) + self.errs:warn_unused_vars(global_scope, true) + + clear_redundant_errors(self.errs.errors) + + add_compat_entries(ast, self.all_needs_compat, self.gen_compat) + + local result = { + ast = ast, + env = env, + type = self.module_type or a_type(ast, "boolean", {}), + filename = filename, + warnings = self.errs.warnings, + type_errors = self.errs.errors, + dependencies = self.dependencies, + } + + env.loaded[filename] = result + table.insert(env.loaded_order, filename or "") + + if self.collector then + env.reporter:store_result(self.collector, env.globals) + end + + return result + end end -------------------------------------------------------------------------------- @@ -12297,9 +12414,24 @@ local function read_full_file(fd: FILE): string, string return content, err end -tl.process = function(filename: string, env: Env, fd?: FILE): Result, string - assert((not fd or type(fd) ~= "string"), "fd must be a file") +local function feat_lax_heuristic(filename?: string, input?: string): Feat + if filename then + local _, extension = filename:match("(.*)%.([a-z]+)$") + extension = extension and extension:lower() + + if extension == "tl" then + return "off" + elseif extension == "lua" then + return "on" + end + end + if input then + return (input:match("^#![^\n]*lua[^\n]*\n")) and "on" or "off" + end + return "off" +end +tl.process = function(filename: string, env: Env, fd?: FILE): Result, string if env and env.loaded and env.loaded[filename] then return env.loaded[filename] end @@ -12319,23 +12451,38 @@ tl.process = function(filename: string, env: Env, fd?: FILE): Result, string return nil, "could not read " .. filename .. ": " .. err end - local _, extension = filename:match("(.*)%.([a-z]+)$") - extension = extension and extension:lower() + return tl.process_string(input, env, filename) +end - local is_lua: boolean - if extension == "tl" then - is_lua = false - elseif extension == "lua" then - is_lua = true - else - is_lua = input:match("^#![^\n]*lua[^\n]*\n") as boolean +function tl.target_from_lua_version(str: string): GenTarget + if str == "Lua 5.1" + or str == "Lua 5.2" then + return "5.1" + elseif str == "Lua 5.3" then + return "5.3" + elseif str == "Lua 5.4" then + return "5.4" end +end - return tl.process_string(input, is_lua, env, filename) +local function default_env_opts(runtime: boolean, filename?: string, input?: string): EnvOptions + local gen_target = runtime and tl.target_from_lua_version(_VERSION) or DEFAULT_GEN_TARGET + local gen_compat: GenCompat = (gen_target == "5.4") and "off" or DEFAULT_GEN_COMPAT + return { + defaults = { + feat_lax = feat_lax_heuristic(filename, input), + gen_target = gen_target, + gen_compat = gen_compat, + run_internal_compiler_checks = false, + } + } end -function tl.process_string(input: string, is_lua: boolean, env: Env, filename?: string): Result - env = env or tl.init_env(is_lua) +function tl.process_string(input: string, env?: Env, filename?: string): Result + assert(type(env) ~= "boolean", "tl.process_string signature has changed") + + env = env or tl.new_env(default_env_opts(false, filename, input)) + if env.loaded and env.loaded[filename] then return env.loaded[filename] end @@ -12347,7 +12494,7 @@ function tl.process_string(input: string, is_lua: boolean, env: Env, filename?: local result = { ok = false, filename = filename, - type = BOOLEAN, + type = a_type({ f = filename, y = 1, x = 1 }, "boolean", {}), type_errors = {}, syntax_errors = syntax_errors, env = env, @@ -12357,14 +12504,7 @@ function tl.process_string(input: string, is_lua: boolean, env: Env, filename?: return result end - local opts: TypeCheckOptions = { - filename = filename, - lax = is_lua, - gen_compat = env.gen_compat, - gen_target = env.gen_target, - env = env, - } - local result = tl.type_check(program, opts) + local result = tl.type_check(program, filename, env.defaults, env) result.syntax_errors = syntax_errors @@ -12372,15 +12512,15 @@ function tl.process_string(input: string, is_lua: boolean, env: Env, filename?: end tl.gen = function(input: string, env: Env, pp: PrettyPrintOptions): string, Result - env = env or assert(tl.init_env(), "Default environment initialization failed") - local result = tl.process_string(input, false, env) + env = env or assert(tl.new_env(default_env_opts(false, nil, input)), "Default environment initialization failed") + local result = tl.process_string(input, env) if (not result.ast) or #result.syntax_errors > 0 then return nil, result end local code: string - code, result.gen_error = tl.pretty_print_ast(result.ast, env.gen_target, pp) + code, result.gen_error = tl.pretty_print_ast(result.ast, env.defaults.gen_target, pp) return code, result end @@ -12396,28 +12536,25 @@ local function tl_package_loader(module_name: string): any, any if #errs > 0 then error(found_filename .. ":" .. errs[1].y .. ":" .. errs[1].x .. ": " .. errs[1].msg) end - local lax = not not found_filename:match("lua$") local env = tl.package_loader_env if not env then - tl.package_loader_env = tl.init_env(lax) + tl.package_loader_env = assert(tl.new_env(), "Default environment initialization failed") env = tl.package_loader_env end - env.modules[module_name] = a_typedecl(CIRCULAR_REQUIRE) + local opts = default_env_opts(true, found_filename) - local result = tl.type_check(program, { - lax = lax, - filename = found_filename, - env = env, - run_internal_compiler_checks = false, - }) + local w = { f = found_filename, x = 1, y = 1 } + env.modules[module_name] = a_typedecl(w, a_type(w, "circular_require", {})) + + local result = tl.type_check(program, found_filename, opts.defaults, env) env.modules[module_name] = result.type -- TODO: should this be a hard error? this seems analogous to -- finding a lua file with a syntax error in it - local code = assert(tl.pretty_print_ast(program, env.gen_target, true)) + local code = assert(tl.pretty_print_ast(program, opts.defaults.gen_target, true)) local chunk, err = load(code, "@" .. found_filename, "t") if chunk then return function(modname: string, loader_data: string): any @@ -12443,21 +12580,10 @@ function tl.loader() end end -function tl.target_from_lua_version(str: string): TargetMode - if str == "Lua 5.1" - or str == "Lua 5.2" then - return "5.1" - elseif str == "Lua 5.3" then - return "5.3" - elseif str == "Lua 5.4" then - return "5.4" - end -end - -local function env_for(lax: boolean, env_tbl: {any:any}): Env +local function env_for(opts: EnvOptions, env_tbl: {any:any}): Env if not env_tbl then if not tl.package_loader_env then - tl.package_loader_env = tl.init_env(lax) + tl.package_loader_env = tl.new_env(opts) end return tl.package_loader_env end @@ -12466,7 +12592,7 @@ local function env_for(lax: boolean, env_tbl: {any:any}): Env tl.load_envs = setmetatable({}, { __mode = "k" }) end - tl.load_envs[env_tbl] = tl.load_envs[env_tbl] or tl.init_env(lax) + tl.load_envs[env_tbl] = tl.load_envs[env_tbl] or tl.new_env(opts) return tl.load_envs[env_tbl] end @@ -12476,17 +12602,14 @@ tl.load = function(input: string, chunkname: string, mode: LoadMode, ...: {any:a return nil, (chunkname or "") .. ":" .. errs[1].y .. ":" .. errs[1].x .. ": " .. errs[1].msg end - local lax = chunkname and not not chunkname:match("lua$") + local opts = default_env_opts(true, chunkname) + if not tl.package_loader_env then - tl.package_loader_env = tl.init_env(lax) + tl.package_loader_env = tl.new_env(opts) end - local result = tl.type_check(program, { - lax = lax, - filename = chunkname or ("string \"" .. input:sub(45) .. (#input > 45 and "..." or "") .. "\""), - env = env_for(lax, ...), - run_internal_compiler_checks = false, - }) + local filename = chunkname or ("string \"" .. input:sub(45) .. (#input > 45 and "..." or "") .. "\"") + local result = tl.type_check(program, filename, opts.defaults, env_for(opts, ...)) if mode and mode:match("c") then if #result.type_errors > 0 then @@ -12500,7 +12623,7 @@ tl.load = function(input: string, chunkname: string, mode: LoadMode, ...: {any:a mode = mode:gsub("c", "") as LoadMode end - local code, err = tl.pretty_print_ast(program, tl.target_from_lua_version(_VERSION), true) + local code, err = tl.pretty_print_ast(program, opts.defaults.gen_target, true) if not code then return nil, err end @@ -12508,4 +12631,29 @@ tl.load = function(input: string, chunkname: string, mode: LoadMode, ...: {any:a return load(code, chunkname, mode, ...) end +-------------------------------------------------------------------------------- +-- Backwards compatibility +-------------------------------------------------------------------------------- + +function tl.get_types(result: Result): TypeReport, TypeReporter + return result.env.reporter:get_report(), result.env.reporter +end + +tl.init_env = function(lax?: boolean, gen_compat?: boolean | GenCompat, gen_target?: GenTarget, predefined?: {string}): Env, string + local opts = { + defaults = { + feat_lax = (lax and "on" or "off") as Feat, + gen_compat = ((gen_compat is GenCompat) and gen_compat) or + (gen_compat == false and "off") or + (gen_compat == true or gen_compat == nil) and "optional", + gen_target = gen_target or + ((_VERSION == "Lua 5.1" or _VERSION == "Lua 5.2") and "5.1") or + "5.3", + }, + predefined_modules = predefined, + } + + return tl.new_env(opts) +end + return tl