Skip to content

Commit

Permalink
Fix mypy type issue for SimulationState and allocate new space only f…
Browse files Browse the repository at this point in the history
…or previously unknown ancilla
  • Loading branch information
tanujkhattar committed Jun 6, 2023
1 parent 70162d7 commit ccde689
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 9 deletions.
2 changes: 1 addition & 1 deletion cirq-core/cirq/protocols/act_on_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def act_on(

arg_fallback = getattr(sim_state, '_act_on_fallback_', None)
if arg_fallback is not None:
qubits = action.qubits if isinstance(action, ops.Operation) else qubits
qubits = action.qubits if is_op else qubits
result = arg_fallback(action, qubits=qubits, allow_decompose=allow_decompose)
if result is True:
return
Expand Down
8 changes: 8 additions & 0 deletions cirq-core/cirq/sim/density_matrix_simulation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,14 @@ def __init__(
)
super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)

def add_qubits(self, qubits: Sequence['cirq.Qid']):
ret = super().add_qubits(qubits)
return (
self.kronecker_product(type(self)(qubits=qubits), inplace=True)
if ret is NotImplemented
else ret
)

def _act_on_fallback_(
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
) -> bool:
Expand Down
20 changes: 12 additions & 8 deletions cirq-core/cirq/sim/simulation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def create_merged_state(self) -> Self:

def add_qubits(self: Self, qubits: Sequence['cirq.Qid']) -> Self:
"""Add qubits to a new state space and take the kron product.
Note that only subclasses that support `kronecker_product`
will support this function. E.g Density Matrix and
State Vector simulators.
Expand All @@ -181,8 +182,9 @@ def add_qubits(self: Self, qubits: Sequence['cirq.Qid']) -> Self:
"""
if qubits is None or not qubits:
return self
new_space = type(self)(qubits=qubits)
return self.kronecker_product(new_space, inplace=True)
if any(q in self.qubits for q in qubits):
raise ValueError(f"Qubit to add {qubits} should not already be tracked.")
return NotImplemented

def remove_qubits(self: Self, qubits: Sequence['cirq.Qid']) -> Self:
"""Remove qubits from the state space.
Expand Down Expand Up @@ -328,15 +330,17 @@ def strat_act_on_from_apply_decompose(
if operations is None:
return NotImplemented
assert len(qubits1) == len(qubits)
qubit_map = {q: qubits[i] for i, q in enumerate(qubits1)}
ancillas = list(set(q for op in operations for q in op.qubits if q not in qubits1))
for q in ancillas:
qubit_map[q] = q
args.add_qubits(ancillas)
all_qubits = frozenset([q for op in operations for q in op.qubits])
qubit_map = dict(zip(all_qubits, all_qubits))
qubit_map |= dict(zip(qubits1, qubits))
new_ancilla = tuple(q for q in sorted(all_qubits.difference(qubits)) if q not in args.qubits)
args = args.add_qubits(new_ancilla)
if args is NotImplemented:
return NotImplemented
for operation in operations:
operation = operation.with_qubits(*[qubit_map[q] for q in operation.qubits])
protocols.act_on(operation, args)
args.remove_qubits(ancillas)
args.remove_qubits(new_ancilla)
return True


Expand Down
8 changes: 8 additions & 0 deletions cirq-core/cirq/sim/state_vector_simulation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,14 @@ def __init__(
)
super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)

def add_qubits(self, qubits: Sequence['cirq.Qid']):
ret = super().add_qubits(qubits)
return (
self.kronecker_product(type(self)(qubits=qubits), inplace=True)
if ret is NotImplemented
else ret
)

def _act_on_fallback_(
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
) -> bool:
Expand Down

0 comments on commit ccde689

Please sign in to comment.