Skip to content

Commit

Permalink
Improve logic for scaling dram read BW w.r.t. fork factor
Browse files Browse the repository at this point in the history
(cherry picked from commit 187e6c3c1c2bd1b03084c37be7a88aa513cf7581)
  • Loading branch information
derdeljanTT authored and vmilosevic committed May 20, 2024
1 parent 4fb9e63 commit e4fd66d
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 44 deletions.
34 changes: 2 additions & 32 deletions pybuda/csrc/balancer/data_movement_bw_estimation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,43 +527,13 @@ BandwidthBucket DramReadEstimator::estimate_bandwidth_impl() const
features_.get_tile_size());
}

return make_bandwidth_prediction(
dram_receiving_stream_buffer_size, dram_buf_read_chunk_size_tiles, dram_scatter_chunk_size_tiles);
}

BandwidthBucket DramReadEstimator::make_bandwidth_prediction(
const int unpacker_buffer_size_bytes,
const int dram_buf_read_chunk_size_tiles,
const int dram_scatter_chunk_size_tiles) const
{
const BandwidthBucket bw_bucket_without_fork = estimate_dram_read_connection(
return estimate_dram_read_connection(
features_.get_consumer_epoch_tiles(),
features_.get_tile_size(),
features_.get_kernel_clear_granularity(),
unpacker_buffer_size_bytes,
dram_receiving_stream_buffer_size,
dram_buf_read_chunk_size_tiles,
dram_scatter_chunk_size_tiles);

return scale_bandwidth_wrt_fork_factor(bw_bucket_without_fork.get_bandwidth(), features_.get_producer_fan_out());
}

BandwidthBucket DramReadEstimator::scale_bandwidth_wrt_fork_factor(
const double bw_without_fork, const int fork_factor) const
{
const double linear_cap = 1.0 * c_linear_noc_threshold / fork_factor;
const double theoretical_cap = 1.0 * c_theoretical_noc_threshold / fork_factor;

if (bw_without_fork <= linear_cap || fork_factor == 1)
{
return BandwidthBucket(bw_without_fork);
}

const double dx = c_theoretical_noc_threshold - bw_without_fork;
const double dy = (bw_without_fork - linear_cap) * (theoretical_cap - linear_cap);

const double fork_bw = dy / dx + linear_cap;

return BandwidthBucket(fork_bw);
}

//----------------------------------------------------------------------------------------------------------------------
Expand Down
10 changes: 0 additions & 10 deletions pybuda/csrc/balancer/data_movement_bw_estimation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,16 +304,6 @@ class DramReadEstimator : public Estimator

private:
BandwidthBucket estimate_bandwidth_impl() const override;

BandwidthBucket make_bandwidth_prediction(
const int unpacker_buffer_size_bytes,
const int dram_buf_read_chunk_size_tiles,
const int dram_scatter_chunk_size_tiles) const;

BandwidthBucket scale_bandwidth_wrt_fork_factor(const double bw_without_fork, const int fork_factor) const;

constexpr static int c_linear_noc_threshold = 20;
constexpr static int c_theoretical_noc_threshold = 24;
};

//----------------------------------------------------------------------------------------------------------------------
Expand Down
27 changes: 25 additions & 2 deletions pybuda/csrc/balancer/policies/policy_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1856,6 +1856,29 @@ std::optional<OpModel> get_op_model_for_input_queue(
}
}

float scale_bandwidth_wrt_fork_factor(const float bw_without_fork, const float fork_factor)
{
constexpr static float c_linear_noc_threshold = 20.0f;
constexpr static float c_theoretical_noc_threshold = 24.0f;

const float linear_cap = c_linear_noc_threshold / fork_factor;
const float theoretical_cap = c_theoretical_noc_threshold / fork_factor;

if (bw_without_fork <= linear_cap || fork_factor <= 1.0f)
{
return bw_without_fork;
}

const float dx = c_theoretical_noc_threshold - bw_without_fork;
const float dy = (bw_without_fork - linear_cap) * (theoretical_cap - linear_cap);

const float fork_bw = dy / dx + linear_cap;

TT_ASSERT(fork_bw > 0, "Scaled bandwidth must be a positive value");

return fork_bw;
}

float get_dram_read_bw_estimation_for_edge(
const Graph *graph,
const Edge &queue_to_op_edge,
Expand Down Expand Up @@ -1885,7 +1908,7 @@ float get_dram_read_bw_estimation_for_edge(
decompose_t_stream)
.get_bandwidth());

edge_dram_bw = std::ceil(edge_dram_bw / dram_fork_divider);
edge_dram_bw = scale_bandwidth_wrt_fork_factor(edge_dram_bw, dram_fork_divider);
}
}
else if (queue_data_inputs.size() > 0)
Expand All @@ -1908,7 +1931,7 @@ float get_dram_read_bw_estimation_for_edge(
decompose_t_stream)
.get_bandwidth());

edge_dram_bw = std::ceil(edge_dram_bw / dram_fork_divider);
edge_dram_bw = scale_bandwidth_wrt_fork_factor(edge_dram_bw, dram_fork_divider);
}
}

Expand Down
2 changes: 2 additions & 0 deletions pybuda/csrc/balancer/policies/policy_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,8 @@ int get_limiter_cycles(
bool invalidate_cached = false,
const OpModels* selected_op_models = nullptr);

float scale_bandwidth_wrt_fork_factor(const float bw_without_fork, const float fork_factor);

float get_dram_read_bw_estimation_for_edge(
const Graph* graph,
const Edge& queue_to_op_edge,
Expand Down

0 comments on commit e4fd66d

Please sign in to comment.