From cf04180a3902bbc8ef191eda4c4a2cc6d2fd6362 Mon Sep 17 00:00:00 2001
From: Joseph Chu
Date: Thu, 3 Oct 2024 23:04:29 +0000
Subject: [PATCH] #13454: Refactor files under tt_metal/distributed and ttnn/distributed

---
 .../Programming Mesh of Devices with TT-NN.md |  4 +-
 tests/scripts/t3000/run_t3000_unit_tests.sh   |  2 +-
 tests/scripts/tg/run_tg_frequent_tests.sh     |  4 +-
 tests/scripts/tgg/run_tgg_unit_tests.sh       |  2 +-
 .../test_ethernet_hop_latencies_no_edm.cpp    |  4 +-
 .../distributed/test_data_parallel_example.py | 60 +++++++++++++++++++
 .../test_data_parallel_example_TG.py          |  0
 .../test_mesh_device_TGG.py                   |  0
 .../test_multidevice_TG.py                    |  0
 .../test_tensor_parallel_example_T3000.py     |  0
 .../unit_tests/gtests/test_ccl_on_galaxy.cpp  | 10 ++--
 .../unit_tests/gtests/test_multi_device.cpp   |  4 +-
 .../unit_tests/gtests/ttnn_test_fixtures.hpp  |  6 +-
 tt_metal/CMakeLists.txt                       |  2 +
 .../mesh_configurations/N300.json             |  0
 .../mesh_configurations/T3000.json            |  0
 .../mesh_configurations/TG.json               |  0
 .../mesh_configurations/TGG.json              |  0
 .../mesh_configurations/device.json           |  0
 .../device => distributed}/mesh_device.cpp    |  7 ++-
 .../device => distributed}/mesh_device.hpp    |  2 +-
 .../mesh_device_view.cpp                      |  4 +-
 .../mesh_device_view.hpp                      |  0
 tt_metal/impl/CMakeLists.txt                  |  2 -
 ttnn/CMakeLists.txt                           |  3 +-
 ttnn/cpp/pybind11/__init__.cpp                |  6 +-
 .../distributed/distributed_pybind.cpp}       | 19 +++--
 .../ttnn/distributed/distributed_pybind.hpp   | 15 +++++
 .../mesh_device.cpp}                          |  8 +--
 .../mesh_device.hpp}                          | 12 +---
 ttnn/cpp/ttnn/events.hpp                      |  2 +-
 .../operations/ccl/all_gather/all_gather.cpp  |  2 +-
 .../indexed_fill/indexed_fill.cpp             |  2 +-
 .../all_gather_matmul/all_gather_matmul.cpp   |  1 -
 .../all_gather_matmul/all_gather_matmul.hpp   |  2 +-
 .../cpp/ttnn/operations/kv_cache/kv_cache.cpp |  2 +-
 ttnn/cpp/ttnn/tensor/tensor.hpp               |  2 +-
 ttnn/ttnn/__init__.py                         | 25 +-------
 ttnn/ttnn/distributed/__init__.py             | 27 +++++++
 .../distributed.py}                           |  0
 40 files changed, 156 insertions(+), 85 deletions(-)
 create mode 100644 tests/ttnn/distributed/test_data_parallel_example.py
 rename tests/ttnn/{multichip_unit_tests => distributed}/test_data_parallel_example_TG.py (100%)
 rename tests/ttnn/{multichip_unit_tests => distributed}/test_mesh_device_TGG.py (100%)
 rename tests/ttnn/{multichip_unit_tests => distributed}/test_multidevice_TG.py (100%)
 rename tests/ttnn/{multichip_unit_tests => distributed}/test_tensor_parallel_example_T3000.py (100%)
 rename tt_metal/{impl/device => distributed}/mesh_configurations/N300.json (100%)
 rename tt_metal/{impl/device => distributed}/mesh_configurations/T3000.json (100%)
 rename tt_metal/{impl/device => distributed}/mesh_configurations/TG.json (100%)
 rename tt_metal/{impl/device => distributed}/mesh_configurations/TGG.json (100%)
 rename tt_metal/{impl/device => distributed}/mesh_configurations/device.json (100%)
 rename tt_metal/{impl/device => distributed}/mesh_device.cpp (98%)
 rename tt_metal/{impl/device => distributed}/mesh_device.hpp (99%)
 rename tt_metal/{impl/device => distributed}/mesh_device_view.cpp (99%)
 rename tt_metal/{impl/device => distributed}/mesh_device_view.hpp (100%)
 rename ttnn/cpp/{pybind11/multi_device.hpp => ttnn/distributed/distributed_pybind.cpp} (95%)
 create mode 100644 ttnn/cpp/ttnn/distributed/distributed_pybind.hpp
 rename ttnn/cpp/ttnn/{multi_device.cpp => distributed/mesh_device.cpp} (96%)
 rename ttnn/cpp/ttnn/{multi_device.hpp => distributed/mesh_device.hpp} (80%)
 create mode 100644 ttnn/ttnn/distributed/__init__.py
 rename ttnn/ttnn/{multi_device.py => distributed/distributed.py} (100%)
diff --git a/tech_reports/Programming Mesh of Devices/Programming Mesh of Devices with TT-NN.md b/tech_reports/Programming Mesh of Devices/Programming Mesh of Devices with TT-NN.md
index e76a59392529..289916bfec1e 100644
--- a/tech_reports/Programming Mesh of Devices/Programming Mesh of Devices with TT-NN.md
+++ b/tech_reports/Programming Mesh of Devices/Programming Mesh of Devices with TT-NN.md
@@ -418,7 +418,7 @@ ttnn.close_device(device)
 
 4. **Executing TT-NN Falcon-7B MLP Module on MeshDevice with Data Parallel**
 
-Full code example can be found in `tests/ttnn/multichip_unit_tests/test_data_parallel_example_TG.py`
+Full code example can be found in `tests/ttnn/distributed/test_data_parallel_example_TG.py`
 
 ```py
 # Load Falcon MLP model from huggingface
@@ -507,7 +507,7 @@ model = transformers.models.falcon.modeling_falcon.FalconMLP(config).eval()
 
 3. **Executing TT-NN Falcon-7B MLP Module on MeshDevice with Tensor Parallel**
 
-See full code example in `tests/ttnn/multichip_unit_tests/test_tensor_parallel_example_T3000.py`
+See full code example in `tests/ttnn/distributed/test_tensor_parallel_example_T3000.py`
 
 ```py
 # Initialize hidden states
diff --git a/tests/scripts/t3000/run_t3000_unit_tests.sh b/tests/scripts/t3000/run_t3000_unit_tests.sh
index ced18f54ec36..95483dbcf96c 100755
--- a/tests/scripts/t3000/run_t3000_unit_tests.sh
+++ b/tests/scripts/t3000/run_t3000_unit_tests.sh
@@ -38,7 +38,7 @@ run_t3000_ttnn_tests() {
   WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest tests/ttnn/unit_tests/test_multi_device_events.py ; fail+=$?
   pytest -n auto tests/ttnn/unit_tests/test_multi_device.py ; fail+=$?
   pytest -n auto tests/ttnn/unit_tests/test_multi_device_async.py ; fail+=$?
-  pytest tests/ttnn/multichip_unit_tests/test_tensor_parallel_example_T3000.py ; fail+=$?
+  pytest tests/ttnn/distributed/test_tensor_parallel_example_T3000.py ; fail+=$?
   # Record the end time
   end_time=$(date +%s)
   duration=$((end_time - start_time))
diff --git a/tests/scripts/tg/run_tg_frequent_tests.sh b/tests/scripts/tg/run_tg_frequent_tests.sh
index 6a7a35af6db7..77525f706317 100755
--- a/tests/scripts/tg/run_tg_frequent_tests.sh
+++ b/tests/scripts/tg/run_tg_frequent_tests.sh
@@ -4,8 +4,8 @@ run_tg_tests() {
 
   # Add tests here
   echo "LOG_METAL: running run_tg_frequent_tests"
-  pytest -n auto tests/ttnn/multichip_unit_tests/test_data_parallel_example_TG.py --timeout=900 ; fail+=$?
-  pytest -n auto tests/ttnn/multichip_unit_tests/test_multidevice_TG.py --timeout=900 ; fail+=$?
+  pytest -n auto tests/ttnn/distributed/test_data_parallel_example_TG.py --timeout=900 ; fail+=$?
+  pytest -n auto tests/ttnn/distributed/test_multidevice_TG.py --timeout=900 ; fail+=$?
   pytest -n auto tests/ttnn/unit_tests/test_multi_device_trace_TG.py --timeout=900 ; fail+=$?
   pytest -n auto models/demos/tg/llama3_70b/tests/test_llama_mlp_galaxy.py --timeout=300 ; fail+=$?
   pytest -n auto models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py --timeout=480 ; fail+=$?
diff --git a/tests/scripts/tgg/run_tgg_unit_tests.sh b/tests/scripts/tgg/run_tgg_unit_tests.sh
index 3ccbb162a1ff..08f8f08c4211 100755
--- a/tests/scripts/tgg/run_tgg_unit_tests.sh
+++ b/tests/scripts/tgg/run_tgg_unit_tests.sh
@@ -9,7 +9,7 @@ run_tgg_tests() {
   ./build/test/ttnn/galaxy_unit_tests_ttnn
   TT_METAL_SLOW_DISPATCH_MODE=1 ./build/test/tt_metal/unit_tests_galaxy --gtest_filter="GalaxyFixture.*:TGGFixture.*"
   ./build/test/tt_metal/unit_tests_galaxy --gtest_filter="GalaxyFixture.*:TGGFixture.*"
-  pytest -s tests/ttnn/multichip_unit_tests/test_mesh_device_TGG.py
+  pytest -s tests/ttnn/distributed/test_mesh_device_TGG.py
 }
 
 main() {
diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp
index aad3f3ca8947..1ee2f290314a 100644
--- a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp
+++ b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp
@@ -8,7 +8,7 @@
 #include
 #include
-#include "impl/device/mesh_device_view.hpp"
+#include "tt_metal/distributed/mesh_device_view.hpp"
 #include "tt_metal/common/logger.hpp"
 #include "device/tt_arch_types.h"
 #include "impl/device/device.hpp"
@@ -27,7 +27,7 @@
 #include "tt_metal/test_utils/stimulus.hpp"
 #include "tt_metal/detail/persistent_kernel_cache.hpp"
-#include "tt_metal/impl/device/mesh_device.hpp"
+#include "tt_metal/distributed/mesh_device.hpp"
 
 using tt::tt_metal::Device;
diff --git a/tests/ttnn/distributed/test_data_parallel_example.py b/tests/ttnn/distributed/test_data_parallel_example.py
new file mode 100644
index 000000000000..fb5f59568c0e
--- /dev/null
+++ b/tests/ttnn/distributed/test_data_parallel_example.py
@@ -0,0 +1,60 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import ttnn
+import torch
+import transformers
+import pytest
+
+from tests.ttnn.utils_for_testing import assert_with_pcc
+from ttnn.model_preprocessing import preprocess_model_parameters
+
+
+class TtFalconMLP:
+    def __init__(self, parameters):
+        super().__init__()
+        self.dense_h_to_4h_weights = parameters.dense_h_to_4h.weight
+        self.dense_4h_to_h_weights = parameters.dense_4h_to_h.weight
+
+    def __call__(self, x: ttnn.Tensor) -> ttnn.Tensor:
+        ff1_linear: ttnn.Tensor = ttnn.linear(x, self.dense_h_to_4h_weights)
+        gelu = ttnn.gelu(ff1_linear)
+        ff2_linear: ttnn.Tensor = ttnn.linear(gelu, self.dense_4h_to_h_weights)
+
+        return ff2_linear
+
+
+@pytest.mark.parametrize("mesh_device", [pytest.param((1, 4), id="1x4_grid")], indirect=True)
+def test_data_parallel_falcon_mlp(mesh_device):
+    # Load Falcon MLP model from huggingface
+    config = transformers.FalconConfig.from_pretrained("tiiuae/falcon-7b-instruct")
+    model = transformers.models.falcon.modeling_falcon.FalconMLP(config).eval()
+
+    # Initialize hidden states
+    batch_size, sequence_length = 4, 128
+    torch_hidden_states = (torch.rand(batch_size, 1, sequence_length, config.hidden_size, dtype=torch.float32) * 2) - 1
+    torch_output = model.forward(torch_hidden_states)
+
+    # Shard input activations on batch dimension to devices in the mesh
+    with ttnn.distribute(mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=0)):
+        hidden_states = ttnn.from_torch(
+            torch_hidden_states,
+            dtype=ttnn.bfloat16,
+            layout=ttnn.TILE_LAYOUT,
+            device=mesh_device,
+        )
+
+    # Replicate model parameters to devices in the mesh
+    with ttnn.distribute(mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device)):
+        parameters = preprocess_model_parameters(
+            initialize_model=lambda: model,
+            device=mesh_device,
+        )
+
+    # Initialize Model
+    ttnn_model = TtFalconMLP(parameters)
+    ttnn_output = ttnn_model(hidden_states)
+
+    with ttnn.distribute(mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0)):
+        assert_with_pcc(torch_output, ttnn.to_torch(ttnn_output), 0.98)
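The new test above is the complete data-parallel flow: shard the activations across the mesh, replicate the weights, run the module, and concatenate the per-device outputs. As a compact reference, the same shard → compute → gather pattern outside of pytest might look like the sketch below; the `run_data_parallel_gelu` helper, its shapes, and the already-open `mesh_device` handle are assumptions for illustration and are not part of this patch.

```py
import torch
import ttnn


def run_data_parallel_gelu(mesh_device):
    # Hypothetical helper mirroring the test above.
    torch_input = torch.rand(8, 1, 32, 32, dtype=torch.float32)

    # Shard the batch dimension (dim=0) across the devices in the mesh.
    with ttnn.distribute(mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=0)):
        x = ttnn.from_torch(
            torch_input,
            dtype=ttnn.bfloat16,
            layout=ttnn.TILE_LAYOUT,
            device=mesh_device,
        )

    # Each device runs the op on its local shard.
    y = ttnn.gelu(x)

    # Concatenate the per-device shards back into one torch tensor on the host.
    with ttnn.distribute(mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0)):
        return ttnn.to_torch(y)
```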
"ttnn/operations/ccl/all_gather/device/all_gather_op.hpp" #include "ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp" #include "ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_types.hpp" -#include "ttnn/cpp/ttnn/multi_device.hpp" +#include "ttnn/cpp/ttnn/distributed/mesh_device.hpp" #include "ttnn/async_runtime.hpp" #include "ttnn_multi_command_queue_fixture.hpp" @@ -110,7 +110,7 @@ TEST(GalaxyTests, TestAllGatherDeadlock) { validate_num_tunnels_and_tunnel_depth(); MeshShape mesh_shape = get_mesh_shape(); - std::shared_ptr mesh = ttnn::multi_device::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER); + std::shared_ptr mesh = ttnn::distributed::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER); // Setup input data and output data containers MemoryConfig mem_cfg = MemoryConfig{ @@ -177,7 +177,7 @@ TEST(GalaxyTests, TestAllGatherDeadlock) { } } } - ttnn::multi_device::close_mesh_device(mesh); + ttnn::distributed::close_mesh_device(mesh); } TEST(GalaxyTests, TestReduceScatterDeadlock) { @@ -187,7 +187,7 @@ TEST(GalaxyTests, TestReduceScatterDeadlock) { validate_num_tunnels_and_tunnel_depth(); MeshShape mesh_shape = get_mesh_shape(); - std::shared_ptr mesh = ttnn::multi_device::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER); + std::shared_ptr mesh = ttnn::distributed::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER); // Create the outer ring on which Reduce Scatter will be run. This allows us to verify that there are no deadlocks when we send CCLs to the // first tunnel (forward path). auto view = MeshDeviceView(*mesh); @@ -273,5 +273,5 @@ TEST(GalaxyTests, TestReduceScatterDeadlock) { } } } - ttnn::multi_device::close_mesh_device(mesh); + ttnn::distributed::close_mesh_device(mesh); } diff --git a/tests/ttnn/unit_tests/gtests/test_multi_device.cpp b/tests/ttnn/unit_tests/gtests/test_multi_device.cpp index 1b51605d4bdd..4ecd79b33119 100644 --- a/tests/ttnn/unit_tests/gtests/test_multi_device.cpp +++ b/tests/ttnn/unit_tests/gtests/test_multi_device.cpp @@ -7,7 +7,7 @@ #include "ttnn/cpp/ttnn/tensor/types.hpp" #include "ttnn/cpp/ttnn/operations/creation.hpp" -namespace ttnn::multi_device::test { +namespace ttnn::distributed::test { using namespace tt::tt_metal; @@ -45,4 +45,4 @@ TEST_F(T3kMultiDeviceFixture, TestGetDistributedTensorConfigFromMultiDeviceStora EXPECT_TRUE(std::holds_alternative(distributed_tensor_config)); } -} // namespace ttnn::multi_device::test +} // namespace ttnn::distributed::test diff --git a/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp b/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp index c8506b79cd95..10e9bc23b974 100644 --- a/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp +++ b/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp @@ -15,7 +15,7 @@ #include "tests/tt_metal/test_utils/env_vars.hpp" #include "tt_metal/host_api.hpp" #include "tt_metal/hostdevcommon/common_values.hpp" -#include "tt_metal/impl/device/mesh_device.hpp" +#include "tt_metal/distributed/mesh_device.hpp" namespace ttnn { @@ -53,7 +53,7 @@ class TTNNFixtureWithDevice : public TTNNFixture { } // namespace ttnn -namespace ttnn::multi_device::test { +namespace ttnn::distributed::test { class T3kMultiDeviceFixture : public ::testing::Test { protected: @@ -78,4 +78,4 @@ class T3kMultiDeviceFixture : public ::testing::Test { std::shared_ptr mesh_device_; }; -} // namespace ttnn::multi_device::test +} // namespace ttnn::distributed::test diff --git a/tt_metal/CMakeLists.txt b/tt_metal/CMakeLists.txt index 
index 881075aa279c..431adbb057e7 100644
--- a/tt_metal/CMakeLists.txt
+++ b/tt_metal/CMakeLists.txt
@@ -11,6 +11,8 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/detail)
 set(TT_METAL_OBJECTS
     ${CMAKE_CURRENT_SOURCE_DIR}/tt_metal.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/graph/graph_tracking.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/distributed/mesh_device.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/distributed/mesh_device_view.cpp
     $
     $
    $
diff --git a/tt_metal/impl/device/mesh_configurations/N300.json b/tt_metal/distributed/mesh_configurations/N300.json
similarity index 100%
rename from tt_metal/impl/device/mesh_configurations/N300.json
rename to tt_metal/distributed/mesh_configurations/N300.json
diff --git a/tt_metal/impl/device/mesh_configurations/T3000.json b/tt_metal/distributed/mesh_configurations/T3000.json
similarity index 100%
rename from tt_metal/impl/device/mesh_configurations/T3000.json
rename to tt_metal/distributed/mesh_configurations/T3000.json
diff --git a/tt_metal/impl/device/mesh_configurations/TG.json b/tt_metal/distributed/mesh_configurations/TG.json
similarity index 100%
rename from tt_metal/impl/device/mesh_configurations/TG.json
rename to tt_metal/distributed/mesh_configurations/TG.json
diff --git a/tt_metal/impl/device/mesh_configurations/TGG.json b/tt_metal/distributed/mesh_configurations/TGG.json
similarity index 100%
rename from tt_metal/impl/device/mesh_configurations/TGG.json
rename to tt_metal/distributed/mesh_configurations/TGG.json
diff --git a/tt_metal/impl/device/mesh_configurations/device.json b/tt_metal/distributed/mesh_configurations/device.json
similarity index 100%
rename from tt_metal/impl/device/mesh_configurations/device.json
rename to tt_metal/distributed/mesh_configurations/device.json
diff --git a/tt_metal/impl/device/mesh_device.cpp b/tt_metal/distributed/mesh_device.cpp
similarity index 98%
rename from tt_metal/impl/device/mesh_device.cpp
rename to tt_metal/distributed/mesh_device.cpp
index dfe8926c5177..b97fb7b6ccaa 100644
--- a/tt_metal/impl/device/mesh_device.cpp
+++ b/tt_metal/distributed/mesh_device.cpp
@@ -2,7 +2,7 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 
-#include "tt_metal/impl/device/mesh_device.hpp"
+#include "tt_metal/distributed/mesh_device.hpp"
 
 #include
 #include
@@ -11,7 +11,8 @@
 #include "tt_metal/common/logger.hpp"
 #include "tt_metal/detail/tt_metal.hpp"
 #include "tt_metal/host_api.hpp"
-#include "tt_metal/impl/device/mesh_device_view.hpp"
+#include "tt_metal/distributed/mesh_device_view.hpp"
+#include "tt_metal/distributed/mesh_device.hpp"
 
 namespace tt::tt_metal {
 
@@ -20,7 +21,7 @@ using PhysicalCoordinate = eth_coord_t;
 
 static std::string get_config_path(const std::string& filename) {
getenv("TT_METAL_HOME") : "./"; - return root_path + "/tt_metal/impl/device/mesh_configurations/" + filename; + return root_path + "/tt_metal/distributed/mesh_configurations/" + filename; } static std::unordered_map load_translation_map(const std::string& filename, const std::string& key) { diff --git a/tt_metal/impl/device/mesh_device.hpp b/tt_metal/distributed/mesh_device.hpp similarity index 99% rename from tt_metal/impl/device/mesh_device.hpp rename to tt_metal/distributed/mesh_device.hpp index f65e095f6d86..6589b3d0fce3 100644 --- a/tt_metal/impl/device/mesh_device.hpp +++ b/tt_metal/distributed/mesh_device.hpp @@ -10,7 +10,7 @@ #include #include "tt_metal/impl/device/device.hpp" -#include "tt_metal/impl/device/mesh_device_view.hpp" +#include "tt_metal/distributed/mesh_device_view.hpp" namespace tt::tt_metal { diff --git a/tt_metal/impl/device/mesh_device_view.cpp b/tt_metal/distributed/mesh_device_view.cpp similarity index 99% rename from tt_metal/impl/device/mesh_device_view.cpp rename to tt_metal/distributed/mesh_device_view.cpp index 48d8e151549c..d5f7e80855af 100644 --- a/tt_metal/impl/device/mesh_device_view.cpp +++ b/tt_metal/distributed/mesh_device_view.cpp @@ -2,12 +2,12 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "tt_metal/impl/device/mesh_device_view.hpp" +#include "tt_metal/distributed/mesh_device_view.hpp" #include #include -#include "tt_metal/impl/device/mesh_device.hpp" +#include "tt_metal/distributed/mesh_device.hpp" namespace tt::tt_metal { diff --git a/tt_metal/impl/device/mesh_device_view.hpp b/tt_metal/distributed/mesh_device_view.hpp similarity index 100% rename from tt_metal/impl/device/mesh_device_view.hpp rename to tt_metal/distributed/mesh_device_view.hpp diff --git a/tt_metal/impl/CMakeLists.txt b/tt_metal/impl/CMakeLists.txt index ab986d1229ed..3c2df98a6b11 100644 --- a/tt_metal/impl/CMakeLists.txt +++ b/tt_metal/impl/CMakeLists.txt @@ -2,8 +2,6 @@ set(IMPL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/device/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device/device_pool.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/device/mesh_device.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/device/mesh_device_view.cpp ${CMAKE_CURRENT_SOURCE_DIR}/buffers/buffer.cpp ${CMAKE_CURRENT_SOURCE_DIR}/buffers/circular_buffer.cpp ${CMAKE_CURRENT_SOURCE_DIR}/buffers/semaphore.cpp diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 3b1420b8fc0f..d22f39a3998e 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -3,9 +3,10 @@ set(ALL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/config.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/core.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/device.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/multi_device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/events.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/run_operation.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/distributed/mesh_device.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/distributed/distributed_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/graph/graph_processor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/graph/graph_trace_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/graph/graph_pybind.cpp diff --git a/ttnn/cpp/pybind11/__init__.cpp b/ttnn/cpp/pybind11/__init__.cpp index 4f263b3760fb..b83e380b1215 100644 --- a/ttnn/cpp/pybind11/__init__.cpp +++ b/ttnn/cpp/pybind11/__init__.cpp @@ -12,9 +12,9 @@ #include "device.hpp" #include "profiler.hpp" #include "events.hpp" -#include "multi_device.hpp" #include "tensor.hpp" #include "reports.hpp" +#include "ttnn/distributed/distributed_pybind.hpp" #include 
"ttnn/deprecated/tt_lib/csrc/operations/primary/module.hpp" #include "ttnn/graph/graph_pybind.hpp" #include "types.hpp" @@ -58,7 +58,7 @@ PYBIND11_MODULE(_ttnn, module) { ttnn::activation::py_module_types(m_activation); ttnn::core::py_module_types(m_core); ttnn::device::py_device_module_types(m_device); - ttnn::multi_device::py_module_types(m_multi_device); + ttnn::distributed::py_module_types(m_multi_device); ttnn::events::py_module_types(m_events); ttnn::reports::py_module_types(m_reports); @@ -80,7 +80,7 @@ PYBIND11_MODULE(_ttnn, module) { ttnn::types::py_module(m_types); ttnn::activation::py_module(m_activation); ttnn::device::py_device_module(m_device); - ttnn::multi_device::py_module(m_multi_device); + ttnn::distributed::py_module(m_multi_device); ttnn::events::py_module(m_events); ttnn::profiler::py_module(m_profiler); ttnn::reports::py_module(m_reports); diff --git a/ttnn/cpp/pybind11/multi_device.hpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp similarity index 95% rename from ttnn/cpp/pybind11/multi_device.hpp rename to ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 7dc90e202ec5..83e71b0a01b0 100644 --- a/ttnn/cpp/pybind11/multi_device.hpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -2,19 +2,18 @@ // // SPDX-License-Identifier: Apache-2.0 -#pragma once +#include "ttnn/distributed/distributed_pybind.hpp" -#include -#include - -#include "ttnn/multi_device.hpp" +#include "ttnn/distributed/mesh_device.hpp" #include "ttnn/tensor/tensor_utils.hpp" +#include "ttnn/tensor/tensor.hpp" +#include "ttnn/types.hpp" +#include "tt_metal/impl/dispatch/command_queue.hpp" -namespace py = pybind11; -namespace ttnn { +namespace ttnn::distributed { -namespace multi_device { +namespace py = pybind11; void py_module_types(py::module& module) { py::class_>(module, "MeshDevice"); @@ -172,6 +171,4 @@ void py_module(py::module& module) { module.def("get_t3k_physical_device_ids_ring", &tt::tt_metal::get_t3k_physical_device_ids_ring); } -} // namespace multi_device - -} // namespace ttnn +} // namespace ttnn::distributed diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.hpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.hpp new file mode 100644 index 000000000000..e197599e1656 --- /dev/null +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.hpp @@ -0,0 +1,15 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+#include "pybind11/pybind_fwd.hpp"
+
+namespace py = pybind11;
+
+namespace ttnn::distributed {
+
+void py_module_types(py::module& module);
+void py_module(py::module& module);
+
+} // namespace ttnn::distributed
diff --git a/ttnn/cpp/ttnn/multi_device.cpp b/ttnn/cpp/ttnn/distributed/mesh_device.cpp
similarity index 96%
rename from ttnn/cpp/ttnn/multi_device.cpp
rename to ttnn/cpp/ttnn/distributed/mesh_device.cpp
index cb1e48fc7bf2..3d351ade00b6 100644
--- a/ttnn/cpp/ttnn/multi_device.cpp
+++ b/ttnn/cpp/ttnn/distributed/mesh_device.cpp
@@ -2,15 +2,15 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 
-#include "multi_device.hpp"
+#include "ttnn/distributed/mesh_device.hpp"
 
 #include
 
 #include "ttnn/tensor/tensor.hpp"
 #include "ttnn/tensor/tensor_utils.hpp"
-#include "tt_metal/impl/device/mesh_device.hpp"
+#include "tt_metal/distributed/mesh_device.hpp"
 
-namespace ttnn::multi_device {
+namespace ttnn::distributed {
 
 std::shared_ptr open_mesh_device(const MeshShape& mesh_shape, size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, DispatchCoreType dispatch_core_type, MeshType mesh_type, const std::pair& offset, const std::vector& physical_device_ids) {
     auto config = MeshDeviceConfig(mesh_shape, offset, physical_device_ids, mesh_type);
@@ -82,4 +82,4 @@ Tensor aggregate_as_tensor(std::vector& tensor_shards)
     }
 }
 
-} // namespace ttnn::multi_device
+} // namespace ttnn::distributed
diff --git a/ttnn/cpp/ttnn/multi_device.hpp b/ttnn/cpp/ttnn/distributed/mesh_device.hpp
similarity index 80%
rename from ttnn/cpp/ttnn/multi_device.hpp
rename to ttnn/cpp/ttnn/distributed/mesh_device.hpp
index ecd95d659b73..25f70992d251 100644
--- a/ttnn/cpp/ttnn/multi_device.hpp
+++ b/ttnn/cpp/ttnn/distributed/mesh_device.hpp
@@ -6,14 +6,10 @@
 
 #include
 
-#include "tt_metal/impl/device/mesh_device.hpp"
 #include "ttnn/tensor/tensor.hpp"
 #include "ttnn/types.hpp"
 
-using Device = ttnn::Device;
-
-namespace ttnn {
-namespace multi_device {
+namespace ttnn::distributed {
 
 std::shared_ptr open_mesh_device(
     const MeshShape& mesh_shape,
@@ -33,8 +29,4 @@ Tensor aggregate_as_tensor(std::vector& tensor_shards);
 
 std::vector get_t3k_physical_device_ids_ring();
 
-} // namespace multi_device
-
-using namespace multi_device;
-
-} // namespace ttnn
+} // namespace ttnn::distributed
diff --git a/ttnn/cpp/ttnn/events.hpp b/ttnn/cpp/ttnn/events.hpp
index 0f1d4fa64b1c..9fd11ccea39f 100644
--- a/ttnn/cpp/ttnn/events.hpp
+++ b/ttnn/cpp/ttnn/events.hpp
@@ -6,7 +6,7 @@
 
 #include
 
-#include "tt_metal/impl/device/mesh_device.hpp"
+#include "tt_metal/distributed/mesh_device.hpp"
 
 namespace ttnn::events {
 
diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.cpp
index 4e03e02f59dd..37db34c6d8cd 100644
--- a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.cpp
+++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.cpp
@@ -4,7 +4,7 @@
 
 #include "ttnn/operations/ccl/all_gather/all_gather.hpp"
 #include "ttnn/operations/ccl/all_gather/device/all_gather_op.hpp"
-#include "ttnn/multi_device.hpp"
+#include "ttnn/distributed/mesh_device.hpp"
 
 namespace ttnn::operations::ccl {
 
diff --git a/ttnn/cpp/ttnn/operations/data_movement/indexed_fill/indexed_fill.cpp b/ttnn/cpp/ttnn/operations/data_movement/indexed_fill/indexed_fill.cpp
index 7046e144cf1f..a29628707b1b 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/indexed_fill/indexed_fill.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/indexed_fill/indexed_fill.cpp
@@ -4,7 +4,7 @@
 #include "ttnn/operations/data_movement/indexed_fill/indexed_fill.hpp"
 #include "ttnn/operations/data_movement/indexed_fill/device/indexed_fill_op.hpp"
-#include "ttnn/multi_device.hpp"
+#include "ttnn/distributed/mesh_device.hpp"
 #include "ttnn/common/constants.hpp"
 
 namespace ttnn::operations::data_movement{
diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul.cpp
index 7b434ee6e3a2..ab9343ca4baf 100644
--- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul.cpp
+++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul.cpp
@@ -4,7 +4,6 @@
 
 #include "ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.hpp"
 #include "ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul.hpp"
-// #include "ttnn/cpp/ttnn/multi_device.hpp"
 
 namespace ttnn {
 namespace operations::experimental::ccl {
diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul.hpp
index 5a897a647f86..572274cda8b6 100644
--- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul.hpp
+++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul.hpp
@@ -7,7 +7,7 @@
 #include "ttnn/decorators.hpp"
 #include "common/core_coord.h"
 #include "ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.hpp"
-#include "ttnn/cpp/ttnn/multi_device.hpp"
+#include "ttnn/cpp/ttnn/distributed/mesh_device.hpp"
 
 namespace ttnn {
 namespace operations::experimental::ccl {
diff --git a/ttnn/cpp/ttnn/operations/kv_cache/kv_cache.cpp b/ttnn/cpp/ttnn/operations/kv_cache/kv_cache.cpp
index db30d409e7f7..6e536008440b 100644
--- a/ttnn/cpp/ttnn/operations/kv_cache/kv_cache.cpp
+++ b/ttnn/cpp/ttnn/operations/kv_cache/kv_cache.cpp
@@ -3,7 +3,7 @@
 // SPDX-License-Identifier: Apache-2.0
 
 #include "kv_cache.hpp"
-#include "ttnn/multi_device.hpp"
+#include "ttnn/distributed/mesh_device.hpp"
 #include "ttnn/run_operation.hpp"
 #include "ttnn/operations/kv_cache/device/update_cache_op.hpp"
 #include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp"
diff --git a/ttnn/cpp/ttnn/tensor/tensor.hpp b/ttnn/cpp/ttnn/tensor/tensor.hpp
index bf00a516e57e..02baefc971a6 100644
--- a/ttnn/cpp/ttnn/tensor/tensor.hpp
+++ b/ttnn/cpp/ttnn/tensor/tensor.hpp
@@ -20,7 +20,7 @@
 #include "tt_metal/impl/buffers/buffer.hpp"
 #include "tt_metal/impl/tile/tile.hpp"
 #include "tt_metal/impl/device/device.hpp"
-#include "tt_metal/impl/device/mesh_device.hpp"
+#include "tt_metal/distributed/mesh_device.hpp"
 #include "tt_metal/tt_stl/reflection.hpp"
 
 namespace tt {
diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py
index 8dcad39de63b..a0e3bf481fac 100644
--- a/ttnn/ttnn/__init__.py
+++ b/ttnn/ttnn/__init__.py
@@ -176,29 +176,8 @@ def manage_config(name, value):
 
 from ttnn.profiler import start_tracy_zone, stop_tracy_zone, tracy_message, tracy_frame
 
-from ttnn.multi_device import (
-    MeshDevice,
-    DispatchCoreType,
-    open_mesh_device,
-    close_mesh_device,
-    get_num_pcie_devices,
-    get_num_devices,
-    get_pcie_device_ids,
-    get_device_ids,
-    create_mesh_device,
-    synchronize_devices,
-    TensorToMesh,
-    ShardTensorToMesh,
-    ShardTensor2dMesh,
-    ReplicateTensorToMesh,
-    MeshToTensor,
-    ConcatMeshToTensor,
-    ListMeshToTensor,
-    visualize_mesh_device,
-    ConcatMesh2dToTensor,
-    distribute,
-    MeshType,
-)
+# TODO: remove this after the distributed module is fully integrated
+from ttnn.distributed import *
 
 from ttnn.core import (
     set_printoptions,
diff --git a/ttnn/ttnn/distributed/__init__.py b/ttnn/ttnn/distributed/__init__.py
new file mode 100644
index 000000000000..635a60b04fa7
--- /dev/null
+++ b/ttnn/ttnn/distributed/__init__.py
@@ -0,0 +1,27 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+from .distributed import (
+    MeshDevice,
+    DispatchCoreType,
+    open_mesh_device,
+    close_mesh_device,
+    get_num_pcie_devices,
+    get_num_devices,
+    get_pcie_device_ids,
+    get_device_ids,
+    create_mesh_device,
+    synchronize_devices,
+    TensorToMesh,
+    ShardTensorToMesh,
+    ShardTensor2dMesh,
+    ReplicateTensorToMesh,
+    MeshToTensor,
+    ConcatMeshToTensor,
+    ListMeshToTensor,
+    visualize_mesh_device,
+    ConcatMesh2dToTensor,
+    distribute,
+    MeshType,
+)
diff --git a/ttnn/ttnn/multi_device.py b/ttnn/ttnn/distributed/distributed.py
similarity index 100%
rename from ttnn/ttnn/multi_device.py
rename to ttnn/ttnn/distributed/distributed.py
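For Python callers the refactor is intended to be transparent: `ttnn/__init__.py` now re-exports everything from `ttnn.distributed`, so the mesh APIs stay reachable under the top-level `ttnn` namespace as well as under the new package. A quick sanity check, shown as a sketch that is not part of the patch and assumes a built `ttnn` wheel is importable:

```py
import ttnn
import ttnn.distributed as ttnn_distributed

# The top-level names and the new package resolve to the same objects,
# because ttnn/__init__.py does `from ttnn.distributed import *`.
assert ttnn.open_mesh_device is ttnn_distributed.open_mesh_device
assert ttnn.ShardTensorToMesh is ttnn_distributed.ShardTensorToMesh
assert ttnn.ConcatMeshToTensor is ttnn_distributed.ConcatMeshToTensor
```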