-
Notifications
You must be signed in to change notification settings - Fork 201
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Store and set the correct CUDA device in device_buffer (#1370)
This changes `device_buffer` to store the active CUDA device ID on creation, and (possibly temporarily) set the active device to that ID before allocating or freeing memory. It also adds tests for containers built on `device_buffer` (`device_buffer`, `device_uvector` and `device_scalar`) that ensure correct operation when the device is changed before doing things that alloc/dealloc memory for those containers. This fixes #1342. HOWEVER, there is an important question yet to answer: `rmm::device_vector` is just an alias for `thrust::device_vector`, which does not use `rmm::device_buffer` for storage. However users may be surprised after this PR because the multidevice semantics of RMM containers will be different from `thrust::device_vector` (and therefore `rmm::device_vector`). Update: opinion is that it's probably OK to diverge from `device_vector`, and some think we should remove `rmm::device_vector`. ~While we discuss this I have set the DO NOT MERGE label.~ Authors: - Mark Harris (https://github.com/harrism) Approvers: - Lawrence Mitchell (https://github.com/wence-) - Jake Hemstad (https://github.com/jrhemstad) URL: #1370
- Loading branch information
Showing
6 changed files
with
283 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
/* | ||
* Copyright (c) 2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "device_check_resource_adaptor.hpp" | ||
#include "rmm/mr/device/per_device_resource.hpp" | ||
|
||
#include <rmm/cuda_stream.hpp> | ||
#include <rmm/device_buffer.hpp> | ||
#include <rmm/device_scalar.hpp> | ||
#include <rmm/device_uvector.hpp> | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#include <type_traits> | ||
|
||
template <typename ContainerType> | ||
struct ContainerMultiDeviceTest : public ::testing::Test {}; | ||
|
||
using containers = | ||
::testing::Types<rmm::device_buffer, rmm::device_uvector<int>, rmm::device_scalar<int>>; | ||
|
||
TYPED_TEST_CASE(ContainerMultiDeviceTest, containers); | ||
|
||
// Verifies a container created on device 0 frees its memory on device 0 even
// when a different device is active at destruction time.
TYPED_TEST(ContainerMultiDeviceTest, CreateDestroyDifferentActiveDevice)
{
  // Report SKIPPED (rather than silently passing) on single-device systems.
  if (rmm::get_num_cuda_devices() < 2) { GTEST_SKIP() << "requires at least two CUDA devices"; }

  rmm::cuda_set_device_raii dev{rmm::cuda_device_id{0}};
  auto* orig_mr = rmm::mr::get_current_device_resource();
  // check_mr asserts that alloc/dealloc happen with the expected device active.
  auto check_mr = device_check_resource_adaptor{orig_mr};
  rmm::mr::set_current_device_resource(&check_mr);

  {
    if constexpr (std::is_same_v<TypeParam, rmm::device_scalar<int>>) {
      // device_scalar's constructor takes only a stream (no size).
      auto buf = TypeParam(rmm::cuda_stream_view{});
      RMM_ASSERT_CUDA_SUCCESS(cudaSetDevice(1));  // force dtor with different active device
    } else {
      auto buf = TypeParam(128, rmm::cuda_stream_view{});
      RMM_ASSERT_CUDA_SUCCESS(cudaSetDevice(1));  // force dtor with different active device
    }
  }

  // Restore device 0 before resetting its per-device memory resource.
  RMM_ASSERT_CUDA_SUCCESS(cudaSetDevice(0));
  rmm::mr::set_current_device_resource(orig_mr);
}
|
||
// Verifies that move-assignment transfers ownership correctly and that the
// moved-into container still frees on its original device when a different
// device is active at destruction time.
TYPED_TEST(ContainerMultiDeviceTest, CreateMoveDestroyDifferentActiveDevice)
{
  // Report SKIPPED (rather than silently passing) on single-device systems.
  if (rmm::get_num_cuda_devices() < 2) { GTEST_SKIP() << "requires at least two CUDA devices"; }

  rmm::cuda_set_device_raii dev{rmm::cuda_device_id{0}};
  auto* orig_mr = rmm::mr::get_current_device_resource();
  // check_mr asserts that alloc/dealloc happen with the expected device active.
  auto check_mr = device_check_resource_adaptor{orig_mr};
  rmm::mr::set_current_device_resource(&check_mr);

  {
    auto buf_1 = []() {
      if constexpr (std::is_same_v<TypeParam, rmm::device_scalar<int>>) {
        // device_scalar's constructor takes only a stream (no size).
        return TypeParam(rmm::cuda_stream_view{});
      } else {
        return TypeParam(128, rmm::cuda_stream_view{});
      }
    }();

    {
      // Move-assign over buf_1; the old buf_1 allocation is freed here and
      // buf_0's allocation must survive the move.
      if constexpr (std::is_same_v<TypeParam, rmm::device_scalar<int>>) {
        auto buf_0 = TypeParam(rmm::cuda_stream_view{});
        buf_1 = std::move(buf_0);
      } else {
        auto buf_0 = TypeParam(128, rmm::cuda_stream_view{});
        buf_1 = std::move(buf_0);
      }
    }

    RMM_ASSERT_CUDA_SUCCESS(cudaSetDevice(1));  // force dtor with different active device
  }

  // Restore device 0 before resetting its per-device memory resource.
  RMM_ASSERT_CUDA_SUCCESS(cudaSetDevice(0));
  rmm::mr::set_current_device_resource(orig_mr);
}
|
||
// Verifies that growing a container allocates on its original device even when
// a different device is active at resize time. device_scalar is excluded
// because it has no resize().
TYPED_TEST(ContainerMultiDeviceTest, ResizeDifferentActiveDevice)
{
  // Report SKIPPED (rather than silently passing) on single-device systems.
  if (rmm::get_num_cuda_devices() < 2) { GTEST_SKIP() << "requires at least two CUDA devices"; }

  rmm::cuda_set_device_raii dev{rmm::cuda_device_id{0}};
  auto* orig_mr = rmm::mr::get_current_device_resource();
  // check_mr asserts that alloc/dealloc happen with the expected device active.
  auto check_mr = device_check_resource_adaptor{orig_mr};
  rmm::mr::set_current_device_resource(&check_mr);

  if constexpr (not std::is_same_v<TypeParam, rmm::device_scalar<int>>) {
    auto buf = TypeParam(128, rmm::cuda_stream_view{});
    RMM_ASSERT_CUDA_SUCCESS(cudaSetDevice(1));  // force resize with different active device
    buf.resize(1024, rmm::cuda_stream_view{});
  }

  // Restore device 0 before resetting its per-device memory resource.
  RMM_ASSERT_CUDA_SUCCESS(cudaSetDevice(0));
  rmm::mr::set_current_device_resource(orig_mr);
}
|
||
// Verifies that shrinking (resize down + shrink_to_fit) reallocates/frees on
// the container's original device even when a different device is active.
// device_scalar is excluded because it has no resize()/shrink_to_fit().
TYPED_TEST(ContainerMultiDeviceTest, ShrinkDifferentActiveDevice)
{
  // Report SKIPPED (rather than silently passing) on single-device systems.
  if (rmm::get_num_cuda_devices() < 2) { GTEST_SKIP() << "requires at least two CUDA devices"; }

  rmm::cuda_set_device_raii dev{rmm::cuda_device_id{0}};
  auto* orig_mr = rmm::mr::get_current_device_resource();
  // check_mr asserts that alloc/dealloc happen with the expected device active.
  auto check_mr = device_check_resource_adaptor{orig_mr};
  rmm::mr::set_current_device_resource(&check_mr);

  if constexpr (not std::is_same_v<TypeParam, rmm::device_scalar<int>>) {
    auto buf = TypeParam(128, rmm::cuda_stream_view{});
    RMM_ASSERT_CUDA_SUCCESS(cudaSetDevice(1));  // force resize/shrink with different active device
    buf.resize(64, rmm::cuda_stream_view{});
    buf.shrink_to_fit(rmm::cuda_stream_view{});
  }

  // Restore device 0 before resetting its per-device memory resource.
  RMM_ASSERT_CUDA_SUCCESS(cudaSetDevice(0));
  rmm::mr::set_current_device_resource(orig_mr);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.