Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow passing a callable to de/serialization funcs #6855

Merged
merged 11 commits into from
Dec 19, 2024
59 changes: 41 additions & 18 deletions cirq-google/cirq_google/api/v2/sweeps.py
Original file line number Diff line number Diff line change
@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, cast, Dict, List, Optional
from typing import Any, cast, Callable, Dict, List, Optional

import sympy
import tunits

import cirq
from cirq.study import sweeps
from cirq_google.api.v2 import run_context_pb2
from cirq_google.study.device_parameter import DeviceParameter

@@ -55,14 +56,18 @@ def _recover_sweep_const(const_pb: run_context_pb2.ConstValue) -> Any:


def sweep_to_proto(
sweep: cirq.Sweep, *, out: Optional[run_context_pb2.Sweep] = None
sweep: cirq.Sweep,
*,
out: Optional[run_context_pb2.Sweep] = None,
sweep_transformer: Callable[[sweeps.SingleSweep], sweeps.SingleSweep] = lambda x: x,
) -> run_context_pb2.Sweep:
"""Converts a Sweep to v2 protobuf message.

Args:
sweep: The sweep to convert.
out: Optional message to be populated. If not given, a new message will
be created.
sweep_transformer: A function called on Linspace, Points.

Returns:
Populated sweep protobuf message.
@@ -91,6 +96,7 @@ def sweep_to_proto(
for s in sweep.sweeps:
sweep_to_proto(s, out=out.sweep_function.sweeps.add())
elif isinstance(sweep, cirq.Linspace) and not isinstance(sweep.key, sympy.Expr):
sweep = cast(cirq.Linspace, sweep_transformer(sweep))
out.single_sweep.parameter_key = sweep.key
if isinstance(sweep.start, tunits.Value):
unit = sweep.start.unit
@@ -110,6 +116,7 @@ def sweep_to_proto(
if sweep.metadata and getattr(sweep.metadata, 'units', None):
out.single_sweep.parameter.units = sweep.metadata.units
elif isinstance(sweep, cirq.Points) and not isinstance(sweep.key, sympy.Expr):
sweep = cast(cirq.Points, sweep_transformer(sweep))
out.single_sweep.parameter_key = sweep.key
if len(sweep.points) == 1:
out.single_sweep.const_value.MergeFrom(_build_sweep_const(sweep.points[0]))
@@ -142,8 +149,17 @@ def sweep_to_proto(
return out


def sweep_from_proto(msg: run_context_pb2.Sweep) -> cirq.Sweep:
"""Creates a Sweep from a v2 protobuf message."""
def sweep_from_proto(
msg: run_context_pb2.Sweep,
sweep_transformer: Callable[[sweeps.SingleSweep], sweeps.SingleSweep] = lambda x: x,
) -> cirq.Sweep:
"""Creates a Sweep from a v2 protobuf message.

Args:
msg: Serialized sweep message.
sweep_transformer: A function called on Linspace, Point, and ConstValue.

"""
which = msg.WhichOneof('sweep')
if which is None:
return cirq.UnitSweep
@@ -178,31 +194,38 @@ def sweep_from_proto(msg: run_context_pb2.Sweep) -> cirq.Sweep:
)
else:
metadata = None

if msg.single_sweep.WhichOneof('sweep') == 'linspace':
unit: float | tunits.Value = 1.0
if msg.single_sweep.linspace.HasField('unit'):
unit = tunits.Value.from_proto(msg.single_sweep.linspace.unit)
return cirq.Linspace(
key=key,
start=msg.single_sweep.linspace.first_point * unit, # type: ignore[arg-type]
stop=msg.single_sweep.linspace.last_point * unit, # type: ignore[arg-type]
length=msg.single_sweep.linspace.num_points,
metadata=metadata,
return sweep_transformer(
cirq.Linspace(
key=key,
start=msg.single_sweep.linspace.first_point * unit, # type: ignore[arg-type]
stop=msg.single_sweep.linspace.last_point * unit, # type: ignore[arg-type]
length=msg.single_sweep.linspace.num_points,
metadata=metadata,
)
)
if msg.single_sweep.WhichOneof('sweep') == 'points':
unit = 1.0
if msg.single_sweep.points.HasField('unit'):
unit = tunits.Value.from_proto(msg.single_sweep.points.unit)
return cirq.Points(
key=key,
points=[p * unit for p in msg.single_sweep.points.points],
metadata=metadata,
return sweep_transformer(
cirq.Points(
key=key,
points=[p * unit for p in msg.single_sweep.points.points],
metadata=metadata,
)
)
if msg.single_sweep.WhichOneof('sweep') == 'const_value':
return cirq.Points(
key=key,
points=[_recover_sweep_const(msg.single_sweep.const_value)],
metadata=metadata,
return sweep_transformer(
cirq.Points(
key=key,
points=[_recover_sweep_const(msg.single_sweep.const_value)],
metadata=metadata,
)
)

raise ValueError(f'single sweep type not set: {msg}')
87 changes: 87 additions & 0 deletions cirq-google/cirq_google/api/v2/sweeps_test.py
Original file line number Diff line number Diff line change
@@ -153,6 +153,58 @@ def test_sweep_to_proto_points():
assert list(proto.single_sweep.points.points) == [-1, 0, 1, 1.5]


def test_sweep_to_proto_with_simple_func_succeeds():
def func(sweep: sweeps.SingleSweep):
if isinstance(sweep, cirq.Points):
sweep.points = [point + 3 for point in sweep.points]

return sweep

sweep = cirq.Points('foo', [1, 2, 3])
proto = v2.sweep_to_proto(sweep, sweep_transformer=func)

assert list(proto.single_sweep.points.points) == [4.0, 5.0, 6.0]


def test_sweep_to_proto_with_func_linspace():
def func(sweep: sweeps.SingleSweep):
return cirq.Linspace('foo', 3 * tunits.ns, 6 * tunits.ns, 3) # type: ignore[arg-type]

sweep = cirq.Linspace('foo', start=1, stop=3, length=3)
proto = v2.sweep_to_proto(sweep, sweep_transformer=func)

assert proto.single_sweep.linspace.first_point == 3.0
assert proto.single_sweep.linspace.last_point == 6.0
assert tunits.Value.from_proto(proto.single_sweep.linspace.unit) == tunits.ns


def test_sweep_to_proto_with_func_const_value():
def func(sweep: sweeps.SingleSweep):
if isinstance(sweep, cirq.Points):
sweep.points = [point + 3 for point in sweep.points]

return sweep

sweep = cirq.Points('foo', points=[1])
proto = v2.sweep_to_proto(sweep, sweep_transformer=func)

assert proto.single_sweep.const_value.int_value == 4


@pytest.mark.parametrize('sweep', [(cirq.Points('foo', [1, 2, 3])), (cirq.Points('foo', [1]))])
def test_sweep_to_proto_with_func_round_trip(sweep):
def add_tunit_func(sweep: sweeps.SingleSweep):
if isinstance(sweep, cirq.Points):
sweep.points = [point * tunits.ns for point in sweep.points] # type: ignore[misc]

return sweep

proto = v2.sweep_to_proto(sweep, sweep_transformer=add_tunit_func)
recovered = v2.sweep_from_proto(proto)

assert list(recovered.points)[0] == 1 * tunits.ns


def test_sweep_to_proto_unit():
proto = v2.sweep_to_proto(cirq.UnitSweep)
assert isinstance(proto, v2.run_context_pb2.Sweep)
@@ -188,6 +240,41 @@ def test_sweep_from_proto_single_sweep_type_not_set():
v2.sweep_from_proto(proto)


@pytest.mark.parametrize('sweep', [cirq.Points('foo', [1, 2, 3]), cirq.Points('foo', [1])])
def test_sweep_from_proto_with_func_succeeds(sweep):
def add_tunit_func(sweep: sweeps.SingleSweep):
if isinstance(sweep, cirq.Points):
sweep.points = [point * tunits.ns for point in sweep.points] # type: ignore[misc]

return sweep

msg = v2.sweep_to_proto(sweep)
sweep = v2.sweep_from_proto(msg, sweep_transformer=add_tunit_func)

assert list(sweep.points)[0] == [1.0 * tunits.ns]
Comment on lines +251 to +254
Copy link
Collaborator

@BichengYing BichengYing Dec 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add another test that add tunit in sweep_to_proto while remove tunit in sweep_from_proto, this is a more practical usage in the pyle.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done



@pytest.mark.parametrize('sweep', [cirq.Points('foo', [1, 2, 3]), cirq.Points('foo', [1])])
def test_sweep_from_proto_with_func_round_trip(sweep):
def add_tunit_func(sweep: sweeps.SingleSweep):
if isinstance(sweep, cirq.Points):
sweep.points = [point * tunits.ns for point in sweep.points] # type: ignore[misc]

return sweep

def strip_tunit_func(sweep: sweeps.SingleSweep):
if isinstance(sweep, cirq.Points):
if isinstance(sweep.points[0], tunits.Value):
sweep.points = [point[point.unit] for point in sweep.points]

return sweep

msg = v2.sweep_to_proto(sweep, sweep_transformer=add_tunit_func)
sweep = v2.sweep_from_proto(msg, sweep_transformer=strip_tunit_func)

assert list(sweep.points)[0] == 1.0


def test_sweep_with_list_sweep():
ls = cirq.study.to_sweep([{'a': 1, 'b': 2}, {'a': 3, 'b': 4}])
proto = v2.sweep_to_proto(ls)