Skip to content

Commit

Permalink
refactor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ajbalogh committed Nov 13, 2024
1 parent 248106a commit 23abd8a
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 33 deletions.
16 changes: 10 additions & 6 deletions src/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,19 @@ def _validate_device_exists(self, name):
ValidationError(referential_integrity=f"Infrastructure.devices[{name}] does not exist")
)

def _validate_infrastructure_connection(self, connection):
def _validate_device_connection(self, connection):
pass

def _validate_binding_infrastructure_path(self, binding):
pass

def _validate_count(self, count):
if count < 1:
def _validate_count(self, object):
"""Validates that the count of an object is greater than 0"""
if object.count < 1:
self._validation_response.errors.append(
ValidationError(count=f"Count {count} must be greater than 0")
ValidationError(
count=f"{object.DESCRIPTOR.name}.count == {object.count} and must be greater than 0"
)
)

def validate(self, request: ValidationRequest):
Expand Down Expand Up @@ -124,6 +127,7 @@ def validate(self, request: ValidationRequest):
for component in device.components.values():
self._validate_presence(component, "name")
self._validate_presence(component, "count")
self._validate_count(component)
for connection in device.connections:
self._validate_component_connection(device, connection)
for link in device.links.values():
Expand All @@ -132,10 +136,10 @@ def validate(self, request: ValidationRequest):
self._validate_presence(link, "name")
self._validate_map(request.infrastructure.device_instances)
for device_instance in request.infrastructure.device_instances.values():
self._validate_count(device_instance.count)
self._validate_count(device_instance)
self._validate_device_exists(device_instance.device)
for connection in request.infrastructure.connections:
self._validate_infrastructure_connection(connection)
self._validate_device_connection(connection)
if request.bindings is not None:
for binding in request.bindings.bindings:
self._validate_oneof(binding, "infrastructure_path")
Expand Down
34 changes: 22 additions & 12 deletions src/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,28 @@
import pytest
import os
import sys
import os

sys.path.extend(
[
os.path.join(os.path.dirname(__file__), ".."),
os.path.join(os.path.dirname(__file__), "../generated"),
]
)
for directory_piece in ["..", "generated", "tests"]:
sys.path.append(os.path.join(os.path.dirname(__file__), directory_piece))

from infra_pb2 import Device
import pytest
from keysight_chakra.generated.infra_pb2 import Device, Component, Link, LinkType, Bandwidth, Npu, Nic
from service import Service


@pytest.fixture
def service() -> Service:
return Service()


@pytest.fixture
def device():
host = Device(name="host")
return host
def host() -> Device:
device = Device(
name="aic-sb203-lx",
components={
"tesla-t4": Component(name="tesla-t4", count=1, npu=Npu()),
"cx6": Component(name="cx6", count=1, nic=Nic()),
},
links={"pcie-3": Link(name="pcie-3", type=LinkType.LINK_PCIE, bandwidth=Bandwidth(gbps=32))},
connections=["tesla-t4.0.pcie-3.cx6.0"],
)
return device
34 changes: 19 additions & 15 deletions src/tests/test_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from service import Service
from generated.service_pb2 import ValidationRequest
from generated.infra_pb2 import (

from keysight_chakra.generated.service_pb2 import ValidationRequest
from keysight_chakra.generated.infra_pb2 import (
Infrastructure,
Inventory,
Device,
Expand All @@ -13,19 +13,23 @@
)


def test_valid_device(device):
def test_validate_device(service, host):
"""Test that a device is valid"""
mii = Link(name="mii", type=LinkType.LINK_CUSTOM, bandwidth=Bandwidth(gbps=100))
device.links[mii.name].CopyFrom(mii)
inventory = Inventory()
inventory.devices[device.name].CopyFrom(device)
infrastructure = Infrastructure(inventory=inventory)
request = ValidationRequest(infrastructure=infrastructure)
response = Service().validate(request=request)
response = service.validate(
request=ValidationRequest(
infrastructure=Infrastructure(
inventory=Inventory(
devices={
host.name: host,
},
),
)
)
)
assert len(response.errors) == 0


def test_missing_bandwidth():
def test_missing_bandwidth(service):
"""Test that a device is missing the bandwidth from a link"""
device = Device(name="host")
mii = Link(name="mii", type=LinkType.LINK_CUSTOM)
Expand All @@ -34,13 +38,13 @@ def test_missing_bandwidth():
inventory.devices[device.name].CopyFrom(device)
infrastructure = Infrastructure(inventory=inventory)
request = ValidationRequest(infrastructure=infrastructure)
response = Service().validate(request=request)
response = service.validate(request=request)
print(response)
assert len(response.errors) == 1
assert response.errors[0].WhichOneof("type") == "oneof"


def test_referential_integrity():
def test_referential_integrity(service):
"""Referential integrity tests"""
device = Device(name="laptop")
mii = Link(name="mii", type=LinkType.LINK_CUSTOM, bandwidth=Bandwidth(gbps=100))
Expand All @@ -57,7 +61,7 @@ def test_referential_integrity():
host = DeviceInstances(name="host", device="laptop", count=4)
infrastructure.device_instances[host.name].CopyFrom(host)
request = ValidationRequest(infrastructure=infrastructure)
response = Service().validate(request=request)
response = service.validate(request=request)
print(response)
for error in response.errors:
assert error.WhichOneof("type") == "referential_integrity"
Expand Down

0 comments on commit 23abd8a

Please sign in to comment.