diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index b6e3856488c..181a990eb67 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -14,21 +14,22 @@ """Protocol and methods for obtaining Kraus representation of quantum channels.""" -from typing import Any, Sequence, Tuple, TypeVar, Union -import warnings - +from typing import Any, Sequence, Tuple, TypeVar, Union, TYPE_CHECKING +from functools import reduce import numpy as np from typing_extensions import Protocol from cirq._doc import doc_private -from cirq.protocols.decompose_protocol import ( - _try_decompose_into_operations_and_qubits, -) -from cirq.protocols.mixture_protocol import has_mixture - +from cirq.protocols import mixture_protocol, decompose_protocol from cirq.type_workarounds import NotImplementedType +from cirq import qis +from cirq.ops import Moment + +if TYPE_CHECKING: + import cirq + # This is a special indicator value used by the channel method to determine # whether or not the caller provided a 'default' argument. It must be of type @@ -114,57 +115,62 @@ def kraus( where $I$ is the identity matrix. The matrices $A_k$ are sometimes called Kraus or noise operators. + Determines the Kraus representation of `val` by the following strategies: + + 1. Try to use `val._has_kraus_()`. + Case a) Method not present or returns `None`. + Continue to next strategy. + Case b) Returns the Kraus operator. + Method returns the result. + + 2. Try to use `mixture_protocol.mixture()`. + Case a) Method not present or returns `None`. + Continue to next strategy. + Case b) Method returns a valid mixture. + Method converts mixture into kraus and returns. + + 3. Try to use serial concatenation recursively. + Case a) One or more decomposed operators doesn't have Kraus. + `val` does not have a kraus representation. + Case b) All decomposed operators have Kraus representation. + Serially concatenate and return the result. + Args: - val: The value to describe by a channel. + val: The value to describe by Kraus representation. default: Determines the fallback behavior when `val` doesn't have - a channel. If `default` is not set, a TypeError is raised. If - default is set to a value, that value is returned. + a representation. If `default` is not set, a TypeError is raised. + If default is set to a value, that value is returned. Returns: - If `val` has a `_kraus_` method and its result is not NotImplemented, - that result is returned. Otherwise, if `val` has a `_mixture_` method - and its results is not NotImplement a tuple made up of channel - corresponding to that mixture being a probabilistic mixture of unitaries - is returned. Otherwise, if `val` has a `_unitary_` method and - its result is not NotImplemented a tuple made up of that result is - returned. Otherwise, if a default value was specified, the default - value is returned. + The kraus representation of `val`. Raises: - TypeError: `val` doesn't have a _kraus_ or _unitary_ method (or that - method returned NotImplemented) and also no default value was - specified. + TypeError: `val` doesn't have a _kraus_, _unitary_, _mixture_ method + (or that method returned NotImplemented) and also no default value + was specified. """ - channel_getter = getattr(val, '_channel_', None) - if channel_getter is not None: - warnings.warn( - '_channel_ is deprecated and will be removed in cirq 0.13, rename to _kraus_', - DeprecationWarning, - ) - kraus_getter = getattr(val, '_kraus_', None) - kraus_result = NotImplemented if kraus_getter is None else kraus_getter() - if kraus_result is not NotImplemented: - return tuple(kraus_result) + result = _gettr_helper(val, ['_kraus_']) + if result is not None and result is not NotImplemented: + return result - mixture_getter = getattr(val, '_mixture_', None) - mixture_result = NotImplemented if mixture_getter is None else mixture_getter() - if mixture_result is not NotImplemented and mixture_result is not None: + mixture_result = mixture_protocol.mixture(val, None) + if mixture_result is not None and mixture_result is not NotImplemented: return tuple(np.sqrt(p) * u for p, u in mixture_result) - unitary_getter = getattr(val, '_unitary_', None) - unitary_result = NotImplemented if unitary_getter is None else unitary_getter() - if unitary_result is not NotImplemented and unitary_result is not None: - return (unitary_result,) + decomposed, qubits, _ = decompose_protocol._try_decompose_into_operations_and_qubits(val) - channel_result = NotImplemented if channel_getter is None else channel_getter() - if channel_result is not NotImplemented: - return tuple(channel_result) + if decomposed is not None and decomposed != [val] and decomposed != []: + + superoperator_list = [_moment_superoperator(x, qubits, None) for x in decomposed] + if not any([x is None for x in superoperator_list]): + superoperator_result = reduce(lambda x, y: x @ y, superoperator_list) + return tuple(qis.superoperator_to_kraus(superoperator_result)) if default is not RaiseTypeErrorIfNotProvided: return default - if kraus_getter is None and unitary_getter is None and mixture_getter is None: + if _gettr_helper(val, ['_kraus_', '_unitary_', '_mixture_']) is None: raise TypeError( "object of type '{}' has no _kraus_ or _mixture_ or " "_unitary_ method.".format(type(val)) @@ -177,7 +183,41 @@ def kraus( def has_kraus(val: Any, *, allow_decompose: bool = True) -> bool: - """Returns whether the value has a Kraus representation. + """Determines whether the value has a Kraus representation. + + Determines whether `val` has a Kraus representation by attempting + the following strategies: + + #TODO + + 1. Try to use `val._has_kraus_()`. + Case a) Method not present or returns `None` or returns `False`. + Continue to next strategy. + Case b) Method returns `True`. + return True. + + 1. Try to use `val._has_channel_()`. + Case a) Method not present or returns `None` or returns `False`. + Continue to next strategy. + Case b) Method returns `True`. + return True. + + 2. Try to use `val._kraus_()`. + Case a) Method not present or returns `NotImplemented`. + Continue to next strategy. + Case b) Method returns a 3D array. + return True. + + 3. Try to use `cirq.has_mixture()`. + Case a) Method not present or returns `None` or returns `False`. + Continue to next strategy. + Case b) Method returns `True`. + return True. + + 4. If decomposition is allowed apply recursion and check. + + If all the above methods fail then it is assumed to have no Kraus + representation. Args: val: The value to check. @@ -190,28 +230,44 @@ def has_kraus(val: Any, *, allow_decompose: bool = True) -> bool: the result is skipped. Returns: - If `val` has a `_has_kraus_` method and its result is not - NotImplemented, that result is returned. Otherwise, if `val` has a - `_has_mixture_` method and its result is not NotImplemented, that - result is returned. Otherwise if `val` has a `_has_unitary_` method - and its results is not NotImplemented, that result is returned. - Otherwise, if the value has a _kraus_ method return if that - has a non-default value. Returns False if none of these functions - exists. + Whether or not `val` has a Kraus representation. """ - kraus_getter = getattr(val, '_has_kraus_', None) - result = NotImplemented if kraus_getter is None else kraus_getter() - if result is not NotImplemented: + result = _gettr_helper(val, ['_has_kraus_', '_has_channel_']) + if result is not None and result is not NotImplemented: return result - result = has_mixture(val, allow_decompose=False) - if result is not NotImplemented and result: - return result + result = _gettr_helper(val, ['_kraus_']) + if result is not None and result is not NotImplemented: + return True + + if mixture_protocol.has_mixture(val, allow_decompose=False): + return True if allow_decompose: - operations, _, _ = _try_decompose_into_operations_and_qubits(val) - if operations is not None: + operations, _, _ = decompose_protocol._try_decompose_into_operations_and_qubits(val) + if operations is not None and operations != [val]: return all(has_kraus(val) for val in operations) - # No has methods, use `_kraus_` or delegates instead. - return kraus(val, None) is not None + return False + + +def _moment_superoperator( + op: Union['cirq.Operation'], qubits: Sequence['cirq.Qid'], default: Any +) -> Union[np.ndarray, TDefault]: + superoperator_result = Moment(op).expand_to(qubits)._superoperator_() + return superoperator_result if superoperator_result is not NotImplemented else default + + +def _gettr_helper(val: Any, gett_str_list: Sequence[str]) -> Any: + notImplementedFlag = False + for gettr_str in gett_str_list: + gettr = getattr(val, gettr_str, None) + if gettr is None: + continue + result = gettr() + if result is NotImplemented: + notImplementedFlag = True + elif result is not None: + return result + + return NotImplemented if notImplementedFlag else None diff --git a/cirq-core/cirq/protocols/kraus_protocol_test.py b/cirq-core/cirq/protocols/kraus_protocol_test.py index 4fc9fc90ea5..f7476637e65 100644 --- a/cirq-core/cirq/protocols/kraus_protocol_test.py +++ b/cirq-core/cirq/protocols/kraus_protocol_test.py @@ -137,6 +137,72 @@ def _unitary_(self) -> np.ndarray: assert cirq.has_kraus(ReturnsUnitary()) +def test_serial_concatenation_default(): + q1 = cirq.GridQubit(1, 1) + + class defaultGate(cirq.Gate): + def num_qubits(self): + return 1 + + def _kraus_(self): + return NotImplemented + + def _unitary_(self): + return NotImplemented + + def _mixture_(self): + return NotImplemented + + class onlyDecompose: + def _decompose_(self): + return [cirq.Y.on(q1), defaultGate().on(q1)] + + def _unitary_(self): + return NotImplemented + + def _mixture_(self): + return NotImplemented + + with pytest.raises(TypeError, match="_unitary_ method."): + _ = cirq.kraus(onlyDecompose()) + assert cirq.kraus(onlyDecompose(), 1) == 1 + assert not cirq.has_kraus(onlyDecompose()) + + +def test_serial_concatenation_circuit(): + q1 = cirq.GridQubit(1, 1) + q2 = cirq.GridQubit(1, 2) + + class defaultGate(cirq.Gate): + def num_qubits(self): + return 1 + + def _kraus_(self): + return cirq.kraus(cirq.X) + + class onlyDecompose: + def _decompose_(self): + circ = cirq.Circuit([cirq.Y.on(q1), defaultGate().on(q2)]) + return cirq.decompose(circ) + + def _unitary_(self): + return NotImplemented + + def _mixture_(self): + return NotImplemented + + g = onlyDecompose() + c = cirq.kraus_to_superoperator((cirq.unitary(cirq.Circuit([cirq.Y.on(q1), cirq.X.on(q2)])),)) + + np.testing.assert_almost_equal(cirq.kraus_to_superoperator(cirq.kraus(g)), c) + np.testing.assert_almost_equal(cirq.kraus_to_superoperator(cirq.kraus(g, None)), c) + np.testing.assert_almost_equal(cirq.kraus_to_superoperator(cirq.kraus(g, NotImplemented)), c) + np.testing.assert_almost_equal(cirq.kraus_to_superoperator(cirq.kraus(g, (1,))), c) + np.testing.assert_almost_equal(cirq.kraus_to_superoperator(cirq.kraus(g, LOCAL_DEFAULT)), c) + + assert cirq.has_kraus(g) + + class HasKraus(cirq.SingleQubitGate): def _has_kraus_(self) -> bool: return True diff --git a/cirq-core/cirq/testing/__init__.py b/cirq-core/cirq/testing/__init__.py index f08cffc00d5..25a5200a95a 100644 --- a/cirq-core/cirq/testing/__init__.py +++ b/cirq-core/cirq/testing/__init__.py @@ -36,6 +36,11 @@ assert_decompose_is_consistent_with_unitary, ) +from cirq.testing.consistent_kraus import ( + assert_kraus_is_consistent_with_unitary, + assert_kraus_is_consistent_with_mixture, +) + from cirq.testing.consistent_pauli_expansion import ( assert_pauli_expansion_is_consistent_with_unitary, ) diff --git a/cirq-core/cirq/testing/consistent_kraus.py b/cirq-core/cirq/testing/consistent_kraus.py new file mode 100644 index 00000000000..3c596eb4562 --- /dev/null +++ b/cirq-core/cirq/testing/consistent_kraus.py @@ -0,0 +1,72 @@ +# Copyright 2018 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import numpy as np + +from cirq import protocols +from cirq.testing import lin_alg_utils + + +def assert_kraus_is_consistent_with_unitary(val: Any, ignoring_global_phase: bool = False): + """Uses `cirq.unitary` to check `val.kraus`'s behavior.""" + # pylint: disable=unused-variable + # __tracebackhide__ = True + # pylint: enable=unused-variable + + expected = protocols.unitary(val, None) + if expected is None: + # If there's no unitary, it's vacuously consistent. + return + + has_krs = protocols.kraus_protocol.has_kraus(val) + krs = protocols.kraus_protocol.kraus(val, None) + + # there is unitary and hence must have kraus operator + assert has_krs + assert len(krs) == 1 + actual = krs[0] + + if ignoring_global_phase: + lin_alg_utils.assert_allclose_up_to_global_phase(actual, expected, atol=1e-8) + else: + # coverage: ignore + np.testing.assert_allclose(actual, expected, atol=1e-8) + + +def assert_kraus_is_consistent_with_mixture(val: Any, ignoring_global_phase: bool = False): + """Uses `cirq.mixture` to check `cirq.kraus`'s behavior.""" + # pylint: disable=unused-variable + # __tracebackhide__ = True + # pylint: enable=unused-variable + + expected = protocols.mixture(val, None) + if expected is None: + # If there's no mixture, it's vacuously consistent. + return + + has_krs = protocols.kraus_protocol.has_kraus(val) + krs = np.array(protocols.kraus_protocol.kraus(val, None)) + + # there is mixture and hence must have kraus operator + assert has_krs + actual = krs + expected = np.array([np.sqrt(p) * u for p, u in expected]) + + if ignoring_global_phase: + lin_alg_utils.assert_allclose_up_to_global_phase(actual, expected, atol=1e-8) + else: + # coverage: ignore + np.testing.assert_allclose(actual, expected, atol=1e-8) diff --git a/cirq-core/cirq/testing/consistent_kraus_test.py b/cirq-core/cirq/testing/consistent_kraus_test.py new file mode 100644 index 00000000000..8a5b9d6bd5e --- /dev/null +++ b/cirq-core/cirq/testing/consistent_kraus_test.py @@ -0,0 +1,71 @@ +# Copyright 2018 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import numpy as np + +import cirq + + +class GoodGateKraus(cirq.SingleQubitGate): + def _kraus_(self, default=None): + return (np.array([[0, 1], [1, 0]]),) + + def _unitary_(self): + return np.array([[0, 1], [1, 0]]) + + def _mixture_(self): + return ((1, np.array([[0, 1], [1, 0]])),) + + +class BadGateKraus(cirq.SingleQubitGate): + def _kraus_(self, default=None): + return (np.array([[0, 1], [1, 0]]),) + + def _unitary_(self): + return np.array([[0, 1], [0, 1]]) + + def _mixture_(self): + return ((1, np.array([[0, 1], [0, 1]])),) + + +def test_assert_kraus_is_consistent_with_unitary(): + gate = GoodGateKraus() + cirq.testing.assert_kraus_is_consistent_with_unitary(gate) + + cirq.testing.assert_kraus_is_consistent_with_unitary(GoodGateKraus().on(cirq.NamedQubit("q"))) + + with pytest.raises(AssertionError): + cirq.testing.assert_kraus_is_consistent_with_unitary(BadGateKraus()) + + with pytest.raises(AssertionError): + cirq.testing.assert_kraus_is_consistent_with_unitary( + BadGateKraus().on(cirq.NamedQubit("q")) + ) + + +def test_assert_kraus_is_consistent_with_mixture(): + gate = GoodGateKraus() + cirq.testing.assert_kraus_is_consistent_with_mixture(gate) + + cirq.testing.assert_kraus_is_consistent_with_mixture(GoodGateKraus().on(cirq.NamedQubit("q"))) + + with pytest.raises(AssertionError): + cirq.testing.assert_kraus_is_consistent_with_mixture(BadGateKraus()) + + with pytest.raises(AssertionError): + cirq.testing.assert_kraus_is_consistent_with_mixture( + BadGateKraus().on(cirq.NamedQubit("q")) + ) diff --git a/cirq-core/cirq/testing/consistent_protocols.py b/cirq-core/cirq/testing/consistent_protocols.py index 51d44a77416..e6999b18408 100644 --- a/cirq-core/cirq/testing/consistent_protocols.py +++ b/cirq-core/cirq/testing/consistent_protocols.py @@ -27,6 +27,10 @@ from cirq.testing.consistent_decomposition import ( assert_decompose_is_consistent_with_unitary, ) +from cirq.testing.consistent_kraus import ( + assert_kraus_is_consistent_with_unitary, + assert_kraus_is_consistent_with_mixture, +) from cirq.testing.consistent_phase_by import ( assert_phase_by_is_consistent_with_unitary, ) @@ -154,6 +158,8 @@ def _assert_meets_standards_helper( assert_qasm_is_consistent_with_unitary(val) assert_has_consistent_trace_distance_bound(val) assert_decompose_is_consistent_with_unitary(val, ignoring_global_phase=ignoring_global_phase) + assert_kraus_is_consistent_with_unitary(val, ignoring_global_phase=ignoring_global_phase) + assert_kraus_is_consistent_with_mixture(val, ignoring_global_phase=ignoring_global_phase) assert_phase_by_is_consistent_with_unitary(val) assert_pauli_expansion_is_consistent_with_unitary(val) assert_equivalent_repr(