Skip to content

Commit

Permalink
Merge pull request #4 from Keysight/issue-generic-interconnect
Browse files Browse the repository at this point in the history
Issue generic interconnect
  • Loading branch information
ajbalogh authored Jul 31, 2024
2 parents 283a2fa + 9f9f4cf commit 5735a56
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 22 deletions.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.0.20
0.0.21
4 changes: 4 additions & 0 deletions protos/infra.proto
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,14 @@ message Pcie {
message NvLink {
}

message Custom {
}

message Switch {
oneof type {
Pcie pcie = 1;
NvLink nvlink = 2;
Custom custom = 3;
}
}

Expand Down
32 changes: 16 additions & 16 deletions src/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ class GenericHost(bld.HostBuilder):
def __init__(
self,
npu_count=1,
nvlink_bandwidth_gbps: int=0
npu_interconnect_bandwidth_gbps: int=0
):
"""Creates a generic device with only npu and nic components that are
connected by a pcie link.
Optionally, npu components can be connected via nvlink using a single nvswitch.
Optionally, npu components within a device can be interconnected via generic links attached to a single generic switch.
name: The name of the generic device
npu_count: The number of npu/nic components in the device.
nvlink_bandwidth_gbps: nvlink bandwidth in gigabits per second. If 0, no nvlink connections will be added to the device.
npu_interconnect_bandwidth_gbps: npu-to-npu interconnect bandwidth in gigabits per second. If 0, no internal npu-to-npu connectivity will be added to the device.
"""
super(GenericHost).__init__()
npu = infra.Component(
Expand All @@ -46,15 +46,15 @@ def __init__(
name="pcie",
type=infra.LinkType.LINK_PCIE,
)
nvlink = infra.Link(
name="nvlink",
type=infra.LinkType.LINK_NVLINK,
bandwidth=infra.Bandwidth(gbps=nvlink_bandwidth_gbps),
npu_interconnect = infra.Link(
name="npu_interconnect",
type=infra.LinkType.LINK_CUSTOM,
bandwidth=infra.Bandwidth(gbps=npu_interconnect_bandwidth_gbps),
)
nvswitch = infra.Component(
name="nvswitch",
npu_interconnect_switch = infra.Component(
name="npu_interconnect_switch",
count=1,
switch=infra.Switch(nvlink=infra.NvLink()),
switch=infra.Switch(custom=infra.Custom()),
)

links = { pcie.name: pcie }
Expand All @@ -77,18 +77,18 @@ def __init__(
)
)

# Add nvlink connections if bandwidth was provided
if nvlink_bandwidth_gbps > 0:
components[nvswitch.name] = nvswitch
links[nvlink.name] = nvlink
# Add npu_interconnect connections if bandwidth was provided
if npu_interconnect_bandwidth_gbps > 0:
components[npu_interconnect_switch.name] = npu_interconnect_switch
links[npu_interconnect.name] = npu_interconnect
for npu_idx_a in range(npu_count):
connections.append(
infra.ComponentConnection(
link=infra.ComponentLink(
c1=npu.name,
c1_index=npu_idx_a,
link=nvlink.name,
c2=nvswitch.name,
link=npu_interconnect.name,
c2=npu_interconnect_switch.name,
c2_index=0,
)
)
Expand Down
10 changes: 5 additions & 5 deletions src/tests/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,20 @@ def test_generic_host_no_params():
host = GenericHost()
assert host.get_component("npu") is not None
assert host.get_component("nic") is not None
assert "nvlink" not in host._device.links
assert "npu_interconnect" not in host._device.links

def test_generic_host_with_params():
npu_count = 4
host = GenericHost(npu_count=npu_count, nvlink_bandwidth_gbps=600)
assert "nvlink" in host._device.links
assert host._device.links["nvlink"].type == infra.LINK_NVLINK
host = GenericHost(npu_count=npu_count, npu_interconnect_bandwidth_gbps=600)
assert "npu_interconnect" in host._device.links
assert host._device.links["npu_interconnect"].type == infra.LINK_CUSTOM

seen_map = {}
for npu_index in range(npu_count):
seen_map[npu_index] = False

for connection in host._device.connections:
if connection.link.c1 == "npu" and connection.link.c2 == "nvswitch":
if connection.link.c1 == "npu" and connection.link.c2 == "npu_interconnect_switch":
npu_index = connection.link.c1_index
assert npu_index in seen_map
assert not seen_map[npu_index]
Expand Down

0 comments on commit 5735a56

Please sign in to comment.