From 1a74e94fffc255e9f3789d02a25686196983910a Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Sun, 5 Sep 2021 13:30:20 +0530 Subject: [PATCH 01/33] Added serial concatanation and wrote a test for the same --- cirq-core/cirq/protocols/kraus_protocol.py | 24 +++++++++++++++---- .../cirq/protocols/kraus_protocol_test.py | 14 +++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index 8975d41e96c..c6ad1ab93b5 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -25,11 +25,11 @@ deprecated_class, ) from cirq._doc import doc_private -from cirq.protocols.decompose_protocol import ( - _try_decompose_into_operations_and_qubits, -) +from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits, decompose from cirq.protocols.mixture_protocol import has_mixture +from cirq.ops import Gate +from cirq.devices import LineQid from cirq.type_workarounds import NotImplementedType @@ -164,7 +164,7 @@ def kraus( return tuple(kraus_result) mixture_getter = getattr(val, '_mixture_', None) - mixture_result = NotImplemented if mixture_getter is None else mixture_getter() + mixture_result = NotImplemented if mpixture_getter is None else mixture_getter() if mixture_result is not NotImplemented and mixture_result is not None: return tuple(np.sqrt(p) * u for p, u in mixture_result) @@ -177,6 +177,22 @@ def kraus( if channel_result is not NotImplemented: return tuple(channel_result) + if isinstance(val, Gate): + operation = val.on(*LineQid.for_gate(val)) + else: + operation = val + + kraus_list = list(map(lambda x: cirq.kraus(x, default), decompose(operation))) + + if all([x != None for x in kraus_list]): + kraus_result = kraus_list[0] + + for i in range(1, len(kraus_list)): + kraus_result = [op_1 * op_2 for op_1 in kraus_result for op_2 in kraus_list[i]] + + if len(kraus_result) != 0: + return tuple(kraus_result) + if default is not RaiseTypeErrorIfNotProvided: return default diff --git a/cirq-core/cirq/protocols/kraus_protocol_test.py b/cirq-core/cirq/protocols/kraus_protocol_test.py index fa6e74bf0ee..c5303c57ee0 100644 --- a/cirq-core/cirq/protocols/kraus_protocol_test.py +++ b/cirq-core/cirq/protocols/kraus_protocol_test.py @@ -157,6 +157,20 @@ def _mixture_(self) -> Iterable[Tuple[float, np.ndarray]]: assert cirq.has_kraus(ReturnsMixture()) +def test_serial_concatation(): + g = cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=-0.5, z_exponent=1) + + c = (cirq.unitary(g),) + + np.allclose(cirq.kraus(g), c) + np.allclose(cirq.kraus(g, None), c) + np.allclose(cirq.kraus(g, NotImplemented), c) + np.allclose(cirq.kraus(g, (1,)), c) + np.allclose(cirq.kraus(g, LOCAL_DEFAULT), c) + + assert cirq.has_kraus(g) + + def test_channel_fallback_to_unitary(): u = np.array([[1, 0], [1, 0]]) From 94fcdfa6e4ec5f65574bb6c95d1b850a98864318 Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Sun, 5 Sep 2021 13:30:20 +0530 Subject: [PATCH 02/33] Added serial concatanation and wrote a test for the same --- cirq-core/cirq/protocols/kraus_protocol.py | 24 +++++++++++++++---- .../cirq/protocols/kraus_protocol_test.py | 14 +++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index 8975d41e96c..5f9048b7536 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -25,11 +25,11 @@ deprecated_class, ) from cirq._doc import doc_private -from cirq.protocols.decompose_protocol import ( - _try_decompose_into_operations_and_qubits, -) +from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits, decompose from cirq.protocols.mixture_protocol import has_mixture +from cirq.ops import Gate +from cirq.devices import LineQid from cirq.type_workarounds import NotImplementedType @@ -164,7 +164,7 @@ def kraus( return tuple(kraus_result) mixture_getter = getattr(val, '_mixture_', None) - mixture_result = NotImplemented if mixture_getter is None else mixture_getter() + mixture_result = NotImplemented if mpixture_getter is None else mixture_getter() if mixture_result is not NotImplemented and mixture_result is not None: return tuple(np.sqrt(p) * u for p, u in mixture_result) @@ -177,6 +177,22 @@ def kraus( if channel_result is not NotImplemented: return tuple(channel_result) + if isinstance(val, Gate): + operation = val.on(*LineQid.for_gate(val)) + else: + operation = val + + kraus_list = list(map(lambda x: cirq.kraus(x, default), decompose(operation))) + + if all([x != default for x in kraus_list]): + kraus_result = kraus_list[0] + + for i in range(1, len(kraus_list)): + kraus_result = [op_1 * op_2 for op_1 in kraus_result for op_2 in kraus_list[i]] + + if len(kraus_result) != 0: + return tuple(kraus_result) + if default is not RaiseTypeErrorIfNotProvided: return default diff --git a/cirq-core/cirq/protocols/kraus_protocol_test.py b/cirq-core/cirq/protocols/kraus_protocol_test.py index fa6e74bf0ee..c5303c57ee0 100644 --- a/cirq-core/cirq/protocols/kraus_protocol_test.py +++ b/cirq-core/cirq/protocols/kraus_protocol_test.py @@ -157,6 +157,20 @@ def _mixture_(self) -> Iterable[Tuple[float, np.ndarray]]: assert cirq.has_kraus(ReturnsMixture()) +def test_serial_concatation(): + g = cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=-0.5, z_exponent=1) + + c = (cirq.unitary(g),) + + np.allclose(cirq.kraus(g), c) + np.allclose(cirq.kraus(g, None), c) + np.allclose(cirq.kraus(g, NotImplemented), c) + np.allclose(cirq.kraus(g, (1,)), c) + np.allclose(cirq.kraus(g, LOCAL_DEFAULT), c) + + assert cirq.has_kraus(g) + + def test_channel_fallback_to_unitary(): u = np.array([[1, 0], [1, 0]]) From d67785ee9fbe025cb61598b55da931750b38cc3b Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Sun, 5 Sep 2021 13:44:54 +0530 Subject: [PATCH 03/33] Fixed small error --- cirq-core/cirq/protocols/kraus_protocol.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index 268f6722f4a..5f9048b7536 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -184,11 +184,7 @@ def kraus( kraus_list = list(map(lambda x: cirq.kraus(x, default), decompose(operation))) -<<<<<<< HEAD if all([x != default for x in kraus_list]): -======= - if all([x != None for x in kraus_list]): ->>>>>>> 1a74e94fffc255e9f3789d02a25686196983910a kraus_result = kraus_list[0] for i in range(1, len(kraus_list)): From 3ea770755380265770430651061aaeb9680d1efe Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Sun, 5 Sep 2021 13:46:18 +0530 Subject: [PATCH 04/33] Fixed small error --- cirq-core/cirq/protocols/kraus_protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index 5f9048b7536..0476dc153dd 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -164,7 +164,7 @@ def kraus( return tuple(kraus_result) mixture_getter = getattr(val, '_mixture_', None) - mixture_result = NotImplemented if mpixture_getter is None else mixture_getter() + mixture_result = NotImplemented if mixture_getter is None else mixture_getter() if mixture_result is not NotImplemented and mixture_result is not None: return tuple(np.sqrt(p) * u for p, u in mixture_result) From 3d1c9f2a09e70dd621ba9fa08e79ebd88892d9b6 Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Sun, 5 Sep 2021 13:30:20 +0530 Subject: [PATCH 05/33] Added serial concatanation and wrote a test for the same Added serial concatanation and wrote a test for the same Fixed small error Fixed small error --- cirq-core/cirq/protocols/kraus_protocol.py | 28 +++++++++++++++++-- .../cirq/protocols/kraus_protocol_test.py | 14 ++++++++++ 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index 8975d41e96c..457062fcdbb 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -25,11 +25,11 @@ deprecated_class, ) from cirq._doc import doc_private -from cirq.protocols.decompose_protocol import ( - _try_decompose_into_operations_and_qubits, -) +from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits, decompose from cirq.protocols.mixture_protocol import has_mixture +from cirq.ops import Gate +from cirq.devices import LineQid from cirq.type_workarounds import NotImplementedType @@ -177,6 +177,28 @@ def kraus( if channel_result is not NotImplemented: return tuple(channel_result) + if isinstance(val, Gate): + operation = val.on(*LineQid.for_gate(val)) + else: + operation = val + + kraus_list = list(map(lambda x: cirq.kraus(x, default), decompose(operation))) + + if all([x != default for x in kraus_list]): +<<<<<<< HEAD +======= + if all([x != None for x in kraus_list]): +>>>>>>> 1a74e94f... Added serial concatanation and wrote a test for the same +======= +>>>>>>> d67785ee... Fixed small error + kraus_result = kraus_list[0] + + for i in range(1, len(kraus_list)): + kraus_result = [op_1 * op_2 for op_1 in kraus_result for op_2 in kraus_list[i]] + + if len(kraus_result) != 0: + return tuple(kraus_result) + if default is not RaiseTypeErrorIfNotProvided: return default diff --git a/cirq-core/cirq/protocols/kraus_protocol_test.py b/cirq-core/cirq/protocols/kraus_protocol_test.py index fa6e74bf0ee..c5303c57ee0 100644 --- a/cirq-core/cirq/protocols/kraus_protocol_test.py +++ b/cirq-core/cirq/protocols/kraus_protocol_test.py @@ -157,6 +157,20 @@ def _mixture_(self) -> Iterable[Tuple[float, np.ndarray]]: assert cirq.has_kraus(ReturnsMixture()) +def test_serial_concatation(): + g = cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=-0.5, z_exponent=1) + + c = (cirq.unitary(g),) + + np.allclose(cirq.kraus(g), c) + np.allclose(cirq.kraus(g, None), c) + np.allclose(cirq.kraus(g, NotImplemented), c) + np.allclose(cirq.kraus(g, (1,)), c) + np.allclose(cirq.kraus(g, LOCAL_DEFAULT), c) + + assert cirq.has_kraus(g) + + def test_channel_fallback_to_unitary(): u = np.array([[1, 0], [1, 0]]) From 3e5c399e8d2f822d31712e8b2b22759b88043fc6 Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Sun, 5 Sep 2021 14:23:55 +0530 Subject: [PATCH 06/33] Changed cirq.kraus to kraus --- cirq-core/cirq/protocols/kraus_protocol.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index 457062fcdbb..f22a7cd2275 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -182,15 +182,9 @@ def kraus( else: operation = val - kraus_list = list(map(lambda x: cirq.kraus(x, default), decompose(operation))) + kraus_list = list(map(lambda x: kraus(x, default), decompose(operation))) if all([x != default for x in kraus_list]): -<<<<<<< HEAD -======= - if all([x != None for x in kraus_list]): ->>>>>>> 1a74e94f... Added serial concatanation and wrote a test for the same -======= ->>>>>>> d67785ee... Fixed small error kraus_result = kraus_list[0] for i in range(1, len(kraus_list)): From 2635d67d23b0b4fd2a2e9cb5ecfba337bdeb1bf5 Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Sun, 5 Sep 2021 14:28:43 +0530 Subject: [PATCH 07/33] merge error fixed --- cirq-core/cirq/protocols/kraus_protocol.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index b6c51252882..f22a7cd2275 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -182,11 +182,7 @@ def kraus( else: operation = val -<<<<<<< HEAD kraus_list = list(map(lambda x: kraus(x, default), decompose(operation))) -======= - kraus_list = list(map(lambda x: cirq.kraus(x, default), decompose(operation))) ->>>>>>> 798615cacc34004a9b53b12ff680622a9a1d8587 if all([x != default for x in kraus_list]): kraus_result = kraus_list[0] From 4e68cb6f3e5ad3375fcd2c722d14eef763325ce8 Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Thu, 9 Sep 2021 21:57:34 +0530 Subject: [PATCH 08/33] Fixed deep comparision issue --- cirq-core/cirq/protocols/kraus_protocol.py | 28 +++++++++++++++------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index f22a7cd2275..03c038ca52e 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -182,15 +182,27 @@ def kraus( else: operation = val - kraus_list = list(map(lambda x: kraus(x, default), decompose(operation))) + decomposed = decompose(operation) + if decomposed != [operation]: + kraus_list = list(map(lambda x: kraus(x, default), decomposed)) + + def checkEquality(x, y): + if type(x) != type(y): + return False + if type(x) not in [list, tuple, np.ndarray]: + return x == y + if type(x) == np.ndarray: + return x.shape == y.shape and np.all(x == y) + if len(x) != len(y): + return False + return all([checkEquality(a, b) for a, b in zip(x, y)]) + + if not any([checkEquality(x, default) for x in kraus_list]): + kraus_result = kraus_list[0] + + for i in range(1, len(kraus_list)): + kraus_result = [op_1 * op_2 for op_1 in kraus_result for op_2 in kraus_list[i]] - if all([x != default for x in kraus_list]): - kraus_result = kraus_list[0] - - for i in range(1, len(kraus_list)): - kraus_result = [op_1 * op_2 for op_1 in kraus_result for op_2 in kraus_list[i]] - - if len(kraus_result) != 0: return tuple(kraus_result) if default is not RaiseTypeErrorIfNotProvided: From e7b3bdb0a5cd8e4fe8f2d01c77715b0cee82cf4f Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Tue, 14 Sep 2021 14:01:24 +0530 Subject: [PATCH 09/33] Fixed matrix multiplication --- cirq-core/cirq/protocols/kraus_protocol.py | 65 +++++++++++++++---- .../cirq/protocols/kraus_protocol_test.py | 2 +- 2 files changed, 53 insertions(+), 14 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index 03c038ca52e..0e9739f2bf9 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -177,31 +177,70 @@ def kraus( if channel_result is not NotImplemented: return tuple(channel_result) + def checkEquality(x, y): + if type(x) != type(y): + return False + if type(x) not in [list, tuple, np.ndarray]: + return x == y + if type(x) == np.ndarray: + return x.shape == y.shape and np.all(x == y) + if len(x) != len(y): + return False + return all([checkEquality(a, b) for a, b in zip(x, y)]) + + def kraus_tensor(op, qubits, default): + kraus_list = kraus(op, default) + if checkEquality(kraus_list, default): + return default + + val = None + op_q = op.qubits + found = False + for i in range(len(qubits)): + if qubits[i] in op_q: + if not found: + found = True + if val is None: + val = kraus_list + else: + val = tuple([np.kron(x, y) for x in val for y in kraus_list]) + + elif val is None: + val = (np.identity(2),) + else: + val = tuple([np.kron(x, np.identity(2)) for x in val]) + + return val + + qubits = None if isinstance(val, Gate): operation = val.on(*LineQid.for_gate(val)) + qubits = LineQid.for_gate(val) else: operation = val decomposed = decompose(operation) + if decomposed != [operation]: - kraus_list = list(map(lambda x: kraus(x, default), decomposed)) - - def checkEquality(x, y): - if type(x) != type(y): - return False - if type(x) not in [list, tuple, np.ndarray]: - return x == y - if type(x) == np.ndarray: - return x.shape == y.shape and np.all(x == y) - if len(x) != len(y): - return False - return all([checkEquality(a, b) for a, b in zip(x, y)]) + if qubits is None: + qubits = [] + for x in decomposed: + qubits += list(x.qubits) + qubits = sorted(list(set(qubits))) + + kraus_list = list(map(lambda x: kraus_tensor(x, qubits, default), decomposed)) if not any([checkEquality(x, default) for x in kraus_list]): kraus_result = kraus_list[0] for i in range(1, len(kraus_list)): - kraus_result = [op_1 * op_2 for op_1 in kraus_result for op_2 in kraus_list[i]] + try: + kraus_result = [ + op_2.dot(op_1) for op_1 in kraus_result for op_2 in kraus_list[i] + ] + except: + print("---->", operation) + exit() return tuple(kraus_result) diff --git a/cirq-core/cirq/protocols/kraus_protocol_test.py b/cirq-core/cirq/protocols/kraus_protocol_test.py index c5303c57ee0..42b091a71a5 100644 --- a/cirq-core/cirq/protocols/kraus_protocol_test.py +++ b/cirq-core/cirq/protocols/kraus_protocol_test.py @@ -157,7 +157,7 @@ def _mixture_(self) -> Iterable[Tuple[float, np.ndarray]]: assert cirq.has_kraus(ReturnsMixture()) -def test_serial_concatation(): +def test_serial_concatenation(): g = cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=-0.5, z_exponent=1) c = (cirq.unitary(g),) From e91faacb009150062f602488205acd77b1b5152c Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Tue, 14 Sep 2021 15:02:16 +0530 Subject: [PATCH 10/33] Fixed typecasting issue --- cirq-core/cirq/protocols/kraus_protocol.py | 26 +++++++++------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index 0e9739f2bf9..40c2c9db278 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -14,7 +14,7 @@ """Protocol and methods for quantum channels.""" -from typing import Any, Sequence, Tuple, TypeVar, Union +from typing import Any, Sequence, Tuple, TypeVar, Union, List import warnings import numpy as np @@ -29,6 +29,7 @@ from cirq.protocols.mixture_protocol import has_mixture from cirq.ops import Gate +from cirq.ops.raw_types import Qid from cirq.devices import LineQid from cirq.type_workarounds import NotImplementedType @@ -212,35 +213,28 @@ def kraus_tensor(op, qubits, default): return val - qubits = None if isinstance(val, Gate): operation = val.on(*LineQid.for_gate(val)) - qubits = LineQid.for_gate(val) else: operation = val decomposed = decompose(operation) if decomposed != [operation]: - if qubits is None: - qubits = [] - for x in decomposed: - qubits += list(x.qubits) - qubits = sorted(list(set(qubits))) + qubits: List[Qid] = [] + for x in decomposed: + qubits.extend(x.qubits) + qubits = sorted(list(set(qubits))) kraus_list = list(map(lambda x: kraus_tensor(x, qubits, default), decomposed)) - if not any([checkEquality(x, default) for x in kraus_list]): + if not any([checkEquality(x, default) for x in kraus_list]) or len(kraus_list) == 0: kraus_result = kraus_list[0] for i in range(1, len(kraus_list)): - try: - kraus_result = [ - op_2.dot(op_1) for op_1 in kraus_result for op_2 in kraus_list[i] - ] - except: - print("---->", operation) - exit() + kraus_result = [ + op_2.dot(op_1) for op_1 in kraus_result for op_2 in kraus_list[i] + ] return tuple(kraus_result) From 78e983a096b3535c7186a84db7fd60558bec1941 Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Tue, 14 Sep 2021 15:22:37 +0530 Subject: [PATCH 11/33] List index error fixed --- cirq-core/cirq/protocols/kraus_protocol.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index 40c2c9db278..b07115b6e18 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -28,9 +28,8 @@ from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits, decompose from cirq.protocols.mixture_protocol import has_mixture -from cirq.ops import Gate -from cirq.ops.raw_types import Qid -from cirq.devices import LineQid +from cirq.ops.raw_types import Qid, Gate +from cirq.devices.line_qubit import LineQid from cirq.type_workarounds import NotImplementedType @@ -228,7 +227,8 @@ def kraus_tensor(op, qubits, default): qubits = sorted(list(set(qubits))) kraus_list = list(map(lambda x: kraus_tensor(x, qubits, default), decomposed)) - if not any([checkEquality(x, default) for x in kraus_list]) or len(kraus_list) == 0: + if len(kraus_list) != 0 and not any([checkEquality(x, default) for x in kraus_list]): + kraus_result = kraus_list[0] for i in range(1, len(kraus_list)): From 682251cde91e2ed97d3da858a1acebf722b22f97 Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Tue, 14 Sep 2021 15:23:23 +0530 Subject: [PATCH 12/33] Formatting error --- cirq-core/cirq/protocols/kraus_protocol.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index b07115b6e18..e776e83a23a 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -232,9 +232,7 @@ def kraus_tensor(op, qubits, default): kraus_result = kraus_list[0] for i in range(1, len(kraus_list)): - kraus_result = [ - op_2.dot(op_1) for op_1 in kraus_result for op_2 in kraus_list[i] - ] + kraus_result = [op_2.dot(op_1) for op_1 in kraus_result for op_2 in kraus_list[i]] return tuple(kraus_result) From b5a761058b575d2a676367a22d14ac4afd8e2451 Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Wed, 15 Sep 2021 14:36:54 +0530 Subject: [PATCH 13/33] Added kraus consistency tests --- cirq-core/cirq/testing/__init__.py | 4 ++ cirq-core/cirq/testing/consistent_kraus.py | 50 ++++++++++++++++++ .../cirq/testing/consistent_kraus_test.py | 52 +++++++++++++++++++ .../cirq/testing/consistent_protocols.py | 4 ++ 4 files changed, 110 insertions(+) create mode 100644 cirq-core/cirq/testing/consistent_kraus.py create mode 100644 cirq-core/cirq/testing/consistent_kraus_test.py diff --git a/cirq-core/cirq/testing/__init__.py b/cirq-core/cirq/testing/__init__.py index 387249245de..9c7969faecd 100644 --- a/cirq-core/cirq/testing/__init__.py +++ b/cirq-core/cirq/testing/__init__.py @@ -40,6 +40,10 @@ assert_decompose_is_consistent_with_unitary, ) +from cirq.testing.consistent_kraus import ( + assert_kraus_is_consistent_with_unitary, +) + from cirq.testing.consistent_pauli_expansion import ( assert_pauli_expansion_is_consistent_with_unitary, ) diff --git a/cirq-core/cirq/testing/consistent_kraus.py b/cirq-core/cirq/testing/consistent_kraus.py new file mode 100644 index 00000000000..56f07d40ac1 --- /dev/null +++ b/cirq-core/cirq/testing/consistent_kraus.py @@ -0,0 +1,50 @@ +# Copyright 2018 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import numpy as np + +from cirq import devices, protocols, ops, circuits +from cirq.testing import lin_alg_utils + + +def assert_kraus_is_consistent_with_unitary(val: Any, ignoring_global_phase: bool = False): + """Uses `val._unitary_` to check `val._phase_by_`'s behavior.""" + # pylint: disable=unused-variable + __tracebackhide__ = True + # pylint: enable=unused-variable + + expected = (protocols.unitary(val, None),) + if expected is None: + # If there's no unitary, it's vacuously consistent. + return + if isinstance(val, ops.Operation): + has_krs = protocols.kraus_protocol.has_kraus(val) + krs = protocols.kraus_protocol.kraus(val, default=None) + else: + has_krs = protocols.kraus_protocol.has_kraus(val) + krs = protocols.kraus_protocol.kraus(val, default=None) + + if not has_krs: + # If there's no kraus, it's vacuously consistent. + return + + actual = krs + + if ignoring_global_phase: + lin_alg_utils.assert_allclose_up_to_global_phase(actual, expected, atol=1e-8) + else: + # coverage: ignore + np.testing.assert_allclose(actual, expected, atol=1e-8) diff --git a/cirq-core/cirq/testing/consistent_kraus_test.py b/cirq-core/cirq/testing/consistent_kraus_test.py new file mode 100644 index 00000000000..53dbba42af3 --- /dev/null +++ b/cirq-core/cirq/testing/consistent_kraus_test.py @@ -0,0 +1,52 @@ +# Copyright 2018 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import numpy as np + +import cirq + + +class GoodGateKraus(cirq.SingleQubitGate): + def _kraus_(self, default=None): + return (np.array([[0, 1], [1, 0]]),) + + def _unitary_(self): + return np.array([[0, 1], [1, 0]]) + + +class BadGateKraus(cirq.SingleQubitGate): + def _kraus_(self, default=None): + return (np.array([[0, 1], [1, 0]]),) + + def _unitary_(self): + return np.array([[0, 1], [0, 1]]) + + +def test_assert_kraus_is_consistent_with_unitary(): + gate = GoodGateKraus() + cirq.testing.assert_kraus_is_consistent_with_unitary(gate) + + cirq.testing.assert_kraus_is_consistent_with_unitary( + GoodGateKraus().on(cirq.NamedQubit('q')) + ) + + with pytest.raises(AssertionError): + cirq.testing.assert_kraus_is_consistent_with_unitary(BadGateKraus()) + + with pytest.raises(AssertionError): + cirq.testing.assert_kraus_is_consistent_with_unitary( + BadGateKraus().on(cirq.NamedQubit('q')) + ) diff --git a/cirq-core/cirq/testing/consistent_protocols.py b/cirq-core/cirq/testing/consistent_protocols.py index cfa208c5067..fb916d0be7e 100644 --- a/cirq-core/cirq/testing/consistent_protocols.py +++ b/cirq-core/cirq/testing/consistent_protocols.py @@ -27,6 +27,9 @@ from cirq.testing.consistent_decomposition import ( assert_decompose_is_consistent_with_unitary, ) +from cirq.testing.consistent_kraus import ( + assert_kraus_is_consistent_with_unitary, +) from cirq.testing.consistent_phase_by import ( assert_phase_by_is_consistent_with_unitary, ) @@ -153,6 +156,7 @@ def _assert_meets_standards_helper( assert_qasm_is_consistent_with_unitary(val) assert_has_consistent_trace_distance_bound(val) assert_decompose_is_consistent_with_unitary(val, ignoring_global_phase=ignoring_global_phase) + assert_kraus_is_consistent_with_unitary(val, ignoring_global_phase=ignoring_global_phase) assert_phase_by_is_consistent_with_unitary(val) assert_pauli_expansion_is_consistent_with_unitary(val) assert_equivalent_repr( From f723659f021164680d66d14a0d6f55cba9d56078 Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Thu, 16 Sep 2021 15:27:09 +0530 Subject: [PATCH 14/33] Changes in kraus consistency tests --- cirq-core/cirq/protocols/kraus_protocol.py | 1 + cirq-core/cirq/testing/consistent_kraus.py | 14 +++++++------- cirq-core/cirq/testing/consistent_kraus_test.py | 4 +--- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index e776e83a23a..5c3ab108e20 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -177,6 +177,7 @@ def kraus( if channel_result is not NotImplemented: return tuple(channel_result) + # serial concatenation def checkEquality(x, y): if type(x) != type(y): return False diff --git a/cirq-core/cirq/testing/consistent_kraus.py b/cirq-core/cirq/testing/consistent_kraus.py index 56f07d40ac1..f32e93c2693 100644 --- a/cirq-core/cirq/testing/consistent_kraus.py +++ b/cirq-core/cirq/testing/consistent_kraus.py @@ -16,20 +16,22 @@ import numpy as np -from cirq import devices, protocols, ops, circuits +from cirq import protocols, ops from cirq.testing import lin_alg_utils def assert_kraus_is_consistent_with_unitary(val: Any, ignoring_global_phase: bool = False): """Uses `val._unitary_` to check `val._phase_by_`'s behavior.""" # pylint: disable=unused-variable - __tracebackhide__ = True + # __tracebackhide__ = True # pylint: enable=unused-variable - expected = (protocols.unitary(val, None),) + expected = protocols.unitary(val, None) if expected is None: # If there's no unitary, it's vacuously consistent. return + expected = (expected,) + if isinstance(val, ops.Operation): has_krs = protocols.kraus_protocol.has_kraus(val) krs = protocols.kraus_protocol.kraus(val, default=None) @@ -37,10 +39,8 @@ def assert_kraus_is_consistent_with_unitary(val: Any, ignoring_global_phase: boo has_krs = protocols.kraus_protocol.has_kraus(val) krs = protocols.kraus_protocol.kraus(val, default=None) - if not has_krs: - # If there's no kraus, it's vacuously consistent. - return - + # there is unitary and hence must have kraus operator + assert has_krs actual = krs if ignoring_global_phase: diff --git a/cirq-core/cirq/testing/consistent_kraus_test.py b/cirq-core/cirq/testing/consistent_kraus_test.py index 53dbba42af3..89095528a73 100644 --- a/cirq-core/cirq/testing/consistent_kraus_test.py +++ b/cirq-core/cirq/testing/consistent_kraus_test.py @@ -39,9 +39,7 @@ def test_assert_kraus_is_consistent_with_unitary(): gate = GoodGateKraus() cirq.testing.assert_kraus_is_consistent_with_unitary(gate) - cirq.testing.assert_kraus_is_consistent_with_unitary( - GoodGateKraus().on(cirq.NamedQubit('q')) - ) + cirq.testing.assert_kraus_is_consistent_with_unitary(GoodGateKraus().on(cirq.NamedQubit('q'))) with pytest.raises(AssertionError): cirq.testing.assert_kraus_is_consistent_with_unitary(BadGateKraus()) From a0128d9d0b5c6e6a3bae0acaddef1a02a1ae4e0c Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Wed, 22 Sep 2021 01:08:19 +0530 Subject: [PATCH 15/33] Changed kraus decomposition and using `cirq.unitary` and `cirq.mixture` Also added consistency checks for kraus with unitary and mixture Need to add serial concatenation to mixture --- cirq-core/cirq/protocols/kraus_protocol.py | 134 +++++++++++++----- .../cirq/protocols/kraus_protocol_test.py | 12 ++ cirq-core/cirq/testing/consistent_kraus.py | 39 ++++- 3 files changed, 142 insertions(+), 43 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index 5c3ab108e20..ccfe17308bb 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -26,7 +26,8 @@ ) from cirq._doc import doc_private from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits, decompose -from cirq.protocols.mixture_protocol import has_mixture +from cirq.protocols.mixture_protocol import mixture +from cirq.protocols.unitary_protocol import unitary from cirq.ops.raw_types import Qid, Gate from cirq.devices.line_qubit import LineQid @@ -158,25 +159,22 @@ def kraus( DeprecationWarning, ) - kraus_getter = getattr(val, '_kraus_', None) - kraus_result = NotImplemented if kraus_getter is None else kraus_getter() - if kraus_result is not NotImplemented: - return tuple(kraus_result) + channel_result = NotImplemented if channel_getter is None else channel_getter() + if channel_result is not NotImplemented: + return tuple(channel_result) - mixture_getter = getattr(val, '_mixture_', None) - mixture_result = NotImplemented if mixture_getter is None else mixture_getter() - if mixture_result is not NotImplemented and mixture_result is not None: + _, kraus_result = _strat_kraus_from_kraus(val) + if kraus_result is not None and kraus_result is not NotImplemented: + return kraus_result + + mixture_result = mixture(val, None) + if mixture_result is not None and mixture_result is not NotImplemented: return tuple(np.sqrt(p) * u for p, u in mixture_result) - unitary_getter = getattr(val, '_unitary_', None) - unitary_result = NotImplemented if unitary_getter is None else unitary_getter() - if unitary_result is not NotImplemented and unitary_result is not None: + unitary_result = unitary(val, None) + if unitary_result is not None and mixture_result is not NotImplemented: return (unitary_result,) - channel_result = NotImplemented if channel_getter is None else channel_getter() - if channel_result is not NotImplemented: - return tuple(channel_result) - # serial concatenation def checkEquality(x, y): if type(x) != type(y): @@ -227,11 +225,10 @@ def kraus_tensor(op, qubits, default): qubits = sorted(list(set(qubits))) kraus_list = list(map(lambda x: kraus_tensor(x, qubits, default), decomposed)) + assert len(decomposed) != 0 - if len(kraus_list) != 0 and not any([checkEquality(x, default) for x in kraus_list]): - + if not any([checkEquality(x, default) for x in kraus_list]): kraus_result = kraus_list[0] - for i in range(1, len(kraus_list)): kraus_result = [op_2.dot(op_1) for op_1 in kraus_result for op_2 in kraus_list[i]] @@ -240,7 +237,10 @@ def kraus_tensor(op, qubits, default): if default is not RaiseTypeErrorIfNotProvided: return default - if kraus_getter is None and unitary_getter is None and mixture_getter is None: + if not any( + getattr(val, instance, None) is not None + for instance in ['_kraus_', '_unitary_', '_mixture_'] + ): raise TypeError( "object of type '{}' has no _kraus_ or _mixture_ or " "_unitary_ method.".format(type(val)) @@ -258,7 +258,39 @@ def has_channel(val: Any, *, allow_decompose: bool = True) -> bool: def has_kraus(val: Any, *, allow_decompose: bool = True) -> bool: - """Returns whether the value has a Kraus representation. + """Determines whether the value has a Kraus representation. + + Determines whether `val` has a Kraus representation by attempting + the following strategies: + + 1. Try to use `val.has_channel_()`. + Case a) Method not present or returns `None`. + Continue to next strategy. + Case b) Method returns `True`. + Kraus. + + 2. Try to use `val._kraus_()`. + Case a) Method not present or returns `NotImplemented`. + Continue to next strategy. + Case b) Method returns a 3D array. + Kraus. + + 3. Try to use `cirq.mixture()`. + Case a) Method not present or returns `NotImplemented`. + Continue to next strategy. + Case b) Method returns a 3D array. + Kraus. + + 4. Try to use `cirq.unitary()`. + Case a) Method not present or returns `NotImplemented`. + No Kraus. + Case b) Method returns a 3D array. + Kraus. + + 5. If decomposition is allowed apply recursion and check. + + If all the above methods fail then it is assumed to have no Kraus + representation. Args: val: The value to check. @@ -271,14 +303,7 @@ def has_kraus(val: Any, *, allow_decompose: bool = True) -> bool: the result is skipped. Returns: - If `val` has a `_has_kraus_` method and its result is not - NotImplemented, that result is returned. Otherwise, if `val` has a - `_has_mixture_` method and its result is not NotImplemented, that - result is returned. Otherwise if `val` has a `_has_unitary_` method - and its results is not NotImplemented, that result is returned. - Otherwise, if the value has a _kraus_ method return if that - has a non-default value. Returns False if none of these functions - exists. + Whether or not `val` has a Kraus representation. """ channel_getter = getattr(val, '_has_channel_', None) if channel_getter is not None: @@ -287,23 +312,56 @@ def has_kraus(val: Any, *, allow_decompose: bool = True) -> bool: DeprecationWarning, ) - kraus_getter = getattr(val, '_has_kraus_', None) - result = NotImplemented if kraus_getter is None else kraus_getter() + result = NotImplemented if channel_getter is None else channel_getter() if result is not NotImplemented: return result - result = has_mixture(val, allow_decompose=False) - if result is not NotImplemented and result: - return result + for instance in ['_has_kraus_', '_has_unitary_', '_has_mixture_']: + getter = getattr(val, instance, None) + result = NotImplemented if getter is None else getter() + if result is not NotImplemented: + return result - result = NotImplemented if channel_getter is None else channel_getter() - if result is not NotImplemented: - return result + strats = [_strat_kraus_from_kraus, _strat_kraus_from_mixture, _strat_kraus_from_unitary] + + if any(strat(val)[1] is not None and strat(val)[1] is not NotImplemented for strat in strats): + return True if allow_decompose: operations, _, _ = _try_decompose_into_operations_and_qubits(val) if operations is not None: return all(has_kraus(val) for val in operations) - # No has methods, use `_kraus_` or delegates instead. - return kraus(val, None) is not None + return False + + +def _strat_kraus_from_kraus(val: Any) -> Union[Tuple[np.ndarray, ...], TDefault]: + """Attempts to compute the value's kraus via its _kraus_ method.""" + kraus_getter = getattr(val, '_kraus_', None) + kraus_result = NotImplemented if kraus_getter is None else kraus_getter() + if kraus_result is not NotImplemented: + return kraus_getter, tuple(kraus_result) + + return kraus_getter, kraus_result + + +def _strat_kraus_from_mixture(val: Any) -> Union[Tuple[np.ndarray, ...], TDefault]: + """Attempts to compute the value's kraus via its _mixture_ method.""" + + mixture_getter = getattr(val, '_mixture_', None) + mixture_result = mixture(val, None) + if mixture_result is not NotImplemented and mixture_result is not None: + return mixture_getter, tuple(np.sqrt(p) * u for p, u in mixture_result) + + return mixture_getter, mixture_result + + +def _strat_kraus_from_unitary(val: Any) -> Union[Tuple[np.ndarray, ...], TDefault]: + """Attempts to compute the value's kraus via its _unitary_ method.""" + + unitary_getter = getattr(val, '_unitary_', None) + unitary_result = unitary(val, None) + if unitary_result is not NotImplemented and unitary_result is not None: + return unitary_getter, (unitary_result,) + + return unitary_getter, unitary_result diff --git a/cirq-core/cirq/protocols/kraus_protocol_test.py b/cirq-core/cirq/protocols/kraus_protocol_test.py index 42b091a71a5..d091fa06a12 100644 --- a/cirq-core/cirq/protocols/kraus_protocol_test.py +++ b/cirq-core/cirq/protocols/kraus_protocol_test.py @@ -171,6 +171,18 @@ def test_serial_concatenation(): assert cirq.has_kraus(g) +def test_empty_decompose(): + g = cirq.PauliString({}) ** 2 + + c = (cirq.unitary(g),) + + np.allclose(cirq.kraus(g), c) + np.allclose(cirq.kraus(g, None), c) + np.allclose(cirq.kraus(g, NotImplemented), c) + np.allclose(cirq.kraus(g, (1,)), c) + np.allclose(cirq.kraus(g, LOCAL_DEFAULT), c) + + def test_channel_fallback_to_unitary(): u = np.array([[1, 0], [1, 0]]) diff --git a/cirq-core/cirq/testing/consistent_kraus.py b/cirq-core/cirq/testing/consistent_kraus.py index f32e93c2693..3fda3b975f4 100644 --- a/cirq-core/cirq/testing/consistent_kraus.py +++ b/cirq-core/cirq/testing/consistent_kraus.py @@ -21,7 +21,7 @@ def assert_kraus_is_consistent_with_unitary(val: Any, ignoring_global_phase: bool = False): - """Uses `val._unitary_` to check `val._phase_by_`'s behavior.""" + """Uses `cirq.unitary` to check `val.kraus`'s behavior.""" # pylint: disable=unused-variable # __tracebackhide__ = True # pylint: enable=unused-variable @@ -30,21 +30,50 @@ def assert_kraus_is_consistent_with_unitary(val: Any, ignoring_global_phase: boo if expected is None: # If there's no unitary, it's vacuously consistent. return - expected = (expected,) if isinstance(val, ops.Operation): has_krs = protocols.kraus_protocol.has_kraus(val) - krs = protocols.kraus_protocol.kraus(val, default=None) + krs = protocols.kraus_protocol.kraus(val, None) else: has_krs = protocols.kraus_protocol.has_kraus(val) - krs = protocols.kraus_protocol.kraus(val, default=None) + krs = protocols.kraus_protocol.kraus(val, None) # there is unitary and hence must have kraus operator assert has_krs - actual = krs + actual = krs[0] if ignoring_global_phase: lin_alg_utils.assert_allclose_up_to_global_phase(actual, expected, atol=1e-8) else: # coverage: ignore np.testing.assert_allclose(actual, expected, atol=1e-8) + + +def assert_kraus_is_consistent_with_mixture(val: Any, ignoring_global_phase: bool = False): + """Uses `cirq.mixture` to check `cirq.kraus`'s behavior.""" + # pylint: disable=unused-variable + # __tracebackhide__ = True + # pylint: enable=unused-variable + + expected = protocols.mixture(val, None) + if expected is None: + # If there's no mixture, it's vacuously consistent. + return + + if isinstance(val, ops.Operation): + has_krs = protocols.kraus_protocol.has_kraus(val) + krs = protocols.kraus_protocol.kraus(val, None) + else: + has_krs = protocols.kraus_protocol.has_kraus(val) + krs = protocols.kraus_protocol.kraus(val, None) + + # there is mixture and hence must have kraus operator + assert has_krs + expected = np.array([np.sqrt(p) * x for p, x in expected]) + + if ignoring_global_phase: + lin_alg_utils.assert_allclose_up_to_global_phase(actual, expected, atol=1e-8) + else: + # coverage: ignore + np.testing.assert_allclose(actual, expected, atol=1e-8) + From c71954f984d5d011658eab6bbd8ab67b23084c89 Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Wed, 22 Sep 2021 03:16:13 +0530 Subject: [PATCH 16/33] Fixed `has_kraus` --- cirq-core/cirq/protocols/kraus_protocol.py | 8 ++++++-- cirq-core/cirq/protocols/kraus_protocol_test.py | 6 ++++-- cirq-core/cirq/testing/consistent_kraus.py | 2 +- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index ccfe17308bb..b3f7dfbc641 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -312,15 +312,19 @@ def has_kraus(val: Any, *, allow_decompose: bool = True) -> bool: DeprecationWarning, ) + results = [] result = NotImplemented if channel_getter is None else channel_getter() if result is not NotImplemented: - return result + results.append(result) for instance in ['_has_kraus_', '_has_unitary_', '_has_mixture_']: getter = getattr(val, instance, None) result = NotImplemented if getter is None else getter() if result is not NotImplemented: - return result + results.append(result) + + if any(results): + return True strats = [_strat_kraus_from_kraus, _strat_kraus_from_mixture, _strat_kraus_from_unitary] diff --git a/cirq-core/cirq/protocols/kraus_protocol_test.py b/cirq-core/cirq/protocols/kraus_protocol_test.py index d091fa06a12..a7570b56ce8 100644 --- a/cirq-core/cirq/protocols/kraus_protocol_test.py +++ b/cirq-core/cirq/protocols/kraus_protocol_test.py @@ -159,8 +159,8 @@ def _mixture_(self) -> Iterable[Tuple[float, np.ndarray]]: def test_serial_concatenation(): g = cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=-0.5, z_exponent=1) - c = (cirq.unitary(g),) + assert cirq.has_kraus(g) np.allclose(cirq.kraus(g), c) np.allclose(cirq.kraus(g, None), c) @@ -173,9 +173,10 @@ def test_serial_concatenation(): def test_empty_decompose(): g = cirq.PauliString({}) ** 2 - c = (cirq.unitary(g),) + assert cirq.has_kraus(g) + np.allclose(cirq.kraus(g), c) np.allclose(cirq.kraus(g, None), c) np.allclose(cirq.kraus(g, NotImplemented), c) @@ -190,6 +191,7 @@ class ReturnsUnitary: def _unitary_(self) -> np.ndarray: return u + assert cirq.has_kraus(ReturnsUnitary()) np.testing.assert_equal(cirq.kraus(ReturnsUnitary()), (u,)) np.testing.assert_equal(cirq.kraus(ReturnsUnitary(), None), (u,)) np.testing.assert_equal(cirq.kraus(ReturnsUnitary(), NotImplemented), (u,)) diff --git a/cirq-core/cirq/testing/consistent_kraus.py b/cirq-core/cirq/testing/consistent_kraus.py index 3fda3b975f4..ed9a3c39dd5 100644 --- a/cirq-core/cirq/testing/consistent_kraus.py +++ b/cirq-core/cirq/testing/consistent_kraus.py @@ -69,6 +69,7 @@ def assert_kraus_is_consistent_with_mixture(val: Any, ignoring_global_phase: boo # there is mixture and hence must have kraus operator assert has_krs + actual = krs expected = np.array([np.sqrt(p) * x for p, x in expected]) if ignoring_global_phase: @@ -76,4 +77,3 @@ def assert_kraus_is_consistent_with_mixture(val: Any, ignoring_global_phase: boo else: # coverage: ignore np.testing.assert_allclose(actual, expected, atol=1e-8) - From ee57307b4335bda82958fc5eab2fcbb35baaeb29 Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Thu, 23 Sep 2021 06:15:27 +0530 Subject: [PATCH 17/33] Serial concatenation in Mixture pending. Kraus needs additional tests --- cirq-core/cirq/protocols/kraus_protocol.py | 76 +++++++------------ .../cirq/protocols/kraus_protocol_test.py | 24 ++++-- cirq-core/cirq/protocols/mixture_protocol.py | 10 +-- 3 files changed, 51 insertions(+), 59 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index b3f7dfbc641..fc70cc38d3c 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -26,8 +26,9 @@ ) from cirq._doc import doc_private from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits, decompose -from cirq.protocols.mixture_protocol import mixture +from cirq.protocols.mixture_protocol import mixture, has_mixture from cirq.protocols.unitary_protocol import unitary +from cirq.protocols.has_unitary_protocol import has_unitary from cirq.ops.raw_types import Qid, Gate from cirq.devices.line_qubit import LineQid @@ -172,24 +173,12 @@ def kraus( return tuple(np.sqrt(p) * u for p, u in mixture_result) unitary_result = unitary(val, None) - if unitary_result is not None and mixture_result is not NotImplemented: + if unitary_result is not None and unitary_result is not NotImplemented: return (unitary_result,) - # serial concatenation - def checkEquality(x, y): - if type(x) != type(y): - return False - if type(x) not in [list, tuple, np.ndarray]: - return x == y - if type(x) == np.ndarray: - return x.shape == y.shape and np.all(x == y) - if len(x) != len(y): - return False - return all([checkEquality(a, b) for a, b in zip(x, y)]) - def kraus_tensor(op, qubits, default): kraus_list = kraus(op, default) - if checkEquality(kraus_list, default): + if _check_equality(kraus_list, default): return default val = None @@ -227,7 +216,7 @@ def kraus_tensor(op, qubits, default): kraus_list = list(map(lambda x: kraus_tensor(x, qubits, default), decomposed)) assert len(decomposed) != 0 - if not any([checkEquality(x, default) for x in kraus_list]): + if not any([_check_equality(x, default) for x in kraus_list]): kraus_result = kraus_list[0] for i in range(1, len(kraus_list)): kraus_result = [op_2.dot(op_1) for op_1 in kraus_result for op_2 in kraus_list[i]] @@ -312,22 +301,21 @@ def has_kraus(val: Any, *, allow_decompose: bool = True) -> bool: DeprecationWarning, ) - results = [] result = NotImplemented if channel_getter is None else channel_getter() - if result is not NotImplemented: - results.append(result) + if result is not NotImplemented and result: + return True - for instance in ['_has_kraus_', '_has_unitary_', '_has_mixture_']: - getter = getattr(val, instance, None) - result = NotImplemented if getter is None else getter() - if result is not NotImplemented: - results.append(result) + for instance in [has_unitary, has_mixture]: + result = instance(val) + if result is not NotImplemented and result: + return True - if any(results): + getter = getattr(val, '_has_kraus_', None) + result = NotImplemented if getter is None else getter() + if result is not NotImplemented and result: return True - strats = [_strat_kraus_from_kraus, _strat_kraus_from_mixture, _strat_kraus_from_unitary] - + strats = [_strat_kraus_from_kraus] if any(strat(val)[1] is not None and strat(val)[1] is not NotImplemented for strat in strats): return True @@ -339,6 +327,18 @@ def has_kraus(val: Any, *, allow_decompose: bool = True) -> bool: return False +def _check_equality(x, y): + if type(x) != type(y): + return False + if type(x) not in [list, tuple, np.ndarray]: + return x == y + if type(x) == np.ndarray: + return x.shape == y.shape and np.all(x == y) + if len(x) != len(y): + return False + return all([_check_equality(a, b) for a, b in zip(x, y)]) + + def _strat_kraus_from_kraus(val: Any) -> Union[Tuple[np.ndarray, ...], TDefault]: """Attempts to compute the value's kraus via its _kraus_ method.""" kraus_getter = getattr(val, '_kraus_', None) @@ -347,25 +347,3 @@ def _strat_kraus_from_kraus(val: Any) -> Union[Tuple[np.ndarray, ...], TDefault] return kraus_getter, tuple(kraus_result) return kraus_getter, kraus_result - - -def _strat_kraus_from_mixture(val: Any) -> Union[Tuple[np.ndarray, ...], TDefault]: - """Attempts to compute the value's kraus via its _mixture_ method.""" - - mixture_getter = getattr(val, '_mixture_', None) - mixture_result = mixture(val, None) - if mixture_result is not NotImplemented and mixture_result is not None: - return mixture_getter, tuple(np.sqrt(p) * u for p, u in mixture_result) - - return mixture_getter, mixture_result - - -def _strat_kraus_from_unitary(val: Any) -> Union[Tuple[np.ndarray, ...], TDefault]: - """Attempts to compute the value's kraus via its _unitary_ method.""" - - unitary_getter = getattr(val, '_unitary_', None) - unitary_result = unitary(val, None) - if unitary_result is not NotImplemented and unitary_result is not None: - return unitary_getter, (unitary_result,) - - return unitary_getter, unitary_result diff --git a/cirq-core/cirq/protocols/kraus_protocol_test.py b/cirq-core/cirq/protocols/kraus_protocol_test.py index a7570b56ce8..91ec02618b2 100644 --- a/cirq-core/cirq/protocols/kraus_protocol_test.py +++ b/cirq-core/cirq/protocols/kraus_protocol_test.py @@ -157,10 +157,23 @@ def _mixture_(self) -> Iterable[Tuple[float, np.ndarray]]: assert cirq.has_kraus(ReturnsMixture()) -def test_serial_concatenation(): - g = cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=-0.5, z_exponent=1) - c = (cirq.unitary(g),) - assert cirq.has_kraus(g) +def test_serial_concatenation_circuit(): + q1 = cirq.GridQubit(1, 1) + q2 = cirq.GridQubit(1, 2) + + class onlyDecompose: + def _decompose_(self): + circ = cirq.Circuit([cirq.Y.on(q1), cirq.X.on(q2)]) + return cirq.decompose(circ) + + def _unitary_(self): + return None + + def _mixture_(self): + return None + + g = onlyDecompose() + c = (cirq.unitary(cirq.Circuit([cirq.Y.on(q1), cirq.X.on(q2)])),) np.allclose(cirq.kraus(g), c) np.allclose(cirq.kraus(g, None), c) @@ -233,4 +246,5 @@ def test_has_kraus(cls): def test_has_channel_when_decomposed(decomposed_cls): op = HasKrausWhenDecomposed(decomposed_cls).on(cirq.NamedQubit('test')) assert cirq.has_kraus(op) - assert not cirq.has_kraus(op, allow_decompose=False) + if not cirq.has_unitary(op) and not cirq.has_mixture(op): + assert not cirq.has_kraus(op, allow_decompose=False) diff --git a/cirq-core/cirq/protocols/mixture_protocol.py b/cirq-core/cirq/protocols/mixture_protocol.py index 05adeff5796..2ea4fffa31d 100644 --- a/cirq-core/cirq/protocols/mixture_protocol.py +++ b/cirq-core/cirq/protocols/mixture_protocol.py @@ -20,6 +20,7 @@ from cirq._doc import doc_private from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits from cirq.protocols.has_unitary_protocol import has_unitary +from cirq.protocols.unitary_protocol import unitary from cirq.type_workarounds import NotImplementedType # This is a special indicator value used by the inverse method to determine @@ -89,15 +90,14 @@ def mixture( if result is not NotImplemented: return result - unitary_getter = getattr(val, '_unitary_', None) - result = NotImplemented if unitary_getter is None else unitary_getter() - if result is not NotImplemented: - return ((1.0, result),) + unitary_result = unitary(val, None) + if unitary_result is not None and unitary_result is not NotImplemented: + return ((1.0, unitary_result),) if default is not RaiseTypeErrorIfNotProvided: return default - if mixture_getter is None and unitary_getter is None: + if not any(getattr(val, instance, None) is not None for instance in ['_unitary_', '_mixture_']): raise TypeError(f"object of type '{type(val)}' has no _mixture_ or _unitary_ method.") raise TypeError( From 0ce91eb93c1047bcfe907d853304185300b43611 Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Fri, 24 Sep 2021 21:29:08 +0530 Subject: [PATCH 18/33] Added tests for serial concatenation. Mixture serial concatenation pending --- cirq-core/cirq/protocols/kraus_protocol.py | 42 ++++------- .../cirq/protocols/kraus_protocol_test.py | 74 +++++++++++++------ cirq-core/cirq/testing/__init__.py | 1 + cirq-core/cirq/testing/consistent_kraus.py | 18 ++--- .../cirq/testing/consistent_kraus_test.py | 25 ++++++- .../cirq/testing/consistent_protocols.py | 12 +-- 6 files changed, 105 insertions(+), 67 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index fc70cc38d3c..02f08520457 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -30,8 +30,7 @@ from cirq.protocols.unitary_protocol import unitary from cirq.protocols.has_unitary_protocol import has_unitary -from cirq.ops.raw_types import Qid, Gate -from cirq.devices.line_qubit import LineQid +from cirq.ops.raw_types import Qid from cirq.type_workarounds import NotImplementedType @@ -44,10 +43,10 @@ RaiseTypeErrorIfNotProvided = (np.array([]),) -TDefault = TypeVar('TDefault') +TDefault = TypeVar("TDefault") -@deprecated_class(deadline='v0.13', fix='use cirq.SupportsKraus instead') +@deprecated_class(deadline="v0.13", fix="use cirq.SupportsKraus instead") class SupportsChannel(Protocol): pass @@ -107,7 +106,7 @@ def _has_kraus_(self) -> bool: """ -@deprecated(deadline='v0.13', fix='use cirq.kraus instead') +@deprecated(deadline="v0.13", fix="use cirq.kraus instead") def channel( val: Any, default: Any = RaiseTypeErrorIfNotProvided ) -> Union[Tuple[np.ndarray, ...], TDefault]: @@ -153,10 +152,10 @@ def kraus( method returned NotImplemented) and also no default value was specified. """ - channel_getter = getattr(val, '_channel_', None) + channel_getter = getattr(val, "_channel_", None) if channel_getter is not None: warnings.warn( - '_channel_ is deprecated and will be removed in cirq 0.13, rename to _kraus_', + "_channel_ is deprecated and will be removed in cirq 0.13, rename to _kraus_", DeprecationWarning, ) @@ -200,22 +199,15 @@ def kraus_tensor(op, qubits, default): return val - if isinstance(val, Gate): - operation = val.on(*LineQid.for_gate(val)) - else: - operation = val + decomposed = decompose(val) - decomposed = decompose(operation) - - if decomposed != [operation]: + if decomposed != [val]: qubits: List[Qid] = [] for x in decomposed: qubits.extend(x.qubits) qubits = sorted(list(set(qubits))) kraus_list = list(map(lambda x: kraus_tensor(x, qubits, default), decomposed)) - assert len(decomposed) != 0 - if not any([_check_equality(x, default) for x in kraus_list]): kraus_result = kraus_list[0] for i in range(1, len(kraus_list)): @@ -228,7 +220,7 @@ def kraus_tensor(op, qubits, default): if not any( getattr(val, instance, None) is not None - for instance in ['_kraus_', '_unitary_', '_mixture_'] + for instance in ["_kraus_", "_unitary_", "_mixture_"] ): raise TypeError( "object of type '{}' has no _kraus_ or _mixture_ or " @@ -241,7 +233,7 @@ def kraus_tensor(op, qubits, default): ) -@deprecated(deadline='v0.13', fix='use cirq.has_kraus instead') +@deprecated(deadline="v0.13", fix="use cirq.has_kraus instead") def has_channel(val: Any, *, allow_decompose: bool = True) -> bool: return has_kraus(val, allow_decompose=allow_decompose) @@ -294,10 +286,10 @@ def has_kraus(val: Any, *, allow_decompose: bool = True) -> bool: Returns: Whether or not `val` has a Kraus representation. """ - channel_getter = getattr(val, '_has_channel_', None) + channel_getter = getattr(val, "_has_channel_", None) if channel_getter is not None: warnings.warn( - '_has_channel_ is deprecated and will be removed in cirq 0.13, rename to _has_kraus_', + "_has_channel_ is deprecated and will be removed in cirq 0.13, rename to _has_kraus_", DeprecationWarning, ) @@ -310,7 +302,7 @@ def has_kraus(val: Any, *, allow_decompose: bool = True) -> bool: if result is not NotImplemented and result: return True - getter = getattr(val, '_has_kraus_', None) + getter = getattr(val, "_has_kraus_", None) result = NotImplemented if getter is None else getter() if result is not NotImplemented and result: return True @@ -334,14 +326,12 @@ def _check_equality(x, y): return x == y if type(x) == np.ndarray: return x.shape == y.shape and np.all(x == y) - if len(x) != len(y): - return False - return all([_check_equality(a, b) for a, b in zip(x, y)]) + return False if len(x) != len(y) else all([_check_equality(a, b) for a, b in zip(x, y)]) -def _strat_kraus_from_kraus(val: Any) -> Union[Tuple[np.ndarray, ...], TDefault]: +def _strat_kraus_from_kraus(val: Any): """Attempts to compute the value's kraus via its _kraus_ method.""" - kraus_getter = getattr(val, '_kraus_', None) + kraus_getter = getattr(val, "_kraus_", None) kraus_result = NotImplemented if kraus_getter is None else kraus_getter() if kraus_result is not NotImplemented: return kraus_getter, tuple(kraus_result) diff --git a/cirq-core/cirq/protocols/kraus_protocol_test.py b/cirq-core/cirq/protocols/kraus_protocol_test.py index 91ec02618b2..a75dbb8873b 100644 --- a/cirq-core/cirq/protocols/kraus_protocol_test.py +++ b/cirq-core/cirq/protocols/kraus_protocol_test.py @@ -29,7 +29,7 @@ def test_channel_no_methods(): class NoMethod: pass - with pytest.raises(TypeError, match='no _kraus_ or _mixture_ or _unitary_ method'): + with pytest.raises(TypeError, match="no _kraus_ or _mixture_ or _unitary_ method"): _ = cirq.kraus(NoMethod()) assert cirq.kraus(NoMethod(), None) is None @@ -41,7 +41,7 @@ class NoMethod: def assert_not_implemented(val): - with pytest.raises(TypeError, match='returned NotImplemented'): + with pytest.raises(TypeError, match="returned NotImplemented"): _ = cirq.kraus(val) assert cirq.kraus(val, None) is None @@ -53,7 +53,7 @@ def assert_not_implemented(val): def test_supports_channel_class_is_deprecated(): - with cirq.testing.assert_deprecated(deadline='v0.13'): + with cirq.testing.assert_deprecated(deadline="v0.13"): class SomeChannel(cirq.SupportsChannel): pass @@ -62,12 +62,12 @@ class SomeChannel(cirq.SupportsChannel): def test_channel_protocol_is_deprecated(): - with cirq.testing.assert_deprecated(deadline='v0.13'): + with cirq.testing.assert_deprecated(deadline="v0.13"): assert np.allclose(cirq.channel(cirq.X), cirq.kraus(cirq.X)) def test_has_channel_protocol_is_deprecated(): - with cirq.testing.assert_deprecated(deadline='v0.13'): + with cirq.testing.assert_deprecated(deadline="v0.13"): assert cirq.has_channel(cirq.depolarize(0.1)) == cirq.has_kraus(cirq.depolarize(0.1)) @@ -88,9 +88,9 @@ def _channel_(self): return (np.eye(2),) val = UsesDeprecatedChannelMethod() - with pytest.warns(DeprecationWarning, match='_has_kraus_'): + with pytest.warns(DeprecationWarning, match="_has_kraus_"): assert cirq.has_kraus(val) - with pytest.warns(DeprecationWarning, match='_kraus_'): + with pytest.warns(DeprecationWarning, match="_kraus_"): ks = cirq.kraus(val) assert len(ks) == 1 assert np.all(ks[0] == np.eye(2)) @@ -109,7 +109,7 @@ class ReturnsNotImplemented: def _unitary_(self): return NotImplemented - with pytest.raises(TypeError, match='returned NotImplemented'): + with pytest.raises(TypeError, match="returned NotImplemented"): _ = cirq.kraus(ReturnsNotImplemented()) assert cirq.kraus(ReturnsNotImplemented(), None) is None assert cirq.kraus(ReturnsNotImplemented(), NotImplemented) is NotImplemented @@ -157,6 +157,38 @@ def _mixture_(self) -> Iterable[Tuple[float, np.ndarray]]: assert cirq.has_kraus(ReturnsMixture()) +def test_serial_concatenation_default(): + q1 = cirq.GridQubit(1, 1) + + class defaultGate(cirq.Gate): + def num_qubits(self): + return 1 + + def _kraus_(self): + return NotImplemented + + def _unitary_(self): + return None + + def _mixture_(self): + return None + + class onlyDecompose: + def _decompose_(self): + return [cirq.Y.on(q1), defaultGate().on(q1)] + + def _unitary_(self): + return None + + def _mixture_(self): + return None + + with pytest.raises(TypeError, match="returned NotImplemented"): + _ = cirq.kraus(onlyDecompose()) + assert cirq.kraus(onlyDecompose(), 0) == 0 + assert not cirq.has_kraus(onlyDecompose()) + + def test_serial_concatenation_circuit(): q1 = cirq.GridQubit(1, 1) q2 = cirq.GridQubit(1, 2) @@ -175,11 +207,11 @@ def _mixture_(self): g = onlyDecompose() c = (cirq.unitary(cirq.Circuit([cirq.Y.on(q1), cirq.X.on(q2)])),) - np.allclose(cirq.kraus(g), c) - np.allclose(cirq.kraus(g, None), c) - np.allclose(cirq.kraus(g, NotImplemented), c) - np.allclose(cirq.kraus(g, (1,)), c) - np.allclose(cirq.kraus(g, LOCAL_DEFAULT), c) + np.testing.assert_equal(cirq.kraus(g), c) + np.testing.assert_equal(cirq.kraus(g, None), c) + np.testing.assert_equal(cirq.kraus(g, NotImplemented), c) + np.testing.assert_equal(cirq.kraus(g, (1,)), c) + np.testing.assert_equal(cirq.kraus(g, LOCAL_DEFAULT), c) assert cirq.has_kraus(g) @@ -190,11 +222,11 @@ def test_empty_decompose(): assert cirq.has_kraus(g) - np.allclose(cirq.kraus(g), c) - np.allclose(cirq.kraus(g, None), c) - np.allclose(cirq.kraus(g, NotImplemented), c) - np.allclose(cirq.kraus(g, (1,)), c) - np.allclose(cirq.kraus(g, LOCAL_DEFAULT), c) + np.testing.assert_equal(cirq.kraus(g), c) + np.testing.assert_equal(cirq.kraus(g, None), c) + np.testing.assert_equal(cirq.kraus(g, NotImplemented), c) + np.testing.assert_equal(cirq.kraus(g, (1,)), c) + np.testing.assert_equal(cirq.kraus(g, LOCAL_DEFAULT), c) def test_channel_fallback_to_unitary(): @@ -237,14 +269,14 @@ def _decompose_(self, qubits): return [self.decomposed_cls().on(q) for q in qubits] -@pytest.mark.parametrize('cls', [HasKraus, HasMixture, HasUnitary]) +@pytest.mark.parametrize("cls", [HasKraus, HasMixture, HasUnitary]) def test_has_kraus(cls): assert cirq.has_kraus(cls()) -@pytest.mark.parametrize('decomposed_cls', [HasKraus, HasMixture, HasUnitary]) +@pytest.mark.parametrize("decomposed_cls", [HasKraus, HasMixture, HasUnitary]) def test_has_channel_when_decomposed(decomposed_cls): - op = HasKrausWhenDecomposed(decomposed_cls).on(cirq.NamedQubit('test')) + op = HasKrausWhenDecomposed(decomposed_cls).on(cirq.NamedQubit("test")) assert cirq.has_kraus(op) if not cirq.has_unitary(op) and not cirq.has_mixture(op): assert not cirq.has_kraus(op, allow_decompose=False) diff --git a/cirq-core/cirq/testing/__init__.py b/cirq-core/cirq/testing/__init__.py index 9c7969faecd..f4068efccdb 100644 --- a/cirq-core/cirq/testing/__init__.py +++ b/cirq-core/cirq/testing/__init__.py @@ -42,6 +42,7 @@ from cirq.testing.consistent_kraus import ( assert_kraus_is_consistent_with_unitary, + assert_kraus_is_consistent_with_mixture, ) from cirq.testing.consistent_pauli_expansion import ( diff --git a/cirq-core/cirq/testing/consistent_kraus.py b/cirq-core/cirq/testing/consistent_kraus.py index ed9a3c39dd5..84f36ab966b 100644 --- a/cirq-core/cirq/testing/consistent_kraus.py +++ b/cirq-core/cirq/testing/consistent_kraus.py @@ -16,7 +16,7 @@ import numpy as np -from cirq import protocols, ops +from cirq import protocols from cirq.testing import lin_alg_utils @@ -31,12 +31,8 @@ def assert_kraus_is_consistent_with_unitary(val: Any, ignoring_global_phase: boo # If there's no unitary, it's vacuously consistent. return - if isinstance(val, ops.Operation): - has_krs = protocols.kraus_protocol.has_kraus(val) - krs = protocols.kraus_protocol.kraus(val, None) - else: - has_krs = protocols.kraus_protocol.has_kraus(val) - krs = protocols.kraus_protocol.kraus(val, None) + has_krs = protocols.kraus_protocol.has_kraus(val) + krs = protocols.kraus_protocol.kraus(val, None) # there is unitary and hence must have kraus operator assert has_krs @@ -60,12 +56,8 @@ def assert_kraus_is_consistent_with_mixture(val: Any, ignoring_global_phase: boo # If there's no mixture, it's vacuously consistent. return - if isinstance(val, ops.Operation): - has_krs = protocols.kraus_protocol.has_kraus(val) - krs = protocols.kraus_protocol.kraus(val, None) - else: - has_krs = protocols.kraus_protocol.has_kraus(val) - krs = protocols.kraus_protocol.kraus(val, None) + has_krs = protocols.kraus_protocol.has_kraus(val) + krs = protocols.kraus_protocol.kraus(val, None) # there is mixture and hence must have kraus operator assert has_krs diff --git a/cirq-core/cirq/testing/consistent_kraus_test.py b/cirq-core/cirq/testing/consistent_kraus_test.py index 89095528a73..8a5b9d6bd5e 100644 --- a/cirq-core/cirq/testing/consistent_kraus_test.py +++ b/cirq-core/cirq/testing/consistent_kraus_test.py @@ -26,6 +26,9 @@ def _kraus_(self, default=None): def _unitary_(self): return np.array([[0, 1], [1, 0]]) + def _mixture_(self): + return ((1, np.array([[0, 1], [1, 0]])),) + class BadGateKraus(cirq.SingleQubitGate): def _kraus_(self, default=None): @@ -34,17 +37,35 @@ def _kraus_(self, default=None): def _unitary_(self): return np.array([[0, 1], [0, 1]]) + def _mixture_(self): + return ((1, np.array([[0, 1], [0, 1]])),) + def test_assert_kraus_is_consistent_with_unitary(): gate = GoodGateKraus() cirq.testing.assert_kraus_is_consistent_with_unitary(gate) - cirq.testing.assert_kraus_is_consistent_with_unitary(GoodGateKraus().on(cirq.NamedQubit('q'))) + cirq.testing.assert_kraus_is_consistent_with_unitary(GoodGateKraus().on(cirq.NamedQubit("q"))) with pytest.raises(AssertionError): cirq.testing.assert_kraus_is_consistent_with_unitary(BadGateKraus()) with pytest.raises(AssertionError): cirq.testing.assert_kraus_is_consistent_with_unitary( - BadGateKraus().on(cirq.NamedQubit('q')) + BadGateKraus().on(cirq.NamedQubit("q")) + ) + + +def test_assert_kraus_is_consistent_with_mixture(): + gate = GoodGateKraus() + cirq.testing.assert_kraus_is_consistent_with_mixture(gate) + + cirq.testing.assert_kraus_is_consistent_with_mixture(GoodGateKraus().on(cirq.NamedQubit("q"))) + + with pytest.raises(AssertionError): + cirq.testing.assert_kraus_is_consistent_with_mixture(BadGateKraus()) + + with pytest.raises(AssertionError): + cirq.testing.assert_kraus_is_consistent_with_mixture( + BadGateKraus().on(cirq.NamedQubit("q")) ) diff --git a/cirq-core/cirq/testing/consistent_protocols.py b/cirq-core/cirq/testing/consistent_protocols.py index fb916d0be7e..cb1353d2846 100644 --- a/cirq-core/cirq/testing/consistent_protocols.py +++ b/cirq-core/cirq/testing/consistent_protocols.py @@ -29,6 +29,7 @@ ) from cirq.testing.consistent_kraus import ( assert_kraus_is_consistent_with_unitary, + assert_kraus_is_consistent_with_mixture, ) from cirq.testing.consistent_phase_by import ( assert_phase_by_is_consistent_with_unitary, @@ -51,10 +52,10 @@ def assert_implements_consistent_protocols( val: Any, *, - exponents: Sequence[Any] = (0, 1, -1, 0.25, -0.5, 0.1, sympy.Symbol('s')), + exponents: Sequence[Any] = (0, 1, -1, 0.25, -0.5, 0.1, sympy.Symbol("s")), qubit_count: Optional[int] = None, ignoring_global_phase: bool = False, - setup_code: str = 'import cirq\nimport numpy as np\nimport sympy', + setup_code: str = "import cirq\nimport numpy as np\nimport sympy", global_vals: Optional[Dict[str, Any]] = None, local_vals: Optional[Dict[str, Any]] = None, ) -> None: @@ -85,11 +86,11 @@ def assert_implements_consistent_protocols( def assert_eigengate_implements_consistent_protocols( eigen_gate_type: Type[ops.EigenGate], *, - exponents: Sequence[value.TParamVal] = (0, 0.5, 1, -1, 0.25, -0.5, 0.1, sympy.Symbol('s')), + exponents: Sequence[value.TParamVal] = (0, 0.5, 1, -1, 0.25, -0.5, 0.1, sympy.Symbol("s")), global_shifts: Sequence[float] = (0, -0.5, 0.1), qubit_count: Optional[int] = None, ignoring_global_phase: bool = False, - setup_code: str = 'import cirq\nimport numpy as np\nimport sympy', + setup_code: str = "import cirq\nimport numpy as np\nimport sympy", global_vals: Optional[Dict[str, Any]] = None, local_vals: Optional[Dict[str, Any]] = None, ) -> None: @@ -157,6 +158,7 @@ def _assert_meets_standards_helper( assert_has_consistent_trace_distance_bound(val) assert_decompose_is_consistent_with_unitary(val, ignoring_global_phase=ignoring_global_phase) assert_kraus_is_consistent_with_unitary(val, ignoring_global_phase=ignoring_global_phase) + assert_kraus_is_consistent_with_mixture(val, ignoring_global_phase=ignoring_global_phase) assert_phase_by_is_consistent_with_unitary(val) assert_pauli_expansion_is_consistent_with_unitary(val) assert_equivalent_repr( @@ -170,7 +172,7 @@ def assert_commutes_magic_method_consistent_with_unitaries( *vals: Sequence[Any], atol: Union[int, float] = 1e-8 ) -> None: if any(isinstance(val, ops.Operation) for val in vals): - raise TypeError('`_commutes_` need not be consistent with unitaries for `Operation`.') + raise TypeError("`_commutes_` need not be consistent with unitaries for `Operation`.") unitaries = [protocols.unitary(val, None) for val in vals] pairs = itertools.permutations(zip(vals, unitaries), 2) for (left_val, left_unitary), (right_val, right_unitary) in pairs: From 694a6dc4a6ab449197a2b94c9e72357f32b0ac28 Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Fri, 24 Sep 2021 23:01:53 +0530 Subject: [PATCH 19/33] Fixed consistency issue in mixture --- cirq-core/cirq/testing/consistent_kraus.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/testing/consistent_kraus.py b/cirq-core/cirq/testing/consistent_kraus.py index 84f36ab966b..225c7209763 100644 --- a/cirq-core/cirq/testing/consistent_kraus.py +++ b/cirq-core/cirq/testing/consistent_kraus.py @@ -57,12 +57,12 @@ def assert_kraus_is_consistent_with_mixture(val: Any, ignoring_global_phase: boo return has_krs = protocols.kraus_protocol.has_kraus(val) - krs = protocols.kraus_protocol.kraus(val, None) + krs = np.array(protocols.kraus_protocol.kraus(val, None)) # there is mixture and hence must have kraus operator assert has_krs actual = krs - expected = np.array([np.sqrt(p) * x for p, x in expected]) + expected = np.array([np.sqrt(p) * u for p, u in expected]) if ignoring_global_phase: lin_alg_utils.assert_allclose_up_to_global_phase(actual, expected, atol=1e-8) From 9c1110240db1cd6d8120bdf75031b310f711ab85 Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Sat, 25 Sep 2021 02:29:26 +0530 Subject: [PATCH 20/33] Added serial concatenation to mixture Need to fix tests and check --- cirq-core/cirq/protocols/kraus_protocol.py | 51 ++++---- cirq-core/cirq/protocols/mixture_protocol.py | 114 +++++++++++++++--- .../cirq/protocols/mixture_protocol_test.py | 87 ++++++++++--- 3 files changed, 196 insertions(+), 56 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index 02f08520457..4149cf769c8 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -175,30 +175,6 @@ def kraus( if unitary_result is not None and unitary_result is not NotImplemented: return (unitary_result,) - def kraus_tensor(op, qubits, default): - kraus_list = kraus(op, default) - if _check_equality(kraus_list, default): - return default - - val = None - op_q = op.qubits - found = False - for i in range(len(qubits)): - if qubits[i] in op_q: - if not found: - found = True - if val is None: - val = kraus_list - else: - val = tuple([np.kron(x, y) for x in val for y in kraus_list]) - - elif val is None: - val = (np.identity(2),) - else: - val = tuple([np.kron(x, np.identity(2)) for x in val]) - - return val - decomposed = decompose(val) if decomposed != [val]: @@ -207,7 +183,7 @@ def kraus_tensor(op, qubits, default): qubits.extend(x.qubits) qubits = sorted(list(set(qubits))) - kraus_list = list(map(lambda x: kraus_tensor(x, qubits, default), decomposed)) + kraus_list = list(map(lambda x: _kraus_tensor(x, qubits, default), decomposed)) if not any([_check_equality(x, default) for x in kraus_list]): kraus_result = kraus_list[0] for i in range(1, len(kraus_list)): @@ -329,6 +305,31 @@ def _check_equality(x, y): return False if len(x) != len(y) else all([_check_equality(a, b) for a, b in zip(x, y)]) +def _kraus_tensor(op, qubits, default): + kraus_list = kraus(op, default) + if _check_equality(kraus_list, default): + return default + + val = None + op_q = op.qubits + found = False + for i in range(len(qubits)): + if qubits[i] in op_q: + if not found: + found = True + if val is None: + val = kraus_list + else: + val = tuple([np.kron(x, y) for x in val for y in kraus_list]) + + elif val is None: + val = (np.identity(2),) + else: + val = tuple([np.kron(x, np.identity(2)) for x in val]) + + return val + + def _strat_kraus_from_kraus(val: Any): """Attempts to compute the value's kraus via its _kraus_ method.""" kraus_getter = getattr(val, "_kraus_", None) diff --git a/cirq-core/cirq/protocols/mixture_protocol.py b/cirq-core/cirq/protocols/mixture_protocol.py index 2ea4fffa31d..e5f4c720407 100644 --- a/cirq-core/cirq/protocols/mixture_protocol.py +++ b/cirq-core/cirq/protocols/mixture_protocol.py @@ -12,16 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. """Protocol for objects that are mixtures (probabilistic combinations).""" -from typing import Any, Sequence, Tuple, Union +from typing import Any, Sequence, Tuple, Union, List import numpy as np from typing_extensions import Protocol from cirq._doc import doc_private -from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits +from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits, decompose from cirq.protocols.has_unitary_protocol import has_unitary from cirq.protocols.unitary_protocol import unitary from cirq.type_workarounds import NotImplementedType +from cirq.devices.line_qubit import LineQid + +from cirq.ops.raw_types import Qid, Gate # This is a special indicator value used by the inverse method to determine # whether or not the caller provided a 'default' argument. @@ -85,19 +88,43 @@ def mixture( with that probability in the mixture. The probabilities will sum to 1.0. """ - mixture_getter = getattr(val, '_mixture_', None) - result = NotImplemented if mixture_getter is None else mixture_getter() - if result is not NotImplemented: - return result + mixture_result = _strat_mixture_from_mixture(val) + if mixture_result is not None and mixture_result is not NotImplemented: + return mixture_result unitary_result = unitary(val, None) if unitary_result is not None and unitary_result is not NotImplemented: return ((1.0, unitary_result),) + if isinstance(val, Gate): + val = val.on(*LineQid.for_gate(val)) + else: + val = val + # serial concatenation + decomposed = decompose(val) + + if decomposed != [val]: + qubits: List[Qid] = [] + for x in decomposed: + qubits.extend(x.qubits) + + qubits = sorted(list(set(qubits))) + mixture_list = list(map(lambda x: _mixture_tensor(x, qubits, default), decomposed)) + if not any([_check_equality(x, default) for x in mixture_list]): + mixture_result = mixture_list[0] + for i in range(1, len(mixture_list)): + mixture_result = [ + _product_mixture_pair(op_1, op_2) + for op_1 in mixture_result + for op_2 in mixture_list[i] + ] + + return tuple(mixture_result) + if default is not RaiseTypeErrorIfNotProvided: return default - if not any(getattr(val, instance, None) is not None for instance in ['_unitary_', '_mixture_']): + if not any(getattr(val, instance, None) is not None for instance in ["_unitary_", "_mixture_"]): raise TypeError(f"object of type '{type(val)}' has no _mixture_ or _unitary_ method.") raise TypeError( @@ -126,7 +153,7 @@ def has_mixture(val: Any, *, allow_decompose: bool = True) -> bool: has a `_mixture_` method return True if that has a non-default value. Returns False if neither function exists. """ - mixture_getter = getattr(val, '_has_mixture_', None) + mixture_getter = getattr(val, "_has_mixture_", None) result = NotImplemented if mixture_getter is None else mixture_getter() if result is not NotImplemented: return result @@ -134,30 +161,87 @@ def has_mixture(val: Any, *, allow_decompose: bool = True) -> bool: if has_unitary(val, allow_decompose=False): return True + strats = [_strat_mixture_from_mixture] + if any(strat(val) is not None and strat(val) is not NotImplemented for strat in strats): + return True + if allow_decompose: operations, _, _ = _try_decompose_into_operations_and_qubits(val) if operations is not None: return all(has_mixture(val) for val in operations) - # No _has_mixture_ or _has_unitary_ function, use _mixture_ instead. - return mixture(val, None) is not None + return False def validate_mixture(supports_mixture: SupportsMixture): """Validates that the mixture's tuple are valid probabilities.""" mixture_tuple = mixture(supports_mixture, None) if mixture_tuple is None: - raise TypeError(f'{supports_mixture}_mixture did not have a _mixture_ method') + raise TypeError(f"{supports_mixture}_mixture did not have a _mixture_ method") def validate_probability(p, p_str): if p < 0: - raise ValueError(f'{p_str} was less than 0.') + raise ValueError(f"{p_str} was less than 0.") elif p > 1: - raise ValueError(f'{p_str} was greater than 1.') + raise ValueError(f"{p_str} was greater than 1.") total = 0.0 for p, val in mixture_tuple: - validate_probability(p, '{}\'s probability'.format(str(val))) + validate_probability(p, "{}'s probability".format(str(val))) total += p if not np.isclose(total, 1.0): - raise ValueError('Sum of probabilities of a mixture was not 1.0') + raise ValueError("Sum of probabilities of a mixture was not 1.0") + + +def _strat_mixture_from_mixture(val: Any): + """Attempts to compute the value's mixture via its _mixture_ method.""" + mixture_getter = getattr(val, "_mixture_", None) + result = NotImplemented if mixture_getter is None else mixture_getter() + return result + + +def _check_equality(x, y): + if type(x) != type(y): + return False + if type(x) not in [list, tuple, np.ndarray]: + return x == y + if type(x) == np.ndarray: + return x.shape == y.shape and np.all(x == y) + return False if len(x) != len(y) else all([_check_equality(a, b) for a, b in zip(x, y)]) + + +def _tensor_mixture_pair(x, y): + p_new = x[0] * y[0] + mat_new = np.kron(x[1], y[1]) + return (p_new, mat_new) + + +def _product_mixture_pair(x, y): + p_new = x[0] * y[0] + mat_new = y[1].dot(x[1]) + return (p_new, mat_new) + + +def _mixture_tensor(op, qubits, default): + mixture_list = mixture(op, default) + if _check_equality(mixture_list, default): + return default + + val = None + op_q = op.qubits + found = False + for i in range(len(qubits)): + if qubits[i] in op_q: + if not found: + found = True + if val is None: + val = mixture_list + else: + val = tuple([_tensor_mixture_pair(x, y) for x in val for y in mixture_list]) + + elif val is None: + val = ((1, np.identity(2)),) + else: + val = tuple([_tensor_mixture_pair(x, (1, np.identity(2))) for x in val]) + + return val diff --git a/cirq-core/cirq/protocols/mixture_protocol_test.py b/cirq-core/cirq/protocols/mixture_protocol_test.py index cadbe0806c6..e0a66f2acd5 100644 --- a/cirq-core/cirq/protocols/mixture_protocol_test.py +++ b/cirq-core/cirq/protocols/mixture_protocol_test.py @@ -33,7 +33,7 @@ def _has_mixture_(self): class ReturnsValidTuple(cirq.SupportsMixture): def _mixture_(self): - return ((0.4, 'a'), (0.6, 'b')) + return ((0.4, "a"), (0.6, "b")) def _has_mixture_(self): return True @@ -41,22 +41,22 @@ def _has_mixture_(self): class ReturnsNonnormalizedTuple: def _mixture_(self): - return ((0.4, 'a'), (0.4, 'b')) + return ((0.4, "a"), (0.4, "b")) class ReturnsNegativeProbability: def _mixture_(self): - return ((0.4, 'a'), (-0.4, 'b')) + return ((0.4, "a"), (-0.4, "b")) class ReturnsGreaterThanUnityProbability: def _mixture_(self): - return ((1.2, 'a'), (0.4, 'b')) + return ((1.2, "a"), (0.4, "b")) class ReturnsMixtureButNoHasMixture: def _mixture_(self): - return ((0.4, 'a'), (0.6, 'b')) + return ((0.4, "a"), (0.6, "b")) class ReturnsUnitary: @@ -76,10 +76,10 @@ def _has_unitary_(self): @pytest.mark.parametrize( - 'val,mixture', + "val,mixture", ( - (ReturnsValidTuple(), ((0.4, 'a'), (0.6, 'b'))), - (ReturnsNonnormalizedTuple(), ((0.4, 'a'), (0.4, 'b'))), + (ReturnsValidTuple(), ((0.4, "a"), (0.6, "b"))), + (ReturnsNonnormalizedTuple(), ((0.4, "a"), (0.4, "b"))), (ReturnsUnitary(), ((1.0, np.ones((2, 2))),)), ), ) @@ -89,20 +89,20 @@ def test_objects_with_mixture(val, mixture): np.testing.assert_almost_equal(keys, expected_keys) np.testing.assert_equal(values, expected_values) - keys, values = zip(*cirq.mixture(val, ((0.3, 'a'), (0.7, 'b')))) + keys, values = zip(*cirq.mixture(val, ((0.3, "a"), (0.7, "b")))) np.testing.assert_almost_equal(keys, expected_keys) np.testing.assert_equal(values, expected_values) @pytest.mark.parametrize( - 'val', (NoMethod(), ReturnsNotImplemented(), ReturnsNotImplementedUnitary()) + "val", (NoMethod(), ReturnsNotImplemented(), ReturnsNotImplementedUnitary()) ) def test_objects_with_no_mixture(val): with pytest.raises(TypeError, match="mixture"): _ = cirq.mixture(val) assert cirq.mixture(val, None) is None assert cirq.mixture(val, NotImplemented) is NotImplemented - default = ((0.4, 'a'), (0.6, 'b')) + default = ((0.4, "a"), (0.6, "b")) assert cirq.mixture(val, default) == default @@ -118,12 +118,67 @@ def test_valid_mixture(): cirq.validate_mixture(ReturnsValidTuple()) +def test_serial_concatenation_default(): + q1 = cirq.GridQubit(1, 1) + + class defaultGate(cirq.Gate): + def num_qubits(self): + return 1 + + def _unitary_(self): + return None + + def _mixture_(self): + return NotImplemented + + class onlyDecompose: + def _decompose_(self): + return [cirq.Y.on(q1), defaultGate().on(q1)] + + def _unitary_(self): + return None + + def _mixture_(self): + return NotImplemented + + with pytest.raises(TypeError, match="returned NotImplemented"): + _ = cirq.mixture(onlyDecompose()) + assert cirq.mixture(onlyDecompose(), 0) == 0 + assert not cirq.has_mixture(onlyDecompose()) + + +def test_serial_concatenation_circuit(): + q1 = cirq.GridQubit(1, 1) + q2 = cirq.GridQubit(1, 2) + + class onlyDecompose: + def _decompose_(self): + circ = cirq.Circuit([cirq.Y.on(q1), cirq.X.on(q2)]) + return cirq.decompose(circ) + + def _unitary_(self): + return None + + def _mixture_(self): + return NotImplemented + + g = onlyDecompose() + c = ((1, cirq.unitary(cirq.Circuit([cirq.Y.on(q1), cirq.X.on(q2)]))),) + + np.testing.assert_equal(cirq.mixture(g), c) + np.testing.assert_equal(cirq.mixture(g, None), c) + np.testing.assert_equal(cirq.mixture(g, NotImplemented), c) + np.testing.assert_equal(cirq.mixture(g, (1,)), c) + + assert cirq.has_mixture(g) + + @pytest.mark.parametrize( - 'val,message', + "val,message", ( - (ReturnsNonnormalizedTuple(), '1.0'), - (ReturnsNegativeProbability(), 'less than 0'), - (ReturnsGreaterThanUnityProbability(), 'greater than 1'), + (ReturnsNonnormalizedTuple(), "1.0"), + (ReturnsNegativeProbability(), "less than 0"), + (ReturnsGreaterThanUnityProbability(), "greater than 1"), ), ) def test_invalid_mixture(val, message): @@ -132,5 +187,5 @@ def test_invalid_mixture(val, message): def test_missing_mixture(): - with pytest.raises(TypeError, match='_mixture_'): + with pytest.raises(TypeError, match="_mixture_"): cirq.validate_mixture(NoMethod) From ca14ba3a1e0f265b54c5b39e19cffadfcf601735 Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Tue, 28 Sep 2021 03:18:11 +0530 Subject: [PATCH 21/33] Completed Mixture serial concatenation adn added tests for the same. Also Added check for combinatorial explosion as requested. --- cirq-core/cirq/protocols/kraus_protocol.py | 12 +++-- .../cirq/protocols/kraus_protocol_test.py | 46 +++++++++++++++++-- cirq-core/cirq/protocols/mixture_protocol.py | 32 ++++++++----- .../cirq/protocols/mixture_protocol_test.py | 43 ++++++++++++++--- 4 files changed, 109 insertions(+), 24 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index 4149cf769c8..d40ad51ea4c 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -25,7 +25,10 @@ deprecated_class, ) from cirq._doc import doc_private -from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits, decompose +from cirq.protocols.decompose_protocol import ( + _try_decompose_into_operations_and_qubits, + decompose, +) from cirq.protocols.mixture_protocol import mixture, has_mixture from cirq.protocols.unitary_protocol import unitary from cirq.protocols.has_unitary_protocol import has_unitary @@ -181,14 +184,17 @@ def kraus( qubits: List[Qid] = [] for x in decomposed: qubits.extend(x.qubits) - qubits = sorted(list(set(qubits))) + limit = (4 ** np.prod(len(qubits))) ** 2 + kraus_list = list(map(lambda x: _kraus_tensor(x, qubits, default), decomposed)) if not any([_check_equality(x, default) for x in kraus_list]): kraus_result = kraus_list[0] for i in range(1, len(kraus_list)): kraus_result = [op_2.dot(op_1) for op_1 in kraus_result for op_2 in kraus_list[i]] - + assert ( + len(kraus_result) < limit + ), f"{val} kraus decomposition had combinatorial explosion." return tuple(kraus_result) if default is not RaiseTypeErrorIfNotProvided: diff --git a/cirq-core/cirq/protocols/kraus_protocol_test.py b/cirq-core/cirq/protocols/kraus_protocol_test.py index a75dbb8873b..7d1a9e6d631 100644 --- a/cirq-core/cirq/protocols/kraus_protocol_test.py +++ b/cirq-core/cirq/protocols/kraus_protocol_test.py @@ -171,7 +171,7 @@ def _unitary_(self): return None def _mixture_(self): - return None + return NotImplemented class onlyDecompose: def _decompose_(self): @@ -193,9 +193,22 @@ def test_serial_concatenation_circuit(): q1 = cirq.GridQubit(1, 1) q2 = cirq.GridQubit(1, 2) + class defaultGate(cirq.Gate): + def num_qubits(self): + return 1 + + def _kraus_(self): + return cirq.kraus(cirq.X) + + def _unitary_(self): + return None + + def _mixture_(self): + return NotImplemented + class onlyDecompose: def _decompose_(self): - circ = cirq.Circuit([cirq.Y.on(q1), cirq.X.on(q2)]) + circ = cirq.Circuit([cirq.Y.on(q1), defaultGate().on(q2)]) return cirq.decompose(circ) def _unitary_(self): @@ -216,7 +229,34 @@ def _mixture_(self): assert cirq.has_kraus(g) -def test_empty_decompose(): +def test_kraus_combinatorial_explosion(): + q1 = cirq.GridQubit(1, 1) + + class defaultGate(cirq.Gate): + def num_qubits(self): + return 1 + + def _kraus_(self): + # for one qubit the upper limit is 16 elements + ls = [np.array([[1, 0], [0, 1]])] * 16 + ls.append(np.array([[1, 0], [0, 1]])) + return tuple(ls) + + class onlyDecompose: + def _decompose_(self): + return [cirq.Y.on(q1), defaultGate().on(q1)] + + def _unitary_(self): + return None + + def _mixture_(self): + return NotImplemented + + with pytest.raises(AssertionError, match="combinatorial explosion."): + _ = cirq.kraus(onlyDecompose()) + + +def test_kraus_empty_decompose(): g = cirq.PauliString({}) ** 2 c = (cirq.unitary(g),) diff --git a/cirq-core/cirq/protocols/mixture_protocol.py b/cirq-core/cirq/protocols/mixture_protocol.py index e5f4c720407..fce55cb4dda 100644 --- a/cirq-core/cirq/protocols/mixture_protocol.py +++ b/cirq-core/cirq/protocols/mixture_protocol.py @@ -18,9 +18,14 @@ from typing_extensions import Protocol from cirq._doc import doc_private -from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits, decompose +from cirq.protocols.decompose_protocol import ( + _try_decompose_into_operations_and_qubits, + decompose, + decompose_once_with_qubits, +) from cirq.protocols.has_unitary_protocol import has_unitary from cirq.protocols.unitary_protocol import unitary +from cirq.protocols.qid_shape_protocol import qid_shape from cirq.type_workarounds import NotImplementedType from cirq.devices.line_qubit import LineQid @@ -96,17 +101,19 @@ def mixture( if unitary_result is not None and unitary_result is not NotImplemented: return ((1.0, unitary_result),) - if isinstance(val, Gate): - val = val.on(*LineQid.for_gate(val)) - else: - val = val - # serial concatenation - decomposed = decompose(val) + decomposed = ( + decompose_once_with_qubits(val, LineQid.for_qid_shape(qid_shape(val)), []) + if isinstance(val, Gate) + else decompose(val) + ) - if decomposed != [val]: + # serial concatenation + if decomposed != [] and decomposed != [val]: qubits: List[Qid] = [] for x in decomposed: qubits.extend(x.qubits) + qubits = list(set(qubits)) + limit = (4 ** np.prod(len(qubits))) ** 2 qubits = sorted(list(set(qubits))) mixture_list = list(map(lambda x: _mixture_tensor(x, qubits, default), decomposed)) @@ -118,8 +125,11 @@ def mixture( for op_1 in mixture_result for op_2 in mixture_list[i] ] - - return tuple(mixture_result) + assert ( + len(mixture_result) < limit + ), f"{val} mixture decomposition had combinatorial explosion." + else: + return tuple(mixture_result) if default is not RaiseTypeErrorIfNotProvided: return default @@ -187,7 +197,7 @@ def validate_probability(p, p_str): total = 0.0 for p, val in mixture_tuple: - validate_probability(p, "{}'s probability".format(str(val))) + validate_probability(p, f"{str(val)}'s probability") total += p if not np.isclose(total, 1.0): raise ValueError("Sum of probabilities of a mixture was not 1.0") diff --git a/cirq-core/cirq/protocols/mixture_protocol_test.py b/cirq-core/cirq/protocols/mixture_protocol_test.py index e0a66f2acd5..2f68141f018 100644 --- a/cirq-core/cirq/protocols/mixture_protocol_test.py +++ b/cirq-core/cirq/protocols/mixture_protocol_test.py @@ -118,6 +118,33 @@ def test_valid_mixture(): cirq.validate_mixture(ReturnsValidTuple()) +def test_combinatorial_explosion(): + q1 = cirq.GridQubit(1, 1) + + class defaultGate(cirq.Gate): + def num_qubits(self): + return 1 + + def _mixture_(self): + # for one qubit the upper limit is 16 elements + ls = [(0.1, np.array([[1, 0], [0, 1]]))] * 16 + ls.append((0.84, np.array([[1, 0], [0, 1]]))) + return tuple(ls) + + class onlyDecompose: + def _decompose_(self): + return [cirq.Y.on(q1), defaultGate().on(q1)] + + def _unitary_(self): + return None + + def _mixture_(self): + return NotImplemented + + with pytest.raises(AssertionError, match="combinatorial explosion."): + _ = cirq.mixture(onlyDecompose()) + + def test_serial_concatenation_default(): q1 = cirq.GridQubit(1, 1) @@ -141,9 +168,12 @@ def _unitary_(self): def _mixture_(self): return NotImplemented + default = (1.0, np.array([[1, 0], [0, 1]])) + with pytest.raises(TypeError, match="returned NotImplemented"): _ = cirq.mixture(onlyDecompose()) assert cirq.mixture(onlyDecompose(), 0) == 0 + np.testing.assert_equal(cirq.mixture(onlyDecompose(), default), default) assert not cirq.has_mixture(onlyDecompose()) @@ -162,15 +192,14 @@ def _unitary_(self): def _mixture_(self): return NotImplemented - g = onlyDecompose() - c = ((1, cirq.unitary(cirq.Circuit([cirq.Y.on(q1), cirq.X.on(q2)]))),) + c = ((1.0, np.array(cirq.unitary(cirq.Circuit([cirq.Y.on(q1), cirq.X.on(q2)])))),) - np.testing.assert_equal(cirq.mixture(g), c) - np.testing.assert_equal(cirq.mixture(g, None), c) - np.testing.assert_equal(cirq.mixture(g, NotImplemented), c) - np.testing.assert_equal(cirq.mixture(g, (1,)), c) + np.testing.assert_equal(cirq.mixture(onlyDecompose()), c) + np.testing.assert_equal(cirq.mixture(onlyDecompose(), None), c) + np.testing.assert_equal(cirq.mixture(onlyDecompose(), NotImplemented), c) + np.testing.assert_equal(cirq.mixture(onlyDecompose(), (1,)), c) - assert cirq.has_mixture(g) + assert cirq.has_mixture(onlyDecompose()) @pytest.mark.parametrize( From f2a703def72795dd731472a324bc1e7bcafe1dca Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Sat, 2 Oct 2021 22:02:56 +0530 Subject: [PATCH 22/33] Fixed header files and reverted the irrelavant quote changes. --- cirq-core/cirq/protocols/kraus_protocol.py | 25 ++++++--------- .../cirq/protocols/kraus_protocol_test.py | 20 ++++++------ cirq-core/cirq/protocols/mixture_protocol.py | 15 ++++----- .../cirq/protocols/mixture_protocol_test.py | 32 +++++++++---------- .../cirq/testing/consistent_protocols.py | 10 +++--- 5 files changed, 47 insertions(+), 55 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index d40ad51ea4c..a149ee88842 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -29,10 +29,7 @@ _try_decompose_into_operations_and_qubits, decompose, ) -from cirq.protocols.mixture_protocol import mixture, has_mixture -from cirq.protocols.unitary_protocol import unitary -from cirq.protocols.has_unitary_protocol import has_unitary - +from cirq.protocols import mixture_protocol, has_unitary_protocol from cirq.ops.raw_types import Qid from cirq.type_workarounds import NotImplementedType @@ -46,10 +43,10 @@ RaiseTypeErrorIfNotProvided = (np.array([]),) -TDefault = TypeVar("TDefault") +TDefault = TypeVar('TDefault') -@deprecated_class(deadline="v0.13", fix="use cirq.SupportsKraus instead") +@deprecated_class(deadline='v0.13', fix='use cirq.SupportsKraus instead') class SupportsChannel(Protocol): pass @@ -155,10 +152,10 @@ def kraus( method returned NotImplemented) and also no default value was specified. """ - channel_getter = getattr(val, "_channel_", None) + channel_getter = getattr(val, '_channel_', None) if channel_getter is not None: warnings.warn( - "_channel_ is deprecated and will be removed in cirq 0.13, rename to _kraus_", + '_channel_ is deprecated and will be removed in cirq 0.13, rename to _kraus_', DeprecationWarning, ) @@ -170,14 +167,10 @@ def kraus( if kraus_result is not None and kraus_result is not NotImplemented: return kraus_result - mixture_result = mixture(val, None) + mixture_result = mixture_protocol.mixture(val, None) if mixture_result is not None and mixture_result is not NotImplemented: return tuple(np.sqrt(p) * u for p, u in mixture_result) - unitary_result = unitary(val, None) - if unitary_result is not None and unitary_result is not NotImplemented: - return (unitary_result,) - decomposed = decompose(val) if decomposed != [val]: @@ -215,7 +208,7 @@ def kraus( ) -@deprecated(deadline="v0.13", fix="use cirq.has_kraus instead") +@deprecated(deadline='v0.13', fix='use cirq.has_kraus instead') def has_channel(val: Any, *, allow_decompose: bool = True) -> bool: return has_kraus(val, allow_decompose=allow_decompose) @@ -271,7 +264,7 @@ def has_kraus(val: Any, *, allow_decompose: bool = True) -> bool: channel_getter = getattr(val, "_has_channel_", None) if channel_getter is not None: warnings.warn( - "_has_channel_ is deprecated and will be removed in cirq 0.13, rename to _has_kraus_", + '_has_channel_ is deprecated and will be removed in cirq 0.13, rename to _has_kraus_', DeprecationWarning, ) @@ -279,7 +272,7 @@ def has_kraus(val: Any, *, allow_decompose: bool = True) -> bool: if result is not NotImplemented and result: return True - for instance in [has_unitary, has_mixture]: + for instance in [has_unitary_protocol.has_unitary, mixture_protocol.has_mixture]: result = instance(val) if result is not NotImplemented and result: return True diff --git a/cirq-core/cirq/protocols/kraus_protocol_test.py b/cirq-core/cirq/protocols/kraus_protocol_test.py index 7d1a9e6d631..6c6457c8faa 100644 --- a/cirq-core/cirq/protocols/kraus_protocol_test.py +++ b/cirq-core/cirq/protocols/kraus_protocol_test.py @@ -29,7 +29,7 @@ def test_channel_no_methods(): class NoMethod: pass - with pytest.raises(TypeError, match="no _kraus_ or _mixture_ or _unitary_ method"): + with pytest.raises(TypeError, match='no _kraus_ or _mixture_ or _unitary_ method'): _ = cirq.kraus(NoMethod()) assert cirq.kraus(NoMethod(), None) is None @@ -41,7 +41,7 @@ class NoMethod: def assert_not_implemented(val): - with pytest.raises(TypeError, match="returned NotImplemented"): + with pytest.raises(TypeError, match='returned NotImplemented'): _ = cirq.kraus(val) assert cirq.kraus(val, None) is None @@ -53,7 +53,7 @@ def assert_not_implemented(val): def test_supports_channel_class_is_deprecated(): - with cirq.testing.assert_deprecated(deadline="v0.13"): + with cirq.testing.assert_deprecated(deadline='v0.13'): class SomeChannel(cirq.SupportsChannel): pass @@ -62,7 +62,7 @@ class SomeChannel(cirq.SupportsChannel): def test_channel_protocol_is_deprecated(): - with cirq.testing.assert_deprecated(deadline="v0.13"): + with cirq.testing.assert_deprecated(deadline='v0.13'): assert np.allclose(cirq.channel(cirq.X), cirq.kraus(cirq.X)) @@ -88,9 +88,9 @@ def _channel_(self): return (np.eye(2),) val = UsesDeprecatedChannelMethod() - with pytest.warns(DeprecationWarning, match="_has_kraus_"): + with pytest.warns(DeprecationWarning, match='_has_kraus_'): assert cirq.has_kraus(val) - with pytest.warns(DeprecationWarning, match="_kraus_"): + with pytest.warns(DeprecationWarning, match='_kraus_'): ks = cirq.kraus(val) assert len(ks) == 1 assert np.all(ks[0] == np.eye(2)) @@ -109,7 +109,7 @@ class ReturnsNotImplemented: def _unitary_(self): return NotImplemented - with pytest.raises(TypeError, match="returned NotImplemented"): + with pytest.raises(TypeError, match='returned NotImplemented'): _ = cirq.kraus(ReturnsNotImplemented()) assert cirq.kraus(ReturnsNotImplemented(), None) is None assert cirq.kraus(ReturnsNotImplemented(), NotImplemented) is NotImplemented @@ -309,14 +309,14 @@ def _decompose_(self, qubits): return [self.decomposed_cls().on(q) for q in qubits] -@pytest.mark.parametrize("cls", [HasKraus, HasMixture, HasUnitary]) +@pytest.mark.parametrize('cls', [HasKraus, HasMixture, HasUnitary]) def test_has_kraus(cls): assert cirq.has_kraus(cls()) -@pytest.mark.parametrize("decomposed_cls", [HasKraus, HasMixture, HasUnitary]) +@pytest.mark.parametrize('decomposed_cls', [HasKraus, HasMixture, HasUnitary]) def test_has_channel_when_decomposed(decomposed_cls): - op = HasKrausWhenDecomposed(decomposed_cls).on(cirq.NamedQubit("test")) + op = HasKrausWhenDecomposed(decomposed_cls).on(cirq.NamedQubit('test')) assert cirq.has_kraus(op) if not cirq.has_unitary(op) and not cirq.has_mixture(op): assert not cirq.has_kraus(op, allow_decompose=False) diff --git a/cirq-core/cirq/protocols/mixture_protocol.py b/cirq-core/cirq/protocols/mixture_protocol.py index fce55cb4dda..994c4db5964 100644 --- a/cirq-core/cirq/protocols/mixture_protocol.py +++ b/cirq-core/cirq/protocols/mixture_protocol.py @@ -23,8 +23,7 @@ decompose, decompose_once_with_qubits, ) -from cirq.protocols.has_unitary_protocol import has_unitary -from cirq.protocols.unitary_protocol import unitary +from cirq.protocols import has_unitary_protocol, unitary_protocol from cirq.protocols.qid_shape_protocol import qid_shape from cirq.type_workarounds import NotImplementedType from cirq.devices.line_qubit import LineQid @@ -97,7 +96,7 @@ def mixture( if mixture_result is not None and mixture_result is not NotImplemented: return mixture_result - unitary_result = unitary(val, None) + unitary_result = unitary_protocol.unitary(val, None) if unitary_result is not None and unitary_result is not NotImplemented: return ((1.0, unitary_result),) @@ -163,12 +162,12 @@ def has_mixture(val: Any, *, allow_decompose: bool = True) -> bool: has a `_mixture_` method return True if that has a non-default value. Returns False if neither function exists. """ - mixture_getter = getattr(val, "_has_mixture_", None) + mixture_getter = getattr(val, '_has_mixture_', None) result = NotImplemented if mixture_getter is None else mixture_getter() if result is not NotImplemented: return result - if has_unitary(val, allow_decompose=False): + if has_unitary_protocol.has_unitary(val, allow_decompose=False): return True strats = [_strat_mixture_from_mixture] @@ -187,13 +186,13 @@ def validate_mixture(supports_mixture: SupportsMixture): """Validates that the mixture's tuple are valid probabilities.""" mixture_tuple = mixture(supports_mixture, None) if mixture_tuple is None: - raise TypeError(f"{supports_mixture}_mixture did not have a _mixture_ method") + raise TypeError(f'{supports_mixture}_mixture did not have a _mixture_ method') def validate_probability(p, p_str): if p < 0: - raise ValueError(f"{p_str} was less than 0.") + raise ValueError(f'{p_str} was less than 0.') elif p > 1: - raise ValueError(f"{p_str} was greater than 1.") + raise ValueError(f'{p_str} was greater than 1.') total = 0.0 for p, val in mixture_tuple: diff --git a/cirq-core/cirq/protocols/mixture_protocol_test.py b/cirq-core/cirq/protocols/mixture_protocol_test.py index 2f68141f018..e5d8026b116 100644 --- a/cirq-core/cirq/protocols/mixture_protocol_test.py +++ b/cirq-core/cirq/protocols/mixture_protocol_test.py @@ -33,7 +33,7 @@ def _has_mixture_(self): class ReturnsValidTuple(cirq.SupportsMixture): def _mixture_(self): - return ((0.4, "a"), (0.6, "b")) + return ((0.4, 'a'), (0.6, 'b')) def _has_mixture_(self): return True @@ -41,22 +41,22 @@ def _has_mixture_(self): class ReturnsNonnormalizedTuple: def _mixture_(self): - return ((0.4, "a"), (0.4, "b")) + return ((0.4, 'a'), (0.4, 'b')) class ReturnsNegativeProbability: def _mixture_(self): - return ((0.4, "a"), (-0.4, "b")) + return ((0.4, 'a'), (-0.4, 'b')) class ReturnsGreaterThanUnityProbability: def _mixture_(self): - return ((1.2, "a"), (0.4, "b")) + return ((1.2, 'a'), (0.4, 'b')) class ReturnsMixtureButNoHasMixture: def _mixture_(self): - return ((0.4, "a"), (0.6, "b")) + return ((0.4, 'a'), (0.6, 'b')) class ReturnsUnitary: @@ -76,10 +76,10 @@ def _has_unitary_(self): @pytest.mark.parametrize( - "val,mixture", + 'val,mixture', ( - (ReturnsValidTuple(), ((0.4, "a"), (0.6, "b"))), - (ReturnsNonnormalizedTuple(), ((0.4, "a"), (0.4, "b"))), + (ReturnsValidTuple(), ((0.4, 'a'), (0.6, 'b'))), + (ReturnsNonnormalizedTuple(), ((0.4, 'a'), (0.4, 'b'))), (ReturnsUnitary(), ((1.0, np.ones((2, 2))),)), ), ) @@ -89,20 +89,20 @@ def test_objects_with_mixture(val, mixture): np.testing.assert_almost_equal(keys, expected_keys) np.testing.assert_equal(values, expected_values) - keys, values = zip(*cirq.mixture(val, ((0.3, "a"), (0.7, "b")))) + keys, values = zip(*cirq.mixture(val, ((0.3, 'a'), (0.7, 'b')))) np.testing.assert_almost_equal(keys, expected_keys) np.testing.assert_equal(values, expected_values) @pytest.mark.parametrize( - "val", (NoMethod(), ReturnsNotImplemented(), ReturnsNotImplementedUnitary()) + 'val', (NoMethod(), ReturnsNotImplemented(), ReturnsNotImplementedUnitary()) ) def test_objects_with_no_mixture(val): with pytest.raises(TypeError, match="mixture"): _ = cirq.mixture(val) assert cirq.mixture(val, None) is None assert cirq.mixture(val, NotImplemented) is NotImplemented - default = ((0.4, "a"), (0.6, "b")) + default = ((0.4, 'a'), (0.6, 'b')) assert cirq.mixture(val, default) == default @@ -203,11 +203,11 @@ def _mixture_(self): @pytest.mark.parametrize( - "val,message", + 'val,message', ( - (ReturnsNonnormalizedTuple(), "1.0"), - (ReturnsNegativeProbability(), "less than 0"), - (ReturnsGreaterThanUnityProbability(), "greater than 1"), + (ReturnsNonnormalizedTuple(), '1.0'), + (ReturnsNegativeProbability(), 'less than 0'), + (ReturnsGreaterThanUnityProbability(), 'greater than 1'), ), ) def test_invalid_mixture(val, message): @@ -216,5 +216,5 @@ def test_invalid_mixture(val, message): def test_missing_mixture(): - with pytest.raises(TypeError, match="_mixture_"): + with pytest.raises(TypeError, match='_mixture_'): cirq.validate_mixture(NoMethod) diff --git a/cirq-core/cirq/testing/consistent_protocols.py b/cirq-core/cirq/testing/consistent_protocols.py index cb1353d2846..a7ba0dd6385 100644 --- a/cirq-core/cirq/testing/consistent_protocols.py +++ b/cirq-core/cirq/testing/consistent_protocols.py @@ -52,10 +52,10 @@ def assert_implements_consistent_protocols( val: Any, *, - exponents: Sequence[Any] = (0, 1, -1, 0.25, -0.5, 0.1, sympy.Symbol("s")), + exponents: Sequence[Any] = (0, 1, -1, 0.25, -0.5, 0.1, sympy.Symbol('s')), qubit_count: Optional[int] = None, ignoring_global_phase: bool = False, - setup_code: str = "import cirq\nimport numpy as np\nimport sympy", + setup_code: str = 'import cirq\nimport numpy as np\nimport sympy', global_vals: Optional[Dict[str, Any]] = None, local_vals: Optional[Dict[str, Any]] = None, ) -> None: @@ -86,11 +86,11 @@ def assert_implements_consistent_protocols( def assert_eigengate_implements_consistent_protocols( eigen_gate_type: Type[ops.EigenGate], *, - exponents: Sequence[value.TParamVal] = (0, 0.5, 1, -1, 0.25, -0.5, 0.1, sympy.Symbol("s")), + exponents: Sequence[value.TParamVal] = (0, 0.5, 1, -1, 0.25, -0.5, 0.1, sympy.Symbol('s')), global_shifts: Sequence[float] = (0, -0.5, 0.1), qubit_count: Optional[int] = None, ignoring_global_phase: bool = False, - setup_code: str = "import cirq\nimport numpy as np\nimport sympy", + setup_code: str = 'import cirq\nimport numpy as np\nimport sympy', global_vals: Optional[Dict[str, Any]] = None, local_vals: Optional[Dict[str, Any]] = None, ) -> None: @@ -172,7 +172,7 @@ def assert_commutes_magic_method_consistent_with_unitaries( *vals: Sequence[Any], atol: Union[int, float] = 1e-8 ) -> None: if any(isinstance(val, ops.Operation) for val in vals): - raise TypeError("`_commutes_` need not be consistent with unitaries for `Operation`.") + raise TypeError('`_commutes_` need not be consistent with unitaries for `Operation`.') unitaries = [protocols.unitary(val, None) for val in vals] pairs = itertools.permutations(zip(vals, unitaries), 2) for (left_val, left_unitary), (right_val, right_unitary) in pairs: From bdad3ea74ca47ff3639c512a0db99eb4bc2f07dc Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Sat, 2 Oct 2021 22:07:23 +0530 Subject: [PATCH 23/33] Missed on reversion. --- cirq-core/cirq/protocols/kraus_protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index a149ee88842..256f02d91a8 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -106,7 +106,7 @@ def _has_kraus_(self) -> bool: """ -@deprecated(deadline="v0.13", fix="use cirq.kraus instead") +@deprecated(deadline='v0.13', fix='use cirq.kraus instead') def channel( val: Any, default: Any = RaiseTypeErrorIfNotProvided ) -> Union[Tuple[np.ndarray, ...], TDefault]: From 742ee483936618bce07f0191676320e4d5a67093 Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Sat, 2 Oct 2021 22:57:52 +0530 Subject: [PATCH 24/33] Using `_try_decompose_into_operations_and_qubits` instead of `decompose`. --- cirq-core/cirq/protocols/kraus_protocol.py | 13 +++---------- cirq-core/cirq/protocols/mixture_protocol.py | 18 ++---------------- 2 files changed, 5 insertions(+), 26 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index 256f02d91a8..bf3c083e0b2 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -27,11 +27,8 @@ from cirq._doc import doc_private from cirq.protocols.decompose_protocol import ( _try_decompose_into_operations_and_qubits, - decompose, ) from cirq.protocols import mixture_protocol, has_unitary_protocol -from cirq.ops.raw_types import Qid - from cirq.type_workarounds import NotImplementedType @@ -171,13 +168,9 @@ def kraus( if mixture_result is not None and mixture_result is not NotImplemented: return tuple(np.sqrt(p) * u for p, u in mixture_result) - decomposed = decompose(val) + decomposed, qubits, _ = _try_decompose_into_operations_and_qubits(val) - if decomposed != [val]: - qubits: List[Qid] = [] - for x in decomposed: - qubits.extend(x.qubits) - qubits = sorted(list(set(qubits))) + if decomposed is not None and decomposed != [val]: limit = (4 ** np.prod(len(qubits))) ** 2 kraus_list = list(map(lambda x: _kraus_tensor(x, qubits, default), decomposed)) @@ -261,7 +254,7 @@ def has_kraus(val: Any, *, allow_decompose: bool = True) -> bool: Returns: Whether or not `val` has a Kraus representation. """ - channel_getter = getattr(val, "_has_channel_", None) + channel_getter = getattr(val, '_has_channel_', None) if channel_getter is not None: warnings.warn( '_has_channel_ is deprecated and will be removed in cirq 0.13, rename to _has_kraus_', diff --git a/cirq-core/cirq/protocols/mixture_protocol.py b/cirq-core/cirq/protocols/mixture_protocol.py index 994c4db5964..b732372c8c1 100644 --- a/cirq-core/cirq/protocols/mixture_protocol.py +++ b/cirq-core/cirq/protocols/mixture_protocol.py @@ -20,15 +20,9 @@ from cirq._doc import doc_private from cirq.protocols.decompose_protocol import ( _try_decompose_into_operations_and_qubits, - decompose, - decompose_once_with_qubits, ) from cirq.protocols import has_unitary_protocol, unitary_protocol -from cirq.protocols.qid_shape_protocol import qid_shape from cirq.type_workarounds import NotImplementedType -from cirq.devices.line_qubit import LineQid - -from cirq.ops.raw_types import Qid, Gate # This is a special indicator value used by the inverse method to determine # whether or not the caller provided a 'default' argument. @@ -100,18 +94,10 @@ def mixture( if unitary_result is not None and unitary_result is not NotImplemented: return ((1.0, unitary_result),) - decomposed = ( - decompose_once_with_qubits(val, LineQid.for_qid_shape(qid_shape(val)), []) - if isinstance(val, Gate) - else decompose(val) - ) + decomposed, qubits, _ = _try_decompose_into_operations_and_qubits(val) # serial concatenation - if decomposed != [] and decomposed != [val]: - qubits: List[Qid] = [] - for x in decomposed: - qubits.extend(x.qubits) - qubits = list(set(qubits)) + if decomposed is not None and decomposed != [val]: limit = (4 ** np.prod(len(qubits))) ** 2 qubits = sorted(list(set(qubits))) From 14c4ada83e21af66ba8b56cc233179e7c7fefec1 Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Sun, 3 Oct 2021 01:35:59 +0530 Subject: [PATCH 25/33] Restructured the protocols. --- cirq-core/cirq/protocols/kraus_protocol.py | 88 +++++++------------- cirq-core/cirq/protocols/mixture_protocol.py | 42 +++++----- 2 files changed, 55 insertions(+), 75 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index bf3c083e0b2..69d0ac63962 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -14,7 +14,7 @@ """Protocol and methods for quantum channels.""" -from typing import Any, Sequence, Tuple, TypeVar, Union, List +from typing import Any, Sequence, Tuple, TypeVar, Union import warnings import numpy as np @@ -28,7 +28,7 @@ from cirq.protocols.decompose_protocol import ( _try_decompose_into_operations_and_qubits, ) -from cirq.protocols import mixture_protocol, has_unitary_protocol +from cirq.protocols import mixture_protocol from cirq.type_workarounds import NotImplementedType @@ -149,20 +149,9 @@ def kraus( method returned NotImplemented) and also no default value was specified. """ - channel_getter = getattr(val, '_channel_', None) - if channel_getter is not None: - warnings.warn( - '_channel_ is deprecated and will be removed in cirq 0.13, rename to _kraus_', - DeprecationWarning, - ) - - channel_result = NotImplemented if channel_getter is None else channel_getter() - if channel_result is not NotImplemented: - return tuple(channel_result) - - _, kraus_result = _strat_kraus_from_kraus(val) - if kraus_result is not None and kraus_result is not NotImplemented: - return kraus_result + result = _gettr_helper(val, ['_kraus_', '_channel_']) + if result is not None and result is not NotImplemented: + return result mixture_result = mixture_protocol.mixture(val, None) if mixture_result is not None and mixture_result is not NotImplemented: @@ -186,10 +175,7 @@ def kraus( if default is not RaiseTypeErrorIfNotProvided: return default - if not any( - getattr(val, instance, None) is not None - for instance in ["_kraus_", "_unitary_", "_mixture_"] - ): + if _gettr_helper(val, ['_kraus_', '_unitary_', '_mixture_']) is None: raise TypeError( "object of type '{}' has no _kraus_ or _mixture_ or " "_unitary_ method.".format(type(val)) @@ -230,13 +216,7 @@ def has_kraus(val: Any, *, allow_decompose: bool = True) -> bool: Case b) Method returns a 3D array. Kraus. - 4. Try to use `cirq.unitary()`. - Case a) Method not present or returns `NotImplemented`. - No Kraus. - Case b) Method returns a 3D array. - Kraus. - - 5. If decomposition is allowed apply recursion and check. + 4. If decomposition is allowed apply recursion and check. If all the above methods fail then it is assumed to have no Kraus representation. @@ -254,29 +234,11 @@ def has_kraus(val: Any, *, allow_decompose: bool = True) -> bool: Returns: Whether or not `val` has a Kraus representation. """ - channel_getter = getattr(val, '_has_channel_', None) - if channel_getter is not None: - warnings.warn( - '_has_channel_ is deprecated and will be removed in cirq 0.13, rename to _has_kraus_', - DeprecationWarning, - ) - - result = NotImplemented if channel_getter is None else channel_getter() - if result is not NotImplemented and result: - return True - - for instance in [has_unitary_protocol.has_unitary, mixture_protocol.has_mixture]: - result = instance(val) - if result is not NotImplemented and result: - return True - - getter = getattr(val, "_has_kraus_", None) - result = NotImplemented if getter is None else getter() - if result is not NotImplemented and result: + result = _gettr_helper(val, ['_has_kraus_', '_has_channel_', '_kraus_', '_channel_']) + if result is not None and result is not NotImplemented and result: return True - strats = [_strat_kraus_from_kraus] - if any(strat(val)[1] is not None and strat(val)[1] is not NotImplemented for strat in strats): + if mixture_protocol.has_mixture(val, allow_decompose=False): return True if allow_decompose: @@ -284,6 +246,7 @@ def has_kraus(val: Any, *, allow_decompose: bool = True) -> bool: if operations is not None: return all(has_kraus(val) for val in operations) + # No has methods, use `_kraus_` or delegates instead. return False @@ -322,11 +285,24 @@ def _kraus_tensor(op, qubits, default): return val -def _strat_kraus_from_kraus(val: Any): - """Attempts to compute the value's kraus via its _kraus_ method.""" - kraus_getter = getattr(val, "_kraus_", None) - kraus_result = NotImplemented if kraus_getter is None else kraus_getter() - if kraus_result is not NotImplemented: - return kraus_getter, tuple(kraus_result) - - return kraus_getter, kraus_result +def _gettr_helper(val: Any, gett_str_list: Sequence[str]): + notImplementedFlag = False + for gettr_str in gett_str_list: + gettr = getattr(val, gettr_str, None) + if gettr is None: + continue + if 'channel' in gettr_str: + warnings.warn( + f'{gettr_str} is deprecated and will be removed in cirq 0.13, rename to ' + f'{gettr_str.replace("channel", "kraus")}', + DeprecationWarning, + ) + result = gettr() + if result is NotImplemented: + notImplementedFlag = True + elif result is not None: + return result + + if notImplementedFlag: + return NotImplemented + return None diff --git a/cirq-core/cirq/protocols/mixture_protocol.py b/cirq-core/cirq/protocols/mixture_protocol.py index b732372c8c1..cdfb7925596 100644 --- a/cirq-core/cirq/protocols/mixture_protocol.py +++ b/cirq-core/cirq/protocols/mixture_protocol.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Protocol for objects that are mixtures (probabilistic combinations).""" -from typing import Any, Sequence, Tuple, Union, List +from typing import Any, Sequence, Tuple, Union import numpy as np from typing_extensions import Protocol @@ -86,7 +86,7 @@ def mixture( with that probability in the mixture. The probabilities will sum to 1.0. """ - mixture_result = _strat_mixture_from_mixture(val) + mixture_result = _gettr_helper(val, ['_mixture_']) if mixture_result is not None and mixture_result is not NotImplemented: return mixture_result @@ -100,7 +100,6 @@ def mixture( if decomposed is not None and decomposed != [val]: limit = (4 ** np.prod(len(qubits))) ** 2 - qubits = sorted(list(set(qubits))) mixture_list = list(map(lambda x: _mixture_tensor(x, qubits, default), decomposed)) if not any([_check_equality(x, default) for x in mixture_list]): mixture_result = mixture_list[0] @@ -119,7 +118,7 @@ def mixture( if default is not RaiseTypeErrorIfNotProvided: return default - if not any(getattr(val, instance, None) is not None for instance in ["_unitary_", "_mixture_"]): + if _gettr_helper(val, ['_unitary_', '_mixture_']) is None: raise TypeError(f"object of type '{type(val)}' has no _mixture_ or _unitary_ method.") raise TypeError( @@ -148,16 +147,11 @@ def has_mixture(val: Any, *, allow_decompose: bool = True) -> bool: has a `_mixture_` method return True if that has a non-default value. Returns False if neither function exists. """ - mixture_getter = getattr(val, '_has_mixture_', None) - result = NotImplemented if mixture_getter is None else mixture_getter() - if result is not NotImplemented: - return result - - if has_unitary_protocol.has_unitary(val, allow_decompose=False): + result = _gettr_helper(val, ['_has_mixture_', '_mixture_']) + if result is not None and result is not NotImplemented and result: return True - strats = [_strat_mixture_from_mixture] - if any(strat(val) is not None and strat(val) is not NotImplemented for strat in strats): + if has_unitary_protocol.has_unitary(val, allow_decompose=False): return True if allow_decompose: @@ -188,13 +182,6 @@ def validate_probability(p, p_str): raise ValueError("Sum of probabilities of a mixture was not 1.0") -def _strat_mixture_from_mixture(val: Any): - """Attempts to compute the value's mixture via its _mixture_ method.""" - mixture_getter = getattr(val, "_mixture_", None) - result = NotImplemented if mixture_getter is None else mixture_getter() - return result - - def _check_equality(x, y): if type(x) != type(y): return False @@ -240,3 +227,20 @@ def _mixture_tensor(op, qubits, default): val = tuple([_tensor_mixture_pair(x, (1, np.identity(2))) for x in val]) return val + + +def _gettr_helper(val: Any, gett_str_list: Sequence[str]): + notImplementedFlag = False + for gettr_str in gett_str_list: + gettr = getattr(val, gettr_str, None) + if gettr is None: + continue + result = gettr() + if result is NotImplemented: + notImplementedFlag = True + elif result is not None: + return result + + if notImplementedFlag: + return NotImplemented + return None From bea0d440d8c83cdc7caf97f2fcc03ec83433673b Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Mon, 25 Oct 2021 19:40:16 +0530 Subject: [PATCH 26/33] Comparision is superoperator based. --- cirq-core/cirq/protocols/kraus_protocol.py | 16 ++- .../cirq/protocols/kraus_protocol_test.py | 107 ++++-------------- 2 files changed, 30 insertions(+), 93 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index 8f8791a8b00..36e52e6a829 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -27,6 +27,8 @@ from cirq.protocols import mixture_protocol from cirq.type_workarounds import NotImplementedType +from cirq.qis.channels import kraus_to_superoperator, superoperator_to_kraus + # This is a special indicator value used by the channel method to determine # whether or not the caller provided a 'default' argument. It must be of type @@ -144,17 +146,13 @@ def kraus( decomposed, qubits, _ = _try_decompose_into_operations_and_qubits(val) if decomposed is not None and decomposed != [val]: - limit = (4 ** np.prod(len(qubits))) ** 2 - kraus_list = list(map(lambda x: _kraus_tensor(x, qubits, default), decomposed)) + kraus_list = list(map(lambda x: _kraus_to_superoperator(x, qubits, default), decomposed)) if not any([_check_equality(x, default) for x in kraus_list]): kraus_result = kraus_list[0] for i in range(1, len(kraus_list)): - kraus_result = [op_2.dot(op_1) for op_1 in kraus_result for op_2 in kraus_list[i]] - assert ( - len(kraus_result) < limit - ), f"{val} kraus decomposition had combinatorial explosion." - return tuple(kraus_result) + kraus_result = kraus_result @ kraus_list[i] + return tuple(superoperator_to_kraus(kraus_result)) if default is not RaiseTypeErrorIfNotProvided: return default @@ -239,7 +237,7 @@ def _check_equality(x, y): return False if len(x) != len(y) else all([_check_equality(a, b) for a, b in zip(x, y)]) -def _kraus_tensor(op, qubits, default): +def _kraus_to_superoperator(op, qubits, default): kraus_list = kraus(op, default) if _check_equality(kraus_list, default): return default @@ -261,7 +259,7 @@ def _kraus_tensor(op, qubits, default): else: val = tuple([np.kron(x, np.identity(2)) for x in val]) - return val + return kraus_to_superoperator(val) def _gettr_helper(val: Any, gett_str_list: Sequence[str]): diff --git a/cirq-core/cirq/protocols/kraus_protocol_test.py b/cirq-core/cirq/protocols/kraus_protocol_test.py index 3eabff75ee4..0d2344ea8fc 100644 --- a/cirq-core/cirq/protocols/kraus_protocol_test.py +++ b/cirq-core/cirq/protocols/kraus_protocol_test.py @@ -52,25 +52,6 @@ def assert_not_implemented(val): assert not cirq.has_kraus(val) -def test_supports_channel_class_is_deprecated(): - with cirq.testing.assert_deprecated(deadline='v0.13'): - - class SomeChannel(cirq.SupportsChannel): - pass - - _ = SomeChannel() - - -def test_channel_protocol_is_deprecated(): - with cirq.testing.assert_deprecated(deadline='v0.13'): - assert np.allclose(cirq.channel(cirq.X), cirq.kraus(cirq.X)) - - -def test_has_channel_protocol_is_deprecated(): - with cirq.testing.assert_deprecated(deadline="v0.13"): - assert cirq.has_channel(cirq.depolarize(0.1)) == cirq.has_kraus(cirq.depolarize(0.1)) - - def test_kraus_returns_not_implemented(): class ReturnsNotImplemented: def _kraus_(self): @@ -140,6 +121,22 @@ def _mixture_(self) -> Iterable[Tuple[float, np.ndarray]]: assert cirq.has_kraus(ReturnsMixture()) +def test_kraus_fallback_to_unitary(): + u = np.array([[1, 0], [1, 0]]) + + class ReturnsUnitary: + def _unitary_(self) -> np.ndarray: + return u + + np.testing.assert_equal(cirq.kraus(ReturnsUnitary()), (u,)) + np.testing.assert_equal(cirq.kraus(ReturnsUnitary(), None), (u,)) + np.testing.assert_equal(cirq.kraus(ReturnsUnitary(), NotImplemented), (u,)) + np.testing.assert_equal(cirq.kraus(ReturnsUnitary(), (1,)), (u,)) + np.testing.assert_equal(cirq.kraus(ReturnsUnitary(), LOCAL_DEFAULT), (u,)) + + assert cirq.has_kraus(ReturnsUnitary()) + + def test_serial_concatenation_default(): q1 = cirq.GridQubit(1, 1) @@ -201,74 +198,17 @@ def _mixture_(self): return None g = onlyDecompose() - c = (cirq.unitary(cirq.Circuit([cirq.Y.on(q1), cirq.X.on(q2)])),) + c = cirq.kraus_to_superoperator((cirq.unitary(cirq.Circuit([cirq.Y.on(q1), cirq.X.on(q2)])),)) - np.testing.assert_equal(cirq.kraus(g), c) - np.testing.assert_equal(cirq.kraus(g, None), c) - np.testing.assert_equal(cirq.kraus(g, NotImplemented), c) - np.testing.assert_equal(cirq.kraus(g, (1,)), c) - np.testing.assert_equal(cirq.kraus(g, LOCAL_DEFAULT), c) + np.testing.assert_almost_equal(cirq.kraus_to_superoperator(cirq.kraus(g)), c) + np.testing.assert_almost_equal(cirq.kraus_to_superoperator(cirq.kraus(g, None)), c) + np.testing.assert_almost_equal(cirq.kraus_to_superoperator(cirq.kraus(g, NotImplemented)), c) + np.testing.assert_almost_equal(cirq.kraus_to_superoperator(cirq.kraus(g, (1,))), c) + np.testing.assert_almost_equal(cirq.kraus_to_superoperator(cirq.kraus(g, LOCAL_DEFAULT)), c) assert cirq.has_kraus(g) -def test_kraus_combinatorial_explosion(): - q1 = cirq.GridQubit(1, 1) - - class defaultGate(cirq.Gate): - def num_qubits(self): - return 1 - - def _kraus_(self): - # for one qubit the upper limit is 16 elements - ls = [np.array([[1, 0], [0, 1]])] * 16 - ls.append(np.array([[1, 0], [0, 1]])) - return tuple(ls) - - class onlyDecompose: - def _decompose_(self): - return [cirq.Y.on(q1), defaultGate().on(q1)] - - def _unitary_(self): - return None - - def _mixture_(self): - return NotImplemented - - with pytest.raises(AssertionError, match="combinatorial explosion."): - _ = cirq.kraus(onlyDecompose()) - - -def test_kraus_empty_decompose(): - g = cirq.PauliString({}) ** 2 - c = (cirq.unitary(g),) - - assert cirq.has_kraus(g) - - np.testing.assert_equal(cirq.kraus(g), c) - np.testing.assert_equal(cirq.kraus(g, None), c) - np.testing.assert_equal(cirq.kraus(g, NotImplemented), c) - np.testing.assert_equal(cirq.kraus(g, (1,)), c) - np.testing.assert_equal(cirq.kraus(g, LOCAL_DEFAULT), c) - - -def test_kraus_fallback_to_unitary(): - u = np.array([[1, 0], [1, 0]]) - - class ReturnsUnitary: - def _unitary_(self) -> np.ndarray: - return u - - assert cirq.has_kraus(ReturnsUnitary()) - np.testing.assert_equal(cirq.kraus(ReturnsUnitary()), (u,)) - np.testing.assert_equal(cirq.kraus(ReturnsUnitary(), None), (u,)) - np.testing.assert_equal(cirq.kraus(ReturnsUnitary(), NotImplemented), (u,)) - np.testing.assert_equal(cirq.kraus(ReturnsUnitary(), (1,)), (u,)) - np.testing.assert_equal(cirq.kraus(ReturnsUnitary(), LOCAL_DEFAULT), (u,)) - - assert cirq.has_kraus(ReturnsUnitary()) - - class HasKraus(cirq.SingleQubitGate): def _has_kraus_(self) -> bool: return True @@ -301,5 +241,4 @@ def test_has_kraus(cls): def test_has_kraus_when_decomposed(decomposed_cls): op = HasKrausWhenDecomposed(decomposed_cls).on(cirq.NamedQubit('test')) assert cirq.has_kraus(op) - if not cirq.has_unitary(op) and not cirq.has_mixture(op): - assert not cirq.has_kraus(op, allow_decompose=False) + assert not cirq.has_kraus(op, allow_decompose=False) From 01d2a187694b5c4ffda066993bdb4fe8994f5cb3 Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Mon, 25 Oct 2021 20:09:20 +0530 Subject: [PATCH 27/33] Removed `_channel_` uses. --- cirq-core/cirq/protocols/kraus_protocol.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index 36e52e6a829..b40d29b98bd 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -135,7 +135,7 @@ def kraus( method returned NotImplemented) and also no default value was specified. """ - result = _gettr_helper(val, ['_kraus_', '_channel_']) + result = _gettr_helper(val, ['_kraus_']) if result is not None and result is not NotImplemented: return result @@ -211,7 +211,7 @@ def has_kraus(val: Any, *, allow_decompose: bool = True) -> bool: Returns: Whether or not `val` has a Kraus representation. """ - result = _gettr_helper(val, ['_has_kraus_', '_has_channel_', '_kraus_', '_channel_']) + result = _gettr_helper(val, ['_has_kraus_', '_has_channel_', '_kraus_']) if result is not None and result is not NotImplemented and result: return True @@ -268,12 +268,6 @@ def _gettr_helper(val: Any, gett_str_list: Sequence[str]): gettr = getattr(val, gettr_str, None) if gettr is None: continue - if 'channel' in gettr_str: - warnings.warn( - f'{gettr_str} is deprecated and will be removed in cirq 0.13, rename to ' - f'{gettr_str.replace("channel", "kraus")}', - DeprecationWarning, - ) result = gettr() if result is NotImplemented: notImplementedFlag = True From e17b162396e58aacf1f932d9f889bbca41444ea1 Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Mon, 25 Oct 2021 20:28:25 +0530 Subject: [PATCH 28/33] Removed unused imports. --- cirq-core/cirq/protocols/kraus_protocol.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index b40d29b98bd..8fafc550b8a 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -15,8 +15,6 @@ """Protocol and methods for obtaining Kraus representation of quantum channels.""" from typing import Any, Sequence, Tuple, TypeVar, Union -import warnings - import numpy as np from typing_extensions import Protocol From c855084be1511f04833170902cf5d98c7ce86cb6 Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Wed, 5 Jan 2022 13:54:57 +0530 Subject: [PATCH 29/33] Using superoperator instead of kraus during serial concatenation. --- cirq-core/cirq/protocols/kraus_protocol.py | 64 ++++++------------- .../cirq/protocols/kraus_protocol_test.py | 10 +-- cirq-core/cirq/protocols/mixture_protocol.py | 9 +-- 3 files changed, 27 insertions(+), 56 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index 8fafc550b8a..153c3e2903d 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -17,15 +17,15 @@ from typing import Any, Sequence, Tuple, TypeVar, Union import numpy as np from typing_extensions import Protocol +from functools import reduce from cirq._doc import doc_private -from cirq.protocols.decompose_protocol import ( - _try_decompose_into_operations_and_qubits, -) -from cirq.protocols import mixture_protocol + +from cirq.protocols import mixture_protocol, decompose_protocol from cirq.type_workarounds import NotImplementedType -from cirq.qis.channels import kraus_to_superoperator, superoperator_to_kraus +from cirq import qis +from cirq.ops import Moment # This is a special indicator value used by the channel method to determine @@ -133,6 +133,7 @@ def kraus( method returned NotImplemented) and also no default value was specified. """ + result = _gettr_helper(val, ['_kraus_']) if result is not None and result is not NotImplemented: return result @@ -141,16 +142,17 @@ def kraus( if mixture_result is not None and mixture_result is not NotImplemented: return tuple(np.sqrt(p) * u for p, u in mixture_result) - decomposed, qubits, _ = _try_decompose_into_operations_and_qubits(val) + decomposed, qubits, _ = decompose_protocol._try_decompose_into_operations_and_qubits(val) if decomposed is not None and decomposed != [val]: - kraus_list = list(map(lambda x: _kraus_to_superoperator(x, qubits, default), decomposed)) - if not any([_check_equality(x, default) for x in kraus_list]): - kraus_result = kraus_list[0] - for i in range(1, len(kraus_list)): - kraus_result = kraus_result @ kraus_list[i] - return tuple(superoperator_to_kraus(kraus_result)) + superoperator_list = list(map(lambda x: _moment_superoperator(x, qubits, None), decomposed)) + if not any([x is None for x in superoperator_list]): + superoperator_result = reduce(lambda x, y: x @ y, superoperator_list) + # superoperator_result = superoperator_list[0] + # for i in range(1, len(superoperator_list)): + # superoperator_result = superoperator_result @ superoperator_list[i] + return tuple(qis.superoperator_to_kraus(superoperator_result)) if default is not RaiseTypeErrorIfNotProvided: return default @@ -217,7 +219,7 @@ def has_kraus(val: Any, *, allow_decompose: bool = True) -> bool: return True if allow_decompose: - operations, _, _ = _try_decompose_into_operations_and_qubits(val) + operations, _, _ = decompose_protocol._try_decompose_into_operations_and_qubits(val) if operations is not None: return all(has_kraus(val) for val in operations) @@ -225,39 +227,9 @@ def has_kraus(val: Any, *, allow_decompose: bool = True) -> bool: return False -def _check_equality(x, y): - if type(x) != type(y): - return False - if type(x) not in [list, tuple, np.ndarray]: - return x == y - if type(x) == np.ndarray: - return x.shape == y.shape and np.all(x == y) - return False if len(x) != len(y) else all([_check_equality(a, b) for a, b in zip(x, y)]) - - -def _kraus_to_superoperator(op, qubits, default): - kraus_list = kraus(op, default) - if _check_equality(kraus_list, default): - return default - - val = None - op_q = op.qubits - found = False - for i in range(len(qubits)): - if qubits[i] in op_q: - if not found: - found = True - if val is None: - val = kraus_list - else: - val = tuple([np.kron(x, y) for x in val for y in kraus_list]) - - elif val is None: - val = (np.identity(2),) - else: - val = tuple([np.kron(x, np.identity(2)) for x in val]) - - return kraus_to_superoperator(val) +def _moment_superoperator(op, qubits, default): + superoperator_result = Moment(op).expand_to(qubits)._superoperator_() + return superoperator_result if superoperator_result is not NotImplemented else default def _gettr_helper(val: Any, gett_str_list: Sequence[str]): diff --git a/cirq-core/cirq/protocols/kraus_protocol_test.py b/cirq-core/cirq/protocols/kraus_protocol_test.py index 0d2344ea8fc..2b82a1d3370 100644 --- a/cirq-core/cirq/protocols/kraus_protocol_test.py +++ b/cirq-core/cirq/protocols/kraus_protocol_test.py @@ -163,10 +163,12 @@ def _unitary_(self): def _mixture_(self): return None - with pytest.raises(TypeError, match="returned NotImplemented"): - _ = cirq.kraus(onlyDecompose()) - assert cirq.kraus(onlyDecompose(), 0) == 0 - assert not cirq.has_kraus(onlyDecompose()) + # with pytest.raises(TypeError, match="returned NotImplemented."): + # _ = cirq.kraus(onlyDecompose()) + # x = cirq.kraus(onlyDecompose(), None) + # print(x) + assert not cirq.kraus(onlyDecompose(), 1) == 1 + # assert not cirq.has_kraus(onlyDecompose()) def test_serial_concatenation_circuit(): diff --git a/cirq-core/cirq/protocols/mixture_protocol.py b/cirq-core/cirq/protocols/mixture_protocol.py index 0496daef1e5..1771b3816de 100644 --- a/cirq-core/cirq/protocols/mixture_protocol.py +++ b/cirq-core/cirq/protocols/mixture_protocol.py @@ -18,10 +18,7 @@ from typing_extensions import Protocol from cirq._doc import doc_private -from cirq.protocols.decompose_protocol import ( - _try_decompose_into_operations_and_qubits, -) -from cirq.protocols import has_unitary_protocol, unitary_protocol +from cirq.protocols import has_unitary_protocol, unitary_protocol, decompose_protocol from cirq.type_workarounds import NotImplementedType # This is a special indicator value used by the inverse method to determine @@ -96,7 +93,7 @@ def mixture( if unitary_result is not None and unitary_result is not NotImplemented: return ((1.0, unitary_result),) - decomposed, qubits, _ = _try_decompose_into_operations_and_qubits(val) + decomposed, qubits, _ = decompose_protocol._try_decompose_into_operations_and_qubits(val) # serial concatenation if decomposed is not None and decomposed != [val]: @@ -156,7 +153,7 @@ def has_mixture(val: Any, *, allow_decompose: bool = True) -> bool: return True if allow_decompose: - operations, _, _ = _try_decompose_into_operations_and_qubits(val) + operations, _, _ = decompose_protocol._try_decompose_into_operations_and_qubits(val) if operations is not None: return all(has_mixture(val) for val in operations) From 321b72ac6e0cf83d00f5163671d9b229854df3e8 Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Tue, 11 Jan 2022 18:25:24 +0530 Subject: [PATCH 30/33] Minor linting changes. --- cirq-core/cirq/protocols/kraus_protocol.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index f5577b22273..d442fa937fe 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -14,7 +14,7 @@ """Protocol and methods for obtaining Kraus representation of quantum channels.""" -from typing import Any, Sequence, Tuple, TypeVar, Union, Optional, List +from typing import Any, Sequence, Tuple, TypeVar, Union, TYPE_CHECKING import numpy as np from typing_extensions import Protocol from functools import reduce @@ -27,6 +27,9 @@ from cirq import qis from cirq.ops import Moment +if TYPE_CHECKING: + import cirq + # This is a special indicator value used by the channel method to determine # whether or not the caller provided a 'default' argument. It must be of type @@ -224,7 +227,7 @@ def has_kraus(val: Any, *, allow_decompose: bool = True) -> bool: def _moment_superoperator( - op: Tuple[Optional[List['cirq.Operation']]], qubits: Sequence['cirq.Qid'], default: Any + op: Union['cirq.Operation'], qubits: Sequence['cirq.Qid'], default: Any ) -> Union[np.ndarray, TDefault]: superoperator_result = Moment(op).expand_to(qubits)._superoperator_() return superoperator_result if superoperator_result is not NotImplemented else default From 3359a962601f2997e81934bc89e2b9b461d91d30 Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Tue, 11 Jan 2022 22:18:33 +0530 Subject: [PATCH 31/33] Restoring `mixture` changes. --- cirq-core/cirq/protocols/kraus_protocol.py | 2 +- .../cirq/protocols/kraus_protocol_test.py | 16 +-- cirq-core/cirq/protocols/mixture_protocol.py | 122 +++--------------- .../cirq/protocols/mixture_protocol_test.py | 84 ------------ 4 files changed, 27 insertions(+), 197 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index d442fa937fe..7e3814368bb 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -15,9 +15,9 @@ """Protocol and methods for obtaining Kraus representation of quantum channels.""" from typing import Any, Sequence, Tuple, TypeVar, Union, TYPE_CHECKING +from functools import reduce import numpy as np from typing_extensions import Protocol -from functools import reduce from cirq._doc import doc_private diff --git a/cirq-core/cirq/protocols/kraus_protocol_test.py b/cirq-core/cirq/protocols/kraus_protocol_test.py index 2145ceb550c..f7476637e65 100644 --- a/cirq-core/cirq/protocols/kraus_protocol_test.py +++ b/cirq-core/cirq/protocols/kraus_protocol_test.py @@ -148,7 +148,7 @@ def _kraus_(self): return NotImplemented def _unitary_(self): - return None + return NotImplemented def _mixture_(self): return NotImplemented @@ -158,10 +158,10 @@ def _decompose_(self): return [cirq.Y.on(q1), defaultGate().on(q1)] def _unitary_(self): - return None + return NotImplemented def _mixture_(self): - return None + return NotImplemented with pytest.raises(TypeError, match="_unitary_ method."): _ = cirq.kraus(onlyDecompose()) @@ -180,22 +180,16 @@ def num_qubits(self): def _kraus_(self): return cirq.kraus(cirq.X) - def _unitary_(self): - return None - - def _mixture_(self): - return NotImplemented - class onlyDecompose: def _decompose_(self): circ = cirq.Circuit([cirq.Y.on(q1), defaultGate().on(q2)]) return cirq.decompose(circ) def _unitary_(self): - return None + return NotImplemented def _mixture_(self): - return None + return NotImplemented g = onlyDecompose() c = cirq.kraus_to_superoperator((cirq.unitary(cirq.Circuit([cirq.Y.on(q1), cirq.X.on(q2)])),)) diff --git a/cirq-core/cirq/protocols/mixture_protocol.py b/cirq-core/cirq/protocols/mixture_protocol.py index b29a592927c..646c0b0cd74 100644 --- a/cirq-core/cirq/protocols/mixture_protocol.py +++ b/cirq-core/cirq/protocols/mixture_protocol.py @@ -18,7 +18,8 @@ from typing_extensions import Protocol from cirq._doc import doc_private -from cirq.protocols import has_unitary_protocol, unitary_protocol, decompose_protocol +from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits +from cirq.protocols.has_unitary_protocol import has_unitary from cirq.type_workarounds import NotImplementedType # This is a special indicator value used by the inverse method to determine @@ -85,39 +86,20 @@ def mixture( does and this method returned `NotImplemented`. """ - mixture_result = _gettr_helper(val, ['_mixture_']) - if mixture_result is not None and mixture_result is not NotImplemented: - return mixture_result - - unitary_result = unitary_protocol.unitary(val, None) - if unitary_result is not None and unitary_result is not NotImplemented: - return ((1.0, unitary_result),) - - decomposed, qubits, _ = decompose_protocol._try_decompose_into_operations_and_qubits(val) - - # serial concatenation - if decomposed is not None and decomposed != [val]: - limit = (4 ** np.prod(len(qubits))) ** 2 - - mixture_list = list(map(lambda x: _mixture_tensor(x, qubits, default), decomposed)) - if not any([_check_equality(x, default) for x in mixture_list]): - mixture_result = mixture_list[0] - for i in range(1, len(mixture_list)): - mixture_result = [ - _product_mixture_pair(op_1, op_2) - for op_1 in mixture_result - for op_2 in mixture_list[i] - ] - assert ( - len(mixture_result) < limit - ), f"{val} mixture decomposition had combinatorial explosion." - else: - return tuple(mixture_result) + mixture_getter = getattr(val, '_mixture_', None) + result = NotImplemented if mixture_getter is None else mixture_getter() + if result is not NotImplemented: + return result + + unitary_getter = getattr(val, '_unitary_', None) + result = NotImplemented if unitary_getter is None else unitary_getter() + if result is not NotImplemented: + return ((1.0, result),) if default is not RaiseTypeErrorIfNotProvided: return default - if _gettr_helper(val, ['_unitary_', '_mixture_']) is None: + if mixture_getter is None and unitary_getter is None: raise TypeError(f"object of type '{type(val)}' has no _mixture_ or _unitary_ method.") raise TypeError( @@ -145,19 +127,21 @@ def has_mixture(val: Any, *, allow_decompose: bool = True) -> bool: has a `_mixture_` method return True if that has a non-default value. Returns False if neither function exists. """ - result = _gettr_helper(val, ['_has_mixture_', '_mixture_']) - if result is not None and result is not NotImplemented and result: - return True + mixture_getter = getattr(val, '_has_mixture_', None) + result = NotImplemented if mixture_getter is None else mixture_getter() + if result is not NotImplemented: + return result - if has_unitary_protocol.has_unitary(val, allow_decompose=False): + if has_unitary(val, allow_decompose=False): return True if allow_decompose: - operations, _, _ = decompose_protocol._try_decompose_into_operations_and_qubits(val) + operations, _, _ = _try_decompose_into_operations_and_qubits(val) if operations is not None: return all(has_mixture(val) for val in operations) - return False + # No _has_mixture_ or _has_unitary_ function, use _mixture_ instead. + return mixture(val, None) is not None def validate_mixture(supports_mixture: SupportsMixture): @@ -177,68 +161,4 @@ def validate_probability(p, p_str): validate_probability(p, f"{val}'s probability") total += p if not np.isclose(total, 1.0): - raise ValueError("Sum of probabilities of a mixture was not 1.0") - - -def _check_equality(x, y): - if type(x) != type(y): - return False - if type(x) not in [list, tuple, np.ndarray]: - return x == y - if type(x) == np.ndarray: - return x.shape == y.shape and np.all(x == y) - return False if len(x) != len(y) else all([_check_equality(a, b) for a, b in zip(x, y)]) - - -def _tensor_mixture_pair(x, y): - p_new = x[0] * y[0] - mat_new = np.kron(x[1], y[1]) - return (p_new, mat_new) - - -def _product_mixture_pair(x, y): - p_new = x[0] * y[0] - mat_new = y[1].dot(x[1]) - return (p_new, mat_new) - - -def _mixture_tensor(op, qubits, default): - mixture_list = mixture(op, default) - if _check_equality(mixture_list, default): - return default - - val = None - op_q = op.qubits - found = False - for i in range(len(qubits)): - if qubits[i] in op_q: - if not found: - found = True - if val is None: - val = mixture_list - else: - val = tuple([_tensor_mixture_pair(x, y) for x in val for y in mixture_list]) - - elif val is None: - val = ((1, np.identity(2)),) - else: - val = tuple([_tensor_mixture_pair(x, (1, np.identity(2))) for x in val]) - - return val - - -def _gettr_helper(val: Any, gett_str_list: Sequence[str]): - notImplementedFlag = False - for gettr_str in gett_str_list: - gettr = getattr(val, gettr_str, None) - if gettr is None: - continue - result = gettr() - if result is NotImplemented: - notImplementedFlag = True - elif result is not None: - return result - - if notImplementedFlag: - return NotImplemented - return None + raise ValueError('Sum of probabilities of a mixture was not 1.0') diff --git a/cirq-core/cirq/protocols/mixture_protocol_test.py b/cirq-core/cirq/protocols/mixture_protocol_test.py index e5d8026b116..cadbe0806c6 100644 --- a/cirq-core/cirq/protocols/mixture_protocol_test.py +++ b/cirq-core/cirq/protocols/mixture_protocol_test.py @@ -118,90 +118,6 @@ def test_valid_mixture(): cirq.validate_mixture(ReturnsValidTuple()) -def test_combinatorial_explosion(): - q1 = cirq.GridQubit(1, 1) - - class defaultGate(cirq.Gate): - def num_qubits(self): - return 1 - - def _mixture_(self): - # for one qubit the upper limit is 16 elements - ls = [(0.1, np.array([[1, 0], [0, 1]]))] * 16 - ls.append((0.84, np.array([[1, 0], [0, 1]]))) - return tuple(ls) - - class onlyDecompose: - def _decompose_(self): - return [cirq.Y.on(q1), defaultGate().on(q1)] - - def _unitary_(self): - return None - - def _mixture_(self): - return NotImplemented - - with pytest.raises(AssertionError, match="combinatorial explosion."): - _ = cirq.mixture(onlyDecompose()) - - -def test_serial_concatenation_default(): - q1 = cirq.GridQubit(1, 1) - - class defaultGate(cirq.Gate): - def num_qubits(self): - return 1 - - def _unitary_(self): - return None - - def _mixture_(self): - return NotImplemented - - class onlyDecompose: - def _decompose_(self): - return [cirq.Y.on(q1), defaultGate().on(q1)] - - def _unitary_(self): - return None - - def _mixture_(self): - return NotImplemented - - default = (1.0, np.array([[1, 0], [0, 1]])) - - with pytest.raises(TypeError, match="returned NotImplemented"): - _ = cirq.mixture(onlyDecompose()) - assert cirq.mixture(onlyDecompose(), 0) == 0 - np.testing.assert_equal(cirq.mixture(onlyDecompose(), default), default) - assert not cirq.has_mixture(onlyDecompose()) - - -def test_serial_concatenation_circuit(): - q1 = cirq.GridQubit(1, 1) - q2 = cirq.GridQubit(1, 2) - - class onlyDecompose: - def _decompose_(self): - circ = cirq.Circuit([cirq.Y.on(q1), cirq.X.on(q2)]) - return cirq.decompose(circ) - - def _unitary_(self): - return None - - def _mixture_(self): - return NotImplemented - - c = ((1.0, np.array(cirq.unitary(cirq.Circuit([cirq.Y.on(q1), cirq.X.on(q2)])))),) - - np.testing.assert_equal(cirq.mixture(onlyDecompose()), c) - np.testing.assert_equal(cirq.mixture(onlyDecompose(), None), c) - np.testing.assert_equal(cirq.mixture(onlyDecompose(), NotImplemented), c) - np.testing.assert_equal(cirq.mixture(onlyDecompose(), (1,)), c) - - assert cirq.has_mixture(onlyDecompose()) - - @pytest.mark.parametrize( 'val,message', ( From bac774fda1489df026adc14f4f4479d874d9b3db Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Wed, 12 Jan 2022 20:09:16 +0530 Subject: [PATCH 32/33] Documentation changes and minor fixes. --- cirq-core/cirq/protocols/kraus_protocol.py | 59 ++++++++++++++-------- 1 file changed, 39 insertions(+), 20 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index 7e3814368bb..fdaed206238 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -115,21 +115,34 @@ def kraus( where $I$ is the identity matrix. The matrices $A_k$ are sometimes called Kraus or noise operators. + Determines the Kraus representation of `val` by the following strategies: + + 1. Try to use `val._has_kraus_()`. + Case a) Method not present or returns `None`. + Continue to next strategy. + Case b) Returns the Kraus operator. + Method returns the result. + + 2. Try to use `mixture_protocol.mixture()`. + Case a) Method not present or returns `None`. + Continue to next strategy. + Case b) Method returns a valid mixture. + Method converts mixture into kraus and returns. + + 3. Try to use serial concatenation recursively. + Case a) One or more decomposed operators doesn't have Kraus. + `val` does not have a kraus representation. + Case b) All decomposed operators have Kraus representation. + Serially concatenate and return the result. + Args: - val: The value to describe by a channel. + val: The value to describe by Kraus representation. default: Determines the fallback behavior when `val` doesn't have - a channel. If `default` is not set, a TypeError is raised. If - default is set to a value, that value is returned. + a representation. If `default` is not set, a TypeError is raised. + If default is set to a value, that value is returned. Returns: - If `val` has a `_kraus_` method and its result is not NotImplemented, - that result is returned. Otherwise, if `val` has a `_mixture_` method - and its results is not NotImplement a tuple made up of channel - corresponding to that mixture being a probabilistic mixture of unitaries - is returned. Otherwise, if `val` has a `_unitary_` method and - its result is not NotImplemented a tuple made up of that result is - returned. Otherwise, if a default value was specified, the default - value is returned. + The kraus representation of `val`. Raises: TypeError: `val` doesn't have a _kraus_, _unitary_, _mixture_ method @@ -175,23 +188,29 @@ def has_kraus(val: Any, *, allow_decompose: bool = True) -> bool: Determines whether `val` has a Kraus representation by attempting the following strategies: - 1. Try to use `val.has_channel_()`. - Case a) Method not present or returns `None`. + 1. Try to use `val._has_kraus_()`. + Case a) Method not present or returns `None` or returns `False`. + Continue to next strategy. + Case b) Method returns `True`. + return True. + + 1. Try to use `val._has_channel_()`. + Case a) Method not present or returns `None` or returns `False`. Continue to next strategy. Case b) Method returns `True`. - Kraus. + return True. 2. Try to use `val._kraus_()`. Case a) Method not present or returns `NotImplemented`. Continue to next strategy. Case b) Method returns a 3D array. - Kraus. + return True. - 3. Try to use `cirq.mixture()`. - Case a) Method not present or returns `NotImplemented`. + 3. Try to use `cirq.has_mixture()`. + Case a) Method not present or returns `None` or returns `False`. Continue to next strategy. - Case b) Method returns a 3D array. - Kraus. + Case b) Method returns `True`. + return True. 4. If decomposition is allowed apply recursion and check. @@ -220,7 +239,7 @@ def has_kraus(val: Any, *, allow_decompose: bool = True) -> bool: if allow_decompose: operations, _, _ = decompose_protocol._try_decompose_into_operations_and_qubits(val) - if operations is not None: + if operations is not None and operations != [val]: return all(has_kraus(val) for val in operations) return False From 60418350ea278252c944861bb405d72182672a96 Mon Sep 17 00:00:00 2001 From: Zshan0 Date: Sun, 23 Jan 2022 22:53:19 +0530 Subject: [PATCH 33/33] Minor changes. --- cirq-core/cirq/protocols/kraus_protocol.py | 14 ++++++++++---- cirq-core/cirq/testing/consistent_kraus.py | 1 + 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index fdaed206238..181a990eb67 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -160,7 +160,7 @@ def kraus( decomposed, qubits, _ = decompose_protocol._try_decompose_into_operations_and_qubits(val) - if decomposed is not None and decomposed != [val]: + if decomposed is not None and decomposed != [val] and decomposed != []: superoperator_list = [_moment_superoperator(x, qubits, None) for x in decomposed] if not any([x is None for x in superoperator_list]): @@ -188,6 +188,8 @@ def has_kraus(val: Any, *, allow_decompose: bool = True) -> bool: Determines whether `val` has a Kraus representation by attempting the following strategies: + #TODO + 1. Try to use `val._has_kraus_()`. Case a) Method not present or returns `None` or returns `False`. Continue to next strategy. @@ -230,8 +232,12 @@ def has_kraus(val: Any, *, allow_decompose: bool = True) -> bool: Returns: Whether or not `val` has a Kraus representation. """ - result = _gettr_helper(val, ['_has_kraus_', '_has_channel_', '_kraus_']) - if result is not None and result is not NotImplemented and result: + result = _gettr_helper(val, ['_has_kraus_', '_has_channel_']) + if result is not None and result is not NotImplemented: + return result + + result = _gettr_helper(val, ['_kraus_']) + if result is not None and result is not NotImplemented: return True if mixture_protocol.has_mixture(val, allow_decompose=False): @@ -252,7 +258,7 @@ def _moment_superoperator( return superoperator_result if superoperator_result is not NotImplemented else default -def _gettr_helper(val: Any, gett_str_list: Sequence[str]): +def _gettr_helper(val: Any, gett_str_list: Sequence[str]) -> Any: notImplementedFlag = False for gettr_str in gett_str_list: gettr = getattr(val, gettr_str, None) diff --git a/cirq-core/cirq/testing/consistent_kraus.py b/cirq-core/cirq/testing/consistent_kraus.py index 225c7209763..3c596eb4562 100644 --- a/cirq-core/cirq/testing/consistent_kraus.py +++ b/cirq-core/cirq/testing/consistent_kraus.py @@ -36,6 +36,7 @@ def assert_kraus_is_consistent_with_unitary(val: Any, ignoring_global_phase: boo # there is unitary and hence must have kraus operator assert has_krs + assert len(krs) == 1 actual = krs[0] if ignoring_global_phase: