diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index ec89663d2ae..b6814bdb156 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -394,6 +394,7 @@ two_qubit_gate_product_tabulation, TwoQubitGateTabulation, TwoQubitGateTabulationResult, + toggle_tags, unroll_circuit_op, unroll_circuit_op_greedy_earliest, unroll_circuit_op_greedy_frontier, diff --git a/cirq-core/cirq/transformers/__init__.py b/cirq-core/cirq/transformers/__init__.py index ccf30f68e38..1ee9de1f165 100644 --- a/cirq-core/cirq/transformers/__init__.py +++ b/cirq-core/cirq/transformers/__init__.py @@ -76,6 +76,7 @@ map_operations_and_unroll, merge_moments, merge_operations, + toggle_tags, unroll_circuit_op, unroll_circuit_op_greedy_earliest, unroll_circuit_op_greedy_frontier, diff --git a/cirq-core/cirq/transformers/align_test.py b/cirq-core/cirq/transformers/align_test.py index 4e8a980a6a2..5203f6c2b30 100644 --- a/cirq-core/cirq/transformers/align_test.py +++ b/cirq-core/cirq/transformers/align_test.py @@ -71,6 +71,40 @@ def test_align_left_no_compile_context(): ) +def test_align_left_subset_of_operations(): + q1 = cirq.NamedQubit('q1') + q2 = cirq.NamedQubit('q2') + tag = "op_to_align" + c_orig = cirq.Circuit( + [ + cirq.Moment([cirq.Y(q1)]), + cirq.Moment([cirq.X(q2)]), + cirq.Moment([cirq.X(q1).with_tags(tag)]), + cirq.Moment([cirq.Y(q2)]), + cirq.measure(*[q1, q2], key='a'), + ] + ) + c_exp = cirq.Circuit( + [ + cirq.Moment([cirq.Y(q1)]), + cirq.Moment([cirq.X(q1).with_tags(tag), cirq.X(q2)]), + cirq.Moment(), + cirq.Moment([cirq.Y(q2)]), + cirq.measure(*[q1, q2], key='a'), + ] + ) + cirq.testing.assert_same_circuits( + cirq.toggle_tags( + cirq.align_left( + cirq.toggle_tags(c_orig, [tag]), + context=cirq.TransformerContext(tags_to_ignore=[tag]), + ), + [tag], + ), + c_exp, + ) + + def test_align_right_no_compile_context(): q1 = cirq.NamedQubit('q1') q2 = cirq.NamedQubit('q2') diff --git a/cirq-core/cirq/transformers/transformer_primitives.py b/cirq-core/cirq/transformers/transformer_primitives.py index 153d8180bef..6015072ccf3 100644 --- a/cirq-core/cirq/transformers/transformer_primitives.py +++ b/cirq-core/cirq/transformers/transformer_primitives.py @@ -426,3 +426,34 @@ def unroll_circuit_op_greedy_frontier( ) frontier = unrolled_circuit.insert_at_frontier(sub_circuit.all_operations(), idx, frontier) return _to_target_circuit_type(unrolled_circuit, circuit) + + +def toggle_tags(circuit: CIRCUIT_TYPE, tags: Sequence[Hashable], *, deep: bool = False): + """Toggles tags applied on each operation in the circuit, via `op.tags ^= tags` + + For every operations `op` in the input circuit, the tags on `op` are replaced by a symmetric + difference of `op.tags` and `tags` -- this is useful in scenarios where you mark a small subset + of operations with a specific tag and then toggle the set of marked operations s.t. every + marked operation is now unmarked and vice versa. + + Often used in transformer workflows to apply a transformer on a small subset of operations. + + Args: + circuit: Input circuit to apply the transformations on. The input circuit is not mutated. + tags: Sequence of tags s.t. `op.tags ^= tags` is done for every operation `op` in circuit. + deep: If true, tags will be recursively toggled for operations in circuits wrapped inside + any circuit operations contained within `circuit`. + + Returns: + Copy of transformed input circuit with operation sets marked with `tags` toggled. + """ + tags_to_xor = set(tags) + + def map_func(op: 'cirq.Operation', _) -> 'cirq.Operation': + return ( + op + if deep and isinstance(op, circuits.CircuitOperation) + else op.untagged.with_tags(*(set(op.tags) ^ tags_to_xor)) + ) + + return map_operations(circuit, map_func, deep=deep) diff --git a/cirq-core/cirq/transformers/transformer_primitives_test.py b/cirq-core/cirq/transformers/transformer_primitives_test.py index 6573425a8b1..6e56cc5f734 100644 --- a/cirq-core/cirq/transformers/transformer_primitives_test.py +++ b/cirq-core/cirq/transformers/transformer_primitives_test.py @@ -166,6 +166,26 @@ def test_map_operations_respects_tags_to_ignore(): ) +def test_apply_tag_to_inverted_op_set(): + q = cirq.LineQubit.range(2) + op = cirq.CNOT(*q) + tag = "tag_to_flip" + c_orig = cirq.Circuit(op, op.with_tags(tag), cirq.CircuitOperation(cirq.FrozenCircuit(op))) + # Toggle with deep = True. + c_toggled = cirq.Circuit( + op.with_tags(tag), op, cirq.CircuitOperation(cirq.FrozenCircuit(op.with_tags(tag))) + ) + cirq.testing.assert_same_circuits(cirq.toggle_tags(c_orig, [tag], deep=True), c_toggled) + cirq.testing.assert_same_circuits(cirq.toggle_tags(c_toggled, [tag], deep=True), c_orig) + + # Toggle with deep = False + c_toggled = cirq.Circuit( + op.with_tags(tag), op, cirq.CircuitOperation(cirq.FrozenCircuit(op)).with_tags(tag) + ) + cirq.testing.assert_same_circuits(cirq.toggle_tags(c_orig, [tag], deep=False), c_toggled) + cirq.testing.assert_same_circuits(cirq.toggle_tags(c_toggled, [tag], deep=False), c_orig) + + def test_unroll_circuit_op_and_variants(): q = cirq.LineQubit.range(2) c = cirq.Circuit(cirq.X(q[0]), cirq.CNOT(q[0], q[1]), cirq.X(q[0]))