Skip to content

Commit

Permalink
tests: run parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
maxim-v4s committed Oct 15, 2024
1 parent 5b33bc5 commit 391e00f
Show file tree
Hide file tree
Showing 6 changed files with 407 additions and 22 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,5 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/
.vscode/
206 changes: 206 additions & 0 deletions tests/test_parameters/test_execution_parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
from typing import List, Optional

import pytest
from pydantic import Field

from qualibrate.parameters import (
ExecutionParameters,
GraphParameters,
NodeParameters,
NodesParameters,
RunnableParameters,
)


class Node1(NodeParameters):
qubits: Optional[List[str]] = ["a", "b", "c"]
int_value: int = 1


class Node2(NodeParameters):
qubits: Optional[List[str]] = ["d", "e", "f"]
float_value: float = 2.0


class Graph(GraphParameters):
qubits: Optional[List[str]] = ["1", "2", "3"]
str_value: str = "test"


class NodesParams(NodesParameters):
node1: Node1 = Field(default_factory=Node1)
node2: Node2 = Field(default_factory=Node2)


class ExecutionParams(ExecutionParameters):
parameters: Graph = Field(default_factory=Graph)
nodes: NodesParams = Field(default_factory=NodesParams)


class TestExecutionParameters:
def test_default_initialization(self):
instance = ExecutionParameters()
assert isinstance(instance.parameters, GraphParameters)
assert isinstance(instance.nodes, NodesParameters)

def test_serialize_with_additional_fields(self):
class ExtendedExecutionParameters(ExecutionParameters):
extra_field: str = "extra"

serialized = ExtendedExecutionParameters.serialize()
assert "parameters" in serialized
assert "nodes" in serialized
assert "extra_field" in serialized

def test_serialize_exclude_targets_default(self):
assert ExecutionParams.serialize() == {
"parameters": {
"qubits": {
"anyOf": [
{"items": {"type": "string"}, "type": "array"},
{"type": "null"},
],
"default": ["1", "2", "3"],
"title": "Qubits",
"is_targets": True,
},
"str_value": {
"default": "test",
"title": "Str Value",
"type": "string",
"is_targets": False,
},
},
"nodes": {
"node1": {
"int_value": {
"default": 1,
"title": "Int Value",
"type": "integer",
"is_targets": False,
}
},
"node2": {
"float_value": {
"default": 2.0,
"title": "Float Value",
"type": "number",
"is_targets": False,
}
},
},
}

def test_serialize_force_exclude_targets(self):
assert ExecutionParams.serialize(exclude_targets=True) == {
"parameters": {
"str_value": {
"default": "test",
"title": "Str Value",
"type": "string",
"is_targets": False,
}
},
"nodes": {
"node1": {
"int_value": {
"default": 1,
"title": "Int Value",
"type": "integer",
"is_targets": False,
}
},
"node2": {
"float_value": {
"default": 2.0,
"title": "Float Value",
"type": "number",
"is_targets": False,
}
},
},
}

def test_serialize_force_not_exclude_targets(self):
assert ExecutionParams.serialize(exclude_targets=False) == {
"parameters": {
"qubits": {
"anyOf": [
{"items": {"type": "string"}, "type": "array"},
{"type": "null"},
],
"default": ["1", "2", "3"],
"title": "Qubits",
"is_targets": True,
},
"str_value": {
"default": "test",
"title": "Str Value",
"type": "string",
"is_targets": False,
},
},
"nodes": {
"node1": {
"qubits": {
"anyOf": [
{"items": {"type": "string"}, "type": "array"},
{"type": "null"},
],
"default": ["a", "b", "c"],
"title": "Qubits",
"is_targets": True,
},
"int_value": {
"default": 1,
"title": "Int Value",
"type": "integer",
"is_targets": False,
},
},
"node2": {
"qubits": {
"anyOf": [
{"items": {"type": "string"}, "type": "array"},
{"type": "null"},
],
"default": ["d", "e", "f"],
"title": "Qubits",
"is_targets": True,
},
"float_value": {
"default": 2.0,
"title": "Float Value",
"type": "number",
"is_targets": False,
},
},
},
}

def test_serialize_with_none_parameters_class(self, mocker):
mock_model_fields = mocker.patch(
"qualibrate.parameters.ExecutionParameters.model_fields"
)
mock_model_fields.__getitem__.return_value = mocker.MagicMock(
annotation=None
)

with pytest.raises(
RuntimeError, match="Graph parameters class can't be none"
):
ExecutionParameters.serialize()

def test_serialize_none_parameters_class(self, mocker):
mock_model_fields = mocker.patch(
"qualibrate.parameters.ExecutionParameters.model_fields"
)
mock_model_fields.__getitem__.return_value = mocker.MagicMock(
annotation=RunnableParameters
)

with pytest.raises(
RuntimeError,
match="Graph parameters class should be subclass of qualibrate.parameters.GraphParameters",
):
ExecutionParameters.serialize()
78 changes: 78 additions & 0 deletions tests/test_parameters/test_node_and_graph_parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from typing import List, Optional

import pytest

from qualibrate.parameters import GraphParameters, NodeParameters


class TestCreateParameters:
class SampleNodeParameters(NodeParameters):
qubits: Optional[List[str]] = None
other_param: str = "test"

class SampleGraphParameters(GraphParameters):
qubits: Optional[List[str]] = None
other_param: str = "test"

def test_node_targets_name(self):
assert NodeParameters.targets_name == "qubits"

def test_graph_targets_name(self):
assert GraphParameters.targets_name == "qubits"

@pytest.mark.parametrize(
"parameters_class", [SampleNodeParameters, SampleGraphParameters]
)
def test_serialize_include_targets(self, parameters_class):
assert parameters_class.serialize(exclude_targets=False) == {
"qubits": {
"anyOf": [
{"items": {"type": "string"}, "type": "array"},
{"type": "null"},
],
"default": None,
"title": "Qubits",
"is_targets": True,
},
"other_param": {
"default": "test",
"title": "Other Param",
"type": "string",
"is_targets": False,
},
}

@pytest.mark.parametrize(
"parameters_class", [SampleNodeParameters, SampleGraphParameters]
)
def test_serialize_exclude_targets(self, parameters_class):
assert parameters_class.serialize(exclude_targets=True) == {
"other_param": {
"default": "test",
"title": "Other Param",
"type": "string",
"is_targets": False,
}
}

@pytest.mark.parametrize(
"parameters_class", [SampleNodeParameters, SampleGraphParameters]
)
def test_serialize_no_exclude_param(self, parameters_class):
assert parameters_class.serialize() == {
"qubits": {
"anyOf": [
{"items": {"type": "string"}, "type": "array"},
{"type": "null"},
],
"default": None,
"title": "Qubits",
"is_targets": True,
},
"other_param": {
"default": "test",
"title": "Other Param",
"type": "string",
"is_targets": False,
},
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from qualibrate.parameters import NodeParameters
from qualibrate.parameters import RunnableParameters


def test_parameters_empty_serialization():
class Parameters(NodeParameters):
class Parameters(RunnableParameters):
qubits: list[str] = []

parameters = Parameters()
Expand All @@ -19,7 +19,7 @@ class Parameters(NodeParameters):


def test_parameters_default_types_serialization():
class Parameters(NodeParameters):
class Parameters(RunnableParameters):
qubits: list[str] = []
bool_val: bool = False
int_val: int = 0
Expand Down
Loading

0 comments on commit 391e00f

Please sign in to comment.