From 25aa7d0c95783aafb1b3606ea93b2bcc38ee1dbe Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Fri, 11 Feb 2022 00:58:02 +0530 Subject: [PATCH] Rewrite `cirq.stratified_circuit` following new Transformer API and primitives. (#4944) * Rewrite stratified_circuit following new Transformer API and primitives. * Reorder arguments lists * Address Mike's Feedback --- cirq/__init__.py | 2 +- cirq/optimizers/__init__.py | 14 +- cirq/optimizers/stratify.py | 151 -------- cirq/transformers/__init__.py | 2 + cirq/transformers/stratify.py | 190 ++++++++++ .../stratify_test.py | 331 +++++++++++++----- 6 files changed, 445 insertions(+), 245 deletions(-) delete mode 100644 cirq/optimizers/stratify.py create mode 100644 cirq/transformers/stratify.py rename cirq/{optimizers => transformers}/stratify_test.py (53%) diff --git a/cirq/__init__.py b/cirq/__init__.py index 0672592530e..ec89663d2ae 100644 --- a/cirq/__init__.py +++ b/cirq/__init__.py @@ -349,7 +349,6 @@ MergeInteractions, MergeInteractionsToSqrtIswap, MergeSingleQubitGates, - stratified_circuit, SynchronizeTerminalMeasurements, ) @@ -382,6 +381,7 @@ single_qubit_matrix_to_phased_x_z, single_qubit_matrix_to_phxz, single_qubit_op_to_framed_phase_form, + stratified_circuit, synchronize_terminal_measurements, TRANSFORMER, TransformerContext, diff --git a/cirq/optimizers/__init__.py b/cirq/optimizers/__init__.py index 36faecb6037..7cf975e3e16 100644 --- a/cirq/optimizers/__init__.py +++ b/cirq/optimizers/__init__.py @@ -60,14 +60,12 @@ MergeSingleQubitGates, ) -from cirq.optimizers.stratify import ( - stratified_circuit, -) - from cirq.optimizers.synchronize_terminal_measurements import ( SynchronizeTerminalMeasurements, ) +from cirq.transformers.stratify import stratified_circuit + from cirq.transformers.analytical_decompositions import ( compute_cphase_exponents_for_fsim_decomposition, decompose_cphase_into_two_fsim, @@ -156,3 +154,11 @@ deadline="v0.16", create_attribute=True, ) + +_compat.deprecated_submodule( + new_module_name="cirq.transformers.stratify", + old_parent="cirq.optimizers", + old_child="stratify", + deadline="v0.16", + create_attribute=True, +) diff --git a/cirq/optimizers/stratify.py b/cirq/optimizers/stratify.py deleted file mode 100644 index e09aecafca4..00000000000 --- a/cirq/optimizers/stratify.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright 2020 The Cirq Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import itertools -from typing import TYPE_CHECKING, Type, Callable, Union, Iterable, Set - -from cirq import ops, circuits - -if TYPE_CHECKING: - import cirq - -# A function that decides based on an operation -# whether it belongs to a class or not -Classifier = Callable[['cirq.Operation'], bool] - -# Any of the possible operation categories that we can stratify on. -Category = Union[ - 'cirq.Gate', 'cirq.Operation', Type['cirq.Gate'], Type['cirq.Operation'], Classifier -] - - -def stratified_circuit( - circuit: 'cirq.Circuit', *, categories: Iterable[Category] -) -> 'cirq.Circuit': - """Repacks avoiding simultaneous operations with different classes. - - Sometimes, certain operations should not be done at the same time. For - example, the physical hardware may not be capable of doing certain - operations at the same time. Or it may have worse noise characteristics - when certain operations are done at the same time. In these cases, it - would be good to rearrange the circuit so that these operations always - occur in different moments. - - (As a secondary effect, this may make the circuit easier to read.) - - This methods takes a series of classifiers identifying categories of - operations and then ensures operations from each category only overlap - with operations from the same category. There is no guarantee that the - resulting circuit will be optimally packed under this constraint. - - Args: - circuit: The circuit whose operations should be re-arranged. - categories: A list of classifiers picking out certain operations. - There are several ways to specify a classifier. You can pass - in a gate instance (e.g. `cirq.X`), a gate type (e.g. - `cirq.XPowGate`), an operation instance (e.g. - `cirq.X(cirq.LineQubit(0))`), an operation type (e.g. - `cirq.GlobalPhaseOperation`), or an arbitrary operation - predicate (e.g. `lambda op: len(op.qubits) == 2`). - - Returns: - A copy of the original circuit, but with re-arranged operations. - """ - - # Normalize categories into classifier functions. - classifiers = [_category_to_classifier(category) for category in categories] - # Make the classifiers exhaustive by adding an "everything else" bucket. - and_the_rest = lambda op: all(not classifier(op) for classifier in classifiers) - classifiers_and_the_rest = [*classifiers, and_the_rest] - - # Try the algorithm with each permutation of the classifiers. - classifiers_permutations = list(itertools.permutations(classifiers_and_the_rest)) - reversed_circuit = circuit[::-1] - solutions = [] - for c in classifiers_permutations: - solutions.append(stratify_circuit(list(c), circuit)) - # Do the same thing, except this time in reverse. This helps for some - # circuits because it inserts operations at the end instead of at the - # beginning. - solutions.append(stratify_circuit(list(c), reversed_circuit)[::-1]) - - # Return the shortest circuit. - return min(solutions, key=lambda c: len(c)) - - -def stratify_circuit(classifiers: Iterable[Classifier], circuit: circuits.Circuit): - """Performs the stratification by iterating through the operations in the - circuit and using the given classifiers to align them. - - Args: - classifiers: A list of rules to align the circuit. Must be exhaustive, - i.e. all operations will be caught by one of the processors. - circuit: The circuit to break out into homogeneous moments. Will not be - edited. - - Returns: - The stratified circuit. - """ - solution = circuits.Circuit() - circuit_copy = circuit.copy() - while len(circuit_copy.all_qubits()) > 0: - for classifier in classifiers: - current_moment = circuits.Moment() - blocked_qubits: Set[ops.Qid] = set() - for moment_idx, moment in enumerate(circuit_copy.moments): - for op in moment.operations: - can_insert = classifier(op) - if not can_insert: - blocked_qubits.update(op.qubits) - else: - # Ensure that all the qubits for this operation are - # still available. - if not any(qubit in blocked_qubits for qubit in op.qubits): - # Add the operation to the current moment and - # remove it from the circuit. - current_moment = current_moment.with_operation(op) - blocked_qubits.update(op.qubits) - circuit_copy.batch_remove([(moment_idx, op)]) - - # Short-circuit: If all the qubits are blocked, go on to the - # next moment. - if blocked_qubits.issuperset(circuit_copy.all_qubits()): - break - - if len(current_moment) > 0: - solution.append(current_moment) - return solution - - -# No type for `category` because MyPy does not keep the return type when -# returning a callback. -def _category_to_classifier(category) -> Classifier: - """Normalizes the given category into a classifier function.""" - if isinstance(category, ops.Gate): - return lambda op: op.gate == category - if isinstance(category, ops.Operation): - return lambda op: op == category - elif isinstance(category, type) and issubclass(category, ops.Gate): - return lambda op: isinstance(op.gate, category) - elif isinstance(category, type) and issubclass(category, ops.Operation): - return lambda op: isinstance(op, category) - elif callable(category): - return lambda op: category(op) - else: - raise TypeError( - f'Unrecognized classifier type ' - f'{type(category)} ({category!r}).\n' - f'Expected a cirq.Gate, cirq.Operation, ' - f'Type[cirq.Gate], Type[cirq.Operation], ' - f'or Callable[[cirq.Operation], bool].' - ) diff --git a/cirq/transformers/__init__.py b/cirq/transformers/__init__.py index 0fa73cafc9f..ccf30f68e38 100644 --- a/cirq/transformers/__init__.py +++ b/cirq/transformers/__init__.py @@ -43,6 +43,8 @@ from cirq.transformers.align import align_left, align_right +from cirq.transformers.stratify import stratified_circuit + from cirq.transformers.expand_composite import expand_composite from cirq.transformers.eject_phased_paulis import eject_phased_paulis diff --git a/cirq/transformers/stratify.py b/cirq/transformers/stratify.py new file mode 100644 index 00000000000..ac0866f9a84 --- /dev/null +++ b/cirq/transformers/stratify.py @@ -0,0 +1,190 @@ +# Copyright 2020 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transformer pass to repack circuits avoiding simultaneous operations with different classes.""" + +import itertools +from typing import ( + TYPE_CHECKING, + Type, + Callable, + Optional, + Union, + Iterable, + Sequence, + List, + Tuple, +) + +from cirq import ops, circuits, _import +from cirq.transformers import transformer_api, transformer_primitives + +drop_empty_moments = _import.LazyLoader('drop_empty_moments', globals(), 'cirq.transformers') + +if TYPE_CHECKING: + import cirq + +# A function that decides based on an operation +# whether it belongs to a class or not +Classifier = Callable[['cirq.Operation'], bool] + +# Any of the possible operation categories that we can stratify on. +Category = Union[ + 'cirq.Gate', 'cirq.Operation', Type['cirq.Gate'], Type['cirq.Operation'], Classifier +] + + +@transformer_api.transformer +def stratified_circuit( + circuit: 'cirq.AbstractCircuit', + *, + context: Optional['cirq.TransformerContext'] = None, + categories: Iterable[Category] = (), +) -> 'cirq.Circuit': + """Repacks avoiding simultaneous operations with different classes. + + This transforms the given circuit to ensure that no operations of different categories are + found in the same moment. Makes no optimality guarantees. + Tagged Operations marked with any of `context.tags_to_ignore` will be treated as a separate + category will be left in their original moments without stratification. + + Args: + circuit: The circuit whose operations should be re-arranged. Will not be modified. + context: `cirq.TransformerContext` storing common configurable options for transformers. + categories: A list of classifiers picking out certain operations. There are several ways + to specify a classifier. You can pass in a gate instance (e.g. `cirq.X`), + a gate type (e.g. `cirq.XPowGate`), an operation instance + (e.g. `cirq.X(cirq.LineQubit(0))`), an operation type (e.g.`cirq.CircuitOperation`), + or an arbitrary operation predicate (e.g. `lambda op: len(op.qubits) == 2`). + + Returns: + A copy of the original circuit, but with re-arranged operations. + """ + + # Normalize categories into classifier functions. + classifiers = [_category_to_classifier(category) for category in categories] + # Make the classifiers exhaustive by adding an "everything else" bucket. + and_the_rest = lambda op: all(not classifier(op) for classifier in classifiers) + classifiers_and_the_rest = [*classifiers, and_the_rest] + + # Try the algorithm with each permutation of the classifiers. + classifiers_permutations = list(itertools.permutations(classifiers_and_the_rest)) + reversed_circuit = circuit[::-1] + solutions = [] + for c in classifiers_permutations: + solutions.append( + _stratify_circuit( + circuit, + classifiers=list(c), + context=context or transformer_api.TransformerContext(), + ) + ) + # Do the same thing, except this time in reverse. This helps for some + # circuits because it inserts operations at the end instead of at the + # beginning. + solutions.append( + _stratify_circuit( + reversed_circuit, + classifiers=list(c), + context=context or transformer_api.TransformerContext(), + )[::-1] + ) + + # Return the shortest circuit. + return min(solutions, key=lambda c: len(c)) + + +def _stratify_circuit( + circuit: circuits.AbstractCircuit, + *, + context: 'cirq.TransformerContext', + classifiers: Sequence[Classifier], +) -> 'cirq.Circuit': + """Performs the stratification by iterating through the operations in the + circuit and using the given classifiers to align them. + + Tagged Operations marked with any of `context.tags_to_ignore` are treated as separate + categories and left in their original moments without stratification. + + Args: + circuit: The circuit to break out into homogeneous moments. Will not be edited. + context: `cirq.TransformerContext` storing common configurable options for transformers. + classifiers: A list of rules to align the circuit. Must be exhaustive, i.e. all operations + will be caught by one of the processors. + + Returns: + The stratified circuit. + """ + num_categories = len(classifiers) + 1 + + def map_func(m: 'cirq.Moment', _) -> Sequence['cirq.Moment']: + stratified_ops: List[List['cirq.Operation']] = [[] for _ in range(num_categories)] + for op in m: + if set(op.tags) & set(context.tags_to_ignore): + stratified_ops[0].append(op) + continue + for i, classifier in enumerate(classifiers): + if classifier(op): + stratified_ops[i + 1].append(op) + break + return [circuits.Moment(op_list) for op_list in stratified_ops] + + stratified_circuit = transformer_primitives.map_moments(circuit, map_func).unfreeze(copy=False) + assert len(stratified_circuit) == len(circuit) * num_categories + + # Try to move operations to the left to reduce circuit depth, preserving stratification. + for curr_idx, moment in enumerate(stratified_circuit): + curr_category = curr_idx % num_categories + if curr_category == 0: + # Moment containing tagged operations to be ignored. + continue + batch_removals: List[Tuple[int, 'cirq.Operation']] = [] + batch_inserts: List[Tuple[int, 'cirq.Operation']] = [] + for op in moment: + prv_idx = stratified_circuit._prev_moment_available(op, curr_idx) + prv_idx = 0 if prv_idx is None else prv_idx + prv_category = prv_idx % num_categories + should_move_to_next_batch = curr_category < prv_category + prv_idx += curr_category - prv_category + num_categories * should_move_to_next_batch + assert prv_idx <= curr_idx and prv_idx % num_categories == curr_idx % num_categories + if prv_idx < curr_idx: + batch_inserts.append((prv_idx, op)) + batch_removals.append((curr_idx, op)) + stratified_circuit.batch_remove(batch_removals) + stratified_circuit.batch_insert_into(batch_inserts) + return drop_empty_moments.drop_empty_moments(stratified_circuit) + + +# No type for `category` because mypy does not keep the return type when +# returning a callback. +def _category_to_classifier(category) -> Classifier: + """Normalizes the given category into a classifier function.""" + if isinstance(category, ops.Gate): + return lambda op: op.gate == category + if isinstance(category, ops.Operation): + return lambda op: op == category + elif isinstance(category, type) and issubclass(category, ops.Gate): + return lambda op: isinstance(op.gate, category) + elif isinstance(category, type) and issubclass(category, ops.Operation): + return lambda op: isinstance(op, category) + elif callable(category): + return lambda op: category(op) + else: + raise TypeError( + f'Unrecognized classifier type ' + f'{type(category)} ({category!r}).\n' + f'Expected a cirq.Gate, cirq.Operation, ' + f'Type[cirq.Gate], Type[cirq.Operation], ' + f'or Callable[[cirq.Operation], bool].' + ) diff --git a/cirq/optimizers/stratify_test.py b/cirq/transformers/stratify_test.py similarity index 53% rename from cirq/optimizers/stratify_test.py rename to cirq/transformers/stratify_test.py index 6cc0726f2e0..fe727a8e277 100644 --- a/cirq/optimizers/stratify_test.py +++ b/cirq/transformers/stratify_test.py @@ -16,6 +16,11 @@ import cirq +def test_deprecated_submodule(): + with cirq.testing.assert_deprecated("Use cirq.transformers.stratify instead", deadline="v0.16"): + _ = cirq.optimizers.stratify.stratified_circuit + + def test_stratified_circuit_classifier_types(): a, b, c, d = cirq.LineQubit.range(4) @@ -36,18 +41,21 @@ def test_stratified_circuit_classifier_types(): cirq.X, ], ) - assert gate_result == cirq.Circuit( - cirq.Moment( - [ - cirq.X(a), - cirq.X(d), - ] - ), - cirq.Moment( - [ - cirq.Y(b), - cirq.X(c) ** 0.5, - ] + cirq.testing.assert_same_circuits( + gate_result, + cirq.Circuit( + cirq.Moment( + [ + cirq.X(a), + cirq.X(d), + ] + ), + cirq.Moment( + [ + cirq.Y(b), + cirq.X(c) ** 0.5, + ] + ), ), ) @@ -57,18 +65,21 @@ def test_stratified_circuit_classifier_types(): cirq.XPowGate, ], ) - assert gate_type_result == cirq.Circuit( - cirq.Moment( - [ - cirq.X(a), - cirq.X(c) ** 0.5, - cirq.X(d), - ] - ), - cirq.Moment( - [ - cirq.Y(b), - ] + cirq.testing.assert_same_circuits( + gate_type_result, + cirq.Circuit( + cirq.Moment( + [ + cirq.X(a), + cirq.X(c) ** 0.5, + cirq.X(d), + ] + ), + cirq.Moment( + [ + cirq.Y(b), + ] + ), ), ) @@ -78,18 +89,21 @@ def test_stratified_circuit_classifier_types(): cirq.X(a), ], ) - assert operation_result == cirq.Circuit( - cirq.Moment( - [ - cirq.X(a), - ] - ), - cirq.Moment( - [ - cirq.Y(b), - cirq.X(c) ** 0.5, - cirq.X(d), - ] + cirq.testing.assert_same_circuits( + operation_result, + cirq.Circuit( + cirq.Moment( + [ + cirq.X(a), + ] + ), + cirq.Moment( + [ + cirq.Y(b), + cirq.X(c) ** 0.5, + cirq.X(d), + ] + ), ), ) @@ -99,15 +113,18 @@ def test_stratified_circuit_classifier_types(): cirq.GateOperation, ], ) - assert operation_type_result == cirq.Circuit( - cirq.Moment( - [ - cirq.X(a), - cirq.Y(b), - cirq.X(c) ** 0.5, - cirq.X(d), - ] - ) + cirq.testing.assert_same_circuits( + operation_type_result, + cirq.Circuit( + cirq.Moment( + [ + cirq.X(a), + cirq.Y(b), + cirq.X(c) ** 0.5, + cirq.X(d), + ] + ) + ), ) predicate_result = cirq.stratified_circuit( @@ -116,18 +133,21 @@ def test_stratified_circuit_classifier_types(): lambda op: op.qubits == (b,), ], ) - assert predicate_result == cirq.Circuit( - cirq.Moment( - [ - cirq.Y(b), - ] - ), - cirq.Moment( - [ - cirq.X(a), - cirq.X(d), - cirq.X(c) ** 0.5, - ] + cirq.testing.assert_same_circuits( + predicate_result, + cirq.Circuit( + cirq.Moment( + [ + cirq.Y(b), + ] + ), + cirq.Moment( + [ + cirq.X(a), + cirq.X(d), + cirq.X(c) ** 0.5, + ] + ), ), ) @@ -171,34 +191,37 @@ def test_overlapping_categories(): ], ) - assert result == cirq.Circuit( - cirq.Moment( - [ - cirq.Y(b), - cirq.Z(c), - ] - ), - cirq.Moment( - [ - cirq.X(a), - ] - ), - cirq.Moment( - [ - cirq.CNOT(a, b), - cirq.CNOT(c, d), - ] - ), - cirq.Moment( - [ - cirq.Y(b), - cirq.Z(c), - ] - ), - cirq.Moment( - [ - cirq.X(a), - ] + cirq.testing.assert_same_circuits( + result, + cirq.Circuit( + cirq.Moment( + [ + cirq.Y(b), + cirq.Z(c), + ] + ), + cirq.Moment( + [ + cirq.X(a), + ] + ), + cirq.Moment( + [ + cirq.CNOT(a, b), + cirq.CNOT(c, d), + ] + ), + cirq.Moment( + [ + cirq.Y(b), + cirq.Z(c), + ] + ), + cirq.Moment( + [ + cirq.X(a), + ] + ), ), ) @@ -228,7 +251,9 @@ def test_greedy_merging(): cirq.Moment([cirq.X(q1), cirq.X(q3)]), cirq.Moment([cirq.SWAP(q1, q2), cirq.SWAP(q3, q4)]), ) - assert cirq.stratified_circuit(input_circuit, categories=[cirq.X]) == expected + cirq.testing.assert_same_circuits( + cirq.stratified_circuit(input_circuit, categories=[cirq.X]), expected + ) def test_greedy_merging_reverse(): @@ -245,7 +270,9 @@ def test_greedy_merging_reverse(): cirq.Moment([cirq.X(q1), cirq.X(q4)]), cirq.Moment([cirq.SWAP(q3, q4)]), ) - assert cirq.stratified_circuit(input_circuit, categories=[cirq.X]) == expected + cirq.testing.assert_same_circuits( + cirq.stratified_circuit(input_circuit, categories=[cirq.X]), expected + ) def test_complex_circuit(): @@ -263,7 +290,131 @@ def test_complex_circuit(): cirq.Moment([cirq.X(q1), cirq.X(q4)]), cirq.Moment([cirq.ISWAP(q1, q2)]), ) - assert cirq.stratified_circuit(input_circuit, categories=[cirq.X, cirq.Z]) == expected + cirq.testing.assert_same_circuits( + cirq.stratified_circuit(input_circuit, categories=[cirq.X, cirq.Z]), expected + ) + + +def test_no_categories_earliest_insert(): + q1, q2, q3, q4, q5 = cirq.LineQubit.range(5) + input_circuit = cirq.Circuit( + cirq.Moment([cirq.ISWAP(q2, q3)]), + cirq.Moment([cirq.X(q1), cirq.ISWAP(q4, q5)]), + cirq.Moment([cirq.ISWAP(q1, q2), cirq.X(q4)]), + ) + cirq.testing.assert_same_circuits( + cirq.Circuit(input_circuit.all_operations()), cirq.stratified_circuit(input_circuit) + ) + + +def test_stratify_respects_no_compile_operations(): + q1, q2, q3, q4, q5 = cirq.LineQubit.range(5) + input_circuit = cirq.Circuit( + cirq.Moment( + [ + cirq.X(q1).with_tags("nocompile"), + cirq.ISWAP(q2, q3).with_tags("nocompile"), + cirq.Z(q5), + ] + ), + cirq.Moment([cirq.X(q1), cirq.ISWAP(q4, q5)]), + cirq.Moment([cirq.ISWAP(q1, q2), cirq.X(q4)]), + ) + expected = cirq.Circuit( + [ + cirq.Moment( + cirq.TaggedOperation(cirq.X(cirq.LineQubit(0)), 'nocompile'), + cirq.TaggedOperation(cirq.ISWAP(cirq.LineQubit(1), cirq.LineQubit(2)), 'nocompile'), + ), + cirq.Moment( + cirq.X(cirq.LineQubit(0)), + ), + cirq.Moment( + cirq.Z(cirq.LineQubit(4)), + ), + cirq.Moment( + cirq.ISWAP(cirq.LineQubit(3), cirq.LineQubit(4)), + cirq.ISWAP(cirq.LineQubit(0), cirq.LineQubit(1)), + ), + cirq.Moment( + cirq.X(cirq.LineQubit(3)), + ), + ] + ) + cirq.testing.assert_has_diagram( + input_circuit, + ''' +0: ───X['nocompile']───────X───────iSwap─── + │ +1: ───iSwap['nocompile']───────────iSwap─── + │ +2: ───iSwap──────────────────────────────── + +3: ────────────────────────iSwap───X─────── + │ +4: ───Z────────────────────iSwap─────────── +''', + ) + cirq.testing.assert_has_diagram( + expected, + ''' +0: ───X['nocompile']───────X───────iSwap─────── + │ +1: ───iSwap['nocompile']───────────iSwap─────── + │ +2: ───iSwap──────────────────────────────────── + +3: ────────────────────────────────iSwap───X─── + │ +4: ────────────────────────────Z───iSwap─────── +''', + ) + cirq.testing.assert_same_circuits( + cirq.stratified_circuit( + input_circuit, + categories=[cirq.X, cirq.Z], + context=cirq.TransformerContext(tags_to_ignore=("nocompile",)), + ), + expected, + ) + + +def test_does_not_move_ccos_behind_measurement(): + q = cirq.LineQubit.range(3) + c_orig = cirq.Circuit( + cirq.measure(q[0], key='m'), + cirq.X(q[1]).with_classical_controls('m'), + cirq.Moment(cirq.X.on_each(q[1], q[2])), + ) + cirq.testing.assert_has_diagram( + c_orig, + ''' +0: ───M─────────── + ║ +1: ───╫───X───X─── + ║ ║ +2: ───╫───╫───X─── + ║ ║ +m: ═══@═══^═══════ +''', + ) + c_out = cirq.stratified_circuit( + c_orig, categories=[cirq.GateOperation, cirq.ClassicallyControlledOperation] + ) + cirq.testing.assert_has_diagram( + c_out, + ''' + ┌──┐ +0: ────M───────────── + ║ +1: ────╫─────X───X─── + ║ ║ +2: ────╫X────╫─────── + ║ ║ +m: ════@═════^═══════ + └──┘ +''', + ) def test_heterogeneous_circuit(): @@ -293,7 +444,9 @@ def test_heterogeneous_circuit(): ), ) - assert cirq.stratified_circuit(input_circuit, categories=[cirq.X, cirq.Z]) == expected + cirq.testing.assert_same_circuits( + cirq.stratified_circuit(input_circuit, categories=[cirq.X, cirq.Z]), expected + ) def test_surface_code_cycle_stratifies_without_growing():