Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Failed executable compilation (tensor.extract_slice and an illegal vector.transfer_read) #42

Open
stellaraccident opened this issue Jun 30, 2024 · 4 comments
Assignees
Labels
sdxl-int8 Issues related to SDXL quantized model support

Comments

@stellaraccident
Copy link
Contributor

The latest int8-model has some compilation issues. If running with a debug compiler, there is an assert:

iree-compile: /data/home/slaurenz/src/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLExtras.h:866: detail::zippy<llvm::detail::zip_first, T, U, Args...> llvm::zip_equal(T &&, U &&, Args &&...) [T = llvm::ArrayRef<long> &, U = llvm::ArrayRef<long> &, Args = <>]: Assertion `all_equal({range_size(t), range_size(u), range_size(args)...}) && "Iteratees do not have equal length"' failed.

...

#14 0x00007f18bdc28f98 llvm::pointer_union_detail::PointerUnionMembers<llvm::PointerUnion<mlir::Attribute, mlir::Value>, llvm::PointerIntPair<void*, 1u, int, llvm::pointer_union_detail::PointerUnionUIntTraits<mlir::Attribute, mlir::Value>, llvm::PointerIntPairInfo<void*, 1u, llvm::pointer_union_detail::PointerUnionUIntTraits<mlir::Attribute, mlir::Value>>>, 1, mlir::Value>::PointerUnionMembers(mlir::Value) /data/home/slaurenz/src/iree/third_party/llvm-project/llvm/include/llvm/ADT/PointerUnion.h:77:16
#15 0x00007f18bdc28f98 llvm::pointer_union_detail::PointerUnionMembers<llvm::PointerUnion<mlir::Attribute, mlir::Value>, llvm::PointerIntPair<void*, 1u, int, llvm::pointer_union_detail::PointerUnionUIntTraits<mlir::Attribute, mlir::Value>, llvm::PointerIntPairInfo<void*, 1u, llvm::pointer_union_detail::PointerUnionUIntTraits<mlir::Attribute, mlir::Value>>>, 0, mlir::Attribute, mlir::Value>::PointerUnionMembers(mlir::Value) /data/home/slaurenz/src/iree/third_party/llvm-project/llvm/include/llvm/ADT/PointerUnion.h:74:17
#16 0x00007f18bdc28f98 llvm::PointerUnion<mlir::Attribute, mlir::Value>::PointerUnion(mlir::Value) /data/home/slaurenz/src/iree/third_party/llvm-project/llvm/include/llvm/ADT/PointerUnion.h:138:15
#17 0x00007f18bdc28f98 mlir::OpFoldResult::OpFoldResult(mlir::Value) /data/home/slaurenz/src/iree/third_party/llvm-project/mlir/include/mlir/IR/OpDefinition.h:269:41
#18 0x00007f18bdc28f98 (anonymous namespace)::delinearizeLaneId(mlir::OpBuilder&, mlir::Location, llvm::ArrayRef<long>, llvm::ArrayRef<long>, long, mlir::Value, llvm::SmallVectorImpl<mlir::Value>&) /data/home/slaurenz/src/iree/third_party/llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp:816:71
#19 0x00007f18bdc2768f (anonymous namespace)::WarpOpTransferRead::matchAndRewrite(mlir::vector::WarpExecuteOnLane0Op, mlir::PatternRewriter&) const /data/home/slaurenz/src/iree/third_party/llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp:875:10

If running with a release compiler, there are errors on two dispatches.

@main$async_dispatch_86

hal.executable public @main$async_dispatch_86 {
  hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>, ukernels = "none"}>) {
    hal.executable.export public @main$async_dispatch_86_generic_2x4096x640x1_f16xf32xf32xf16xf16xf16 ordinal(0) layout(#hal.pipeline.layout<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer, ReadOnly>, <3, storage_buffer>]>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>, #hal.interface.binding<0, 3>]} {
    ^bb0(%arg0: !hal.device):
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice 
      hal.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @main$async_dispatch_86_generic_2x4096x640x1_f16xf32xf32xf16xf16xf16() {
        %cst = arith.constant 9.99999974E-6 : f32
        %cst_0 = arith.constant 6.400000e+02 : f32
        %cst_1 = arith.constant 0.000000e+00 : f32
        %c0 = arith.constant 0 : index
        %0 = hal.interface.constant.load[0] : i32
        %1 = hal.interface.constant.load[1] : i32
        %2 = arith.index_castui %0 : i32 to index
        %3 = arith.index_castui %1 : i32 to index
        %4 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%2) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x4096x640xf16>>
        %5 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%2) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x4096x640x1xf16>>
        %6 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<640x1xf16>>
        %7 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<640x1xf16>>
        %8 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%3) : !flow.dispatch.tensor<writeonly:tensor<2x4096x640x1xf16>>
        %9 = flow.dispatch.tensor.load %4, offsets = [0, 0, 0], sizes = [2, 4096, 640], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x4096x640xf16>> -> tensor<2x4096x640xf16>
        %10 = flow.dispatch.tensor.load %5, offsets = [0, 0, 0, 0], sizes = [2, 4096, 640, 1], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x4096x640x1xf16>> -> tensor<2x4096x640x1xf16>
        %11 = flow.dispatch.tensor.load %6, offsets = [0, 0], sizes = [640, 1], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<640x1xf16>> -> tensor<640x1xf16>
        %12 = flow.dispatch.tensor.load %7, offsets = [0, 0], sizes = [640, 1], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<640x1xf16>> -> tensor<640x1xf16>
        %13 = tensor.empty() : tensor<2x4096x640x1xf16>
        %14 = tensor.empty() : tensor<2x4096xf32>
        %15 = tensor.empty() : tensor<2x4096x640xf32>
        %16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%9 : tensor<2x4096x640xf16>) outs(%15 : tensor<2x4096x640xf32>) {
        ^bb0(%in: f16, %out: f32):
          %22 = arith.extf %in : f16 to f32
          linalg.yield %22 : f32
        } -> tensor<2x4096x640xf32>
        %17 = linalg.fill ins(%cst_1 : f32) outs(%14 : tensor<2x4096xf32>) -> tensor<2x4096xf32>
        %18 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%16 : tensor<2x4096x640xf32>) outs(%17 : tensor<2x4096xf32>) {
        ^bb0(%in: f32, %out: f32):
          %22 = arith.addf %in, %out : f32
          linalg.yield %22 : f32
        } -> tensor<2x4096xf32>
        %19 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%18 : tensor<2x4096xf32>) outs(%14 : tensor<2x4096xf32>) {
        ^bb0(%in: f32, %out: f32):
          %22 = arith.divf %in, %cst_0 : f32
          linalg.yield %22 : f32
        } -> tensor<2x4096xf32>
        %20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%16, %19 : tensor<2x4096x640xf32>, tensor<2x4096xf32>) outs(%17 : tensor<2x4096xf32>) {
        ^bb0(%in: f32, %in_2: f32, %out: f32):
          %22 = arith.subf %in, %in_2 : f32
          %23 = arith.mulf %22, %22 : f32
          %24 = arith.addf %23, %out : f32
          linalg.yield %24 : f32
        } -> tensor<2x4096xf32>
        %21 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%10, %19, %20, %11, %12 : tensor<2x4096x640x1xf16>, tensor<2x4096xf32>, tensor<2x4096xf32>, tensor<640x1xf16>, tensor<640x1xf16>) outs(%13 : tensor<2x4096x640x1xf16>) {
        ^bb0(%in: f16, %in_2: f32, %in_3: f32, %in_4: f16, %in_5: f16, %out: f16):
          %22 = arith.divf %in_3, %cst_0 : f32
          %23 = arith.addf %22, %cst : f32
          %24 = math.rsqrt %23 : f32
          %25 = arith.extf %in : f16 to f32
          %26 = arith.subf %25, %in_2 : f32
          %27 = arith.mulf %26, %24 : f32
          %28 = arith.extf %in_4 : f16 to f32
          %29 = arith.mulf %27, %28 : f32
          %30 = arith.extf %in_5 : f16 to f32
          %31 = arith.addf %29, %30 : f32
          %32 = arith.truncf %31 : f32 to f16
          linalg.yield %32 : f16
        } -> tensor<2x4096x640x1xf16>
        flow.dispatch.tensor.store %21, %8, offsets = [0, 0, 0, 0], sizes = [2, 4096, 640, 1], strides = [1, 1, 1, 1] : tensor<2x4096x640x1xf16> -> !flow.dispatch.tensor<writeonly:tensor<2x4096x640x1xf16>>
        return
      }
    }
  }
}

@main$async_dispatch_88

hal.executable public @main$async_dispatch_88 {
  hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>, ukernels = "none"}>) {
    hal.executable.export public @main$async_dispatch_88_batch_matmul_transpose_b_2x4096x640x640_i32 ordinal(0) layout(#hal.pipeline.layout<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>]} {
    ^bb0(%arg0: !hal.device):
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice 
      hal.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @main$async_dispatch_88_batch_matmul_transpose_b_2x4096x640x640_i32() {
        %c0 = arith.constant 0 : index
        %c0_i32 = arith.constant 0 : i32
        %0 = hal.interface.constant.load[0] : i32
        %1 = hal.interface.constant.load[1] : i32
        %2 = arith.index_castui %0 : i32 to index
        %3 = arith.index_castui %1 : i32 to index
        %4 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<640x640xi8>>
        %5 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%2) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x4096x640xi8>>
        %6 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%3) : !flow.dispatch.tensor<writeonly:tensor<2x4096x640xi32>>
        %7 = flow.dispatch.tensor.load %4, offsets = [0, 0], sizes = [640, 640], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<640x640xi8>> -> tensor<640x640xi8>
        %8 = flow.dispatch.tensor.load %5, offsets = [0, 0, 0], sizes = [2, 4096, 640], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x4096x640xi8>> -> tensor<2x4096x640xi8>
        %9 = tensor.empty() : tensor<2x4096x2xi32>
        %10 = tensor.empty() : tensor<2x640x640xi32>
        %11 = tensor.empty() : tensor<2x4096x640xi32>
        %12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%7 : tensor<640x640xi8>) outs(%10 : tensor<2x640x640xi32>) {
        ^bb0(%in: i8, %out: i32):
          %16 = arith.extsi %in : i8 to i32
          linalg.yield %16 : i32
        } -> tensor<2x640x640xi32>
        %cast = tensor.cast %9 : tensor<2x4096x2xi32> to tensor<?x?x?xi32>
        %13 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%8 : tensor<2x4096x640xi8>) outs(%11 : tensor<2x4096x640xi32>) {
        ^bb0(%in: i8, %out: i32):
          %16 = arith.extsi %in : i8 to i32
          linalg.yield %16 : i32
        } -> tensor<2x4096x640xi32>
        %cast_0 = tensor.cast %cast : tensor<?x?x?xi32> to tensor<2x4096x640xi32>
        %14 = linalg.fill ins(%c0_i32 : i32) outs(%cast_0 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>
        %15 = linalg.batch_matmul_transpose_b ins(%13, %12 : tensor<2x4096x640xi32>, tensor<2x640x640xi32>) outs(%14 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>
        flow.dispatch.tensor.store %15, %6, offsets = [0, 0, 0], sizes = [2, 4096, 640], strides = [1, 1, 1] : tensor<2x4096x640xi32> -> !flow.dispatch.tensor<writeonly:tensor<2x4096x640xi32>>
        return
      }
    }
  }
}

Probably related.

@stellaraccident
Copy link
Contributor Author

Dispatch 86 does trigger the assert in a debug build.
Dispatch 88 does not trigger an assert in the debug build (regular MLIR failure).

I think these are the only two issues left in the model (compilation blocking at least). Either can be compiled in isolation with iree-compile --compile-mode=hal-executable

@antiagainst antiagainst added the sdxl-int8 Issues replated to SDXL quantized model support label Jun 30, 2024
@antiagainst antiagainst moved this to In progress in Turbine: SDXL on CDNA Jun 30, 2024
@antiagainst antiagainst moved this from In progress to Todo in Turbine: SDXL on CDNA Jun 30, 2024
@antiagainst
Copy link
Contributor

The extra innermost unit dim is causing issues for warp reduction in dispatch 86. It can be fixed via the following:

diff --git a/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp b/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp
index c35d70d99a..7bdf02468c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
@@ -70,12 +71,14 @@ struct OptimizeVectorTransferPass
     // to transfer reads.
     {
       RewritePatternSet patterns(&getContext());
-      mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+      vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+      vector::populateDropUnitDimWithShapeCastPatterns(patterns);
       vector::ExtractOp::getCanonicalizationPatterns(patterns, &getContext());
+      vector::ShapeCastOp::getCanonicalizationPatterns(patterns, &getContext());
+      vector::BroadcastOp::getCanonicalizationPatterns(patterns, &getContext());
       patterns.add<TransposeUnitDimToShapeCast>(&getContext());
-      mlir::vector::
-          populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
-              patterns);
+      vector::populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
+          patterns);
       if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
         return signalPassFailure();
       }

@antiagainst
Copy link
Contributor

For dispatch 88, @bangtianliu can you look into the issue?

@bangtianliu
Copy link

sure

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
sdxl-int8 Issues related to SDXL quantized model support
Projects
Status: Todo
Development

No branches or pull requests

3 participants