forked from ColfaxResearch/cfx-article-src
Add version with TiledCopy vectorized load into registers and write out via TMA store
Showing 3 changed files with 272 additions and 0 deletions.
@@ -0,0 +1,55 @@
#pragma once

#include <iostream>

#include <cuda_runtime.h>

#include <cute/atom/mma_traits_sm90_gmma.hpp>

namespace cfx {

using namespace cute;

// Helper functions for retrieving optimal swizzled layouts.
// The swizzle atom is chosen by the byte width of one tile row
// (headSizeBytes), picking the widest swizzle that matches it.
template <typename PrecType, int DIM> constexpr auto getSmemLayoutK() {

  constexpr int headSizeBytes = sizeof(PrecType) * DIM;

  if constexpr (headSizeBytes == 16) {
    return GMMA::Layout_K_INTER_Atom<PrecType>{};
  } else if constexpr (headSizeBytes == 32) {
    return GMMA::Layout_K_SW32_Atom<PrecType>{};
  } else if constexpr (headSizeBytes == 64) {
    return GMMA::Layout_K_SW64_Atom<PrecType>{};
  } else {
    return GMMA::Layout_K_SW128_Atom<PrecType>{};
  }
}

template <typename PrecType, int DIM> constexpr auto getSmemLayoutMN() {

  constexpr int headSizeBytes = sizeof(PrecType) * DIM;

  if constexpr (headSizeBytes == 16) {
    return GMMA::Layout_MN_INTER_Atom<PrecType>{};
  } else if constexpr (headSizeBytes == 32) {
    return GMMA::Layout_MN_SW32_Atom<PrecType>{};
  } else if constexpr (headSizeBytes == 64) {
    return GMMA::Layout_MN_SW64_Atom<PrecType>{};
  } else {
    return GMMA::Layout_MN_SW128_Atom<PrecType>{};
  }
}
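
// Example (illustrative only): for float (4 bytes) with DIM = 32,
// headSizeBytes is 128, so the helper falls through to the 128-byte
// swizzle atom:
//   static_assert(std::is_same_v<
//       decltype(getSmemLayoutK<float, 32>()),
//       GMMA::Layout_K_SW128_Atom<float>>);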

void set_smem_size(int smem_size, void const* kernel)
{
  // Account for dynamic smem capacity if needed: requests above the
  // default 48 KB must be opted into explicitly.
  if (smem_size >= (48 << 10)) {
    cudaError_t result = cudaFuncSetAttribute(
        kernel,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        smem_size);
    if (cudaSuccess != result) {
      result = cudaGetLastError(); // to clear the error bit
      std::cout << " Shared Memory Allocation Failed" << std::endl
                << " cudaFuncSetAttribute() returned error: "
                << cudaGetErrorString(result) << std::endl;
    }
  }
}
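
// Typical usage (a sketch; `myKernel`, `grid`, `block`, and `smem_size`
// are placeholders, not names from this commit):
//   cfx::set_smem_size(smem_size, (void const*)myKernel);
//   myKernel<<<grid, block, smem_size>>>(...);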

} // namespace cfx
@@ -0,0 +1,215 @@
#pragma once

#include <cassert>
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <iostream>

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

#include <cutlass/numeric_types.h>
#include <cute/tensor.hpp>
#include <cute/arch/cluster_sm90.hpp>
#include <cutlass/cluster_launch.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/arch/barrier.h>

#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/helper_cuda.hpp"
#include "cutlass/util/print_error.hpp"

#include "cutlass/detail/layout.hpp"

#include "shared_storage.h"
#include "smem_helper.hpp"

using namespace cute;

template <class TensorS, class SmemLayout, class TiledCopyS, class TiledCopyD,
          class GmemLayoutD, class TileShapeD, class ThreadLayoutM,
          class SmemLayoutM>
__global__ static void __launch_bounds__(256)
    transposeKernelTMA(TensorS const S, SmemLayout const smemLayout,
                       TiledCopyS const tiled_copy_S,
                       CUTE_GRID_CONSTANT TiledCopyD const tmaStoreD,
                       GmemLayoutD const gmemLayoutD,
                       TileShapeD const tileShapeD, ThreadLayoutM const tM,
                       SmemLayoutM const smemLayoutM) {
  using namespace cute;
  using Element = typename TensorS::value_type;

  int lane_predicate = cute::elect_one_sync();
  int warp_idx = cutlass::canonical_warp_idx_sync();
  bool leaderWarp = warp_idx == 0;

  // Use the SharedStorage structure to allocate aligned SMEM addresses.
  extern __shared__ char shared_memory[];
  using SharedStorage = SharedStorageTranspose<Element, SmemLayout>;
  SharedStorage &shared_storage =
      *reinterpret_cast<SharedStorage *>(shared_memory);
  Tensor sM =
      make_tensor(make_smem_ptr(shared_storage.smem.data()), smemLayoutM);

  Tensor gS = S(make_coord(_, _), blockIdx.x, blockIdx.y); // (bM, bN)
  auto thr_copy_S = tiled_copy_S.get_thread_slice(threadIdx.x);

  Tensor tSgS = thr_copy_S.partition_S(gS); // (CopyOp, CopyM, CopyN)
  Tensor tSrS = make_fragment_like(tSgS);   // (CopyOp, CopyM, CopyN)
  Tensor tMsM = local_partition(sM, tM, threadIdx.x);

  // Copy from GMEM to RMEM to SMEM.
  copy(tiled_copy_S, tSgS, tSrS);
  copy(tSrS, tMsM);
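
  // Staging through registers decouples the two sides of the transpose:
  // the GMEM reads are vectorized and coalesced along source rows, while
  // the SMEM writes follow the swizzled, transposed layout. As the host
  // code notes, the profiler still reports uncoalesced smem accesses for
  // this variant.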

  auto synchronize = [&]() {
    cutlass::arch::NamedBarrier::sync(size(ThreadLayoutM{}), 0);
  };
  // Make the smem writes visible to the TMA unit (the async proxy), then
  // synchronize all 256 threads before issuing the store.
  cutlass::arch::fence_view_async_shared();
  synchronize();

  // Issue the TMA store.
  Tensor mD = tmaStoreD.get_tma_tensor(shape(gmemLayoutD));
  // Block coordinates are swapped: D is tiled as ((bN, bM), n', m').
  auto blkCoordD = make_coord(blockIdx.y, blockIdx.x);
  Tensor gD = local_tile(mD, tileShapeD, blkCoordD);
  Tensor sD = make_tensor(make_smem_ptr(shared_storage.smem.data()),
                          smemLayout); // (bN, bM)

  auto cta_tmaD = tmaStoreD.get_slice(0);

  Tensor tDgDX = cta_tmaD.partition_D(gD);
  Tensor tDgD = group_modes<1, rank(tDgDX)>(tDgDX); // (TMA, REST)
  assert(size<1>(tDgD) == 1);

  Tensor tDsDX = cta_tmaD.partition_S(sD);
  Tensor tDsD = group_modes<1, rank(tDsDX)>(tDsDX); // (TMA, REST)
  static_assert(size<1>(tDsD) == 1);

  if (leaderWarp and lane_predicate) {
    copy(tmaStoreD, tDsD, tDgD);
    // Commit the store to a bulk async group so the wait below can
    // observe its completion.
    tma_store_arrive();
  }
  // Wait for TMA store to complete.
  tma_store_wait<0>();
}

int transpose_host_kernel_tma(int M, int N) {
  printf("Vectorized load into registers, write out via TMA Store\n");
  printf("Profiler reports uncoalesced smem accesses\n");

  using Element = float;
  using namespace cute;

  auto tensor_shape = make_shape(M, N);
  auto tensor_shape_trans = make_shape(N, M);

  // Allocate and initialize
  thrust::host_vector<Element> h_S(size(tensor_shape));       // (M, N)
  thrust::host_vector<Element> h_D(size(tensor_shape_trans)); // (N, M)

  for (size_t i = 0; i < h_S.size(); ++i) {
    h_S[i] = static_cast<Element>(i);
    h_D[i] = Element{};
  }

  thrust::device_vector<Element> d_S = h_S;
  thrust::device_vector<Element> d_D = h_D;

  //
  // Make tensors
  //

  // Could also have ColMajor.
  auto gmemLayoutS = make_layout(tensor_shape, GenRowMajor{});
  auto gmemLayoutD = make_layout(tensor_shape_trans, GenRowMajor{});

  Tensor tensor_S = make_tensor(
      make_gmem_ptr(thrust::raw_pointer_cast(d_S.data())), gmemLayoutS);
  Tensor tensor_D = make_tensor(
      make_gmem_ptr(thrust::raw_pointer_cast(d_D.data())), gmemLayoutD);

  //
  // Tile tensors
  //

  using bM = Int<32>;
  using bN = Int<32>;

  auto block_shape = make_shape(bM{}, bN{});       // (bM, bN)
  auto block_shape_trans = make_shape(bN{}, bM{}); // (bN, bM)

  Tensor tiled_tensor_S = tiled_divide(tensor_S, block_shape);       // ((bM, bN), m', n')
  Tensor tiled_tensor_D = tiled_divide(tensor_D, block_shape_trans); // ((bN, bM), n', m')

  auto threadLayoutS =
      make_layout(make_shape(Int<32>{}, Int<8>{}), GenRowMajor{});
  auto vecLayoutS = make_layout(make_shape(Int<1>{}, Int<4>{}));
  using AccessTypeS = cutlass::AlignedArray<Element, size(vecLayoutS)>;
  using AtomS = Copy_Atom<UniversalCopy<AccessTypeS>, Element>;
  auto tiled_copy_S = make_tiled_copy(AtomS{}, threadLayoutS, vecLayoutS);
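
  // This TiledCopy pairs a (32, 8) row-major thread arrangement with a
  // (1, 4) value tile per thread: each thread issues one 16-byte
  // (4 x float) vector load, and the 256 threads cover the full 32 x 32
  // input tile in a single copy call.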

  auto tileShapeD = block_shape_trans;
  auto smemLayoutD =
      tile_to_shape(cfx::getSmemLayoutK<Element, bM{}>(),
                    make_shape(shape<0>(tileShapeD), shape<1>(tileShapeD)));
  // TMA only supports certain swizzles:
  // https://github.com/NVIDIA/cutlass/blob/main/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp
  auto tmaD = make_tma_copy(SM90_TMA_STORE{}, tensor_D, smemLayoutD,
                            tileShapeD, Int<1>{});
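
  // With Element = float and bM = 32, getSmemLayoutK selects the 128-byte
  // swizzle atom (one of the modes TMA accepts), and tile_to_shape extends
  // it to the full (bN, bM) = (32, 32) store tile.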

  auto tileShapeM = make_shape(Int<4>{}, Int<8>{}, Int<32>{});
  auto smemLayoutM = composition(smemLayoutD, make_layout(tileShapeM));
  auto threadLayoutM = make_layout(make_shape(Int<1>{}, Int<8>{}, Int<32>{}),
                                   make_stride(Int<1>{}, Int<1>{}, Int<8>{}));
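
  // smemLayoutM reshapes the swizzled (32, 32) store layout into a
  // (4, 8, 32) view. Partitioning it with the 256-thread threadLayoutM
  // gives each thread a (4, 1, 1) slice: the four elements it writes in
  // the kernel's copy(tSrS, tMsM).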

  size_t smem_size =
      int(sizeof(SharedStorageTranspose<Element, decltype(smemLayoutD)>));

  //
  // Determine grid and block dimensions
  //

  dim3 gridDim(size<1>(tiled_tensor_S),
               size<2>(tiled_tensor_S)); // Grid shape corresponds to modes m' and n'
  dim3 blockDim(size(threadLayoutS));

  // Warm-up launch before the timed trials.
  transposeKernelTMA<<<gridDim, blockDim, smem_size>>>(
      tiled_tensor_S, smemLayoutD, tiled_copy_S, tmaD, gmemLayoutD,
      tileShapeD, threadLayoutM, smemLayoutM);

  int iterations = 10;

  for (int i = 0; i < iterations; i++) {
    auto t1 = std::chrono::high_resolution_clock::now();
    transposeKernelTMA<<<gridDim, blockDim, smem_size>>>(
        tiled_tensor_S, smemLayoutD, tiled_copy_S, tmaD, gmemLayoutD,
        tileShapeD, threadLayoutM, smemLayoutM);
    cudaError result = cudaDeviceSynchronize();
    auto t2 = std::chrono::high_resolution_clock::now();
    if (result != cudaSuccess) {
      std::cerr << "CUDA Runtime error: " << cudaGetErrorString(result)
                << std::endl;
      return -1;
    }
    std::chrono::duration<double, std::milli> tDiff = t2 - t1;
    double time_ms = tDiff.count();
    // Bandwidth counts one read and one write of the M x N matrix.
    std::cout << "Trial " << i << " Completed in " << time_ms << "ms ("
              << 2e-6 * M * N * sizeof(Element) / time_ms << " GB/s)"
              << std::endl;
  }

  cudaError result = cudaDeviceSynchronize();
  if (result != cudaSuccess) {
    std::cerr << "CUDA Runtime error: " << cudaGetErrorString(result)
              << std::endl;
    return -1;
  }

  //
  // Verify
  //

  h_D = d_D;

  int good = 0, bad = 0;

  // Maps the linear index of an element of D to the linear index of the
  // matching element of S: cute interprets the 1-D index
  // colexicographically over (M, N) and applies row-major strides.
  auto transposeFunction = make_layout(tensor_shape, GenRowMajor{});

  for (size_t i = 0; i < h_D.size(); ++i) {
    if (h_D[i] == h_S[transposeFunction(i)])
      good++;
    else
      bad++;
  }

  std::cout << "Success " << good << ", Fail " << bad << std::endl;

  return 0;
}