From 0d34326d71cb5dda4814076cee733ced7dd43958 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 25 Oct 2024 15:59:09 -0300 Subject: [PATCH] fix: type variables from returns resolve for arguments Fixes #838. --- spec/lang/inference/function_call_spec.lua | 70 ++++++++++++++++++++ spec/lang/inference/function_result_spec.lua | 19 ------ spec/lang/statement/forin_spec.lua | 2 +- tl.lua | 4 +- tl.tl | 4 +- 5 files changed, 75 insertions(+), 24 deletions(-) create mode 100644 spec/lang/inference/function_call_spec.lua delete mode 100644 spec/lang/inference/function_result_spec.lua diff --git a/spec/lang/inference/function_call_spec.lua b/spec/lang/inference/function_call_spec.lua new file mode 100644 index 000000000..c9758bcde --- /dev/null +++ b/spec/lang/inference/function_call_spec.lua @@ -0,0 +1,70 @@ +local util = require("spec.util") + +describe("function call", function() + describe("results", function() + it("should be adjusted down to 1 result in an expression list", util.check([[ + local function f(): string, number + end + local a, b = f(), "hi" + a = "hey" + ]])) + + it("can resolve type arguments based on expected type at use site (#512)", util.check([[ + local function get_foos():{T} + return {} + end + + local foos:{integer} = get_foos() + print(foos) + ]])) + end) + + describe("arguments", function() + it("type variables from returns resolve for arguments (regression test for #838)", util.check([[ + local fcts: {integer:function(val: any, opt?: string): any} + + local function bar (val: number): number + print(val) + return val + end + + local function bar2 (val: number, val2: string): number + print(val, val2) + return val + end + + fcts = { -- OK, with table constructor + [11] = function (val: string): string + print(val) + return val + end, + [12] = function (val: string, val2: string): string + print(val, val2) + return val + end, + [21] = bar, + [22] = bar2, + } + setmetatable(fcts, { + __tostring = function(): string return 'fcts' end + }) + + fcts = setmetatable({ -- Ok, as an argument via type variable + [11] = function (val: string): string + print(val) + return val + end, + [12] = function (val: string, val2: string): string + print(val, val2) + return val + end, + [21] = bar, + [22] = bar2, + }, { + __tostring = function(): string return 'fcts' end + }) + + print(fcts) + ]])) + end) +end) diff --git a/spec/lang/inference/function_result_spec.lua b/spec/lang/inference/function_result_spec.lua deleted file mode 100644 index 96891cbeb..000000000 --- a/spec/lang/inference/function_result_spec.lua +++ /dev/null @@ -1,19 +0,0 @@ -local util = require("spec.util") - -describe("function results", function() - it("should be adjusted down to 1 result in an expression list", util.check([[ - local function f(): string, number - end - local a, b = f(), "hi" - a = "hey" - ]])) - - it("can resolve type arguments based on expected type at use site (#512)", util.check([[ - local function get_foos():{T} - return {} - end - - local foos:{integer} = get_foos() - print(foos) - ]])) -end) diff --git a/spec/lang/statement/forin_spec.lua b/spec/lang/statement/forin_spec.lua index 3a5916404..327068e85 100644 --- a/spec/lang/statement/forin_spec.lua +++ b/spec/lang/statement/forin_spec.lua @@ -96,7 +96,7 @@ describe("forin", function() it("with a callable record iterator", util.check([[ local record R incr: integer - metamethod __call: function(): integer + metamethod __call: function(self): integer end local function foo(incr: integer): R diff --git a/tl.lua b/tl.lua index 93c69f3d8..ec3321984 100644 --- a/tl.lua +++ b/tl.lua @@ -12267,13 +12267,13 @@ self:expand_type(node, values, elements) }) for _, typ in ipairs(e1args) do at = at + 1 if node.e2[at] then - node.e2[at].expected = typ + node.e2[at].expected = self:infer_at(node.e2[at], typ) end end if e1type.args.is_va then local typ = e1args[#e1args] for i = at + 1, #node.e2 do - node.e2[i].expected = typ + node.e2[i].expected = self:infer_at(node.e2[i], typ) end end end diff --git a/tl.tl b/tl.tl index c894a74d6..af2e01ff5 100644 --- a/tl.tl +++ b/tl.tl @@ -12267,13 +12267,13 @@ do for _, typ in ipairs(e1args) do at = at + 1 if node.e2[at] then - node.e2[at].expected = typ + node.e2[at].expected = self:infer_at(node.e2[at], typ) end end if e1type.args.is_va then local typ = e1args[#e1args] for i = at + 1, #node.e2 do - node.e2[i].expected = typ + node.e2[i].expected = self:infer_at(node.e2[i], typ) end end end