Skip to content

Commit

Permalink
#5560: Use static methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Aswinmcw committed Oct 25, 2024
1 parent 81e2007 commit 57c79cc
Showing 1 changed file with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ namespace operations{
namespace experimental{
namespace ccl{

AllReduceStrategy choose_all_reduce_strategy(const Tensor& input_tensor, uint32_t num_devices, uint32_t num_links) {
static AllReduceStrategy choose_all_reduce_strategy(const Tensor& input_tensor, uint32_t num_devices, uint32_t num_links) {
auto shape = input_tensor.get_logical_shape();
auto rank = shape.rank();

Expand Down Expand Up @@ -98,7 +98,7 @@ AllReduceStrategy choose_all_reduce_strategy(const Tensor& input_tensor, uint32_
}


Tensor all_gather_local_reduce(const Tensor& input_tensor, uint32_t num_devices, uint32_t num_links, const MemoryConfig& output_mem_config,
static Tensor all_gather_local_reduce(const Tensor& input_tensor, uint32_t num_devices, uint32_t num_links, const MemoryConfig& output_mem_config,
const std::optional<size_t> user_defined_num_workers, const std::optional<size_t> user_defined_num_buffers_per_channel, const std::vector<Device*>& devices, const ttnn::ccl::Topology& topology) {

auto shape = input_tensor.get_logical_shape();
Expand All @@ -125,7 +125,7 @@ Tensor all_gather_local_reduce(const Tensor& input_tensor, uint32_t num_devices,
return ttnn::reshape(sum_tensor, shape);
}

Tensor reduce_scatter_all_gather(const Tensor& input_tensor, const ttnn::operations::binary::BinaryOpType binary_op_type, uint32_t num_devices, uint32_t num_links, const MemoryConfig& output_mem_config,
static Tensor reduce_scatter_all_gather(const Tensor& input_tensor, const ttnn::operations::binary::BinaryOpType binary_op_type, uint32_t num_devices, uint32_t num_links, const MemoryConfig& output_mem_config,
const std::optional<size_t> user_defined_num_workers, const std::optional<size_t> user_defined_num_buffers_per_channel, const std::vector<Device*>& devices, const ttnn::ccl::Topology& topology) {
auto shape = input_tensor.get_logical_shape();
auto rank = shape.rank();
Expand Down

0 comments on commit 57c79cc

Please sign in to comment.