Skip to content

Commit

Permalink
Merge branch 'main' into z_cal2
Browse files Browse the repository at this point in the history
  • Loading branch information
NoureldinYosri authored Nov 7, 2024
2 parents a372c8e + 2c914ce commit fcb2d2a
Show file tree
Hide file tree
Showing 23 changed files with 1,212 additions and 335 deletions.
51 changes: 5 additions & 46 deletions cirq-core/cirq/work/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,7 @@
"""Abstract base class for things sampling quantum circuits."""

import collections
from itertools import islice
from typing import (
Dict,
FrozenSet,
Iterator,
List,
Optional,
Sequence,
Tuple,
TypeVar,
TYPE_CHECKING,
Union,
)
from typing import Dict, FrozenSet, List, Optional, Sequence, Tuple, TypeVar, TYPE_CHECKING, Union

import duet
import pandas as pd
Expand All @@ -49,14 +37,6 @@
class Sampler(metaclass=value.ABCMetaImplementAnyOneOf):
"""Something capable of sampling quantum circuits. Simulator or hardware."""

# Users have a rate limit of 1000 QPM for read/write requests to
# the Quantum Engine. The sampler will poll from the DB every 1s
# for inflight requests for results. Empirically, for circuits
# sent in run_batch, sending circuits in CHUNK_SIZE=5 for large
# number of circuits (> 200) with large depths (100 layers)
# does not encounter quota exceeded issues for non-streaming cases.
CHUNK_SIZE: int = 5

def run(
self,
program: 'cirq.AbstractCircuit',
Expand Down Expand Up @@ -311,32 +291,16 @@ async def run_batch_async(
programs: Sequence['cirq.AbstractCircuit'],
params_list: Optional[Sequence['cirq.Sweepable']] = None,
repetitions: Union[int, Sequence[int]] = 1,
limiter: duet.Limiter = duet.Limiter(10),
) -> Sequence[Sequence['cirq.Result']]:
"""Runs the supplied circuits asynchronously.
See docs for `cirq.Sampler.run_batch`.
"""
params_list, repetitions = self._normalize_batch_args(programs, params_list, repetitions)
if len(programs) <= self.CHUNK_SIZE:
return await duet.pstarmap_async(
self.run_sweep_async, zip(programs, params_list, repetitions)
)

results = []
for program_chunk, params_chunk, reps_chunk in zip(
_chunked(programs, self.CHUNK_SIZE),
_chunked(params_list, self.CHUNK_SIZE),
_chunked(repetitions, self.CHUNK_SIZE),
):
# Run_sweep_async for the current chunk
await duet.sleep(1) # Delay for 1 second between chunk
results.extend(
await duet.pstarmap_async(
self.run_sweep_async, zip(program_chunk, params_chunk, reps_chunk)
)
)

return results
return await duet.pstarmap_async(
self.run_sweep_async, zip(programs, params_list, repetitions, [limiter] * len(programs))
)

def _normalize_batch_args(
self,
Expand Down Expand Up @@ -489,8 +453,3 @@ def _get_measurement_shapes(
)
num_instances[key] += 1
return {k: (num_instances[k], qid_shape) for k, qid_shape in qid_shapes.items()}


def _chunked(iterable: Sequence[T], n: int) -> Iterator[tuple[T, ...]]:
it = iter(iterable)
return iter(lambda: tuple(islice(it, n)), ())
54 changes: 3 additions & 51 deletions cirq-core/cirq/work/sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
"""Tests for cirq.Sampler."""
from typing import Sequence
from unittest import mock

import pytest

Expand Down Expand Up @@ -224,7 +223,9 @@ async def test_run_batch_async_calls_run_sweep_asynchronously():
params_list = [params1, params2]

class AsyncSampler(cirq.Sampler):
async def run_sweep_async(self, program, params, repetitions: int = 1):
async def run_sweep_async(
self, program, params, repetitions: int = 1, unused: duet.Limiter = duet.Limiter(None)
):
if params == params1:
await duet.sleep(0.001)

Expand Down Expand Up @@ -267,55 +268,6 @@ def test_sampler_run_batch_bad_input_lengths():
)


@mock.patch('duet.pstarmap_async')
@pytest.mark.parametrize('call_count', [1, 2, 3])
@duet.sync
async def test_run_batch_async_sends_circuits_in_chunks(spy, call_count):
class AsyncSampler(cirq.Sampler):
CHUNK_SIZE = 3

async def run_sweep_async(self, _, params, __: int = 1):
pass # pragma: no cover

sampler = AsyncSampler()
a = cirq.LineQubit(0)
circuit_list = [cirq.Circuit(cirq.X(a) ** sympy.Symbol('t'), cirq.measure(a, key='m'))] * (
sampler.CHUNK_SIZE * call_count
)
param_list = [cirq.Points('t', [0.3, 0.7])] * (sampler.CHUNK_SIZE * call_count)

await sampler.run_batch_async(circuit_list, params_list=param_list)

assert spy.call_count == call_count


@pytest.mark.parametrize('call_count', [1, 2, 3])
@duet.sync
async def test_run_batch_async_runs_runs_sequentially(call_count):
a = cirq.LineQubit(0)
finished = []
circuit1 = cirq.Circuit(cirq.X(a) ** sympy.Symbol('t'), cirq.measure(a, key='m'))
circuit2 = cirq.Circuit(cirq.Y(a) ** sympy.Symbol('t'), cirq.measure(a, key='m'))
params1 = cirq.Points('t', [0.3, 0.7])
params2 = cirq.Points('t', [0.4, 0.6])

class AsyncSampler(cirq.Sampler):
CHUNK_SIZE = 1

async def run_sweep_async(self, _, params, __: int = 1):
if params == params1:
await duet.sleep(0.001)

finished.append(params)

sampler = AsyncSampler()
circuit_list = [circuit1, circuit2] * call_count
param_list = [params1, params2] * call_count
await sampler.run_batch_async(circuit_list, params_list=param_list)

assert finished == param_list


def test_sampler_simple_sample_expectation_values():
a = cirq.LineQubit(0)
sampler = cirq.Simulator()
Expand Down
3 changes: 3 additions & 0 deletions cirq-google/cirq_google/api/v2/program.proto
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
syntax = "proto3";

import "tunits/proto/tunits.proto";

package cirq.google.api.v2;

option java_package = "com.google.cirq.google.api.v2";
Expand Down Expand Up @@ -296,6 +298,7 @@ message ArgValue {
RepeatedInt64 int64_values = 5;
RepeatedDouble double_values = 6;
RepeatedString string_values = 7;
tunits.Value value_with_unit = 8;
}
}

Expand Down
191 changes: 96 additions & 95 deletions cirq-google/cirq_google/api/v2/program_pb2.py

Large diffs are not rendered by default.

11 changes: 8 additions & 3 deletions cirq-google/cirq_google/api/v2/program_pb2.pyi

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 14 additions & 9 deletions cirq-google/cirq_google/engine/engine_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,12 +262,14 @@ def delete(self) -> None:
"""Deletes the job and result, if any."""
self.context.client.delete_job(self.project_id, self.program_id, self.job_id)

async def results_async(self) -> Sequence[EngineResult]:
async def results_async(
self, limiter: duet.Limiter = duet.Limiter(None)
) -> Sequence[EngineResult]:
"""Returns the job results, blocking until the job is complete."""
import cirq_google.engine.engine as engine_base

if self._results is None:
result_response = await self._await_result_async()
result_response = await self._await_result_async(limiter)
result = result_response.result
result_type = result.type_url[len(engine_base.TYPE_PREFIX) :]
if (
Expand All @@ -286,7 +288,9 @@ async def results_async(self) -> Sequence[EngineResult]:
raise ValueError(f'invalid result proto version: {result_type}')
return self._results

async def _await_result_async(self) -> quantum.QuantumResult:
async def _await_result_async(
self, limiter: duet.Limiter = duet.Limiter(None)
) -> quantum.QuantumResult:
if self._job_result_future is not None:
response = await self._job_result_future
if isinstance(response, quantum.QuantumResult):
Expand All @@ -299,12 +303,13 @@ async def _await_result_async(self) -> quantum.QuantumResult:
'Internal error: The job response type is not recognized.'
) # pragma: no cover

async with duet.timeout_scope(self.context.timeout): # type: ignore[arg-type]
while True:
job = await self._refresh_job_async()
if job.execution_status.state in TERMINAL_STATES:
break
await duet.sleep(1)
async with limiter:
async with duet.timeout_scope(self.context.timeout): # type: ignore[arg-type]
while True:
job = await self._refresh_job_async()
if job.execution_status.state in TERMINAL_STATES:
break
await duet.sleep(1)
_raise_on_failure(job)
response = await self.context.client.get_job_results_async(
self.project_id, self.program_id, self.job_id
Expand Down
20 changes: 18 additions & 2 deletions cirq-google/cirq_google/engine/processor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import cirq
import duet
from cirq_google.engine.engine_job import EngineJob

if TYPE_CHECKING:
import cirq_google as cg
Expand Down Expand Up @@ -58,9 +59,14 @@ def __init__(
self._run_name = run_name
self._snapshot_id = snapshot_id
self._device_config_name = device_config_name
self._result_limiter = duet.Limiter(None)

async def run_sweep_async(
self, program: 'cirq.AbstractCircuit', params: cirq.Sweepable, repetitions: int = 1
self,
program: 'cirq.AbstractCircuit',
params: cirq.Sweepable,
repetitions: int = 1,
limiter: duet.Limiter = duet.Limiter(None),
) -> Sequence['cg.EngineResult']:
job = await self._processor.run_sweep_async(
program=program,
Expand All @@ -70,6 +76,10 @@ async def run_sweep_async(
snapshot_id=self._snapshot_id,
device_config_name=self._device_config_name,
)

if isinstance(job, EngineJob):
return await job.results_async(limiter)

return await job.results_async()

run_sweep = duet.sync(run_sweep_async)
Expand All @@ -79,10 +89,12 @@ async def run_batch_async(
programs: Sequence[cirq.AbstractCircuit],
params_list: Optional[Sequence[cirq.Sweepable]] = None,
repetitions: Union[int, Sequence[int]] = 1,
limiter: duet.Limiter = duet.Limiter(10),
) -> Sequence[Sequence['cg.EngineResult']]:
self._result_limiter = limiter
return cast(
Sequence[Sequence['cg.EngineResult']],
await super().run_batch_async(programs, params_list, repetitions),
await super().run_batch_async(programs, params_list, repetitions, self._result_limiter),
)

run_batch = duet.sync(run_batch_async)
Expand All @@ -102,3 +114,7 @@ def snapshot_id(self) -> str:
@property
def device_config_name(self) -> str:
return self._device_config_name

@property
def result_limiter(self) -> duet.Limiter:
return self._result_limiter
26 changes: 26 additions & 0 deletions cirq-google/cirq_google/engine/processor_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import cirq
import cirq_google as cg
from cirq_google.engine.abstract_processor import AbstractProcessor
from cirq_google.engine.engine_job import EngineJob


@pytest.mark.parametrize('circuit', [cirq.Circuit(), cirq.FrozenCircuit()])
Expand Down Expand Up @@ -169,6 +170,31 @@ def test_run_batch_differing_repetitions():
)


def test_run_batch_receives_results_using_limiter():
processor = mock.create_autospec(AbstractProcessor)
run_name = "RUN_NAME"
device_config_name = "DEVICE_CONFIG_NAME"
sampler = cg.ProcessorSampler(
processor=processor, run_name=run_name, device_config_name=device_config_name
)

job = mock.AsyncMock(EngineJob)

processor.run_sweep_async.return_value = job
a = cirq.LineQubit(0)
circuit1 = cirq.Circuit(cirq.X(a))
circuit2 = cirq.Circuit(cirq.Y(a))
params1 = [cirq.ParamResolver({'t': 1})]
params2 = [cirq.ParamResolver({'t': 2})]
circuits = [circuit1, circuit2]
params_list = [params1, params2]
repetitions = [1, 2]

sampler.run_batch(circuits, params_list, repetitions)

job.results_async.assert_called_with(sampler.result_limiter)


def test_processor_sampler_processor_property():
processor = mock.create_autospec(AbstractProcessor)
sampler = cg.ProcessorSampler(processor=processor)
Expand Down
11 changes: 9 additions & 2 deletions cirq-google/cirq_google/serialization/arg_func_langs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from cirq_google.api import v2
from cirq_google.ops import InternalGate
from cirq.qis import CliffordTableau
import tunits

SUPPORTED_FUNCTIONS_FOR_LANGUAGE: Dict[Optional[str], FrozenSet[str]] = {
'': frozenset(),
Expand All @@ -33,8 +34,10 @@
SUPPORTED_SYMPY_OPS = (sympy.Symbol, sympy.Add, sympy.Mul, sympy.Pow)

# Argument types for gates.
ARG_LIKE = Union[int, float, numbers.Real, Sequence[bool], str, sympy.Expr]
ARG_RETURN_LIKE = Union[float, int, str, List[bool], List[int], List[float], List[str], sympy.Expr]
ARG_LIKE = Union[int, float, numbers.Real, Sequence[bool], str, sympy.Expr, tunits.Value]
ARG_RETURN_LIKE = Union[
float, int, str, List[bool], List[int], List[float], List[str], sympy.Expr, tunits.Value
]
FLOAT_ARG_LIKE = Union[float, sympy.Expr]

# Types for comparing floats
Expand Down Expand Up @@ -182,6 +185,8 @@ def arg_to_proto(
)
field, types_tuple = numerical_fields[cur_index]
field.extend(types_tuple[0](x) for x in value)
elif isinstance(value, tunits.Value):
msg.arg_value.value_with_unit.MergeFrom(value.to_proto())
else:
_arg_func_to_proto(value, arg_function_language, msg)

Expand Down Expand Up @@ -329,6 +334,8 @@ def arg_from_proto(
return [float(v) for v in arg_value.double_values.values]
if which_val == 'string_values':
return [str(v) for v in arg_value.string_values.values]
if which_val == 'value_with_unit':
return tunits.Value.from_proto(arg_value.value_with_unit)
raise ValueError(f'Unrecognized value type: {which_val!r}')

if which == 'symbol':
Expand Down
Loading

0 comments on commit fcb2d2a

Please sign in to comment.