Replace GEMM backend: cublas.gemm -> cublaslt.matmul #1736

Merged on Jan 23, 2024 (52 commits)
2cc477b
Replace GEMM backend: cublas.gemm -> cublaslt.matmul
achirkin Aug 14, 2023
dc7a9a4
Replace broken (due to missing direct includes) direct uses of cublas…
achirkin Aug 14, 2023
34a9479
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 15, 2023
71c03c0
Fix docs
achirkin Aug 15, 2023
a2fb088
Replace cublasgemm where it makes sense
achirkin Aug 16, 2023
699de0c
Fix a typo
achirkin Aug 16, 2023
f994f19
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 16, 2023
f4d634a
Put the cache into the resource handle as a user-define resource
achirkin Aug 21, 2023
2d1bf5c
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 22, 2023
e57eebf
Move matmul into a separate file
achirkin Aug 22, 2023
d44bf20
Complete the docs
achirkin Aug 22, 2023
facf81d
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 23, 2023
157d8ae
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 24, 2023
be68b61
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 24, 2023
f5ac41a
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 28, 2023
2d4dcb2
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 29, 2023
6f58669
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 30, 2023
a0e93fd
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 30, 2023
4c0d742
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 30, 2023
01c3634
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 30, 2023
abb3f00
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 30, 2023
e24b1c0
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 31, 2023
de29580
move matmul.hpp to cublaslt_wrappers.hpp
achirkin Aug 31, 2023
3835ed0
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 31, 2023
de60202
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Sep 1, 2023
fe84fae
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Sep 5, 2023
f47626a
Merge branch 'branch-23.10' into fea-cublaslt-matmul
cjnolet Sep 6, 2023
d7efc0c
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Sep 7, 2023
dd7ee22
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Sep 8, 2023
01e62b0
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Sep 9, 2023
8fdf6cc
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Sep 13, 2023
324f5c6
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Sep 19, 2023
ba6883f
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Sep 19, 2023
a56ea2c
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Nov 20, 2023
cd4663a
Merge branch 'branch-23.12' into fea-cublaslt-matmul
achirkin Nov 20, 2023
c2f1daa
Merge branch 'branch-24.02' into fea-cublaslt-matmul
achirkin Dec 14, 2023
c976de0
Merge branch 'branch-24.02' into fea-cublaslt-matmul
achirkin Dec 15, 2023
a5de437
Merge branch 'branch-24.02' into fea-cublaslt-matmul
achirkin Jan 8, 2024
7849786
Update copyright year for changed files
achirkin Jan 8, 2024
9bec3cf
Merge branch 'branch-24.02' into fea-cublaslt-matmul
achirkin Jan 12, 2024
ceb8d10
Merge branch 'branch-24.02' into fea-cublaslt-matmul
achirkin Jan 15, 2024
1f39534
Deprecate linalg/gemm.cuh
achirkin Jan 15, 2024
b2e3b8b
Update copyright years
achirkin Jan 15, 2024
05c64fc
Rename user_resource -> custom_resource
achirkin Jan 15, 2024
fdbe003
Merge branch 'branch-24.02' into fea-cublaslt-matmul
achirkin Jan 17, 2024
9e08c0f
Merge branch 'branch-24.02' into fea-cublaslt-matmul
achirkin Jan 17, 2024
97f1d49
Merge branch 'branch-24.02' into fea-cublaslt-matmul
achirkin Jan 18, 2024
6164e4f
Merge branch 'branch-24.02' into fea-cublaslt-matmul
achirkin Jan 19, 2024
f6ded84
Merge branch 'branch-24.02' into fea-cublaslt-matmul
achirkin Jan 19, 2024
88ecbb0
Merge branch 'branch-24.02' into fea-cublaslt-matmul
achirkin Jan 22, 2024
ca11f9f
Use plain the vector instead of the unordered_map for the cache and c…
achirkin Jan 23, 2024
47303b7
Merge branch 'branch-24.02' into fea-cublaslt-matmul
achirkin Jan 23, 2024
6 changes: 3 additions & 3 deletions cpp/include/raft/core/resource/cublas_handle.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -60,8 +60,8 @@ class cublas_resource_factory : public resource_factory {
*/

/**
- * Load a cublasres_t from raft res if it exists, otherwise
- * add it and return it.
+ * Load a `cublasHandle_t` from raft res if it exists, otherwise add it and return it.
+ *
* @param[in] res the raft resources object
* @return cublas handle
*/
68 changes: 68 additions & 0 deletions cpp/include/raft/core/resource/cublaslt_handle.hpp
@@ -0,0 +1,68 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cublasLt.h>
#include <raft/core/cublas_macros.hpp>
#include <raft/core/resource/resource_types.hpp>
#include <raft/core/resources.hpp>

#include <memory>

namespace raft::resource {

class cublaslt_resource : public resource {
public:
cublaslt_resource() { RAFT_CUBLAS_TRY(cublasLtCreate(&handle_)); }
~cublaslt_resource() noexcept override { RAFT_CUBLAS_TRY_NO_THROW(cublasLtDestroy(handle_)); }
auto get_resource() -> void* override { return &handle_; }

private:
cublasLtHandle_t handle_;
};

/** Factory that knows how to construct a specific raft::resource to populate the res_t. */
class cublaslt_resource_factory : public resource_factory {
public:
auto get_resource_type() -> resource_type override { return resource_type::CUBLASLT_HANDLE; }
auto make_resource() -> resource* override { return new cublaslt_resource(); }
};

/**
* @defgroup resource_cublaslt cuBLASLt handle resource functions
* @{
*/

/**
* Load a `cublasLtHandle_t` from raft res if it exists, otherwise add it and return it.
*
* @param[in] res the raft resources object
* @return cublasLt handle
*/
inline auto get_cublaslt_handle(resources const& res) -> cublasLtHandle_t
{
if (!res.has_resource_factory(resource_type::CUBLASLT_HANDLE)) {
res.add_resource_factory(std::make_shared<cublaslt_resource_factory>());
}
auto ret = *res.get_resource<cublasLtHandle_t>(resource_type::CUBLASLT_HANDLE);
return ret;
}

/**
* @}
*/

} // namespace raft::resource
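
The accessor above follows a lazy-initialization pattern: it registers a factory on first use, and the resources object then constructs and caches the handle. The following is a minimal, self-contained sketch of that pattern, not raft's implementation. The `resources`, `resource`, and `resource_factory` classes here are simplified stand-ins, and `fake_handle_t`, `FAKE_HANDLE`, `created_count`, and `get_fake_handle` are hypothetical names that replace the cuBLASLt pieces so the example runs without CUDA.

```cpp
#include <array>
#include <cassert>
#include <memory>
#include <mutex>
#include <utility>

// Simplified stand-ins (assumptions); the real raft classes are more involved.
enum resource_type { FAKE_HANDLE = 0, LAST_KEY };

struct resource {
  virtual ~resource() noexcept = default;
  virtual auto get_resource() -> void* = 0;
};

struct resource_factory {
  virtual ~resource_factory() = default;
  virtual auto get_resource_type() -> resource_type = 0;
  virtual auto make_resource() -> resource* = 0;
};

class resources {
 public:
  auto has_resource_factory(resource_type t) const -> bool
  {
    std::lock_guard<std::mutex> guard(mutex_);
    return factories_[t] != nullptr;
  }
  void add_resource_factory(std::shared_ptr<resource_factory> f) const
  {
    std::lock_guard<std::mutex> guard(mutex_);
    auto t        = f->get_resource_type();
    factories_[t] = std::move(f);
  }
  // Construct the resource on first access, then keep returning the cached instance.
  template <typename T>
  auto get_resource(resource_type t) const -> T*
  {
    std::lock_guard<std::mutex> guard(mutex_);
    if (!instances_[t]) { instances_[t].reset(factories_[t]->make_resource()); }
    return static_cast<T*>(instances_[t]->get_resource());
  }

 private:
  mutable std::mutex mutex_;
  mutable std::array<std::shared_ptr<resource_factory>, LAST_KEY> factories_{};
  mutable std::array<std::unique_ptr<resource>, LAST_KEY> instances_{};
};

// Hypothetical opaque handle standing in for cublasLtHandle_t.
using fake_handle_t = int*;

// Counts how many handles were ever created (for demonstration only).
inline auto created_count() -> int&
{
  static int n = 0;
  return n;
}

class fake_handle_resource : public resource {
 public:
  fake_handle_resource() : handle_(new int{42}) { ++created_count(); }
  ~fake_handle_resource() noexcept override { delete handle_; }
  auto get_resource() -> void* override { return &handle_; }

 private:
  fake_handle_t handle_;
};

class fake_handle_factory : public resource_factory {
 public:
  auto get_resource_type() -> resource_type override { return FAKE_HANDLE; }
  auto make_resource() -> resource* override { return new fake_handle_resource(); }
};

// Same shape as get_cublaslt_handle: register the factory if missing, then fetch.
inline auto get_fake_handle(resources const& res) -> fake_handle_t
{
  if (!res.has_resource_factory(FAKE_HANDLE)) {
    res.add_resource_factory(std::make_shared<fake_handle_factory>());
  }
  return *res.get_resource<fake_handle_t>(FAKE_HANDLE);
}
```

In this sketch, calling `get_fake_handle(res)` twice on the same `resources` object returns the same handle and constructs it only once, which is the behavior the doc comment on `get_cublaslt_handle` describes.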
93 changes: 93 additions & 0 deletions cpp/include/raft/core/resource/custom_resource.hpp
@@ -0,0 +1,93 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <raft/core/resource/resource_types.hpp>
#include <raft/core/resources.hpp>

#include <algorithm>
#include <memory>
#include <typeindex>

namespace raft::resource {

class custom_resource : public resource {
public:
custom_resource() = default;
~custom_resource() noexcept override = default;
auto get_resource() -> void* override { return this; }

template <typename ResourceT>
auto load() -> ResourceT*
{
std::lock_guard<std::mutex> _(lock_);
auto key = std::type_index{typeid(ResourceT)};
auto pos = std::lower_bound(store_.begin(), store_.end(), kv{key, {nullptr}});
if ((pos != store_.end()) && std::get<0>(*pos) == key) {
return reinterpret_cast<ResourceT*>(std::get<1>(*pos).get());
}
auto store_ptr = new ResourceT{};
store_.insert(pos, kv{key, std::shared_ptr<void>(store_ptr, [](void* ptr) {
delete reinterpret_cast<ResourceT*>(ptr);
})});
return store_ptr;
}

private:
using kv = std::tuple<std::type_index, std::shared_ptr<void>>;
std::mutex lock_{};
std::vector<kv> store_{};
};

/** Factory that knows how to construct a specific raft::resource to populate the res_t. */
class custom_resource_factory : public resource_factory {
public:
auto get_resource_type() -> resource_type override { return resource_type::CUSTOM; }
auto make_resource() -> resource* override { return new custom_resource(); }
};

/**
* @defgroup resource_custom custom resource functions
* @{
*/

/**
* Get the custom default-constructible resource if it exists, create it otherwise.
*
* Note: in contrast to the other, hard-coded resources, there's no information about the custom
* resources at compile time. Hence, custom resources are kept in a sorted vector keyed by
* `std::type_index` and looked up at runtime, which leads to slightly slower access times.
*
* @tparam ResourceT the type of the resource; it must be complete and default-constructible.
*
* @param[in] res the raft resources object
* @return a pointer to the custom resource.
*/
template <typename ResourceT>
auto get_custom_resource(resources const& res) -> ResourceT*
{
static_assert(std::is_default_constructible_v<ResourceT>);
if (!res.has_resource_factory(resource_type::CUSTOM)) {
res.add_resource_factory(std::make_shared<custom_resource_factory>());
}
return res.get_resource<custom_resource>(resource_type::CUSTOM)->load<ResourceT>();
}

/**
* @}
*/

} // namespace raft::resource
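
The `load()` method above keeps one default-constructed instance per type in a vector sorted by `std::type_index` and finds it with a binary search under a mutex. The following is a minimal, self-contained sketch of that technique outside of raft. The class name `type_store` and the example types `call_counter` and `matmul_cache_sketch` are hypothetical. Unlike the diff, the sketch compares on the key alone and uses `std::make_shared` instead of a custom deleter.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <mutex>
#include <string>
#include <tuple>
#include <typeindex>
#include <typeinfo>
#include <utility>
#include <vector>

class type_store {
 private:
  using kv = std::tuple<std::type_index, std::shared_ptr<void>>;

 public:
  // Return the single instance of T held by this store, creating it on first use.
  template <typename T>
  auto load() -> T*
  {
    std::lock_guard<std::mutex> guard(lock_);
    auto key = std::type_index{typeid(T)};
    // Binary search over entries sorted by type_index.
    auto pos = std::lower_bound(
      store_.begin(), store_.end(), key, [](kv const& entry, std::type_index const& k) {
        return std::get<0>(entry) < k;
      });
    if (pos != store_.end() && std::get<0>(*pos) == key) {
      return static_cast<T*>(std::get<1>(*pos).get());
    }
    // shared_ptr<void> remembers how to destroy a T, so type erasure is safe here.
    auto ptr = std::make_shared<T>();
    T* raw   = ptr.get();
    store_.insert(pos, kv{key, std::move(ptr)});  // inserting at pos keeps the order
    return raw;
  }

  auto size() const -> std::size_t
  {
    std::lock_guard<std::mutex> guard(lock_);
    return store_.size();
  }

 private:
  mutable std::mutex lock_;
  std::vector<kv> store_;
};

// Hypothetical resource types used only for illustration.
struct call_counter {
  int value = 0;
};
struct matmul_cache_sketch {
  std::vector<std::string> entries;
};
```

Repeated `load<call_counter>()` calls on one `type_store` return the same pointer, and the pointee stays valid when other types are inserted, because the vector holds only the owning smart pointers and not the objects themselves. For a handful of resource types, a sorted vector is likely to be at least as fast as a hash map, which may be the motivation for the final commit in this PR.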
4 changes: 3 additions & 1 deletion cpp/include/raft/core/resource/resource_types.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -43,6 +43,8 @@ enum resource_type {
// CUDA-free builds
THRUST_POLICY, // thrust execution policy
WORKSPACE_RESOURCE, // rmm device memory resource
+ CUBLASLT_HANDLE, // cublasLt handle
+ CUSTOM, // runtime-shared default-constructible resource

LAST_KEY // reserved for the last key
};
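
The two enumerators are appended just before `LAST_KEY`, so existing values keep their indices and the sentinel grows by two. The sketch below illustrates why that placement matters when the sentinel sizes a per-type table. It is an assumption that the resources object uses the enum this way. The `before` and `after` namespaces and `make_slot_table` are hypothetical, and the enums list only the tail entries visible in the hunk, so the numeric values differ from the real header.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

namespace before {
enum resource_type { THRUST_POLICY, WORKSPACE_RESOURCE, LAST_KEY };
}  // namespace before

namespace after {
enum resource_type { THRUST_POLICY, WORKSPACE_RESOURCE, CUBLASLT_HANDLE, CUSTOM, LAST_KEY };
}  // namespace after

// Existing enumerators keep their values because the new ones come after them.
static_assert(static_cast<int>(before::THRUST_POLICY) == static_cast<int>(after::THRUST_POLICY),
              "appending must not renumber existing keys");
static_assert(
  static_cast<int>(before::WORKSPACE_RESOURCE) == static_cast<int>(after::WORKSPACE_RESOURCE),
  "appending must not renumber existing keys");

// The sentinel moves up by exactly the number of new keys.
static_assert(static_cast<int>(after::LAST_KEY) - static_cast<int>(before::LAST_KEY) == 2,
              "two keys were added");

// A table sized by the sentinel gains slots for the new keys automatically.
inline auto make_slot_table() -> std::array<bool, after::LAST_KEY>
{
  return std::array<bool, after::LAST_KEY>{};  // all slots start empty (false)
}
```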