
JitItem -> ExecItem (tinygrad#4146)
* JitItem -> ExecItem

* execitem in realize

* cleaner

* JITRunner -> Runner
geohot authored Apr 11, 2024
1 parent e79a11b commit b7e281c
Showing 10 changed files with 55 additions and 58 deletions.
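Taken together, the diffs below do three things: rename the base executor class JITRunner to Runner in tinygrad/device.py, move the JIT cache entry type out of tinygrad/engine/jit.py into tinygrad/engine/realize.py under the new name ExecItem, and rewrite run_schedule in terms of the new ExecItem.run. A minimal sketch of the new names and import locations, as they appear in the diffs below:

# new homes introduced by this commit (per the diffs below)
from tinygrad.device import Runner             # previously JITRunner
from tinygrad.engine.realize import ExecItem   # previously tinygrad.engine.jit.JitItem

# ExecItem pairs a Runner with the buffers it executes on:
#   ExecItem(prg=some_runner, rawbufs=[output_buffer, input_buffer])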
6 changes: 3 additions & 3 deletions extra/backends/ops_hip.py
@@ -4,7 +4,7 @@
import tinygrad.runtime.autogen.hip as hip
from tinygrad.helpers import DEBUG, getenv, init_c_var
from tinygrad.helpers import from_mv, round_up, to_mv, colored, init_c_struct_t
from tinygrad.device import Compiled, LRUAllocator, BufferOptions, JITRunner, Device, Buffer, MallocAllocator, update_stats, Compiler, CompilerOptions
from tinygrad.device import Compiled, LRUAllocator, BufferOptions, Runner, Device, Buffer, MallocAllocator, update_stats, Compiler, CompilerOptions
from tinygrad.renderer.cstyle import HIPRenderer
from tinygrad.runtime.driver.hip_comgr import compile_hip

@@ -128,7 +128,7 @@ def transfer(self, dest:T, src:T, sz:int, **kwargs):
hip_set_device(self.device.device)
check(hip.hipMemcpyAsync(dest, src, sz, hip.hipMemcpyDeviceToDevice, None))

class HIPSyncEvent(JITRunner):
class HIPSyncEvent(Runner):
def __init__(self, lb):
self.lb, self.device, self.dname = lb, cast(HIPDevice, Device[lb.device]), lb.device
super().__init__()
@@ -138,7 +138,7 @@ def __call__(self, rawbufs:List[Buffer], var_vals, wait=False, jit=False):
check(hip.hipStreamWriteValue32(None, rawbufs[0]._buf, 1, 0))
update_stats(colored("sync", "red"), 0, 0, {}, None, 1, jit, device=self.dname)

class HIPWaitEvent(JITRunner):
class HIPWaitEvent(Runner):
def __init__(self, device):
self.device, self.dname = cast(HIPDevice, Device[device]), device
super().__init__()
4 changes: 2 additions & 2 deletions test/external/external_test_hsa_driver.py
@@ -4,7 +4,7 @@
from tinygrad.dtype import dtypes
from tinygrad.runtime.driver.hsa import AQLQueue
from tinygrad.runtime.graph.hsa import VirtAQLQueue, HSAGraph
from tinygrad.engine.jit import JitItem
from tinygrad.engine.realize import ExecItem

def get_hsa_inc_prog(dev, inc=1):
prg = f"""
@@ -102,7 +102,7 @@ def test_hsa_copies_sync(self):
test_buf1.copyin(memoryview(bytearray(1*4)))
test_buf2.copyin(memoryview(bytearray(1*4)))

jit_cache = [JitItem(BufferXfer(), [test_buf0, test_buf2]), JitItem(BufferXfer(), [test_buf2, test_buf1])]
jit_cache = [ExecItem(BufferXfer(), [test_buf0, test_buf2]), ExecItem(BufferXfer(), [test_buf2, test_buf1])]
graph = HSAGraph(jit_cache, [], {})

for i in range(10000):
8 changes: 4 additions & 4 deletions test/helpers.py
@@ -1,6 +1,6 @@
import sys
from tinygrad import Tensor, Device, dtypes
from tinygrad.device import JITRunner
from tinygrad.device import Runner
from tinygrad.dtype import DType
from tinygrad.nn.state import get_parameters
from tinygrad.helpers import Context, CI, OSX, getenv
@@ -13,12 +13,12 @@ def derandomize_model(model):

def assert_jit_cache_len(fxn, expected_len):
assert len(fxn.jit_cache) > 0
# until we have a better way of typing the prg in JitItem
if issubclass(type(fxn.jit_cache[0].prg), JITRunner) and not type(fxn.jit_cache[0].prg).__name__.endswith('Graph'):
# until we have a better way of typing the prg in ExecItem
if issubclass(type(fxn.jit_cache[0].prg), Runner) and not type(fxn.jit_cache[0].prg).__name__.endswith('Graph'):
assert len(fxn.jit_cache) == expected_len
else:
assert len(fxn.jit_cache) == 1
# until we have a better way of typing the prg in JitItem
# until we have a better way of typing the prg in ExecItem
assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph')
assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len

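assert_jit_cache_len distinguishes a plain Runner cache from a graph-wrapped cache by the class name suffix. A hedged usage sketch, assuming the public TinyJit decorator and a one-kernel jitted function (neither is part of this diff); the cache is only captured after the second call:

from tinygrad import Tensor, TinyJit           # assumed public API at this commit
from test.helpers import assert_jit_cache_len

@TinyJit
def step(x: Tensor) -> Tensor: return (x + 1).realize()   # a single elementwise kernel

for _ in range(3): step(Tensor.rand(4, 4))     # run enough times for the JIT to capture
assert_jit_cache_len(step, 1)                  # one cached kernel expected, graphed or not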
8 changes: 4 additions & 4 deletions tinygrad/device.py
@@ -39,7 +39,7 @@ def DEFAULT(self) -> str:

# **************** base Runner + helpers ****************

class JITRunner:
class Runner:
def __init__(self):
self.op_estimate:sint = 0
self.mem_estimate:sint = 0
@@ -67,7 +67,7 @@ def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Option

# **************** Buffer / Allocator ****************

class BufferCopy(JITRunner):
class BufferCopy(Runner):
def copy(self, dest, src):
if src.device.startswith("DISK") and hasattr(dest.allocator, 'copy_from_fd') and src.nbytes >= 4096 and hasattr(src.allocator.device, 'fd'):
dest.allocator.copy_from_fd(dest._buf, src.allocator.device.fd, src._buf.offset, src.nbytes)
@@ -158,7 +158,7 @@ def compile_cached(self, src:str) -> bytes:
if self.cachekey is not None: diskcache_put(self.cachekey, src, lib)
return lib

class CompiledASTRunner(JITRunner):
class CompiledASTRunner(Runner):
def __init__(self, name:str, prg:str, dname:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None,
variables:Optional[List[Variable]]=None, op_estimate:sint=0, mem_estimate:sint=0, precompiled:Optional[bytes]=None, outcount:int=1):
super().__init__()
@@ -201,7 +201,7 @@ def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=Fals
self.first_run = False
return et

class MultiDeviceJITGraph(JITRunner):
class MultiDeviceJITGraph(Runner):
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
raise NotImplementedError("override this")

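Runner is now the single base class for anything the scheduler or JIT can execute: compiled kernels (CompiledASTRunner), copies (BufferCopy/BufferXfer), and whole graphs (MultiDeviceJITGraph). A minimal sketch of a custom Runner subclass, modeled on HIPSyncEvent above and CustomOp below; the class itself and its stats call are illustrative, not part of this commit:

from typing import Dict, List
from tinygrad.buffer import Buffer
from tinygrad.device import Runner, update_stats
from tinygrad.shape.symbolic import Variable

class DebugPrintRunner(Runner):   # hypothetical example
  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False):
    print(f"running with {len(rawbufs)} buffers on {rawbufs[0].device}")
    # mirror the update_stats call pattern used by HIPSyncEvent above
    update_stats("debug_print", 0, 0, {}, None, 1, jit, device=rawbufs[0].device)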
33 changes: 14 additions & 19 deletions tinygrad/engine/jit.py
@@ -4,48 +4,43 @@
from tinygrad.nn.state import get_parameters
from tinygrad.dtype import DType
from tinygrad.helpers import DEBUG, merge_dicts, getenv, all_int, Context, GRAPH, flatten, GraphException
from tinygrad.device import Compiled, JITRunner, CompiledASTRunner, Buffer, BufferXfer, MultiDeviceJITGraph, Device
from tinygrad.device import Compiled, Runner, CompiledASTRunner, Buffer, BufferXfer, MultiDeviceJITGraph, Device
from tinygrad.tensor import Tensor
from tinygrad.lazy import LazyBuffer
from tinygrad.features.multi import MultiLazyBuffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.engine.realize import ExecItem
from weakref import ref, WeakKeyDictionary
from dataclasses import dataclass

@dataclass(frozen=True)
class JitItem:
prg: JITRunner # or a graph executor like MetalGraph
rawbufs: List[Optional[Buffer]]

def get_jit_stats(jit_cache: List[JitItem]) -> Tuple[sint, int]:
def get_jit_stats(jit_cache: List[ExecItem]) -> Tuple[sint, int]:
return functools.reduce(operator.add, [ji.prg.op_estimate for ji in jit_cache if isinstance(ji.prg, CompiledASTRunner)], 0), \
functools.reduce(operator.add, [ji.prg.mem_estimate for ji in jit_cache if isinstance(ji.prg, CompiledASTRunner)], 0)
def get_input_replace(jit_cache: List[JitItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]:
def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]:
input_replace: Dict[Tuple[int, int], int] = {}
for j,ji in enumerate(jit_cache):
for i,a in enumerate(ji.rawbufs):
if a in input_rawbuffers:
input_replace[(j,i)] = input_rawbuffers.index(a)
return input_replace
def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[JitItem]) -> List[int]:
def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[ExecItem]) -> List[int]:
return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ((ji.prg.global_size and not all_int(ji.prg.global_size)) or (ji.prg.local_size and not all_int(ji.prg.local_size)))] # noqa: E501
def get_jc_idxs_with_updatable_var_vals(jit_cache: List[JitItem]) -> List[int]:
def get_jc_idxs_with_updatable_var_vals(jit_cache: List[ExecItem]) -> List[int]:
return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ji.prg.vars]

def apply_graph_to_jit(jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]) -> List[JitItem]:
def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]) -> List[ExecItem]:
# Split JIT cache into batches for faster graph execution.
# This allows the accelerator to run some batches while subsequent graphs are still being updated.
max_batch_size = getenv("JIT_BATCH_SIZE", 32)
graphed_jit_cache: List[JitItem] = []
current_batch: List[JitItem] = []
graphed_jit_cache: List[ExecItem] = []
current_batch: List[ExecItem] = []
current_device: Optional[Compiled] = None

def flush_batch():
nonlocal current_batch, current_device, max_batch_size
try:
if len(current_batch) <= 1 or current_device is None: raise GraphException("only one kernel doesn't graph")
graphed_jit_cache.append(JitItem(current_device.graph(current_batch, input_rawbuffers, var_vals), cast(List[Optional[Buffer]], input_rawbuffers))) # noqa: E501
graphed_jit_cache.append(ExecItem(current_device.graph(current_batch, input_rawbuffers, var_vals), cast(List[Optional[Buffer]], input_rawbuffers))) # noqa: E501
max_batch_size *= 2
if DEBUG >= 2: print(f"\tJIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
except GraphException as e:
@@ -82,7 +77,7 @@ def __init__(self, fxn:Callable[..., ReturnType]):
self.reset()

def reset(self):
self.jit_cache: List[JitItem] = []
self.jit_cache: List[ExecItem] = []
self.input_replace: Dict[Tuple[int, int], int] = {}
self.cnt: int = 0
self.ret: Optional[ReturnType] = None
@@ -162,7 +157,7 @@ def alloc_if_needed(self, buffer_cache: Dict[PlaceHolder, Buffer]) -> Buffer:

class _CacheCollector:
def __init__(self):
self.cache: Optional[List[Tuple[JITRunner, List[Union[Buffer, PlaceHolder]]]]] = None
self.cache: Optional[List[Tuple[Runner, List[Union[Buffer, PlaceHolder]]]]] = None

def start(self, var_vals:Optional[Dict[Variable, int]]=None):
self.cache = []
@@ -179,9 +174,9 @@ def add(self, prg, rawbufs:List[Buffer], var_vals:Dict[Variable, int]):

self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(x, Buffer) else x for x in rawbufs]))

def finish(self) -> List[JitItem]:
def finish(self) -> List[ExecItem]:
if self.cache is None: return []
buffer_cache: Dict[PlaceHolder, Buffer] = {}
saved_cache, self.cache = self.cache, None
return [JitItem(prg, [x.alloc_if_needed(buffer_cache) if isinstance(x, PlaceHolder) else x for x in pl]) for prg, pl in saved_cache]
return [ExecItem(prg, [x.alloc_if_needed(buffer_cache) if isinstance(x, PlaceHolder) else x for x in pl]) for prg, pl in saved_cache]
CacheCollector = _CacheCollector()
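get_input_replace records, for each (cache index, buffer index) slot holding a JIT input, which call argument should be swapped in on later invocations; apply_graph_to_jit then packs the same ExecItems into graph batches, doubling JIT_BATCH_SIZE after each batch that graphs successfully. A small worked sketch of the mapping, using strings as stand-in buffers and None as a stand-in Runner (illustration only; the dataclass types are not enforced at runtime):

from tinygrad.engine.jit import get_input_replace
from tinygrad.engine.realize import ExecItem

a, b, tmp = "input_a", "input_b", "intermediate"                    # stand-ins for Buffers
jit_cache = [ExecItem(None, [tmp, a]), ExecItem(None, [b, tmp])]    # type: ignore
print(get_input_replace(jit_cache, [a, b]))                         # -> {(0, 1): 0, (1, 0): 1}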
33 changes: 17 additions & 16 deletions tinygrad/engine/realize.py
@@ -1,21 +1,29 @@
from typing import List, Dict, Optional
from typing import List, Dict, Optional, cast, Generator
from dataclasses import dataclass
from tinygrad.helpers import colored
from tinygrad.ops import ScheduleItem, BufferOps, LoadOps
from tinygrad.device import JITRunner, Device, BufferCopy, BufferXfer, update_stats
from tinygrad.device import Runner, Device, BufferCopy, BufferXfer, update_stats
from tinygrad.buffer import Buffer
from tinygrad.shape.symbolic import Variable

class CustomOp(JITRunner):
@dataclass(frozen=True)
class ExecItem:
prg: Runner
rawbufs: List[Optional[Buffer]]
def run(self, var_vals:Optional[Dict[Variable, int]]=None):
self.prg.exec([cast(Buffer, x).ensure_allocated() for x in self.rawbufs], var_vals if var_vals is not None else {})

class CustomOp(Runner):
def __init__(self, fxn):
self.fxn = fxn
super().__init__()
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False): self.fxn(*rawbufs)

class EmptyOp(JITRunner):
class EmptyOp(Runner):
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False):
update_stats(colored(f"empty {rawbufs[0].size:10d} {rawbufs[0].dtype}", "yellow"), 0, 0, {}, jit, 1, device=rawbufs[0].device)

def lower_schedule_item(si:ScheduleItem) -> JITRunner:
def lower_schedule_item(si:ScheduleItem) -> Runner:
assert len(set(x.device for x in si.outputs+si.inputs)) == 1 or si.ast[0].op is LoadOps.COPY
if si.ast[0].op is BufferOps.STORE: return Device[si.outputs[0].device].get_runner(*si.ast)
assert len(si.ast) == 1 and len(si.outputs) == 1, "only ASTRunner supports multioutput"
@@ -27,15 +35,8 @@ def lower_schedule_item(si:ScheduleItem) -> JITRunner:
if ast.op is LoadOps.EMPTY: return EmptyOp()
raise RuntimeError(f"don't know how to lower {ast}")

def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, int]] = None):
while len(schedule):
si = schedule.pop(0)

# get the program
prg = lower_schedule_item(si)

# allocate output buffers
for out in si.outputs: out.ensure_allocated()
def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:
while len(schedule): yield ExecItem(lower_schedule_item(si:=schedule.pop(0)), list(si.outputs+si.inputs))

# run the function (put it in JIT)
prg.exec(list(si.outputs+si.inputs), var_vals if var_vals is not None else {})
def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, int]]=None):
for ei in lower_schedule(schedule): ei.run(var_vals)
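run_schedule is now a two-step composition: lower_schedule turns each ScheduleItem into an ExecItem, and ExecItem.run allocates any unallocated buffers and executes. A hedged sketch of driving the lowering loop by hand to inspect each kernel before it runs (how the schedule list is produced is outside this diff; the helper name is illustrative):

from typing import List
from tinygrad.ops import ScheduleItem
from tinygrad.engine.realize import lower_schedule

def run_with_logging(schedule:List[ScheduleItem]) -> None:
  # same loop as run_schedule above, plus a print per ExecItem (note: consumes the schedule list)
  for ei in lower_schedule(schedule):
    print(f"{type(ei.prg).__name__} over {len(ei.rawbufs)} buffers")
    ei.run()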
6 changes: 3 additions & 3 deletions tinygrad/runtime/graph/cuda.py
@@ -5,11 +5,11 @@
from tinygrad.device import CompiledASTRunner, update_stats, Buffer, MultiDeviceJITGraph, BufferXfer, Device, BufferOptions
from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
from tinygrad.shape.symbolic import Variable
from tinygrad.engine.jit import JitItem, get_input_replace, get_jit_stats, \
get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
from tinygrad.engine.realize import ExecItem
from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals

class CUDAGraph(MultiDeviceJITGraph):
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
# Check all jit items are compatible.
if not all(isinstance(ji.prg, CompiledASTRunner) or isinstance(ji.prg, BufferXfer) for ji in jit_cache): raise GraphException

6 changes: 3 additions & 3 deletions tinygrad/runtime/graph/hsa.py
@@ -5,8 +5,8 @@
from tinygrad.device import Compiled, CompiledASTRunner, BufferXfer, MultiDeviceJITGraph, update_stats, Device
from tinygrad.shape.symbolic import Variable
from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
from tinygrad.engine.jit import JitItem, get_input_replace, get_jit_stats, \
get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
from tinygrad.engine.realize import ExecItem
from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
import tinygrad.runtime.autogen.hsa as hsa
from tinygrad.runtime.driver.hsa import check, AQLQueue, AQL_PACKET_SIZE, EMPTY_SIGNAL

@@ -26,7 +26,7 @@ def _submit_packet(self):
self.available_packet_slots -= 1

class HSAGraph(MultiDeviceJITGraph):
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
self.jit_cache = jit_cache
self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache) #type:ignore
5 changes: 3 additions & 2 deletions tinygrad/runtime/graph/metal.py
@@ -3,12 +3,13 @@
from tinygrad.dtype import dtypes
from tinygrad.helpers import dedup, unwrap2, GraphException
from tinygrad.device import Buffer, CompiledASTRunner, update_stats
from tinygrad.engine.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims
from tinygrad.engine.realize import ExecItem
from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims
from tinygrad.shape.symbolic import Variable
from tinygrad.runtime.ops_metal import MetalDevice, wait_check

class MetalGraph:
def __init__(self, device:MetalDevice, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
def __init__(self, device:MetalDevice, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
if not all(isinstance(ji.prg, CompiledASTRunner) for ji in jit_cache): raise GraphException

self.jit_cache = jit_cache
4 changes: 2 additions & 2 deletions tinygrad/runtime/ops_disk.py
@@ -2,7 +2,7 @@
import os, mmap, _posixshmem, io, functools
from typing import Dict, List, Any, Optional
from tinygrad.helpers import prod, OSX
from tinygrad.device import Compiled, Allocator, JITRunner, Buffer
from tinygrad.device import Compiled, Allocator, Runner, Buffer
from tinygrad.ops import UnaryOps, LazyOp, BufferOps
from tinygrad.shape.view import strides_for_shape

@@ -32,7 +32,7 @@ def copyout(self, dest:memoryview, src:DiskBuffer):
else:
dest[:] = src._buf()

class DiskRunner(JITRunner):
class DiskRunner(Runner):
def __init__(self, ast:LazyOp):
# two ASTs are allowed here.
assert ast.op is BufferOps.STORE, "output of AST must be store"
