From 59ab58165b5237ed98084f7976a91d76e3ed798f Mon Sep 17 00:00:00 2001 From: anbalogh Date: Mon, 11 Nov 2024 19:29:17 +0000 Subject: [PATCH] remove builders, add tests --- Makefile | 2 +- protos/bind.proto | 16 +------ protos/infra.proto | 50 +++++++--------------- protos/service.proto | 33 ++++++++++++++ src/infra_service.py | 25 ----------- src/service.py | 50 ++++++++++++++++++++++ src/tests/conftest.py | 10 ++++- src/tests/test_device.py | 83 ------------------------------------ src/tests/test_rack_plane.py | 81 ----------------------------------- src/tests/test_service.py | 36 ++++++++++++++++ 10 files changed, 146 insertions(+), 240 deletions(-) create mode 100644 protos/service.proto delete mode 100644 src/infra_service.py create mode 100644 src/service.py delete mode 100644 src/tests/test_device.py delete mode 100644 src/tests/test_rack_plane.py create mode 100644 src/tests/test_service.py diff --git a/Makefile b/Makefile index edf774c..03c6338 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ build: ## compile all .proto files and generate artifacts --python_out=$(GENERATED_DIR) \ --pyi_out=$(GENERATED_DIR) \ --grpc_python_out=$(GENERATED_DIR) \ - et_def.proto infra.proto bind.proto + et_def.proto infra.proto bind.proto service.proto python3 -m pip uninstall -y keysight-chakra python3 setup.py bdist_wheel python3 -m pip install --no-cache . diff --git a/protos/bind.proto b/protos/bind.proto index cc4626d..5c89745 100644 --- a/protos/bind.proto +++ b/protos/bind.proto @@ -8,7 +8,6 @@ syntax = "proto3"; package bind; import "google/protobuf/any.proto"; -import "infra.proto"; message Data { // Use this field to provide descriptive information about the message @@ -53,7 +52,7 @@ message Binding { oneof infrastructure_path { // the binding is global to the infrastructure and the value provided here // is for informational purposes only - string global = 1; + string infrastructure = 1; // binding is specific to an Infrastructure.inventory.device.name // example: dgx @@ -102,16 +101,3 @@ message Bindings { // A list of user defined information for specific endpoints repeated Binding bindings = 1; } - -message ValidationRequest { - infra.Infrastructure infrastructure = 1; - repeated Binding bindings = 2; -} - -message ValidationResponse { - repeated Binding invalid_bindings = 1; -} - -service BindService { - rpc Validate(ValidationRequest) returns (ValidationResponse); -} \ No newline at end of file diff --git a/protos/infra.proto b/protos/infra.proto index bab57b5..6d8a791 100644 --- a/protos/infra.proto +++ b/protos/infra.proto @@ -66,10 +66,10 @@ message Switch { // Component describes a number of components that share a specific type message Component { // the name of the component - string name = 1; + optional string name = 1; // the number of this type of component - uint32 count = 2; + optional uint32 count = 2; // the type of component oneof type { @@ -115,7 +115,7 @@ message Bandwidth { // Link describes a link between Components message Link { // name of the link - string name = 1; + optional string name = 1; // type of link LinkType type = 2; @@ -128,21 +128,18 @@ message Link { // between those components message Device { // the name of the device - string name = 1; + optional string name = 1; // collection of unique components in the device - repeated Component components = 3; + map components = 3; // collection of unique links in the device - repeated Link links = 4; + map links = 4; // a list of connections that describe how Components are connected to each // other in a single Device // // format: The following pieces of information each separated by a "." - // An * indicates all possible indexes for a component will be mapped - // In the case where an * is used for both indexes then the mapping will be - // a one-to-one mapping. // // - name = Component.name // - index < Component.count @@ -153,6 +150,7 @@ message Device { // examples: // nic.0.pcie.cpu.0 // npu.0.pcie.nvswitch.0 + // asic.0.mii.nic.0 // repeated string connections = 5; } @@ -162,15 +160,15 @@ message DeviceInstances { // it should be used to categorize the devices // for example it can be Dgx1Host, ZionexHost, RackSwitch, PodSwitch, // SpineSwitch etc. - string name = 1; + optional string name = 1; // the name of an actual device that exists in the // Infrastructure.inventory.devices field // this allows for a Device to be reused. - string device = 2; + optional string device = 2; // the number of instances of the device in the infrastructure under this name - uint32 count = 3; + optional uint32 count = 3; } // The Inventory message is a collection of unique devices and links present @@ -179,15 +177,15 @@ message DeviceInstances { // DeviceInstance, DeviceLink, ConnectionLink messages message Inventory { // A collection of all unique types of devices in the infrastructure - // Uniquess is determined by the Device.name field. - // This list is not an instance list, for eg, you define one DGX1 or ZionEx - // device and use the DeviceInstances message to scale up the number of those - // devices. + // Uniqueness is determined by the Device.name field. + // This list is not an instance list instead use the DeviceInstances message + // to create an instance of a Device and to scale it to the count present + // in your infrastructure. map devices = 1; // A collection of all unique types of links in the infrastructure. - // These links can be reused multiple times when creating ComponentConnection - // and DeviceConnection messages. + // These links can be reused multiple times when creating connections + // between devices. map links = 2; } @@ -207,9 +205,6 @@ message Infrastructure { repeated DeviceInstances device_instances = 2; // format: The following pieces of information each separated by a "." - // An * indicates all possible indexes for a component will be mapped - // In the case where an * is used for both indexes then the mapping will be - // a one-to-one mapping. // // - name = DeviceInstance.name // - index < DeviceInstance.count @@ -223,19 +218,6 @@ message Infrastructure { // // examples: // host.0.nic.0.100gpbs.racksw.0.nic.0 - // host.0.nic.0-8.100gpbs.racksw.0.nic.8-15 // repeated string connections = 3; } - -message ValidationRequest { - infra.Infrastructure infrastructure = 1; -} - -message ValidationResponse { - repeated string invalid_connections = 1; -} - -service InfraService { - rpc Validate(ValidationRequest) returns (ValidationResponse); -} \ No newline at end of file diff --git a/protos/service.proto b/protos/service.proto new file mode 100644 index 0000000..3014998 --- /dev/null +++ b/protos/service.proto @@ -0,0 +1,33 @@ +// service.proto +// +// Service and rpcs for infrastructure and bindings + +syntax = "proto3"; + +package service; + +import "infra.proto"; +import "bind.proto"; + +message ValidationRequest { + infra.Infrastructure infrastructure = 1; + bind.Bindings bindingas = 2; +} + +message ValidationError { + oneof type { + string optional = 1; + string oneof = 2; + string map = 3; + string repeated = 4; + string connection = 5; + } +} + +message ValidationResponse { + repeated ValidationError errors = 1; +} + +service Service { + rpc Validate(ValidationRequest) returns (ValidationResponse); +} \ No newline at end of file diff --git a/src/infra_service.py b/src/infra_service.py deleted file mode 100644 index e67274b..0000000 --- a/src/infra_service.py +++ /dev/null @@ -1,25 +0,0 @@ -if __package__ is None or __package__ == "": - from generated import infra_pb2, infra_pb2_grpc -else: - from .generated import infra_pb2, infra_pb2_grpc - -class InfraService(): - @staticmethod - def Validate(request: infra_pb2.ValidationRequest): - """Validate every connection in Device and Infrastructure. - - Every Device in Infrastructure.inventory.devices has connections which - must have a valid number of pieces separated by a ".". - - The names in the following connection breakdown must be present in the - Device components and links. - - The format of a Device connection is the following: - "component_name.component_index.link_name.component_name.component_index" - """ - validation_response = infra_pb2.ValidationResponse() - for name, device in request.infrastructure.inventory.devices.items(): - if name != device.name: - validation_response.inva - for connection in device.connections: - return validation_response \ No newline at end of file diff --git a/src/service.py b/src/service.py new file mode 100644 index 0000000..849648f --- /dev/null +++ b/src/service.py @@ -0,0 +1,50 @@ +if __package__ is None or __package__ == "": + from generated.service_pb2 import ValidationRequest, ValidationError, ValidationResponse +else: + from .generated.service_pb2 import ValidationRequest, ValidationError, ValidationResponse + + +class Service: + @staticmethod + def Validate(request: ValidationRequest): + """Validate every connection in the Infrastructure. + + Every Device in Infrastructure.inventory.devices has connections which + must have a valid number of pieces separated by a ".". + + Every connection in Infrastructure.connections must be composed of + a valid number of pieces separated by a "." and the pieces must exist + in the Infrastructure.inventory.devices and Infrastructure.inventory.links. + + The format of a Device connection is the following: + "component_name.component_index.link_name.component_name.component_index" + """ + errors = [] + for device_key, device in request.infrastructure.inventory.devices.items(): + if device.HasField("name") is False: + errors.append(ValidationError(optional=f"Device name field has not been set")) + if device_key != device.name: + errors.append( + ValidationError( + map=f"Device key '{device_key}' does not match Device.name '{device.name}'" + ) + ) + for link_key, link in device.links.items(): + if link_key != link.name: + errors.append( + ValidationError( + map=f"Device '{device.name}' link key '{link_key}' does not match Link.name '{link.name}'" + ) + ) + if link.bandwidth.WhichOneof("type") is None: + errors.append(ValidationError(oneof="Device.links.bandwidth type must be set")) + for connection in device.connections: + try: + src, src_idx, link, dst, dst_idx = connection.split(".") + except ValueError: + errors.append( + ValidationError( + connection=f"Component connection in device '{device.name}' is incorrectly formatted" + ) + ) + return ValidationResponse(errors=errors) diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 5871ed8..169750e 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -1 +1,9 @@ -import pytest +import os +import sys + +sys.path.extend( + [ + os.path.join(os.path.dirname(__file__), ".."), + os.path.join(os.path.dirname(__file__), "../generated"), + ] +) diff --git a/src/tests/test_device.py b/src/tests/test_device.py deleted file mode 100644 index f4ce2a1..0000000 --- a/src/tests/test_device.py +++ /dev/null @@ -1,83 +0,0 @@ -import pytest - -from keysight_chakra.generic import GenericHost -from keysight_chakra.closfabric import ClosFabric -from keysight_chakra.zionex import ZionEx -from keysight_chakra.infrastructure import Infrastructure - -import sys -import os -sys.path.append(os.path.dirname(os.path.realpath(__file__))+"/../") -import generated.infra_pb2 as infra - - -def test_generic_host_no_params(): - host = GenericHost() - assert host.get_component("npu") is not None - assert host.get_component("nic") is not None - assert "npu_interconnect" not in host._device.links - -def test_generic_host_with_params(): - npu_count = 4 - host = GenericHost(npu_count=npu_count, npu_interconnect_bandwidth_gbps=600) - assert "npu_interconnect" in host._device.links - assert host._device.links["npu_interconnect"].type == infra.LINK_CUSTOM - - seen_map = {} - for npu_index in range(npu_count): - seen_map[npu_index] = False - - for connection in host._device.connections: - if connection.link.c1 == "npu" and connection.link.c2 == "npu_interconnect_switch": - npu_index = connection.link.c1_index - assert npu_index in seen_map - assert not seen_map[npu_index] - seen_map[npu_index] = True - for v in seen_map.values(): - assert v == True - -@pytest.mark.parametrize( - "host_devices, rack_capacity, pod_capacity, spine_capacity", - [ - (8, 4, 1, 0), - (4, 4, 0, 0), - ], -) -@pytest.mark.parametrize( - "over_subscription", - [(1, 1), (2, 1)], -) -def test_clos_fabric( - host_devices, - rack_capacity, - pod_capacity, - over_subscription, - spine_capacity, -): - """Test creating 2 tier clos fabric""" - host = GenericHost(npu_count=1) - clos_fabric = ClosFabric( - host_device=host, - host_devices=host_devices, - rack_capacity=rack_capacity, - rack_to_pod_oversubscription=over_subscription, - pod_capacity=pod_capacity, - spine_capacity=spine_capacity, - ) - infrastructure = Infrastructure( - host_device=host, - host_devices=host_devices, - fabric=clos_fabric, - assignment_scheme="ROUND_ROBIN", - ) - - -def test_zionex_host(): - host = ZionEx() - assert host.get_component("cpu") is not None - assert host.get_component("npu") is not None - assert host.get_component("nic") is not None - - -if __name__ == "__main__": - pytest.main(["-s", __file__]) diff --git a/src/tests/test_rack_plane.py b/src/tests/test_rack_plane.py deleted file mode 100644 index 1ce34f8..0000000 --- a/src/tests/test_rack_plane.py +++ /dev/null @@ -1,81 +0,0 @@ -"""rack plane related unit tests""" - -import pytest - -if __package__ is None or __package__ == "": - from src.generated import infra_pb2 - from src.rack_plane_host import RackPlaneHostBuilder - from src.rack_plane_fabric import RackPlaneFabricBuilder - from src.infrastructure import Infrastructure -else: - from .generated import infra_pb2 - from keysight_chakra.rack_plane_host import RackPlaneHostBuilder - from keysight_chakra.rack_plane_fabric import RackPlaneFabricBuilder - from keysight_chakra.infrastructure import Infrastructure - - -@pytest.mark.parametrize("host_count", [2, 3, 4, 8]) -@pytest.mark.parametrize("sup_nic_count", [2, 3, 4]) -def test_rack_plane_fabric_and_host(host_count: int, sup_nic_count: int): - """verifies that the correct infrastructure can be created from rack_plane fabric/host""" - rp_host_builder = RackPlaneHostBuilder( - npu_count=1, scale_up_nic_count=sup_nic_count, scale_out_nic_count=1 - ) - rp_fabric_builder = RackPlaneFabricBuilder(host_builder=rp_host_builder) - infra_builder = Infrastructure( - host_device=rp_host_builder, - host_devices=host_count, - fabric=rp_fabric_builder, - assignment_scheme="ROUND_ROBIN", - ) - infrastructure = infra_builder.infrastructure - - assert infrastructure is not None - # loose check confirming the correct number of connections - # between host and rack switches - assert len(infrastructure.connections) == host_count * sup_nic_count - - -def test_rack_plane_fabric_and_host_detailed(): - """verifies that the correct infrastructure can be created from rack_plane fabric/host""" - sup_nic_count = 2 - host_count = 2 - rp_host_builder = RackPlaneHostBuilder( - npu_count=1, scale_up_nic_count=sup_nic_count, scale_out_nic_count=1 - ) - rp_fabric_builder = RackPlaneFabricBuilder(host_builder=rp_host_builder) - infra_builder = Infrastructure( - host_device=rp_host_builder, - host_devices=host_count, - fabric=rp_fabric_builder, - assignment_scheme="ROUND_ROBIN", - ) - infrastructure = infra_builder.infrastructure - - assert infrastructure is not None - assert len(infrastructure.connections) == host_count * sup_nic_count - - # now let's confirm every details of the DeviceConnections - def assert_device_conn( - dev_conn: infra_pb2.DeviceConnection, - d1_index: int, - c1_index: int, - d2_index: int, - c2_index: int, - ): - assert dev_conn.link.d1 == "RackPlaneHost" - assert dev_conn.link.c1 == "scale-up-nic" - assert dev_conn.link.d2 == "RackSwitch" - assert dev_conn.link.c2 == "port-down" - assert dev_conn.link.link == "eth" - assert dev_conn.link.d1_index == d1_index - assert dev_conn.link.c1_index == c1_index - assert dev_conn.link.d2_index == d2_index - assert dev_conn.link.c2_index == c2_index - - # plane 0 d1,c1,d2,c2 - assert_device_conn(infrastructure.connections[0], 0, 0, 0, 0) - assert_device_conn(infrastructure.connections[1], 0, 1, 1, 0) - # plane 1 (and thus rack switch 1; and scale up nic 1 on all hosts) - assert_device_conn(infrastructure.connections[2], 1, 0, 0, 1) - assert_device_conn(infrastructure.connections[3], 1, 1, 1, 1) diff --git a/src/tests/test_service.py b/src/tests/test_service.py new file mode 100644 index 0000000..b30c0fb --- /dev/null +++ b/src/tests/test_service.py @@ -0,0 +1,36 @@ +import pytest +from service import Service +from generated.service_pb2 import ValidationRequest +from generated.infra_pb2 import Infrastructure, Inventory, Device, Link, LinkType, Bandwidth + + +def test_valid_device(): + """Test that a device is valid""" + device = Device(name="host") + mii = Link(name="mii", type=LinkType.LINK_CUSTOM, bandwidth=Bandwidth(gbps=100)) + device.links[mii.name].CopyFrom(mii) + inventory = Inventory() + inventory.devices[device.name].CopyFrom(device) + infrastructure = Infrastructure(inventory=inventory) + request = ValidationRequest(infrastructure=infrastructure) + response = Service.Validate(request=request) + assert len(response.errors) == 0 + + +def test_missing_bandwidth(): + """Test that a device is missing the bandwidth from a link""" + device = Device(name="host") + mii = Link(name="mii", type=LinkType.LINK_CUSTOM) + device.links[mii.name].CopyFrom(mii) + inventory = Inventory() + inventory.devices[device.name].CopyFrom(device) + infrastructure = Infrastructure(inventory=inventory) + request = ValidationRequest(infrastructure=infrastructure) + response = Service.Validate(request=request) + print(response) + assert len(response.errors) == 1 + assert response.errors[0].WhichOneof("type") == "oneof" + + +if __name__ == "__main__": + pytest.main(["-s", __file__])