From 7f14e952ba7be5dc07ebc7ddb4e4656a8ca7aa60 Mon Sep 17 00:00:00 2001 From: Seneca Meeks Date: Tue, 17 Dec 2024 16:31:17 -0800 Subject: [PATCH] allow passing func to de/serialization funcs --- cirq-google/cirq_google/api/v2/sweeps.py | 89 +++++++++++++++++-- cirq-google/cirq_google/api/v2/sweeps_test.py | 61 +++++++++++++ 2 files changed, 143 insertions(+), 7 deletions(-) diff --git a/cirq-google/cirq_google/api/v2/sweeps.py b/cirq-google/cirq_google/api/v2/sweeps.py index 926ac8c16b2..480ed2d3a65 100644 --- a/cirq-google/cirq_google/api/v2/sweeps.py +++ b/cirq-google/cirq_google/api/v2/sweeps.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, cast, Dict, List, Optional +from typing import Any, cast, Dict, List, Optional, Callable +import copy import sympy import tunits @@ -55,7 +56,10 @@ def _recover_sweep_const(const_pb: run_context_pb2.ConstValue) -> Any: def sweep_to_proto( - sweep: cirq.Sweep, *, out: Optional[run_context_pb2.Sweep] = None + sweep: cirq.Sweep, + *, + out: Optional[run_context_pb2.Sweep] = None, + func: Callable[..., None] | None = None, ) -> run_context_pb2.Sweep: """Converts a Sweep to v2 protobuf message. @@ -63,6 +67,7 @@ def sweep_to_proto( sweep: The sweep to convert. out: Optional message to be populated. If not given, a new message will be created. + func: A function called on Linspace, Points. Returns: Populated sweep protobuf message. @@ -91,6 +96,17 @@ def sweep_to_proto( for s in sweep.sweeps: sweep_to_proto(s, out=out.sweep_function.sweeps.add()) elif isinstance(sweep, cirq.Linspace) and not isinstance(sweep.key, sympy.Expr): + if func: + try: + copied_sweep = copy.deepcopy(sweep) + func(copied_sweep) + sweep = copied_sweep + except Exception as e: + print( + f"The function {func} was not applied to {sweep}." + f" because there was an exception thrown: {e}." + ) + out.single_sweep.parameter_key = sweep.key if isinstance(sweep.start, tunits.Value): unit = sweep.start.unit @@ -110,6 +126,16 @@ def sweep_to_proto( if sweep.metadata and getattr(sweep.metadata, 'units', None): out.single_sweep.parameter.units = sweep.metadata.units elif isinstance(sweep, cirq.Points) and not isinstance(sweep.key, sympy.Expr): + if func: + try: + copied_sweep = copy.deepcopy(sweep) + func(copied_sweep) + sweep = copied_sweep + except Exception as e: + print( + f"The function {func} was not applied to {sweep}." + f" because there was an exception thrown: {e}." + ) out.single_sweep.parameter_key = sweep.key if len(sweep.points) == 1: out.single_sweep.const_value.MergeFrom(_build_sweep_const(sweep.points[0])) @@ -142,8 +168,16 @@ def sweep_to_proto( return out -def sweep_from_proto(msg: run_context_pb2.Sweep) -> cirq.Sweep: - """Creates a Sweep from a v2 protobuf message.""" +def sweep_from_proto( + msg: run_context_pb2.Sweep, func: Callable[..., None] | None = None +) -> cirq.Sweep: + """Creates a Sweep from a v2 protobuf message. + + Args: + msg: Serialized sweep message. + func: A function called on Linspace, Point, and ConstValue. + + """ which = msg.WhichOneof('sweep') if which is None: return cirq.UnitSweep @@ -182,28 +216,69 @@ def sweep_from_proto(msg: run_context_pb2.Sweep) -> cirq.Sweep: unit: float | tunits.Value = 1.0 if msg.single_sweep.linspace.HasField('unit'): unit = tunits.Value.from_proto(msg.single_sweep.linspace.unit) - return cirq.Linspace( + sweep = cirq.Linspace( key=key, start=msg.single_sweep.linspace.first_point * unit, # type: ignore[arg-type] stop=msg.single_sweep.linspace.last_point * unit, # type: ignore[arg-type] length=msg.single_sweep.linspace.num_points, metadata=metadata, ) + try: + # Allow for a function to modify a copy of the the sweep. If there are + # no exceptions cirq.Point is modified. + if func: + copied_sweep = copy.deepcopy(sweep) + func(copied_sweep) + return copied_sweep + except Exception as e: + print( + f"The function {func} was not applied to {sweep}." + f" because there was an exception thrown: {e}." + ) + return sweep + if msg.single_sweep.WhichOneof('sweep') == 'points': unit = 1.0 if msg.single_sweep.points.HasField('unit'): unit = tunits.Value.from_proto(msg.single_sweep.points.unit) - return cirq.Points( + sweep = cirq.Points( key=key, points=[p * unit for p in msg.single_sweep.points.points], metadata=metadata, ) + try: + # Allow for a function to modify a copy of the the sweep. If there are + # no exceptions cirq.Point is modified. + if func: + copied_sweep = copy.deepcopy(sweep) + func(copied_sweep) + return copied_sweep + except Exception as e: + print( + f"The function {func} was not applied to {sweep}." + f" because there was an exception thrown: {e}." + ) + return sweep + if msg.single_sweep.WhichOneof('sweep') == 'const_value': - return cirq.Points( + sweep = cirq.Points( key=key, points=[_recover_sweep_const(msg.single_sweep.const_value)], metadata=metadata, ) + try: + # Allow for a function to modify a copy of the the sweep. If there are + # no exceptions cirq.Point is modified. + if func: + copied_sweep = copy.deepcopy(sweep) + func(copied_sweep) + return copied_sweep + except Exception as e: + print( + f"The function {func} was not applied to {sweep}." + f" because there was an exception thrown: {e}." + ) + return sweep raise ValueError(f'single sweep type not set: {msg}') diff --git a/cirq-google/cirq_google/api/v2/sweeps_test.py b/cirq-google/cirq_google/api/v2/sweeps_test.py index 4f41b780772..bd76476f38d 100644 --- a/cirq-google/cirq_google/api/v2/sweeps_test.py +++ b/cirq-google/cirq_google/api/v2/sweeps_test.py @@ -153,6 +153,42 @@ def test_sweep_to_proto_points(): assert list(proto.single_sweep.points.points) == [-1, 0, 1, 1.5] +def test_sweep_to_proto_with_simple_func_succeeds(): + def func(sweep: cirq.Sweep): + for idx, point in enumerate(sweep.points): + sweep.points[idx] = int(point + 3) + + sweep = cirq.Points('foo', [1, 2, 3]) + proto = v2.sweep_to_proto(sweep, func=func) + + assert list(proto.single_sweep.points.points) == [4, 5, 6] + + +def test_sweep_to_proto_with_func_round_trip(): + def add_tunit_func(sweep: cirq.Sweep): + for idx, point in enumerate(sweep.points): + sweep.points[idx] = point * tunits.ns + + sweep = cirq.Points('foo', [1, 2, 3]) + proto = v2.sweep_to_proto(sweep, func=add_tunit_func) + recovered = v2.sweep_from_proto(proto) + + assert list(recovered.points) == [1.0 * tunits.ns, 2.0 * tunits.ns, 3.0 * tunits.ns] + + +def test_sweep_to_proto_with_invalid_func_round_trip(): + def raise_error_func(sweep: cirq.Sweep): + for idx, point in enumerate(sweep.points): + sweep.points[idx] = point * tunits.ns + raise ValueError("err") + + sweep = cirq.Points('foo', [1, 2, 3]) + proto = v2.sweep_to_proto(sweep, func=raise_error_func) + recovered = v2.sweep_from_proto(proto) + + assert list(recovered.points) == [1, 2, 3] + + def test_sweep_to_proto_unit(): proto = v2.sweep_to_proto(cirq.UnitSweep) assert isinstance(proto, v2.run_context_pb2.Sweep) @@ -188,6 +224,31 @@ def test_sweep_from_proto_single_sweep_type_not_set(): v2.sweep_from_proto(proto) +def test_sweep_from_proto_with_func_succeeds(): + def add_tunit_func(sweep: cirq.Sweep): + for idx, point in enumerate(sweep.points): + sweep.points[idx] = point * tunits.ns + + sweep = cirq.Points('foo', [1, 2, 3]) + msg = v2.sweep_to_proto(sweep) + sweep = v2.sweep_from_proto(msg, func=add_tunit_func) + + assert list(sweep.points) == [1.0 * tunits.ns, 2.0 * tunits.ns, 3.0 * tunits.ns] + + +def test_sweep_from_proto_with_invalid_func_round_trip(): + def raise_error_func(sweep: cirq.Sweep): + for idx, point in enumerate(sweep.points): + sweep.points[idx] = point * tunits.ns + raise ValueError("err") + + sweep = cirq.Points('foo', [1, 2, 3]) + proto = v2.sweep_to_proto(sweep) + recovered = v2.sweep_from_proto(proto, func=raise_error_func) + + assert list(recovered.points) == [1, 2, 3] + + def test_sweep_with_list_sweep(): ls = cirq.study.to_sweep([{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]) proto = v2.sweep_to_proto(ls)