Skip to content

Commit

Permalink
Fix phase in factor (quantumlib#5847)
Browse files Browse the repository at this point in the history
Fixes quantumlib#5834

Phase in input tensor was being allocated to both output tensors. This PR readjusts to remove the phase from the remainder tensor (a nice thing about this is that if the remainder _is_ nothing but a global phase, then it's `1`). 

The validation is updated to check `np.all_close` rather than `allclose_up_to_global_phase` (I also ran all simulator unit tests with "validate=True" locally and they all passed with these changes), and a unit test for quantumlib#5834 is added.

Also added check that we aren't factoring out the last qubit during simulation and losing the remaining phase. This isn't strictly necessary since the remainder is guaranteed to be `1`, but prevent any surprises if that changes (and may as well skip it anyway for perf sake).
  • Loading branch information
daxfohl authored and rht committed May 1, 2023
1 parent 0dfd7c6 commit 1056cae
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 5 deletions.
6 changes: 3 additions & 3 deletions cirq-core/cirq/linalg/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,12 +589,12 @@ def factor_state_vector(
slices1 = (slice(None),) * n_axes + pivot[n_axes:]
slices2 = pivot[:n_axes] + (slice(None),) * (t1.ndim - n_axes)
extracted = t1[slices1]
extracted = extracted / np.sum(abs(extracted) ** 2) ** 0.5
extracted = extracted / np.linalg.norm(extracted)
remainder = t1[slices2]
remainder = remainder / np.sum(abs(remainder) ** 2) ** 0.5
remainder = remainder / (np.linalg.norm(remainder) * t1[pivot] / abs(t1[pivot]))
if validate:
t2 = state_vector_kronecker_product(extracted, remainder)
if not predicates.allclose_up_to_global_phase(t2, t1, atol=atol):
if not np.allclose(t2, t1, atol=atol):
if not np.isclose(np.linalg.norm(t1), 1):
raise ValueError('Input state must be normalized.')
raise EntangledStateError('The tensor cannot be factored by the requested axes')
Expand Down
19 changes: 19 additions & 0 deletions cirq-core/cirq/linalg/transformations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,3 +613,22 @@ def test_default_tolerance():
# Here, we do NOT specify the default tolerance. It is merely to check that the default value
# is reasonable.
cirq.sub_state_vector(final_state_vector, [0])


@pytest.mark.parametrize('state_1', [0, 1])
@pytest.mark.parametrize('state_2', [0, 1])
def test_factor_state_vector(state_1: int, state_2: int):
# Kron two state vectors and apply a phase. Factoring should produce the expected results.
n = 12
for i in range(n):
phase = np.exp(2 * np.pi * 1j * i / n)
a = cirq.to_valid_state_vector(state_1, 1)
b = cirq.to_valid_state_vector(state_2, 1)
c = cirq.linalg.transformations.state_vector_kronecker_product(a, b) * phase
a1, b1 = cirq.linalg.transformations.factor_state_vector(c, [0], validate=True)
c1 = cirq.linalg.transformations.state_vector_kronecker_product(a1, b1)
assert np.allclose(c, c1)

# All phase goes into a1, and b1 is just the dephased state vector
assert np.allclose(a1, a * phase)
assert np.allclose(b1, b)
2 changes: 1 addition & 1 deletion cirq-core/cirq/sim/simulation_product_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _act_on_fallback_(
gate_opt, (ops.ResetChannel, ops.MeasurementGate)
):
for q in qubits:
if op_args.allows_factoring:
if op_args.allows_factoring and len(op_args.qubits) > 1:
q_args, op_args = op_args.factor((q,), validate=False)
self._sim_states[q] = q_args

Expand Down
16 changes: 16 additions & 0 deletions cirq-core/cirq/sim/sparse_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1434,3 +1434,19 @@ def test_unseparated_states_str():
qubits: (cirq.LineQubit(0), cirq.LineQubit(1))
output vector: 0.707j|00⟩ + 0.707j|10⟩"""
)


@pytest.mark.parametrize('split', [True, False])
def test_measurement_preserves_phase(split: bool):
c1, c2, t = cirq.LineQubit.range(3)
circuit = cirq.Circuit(
cirq.H(t),
cirq.measure(t, key='t'),
cirq.CZ(c1, c2).with_classical_controls('t'),
cirq.reset(t),
)
simulator = cirq.Simulator(split_untangled_states=split)
# Run enough times that both options of |110> - |111> are likely measured.
for _ in range(20):
result = simulator.simulate(circuit, initial_state=(1, 1, 1), qubit_order=(c1, c2, t))
assert result.dirac_notation() == '|110⟩'
2 changes: 1 addition & 1 deletion docs/experiments/textbook_algorithms.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@
"print(np.round(bobs_bloch_vector, 3))\n",
"\n",
"# Verify they are the same state!\n",
"np.testing.assert_allclose(bobs_bloch_vector, message_bloch_vector, atol=1e-7)"
"np.testing.assert_allclose(bobs_bloch_vector, message_bloch_vector, atol=1e-6)"
]
},
{
Expand Down

0 comments on commit 1056cae

Please sign in to comment.