Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function decoration FuncParamAttr support in SPIRV dialect. #645

Open
chengjunlu opened this issue Apr 27, 2023 · 1 comment
Open

Add function decoration FuncParamAttr support in SPIRV dialect. #645

chengjunlu opened this issue Apr 27, 2023 · 1 comment
Labels
Triton Issues tracking Triton/IMEX collaboration

Comments

@chengjunlu
Copy link

The FuncParamAttr decoration is not yet supported by the SPIR-V dialect, but it is required to link against some functions in the Intel Math Libraries.

The SPIR-V signature of the function to be imported is:

OpName %__devicelib_imf_float2bfloat16 "__devicelib_imf_float2bfloat16" 
OpDecorate %__devicelib_imf_float2bfloat16 LinkageAttributes "__devicelib_imf_float2bfloat16" Export 
OpDecorate %__devicelib_imf_float2bfloat16 FuncParamAttr Zext 

The corresponding import declaration in the SPIR-V dialect is:

 spirv.func @__devicelib_imf_float2bfloat16(f32) -> i16 "Inline" attributes {FuncParamAttr = "Zext", libname = "libdevice", libpath = "", linkage_attributes = ["__devicelib_imf_float2bfloat16", "Import"]}

Serializing the SPIR-V dialect produces the following error:
error: unhandled decoration FuncParamAttr

@silee2 silee2 added the Triton Issues tracking Triton/IMEX collaboration label May 1, 2023
@chengjunlu
Copy link
Author

Here is the SPIR-V dialect IR that links the function carrying the FuncParamAttr = "Zext" decoration.


// -----// IR Dump After CSE (cse) //----- //
// Reproducer module: per its spirv.target_env it targets SPIR-V v1.4 with the
// OpenCL API; the triton_gpu attributes record 4 warps and 32 threads per warp.
module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume]>, api=OpenCL, #spirv.resource_limits<>>, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 0 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  // Built-in input variable read by the kernel to obtain the local invocation id.
  spirv.GlobalVariable @__builtin_var_LocalInvocationId__ built_in("LocalInvocationId") : !spirv.ptr<vector<3xi64>, Input>
  // Body-less declarations with "Import" linkage_attributes (libname = "libdevice").
  // The first one (i16 -> f32) carries no FuncParamAttr.
  spirv.func @__devicelib_imf_bfloat162float(i16) -> f32 "Inline" attributes {libname = "libdevice", libpath = "", linkage_attributes = ["__devicelib_imf_bfloat162float", "Import"]}
  // This declaration (f32 -> i16) is the one carrying FuncParamAttr = "Zext",
  // i.e. the attribute named in the "unhandled decoration FuncParamAttr" error.
  spirv.func @__devicelib_imf_float2bfloat16(f32) -> i16 "Inline" attributes {FuncParamAttr = "Zext", libname = "libdevice", libpath = "", linkage_attributes = ["__devicelib_imf_float2bfloat16", "Import"]}
  // Entry-point kernel: loads one i16 from %arg1 and one from %arg2 at a
  // per-invocation index, converts both to f32 via the imported helper, adds
  // them, round-trips the sum through float2bfloat16 / bfloat162float, and
  // stores the resulting f32 bits to %arg0 at the same index.
  spirv.func @kernel_0d1d2d(%arg0: !spirv.ptr<f32, CrossWorkgroup> {tt.divisibility = 16 : i32}, %arg1: !spirv.ptr<i16, CrossWorkgroup> {tt.divisibility = 16 : i32}, %arg2: !spirv.ptr<i16, CrossWorkgroup> {tt.divisibility = 16 : i32}) "None" attributes {noinline = false, spirv.entry_point_abi = #spirv.entry_point_abi<>, sym_visibility = "public"} {
    // Element 0 of LocalInvocationId, narrowed from i64 to i32.
    %__builtin_var_LocalInvocationId___addr = spirv.mlir.addressof @__builtin_var_LocalInvocationId__ : !spirv.ptr<vector<3xi64>, Input>
    %0 = spirv.Load "Input" %__builtin_var_LocalInvocationId___addr : vector<3xi64>
    %1 = spirv.CompositeExtract %0[0 : i32] : vector<3xi64>
    %2 = spirv.SConvert %1 : i64 to i32
    // Index computation: %3 = id mod 32, %4 = id / 32, %5 = %4 mod 4,
    // %6 = %3 mod 128, then %9 = 1 * (%6 + %5 * 32).
    // NOTE(review): 32 and 4 match the module's threads-per-warp / num-warps
    // attributes, so these look like lane and warp ids - confirm.
    %cst32_i32 = spirv.Constant 32 : i32
    %3 = spirv.UMod %2, %cst32_i32 : i32
    %4 = spirv.UDiv %2, %cst32_i32 : i32
    %cst4_i32 = spirv.Constant 4 : i32
    %5 = spirv.UMod %4, %cst4_i32 : i32
    %cst128_i32 = spirv.Constant 128 : i32
    %6 = spirv.UMod %3, %cst128_i32 : i32
    %cst1_i32 = spirv.Constant 1 : i32
    %7 = spirv.IMul %5, %cst32_i32 : i32
    %8 = spirv.IAdd %6, %7 : i32
    %9 = spirv.IMul %cst1_i32, %8 : i32
    // %10 and %11 have no visible uses in this dump.
    %10 = spirv.Undef : !spirv.struct<(i32)>
    %11 = spirv.Undef : !spirv.struct<(!spirv.ptr<i16, CrossWorkgroup>)>
    // Address of the %arg1 element at index %9.
    %12 = spirv.PtrAccessChain %arg1[%9] : !spirv.ptr<i16, CrossWorkgroup>, i32
    // The branch condition is the constant true; the false edge forwards the
    // undef i16 value %13 as the block argument instead of a loaded value.
    %true = spirv.Constant true
    %13 = spirv.Undef : i16
    spirv.BranchConditional %true, ^bb1, ^bb2(%13 : i16)
  ^bb1:  // pred: ^bb0
    // Load the first i16 operand and pass it on as %15.
    %14 = spirv.Load "CrossWorkgroup" %12 : i16
    spirv.Branch ^bb2(%14 : i16)
  ^bb2(%15: i16):  // 2 preds: ^bb0, ^bb1
    // Same conditional-load pattern for the %arg2 element at index %9.
    %16 = spirv.Undef : !spirv.struct<(i16)>
    %17 = spirv.PtrAccessChain %arg2[%9] : !spirv.ptr<i16, CrossWorkgroup>, i32
    spirv.BranchConditional %true, ^bb3, ^bb4(%13 : i16)
  ^bb3:  // pred: ^bb2
    %18 = spirv.Load "CrossWorkgroup" %17 : i16
    spirv.Branch ^bb4(%18 : i16)
  ^bb4(%19: i16):  // 2 preds: ^bb2, ^bb3
    // Convert both i16 operands to f32 through the imported helper and add them.
    %20 = spirv.FunctionCall @__devicelib_imf_bfloat162float(%15) : (i16) -> f32
    %21 = spirv.FunctionCall @__devicelib_imf_bfloat162float(%19) : (i16) -> f32
    %22 = spirv.FAdd %20, %21 : f32
    // Call into the imported function whose declaration has FuncParamAttr = "Zext".
    %23 = spirv.FunctionCall @__devicelib_imf_float2bfloat16(%22) : (f32) -> i16
    %24 = spirv.Undef : !spirv.struct<(!spirv.ptr<f32, CrossWorkgroup>)>
    // Destination address in %arg0 at index %9; the i16 result is converted
    // back to f32 before being stored.
    %25 = spirv.PtrAccessChain %arg0[%9] : !spirv.ptr<f32, CrossWorkgroup>, i32
    %26 = spirv.FunctionCall @__devicelib_imf_bfloat162float(%23) : (i16) -> f32
    %27 = spirv.Undef : !spirv.struct<(f32)>
    // Second read of LocalInvocationId; %30 has no visible use in this dump.
    %28 = spirv.Load "Input" %__builtin_var_LocalInvocationId___addr : vector<3xi64>
    %29 = spirv.CompositeExtract %28[0 : i32] : vector<3xi64>
    %30 = spirv.SConvert %29 : i64 to i32
    spirv.BranchConditional %true, ^bb5, ^bb6
  ^bb5:  // pred: ^bb4
    // Store the f32 result as its i32 bit pattern through a bitcast pointer.
    %31 = spirv.Bitcast %26 : f32 to i32
    %32 = spirv.Bitcast %25 : !spirv.ptr<f32, CrossWorkgroup> to !spirv.ptr<i32, CrossWorkgroup>
    spirv.Store "CrossWorkgroup" %32, %31 : i32
    spirv.Branch ^bb6
  ^bb6:  // 2 preds: ^bb4, ^bb5
    spirv.Return
  }
}

@mshahneo mshahneo reopened this Jun 12, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Triton Issues tracking Triton/IMEX collaboration
Projects
None yet
Development

No branches or pull requests

3 participants