diff --git a/tinygrad/runtime/graph/cuda.py b/tinygrad/runtime/graph/cuda.py index 4699e33cc865..29f9b53598ee 100644 --- a/tinygrad/runtime/graph/cuda.py +++ b/tinygrad/runtime/graph/cuda.py @@ -2,7 +2,7 @@ from typing import Any, Optional, Tuple, Dict, List, cast import tinygrad.runtime.autogen.cuda as cuda from tinygrad.helpers import init_c_var, GraphException -from tinygrad.device import CompiledASTRunner, update_stats, Buffer, MultiDeviceJITGraph, BufferXfer, Device +from tinygrad.device import CompiledASTRunner, update_stats, Buffer, MultiDeviceJITGraph, BufferXfer, Device, BufferOptions from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution from tinygrad.shape.symbolic import Variable from tinygrad.engine.jit import JitItem, get_input_replace, get_jit_stats, \ @@ -24,6 +24,7 @@ def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0))) self.w_dependency_map: Dict[Any, Any] = {} self.r_dependency_map: Dict[Any, List[Any]] = collections.defaultdict(list) + self.cpu_buffers = [] for j,ji in enumerate(self.jit_cache): if isinstance(ji.prg, CompiledASTRunner): @@ -41,17 +42,25 @@ def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var self.updatable_nodes[j] = (new_node, kern_params, c_args, False) elif isinstance(ji.prg, BufferXfer): dest, src = [cast(Buffer, x) for x in ji.rawbufs[0:2]] - src_dev = cast(CUDADevice, Device[src.device]) + src_dev, dest_dev = cast(CUDADevice, Device[src.device]), cast(CUDADevice, Device[dest.device]) + cpu_buffer = Buffer(device=src.device, dtype=src.dtype, size=src.size, options=BufferOptions(host=True)).allocate() + self.cpu_buffers.append(cpu_buffer) - new_node = cuda.CUgraphNode() - deps = self.access_resources(read=[src], write=[dest], new_dependency=new_node) + node_to = cuda.CUgraphNode() + node_from = cuda.CUgraphNode() + deps = self.access_resources(read=[src], write=[dest], new_dependency=node_from) c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1, + dstMemoryType=cuda.CU_MEMORYTYPE_HOST, dstHost=cpu_buffer._buf, dstPitch=dest.nbytes, dstHeight=1, + WidthInBytes=dest.nbytes, Height=1, Depth=1) + check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_to), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context)) + cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_HOST, srcHost=cpu_buffer._buf, srcPitch=src.nbytes, srcHeight=1, dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1, WidthInBytes=dest.nbytes, Height=1, Depth=1) - check(cuda.cuGraphAddMemcpyNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context)) - if j in self.jc_idxs_with_updatable_rawbufs: self.updatable_nodes[j] = (new_node, cp_params, src_dev.context, True) + check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, (cuda.CUgraphNode*1)(node_to), 1, + ctypes.byref(cp_params), dest_dev.context)) + if j in self.jc_idxs_with_updatable_rawbufs: self.updatable_nodes[j] = (node_from, cp_params, src_dev.context, True) self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0))) diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py index f1016582efdd..34c57ab57444 100644 --- a/tinygrad/runtime/ops_cuda.py +++ b/tinygrad/runtime/ops_cuda.py @@ -125,7 +125,7 @@ def __init__(self, device:CUDADevice): super().__init__() def _alloc(self, size, options:BufferOptions): check(cuda.cuCtxSetCurrent(self.device.context)) - if options.host: return init_c_var(ctypes.c_void_p(), lambda x: check(cuda.cuMemHostAlloc(ctypes.byref(x), size, 0))) + if options.host: return init_c_var(ctypes.c_void_p(), lambda x: check(cuda.cuMemHostAlloc(ctypes.byref(x), size, 0x01))) else: return init_c_var(cuda.CUdeviceptr(), lambda x: check(cuda.cuMemAlloc_v2(ctypes.byref(x), size))) def _free(self, opaque, options:BufferOptions): if options.host: return check(cuda.cuMemFreeHost(opaque))