Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cirq.toggle_tags helper to apply transformers on specific subsets of operations in a circuit #4973

Merged
merged 3 commits into from
Feb 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@
two_qubit_gate_product_tabulation,
TwoQubitGateTabulation,
TwoQubitGateTabulationResult,
toggle_tags,
unroll_circuit_op,
unroll_circuit_op_greedy_earliest,
unroll_circuit_op_greedy_frontier,
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
map_operations_and_unroll,
merge_moments,
merge_operations,
toggle_tags,
unroll_circuit_op,
unroll_circuit_op_greedy_earliest,
unroll_circuit_op_greedy_frontier,
Expand Down
34 changes: 34 additions & 0 deletions cirq-core/cirq/transformers/align_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,40 @@ def test_align_left_no_compile_context():
)


def test_align_left_subset_of_operations():
q1 = cirq.NamedQubit('q1')
q2 = cirq.NamedQubit('q2')
tag = "op_to_align"
c_orig = cirq.Circuit(
[
cirq.Moment([cirq.Y(q1)]),
cirq.Moment([cirq.X(q2)]),
cirq.Moment([cirq.X(q1).with_tags(tag)]),
cirq.Moment([cirq.Y(q2)]),
cirq.measure(*[q1, q2], key='a'),
]
)
c_exp = cirq.Circuit(
[
cirq.Moment([cirq.Y(q1)]),
cirq.Moment([cirq.X(q1).with_tags(tag), cirq.X(q2)]),
cirq.Moment(),
cirq.Moment([cirq.Y(q2)]),
cirq.measure(*[q1, q2], key='a'),
]
)
cirq.testing.assert_same_circuits(
cirq.toggle_tags(
cirq.align_left(
cirq.toggle_tags(c_orig, [tag]),
context=cirq.TransformerContext(tags_to_ignore=[tag]),
),
[tag],
),
c_exp,
)


def test_align_right_no_compile_context():
q1 = cirq.NamedQubit('q1')
q2 = cirq.NamedQubit('q2')
Expand Down
31 changes: 31 additions & 0 deletions cirq-core/cirq/transformers/transformer_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,3 +426,34 @@ def unroll_circuit_op_greedy_frontier(
)
frontier = unrolled_circuit.insert_at_frontier(sub_circuit.all_operations(), idx, frontier)
return _to_target_circuit_type(unrolled_circuit, circuit)


def toggle_tags(circuit: CIRCUIT_TYPE, tags: Sequence[Hashable], *, deep: bool = False):
"""Toggles tags applied on each operation in the circuit, via `op.tags ^= tags`

For every operations `op` in the input circuit, the tags on `op` are replaced by a symmetric
difference of `op.tags` and `tags` -- this is useful in scenarios where you mark a small subset
of operations with a specific tag and then toggle the set of marked operations s.t. every
marked operation is now unmarked and vice versa.

Often used in transformer workflows to apply a transformer on a small subset of operations.

Args:
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
tags: Sequence of tags s.t. `op.tags ^= tags` is done for every operation `op` in circuit.
deep: If true, tags will be recursively toggled for operations in circuits wrapped inside
any circuit operations contained within `circuit`.

Returns:
Copy of transformed input circuit with operation sets marked with `tags` toggled.
"""
tags_to_xor = set(tags)

def map_func(op: 'cirq.Operation', _) -> 'cirq.Operation':
return (
op
if deep and isinstance(op, circuits.CircuitOperation)
else op.untagged.with_tags(*(set(op.tags) ^ tags_to_xor))
)

return map_operations(circuit, map_func, deep=deep)
20 changes: 20 additions & 0 deletions cirq-core/cirq/transformers/transformer_primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,26 @@ def test_map_operations_respects_tags_to_ignore():
)


def test_apply_tag_to_inverted_op_set():
q = cirq.LineQubit.range(2)
op = cirq.CNOT(*q)
tag = "tag_to_flip"
c_orig = cirq.Circuit(op, op.with_tags(tag), cirq.CircuitOperation(cirq.FrozenCircuit(op)))
# Toggle with deep = True.
c_toggled = cirq.Circuit(
op.with_tags(tag), op, cirq.CircuitOperation(cirq.FrozenCircuit(op.with_tags(tag)))
)
cirq.testing.assert_same_circuits(cirq.toggle_tags(c_orig, [tag], deep=True), c_toggled)
cirq.testing.assert_same_circuits(cirq.toggle_tags(c_toggled, [tag], deep=True), c_orig)

# Toggle with deep = False
c_toggled = cirq.Circuit(
op.with_tags(tag), op, cirq.CircuitOperation(cirq.FrozenCircuit(op)).with_tags(tag)
)
cirq.testing.assert_same_circuits(cirq.toggle_tags(c_orig, [tag], deep=False), c_toggled)
cirq.testing.assert_same_circuits(cirq.toggle_tags(c_toggled, [tag], deep=False), c_orig)


def test_unroll_circuit_op_and_variants():
q = cirq.LineQubit.range(2)
c = cirq.Circuit(cirq.X(q[0]), cirq.CNOT(q[0], q[1]), cirq.X(q[0]))
Expand Down