Skip to content

Commit

Permalink
address reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
galipremsagar committed Nov 21, 2024
1 parent aaabb82 commit f61e92b
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 14 deletions.
37 changes: 23 additions & 14 deletions python/cudf_polars/cudf_polars/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ def _env_get_int(name, default):


@cache
def default_memory_resource(device: int) -> rmm.mr.DeviceMemoryResource:
def default_memory_resource(
device: int,
cuda_managed_memory: bool, # noqa: FBT001
) -> rmm.mr.DeviceMemoryResource:
"""
Return the default memory resource for cudf-polars.
Expand All @@ -58,31 +61,35 @@ def default_memory_resource(device: int) -> rmm.mr.DeviceMemoryResource:
device
Disambiguating device id when selecting the device. Must be
the active device when this function is called.
cuda_managed_memory
Whether to use managed memory or not.
Returns
-------
rmm.mr.DeviceMemoryResource
The default memory resource that cudf-polars uses. Currently
a managed memory resource, if `POLARS_GPU_ENABLE_CUDA_MANAGED_MEMORY`
environment variable is set to `0`, then an async pool resource is returned.
a managed memory resource if `cuda_managed_memory` is `True`;
otherwise, an async pool resource is returned.
bool
A flag indicating whether to enable prefetching.
(NOTE(review): this `bool` Returns entry looks stale — the tuple
returns were removed in this commit and the function now returns only
the memory resource, enabling prefetching internally; confirm and drop.)
"""
try:
if (
_env_get_int("POLARS_GPU_ENABLE_CUDA_MANAGED_MEMORY", default=1) == 1
cuda_managed_memory
and pylibcudf.utils._is_concurrent_managed_access_supported()
):
free_memory, _ = rmm.mr.available_device_memory()
free_memory = int(round(float(free_memory) * 0.80 / 256) * 256)
return rmm.mr.PrefetchResourceAdaptor(
for key in _SUPPORTED_PREFETCHES:
pylibcudf.experimental.enable_prefetching(key)
mr = rmm.mr.PrefetchResourceAdaptor(
rmm.mr.PoolMemoryResource(
rmm.mr.ManagedMemoryResource(),
initial_pool_size=free_memory,
)
), True
)
else:
return rmm.mr.CudaAsyncMemoryResource(), False
mr = rmm.mr.CudaAsyncMemoryResource()
except RuntimeError as e: # pragma: no cover
msg, *_ = e.args
if (
Expand All @@ -96,6 +103,8 @@ def default_memory_resource(device: int) -> rmm.mr.DeviceMemoryResource:
) from None
else:
raise
else:
return mr


@contextlib.contextmanager
Expand Down Expand Up @@ -124,13 +133,13 @@ def set_memory_resource(
previous = rmm.mr.get_current_device_resource()
if mr is None:
device: int = gpu.getDevice()
mr, enable_prefetching = default_memory_resource(device)
rmm.mr.set_current_device_resource(mr)
if enable_prefetching:
for key in _SUPPORTED_PREFETCHES:
pylibcudf.experimental.enable_prefetching(key)
else:
rmm.mr.set_current_device_resource(mr)
mr = default_memory_resource(
device=device,
cuda_managed_memory=bool(
_env_get_int("POLARS_GPU_ENABLE_CUDA_MANAGED_MEMORY", default=1) != 0
),
)
rmm.mr.set_current_device_resource(mr)
try:
yield mr
finally:
Expand Down
8 changes: 8 additions & 0 deletions python/cudf_polars/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import rmm

from cudf_polars.callback import default_memory_resource
from cudf_polars.dsl.ir import DataFrameScan
from cudf_polars.testing.asserts import (
assert_gpu_result_equal,
Expand Down Expand Up @@ -67,6 +68,13 @@ def test_cudf_polars_enable_disable_managed_memory(monkeypatch, disable_managed_
"POLARS_GPU_ENABLE_CUDA_MANAGED_MEMORY", disable_managed_memory
)
result = q.collect(engine=pl.GPUEngine())
mr = default_memory_resource(0, bool(disable_managed_memory == "1"))
if disable_managed_memory == "1":
assert isinstance(mr, rmm.mr.PrefetchResourceAdaptor)
assert isinstance(mr.upstream_mr, rmm.mr.PoolMemoryResource)
else:
assert isinstance(mr, rmm.mr.CudaAsyncMemoryResource)
monkeycontext.delenv("POLARS_GPU_ENABLE_CUDA_MANAGED_MEMORY")
assert_frame_equal(q.collect(), result)


Expand Down

0 comments on commit f61e92b

Please sign in to comment.