diff --git a/cirq/sim/density_matrix_simulator_test.py b/cirq/sim/density_matrix_simulator_test.py index 58f1541d549..213932eb408 100644 --- a/cirq/sim/density_matrix_simulator_test.py +++ b/cirq/sim/density_matrix_simulator_test.py @@ -527,6 +527,25 @@ def test_simulate_qudits(dtype: Type[np.number], split: bool): assert len(result.measurements) == 0 +@pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) +@pytest.mark.parametrize('split', [True, False]) +def test_reset_one_qubit_does_not_affect_partial_trace_of_other_qubits( + dtype: Type[np.number], split: bool +): + q0, q1 = cirq.LineQubit.range(2) + simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) + circuit = cirq.Circuit( + cirq.H(q0), + cirq.CX(q0, q1), + cirq.reset(q0), + ) + result = simulator.simulate(circuit) + expected = np.zeros((4, 4), dtype=dtype) + expected[0, 0] = 0.5 + expected[1, 1] = 0.5 + np.testing.assert_almost_equal(result.final_density_matrix, expected) + + @pytest.mark.parametrize( 'dtype,circuit', itertools.product(