Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make circuit._prev_moment_available a public method. #4980

Merged
merged 8 commits into from
Feb 14, 2022
28 changes: 25 additions & 3 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1618,6 +1618,7 @@ class Circuit(AbstractCircuit):
AbstractCircuit):

* next_moment_operating_on
* earliest_available_moment
* prev_moment_operating_on
* next_moments_operating_on
* operation_at
Expand Down Expand Up @@ -1950,7 +1951,27 @@ def transform_qubits(
with _compat.block_overlapping_deprecation(re.escape(_DEVICE_DEP_MESSAGE)):
return Circuit(op_list, device=self._device if new_device is None else new_device)

def _prev_moment_available(self, op: 'cirq.Operation', end_moment_index: int) -> Optional[int]:
def earliest_available_moment(
self, op: 'cirq.Operation', *, end_moment_index: Optional[int] = None
) -> int:
"""Finds the index of the earliest (i.e. left most) moment which can accommodate `op`.

Note that, unlike `circuit.prev_moment_operating_on`, this method also takes care of
implicit dependencies between measurements and classically controlled operations (CCO)
that depend on the results of those measurements. Therefore, using this method, a CCO
`op` would not be allowed to move left past a measurement it depends upon.

Args:
op: Operation for which the earliest moment that can accommodate it needs to be found.
end_moment_index: The moment index just after the starting point of the reverse search.
Defaults to the length of the list of moments.

Returns:
Index of the earliest matching moment. Returns `end_moment_index` if no moment on left
is available.
"""
if end_moment_index is None:
end_moment_index = len(self.moments)
last_available = end_moment_index
k = end_moment_index
op_control_keys = protocols.control_keys(op)
Expand All @@ -1968,6 +1989,8 @@ def _prev_moment_available(self, op: 'cirq.Operation', end_moment_index: int) ->
):
return last_available
if self._can_add_op_at(k, op):
# Note: Remove the if condition after `self._device` is gone and move the method to
# `cirq.AbstractDevice`.
last_available = k
return last_available

Expand Down Expand Up @@ -2005,8 +2028,7 @@ def _pick_or_create_inserted_op_moment_index(

if strategy is InsertStrategy.EARLIEST:
if self._can_add_op_at(splitter_index, op):
p = self._prev_moment_available(op, splitter_index)
return p or 0
return self.earliest_available_moment(op, end_moment_index=splitter_index)

return self._pick_or_create_inserted_op_moment_index(
splitter_index, op, InsertStrategy.INLINE
Expand Down
22 changes: 22 additions & 0 deletions cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1402,6 +1402,28 @@ def test_prev_moment_operating_on_distance(circuit_cls):
c.prev_moment_operating_on([a], 6, max_distance=-1)


def test_earliest_available_moment():
q = cirq.LineQubit.range(3)
c = cirq.Circuit(
cirq.Moment(cirq.measure(q[0], key="m")),
cirq.Moment(cirq.X(q[1]).with_classical_controls("m")),
)
assert c.earliest_available_moment(cirq.Y(q[0])) == 1
assert c.earliest_available_moment(cirq.Y(q[1])) == 2
assert c.earliest_available_moment(cirq.Y(q[2])) == 0
assert c.earliest_available_moment(cirq.Y(q[2]).with_classical_controls("m")) == 1
assert (
c.earliest_available_moment(cirq.Y(q[2]).with_classical_controls("m"), end_moment_index=1)
== 1
)

# Returns `end_moment_index` by default without verifying if an operation already exists there.
assert (
c.earliest_available_moment(cirq.Y(q[1]).with_classical_controls("m"), end_moment_index=1)
== 1
)


@pytest.mark.parametrize('circuit_cls', [cirq.Circuit, cirq.FrozenCircuit])
def test_operation_at(circuit_cls):
a = cirq.NamedQubit('a')
Expand Down
3 changes: 1 addition & 2 deletions cirq-core/cirq/transformers/stratify.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,7 @@ def map_func(m: 'cirq.Moment', _) -> Sequence['cirq.Moment']:
batch_removals: List[Tuple[int, 'cirq.Operation']] = []
batch_inserts: List[Tuple[int, 'cirq.Operation']] = []
for op in moment:
prv_idx = stratified_circuit._prev_moment_available(op, curr_idx)
prv_idx = 0 if prv_idx is None else prv_idx
prv_idx = stratified_circuit.earliest_available_moment(op, end_moment_index=curr_idx)
prv_category = prv_idx % num_categories
should_move_to_next_batch = curr_category < prv_category
prv_idx += curr_category - prv_category + num_categories * should_move_to_next_batch
Expand Down