From d84e9c777a5f951e4bc6876782c5889bc785b5aa Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Wed, 9 Feb 2022 20:35:26 -0800 Subject: [PATCH 1/2] Add helper to apply transformers on specific subsets of operations in a circuit --- cirq-core/cirq/__init__.py | 1 + cirq-core/cirq/transformers/__init__.py | 1 + cirq-core/cirq/transformers/align_test.py | 34 +++++++++++++++++++ .../transformers/transformer_primitives.py | 15 ++++++++ .../transformer_primitives_test.py | 16 +++++++++ 5 files changed, 67 insertions(+) diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 0672592530e..9974b800ef8 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -397,6 +397,7 @@ unroll_circuit_op, unroll_circuit_op_greedy_earliest, unroll_circuit_op_greedy_frontier, + xor_ops_with_tags, ) from cirq.qis import ( diff --git a/cirq-core/cirq/transformers/__init__.py b/cirq-core/cirq/transformers/__init__.py index 0fa73cafc9f..c8f0a04195f 100644 --- a/cirq-core/cirq/transformers/__init__.py +++ b/cirq-core/cirq/transformers/__init__.py @@ -77,4 +77,5 @@ unroll_circuit_op, unroll_circuit_op_greedy_earliest, unroll_circuit_op_greedy_frontier, + xor_ops_with_tags, ) diff --git a/cirq-core/cirq/transformers/align_test.py b/cirq-core/cirq/transformers/align_test.py index 4e8a980a6a2..1ff3d630365 100644 --- a/cirq-core/cirq/transformers/align_test.py +++ b/cirq-core/cirq/transformers/align_test.py @@ -71,6 +71,40 @@ def test_align_left_no_compile_context(): ) +def test_align_left_subset_of_operations(): + q1 = cirq.NamedQubit('q1') + q2 = cirq.NamedQubit('q2') + tag = "op_to_align" + c_orig = cirq.Circuit( + [ + cirq.Moment([cirq.Y(q1)]), + cirq.Moment([cirq.X(q2)]), + cirq.Moment([cirq.X(q1).with_tags(tag)]), + cirq.Moment([cirq.Y(q2)]), + cirq.measure(*[q1, q2], key='a'), + ] + ) + c_exp = cirq.Circuit( + [ + cirq.Moment([cirq.Y(q1)]), + cirq.Moment([cirq.X(q1).with_tags(tag), cirq.X(q2)]), + cirq.Moment(), + cirq.Moment([cirq.Y(q2)]), + cirq.measure(*[q1, q2], key='a'), + ] + ) + cirq.testing.assert_same_circuits( + cirq.xor_ops_with_tags( + cirq.align_left( + cirq.xor_ops_with_tags(c_orig, [tag]), + context=cirq.TransformerContext(tags_to_ignore=[tag]), + ), + [tag], + ), + c_exp, + ) + + def test_align_right_no_compile_context(): q1 = cirq.NamedQubit('q1') q2 = cirq.NamedQubit('q2') diff --git a/cirq-core/cirq/transformers/transformer_primitives.py b/cirq-core/cirq/transformers/transformer_primitives.py index 153d8180bef..7eb32da87b6 100644 --- a/cirq-core/cirq/transformers/transformer_primitives.py +++ b/cirq-core/cirq/transformers/transformer_primitives.py @@ -426,3 +426,18 @@ def unroll_circuit_op_greedy_frontier( ) frontier = unrolled_circuit.insert_at_frontier(sub_circuit.all_operations(), idx, frontier) return _to_target_circuit_type(unrolled_circuit, circuit) + + +def xor_ops_with_tags(circuit: CIRCUIT_TYPE, tags: Sequence[Hashable], *, deep: bool = False): + tags_to_xor = set(tags) + + def map_func(op: 'cirq.Operation', _) -> 'cirq.Operation': + op_tags = set(op.tags) + new_tags = (op_tags - tags_to_xor) | (tags_to_xor - op_tags) + return ( + op + if deep and isinstance(op, circuits.CircuitOperation) + else op.untagged.with_tags(*new_tags) + ) + + return map_operations(circuit, map_func, deep=deep) diff --git a/cirq-core/cirq/transformers/transformer_primitives_test.py b/cirq-core/cirq/transformers/transformer_primitives_test.py index 6573425a8b1..9c955457b2e 100644 --- a/cirq-core/cirq/transformers/transformer_primitives_test.py +++ b/cirq-core/cirq/transformers/transformer_primitives_test.py @@ -166,6 +166,22 @@ def test_map_operations_respects_tags_to_ignore(): ) +def test_apply_tag_to_inverted_op_set(): + q = cirq.LineQubit.range(2) + op = cirq.CNOT(*q) + tag = "tag_to_flip" + c_orig = cirq.Circuit(op, op.with_tags(tag), cirq.CircuitOperation(cirq.FrozenCircuit(op))) + c_flipped_with_deep = cirq.Circuit( + op.with_tags(tag), op, cirq.CircuitOperation(cirq.FrozenCircuit(op.with_tags(tag))) + ) + c_flipped_without_deep = cirq.Circuit( + op.with_tags(tag), op, cirq.CircuitOperation(cirq.FrozenCircuit(op)).with_tags(tag) + ) + for c_flip, deep in zip([c_flipped_with_deep, c_flipped_without_deep], [True, False]): + cirq.testing.assert_same_circuits(cirq.xor_ops_with_tags(c_orig, [tag], deep=deep), c_flip) + cirq.testing.assert_same_circuits(cirq.xor_ops_with_tags(c_flip, [tag], deep=deep), c_orig) + + def test_unroll_circuit_op_and_variants(): q = cirq.LineQubit.range(2) c = cirq.Circuit(cirq.X(q[0]), cirq.CNOT(q[0], q[1]), cirq.X(q[0])) From 64297ae020db4e5e7f5ac34c8d624ac51390b714 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Thu, 10 Feb 2022 12:12:26 -0800 Subject: [PATCH 2/2] Rename to toggle_tags and address feedback --- cirq-core/cirq/__init__.py | 2 +- cirq-core/cirq/transformers/__init__.py | 2 +- cirq-core/cirq/transformers/align_test.py | 4 ++-- .../transformers/transformer_primitives.py | 24 +++++++++++++++---- .../transformer_primitives_test.py | 14 +++++++---- 5 files changed, 33 insertions(+), 13 deletions(-) diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index a0032029b28..b6814bdb156 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -394,10 +394,10 @@ two_qubit_gate_product_tabulation, TwoQubitGateTabulation, TwoQubitGateTabulationResult, + toggle_tags, unroll_circuit_op, unroll_circuit_op_greedy_earliest, unroll_circuit_op_greedy_frontier, - xor_ops_with_tags, ) from cirq.qis import ( diff --git a/cirq-core/cirq/transformers/__init__.py b/cirq-core/cirq/transformers/__init__.py index 99b7643b434..1ee9de1f165 100644 --- a/cirq-core/cirq/transformers/__init__.py +++ b/cirq-core/cirq/transformers/__init__.py @@ -76,8 +76,8 @@ map_operations_and_unroll, merge_moments, merge_operations, + toggle_tags, unroll_circuit_op, unroll_circuit_op_greedy_earliest, unroll_circuit_op_greedy_frontier, - xor_ops_with_tags, ) diff --git a/cirq-core/cirq/transformers/align_test.py b/cirq-core/cirq/transformers/align_test.py index 1ff3d630365..5203f6c2b30 100644 --- a/cirq-core/cirq/transformers/align_test.py +++ b/cirq-core/cirq/transformers/align_test.py @@ -94,9 +94,9 @@ def test_align_left_subset_of_operations(): ] ) cirq.testing.assert_same_circuits( - cirq.xor_ops_with_tags( + cirq.toggle_tags( cirq.align_left( - cirq.xor_ops_with_tags(c_orig, [tag]), + cirq.toggle_tags(c_orig, [tag]), context=cirq.TransformerContext(tags_to_ignore=[tag]), ), [tag], diff --git a/cirq-core/cirq/transformers/transformer_primitives.py b/cirq-core/cirq/transformers/transformer_primitives.py index 7eb32da87b6..6015072ccf3 100644 --- a/cirq-core/cirq/transformers/transformer_primitives.py +++ b/cirq-core/cirq/transformers/transformer_primitives.py @@ -428,16 +428,32 @@ def unroll_circuit_op_greedy_frontier( return _to_target_circuit_type(unrolled_circuit, circuit) -def xor_ops_with_tags(circuit: CIRCUIT_TYPE, tags: Sequence[Hashable], *, deep: bool = False): +def toggle_tags(circuit: CIRCUIT_TYPE, tags: Sequence[Hashable], *, deep: bool = False): + """Toggles tags applied on each operation in the circuit, via `op.tags ^= tags` + + For every operations `op` in the input circuit, the tags on `op` are replaced by a symmetric + difference of `op.tags` and `tags` -- this is useful in scenarios where you mark a small subset + of operations with a specific tag and then toggle the set of marked operations s.t. every + marked operation is now unmarked and vice versa. + + Often used in transformer workflows to apply a transformer on a small subset of operations. + + Args: + circuit: Input circuit to apply the transformations on. The input circuit is not mutated. + tags: Sequence of tags s.t. `op.tags ^= tags` is done for every operation `op` in circuit. + deep: If true, tags will be recursively toggled for operations in circuits wrapped inside + any circuit operations contained within `circuit`. + + Returns: + Copy of transformed input circuit with operation sets marked with `tags` toggled. + """ tags_to_xor = set(tags) def map_func(op: 'cirq.Operation', _) -> 'cirq.Operation': - op_tags = set(op.tags) - new_tags = (op_tags - tags_to_xor) | (tags_to_xor - op_tags) return ( op if deep and isinstance(op, circuits.CircuitOperation) - else op.untagged.with_tags(*new_tags) + else op.untagged.with_tags(*(set(op.tags) ^ tags_to_xor)) ) return map_operations(circuit, map_func, deep=deep) diff --git a/cirq-core/cirq/transformers/transformer_primitives_test.py b/cirq-core/cirq/transformers/transformer_primitives_test.py index 9c955457b2e..6e56cc5f734 100644 --- a/cirq-core/cirq/transformers/transformer_primitives_test.py +++ b/cirq-core/cirq/transformers/transformer_primitives_test.py @@ -171,15 +171,19 @@ def test_apply_tag_to_inverted_op_set(): op = cirq.CNOT(*q) tag = "tag_to_flip" c_orig = cirq.Circuit(op, op.with_tags(tag), cirq.CircuitOperation(cirq.FrozenCircuit(op))) - c_flipped_with_deep = cirq.Circuit( + # Toggle with deep = True. + c_toggled = cirq.Circuit( op.with_tags(tag), op, cirq.CircuitOperation(cirq.FrozenCircuit(op.with_tags(tag))) ) - c_flipped_without_deep = cirq.Circuit( + cirq.testing.assert_same_circuits(cirq.toggle_tags(c_orig, [tag], deep=True), c_toggled) + cirq.testing.assert_same_circuits(cirq.toggle_tags(c_toggled, [tag], deep=True), c_orig) + + # Toggle with deep = False + c_toggled = cirq.Circuit( op.with_tags(tag), op, cirq.CircuitOperation(cirq.FrozenCircuit(op)).with_tags(tag) ) - for c_flip, deep in zip([c_flipped_with_deep, c_flipped_without_deep], [True, False]): - cirq.testing.assert_same_circuits(cirq.xor_ops_with_tags(c_orig, [tag], deep=deep), c_flip) - cirq.testing.assert_same_circuits(cirq.xor_ops_with_tags(c_flip, [tag], deep=deep), c_orig) + cirq.testing.assert_same_circuits(cirq.toggle_tags(c_orig, [tag], deep=False), c_toggled) + cirq.testing.assert_same_circuits(cirq.toggle_tags(c_toggled, [tag], deep=False), c_orig) def test_unroll_circuit_op_and_variants():