Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#0: Fix get_dispatch_core_config in conftest.py to not modify the device_params to not affect subsequent tests #16290

Merged
merged 1 commit into from
Dec 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 19 additions & 22 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,18 @@ def get_dispatch_core_type():
return dispatch_core_type


def get_dispatch_core_config(device_params):
def get_updated_device_params(device_params):
import ttnn

dispatch_core_type = get_dispatch_core_type()
dispatch_core_axis = device_params.pop(
new_device_params = device_params.copy()
dispatch_core_axis = new_device_params.pop(
"dispatch_core_axis",
ttnn.DispatchCoreAxis.COL if os.environ["ARCH_NAME"] == "blackhole" else ttnn.DispatchCoreAxis.ROW,
)
dispatch_core_config = ttnn.DispatchCoreConfig(dispatch_core_type, dispatch_core_axis)
return dispatch_core_config
new_device_params["dispatch_core_config"] = dispatch_core_config
return new_device_params


@pytest.fixture(scope="function")
Expand All @@ -129,8 +131,8 @@ def device(request, device_params):

num_devices = ttnn.GetNumPCIeDevices()
assert device_id < num_devices, "CreateDevice not supported for non-mmio device"
dispatch_core_config = get_dispatch_core_config(device_params)
device = ttnn.CreateDevice(device_id=device_id, dispatch_core_config=dispatch_core_config, **device_params)
updated_device_params = get_updated_device_params(device_params)
device = ttnn.CreateDevice(device_id=device_id, **updated_device_params)
ttnn.SetDefaultDevice(device)

yield device
Expand All @@ -150,8 +152,8 @@ def pcie_devices(request, device_params):
request.node.pci_ids = device_ids

# Get only physical devices
dispatch_core_config = get_dispatch_core_config(device_params)
devices = ttnn.CreateDevices(device_ids, dispatch_core_config=dispatch_core_config, **device_params)
updated_device_params = get_updated_device_params(device_params)
devices = ttnn.CreateDevices(device_ids, **updated_device_params)

yield [devices[i] for i in range(num_devices)]

Expand All @@ -170,8 +172,8 @@ def all_devices(request, device_params):
request.node.pci_ids = [ttnn.GetPCIeDeviceID(i) for i in device_ids]

# Get only physical devices
dispatch_core_config = get_dispatch_core_config(device_params)
devices = ttnn.CreateDevices(device_ids, dispatch_core_config=dispatch_core_config, **device_params)
updated_device_params = get_updated_device_params(device_params)
devices = ttnn.CreateDevices(device_ids, **updated_device_params)

yield [devices[i] for i in range(num_devices)]

Expand Down Expand Up @@ -222,10 +224,8 @@ def mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device_par

request.node.pci_ids = [ttnn.GetPCIeDeviceID(i) for i in device_ids[:num_devices_requested]]

dispatch_core_config = get_dispatch_core_config(device_params)
mesh_device = ttnn.open_mesh_device(
mesh_shape=mesh_shape, dispatch_core_config=dispatch_core_config, **device_params
)
updated_device_params = get_updated_device_params(device_params)
mesh_device = ttnn.open_mesh_device(mesh_shape=mesh_shape, **updated_device_params)

logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created")
yield mesh_device
Expand All @@ -252,11 +252,10 @@ def pcie_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, devic

request.node.pci_ids = device_ids[:num_pcie_devices_requested]

dispatch_core_config = get_dispatch_core_config(device_params)
updated_device_params = get_updated_device_params(device_params)
mesh_device = ttnn.open_mesh_device(
mesh_shape=ttnn.MeshShape(2, 2),
dispatch_core_config=dispatch_core_config,
**device_params,
**updated_device_params,
offset=ttnn.MeshOffset(0, 1),
mesh_type=ttnn.MeshType.Ring,
)
Expand All @@ -278,11 +277,10 @@ def n300_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, devic
if ttnn.get_num_devices() < 2:
pytest.skip()

dispatch_core_config = get_dispatch_core_config(device_params)
updated_device_params = get_updated_device_params(device_params)
mesh_device = ttnn.open_mesh_device(
mesh_shape=ttnn.MeshShape(1, 2),
dispatch_core_config=dispatch_core_config,
**device_params,
**updated_device_params,
)

logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created")
Expand All @@ -303,11 +301,10 @@ def t3k_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device
pytest.skip()

request.node.pci_ids = ttnn.get_pcie_device_ids()
dispatch_core_config = get_dispatch_core_config(device_params)
updated_device_params = get_updated_device_params(device_params)
mesh_device = ttnn.open_mesh_device(
mesh_shape=ttnn.MeshShape(2, 4),
dispatch_core_config=dispatch_core_config,
**device_params,
**updated_device_params,
mesh_type=ttnn.MeshType.Ring,
)

Expand Down
Loading