From ccde6892b01929036369f4313a85e707fe0338ac Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Mon, 5 Jun 2023 17:47:50 -0700 Subject: [PATCH] Fix mypy type issue for SimulationState and allocate new space only for previously unknown ancilla --- cirq-core/cirq/protocols/act_on_protocol.py | 2 +- .../sim/density_matrix_simulation_state.py | 8 ++++++++ cirq-core/cirq/sim/simulation_state.py | 20 +++++++++++-------- .../cirq/sim/state_vector_simulation_state.py | 8 ++++++++ 4 files changed, 29 insertions(+), 9 deletions(-) diff --git a/cirq-core/cirq/protocols/act_on_protocol.py b/cirq-core/cirq/protocols/act_on_protocol.py index 07c8f95ee005..8ac4875c01c4 100644 --- a/cirq-core/cirq/protocols/act_on_protocol.py +++ b/cirq-core/cirq/protocols/act_on_protocol.py @@ -149,7 +149,7 @@ def act_on( arg_fallback = getattr(sim_state, '_act_on_fallback_', None) if arg_fallback is not None: - qubits = action.qubits if isinstance(action, ops.Operation) else qubits + qubits = action.qubits if is_op else qubits result = arg_fallback(action, qubits=qubits, allow_decompose=allow_decompose) if result is True: return diff --git a/cirq-core/cirq/sim/density_matrix_simulation_state.py b/cirq-core/cirq/sim/density_matrix_simulation_state.py index 1b7c5fa6e212..f8e0db62d4a9 100644 --- a/cirq-core/cirq/sim/density_matrix_simulation_state.py +++ b/cirq-core/cirq/sim/density_matrix_simulation_state.py @@ -285,6 +285,14 @@ def __init__( ) super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data) + def add_qubits(self, qubits: Sequence['cirq.Qid']): + ret = super().add_qubits(qubits) + return ( + self.kronecker_product(type(self)(qubits=qubits), inplace=True) + if ret is NotImplemented + else ret + ) + def _act_on_fallback_( self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True ) -> bool: diff --git a/cirq-core/cirq/sim/simulation_state.py b/cirq-core/cirq/sim/simulation_state.py index 35b324abab06..8958ea75e663 100644 --- a/cirq-core/cirq/sim/simulation_state.py +++ b/cirq-core/cirq/sim/simulation_state.py @@ -168,6 +168,7 @@ def create_merged_state(self) -> Self: def add_qubits(self: Self, qubits: Sequence['cirq.Qid']) -> Self: """Add qubits to a new state space and take the kron product. + Note that only subclasses that support `kronecker_product` will support this function. E.g Density Matrix and State Vector simulators. @@ -181,8 +182,9 @@ def add_qubits(self: Self, qubits: Sequence['cirq.Qid']) -> Self: """ if qubits is None or not qubits: return self - new_space = type(self)(qubits=qubits) - return self.kronecker_product(new_space, inplace=True) + if any(q in self.qubits for q in qubits): + raise ValueError(f"Qubit to add {qubits} should not already be tracked.") + return NotImplemented def remove_qubits(self: Self, qubits: Sequence['cirq.Qid']) -> Self: """Remove qubits from the state space. @@ -328,15 +330,17 @@ def strat_act_on_from_apply_decompose( if operations is None: return NotImplemented assert len(qubits1) == len(qubits) - qubit_map = {q: qubits[i] for i, q in enumerate(qubits1)} - ancillas = list(set(q for op in operations for q in op.qubits if q not in qubits1)) - for q in ancillas: - qubit_map[q] = q - args.add_qubits(ancillas) + all_qubits = frozenset([q for op in operations for q in op.qubits]) + qubit_map = dict(zip(all_qubits, all_qubits)) + qubit_map |= dict(zip(qubits1, qubits)) + new_ancilla = tuple(q for q in sorted(all_qubits.difference(qubits)) if q not in args.qubits) + args = args.add_qubits(new_ancilla) + if args is NotImplemented: + return NotImplemented for operation in operations: operation = operation.with_qubits(*[qubit_map[q] for q in operation.qubits]) protocols.act_on(operation, args) - args.remove_qubits(ancillas) + args.remove_qubits(new_ancilla) return True diff --git a/cirq-core/cirq/sim/state_vector_simulation_state.py b/cirq-core/cirq/sim/state_vector_simulation_state.py index 9a0f547c5a4e..5903a15db79c 100644 --- a/cirq-core/cirq/sim/state_vector_simulation_state.py +++ b/cirq-core/cirq/sim/state_vector_simulation_state.py @@ -355,6 +355,14 @@ def __init__( ) super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data) + def add_qubits(self, qubits: Sequence['cirq.Qid']): + ret = super().add_qubits(qubits) + return ( + self.kronecker_product(type(self)(qubits=qubits), inplace=True) + if ret is NotImplemented + else ret + ) + def _act_on_fallback_( self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True ) -> bool: