diff --git a/cpp/src/wholememory_ops/functions/embedding_optimizer_func.cu b/cpp/src/wholememory_ops/functions/embedding_optimizer_func.cu
index e6d751280..0249ba1ac 100644
--- a/cpp/src/wholememory_ops/functions/embedding_optimizer_func.cu
+++ b/cpp/src/wholememory_ops/functions/embedding_optimizer_func.cu
@@ -214,7 +214,8 @@ __global__ void sgd_optimizer_step_kernel(const IndiceT* indices_ptr,
     int local_dim_idx = threadIdx.x;
     float grad_value = 0.0f;
     int embedding_idx = local_dim_idx + loop_start_idx;
-    if (embedding_idx < embedding_dim) { grad_value = grads_ptr[embedding_idx]; }
+    if (embedding_idx >= embedding_dim) { break; }
+    grad_value = grads_ptr[embedding_idx];
     float embedding_value = embedding_ptr[embedding_idx];
     grad_value += weight_decay * embedding_value;
     embedding_value -= lr * grad_value;
@@ -392,7 +393,8 @@ __global__ void lazy_adam_optimizer_step_kernel(const IndiceT* indices_ptr,
     int local_dim_idx = threadIdx.x;
     float grad_value = 0.0f;
     int embedding_idx = local_dim_idx + loop_start_idx;
-    if (embedding_idx < embedding_dim) { grad_value = grads_ptr[local_dim_idx + loop_start_idx]; }
+    if (embedding_idx >= embedding_dim) { break; }
+    grad_value = grads_ptr[local_dim_idx + loop_start_idx];
     float embedding_value = embedding_ptr[embedding_idx];
     if (AdamW) {
       embedding_value -= lr * weight_decay * embedding_value;
@@ -644,7 +646,8 @@ __global__ void ada_grad_optimizer_step_kernel(const IndiceT* indices_ptr,
     int local_dim_idx = threadIdx.x;
     float grad_value = 0.0f;
     int embedding_idx = local_dim_idx + loop_start_idx;
-    if (embedding_idx < embedding_dim) { grad_value = grads_ptr[embedding_idx]; }
+    if (embedding_idx >= embedding_dim) { break; }
+    grad_value = grads_ptr[embedding_idx];
     float embedding_value = embedding_ptr[embedding_idx];
     grad_value = grad_value + weight_decay * embedding_value;
     float state_sum = state_sum_ptr[embedding_idx];
@@ -841,7 +844,8 @@ __global__ void rms_prop_optimizer_step_kernel(const IndiceT* indices_ptr,
     int local_dim_idx = threadIdx.x;
     float grad_value = 0.0f;
     int embedding_idx = local_dim_idx + loop_start_idx;
-    if (embedding_idx < embedding_dim) { grad_value = grads_ptr[local_dim_idx + loop_start_idx]; }
+    if (embedding_idx >= embedding_dim) { break; }
+    grad_value = grads_ptr[local_dim_idx + loop_start_idx];
     float embedding_value = embedding_ptr[embedding_idx];
     grad_value = grad_value + weight_decay * embedding_value;
     float v = v_ptr[embedding_idx];
diff --git a/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh b/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh
index d2d040a0e..5fa93ee12 100644
--- a/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh
+++ b/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh
@@ -29,7 +29,7 @@ class nvshmem_device_reference {
     : pointer_(static_cast<DataTypeT*>(nvshmem_ref.pointer)),
       typed_stride_(nvshmem_ref.stride / sizeof(DataTypeT))
   {
-    assert(gref.stride % sizeof(DataTypeT) == 0);
+    assert(nvshmem_ref.stride % sizeof(DataTypeT) == 0);
   }
 
   __device__ nvshmem_device_reference() = delete;
diff --git a/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu b/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu
index a860cbc6c..4051f12bd 100644
--- a/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu
+++ b/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu
@@ -185,6 +185,7 @@ wholememory_error_code_t wholememory_gather_nvshmem(
     p_env_fns,
     stream);
   // ungistre
+  WM_CUDA_CHECK(cudaStreamSynchronize(stream));
   if (nvshmemx_buffer_unregister(temp_output_ptr) != 0) {
     WHOLEMEMORY_ERROR("nvshmemx_buffer_unregister error in wholememory_gather_nvshmem");
   }
diff --git a/cpp/tests/wholememory_ops/wholememory_embedding_gradient_apply_tests.cu b/cpp/tests/wholememory_ops/wholememory_embedding_gradient_apply_tests.cu
index 453b13b41..bb6360fc0 100644
--- a/cpp/tests/wholememory_ops/wholememory_embedding_gradient_apply_tests.cu
+++ b/cpp/tests/wholememory_ops/wholememory_embedding_gradient_apply_tests.cu
@@ -149,7 +149,7 @@ struct EmbeddingBackwardTestParams {
   wholememory_optimizer_type_t optimizer_type = WHOLEMEMORY_OPT_SGD;
   float cache_ratio = 0.2;
   bool use_cache = false;
-  int run_count = 1;
+  int run_count = 3;
 
   float lr_ = 0.1;
 
@@ -428,7 +428,7 @@ void prepare_data_and_reference(
     int64_t end_entry = (thread_rank + 1) * total_entry_count / thread_world_size;
     CPUOptimizer cpu_optimizer(&params, start_entry, end_entry);
     int embedding_dim = params.grad_description.sizes[1];
-    for (int step = 0; step <= params.run_count; step++) {
+    for (int step = 0; step < params.run_count; step++) {
       int step_id = std::min(step, params.run_count - 1);
       std::vector<int64_t> indices;
       std::vector<std::vector<float>> grads;
@@ -625,7 +625,7 @@ TEST_P(WholeMemoryEmbeddingBackwardParameterTests, EmbeddingGatherGradientApplyT
   EXPECT_EQ(cudaStreamSynchronize(nullptr), cudaSuccess);
   EXPECT_EQ(wholememory_communicator_barrier(wm_comm), WHOLEMEMORY_SUCCESS);
 
-  for (int run = 0; run <= params.run_count; run++) {
+  for (int run = 0; run < params.run_count; run++) {
     int step_id = std::min(run, params.run_count - 1);
     auto& rank_indices_vec = step_rank_indices[step_id][world_rank];
     auto& rank_grads_vec = step_rank_grads[step_id][world_rank];
@@ -737,6 +737,8 @@ INSTANTIATE_TEST_SUITE_P(
     EmbeddingBackwardTestParams().set_use_cache().set_indice_count(10000127).set_optimizer_type(WHOLEMEMORY_OPT_ADAGRAD),
     EmbeddingBackwardTestParams().set_use_cache().set_indice_count(10000127).set_optimizer_type(WHOLEMEMORY_OPT_LAZY_ADAM),
 #endif
+    EmbeddingBackwardTestParams().set_entry_count(500).set_indice_count(400).set_embedding_dim(4),
+    EmbeddingBackwardTestParams().set_embedding_dim(3),
     EmbeddingBackwardTestParams().set_use_cache().set_grad_stride(131),
     EmbeddingBackwardTestParams().set_use_cache().set_grad_stride(131).set_optimizer_type(
       WHOLEMEMORY_OPT_RMSPROP),
diff --git a/python/pylibwholegraph/pylibwholegraph/torch/common_options.py b/python/pylibwholegraph/pylibwholegraph/torch/common_options.py
index 0999fdfe5..42746add8 100644
--- a/python/pylibwholegraph/pylibwholegraph/torch/common_options.py
+++ b/python/pylibwholegraph/pylibwholegraph/torch/common_options.py
@@ -132,7 +132,7 @@ def add_common_sampler_options(argparser: ArgumentParser):
     argparser.add_argument(
         "-s",
         "--inferencesample",
-        type=int,
+        type=str,
         dest="inferencesample",
         default="30",
         help="inference sample count, -1 is all",