Skip to content

Commit

Permalink
Use duet for async function, in particular for Collector
Browse files Browse the repository at this point in the history
  • Loading branch information
maffoo committed Apr 8, 2021
1 parent 259dd91 commit 7c90792
Show file tree
Hide file tree
Showing 12 changed files with 98 additions and 458 deletions.
4 changes: 2 additions & 2 deletions cirq/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ def pytest_pyfunc_call(pyfuncitem):
if inspect.iscoroutinefunction(pyfuncitem._obj):
# coverage: ignore
raise ValueError(
f'{pyfuncitem._obj.__name__} is async but not '
f'decorated with "@pytest.mark.asyncio".'
f'{pyfuncitem._obj.__name__} is a bare async function. '
f'Should be decorated with "@duet.sync".'
)
4 changes: 3 additions & 1 deletion cirq/sim/simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import abc
from typing import Generic, Dict, Any
from unittest import mock

import duet
import numpy as np
import pytest

Expand Down Expand Up @@ -353,7 +355,7 @@ def text(self, to_print):
assert p.text_pretty == 'SimulationTrialResult(...)'


@pytest.mark.asyncio
@duet.sync
async def test_async_sample():
m = {'mock': np.array([[0], [1]])}

Expand Down
4 changes: 0 additions & 4 deletions cirq/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@

"""Utilities for testing code."""

from cirq.testing.asynchronous import (
asyncio_pending,
)

from cirq.testing.circuit_compare import (
assert_circuits_with_terminal_measurements_are_equivalent,
assert_has_consistent_apply_unitary,
Expand Down
73 changes: 0 additions & 73 deletions cirq/testing/asynchronous.py

This file was deleted.

45 changes: 0 additions & 45 deletions cirq/testing/asynchronous_test.py

This file was deleted.

109 changes: 57 additions & 52 deletions cirq/work/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Iterable, List, Optional, TYPE_CHECKING, Union
import abc
import asyncio
from typing import Any, Iterable, Iterator, List, Optional, TYPE_CHECKING, Union
from typing_extensions import Protocol

import duet
import numpy as np

from cirq import circuits, study, value
from cirq.work import work_pool

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -55,7 +55,12 @@ def __repr__(self) -> str:
)


CIRCUIT_SAMPLE_JOB_TREE = Union[CircuitSampleJob, Iterable[Any]]
class CircuitSampleJobTree(Protocol):
def __iter__(self) -> Iterable[Union[CircuitSampleJob, 'CircuitSampleJobTree']]:
pass


CIRCUIT_SAMPLE_JOB_TREE = Union[CircuitSampleJob, CircuitSampleJobTree]


class Collector(metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -125,15 +130,12 @@ def collect(
Returns:
The collector's result after all desired samples have been
collected.
See Also:
Python 3 documentation "Coroutines and Tasks"
https://docs.python.org/3/library/asyncio-task.html
"""
return asyncio.get_event_loop().run_until_complete(
self.collect_async(
sampler, concurrency=concurrency, max_total_samples=max_total_samples
)
return duet.run(
self.collect_async,
sampler,
concurrency=concurrency,
max_total_samples=max_total_samples,
)

async def collect_async(
Expand Down Expand Up @@ -163,53 +165,56 @@ async def collect_async(
Returns:
The collector's result after all desired samples have been
collected.
See Also:
Python 3 documentation "Coroutines and Tasks"
https://docs.python.org/3/library/asyncio-task.html
"""
pool = work_pool.CompletionOrderedAsyncWorkPool()
results = duet.AsyncCollector()
job_error = None
running_jobs = 0
queued_jobs: List[CircuitSampleJob] = []
remaining_samples = np.infty if max_total_samples is None else max_total_samples

async def _start_async_job(job):
return job, await sampler.run_async(job.circuit, repetitions=job.repetitions)
async def run_job(job):
nonlocal job_error
try:
result = await sampler.run_async(job.circuit, repetitions=job.repetitions)
except Exception as error:
if not job_error:
results.error(error)
job_error = error
else:
if not job_error:
results.add((job, result))

# Keep dispatching and processing work.
while True:
# Fill up the work pool.
while remaining_samples > 0 and pool.num_uncollected < concurrency:
if not queued_jobs:
queued_jobs.extend(_flatten_jobs(self.next_job()))

# If no jobs were given, stop asking until something completes.
if not queued_jobs:
async with duet.new_scope() as scope:
while True:
# Fill up the work pool.
while remaining_samples > 0 and running_jobs < concurrency:
if not queued_jobs:
queued_jobs.extend(_flatten_jobs(self.next_job()))

# If no jobs were given, stop asking until something completes.
if not queued_jobs:
break

# Start new sampling job.
new_job = queued_jobs.pop(0)
remaining_samples -= new_job.repetitions
running_jobs += 1
scope.spawn(run_job, new_job)

# If no jobs are running, we're in a steady state. Halt.
if not running_jobs:
break

# Start new sampling job.
new_job = queued_jobs.pop(0)
remaining_samples -= new_job.repetitions
pool.include_work(_start_async_job(new_job))

# If no jobs were started or running, we're in a steady state. Halt.
if not pool.num_uncollected:
break

# Forward next job result from pool.
done_job, done_val = await pool.__anext__()
self.on_job_result(done_job, done_val)


def _flatten_jobs(given: Optional[CIRCUIT_SAMPLE_JOB_TREE]) -> List[CircuitSampleJob]:
out: List[CircuitSampleJob] = []
if given is not None:
_flatten_jobs_helper(given, out=out)
return out
# Get result from next completed job and call on_job_result.
done_job, done_val = await results.__anext__()
running_jobs -= 1
self.on_job_result(done_job, done_val)


def _flatten_jobs_helper(given: CIRCUIT_SAMPLE_JOB_TREE, *, out: List[CircuitSampleJob]) -> None:
if isinstance(given, CircuitSampleJob):
out.append(given)
elif given is not None:
for item in given:
_flatten_jobs_helper(item, out=out)
def _flatten_jobs(tree: Optional[CIRCUIT_SAMPLE_JOB_TREE]) -> Iterator[CircuitSampleJob]:
if isinstance(tree, CircuitSampleJob):
yield tree
elif tree is not None:
for item in tree:
yield from _flatten_jobs(item)
28 changes: 25 additions & 3 deletions cirq/work/collector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import duet
import pytest

import cirq
Expand All @@ -36,7 +38,7 @@ def test_circuit_sample_job_repr():
)


@pytest.mark.asyncio
@duet.sync
async def test_async_collect():
received = []

Expand All @@ -49,10 +51,10 @@ def next_job(self):
def on_job_result(self, job, result):
received.append(job.tag)

completion = TestCollector().collect_async(
result = await TestCollector().collect_async(
sampler=cirq.Simulator(), max_total_samples=100, concurrency=5
)
assert await completion is None
assert result is None
assert received == ['test'] * 10


Expand All @@ -72,6 +74,26 @@ def on_job_result(self, job, result):
assert received == ['test'] * 10


def test_failed_job():

class FailingSampler:
async def run_async(self, circuit, repetitions):
await duet.completed_future(None)
raise Exception('job failed!')

class TestCollector(cirq.Collector):
def next_job(self):
q = cirq.LineQubit(0)
circuit = cirq.Circuit(cirq.H(q), cirq.measure(q))
return cirq.CircuitSampleJob(circuit=circuit, repetitions=10, tag='test')

def on_job_result(self, job, result):
pass

with pytest.raises(Exception, match='job failed!'):
TestCollector().collect(sampler=FailingSampler(), max_total_samples=100, concurrency=5)


def test_collect_with_reaction():
events = [0]
sent = 0
Expand Down
Loading

0 comments on commit 7c90792

Please sign in to comment.