diff --git a/conftest.py b/conftest.py index 6e43d1a6499..9406e5a55f0 100644 --- a/conftest.py +++ b/conftest.py @@ -103,16 +103,18 @@ def get_dispatch_core_type(): return dispatch_core_type -def get_dispatch_core_config(device_params): +def get_updated_device_params(device_params): import ttnn dispatch_core_type = get_dispatch_core_type() - dispatch_core_axis = device_params.pop( + new_device_params = device_params.copy() + dispatch_core_axis = new_device_params.pop( "dispatch_core_axis", ttnn.DispatchCoreAxis.COL if os.environ["ARCH_NAME"] == "blackhole" else ttnn.DispatchCoreAxis.ROW, ) dispatch_core_config = ttnn.DispatchCoreConfig(dispatch_core_type, dispatch_core_axis) - return dispatch_core_config + new_device_params["dispatch_core_config"] = dispatch_core_config + return new_device_params @pytest.fixture(scope="function") @@ -129,8 +131,8 @@ def device(request, device_params): num_devices = ttnn.GetNumPCIeDevices() assert device_id < num_devices, "CreateDevice not supported for non-mmio device" - dispatch_core_config = get_dispatch_core_config(device_params) - device = ttnn.CreateDevice(device_id=device_id, dispatch_core_config=dispatch_core_config, **device_params) + updated_device_params = get_updated_device_params(device_params) + device = ttnn.CreateDevice(device_id=device_id, **updated_device_params) ttnn.SetDefaultDevice(device) yield device @@ -150,8 +152,8 @@ def pcie_devices(request, device_params): request.node.pci_ids = device_ids # Get only physical devices - dispatch_core_config = get_dispatch_core_config(device_params) - devices = ttnn.CreateDevices(device_ids, dispatch_core_config=dispatch_core_config, **device_params) + updated_device_params = get_updated_device_params(device_params) + devices = ttnn.CreateDevices(device_ids, **updated_device_params) yield [devices[i] for i in range(num_devices)] @@ -170,8 +172,8 @@ def all_devices(request, device_params): request.node.pci_ids = [ttnn.GetPCIeDeviceID(i) for i in device_ids] # Get only physical devices - dispatch_core_config = get_dispatch_core_config(device_params) - devices = ttnn.CreateDevices(device_ids, dispatch_core_config=dispatch_core_config, **device_params) + updated_device_params = get_updated_device_params(device_params) + devices = ttnn.CreateDevices(device_ids, **updated_device_params) yield [devices[i] for i in range(num_devices)] @@ -222,10 +224,8 @@ def mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device_par request.node.pci_ids = [ttnn.GetPCIeDeviceID(i) for i in device_ids[:num_devices_requested]] - dispatch_core_config = get_dispatch_core_config(device_params) - mesh_device = ttnn.open_mesh_device( - mesh_shape=mesh_shape, dispatch_core_config=dispatch_core_config, **device_params - ) + updated_device_params = get_updated_device_params(device_params) + mesh_device = ttnn.open_mesh_device(mesh_shape=mesh_shape, **updated_device_params) logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created") yield mesh_device @@ -252,11 +252,10 @@ def pcie_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, devic request.node.pci_ids = device_ids[:num_pcie_devices_requested] - dispatch_core_config = get_dispatch_core_config(device_params) + updated_device_params = get_updated_device_params(device_params) mesh_device = ttnn.open_mesh_device( mesh_shape=ttnn.MeshShape(2, 2), - dispatch_core_config=dispatch_core_config, - **device_params, + **updated_device_params, offset=ttnn.MeshOffset(0, 1), mesh_type=ttnn.MeshType.Ring, ) @@ -278,11 +277,10 @@ def n300_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, devic if ttnn.get_num_devices() < 2: pytest.skip() - dispatch_core_config = get_dispatch_core_config(device_params) + updated_device_params = get_updated_device_params(device_params) mesh_device = ttnn.open_mesh_device( mesh_shape=ttnn.MeshShape(1, 2), - dispatch_core_config=dispatch_core_config, - **device_params, + **updated_device_params, ) logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created") @@ -303,11 +301,10 @@ def t3k_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device pytest.skip() request.node.pci_ids = ttnn.get_pcie_device_ids() - dispatch_core_config = get_dispatch_core_config(device_params) + updated_device_params = get_updated_device_params(device_params) mesh_device = ttnn.open_mesh_device( mesh_shape=ttnn.MeshShape(2, 4), - dispatch_core_config=dispatch_core_config, - **device_params, + **updated_device_params, mesh_type=ttnn.MeshType.Ring, )