Skip to content

Commit

Permalink
Code Quality Fixes in StatePreparationChannel (quantumlib#4503)
Browse files Browse the repository at this point in the history
* Fixing Minor Errors

Fixed many of the nits mentioned in quantumlib#4482, value-equality not a part of this commit.

* Approx Equality support added

State Preparation Channel got equality comparison similar to that of Matrix Gate.

* Added name parameter to state preparation channel

name parameter added to constructor, JSON, and repr serialization as well as test data.

* Custom names in serialization tests

The name has been changed in the JSON and repr protocol tests to make the older version fail.

* Format and Lint Fix

Had to run pylint.

* Adding tests to ignore name

Two gates with different names are still equal, added tests for that.
  • Loading branch information
AnimeshSinha1309 authored Sep 28, 2021
1 parent 52bf065 commit 24308de
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 20 deletions.
42 changes: 28 additions & 14 deletions cirq/ops/state_preparation_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Quantum gates to prepare a given target state."""

from typing import Any, Dict, Tuple, TYPE_CHECKING
from typing import Any, Dict, Tuple, Iterable, TYPE_CHECKING

import numpy as np

Expand All @@ -29,12 +29,12 @@
class StatePreparationChannel(raw_types.Gate):
"""A channel which prepares any state provided as the state vector on it's target qubits."""

def __init__(self, target_state: np.ndarray, name: str = "StatePreparation") -> None:
def __init__(self, target_state: np.ndarray, *, name: str = "StatePreparation") -> None:
"""Initializes a State Preparation channel.
Args:
target_state: The state vector that this gate should prepare.
name: the name of the gate
name: the name of the gate, used when printing it in the circuit diagram
Raises:
ValueError: if the array is not 1D, or does not have 2**n elements for some integer n.
Expand All @@ -51,8 +51,7 @@ def __init__(self, target_state: np.ndarray, name: str = "StatePreparation") ->
self._name = name
self._qid_shape = (2,) * n

@staticmethod
def _has_unitary_() -> bool:
def _has_unitary_(self) -> bool:
"""Checks and returns if the gate has a unitary representation.
It doesn't, since the resetting of the channels is a non-unitary operations,
it involves measurement."""
Expand All @@ -63,19 +62,23 @@ def _json_dict_(self) -> Dict[str, Any]:
return {
'cirq_type': self.__class__.__name__,
'target_state': self._state.tolist(),
'name': self._name,
}

@classmethod
def _from_json_dict_(cls, target_state, **kwargs):
def _from_json_dict_(
cls, target_state: np.ndarray, name: str, **kwargs
) -> 'StatePreparationChannel':
"""Recreates the channel object from it's serialized form
Args:
target_state: the state to prepare using this channel
name: the name of the gate for printing in circuit diagrams
kwargs: other keyword arguments, ignored
"""
return cls(target_state=np.array(target_state))
return cls(target_state=np.array(target_state), name=name)

def _num_qubits_(self):
def _num_qubits_(self) -> int:
return self._num_qubits

def _qid_shape_(self) -> Tuple[int, ...]:
Expand All @@ -92,12 +95,12 @@ def _circuit_diagram_info_(
)
return protocols.CircuitDiagramInfo(wire_symbols=symbols)

@staticmethod
def _has_kraus_():
def _has_kraus_(self) -> bool:
return True

def _kraus_(self):
def _kraus_(self) -> Iterable[np.ndarray]:
"""Returns the Kraus operator for this gate
The Kraus Operator is |Psi><i| for all |i>, where |Psi> is the target state.
This allows is to take any input state to the target state.
The operator satisfies the completeness relation Sum(E^ E) = I.
Expand All @@ -108,13 +111,24 @@ def _kraus_(self):
return operator

def __repr__(self) -> str:
return f'cirq.StatePreparationChannel({proper_repr(self._state)})'
return (
f'cirq.StatePreparationChannel('
f'target_state={proper_repr(self.state)}, name="{self._name}")'
)

def __str__(self) -> str:
return f'StatePreparationChannel({self.state})'

def _approx_eq_(self, other: Any, atol) -> bool:
if not isinstance(other, StatePreparationChannel):
return False
return np.allclose(self.state, other.state, rtol=0, atol=atol)

def __eq__(self, other) -> bool:
if not isinstance(other, StatePreparationChannel):
return False
return np.allclose(self.state, other.state)
return np.array_equal(self.state, other.state)

@property
def state(self):
def state(self) -> np.ndarray:
return self._state
31 changes: 27 additions & 4 deletions cirq/ops/state_preparation_channel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,8 @@ def test_gate_params():
assert gate.num_qubits() == 2
assert not gate._has_unitary_()
assert gate._has_kraus_()
assert (
repr(gate)
== 'cirq.StatePreparationChannel(np.array([(1+0j), 0j, 0j, 0j], dtype=np.complex128))'
)
assert str(gate) == 'StatePreparationChannel([1.+0.j 0.+0.j 0.+0.j 0.+0.j])'
cirq.testing.assert_equivalent_repr(gate)


def test_gate_error_handling():
Expand All @@ -126,3 +124,28 @@ def test_equality_of_gates():
gate_2 = cirq.StatePreparationChannel(state)
assert gate_1 == gate_2, "Equal state not leading to same gate"
assert not gate_1 == state, "Incompatible objects shouldn't be equal"
state = np.array([0, 1, 0, 0], dtype=np.complex64)
gate_3 = cirq.StatePreparationChannel(state, name='gate_a')
gate_4 = cirq.StatePreparationChannel(state, name='gate_b')
assert gate_3 == gate_4, "Equal state with different names not leading to same gate"
assert gate_1 != gate_3, "Different states shouldn't lead to same gate"


def test_approx_equality_of_gates():
state = np.array([1, 0, 0, 0], dtype=np.complex64)
gate_1 = cirq.StatePreparationChannel(state)
gate_2 = cirq.StatePreparationChannel(state)
assert cirq.approx_eq(gate_1, gate_2), "Equal state not leading to same gate"
assert not cirq.approx_eq(gate_1, state), "Different object types cannot be approx equal"
perturbed_state = np.array([1 - 1e-9, 1e-10, 0, 0], dtype=np.complex64)
gate_3 = cirq.StatePreparationChannel(perturbed_state)
assert cirq.approx_eq(gate_3, gate_1), "Almost equal states should lead to the same gate"
different_state = np.array([1 - 1e-5, 1e-4, 0, 0], dtype=np.complex64)
gate_4 = cirq.StatePreparationChannel(different_state)
assert not cirq.approx_eq(gate_4, gate_1), "Different states should not lead to the same gate"
assert cirq.approx_eq(
gate_4, gate_1, atol=1e-3
), "Gates with difference in states under the tolerance aren't equal"
assert not cirq.approx_eq(
gate_4, gate_1, atol=1e-6
), "Gates with difference in states over the tolerance are equal"
3 changes: 2 additions & 1 deletion cirq/protocols/json_test_data/StatePreparationChannel.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@
"real": 0.0,
"imag": 0.0
}
]
],
"name": "StatePrepare"
}
2 changes: 1 addition & 1 deletion cirq/protocols/json_test_data/StatePreparationChannel.repr
Original file line number Diff line number Diff line change
@@ -1 +1 @@
cirq.StatePreparationChannel(np.array([(1+0j), 0j, 0j, 0j], dtype=np.complex128))
cirq.StatePreparationChannel(target_state=np.array([(1+0j), 0j, 0j, 0j], dtype=np.complex128), name="StatePrepare")

0 comments on commit 24308de

Please sign in to comment.