diff --git a/cpp/src/wholememory_ops/functions/embedding_optimizer_func.cu b/cpp/src/wholememory_ops/functions/embedding_optimizer_func.cu
index e6d751280..0249ba1ac 100644
--- a/cpp/src/wholememory_ops/functions/embedding_optimizer_func.cu
+++ b/cpp/src/wholememory_ops/functions/embedding_optimizer_func.cu
@@ -214,7 +214,9 @@ __global__ void sgd_optimizer_step_kernel(const IndiceT* indices_ptr,
     int local_dim_idx = threadIdx.x;
     float grad_value  = 0.0f;
     int embedding_idx = local_dim_idx + loop_start_idx;
-    if (embedding_idx < embedding_dim) { grad_value = grads_ptr[embedding_idx]; }
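+    // threads whose column is past embedding_dim leave the loop here, so the unguarded
+    // embedding_ptr accesses below stay in bounds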
+    if (embedding_idx >= embedding_dim) { break; }
+    grad_value            = grads_ptr[embedding_idx];
     float embedding_value = embedding_ptr[embedding_idx];
     grad_value += weight_decay * embedding_value;
     embedding_value -= lr * grad_value;
@@ -392,7 +394,9 @@ __global__ void lazy_adam_optimizer_step_kernel(const IndiceT* indices_ptr,
     int local_dim_idx = threadIdx.x;
     float grad_value  = 0.0f;
     int embedding_idx = local_dim_idx + loop_start_idx;
-    if (embedding_idx < embedding_dim) { grad_value = grads_ptr[local_dim_idx + loop_start_idx]; }
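+    // leave the loop for columns beyond embedding_dim before the unguarded embedding_ptr read below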
+    if (embedding_idx >= embedding_dim) { break; }
+    grad_value            = grads_ptr[local_dim_idx + loop_start_idx];
     float embedding_value = embedding_ptr[embedding_idx];
     if (AdamW) {
       embedding_value -= lr * weight_decay * embedding_value;
@@ -644,7 +648,9 @@ __global__ void ada_grad_optimizer_step_kernel(const IndiceT* indices_ptr,
     int local_dim_idx = threadIdx.x;
     float grad_value  = 0.0f;
     int embedding_idx = local_dim_idx + loop_start_idx;
-    if (embedding_idx < embedding_dim) { grad_value = grads_ptr[embedding_idx]; }
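+    // leave the loop for out-of-range columns so the embedding_ptr and state_sum_ptr accesses stay in bounds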
+    if (embedding_idx >= embedding_dim) { break; }
+    grad_value                   = grads_ptr[embedding_idx];
     float embedding_value        = embedding_ptr[embedding_idx];
     grad_value                   = grad_value + weight_decay * embedding_value;
     float state_sum              = state_sum_ptr[embedding_idx];
@@ -841,7 +847,9 @@ __global__ void rms_prop_optimizer_step_kernel(const IndiceT* indices_ptr,
     int local_dim_idx = threadIdx.x;
     float grad_value  = 0.0f;
     int embedding_idx = local_dim_idx + loop_start_idx;
-    if (embedding_idx < embedding_dim) { grad_value = grads_ptr[local_dim_idx + loop_start_idx]; }
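+    // leave the loop for out-of-range columns so the embedding_ptr and v_ptr accesses stay in bounds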
+    if (embedding_idx >= embedding_dim) { break; }
+    grad_value                   = grads_ptr[local_dim_idx + loop_start_idx];
     float embedding_value        = embedding_ptr[embedding_idx];
     grad_value                   = grad_value + weight_decay * embedding_value;
     float v                      = v_ptr[embedding_idx];
diff --git a/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh b/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh
index d2d040a0e..5fa93ee12 100644
--- a/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh
+++ b/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh
@@ -29,7 +29,7 @@ class nvshmem_device_reference {
     : pointer_(static_cast<DataTypeT*>(nvshmem_ref.pointer)),
       typed_stride_(nvshmem_ref.stride / sizeof(DataTypeT))
   {
-    assert(gref.stride % sizeof(DataTypeT) == 0);
+    assert(nvshmem_ref.stride % sizeof(DataTypeT) == 0);
   }
 
   __device__ nvshmem_device_reference() = delete;
diff --git a/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu b/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu
index a860cbc6c..4051f12bd 100644
--- a/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu
+++ b/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu
@@ -185,6 +185,7 @@ wholememory_error_code_t wholememory_gather_nvshmem(
                                      p_env_fns,
                                      stream);
-    // ungistre
+    // wait for all pending work on the stream before unregistering the temporary output buffer
+    WM_CUDA_CHECK(cudaStreamSynchronize(stream));
     if (nvshmemx_buffer_unregister(temp_output_ptr) != 0) {
       WHOLEMEMORY_ERROR("nvshmemx_buffer_unregister error in wholememory_gather_nvshmem");
     }
diff --git a/cpp/tests/wholememory_ops/wholememory_embedding_gradient_apply_tests.cu b/cpp/tests/wholememory_ops/wholememory_embedding_gradient_apply_tests.cu
index 453b13b41..bb6360fc0 100644
--- a/cpp/tests/wholememory_ops/wholememory_embedding_gradient_apply_tests.cu
+++ b/cpp/tests/wholememory_ops/wholememory_embedding_gradient_apply_tests.cu
@@ -149,7 +149,7 @@ struct EmbeddingBackwardTestParams {
   wholememory_optimizer_type_t optimizer_type         = WHOLEMEMORY_OPT_SGD;
   float cache_ratio                                   = 0.2;
   bool use_cache                                      = false;
-  int run_count                                       = 1;
+  int run_count                                       = 3;
 
   float lr_ = 0.1;
 
@@ -428,7 +428,7 @@ void prepare_data_and_reference(
                    int64_t end_entry = (thread_rank + 1) * total_entry_count / thread_world_size;
                    CPUOptimizer cpu_optimizer(&params, start_entry, end_entry);
                    int embedding_dim = params.grad_description.sizes[1];
-                   for (int step = 0; step <= params.run_count; step++) {
+                   for (int step = 0; step < params.run_count; step++) {
                      int step_id = std::min(step, params.run_count - 1);
                      std::vector<int64_t> indices;
                      std::vector<std::vector<float>> grads;
@@ -625,7 +625,7 @@ TEST_P(WholeMemoryEmbeddingBackwardParameterTests, EmbeddingGatherGradientApplyT
       EXPECT_EQ(cudaStreamSynchronize(nullptr), cudaSuccess);
       EXPECT_EQ(wholememory_communicator_barrier(wm_comm), WHOLEMEMORY_SUCCESS);
 
-      for (int run = 0; run <= params.run_count; run++) {
+      for (int run = 0; run < params.run_count; run++) {
         int step_id            = std::min(run, params.run_count - 1);
         auto& rank_indices_vec = step_rank_indices[step_id][world_rank];
         auto& rank_grads_vec   = step_rank_grads[step_id][world_rank];
@@ -737,6 +737,9 @@
         EmbeddingBackwardTestParams().set_use_cache().set_indice_count(10000127).set_optimizer_type(WHOLEMEMORY_OPT_ADAGRAD),
         EmbeddingBackwardTestParams().set_use_cache().set_indice_count(10000127).set_optimizer_type(WHOLEMEMORY_OPT_LAZY_ADAM),
 #endif
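+    // small embedding dims (4 and 3) exercise the tail columns handled by the new bounds checks in the optimizer kernels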
+    EmbeddingBackwardTestParams().set_entry_count(500).set_indice_count(400).set_embedding_dim(4),
+    EmbeddingBackwardTestParams().set_embedding_dim(3),
     EmbeddingBackwardTestParams().set_use_cache().set_grad_stride(131),
     EmbeddingBackwardTestParams().set_use_cache().set_grad_stride(131).set_optimizer_type(
       WHOLEMEMORY_OPT_RMSPROP),
diff --git a/python/pylibwholegraph/pylibwholegraph/torch/common_options.py b/python/pylibwholegraph/pylibwholegraph/torch/common_options.py
index 0999fdfe5..42746add8 100644
--- a/python/pylibwholegraph/pylibwholegraph/torch/common_options.py
+++ b/python/pylibwholegraph/pylibwholegraph/torch/common_options.py
@@ -132,7 +132,7 @@ def add_common_sampler_options(argparser: ArgumentParser):
     argparser.add_argument(
         "-s",
         "--inferencesample",
-        type=int,
+        type=str,
         dest="inferencesample",
         default="30",
         help="inference sample count, -1 is all",