diff --git a/src/A64FX.rb b/src/A64FX.rb index 7125f08..2ade3d9 100644 --- a/src/A64FX.rb +++ b/src/A64FX.rb @@ -551,7 +551,7 @@ def convert_to_code_a64fx(conversion_type,predicate=$current_predicate) when :landnot then retval="svbic_b_z(" + predicate + "," + @lop.convert_to_code(conversion_type) + "," + @rop.convert_to_code(conversion_type) + ")" when :dot then - if @rop == "x" || @rop == "y" || @rop == "z" || @rop == "w" + if (@rop == "x" || @rop == "y" || @rop == "z" || @rop == "w") && @lop.type =~ /vec/ retval=@lop.convert_to_code(conversion_type)+"." retval += ["v0","v1","v2","v3"][["x","y","z","w"].index(@rop)] #@rop.convert_to_code(conversion_type) diff --git a/src/AVX-512.rb b/src/AVX-512.rb index 73746e5..25fd16c 100644 --- a/src/AVX-512.rb +++ b/src/AVX-512.rb @@ -304,7 +304,7 @@ def convert_to_code_avx512(conversion_type) end class TableDecl - def convert_to_code_a64fx(conversion_type) + def convert_to_code_avx512(conversion_type) ret = "" nelem = get_num_elem(@type,conversion_type) nreg = (@table.vals.length/nelem) @@ -356,7 +356,7 @@ def convert_to_code_avx512(conversion_type) class String def convert_to_code_avx512(conversion_type,h=$varhash) name = get_name(self) - #abort "error: undefined reference to #{name} in convert_to_code_a64fx of String" + #abort "error: undefined reference to #{name} in convert_to_code_avx512 of String" s = self if h[name] != nil iotype = h[name][0] diff --git a/src/AVX2.rb b/src/AVX2.rb index 7697532..2372c21 100644 --- a/src/AVX2.rb +++ b/src/AVX2.rb @@ -447,7 +447,7 @@ def convert_to_code_avx2(conversion_type) end class TableDecl - def convert_to_code_a64fx(conversion_code) + def convert_to_code_avx2(conversion_code) ret = "" simd_width = get_simd_width_avx2(@type) nreg = (@table.vals.length/simd_width) diff --git a/src/common.rb b/src/common.rb index 2b55446..51c04d3 100644 --- a/src/common.rb +++ b/src/common.rb @@ -197,6 +197,8 @@ def get_name(x) ret = get_name(x.exp) elsif x.class == String ret = x + elsif x.class == TableDecl + ret = get_name(x.name) else return nil abort "get_name is not allowed to use for #{x.class}" diff --git a/src/expand_function.rb b/src/expand_function.rb index 3661d5c..b53e0e2 100644 --- a/src/expand_function.rb +++ b/src/expand_function.rb @@ -14,6 +14,9 @@ def expand_function class Statement def expand_function + if $varhash[get_name(self)][3] == "local" + return [self] + end exp,statements = @expression.expand_function tmp = Statement.new([@name,exp,@type,@op]) statements += [tmp] @@ -53,6 +56,7 @@ class FuncCall def expand_function ret = self.dup statements = [] + if !$reserved_function.index(@name) abort "undefined reference to function #{@name}" if $funchash[@name] == nil function = $funchash[@name] diff --git a/src/intermediate_exp_class.rb b/src/intermediate_exp_class.rb index 4666825..bd536d6 100644 --- a/src/intermediate_exp_class.rb +++ b/src/intermediate_exp_class.rb @@ -61,7 +61,7 @@ def convert_to_code(conversion_type="reference") } ret += "} // loop of #{@index}\n" if conversion_type == "A64FX" - if @interval != "1" && @option == :up + if @interval != 1 && @option == :up $current_predicate = predicate $current_predicate = "svptrue_b#{$min_element_size}()" if $current_predicate == nil end @@ -225,6 +225,10 @@ def convert_to_code(conversion_type="reference") end ret end + + def get_related_variable + [] + end end class FloatingPoint @@ -254,6 +258,9 @@ def replace_fdpsname_recursive(h=$varhash) def replace_recursive(orig,replaced) self.dup end + def replace_by_list(n,l) + end + def isJRelated(list) false end @@ -700,6 +707,14 @@ def replace_recursive(orig,replaced) FuncCall.new([@name,ops,@type]) end + def replace_fdpsname_recursive(h=$varhash) + ops = Array.new + @ops.each{ |op| + ops.push(op.replace_fdpsname_recursive(h)) + } + FuncCall.new([@name,ops,@type]) + end + def replace_by_list(name_list,replaced_list) name_list.zip(replaced_list){ |n,r| self.replace_recursive(n,r) @@ -809,6 +824,7 @@ def convert_to_code(conversion_type="reference") pg_accum = "pg_accum" case @operator when :if + $predicate_queue.push($current_predicate) ret += "{\n" $accumulate_predicate = "pg#{$pg_count}" $varhash[$accumulate_predicate] = [nil,type,nil,nil] @@ -1042,7 +1058,7 @@ def convert_to_code(conversion_type,index = nil) index += ",#{i*@interval.to_i + @offset.to_i}" end end - ret += "int #{index_name}[#{nelem}] = {#{index}};\n" + ret += "alignas(32) int #{index_name}[#{nelem}] = {#{index}};\n" index_simd_width = 32 * nelem ret += "__m#{index_simd_width}i #{vindex_name} = " case index_simd_width @@ -1065,7 +1081,7 @@ def convert_to_code(conversion_type,index = nil) index += "#{i*@interval.to_i + @offset.to_i}" end end - ret += "int#{size}_t #{index_name}[#{nelem}] = {#{index}};\n" + ret += "alignas(#{size}) int#{size}_t #{index_name}[#{nelem}] = {#{index}};\n" ret += "__m512i #{vindex_name} = _mm512_load_epi#{size}(#{index_name});\n" ret += "#{@dest.convert_to_code(conversion_type)} = _mm512_i#{size}gather_#{get_type_suffix_avx512(@type)}(#{vindex_name},#{@src.convert_to_code(conversion_type)},#{scale});" else @@ -1264,7 +1280,9 @@ def convert_to_code(conversion_type="reference") when /A64FX/ ret = "#{@name.convert_to_code(conversion_type)} = svdup_n_#{get_type_suffix_a64fx(@type)}(#{@expression.convert_to_code(conversion_type)});" when /AVX2/ - ret = "#{@name.convert_to_code(conversion_type)} = _mm256_set1_#{get_type_suffix_avx2(@type)}(#{@expression.convert_to_code(conversion_type)});" + set1_suffix = "" + set1_suffix = "x" if @type =~ /(S|U)64/ + ret = "#{@name.convert_to_code(conversion_type)} = _mm256_set1_#{get_type_suffix_avx2(@type)}#{set1_suffix}(#{@expression.convert_to_code(conversion_type)});" when /AVX-512/ ret = "#{@name.convert_to_code(conversion_type)} = _mm512_set1_#{get_type_suffix_avx512(@type)}(#{@expression.convert_to_code(conversion_type)});" end @@ -1326,6 +1344,13 @@ def replace_recursive(orig,replaced) MADD.new([@operator,aop,bop,cop,@type]) end + def replace_fdpsname_recursive(h=$varhash) + aop = @aop.replace_fdpsname_recursive(h) + bop = @bop.replace_fdpsname_recursive(h) + cop = @cop.replace_fdpsname_recursive(h) + MADD.new([@operator,aop,bop,cop,@type]) + end + def replace_by_list(name_list,replaced_list) name_list.zip(replaced_list){ |n,r| self.replace_recursive(n,r) @@ -1450,8 +1475,10 @@ def convert_to_code(conversion_type="reference") if suffix == "epi32" ret = "_mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(#{@op2.convert_to_code(conversion_type)}),_mm256_castsi256_ps(#{@op1.convert_to_code(conversion_type)}),#{predicate}));" elsif suffix == "epi64" - predicate = "_mm256_castsi256_pd(" + predicate + ")" - ret = "_mm256_blendv_pd(_mm256_castsi256_pd(#{@op2.convert_to_code(conversion_type)}),_mm256_castsi256_pd(#{@op1.convert_to_code(conversion_type)}),#{predicate});" + #predicate = "_mm256_castsi256_pd(" + predicate + ")" + predicate = "_mm256_castps_pd(" + predicate + ")" + ret = "_mm256_blendv_pd(_mm256_castsi256_pd(#{@op2.convert_to_code(conversion_type)}),_mm256_castsi256_pd(#{@op1.convert_to_code(conversion_type)}),#{predicate})" + ret = "_mm256_castpd_si256(#{ret});" else predicate = "_mm256_castps_pd(" + predicate + ")" if suffix == "pd" ret = "_mm256_blendv_#{suffix}(#{@op2.convert_to_code(conversion_type)},#{@op1.convert_to_code(conversion_type)},#{predicate});" diff --git a/src/kernel_body_multi_prec.rb b/src/kernel_body_multi_prec.rb index 67af303..d54089f 100644 --- a/src/kernel_body_multi_prec.rb +++ b/src/kernel_body_multi_prec.rb @@ -185,11 +185,12 @@ def convert_to_code(conversion_type) sh_suffix = "ps" if size == 32 sh_suffix = "pd" if size == 64 lop = "#{src}" - lop = "_mm256_castsi256_ps(#{lop})" if type =~ /(S|U)(64|32)/ + lop = "_mm256_castsi256_#{sh_suffix}(#{lop})" if type =~ /(S|U)(64|32)/ + rop = lop rop = "_mm256_shuffle_#{sh_suffix}(#{rop},#{rop},#{imm8})" - rop = "_mm256_castps_si256(#{rop})" if type =~ /(S|U)(64|32)/ + rop = "_mm256_cast#{sh_suffix}_si256(#{rop})" if type =~ /(S|U)(64|32)/ ret += "#{src} = _mm256_#{op}_#{suffix}(#{src},#{rop});\n" if size == 32 rop = lop @@ -675,7 +676,6 @@ def split_downcast(ss) #ret += ["//downcasting from #{from} to #{to}"] nline = get_single_data_size(from) / get_single_data_size(to) - p "nline:",nline ops = Array.new tmps = Array.new for i in 0...nline @@ -683,7 +683,6 @@ def split_downcast(ss) replace = "#{orig}_#{i}" tmp = Statement.new([s.name,s.expression.ops[0],s.type,s.op]) tmp.replace_name(orig,replace) - tmps += [tmp] ops += [replace] end @@ -715,12 +714,19 @@ def load_local_var(name,type,nelem,iotype,offset = 0,h = $varhash) iotype = h[name][0] ij = "i" ij = "j" if iotype == "EPJ" + get_vector_elements(type).each{|dim| dst = name dst = Expression.new([:dot,name,dim,type_single]) if dim != "" - src = PointerOf.new([type_single,Expression.new([:array,"#{name}_tmp","#{ij}+#{offset}",])]) - src = PointerOf.new([type_single,Expression.new([:array,"#{name}_tmp_"+dim,"#{ij}+#{offset}",])]) if dim != "" - ret += [Load.new([dst,src,nelem*$max_element_size/get_single_data_size(type_single),type_single,iotype,"local"])] + src = Expression.new([:array,"#{name}_tmp","#{ij}+#{offset}",]) + src = Expression.new([:array,"#{name}_tmp_"+dim,"#{ij}+#{offset}",]) if dim != "" + #ret += [Load.new([dst,src,nelem*$max_element_size/get_single_data_size(type_single),type_single,iotype,"local"])] + if nelem == 1 then + ret += [Duplicate.new([dst,src,type_single])] + else + src = PointerOf.new([type_single,src]) + ret += [LoadState.new([dst,src,type_single])] + end } ret end @@ -788,7 +794,8 @@ def load_jvars(fvars,nelem,conversion_type,h=$varhash) suffix = "_#{i}" if nsplit > 1 ret += [Declaration.new([type,name+suffix])] get_vector_elements(type).each{ |dim| - index = "j+#{i*nelem}" + #index = "j+#{i*nelem}" + index = "j" src = Expression.new([:dot,Expression.new([:array,get_iotype_array(iotype),index]),fdpsname,type_single]) src = Expression.new([:dot,src,dim,type_single]) if type =~ /vec/ src = PointerOf.new([type,src]) @@ -1055,29 +1062,26 @@ def kernel_body_multi_prec(ninj,conversion_type,istart=0,h=$varhash,isTail = fal new_exp = exp.replace_fdpsname_recursive(h) loop_tmp.statements += [Statement.new([new_name,new_exp])] code += loop_tmp.convert_to_code("reference") if !isTail + elsif s.class == TableDecl + # do nothing else ss.push(s) end } # declare and load TABLE variable - tmp = Array.new - ss.each{ |s| + @statements.each{ |s| if s.class == TableDecl - ret.push(s) - tmp.push(s) + code += s.convert_to_code(conversion_type) end } - tmp.each{ |s| - ss.delete(s) - } fvars = generate_force_related_map(ss) lane_size = get_simd_width(conversion_type) / $min_element_size lane_size = 1 if lane_size == 0 split_vars = Array.new - ["EPI","FORCE"].each{ |io| + ["EPI","EPJ","FORCE"].each{ |io| fvars.each{ |v| iotype = h[v][0] if iotype == io diff --git a/src/parserdriver.rb b/src/parserdriver.rb index 1a425f0..dde6ecd 100644 --- a/src/parserdriver.rb +++ b/src/parserdriver.rb @@ -158,7 +158,7 @@ def check_references(h = $varhash) message += " " end message += "^\n" - message = "error : undefined reference to \"#{v}\"\n" + message + message = "error : undefined reference to \"#{v}\" in check_references\n" + message abort message end } @@ -1234,6 +1234,9 @@ def make_conditional_branch_block_recursive3(ss,h = $varhash,related_vars = []) related_vars += s.expression.get_related_variable end elsif s.class == Pragma + # do nothing + elsif s.class == TableDecl + # do nothing else related_vars += s.expression.get_related_variable end