diff --git a/cmake/imex.cmake b/cmake/imex.cmake index 1c9339282..e98fd8784 100644 --- a/cmake/imex.cmake +++ b/cmake/imex.cmake @@ -8,7 +8,7 @@ if (NOT DEFINED IMEX_INCLUDES) # TODO: Change to main https://github.com/intel/mlir-extensions when all the # required functionality is merged. - gc_fetch_content(imex 496b240093b5e132b60c5ee69878300fe69be300 https://github.com/Menooker/mlir-extensions + gc_fetch_content(imex d5bbd635dee500b8cff138686833bacfac5ade78 https://github.com/Menooker/mlir-extensions SET IMEX_CHECK_LLVM_VERSION=ON IMEX_ENABLE_L0_RUNTIME=0 ) diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index 6ccf86fc6..63a0d7c34 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -47,7 +47,7 @@ def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> { "DPAS register block sizes MxNxK">, ]; } -#endif +#endif // GC_USE_IMEX def IterativeTilingAndFusion : Pass<"iterative-tiling-and-fusion", "func::FuncOp"> { diff --git a/lib/gc/ExecutionEngine/Driver/CMakeLists.txt b/lib/gc/ExecutionEngine/Driver/CMakeLists.txt index 465e5bbd2..d6020e8bc 100644 --- a/lib/gc/ExecutionEngine/Driver/CMakeLists.txt +++ b/lib/gc/ExecutionEngine/Driver/CMakeLists.txt @@ -27,7 +27,7 @@ else() endif() set(GC_PASSES GcInterface GcPasses) -if(GC_UNABLE_GPU) +if(GC_ENABLE_IMEX) list(APPEND GC_PASSES GcGpuPasses) endif() diff --git a/lib/gc/ExecutionEngine/OpenCLRuntime/OpenCLRuntimeWrappers.cpp b/lib/gc/ExecutionEngine/OpenCLRuntime/OpenCLRuntimeWrappers.cpp index b7675b9f7..84150563c 100644 --- a/lib/gc/ExecutionEngine/OpenCLRuntime/OpenCLRuntimeWrappers.cpp +++ b/lib/gc/ExecutionEngine/OpenCLRuntime/OpenCLRuntimeWrappers.cpp @@ -129,11 +129,35 @@ template size_t countUntil(T *ptr, T &&elem) { } // namespace static cl_device_id getDevice(cl_device_type *devtype) { - cl_platform_id platform; // OpenCL platform - cl_device_id device; // device ID - CL_SAFE_CALL(clGetPlatformIDs(1, &platform, NULL)); - CL_SAFE_CALL(clGetDeviceIDs(platform, *devtype, 1, &device, NULL)); - return device; + cl_uint numPlatforms; + CL_SAFE_CALL(clGetPlatformIDs(0, nullptr, &numPlatforms)) // get num platforms + + std::vector platforms(numPlatforms); + CL_SAFE_CALL(clGetPlatformIDs(numPlatforms, platforms.data(), + nullptr)); // get available platforms + + for (cl_uint i = 0; i < numPlatforms; ++i) { + // Get GPU device IDs for each platform + cl_uint numDevices; + cl_int status = + clGetDeviceIDs(platforms[i], *devtype, 0, /*devices.data()=*/nullptr, + &numDevices); // get num devices with 'devtype' + if (status != CL_SUCCESS) { + if (status == CL_DEVICE_NOT_FOUND) { + continue; // No GPU devices found on this platform + } + fprintf(stderr, "CL error %d @ line=%d (%s)\n", status, __LINE__, + "Error getting device IDs"); + abort(); + } + + std::vector devices(numDevices); + clGetDeviceIDs(platforms[i], *devtype, numDevices, devices.data(), nullptr); + return devices[0]; + } + + fprintf(stderr, "No suitable devices found."); + abort(); } struct GPUCLQUEUE { diff --git a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp index bc4326abc..eacb5933b 100644 --- a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp +++ b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp @@ -597,12 +597,22 @@ static SmallVector updateTilesOffsets(PatternRewriter &rewriter, Location loc, ValueRange tiles, ArrayRef offsets) { SmallVector updatedTiles; + // convert static offsets to dynamic because of this IMEX bug: + // https://github.com/intel/mlir-extensions/issues/815 + std::vector dynOffsets; + for (auto &x : offsets) { + Value offset = rewriter.create(loc, x); + dynOffsets.push_back(offset); + } + ValueRange newOffsets{dynOffsets}; for (auto tile : tiles) { - auto updatedTile = - rewriter - .create(loc, tile.getType(), tile, - /*offsets=*/ValueRange{}, offsets) - .getResult(); + auto updatedTile = rewriter + .create( + loc, tile.getType(), tile, + /*offsets=*/newOffsets, + SmallVector{ShapedType::kDynamic, + ShapedType::kDynamic}) + .getResult(); updatedTiles.push_back(updatedTile); } @@ -648,11 +658,17 @@ static SmallVector createDescriptorTiles(PatternRewriter &rewriter, SmallVector tiles; for (int i = 0; i < loadShape[0]; i += descTile[0]) { + // convert static offsets to dynamic because of this IMEX bug: + // https://github.com/intel/mlir-extensions/issues/815 + Value newRowOffs = rewriter.create(loc, i); for (int j = 0; j < loadShape[1]; j += descTile[1] * arrayLength) { + Value newColOffs = rewriter.create(loc, j); auto tile = rewriter .create( loc, descType, rootTile, - /*offsets=*/ValueRange{}, SmallVector{i, j}) + /*offsets=*/ValueRange{newRowOffs, newColOffs}, + SmallVector{ShapedType::kDynamic, + ShapedType::kDynamic}) .getResult(); tiles.push_back(tile); } @@ -732,17 +748,18 @@ loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles, VectorType vecLoadType = VectorType::get(tileType.getShape(), tileType.getElementType()); - UnitAttr vnniAxisAttr = nullptr; + mlir::UnitAttr packedAttr = nullptr; if (vnniConf) { - vnniAxisAttr = UnitAttr::get(rewriter.getContext()); vecLoadType = getVnniVector(tileType.getShape(), tileType.getElementType(), *vnniConf); + packedAttr = mlir::UnitAttr::get(rewriter.getContext()); } - + IntegerAttr transpose_bit = nullptr; SmallVector loadVec; for (auto tile : loadTiles) { + auto loadOp = rewriter.create( - loc, vecLoadType, tile, vnniAxisAttr, transpose, nullptr, + loc, vecLoadType, tile, packedAttr, transpose, transpose_bit, /*l1_hint=*/hint, /*l2_hint=*/hint, /*l3_hint=*/hint); loadVec.push_back(loadOp); @@ -1057,7 +1074,7 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp, // Load A sub-tiles. SmallVector loadVecA = - loadNdDescTiles(rewriter, loc, tilesA, readCacheHint, vnniConfA); + loadNdDescTiles(rewriter, loc, tilesA, readCacheHint); auto tileTypeA = cast(tilesA[0].getType()); // Load B sub-tiles. diff --git a/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-dpas.mlir b/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-dpas.mlir index 2ba6cc0af..a62ae1ce5 100644 --- a/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-dpas.mlir +++ b/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-dpas.mlir @@ -18,7 +18,7 @@ func.func @matmul(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: mem // Create output initial value load tiles. // CHECK: %[[rootC:.+]] = xegpu.create_nd_tdesc %[[C]] -// CHECK: %[[tC:.+]] = xegpu.update_nd_offset %[[rootC]], [0, 0] +// CHECK: %[[tC:.+]] = xegpu.update_nd_offset %[[rootC]], [%c0, %c0] // CHECK-COUNT-7: xegpu.update_nd_offset %[[rootC]] // Load initial accumulator values. @@ -31,9 +31,9 @@ func.func @matmul(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: mem // Create input load tiles. // CHECK: %[[rootA:.+]] = xegpu.create_nd_tdesc %[[A]] -// CHECK: %[[tA:.+]] = xegpu.update_nd_offset %[[rootA]], [0, 0] +// CHECK: %[[tA:.+]] = xegpu.update_nd_offset %[[rootA]], [%c0, %c0] // CHECK: %[[rootB:.+]] = xegpu.create_nd_tdesc %[[B]] -// CHECK: %[[tB:.+]] = xegpu.update_nd_offset %[[rootB]], [0, 0] +// CHECK: %[[tB:.+]] = xegpu.update_nd_offset %[[rootB]], [%c0, %c0] // CHECK-COUNT-1: xegpu.update_nd_offset %[[rootB]] // Create DPAS computation loop over tiled reduction dimension. @@ -63,7 +63,7 @@ func.func @matmul(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: mem // Extract DPAS-sized chunks from larger loaded tile A. // Tile B is already in the correct shape. -// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x8x2xf16> to vector<512xf16> +// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x16xf16> to vector<512xf16> // CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<512xf16> to vector<128xf16> // CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x8x2xf16> // CHECK-COUNT-3: vector.extract_strided_slice diff --git a/test/mlir/test/gc/gpu-runner/XeGPU/f16_matmul_32x32.mlir b/test/mlir/test/gc/gpu-runner/XeGPU/f16_matmul_32x32.mlir new file mode 100644 index 000000000..e91a1e7ca --- /dev/null +++ b/test/mlir/test/gc/gpu-runner/XeGPU/f16_matmul_32x32.mlir @@ -0,0 +1,97 @@ +// RUN: gc-opt %s --pass-pipeline='builtin.module(func.func(iterative-tiling-and-fusion{use-cost-model=0 default-tile-size=matmul:{16,16}}),eliminate-empty-tensors,empty-tensor-to-alloc-tensor,one-shot-bufferize{bufferize-function-boundaries=1 function-boundary-type-conversion=identity-layout-map},drop-equivalent-buffer-results,func.func(finalizing-bufferize),canonicalize,cse,drop-equivalent-buffer-results,expand-realloc,canonicalize,ownership-based-buffer-deallocation,canonicalize,buffer-deallocation-simplification,bufferization-lower-deallocations,cse,canonicalize,convert-bufferization-to-memref,func.func(scf-forall-to-parallel),func.func(linalg-to-xegpu{stages=1 dpas-tile=8,16,16 k-tile=16}),xegpu-fold-alias-ops,func.func(convert-linalg-to-parallel-loops),func.func(gpu-map-parallel-loops),func.func(convert-parallel-loops-to-gpu),func.func(insert-gpu-allocs),gpu-kernel-outlining,canonicalize,set-spirv-capabilities{client-api=opencl},gpu.module(set-spirv-abi-attrs{client-api=opencl}),lower-affine,imex-vector-linearize,gpu.module(convert-xegpu-to-vc),reconcile-unrealized-casts,bf16-to-gpu,gpu.module(convert-func-to-spirv),gpu.module(convert-vector-to-spirv),imex-convert-gpu-to-spirv,spirv.module(spirv-lower-abi-attrs,spirv-update-vce),func.func(llvm-request-c-wrappers),serialize-spirv,convert-vector-to-scf,convert-gpu-to-gpux,convert-scf-to-cf,convert-cf-to-llvm,convert-vector-to-llvm,convert-index-to-llvm,convert-arith-to-llvm,convert-func-to-llvm,convert-math-to-llvm,convert-gpux-to-llvm,convert-index-to-llvm,expand-strided-metadata,lower-affine,finalize-memref-to-llvm,reconcile-unrealized-casts)' \ +// RUN: | gc-cpu-runner -e main --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime | FileCheck %s +module{ + +func.func @linalg_matmul(%arg0: tensor<32x32xf16>, + %arg1: tensor<32x32xf16>, + %arg2: tensor<32x32xf16>) -> tensor<32x32xf16> { + %0 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf16>, tensor<32x32xf16>) + outs(%arg2 : tensor<32x32xf16>) -> tensor<32x32xf16> + return %0 : tensor<32x32xf16> +} + +func.func @generate_t(%min : f16, %max : f16) -> tensor<32x32xf16> { + %c32 = arith.constant 32.0 : f16 + %c1023 = arith.constant 1023.0 : f16 + %tmp = arith.subf %max, %min : f16 + %step = arith.divf %tmp, %c1023 : f16 + + // Generate the values + // for i in range(n): + // for j in range(n): + // index = i * n + j + // value = min_value + index * step + %0 = tensor.generate { + ^bb0(%i : index, %j : index): + %cst32 = arith.constant 32.0 : f16 + %int0 = arith.index_cast %i : index to i16 + %int1 = arith.index_cast %j : index to i16 + %fp1 = arith.uitofp %int0 : i16 to f16 + %fp2 = arith.uitofp %int1 : i16 to f16 + + %tmp1 = arith.mulf %fp1, %cst32 : f16 + %res = arith.addf %tmp1, %fp2 : f16 + + %tmp2 = arith.mulf %res, %step : f16 + %val = arith.addf %min, %tmp2 : f16 + tensor.yield %val : f16 + } : tensor<32x32xf16> + return %0 : tensor<32x32xf16> +} + +func.func @main() { + %a0 = arith.constant 0.0 : f16 + %b0 = arith.constant 256.0 : f16 + %0 = call @generate_t(%a0, %b0) : (f16, f16) -> tensor<32x32xf16> + + %a1 = arith.constant 0.0 : f16 + %b1 = arith.constant 10.0 : f16 + %1 = call @generate_t(%a1, %b1) : (f16, f16) -> tensor<32x32xf16> + + %2 = arith.constant dense<0.0> : tensor<32x32xf16> + %gpu_res = call @linalg_matmul(%0, %1, %2) : (tensor<32x32xf16>, tensor<32x32xf16>, tensor<32x32xf16>) -> tensor<32x32xf16> + + %cast = tensor.cast %gpu_res : tensor<32x32xf16> to tensor<*xf16> + call @printMemrefF16(%cast) : (tensor<*xf16>) -> () + return +} + +func.func private @printMemrefF16(%ptr : tensor<*xf16>) +} + +// CHECK: Unranked Memref base@{{(0x)?[-0-9a-fA-F]*}} +// CHECK-SAME: rank = 2 offset = 0 sizes = [32, 32] strides = [32, 1] data = +// Computed using numpy: +// CHECK-NEXT: [815, 816.5, 817.5, 819, 820, 821.5, 822.5, 824, 825, 826, 827, 828.5, 830, 831, 832, 833.5, 834.5, 836, 837, 838.5, 839.5, 840.5, 841.5, 843.5, 844.5, 845.5, 846.5, 848, 849, 850.5, 851.5, 853], +// CHECK-NEXT: [2058, 2062, 2064, 2068, 2072, 2076, 2080, 2084, 2088, 2090, 2094, 2098, 2102, 2106, 2110, 2114, 2116, 2120, 2124, 2128, 2132, 2136, 2138, 2144, 2146, 2150, 2154, 2158, 2162, 2166, 2168, 2172], +// CHECK-NEXT: [3298, 3304, 3310, 3318, 3324, 3330, 3336, 3342, 3348, 3354, 3360, 3368, 3374, 3380, 3386, 3392, 3398, 3404, 3410, 3418, 3424, 3430, 3434, 3442, 3448, 3454, 3460, 3468, 3472, 3478, 3484, 3492], +// CHECK-NEXT: [4544, 4552, 4560, 4568, 4576, 4584, 4592, 4604, 4612, 4620, 4628, 4640, 4648, 4656, 4664, 4672, 4680, 4692, 4700, 4708, 4716, 4724, 4732, 4744, 4752, 4760, 4768, 4780, 4788, 4796, 4804, 4812], +// CHECK-NEXT: [5784, 5796, 5804, 5816, 5828, 5840, 5848, 5864, 5872, 5884, 5896, 5908, 5916, 5928, 5940, 5952, 5964, 5972, 5984, 5996, 6008, 6020, 6028, 6044, 6052, 6064, 6072, 6088, 6096, 6108, 6120, 6132], +// CHECK-NEXT: [7024, 7036, 7052, 7068, 7080, 7092, 7104, 7120, 7136, 7148, 7160, 7176, 7188, 7204, 7216, 7232, 7244, 7256, 7272, 7284, 7300, 7312, 7324, 7340, 7352, 7368, 7380, 7396, 7408, 7420, 7436, 7452], +// CHECK-NEXT: [8272, 8288, 8304, 8320, 8336, 8352, 8368, 8384, 8400, 8416, 8432, 8448, 8464, 8480, 8496, 8512, 8528, 8544, 8560, 8576, 8592, 8608, 8624, 8648, 8656, 8672, 8688, 8712, 8728, 8744, 8752, 8776], +// CHECK-NEXT: [9512, 9528, 9544, 9568, 9584, 9608, 9624, 9640, 9664, 9680, 9696, 9720, 9736, 9752, 9768, 9792, 9808, 9832, 9848, 9872, 9888, 9904, 9920, 9944, 9960, 9976, 10000, 10016, 10032, 10056, 10072, 10096], +// CHECK-NEXT: [10752, 10776, 10792, 10816, 10840, 10856, 10880, 10904, 10920, 10944, 10960, 10984, 11008, 11024, 11048, 11072, 11088, 11112, 11136, 11160, 11176, 11200, 11216, 11240, 11264, 11280, 11304, 11328, 11344, 11368, 11384, 11408], +// CHECK-NEXT: [11992, 12016, 12040, 12064, 12088, 12112, 12136, 12160, 12184, 12208, 12232, 12256, 12280, 12304, 12320, 12352, 12376, 12400, 12416, 12448, 12464, 12488, 12512, 12544, 12560, 12584, 12608, 12632, 12656, 12680, 12704, 12728], +// CHECK-NEXT: [13232, 13256, 13288, 13312, 13336, 13368, 13392, 13416, 13440, 13472, 13496, 13528, 13552, 13576, 13600, 13632, 13656, 13680, 13704, 13736, 13760, 13784, 13808, 13840, 13864, 13888, 13912, 13944, 13968, 13992, 14016, 14048], +// CHECK-NEXT: [14472, 14504, 14528, 14560, 14592, 14616, 14648, 14680, 14704, 14736, 14760, 14792, 14816, 14848, 14872, 14904, 14936, 14960, 14992, 15024, 15048, 15080, 15104, 15136, 15168, 15192, 15216, 15256, 15280, 15304, 15336, 15368], +// CHECK-NEXT: [15728, 15760, 15784, 15824, 15848, 15880, 15912, 15944, 15976, 16008, 16032, 16072, 16104, 16128, 16160, 16200, 16224, 16256, 16288, 16320, 16352, 16384, 16416, 16448, 16480, 16512, 16528, 16576, 16608, 16624, 16656, 16704], +// CHECK-NEXT: [16960, 16992, 17024, 17072, 17104, 17136, 17168, 17200, 17232, 17264, 17296, 17344, 17376, 17408, 17440, 17472, 17504, 17536, 17568, 17616, 17648, 17680, 17712, 17744, 17776, 17808, 17840, 17888, 17904, 17952, 17984, 18016], +// CHECK-NEXT: [18208, 18240, 18272, 18320, 18352, 18384, 18416, 18464, 18496, 18528, 18560, 18608, 18640, 18672, 18720, 18752, 18784, 18816, 18864, 18896, 18928, 18976, 19008, 19040, 19072, 19120, 19152, 19184, 19216, 19264, 19296, 19328], +// CHECK-NEXT: [19456, 19488, 19520, 19568, 19600, 19648, 19680, 19728, 19760, 19792, 19840, 19872, 19920, 19952, 19984, 20032, 20064, 20112, 20144, 20192, 20224, 20256, 20304, 20336, 20384, 20416, 20448, 20496, 20528, 20576, 20608, 20656], +// CHECK-NEXT: [20688, 20736, 20768, 20816, 20848, 20896, 20928, 20976, 21024, 21056, 21104, 21152, 21184, 21232, 21264, 21312, 21344, 21392, 21424, 21472, 21520, 21552, 21600, 21648, 21680, 21728, 21760, 21808, 21840, 21888, 21920, 21968], +// CHECK-NEXT: [21936, 21968, 22016, 22064, 22112, 22144, 22192, 22240, 22288, 22320, 22368, 22416, 22448, 22496, 22544, 22592, 22624, 22672, 22720, 22768, 22800, 22848, 22896, 22944, 22976, 23024, 23072, 23120, 23152, 23200, 23232, 23296], +// CHECK-NEXT: [23168, 23216, 23264, 23312, 23360, 23408, 23440, 23504, 23536, 23584, 23632, 23680, 23728, 23776, 23808, 23872, 23904, 23952, 24000, 24048, 24096, 24144, 24192, 24240, 24288, 24320, 24368, 24416, 24464, 24512, 24560, 24608], +// CHECK-NEXT: [24416, 24464, 24512, 24560, 24608, 24656, 24704, 24752, 24800, 24848, 24896, 24944, 24992, 25040, 25088, 25152, 25200, 25248, 25280, 25344, 25392, 25440, 25488, 25536, 25584, 25632, 25680, 25728, 25776, 25824, 25872, 25920], +// CHECK-NEXT: [25648, 25712, 25760, 25808, 25856, 25904, 25952, 26016, 26064, 26112, 26160, 26224, 26272, 26320, 26368, 26432, 26480, 26528, 26576, 26624, 26672, 26736, 26784, 26832, 26880, 26928, 26976, 27040, 27088, 27136, 27184, 27248], +// CHECK-NEXT: [26896, 26944, 26992, 27056, 27104, 27168, 27216, 27280, 27328, 27376, 27424, 27488, 27536, 27600, 27648, 27712, 27760, 27808, 27856, 27920, 27968, 28016, 28080, 28128, 28192, 28240, 28288, 28352, 28400, 28448, 28496, 28560], +// CHECK-NEXT: [28128, 28192, 28240, 28304, 28368, 28416, 28464, 28528, 28592, 28640, 28688, 28752, 28816, 28864, 28912, 28976, 29040, 29088, 29152, 29216, 29264, 29312, 29376, 29440, 29488, 29536, 29600, 29664, 29712, 29760, 29824, 29888], +// CHECK-NEXT: [29376, 29440, 29488, 29552, 29616, 29664, 29728, 29792, 29840, 29904, 29952, 30032, 30080, 30144, 30192, 30256, 30320, 30368, 30432, 30496, 30544, 30608, 30672, 30736, 30784, 30848, 30896, 30960, 31024, 31072, 31136, 31200], +// CHECK-NEXT: [30640, 30704, 30752, 30832, 30880, 30944, 31008, 31072, 31120, 31184, 31248, 31312, 31376, 31440, 31488, 31568, 31616, 31680, 31728, 31808, 31856, 31920, 31984, 32048, 32112, 32176, 32224, 32288, 32352, 32416, 32464, 32544], +// CHECK-NEXT: [31872, 31936, 32000, 32080, 32128, 32192, 32256, 32336, 32384, 32448, 32512, 32576, 32640, 32704, 32768, 32832, 32896, 32960, 33024, 33088, 33152, 33216, 33280, 33344, 33408, 33472, 33536, 33600, 33664, 33728, 33792, 33856], +// CHECK-NEXT: [33120, 33184, 33248, 33312, 33376, 33440, 33504, 33600, 33664, 33728, 33792, 33856, 33920, 33984, 34048, 34112, 34176, 34240, 34304, 34368, 34432, 34496, 34560, 34656, 34720, 34784, 34848, 34912, 34976, 35040, 35104, 35168], +// CHECK-NEXT: [34368, 34432, 34496, 34560, 34624, 34688, 34752, 34848, 34912, 34976, 35040, 35136, 35200, 35264, 35328, 35392, 35456, 35520, 35584, 35680, 35744, 35808, 35872, 35936, 36000, 36064, 36128, 36224, 36288, 36352, 36416, 36512], +// CHECK-NEXT: [35616, 35680, 35744, 35808, 35872, 35968, 36032, 36096, 36160, 36256, 36320, 36384, 36448, 36512, 36608, 36672, 36736, 36800, 36864, 36960, 37024, 37088, 37152, 37248, 37312, 37376, 37440, 37536, 37600, 37664, 37728, 37824], +// CHECK-NEXT: [36832, 36928, 36992, 37056, 37152, 37216, 37280, 37376, 37440, 37504, 37568, 37664, 37728, 37792, 37856, 37952, 38016, 38080, 38176, 38240, 38304, 38400, 38464, 38560, 38624, 38688, 38752, 38848, 38912, 38976, 39040, 39136], +// CHECK-NEXT: [38080, 38144, 38240, 38304, 38400, 38464, 38528, 38624, 38688, 38784, 38848, 38912, 39008, 39072, 39136, 39232, 39296, 39392, 39456, 39552, 39616, 39680, 39744, 39840, 39904, 40000, 40064, 40160, 40224, 40288, 40352, 40448], +// CHECK-NEXT: [39328, 39392, 39488, 39552, 39648, 39712, 39776, 39872, 39968, 40032, 40096, 40192, 40256, 40352, 40416, 40512, 40576, 40672, 40736, 40832, 40896, 40992, 41056, 41152, 41216, 41280, 41376, 41472, 41536, 41600, 41696, 41760]