diff --git a/conftest.py b/conftest.py index 7df64b2c750..6c617cc1e7a 100644 --- a/conftest.py +++ b/conftest.py @@ -326,9 +326,6 @@ def device_mesh(request, silicon_arch_name, silicon_arch_wormhole_b0): except (ValueError, AttributeError): num_devices_requested = len(device_ids) - if num_devices_requested <= 1: - pytest.skip("Requires multiple devices to run") - device_mesh = ttnn.open_device_mesh(ttnn.DeviceGrid(1, num_devices_requested), device_ids[:num_devices_requested]) logger.debug(f"multidevice with {device_mesh.get_num_devices()} devices is created") @@ -354,9 +351,6 @@ def pcie_device_mesh(request, silicon_arch_name, silicon_arch_wormhole_b0): except (ValueError, AttributeError): num_pcie_devices_requested = len(device_ids) - if num_pcie_devices_requested <= 1: - pytest.skip("Requires multiple devices to run") - device_mesh = ttnn.open_device_mesh( ttnn.DeviceGrid(1, num_pcie_devices_requested), device_ids[:num_pcie_devices_requested] ) @@ -386,9 +380,6 @@ def t3k_device_mesh(request, silicon_arch_name, silicon_arch_wormhole_b0): except (ValueError, AttributeError): num_devices_requested = len(device_ids) - if num_devices_requested <= 1: - pytest.skip("Requires multiple devices to run") - device_mesh = ttnn.open_device_mesh(ttnn.DeviceGrid(1, num_devices_requested), device_ids[:num_devices_requested]) logger.debug(f"multidevice with {device_mesh.get_num_devices()} devices is created") diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py index 6301284023c..192babe1f3e 100644 --- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py +++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py @@ -52,6 +52,7 @@ def torch_model(): @pytest.mark.parametrize( "device_mesh", [ + 1, 2, ], indirect=True, diff --git a/tests/ttnn/unit_tests/test_multi_device.py b/tests/ttnn/unit_tests/test_multi_device.py index c8b7386279d..501840cfe5c 100644 --- a/tests/ttnn/unit_tests/test_multi_device.py +++ b/tests/ttnn/unit_tests/test_multi_device.py @@ -159,6 +159,8 @@ def test_multi_device_replicate(device_mesh, shape, layout, memory_config): def test_ttnn_multi_device_all_gather(pcie_device_mesh): """Multidevice API test for ttnn.all_gather CCL operation""" + if pcie_device_mesh.get_num_devices() <= 1: + pytest.skip("Requires multiple devices to run") full_tensor = torch.rand((1, 1, 32, 32 * pcie_device_mesh.get_num_devices()), dtype=torch.bfloat16) ttnn_tensor = ttnn.from_torch(full_tensor, mesh_mapper=ShardTensorToMesh(pcie_device_mesh, dim=3)) diff --git a/tests/ttnn/unit_tests/test_multi_device_async.py b/tests/ttnn/unit_tests/test_multi_device_async.py index 2f5cc0e8252..35a3bf71a5b 100644 --- a/tests/ttnn/unit_tests/test_multi_device_async.py +++ b/tests/ttnn/unit_tests/test_multi_device_async.py @@ -278,8 +278,8 @@ def test_multi_device_explicit_dealloc(pcie_device_mesh): """Multidevice API: Ensure that deallocating multi-device tensors works as expected""" from ttnn import ShardTensorToMesh, ConcatMeshToTensor, ReplicateTensorToMesh - for device in pcie_device_mesh.get_device_ids(): - pcie_device_mesh.get_device(device).enable_async(True) + if pcie_device_mesh.get_num_devices() <= 1: + pytest.skip("Requires multiple devices to run") # Create input tensors that cause OOM during op execution # Explictly deallocate buffers after each op to ensure we don't run OOM. @@ -311,9 +311,6 @@ def test_multi_device_explicit_dealloc(pcie_device_mesh): ttnn_output_tensor, mesh_composer=ConcatMeshToTensor(pcie_device_mesh, dim=0) ) - for device in pcie_device_mesh.get_device_ids(): - pcie_device_mesh.get_device(device).enable_async(False) - @pytest.mark.parametrize("scalar", [3]) @pytest.mark.parametrize("size", [64]) diff --git a/tests/ttnn/unit_tests/test_multi_device_trace.py b/tests/ttnn/unit_tests/test_multi_device_trace.py index e7527971348..aa350b6d1e7 100644 --- a/tests/ttnn/unit_tests/test_multi_device_trace.py +++ b/tests/ttnn/unit_tests/test_multi_device_trace.py @@ -16,6 +16,9 @@ @pytest.mark.parametrize("use_all_gather", [True, False]) @pytest.mark.parametrize("enable_async", [True, False]) def test_multi_device_single_trace(pcie_device_mesh, shape, use_all_gather, enable_async): + if pcie_device_mesh.get_num_devices() <= 1: + pytest.skip("This test requires multiple devices") + # Trace requires program cache to be enabled for device_id in pcie_device_mesh.get_device_ids(): pcie_device_mesh.get_device(device_id).enable_async(enable_async) @@ -103,6 +106,9 @@ def test_multi_device_multi_trace(pcie_device_mesh, shape, use_all_gather, enabl if shape == (1, 1, 32, 32) or shape == (1, 3, 512, 512) or shape == (1, 3, 32, 32): pytest.skip("This configuration is not working with all-gather") + if pcie_device_mesh.get_num_devices() <= 1: + pytest.skip("This test requires multiple devices") + # Trace requires program cache to be enabled for device_id in pcie_device_mesh.get_device_ids(): pcie_device_mesh.get_device(device_id).enable_async(enable_async) diff --git a/tt_eager/tensor/tensor.cpp b/tt_eager/tensor/tensor.cpp index c59e12608b5..694138fe1f8 100644 --- a/tt_eager/tensor/tensor.cpp +++ b/tt_eager/tensor/tensor.cpp @@ -604,6 +604,8 @@ Tensor Tensor::to(Layout target_layout, DeviceMesh* device_mesh) const { auto& worker = workers[worker_index]; worker->push_work([*this, tensor_modified_layout, target_layout, worker, worker_index]() mutable { TT_ASSERT( + this->storage_type() == StorageType::OWNED || + this->storage_type() == StorageType::BORROWED|| this->storage_type() == StorageType::MULTI_DEVICE_HOST && "to(layout) must be called on host tensors with MULTI_DEVICE_HOST_STORAGE when multiple workers " "are specified"); diff --git a/tt_eager/tensor/tensor_utils.cpp b/tt_eager/tensor/tensor_utils.cpp index c9d96d91cd6..f6cd958d791 100644 --- a/tt_eager/tensor/tensor_utils.cpp +++ b/tt_eager/tensor/tensor_utils.cpp @@ -363,16 +363,20 @@ const Shape infer_dims_for_reshape(int N, int C, int H, int W, uint32_t old_volu bool is_device_tensor(const Tensor& tensor) { return tensor.storage_type() == StorageType::DEVICE; } Tensor get_device_tensor(const Tensor& multi_device_tensor, const int device_id) { - const auto& tensor_storage = std::get(multi_device_tensor.get_storage()); - if (tensor_storage.has_buffer_for_device_id(device_id)) { - return Tensor{ - DeviceStorage{tensor_storage.get_buffer_for_device_id(device_id)}, - multi_device_tensor.get_legacy_shape(), - multi_device_tensor.get_dtype(), - multi_device_tensor.get_layout() - }; + if (std::holds_alternative(multi_device_tensor.get_storage())) { + const auto& tensor_storage = std::get(multi_device_tensor.get_storage()); + if (tensor_storage.has_buffer_for_device_id(device_id)) { + return Tensor{ + DeviceStorage{tensor_storage.get_buffer_for_device_id(device_id)}, + multi_device_tensor.get_legacy_shape(), + multi_device_tensor.get_dtype(), + multi_device_tensor.get_layout()}; + } + } else if (std::holds_alternative(multi_device_tensor.get_storage())) { + return multi_device_tensor; } - TT_THROW("Device not found in multi-device tensor"); + + TT_THROW("User is trying to access a device tensor that is not on device."); } Tensor get_device_tensor(const Tensor& multi_device_tensor, const Device* device) { diff --git a/ttnn/cpp/ttnn/multi_device.hpp b/ttnn/cpp/ttnn/multi_device.hpp index 41943189363..1a36bad3086 100644 --- a/ttnn/cpp/ttnn/multi_device.hpp +++ b/ttnn/cpp/ttnn/multi_device.hpp @@ -46,6 +46,8 @@ std::vector get_device_tensors(const ttnn::Tensor& tensor) { tensors.push_back(shard); } return tensors; + } else { + return {tensor}; } TT_THROW("Expected tensor to be on MultiDeviceHostStorage type!"); }