Skip to content

Commit

Permalink
Use duet for asynchrony, in particular for Collector (#4009)
Browse files Browse the repository at this point in the history
This switches async sampler methods to use the duet library (https://github.com/google/duet). Duet is "reentrant" in that it supports nested invocations of the event loop (duet.run). This differs from most other python async libraries which are not natively reentrant (without patching as in the case of nest-asyncio for asyncio). Re-entrancy makes it possible to refactor code incrementally to use duet whereas non-reentrant libraries typically require more global changes.

The asyncio-based work pool and testing utilities are deprecated and will be removed in a future release.

Review: @mpharrigan, @dstrain115
  • Loading branch information
maffoo authored Aug 25, 2021
1 parent 58cda7a commit de31ee9
Show file tree
Hide file tree
Showing 13 changed files with 155 additions and 129 deletions.
4 changes: 2 additions & 2 deletions cirq-core/cirq/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ def pytest_pyfunc_call(pyfuncitem):
if inspect.iscoroutinefunction(pyfuncitem._obj):
# coverage: ignore
raise ValueError(
f'{pyfuncitem._obj.__name__} is async but not '
f'decorated with "@pytest.mark.asyncio".'
f'{pyfuncitem._obj.__name__} is a bare async function. '
f'It should be decorated with "@duet.sync".'
)
3 changes: 2 additions & 1 deletion cirq-core/cirq/sim/simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Generic, Dict, Any, List, Sequence, Union
from unittest import mock

import duet
import numpy as np
import pytest

Expand Down Expand Up @@ -371,7 +372,7 @@ def text(self, to_print):
assert p.text_pretty == 'SimulationTrialResult(...)'


@pytest.mark.asyncio
@duet.sync
async def test_async_sample():
m = {'mock': np.array([[0], [1]])}

Expand Down
3 changes: 3 additions & 0 deletions cirq-core/cirq/testing/asynchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
import asyncio
from typing import Union, Awaitable, Coroutine

from cirq._compat import deprecated


@deprecated(deadline="v0.14", fix="Use duet instead.")
def asyncio_pending(
future: Union[Awaitable, asyncio.Future, Coroutine], timeout: float = 0.001
) -> Awaitable[bool]:
Expand Down
19 changes: 12 additions & 7 deletions cirq-core/cirq/testing/asynchronous_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,32 @@
import cirq


def _asyncio_pending(*args, **kw):
with cirq.testing.assert_deprecated("Use duet", deadline="v0.14"):
return cirq.testing.asyncio_pending(*args, **kw)


@pytest.mark.asyncio
async def test_asyncio_pending():
f = asyncio.Future()

assert await cirq.testing.asyncio_pending(f)
assert await _asyncio_pending(f)
f.set_result(5)
assert not await cirq.testing.asyncio_pending(f)
assert not await cirq.testing.asyncio_pending(f, timeout=100)
assert not await _asyncio_pending(f)
assert not await _asyncio_pending(f, timeout=100)

e = asyncio.Future()

assert await cirq.testing.asyncio_pending(e)
assert await _asyncio_pending(e)
e.set_exception(ValueError('test fail'))
assert not await cirq.testing.asyncio_pending(e)
assert not await cirq.testing.asyncio_pending(e, timeout=100)
assert not await _asyncio_pending(e)
assert not await _asyncio_pending(e, timeout=100)


@pytest.mark.asyncio
async def test_asyncio_pending_common_mistake_caught():
f = asyncio.Future()
pending = cirq.testing.asyncio_pending(f)
pending = _asyncio_pending(f)
with pytest.raises(RuntimeError, match='forgot the "await"'):
assert pending
assert await pending
109 changes: 57 additions & 52 deletions cirq-core/cirq/work/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Iterable, List, Optional, TYPE_CHECKING, Union
import abc
import asyncio
from typing import Any, Iterator, List, Optional, Tuple, TYPE_CHECKING, Union
from typing_extensions import Protocol

import duet
import numpy as np

from cirq import circuits, study, value
from cirq.work import work_pool

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -56,7 +56,12 @@ def __repr__(self) -> str:
)


CIRCUIT_SAMPLE_JOB_TREE = Union[CircuitSampleJob, Iterable[Any]]
class CircuitSampleJobTree(Protocol):
def __iter__(self) -> Iterator[Union[CircuitSampleJob, 'CircuitSampleJobTree']]:
pass


CIRCUIT_SAMPLE_JOB_TREE = Union[CircuitSampleJob, CircuitSampleJobTree]


class Collector(metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -126,15 +131,12 @@ def collect(
Returns:
The collector's result after all desired samples have been
collected.
See Also:
Python 3 documentation "Coroutines and Tasks"
https://docs.python.org/3/library/asyncio-task.html
"""
return asyncio.get_event_loop().run_until_complete(
self.collect_async(
sampler, concurrency=concurrency, max_total_samples=max_total_samples
)
return duet.run(
self.collect_async,
sampler,
concurrency=concurrency,
max_total_samples=max_total_samples,
)

async def collect_async(
Expand Down Expand Up @@ -164,53 +166,56 @@ async def collect_async(
Returns:
The collector's result after all desired samples have been
collected.
See Also:
Python 3 documentation "Coroutines and Tasks"
https://docs.python.org/3/library/asyncio-task.html
"""
pool = work_pool.CompletionOrderedAsyncWorkPool()
results: duet.AsyncCollector[Tuple[CircuitSampleJob, 'cirq.Result']] = duet.AsyncCollector()
job_error = None
running_jobs = 0
queued_jobs: List[CircuitSampleJob] = []
remaining_samples = np.infty if max_total_samples is None else max_total_samples

async def _start_async_job(job):
return job, await sampler.run_async(job.circuit, repetitions=job.repetitions)
async def run_job(job):
nonlocal job_error
try:
result = await sampler.run_async(job.circuit, repetitions=job.repetitions)
except Exception as error:
if not job_error:
results.error(error)
job_error = error
else:
if not job_error:
results.add((job, result))

# Keep dispatching and processing work.
while True:
# Fill up the work pool.
while remaining_samples > 0 and pool.num_uncollected < concurrency:
if not queued_jobs:
queued_jobs.extend(_flatten_jobs(self.next_job()))

# If no jobs were given, stop asking until something completes.
if not queued_jobs:
async with duet.new_scope() as scope:
while True:
# Fill up the work pool.
while remaining_samples > 0 and running_jobs < concurrency:
if not queued_jobs:
queued_jobs.extend(_flatten_jobs(self.next_job()))

# If no jobs were given, stop asking until something completes.
if not queued_jobs:
break

# Start new sampling job.
new_job = queued_jobs.pop(0)
remaining_samples -= new_job.repetitions
running_jobs += 1
scope.spawn(run_job, new_job)

# If no jobs are running, we're in a steady state. Halt.
if not running_jobs:
break

# Start new sampling job.
new_job = queued_jobs.pop(0)
remaining_samples -= new_job.repetitions
pool.include_work(_start_async_job(new_job))

# If no jobs were started or running, we're in a steady state. Halt.
if not pool.num_uncollected:
break

# Forward next job result from pool.
done_job, done_val = await pool.__anext__()
self.on_job_result(done_job, done_val)


def _flatten_jobs(given: Optional[CIRCUIT_SAMPLE_JOB_TREE]) -> List[CircuitSampleJob]:
out: List[CircuitSampleJob] = []
if given is not None:
_flatten_jobs_helper(given, out=out)
return out
# Get result from next completed job and call on_job_result.
job, result = await results.__anext__()
running_jobs -= 1
self.on_job_result(job, result)


def _flatten_jobs_helper(given: CIRCUIT_SAMPLE_JOB_TREE, *, out: List[CircuitSampleJob]) -> None:
if isinstance(given, CircuitSampleJob):
out.append(given)
elif given is not None:
for item in given:
_flatten_jobs_helper(item, out=out)
def _flatten_jobs(tree: Optional[CIRCUIT_SAMPLE_JOB_TREE]) -> Iterator[CircuitSampleJob]:
if isinstance(tree, CircuitSampleJob):
yield tree
elif tree is not None:
for item in tree:
yield from _flatten_jobs(item)
27 changes: 24 additions & 3 deletions cirq-core/cirq/work/collector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import duet
import pytest

import cirq
Expand All @@ -36,7 +38,7 @@ def test_circuit_sample_job_repr():
)


@pytest.mark.asyncio
@duet.sync
async def test_async_collect():
received = []

Expand All @@ -49,10 +51,10 @@ def next_job(self):
def on_job_result(self, job, result):
received.append(job.tag)

completion = TestCollector().collect_async(
result = await TestCollector().collect_async(
sampler=cirq.Simulator(), max_total_samples=100, concurrency=5
)
assert await completion is None
assert result is None
assert received == ['test'] * 10


Expand All @@ -72,6 +74,25 @@ def on_job_result(self, job, result):
assert received == ['test'] * 10


def test_failed_job():
class FailingSampler:
async def run_async(self, circuit, repetitions):
await duet.completed_future(None)
raise Exception('job failed!')

class TestCollector(cirq.Collector):
def next_job(self):
q = cirq.LineQubit(0)
circuit = cirq.Circuit(cirq.H(q), cirq.measure(q))
return cirq.CircuitSampleJob(circuit=circuit, repetitions=10, tag='test')

def on_job_result(self, job, result):
pass

with pytest.raises(Exception, match='job failed!'):
TestCollector().collect(sampler=FailingSampler(), max_total_samples=100, concurrency=5)


def test_collect_with_reaction():
events = [0]
sent = 0
Expand Down
14 changes: 7 additions & 7 deletions cirq-core/cirq/work/pauli_sum_collector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import duet

import cirq


@pytest.mark.asyncio
@duet.sync
async def test_pauli_string_sample_collector():
a, b = cirq.LineQubit.range(2)
p = cirq.PauliSumCollector(
Expand All @@ -28,22 +28,22 @@ async def test_pauli_string_sample_collector():
+ (1 - 0j),
samples_per_term=100,
)
completion = p.collect_async(sampler=cirq.Simulator())
assert await completion is None
result = await p.collect_async(sampler=cirq.Simulator())
assert result is None
energy = p.estimated_energy()
assert isinstance(energy, float) and energy == 12


@pytest.mark.asyncio
@duet.sync
async def test_pauli_string_sample_single():
a, b = cirq.LineQubit.range(2)
p = cirq.PauliSumCollector(
circuit=cirq.Circuit(cirq.H(a), cirq.CNOT(a, b), cirq.X(a), cirq.Z(b)),
observable=cirq.X(a) * cirq.X(b),
samples_per_term=100,
)
completion = p.collect_async(sampler=cirq.Simulator())
assert await completion is None
result = await p.collect_async(sampler=cirq.Simulator())
assert result is None
assert p.estimated_energy() == -1


Expand Down
5 changes: 3 additions & 2 deletions cirq-core/cirq/work/sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@
"""Tests for cirq.Sampler."""
import pytest

import duet
import numpy as np
import pandas as pd
import sympy

import cirq


@pytest.mark.asyncio
@duet.sync
async def test_sampler_async_fail():
class FailingSampler(cirq.Sampler):
def run_sweep(self, program, params, repetitions: int = 1):
Expand Down Expand Up @@ -132,7 +133,7 @@ def test_sampler_sample_inconsistent_keys():
)


@pytest.mark.asyncio
@duet.sync
async def test_sampler_async_not_run_inline():
ran = False

Expand Down
5 changes: 5 additions & 0 deletions cirq-core/cirq/work/work_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
import collections
from typing import Optional, Awaitable, Union

from cirq._compat import deprecated_class


@deprecated_class(
deadline='v0.14', fix='Use duet.AsyncCollector instead. See cirq.Collector for an example.'
)
class CompletionOrderedAsyncWorkPool:
"""Ensures given work is executing, and exposes it in completion order."""

Expand Down
Loading

0 comments on commit de31ee9

Please sign in to comment.