diff --git a/common/kokkos-sampler/kp_sampler_skip.cpp b/common/kokkos-sampler/kp_sampler_skip.cpp index 773753f8b..ff2f76ed4 100644 --- a/common/kokkos-sampler/kp_sampler_skip.cpp +++ b/common/kokkos-sampler/kp_sampler_skip.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include "../../profiling/all/kp_core.hpp" #include "kp_config.hpp" @@ -13,6 +14,9 @@ static uint64_t kernelSampleSkip = 101; static int tool_verbosity = 0; static int tool_globFence = 0; +// a hash table mapping kID to nestedkID +static std::unordered_map infokIDSample; + typedef void (*initFunction)(const int, const uint64_t, const uint32_t, void*); typedef void (*finalizeFunction)(); typedef void (*beginFunction)(const char*, const uint32_t, uint64_t*); @@ -153,15 +157,15 @@ void kokkosp_finalize_library() { void kokkosp_begin_parallel_for(const char* name, const uint32_t devID, uint64_t* kID) { *kID = uniqID++; - if (((*kID) % kernelSampleSkip) == 0) { if (tool_verbosity > 0) { printf("KokkosP: sample %llu calling child-begin function...\n", (unsigned long long)(*kID)); } - if (NULL != beginForCallee) { - (*beginForCallee)(name, devID, kID); + uint64_t nestedkID = 0; + (*beginForCallee)(name, devID, &nestedkID); + infokIDSample.insert({*kID, nestedkID}); } } } @@ -172,9 +176,9 @@ void kokkosp_end_parallel_for(const uint64_t kID) { printf("KokkosP: sample %llu calling child-end function...\n", (unsigned long long)(kID)); } - if (NULL != endForCallee) { - (*endForCallee)(kID); + uint64_t retrievedNestedkID = infokIDSample.at(kID); + (*endForCallee)(retrievedNestedkID); } } } @@ -182,15 +186,15 @@ void kokkosp_end_parallel_for(const uint64_t kID) { void kokkosp_begin_parallel_scan(const char* name, const uint32_t devID, uint64_t* kID) { *kID = uniqID++; - if (((*kID) % kernelSampleSkip) == 0) { if (tool_verbosity > 0) { printf("KokkosP: sample %llu calling child-begin function...\n", (unsigned long long)(*kID)); } - if (NULL != beginScanCallee) { - (*beginScanCallee)(name, devID, kID); + uint64_t nestedkID = 0; + (*beginScanCallee)(name, devID, &nestedkID); + infokIDSample.insert({*kID, nestedkID}); } } } @@ -203,7 +207,8 @@ void kokkosp_end_parallel_scan(const uint64_t kID) { } if (NULL != endScanCallee) { - (*endScanCallee)(kID); + uint64_t retrievedNestedkID = infokIDSample.at(kID); + (*endScanCallee)(retrievedNestedkID); } } } @@ -211,7 +216,6 @@ void kokkosp_end_parallel_scan(const uint64_t kID) { void kokkosp_begin_parallel_reduce(const char* name, const uint32_t devID, uint64_t* kID) { *kID = uniqID++; - if (((*kID) % kernelSampleSkip) == 0) { if (tool_verbosity > 0) { printf("KokkosP: sample %llu calling child-begin function...\n", @@ -219,7 +223,9 @@ void kokkosp_begin_parallel_reduce(const char* name, const uint32_t devID, } if (NULL != beginReduceCallee) { - (*beginReduceCallee)(name, devID, kID); + uint64_t nestedkID = 0; + (*beginReduceCallee)(name, devID, &nestedkID); + infokIDSample.insert({*kID, nestedkID}); } } } @@ -232,7 +238,8 @@ void kokkosp_end_parallel_reduce(const uint64_t kID) { } if (NULL != endReduceCallee) { - (*endReduceCallee)(kID); + uint64_t retrievedNestedkID = infokIDSample.at(kID); + (*endReduceCallee)(retrievedNestedkID); } } }