diff --git a/python/cudf/cudf/core/buffer/buffer.py b/python/cudf/cudf/core/buffer/buffer.py index 1631fa00412..b2aba4f978b 100644 --- a/python/cudf/cudf/core/buffer/buffer.py +++ b/python/cudf/cudf/core/buffer/buffer.py @@ -106,8 +106,25 @@ class BufferOwner(Serializable): been accessed outside of BufferOwner. In this case, we have no control over knowing if the data is being modified by a third party. - Use `_from_device_memory` and `_from_host_memory` to create + Use `from_device_memory` and `from_host_memory` to create a new instance from either device or host memory respectively. + + Parameters + ---------- + ptr + An integer representing a pointer to memory. + size + The size of the memory in nbytes + owner + Python object to which the lifetime of the memory allocation is tied. + This buffer will keep a reference to `owner`. + exposed + Pointer to the underlying memory + + Raises + ------ + ValueError + If size is negative """ _ptr: int @@ -117,14 +134,25 @@ class BufferOwner(Serializable): # The set of buffers that point to this owner. _slices: weakref.WeakSet[Buffer] - def __init__(self): - raise ValueError( - f"do not create a {self.__class__} directly, please " - "use the factory function `cudf.core.buffer.as_buffer`" - ) + def __init__( + self, + *, + ptr: int, + size: int, + owner: object, + exposed: bool, + ): + if size < 0: + raise ValueError("size cannot be negative") + + self._ptr = ptr + self._size = size + self._owner = owner + self._exposed = exposed + self._slices = weakref.WeakSet() @classmethod - def _from_device_memory(cls, data: Any, exposed: bool) -> Self: + def from_device_memory(cls, data: Any, exposed: bool) -> Self: """Create from an object providing a `__cuda_array_interface__`. No data is being copied. @@ -151,24 +179,15 @@ def _from_device_memory(cls, data: Any, exposed: bool) -> Self: If the resulting buffer has negative size """ - # Bypass `__init__` and initialize attributes manually - ret = cls.__new__(cls) - ret._owner = data - ret._exposed = exposed - ret._slices = weakref.WeakSet() if isinstance(data, rmm.DeviceBuffer): # Common case shortcut - ret._ptr = data.ptr - ret._size = data.size + ptr = data.ptr + size = data.size else: - ret._ptr, ret._size = get_ptr_and_size( - data.__cuda_array_interface__ - ) - if ret.size < 0: - raise ValueError("size cannot be negative") - return ret + ptr, size = get_ptr_and_size(data.__cuda_array_interface__) + return cls(ptr=ptr, size=size, owner=data, exposed=exposed) @classmethod - def _from_host_memory(cls, data: Any) -> Self: + def from_host_memory(cls, data: Any) -> Self: """Create an owner from a buffer or array like object Data must implement `__array_interface__`, the buffer protocol, and/or @@ -196,7 +215,7 @@ def _from_host_memory(cls, data: Any) -> Self: # Copy to device memory buf = rmm.DeviceBuffer(ptr=ptr, size=size) # Create from device memory - return cls._from_device_memory(buf, exposed=False) + return cls.from_device_memory(buf, exposed=False) @property def size(self) -> int: @@ -375,7 +394,7 @@ def copy(self, deep: bool = True) -> Self: ) # Otherwise, we create a new copy of the memory - owner = self._owner._from_device_memory( + owner = self._owner.from_device_memory( rmm.DeviceBuffer( ptr=self._owner.get_ptr(mode="read") + self._offset, size=self.size, @@ -439,9 +458,9 @@ def deserialize(cls, header: dict, frames: list) -> Self: owner_type: BufferOwner = pickle.loads(header["owner-type-serialized"]) if hasattr(frame, "__cuda_array_interface__"): - owner = owner_type._from_device_memory(frame, exposed=False) + owner = owner_type.from_device_memory(frame, exposed=False) else: - owner = owner_type._from_host_memory(frame) + owner = owner_type.from_host_memory(frame) return cls( owner=owner, offset=0, diff --git a/python/cudf/cudf/core/buffer/exposure_tracked_buffer.py b/python/cudf/cudf/core/buffer/exposure_tracked_buffer.py index 4c08016adbb..15f00fc670d 100644 --- a/python/cudf/cudf/core/buffer/exposure_tracked_buffer.py +++ b/python/cudf/cudf/core/buffer/exposure_tracked_buffer.py @@ -23,8 +23,6 @@ class ExposureTrackedBuffer(Buffer): The size of the slice (in bytes) """ - _owner: BufferOwner - def __init__( self, owner: BufferOwner, @@ -32,11 +30,7 @@ def __init__( size: Optional[int] = None, ) -> None: super().__init__(owner=owner, offset=offset, size=size) - self._owner._slices.add(self) - - @property - def exposed(self) -> bool: - return self._owner.exposed + self.owner._slices.add(self) def get_ptr(self, *, mode: Literal["read", "write"]) -> int: if mode == "write" and cudf.get_option("copy_on_write"): @@ -72,7 +66,7 @@ def copy(self, deep: bool = True) -> Self: copy-on-write option (see above). """ if cudf.get_option("copy_on_write"): - return super().copy(deep=deep or self.exposed) + return super().copy(deep=deep or self.owner.exposed) return super().copy(deep=deep) @property @@ -98,11 +92,11 @@ def make_single_owner_inplace(self) -> None: Buffer representing the same device memory as `data` """ - if len(self._owner._slices) > 1: - # If this is not the only slice pointing to `self._owner`, we - # point to a new deep copy of the owner. + if len(self.owner._slices) > 1: + # If this is not the only slice pointing to `self.owner`, we + # point to a new copy of our slice of `self.owner`. t = self.copy(deep=True) - self._owner = t._owner + self._owner = t.owner self._offset = t._offset self._size = t._size self._owner._slices.add(self) diff --git a/python/cudf/cudf/core/buffer/spill_manager.py b/python/cudf/cudf/core/buffer/spill_manager.py index 3e654e01401..cd81149bdb8 100644 --- a/python/cudf/cudf/core/buffer/spill_manager.py +++ b/python/cudf/cudf/core/buffer/spill_manager.py @@ -10,6 +10,7 @@ import warnings import weakref from collections import defaultdict +from contextlib import contextmanager from dataclasses import dataclass from functools import partial from typing import Dict, List, Optional, Tuple @@ -201,10 +202,6 @@ class SpillManager: This class implements tracking of all known spillable buffers, on-demand spilling of said buffers, and (optionally) maintains a memory usage limit. - When `spill_on_demand=True`, the manager registers an RMM out-of-memory - error handler, which will spill spillable buffers in order to free up - memory. - When `device_memory_limit=`, the manager will try keep the device memory usage below the specified limit by spilling of spillable buffers continuously, which will introduce a modest overhead. @@ -213,8 +210,6 @@ class SpillManager: Parameters ---------- - spill_on_demand : bool - Enable spill on demand. device_memory_limit: int, optional If not None, this is the device memory limit in bytes that triggers device to host spilling. The global manager sets this to the value @@ -230,30 +225,15 @@ class SpillManager: def __init__( self, *, - spill_on_demand: bool = False, device_memory_limit: Optional[int] = None, statistic_level: int = 0, ) -> None: self._lock = threading.Lock() self._buffers = weakref.WeakValueDictionary() self._id_counter = 0 - self._spill_on_demand = spill_on_demand self._device_memory_limit = device_memory_limit self.statistics = SpillStatistics(statistic_level) - if self._spill_on_demand: - # Set the RMM out-of-memory handle if not already set - mr = rmm.mr.get_current_device_resource() - if all( - not isinstance(m, rmm.mr.FailureCallbackResourceAdaptor) - for m in get_rmm_memory_resource_stack(mr) - ): - rmm.mr.set_current_device_resource( - rmm.mr.FailureCallbackResourceAdaptor( - mr, self._out_of_memory_handle - ) - ) - def _out_of_memory_handle(self, nbytes: int, *, retry_once=True) -> bool: """Try to handle an out-of-memory error by spilling @@ -408,8 +388,7 @@ def __repr__(self) -> str: dev_limit = format_bytes(self._device_memory_limit) return ( - f"" @@ -442,12 +421,82 @@ def get_global_manager() -> Optional[SpillManager]: """Get the global manager or None if spilling is disabled""" global _global_manager_uninitialized if _global_manager_uninitialized: - manager = None if get_option("spill"): manager = SpillManager( - spill_on_demand=get_option("spill_on_demand"), device_memory_limit=get_option("spill_device_limit"), statistic_level=get_option("spill_stats"), ) - set_global_manager(manager) + set_global_manager(manager) + if get_option("spill_on_demand"): + set_spill_on_demand_globally() + else: + set_global_manager(None) return _global_manager + + +def set_spill_on_demand_globally() -> None: + """Enable spill on demand in the current global spill manager. + + Warning: this modifies the current RMM memory resource. A memory resource + to handle out-of-memory errors is pushed onto the RMM memory resource stack. + + Raises + ------ + ValueError + If no global spill manager exists (spilling is disabled). + ValueError + If a failure callback resource is already in the resource stack. + """ + + manager = get_global_manager() + if manager is None: + raise ValueError( + "Cannot enable spill on demand with no global spill manager" + ) + mr = rmm.mr.get_current_device_resource() + if any( + isinstance(m, rmm.mr.FailureCallbackResourceAdaptor) + for m in get_rmm_memory_resource_stack(mr) + ): + raise ValueError( + "Spill on demand (or another failure callback resource) " + "is already registered" + ) + rmm.mr.set_current_device_resource( + rmm.mr.FailureCallbackResourceAdaptor( + mr, manager._out_of_memory_handle + ) + ) + + +@contextmanager +def spill_on_demand_globally(): + """Context to enable spill on demand temporarily. + + Warning: this modifies the current RMM memory resource. A memory resource + to handle out-of-memory errors is pushed onto the RMM memory resource stack + when entering the context and popped again when exiting. + + Raises + ------ + ValueError + If no global spill manager exists (spilling is disabled). + ValueError + If a failure callback resource is already in the resource stack. + ValueError + If the RMM memory source stack was changed while in the context. + """ + set_spill_on_demand_globally() + # Save the new memory resource stack for later cleanup + mr_stack = get_rmm_memory_resource_stack( + rmm.mr.get_current_device_resource() + ) + try: + yield + finally: + mr = rmm.mr.get_current_device_resource() + if mr_stack != get_rmm_memory_resource_stack(mr): + raise ValueError( + "RMM memory source stack was changed while in the context" + ) + rmm.mr.set_current_device_resource(mr_stack[1]) diff --git a/python/cudf/cudf/core/buffer/spillable_buffer.py b/python/cudf/cudf/core/buffer/spillable_buffer.py index a9569190e75..a1af3ba8c9d 100644 --- a/python/cudf/cudf/core/buffer/spillable_buffer.py +++ b/python/cudf/cudf/core/buffer/spillable_buffer.py @@ -20,6 +20,7 @@ cuda_array_interface_wrapper, host_memory_allocation, ) +from cudf.core.buffer.exposure_tracked_buffer import ExposureTrackedBuffer from cudf.utils.nvtx_annotation import _get_color_for_nvtx, annotate from cudf.utils.string import format_bytes @@ -93,8 +94,8 @@ class SpillableBufferOwner(BufferOwner): def _finalize_init(self, ptr_desc: Dict[str, Any]) -> None: """Finish initialization of the spillable buffer - This implements the common initialization that `_from_device_memory` - and `_from_host_memory` are missing. + This implements the common initialization that `from_device_memory` + and `from_host_memory` are missing. Parameters ---------- @@ -119,7 +120,7 @@ def _finalize_init(self, ptr_desc: Dict[str, Any]) -> None: self._manager.add(self) @classmethod - def _from_device_memory(cls, data: Any, exposed: bool) -> Self: + def from_device_memory(cls, data: Any, exposed: bool) -> Self: """Create a spillabe buffer from device memory. No data is being copied. @@ -136,12 +137,12 @@ def _from_device_memory(cls, data: Any, exposed: bool) -> Self: SpillableBufferOwner Buffer representing the same device memory as `data` """ - ret = super()._from_device_memory(data, exposed=exposed) + ret = super().from_device_memory(data, exposed=exposed) ret._finalize_init(ptr_desc={"type": "gpu"}) return ret @classmethod - def _from_host_memory(cls, data: Any) -> Self: + def from_host_memory(cls, data: Any) -> Self: """Create a spillabe buffer from host memory. Data must implement `__array_interface__`, the buffer protocol, and/or @@ -170,11 +171,7 @@ def _from_host_memory(cls, data: Any) -> Self: data = data.cast("B") # Make sure itemsize==1 # Create an already spilled buffer - ret = cls.__new__(cls) - ret._owner = None - ret._ptr = 0 - ret._size = data.nbytes - ret._exposed = False + ret = cls(ptr=0, size=data.nbytes, owner=None, exposed=False) ret._finalize_init(ptr_desc={"type": "cpu", "memoryview": data}) return ret @@ -372,21 +369,8 @@ def __str__(self) -> str: ) -class SpillableBuffer(Buffer): - """A slice of a spillable buffer - - This buffer applies the slicing and then delegates all - operations to its owning buffer. - - Parameters - ---------- - owner : SpillableBufferOwner - The owner of the view - offset : int - Memory offset into the owning buffer - size : int - Size of the view (in bytes) - """ +class SpillableBuffer(ExposureTrackedBuffer): + """A slice of a spillable buffer""" _owner: SpillableBufferOwner @@ -397,10 +381,6 @@ def spill(self, target: str = "cpu") -> None: def is_spilled(self) -> bool: return self._owner.is_spilled - @property - def exposed(self) -> bool: - return self._owner.exposed - @property def spillable(self) -> bool: return self._owner.spillable @@ -412,9 +392,6 @@ def memory_info(self) -> Tuple[int, int, str]: (ptr, _, device_type) = self._owner.memory_info() return (ptr + self._offset, self.nbytes, device_type) - def mark_exposed(self) -> None: - self._owner.mark_exposed() - def serialize(self) -> Tuple[dict, list]: """Serialize the Buffer @@ -449,7 +426,7 @@ def serialize(self) -> Tuple[dict, list]: ptr, size, _ = self.memory_info() frames = [ Buffer( - owner=BufferOwner._from_device_memory( + owner=BufferOwner.from_device_memory( cuda_array_interface_wrapper( ptr=ptr, size=size, @@ -461,6 +438,22 @@ def serialize(self) -> Tuple[dict, list]: ] return header, frames + def copy(self, deep: bool = True) -> Self: + from cudf.core.buffer.utils import acquire_spill_lock + + if not deep: + return super().copy(deep=False) + + if self.is_spilled: + # In this case, we make the new copy point to the same spilled + # data in host memory. We can do this since spilled data is never + # modified. + owner = self._owner.from_host_memory(self.memoryview()) + return self.__class__(owner=owner, offset=0, size=owner.size) + + with acquire_spill_lock(): + return super().copy(deep=deep) + @property def __cuda_array_interface__(self) -> dict: return { diff --git a/python/cudf/cudf/core/buffer/utils.py b/python/cudf/cudf/core/buffer/utils.py index c2ec7effd13..3346d05ed4a 100644 --- a/python/cudf/cudf/core/buffer/utils.py +++ b/python/cudf/cudf/core/buffer/utils.py @@ -133,13 +133,13 @@ def as_buffer( if not hasattr(data, "__cuda_array_interface__"): if exposed: raise ValueError("cannot created exposed host memory") - return buffer_class(owner=owner_class._from_host_memory(data)) + return buffer_class(owner=owner_class.from_host_memory(data)) # Check if `data` is owned by a known class owner = get_buffer_owner(data) if owner is None: # `data` is new device memory return buffer_class( - owner=owner_class._from_device_memory(data, exposed=exposed) + owner=owner_class.from_device_memory(data, exposed=exposed) ) # At this point, we know that `data` is owned by a known class, which diff --git a/python/cudf/cudf/options.py b/python/cudf/cudf/options.py index 7a0db49bd20..efa8eabd8b8 100644 --- a/python/cudf/cudf/options.py +++ b/python/cudf/cudf/options.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. import os import textwrap @@ -152,11 +152,6 @@ def _validator(val): def _cow_validator(val): - if get_option("spill") and val: - raise ValueError( - "Copy-on-write is not supported when spilling is enabled. " - "Please set `spill` to `False`" - ) if val not in {False, True}: raise ValueError( f"{val} is not a valid option. Must be one of {{False, True}}." @@ -164,14 +159,6 @@ def _cow_validator(val): def _spill_validator(val): - try: - if get_option("copy_on_write") and val: - raise ValueError( - "Spilling is not supported when copy-on-write is enabled. " - "Please set `copy_on_write` to `False`" - ) - except KeyError: - pass if val not in {False, True}: raise ValueError( f"{val} is not a valid option. Must be one of {{False, True}}." diff --git a/python/cudf/cudf/tests/test_copying.py b/python/cudf/cudf/tests/test_copying.py index e737a73e86b..0bc9ffa8004 100644 --- a/python/cudf/cudf/tests/test_copying.py +++ b/python/cudf/cudf/tests/test_copying.py @@ -7,8 +7,11 @@ import cudf from cudf import Series +from cudf.core.buffer.spill_manager import get_global_manager from cudf.testing._utils import NUMERIC_TYPES, OTHER_TYPES, assert_eq +pytestmark = pytest.mark.spilling + @pytest.mark.parametrize("dtype", NUMERIC_TYPES + OTHER_TYPES) def test_repeat(dtype): @@ -302,6 +305,8 @@ def test_series_zero_copy_cow_on(): def test_series_zero_copy_cow_off(): + is_spill_enabled = get_global_manager() is not None + with cudf.option_context("copy_on_write", False): s = cudf.Series([1, 2, 3, 4, 5]) s1 = s.copy(deep=False) @@ -334,8 +339,12 @@ def test_series_zero_copy_cow_off(): assert_eq(s, cudf.Series([20, 10, 10, 4, 5])) assert_eq(s1, cudf.Series([20, 10, 10, 4, 5])) assert_eq(cp_array, cp.array([20, 10, 10, 4, 5])) - assert_eq(s2, cudf.Series([20, 10, 10, 4, 5])) - assert_eq(s3, cudf.Series([20, 10, 10, 4, 5])) + if not is_spill_enabled: + # Since spilling might make a copy of the data, we cannot + # expect the two series to be a zero-copy of the cupy array + # when spilling is enabled globally. + assert_eq(s2, cudf.Series([20, 10, 10, 4, 5])) + assert_eq(s3, cudf.Series([20, 10, 10, 4, 5])) s4 = cudf.Series([10, 20, 30, 40, 50]) s5 = cudf.Series(s4) diff --git a/python/cudf/cudf/tests/test_spilling.py b/python/cudf/cudf/tests/test_spilling.py index f18cb32a091..913a958b4c2 100644 --- a/python/cudf/cudf/tests/test_spilling.py +++ b/python/cudf/cudf/tests/test_spilling.py @@ -32,6 +32,7 @@ get_global_manager, get_rmm_memory_resource_stack, set_global_manager, + spill_on_demand_globally, ) from cudf.core.buffer.spillable_buffer import ( SpillableBuffer, @@ -47,6 +48,22 @@ ) +@contextlib.contextmanager +def set_rmm_memory_pool(nbytes: int): + mr = rmm.mr.get_current_device_resource() + rmm.mr.set_current_device_resource( + rmm.mr.PoolMemoryResource( + mr, + initial_pool_size=nbytes, + maximum_pool_size=nbytes, + ) + ) + try: + yield + finally: + rmm.mr.set_current_device_resource(mr) + + def single_column_df(target="gpu") -> cudf.DataFrame: """Create a standard single column dataframe used for testing @@ -120,18 +137,18 @@ def test_spillable_buffer(manager: SpillManager): buf = as_buffer(data=rmm.DeviceBuffer(size=10), exposed=False) assert isinstance(buf, SpillableBuffer) assert buf.spillable - buf.mark_exposed() - assert buf.exposed + buf.owner.mark_exposed() + assert buf.owner.exposed assert not buf.spillable buf = as_buffer(data=rmm.DeviceBuffer(size=10), exposed=False) # Notice, accessing `__cuda_array_interface__` itself doesn't # expose the pointer, only accessing the "data" field exposes # the pointer. iface = buf.__cuda_array_interface__ - assert not buf.exposed + assert not buf.owner.exposed assert buf.spillable iface["data"][0] # Expose pointer - assert buf.exposed + assert buf.owner.exposed assert not buf.spillable @@ -141,7 +158,6 @@ def test_spillable_buffer(manager: SpillManager): "get_ptr", "memoryview", "is_spilled", - "exposed", "spillable", "spill_lock", "spill", @@ -210,7 +226,7 @@ def test_spilling_buffer(manager: SpillManager): buf = as_buffer(rmm.DeviceBuffer(size=10), exposed=False) buf.spill(target="cpu") assert buf.is_spilled - buf.mark_exposed() # Expose pointer and trigger unspill + buf.owner.mark_exposed() # Expose pointer and trigger unspill assert not buf.is_spilled with pytest.raises(ValueError, match="unspillable buffer"): buf.spill(target="cpu") @@ -237,7 +253,7 @@ def _get_manager_in_env(monkeypatch, var_vals): def test_environment_variables_spill_off(monkeypatch): with _get_manager_in_env( monkeypatch, - [("CUDF_SPILL", "off"), ("CUDF_SPILL_ON_DEMAND", "off")], + [("CUDF_SPILL", "off")], ) as manager: assert manager is None @@ -245,10 +261,9 @@ def test_environment_variables_spill_off(monkeypatch): def test_environment_variables_spill_on(monkeypatch): with _get_manager_in_env( monkeypatch, - [("CUDF_SPILL", "on")], + [("CUDF_SPILL", "on"), ("CUDF_SPILL_ON_DEMAND", "off")], ) as manager: assert isinstance(manager, SpillManager) - assert manager._spill_on_demand is True assert manager._device_memory_limit is None assert manager.statistics.level == 0 @@ -256,7 +271,11 @@ def test_environment_variables_spill_on(monkeypatch): def test_environment_variables_device_limit(monkeypatch): with _get_manager_in_env( monkeypatch, - [("CUDF_SPILL", "on"), ("CUDF_SPILL_DEVICE_LIMIT", "1000")], + [ + ("CUDF_SPILL", "on"), + ("CUDF_SPILL_ON_DEMAND", "off"), + ("CUDF_SPILL_DEVICE_LIMIT", "1000"), + ], ) as manager: assert isinstance(manager, SpillManager) assert manager._device_memory_limit == 1000 @@ -269,6 +288,7 @@ def test_environment_variables_spill_stats(monkeypatch, level): monkeypatch, [ ("CUDF_SPILL", "on"), + ("CUDF_SPILL_ON_DEMAND", "off"), ("CUDF_SPILL_DEVICE_LIMIT", "1000"), ("CUDF_SPILL_STATS", f"{level}"), ], @@ -529,12 +549,8 @@ def test_serialize_cuda_dataframe(manager: SpillManager): assert_eq(df1, df2) -@pytest.mark.skip( - reason="This test is not safe because other tests may have enabled" - "spilling and already modified rmm's global state" -) def test_get_rmm_memory_resource_stack(): - mr1 = rmm.mr.get_current_device_resource() + mr1 = rmm.mr.CudaMemoryResource() assert all( not isinstance(m, rmm.mr.FailureCallbackResourceAdaptor) for m in get_rmm_memory_resource_stack(mr1) @@ -560,9 +576,9 @@ def test_df_transpose(manager: SpillManager): df1 = cudf.DataFrame({"a": [1, 2]}) df2 = df1.transpose() # For now, all buffers are marked as exposed - assert df1._data._data["a"].data.exposed - assert df2._data._data[0].data.exposed - assert df2._data._data[1].data.exposed + assert df1._data._data["a"].data.owner.exposed + assert df2._data._data[0].data.owner.exposed + assert df2._data._data[1].data.owner.exposed def test_as_buffer_of_spillable_buffer(manager: SpillManager): @@ -651,7 +667,7 @@ def test_statistics_expose(manager: SpillManager): ] # Expose the first buffer - buffers[0].mark_exposed() + buffers[0].owner.mark_exposed() assert len(manager.statistics.exposes) == 1 stat = list(manager.statistics.exposes.values())[0] assert stat.count == 1 @@ -660,7 +676,7 @@ def test_statistics_expose(manager: SpillManager): # Expose all 10 buffers for i in range(10): - buffers[i].mark_exposed() + buffers[i].owner.mark_exposed() # The rest of the ptr accesses should accumulate to a single stat # because they resolve to the same traceback. @@ -680,9 +696,91 @@ def test_statistics_expose(manager: SpillManager): # Expose the new buffers and check that they are counted as spilled for i in range(10): - buffers[i].mark_exposed() + buffers[i].owner.mark_exposed() assert len(manager.statistics.exposes) == 3 stat = list(manager.statistics.exposes.values())[2] assert stat.count == 10 assert stat.total_nbytes == buffers[0].nbytes * 10 assert stat.spilled_nbytes == buffers[0].nbytes * 10 + + +def test_spill_on_demand(manager: SpillManager): + with set_rmm_memory_pool(1024): + a = as_buffer(data=rmm.DeviceBuffer(size=1024)) + assert isinstance(a, SpillableBuffer) + assert not a.is_spilled + + with pytest.raises(MemoryError, match="Maximum pool size exceeded"): + as_buffer(data=rmm.DeviceBuffer(size=1024)) + + with spill_on_demand_globally(): + b = as_buffer(data=rmm.DeviceBuffer(size=1024)) + assert a.is_spilled + assert not b.is_spilled + + with pytest.raises(MemoryError, match="Maximum pool size exceeded"): + as_buffer(data=rmm.DeviceBuffer(size=1024)) + + +def test_spilling_and_copy_on_write(manager: SpillManager): + with cudf.option_context("copy_on_write", True): + a: SpillableBuffer = as_buffer(data=rmm.DeviceBuffer(size=10)) + + b = a.copy(deep=False) + assert a.owner == b.owner + a.spill(target="cpu") + assert a.is_spilled + assert b.is_spilled + + # Write access trigger copy of `a` into `b` but since `a` is spilled + # the copy is done in host memory and `a` remains spilled. + with acquire_spill_lock(): + b.get_ptr(mode="write") + assert a.is_spilled + assert not b.is_spilled + + # Deep copy of the spilled buffer `a` + b = a.copy(deep=True) + assert a.owner != b.owner + assert a.is_spilled + assert b.is_spilled + a.spill(target="gpu") + assert not a.is_spilled + assert b.is_spilled + + # Deep copy of the unspilled buffer `a` + b = a.copy(deep=True) + assert a.spillable + assert not a.is_spilled + assert not b.is_spilled + + b = a.copy(deep=False) + assert a.owner == b.owner + # Write access trigger copy of `a` into `b` in device memory + with acquire_spill_lock(): + b.get_ptr(mode="write") + assert a.owner != b.owner + assert not a.is_spilled + assert not b.is_spilled + # And `a` and `b` is now seperated with there one spilling status + a.spill(target="cpu") + assert a.is_spilled + assert not b.is_spilled + b.spill(target="cpu") + assert a.is_spilled + assert b.is_spilled + + # Read access with a spill lock unspill `a` and allows copy-on-write + with acquire_spill_lock(): + a.get_ptr(mode="read") + b = a.copy(deep=False) + assert a.owner == b.owner + assert not a.is_spilled + + # Read access without a spill lock exposes `a` and forces a deep copy + a.get_ptr(mode="read") + b = a.copy(deep=False) + assert a.owner != b.owner + assert not a.is_spilled + assert a.owner.exposed + assert not b.owner.exposed