Skip to content

Commit

Permalink
update wholegraph
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuofan1123 committed Nov 6, 2024
1 parent f7ab898 commit 7aefaeb
Show file tree
Hide file tree
Showing 64 changed files with 3,036 additions and 664 deletions.
4 changes: 3 additions & 1 deletion cpp/bench/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ function(ConfigureBench)
rmm::rmm
pthread
)

if(BUILD_WITH_NVSHMEM)
target_compile_definitions(${BENCH_NAME} PRIVATE WITH_NVSHMEM_SUPPORT)
endif()
set_target_properties(
${BENCH_NAME}
PROPERTIES # set target compile options
Expand Down
17 changes: 17 additions & 0 deletions cpp/bench/common/wholegraph_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,23 @@ void host_random_init_integer_indices(void* indices,
}
}

void host_random_partition(size_t* partition_sizes, size_t total_size, int partition_count)
{
std::default_random_engine random_engine(0);
std::uniform_int_distribution<size_t> uniform(90, 100);
size_t acc_size = 0;
size_t random_sum = 0;
for (int i = 0; i < partition_count; i++) {
partition_sizes[i] = (size_t)uniform(random_engine);
random_sum += partition_sizes[i];
}
for (int i = 0; i < partition_count; i++) {
partition_sizes[i] = (size_t)((partition_sizes[i] / (double)random_sum) * total_size);
acc_size += partition_sizes[i];
}
partition_sizes[0] += total_size - acc_size;
}

void MultiProcessMeasurePerformance(std::function<void()> run_fn,
wholememory_comm_t& wm_comm,
const PerformanceMeter& meter,
Expand Down
2 changes: 2 additions & 0 deletions cpp/bench/common/wholegraph_benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ void host_random_init_integer_indices(void* indices,
wholememory_array_description_t indices_desc,
int64_t max_indices);

void host_random_partition(size_t* partition_sizes, size_t total_size, int partition_count);

struct Metric {
Metric(const std::string& metrics_name,
const std::string& metrics_unit,
Expand Down
73 changes: 63 additions & 10 deletions cpp/bench/wholememory_ops/gather_scatter_bench.cu
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ typedef struct GatherScatterBenchParam {

int64_t get_embedding_dim() const { return embedding_dim; }
wholememory_dtype_t get_embedding_type() const { return embedding_type; }
int get_partition_method() const { return partition_method; }
std::string get_distributed_backend() const { return distributed_backend; }

GatherScatterBenchParam& set_memory_type(wholememory_memory_type_t new_memory_type)
{
Expand Down Expand Up @@ -153,6 +155,18 @@ typedef struct GatherScatterBenchParam {
return *this;
}

GatherScatterBenchParam& set_partition_method(int new_partition_method)
{
partition_method = new_partition_method;
return *this;
}

GatherScatterBenchParam& set_distributed_backend(std::string new_distributed_backend)
{
distributed_backend = new_distributed_backend;
return *this;
}

private:
int64_t get_embedding_entry_count() const
{
Expand Down Expand Up @@ -196,6 +210,8 @@ typedef struct GatherScatterBenchParam {
int64_t embedding_dim = 32;
int loop_count = 20;
std::string test_type = "gather"; // gather or scatter
int partition_method = 0;
std::string distributed_backend = "nccl"; // nccl or nvshmem

std::string server_addr = "localhost";
int server_port = 24987;
Expand Down Expand Up @@ -256,7 +272,15 @@ void gather_scatter_benchmark(GatherScatterBenchParam& params)

wholememory_comm_t wm_comm =
create_communicator_by_socket(side_band_communicator, world_rank, world_size);

std::string distributed_backend = params.get_distributed_backend();
#ifdef WITH_NVSHMEM_SUPPORT
if (distributed_backend.compare("nvshmem") == 0)
WHOLEMEMORY_CHECK_NOTHROW(wholememory_communicator_set_distributed_backend(
wm_comm, WHOLEMEMORY_DB_NVSHMEM) == WHOLEMEMORY_SUCCESS);
#else
distributed_backend = "nccl";
params.set_distributed_backend("nccl");
#endif
ShutDownSidebandCommunicator(side_band_communicator);

auto embedding_desc = params.get_embedding_desc();
Expand All @@ -268,12 +292,17 @@ void gather_scatter_benchmark(GatherScatterBenchParam& params)
wholememory_tensor_t embedding_tensor;
wholememory_tensor_description_t embedding_tensor_desc;
wholememory_copy_matrix_desc_to_tensor(&embedding_tensor_desc, &embedding_desc);
std::vector<size_t> rank_partition(world_size);
wholegraph::bench::host_random_partition(
rank_partition.data(), embedding_tensor_desc.sizes[0], world_size);
WHOLEMEMORY_CHECK_NOTHROW(wholememory_create_tensor(&embedding_tensor,
&embedding_tensor_desc,
wm_comm,
params.get_memory_type(),
params.get_memory_location()) ==
WHOLEMEMORY_SUCCESS);
params.get_memory_location(),
params.get_partition_method() == 1
? rank_partition.data()
: nullptr) == WHOLEMEMORY_SUCCESS);

cudaStream_t stream;
WM_CUDA_CHECK_NO_THROW(cudaStreamCreate(&stream));
Expand Down Expand Up @@ -318,16 +347,17 @@ void gather_scatter_benchmark(GatherScatterBenchParam& params)
double gather_size_mb = (double)params.get_gather_size() / 1024.0 / 1024.0;
if (local_rank == 0) {
printf(
"%s, world_size=%d, memoryType=%s, memoryLocation=%s, elt_size=%ld, embeddingDim=%ld, "
"embeddingTableSize=%.2lf MB, gatherSize=%.2lf MB\n",
"%s, worldSize=%d, memoryType=%s, memoryLocation=%s, eltSize=%ld, embeddingDim=%ld, "
"embeddingTableSize=%.2lf MB, gatherSize=%.2lf MB, distributedBackend=%s\n",
test_type.c_str(),
world_size,
get_memory_type_string(params.get_memory_type()).c_str(),
get_memory_location_string(params.get_memory_location()).c_str(),
wholememory_dtype_get_element_size(params.get_embedding_type()),
params.get_embedding_dim(),
emb_size_mb,
gather_size_mb);
gather_size_mb,
distributed_backend.c_str());
}

PerformanceMeter meter;
Expand Down Expand Up @@ -388,7 +418,7 @@ void gather_scatter_benchmark(GatherScatterBenchParam& params)
int main(int argc, char** argv)
{
wholegraph::bench::gather_scatter::GatherScatterBenchParam params;
const char* optstr = "ht:l:e:g:d:c:f:a:p:r:s:n:";
const char* optstr = "ht:l:e:g:d:c:f:a:p:r:s:n:m:b:";
struct option opts[] = {
{"help", no_argument, NULL, 'h'},
{"memory_type",
Expand All @@ -405,8 +435,9 @@ int main(int argc, char** argv)
{"node_size", required_argument, NULL, 's'}, // node_size
{"num_gpu", required_argument, NULL, 'n'}, // num gpu per node
{"server_addr", required_argument, NULL, 'a'}, // server_addr
{"server_port", required_argument, NULL, 'p'} // server_port
};
{"server_port", required_argument, NULL, 'p'}, // server_port
{"partition_method", required_argument, NULL, 'm'},
{"distributed_backend", required_argument, NULL, 'b'}};

const char* usage =
"Usage: %s [options]\n"
Expand All @@ -424,7 +455,9 @@ int main(int argc, char** argv)
" -s, --node_size node_size or process count\n"
" -n, --num_gpu num_gpu per process\n"
" -a, --server_addr specify sideband server address\n"
" -p, --server_port specify sideband server port\n";
" -p, --server_port specify sideband server port\n"
" -m, --partition_method specify rank partition method, 0: Default, 1: Random\n"
" -b, --distributed_backend specify distributed backend: nccl or nvshmem\n";

int c;
bool has_option = false;
Expand Down Expand Up @@ -536,6 +569,26 @@ int main(int argc, char** argv)
}
params.set_num_gpu(val);
break;
case 'm':
val = std::atoi(optarg);
if (val != 0 && val != 1) {
printf("Invalid argument for option -m\n");
printf(usage, argv[0]);
exit(EXIT_FAILURE);
}
params.set_partition_method(val);
break;
case 'b':
if (strcmp(optarg, "nccl") == 0) {
params.set_distributed_backend("nccl");
} else if (strcmp(optarg, "nvshmem") == 0) {
params.set_distributed_backend("nvshmem");
} else {
printf("Invalid argument for option -b\n");
printf(usage, argv[0]);
exit(EXIT_FAILURE);
}
break;
default:
printf("Invalid or unrecognized option\n");
printf(usage, argv[0]);
Expand Down
33 changes: 29 additions & 4 deletions cpp/include/wholememory/device_reference.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,48 @@ class device_reference {
public:
__device__ __forceinline__ explicit device_reference(const wholememory_gref_t& gref)
: pointer_(static_cast<DataTypeT*>(gref.pointer)),
typed_stride_(gref.stride / sizeof(DataTypeT))
typed_stride_(gref.stride / sizeof(DataTypeT)),
world_size_(gref.world_size),
same_chunk_(gref.same_chunk)
{
assert(gref.stride % sizeof(DataTypeT) == 0);
if (typed_stride_ != 0 && !same_chunk_) {
assert(world_size_ <= 8); // intra-node WHOLEMEMORY_MT_CHUNKED
for (int i = 0; i < world_size_ + 1; i++) {
assert(gref.rank_memory_offsets[i] % sizeof(DataTypeT) == 0);
typed_rank_mem_offsets_[i] = gref.rank_memory_offsets[i] / sizeof(DataTypeT);
}
}
}
__device__ device_reference() = delete;

__device__ __forceinline__ DataTypeT& operator[](size_t index)
{
if (typed_stride_ == 0) { return pointer_[index]; }
size_t rank = index / typed_stride_;
return static_cast<DataTypeT**>(
static_cast<void*>(pointer_))[rank][index - rank * typed_stride_];
if (same_chunk_) {
size_t rank = index / typed_stride_;
return static_cast<DataTypeT**>(
static_cast<void*>(pointer_))[rank][index - rank * typed_stride_];
} else {
size_t rank = 0;
for (int i = 1; i < world_size_ + 1; i++) {
if (index < typed_rank_mem_offsets_[i]) {
rank = i - 1;
break;
}
}
return static_cast<DataTypeT**>(
static_cast<void*>(pointer_))[rank][index - typed_rank_mem_offsets_[rank]];
}
}

private:
DataTypeT* pointer_;
int world_size_;
size_t typed_stride_;

bool same_chunk_;
size_t typed_rank_mem_offsets_[8 + 1];
};

} // namespace wholememory
7 changes: 5 additions & 2 deletions cpp/include/wholememory/embedding.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ wholememory_error_code_t wholememory_destroy_embedding_cache_policy(
* @param memory_type : Memory Type of the underlying WholeMemory
* @param memory_location : Memory Location of the underlying WholeMemory
* @param cache_policy : Cache policy for this embedding, if don't use cache, use nullptr
* @param embedding_entry_partition: Embedding entry count of each rank, the length must be
* world_size
* @param user_defined_sms : User-defined sms number for raw embedding gather/scatter
* @param round_robin_size : continuous embedding size in each rank under round-robin shard mode
* @return : wholememory_error_code_t
Expand All @@ -140,8 +142,9 @@ wholememory_error_code_t wholememory_create_embedding(
wholememory_memory_type_t memory_type,
wholememory_memory_location_t memory_location,
wholememory_embedding_cache_policy_t cache_policy,
int user_defined_sms = -1,
int round_robin_size = 0);
size_t* embedding_entry_partition = nullptr,
int user_defined_sms = -1,
int round_robin_size = 0);

/**
* Destroy WholeMemory Embedding
Expand Down
8 changes: 7 additions & 1 deletion cpp/include/wholememory/global_reference.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,18 @@ extern "C" {
/**
* @brief Global reference of a WholeMemory object
*
* A global reference is for Continuous of Chunked WholeMemory Type, in these types, each rank can
* A global reference is for Continuous or Chunked WholeMemory Type, in these types, each rank can
* directly access all memory from all ranks. The global reference is used to do this direct access.
*/
struct wholememory_gref_t {
void* pointer; /*!< pointer to data for CONTINUOUS WholeMemory or pointer to data pointer array
for CHUNKED WholeMemory */
size_t*
rank_memory_offsets; /*!< memory offset of each rank, and the length must be world_size+1 */
int world_size;
size_t
stride; /*!< must be 0 for CONTINUOUS WholeMemory or memory size in byte for each pointer */
bool same_chunk; /*!< if true, rank can be got by offset/stride */
};

/**
Expand All @@ -43,9 +47,11 @@ wholememory_gref_t wholememory_create_continuous_global_reference(void* ptr);

struct wholememory_nvshmem_ref_t {
void* pointer;
size_t* rank_memory_offsets;
size_t stride;
int world_rank;
int world_size;
bool same_chunk;
};

#ifdef __cplusplus
Expand Down
Loading

0 comments on commit 7aefaeb

Please sign in to comment.