#9486: Merge CCL reduce_scatter to TTNN #9979

Merged: 10 commits, Jul 22, 2024
16 changes: 8 additions & 8 deletions models/demos/t3000/falcon40b/tt/falcon_mlp.py
@@ -124,12 +124,12 @@ def fwd_decode(self, x: List[ttnn.experimental.tensor.Tensor]) -> List[ttnn.expe
         hidden_states
     )  # Workaround for reduce_scatter only taking a vector of tensors and not device_mesh

-    hidden_states = ttnn.experimental.tensor.reduce_scatter(
+    hidden_states = ttnn.reduce_scatter(
         hidden_states,
-        scatter_split_dim=3,
-        reduce_op=ttnn.experimental.tensor.ReduceOpMath.SUM,
+        scatter_dim=3,
+        math_op=ttnn.experimental.tensor.ReduceOpMath.SUM,
         num_links=1,  # only unidirectional supported for now
-        output_mem_config=self.model_config["DEFAULT_MEMCFG"],
+        memory_config=self.model_config["DEFAULT_MEMCFG"],
     )

hidden_states = ttnn.aggregate_as_tensor(hidden_states) # Workaround reverse
@@ -198,12 +198,12 @@ def fwd_prefill(self, x: List[ttnn.experimental.tensor.Tensor]) -> List[ttnn.exp
         self.output
     )  # Workaround for reduce_scatter only taking a vector of tensors and not device_mesh

-    hidden_states = ttnn.experimental.tensor.reduce_scatter(
+    hidden_states = ttnn.reduce_scatter(
         hidden_states,
-        scatter_split_dim=3,
-        reduce_op=ttnn.experimental.tensor.ReduceOpMath.SUM,
+        scatter_dim=3,
+        math_op=ttnn.experimental.tensor.ReduceOpMath.SUM,
         num_links=1,  # only one link supported for now
-        output_mem_config=self.model_config["DEFAULT_MEMCFG"],
+        memory_config=self.model_config["DEFAULT_MEMCFG"],
     )

hidden_states = ttnn.aggregate_as_tensor(hidden_states) # Workaround reverse
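Across both hunks the call site moves from `ttnn.experimental.tensor.reduce_scatter` to `ttnn.reduce_scatter`, and three keyword arguments are renamed. A minimal helper summarizing the migration — `migrate_kwargs` and `OLD_TO_NEW` are illustrative names, not part of TTNN; the mapping itself is taken from the diffs above:

```python
# Keyword renames introduced when reduce_scatter moved into TTNN,
# as seen in this PR's diffs (hypothetical helper, not a TTNN API).
OLD_TO_NEW = {
    "scatter_split_dim": "scatter_dim",
    "reduce_op": "math_op",
    "output_mem_config": "memory_config",
}

def migrate_kwargs(old_kwargs: dict) -> dict:
    """Rename legacy reduce_scatter keyword arguments to the new TTNN names,
    passing through any keys (e.g. num_links) that did not change."""
    return {OLD_TO_NEW.get(k, k): v for k, v in old_kwargs.items()}
```

Callers would then unpack the result into `ttnn.reduce_scatter(tensors, **new_kwargs)` instead of the old experimental entry point.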
2 changes: 1 addition & 1 deletion tests/scripts/t3000/run_t3000_frequent_tests.sh
@@ -71,7 +71,7 @@ run_t3000_tteager_tests() {
   echo "LOG_METAL: Running run_t3000_tteager_tests"

   pytest -n auto tests/ttnn/unit_tests/operations/test_all_gather.py -k post_commit ; fail+=$?
-  pytest -n auto tests/tt_eager/python_api_testing/unit_testing/misc/test_reduce_scatter_post_commit.py ; fail+=$?
+  pytest -n auto tests/ttnn/unit_tests/operations/test_reduce_scatter_post_commit.py ; fail+=$?

   # distributed layernorm
   WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest tests/ttnn/unit_tests/operations/test_distributed_layernorm.py ; fail+=$?
@@ -6,6 +6,7 @@
 import pytest
 from loguru import logger
 import tt_lib as ttl
+import ttnn
 from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc
 from models.utility_functions import skip_for_grayskull, get_devices_for_t3000
 import itertools
@@ -75,12 +76,12 @@ def run_reduce_scatter_test(

     # Run the op
     # for i in range(num_iters):
-    tt_out_tensors = ttl.tensor.reduce_scatter(
+    tt_out_tensors = ttnn.reduce_scatter(
         tt_input_tensors,
-        scatter_split_dim=scatter_dim,
-        reduce_op=math_op,
+        scatter_dim=scatter_dim,
+        math_op=math_op,
         num_links=num_links,
-        output_mem_config=mem_config,
+        memory_config=mem_config,
     )

     for d in devices:
@@ -6,6 +6,7 @@
 import pytest
 from loguru import logger
 import tt_lib as ttl
+import ttnn
 from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc
 from models.utility_functions import skip_for_grayskull, get_devices_for_t3000
 import itertools
@@ -75,12 +76,12 @@ def run_reduce_scatter_test(

     # Run the op
     # for i in range(num_iters):
-    tt_out_tensors = ttl.tensor.reduce_scatter(
+    tt_out_tensors = ttnn.reduce_scatter(
         tt_input_tensors,
[Review thread on this hunk]

SeanNijjar (Contributor), Jul 15, 2024:
    I assume @cfjchu will have the same request here to use the new mesh tensor infrastructure and to pass in the new t3000 fixture instead of all_devices. I view it as somewhat orthogonal to moving the ops to TTNN, so I personally would not hold up the PR and would just make sure to do that after the move is done.

Collaborator:
    Yes, as long as it is not dropped. The primary way users should invoke the op is via multi-device tensors, so I'd expect this change to be done so it's usable.

cfjchu (Collaborator), Jul 15, 2024:
    I'm tracking this here: #10296
-        scatter_split_dim=scatter_dim,
-        reduce_op=math_op,
+        scatter_dim=scatter_dim,
+        math_op=math_op,
         num_links=num_links,
-        output_mem_config=mem_config,
+        memory_config=mem_config,
     )

     for d in devices:
2 changes: 2 additions & 0 deletions ttnn/CMakeLists.txt
@@ -40,6 +40,8 @@ set(TTNN_SRCS
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/ccl_common.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/examples/example/device/example_device_operation.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/examples/example/device/single_core_program_factory.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/examples/example/device/multi_core_program_factory.cpp
2 changes: 2 additions & 0 deletions ttnn/cpp/pybind11/operations/__init__.hpp
@@ -9,6 +9,7 @@

 #include "ttnn/operations/ccl/all_gather/all_gather_pybind.hpp"
 #include "ttnn/operations/ccl/line_all_gather/line_all_gather_pybind.hpp"
+#include "ttnn/operations/ccl/reduce_scatter/reduce_scatter_pybind.hpp"
 #include "pybind11/operations/copy.hpp"
 #include "pybind11/operations/core.hpp"
 #include "pybind11/operations/creation.hpp"
@@ -66,6 +67,7 @@ void py_module(py::module& module) {
     auto m_ccl = module.def_submodule("ccl", "collective communication operations");
     ccl::py_bind_all_gather(m_ccl);
     ccl::py_bind_line_all_gather(m_ccl);
+    ccl::py_bind_reduce_scatter(m_ccl);

     auto m_complex_unary = module.def_submodule("complex_unary", "complex_unary operations");
     complex_unary::py_module(m_complex_unary);
2 changes: 0 additions & 2 deletions ttnn/cpp/ttnn/experimental/tt_dnn/op_library/CMakeLists.txt
@@ -4,8 +4,6 @@ set(TT_DNN_SRCS
     ${CMAKE_CURRENT_SOURCE_DIR}/auto_format.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/data_transfer/data_transfer_op.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/layout_conversion/layout_conversion_op.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/ccl/reduce_scatter/reduce_scatter_op.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/sharded/sharded_op.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/sharded/multi_core/sharded_op_multi_core.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/sharded_partial/sharded_op_partial.cpp
@@ -19,7 +19,6 @@
 #include "ttnn/experimental/tt_dnn/op_library/non_zero_indices/non_zero_indices_op.hpp"
 #include "ttnn/experimental/tt_dnn/op_library/sharded/sharded_op.hpp"
 #include "ttnn/experimental/tt_dnn/op_library/sharded_partial/sharded_op_partial.hpp"
-#include "ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp"


 namespace tt::tt_metal::detail{
@@ -447,28 +446,6 @@ namespace tt::tt_metal::detail{
         R"doc(Converts a partial tensor from sharded_to_interleaved memory layout)doc"
     );

-    // ---------- Multi-Device ops ----------
-
-    // Reduce Scatter
-    m_tensor.def("reduce_scatter", &reduce_scatter,
-        py::arg("input_tensors"), py::arg("scatter_split_dim"), py::arg("reduce_op"), py::arg("num_links") = 1, py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
-        R"doc(
-        Performs reduce scatter across chips, where the input tensors are sliced along the scatter dim, and pairwise reduced as they propagate and reduce through the cluster.
-
-        For example, a reduce scatter on a ring of rank 8 and input tensor shapes (per rank) of [1,1,1024,8096] and scatter_dim=3, will split each input tensor
-        on width into 8 parts of size [1,1,1024,1024]. Each of those parts will reduce with the corresponding chunk from the other ranks. All chips will collectively
-        reduce the first incoming [1,1,1024,1024] chunk with their local first [1,1,1024,1024] chunk and be forwarded. The second incoming [1,1,1024,1024] chunk will
-        be reduced with the second local [1,1,1024,1024] chunk and be forwarded and so on. Each rank in the ring will start on a different offset into the chunk such
-        that by the end, they will finish with a different reduced chunk offset from the original tensor shape.
-
-        .. csv-table::
-            :header: "Argument", "Description", "Data type", "Valid range", "Required"
-
-            "scatter_split_dim", "Dimension to evenly slice input tensor along for each rank", "int", "0..3", "Yes"
-            "reduce_op", "reduction math operation", " ReduceOpMath", "SUM", "No"
-            "num_links", "Number of ethernet links to allow the op to use to send data chip to chip for the operation. Default=1", "int", "1..max_num_links", "No"
-            "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
-        )doc");
 }

 }
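The behaviour described by the removed docstring — each rank's tensor is split along the scatter dimension into one chunk per rank, and matching chunks are elementwise-reduced as they travel around the ring — can be sketched with a small host-side simulation. This is plain Python with no TTNN dependency, and it simplifies away the ring propagation order and per-rank starting offsets:

```python
def ring_reduce_scatter(inputs):
    """Simulate reduce-scatter semantics on a ring of N ranks.

    `inputs[r]` is rank r's tensor, pre-split into N chunks (lists of
    numbers). Rank r finishes holding the elementwise sum of chunk r
    from every rank, i.e. a different reduced slice per rank.
    """
    num_ranks = len(inputs)
    outputs = []
    for r in range(num_ranks):
        # Accumulate the r-th chunk contributed by every rank.
        acc = [0] * len(inputs[0][r])
        for rank_chunks in inputs:
            acc = [a + b for a, b in zip(acc, rank_chunks[r])]
        outputs.append(acc)
    return outputs
```

For a 2-rank ring where rank 0 holds chunks `[[1, 1], [2, 2]]` and rank 1 holds `[[3, 3], [4, 4]]`, rank 0 ends with `[4, 4]` (sum of the first chunks) and rank 1 with `[6, 6]` (sum of the second chunks) — the scaled-down analogue of the 8-rank `[1,1,1024,8096]` example in the docstring.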
@@ -124,8 +124,7 @@ class ChannelBuffer final {
         ASSERT(TERMINATION_MODE != ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED);
         goto_state(STATE::DONE);
     }
-    };
-
+    }
     // Resets the semaphore in local L1, which workers write to remotely.
     FORCE_INLINE void clear_local_semaphore() { noc_semaphore_set(local_semaphore_address, 0); }
Expand Down
@@ -45,7 +45,6 @@ struct LineAllGather {
     operation::ProgramWithCallbacks create_program(const std::vector<Tensor>& input_tensors, std::vector<Tensor> &output_tensors) const;
 };

-
 namespace operations {
 namespace ccl {