Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Density Matrix and State Vector Simulators to work when an operation allocates new qubits as part of its decomposition #6108

Merged
merged 25 commits into from
Jun 7, 2023
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
c2ff8bd
WIP add factoring and kron methods to sim state for adding and removi…
senecameeks May 24, 2023
e673105
add test cases
senecameeks May 26, 2023
80fbbbc
add delegating gate test case
senecameeks May 26, 2023
5f15d97
update test
senecameeks May 26, 2023
cd7a573
all tests pass
senecameeks May 26, 2023
1cd81dd
add test case for unitary Y
senecameeks May 26, 2023
d9a46cc
nit
senecameeks May 26, 2023
908ee77
addresses PR comments by adding empty checks. Applys formatter. Subse…
senecameeks Jun 1, 2023
c903d32
nit formatting changes, add docustring with input/output for remove_q…
senecameeks Jun 1, 2023
b92d6d8
Merge branch 'master' of https://github.com/quantumlib/cirq
senecameeks Jun 2, 2023
70162d7
Merge branch 'master' into master
tanujkhattar Jun 5, 2023
530a69e
merge this branch and tanujkhattar@ccde689
senecameeks Jun 6, 2023
1be80c8
merging branches, adding test coverage in next push
senecameeks Jun 6, 2023
2b06182
Merge branch 'master' of https://github.com/quantumlib/cirq
senecameeks Jun 6, 2023
4adf75b
Merge branch 'master' of github.com:senecameeks/Cirq
senecameeks Jun 6, 2023
f70753b
format files
senecameeks Jun 6, 2023
f74f760
add coverage tests
senecameeks Jun 6, 2023
5d31ce3
change assert
senecameeks Jun 7, 2023
096fc14
coverage and type check tests should pass
senecameeks Jun 7, 2023
964dc69
incorporate tanujkhattar@1db8ac5
senecameeks Jun 7, 2023
40d5b33
nit
senecameeks Jun 7, 2023
0ab07c2
Merge branch 'master' into master
tanujkhattar Jun 7, 2023
40369ee
remove block comment
senecameeks Jun 7, 2023
ddd6fd9
Merge branch 'master' of github.com:senecameeks/Cirq
senecameeks Jun 7, 2023
774d715
add coverage
senecameeks Jun 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cirq-core/cirq/protocols/act_on_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def act_on(

arg_fallback = getattr(sim_state, '_act_on_fallback_', None)
if arg_fallback is not None:
qubits = action.qubits if isinstance(action, ops.Operation) else qubits
qubits = action.qubits if is_op else qubits
result = arg_fallback(action, qubits=qubits, allow_decompose=allow_decompose)
if result is True:
return
Expand Down
8 changes: 8 additions & 0 deletions cirq-core/cirq/sim/density_matrix_simulation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,14 @@ def __init__(
)
super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)

def add_qubits(self, qubits: Sequence['cirq.Qid']):
ret = super().add_qubits(qubits)
return (
self.kronecker_product(type(self)(qubits=qubits), inplace=True)
if ret is NotImplemented
else ret
)

def _act_on_fallback_(
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
) -> bool:
Expand Down
43 changes: 41 additions & 2 deletions cirq-core/cirq/sim/simulation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,38 @@ def create_merged_state(self) -> Self:
"""Creates a final merged state."""
return self

def add_qubits(self: Self, qubits: Sequence['cirq.Qid']) -> None:
"""Add qubits to a new state space and take the kron product.

Note that only Density Matrix and State Vector simulators
override this function.

Args:
qubits: Sequence of qubits to be added.

Returns:
NotImplemented: If the subclass does not implement this method.

Raises:
ValueError: If a qubit being added is already tracked.
"""
if any(q in self.qubits for q in qubits):
raise ValueError(f"Qubit to add {qubits} should not already be tracked.")
return NotImplemented

def remove_qubits(self: Self, qubits: Sequence['cirq.Qid']) -> Self:
"""Remove qubits from the state space.

Args:
qubits: Sequence of qubits to be added.

Returns:
A new Simulation State with qubits removed. Or
`self` if there are no qubits to remove."""
if qubits is None or not qubits:
return self
return self.factor(qubits, inplace=True)[1]

def kronecker_product(self, other: Self, *, inplace=False) -> Self:
"""Joins two state spaces together."""
args = self if inplace else copy.copy(self)
Expand Down Expand Up @@ -294,13 +326,20 @@ def strat_act_on_from_apply_decompose(
val: Any, args: 'cirq.SimulationState', qubits: Sequence['cirq.Qid']
) -> bool:
operations, qubits1, _ = _try_decompose_into_operations_and_qubits(val)
assert len(qubits1) == len(qubits)
qubit_map = {q: qubits[i] for i, q in enumerate(qubits1)}
if operations is None:
return NotImplemented
assert len(qubits1) == len(qubits)
all_qubits = frozenset([q for op in operations for q in op.qubits])
qubit_map = dict(zip(all_qubits, all_qubits))
qubit_map.update(dict(zip(qubits1, qubits)))
new_ancilla = tuple(q for q in sorted(all_qubits.difference(qubits)) if q not in args.qubits)
args = args.add_qubits(new_ancilla)
if args is NotImplemented:
return NotImplemented
for operation in operations:
operation = operation.with_qubits(*[qubit_map[q] for q in operation.qubits])
protocols.act_on(operation, args)
args.remove_qubits(new_ancilla)
return True


Expand Down
154 changes: 150 additions & 4 deletions cirq-core/cirq/sim/simulation_state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import cirq
from cirq.sim import simulation_state
from cirq.testing import PhaseUsingCleanAncilla, PhaseUsingDirtyAncilla


class DummyQuantumState(cirq.QuantumStateRepresentation):
Expand All @@ -31,17 +32,79 @@ def measure(self, axes, seed=None):
def reindex(self, axes):
return self

def kron(self, other):
return self

def factor(self, axes, validate=True, atol=1e-07):
return (self, self)


class DummySimulationState(cirq.SimulationState):
def __init__(self):
super().__init__(state=DummyQuantumState(), qubits=cirq.LineQubit.range(2))
def __init__(self, qubits=cirq.LineQubit.range(2)):
super().__init__(state=DummyQuantumState(), qubits=qubits)

def _act_on_fallback_(
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
) -> bool:
return True


class AncillaZ(cirq.Gate):
def __init__(self, exponent=1):
self._exponent = exponent

def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
ancilla = cirq.NamedQubit('Ancilla')
yield cirq.CX(qubits[0], ancilla)
yield cirq.Z(ancilla) ** self._exponent
yield cirq.CX(qubits[0], ancilla)


class AncillaH(cirq.Gate):
def __init__(self, exponent=1):
self._exponent = exponent

def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
ancilla = cirq.NamedQubit('Ancilla')
yield cirq.H(ancilla) ** self._exponent
yield cirq.CX(ancilla, qubits[0])
yield cirq.H(ancilla) ** self._exponent


class AncillaY(cirq.Gate):
def __init__(self, exponent=1):
self._exponent = exponent

def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
ancilla = cirq.NamedQubit('Ancilla')
yield cirq.Y(ancilla) ** self._exponent
yield cirq.CX(ancilla, qubits[0])
yield cirq.Y(ancilla) ** self._exponent


class DelegatingAncillaZ(cirq.Gate):
def __init__(self, exponent=1):
self._exponent = exponent

def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
a = cirq.NamedQubit('a')
yield cirq.CX(qubits[0], a)
yield AncillaZ(self._exponent).on(a)
yield cirq.CX(qubits[0], a)


def test_measurements():
args = DummySimulationState()
args.measure([cirq.LineQubit(0)], "test", [False], {})
Expand All @@ -57,8 +120,9 @@ def _decompose_(self, qubits):
yield cirq.X(*qubits)

args = DummySimulationState()
assert simulation_state.strat_act_on_from_apply_decompose(
Composite(), args, [cirq.LineQubit(0)]
assert (
simulation_state.strat_act_on_from_apply_decompose(Composite(), args, [cirq.LineQubit(0)])
is NotImplemented
)


Expand Down Expand Up @@ -101,3 +165,85 @@ def test_field_getters():
args = DummySimulationState()
assert args.prng is np.random
assert args.qubit_map == {q: i for i, q in enumerate(cirq.LineQubit.range(2))}


@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
def test_ancilla_z(exp):
q = cirq.LineQubit(0)
test_circuit = cirq.Circuit(AncillaZ(exp).on(q))

control_circuit = cirq.Circuit(cirq.ZPowGate(exponent=exp).on(q))

assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)


@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
def test_ancilla_y(exp):
q = cirq.LineQubit(0)
test_circuit = cirq.Circuit(AncillaY(exp).on(q))

control_circuit = cirq.Circuit(cirq.Y(q))
control_circuit.append(cirq.Y(q))
control_circuit.append(cirq.XPowGate(exponent=exp).on(q))

assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)


@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
def test_borrowable_qubit(exp):
q = cirq.LineQubit(0)
test_circuit = cirq.Circuit()
test_circuit.append(cirq.H(q))
test_circuit.append(cirq.X(q))
test_circuit.append(AncillaH(exp).on(q))

control_circuit = cirq.Circuit(cirq.H(q))

assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)


@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
def test_delegating_gate_qubit(exp):
q = cirq.LineQubit(0)

test_circuit = cirq.Circuit()
test_circuit.append(cirq.H(q))
test_circuit.append(DelegatingAncillaZ(exp).on(q))

control_circuit = cirq.Circuit(cirq.H(q))
control_circuit.append(cirq.ZPowGate(exponent=exp).on(q))

assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)


@pytest.mark.parametrize('num_ancilla', [1, 2, 3])
def test_phase_using_dirty_ancilla(num_ancilla: int):
q = cirq.LineQubit(0)
anc = cirq.NamedQubit.range(num_ancilla, prefix='anc')

u = cirq.MatrixGate(cirq.testing.random_unitary(2 ** (num_ancilla + 1)))
test_circuit = cirq.Circuit(
u.on(q, *anc), PhaseUsingDirtyAncilla(ancilla_bitsize=num_ancilla).on(q)
)
control_circuit = cirq.Circuit(u.on(q, *anc), cirq.Z(q))
assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)


@pytest.mark.parametrize('num_ancilla', [1, 2, 3])
@pytest.mark.parametrize('theta', np.linspace(0, 2 * np.pi, 10))
def test_phase_using_clean_ancilla(num_ancilla: int, theta: float):
q = cirq.LineQubit(0)
u = cirq.MatrixGate(cirq.testing.random_unitary(2))
test_circuit = cirq.Circuit(
u.on(q), PhaseUsingCleanAncilla(theta=theta, ancilla_bitsize=num_ancilla).on(q)
)
control_circuit = cirq.Circuit(u.on(q), cirq.ZPowGate(exponent=theta).on(q))

assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)


def assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit) -> None:
for test_simulator in ['cirq.final_state_vector', 'cirq.final_density_matrix']:
test_sim = eval(test_simulator)(test_circuit)
control_sim = eval(test_simulator)(control_circuit)
cirq.testing.assert_allclose_up_to_global_phase(test_sim, control_sim, atol=1e-6)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@daxfohl Do you know if we mess up the global phase as part of simulations ? The factor / kron methods are messing up global phase since if I replace PhaseUsingCleanAncilla(theta=theta, ancilla_bitsize=num_ancilla).on(q) with cirq.MatrixGate(cirq.unitary(PhaseUsingCleanAncilla(theta=theta, ancilla_bitsize=num_ancilla))).on(q); the tests pass using assert np.allclose(test_sim, control_sim) instead of
cirq.testing.assert_allclose_up_to_global_phase(test_sim, control_sim, atol=1e-6).

Ideally, the resulting state vectors and density matrices should be identical and not equal upto global phase, but the factor / kron (or some other interaction of these methods in this PR) seems to mess up the global phase here.

Copy link
Collaborator

@tanujkhattar tanujkhattar Jun 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, so what is happening is as follows:

state_to_factor = (0.63+0.33j)|00⟩ + (-0.2+0.67j)|10⟩
# StateVectorSimulationState factors the above state into e, r s.t. 
e= (0.88+0.47j)|0⟩
r= 0.71|0⟩ + (0.14+0.69j)|1⟩
# Now, we just ignore `e` and return `r`; which has the wrong global phase. 

@daxfohl Do we not encounter this problem when factoring out measured / reset qubits when we construct SimulationProductState ? I guess we can just multiply the extra phase to correct it, but curious if this is known pattern elsewhere as well? Seems very relevant.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, I guess we never really "discard" qubits when constructing a SimulationProductState from unentangled states, and therefore this issue doesn't arise?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a fix in tanujkhattar@1db8ac5

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought I fixed this in #5847. What is the difference in this scenario?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Probably the linalg functions should take an extra bool parameter to force all phase into the reminder, and verify that the extracted axes are left in a classical basis state. All use cases of those functions, that's what they're ultimately trying to do.

8 changes: 8 additions & 0 deletions cirq-core/cirq/sim/state_vector_simulation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,14 @@ def __init__(
)
super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)

def add_qubits(self, qubits: Sequence['cirq.Qid']):
ret = super().add_qubits(qubits)
return (
self.kronecker_product(type(self)(qubits=qubits), inplace=True)
if ret is NotImplemented
else ret
)

def _act_on_fallback_(
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
) -> bool:
Expand Down