Skip to content

Commit

Permalink
Just-in-time deserialization (#353)
Browse files Browse the repository at this point in the history
Co-authored-by: Peter Andreas Entschev <[email protected]>

* Initial implementation of ObjectProxy

* Added basic tests of ObjectProxy

* Implemented some more proxy attributes

* Added optional spilling of proxy objects

* Re-added dask_serialize for DeviceSerialized

* Added support of __array__

* Added __sizeof__

* Added some spill_proxy tests in test_device_host_file.py

* Checking len() instead of .size()

* Added dispatch support of hash_object_dispatch and group_split_dispatch

* Added "*args, **kwargs" to dispatch of ObjectProxy

* Added dispatch of make_scalar

* Added dispatch of concat_dispatch

* meta.yaml: added pandas dependency

* meta.yaml: depend on dask (not only dask-core)

* Added jit-unspill worker option

* meta.yaml: removed pandas

* Using explicit args for the dispatch functions

* ObjectProxy._obj_pxy_serialize(): takes serializers

* serializers replaces is_serialized

* Supporting cuda serializers

* Added a lot of operators

* fixed typos

* Support and test of a proxy object of a proxy object

* test_spilling_local_cuda_cluster(): added some extra checks

* Added _obj_pxy_is_cuda_object()

* asproxy(): added subclass argument

* fixed type in test_spilling_local_cuda_cluster check

* Added test of communicating proxy objects

* Making ObjectProxy threadsafe

* renamed ObjectProxy => ProxyObject

* Never re-serialize proxy objects

* Test: setting device_memory_limit="1B" to force serialization

* test: added an explicit client shutdown

* Added some str/repr tests

* added some more checks in test_proxy_object_of_numpy

* ProxyObject: added docs

* added unproxy()

* Added ValueError when serializers isn't specified

* Style and spelling fixes

* ProxyObject.__sizeof__(): use dask.sizeof()

* Serializer: convert to tuples before comparing
  • Loading branch information
madsbk authored Nov 23, 2020
1 parent 70950a4 commit 1429b67
Show file tree
Hide file tree
Showing 8 changed files with 912 additions and 19 deletions.
2 changes: 1 addition & 1 deletion conda/recipes/dask-cuda/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ requirements:
- setuptools
run:
- python x.x
- dask-core >=2.4.0
- dask >=2.4.0
- distributed >=2.18.0
- pynvml >=8.0.3
- numpy >=1.16.0
Expand Down
7 changes: 7 additions & 0 deletions dask_cuda/cli/dask_cuda_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,11 @@
"InfiniBand only and will still cause unpredictable errors if not _ALL_ "
"interfaces are connected and properly configured.",
)
@click.option(
"--enable-jit-unspill/--disable-jit-unspill",
default=None, # If not specified, use Dask config
help="Enable just-in-time unspilling",
)
def main(
scheduler,
host,
Expand Down Expand Up @@ -218,6 +223,7 @@ def main(
enable_nvlink,
enable_rdmacm,
net_devices,
enable_jit_unspill,
**kwargs,
):
if tls_ca_file and tls_cert and tls_key:
Expand Down Expand Up @@ -252,6 +258,7 @@ def main(
enable_nvlink,
enable_rdmacm,
net_devices,
enable_jit_unspill,
**kwargs,
)

Expand Down
7 changes: 7 additions & 0 deletions dask_cuda/cuda_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(
enable_nvlink=False,
enable_rdmacm=False,
net_devices=None,
jit_unspill=None,
**kwargs,
):
# Required by RAPIDS libraries (e.g., cuDF) to ensure no context
Expand Down Expand Up @@ -177,6 +178,11 @@ def del_pid_file():
cuda_device_index=0,
)

if jit_unspill is None:
self.jit_unspill = dask.config.get("jit-unspill", default=False)
else:
self.jit_unspill = jit_unspill

self.nannies = [
t(
scheduler,
Expand Down Expand Up @@ -216,6 +222,7 @@ def del_pid_file():
),
"memory_limit": memory_limit,
"local_directory": local_directory,
"jit_unspill": self.jit_unspill,
},
),
**kwargs,
Expand Down
41 changes: 37 additions & 4 deletions dask_cuda/device_host_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@
from distributed.utils import nbytes
from distributed.worker import weight

from . import proxy_object
from .is_device_object import is_device_object
from .utils import nvtx_annotate


class DeviceSerialized:
""" Store device object on the host
This stores a device-side object as
1. A msgpack encodable header
2. A list of `bytes`-like objects (like NumPy arrays)
that are in host memory
Expand Down Expand Up @@ -66,6 +65,27 @@ def host_to_device(s: DeviceSerialized) -> object:
return deserialize(s.header, s.frames)


@nvtx_annotate("SPILL_D2H", color="red", domain="dask_cuda")
def pxy_obj_device_to_host(obj: object) -> proxy_object.ProxyObject:
    """Spill `obj` from device to host by wrapping it in a proxy object.

    Parameters
    ----------
    obj: object
        The object being moved out of the device store. It may already
        be a proxy object (i.e. expose an ``_obj_pxy`` mapping).

    Returns
    -------
    proxy_object.ProxyObject
        `obj` itself when it exposes ``_obj_pxy["serializers"]`` and that
        entry is ``None``; otherwise the result of
        ``proxy_object.asproxy(obj, serializers=["dask", "pickle"])``.
    """
    try:
        current_serializers = obj._obj_pxy["serializers"]
    except (KeyError, AttributeError):
        # `obj` has no proxy state (or no "serializers" entry); fall
        # through and wrap it below.
        pass
    else:
        # Never re-serialize proxy objects.
        if current_serializers is None:
            return obj

    # Notice, both the "dask" and the "pickle" serializer will
    # spill `obj` to main memory.
    spill_serializers = ["dask", "pickle"]
    return proxy_object.asproxy(obj, serializers=spill_serializers)


@nvtx_annotate("SPILL_H2D", color="green", domain="dask_cuda")
def pxy_obj_host_to_device(s: proxy_object.ProxyObject) -> object:
    """Return the proxy object `s` unchanged.

    Notice, we do _not_ deserialize at this point. The proxy
    object automatically deserializes just-in-time.
    """
    return s


class DeviceHostFile(ZictBase):
""" Manages serialization/deserialization of objects.
Expand All @@ -86,10 +106,16 @@ class DeviceHostFile(ZictBase):
implies no spilling to disk.
local_directory: path
Path where to store serialized objects on disk
jit_unspill: bool
If True, enable just-in-time unspilling (see proxy_object.ProxyObject).
"""

def __init__(
self, device_memory_limit=None, memory_limit=None, local_directory=None,
self,
device_memory_limit=None,
memory_limit=None,
local_directory=None,
jit_unspill=False,
):
if local_directory is None:
local_directory = dask.config.get("temporary-directory") or os.getcwd()
Expand All @@ -115,7 +141,14 @@ def __init__(

self.device_keys = set()
self.device_func = dict()
self.device_host_func = Func(device_to_host, host_to_device, self.host_buffer)
if jit_unspill:
self.device_host_func = Func(
pxy_obj_device_to_host, pxy_obj_host_to_device, self.host_buffer
)
else:
self.device_host_func = Func(
device_to_host, host_to_device, self.host_buffer
)
self.device_buffer = Buffer(
self.device_func, self.device_host_func, device_memory_limit, weight=weight
)
Expand Down
9 changes: 9 additions & 0 deletions dask_cuda/local_cuda_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ class LocalCUDACluster(LocalCluster):
but in that case with default (non-managed) memory type.
WARNING: managed memory is currently incompatible with NVLink, trying
to enable both will result in an exception.
jit_unspill: bool
If True, enable just-in-time unspilling (see proxy_object.ProxyObject).
Examples
--------
Expand Down Expand Up @@ -133,6 +135,7 @@ def __init__(
ucx_net_devices=None,
rmm_pool_size=None,
rmm_managed_memory=False,
jit_unspill=None,
**kwargs,
):
# Required by RAPIDS libraries (e.g., cuDF) to ensure no context
Expand Down Expand Up @@ -182,6 +185,11 @@ def __init__(
"Processes are necessary in order to use multiple GPUs with Dask"
)

if jit_unspill is None:
self.jit_unspill = dask.config.get("jit-unspill", default=False)
else:
self.jit_unspill = jit_unspill

if data is None:
data = (
DeviceHostFile,
Expand All @@ -191,6 +199,7 @@ def __init__(
"local_directory": local_directory
or dask.config.get("temporary-directory")
or os.getcwd(),
"jit_unspill": self.jit_unspill,
},
)

Expand Down
Loading

0 comments on commit 1429b67

Please sign in to comment.