Skip to content

Commit

Permalink
[Hopper TMA] CUDA codegen for async copy with barrier synchronization (
Browse files Browse the repository at this point in the history
…apache#15616)

[Codegen] CUDA async copy with barrier synchronization
  • Loading branch information
adstraw authored Aug 25, 2023
1 parent b5dae98 commit b9652a2
Show file tree
Hide file tree
Showing 11 changed files with 372 additions and 56 deletions.
32 changes: 32 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,38 @@ TVM_DLL const Op& ptx_cp_async();
TVM_DLL const Op& ptx_commit_group();
TVM_DLL const Op& ptx_wait_group();

/*!
* \brief TVM intrinsic for the PTX async copy barrier using cp.async.mbarrier.arrive
*
* ptx_cp_async_barrier(barrier_array, barrier_id)
*
*/
TVM_DLL const Op& ptx_cp_async_barrier();

/*!
* \brief TVM intrinsic for PTX barrier initialization of thread count using mbarrier.init
*
* ptx_init_barrier_thread_count(barrier_array, barrier_id, thread_count)
*
*/
TVM_DLL const Op& ptx_init_barrier_thread_count();

/*!
* \brief TVM intrinsic for PTX barrier arrival using mbarrier.arrive
*
* ptx_arrive_barrier(barrier_array, barrier_id)
*
*/
TVM_DLL const Op& ptx_arrive_barrier();

/*!
* \brief TVM intrinsic for PTX barrier wait using mbarrier.try_wait
*
* ptx_wait_barrier(barrier_array, barrier_id)
*
*/
TVM_DLL const Op& ptx_wait_barrier();

/*!
* \brief tvm intrinsic for storing the result of PTX MMA into a destination pointer.
* For example, if each thread in a warp of size 32 has 4 elements from the result of
Expand Down
8 changes: 8 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1844,6 +1844,10 @@ def wrapped(*args, **kwargs):
tvm_warp_activemask = _tir_op.tvm_warp_activemask
ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group)
ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group)
ptx_cp_async_barrier = _op_wrapper(_tir_op.ptx_cp_async_barrier)
ptx_init_barrier_thread_count = _op_wrapper(_tir_op.ptx_init_barrier_thread_count)
ptx_arrive_barrier = _op_wrapper(_tir_op.ptx_arrive_barrier)
ptx_wait_barrier = _op_wrapper(_tir_op.ptx_wait_barrier)
assume = _op_wrapper(_tir_op.assume)
undef = _op_wrapper(_tir_op.undef)
TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace)
Expand Down Expand Up @@ -2113,6 +2117,10 @@ def wrapped(*args, **kwargs):
"ptx_cp_async",
"ptx_wait_group",
"ptx_commit_group",
"ptx_cp_async_barrier",
"ptx_init_barrier_thread_count",
"ptx_arrive_barrier",
"ptx_wait_barrier",
"mma_store",
"mma_fill",
"vectorlow",
Expand Down
11 changes: 10 additions & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,16 @@
tvm_fill_fragment,
)
from .op import ptx_mma, ptx_mma_sp, mma_store, mma_fill
from .op import ptx_ldmatrix, ptx_cp_async, ptx_commit_group, ptx_wait_group
from .op import (
ptx_ldmatrix,
ptx_cp_async,
ptx_commit_group,
ptx_wait_group,
ptx_cp_async_barrier,
ptx_init_barrier_thread_count,
ptx_arrive_barrier,
ptx_wait_barrier,
)
from .op import vectorlow, vectorhigh, vectorcombine
from .op import infinity, reinterpret
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
Expand Down
80 changes: 80 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1397,6 +1397,86 @@ def ptx_wait_group(num):
return call_intrin("", "tir.ptx_wait_group", num)


def ptx_cp_async_barrier(barrier_arr, barrier_id):
    """TVM intrinsic for the PTX async copy barrier, ``cp.async.mbarrier.arrive``.

    Makes all prior ``cp.async`` operations issued by the calling thread
    arrive at a shared-memory barrier.
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive

    Parameters
    ----------
    barrier_arr : string
        The name of the barrier array in shared memory

    barrier_id : int
        Index into the barrier array

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    args = (barrier_arr, barrier_id)
    return call_intrin("", "tir.ptx_cp_async_barrier", *args)


def ptx_init_barrier_thread_count(barrier_arr, barrier_id, thread_count):
    """TVM intrinsic for PTX barrier initialization, ``mbarrier.init``.

    Initializes a shared-memory barrier with the number of threads that are
    expected to arrive at it.
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init

    Parameters
    ----------
    barrier_arr : string
        The name of the barrier array in shared memory

    barrier_id : int
        Index into the barrier array

    thread_count : int
        Number of threads expected to arrive at the barrier

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    args = (barrier_arr, barrier_id, thread_count)
    return call_intrin("", "tir.ptx_init_barrier_thread_count", *args)


def ptx_arrive_barrier(barrier_arr, barrier_id):
    """TVM intrinsic for PTX barrier arrival, ``mbarrier.arrive``.

    Signals the calling thread's arrival at a shared-memory barrier.
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive

    Parameters
    ----------
    barrier_arr : string
        The name of the barrier array in shared memory

    barrier_id : int
        Index into the barrier array

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    args = (barrier_arr, barrier_id)
    return call_intrin("", "tir.ptx_arrive_barrier", *args)


def ptx_wait_barrier(barrier_arr, barrier_id):
    """TVM intrinsic for PTX barrier wait, ``mbarrier.try_wait``.

    Blocks the calling thread until the shared-memory barrier completes its
    current phase.
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait

    Parameters
    ----------
    barrier_arr : string
        The name of the barrier array in shared memory

    barrier_id : int
        Index into the barrier array

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    args = (barrier_arr, barrier_id)
    return call_intrin("", "tir.ptx_wait_barrier", *args)


def vectorlow(dtype, vec):
"""Get the low level half of the vector
Expand Down
39 changes: 39 additions & 0 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,18 @@ std::string CodeGenCUDA::Finish() {
decl_stream << "#include <mma.h>\n";
}

if (need_cast_smem_ptr_to_int_) {
decl_stream << "__forceinline__ __device__ unsigned int\n";
decl_stream << "cast_smem_ptr_to_int(const void* const smem_ptr)\n";
decl_stream << "{\n";
decl_stream << " unsigned int smem_int;\n";
decl_stream << " asm volatile (\"{ .reg .u64 smem_int; cvta.to.shared.u64 smem_int, %1; "
"cvt.u32.u64 %0, smem_int; }\"\n";
decl_stream << " : \"=r\"(smem_int) : \"l\"(smem_ptr));\n";
decl_stream << " return smem_int;\n";
decl_stream << "}\n";
}

decl_stream << "\n#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \\\n";
decl_stream << " (__CUDACC_VER_MAJOR__ > 11))\n";
decl_stream << "#define TVM_ENABLE_L2_PREFETCH 1\n";
Expand Down Expand Up @@ -873,6 +885,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
os << "}\n";
} else {
std::string smem_elem_offset = this->PrintExpr(op->args[6]);
need_cast_smem_ptr_to_int_ = true;
this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset,
smem_ptr, smem_elem_offset);
}
Expand Down Expand Up @@ -941,6 +954,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string src_offset = this->PrintExpr(op->args[3]);
std::string size = this->PrintExpr(op->args[4]);
// use size of argument list to indicate whether or not to use predicated cp.async
need_cast_smem_ptr_to_int_ = true;
if (op->args.size() == 5) {
this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size);
} else {
Expand All @@ -952,6 +966,31 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
} else if (op->op.same_as(builtin::ptx_wait_group())) {
int n = Downcast<IntImm>(op->args[0])->value;
this->stream << "__asm__ __volatile__(\"cp.async.wait_group " << n << ";\");\n\n";
} else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
need_cast_smem_ptr_to_int_ = true;
std::string barriers_arr = Downcast<StringImm>(op->args[0])->value;
std::string barrier_id = this->PrintExpr(op->args[1]);
std::string barrier = barriers_arr + "[" + barrier_id + "]";
this->stream << PrintCpAsyncBarrierAsm(barrier);
} else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
need_cast_smem_ptr_to_int_ = true;
std::string barriers_arr = Downcast<StringImm>(op->args[0])->value;
std::string barrier_id = this->PrintExpr(op->args[1]);
std::string barrier = barriers_arr + "[" + barrier_id + "]";
std::string thread_count = this->PrintExpr(op->args[2]);
this->stream << PrintInitBarrierThreadCountAsm(barrier, thread_count);
} else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
need_cast_smem_ptr_to_int_ = true;
std::string barriers_arr = Downcast<StringImm>(op->args[0])->value;
std::string barrier_id = this->PrintExpr(op->args[1]);
std::string barrier = barriers_arr + "[" + barrier_id + "]";
this->stream << PrintArriveBarrierAsm(barrier);
} else if (op->op.same_as(builtin::ptx_wait_barrier())) {
need_cast_smem_ptr_to_int_ = true;
std::string barriers_arr = Downcast<StringImm>(op->args[0])->value;
std::string barrier_id = this->PrintExpr(op->args[1]);
std::string barrier = barriers_arr + "[" + barrier_id + "]";
this->stream << PrintWaitBarrierAsm(barrier);
} else if (op->op.same_as(builtin::ptx_ldg32())) {
/*
asm volatile (
Expand Down
2 changes: 2 additions & 0 deletions src/target/source/codegen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ class CodeGenCUDA final : public CodeGenC {
bool need_math_constants_h_{false};
// whether need mma.h
bool need_mma_h_{false};
// whether need cast_smem_ptr_to_int helper function
bool need_cast_smem_ptr_to_int_{false};
// Op attribute map
OpAttrMap<bool> op_need_warp_shuffle_ = Op::GetAttrMap<bool>("cuda.need_warp_shuffle");

Expand Down
93 changes: 75 additions & 18 deletions src/target/source/ptx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -603,12 +603,7 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type
CHECK(data_type == ptx::DataType::kBit16) << "ldmatrix only accept matrix with type .b16.";
std::string asm_code = R"(
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)({smem_addr}))
);
unsigned int addr = cast_smem_ptr_to_int({smem_addr});
__asm__ __volatile__(
"ldmatrix.sync.aligned{.shape}{.num}{.trans}{.ss}{.type}"
"{templates};\n"
Expand Down Expand Up @@ -638,12 +633,7 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
const std::string& global_elem_offset, const std::string& bytes) {
std::string asm_code = R"(
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)({smem_addr}))
);
unsigned int addr = cast_smem_ptr_to_int({smem_addr});
__asm__ __volatile__(
#if TVM_ENABLE_L2_PREFETCH
"cp.async.{cg_or_ca}.shared.global.L2::128B [%0], [%1], %2;"
Expand Down Expand Up @@ -674,12 +664,7 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
<< "Only support 16, 12, 8, 4, 2, 1 bytes for predicated cp.async";
std::string predicated_asm_code = R"(
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }"
: "=r"(addr)
: "l"((void *)({smem_addr}))
);
unsigned int addr = cast_smem_ptr_to_int({smem_addr});
int pred_guard = (int){pred_guard};
__asm__ __volatile__(
"{ .reg .pred p;"
Expand Down Expand Up @@ -724,5 +709,77 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
return predicated_asm_code;
}

// Emit CUDA source that makes prior cp.async operations arrive at the given
// shared-memory mbarrier (cp.async.mbarrier.arrive).
// \param barrier The barrier, in the form barrier_array[barrier_index].
std::string PrintCpAsyncBarrierAsm(const std::string& barrier) {
  // Template of the generated code; {barrier} is substituted below.
  std::string asm_code = R"(
  {
    unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier});
    __asm__ __volatile__(
      "cp.async.mbarrier.arrive.shared.b64 [%0];"
      :: "r" (barrier_addr_int)
    );
  }
)";

  Replacer replacer;
  replacer.register_rule("{barrier}", barrier);
  return replacer.rewrite(asm_code);
}

// Emit CUDA source that initializes the given shared-memory mbarrier with the
// number of threads expected to arrive at it (mbarrier.init).
// \param barrier The barrier, in the form barrier_array[barrier_index].
// \param thread_count Expression for the expected arrival count.
std::string PrintInitBarrierThreadCountAsm(const std::string& barrier,
                                           const std::string& thread_count) {
  // Template of the generated code; {barrier} and {thread_count} are
  // substituted below.
  std::string asm_code = R"(
  {
    unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier});
    int thread_count = {thread_count};
    __asm__ __volatile__(
      "mbarrier.init.shared.b64 [%0], %1;"
      :: "r"(barrier_addr_int), "r"(thread_count)
    );
  }
)";

  Replacer replacer;
  replacer.register_rule("{barrier}", barrier);
  replacer.register_rule("{thread_count}", thread_count);
  return replacer.rewrite(asm_code);
}

// Emit CUDA source that signals the calling thread's arrival at the given
// shared-memory mbarrier (mbarrier.arrive); the returned arrival state token
// is discarded.
// \param barrier The barrier, in the form barrier_array[barrier_index].
std::string PrintArriveBarrierAsm(const std::string& barrier) {
  // Template of the generated code; {barrier} is substituted below.
  std::string asm_code = R"(
  {
    unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier});
    __asm__ __volatile__(
      "{ .reg .b64 state; mbarrier.arrive.shared.b64 state, [%0]; }"
      :: "r"(barrier_addr_int)
    );
  }
)";

  Replacer replacer;
  replacer.register_rule("{barrier}", barrier);
  return replacer.rewrite(asm_code);
}

// Emit CUDA source that spins on the given shared-memory mbarrier until the
// current phase completes (mbarrier.try_wait.parity in a PTX-level retry loop).
// NOTE(review): the phase bit is hard-coded to 0 — presumably the barrier is
// only taken through a single phase per kernel; confirm against callers.
// \param barrier The barrier, in the form barrier_array[barrier_index].
std::string PrintWaitBarrierAsm(const std::string& barrier) {
  // Template of the generated code; {barrier} is substituted below.
  std::string asm_code = R"(
  {
    unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier});
    constexpr int phase_bit = 0;
    __asm__ __volatile__(
      "{ .reg .pred P; WAIT: mbarrier.try_wait.parity.shared.b64 P, [%0], %1; @P bra.uni DONE; bra.uni WAIT; DONE: }"
      :: "r"(barrier_addr_int), "r"(phase_bit)
    );
  }
)";

  Replacer replacer;
  replacer.register_rule("{barrier}", barrier);
  return replacer.rewrite(asm_code);
}

} // namespace codegen
} // namespace tvm
26 changes: 26 additions & 0 deletions src/target/source/ptx.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,32 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
const std::string& bytes,
const std::string& predicate_value);

/*!
* \brief Print ptx async copy barrier using cp.async.mbarrier.arrive
* \param barrier: The barrier in shared memory in the form barrier_array[barrier_index]
*/
std::string PrintCpAsyncBarrierAsm(const std::string& barrier);

/*!
* \brief Print ptx barrier initialization of thread count using mbarrier.init
* \param barrier: The barrier in shared memory in the form barrier_array[barrier_index]
* \param thread_count: The number of threads expected to arrive at the barrier
*/
std::string PrintInitBarrierThreadCountAsm(const std::string& barrier,
const std::string& thread_count);

/*!
* \brief Print ptx barrier arrival using mbarrier.arrive
* \param barrier: The barrier in shared memory in the form barrier_array[barrier_index]
*/
std::string PrintArriveBarrierAsm(const std::string& barrier);

/*!
* \brief Print ptx barrier wait using mbarrier.try_wait
* \param barrier: The barrier in shared memory in the form barrier_array[barrier_index]
*/
std::string PrintWaitBarrierAsm(const std::string& barrier);

} // namespace codegen
} // namespace tvm

Expand Down
Loading

0 comments on commit b9652a2

Please sign in to comment.