#13454: Refactor files under tt_metal/distributed and ttnn/distributed
cfjchu committed Oct 4, 2024
1 parent 85a86a6 commit cf04180
Showing 40 changed files with 156 additions and 85 deletions.
@@ -418,7 +418,7 @@ ttnn.close_device(device)

4. **Executing TT-NN Falcon-7B MLP Module on MeshDevice with Data Parallel**

-Full code example can be found in `tests/ttnn/multichip_unit_tests/test_data_parallel_example_TG.py`
+Full code example can be found in `tests/ttnn/distributed/test_data_parallel_example_TG.py`

```py
# Load Falcon MLP model from huggingface
@@ -507,7 +507,7 @@ model = transformers.models.falcon.modeling_falcon.FalconMLP(config).eval()

3. **Executing TT-NN Falcon-7B MLP Module on MeshDevice with Tensor Parallel**

-See full code example in `tests/ttnn/multichip_unit_tests/test_tensor_parallel_example_T3000.py`
+See full code example in `tests/ttnn/distributed/test_tensor_parallel_example_T3000.py`

```py
# Initialize hidden states
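Editor's note: the tensor-parallel example referenced above follows the same `ttnn.distribute` pattern as the data-parallel test added later in this commit, with the weights sharded across the mesh instead of the batch. The sketch below is illustrative only: the shard dimension, tensor shapes, and the in-scope `mesh_device`/`TtFalconMLP` names are assumptions, not code copied from `test_tensor_parallel_example_T3000.py`.

```py
import torch
import transformers
import ttnn
from ttnn.model_preprocessing import preprocess_model_parameters

# mesh_device: an already-open ttnn.MeshDevice (assumed in scope)
config = transformers.FalconConfig.from_pretrained("tiiuae/falcon-7b-instruct")
model = transformers.models.falcon.modeling_falcon.FalconMLP(config).eval()
torch_hidden_states = (torch.rand(1, 1, 128, config.hidden_size) * 2) - 1

# Replicate the activations to every device in the mesh
with ttnn.distribute(mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device)):
    hidden_states = ttnn.from_torch(
        torch_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device
    )

# Shard the MLP weights across the mesh on their last (width) dimension
with ttnn.distribute(mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1)):
    parameters = preprocess_model_parameters(initialize_model=lambda: model, device=mesh_device)

# Run the sharded model, then stitch the per-device outputs back together
with ttnn.distribute(mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1)):
    ttnn_output = ttnn.to_torch(TtFalconMLP(parameters)(hidden_states))
```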
2 changes: 1 addition & 1 deletion tests/scripts/t3000/run_t3000_unit_tests.sh
@@ -38,7 +38,7 @@ run_t3000_ttnn_tests() {
WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest tests/ttnn/unit_tests/test_multi_device_events.py ; fail+=$?
pytest -n auto tests/ttnn/unit_tests/test_multi_device.py ; fail+=$?
pytest -n auto tests/ttnn/unit_tests/test_multi_device_async.py ; fail+=$?
-pytest tests/ttnn/multichip_unit_tests/test_tensor_parallel_example_T3000.py ; fail+=$?
+pytest tests/ttnn/distributed/test_tensor_parallel_example_T3000.py ; fail+=$?
# Record the end time
end_time=$(date +%s)
duration=$((end_time - start_time))
4 changes: 2 additions & 2 deletions tests/scripts/tg/run_tg_frequent_tests.sh
@@ -4,8 +4,8 @@ run_tg_tests() {
# Add tests here
echo "LOG_METAL: running run_tg_frequent_tests"

-pytest -n auto tests/ttnn/multichip_unit_tests/test_data_parallel_example_TG.py --timeout=900 ; fail+=$?
-pytest -n auto tests/ttnn/multichip_unit_tests/test_multidevice_TG.py --timeout=900 ; fail+=$?
+pytest -n auto tests/ttnn/distributed/test_data_parallel_example_TG.py --timeout=900 ; fail+=$?
+pytest -n auto tests/ttnn/distributed/test_multidevice_TG.py --timeout=900 ; fail+=$?
pytest -n auto tests/ttnn/unit_tests/test_multi_device_trace_TG.py --timeout=900 ; fail+=$?
pytest -n auto models/demos/tg/llama3_70b/tests/test_llama_mlp_galaxy.py --timeout=300 ; fail+=$?
pytest -n auto models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py --timeout=480 ; fail+=$?
2 changes: 1 addition & 1 deletion tests/scripts/tgg/run_tgg_unit_tests.sh
@@ -9,7 +9,7 @@ run_tgg_tests() {
./build/test/ttnn/galaxy_unit_tests_ttnn
TT_METAL_SLOW_DISPATCH_MODE=1 ./build/test/tt_metal/unit_tests_galaxy --gtest_filter="GalaxyFixture.*:TGGFixture.*"
./build/test/tt_metal/unit_tests_galaxy --gtest_filter="GalaxyFixture.*:TGGFixture.*"
-pytest -s tests/ttnn/multichip_unit_tests/test_mesh_device_TGG.py
+pytest -s tests/ttnn/distributed/test_mesh_device_TGG.py
}

main() {
@@ -8,7 +8,7 @@
#include <random>
#include <tuple>

#include "impl/device/mesh_device_view.hpp"
#include "tt_metal/distributed/mesh_device_view.hpp"
#include "tt_metal/common/logger.hpp"
#include "device/tt_arch_types.h"
#include "impl/device/device.hpp"
@@ -27,7 +27,7 @@
#include "tt_metal/test_utils/stimulus.hpp"

#include "tt_metal/detail/persistent_kernel_cache.hpp"
#include "tt_metal/impl/device/mesh_device.hpp"
#include "tt_metal/distributed/mesh_device.hpp"

using tt::tt_metal::Device;

60 changes: 60 additions & 0 deletions tests/ttnn/distributed/test_data_parallel_example.py
@@ -0,0 +1,60 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import ttnn
import torch
import transformers
import pytest

from tests.ttnn.utils_for_testing import assert_with_pcc
from ttnn.model_preprocessing import preprocess_model_parameters


class TtFalconMLP:
def __init__(self, parameters):
super().__init__()
self.dense_h_to_4h_weights = parameters.dense_h_to_4h.weight
self.dense_4h_to_h_weights = parameters.dense_4h_to_h.weight

def __call__(self, x: ttnn.Tensor) -> ttnn.Tensor:
ff1_linear: ttnn.Tensor = ttnn.linear(x, self.dense_h_to_4h_weights)
gelu = ttnn.gelu(ff1_linear)
ff2_linear: ttnn.Tensor = ttnn.linear(gelu, self.dense_4h_to_h_weights)

return ff2_linear


@pytest.mark.parametrize("mesh_device", [pytest.param((1, 4), id="1x4_grid")], indirect=True)
def test_data_parallel_falcon_mlp(mesh_device):
# Load Falcon MLP model from huggingface
config = transformers.FalconConfig.from_pretrained("tiiuae/falcon-7b-instruct")
model = transformers.models.falcon.modeling_falcon.FalconMLP(config).eval()

# Initialize hidden states
batch_size, sequence_length = 4, 128
torch_hidden_states = (torch.rand(batch_size, 1, sequence_length, config.hidden_size, dtype=torch.float32) * 2) - 1
torch_output = model.forward(torch_hidden_states)

# Shard input activations on batch dimension to devices in the mesh
with ttnn.distribute(mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=0)):
hidden_states = ttnn.from_torch(
torch_hidden_states,
dtype=ttnn.bfloat16,
layout=ttnn.TILE_LAYOUT,
device=mesh_device,
)

# Replicate model parameters to devices in the mesh
with ttnn.distribute(mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device)):
parameters = preprocess_model_parameters(
initialize_model=lambda: model,
device=mesh_device,
)

# Initialize Model
ttnn_model = TtFalconMLP(parameters)
ttnn_output = ttnn_model(hidden_states)

with ttnn.distribute(mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0)):
assert_with_pcc(torch_output, ttnn.to_torch(ttnn_output), 0.98)
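Editor's note on the test above: `ttnn.distribute` installs a default mesh mapper/composer for the calls inside the `with` block. The same wiring can also be expressed per call, as in other TT-NN multi-device examples; the keyword-argument form below is a sketch under that assumption, reusing the names from the test.

```py
# Per-call form of the same sharding (sketch): pass the mapper/composer
# directly instead of wrapping the calls in ttnn.distribute.
hidden_states = ttnn.from_torch(
    torch_hidden_states,
    dtype=ttnn.bfloat16,
    layout=ttnn.TILE_LAYOUT,
    device=mesh_device,
    mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=0),  # shard on batch dim
)
torch_out = ttnn.to_torch(ttnn_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))
```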
10 changes: 5 additions & 5 deletions tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp
@@ -7,7 +7,7 @@
#include "ttnn/operations/ccl/all_gather/device/all_gather_op.hpp"
#include "ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp"
#include "ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_types.hpp"
#include "ttnn/cpp/ttnn/multi_device.hpp"
#include "ttnn/cpp/ttnn/distributed/mesh_device.hpp"
#include "ttnn/async_runtime.hpp"
#include "ttnn_multi_command_queue_fixture.hpp"

@@ -110,7 +110,7 @@ TEST(GalaxyTests, TestAllGatherDeadlock) {
validate_num_tunnels_and_tunnel_depth();

MeshShape mesh_shape = get_mesh_shape();
-std::shared_ptr<MeshDevice> mesh = ttnn::multi_device::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER);
+std::shared_ptr<MeshDevice> mesh = ttnn::distributed::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER);

// Setup input data and output data containers
MemoryConfig mem_cfg = MemoryConfig{
@@ -177,7 +177,7 @@ TEST(GalaxyTests, TestAllGatherDeadlock) {
}
}
}
-ttnn::multi_device::close_mesh_device(mesh);
+ttnn::distributed::close_mesh_device(mesh);
}

TEST(GalaxyTests, TestReduceScatterDeadlock) {
@@ -187,7 +187,7 @@ TEST(GalaxyTests, TestReduceScatterDeadlock) {
validate_num_tunnels_and_tunnel_depth();

MeshShape mesh_shape = get_mesh_shape();
-std::shared_ptr<MeshDevice> mesh = ttnn::multi_device::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER);
+std::shared_ptr<MeshDevice> mesh = ttnn::distributed::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER);
// Create the outer ring on which Reduce Scatter will be run. This allows us to verify that there are no deadlocks when we send CCLs to the
// first tunnel (forward path).
auto view = MeshDeviceView(*mesh);
@@ -273,5 +273,5 @@ TEST(GalaxyTests, TestReduceScatterDeadlock) {
}
}
}
-ttnn::multi_device::close_mesh_device(mesh);
+ttnn::distributed::close_mesh_device(mesh);
}
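Editor's note: the rename shown here (`ttnn::multi_device::open_mesh_device` to `ttnn::distributed::open_mesh_device`) has a Python-side counterpart on the top-level `ttnn` module. A minimal sketch follows; the 8x4 shape is illustrative for a Galaxy system, and `ttnn.MeshShape`/`get_num_devices` are assumed from the MeshDevice bindings rather than taken from this diff.

```py
import ttnn

# Open a mesh of devices, do some work, and always close it afterwards.
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(8, 4))
try:
    print(f"opened a mesh of {mesh_device.get_num_devices()} devices")
finally:
    ttnn.close_mesh_device(mesh_device)
```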
4 changes: 2 additions & 2 deletions tests/ttnn/unit_tests/gtests/test_multi_device.cpp
@@ -7,7 +7,7 @@
#include "ttnn/cpp/ttnn/tensor/types.hpp"
#include "ttnn/cpp/ttnn/operations/creation.hpp"

-namespace ttnn::multi_device::test {
+namespace ttnn::distributed::test {

using namespace tt::tt_metal;

@@ -45,4 +45,4 @@ TEST_F(T3kMultiDeviceFixture, TestGetDistributedTensorConfigFromMultiDeviceStora
EXPECT_TRUE(std::holds_alternative<ReplicateTensor>(distributed_tensor_config));
}

-} // namespace ttnn::multi_device::test
+} // namespace ttnn::distributed::test
6 changes: 3 additions & 3 deletions tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp
@@ -15,7 +15,7 @@
#include "tests/tt_metal/test_utils/env_vars.hpp"
#include "tt_metal/host_api.hpp"
#include "tt_metal/hostdevcommon/common_values.hpp"
#include "tt_metal/impl/device/mesh_device.hpp"
#include "tt_metal/distributed/mesh_device.hpp"

namespace ttnn {

@@ -53,7 +53,7 @@ class TTNNFixtureWithDevice : public TTNNFixture {
} // namespace ttnn


-namespace ttnn::multi_device::test {
+namespace ttnn::distributed::test {

class T3kMultiDeviceFixture : public ::testing::Test {
protected:
@@ -78,4 +78,4 @@ class T3kMultiDeviceFixture : public ::testing::Test {
std::shared_ptr<MeshDevice> mesh_device_;
};

-} // namespace ttnn::multi_device::test
+} // namespace ttnn::distributed::test
2 changes: 2 additions & 0 deletions tt_metal/CMakeLists.txt
@@ -11,6 +11,8 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/detail)
set(TT_METAL_OBJECTS
${CMAKE_CURRENT_SOURCE_DIR}/tt_metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/graph/graph_tracking.cpp
+${CMAKE_CURRENT_SOURCE_DIR}/distributed/mesh_device.cpp
+${CMAKE_CURRENT_SOURCE_DIR}/distributed/mesh_device_view.cpp
$<TARGET_OBJECTS:profiler>
$<TARGET_OBJECTS:common>
$<TARGET_OBJECTS:jit_build>
@@ -2,7 +2,7 @@
//
// SPDX-License-Identifier: Apache-2.0

#include "tt_metal/impl/device/mesh_device.hpp"
#include "tt_metal/distributed/mesh_device.hpp"

#include <memory>
#include <unordered_map>
@@ -11,7 +11,8 @@
#include "tt_metal/common/logger.hpp"
#include "tt_metal/detail/tt_metal.hpp"
#include "tt_metal/host_api.hpp"
#include "tt_metal/impl/device/mesh_device_view.hpp"
#include "tt_metal/distributed/mesh_device_view.hpp"
#include "tt_metal/distributed/mesh_device.hpp"

namespace tt::tt_metal {

Expand All @@ -20,7 +21,7 @@ using PhysicalCoordinate = eth_coord_t;

static std::string get_config_path(const std::string& filename) {
std::string root_path = getenv("TT_METAL_HOME") ? getenv("TT_METAL_HOME") : "./";
return root_path + "/tt_metal/impl/device/mesh_configurations/" + filename;
return root_path + "/tt_metal/distributed/mesh_configurations/" + filename;
}

static std::unordered_map<LogicalCoordinate, PhysicalCoordinate> load_translation_map(const std::string& filename, const std::string& key) {
@@ -10,7 +10,7 @@
#include <vector>

#include "tt_metal/impl/device/device.hpp"
#include "tt_metal/impl/device/mesh_device_view.hpp"
#include "tt_metal/distributed/mesh_device_view.hpp"

namespace tt::tt_metal {

@@ -2,12 +2,12 @@
//
// SPDX-License-Identifier: Apache-2.0

#include "tt_metal/impl/device/mesh_device_view.hpp"
#include "tt_metal/distributed/mesh_device_view.hpp"

#include <algorithm>
#include <stdexcept>

#include "tt_metal/impl/device/mesh_device.hpp"
#include "tt_metal/distributed/mesh_device.hpp"

namespace tt::tt_metal {

File renamed without changes.
2 changes: 0 additions & 2 deletions tt_metal/impl/CMakeLists.txt
@@ -2,8 +2,6 @@
set(IMPL_SRC
${CMAKE_CURRENT_SOURCE_DIR}/device/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device/device_pool.cpp
-${CMAKE_CURRENT_SOURCE_DIR}/device/mesh_device.cpp
-${CMAKE_CURRENT_SOURCE_DIR}/device/mesh_device_view.cpp
${CMAKE_CURRENT_SOURCE_DIR}/buffers/buffer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/buffers/circular_buffer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/buffers/semaphore.cpp
3 changes: 2 additions & 1 deletion ttnn/CMakeLists.txt
@@ -3,9 +3,10 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/config.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/core.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/device.cpp
-${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/multi_device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/events.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/run_operation.cpp
+${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/distributed/mesh_device.cpp
+${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/distributed/distributed_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/graph/graph_processor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/graph/graph_trace_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/graph/graph_pybind.cpp
6 changes: 3 additions & 3 deletions ttnn/cpp/pybind11/__init__.cpp
@@ -12,9 +12,9 @@
#include "device.hpp"
#include "profiler.hpp"
#include "events.hpp"
#include "multi_device.hpp"
#include "tensor.hpp"
#include "reports.hpp"
#include "ttnn/distributed/distributed_pybind.hpp"
#include "ttnn/deprecated/tt_lib/csrc/operations/primary/module.hpp"
#include "ttnn/graph/graph_pybind.hpp"
#include "types.hpp"
@@ -58,7 +58,7 @@ PYBIND11_MODULE(_ttnn, module) {
ttnn::activation::py_module_types(m_activation);
ttnn::core::py_module_types(m_core);
ttnn::device::py_device_module_types(m_device);
-ttnn::multi_device::py_module_types(m_multi_device);
+ttnn::distributed::py_module_types(m_multi_device);
ttnn::events::py_module_types(m_events);
ttnn::reports::py_module_types(m_reports);

@@ -80,7 +80,7 @@ PYBIND11_MODULE(_ttnn, module) {
ttnn::types::py_module(m_types);
ttnn::activation::py_module(m_activation);
ttnn::device::py_device_module(m_device);
-ttnn::multi_device::py_module(m_multi_device);
+ttnn::distributed::py_module(m_multi_device);
ttnn::events::py_module(m_events);
ttnn::profiler::py_module(m_profiler);
ttnn::reports::py_module(m_reports);
@@ -2,19 +2,18 @@
//
// SPDX-License-Identifier: Apache-2.0

-#pragma once
+#include "ttnn/distributed/distributed_pybind.hpp"

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "ttnn/multi_device.hpp"
#include "ttnn/distributed/mesh_device.hpp"
#include "ttnn/tensor/tensor_utils.hpp"
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/types.hpp"
#include "tt_metal/impl/dispatch/command_queue.hpp"

-namespace py = pybind11;

-namespace ttnn {
+namespace ttnn::distributed {

-namespace multi_device {
+namespace py = pybind11;

void py_module_types(py::module& module) {
py::class_<MeshDevice, std::shared_ptr<MeshDevice>>(module, "MeshDevice");
@@ -172,6 +171,4 @@ void py_module(py::module& module) {
module.def("get_t3k_physical_device_ids_ring", &tt::tt_metal::get_t3k_physical_device_ids_ring);
}

-} // namespace multi_device

-} // namespace ttnn
+} // namespace ttnn::distributed
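Editor's note: among the bindings re-registered under `ttnn::distributed` is `get_t3k_physical_device_ids_ring`, which returns T3000 device IDs ordered as a ring. A hedged sketch of how it might pair with `open_mesh_device`; the `physical_device_ids` keyword mirrors the C++ signature later in this commit and is an assumption for the Python binding.

```py
import ttnn

# Order the eight T3000 devices as a ring, then open a 1x8 mesh over them.
ring_ids = ttnn.get_t3k_physical_device_ids_ring()
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, 8), physical_device_ids=ring_ids)
ttnn.close_mesh_device(mesh_device)
```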
15 changes: 15 additions & 0 deletions ttnn/cpp/ttnn/distributed/distributed_pybind.hpp
@@ -0,0 +1,15 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once
#include "pybind11/pybind_fwd.hpp"

namespace py = pybind11;

namespace ttnn::distributed {

void py_module_types(py::module& module);
void py_module(py::module& module);

} // namespace ttnn::distributed
@@ -2,15 +2,15 @@
//
// SPDX-License-Identifier: Apache-2.0

#include "multi_device.hpp"
#include "ttnn/distributed/mesh_device.hpp"

#include <memory>

#include "ttnn/tensor/tensor.hpp"
#include "ttnn/tensor/tensor_utils.hpp"
#include "tt_metal/impl/device/mesh_device.hpp"
#include "tt_metal/distributed/mesh_device.hpp"

-namespace ttnn::multi_device {
+namespace ttnn::distributed {

std::shared_ptr<MeshDevice> open_mesh_device(const MeshShape& mesh_shape, size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, DispatchCoreType dispatch_core_type, MeshType mesh_type, const std::pair<size_t, size_t>& offset, const std::vector<int>& physical_device_ids) {
auto config = MeshDeviceConfig(mesh_shape, offset, physical_device_ids, mesh_type);
@@ -82,4 +82,4 @@ Tensor aggregate_as_tensor(std::vector<Tensor>& tensor_shards)
}
}

-} // namespace ttnn::multi_device
+} // namespace ttnn::distributed
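Editor's note: `aggregate_as_tensor` above composes per-device shards into one multi-device tensor, and `ttnn.get_device_tensors` is its inverse on the Python side. A round-trip sketch, assuming an existing multi-device `ttnn_tensor` and that both helpers keep their current signatures:

```py
# Decompose a multi-device tensor into per-device shards, then rebuild it.
shards = ttnn.get_device_tensors(ttnn_tensor)   # list of single-device Tensors
rebuilt = ttnn.aggregate_as_tensor(shards)      # one multi-device Tensor again
```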