Skip to content
This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Add a new memory type: Hierarchy #227

Merged
merged 6 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions cpp/include/wholememory/wholememory.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ enum wholememory_memory_type_t {
WHOLEMEMORY_MT_CONTINUOUS, /*!< Memory from all ranks are mapped in continuous address space */
WHOLEMEMORY_MT_CHUNKED, /*!< Memory from all ranks are mapped in chunked address space */
WHOLEMEMORY_MT_DISTRIBUTED, /*!< Memory from other ranks are not mapped. */
WHOLEMEMORY_MT_HIERARCHY, /*!< Memory from other ranks are mapped in hierarchy address space */
};

/**
Expand Down Expand Up @@ -206,6 +207,23 @@ wholememory_error_code_t wholememory_communicator_get_rank(int* rank, wholememor
*/
wholememory_error_code_t wholememory_communicator_get_size(int* size, wholememory_comm_t comm);

/**
* Get the local rank size of current process in the WholeMemory Communicator
* @param local_size : returned local rank size
* @param comm : WholeMemory Communicator
* @return : wholememory_error_code_t
*/

wholememory_error_code_t wholememory_communicator_get_local_size(int* local_size,
wholememory_comm_t comm);

/**
* Get the clique info of WholeMemory Communicator
* @param clique_info : returned clique info
* @param comm : WholeMemory Communicator
* @return : wholememory_error_code_t
*/

wholememory_error_code_t wholememory_communicator_get_clique_info(clique_info_t* clique_info,
wholememory_comm_t comm);

Expand Down Expand Up @@ -265,6 +283,25 @@ wholememory_error_code_t wholememory_free(wholememory_handle_t wholememory_handl
wholememory_error_code_t wholememory_get_communicator(wholememory_comm_t* comm,
wholememory_handle_t wholememory_handle);

/**
* Get underlying Wholememory Local Communicator for "Hierarchy" memory type from WholeMemory Handle
* @param comm : returned Local WholeMemory Communicator
* @param wholememory_handle : WholeMemory Handle
* @return : wholememory_error_code_t
*/
wholememory_error_code_t wholememory_get_local_communicator(
wholememory_comm_t* comm, wholememory_handle_t wholememory_handle);

/**
* Get underlying Wholememory Cross Communicator for "Hierarchy" memory type from WholeMemory Handle
* One comminicator includes all rank with a same local id from different nodes
* @param comm : returned Cross WholeMemory Communicator
* @param wholememory_handle : WholeMemory Handle
* @return : wholememory_error_code_t
*/
wholememory_error_code_t wholememory_get_cross_communicator(
wholememory_comm_t* comm, wholememory_handle_t wholememory_handle);

/**
* Get WholeMemory Type
* @param wholememory_handle : WholeMemory Handle
Expand Down
7 changes: 7 additions & 0 deletions cpp/src/wholememory/communicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,13 @@ wholememory_error_code_t communicator_get_size(int* size, wholememory_comm_t com
return WHOLEMEMORY_SUCCESS;
}

wholememory_error_code_t communicator_get_local_size(int* local_size,
wholememory_comm_t comm) noexcept
{
*local_size = comm->intra_node_rank_num;
return WHOLEMEMORY_SUCCESS;
}

// wholememory_error_code_t communicator_get_clique_rank(int* clique_rank,
// wholememory_comm_t comm) noexcept
// {
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/wholememory/communicator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,9 @@ wholememory_error_code_t communicator_get_rank(int* rank, wholememory_comm_t com

wholememory_error_code_t communicator_get_size(int* size, wholememory_comm_t comm) noexcept;

wholememory_error_code_t communicator_get_local_size(int* local_size,
wholememory_comm_t comm) noexcept;

wholememory_error_code_t communicator_get_clique_info(clique_info_t* clique_info,
wholememory_comm_t comm) noexcept;

Expand Down
3 changes: 3 additions & 0 deletions cpp/src/wholememory/embedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -964,6 +964,9 @@ wholememory_error_code_t wholememory_create_embedding(
int embedding_world_size = 1;
WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&embedding_world_size, comm));
if (cache_policy != nullptr) {
if (memory_type == WHOLEMEMORY_MT_HIERARCHY) {
WHOLEMEMORY_ERROR("Cache is not supported now in hierarchy memory type.");
}
if (cache_policy->cache_comm == comm) {
if (cache_policy->cache_memory_location != WHOLEMEMORY_ML_DEVICE) {
WHOLEMEMORY_ERROR(
Expand Down
94 changes: 91 additions & 3 deletions cpp/src/wholememory/memory_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class wholememory_impl {
return gref;
}
virtual bool contains_pointer(const void* ptr) const = 0;
void get_local_memory(void** local_ptr, size_t* local_size, size_t* local_offset) const
virtual void get_local_memory(void** local_ptr, size_t* local_size, size_t* local_offset) const
{
if (local_ptr != nullptr) *local_ptr = local_partition_memory_pointer_;
if (local_size != nullptr) *local_size = get_local_size();
Expand All @@ -128,7 +128,7 @@ class wholememory_impl {
*rank_memory_offset = 0;
return false;
}
[[nodiscard]] size_t get_partition_stride() const
[[nodiscard]] virtual size_t get_partition_stride() const
{
return rank_partition_strategy_.partition_mem_stride;
}
Expand Down Expand Up @@ -326,7 +326,7 @@ class distributed_wholememory_impl : public wholememory_impl {
data_granularity,
rank_entry_partition)
{
WHOLEMEMORY_CHECK(type_ == WHOLEMEMORY_MT_DISTRIBUTED);
WHOLEMEMORY_CHECK(type_ == WHOLEMEMORY_MT_DISTRIBUTED || type_ == WHOLEMEMORY_MT_HIERARCHY);
}
void create_memory() override
{
Expand Down Expand Up @@ -647,6 +647,12 @@ class continuous_device_wholememory_impl : public wholememory_impl {
data_granularity,
rank_entry_partition)
{
// printf(
// "while in continuous device wholememory creation, the memory_type (%d) and memory_location
// "
// "(%d).\n",
// (int)memory_type,
// (int)memory_location);
WHOLEMEMORY_CHECK(type_ == WHOLEMEMORY_MT_CONTINUOUS);
}
void create_memory() override
Expand Down Expand Up @@ -1747,6 +1753,43 @@ struct wholememory_create_param {
size_t min_granularity;
};

class hierarchy_wholememory_impl : public distributed_wholememory_impl {
public:
hierarchy_wholememory_impl(wholememory_handle_t wholememory_handle,
size_t total_size,
wholememory_comm_t global_comm,
wholememory_comm_t local_comm,
wholememory_memory_type_t memory_type,
wholememory_memory_location_t memory_location,
size_t data_granularity,
size_t* rank_entry_partition)
: distributed_wholememory_impl(wholememory_handle,
total_size,
global_comm,
memory_type,
memory_location,
data_granularity,
rank_entry_partition)
{
WHOLEMEMORY_CHECK(memory_type == WHOLEMEMORY_MT_HIERARCHY);
local_comm_ = local_comm;
int world_rank = -1, world_size = -1, local_size = -1;
wholememory_communicator_get_rank(&world_rank, global_comm);
wholememory_communicator_get_size(&world_size, global_comm);
wholememory_communicator_get_size(&local_size, local_comm);
WHOLEMEMORY_CHECK(world_size % local_size == 0);
wholememory_split_communicator(
&cross_comm_, global_comm, world_rank % local_size, world_rank / local_size);
}

[[nodiscard]] wholememory_comm_t get_local_comm() const { return local_comm_; }
[[nodiscard]] wholememory_comm_t get_cross_comm() const { return cross_comm_; }

protected:
wholememory_comm_t local_comm_;
wholememory_comm_t cross_comm_;
};

wholememory_error_code_t create_wholememory(wholememory_handle_t* wholememory_handle_ptr,
size_t total_size,
wholememory_comm_t comm,
Expand Down Expand Up @@ -1853,6 +1896,21 @@ wholememory_error_code_t create_wholememory(wholememory_handle_t* wholememory_ha
data_granularity,
rank_entry_partition);
}
} else if (memory_type == WHOLEMEMORY_MT_HIERARCHY) {
wholememory_comm_t local_comm;
int world_rank = -1, local_size = -1;
wholememory_communicator_get_rank(&world_rank, comm);
wholememory_communicator_get_local_size(&local_size, comm);
wholememory_split_communicator(
&local_comm, comm, world_rank / local_size, world_rank % local_size);
whole_memory_handle->impl = new hierarchy_wholememory_impl(whole_memory_handle,
total_size,
comm,
local_comm,
memory_type,
memory_location,
data_granularity,
rank_entry_partition);
} else {
WHOLEMEMORY_FATAL("Unsupported memory_type (%d) and memory_location (%d).",
(int)memory_type,
Expand Down Expand Up @@ -1928,6 +1986,36 @@ wholememory_error_code_t get_communicator_from_handle(
return WHOLEMEMORY_SUCCESS;
}

wholememory_error_code_t get_local_communicator_from_handle(
wholememory_comm_t* comm, wholememory_handle_t wholememory_handle) noexcept
{
if (wholememory_handle == nullptr || wholememory_handle->impl == nullptr) {
return WHOLEMEMORY_INVALID_INPUT;
}
if (get_memory_type(wholememory_handle) != WHOLEMEMORY_MT_HIERARCHY) {
return WHOLEMEMORY_NOT_SUPPORTED;
}
hierarchy_wholememory_impl* hierarchy_impl =
dynamic_cast<hierarchy_wholememory_impl*>(wholememory_handle->impl);
*comm = hierarchy_impl->get_local_comm();
return WHOLEMEMORY_SUCCESS;
}

wholememory_error_code_t get_cross_communicator_from_handle(
wholememory_comm_t* comm, wholememory_handle_t wholememory_handle) noexcept
{
if (wholememory_handle == nullptr || wholememory_handle->impl == nullptr) {
return WHOLEMEMORY_INVALID_INPUT;
}
if (get_memory_type(wholememory_handle) != WHOLEMEMORY_MT_HIERARCHY) {
return WHOLEMEMORY_NOT_SUPPORTED;
}
hierarchy_wholememory_impl* hierarchy_impl =
dynamic_cast<hierarchy_wholememory_impl*>(wholememory_handle->impl);
*comm = hierarchy_impl->get_cross_comm();
return WHOLEMEMORY_SUCCESS;
}

wholememory_memory_type_t get_memory_type(wholememory_handle_t wholememory_handle) noexcept
{
return wholememory_handle->impl->get_type();
Expand Down
12 changes: 12 additions & 0 deletions cpp/src/wholememory/memory_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ wholememory_error_code_t destroy_wholememory(wholememory_handle_t wholememory_ha
wholememory_error_code_t get_communicator_from_handle(
wholememory_comm_t* comm, wholememory_handle_t wholememory_handle) noexcept;

wholememory_error_code_t get_local_communicator_from_handle(
wholememory_comm_t* comm, wholememory_handle_t wholememory_handle) noexcept;

wholememory_error_code_t get_cross_communicator_from_handle(
wholememory_comm_t* comm, wholememory_handle_t wholememory_handle) noexcept;

wholememory_memory_type_t get_memory_type(wholememory_handle_t wholememory_handle) noexcept;

wholememory_memory_location_t get_memory_location(wholememory_handle_t wholememory_handle) noexcept;
Expand All @@ -65,6 +71,12 @@ wholememory_error_code_t get_local_memory_from_handle(
size_t* local_offset,
wholememory_handle_t wholememory_handle) noexcept;

wholememory_error_code_t get_local_node_memory_from_handle(
void** local_ptr,
size_t* local_size,
size_t* local_offset,
wholememory_handle_t wholememory_handle) noexcept;

wholememory_error_code_t get_rank_memory_from_handle(
void** rank_memory_ptr,
size_t* rank_memory_size,
Expand Down
19 changes: 19 additions & 0 deletions cpp/src/wholememory/wholememory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ wholememory_error_code_t wholememory_communicator_get_size(int* size, wholememor
{
return wholememory::communicator_get_size(size, comm);
}

wholememory_error_code_t wholememory_communicator_get_local_size(int* local_size,
wholememory_comm_t comm)
{
return wholememory::communicator_get_local_size(local_size, comm);
}

bool wholememory_communicator_is_bind_to_nvshmem(wholememory_comm_t comm)
{
#ifdef WITH_NVSHMEM_SUPPORT
Expand Down Expand Up @@ -130,6 +137,18 @@ wholememory_error_code_t wholememory_get_communicator(wholememory_comm_t* comm,
return wholememory::get_communicator_from_handle(comm, wholememory_handle);
}

wholememory_error_code_t wholememory_get_local_communicator(wholememory_comm_t* comm,
wholememory_handle_t wholememory_handle)
{
return wholememory::get_local_communicator_from_handle(comm, wholememory_handle);
}

wholememory_error_code_t wholememory_get_cross_communicator(wholememory_comm_t* comm,
wholememory_handle_t wholememory_handle)
{
return wholememory::get_cross_communicator_from_handle(comm, wholememory_handle);
}

wholememory_memory_type_t wholememory_get_memory_type(wholememory_handle_t wholememory_handle)
{
return wholememory::get_memory_type(wholememory_handle);
Expand Down
Loading
Loading