Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
subarutaro committed Sep 5, 2021
2 parents 6b03a98 + f8a55af commit 381574c
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 27 deletions.
2 changes: 1 addition & 1 deletion src/A64FX.rb
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ def convert_to_code_a64fx(conversion_type,predicate=$current_predicate)
when :landnot then
retval="svbic_b_z(" + predicate + "," + @lop.convert_to_code(conversion_type) + "," + @rop.convert_to_code(conversion_type) + ")"
when :dot then
if @rop == "x" || @rop == "y" || @rop == "z" || @rop == "w"
if (@rop == "x" || @rop == "y" || @rop == "z" || @rop == "w") && @lop.type =~ /vec/
retval=@lop.convert_to_code(conversion_type)+"."
retval += ["v0","v1","v2","v3"][["x","y","z","w"].index(@rop)]
#@rop.convert_to_code(conversion_type)
Expand Down
4 changes: 2 additions & 2 deletions src/AVX-512.rb
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def convert_to_code_avx512(conversion_type)
end

class TableDecl
def convert_to_code_a64fx(conversion_type)
def convert_to_code_avx512(conversion_type)
ret = ""
nelem = get_num_elem(@type,conversion_type)
nreg = (@table.vals.length/nelem)
Expand Down Expand Up @@ -356,7 +356,7 @@ def convert_to_code_avx512(conversion_type)
class String
def convert_to_code_avx512(conversion_type,h=$varhash)
name = get_name(self)
#abort "error: undefined reference to #{name} in convert_to_code_a64fx of String"
#abort "error: undefined reference to #{name} in convert_to_code_avx512 of String"
s = self
if h[name] != nil
iotype = h[name][0]
Expand Down
2 changes: 1 addition & 1 deletion src/AVX2.rb
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def convert_to_code_avx2(conversion_type)
end

class TableDecl
def convert_to_code_a64fx(conversion_code)
def convert_to_code_avx2(conversion_code)
ret = ""
simd_width = get_simd_width_avx2(@type)
nreg = (@table.vals.length/simd_width)
Expand Down
2 changes: 2 additions & 0 deletions src/common.rb
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ def get_name(x)
ret = get_name(x.exp)
elsif x.class == String
ret = x
elsif x.class == TableDecl
ret = get_name(x.name)
else
return nil
abort "get_name is not allowed to use for #{x.class}"
Expand Down
4 changes: 4 additions & 0 deletions src/expand_function.rb
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ def expand_function

class Statement
def expand_function
if $varhash[get_name(self)][3] == "local"
return [self]
end
exp,statements = @expression.expand_function
tmp = Statement.new([@name,exp,@type,@op])
statements += [tmp]
Expand Down Expand Up @@ -53,6 +56,7 @@ class FuncCall
def expand_function
ret = self.dup
statements = []

if !$reserved_function.index(@name)
abort "undefined reference to function #{@name}" if $funchash[@name] == nil
function = $funchash[@name]
Expand Down
39 changes: 33 additions & 6 deletions src/intermediate_exp_class.rb
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def convert_to_code(conversion_type="reference")
}
ret += "} // loop of #{@index}\n"
if conversion_type == "A64FX"
if @interval != "1" && @option == :up
if @interval != 1 && @option == :up
$current_predicate = predicate
$current_predicate = "svptrue_b#{$min_element_size}()" if $current_predicate == nil
end
Expand Down Expand Up @@ -225,6 +225,10 @@ def convert_to_code(conversion_type="reference")
end
ret
end

def get_related_variable
[]
end
end

class FloatingPoint
Expand Down Expand Up @@ -254,6 +258,9 @@ def replace_fdpsname_recursive(h=$varhash)
def replace_recursive(orig,replaced)
self.dup
end
def replace_by_list(n,l)
end

def isJRelated(list)
false
end
Expand Down Expand Up @@ -700,6 +707,14 @@ def replace_recursive(orig,replaced)
FuncCall.new([@name,ops,@type])
end

def replace_fdpsname_recursive(h=$varhash)
ops = Array.new
@ops.each{ |op|
ops.push(op.replace_fdpsname_recursive(h))
}
FuncCall.new([@name,ops,@type])
end

def replace_by_list(name_list,replaced_list)
name_list.zip(replaced_list){ |n,r|
self.replace_recursive(n,r)
Expand Down Expand Up @@ -809,6 +824,7 @@ def convert_to_code(conversion_type="reference")
pg_accum = "pg_accum"
case @operator
when :if
$predicate_queue.push($current_predicate)
ret += "{\n"
$accumulate_predicate = "pg#{$pg_count}"
$varhash[$accumulate_predicate] = [nil,type,nil,nil]
Expand Down Expand Up @@ -1042,7 +1058,7 @@ def convert_to_code(conversion_type,index = nil)
index += ",#{i*@interval.to_i + @offset.to_i}"
end
end
ret += "int #{index_name}[#{nelem}] = {#{index}};\n"
ret += "alignas(32) int #{index_name}[#{nelem}] = {#{index}};\n"
index_simd_width = 32 * nelem
ret += "__m#{index_simd_width}i #{vindex_name} = "
case index_simd_width
Expand All @@ -1065,7 +1081,7 @@ def convert_to_code(conversion_type,index = nil)
index += "#{i*@interval.to_i + @offset.to_i}"
end
end
ret += "int#{size}_t #{index_name}[#{nelem}] = {#{index}};\n"
ret += "alignas(#{size}) int#{size}_t #{index_name}[#{nelem}] = {#{index}};\n"
ret += "__m512i #{vindex_name} = _mm512_load_epi#{size}(#{index_name});\n"
ret += "#{@dest.convert_to_code(conversion_type)} = _mm512_i#{size}gather_#{get_type_suffix_avx512(@type)}(#{vindex_name},#{@src.convert_to_code(conversion_type)},#{scale});"
else
Expand Down Expand Up @@ -1264,7 +1280,9 @@ def convert_to_code(conversion_type="reference")
when /A64FX/
ret = "#{@name.convert_to_code(conversion_type)} = svdup_n_#{get_type_suffix_a64fx(@type)}(#{@expression.convert_to_code(conversion_type)});"
when /AVX2/
ret = "#{@name.convert_to_code(conversion_type)} = _mm256_set1_#{get_type_suffix_avx2(@type)}(#{@expression.convert_to_code(conversion_type)});"
set1_suffix = ""
set1_suffix = "x" if @type =~ /(S|U)64/
ret = "#{@name.convert_to_code(conversion_type)} = _mm256_set1_#{get_type_suffix_avx2(@type)}#{set1_suffix}(#{@expression.convert_to_code(conversion_type)});"
when /AVX-512/
ret = "#{@name.convert_to_code(conversion_type)} = _mm512_set1_#{get_type_suffix_avx512(@type)}(#{@expression.convert_to_code(conversion_type)});"
end
Expand Down Expand Up @@ -1326,6 +1344,13 @@ def replace_recursive(orig,replaced)
MADD.new([@operator,aop,bop,cop,@type])
end

def replace_fdpsname_recursive(h=$varhash)
aop = @aop.replace_fdpsname_recursive(h)
bop = @bop.replace_fdpsname_recursive(h)
cop = @cop.replace_fdpsname_recursive(h)
MADD.new([@operator,aop,bop,cop,@type])
end

def replace_by_list(name_list,replaced_list)
name_list.zip(replaced_list){ |n,r|
self.replace_recursive(n,r)
Expand Down Expand Up @@ -1450,8 +1475,10 @@ def convert_to_code(conversion_type="reference")
if suffix == "epi32"
ret = "_mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(#{@op2.convert_to_code(conversion_type)}),_mm256_castsi256_ps(#{@op1.convert_to_code(conversion_type)}),#{predicate}));"
elsif suffix == "epi64"
predicate = "_mm256_castsi256_pd(" + predicate + ")"
ret = "_mm256_blendv_pd(_mm256_castsi256_pd(#{@op2.convert_to_code(conversion_type)}),_mm256_castsi256_pd(#{@op1.convert_to_code(conversion_type)}),#{predicate});"
#predicate = "_mm256_castsi256_pd(" + predicate + ")"
predicate = "_mm256_castps_pd(" + predicate + ")"
ret = "_mm256_blendv_pd(_mm256_castsi256_pd(#{@op2.convert_to_code(conversion_type)}),_mm256_castsi256_pd(#{@op1.convert_to_code(conversion_type)}),#{predicate})"
ret = "_mm256_castpd_si256(#{ret});"
else
predicate = "_mm256_castps_pd(" + predicate + ")" if suffix == "pd"
ret = "_mm256_blendv_#{suffix}(#{@op2.convert_to_code(conversion_type)},#{@op1.convert_to_code(conversion_type)},#{predicate});"
Expand Down
36 changes: 20 additions & 16 deletions src/kernel_body_multi_prec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,12 @@ def convert_to_code(conversion_type)
sh_suffix = "ps" if size == 32
sh_suffix = "pd" if size == 64
lop = "#{src}"
lop = "_mm256_castsi256_ps(#{lop})" if type =~ /(S|U)(64|32)/
lop = "_mm256_castsi256_#{sh_suffix}(#{lop})" if type =~ /(S|U)(64|32)/


rop = lop
rop = "_mm256_shuffle_#{sh_suffix}(#{rop},#{rop},#{imm8})"
rop = "_mm256_castps_si256(#{rop})" if type =~ /(S|U)(64|32)/
rop = "_mm256_cast#{sh_suffix}_si256(#{rop})" if type =~ /(S|U)(64|32)/
ret += "#{src} = _mm256_#{op}_#{suffix}(#{src},#{rop});\n"
if size == 32
rop = lop
Expand Down Expand Up @@ -675,15 +676,13 @@ def split_downcast(ss)
#ret += ["//downcasting from #{from} to #{to}"]

nline = get_single_data_size(from) / get_single_data_size(to)
p "nline:",nline
ops = Array.new
tmps = Array.new
for i in 0...nline
orig = get_name(s)
replace = "#{orig}_#{i}"
tmp = Statement.new([s.name,s.expression.ops[0],s.type,s.op])
tmp.replace_name(orig,replace)

tmps += [tmp]
ops += [replace]
end
Expand Down Expand Up @@ -715,12 +714,19 @@ def load_local_var(name,type,nelem,iotype,offset = 0,h = $varhash)
iotype = h[name][0]
ij = "i"
ij = "j" if iotype == "EPJ"

get_vector_elements(type).each{|dim|
dst = name
dst = Expression.new([:dot,name,dim,type_single]) if dim != ""
src = PointerOf.new([type_single,Expression.new([:array,"#{name}_tmp","#{ij}+#{offset}",])])
src = PointerOf.new([type_single,Expression.new([:array,"#{name}_tmp_"+dim,"#{ij}+#{offset}",])]) if dim != ""
ret += [Load.new([dst,src,nelem*$max_element_size/get_single_data_size(type_single),type_single,iotype,"local"])]
src = Expression.new([:array,"#{name}_tmp","#{ij}+#{offset}",])
src = Expression.new([:array,"#{name}_tmp_"+dim,"#{ij}+#{offset}",]) if dim != ""
#ret += [Load.new([dst,src,nelem*$max_element_size/get_single_data_size(type_single),type_single,iotype,"local"])]
if nelem == 1 then
ret += [Duplicate.new([dst,src,type_single])]
else
src = PointerOf.new([type_single,src])
ret += [LoadState.new([dst,src,type_single])]
end
}
ret
end
Expand Down Expand Up @@ -788,7 +794,8 @@ def load_jvars(fvars,nelem,conversion_type,h=$varhash)
suffix = "_#{i}" if nsplit > 1
ret += [Declaration.new([type,name+suffix])]
get_vector_elements(type).each{ |dim|
index = "j+#{i*nelem}"
#index = "j+#{i*nelem}"
index = "j"
src = Expression.new([:dot,Expression.new([:array,get_iotype_array(iotype),index]),fdpsname,type_single])
src = Expression.new([:dot,src,dim,type_single]) if type =~ /vec/
src = PointerOf.new([type,src])
Expand Down Expand Up @@ -1055,29 +1062,26 @@ def kernel_body_multi_prec(ninj,conversion_type,istart=0,h=$varhash,isTail = fal
new_exp = exp.replace_fdpsname_recursive(h)
loop_tmp.statements += [Statement.new([new_name,new_exp])]
code += loop_tmp.convert_to_code("reference") if !isTail
elsif s.class == TableDecl
# do nothing
else
ss.push(s)
end
}

# declare and load TABLE variable
tmp = Array.new
ss.each{ |s|
@statements.each{ |s|
if s.class == TableDecl
ret.push(s)
tmp.push(s)
code += s.convert_to_code(conversion_type)
end
}
tmp.each{ |s|
ss.delete(s)
}

fvars = generate_force_related_map(ss)
lane_size = get_simd_width(conversion_type) / $min_element_size
lane_size = 1 if lane_size == 0

split_vars = Array.new
["EPI","FORCE"].each{ |io|
["EPI","EPJ","FORCE"].each{ |io|
fvars.each{ |v|
iotype = h[v][0]
if iotype == io
Expand Down
5 changes: 4 additions & 1 deletion src/parserdriver.rb
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def check_references(h = $varhash)
message += " "
end
message += "^\n"
message = "error : undefined reference to \"#{v}\"\n" + message
message = "error : undefined reference to \"#{v}\" in check_references\n" + message
abort message
end
}
Expand Down Expand Up @@ -1234,6 +1234,9 @@ def make_conditional_branch_block_recursive3(ss,h = $varhash,related_vars = [])
related_vars += s.expression.get_related_variable
end
elsif s.class == Pragma
# do nothing
elsif s.class == TableDecl
# do nothing
else
related_vars += s.expression.get_related_variable
end
Expand Down

0 comments on commit 381574c

Please sign in to comment.