From 1056caebc35cf2933cabeb97a2cf331184e447a7 Mon Sep 17 00:00:00 2001 From: Dax Fohl Date: Thu, 1 Sep 2022 16:24:23 -0700 Subject: [PATCH] Fix phase in factor (#5847) Fixes #5834 Phase in input tensor was being allocated to both output tensors. This PR readjusts to remove the phase from the remainder tensor (a nice thing about this is that if the remainder _is_ nothing but a global phase, then it's `1`). The validation is updated to check `np.all_close` rather than `allclose_up_to_global_phase` (I also ran all simulator unit tests with "validate=True" locally and they all passed with these changes), and a unit test for #5834 is added. Also added check that we aren't factoring out the last qubit during simulation and losing the remaining phase. This isn't strictly necessary since the remainder is guaranteed to be `1`, but prevent any surprises if that changes (and may as well skip it anyway for perf sake). --- cirq-core/cirq/linalg/transformations.py | 6 +++--- cirq-core/cirq/linalg/transformations_test.py | 19 +++++++++++++++++++ .../cirq/sim/simulation_product_state.py | 2 +- cirq-core/cirq/sim/sparse_simulator_test.py | 16 ++++++++++++++++ docs/experiments/textbook_algorithms.ipynb | 2 +- 5 files changed, 40 insertions(+), 5 deletions(-) diff --git a/cirq-core/cirq/linalg/transformations.py b/cirq-core/cirq/linalg/transformations.py index 1a1b8c257e0..be91bf87688 100644 --- a/cirq-core/cirq/linalg/transformations.py +++ b/cirq-core/cirq/linalg/transformations.py @@ -589,12 +589,12 @@ def factor_state_vector( slices1 = (slice(None),) * n_axes + pivot[n_axes:] slices2 = pivot[:n_axes] + (slice(None),) * (t1.ndim - n_axes) extracted = t1[slices1] - extracted = extracted / np.sum(abs(extracted) ** 2) ** 0.5 + extracted = extracted / np.linalg.norm(extracted) remainder = t1[slices2] - remainder = remainder / np.sum(abs(remainder) ** 2) ** 0.5 + remainder = remainder / (np.linalg.norm(remainder) * t1[pivot] / abs(t1[pivot])) if validate: t2 = state_vector_kronecker_product(extracted, remainder) - if not predicates.allclose_up_to_global_phase(t2, t1, atol=atol): + if not np.allclose(t2, t1, atol=atol): if not np.isclose(np.linalg.norm(t1), 1): raise ValueError('Input state must be normalized.') raise EntangledStateError('The tensor cannot be factored by the requested axes') diff --git a/cirq-core/cirq/linalg/transformations_test.py b/cirq-core/cirq/linalg/transformations_test.py index 6e13fb64d62..b92ff5faa13 100644 --- a/cirq-core/cirq/linalg/transformations_test.py +++ b/cirq-core/cirq/linalg/transformations_test.py @@ -613,3 +613,22 @@ def test_default_tolerance(): # Here, we do NOT specify the default tolerance. It is merely to check that the default value # is reasonable. cirq.sub_state_vector(final_state_vector, [0]) + + +@pytest.mark.parametrize('state_1', [0, 1]) +@pytest.mark.parametrize('state_2', [0, 1]) +def test_factor_state_vector(state_1: int, state_2: int): + # Kron two state vectors and apply a phase. Factoring should produce the expected results. + n = 12 + for i in range(n): + phase = np.exp(2 * np.pi * 1j * i / n) + a = cirq.to_valid_state_vector(state_1, 1) + b = cirq.to_valid_state_vector(state_2, 1) + c = cirq.linalg.transformations.state_vector_kronecker_product(a, b) * phase + a1, b1 = cirq.linalg.transformations.factor_state_vector(c, [0], validate=True) + c1 = cirq.linalg.transformations.state_vector_kronecker_product(a1, b1) + assert np.allclose(c, c1) + + # All phase goes into a1, and b1 is just the dephased state vector + assert np.allclose(a1, a * phase) + assert np.allclose(b1, b) diff --git a/cirq-core/cirq/sim/simulation_product_state.py b/cirq-core/cirq/sim/simulation_product_state.py index 1338d7ebf7a..c4421b41ec2 100644 --- a/cirq-core/cirq/sim/simulation_product_state.py +++ b/cirq-core/cirq/sim/simulation_product_state.py @@ -122,7 +122,7 @@ def _act_on_fallback_( gate_opt, (ops.ResetChannel, ops.MeasurementGate) ): for q in qubits: - if op_args.allows_factoring: + if op_args.allows_factoring and len(op_args.qubits) > 1: q_args, op_args = op_args.factor((q,), validate=False) self._sim_states[q] = q_args diff --git a/cirq-core/cirq/sim/sparse_simulator_test.py b/cirq-core/cirq/sim/sparse_simulator_test.py index 1ebef6b4240..ee16e285313 100644 --- a/cirq-core/cirq/sim/sparse_simulator_test.py +++ b/cirq-core/cirq/sim/sparse_simulator_test.py @@ -1434,3 +1434,19 @@ def test_unseparated_states_str(): qubits: (cirq.LineQubit(0), cirq.LineQubit(1)) output vector: 0.707j|00⟩ + 0.707j|10⟩""" ) + + +@pytest.mark.parametrize('split', [True, False]) +def test_measurement_preserves_phase(split: bool): + c1, c2, t = cirq.LineQubit.range(3) + circuit = cirq.Circuit( + cirq.H(t), + cirq.measure(t, key='t'), + cirq.CZ(c1, c2).with_classical_controls('t'), + cirq.reset(t), + ) + simulator = cirq.Simulator(split_untangled_states=split) + # Run enough times that both options of |110> - |111> are likely measured. + for _ in range(20): + result = simulator.simulate(circuit, initial_state=(1, 1, 1), qubit_order=(c1, c2, t)) + assert result.dirac_notation() == '|110⟩' diff --git a/docs/experiments/textbook_algorithms.ipynb b/docs/experiments/textbook_algorithms.ipynb index 2287dcc5fad..0949421cf64 100644 --- a/docs/experiments/textbook_algorithms.ipynb +++ b/docs/experiments/textbook_algorithms.ipynb @@ -233,7 +233,7 @@ "print(np.round(bobs_bloch_vector, 3))\n", "\n", "# Verify they are the same state!\n", - "np.testing.assert_allclose(bobs_bloch_vector, message_bloch_vector, atol=1e-7)" + "np.testing.assert_allclose(bobs_bloch_vector, message_bloch_vector, atol=1e-6)" ] }, {