Skip to content

Commit

Permalink
Rename io_interconnect to npu_interconnect.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricefrog committed Jul 31, 2024
1 parent 28187e5 commit 9f9f4cf
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 18 deletions.
26 changes: 13 additions & 13 deletions src/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class GenericHost(bld.HostBuilder):
def __init__(
self,
npu_count=1,
io_interconnect_bandwidth_gbps: int=0
npu_interconnect_bandwidth_gbps: int=0
):
"""Creates a generic device with only npu and nic components that are
connected by a pcie link.
Expand All @@ -29,7 +29,7 @@ def __init__(
name: The name of the generic device
npu_count: The number of npu/nic components in the device.
io_interconnect_bandwidth_gbps: npu-to-npu interconnect bandwidth in gigabits per second. If 0, no internal npu-to-npu connectivity will be added to the device.
npu_interconnect_bandwidth_gbps: npu-to-npu interconnect bandwidth in gigabits per second. If 0, no internal npu-to-npu connectivity will be added to the device.
"""
super(GenericHost).__init__()
npu = infra.Component(
Expand All @@ -46,13 +46,13 @@ def __init__(
name="pcie",
type=infra.LinkType.LINK_PCIE,
)
io_interconnect = infra.Link(
name="io_interconnect",
npu_interconnect = infra.Link(
name="npu_interconnect",
type=infra.LinkType.LINK_CUSTOM,
bandwidth=infra.Bandwidth(gbps=io_interconnect_bandwidth_gbps),
bandwidth=infra.Bandwidth(gbps=npu_interconnect_bandwidth_gbps),
)
io_interconnect_switch = infra.Component(
name="io_interconnect_switch",
npu_interconnect_switch = infra.Component(
name="npu_interconnect_switch",
count=1,
switch=infra.Switch(custom=infra.Custom()),
)
Expand All @@ -77,18 +77,18 @@ def __init__(
)
)

# Add io_interconnect connections if bandwidth was provided
if io_interconnect_bandwidth_gbps > 0:
components[io_interconnect_switch.name] = io_interconnect_switch
links[io_interconnect.name] = io_interconnect
# Add npu_interconnect connections if bandwidth was provided
if npu_interconnect_bandwidth_gbps > 0:
components[npu_interconnect_switch.name] = npu_interconnect_switch
links[npu_interconnect.name] = npu_interconnect
for npu_idx_a in range(npu_count):
connections.append(
infra.ComponentConnection(
link=infra.ComponentLink(
c1=npu.name,
c1_index=npu_idx_a,
link=io_interconnect.name,
c2=io_interconnect_switch.name,
link=npu_interconnect.name,
c2=npu_interconnect_switch.name,
c2_index=0,
)
)
Expand Down
10 changes: 5 additions & 5 deletions src/tests/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,20 @@ def test_generic_host_no_params():
host = GenericHost()
assert host.get_component("npu") is not None
assert host.get_component("nic") is not None
assert "io_interconnect" not in host._device.links
assert "npu_interconnect" not in host._device.links

def test_generic_host_with_params():
npu_count = 4
host = GenericHost(npu_count=npu_count, io_interconnect_bandwidth_gbps=600)
assert "io_interconnect" in host._device.links
assert host._device.links["io_interconnect"].type == infra.LINK_CUSTOM
host = GenericHost(npu_count=npu_count, npu_interconnect_bandwidth_gbps=600)
assert "npu_interconnect" in host._device.links
assert host._device.links["npu_interconnect"].type == infra.LINK_CUSTOM

seen_map = {}
for npu_index in range(npu_count):
seen_map[npu_index] = False

for connection in host._device.connections:
if connection.link.c1 == "npu" and connection.link.c2 == "io_interconnect_switch":
if connection.link.c1 == "npu" and connection.link.c2 == "npu_interconnect_switch":
npu_index = connection.link.c1_index
assert npu_index in seen_map
assert not seen_map[npu_index]
Expand Down

0 comments on commit 9f9f4cf

Please sign in to comment.