From e33942b1e5a2c2df6dc2107578d8746304e225b5 Mon Sep 17 00:00:00 2001 From: Greg Kahanamoku-Meyer Date: Tue, 7 May 2024 16:14:46 -1000 Subject: [PATCH 1/5] enable simulation of controlled gates in classical simulator --- cirq-core/cirq/sim/classical_simulator.py | 19 ++++- .../cirq/sim/classical_simulator_test.py | 81 +++++++++++++++++++ 2 files changed, 98 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/sim/classical_simulator.py b/cirq-core/cirq/sim/classical_simulator.py index a5287637bfc..8d4b4622d48 100644 --- a/cirq-core/cirq/sim/classical_simulator.py +++ b/cirq-core/cirq/sim/classical_simulator.py @@ -117,12 +117,24 @@ def _act_on_fallback_(self, action, qubits: Sequence['cirq.Qid'], allow_decompos Raises: ValueError: If initial_state shape for type np.ndarray is not equal to 1. - If gate is not one of X, CNOT, SWAP, CCNOT, or a measurement. + If gate is not one of X, SWAP, a controlled version of X or SWAP, or a measurement. """ if isinstance(self._state.basis, np.ndarray) and len(self._state.basis.shape) != 1: raise ValueError('initial_state shape for type np.ndarray is not equal to 1') gate = action.gate if isinstance(action, ops.Operation) else action mapped_qubits = [self.qubit_map[i] for i in qubits] + + if isinstance(gate, ops.ControlledGate): + control_qubits = mapped_qubits[: gate.num_controls()] + mapped_qubits = mapped_qubits[gate.num_controls() :] + + for c, v in zip(control_qubits, gate.control_values): + if self._state.basis[c] not in v: + # gate has no effect; controls were off + return True + + gate = gate.sub_gate + if _is_identity(gate): pass elif gate == ops.X: @@ -138,7 +150,10 @@ def _act_on_fallback_(self, action, qubits: Sequence['cirq.Qid'], allow_decompos c1, c2, q = mapped_qubits self._state.basis[q] ^= self._state.basis[c1] & self._state.basis[c2] else: - raise ValueError(f'{gate} is not one of X, CNOT, SWAP, CCNOT, or a measurement') + raise ValueError( + f'{gate} is not one of X, SWAP; a controlled version ' + 'of X or SWAP; or a measurement' + ) return True diff --git a/cirq-core/cirq/sim/classical_simulator_test.py b/cirq-core/cirq/sim/classical_simulator_test.py index 3cf8c170bd8..5d90d531bd9 100644 --- a/cirq-core/cirq/sim/classical_simulator_test.py +++ b/cirq-core/cirq/sim/classical_simulator_test.py @@ -78,6 +78,87 @@ def test_CCNOT(): np.testing.assert_equal(results, expected_results) +def test_CCCX(): + CCCX = cirq.CCNOT.controlled() + qubits = cirq.LineQubit.range(4) + circuit = cirq.Circuit() + + for i in range(8): + not_idxs = [] + tmp = i + for j in range(3): + if tmp & 1: + not_idxs.append(j) + tmp >>= 1 + + circuit.append(cirq.X(qubits[j]) for j in not_idxs) + circuit.append(CCCX(*qubits)) + circuit.append(cirq.measure(qubits, key='key')) + circuit.append(cirq.X(qubits[j]) for j in not_idxs) + + expected_results = { + 'key': np.array( + [ + [ + [0, 0, 0, 0], + [1, 0, 0, 0], + [0, 1, 0, 0], + [1, 1, 0, 0], + [0, 0, 1, 0], + [1, 0, 1, 0], + [0, 1, 1, 0], + [1, 1, 1, 1], + ] + ], + dtype=np.uint8, + ) + } + sim = cirq.ClassicalStateSimulator() + results = sim.run(circuit, param_resolver=None, repetitions=1).records + np.testing.assert_equal(results, expected_results) + + +def test_CSWAP(): + CSWAP = cirq.SWAP.controlled() + qubits = cirq.LineQubit.range(3) + circuit = cirq.Circuit() + + for i in range(8): + not_idxs = [] + tmp = i + for j in range(3): + if tmp & 1: + not_idxs.append(j) + tmp >>= 1 + + circuit.append(cirq.X(qubits[j]) for j in not_idxs) + circuit.append(CSWAP(*qubits)) + circuit.append(cirq.measure(qubits, key='key')) + circuit.append(CSWAP(*qubits)) + circuit.append(cirq.X(qubits[j]) for j in not_idxs) + + expected_results = { + 'key': np.array( + [ + [ + [0, 0, 0], + [1, 0, 0], + [0, 1, 0], + [1, 0, 1], + [0, 0, 1], + [1, 1, 0], + [0, 1, 1], + [1, 1, 1], + ] + ], + dtype=np.uint8, + ) + } + sim = cirq.ClassicalStateSimulator() + results = sim.run(circuit, param_resolver=None, repetitions=1).records + np.testing.assert_equal(results, expected_results) + + def test_measurement_gate(): q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit() From 4b62458fa763f0651bd16f4f97a171e121306e44 Mon Sep 17 00:00:00 2001 From: Greg Kahanamoku-Meyer Date: Fri, 17 May 2024 12:11:02 -1000 Subject: [PATCH 2/5] small fixes following pull request --- cirq-core/cirq/sim/classical_simulator.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/sim/classical_simulator.py b/cirq-core/cirq/sim/classical_simulator.py index 8d4b4622d48..f9c8f1e3005 100644 --- a/cirq-core/cirq/sim/classical_simulator.py +++ b/cirq-core/cirq/sim/classical_simulator.py @@ -117,7 +117,8 @@ def _act_on_fallback_(self, action, qubits: Sequence['cirq.Qid'], allow_decompos Raises: ValueError: If initial_state shape for type np.ndarray is not equal to 1. - If gate is not one of X, SWAP, a controlled version of X or SWAP, or a measurement. + If gate is not one of X, SWAP, a controlled version of X or SWAP, + or a measurement. """ if isinstance(self._state.basis, np.ndarray) and len(self._state.basis.shape) != 1: raise ValueError('initial_state shape for type np.ndarray is not equal to 1') @@ -129,7 +130,8 @@ def _act_on_fallback_(self, action, qubits: Sequence['cirq.Qid'], allow_decompos mapped_qubits = mapped_qubits[gate.num_controls() :] for c, v in zip(control_qubits, gate.control_values): - if self._state.basis[c] not in v: + controls_state = tuple(self._state.basis[c] for c in control_qubits) + if controls_state not in gate.control_values.expand(): # gate has no effect; controls were off return True From 19d5afe6f2fdd3ec9c4f56d730f655109ca1f216 Mon Sep 17 00:00:00 2001 From: Greg Kahanamoku-Meyer Date: Fri, 17 May 2024 12:45:17 -1000 Subject: [PATCH 3/5] improve tests --- .../cirq/sim/classical_simulator_test.py | 85 +++++-------------- 1 file changed, 21 insertions(+), 64 deletions(-) diff --git a/cirq-core/cirq/sim/classical_simulator_test.py b/cirq-core/cirq/sim/classical_simulator_test.py index 5d90d531bd9..2594bb10c81 100644 --- a/cirq-core/cirq/sim/classical_simulator_test.py +++ b/cirq-core/cirq/sim/classical_simulator_test.py @@ -16,6 +16,7 @@ import pytest import cirq import sympy +from itertools import product def test_x_gate(): @@ -78,85 +79,41 @@ def test_CCNOT(): np.testing.assert_equal(results, expected_results) -def test_CCCX(): +@pytest.mark.parametrize(['initial_state'], [(list(x),) for x in product([0, 1], repeat=4)]) +def test_CCCX(initial_state): CCCX = cirq.CCNOT.controlled() qubits = cirq.LineQubit.range(4) - circuit = cirq.Circuit() - for i in range(8): - not_idxs = [] - tmp = i - for j in range(3): - if tmp & 1: - not_idxs.append(j) - tmp >>= 1 + circuit = cirq.Circuit() + circuit.append(CCCX(*qubits)) + circuit.append(cirq.measure(qubits, key='key')) - circuit.append(cirq.X(qubits[j]) for j in not_idxs) - circuit.append(CCCX(*qubits)) - circuit.append(cirq.measure(qubits, key='key')) - circuit.append(cirq.X(qubits[j]) for j in not_idxs) + final_state = initial_state.copy() + final_state[-1] ^= all(final_state[:-1]) - expected_results = { - 'key': np.array( - [ - [ - [0, 0, 0, 0], - [1, 0, 0, 0], - [0, 1, 0, 0], - [1, 1, 0, 0], - [0, 0, 1, 0], - [1, 0, 1, 0], - [0, 1, 1, 0], - [1, 1, 1, 1], - ] - ], - dtype=np.uint8, - ) - } sim = cirq.ClassicalStateSimulator() - results = sim.run(circuit, param_resolver=None, repetitions=1).records - np.testing.assert_equal(results, expected_results) + results = sim.simulate(circuit, initial_state=initial_state).measurements['key'] + np.testing.assert_equal(results, final_state) -def test_CSWAP(): +@pytest.mark.parametrize(['initial_state'], [(list(x),) for x in product([0, 1], repeat=3)]) +def test_CSWAP(initial_state): CSWAP = cirq.SWAP.controlled() qubits = cirq.LineQubit.range(3) circuit = cirq.Circuit() - for i in range(8): - not_idxs = [] - tmp = i - for j in range(3): - if tmp & 1: - not_idxs.append(j) - tmp >>= 1 + circuit = cirq.Circuit() + circuit.append(CSWAP(*qubits)) + circuit.append(cirq.measure(qubits, key='key')) - circuit.append(cirq.X(qubits[j]) for j in not_idxs) - circuit.append(CSWAP(*qubits)) - circuit.append(cirq.measure(qubits, key='key')) - circuit.append(CSWAP(*qubits)) - circuit.append(cirq.X(qubits[j]) for j in not_idxs) + a, b, c = initial_state + if a: + b, c = c, b + final_state = [a, b, c] - expected_results = { - 'key': np.array( - [ - [ - [0, 0, 0], - [1, 0, 0], - [0, 1, 0], - [1, 0, 1], - [0, 0, 1], - [1, 1, 0], - [0, 1, 1], - [1, 1, 1], - ] - ], - dtype=np.uint8, - ) - } sim = cirq.ClassicalStateSimulator() - results = sim.run(circuit, param_resolver=None, repetitions=1).records - np.testing.assert_equal(results, expected_results) + results = sim.simulate(circuit, initial_state=initial_state).measurements['key'] + np.testing.assert_equal(results, final_state) def test_measurement_gate(): From 776fc9d23986ed4ac5932e8aa6845df9b48fc303 Mon Sep 17 00:00:00 2001 From: Greg Kahanamoku-Meyer Date: Mon, 20 May 2024 21:11:21 -1000 Subject: [PATCH 4/5] remove unnecessary for loop --- cirq-core/cirq/sim/classical_simulator.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/cirq-core/cirq/sim/classical_simulator.py b/cirq-core/cirq/sim/classical_simulator.py index f9c8f1e3005..02879e518a1 100644 --- a/cirq-core/cirq/sim/classical_simulator.py +++ b/cirq-core/cirq/sim/classical_simulator.py @@ -129,11 +129,10 @@ def _act_on_fallback_(self, action, qubits: Sequence['cirq.Qid'], allow_decompos control_qubits = mapped_qubits[: gate.num_controls()] mapped_qubits = mapped_qubits[gate.num_controls() :] - for c, v in zip(control_qubits, gate.control_values): - controls_state = tuple(self._state.basis[c] for c in control_qubits) - if controls_state not in gate.control_values.expand(): - # gate has no effect; controls were off - return True + controls_state = tuple(self._state.basis[c] for c in control_qubits) + if controls_state not in gate.control_values.expand(): + # gate has no effect; controls were off + return True gate = gate.sub_gate From b44e9b34cd7fb72dcec25a580746f01413abf401 Mon Sep 17 00:00:00 2001 From: Greg Kahanamoku-Meyer Date: Wed, 22 May 2024 15:11:08 -1000 Subject: [PATCH 5/5] fix linter error --- cirq-core/cirq/sim/classical_simulator_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-core/cirq/sim/classical_simulator_test.py b/cirq-core/cirq/sim/classical_simulator_test.py index 2594bb10c81..96c8a4afdc0 100644 --- a/cirq-core/cirq/sim/classical_simulator_test.py +++ b/cirq-core/cirq/sim/classical_simulator_test.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from itertools import product import numpy as np import pytest import cirq import sympy -from itertools import product def test_x_gate():