Skip to content

Commit

Permalink
#9088: support multi-device mesh with single device
Browse files Browse the repository at this point in the history
  • Loading branch information
cfjchu committed Jun 4, 2024
1 parent 54b93f2 commit 03c757e
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 23 deletions.
9 changes: 0 additions & 9 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,9 +326,6 @@ def device_mesh(request, silicon_arch_name, silicon_arch_wormhole_b0):
except (ValueError, AttributeError):
num_devices_requested = len(device_ids)

if num_devices_requested <= 1:
pytest.skip("Requires multiple devices to run")

device_mesh = ttnn.open_device_mesh(ttnn.DeviceGrid(1, num_devices_requested), device_ids[:num_devices_requested])

logger.debug(f"multidevice with {device_mesh.get_num_devices()} devices is created")
Expand All @@ -354,9 +351,6 @@ def pcie_device_mesh(request, silicon_arch_name, silicon_arch_wormhole_b0):
except (ValueError, AttributeError):
num_pcie_devices_requested = len(device_ids)

if num_pcie_devices_requested <= 1:
pytest.skip("Requires multiple devices to run")

device_mesh = ttnn.open_device_mesh(
ttnn.DeviceGrid(1, num_pcie_devices_requested), device_ids[:num_pcie_devices_requested]
)
Expand Down Expand Up @@ -386,9 +380,6 @@ def t3k_device_mesh(request, silicon_arch_name, silicon_arch_wormhole_b0):
except (ValueError, AttributeError):
num_devices_requested = len(device_ids)

if num_devices_requested <= 1:
pytest.skip("Requires multiple devices to run")

device_mesh = ttnn.open_device_mesh(ttnn.DeviceGrid(1, num_devices_requested), device_ids[:num_devices_requested])

logger.debug(f"multidevice with {device_mesh.get_num_devices()} devices is created")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def torch_model():
@pytest.mark.parametrize(
"device_mesh",
[
1,
2,
],
indirect=True,
Expand Down
2 changes: 2 additions & 0 deletions tests/ttnn/unit_tests/test_multi_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ def test_multi_device_replicate(device_mesh, shape, layout, memory_config):

def test_ttnn_multi_device_all_gather(pcie_device_mesh):
"""Multidevice API test for ttnn.all_gather CCL operation"""
if pcie_device_mesh.get_num_devices() <= 1:
pytest.skip("Requires multiple devices to run")
full_tensor = torch.rand((1, 1, 32, 32 * pcie_device_mesh.get_num_devices()), dtype=torch.bfloat16)

ttnn_tensor = ttnn.from_torch(full_tensor, mesh_mapper=ShardTensorToMesh(pcie_device_mesh, dim=3))
Expand Down
7 changes: 2 additions & 5 deletions tests/ttnn/unit_tests/test_multi_device_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,8 @@ def test_multi_device_explicit_dealloc(pcie_device_mesh):
"""Multidevice API: Ensure that deallocating multi-device tensors works as expected"""
from ttnn import ShardTensorToMesh, ConcatMeshToTensor, ReplicateTensorToMesh

for device in pcie_device_mesh.get_device_ids():
pcie_device_mesh.get_device(device).enable_async(True)
if pcie_device_mesh.get_num_devices() <= 1:
pytest.skip("Requires multiple devices to run")

# Create input tensors that cause OOM during op execution
# Explicitly deallocate buffers after each op to ensure we don't run OOM.
Expand Down Expand Up @@ -311,9 +311,6 @@ def test_multi_device_explicit_dealloc(pcie_device_mesh):
ttnn_output_tensor, mesh_composer=ConcatMeshToTensor(pcie_device_mesh, dim=0)
)

for device in pcie_device_mesh.get_device_ids():
pcie_device_mesh.get_device(device).enable_async(False)


@pytest.mark.parametrize("scalar", [3])
@pytest.mark.parametrize("size", [64])
Expand Down
6 changes: 6 additions & 0 deletions tests/ttnn/unit_tests/test_multi_device_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
@pytest.mark.parametrize("use_all_gather", [True, False])
@pytest.mark.parametrize("enable_async", [True, False])
def test_multi_device_single_trace(pcie_device_mesh, shape, use_all_gather, enable_async):
if pcie_device_mesh.get_num_devices() <= 1:
pytest.skip("This test requires multiple devices")

# Trace requires program cache to be enabled
for device_id in pcie_device_mesh.get_device_ids():
pcie_device_mesh.get_device(device_id).enable_async(enable_async)
Expand Down Expand Up @@ -103,6 +106,9 @@ def test_multi_device_multi_trace(pcie_device_mesh, shape, use_all_gather, enabl
if shape == (1, 1, 32, 32) or shape == (1, 3, 512, 512) or shape == (1, 3, 32, 32):
pytest.skip("This configuration is not working with all-gather")

if pcie_device_mesh.get_num_devices() <= 1:
pytest.skip("This test requires multiple devices")

# Trace requires program cache to be enabled
for device_id in pcie_device_mesh.get_device_ids():
pcie_device_mesh.get_device(device_id).enable_async(enable_async)
Expand Down
2 changes: 2 additions & 0 deletions tt_eager/tensor/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,8 @@ Tensor Tensor::to(Layout target_layout, DeviceMesh* device_mesh) const {
auto& worker = workers[worker_index];
worker->push_work([*this, tensor_modified_layout, target_layout, worker, worker_index]() mutable {
TT_ASSERT(
this->storage_type() == StorageType::OWNED ||
this->storage_type() == StorageType::BORROWED||
this->storage_type() == StorageType::MULTI_DEVICE_HOST &&
"to(layout) must be called on host tensors with MULTI_DEVICE_HOST_STORAGE when multiple workers "
"are specified");
Expand Down
22 changes: 13 additions & 9 deletions tt_eager/tensor/tensor_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,16 +363,20 @@ const Shape infer_dims_for_reshape(int N, int C, int H, int W, uint32_t old_volu
bool is_device_tensor(const Tensor& tensor) { return tensor.storage_type() == StorageType::DEVICE; }

Tensor get_device_tensor(const Tensor& multi_device_tensor, const int device_id) {
const auto& tensor_storage = std::get<MultiDeviceStorage>(multi_device_tensor.get_storage());
if (tensor_storage.has_buffer_for_device_id(device_id)) {
return Tensor{
DeviceStorage{tensor_storage.get_buffer_for_device_id(device_id)},
multi_device_tensor.get_legacy_shape(),
multi_device_tensor.get_dtype(),
multi_device_tensor.get_layout()
};
if (std::holds_alternative<tt::tt_metal::MultiDeviceStorage>(multi_device_tensor.get_storage())) {
const auto& tensor_storage = std::get<MultiDeviceStorage>(multi_device_tensor.get_storage());
if (tensor_storage.has_buffer_for_device_id(device_id)) {
return Tensor{
DeviceStorage{tensor_storage.get_buffer_for_device_id(device_id)},
multi_device_tensor.get_legacy_shape(),
multi_device_tensor.get_dtype(),
multi_device_tensor.get_layout()};
}
} else if (std::holds_alternative<tt::tt_metal::DeviceStorage>(multi_device_tensor.get_storage())) {
return multi_device_tensor;
}
TT_THROW("Device not found in multi-device tensor");

TT_THROW("User is trying to access a device tensor that is not on device.");
}

Tensor get_device_tensor(const Tensor& multi_device_tensor, const Device* device) {
Expand Down
2 changes: 2 additions & 0 deletions ttnn/cpp/ttnn/multi_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ std::vector<ttnn::Tensor> get_device_tensors(const ttnn::Tensor& tensor) {
tensors.push_back(shard);
}
return tensors;
} else {
return {tensor};
}
TT_THROW("Expected tensor to be on MultiDeviceHostStorage type!");
}
Expand Down

0 comments on commit 03c757e

Please sign in to comment.