Skip to content

Commit

Permalink
allow passing func to de/serialization funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
senecameeks committed Dec 18, 2024
1 parent 2b19bd3 commit 7f14e95
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 7 deletions.
89 changes: 82 additions & 7 deletions cirq-google/cirq_google/api/v2/sweeps.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, cast, Dict, List, Optional
from typing import Any, cast, Dict, List, Optional, Callable

import copy
import sympy
import tunits

Expand Down Expand Up @@ -55,14 +56,18 @@ def _recover_sweep_const(const_pb: run_context_pb2.ConstValue) -> Any:


def sweep_to_proto(
sweep: cirq.Sweep, *, out: Optional[run_context_pb2.Sweep] = None
sweep: cirq.Sweep,
*,
out: Optional[run_context_pb2.Sweep] = None,
func: Callable[..., None] | None = None,
) -> run_context_pb2.Sweep:
"""Converts a Sweep to v2 protobuf message.
Args:
sweep: The sweep to convert.
out: Optional message to be populated. If not given, a new message will
be created.
func: A function called on Linspace, Points.
Returns:
Populated sweep protobuf message.
Expand Down Expand Up @@ -91,6 +96,17 @@ def sweep_to_proto(
for s in sweep.sweeps:
sweep_to_proto(s, out=out.sweep_function.sweeps.add())
elif isinstance(sweep, cirq.Linspace) and not isinstance(sweep.key, sympy.Expr):
if func:
try:
copied_sweep = copy.deepcopy(sweep)
func(copied_sweep)
sweep = copied_sweep
except Exception as e:
print(
f"The function {func} was not applied to {sweep}."
f" because there was an exception thrown: {e}."
)

out.single_sweep.parameter_key = sweep.key
if isinstance(sweep.start, tunits.Value):
unit = sweep.start.unit
Expand All @@ -110,6 +126,16 @@ def sweep_to_proto(
if sweep.metadata and getattr(sweep.metadata, 'units', None):
out.single_sweep.parameter.units = sweep.metadata.units
elif isinstance(sweep, cirq.Points) and not isinstance(sweep.key, sympy.Expr):
if func:
try:
copied_sweep = copy.deepcopy(sweep)
func(copied_sweep)
sweep = copied_sweep
except Exception as e:
print(
f"The function {func} was not applied to {sweep}."
f" because there was an exception thrown: {e}."
)
out.single_sweep.parameter_key = sweep.key
if len(sweep.points) == 1:
out.single_sweep.const_value.MergeFrom(_build_sweep_const(sweep.points[0]))
Expand Down Expand Up @@ -142,8 +168,16 @@ def sweep_to_proto(
return out


def sweep_from_proto(msg: run_context_pb2.Sweep) -> cirq.Sweep:
"""Creates a Sweep from a v2 protobuf message."""
def sweep_from_proto(
msg: run_context_pb2.Sweep, func: Callable[..., None] | None = None
) -> cirq.Sweep:
"""Creates a Sweep from a v2 protobuf message.
Args:
msg: Serialized sweep message.
func: A function called on Linspace, Point, and ConstValue.
"""
which = msg.WhichOneof('sweep')
if which is None:
return cirq.UnitSweep
Expand Down Expand Up @@ -182,28 +216,69 @@ def sweep_from_proto(msg: run_context_pb2.Sweep) -> cirq.Sweep:
unit: float | tunits.Value = 1.0
if msg.single_sweep.linspace.HasField('unit'):
unit = tunits.Value.from_proto(msg.single_sweep.linspace.unit)
return cirq.Linspace(
sweep = cirq.Linspace(
key=key,
start=msg.single_sweep.linspace.first_point * unit, # type: ignore[arg-type]
stop=msg.single_sweep.linspace.last_point * unit, # type: ignore[arg-type]
length=msg.single_sweep.linspace.num_points,
metadata=metadata,
)
try:
# Allow for a function to modify a copy of the the sweep. If there are
# no exceptions cirq.Point is modified.
if func:
copied_sweep = copy.deepcopy(sweep)
func(copied_sweep)
return copied_sweep
except Exception as e:
print(
f"The function {func} was not applied to {sweep}."
f" because there was an exception thrown: {e}."
)
return sweep

if msg.single_sweep.WhichOneof('sweep') == 'points':
unit = 1.0
if msg.single_sweep.points.HasField('unit'):
unit = tunits.Value.from_proto(msg.single_sweep.points.unit)
return cirq.Points(
sweep = cirq.Points(
key=key,
points=[p * unit for p in msg.single_sweep.points.points],
metadata=metadata,
)
try:
# Allow for a function to modify a copy of the the sweep. If there are
# no exceptions cirq.Point is modified.
if func:
copied_sweep = copy.deepcopy(sweep)
func(copied_sweep)
return copied_sweep
except Exception as e:
print(
f"The function {func} was not applied to {sweep}."
f" because there was an exception thrown: {e}."
)
return sweep

if msg.single_sweep.WhichOneof('sweep') == 'const_value':
return cirq.Points(
sweep = cirq.Points(
key=key,
points=[_recover_sweep_const(msg.single_sweep.const_value)],
metadata=metadata,
)
try:
# Allow for a function to modify a copy of the the sweep. If there are
# no exceptions cirq.Point is modified.
if func:
copied_sweep = copy.deepcopy(sweep)
func(copied_sweep)
return copied_sweep
except Exception as e:
print(
f"The function {func} was not applied to {sweep}."
f" because there was an exception thrown: {e}."
)
return sweep

raise ValueError(f'single sweep type not set: {msg}')

Expand Down
61 changes: 61 additions & 0 deletions cirq-google/cirq_google/api/v2/sweeps_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,42 @@ def test_sweep_to_proto_points():
assert list(proto.single_sweep.points.points) == [-1, 0, 1, 1.5]


def test_sweep_to_proto_with_simple_func_succeeds():
def func(sweep: cirq.Sweep):
for idx, point in enumerate(sweep.points):
sweep.points[idx] = int(point + 3)

sweep = cirq.Points('foo', [1, 2, 3])
proto = v2.sweep_to_proto(sweep, func=func)

assert list(proto.single_sweep.points.points) == [4, 5, 6]


def test_sweep_to_proto_with_func_round_trip():
def add_tunit_func(sweep: cirq.Sweep):
for idx, point in enumerate(sweep.points):
sweep.points[idx] = point * tunits.ns

sweep = cirq.Points('foo', [1, 2, 3])
proto = v2.sweep_to_proto(sweep, func=add_tunit_func)
recovered = v2.sweep_from_proto(proto)

assert list(recovered.points) == [1.0 * tunits.ns, 2.0 * tunits.ns, 3.0 * tunits.ns]


def test_sweep_to_proto_with_invalid_func_round_trip():
def raise_error_func(sweep: cirq.Sweep):
for idx, point in enumerate(sweep.points):
sweep.points[idx] = point * tunits.ns
raise ValueError("err")

sweep = cirq.Points('foo', [1, 2, 3])
proto = v2.sweep_to_proto(sweep, func=raise_error_func)
recovered = v2.sweep_from_proto(proto)

assert list(recovered.points) == [1, 2, 3]


def test_sweep_to_proto_unit():
proto = v2.sweep_to_proto(cirq.UnitSweep)
assert isinstance(proto, v2.run_context_pb2.Sweep)
Expand Down Expand Up @@ -188,6 +224,31 @@ def test_sweep_from_proto_single_sweep_type_not_set():
v2.sweep_from_proto(proto)


def test_sweep_from_proto_with_func_succeeds():
def add_tunit_func(sweep: cirq.Sweep):
for idx, point in enumerate(sweep.points):
sweep.points[idx] = point * tunits.ns

sweep = cirq.Points('foo', [1, 2, 3])
msg = v2.sweep_to_proto(sweep)
sweep = v2.sweep_from_proto(msg, func=add_tunit_func)

assert list(sweep.points) == [1.0 * tunits.ns, 2.0 * tunits.ns, 3.0 * tunits.ns]


def test_sweep_from_proto_with_invalid_func_round_trip():
def raise_error_func(sweep: cirq.Sweep):
for idx, point in enumerate(sweep.points):
sweep.points[idx] = point * tunits.ns
raise ValueError("err")

sweep = cirq.Points('foo', [1, 2, 3])
proto = v2.sweep_to_proto(sweep)
recovered = v2.sweep_from_proto(proto, func=raise_error_func)

assert list(recovered.points) == [1, 2, 3]


def test_sweep_with_list_sweep():
ls = cirq.study.to_sweep([{'a': 1, 'b': 2}, {'a': 3, 'b': 4}])
proto = v2.sweep_to_proto(ls)
Expand Down

0 comments on commit 7f14e95

Please sign in to comment.