From 6d7c8e49ff91b8a8f623b84545acdd90691c6d1c Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Wed, 24 Aug 2016 14:32:37 -0400 Subject: [PATCH] Extend support for :asvalue() + test --- src/terralib.lua | 382 +++++++++++++++++++++++------------------ tests/asvalue_select.t | 13 ++ 2 files changed, 227 insertions(+), 168 deletions(-) create mode 100644 tests/asvalue_select.t diff --git a/src/terralib.lua b/src/terralib.lua index ad052772..b49c2839 100644 --- a/src/terralib.lua +++ b/src/terralib.lua @@ -47,7 +47,7 @@ ident = escapedident(luaexpression expression) # removed during specializati field = recfield(ident key, tree value) | listfield(tree value) - + structbody = structentry(string key, luaexpression type) | structlist(structbody* entries) @@ -59,20 +59,20 @@ structdef = (luaexpression? metatype, structlist records) attr = (boolean nontemporal, number? alignment, boolean isvolatile) Symbol = (Type type, string displayname, number id) Label = (string displayname, number id) -tree = +tree = # trees that are introduced in parsing and are ... # removed during specialization luaexpression(function expression, boolean isexpression) # removed during typechecking | constructoru(field* records) #untyped version | selectu(tree value, ident field) #untyped version - | method(tree value,ident name,tree* arguments) + | method(tree value,ident name,tree* arguments) | statlist(tree* statements) | fornumu(param variable, tree initial, tree limit, tree? step,block body) #untyped version | defvar(param* variables, boolean hasinit, tree* initializers) | forlist(param* variables, tree iterator, block body) | functiondefu(param* parameters, boolean is_varargs, TypeOrLuaExpression? returntype, block body) - + # introduced temporarily during specialization/typing, but removed after typing | luaobject(any value) | setteru(function setter) # temporary node introduced and removed during typechecking to handle __update and __setfield @@ -110,7 +110,7 @@ tree = | constructor(tree* expressions) | returnstat(tree expression) | setter(allocvar rhs, tree setter) # handles custom assignment behavior, real rhs is first stored in 'rhs' and then the 'setter' expression uses it - + # special purpose nodes, they only occur in specific locations, but are considered trees because they can contain typed trees | ifbranch(tree condition, block body) | storelocation(number index, tree value) # for struct cast, value uses structvariable @@ -131,11 +131,11 @@ labelstate = undefinedlabel(gotostat * gotos, table* positions) #undefined label definition = functiondef(string? name, functype type, allocvar* parameters, boolean is_varargs, block body, table labeldepths, globalvalue* globalsused) | functionextern(string? name, functype type) - + globalvalue = terrafunction(definition? definition) | globalvariable(tree? initializer, number addressspace, boolean extern, boolean constant) attributes(string name, Type type, table anchor) - + overloadedterrafunction = (string name, terrafunction* definitions) ]] terra.irtypes = T @@ -152,7 +152,7 @@ local tokens = setmetatable({},{__index = function(self,idx) return idx end }) terra.isverbose = 0 --set by C api -local function dbprint(level,...) +local function dbprint(level,...) if terra.isverbose >= level then print(...) end @@ -174,7 +174,7 @@ end function T.tree:is(value) return self.kind == value end - + function terra.printraw(self) local function header(t) local mt = getmetatable(t) @@ -267,7 +267,7 @@ function terra.newanchor(depth) return setmetatable(body,terra.tree) end -function terra.istree(v) +function terra.istree(v) return T.tree:isclassof(v) end @@ -308,7 +308,7 @@ function terra.newenvironment(_luaenv) __index = function(_,idx) return self._localenv[idx] or self._luaenv[idx] end; - __newindex = function() + __newindex = function() error("cannot define global variables or assign to upvalues in an escape") end; }) @@ -333,12 +333,12 @@ local function formaterror(anchor,...) errlist:insert(anchor.filename..":"..anchor.linenumber..": ") for i = 1,select("#",...) do errlist:insert(tostring(select(i,...))) end errlist:insert("\n") - if not anchor.offset then + if not anchor.offset then return errlist:concat() end - + local filename = anchor.filename - local filetext = diagcache[filename] + local filetext = diagcache[filename] if not filetext then local file = io.open(filename,"r") if file then @@ -359,7 +359,7 @@ local function formaterror(anchor,...) while finish < filetext:len() and filetext:byte(finish + 1) ~= NL do finish = finish + 1 end - local line = filetext:sub(begin,finish) + local line = filetext:sub(begin,finish) errlist:insert(line) errlist:insert("\n") for i = begin,anchor.offset do @@ -449,7 +449,7 @@ function debug.traceback(msg,level) local whatname,what = debug.getlocal(level,2) assert(anchorname == "anchor" and whatname == "what") lines:insert("\n\t") - lines:insert(formaterror(anchor,"Errors reported during "..what):sub(1,-2)) + lines:insert(formaterror(anchor,"Errors reported during "..what):sub(1,-2)) else local short_src,currentline,linedefined = di.short_src,di.currentline,di.linedefined local file,outsideline = di.source:match("^@$terra$(.*)$terra$(%d+)$") @@ -467,7 +467,7 @@ function debug.traceback(msg,level) elseif di.what == "main" then lines:insert(" in main chunk") elseif di.what == "C" then - lines:insert( (" at %s"):format(tostring(di.func))) + lines:insert( (" at %s"):format(tostring(di.func))) else lines:insert((" in function <%s:%d>"):format(short_src,linedefined)) end @@ -556,8 +556,8 @@ function T.terrafunction:printstats() end function T.terrafunction:isextern() return self.definition and self.definition.kind == "functionextern" end function T.terrafunction:isdefined() return self.definition ~= nil end -function T.terrafunction:setname(name) - self.name = tostring(name) +function T.terrafunction:setname(name) + self.name = tostring(name) if self.definition then self.definition.name = name end return self end @@ -567,19 +567,19 @@ function T.terrafunction:adddefinition(functiondef) self:resetdefinition(functiondef) end function T.terrafunction:resetdefinition(functiondef) - if T.terrafunction:isclassof(functiondef) and functiondef:isdefined() then + if T.terrafunction:isclassof(functiondef) and functiondef:isdefined() then functiondef = functiondef.definition end assert(T.definition:isclassof(functiondef), "expected a defined terra function") if self.readytocompile then error("cannot reset a definition of function that has already been compiled",2) end - if self.type ~= functiondef.type and self.type ~= terra.types.placeholderfunction then + if self.type ~= functiondef.type and self.type ~= terra.types.placeholderfunction then error(("attempting to define terra function declaration with type %s with a terra function definition of type %s"):format(tostring(self.type),tostring(functiondef.type))) end self.definition,self.type,functiondef.name = functiondef,functiondef.type,assert(self.name) end function T.terrafunction:gettype(nop) assert(nop == nil, ":gettype no longer takes any callbacks for when a function is complete") - if self.type == terra.types.placeholderfunction then + if self.type == terra.types.placeholderfunction then error("function being recursively referenced needs an explicit return type, function defintion at: "..formaterror(self.anchor,""),2) end return self.type @@ -645,7 +645,7 @@ local function constantcheck(e,checklvalue) if e.expression.type:isarray() then if checklvalue then constantcheck(e.expression,true) - else + else erroratlocation(e,"non-constant cast of array to pointer used as a constant initializer") end else constantcheck(e.expression) end @@ -656,7 +656,7 @@ local function constantcheck(e,checklvalue) else erroratlocation(e,"non-constant expression being used as a constant initializer") end - return e + return e end local function createglobalinitializer(anchor, typ, c) @@ -753,7 +753,7 @@ local compilationunit = {} compilationunit.__index = compilationunit function terra.newcompilationunit(target,opt) assert(terra.istarget(target),"expected a target object") - return setmetatable({ symbols = newweakkeytable(), + return setmetatable({ symbols = newweakkeytable(), collectfunctions = opt, llvm_cu = cdatawithdestructor(terra.initcompilationunit(target.llvm_target,opt),terra.freecompilationunit) },compilationunit) -- mapping from Types,Functions,Globals,Constants -> llvm value associated with them for this compilation end @@ -806,7 +806,7 @@ end function terra.createmacro(fromterra,fromlua) return setmetatable({fromterra = fromterra,fromlua = fromlua}, terra.macro) end -function terra.internalmacro(...) +function terra.internalmacro(...) local m = terra.createmacro(...) m._internal = true return m @@ -858,7 +858,7 @@ function T.quote:asvalue() elseif e:is "constructor" then local t,typ = {},e.type for i,r in ipairs(typ:getentries()) do - local v,e = getvalue(e.expressions[i]) + local v,e = getvalue(e.expressions[i]) if e then return nil,e end local key = typ.convertible == "tuple" and i or r.field t[key] = v @@ -868,9 +868,55 @@ function T.quote:asvalue() local v,er = getvalue(e.operands[1]) return type(v) == "number" and -v, er elseif e:is "var" then return e.symbol - else - return nil, "not a constant value (note: :asvalue() isn't implement for all constants yet)" + elseif e:is "operator" and ( + e.operator == tokens[">"] or + e.operator == tokens[">="] or + e.operator == tokens["<"] or + e.operator == tokens["<="] or + e.operator == tokens["+"] or + e.operator == tokens["-"] or + e.operator == tokens["*"] or + e.operator == tokens["/"] or + e.operator == tokens["%"] ) and #e.operands == 2 then + local op1 = getvalue(e.operands[1]) + local op2 = getvalue(e.operands[2]) + if op1 ~= nil and op2 ~= nil then + if e.operator == tokens[">"] then return op1 > op2 end + if e.operator == tokens[">="] then return op1 >= op2 end + if e.operator == tokens["<"] then return op1 < op2 end + if e.operator == tokens["<="] then return op1 <= op2 end + if e.operator == tokens["+"] then return op1 + op2 end + if e.operator == tokens["-"] then return op1 - op2 end + if e.operator == tokens["*"] then return op1 * op2 end + if e.operator == tokens["/"] then return op1 / op2 end + if e.operator == tokens["%"] then return op1 % op2 end + end + elseif e:is "operator" and ( + e.operator == "and" or + e.operator == "or" ) and #e.operands == 2 then + local op1 = getvalue(e.operands[1]) + local op2 = getvalue(e.operands[2]) + if op1 ~= nil and op2 ~= nil then + if e.operator == "and" then return op1 and op2 end + if e.operator == "or" then return op1 or op2 end + end + -- Short-circuit case + if op1 ~= nil or op2 ~= nil then + if e.operator == "and" and + ( (op1 ~= nil and not op1) or (op2 ~= nil and not op2) ) then + return false + end + if e.operator == "or" and (op1 or op2) then return true end + end + elseif e:is "operator" and e.operator == tokens["select"] and #e.operands == 3 then + local cmp = getvalue(e.operands[1]) + if cmp ~= nil then + if cmp then return getvalue(e.operands[2]) end + return getvalue(e.operands[3]) + end end + + return nil, "not a constant value (note: :asvalue() isn't implement for all constants yet)" end return getvalue(self.tree) end @@ -900,7 +946,7 @@ function T.Symbol:__tostring() end function T.Symbol:tocname() return "__symbol"..tostring(self.id) end -_G["symbol"] = terra.newsymbol +_G["symbol"] = terra.newsymbol -- LABEL function terra.islabel(l) return T.Label:isclassof(l) end @@ -950,7 +996,7 @@ terra.asm = terra.internalmacro(function(diag,tree,returntype, asm, constraints, local args = List{...} return typecheck(newobject(tree, T.inlineasm,returntype:astype(), tostring(asm:asvalue()), not not volatile:asvalue(), tostring(constraints:asvalue()), args)) end) - + local evalluaexpression -- CONSTRUCTORS @@ -968,7 +1014,7 @@ local function layoutstruct(st,tree,env) end return { field = v.key, type = resolvedtype } end - + local function getrecords(records) return records:map(function(v) if v.kind == "structlist" then @@ -1032,7 +1078,7 @@ function terra.defineobjects(fmt,envfn,...) end return t,name:match("[^.]*$") end - + local decls = terralib.newlist() for i,c in ipairs(cmds) do --pass: declare all structs if "s" == c.c then @@ -1096,7 +1142,7 @@ function terra.defineobjects(fmt,envfn,...) end end diag:finishandabortiferrors("Errors reported during function declaration.",2) - + for i,c in ipairs(cmds) do -- pass: define structs if "s" == c.c and c.tree then layoutstruct(decls[i],c.tree,env) @@ -1143,7 +1189,7 @@ end -- TYPE -do +do --returns a function string -> string that makes names unique by appending numbers local function uniquenameset(sep) @@ -1164,7 +1210,7 @@ do local function tovalididentifier(name) return tostring(name):gsub("[^_%w]","_"):gsub("^(%d)","_%1"):gsub("^$","_") --sanitize input to be valid identifier end - + local function memoizefunction(fn) local info = debug.getinfo(fn,'u') local nparams = not info.isvararg and info.nparams @@ -1189,7 +1235,7 @@ do return v end end - + local types = {} local defaultproperties = { "name", "tree", "undefined", "incomplete", "convertible", "cachedcstring", "llvm_definingfunction" } for i,dp in ipairs(defaultproperties) do @@ -1206,9 +1252,9 @@ do end T.Type.__tostring = nil --force override to occur T.Type.__tostring = memoizefunction(function(self) - if self:isstruct() then + if self:isstruct() then if self.metamethods.__typename then - local status,r = pcall(function() + local status,r = pcall(function() return tostring(self.metamethods.__typename(self)) end) if status then return r end @@ -1227,7 +1273,7 @@ do if not self.name then error("unknown type?") end return self.name end) - + T.Type.printraw = terra.printraw function T.Type:isprimitive() return self.kind == "primitive" end function T.Type:isintegral() return self.kind == "primitive" and self.type == "integer" end @@ -1242,20 +1288,20 @@ do function T.Type:ispointertostruct() return self:ispointer() and self.type:isstruct() end function T.Type:ispointertofunction() return self:ispointer() and self.type:isfunction() end function T.Type:isaggregate() return self:isstruct() or self:isarray() end - + function T.Type:iscomplete() return not self.incomplete end - + function T.Type:isvector() return self.kind == "vector" end - + function T.Type:isunit() return types.unit == self end - + local applies_to_vectors = {"isprimitive","isintegral","isarithmetic","islogical", "canbeord"} for i,n in ipairs(applies_to_vectors) do T.Type[n.."orvector"] = function(self) - return self[n](self) or (self:isvector() and self.type[n](self.type)) + return self[n](self) or (self:isvector() and self.type[n](self.type)) end end - + --pretty print of layout of type function T.Type:layoutstring() local seen = {} @@ -1305,7 +1351,7 @@ do if not self[key] then if self[inside] then erroratlocation(self.anchor,erroronrecursion) - else + else self[inside] = true self[key] = getvalue(self) self[inside] = nil @@ -1319,25 +1365,25 @@ do local str = "struct "..nm.." { " local entries = layout.entries for i,v in ipairs(entries) do - + local prevalloc = entries[i-1] and entries[i-1].allocation local nextalloc = entries[i+1] and entries[i+1].allocation - + if v.inunion and prevalloc ~= v.allocation then str = str .. " union { " end - + local keystr = terra.islabel(v.key) and v.key:tocname() or v.key str = str..v.type:cstring().." "..keystr.."; " - + if v.inunion and nextalloc ~= v.allocation then str = str .. " }; " end - + end str = str .. "};" local status,err = pcall(ffi.cdef,str) - if not status then + if not status then if err:match("attempt to redefine") then print(("warning: attempting to define a C struct %s that has already been defined by the luajit ffi, assuming the Terra type matches it."):format(nm)) else error(err) end @@ -1395,7 +1441,7 @@ do elseif self:isstruct() then local nm = uniquecname(tostring(self)) ffi.cdef("typedef struct "..nm.." "..nm..";") --just make a typedef to the opaque type - --when the struct is + --when the struct is self.cachedcstring = nm if self.cachedlayout then definecstruct(nm,self.cachedlayout) @@ -1414,7 +1460,7 @@ do local pow2 = 1 --round N to next power of 2 while pow2 < self.N do pow2 = 2*pow2 end ffi.cdef("typedef "..value.." "..nm.." __attribute__ ((vector_size("..tostring(pow2*elemSz)..")));") - self.cachedcstring = nm + self.cachedcstring = nm elseif self == types.niltype then local nilname = uniquecname("niltype") ffi.cdef("typedef void * "..nilname..";") @@ -1427,13 +1473,13 @@ do error("NYI - cstring") end if not self.cachedcstring then error("cstring not set? "..tostring(self)) end - + --create a map from this ctype to the terra type to that we can implement terra.typeof(cdata) local ctype = ffi.typeof(self.cachedcstring) types.ctypetoterra[tonumber(ctype)] = self local rctype = ffi.typeof(self.cachedcstring.."&") types.ctypetoterra[tonumber(rctype)] = self - + if self:isstruct() then local function index(obj,idx) local method = self:getmethod(idx) @@ -1448,7 +1494,7 @@ do return self.cachedcstring end - + T.struct.getentries = memoizeproperty{ name = "entries"; @@ -1465,7 +1511,7 @@ do end local function checkentry(e,results) if type(e) == "table" then - local f = e.field or e[1] + local f = e.field or e[1] local t = e.type or e[2] if terra.types.istype(t) and (type(f) == "string" or terra.islabel(f)) then results:insert { type = t, field = f} @@ -1493,7 +1539,7 @@ do end end T.struct.getlayout = memoizeproperty { - name = "layout"; + name = "layout"; erroronrecursion = "type recursively contains itself, or using a type whose layout failed"; getvalue = function(self) local tree = self.anchor @@ -1501,7 +1547,7 @@ do local nextallocation = 0 local uniondepth = 0 local unionsize = 0 - + local layout = { entries = terra.newlist(), keytoindex = {} @@ -1513,12 +1559,12 @@ do elseif t:isarray() then ensurelayout(t.type) elseif t == types.opaque then - reportopaque(self) + reportopaque(self) end end ensurelayout(t) local entry = { type = t, key = k, allocation = nextallocation, inunion = uniondepth > 0 } - + if layout.keytoindex[entry.key] ~= nil then erroratlocation(tree,"duplicate field ",tostring(entry.key)) end @@ -1553,7 +1599,7 @@ do end end addentrylist(entries) - + dbprint(2,"Resolved Named Struct To:") dbprintraw(2,self) if self.cachedcstring then @@ -1567,7 +1613,7 @@ do self.returntype:complete() return self end - function T.Type:complete() + function T.Type:complete() if self.incomplete then if self:isarray() then self.type:complete() @@ -1598,7 +1644,7 @@ do function T.Type:tcomplete(anchor) return invokeuserfunction(anchor,"finalizing type",false,self.complete,self) end - + local function defaultgetmethod(self,methodname) local fnlike = self.methods[methodname] if not fnlike and terra.ismacro(self.metamethods.__methodmissing) then @@ -1628,20 +1674,20 @@ do function T.struct:getfields() return self:getlayout().entries end - + function types.istype(t) return T.Type:isclassof(t) end - + --map from luajit ffi ctype objects to corresponding terra type types.ctypetoterra = {} - + local function globaltype(name, typ) typ.name = typ.name or name rawset(_G,name,typ) types[name] = typ end - + --initialize integral types local integer_sizes = {1,2,4,8} for _,size in ipairs(integer_sizes) do @@ -1654,23 +1700,23 @@ do globaltype(name,typ) typ:cstring() -- force registration of integral types so calls like terra.typeof(1LL) work end - end - + end + globaltype("float", T.primitive("float",4,true)) globaltype("double",T.primitive("float",8,true)) globaltype("bool", T.primitive("logical",1,false)) - + types.error,T.error.name = T.error,"" T.luaobjecttype.name = "luaobjecttype" - + types.niltype = T.niltype globaltype("niltype",T.niltype) - + types.opaque,T.opaque.incomplete = T.opaque,true globaltype("opaque", T.opaque) - + types.array,types.vector,types.functype = T.array,T.vector,T.functype - + T.functype.incomplete = true function T.functype:init() if self.isvararg and #self.parameters == 0 then error("vararg functions must have at least one concrete parameter") end @@ -1679,18 +1725,18 @@ do function T.array:init() self.incomplete = true end - + function T.vector:init() if not self.type:isprimitive() and self.type ~= T.error then error("vectors must be composed of primitive types (for now...) but found type "..tostring(self.type)) end end - + types.tuple = memoizefunction(function(...) local args = terra.newlist {...} local t = types.newstruct() for i,e in ipairs(args) do - if not types.istype(e) then + if not types.istype(e) then error("expected a type but found "..type(e)) end t.entries:insert {"_"..(i-1),e} @@ -1714,7 +1760,7 @@ do function types.newstructwithanchor(displayname,anchor) assert(displayname ~= "") local name = getuniquestructname(displayname) - local tbl = T.struct(name) + local tbl = T.struct(name) tbl.entries = List() tbl.methods = {} tbl.metamethods = {} @@ -1722,7 +1768,7 @@ do tbl.incomplete = true return tbl end - + function types.funcpointer(parameters,ret,isvararg) if types.istype(parameters) then parameters = {parameters} @@ -1761,7 +1807,7 @@ end -- TYPECHECKER function evalluaexpression(env, e) if not T.luaexpression:isclassof(e) then - error("not a lua expression?") + error("not a lua expression?") end assert(type(e.expression) == "function") local fn = e.expression @@ -1788,7 +1834,7 @@ function evaltype(diag,env,typ) diag:reporterror(typ,"expected a type but found ",terra.type(v)) return terra.types.error end - + function evaluateparameterlist(diag, env, paramlist, requiretypes) local result = List() for i,p in ipairs(paramlist) do @@ -1820,21 +1866,21 @@ function evaluateparameterlist(diag, env, paramlist, requiretypes) if requiretypes and not entry.type then diag:reporterror(entry,"type must be specified for parameters and uninitialized variables") end - + end return result end - + local function semanticcheck(diag,parameters,block) local symbolenv = terra.newenvironment() - + local labelstates = {} -- map from label value to labelstate object, either representing a defined or undefined label - local globalsused = List() - + local globalsused = List() + local loopdepth = 0 local function enterloop() loopdepth = loopdepth + 1 end local function leaveloop() loopdepth = loopdepth - 1 end - + local scopeposition = List() local function getscopeposition() return List { unpack(scopeposition) } end local function getscopedepth(position) @@ -1959,7 +2005,7 @@ local function semanticcheck(diag,parameters,block) end visit(parameters) visit(block) - + --check the label table for any labels that have been referenced but not defined local labeldepths = {} for k,state in pairs(labelstates) do @@ -1969,7 +2015,7 @@ local function semanticcheck(diag,parameters,block) labeldepths[k] = getscopedepth(state.position) end end - + return labeldepths, globalsused end @@ -1977,7 +2023,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) local env = terra.newenvironment(luaenv or {}) local diag = terra.newdiagnostics() simultaneousdefinitions = simultaneousdefinitions or {} - + local invokeuserfunction = function(...) diag:finishandabortiferrors("Errors reported during typechecking.",2) return invokeuserfunction(...) @@ -1986,7 +2032,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) diag:finishandabortiferrors("Errors reported during typechecking.",2) return evalluaexpression(...) end - + local function checklabel(e,stringok) if e.kind == "namedident" then return e end local r = evalluaexpression(env:combinedenv(),e.expression) @@ -2052,7 +2098,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) local layout = v.type:getlayout(v) local index = layout.keytoindex[field] - + if index == nil then return nil,false end @@ -2152,7 +2198,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) end end local structvariable, var_ref = allocvar(exp,exp.type,"") - + local entries = List() if #from.entries > #to.entries or (not explicit and #from.entries ~= #to.entries) then err("structural cast invalid, source has ",#from.entries," fields but target has only ",#to.entries) @@ -2186,7 +2232,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) return createcast(exp,typ), true elseif typ:ispointer() and exp.type == terra.types.niltype then --niltype can be any pointer return createcast(exp,typ), true - elseif typ:isstruct() and typ.convertible and exp.type:isstruct() and exp.type.convertible then + elseif typ:isstruct() and typ.convertible and exp.type:isstruct() and exp.type.convertible then return structcast(false,exp,typ,speculative), true elseif typ:ispointer() and exp.type:isarray() and typ.type == exp.type.type then return createcast(exp,typ), true @@ -2214,7 +2260,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) local success,result = invokeuserfunction(exp, "invoking __cast", true,__cast,exp.type,typ,quotedexp) if success then local result = asterraexpression(exp,result) - if result.type ~= typ then + if result.type ~= typ then diag:reporterror(exp,"user-defined cast returned expression with the wrong type.") end return result,true @@ -2247,7 +2293,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) elseif (typ:isprimitive() and exp.type:isprimitive()) or (typ:isvector() and exp.type:isvector() and typ.N == exp.type.N) then --explicit conversions from logicals to other primitives are allowed return createcast(exp,typ) - elseif typ:isstruct() and exp.type:isstruct() and exp.type.convertible then + elseif typ:isstruct() and exp.type:isstruct() and exp.type.convertible then return structcast(true,exp,typ) else return insertcast(exp,typ) --otherwise, allow any implicit casts @@ -2326,7 +2372,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) entries:insert(rt) end return terra.types.tuple(unpack(entries)) - else + else err() return terra.types.error end @@ -2344,7 +2390,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) return e:aserror() end return ee:copy { operands = List{e} }:withtype(e.type) - end + end local function meetbinary(e,property,lhs,rhs) @@ -2371,9 +2417,9 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) if #operands == 1 then return checkunary(e,operands,"isarithmeticorvector") end - + local l,r = unpack(operands) - + local function pointerlike(t) return t:ispointer() or t:isarray() end @@ -2424,15 +2470,15 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) else typ = a.type end - + a = insertcast(a,typ) b = insertcast(b,typ) - + else diag:reporterror(ee,"arguments to shift must be integers but found ",a.type," and ", b.type) end end - + return ee:copy { operands = List{a,b} }:withtype(typ) end @@ -2446,7 +2492,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) diag:reporterror(ee,"conditional in select is not the same shape as ",cond.type) end elseif cond.type ~= terra.types.bool then - diag:reporterror(ee,"expected a boolean or vector of booleans but found ",cond.type) + diag:reporterror(ee,"expected a boolean or vector of booleans but found ",cond.type) end end return ee:copy {operands = List {cond,l,r}}:withtype(t) @@ -2475,7 +2521,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) local function checkoperator(ee) local op_string = ee.operator - + --check non-overloadable operators first if op_string == "@" then local e = checkexp(ee.operands[1]) @@ -2485,16 +2531,16 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) local ty = terra.types.pointer(e.type) return ee:copy { operands = List {e} }:withtype(ty) end - + local op, genericoverloadmethod, unaryoverloadmethod = unpack(operator_table[op_string] or {}) - + if op == nil then diag:reporterror(ee,"operator ",op_string," not defined in terra code.") return ee:aserror() end - + local operands = ee.operands:map(checkexp) - + local overloads = terra.newlist() for i,e in ipairs(operands) do if e.type:isstruct() then @@ -2505,7 +2551,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) end end end - + if #overloads > 0 then return checkcall(ee, overloads, operands, "all", true, "expression") end @@ -2514,7 +2560,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) --functions to handle typecheck invocations (functions,methods,macros,operator overloads) local function removeluaobject(e) - if not e:is "luaobject" or e.type == terra.types.error then + if not e:is "luaobject" or e.type == terra.types.error then return e --don't repeat error messages else if terra.types.istype(e.value) then @@ -2565,7 +2611,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) local function tryinsertcasts(anchor, typelists,castbehavior, speculate, allowambiguous, paramlist) local PERFECT_MATCH,CAST_MATCH,TOP = 1,2,math.huge - + local function trylist(typelist, speculate) if #typelist ~= #paramlist then if not speculate then @@ -2645,7 +2691,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) if #results > 1 and not allowambiguous then local strings = results:map(function(x) return mkstring(typelists[x.idx],"type list (",",",") ") end) diag:reporterror(anchor,"call to overloaded function is ambiguous. can apply to ",unpack(strings)) - end + end return results[1].expressions, results[1].idx end end @@ -2704,11 +2750,11 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) if location == "lexpression" and typ.metamethods.__update then local function setter(rhs) arguments:insert(rhs) - return checkmethodwithreciever(exp, true, "__update", fnlike, arguments, "statement") + return checkmethodwithreciever(exp, true, "__update", fnlike, arguments, "statement") end return newobject(exp,T.setteru,setter) end - return checkmethodwithreciever(exp, true, "__apply", fnlike, arguments, location) + return checkmethodwithreciever(exp, true, "__apply", fnlike, arguments, location) end end return checkcall(exp, terra.newlist { fnlike } , arguments, "none", false, location) @@ -2716,8 +2762,8 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) function checkcall(anchor, fnlikelist, arguments, castbehavior, allowambiguous, location) --arguments are always typed trees, or a lua object assert(#fnlikelist > 0) - - --collect all the terra functions, stop collecting when we reach the first + + --collect all the terra functions, stop collecting when we reach the first --macro and record it as themacro local terrafunctions = terra.newlist() local themacro = nil @@ -2751,14 +2797,14 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) if fn.type ~= terra.types.error then diag:reporterror(anchor,"expected a function but found ",fn.type) end - end + end end local function createcall(callee, paramlist) callee.type.type:tcompletefunction(anchor) return newobject(anchor,T.apply,callee,paramlist):withtype(callee.type.type.returntype) end - + if #terrafunctions > 0 then local paramlist = arguments:map(removeluaobject) local function getparametertypes(fn) --get the expected types for parameters to the call (this extends the function type to the length of the parameters if the function is vararg) @@ -2824,7 +2870,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) local v = checkexp(e.value,"luavalue") local f = checklabel(e.field,true) local field = f.value - + if v:is "luaobject" then -- handle A.B where A is a luatable or type --check for and handle Type.staticmethod if terra.types.istype(v.value) and v.value:isstruct() then @@ -2846,7 +2892,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) return asterraexpression(e,selected,location) end end - + if v.type:ispointertostruct() then --allow 1 implicit dereference v = insertdereference(v) end @@ -2856,12 +2902,12 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) if not success then --struct has no member field, call metamethod __entrymissing local typ = v.type - + local function checkmacro(metamethod,arguments,location) local named = terra.internalmacro(function(ctx,tree,...) return typ.metamethods[metamethod]:run(ctx,tree,field,...) end) - local getter = asterraexpression(e, named, "luaobject") + local getter = asterraexpression(e, named, "luaobject") return checkcall(v, terra.newlist{ getter }, arguments, "first", false, location) end if location == "lexpression" and typ.metamethods.__setentry then @@ -2891,7 +2937,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) elseif e:is "index" then local v = checkexp(e.value) local idx = checkexp(e.index) - local typ,lvalue = terra.types.error, v.type:ispointer() or (v.type:isarray() and v.lvalue) + local typ,lvalue = terra.types.error, v.type:ispointer() or (v.type:isarray() and v.lvalue) if v.type:ispointer() or v.type:isarray() or v.type:isvector() then typ = v.type.type if not idx.type:isintegral() and idx.type ~= terra.types.error then @@ -2912,7 +2958,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) elseif e:is "vectorconstructor" or e:is "arrayconstructor" then local entries = checkexpressions(e.expressions) local N = #entries - + local typ if e.oftype ~= nil then typ = e.oftype:tcomplete(e) @@ -2921,14 +2967,14 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) diag:reporterror(e,"cannot determine type of empty aggregate") return e:aserror() end - + --figure out what type this vector has typ = entries[1].type for i,e2 in ipairs(entries) do typ = typemeet(e,typ,e2.type) end end - + local aggtype if e:is "vectorconstructor" then if not typ:isprimitive() and typ ~= terra.types.error then @@ -2939,7 +2985,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) else aggtype = terra.types.array(typ,N) end - + --insert the casts to the right type in the parameter list local typs = entries:map(function(x) return typ end) entries = insertcasts(e,typs,entries) @@ -3001,7 +3047,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) return e:aserror() end end - + local result = docheck(e_) --freeze all types returned by the expression (or list of expressions) if not result:is "luaobject" and not result:is "setteru" then @@ -3013,7 +3059,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) if location ~= "luavalue" then result = removeluaobject(result) end - + return result end @@ -3109,7 +3155,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) return checkblock(s) elseif s:is "returnstat" then return s:copy { expression = checkexp(s.expression)} - elseif s:is "label" or s:is "gotostat" then + elseif s:is "label" or s:is "gotostat" then local ss = checklabel(s.label) return copyobject(s, { label = ss }) elseif s:is "breakstat" then @@ -3118,7 +3164,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) return checkcondbranch(s) elseif s:is "fornumu" then local initial, limit, step = checkexp(s.initial), checkexp(s.limit), s.step and checkexp(s.step) - local t = typemeet(initial,initial.type,limit.type) + local t = typemeet(initial,initial.type,limit.type) t = step and typemeet(limit,t,step.type) or t local variables = checkformalparameterlist(List {s.variable },false) if #variables ~= 1 then @@ -3133,7 +3179,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) return newobject(s,T.fornum,variable,initial,limit,step,body) elseif s:is "forlist" then local iterator = checkexp(s.iterator) - + local typ = iterator.type if typ:ispointertostruct() then typ,iterator = typ.type, insertdereference(iterator) @@ -3143,7 +3189,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) return s end local generator = typ.metamethods.__for - + local function bodycallback(...) local exps = List() for i = 1,select("#",...) do @@ -3158,7 +3204,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) local stats = createstatementlist(s, List { assign, body }) return terra.newquote(stats) end - + local value = invokeuserfunction(s, "invoking __for", false ,generator,terra.newquote(iterator), bodycallback) return asterraexpression(s,value,"statement") elseif s:is "ifstat" then @@ -3172,7 +3218,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) elseif s:is "defvar" then local rhs = s.hasinit and checkexpressions(s.initializers) local lhs = checkformalparameterlist(s.variables, not s.hasinit) - local res = s.hasinit and createassignment(s,lhs,rhs) + local res = s.hasinit and createassignment(s,lhs,rhs) or createstatementlist(s,lhs) return res elseif s:is "assignment" then @@ -3202,7 +3248,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) else newstats:insert(s) end - end + end for _,s in ipairs(stmts) do local r = checksingle(s) if r.kind == "statlist" then -- lists of statements are spliced directly into the list @@ -3289,7 +3335,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) local typed_parameters = checkformalparameterlist(topexp.parameters, true) local parameter_types = typed_parameters:map("type") local body,returntype = checkreturns(checkblock(topexp.body),topexp.returntype) - + local fntype = terra.types.functype(parameter_types,returntype,false):tcompletefunction(topexp) diag:finishandabortiferrors("Errors reported during typechecking.",2) local labeldepths,globalsused = semanticcheck(diag,typed_parameters,body) @@ -3331,7 +3377,7 @@ function terra.registerinternalizedfiles(names,contents,sizes) cur.children = cur.children or {} cur.kind = "directory" if not cur.children[segment] then - cur.children[segment] = {} + cur.children[segment] = {} end cur = cur.children[segment] end @@ -3370,7 +3416,7 @@ function terra.includecstring(code,cargs,target) args:insert("-internal-isystem") args:insert(path) end - + if cargs then args:insertall(cargs) end @@ -3444,7 +3490,7 @@ local function createunpacks(tupleonly) local entries = typ:getentries() from = from and tonumber(from:asvalue()) or 1 to = to and tonumber(to:asvalue()) or #entries - for i = from,to do + for i = from,to do local e= entries[i] if e.field then local ident = newobject(tree,type(e.field) == "string" and T.namedident or T.labelident,e.field) @@ -3455,7 +3501,7 @@ local function createunpacks(tupleonly) end local function unpacklua(cdata,from,to) local t = type(cdata) == "cdata" and terra.typeof(cdata) - if not t or not t:isstruct() or (tupleonly and t.convertible ~= "tuple") then + if not t or not t:isstruct() or (tupleonly and t.convertible ~= "tuple") then return cdata end local results = terralib.newlist() @@ -3492,7 +3538,7 @@ local function createattributetable(q) if type(attr) ~= "table" then error("attributes must be a table") end - return T.attr(attr.nontemporal and true or false, + return T.attr(attr.nontemporal and true or false, type(attr.align) == "number" and attr.align or nil, attr.isvolatile and true or false) end @@ -3521,7 +3567,7 @@ function prettystring(toptree,breaklines) local buffer = terralib.newlist() -- list of strings that concat together into the pretty output local env = terra.newenvironment({}) local indentstack = terralib.newlist{ 0 } -- the depth of each indent level - + local currentlinelength = 0 local function enterblock() indentstack:insert(indentstack[#indentstack] + 4) @@ -3535,7 +3581,7 @@ function prettystring(toptree,breaklines) local function emit(fmt,...) local function toformat(x) if type(x) ~= "number" and type(x) ~= "string" then - return tostring(x) + return tostring(x) else return x end @@ -3551,7 +3597,7 @@ function prettystring(toptree,breaklines) end local function differentlocation(a,b) return (a.linenumber ~= b.linenumber or a.filename ~= b.filename) - end + end local lastanchor = { linenumber = "", filename = "" } local function begin(anchor,...) local fname = differentlocation(lastanchor,anchor) and (anchor.filename..":"..anchor.linenumber..": ") @@ -3603,10 +3649,10 @@ function prettystring(toptree,breaklines) end local function emitParam(p) assert(T.allocvar:isclassof(p) or T.param:isclassof(p)) - if T.unevaluatedparam:isclassof(p) then + if T.unevaluatedparam:isclassof(p) then emit("%s%s",IdentToString(p.name),p.type and " : "..luaexpression or "") else - emitIdent(p.name,p.symbol) + emitIdent(p.name,p.symbol) if p.type then emit(" : %s",p.type) end end end @@ -3661,7 +3707,7 @@ function prettystring(toptree,breaklines) begin(s,"for ") emitParam(s.variable) emit(" = ") - emitExp(s.initial) emit(",") emitExp(s.limit) + emitExp(s.initial) emit(",") emitExp(s.limit) if s.step then emit(",") emitExp(s.step) end emit(" do\n") emitStmt(s.body) @@ -3716,7 +3762,7 @@ function prettystring(toptree,breaklines) emit("\n") end end - + local function makeprectable(...) local lst = {...} local sz = #lst @@ -3734,7 +3780,7 @@ function prettystring(toptree,breaklines) "~=",3,">",3,">=",3, "and",2,"or",1, "@",9,"&",9,"not",9,"select",12) - + local function getprec(e) if e:is "operator" then if "-" == e.operator and #e.operands == 1 then return 9 --unary minus case @@ -3833,7 +3879,7 @@ function prettystring(toptree,breaklines) emit(")") elseif e:is "constructor" then local success,keys = pcall(function() return e.type:getlayout().entries:map(function(e) return tostring(e.key) end) end) - if not success then emit(" = ") + if not success then emit(" = ") else emitList(keys,"",", "," = ",emit) end emitParamList(e.expressions) elseif e:is "constructoru" then @@ -4009,8 +4055,8 @@ end -- configure path variables terra.cudahome = os.getenv("CUDA_HOME") or (ffi.os == "Windows" and os.getenv("CUDA_PATH")) or "/usr/local/cuda" -terra.cudalibpaths = ({ OSX = {driver = "/usr/local/cuda/lib/libcuda.dylib", runtime = "$CUDA_HOME/lib/libcudart.dylib", nvvm = "$CUDA_HOME/nvvm/lib/libnvvm.dylib"}; - Linux = {driver = "libcuda.so", runtime = "$CUDA_HOME/lib64/libcudart.so", nvvm = "$CUDA_HOME/nvvm/lib64/libnvvm.so"}; +terra.cudalibpaths = ({ OSX = {driver = "/usr/local/cuda/lib/libcuda.dylib", runtime = "$CUDA_HOME/lib/libcudart.dylib", nvvm = "$CUDA_HOME/nvvm/lib/libnvvm.dylib"}; + Linux = {driver = "libcuda.so", runtime = "$CUDA_HOME/lib64/libcudart.so", nvvm = "$CUDA_HOME/nvvm/lib64/libnvvm.so"}; Windows = {driver = "nvcuda.dll", runtime = "$CUDA_HOME\\bin\\cudart64_*.dll", nvvm = "$CUDA_HOME\\nvvm\\bin\\nvvm64_*.dll"}; })[ffi.os] for name,path in pairs(terra.cudalibpaths) do path = path:gsub("%$CUDA_HOME",terra.cudahome) @@ -4022,7 +4068,7 @@ for name,path in pairs(terra.cudalibpaths) do end end terra.cudalibpaths[name] = path -end +end terra.systemincludes = List() if ffi.os == "Windows" then @@ -4033,7 +4079,7 @@ if ffi.os == "Windows" then return result or default end terra.vshome = registrystring([[HKLM\Software\WOW6432Node\Microsoft\VisualStudio\12.0]],"ShellFolder",[[C:\Program Files (x86)\Microsoft Visual Studio 12.0\VC\]]) - local windowsdk = registrystring([[HKLM\SOFTWARE\Wow6432Node\Microsoft\Microsoft SDKs\Windows\v8.1]],"InstallationFolder",[[C:\Program Files (x86)\Windows Kits\8.1\]]) + local windowsdk = registrystring([[HKLM\SOFTWARE\Wow6432Node\Microsoft\Microsoft SDKs\Windows\v8.1]],"InstallationFolder",[[C:\Program Files (x86)\Windows Kits\8.1\]]) terra.systemincludes:insertall { ("%sVC/INCLUDE"):format(terra.vshome), @@ -4209,7 +4255,7 @@ function terra.linkllvmstring(str,target) return terra.linkllvm(str,target,true) terra.languageextension = { tokentype = {}; --metatable for tokentype objects - tokenkindtotoken = {}; --map from token's kind id (terra.kind.name), to the singleton table (terra.languageextension.name) + tokenkindtotoken = {}; --map from token's kind id (terra.kind.name), to the singleton table (terra.languageextension.name) } function terra.importlanguage(languages,entrypoints,langstring) @@ -4218,7 +4264,7 @@ function terra.importlanguage(languages,entrypoints,langstring) if not lang or type(lang) ~= "table" then error("expected a table to define language") end lang.name = lang.name or "anonymous" local function haslist(field,typ) - if not lang[field] then + if not lang[field] then error(field .. " expected to be list of "..typ) end for i,k in ipairs(lang[field]) do @@ -4229,7 +4275,7 @@ function terra.importlanguage(languages,entrypoints,langstring) end haslist("keywords","string") haslist("entrypoints","string") - + for i,e in ipairs(lang.entrypoints) do if entrypoints[e] then error(("language '%s' uses entrypoint '%s' already defined by language '%s'"):format(lang.name,e,entrypoints[e].name),-1) @@ -4275,7 +4321,7 @@ end function terra.runlanguage(lang,cur,lookahead,next,embeddedcode,source,isstatement,islocal) local lex = {} - + lex.name = terra.languageextension.name lex.string = terra.languageextension.string lex.number = terra.languageextension.number @@ -4310,7 +4356,7 @@ function terra.runlanguage(lang,cur,lookahead,next,embeddedcode,source,isstateme return v end local function doembeddedcode(self,isterra,isexp) - self._cur,self._lookahead = nil,nil --parsing an expression invalidates our lua representations + self._cur,self._lookahead = nil,nil --parsing an expression invalidates our lua representations local expr = embeddedcode(isterra,isexp) return function(env) local oldenv = getfenv(expr) @@ -4337,7 +4383,7 @@ function terra.runlanguage(lang,cur,lookahead,next,embeddedcode,source,isstateme function lex:typetostring(name) return name end - + function lex:nextif(typ) if self:cur().type == typ then return self:next() @@ -4385,7 +4431,7 @@ function terra.runlanguage(lang,cur,lookahead,next,embeddedcode,source,isstateme else lex:error("unexpected token") end - + if not constructor or type(constructor) ~= "function" then error("expected language to return a construction function") end @@ -4395,9 +4441,9 @@ function terra.runlanguage(lang,cur,lookahead,next,embeddedcode,source,isstateme return b == 1 and e == string.len(str) end - --fixup names + --fixup names - if not names then + if not names then names = {} end diff --git a/tests/asvalue_select.t b/tests/asvalue_select.t new file mode 100644 index 00000000..a68567dc --- /dev/null +++ b/tests/asvalue_select.t @@ -0,0 +1,13 @@ +local s = `terralib.select(0 > 0, 0, 1 + 1) +assert(s:asvalue() == 2) +local s = `terralib.select(0 >= -1, 0, 1) +assert(s:asvalue() == 0) +local s = `terralib.select(false and false, 0, 1 - 2) +assert(s:asvalue() == -1) +local s = `terralib.select(false or true, 0, 1) +assert(s:asvalue() == 0) +local foo = terra(i : int) return i > 0 end +local s = `terralib.select(foo(123) or true, 0, 1) +assert(s:asvalue() == 0) +local s = `terralib.select(foo(123) and false, 0, 1) +assert(s:asvalue() == 1)