Skip to content

Commit

Permalink
#0: Fix get_dispatch_core_config in conftest.py to not modify the dev…
Browse files Browse the repository at this point in the history
…ice_params to not affect subsequent tests

Changed to a new function get_updated_device_params that returns a new updated copy of device_params instead
  • Loading branch information
tt-aho committed Dec 26, 2024
1 parent 82b35a3 commit c0c1bb6
Showing 1 changed file with 19 additions and 22 deletions.
41 changes: 19 additions & 22 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,18 @@ def get_dispatch_core_type():
return dispatch_core_type


def get_dispatch_core_config(device_params):
def get_updated_device_params(device_params):
import ttnn

dispatch_core_type = get_dispatch_core_type()
dispatch_core_axis = device_params.pop(
new_device_params = device_params.copy()
dispatch_core_axis = new_device_params.pop(
"dispatch_core_axis",
ttnn.DispatchCoreAxis.COL if os.environ["ARCH_NAME"] == "blackhole" else ttnn.DispatchCoreAxis.ROW,
)
dispatch_core_config = ttnn.DispatchCoreConfig(dispatch_core_type, dispatch_core_axis)
return dispatch_core_config
new_device_params["dispatch_core_config"] = dispatch_core_config
return new_device_params


@pytest.fixture(scope="function")
Expand All @@ -129,8 +131,8 @@ def device(request, device_params):

num_devices = ttnn.GetNumPCIeDevices()
assert device_id < num_devices, "CreateDevice not supported for non-mmio device"
dispatch_core_config = get_dispatch_core_config(device_params)
device = ttnn.CreateDevice(device_id=device_id, dispatch_core_config=dispatch_core_config, **device_params)
updated_device_params = get_updated_device_params(device_params)
device = ttnn.CreateDevice(device_id=device_id, **updated_device_params)
ttnn.SetDefaultDevice(device)

yield device
Expand All @@ -150,8 +152,8 @@ def pcie_devices(request, device_params):
request.node.pci_ids = device_ids

# Get only physical devices
dispatch_core_config = get_dispatch_core_config(device_params)
devices = ttnn.CreateDevices(device_ids, dispatch_core_config=dispatch_core_config, **device_params)
updated_device_params = get_updated_device_params(device_params)
devices = ttnn.CreateDevices(device_ids, **updated_device_params)

yield [devices[i] for i in range(num_devices)]

Expand All @@ -170,8 +172,8 @@ def all_devices(request, device_params):
request.node.pci_ids = [ttnn.GetPCIeDeviceID(i) for i in device_ids]

# Get only physical devices
dispatch_core_config = get_dispatch_core_config(device_params)
devices = ttnn.CreateDevices(device_ids, dispatch_core_config=dispatch_core_config, **device_params)
updated_device_params = get_updated_device_params(device_params)
devices = ttnn.CreateDevices(device_ids, **updated_device_params)

yield [devices[i] for i in range(num_devices)]

Expand Down Expand Up @@ -222,10 +224,8 @@ def mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device_par

request.node.pci_ids = [ttnn.GetPCIeDeviceID(i) for i in device_ids[:num_devices_requested]]

dispatch_core_config = get_dispatch_core_config(device_params)
mesh_device = ttnn.open_mesh_device(
mesh_shape=mesh_shape, dispatch_core_config=dispatch_core_config, **device_params
)
updated_device_params = get_updated_device_params(device_params)
mesh_device = ttnn.open_mesh_device(mesh_shape=mesh_shape, **updated_device_params)

logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created")
yield mesh_device
Expand All @@ -252,11 +252,10 @@ def pcie_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, devic

request.node.pci_ids = device_ids[:num_pcie_devices_requested]

dispatch_core_config = get_dispatch_core_config(device_params)
updated_device_params = get_updated_device_params(device_params)
mesh_device = ttnn.open_mesh_device(
mesh_shape=ttnn.MeshShape(2, 2),
dispatch_core_config=dispatch_core_config,
**device_params,
**updated_device_params,
offset=ttnn.MeshOffset(0, 1),
mesh_type=ttnn.MeshType.Ring,
)
Expand All @@ -278,11 +277,10 @@ def n300_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, devic
if ttnn.get_num_devices() < 2:
pytest.skip()

dispatch_core_config = get_dispatch_core_config(device_params)
updated_device_params = get_updated_device_params(device_params)
mesh_device = ttnn.open_mesh_device(
mesh_shape=ttnn.MeshShape(1, 2),
dispatch_core_config=dispatch_core_config,
**device_params,
**updated_device_params,
)

logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created")
Expand All @@ -303,11 +301,10 @@ def t3k_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device
pytest.skip()

request.node.pci_ids = ttnn.get_pcie_device_ids()
dispatch_core_config = get_dispatch_core_config(device_params)
updated_device_params = get_updated_device_params(device_params)
mesh_device = ttnn.open_mesh_device(
mesh_shape=ttnn.MeshShape(2, 4),
dispatch_core_config=dispatch_core_config,
**device_params,
**updated_device_params,
mesh_type=ttnn.MeshType.Ring,
)

Expand Down

0 comments on commit c0c1bb6

Please sign in to comment.