
(PUNET/I8): gfx1100 compilation issue with qbmm #44

Open
monorimet opened this issue Jul 2, 2024 · 0 comments

@monorimet
Collaborator

Using the IR in punet/base_ir and the latest WMMA spec, on the shared/sdxl_quantized branch of IREE, I run into a compilation error with the int8-quantized punet model:

(shark.venv) PS C:\V\SHARK-Turbine> iree-compile .\vmfbs\punet_06_29.mlir --iree-hal-target-backends=rocm --iree-rocm-target-chip=gfx1100 --iree-opt-const-eval=false --iree-opt-const-expr-hoisting=false --iree-opt-data-tiling=false --iree-flow-enable-aggressive-fusion --iree-vm-target-truncate-unsupported-floats --iree-codegen-llvmgpu-use-vector-distribution --iree-llvmgpu-enable-prefetch --iree-codegen-gpu-native-math-precision=true --iree-execution-model=async-external --iree-codegen-transform-dialect-library=./vmfbs/attention_and_matmul_spec_wmma.mlir --iree-hal-dump-executable-files-to=./punet_dps_0702_gfx1100 --iree-preprocessing-pass-pipeline='builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics))' --mlir-print-op-on-diagnostic=false -o sdxl_turbo_bs1_64_512x512_i8_punet_gfx1100.vmfb

<unknown>:0: error: LLVM Translation failed for operation: builtin.unrealized_conversion_cast
./punet_dps_0702_gfx1100\configured_module_main$async_dispatch_104.mlir:2:2: error: failed to translate the MLIR LLVM dialect to the native llvm::Module
  hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #iree_gpu.target<arch = "gfx1100", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<WMMA_F16_16x16x16_F32>, <WMMA_F16_16x16x16_F16>], subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>, ukernels = "none"}>) {      
 ^
./punet_dps_0702_gfx1100\configured_module_main$async_dispatch_104.mlir:2:2: error: failed to serialize executable for target backend rocm
  hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #iree_gpu.target<arch = "gfx1100", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<WMMA_F16_16x16x16_F32>, <WMMA_F16_16x16x16_F16>], subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>, ukernels = "none"}>) {      
 ^
./punet_dps_0702_gfx1100\configured_module_main$async_dispatch_104.mlir:1:0: error: failed to serialize executables
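To rerun this repro from a script (for example while bisecting IREE commits), the invocation above can be expressed as an argument list. This is a minimal sketch: build_repro_argv, failing_dispatch_files and run_repro are hypothetical helper names, the flags are copied from the command above, the assumption that diagnostics arrive on stderr is mine, and the log pattern only matches the "failed to translate" line format seen in this report.

```python
import re
import subprocess


def build_repro_argv(input_mlir, output_vmfb, spec_path, dump_dir,
                     chip="gfx1100", iree_compile="iree-compile"):
    """Return the iree-compile argv from this report (hypothetical helper)."""
    return [
        iree_compile,
        input_mlir,
        "--iree-hal-target-backends=rocm",
        f"--iree-rocm-target-chip={chip}",
        "--iree-opt-const-eval=false",
        "--iree-opt-const-expr-hoisting=false",
        "--iree-opt-data-tiling=false",
        "--iree-flow-enable-aggressive-fusion",
        "--iree-vm-target-truncate-unsupported-floats",
        "--iree-codegen-llvmgpu-use-vector-distribution",
        "--iree-llvmgpu-enable-prefetch",
        "--iree-codegen-gpu-native-math-precision=true",
        "--iree-execution-model=async-external",
        f"--iree-codegen-transform-dialect-library={spec_path}",
        f"--iree-hal-dump-executable-files-to={dump_dir}",
        # No shell quoting needed: each list element is passed as one argument.
        "--iree-preprocessing-pass-pipeline=builtin.module("
        "iree-preprocessing-transpose-convolution-pipeline, "
        "iree-global-opt-raise-special-ops, "
        "util.func(iree-preprocessing-pad-to-intrinsics))",
        "--mlir-print-op-on-diagnostic=false",
        "-o",
        output_vmfb,
    ]


# Matches the dumped-file stem in lines like
# ".../configured_module_main$async_dispatch_104.mlir:2:2: error: failed to translate ..."
_TRANSLATE_FAIL = re.compile(
    r"([\w$]+)\.mlir:\d+:\d+: error: failed to translate the MLIR LLVM dialect"
)


def failing_dispatch_files(log_text):
    """Unique dumped-file stems named in 'failed to translate' errors, in order."""
    return list(dict.fromkeys(_TRANSLATE_FAIL.findall(log_text)))


def run_repro(argv):
    """Run iree-compile; return (exit code, failing dump-file stems from stderr)."""
    result = subprocess.run(argv, capture_output=True, text=True)
    return result.returncode, failing_dispatch_files(result.stderr)
```

Passing a list sidesteps the PowerShell vs. POSIX shell quoting differences around the --iree-preprocessing-pass-pipeline value, and the parsed stem points straight at the dumped file to inspect under the --iree-hal-dump-executable-files-to directory.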

The failing dispatch (main$async_dispatch_104):

hal.executable public @main$async_dispatch_104 {
  hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #iree_gpu.target<arch = "gfx1100", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<WMMA_F16_16x16x16_F32>, <WMMA_F16_16x16x16_F16>], subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>, ukernels = "none"}>) {
    hal.executable.export public @main$async_dispatch_104_batch_matmul_transpose_b_2x4096x640x640_i8xi8xi32 ordinal(0) layout(#hal.pipeline.layout<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer, ReadOnly>, <3, storage_buffer, ReadOnly>, <4, storage_buffer>]>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>, #hal.interface.binding<0, 3>, #hal.interface.binding<0, 4>]} {
    ^bb0(%arg0: !hal.device):
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice 
      hal.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @main$async_dispatch_104_batch_matmul_transpose_b_2x4096x640x640_i8xi8xi32() attributes {translation_info = #iree_codegen.translation_info<LLVMGPUMatmulSimt workgroup_size = [32, 8, 1] subgroup_size = 32, {pipeline_depth = 0 : i64, store_stage = 1 : i64}>} {
        %c0_i32 = arith.constant 0 : i32
        %c140352 = arith.constant 140352 : index
        %c959552 = arith.constant 959552 : index
        %c0 = arith.constant 0 : index
        %0 = hal.interface.constant.load[0] : i32
        %1 = hal.interface.constant.load[1] : i32
        %2 = arith.index_castui %0 : i32 to index
        %3 = arith.index_castui %1 : i32 to index
        %4 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%2) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x4096x640xi8>>
        %5 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c140352) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x640x640xi8>>
        %6 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c959552) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x4096xi32>>
        %7 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<640xi8>>
        %8 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<640xi32>>
        %9 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<640xf32>>
        %10 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) alignment(64) offset(%3) : !flow.dispatch.tensor<writeonly:tensor<2x4096x640xf16>>
        %11 = flow.dispatch.tensor.load %4, offsets = [0, 0, 0], sizes = [2, 4096, 640], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x4096x640xi8>> -> tensor<2x4096x640xi8>
        %12 = flow.dispatch.tensor.load %5, offsets = [0, 0, 0], sizes = [2, 640, 640], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x640x640xi8>> -> tensor<2x640x640xi8>
        %13 = flow.dispatch.tensor.load %6, offsets = [0, 0], sizes = [2, 4096], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2x4096xi32>> -> tensor<2x4096xi32>
        %14 = flow.dispatch.tensor.load %7, offsets = [0], sizes = [640], strides = [1] : !flow.dispatch.tensor<readonly:tensor<640xi8>> -> tensor<640xi8>
        %15 = flow.dispatch.tensor.load %8, offsets = [0], sizes = [640], strides = [1] : !flow.dispatch.tensor<readonly:tensor<640xi32>> -> tensor<640xi32>
        %16 = flow.dispatch.tensor.load %9, offsets = [0], sizes = [640], strides = [1] : !flow.dispatch.tensor<readonly:tensor<640xf32>> -> tensor<640xf32>
        %17 = tensor.empty() : tensor<2x4096x640xf16>
        %18 = tensor.empty() : tensor<2x4096x640xi32>
        %19 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 32, 128, 32]]>} ins(%c0_i32 : i32) outs(%18 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>
        %20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 32, 128, 32]]>} {
        ^bb0(%in: i8, %in_0: i8, %out: i32):
          %22 = arith.extsi %in : i8 to i32
          %23 = arith.extsi %in_0 : i8 to i32
          %24 = arith.muli %22, %23 : i32
          %25 = arith.addi %out, %24 : i32
          linalg.yield %25 : i32
        } -> tensor<2x4096x640xi32>
        %21 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%20, %13, %14, %15, %16 : tensor<2x4096x640xi32>, tensor<2x4096xi32>, tensor<640xi8>, tensor<640xi32>, tensor<640xf32>) outs(%17 : tensor<2x4096x640xf16>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 32, 128, 32]]>} {
        ^bb0(%in: i32, %in_0: i32, %in_1: i8, %in_2: i32, %in_3: f32, %out: f16):
          %22 = arith.extsi %in_1 : i8 to i32
          %23 = arith.muli %in_0, %22 : i32
          %24 = arith.subi %in, %23 : i32
          %25 = arith.addi %24, %in_2 : i32
          %26 = arith.sitofp %25 : i32 to f32
          %27 = arith.mulf %26, %in_3 : f32
          %28 = arith.truncf %27 : f32 to f16
          linalg.yield %28 : f16
        } -> tensor<2x4096x640xf16>
        flow.dispatch.tensor.store %21, %10, offsets = [0, 0, 0], sizes = [2, 4096, 640], strides = [1, 1, 1] : tensor<2x4096x640xf16> -> !flow.dispatch.tensor<writeonly:tensor<2x4096x640xf16>>
        return
      }
    }
  }
}

Latest IREE commit:

6cc8afe65b23325f5a2de46c3ec2bc4b34f00ba4

Attention and matmul spec (wmma):
attention_and_matmul_spec_wmma.mlir.txt

Since we are bringing up on gfx942, this is not an immediate blocker, but I'm filing it because we'll need this fixed at some point.
