Skip to content

Commit

Permalink
fix: propagate bidirectional inference for variadic arguments
Browse files Browse the repository at this point in the history
Closes #715.
  • Loading branch information
hishamhm committed Nov 6, 2023
1 parent a14954f commit c5150c6
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 7 deletions.
2 changes: 1 addition & 1 deletion spec/assignment/to_map_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,6 @@ describe("assignment to maps", function()
f({"string value", pi=math.pi})
]], {
{ msg = "argument 1: in map key: got integer, expected string" }
{ msg = "in map key: got integer, expected string" }
}))
end)
20 changes: 17 additions & 3 deletions tl.lua
Original file line number Diff line number Diff line change
Expand Up @@ -9576,8 +9576,13 @@ tl.type_check = function(ast, opts)
assert_is_a(node[i], cvtype, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(n))
end
elseif node[i].key_parsed == "implicit" then
if is_map then
assert_is_a(node[i], INTEGER, decltype.keys, in_context(node.expected_context, "in map key"))
assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value"))
end
force_array = expand_type(node[i], force_array, child.vtype)
elseif is_map then
force_array = nil
assert_is_a(node[i], child.ktype, decltype.keys, in_context(node.expected_context, "in map key"))
assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value"))
else
Expand Down Expand Up @@ -9881,9 +9886,18 @@ tl.type_check = function(ast, opts)
if node.expected then
is_a(node.e1.type.rets, node.expected)
end
for i, typ in ipairs(node.e1.type.args) do
if node.e2[i + argdelta] then
node.e2[i + argdelta].expected = typ
local e1args = node.e1.type.args
local at = argdelta
for _, typ in ipairs(e1args) do
at = at + 1
if node.e2[at] then
node.e2[at].expected = typ
end
end
if e1args.is_va then
local typ = e1args[#e1args]
for i = at + 1, #node.e2 do
node.e2[i].expected = typ
end
end
end
Expand Down
20 changes: 17 additions & 3 deletions tl.tl
Original file line number Diff line number Diff line change
Expand Up @@ -9576,8 +9576,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string
assert_is_a(node[i], cvtype, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(n))
end
elseif node[i].key_parsed == "implicit" then
if is_map then
assert_is_a(node[i], INTEGER, decltype.keys, in_context(node.expected_context, "in map key"))
assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value"))
end
force_array = expand_type(node[i], force_array, child.vtype)
elseif is_map then
force_array = nil
assert_is_a(node[i], child.ktype, decltype.keys, in_context(node.expected_context, "in map key"))
assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value"))
else
Expand Down Expand Up @@ -9881,9 +9886,18 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string
if node.expected then
is_a(node.e1.type.rets, node.expected)
end
for i, typ in ipairs(node.e1.type.args) do
if node.e2[i + argdelta] then
node.e2[i + argdelta].expected = typ
local e1args = node.e1.type.args
local at = argdelta
for _, typ in ipairs(e1args) do
at = at + 1
if node.e2[at] then
node.e2[at].expected = typ
end
end
if e1args.is_va then
local typ = e1args[#e1args]
for i = at + 1, #node.e2 do
node.e2[i].expected = typ
end
end
end
Expand Down

0 comments on commit c5150c6

Please sign in to comment.