Skip to content

Commit

Permalink
cuda p2p enable when available (tinygrad#4153)
Browse files: browse the repository at this point in the history
  • Loading branch information
nimlgen authored Apr 12, 2024
1 parent 380f27d commit 5a57b48
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
2 changes: 1 addition & 1 deletion tinygrad/runtime/graph/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], va
node_from = cuda.CUgraphNode()
deps = self.access_resources(read=[src], write=[dest], new_dependency=node_from)
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
if getenv("CUDA_P2P"):
if getenv("CUDA_P2P", CUDADevice.peer_access):
cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1,
dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1,
WidthInBytes=dest.nbytes, Height=1, Depth=1)
Expand Down
14 changes: 12 additions & 2 deletions tinygrad/runtime/ops_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,25 @@ def transfer(self, dest, src, sz:int, src_dev, dest_dev):

class CUDADevice(Compiled):
devices: List[CUDADevice] = []
peer_access = False

def __init__(self, device:str):
device_id = int(device.split(":")[1]) if ":" in device else 0
if not CUDACPU:
check(cuda.cuInit(0))
check(cuda.cuDeviceGet(ctypes.byref(cu_device := cuda.CUdevice()), device_id))
self.context = init_c_var(cuda.CUcontext(), lambda x: check(cuda.cuCtxCreate_v2(ctypes.byref(x), 0, cu_device)))
self.cu_device = init_c_var(cuda.CUdevice(), lambda x: check(cuda.cuDeviceGet(ctypes.byref(x), device_id)))
self.context = init_c_var(cuda.CUcontext(), lambda x: check(cuda.cuCtxCreate_v2(ctypes.byref(x), 0, self.cu_device)))
check(cuda.cuDeviceComputeCapability(ctypes.byref(major := ctypes.c_int()), ctypes.byref(minor := ctypes.c_int()), device_id))

for dev in CUDADevice.devices:
check(cuda.cuDeviceCanAccessPeer(ctypes.byref(val := ctypes.c_int()), self.cu_device, dev.cu_device))
if val.value != 1: continue
check(cuda.cuCtxSetCurrent(dev.context))
check(cuda.cuCtxEnablePeerAccess(self.context, 0))
check(cuda.cuCtxSetCurrent(self.context))
check(cuda.cuCtxEnablePeerAccess(dev.context, 0))
CUDADevice.peer_access = True

self.arch = f"sm_{major.value}{minor.value}" if not CUDACPU else "sm_35"
self.pending_copyin: List[Tuple[int, int, Optional[BufferOptions]]] = []
CUDADevice.devices.append(self)
Expand Down

0 comments on commit 5a57b48

Please sign in to comment.