Skip to content

Commit

Permalink
Cleaned up a few more instances of counters that should be converted …
Browse files Browse the repository at this point in the history
…to uint64_t in the data ingestion pipeline. (#2405)
  • Loading branch information
bvanessen authored Jan 30, 2024
1 parent d7c5780 commit 4f62fab
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion include/lbann/callbacks/variable_minibatch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ namespace callback {
class variable_minibatch : public callback_base
{
public:
variable_minibatch(size_t starting_mbsize);
variable_minibatch(uint64_t starting_mbsize);
variable_minibatch(const variable_minibatch&) = default;
variable_minibatch& operator=(const variable_minibatch&) = default;
/// Set the initial mini-batch size.
Expand Down
2 changes: 1 addition & 1 deletion include/lbann/data_ingestion/data_store_conduit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ class data_store_conduit
// methods follow
//=========================================================================

void start_exchange_data_by_sample(size_t current_pos, size_t mb_size);
void start_exchange_data_by_sample(uint64_t current_pos, uint64_t mb_size);
void finish_exchange_data_by_sample();

void setup_data_store_buffers();
Expand Down
14 changes: 7 additions & 7 deletions src/callbacks/variable_minibatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
namespace lbann {
namespace callback {

variable_minibatch::variable_minibatch(size_t starting_mbsize)
variable_minibatch::variable_minibatch(uint64_t starting_mbsize)
: m_starting_mbsize(starting_mbsize),
m_current_mini_batch_size(starting_mbsize)
{}
Expand Down Expand Up @@ -177,9 +177,9 @@ float variable_minibatch::get_current_learning_rate(model* m) const
return 0.0f;
}

step_minibatch::step_minibatch(size_t starting_mbsize,
size_t step,
size_t ramp_time)
step_minibatch::step_minibatch(uint64_t starting_mbsize,
uint64_t step,
uint64_t ramp_time)
: variable_minibatch(starting_mbsize), m_step(step), m_ramp_time(ramp_time)
{}

Expand Down Expand Up @@ -209,7 +209,7 @@ void step_minibatch::write_specific_proto(lbann_data::Callback& proto) const
msg->set_ramp_time(m_ramp_time);
}

minibatch_schedule::minibatch_schedule(size_t starting_mbsize,
minibatch_schedule::minibatch_schedule(uint64_t starting_mbsize,
std::vector<minibatch_step> steps)
: variable_minibatch(starting_mbsize), m_steps(std::move(steps))
{
Expand All @@ -221,9 +221,9 @@ minibatch_schedule::minibatch_schedule(size_t starting_mbsize,
}

bool minibatch_schedule::schedule(model* m,
size_t& new_mbsize,
uint64_t& new_mbsize,
float& new_lr,
size_t& ramp_time)
uint64_t& ramp_time)
{
const auto& c =
static_cast<const SGDExecutionContext&>(m->get_execution_context());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ int buffered_data_coordinator<TensorDataType>::fetch_to_local_matrix(
// Compute the size of the current local mini-batch
const uint64_t end_pos =
std::min(relative_base_position + loaded_mini_batch_size,
dr->m_shuffled_indices.size());
(uint64_t)dr->m_shuffled_indices.size());
const uint64_t local_mini_batch_size = std::min(
((end_pos - relative_base_position) + ds.get_sample_stride() - 1) /
ds.get_sample_stride(),
Expand Down
4 changes: 2 additions & 2 deletions src/data_ingestion/data_store_conduit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1819,8 +1819,8 @@ void data_store_conduit::profile_timing()
}
}

void data_store_conduit::start_exchange_mini_batch_data(size_t current_pos,
size_t mb_size,
void data_store_conduit::start_exchange_mini_batch_data(uint64_t current_pos,
uint64_t mb_size,
bool at_new_epoch)
{
if (is_local_cache() && is_fully_loaded()) {
Expand Down

0 comments on commit 4f62fab

Please sign in to comment.