Skip to content

Commit

Permalink
update validation
Browse files Browse the repository at this point in the history
  • Loading branch information
ajbalogh committed Nov 12, 2024
1 parent 59ab581 commit 4d9a04b
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 36 deletions.
2 changes: 2 additions & 0 deletions protos/bind.proto
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ package bind;

import "google/protobuf/any.proto";

// Data message allows a user to provide data outside of the scope of the
// infrastructure graph.
message Data {
// Use this field to provide descriptive information about the message
// that is packed into the value field.
Expand Down
113 changes: 83 additions & 30 deletions src/service.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,76 @@
from typing import Tuple


if __package__ is None or __package__ == "":
from generated.service_pb2 import ValidationRequest, ValidationError, ValidationResponse
else:
from .generated.service_pb2 import ValidationRequest, ValidationError, ValidationResponse


class Service:
@staticmethod
def Validate(request: ValidationRequest):
def __init__(self):
self._validation_request = None
self._validation_response = None

def _validate_presence(self, object, name):
if object.HasField(name) is False:
self._validation_response.errors.append(
ValidationError(optional=f"{object.DESCRIPTOR.name} {name} field has not been set")
)

def _validate_map(self, map):
for key, object in map.items():
if object.name != key:
self._validation_response.errors.append(
ValidationError(
map=f"{object.DESCRIPTOR.name} name value:{object.name} does not match map key:{key}'"
)
)

def _validate_component_connection(self, device, connection: str):
try:
c1, c1_idx, link, c2, c2_idx = connection.split(".")
self._validate_component(device, c1, c1_idx)
self._validate_component(device, c2, c2_idx)
self._validate_link_name(device, link)
except ValueError:
self._validation_response.errors.append(
ValidationError(
connection=f"Connection:{connection} in device:{device.name} is incorrectly formatted"
)
)

def _validate_component(self, device, name, index):
if name not in device.components:
self._validation_response.errors.append(
ValidationError(connection=f"Component:{name} not present in device:{device.name}")
)
try:
index = int(index)
if index < 0 or index > device.components[name].count - 1:
self._validation_response.errors.append(
ValidationError(
connection=f"Component:{name} index:{index} must be >= 0 and <{device.components[name].count}"
)
)
except ValueError:
self._validation_response.errors.append(
ValidationError(connection=f"Index:{index} must be a valid integer")
)

def _validate_link_name(self, device, name: str):
if name not in device.links:
self._validation_response.errors.append(
ValidationError(connection=f"{device.name} does not contain a link:{name}")
)

def _validate_oneof(self, object, name):
if object.WhichOneof(name) is None:
self._validation_response.errors.append(
ValidationError(oneof=f"{object.DESCRIPTOR.name} oneof:{name} must be set")
)

def validate(self, request: ValidationRequest):
"""Validate every connection in the Infrastructure.
Every Device in Infrastructure.inventory.devices has connections which
Expand All @@ -19,32 +83,21 @@ def Validate(request: ValidationRequest):
The format of a Device connection is the following:
"component_name.component_index.link_name.component_name.component_index"
"""
errors = []
for device_key, device in request.infrastructure.inventory.devices.items():
if device.HasField("name") is False:
errors.append(ValidationError(optional=f"Device name field has not been set"))
if device_key != device.name:
errors.append(
ValidationError(
map=f"Device key '{device_key}' does not match Device.name '{device.name}'"
)
)
for link_key, link in device.links.items():
if link_key != link.name:
errors.append(
ValidationError(
map=f"Device '{device.name}' link key '{link_key}' does not match Link.name '{link.name}'"
)
)
if link.bandwidth.WhichOneof("type") is None:
errors.append(ValidationError(oneof="Device.links.bandwidth type must be set"))
self._validation_request = request
self._validation_response = ValidationResponse()

self._validate_map(request.infrastructure.inventory.devices)
for device in request.infrastructure.inventory.devices.values():
self._validate_presence(device, "name")
self._validate_map(device.links)
self._validate_map(device.components)
for link in device.links.values():
self._validate_presence(link, "name")
for component in device.components.values():
self._validate_presence(component, "name")
self._validate_presence(component, "count")
for connection in device.connections:
try:
src, src_idx, link, dst, dst_idx = connection.split(".")
except ValueError:
errors.append(
ValidationError(
connection=f"Component connection in device '{device.name}' is incorrectly formatted"
)
)
return ValidationResponse(errors=errors)
self._validate_component_connection(device, connection)
for link in device.links.values():
self._validate_oneof(link.bandwidth, "type")
return self._validation_response
9 changes: 9 additions & 0 deletions src/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
import os
import sys

Expand All @@ -7,3 +8,11 @@
os.path.join(os.path.dirname(__file__), "../generated"),
]
)

from infra_pb2 import Device


@pytest.fixture
def device():
host = Device(name="host")
return host
31 changes: 25 additions & 6 deletions src/tests/test_service.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
import pytest
from service import Service
from generated.service_pb2 import ValidationRequest
from generated.infra_pb2 import Infrastructure, Inventory, Device, Link, LinkType, Bandwidth
from generated.infra_pb2 import Infrastructure, Inventory, Device, Link, LinkType, Bandwidth, Component


def test_valid_device():
def test_valid_device(device):
"""Test that a device is valid"""
device = Device(name="host")
mii = Link(name="mii", type=LinkType.LINK_CUSTOM, bandwidth=Bandwidth(gbps=100))
device.links[mii.name].CopyFrom(mii)
inventory = Inventory()
inventory.devices[device.name].CopyFrom(device)
infrastructure = Infrastructure(inventory=inventory)
request = ValidationRequest(infrastructure=infrastructure)
response = Service.Validate(request=request)
response = Service().validate(request=request)
assert len(response.errors) == 0


Expand All @@ -26,11 +25,31 @@ def test_missing_bandwidth():
inventory.devices[device.name].CopyFrom(device)
infrastructure = Infrastructure(inventory=inventory)
request = ValidationRequest(infrastructure=infrastructure)
response = Service.Validate(request=request)
response = Service().validate(request=request)
print(response)
assert len(response.errors) == 1
assert response.errors[0].WhichOneof("type") == "oneof"


def test_invalid_component_connection():
"""Test that a component connection is valid"""
device = Device(name="host")
mii = Link(name="mii", type=LinkType.LINK_CUSTOM, bandwidth=Bandwidth(gbps=100))
device.links[mii.name].CopyFrom(mii)
asic = Component(name="asic", count=1)
nic = Component(name="nic", count=1)
device.components[asic.name].CopyFrom(asic)
device.components[nic.name].CopyFrom(nic)
device.connections.append(f"{asic.name}.x.{mii.name}.null.-1")
inventory = Inventory()
inventory.devices[device.name].CopyFrom(device)
infrastructure = Infrastructure(inventory=inventory)
request = ValidationRequest(infrastructure=infrastructure)
response = Service().validate(request=request)
print(response)
assert len(response.errors) == 3
assert response.errors[0].WhichOneof("type") == "connection"


if __name__ == "__main__":
pytest.main(["-s", __file__])
pytest.main(["-s", "-o", "log_cli=True", "-o", "log_cli_level=INFO", __file__])

0 comments on commit 4d9a04b

Please sign in to comment.