Skip to content

Commit

Permalink
rename service, add param typing
Browse files Browse the repository at this point in the history
  • Loading branch information
ajbalogh committed Nov 14, 2024
1 parent 23abd8a commit 4c02481
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 21 deletions.
3 changes: 2 additions & 1 deletion protos/service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ message ValidationResponse {
repeated string info = 3;
}

service Service {
service InfraService {
// Validate rpc validates both infra and binding messages
rpc Validate(ValidationRequest) returns (ValidationResponse);
}
6 changes: 3 additions & 3 deletions src/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@

import pytest
from keysight_chakra.generated.infra_pb2 import Device, Component, Link, LinkType, Bandwidth, Npu, Nic
from service import Service
from validation import Validation


@pytest.fixture
def service() -> Service:
return Service()
def validation() -> Validation:
return Validation()


@pytest.fixture
Expand Down
12 changes: 6 additions & 6 deletions src/tests/test_service.py → src/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
)


def test_validate_device(service, host):
def test_validate_device(validation, host):
"""Test that a device is valid"""
response = service.validate(
response = validation.validate(
request=ValidationRequest(
infrastructure=Infrastructure(
inventory=Inventory(
Expand All @@ -29,7 +29,7 @@ def test_validate_device(service, host):
assert len(response.errors) == 0


def test_missing_bandwidth(service):
def test_missing_bandwidth(validation):
"""Test that a device is missing the bandwidth from a link"""
device = Device(name="host")
mii = Link(name="mii", type=LinkType.LINK_CUSTOM)
Expand All @@ -38,13 +38,13 @@ def test_missing_bandwidth(service):
inventory.devices[device.name].CopyFrom(device)
infrastructure = Infrastructure(inventory=inventory)
request = ValidationRequest(infrastructure=infrastructure)
response = service.validate(request=request)
response = validation.validate(request=request)
print(response)
assert len(response.errors) == 1
assert response.errors[0].WhichOneof("type") == "oneof"


def test_referential_integrity(service):
def test_referential_integrity(validation):
"""Referential integrity tests"""
device = Device(name="laptop")
mii = Link(name="mii", type=LinkType.LINK_CUSTOM, bandwidth=Bandwidth(gbps=100))
Expand All @@ -61,7 +61,7 @@ def test_referential_integrity(service):
host = DeviceInstances(name="host", device="laptop", count=4)
infrastructure.device_instances[host.name].CopyFrom(host)
request = ValidationRequest(infrastructure=infrastructure)
response = service.validate(request=request)
response = validation.validate(request=request)
print(response)
for error in response.errors:
assert error.WhichOneof("type") == "referential_integrity"
Expand Down
29 changes: 18 additions & 11 deletions src/service.py → src/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,31 @@
"""

from typing import Annotated
from google.protobuf.message import Message

if __package__ is None or __package__ == "":
from generated.service_pb2 import ValidationRequest, ValidationError, ValidationResponse
from generated.infra_pb2 import Device
from generated.bind_pb2 import Binding
else:
from .generated.service_pb2 import ValidationRequest, ValidationError, ValidationResponse
from .generated.infra_pb2 import Device
from .generated.bind_pb2 import Binding


class Service:
class Validation:
def __init__(self):
self._validation_request = None
self._validation_response = None

def _validate_presence(self, object, name):
def _validate_presence(self, object: Message, name: str):
if object.HasField(name) is False:
self._validation_response.errors.append(
ValidationError(optional=f"{object.DESCRIPTOR.name} {name} field has not been set")
)

def _validate_map(self, map):
def _validate_map(self, map: Annotated[object, "google protobuf MessageMapContainer"]):
for key, object in map.items():
if object.name != key:
self._validation_response.errors.append(
Expand All @@ -29,7 +36,7 @@ def _validate_map(self, map):
)
)

def _validate_component_connection(self, device, connection: str):
def _validate_component_connection(self, device: Device, connection: str):
try:
c1, c1_idx, link, c2, c2_idx = connection.split(".")
self._validate_component(device, c1, c1_idx)
Expand All @@ -42,7 +49,7 @@ def _validate_component_connection(self, device, connection: str):
)
)

def _validate_component(self, device, name, index):
def _validate_component(self, device: Device, name: str, index: int):
if name not in device.components:
self._validation_response.errors.append(
ValidationError(
Expand All @@ -62,33 +69,33 @@ def _validate_component(self, device, name, index):
ValidationError(referential_integrity=f"Index:{index} must be a valid integer")
)

def _validate_link_name(self, device, name: str):
def _validate_link_name(self, device: Device, name: str):
if name not in device.links:
self._validation_response.errors.append(
ValidationError(
referential_integrity=f"Infrastructure.devices[{device.name}].links[{name}] does not exist"
)
)

def _validate_oneof(self, object, name):
def _validate_oneof(self, object: Message, name: str):
if object.WhichOneof(name) is None:
self._validation_response.errors.append(
ValidationError(oneof=f"{object.DESCRIPTOR.name} oneof:{name} must be set")
)

def _validate_device_exists(self, name):
def _validate_device_exists(self, name: str):
if name not in self._validation_request.infrastructure.inventory.devices:
self._validation_response.errors.append(
ValidationError(referential_integrity=f"Infrastructure.devices[{name}] does not exist")
)

def _validate_device_connection(self, connection):
def _validate_device_connection(self, connection: str):
pass

def _validate_binding_infrastructure_path(self, binding):
def _validate_binding_infrastructure_path(self, binding: Binding):
pass

def _validate_count(self, object):
def _validate_count(self, object: Message):
"""Validates that the count of an object is greater than 0"""
if object.count < 1:
self._validation_response.errors.append(
Expand Down

0 comments on commit 4c02481

Please sign in to comment.