Skip to content

Commit

Permalink
feat: Implement accessing ports/children/resources by name (#114)
Browse files Browse the repository at this point in the history
* Implement basic version of accessing ports/children/resources by name

* Add tests for accessing resources by name

* Add tests for accessing and deleting nonexistent children

* Rename test file to better reflect its purpose

* Implement assignment validation

* Improve type hints

* Mention by_name accessor in the docs
  • Loading branch information
dexter2206 authored Jul 25, 2024
1 parent 54f963c commit b2920bd
Show file tree
Hide file tree
Showing 4 changed files with 321 additions and 8 deletions.
9 changes: 9 additions & 0 deletions docs/library/userguide.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ data = load_some_program()
program = SchemaV1.model_validate(data)
```

One of the benefits of using QREF's pydantic models is ability to obtain objects like children, ports
or resources by name, instead of list indices. This is done by special `.by_name` accessor. For instance
to get a child named `"foo"` of a `routine` object, one can use the following syntax:

```python
foo = routine.children.by_name["foo"]
```


### Topology validation

There can be cases where a program is correct from the perspective of Pydantic validation, but has incorrect topology. This includes cases such as:
Expand Down
73 changes: 65 additions & 8 deletions src/qref/schema_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,17 @@

from __future__ import annotations

from typing import Annotated, Any, Literal, Optional, Union
from typing import (
Annotated,
Any,
Iterator,
Literal,
MutableMapping,
Optional,
TypeVar,
Union,
get_args,
)

from pydantic import (
AfterValidator,
Expand All @@ -28,6 +38,7 @@
model_validator,
)
from pydantic.json_schema import GenerateJsonSchema
from pydantic_core import core_schema
from typing_extensions import Self

NAME_PATTERN = "[A-Za-z_][A-Za-z0-9_]*"
Expand All @@ -44,15 +55,61 @@
]
_Value = Union[int, float, str]

T = TypeVar("T")


class _ProxyMapping(MutableMapping[str, T]):
def __init__(self, source: list[T]):
self.source = source

def _find_item(self, name: str) -> tuple[int, T]:
try:
# To avoid the type: ignore below, we would have to define a protocol for named things,
# which seems to be an overkill, especially that this class is private.
return next(iter([(i, item) for i, item in enumerate(self.source) if item.name == name])) # type: ignore
except StopIteration:
raise KeyError(name)

def __getitem__(self, name: str) -> T:
_index, item = self._find_item(name)
return item

def __setitem__(self, name: str, new_item: T) -> None:
index, _current_item = self._find_item(name)
self.source[index] = new_item

def __delitem__(self, name: str):
index, _current_item = self._find_item(name)
del self.source[index]

def __iter__(self) -> Iterator[str]:
# Same reason for type: ignore as above
return iter((item.name for item in self.source)) # type: ignore

def __len__(self) -> int:
return len(self.source)


class NamedList(list[T]):
@property
def by_name(self) -> _ProxyMapping[T]:
return _ProxyMapping(self)

@classmethod
def __get_pydantic_core_schema__(cls, source, handler):
args = get_args(source)
schema = handler.generate_schema(list[args[0]])
return core_schema.no_info_after_validator_function(NamedList, schema)


def _sorter(key):
def _sorter(key, cls=list):
def _inner(v):
return sorted(v, key=key)
return cls(sorted(v, key=key))

return _inner


_name_sorter = AfterValidator(_sorter(lambda p: p.name))
_name_sorter = AfterValidator(_sorter(lambda p: p.name, NamedList))
_source_sorter = AfterValidator(_sorter(lambda c: c.source))


Expand Down Expand Up @@ -128,10 +185,10 @@ class RoutineV1(BaseModel):
"""

name: _Name
children: Annotated[list[RoutineV1], _name_sorter] = Field(default_factory=list)
children: Annotated[NamedList[RoutineV1], _name_sorter] = Field(default_factory=list)
type: Optional[str] = None
ports: Annotated[list[PortV1], _name_sorter] = Field(default_factory=list)
resources: Annotated[list[ResourceV1], _name_sorter] = Field(default_factory=list)
ports: Annotated[NamedList[PortV1], _name_sorter] = Field(default_factory=list)
resources: Annotated[NamedList[ResourceV1], _name_sorter] = Field(default_factory=list)
connections: Annotated[list[Annotated[ConnectionV1, _connection_parser]], _source_sorter] = Field(
default_factory=list
)
Expand All @@ -140,7 +197,7 @@ class RoutineV1(BaseModel):
linked_params: Annotated[list[ParamLinkV1], _source_sorter] = Field(default_factory=list)
meta: dict[str, Any] = Field(default_factory=dict)

model_config = ConfigDict(title="Routine")
model_config = ConfigDict(title="Routine", validate_assignment=True)

def __init__(self, **data: Any):
super().__init__(**{k: v for k, v in data.items() if v != [] and v != {}})
Expand Down
134 changes: 134 additions & 0 deletions tests/qref/test_accessing_objects_by_name.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright 2024 PsiQuantum, Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

from qref.schema_v1 import PortV1, ResourceV1, RoutineV1


@pytest.fixture
def example_routine():
return RoutineV1.model_validate(
{
"name": "root",
"children": [
{
"name": "a",
"ports": [
{"name": "ctrl", "size": "N", "direction": "input"},
{"name": "target", "size": "N", "direction": "input"},
],
},
{"name": "b", "children": [{"name": "c"}]},
],
"ports": [
{"name": "in_0", "size": 2, "direction": "input"},
{"name": "out_0", "size": 3, "direction": "output"},
],
"resources": [
{"name": "n_rotations", "value": 4, "type": "additive"},
{"name": "n_toffs", "value": 100, "type": "additive"},
],
}
)


class TestAccessingChildrenByName:
def test_can_get_direct_child_by_name(self, example_routine):
assert example_routine.children.by_name["a"] == example_routine.children[0]

def test_can_get_nested_child_by_name(self, example_routine):
assert example_routine.children.by_name["b"].children.by_name["c"] == example_routine.children[1].children[0]

def test_can_set_child_by_name(self, example_routine):
new_child = RoutineV1(name="b")
example_routine.children.by_name["b"] = new_child

assert example_routine.children[1] == new_child
assert example_routine.children.by_name["b"] == new_child

def test_can_delete_child_by_name(self, example_routine):
del example_routine.children.by_name["b"]

assert [child.name for child in example_routine.children] == ["a"]

def test_trying_to_get_nonexisting_child_raises_key_error(self, example_routine):
with pytest.raises(KeyError) as exc_info:
_ = example_routine.children.by_name["x"]

assert exc_info.value.args == ("x",)

def test_nonexistent_child_cannot_be_set(self, example_routine):
with pytest.raises(KeyError) as exc_info:
new_child = RoutineV1(name="x")
example_routine.children.by_name["x"] = new_child

assert exc_info.value.args == ("x",)


class TestAccessingPortsByName:
def test_can_get_port_by_name(self, example_routine):
assert example_routine.ports.by_name["in_0"] == example_routine.ports[0]

def test_can_set_port_by_name(self, example_routine):
new_port = PortV1(name="ctrl", direction="input", size=10)
example_routine.children[0].ports.by_name["ctrl"] = new_port

assert example_routine.children[0].ports[0] == new_port

def test_can_delete_port_by_name(self, example_routine):
del example_routine.ports.by_name["out_0"]

assert [port.name for port in example_routine.ports] == ["in_0"]

def test_trying_to_get_nonexisting_port_raises_key_error(self, example_routine):
with pytest.raises(KeyError) as exc_info:
_ = example_routine.ports.by_name["in_10"]

assert exc_info.value.args == ("in_10",)

def test_nonexistent_port_cannot_be_set(self, example_routine):
with pytest.raises(KeyError) as exc_info:
new_port = PortV1(name="in_10", size=42, direction="input")
example_routine.children.by_name["in_10"] = new_port

assert exc_info.value.args == ("in_10",)


class TestAccessingResourcesByName:
def test_can_get_resource_by_name(self, example_routine):
assert example_routine.resources.by_name["n_toffs"] == example_routine.resources[1]

def test_can_set_resource_by_name(self, example_routine):
new_resource = ResourceV1(name="n_toffs", value=10, type="multiplicative")
example_routine.resources.by_name["n_toffs"] = new_resource

assert example_routine.resources[1] == new_resource

def test_can_delete_resource_by_name(self, example_routine):
del example_routine.resources.by_name["n_toffs"]

assert [resource.name for resource in example_routine.resources] == ["n_rotations"]

def test_trying_to_get_nonexisting_resource_raises_key_error(self, example_routine):
with pytest.raises(KeyError) as exc_info:
_ = example_routine.resources.by_name["n_qubits"]

assert exc_info.value.args == ("n_qubits",)

def test_nonexistent_resource_cannot_be_set(self, example_routine):
with pytest.raises(KeyError) as exc_info:
new_resource = ResourceV1(name="n_qubits", value=42, type="other")
example_routine.resources.by_name["n_qubits"] = new_resource

assert exc_info.value.args == ("n_qubits",)
113 changes: 113 additions & 0 deletions tests/qref/test_assignment_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright 2024 PsiQuantum, Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from pydantic import ValidationError

from qref.schema_v1 import PortV1, ResourceV1, RoutineV1


@pytest.fixture
def example_routine():
return RoutineV1.model_validate(
{
"name": "root",
"children": [
{
"name": "a",
"ports": [
{"name": "ctrl", "size": "N", "direction": "input"},
{"name": "target", "size": "N", "direction": "input"},
{"name": "out", "size": "2N", "direction": "output"},
],
},
{"name": "b"},
],
"ports": [
{"name": "in_0", "size": 2, "direction": "input"},
{"name": "out_0", "size": 3, "direction": "output"},
],
"resources": [
{"name": "n_rotations", "value": 4, "type": "additive"},
{"name": "n_toffs", "value": 100, "type": "additive"},
],
"connections": ["in_0 -> a.ctrl"],
}
)


def test_setting_children_list_to_a_list_of_dictionaries_gives_a_list_of_routines(example_routine):
example_routine.children = [
{"name": "a", "ports": [{"name": "ctrl", "size": "N", "direction": "input"}]},
{"name": "c"},
{"name": "d"},
]

assert all(isinstance(child, RoutineV1) for child in example_routine.children)


def test_setting_children_list_to_an_incorrect_dictionary_raises_validation_error(example_routine):
with pytest.raises(ValidationError):
# The assignment here is invalid because we are left with connection to a.in_0, which
# will cease to exist after the assignment.
example_routine.children = [
{"name": "a", "ports": [{"name": "in_0", "size": "N", "direction": "input"}]},
{"name": "c"},
{"name": "d"},
]


def test_setting_ports_list_to_a_list_of_dictionaries_gives_a_list_of_ports(example_routine):
example_routine.ports = [
{"name": "in_0", "size": 2, "direction": "input"},
]

assert len(example_routine.ports) == 1 and isinstance(example_routine.ports[0], PortV1)


def test_setting_port_list_to_an_incorrect_value_raises_validation_error(example_routine):
with pytest.raises(ValidationError):
example_routine.ports = [{"name": "out_0", "size": 3, "direction": "output"}]


def test_setting_resources_to_a_list_of_dictionaries_gives_a_list_of_resources_v1(example_routine):
example_routine.resources = [
{"name": "n_rotations", "value": 40, "type": "additive"},
{"name": "n_toffs", "value": 10, "type": "additive"},
]

assert all(isinstance(resource, ResourceV1) for resource in example_routine.resources)

assert example_routine.resources.by_name["n_rotations"].value == 40
assert example_routine.resources.by_name["n_toffs"].value == 10


def test_setting_resources_to_an_incorrect_value_raies_validation_error(example_routine):
with pytest.raises(ValidationError):
example_routine.resources = [{"name": "n_toffs", "quantity": "N"}]


def test_setting_connections_to_a_new_value_converts_it_to_list_of_connection_v1(example_routine):
example_routine.connections = ["in_0 -> a.target"]

assert len(example_routine.connections) == 1

connection = example_routine.connections[0]

assert (connection.source, connection.target) == ("in_0", "a.target")


def test_setting_connections_to_incorrect_value_raises_validation_error(example_routine):
with pytest.raises(ValidationError):
# Incorrect, since there is no b.in_0 port
example_routine.connections = ["in_0 -> b.in_0"]

0 comments on commit b2920bd

Please sign in to comment.