diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index f1dd45369eb..3415337a49f 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -1618,6 +1618,7 @@ class Circuit(AbstractCircuit): AbstractCircuit): * next_moment_operating_on + * earliest_available_moment * prev_moment_operating_on * next_moments_operating_on * operation_at @@ -1950,7 +1951,27 @@ def transform_qubits( with _compat.block_overlapping_deprecation(re.escape(_DEVICE_DEP_MESSAGE)): return Circuit(op_list, device=self._device if new_device is None else new_device) - def _prev_moment_available(self, op: 'cirq.Operation', end_moment_index: int) -> Optional[int]: + def earliest_available_moment( + self, op: 'cirq.Operation', *, end_moment_index: Optional[int] = None + ) -> int: + """Finds the index of the earliest (i.e. left most) moment which can accommodate `op`. + + Note that, unlike `circuit.prev_moment_operating_on`, this method also takes care of + implicit dependencies between measurements and classically controlled operations (CCO) + that depend on the results of those measurements. Therefore, using this method, a CCO + `op` would not be allowed to move left past a measurement it depends upon. + + Args: + op: Operation for which the earliest moment that can accommodate it needs to be found. + end_moment_index: The moment index just after the starting point of the reverse search. + Defaults to the length of the list of moments. + + Returns: + Index of the earliest matching moment. Returns `end_moment_index` if no moment on left + is available. + """ + if end_moment_index is None: + end_moment_index = len(self.moments) last_available = end_moment_index k = end_moment_index op_control_keys = protocols.control_keys(op) @@ -1968,6 +1989,8 @@ def _prev_moment_available(self, op: 'cirq.Operation', end_moment_index: int) -> ): return last_available if self._can_add_op_at(k, op): + # Note: Remove the if condition after `self._device` is gone and move the method to + # `cirq.AbstractDevice`. last_available = k return last_available @@ -2005,8 +2028,7 @@ def _pick_or_create_inserted_op_moment_index( if strategy is InsertStrategy.EARLIEST: if self._can_add_op_at(splitter_index, op): - p = self._prev_moment_available(op, splitter_index) - return p or 0 + return self.earliest_available_moment(op, end_moment_index=splitter_index) return self._pick_or_create_inserted_op_moment_index( splitter_index, op, InsertStrategy.INLINE diff --git a/cirq-core/cirq/circuits/circuit_test.py b/cirq-core/cirq/circuits/circuit_test.py index 76e5262dd3a..aedfa273d51 100644 --- a/cirq-core/cirq/circuits/circuit_test.py +++ b/cirq-core/cirq/circuits/circuit_test.py @@ -1402,6 +1402,28 @@ def test_prev_moment_operating_on_distance(circuit_cls): c.prev_moment_operating_on([a], 6, max_distance=-1) +def test_earliest_available_moment(): + q = cirq.LineQubit.range(3) + c = cirq.Circuit( + cirq.Moment(cirq.measure(q[0], key="m")), + cirq.Moment(cirq.X(q[1]).with_classical_controls("m")), + ) + assert c.earliest_available_moment(cirq.Y(q[0])) == 1 + assert c.earliest_available_moment(cirq.Y(q[1])) == 2 + assert c.earliest_available_moment(cirq.Y(q[2])) == 0 + assert c.earliest_available_moment(cirq.Y(q[2]).with_classical_controls("m")) == 1 + assert ( + c.earliest_available_moment(cirq.Y(q[2]).with_classical_controls("m"), end_moment_index=1) + == 1 + ) + + # Returns `end_moment_index` by default without verifying if an operation already exists there. + assert ( + c.earliest_available_moment(cirq.Y(q[1]).with_classical_controls("m"), end_moment_index=1) + == 1 + ) + + @pytest.mark.parametrize('circuit_cls', [cirq.Circuit, cirq.FrozenCircuit]) def test_operation_at(circuit_cls): a = cirq.NamedQubit('a') diff --git a/cirq-core/cirq/transformers/stratify.py b/cirq-core/cirq/transformers/stratify.py index ac0866f9a84..dad3e3bdaee 100644 --- a/cirq-core/cirq/transformers/stratify.py +++ b/cirq-core/cirq/transformers/stratify.py @@ -152,8 +152,7 @@ def map_func(m: 'cirq.Moment', _) -> Sequence['cirq.Moment']: batch_removals: List[Tuple[int, 'cirq.Operation']] = [] batch_inserts: List[Tuple[int, 'cirq.Operation']] = [] for op in moment: - prv_idx = stratified_circuit._prev_moment_available(op, curr_idx) - prv_idx = 0 if prv_idx is None else prv_idx + prv_idx = stratified_circuit.earliest_available_moment(op, end_moment_index=curr_idx) prv_category = prv_idx % num_categories should_move_to_next_batch = curr_category < prv_category prv_idx += curr_category - prv_category + num_categories * should_move_to_next_batch