Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update wholegraph #65

Merged
merged 2 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cpp/bench/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ function(ConfigureBench)
rmm::rmm
pthread
)

if(BUILD_WITH_NVSHMEM)
target_compile_definitions(${BENCH_NAME} PRIVATE WITH_NVSHMEM_SUPPORT)
endif()
set_target_properties(
${BENCH_NAME}
PROPERTIES # set target compile options
Expand Down
17 changes: 17 additions & 0 deletions cpp/bench/common/wholegraph_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,23 @@ void host_random_init_integer_indices(void* indices,
}
}

void host_random_partition(size_t* partition_sizes, size_t total_size, int partition_count)
{
std::default_random_engine random_engine(0);
std::uniform_int_distribution<size_t> uniform(90, 100);
size_t acc_size = 0;
size_t random_sum = 0;
for (int i = 0; i < partition_count; i++) {
partition_sizes[i] = (size_t)uniform(random_engine);
random_sum += partition_sizes[i];
}
for (int i = 0; i < partition_count; i++) {
partition_sizes[i] = (size_t)((partition_sizes[i] / (double)random_sum) * total_size);
acc_size += partition_sizes[i];
}
partition_sizes[0] += total_size - acc_size;
}

void MultiProcessMeasurePerformance(std::function<void()> run_fn,
wholememory_comm_t& wm_comm,
const PerformanceMeter& meter,
Expand Down
2 changes: 2 additions & 0 deletions cpp/bench/common/wholegraph_benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ void host_random_init_integer_indices(void* indices,
wholememory_array_description_t indices_desc,
int64_t max_indices);

void host_random_partition(size_t* partition_sizes, size_t total_size, int partition_count);

struct Metric {
Metric(const std::string& metrics_name,
const std::string& metrics_unit,
Expand Down
73 changes: 63 additions & 10 deletions cpp/bench/wholememory_ops/gather_scatter_bench.cu
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ typedef struct GatherScatterBenchParam {

int64_t get_embedding_dim() const { return embedding_dim; }
wholememory_dtype_t get_embedding_type() const { return embedding_type; }
int get_partition_method() const { return partition_method; }
std::string get_distributed_backend() const { return distributed_backend; }

GatherScatterBenchParam& set_memory_type(wholememory_memory_type_t new_memory_type)
{
Expand Down Expand Up @@ -153,6 +155,18 @@ typedef struct GatherScatterBenchParam {
return *this;
}

GatherScatterBenchParam& set_partition_method(int new_partition_method)
{
partition_method = new_partition_method;
return *this;
}

GatherScatterBenchParam& set_distributed_backend(std::string new_distributed_backend)
{
distributed_backend = new_distributed_backend;
return *this;
}

private:
int64_t get_embedding_entry_count() const
{
Expand Down Expand Up @@ -196,6 +210,8 @@ typedef struct GatherScatterBenchParam {
int64_t embedding_dim = 32;
int loop_count = 20;
std::string test_type = "gather"; // gather or scatter
int partition_method = 0;
std::string distributed_backend = "nccl"; // nccl or nvshmem

std::string server_addr = "localhost";
int server_port = 24987;
Expand Down Expand Up @@ -256,7 +272,15 @@ void gather_scatter_benchmark(GatherScatterBenchParam& params)

wholememory_comm_t wm_comm =
create_communicator_by_socket(side_band_communicator, world_rank, world_size);

std::string distributed_backend = params.get_distributed_backend();
#ifdef WITH_NVSHMEM_SUPPORT
if (distributed_backend.compare("nvshmem") == 0)
WHOLEMEMORY_CHECK_NOTHROW(wholememory_communicator_set_distributed_backend(
wm_comm, WHOLEMEMORY_DB_NVSHMEM) == WHOLEMEMORY_SUCCESS);
#else
distributed_backend = "nccl";
params.set_distributed_backend("nccl");
#endif
ShutDownSidebandCommunicator(side_band_communicator);

auto embedding_desc = params.get_embedding_desc();
Expand All @@ -268,12 +292,17 @@ void gather_scatter_benchmark(GatherScatterBenchParam& params)
wholememory_tensor_t embedding_tensor;
wholememory_tensor_description_t embedding_tensor_desc;
wholememory_copy_matrix_desc_to_tensor(&embedding_tensor_desc, &embedding_desc);
std::vector<size_t> rank_partition(world_size);
wholegraph::bench::host_random_partition(
rank_partition.data(), embedding_tensor_desc.sizes[0], world_size);
WHOLEMEMORY_CHECK_NOTHROW(wholememory_create_tensor(&embedding_tensor,
&embedding_tensor_desc,
wm_comm,
params.get_memory_type(),
params.get_memory_location()) ==
WHOLEMEMORY_SUCCESS);
params.get_memory_location(),
params.get_partition_method() == 1
? rank_partition.data()
: nullptr) == WHOLEMEMORY_SUCCESS);

cudaStream_t stream;
WM_CUDA_CHECK_NO_THROW(cudaStreamCreate(&stream));
Expand Down Expand Up @@ -318,16 +347,17 @@ void gather_scatter_benchmark(GatherScatterBenchParam& params)
double gather_size_mb = (double)params.get_gather_size() / 1024.0 / 1024.0;
if (local_rank == 0) {
printf(
"%s, world_size=%d, memoryType=%s, memoryLocation=%s, elt_size=%ld, embeddingDim=%ld, "
"embeddingTableSize=%.2lf MB, gatherSize=%.2lf MB\n",
"%s, worldSize=%d, memoryType=%s, memoryLocation=%s, eltSize=%ld, embeddingDim=%ld, "
"embeddingTableSize=%.2lf MB, gatherSize=%.2lf MB, distributedBackend=%s\n",
test_type.c_str(),
world_size,
get_memory_type_string(params.get_memory_type()).c_str(),
get_memory_location_string(params.get_memory_location()).c_str(),
wholememory_dtype_get_element_size(params.get_embedding_type()),
params.get_embedding_dim(),
emb_size_mb,
gather_size_mb);
gather_size_mb,
distributed_backend.c_str());
}

PerformanceMeter meter;
Expand Down Expand Up @@ -388,7 +418,7 @@ void gather_scatter_benchmark(GatherScatterBenchParam& params)
int main(int argc, char** argv)
{
wholegraph::bench::gather_scatter::GatherScatterBenchParam params;
const char* optstr = "ht:l:e:g:d:c:f:a:p:r:s:n:";
const char* optstr = "ht:l:e:g:d:c:f:a:p:r:s:n:m:b:";
struct option opts[] = {
{"help", no_argument, NULL, 'h'},
{"memory_type",
Expand All @@ -405,8 +435,9 @@ int main(int argc, char** argv)
{"node_size", required_argument, NULL, 's'}, // node_size
{"num_gpu", required_argument, NULL, 'n'}, // num gpu per node
{"server_addr", required_argument, NULL, 'a'}, // server_addr
{"server_port", required_argument, NULL, 'p'} // server_port
};
{"server_port", required_argument, NULL, 'p'}, // server_port
{"partition_method", required_argument, NULL, 'm'},
{"distributed_backend", required_argument, NULL, 'b'}};

const char* usage =
"Usage: %s [options]\n"
Expand All @@ -424,7 +455,9 @@ int main(int argc, char** argv)
" -s, --node_size node_size or process count\n"
" -n, --num_gpu num_gpu per process\n"
" -a, --server_addr specify sideband server address\n"
" -p, --server_port specify sideband server port\n";
" -p, --server_port specify sideband server port\n"
" -m, --partition_method specify rank partition method, 0: Default, 1: Random\n"
" -b, --distributed_backend specify distributed backend: nccl or nvshmem\n";

int c;
bool has_option = false;
Expand Down Expand Up @@ -536,6 +569,26 @@ int main(int argc, char** argv)
}
params.set_num_gpu(val);
break;
case 'm':
val = std::atoi(optarg);
if (val != 0 && val != 1) {
printf("Invalid argument for option -m\n");
printf(usage, argv[0]);
exit(EXIT_FAILURE);
}
params.set_partition_method(val);
break;
case 'b':
if (strcmp(optarg, "nccl") == 0) {
params.set_distributed_backend("nccl");
} else if (strcmp(optarg, "nvshmem") == 0) {
params.set_distributed_backend("nvshmem");
} else {
printf("Invalid argument for option -b\n");
printf(usage, argv[0]);
exit(EXIT_FAILURE);
}
break;
default:
printf("Invalid or unrecognized option\n");
printf(usage, argv[0]);
Expand Down
33 changes: 29 additions & 4 deletions cpp/include/wholememory/device_reference.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,48 @@ class device_reference {
public:
__device__ __forceinline__ explicit device_reference(const wholememory_gref_t& gref)
: pointer_(static_cast<DataTypeT*>(gref.pointer)),
typed_stride_(gref.stride / sizeof(DataTypeT))
typed_stride_(gref.stride / sizeof(DataTypeT)),
world_size_(gref.world_size),
same_chunk_(gref.same_chunk)
{
assert(gref.stride % sizeof(DataTypeT) == 0);
if (typed_stride_ != 0 && !same_chunk_) {
assert(world_size_ <= 8); // intra-node WHOLEMEMORY_MT_CHUNKED
for (int i = 0; i < world_size_ + 1; i++) {
assert(gref.rank_memory_offsets[i] % sizeof(DataTypeT) == 0);
typed_rank_mem_offsets_[i] = gref.rank_memory_offsets[i] / sizeof(DataTypeT);
}
}
}
__device__ device_reference() = delete;

__device__ __forceinline__ DataTypeT& operator[](size_t index)
{
if (typed_stride_ == 0) { return pointer_[index]; }
size_t rank = index / typed_stride_;
return static_cast<DataTypeT**>(
static_cast<void*>(pointer_))[rank][index - rank * typed_stride_];
if (same_chunk_) {
size_t rank = index / typed_stride_;
return static_cast<DataTypeT**>(
static_cast<void*>(pointer_))[rank][index - rank * typed_stride_];
} else {
size_t rank = 0;
for (int i = 1; i < world_size_ + 1; i++) {
if (index < typed_rank_mem_offsets_[i]) {
rank = i - 1;
break;
}
}
return static_cast<DataTypeT**>(
static_cast<void*>(pointer_))[rank][index - typed_rank_mem_offsets_[rank]];
}
}

private:
DataTypeT* pointer_;
int world_size_;
size_t typed_stride_;

bool same_chunk_;
size_t typed_rank_mem_offsets_[8 + 1];
};

} // namespace wholememory
7 changes: 5 additions & 2 deletions cpp/include/wholememory/embedding.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ wholememory_error_code_t wholememory_destroy_embedding_cache_policy(
* @param memory_type : Memory Type of the underlying WholeMemory
* @param memory_location : Memory Location of the underlying WholeMemory
* @param cache_policy : Cache policy for this embedding, if don't use cache, use nullptr
* @param embedding_entry_partition: Embedding entry count of each rank, the length must be
* world_size
* @param user_defined_sms : User-defined sms number for raw embedding gather/scatter
* @param round_robin_size : continuous embedding size in each rank under round-robin shard mode
* @return : wholememory_error_code_t
Expand All @@ -140,8 +142,9 @@ wholememory_error_code_t wholememory_create_embedding(
wholememory_memory_type_t memory_type,
wholememory_memory_location_t memory_location,
wholememory_embedding_cache_policy_t cache_policy,
int user_defined_sms = -1,
int round_robin_size = 0);
size_t* embedding_entry_partition = nullptr,
int user_defined_sms = -1,
int round_robin_size = 0);

/**
* Destroy WholeMemory Embedding
Expand Down
8 changes: 7 additions & 1 deletion cpp/include/wholememory/global_reference.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,18 @@ extern "C" {
/**
* @brief Global reference of a WholeMemory object
*
* A global reference is for Continuous of Chunked WholeMemory Type, in these types, each rank can
* A global reference is for Continuous or Chunked WholeMemory Type, in these types, each rank can
* directly access all memory from all ranks. The global reference is used to do this direct access.
*/
struct wholememory_gref_t {
void* pointer; /*!< pointer to data for CONTINUOUS WholeMemory or pointer to data pointer array
for CHUNKED WholeMemory */
size_t*
rank_memory_offsets; /*!< memory offset of each rank, and the length must be world_size+1 */
int world_size;
size_t
stride; /*!< must be 0 for CONTINUOUS WholeMemory or memory size in byte for each pointer */
bool same_chunk; /*!< if true, rank can be got by offset/stride */
};

/**
Expand All @@ -43,9 +47,11 @@ wholememory_gref_t wholememory_create_continuous_global_reference(void* ptr);

struct wholememory_nvshmem_ref_t {
void* pointer;
size_t* rank_memory_offsets;
size_t stride;
int world_rank;
int world_size;
bool same_chunk;
};

#ifdef __cplusplus
Expand Down
Loading