Skip to content

Commit

Permalink
feat: topology verification
Browse files Browse the repository at this point in the history
* feat: add topology verification

* fix: add missing copyright

* fix: fix issues with topology verification

Co-authored-by: Konrad Jałowiecki <[email protected]>

* test: update validation tests

* style: fix minor issues

---------

Co-authored-by: Konrad Jałowiecki <[email protected]>
  • Loading branch information
mstechly and dexter2206 authored Jun 13, 2024
1 parent ea4989f commit c66f1ec
Show file tree
Hide file tree
Showing 13 changed files with 545 additions and 12 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,4 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

.vscode/
24 changes: 24 additions & 0 deletions docs/library/userguide.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,30 @@ data = load_some_program()
program = SchemaV1.model_validate(data)
```

### Topology validation

There can be cases where a program is correct from the perspective of Pydantic validation, but has incorrect topology. This includes cases such as:

- Disconnected ports
- Ports with multiple connections
- Cycles in the graph

In order to validate whether the topology of the program is correct you can use `verify_topology` method. Here's a short snippet showing how one can verify their program and print out the problems (if any).

```python
from qref.verification import verify_topology

program = load_some_program()

verification_output = verify_topology(program)

if not verification_output:
print("Program topology is incorrect, due to the following issues:")
for problem in verification_output.problems:
print(problem)

```

### Rendering QREF files using `qref-render` (experimental)

!!! Warning
Expand Down
33 changes: 32 additions & 1 deletion src/qref/_schema_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,14 @@

from typing import Annotated, Any, Literal, Optional, Union

from pydantic import AfterValidator, BaseModel, ConfigDict, Field, StringConstraints
from pydantic import (
AfterValidator,
BaseModel,
ConfigDict,
Field,
StringConstraints,
field_validator,
)
from pydantic.json_schema import GenerateJsonSchema

NAME_PATTERN = "[A-Za-z_][A-Za-z0-9_]*"
Expand Down Expand Up @@ -96,6 +103,30 @@ class RoutineV1(BaseModel):
def __init__(self, **data: Any):
super().__init__(**{k: v for k, v in data.items() if v != [] and v != {}})

@field_validator("connections", mode="after")
@classmethod
def _validate_connections(cls, v, values) -> list[_ConnectionV1]:
children_port_names = [
f"{child.name}.{port.name}"
for child in values.data.get("children")
for port in child.ports
]
parent_port_names = [port.name for port in values.data["ports"]]
available_port_names = set(children_port_names + parent_port_names)

missed_ports = [
port
for connection in v
for port in (connection.source, connection.target)
if port not in available_port_names
]
if missed_ports:
raise ValueError(
"The following ports appear in a connection but are not "
"among routine's port or their children's ports: {missed_ports}."
)
return v


class SchemaV1(BaseModel):
"""Root object in Program schema V1."""
Expand Down
154 changes: 154 additions & 0 deletions src/qref/verification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Copyright 2024 PsiQuantum, Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from dataclasses import dataclass
from typing import Optional, Union

from ._schema_v1 import RoutineV1, SchemaV1


@dataclass
class TopologyVerificationOutput:
"""Dataclass containing the output of the topology verification"""

problems: list[str]

@property
def is_valid(self):
return len(self.problems) == 0

def __bool__(self) -> bool:
return self.is_valid


def verify_topology(routine: Union[SchemaV1, RoutineV1]) -> TopologyVerificationOutput:
"""Checks whether program has correct topology.
Args:
routine: Routine or program to be verified.
"""
if isinstance(routine, SchemaV1):
routine = routine.program
problems = _verify_routine_topology(routine)
return TopologyVerificationOutput(problems)


def _verify_routine_topology(routine: RoutineV1) -> list[str]:
problems = []
adjacency_list = _get_adjacency_list_from_routine(routine, path=None)

problems += _find_cycles(adjacency_list)
problems += _find_disconnected_ports(routine)

for child in routine.children:
new_problems = _verify_routine_topology(child)
problems += new_problems
return problems


def _get_adjacency_list_from_routine(
routine: RoutineV1, path: Optional[str]
) -> dict[str, list[str]]:
"""This function creates a flat graph representing one hierarchy level of a routine.
Nodes represent ports and edges represent connections (they're directed).
Additionaly, we add node for each children and edges coming from all the input ports
into the children, and from the children into all the output ports.
"""
graph = defaultdict(list)
if path is None:
current_path = routine.name
else:
current_path = ".".join([path, routine.name])

# First, we go through all the connections and add them as adges to the graph
for connection in routine.connections:
source = ".".join([current_path, connection.source])
target = ".".join([current_path, connection.target])
graph[source].append(target)

# Then for each children we add an extra node and set of connections
for child in routine.children:
input_ports = []
output_ports = []

child_path = ".".join([current_path, child.name])
for port in child.ports:
if port.direction == "input":
input_ports.append(".".join([child_path, port.name]))
elif port.direction == "output":
output_ports.append(".".join([child_path, port.name]))

for input_port in input_ports:
graph[input_port].append(child_path)

graph[child_path] += output_ports

return graph


def _find_cycles(adjacency_list: dict[str, list[str]]) -> list[str]:
# Note: it only returns the first detected cycle.
for node in list(adjacency_list.keys()):
problem = _dfs_iteration(adjacency_list, node)
if problem:
return problem
return []


def _dfs_iteration(adjacency_list, start_node) -> list[str]:
to_visit = [start_node]
visited = []
predecessors = {}

while to_visit:
node = to_visit.pop()
visited.append(node)
for neighbour in adjacency_list[node]:
predecessors[neighbour] = node
if neighbour == start_node:
# Reconstruct the cycle
cycle = [neighbour]
while len(cycle) < 2 or cycle[-1] != start_node:
cycle.append(predecessors[cycle[-1]])
return [f"Cycle detected: {cycle[::-1]}"]
if neighbour not in visited:
to_visit.append(neighbour)
return []


def _find_disconnected_ports(routine: RoutineV1):
problems = []
for child in routine.children:
for port in child.ports:
pname = f"{routine.name}.{child.name}.{port.name}"
if port.direction == "input":
matches_in = [
c for c in routine.connections if c.target == f"{child.name}.{port.name}"
]
if len(matches_in) == 0:
problems.append(f"No incoming connections to {pname}.")
elif len(matches_in) > 1:
problems.append(f"Too many incoming connections to {pname}.")
elif port.direction == "output":
matches_out = [
c for c in routine.connections if c.source == f"{child.name}.{port.name}"
]
if len(matches_out) == 0:
problems.append(f"No outgoing connections from {pname}.")
elif len(matches_out) > 1:
problems.append(f"Too many outgoing connections from {pname}.")

return problems
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


def _load_valid_examples():
for path in VALID_PROGRAMS_ROOT_PATH.iterdir():
for path in sorted(VALID_PROGRAMS_ROOT_PATH.iterdir()):
with open(path) as f:
data = yaml.safe_load(f)
yield pytest.param(data["input"], id=data["description"])
Expand Down
27 changes: 27 additions & 0 deletions tests/qref/data/invalid_pydantic_programs.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
- input:
version: v1
program:
name: root
children:
- name: foo
ports:
- name: in_0
direction: input
size: 3
- name: out_0
direction: output
size: 3
- name: bar
ports:
- name: in_0
direction: input
size: 3
- name: out_0
direction: output
size: 3
connections:
- source: foo.out_0
target: bar.in_1
description: "Connection contains non-existent port name"
error_path: "$.program.connections[0].source"
error_message: "'foo.foo.out_0' does not match '^(([A-Za-z_][A-Za-z0-9_]*)|([A-Za-z_][A-Za-z0-9_]*\\\\.[A-Za-z_][A-Za-z0-9_]*))$'"
Loading

0 comments on commit c66f1ec

Please sign in to comment.