cudagraph memcpy through host (tinygrad#4137)

geohot authored Apr 10, 2024
1 parent 5e6d215 commit af5984d
Showing 2 changed files with 16 additions and 7 deletions.
tinygrad/runtime/graph/cuda.py (21 changes: 15 additions & 6 deletions)
@@ -2,7 +2,7 @@
 from typing import Any, Optional, Tuple, Dict, List, cast
 import tinygrad.runtime.autogen.cuda as cuda
 from tinygrad.helpers import init_c_var, GraphException
-from tinygrad.device import CompiledASTRunner, update_stats, Buffer, MultiDeviceJITGraph, BufferXfer, Device
+from tinygrad.device import CompiledASTRunner, update_stats, Buffer, MultiDeviceJITGraph, BufferXfer, Device, BufferOptions
 from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
 from tinygrad.shape.symbolic import Variable
 from tinygrad.engine.jit import JitItem, get_input_replace, get_jit_stats, \
@@ -24,6 +24,7 @@ def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var
     self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
     self.w_dependency_map: Dict[Any, Any] = {}
     self.r_dependency_map: Dict[Any, List[Any]] = collections.defaultdict(list)
+    self.cpu_buffers = []
 
     for j,ji in enumerate(self.jit_cache):
       if isinstance(ji.prg, CompiledASTRunner):
@@ -41,17 +42,25 @@ def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var
           self.updatable_nodes[j] = (new_node, kern_params, c_args, False)
       elif isinstance(ji.prg, BufferXfer):
         dest, src = [cast(Buffer, x) for x in ji.rawbufs[0:2]]
-        src_dev = cast(CUDADevice, Device[src.device])
+        src_dev, dest_dev = cast(CUDADevice, Device[src.device]), cast(CUDADevice, Device[dest.device])
+        cpu_buffer = Buffer(device=src.device, dtype=src.dtype, size=src.size, options=BufferOptions(host=True)).allocate()
+        self.cpu_buffers.append(cpu_buffer)
 
-        new_node = cuda.CUgraphNode()
-        deps = self.access_resources(read=[src], write=[dest], new_dependency=new_node)
+        node_to = cuda.CUgraphNode()
+        node_from = cuda.CUgraphNode()
+        deps = self.access_resources(read=[src], write=[dest], new_dependency=node_from)
         c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
 
         cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1,
+                                          dstMemoryType=cuda.CU_MEMORYTYPE_HOST, dstHost=cpu_buffer._buf, dstPitch=dest.nbytes, dstHeight=1,
+                                          WidthInBytes=dest.nbytes, Height=1, Depth=1)
+        check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_to), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context))
+        cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_HOST, srcHost=cpu_buffer._buf, srcPitch=src.nbytes, srcHeight=1,
                                           dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1,
                                           WidthInBytes=dest.nbytes, Height=1, Depth=1)
-        check(cuda.cuGraphAddMemcpyNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context))
-        if j in self.jc_idxs_with_updatable_rawbufs: self.updatable_nodes[j] = (new_node, cp_params, src_dev.context, True)
+        check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, (cuda.CUgraphNode*1)(node_to), 1,
+                                        ctypes.byref(cp_params), dest_dev.context))
+        if j in self.jc_idxs_with_updatable_rawbufs: self.updatable_nodes[j] = (node_from, cp_params, src_dev.context, True)
 
     self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))
 
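Taken together, the graph change above swaps the old single device-to-device memcpy node for a two-hop copy staged through a pinned host buffer: node_to copies src into cpu_buffer on the source device's context, and node_from, which explicitly depends on node_to, copies cpu_buffer into dest on the destination device's context; access_resources now registers node_from, so later graph nodes order themselves after the second hop. Below is a minimal sketch of the same pattern pulled out into a standalone helper, using only calls that already appear in the diff; the names graph, deps, src_buf, dest_buf, host_buf, nbytes, src_ctx and dest_ctx are hypothetical caller-supplied values, not part of the commit:

# Hypothetical helper illustrating the staged copy used by the commit.
import ctypes
import tinygrad.runtime.autogen.cuda as cuda
from tinygrad.runtime.ops_cuda import check

def add_staged_memcpy(graph, deps, src_buf, dest_buf, host_buf, nbytes, src_ctx, dest_ctx):
  node_to, node_from = cuda.CUgraphNode(), cuda.CUgraphNode()
  c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
  # hop 1: device -> pinned host, issued on the source device's context
  p = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src_buf, srcPitch=nbytes, srcHeight=1,
                            dstMemoryType=cuda.CU_MEMORYTYPE_HOST, dstHost=host_buf, dstPitch=nbytes, dstHeight=1,
                            WidthInBytes=nbytes, Height=1, Depth=1)
  check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_to), graph, c_deps, len(deps), ctypes.byref(p), src_ctx))
  # hop 2: pinned host -> device, depends on hop 1, issued on the destination device's context
  p = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_HOST, srcHost=host_buf, srcPitch=nbytes, srcHeight=1,
                            dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest_buf, dstPitch=nbytes, dstHeight=1,
                            WidthInBytes=nbytes, Height=1, Depth=1)
  check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), graph, (cuda.CUgraphNode*1)(node_to), 1, ctypes.byref(p), dest_ctx))
  return node_from  # later nodes should depend on the second hop, as access_resources does above
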
tinygrad/runtime/ops_cuda.py (2 changes: 1 addition & 1 deletion)
@@ -125,7 +125,7 @@ def __init__(self, device:CUDADevice):
     super().__init__()
   def _alloc(self, size, options:BufferOptions):
     check(cuda.cuCtxSetCurrent(self.device.context))
-    if options.host: return init_c_var(ctypes.c_void_p(), lambda x: check(cuda.cuMemHostAlloc(ctypes.byref(x), size, 0)))
+    if options.host: return init_c_var(ctypes.c_void_p(), lambda x: check(cuda.cuMemHostAlloc(ctypes.byref(x), size, 0x01)))
     else: return init_c_var(cuda.CUdeviceptr(), lambda x: check(cuda.cuMemAlloc_v2(ctypes.byref(x), size)))
   def _free(self, opaque, options:BufferOptions):
     if options.host: return check(cuda.cuMemFreeHost(opaque))
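The cuMemHostAlloc flag change from 0 to 0x01 corresponds to CU_MEMHOSTALLOC_PORTABLE, presumably so the pinned staging allocation is visible to all CUDA contexts, which the new transfer path relies on since the buffer is written from the source device's context and read from the destination device's. A small sketch of how such a host buffer is requested through the Buffer API used in the graph code; the device name, size, and dtype here are illustrative assumptions:

# Hypothetical example: allocate a pinned, portable host staging buffer.
from tinygrad.device import Buffer, BufferOptions
from tinygrad.dtype import dtypes

staging = Buffer(device="CUDA", size=1024, dtype=dtypes.uint8, options=BufferOptions(host=True)).allocate()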
