Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cirq.merge_k_qubit_unitaries transformer to replace cirq.MergeSingleQubitGates optimizer #4986

Merged
merged 11 commits into from
Feb 16, 2022
Merged
3 changes: 2 additions & 1 deletion cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,13 +372,14 @@
map_moments,
map_operations,
map_operations_and_unroll,
merge_k_qubit_unitaries,
merge_k_qubit_unitaries_to_circuit_op,
merge_moments,
merge_operations,
merge_operations_to_circuit_op,
merge_single_qubit_gates,
merge_single_qubit_gates_to_phased_x_and_z,
merge_single_qubit_gates_to_phxz,
merge_single_qubit_moments_to_phxz,
prepare_two_qubit_state_using_cz,
prepare_two_qubit_state_using_sqrt_iswap,
single_qubit_matrix_to_gates,
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/contrib/paulistring/convert_gate_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def converted_gate_set(
"""
conv_circuit = circuits.Circuit(circuit)
optimizers.ConvertToCzAndSingleGates().optimize_circuit(conv_circuit)
conv_circuit = transformers.merge_single_qubit_gates(conv_circuit)
conv_circuit = transformers.merge_k_qubit_unitaries(conv_circuit, k=1)
ConvertToPauliStringPhasors(
ignore_failures=True,
keep_clifford=not no_clifford_gates,
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/optimizers/merge_single_qubit_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import cirq


@_compat.deprecated_class(deadline='v1.0', fix='Use cirq.merge_single_qubit_gates instead.')
@_compat.deprecated_class(deadline='v1.0', fix='Use cirq.merge_k_qubit_unitaries instead.')
class MergeSingleQubitGates(circuits.PointOptimizer):
"""Optimizes runs of adjacent unitary 1-qubit operations."""

Expand Down
19 changes: 11 additions & 8 deletions cirq-core/cirq/optimizers/merge_single_qubit_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ def assert_optimizes(
before: cirq.Circuit,
expected: cirq.Circuit,
optimizer: Optional[Callable[[cirq.Circuit], None]] = None,
deprecated_msg: str = "Use cirq.merge_k_qubit_unitaries",
):
with cirq.testing.assert_deprecated("Use cirq.merge_single_qubit_gates", deadline='v1.0'):
with cirq.testing.assert_deprecated(deprecated_msg, deadline='v1.0'):
if optimizer is None:
optimizer = cirq.MergeSingleQubitGates().optimize_circuit
optimizer(before)
Expand All @@ -39,7 +40,7 @@ def assert_optimizes(


def test_leaves_singleton():
with cirq.testing.assert_deprecated("Use cirq.merge_single_qubit_gates", deadline='v1.0'):
with cirq.testing.assert_deprecated("Use cirq.merge_k_qubit_unitaries", deadline='v1.0'):
m = cirq.MergeSingleQubitGates()
q = cirq.NamedQubit('q')
c = cirq.Circuit([cirq.Moment([cirq.X(q)])])
Expand All @@ -50,15 +51,15 @@ def test_leaves_singleton():


def test_not_both():
with cirq.testing.assert_deprecated("Use cirq.merge_single_qubit_gates", deadline='v1.0'):
with cirq.testing.assert_deprecated("Use cirq.merge_k_qubit_unitaries", deadline='v1.0'):
with pytest.raises(ValueError):
_ = cirq.MergeSingleQubitGates(
synthesizer=lambda *args: None, rewriter=lambda *args: None
)


def test_combines_sequence():
with cirq.testing.assert_deprecated("Use cirq.merge_single_qubit_gates", deadline='v1.0'):
with cirq.testing.assert_deprecated("Use cirq.merge_k_qubit_unitaries", deadline='v1.0'):
m = cirq.MergeSingleQubitGates()
q = cirq.NamedQubit('q')
c = cirq.Circuit(cirq.X(q) ** 0.5, cirq.Z(q) ** 0.5, cirq.X(q) ** -0.5)
Expand Down Expand Up @@ -89,7 +90,7 @@ def test_removes_identity_sequence():


def test_stopped_at_2qubit():
with cirq.testing.assert_deprecated("Use cirq.merge_single_qubit_gates", deadline='v1.0'):
with cirq.testing.assert_deprecated("Use cirq.merge_k_qubit_unitaries", deadline='v1.0'):
m = cirq.MergeSingleQubitGates()
q = cirq.NamedQubit('q')
q2 = cirq.NamedQubit('q2')
Expand All @@ -116,7 +117,7 @@ def test_stopped_at_2qubit():


def test_ignores_2qubit_target():
with cirq.testing.assert_deprecated("Use cirq.merge_single_qubit_gates", deadline='v1.0'):
with cirq.testing.assert_deprecated("Use cirq.merge_k_qubit_unitaries", deadline='v1.0'):
m = cirq.MergeSingleQubitGates()
q = cirq.NamedQubit('q')
q2 = cirq.NamedQubit('q2')
Expand All @@ -140,7 +141,7 @@ class UnsupportedDummy(cirq.SingleQubitGate):
UnsupportedDummy()(q0),
)
c_orig = cirq.Circuit(circuit)
with cirq.testing.assert_deprecated("Use cirq.merge_single_qubit_gates", deadline='v1.0'):
with cirq.testing.assert_deprecated("Use cirq.merge_k_qubit_unitaries", deadline='v1.0'):
cirq.MergeSingleQubitGates().optimize_circuit(circuit)

assert circuit == c_orig
Expand All @@ -156,7 +157,7 @@ def test_rewrite():
cirq.CZ(q0, q1),
cirq.Y(q1),
)
with cirq.testing.assert_deprecated("Use cirq.merge_single_qubit_gates", deadline='v1.0'):
with cirq.testing.assert_deprecated("Use cirq.merge_k_qubit_unitaries", deadline='v1.0'):
cirq.MergeSingleQubitGates(rewriter=lambda ops: cirq.H(ops[0].qubits[0])).optimize_circuit(
circuit
)
Expand Down Expand Up @@ -190,6 +191,7 @@ def test_merge_single_qubit_gates_into_phased_x_z():
(cirq.PhasedXPowGate(phase_exponent=-0.5)(a)) ** 0.5,
),
optimizer=cirq.merge_single_qubit_gates_into_phased_x_z,
deprecated_msg="Use cirq.merge_single_qubit_gates_to_phased_x_and_z",
)


Expand Down Expand Up @@ -217,4 +219,5 @@ def phxz(a, x, z):
phxz(-0.5, 0.5, 0).on(a),
),
optimizer=cirq.merge_single_qubit_gates_into_phxz,
deprecated_msg="Use cirq.merge_single_qubit_gates_to_phxz",
)
4 changes: 3 additions & 1 deletion cirq-core/cirq/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,12 @@
dephase_measurements,
)

from cirq.transformers.merge_k_qubit_gates import merge_k_qubit_unitaries

from cirq.transformers.merge_single_qubit_gates import (
merge_single_qubit_gates,
merge_single_qubit_gates_to_phased_x_and_z,
merge_single_qubit_gates_to_phxz,
merge_single_qubit_moments_to_phxz,
)

from cirq.transformers.synchronize_terminal_measurements import synchronize_terminal_measurements
Expand Down
77 changes: 77 additions & 0 deletions cirq-core/cirq/transformers/merge_k_qubit_gates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright 2022 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer pass to merge connected components of k-qubit unitary operations."""

from typing import cast, Optional, Callable, TYPE_CHECKING

from cirq import ops, protocols, circuits
from cirq.transformers import transformer_api, transformer_primitives

if TYPE_CHECKING:
import cirq


@transformer_api.transformer
def merge_k_qubit_unitaries(
circuit: 'cirq.AbstractCircuit',
*,
context: Optional['cirq.TransformerContext'] = None,
k: int = 0,
rewriter: Optional[Callable[['cirq.CircuitOperation'], 'cirq.OP_TREE']] = None,
) -> 'cirq.Circuit':
Comment on lines +27 to +33
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little confused now, this seems like a kind of drastic API change. What is k for if it can only be zero or one ? This also errors by default ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

k can by any integer > 0 . This will be useful to replace MergeInteractions and MergeInteractionsToSqrtISwap, both of which need to merge connected components of 2q gates. We can directly call this method with k=2 instead of rolling out another implementation.

"""Merges connected components of unitary operations, acting on <= k qubits.

Uses rewriter to convert a connected component of unitary operations acting on <= k-qubits
into a more desirable form. If not specified, connected components are replaced by a single
`cirq.MatrixGate` containing unitary matrix of the merged component.

Args:
circuit: Input circuit to transform. It will not be modified.
context: `cirq.TransformerContext` storing common configurable options for transformers.
k: Connected components of unitary operations acting on <= k qubits are merged.
rewriter: Callable type that takes a `cirq.CircuitOperation`, encapsulating a connected
component of unitary operations acting on <= k qubits, and produces a `cirq.OP_TREE`.
Specifies how to merge the connected component into a more desirable form.

Returns:
Copy of the transformed input circuit.

Raises:
ValueError: If k <= 0
"""
if k <= 0:
raise ValueError(f"k should be greater than or equal to 1. Found {k}.")
merged_circuit_op_tag = "_merged_k_qubit_unitaries_component"

def map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
if not (protocols.num_qubits(op) <= k and protocols.has_unitary(op)):
return op
if rewriter:
return rewriter(
cast(circuits.CircuitOperation, op.untagged)
if merged_circuit_op_tag in op.tags
else circuits.CircuitOperation(circuits.FrozenCircuit(op))
)
return ops.MatrixGate(protocols.unitary(op)).on(*op.qubits)

circuit = transformer_primitives.merge_k_qubit_unitaries_to_circuit_op(
circuit,
k=k,
tags_to_ignore=context.tags_to_ignore if context else (),
merged_circuit_op_tag=merged_circuit_op_tag,
)
return transformer_primitives.map_operations_and_unroll(
circuit, map_func, tags_to_ignore=context.tags_to_ignore if context else ()
).unfreeze(copy=False)
Loading