Skip to content

Commit

Permalink
Disallow returning multiple types from modules, fixes #227
Browse files Browse the repository at this point in the history
  • Loading branch information
edubart committed Nov 4, 2023
1 parent a60b9fe commit a1b4a34
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 13 deletions.
32 changes: 23 additions & 9 deletions lualib/nelua/analyzer.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1506,7 +1506,7 @@ local function visitor_Call(context, node, argnodes, calleetype, calleesym, call
end
attr.calleesym = calleesym
if calleetype then
attr.type = calleetype:get_return_type(1)
attr.type, attr.value = calleetype:get_return_type_and_value(1)
sideeffect = calleetype.sideeffect
if calleetype.symbol then
calleetype.symbol:add_use_by(context.state.funcscope.funcsym)
Expand Down Expand Up @@ -2299,14 +2299,8 @@ function visitors.VarDecl(context, node)
if vartype.is_nolvalue then
varnode:raisef("variable declaration cannot be of the type '%s'", vartype)
end
if vartype.is_type and not valnode then
varnode:raisef("a type declaration must assign to a type")
end
end
assert(symbol.type == vartype)
if (varnode.attr.comptime or varnode.attr.const) and not varnode.attr.nodecl and not valnode then
varnode:raisef("const variables must have an initial value")
end
if valnode then
context:traverse_node(valnode, {symbol=symbol, desiredtype=vartype})
valtype = valnode.attr.type
Expand Down Expand Up @@ -2339,6 +2333,16 @@ function visitors.VarDecl(context, node)
elseif vartype == primtypes.type and valtype ~= primtypes.type then
valnode:raisef("cannot assign a type to '%s'", valtype)
end
else
if i > 1 and (valtype and valtype.is_type) then
varnode:raisef("a type declaration can only assign to the first assignment expression")
end
if vartype and vartype.is_type then
varnode:raisef("a type declaration must assign to a type")
end
if (varnode.attr.comptime or varnode.attr.const) and not varnode.attr.nodecl then
varnode:raisef("const variables must have an initial value")
end
end
if not inscope then
symbol.scope:add_symbol(symbol)
Expand Down Expand Up @@ -2494,14 +2498,20 @@ function visitors.Return(context, node)
end
end
if retnode then
if rettype and rettype.is_type then
funcscope:add_return_value(i, retnode.attr.value)
end
done = done and retnode.done and true
end
end
node.done = done
else
context:traverse_nodes(retnodes)
for i,retnode,rettype in iargnodes(retnodes) do
funcscope:add_return_type(i, rettype, retnode and retnode.attr.value, retnode)
funcscope:add_return_type(i, rettype, retnode)
if rettype and retnode and rettype.is_type then
funcscope:add_return_value(i, retnode.attr.value)
end
end
end
end
Expand Down Expand Up @@ -2534,7 +2544,11 @@ function visitors.In(context, node)
else
context:traverse_node(retnode)
local retattr = retnode.attr
exprscope:add_return_type(1, retattr.type, retattr.value, retnode)
local rettype = retattr.type
exprscope:add_return_type(1, retattr.type, retnode)
if rettype and rettype.is_type then
exprscope:add_return_value(1, retattr.value)
end
end
end

Expand Down
18 changes: 14 additions & 4 deletions lualib/nelua/scope.lua
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ function Scope:resolve_symbols()
return count
end

function Scope:add_return_type(index, type, value, refnode)
function Scope:add_return_type(index, type, refnode)
if not type then
-- ignore the unknown types in recursive functions
if refnode then
Expand Down Expand Up @@ -386,9 +386,19 @@ function Scope:add_return_type(index, type, value, refnode)
elseif type and not tabler.ifind(rettypes, type) then
rettypes[#rettypes+1] = type
end
if value then
self.retvalues = self.retvalues or {}
self.retvalues[index] = value
end

function Scope:add_return_value(index, value)
if not value then return end
local retvalues = self.retvalues
if not retvalues then
retvalues = {}
self.retvalues = retvalues
end
if retvalues[index] == nil then
retvalues[index] = value
elseif retvalues[index] ~= value then
self.node:raisef("function cannot return multiple distinct compile time types")
end
end

Expand Down
15 changes: 15 additions & 0 deletions lualib/nelua/types.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1785,6 +1785,21 @@ function FunctionType:get_return_type(index)
end
end

-- Get the return value in the specified index.
function FunctionType:get_return_type_and_value(index)
local rettype = self:get_return_type(index)
if not rettype.is_comptime then return rettype end
local node = self.node
if not node then return rettype end
local scope = node.scope
if not scope then return rettype end
local retvalues = scope.retvalues
if not retvalues then return rettype end
local retvalue = retvalues[index]
if not retvalue then return rettype end
return rettype, retvalue
end

-- Get the desired type when converting this type from another type.
function FunctionType:get_convertible_from_type(type, explicit, fromcall)
if type.is_nilptr then
Expand Down
6 changes: 6 additions & 0 deletions spec/analyzer_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,13 @@ it("function definition", function()
]])
expect.analyze_ast([[
local function f(): type return @integer end
local int = f()
local a: int = 0
]])
expect.analyze_error([[
local function f() return @integer, @string end
local int, str = f()
]], "a type declaration can only assign to the first assignment expression")
expect.analyze_error([[
do
global function f() end
Expand Down

0 comments on commit a1b4a34

Please sign in to comment.