From 4f62fab9682e43d04af76cfca96389f71d363963 Mon Sep 17 00:00:00 2001 From: Brian Van Essen Date: Tue, 30 Jan 2024 10:39:31 +0900 Subject: [PATCH] Cleaned up a few more instances of counters that should be converted to uint64_t in the data ingestion pipeline. (#2405) --- include/lbann/callbacks/variable_minibatch.hpp | 2 +- .../lbann/data_ingestion/data_store_conduit.hpp | 2 +- src/callbacks/variable_minibatch.cpp | 14 +++++++------- .../coordinator/buffered_data_coordinator.cpp | 2 +- src/data_ingestion/data_store_conduit.cpp | 4 ++-- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/include/lbann/callbacks/variable_minibatch.hpp b/include/lbann/callbacks/variable_minibatch.hpp index 6a63d09f980..67138e3bff7 100644 --- a/include/lbann/callbacks/variable_minibatch.hpp +++ b/include/lbann/callbacks/variable_minibatch.hpp @@ -42,7 +42,7 @@ namespace callback { class variable_minibatch : public callback_base { public: - variable_minibatch(size_t starting_mbsize); + variable_minibatch(uint64_t starting_mbsize); variable_minibatch(const variable_minibatch&) = default; variable_minibatch& operator=(const variable_minibatch&) = default; /// Set the initial mini-batch size. diff --git a/include/lbann/data_ingestion/data_store_conduit.hpp b/include/lbann/data_ingestion/data_store_conduit.hpp index 342cf23b89e..e55445c8e53 100644 --- a/include/lbann/data_ingestion/data_store_conduit.hpp +++ b/include/lbann/data_ingestion/data_store_conduit.hpp @@ -541,7 +541,7 @@ class data_store_conduit // methods follow //========================================================================= - void start_exchange_data_by_sample(size_t current_pos, size_t mb_size); + void start_exchange_data_by_sample(uint64_t current_pos, uint64_t mb_size); void finish_exchange_data_by_sample(); void setup_data_store_buffers(); diff --git a/src/callbacks/variable_minibatch.cpp b/src/callbacks/variable_minibatch.cpp index e91b9393db0..36f16385c8a 100644 --- a/src/callbacks/variable_minibatch.cpp +++ b/src/callbacks/variable_minibatch.cpp @@ -43,7 +43,7 @@ namespace lbann { namespace callback { -variable_minibatch::variable_minibatch(size_t starting_mbsize) +variable_minibatch::variable_minibatch(uint64_t starting_mbsize) : m_starting_mbsize(starting_mbsize), m_current_mini_batch_size(starting_mbsize) {} @@ -177,9 +177,9 @@ float variable_minibatch::get_current_learning_rate(model* m) const return 0.0f; } -step_minibatch::step_minibatch(size_t starting_mbsize, - size_t step, - size_t ramp_time) +step_minibatch::step_minibatch(uint64_t starting_mbsize, + uint64_t step, + uint64_t ramp_time) : variable_minibatch(starting_mbsize), m_step(step), m_ramp_time(ramp_time) {} @@ -209,7 +209,7 @@ void step_minibatch::write_specific_proto(lbann_data::Callback& proto) const msg->set_ramp_time(m_ramp_time); } -minibatch_schedule::minibatch_schedule(size_t starting_mbsize, +minibatch_schedule::minibatch_schedule(uint64_t starting_mbsize, std::vector steps) : variable_minibatch(starting_mbsize), m_steps(std::move(steps)) { @@ -221,9 +221,9 @@ minibatch_schedule::minibatch_schedule(size_t starting_mbsize, } bool minibatch_schedule::schedule(model* m, - size_t& new_mbsize, + uint64_t& new_mbsize, float& new_lr, - size_t& ramp_time) + uint64_t& ramp_time) { const auto& c = static_cast(m->get_execution_context()); diff --git a/src/data_ingestion/coordinator/buffered_data_coordinator.cpp b/src/data_ingestion/coordinator/buffered_data_coordinator.cpp index 8a401acca3e..4d348d834eb 100644 --- a/src/data_ingestion/coordinator/buffered_data_coordinator.cpp +++ b/src/data_ingestion/coordinator/buffered_data_coordinator.cpp @@ -167,7 +167,7 @@ int buffered_data_coordinator::fetch_to_local_matrix( // Compute the size of the current local mini-batch const uint64_t end_pos = std::min(relative_base_position + loaded_mini_batch_size, - dr->m_shuffled_indices.size()); + (uint64_t)dr->m_shuffled_indices.size()); const uint64_t local_mini_batch_size = std::min( ((end_pos - relative_base_position) + ds.get_sample_stride() - 1) / ds.get_sample_stride(), diff --git a/src/data_ingestion/data_store_conduit.cpp b/src/data_ingestion/data_store_conduit.cpp index 9f7d87caf7c..3b127793d2b 100644 --- a/src/data_ingestion/data_store_conduit.cpp +++ b/src/data_ingestion/data_store_conduit.cpp @@ -1819,8 +1819,8 @@ void data_store_conduit::profile_timing() } } -void data_store_conduit::start_exchange_mini_batch_data(size_t current_pos, - size_t mb_size, +void data_store_conduit::start_exchange_mini_batch_data(uint64_t current_pos, + uint64_t mb_size, bool at_new_epoch) { if (is_local_cache() && is_fully_loaded()) {