Skip to content

Commit

Permalink
[CUDA] Sets the CUDA context in more methods (#400)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmed256 authored Sep 10, 2020
1 parent 56eb9ea commit c8a5876
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 11 deletions.
2 changes: 2 additions & 0 deletions include/occa/modes/cuda/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ namespace occa {

void* getNullPtr();

void setCudaContext();

//---[ Stream ]-------------------
virtual modeStream_t* createStream(const occa::properties &props);

Expand Down
30 changes: 20 additions & 10 deletions src/modes/cuda/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,17 @@ namespace occa {
return (void*) &(nullPtr->cuPtr);
}

void device::setCudaContext() {
OCCA_CUDA_ERROR("Device: Setting Context",
cuCtxSetCurrent(cuContext));
}

//---[ Stream ]---------------------
modeStream_t* device::createStream(const occa::properties &props) {
CUstream cuStream = NULL;

OCCA_CUDA_ERROR("Device: Setting Context",
cuCtxSetCurrent(cuContext));
setCudaContext();

OCCA_CUDA_ERROR("Device: createStream",
cuStreamCreate(&cuStream, CU_STREAM_DEFAULT));

Expand All @@ -141,8 +146,8 @@ namespace occa {
occa::streamTag device::tagStream() {
CUevent cuEvent = NULL;

OCCA_CUDA_ERROR("Device: Setting Context",
cuCtxSetCurrent(cuContext));
setCudaContext();

OCCA_CUDA_ERROR("Device: Tagging Stream (Creating Tag)",
cuEventCreate(&cuEvent,
CU_EVENT_DEFAULT));
Expand Down Expand Up @@ -219,6 +224,8 @@ namespace occa {
CUfunction cuFunction;
CUresult error;

setCudaContext();

error = cuModuleLoad(&cuModule, binaryFilename.c_str());
if (error) {
lock.release();
Expand Down Expand Up @@ -353,6 +360,8 @@ namespace occa {
CUmodule cuModule;
CUresult error;

setCudaContext();

error = cuModuleLoad(&cuModule, binaryFilename.c_str());
if (error) {
lock.release();
Expand Down Expand Up @@ -410,6 +419,8 @@ namespace occa {
CUmodule cuModule = NULL;
CUfunction cuFunction = NULL;

setCudaContext();

OCCA_CUDA_ERROR("Kernel [" + kernelName + "]: Loading Module",
cuModuleLoad(&cuModule, filename.c_str()));

Expand Down Expand Up @@ -438,8 +449,7 @@ namespace occa {

cuda::memory &mem = *(new cuda::memory(this, bytes, props));

OCCA_CUDA_ERROR("Device: Setting Context",
cuCtxSetCurrent(cuContext));
setCudaContext();

OCCA_CUDA_ERROR("Device: malloc",
cuMemAlloc(&(mem.cuPtr), bytes));
Expand All @@ -456,8 +466,8 @@ namespace occa {

cuda::memory &mem = *(new cuda::memory(this, bytes, props));

OCCA_CUDA_ERROR("Device: Setting Context",
cuCtxSetCurrent(cuContext));
setCudaContext();

OCCA_CUDA_ERROR("Device: malloc host",
cuMemAllocHost((void**) &(mem.mappedPtr), bytes));
OCCA_CUDA_ERROR("Device: get device pointer from host",
Expand All @@ -481,8 +491,8 @@ namespace occa {
const unsigned int flags = (props.get("attached_host", false) ?
CU_MEM_ATTACH_HOST : CU_MEM_ATTACH_GLOBAL);

OCCA_CUDA_ERROR("Device: Setting Context",
cuCtxSetCurrent(cuContext));
setCudaContext();

OCCA_CUDA_ERROR("Device: Unified alloc",
cuMemAllocManaged(&(mem.cuPtr),
bytes,
Expand Down
6 changes: 5 additions & 1 deletion src/modes/cuda/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ namespace occa {
}

void kernel::deviceRun() const {
device *devicePtr = (device*) modeDevice;

const int args = (int) arguments.size();
if (!args) {
vArgs.resize(1);
Expand All @@ -74,10 +76,12 @@ namespace occa {
vArgs[i] = arguments[i].ptr();
// Set a proper NULL pointer
if (!vArgs[i]) {
vArgs[i] = ((device*) modeDevice)->getNullPtr();
vArgs[i] = devicePtr->getNullPtr();
}
}

devicePtr->setCudaContext();

OCCA_CUDA_ERROR("Launching Kernel",
cuLaunchKernel(cuFunction,
outerDims.x, outerDims.y, outerDims.z,
Expand Down

0 comments on commit c8a5876

Please sign in to comment.