From 846916a7c419072ebb5580000b4cf5d0c808eed5 Mon Sep 17 00:00:00 2001 From: Dax Fohl Date: Thu, 26 Aug 2021 12:24:19 -0700 Subject: [PATCH 1/6] Change (gate, qubits) to GateOperation before act_on_fallback --- cirq-core/cirq/contrib/quimb/mps_simulator.py | 7 ++----- cirq-core/cirq/protocols/act_on_protocol.py | 6 ++++-- cirq-core/cirq/protocols/act_on_protocol_test.py | 8 +++----- cirq-core/cirq/sim/act_on_args_container.py | 14 ++++++-------- cirq-core/cirq/sim/act_on_args_container_test.py | 5 ++--- cirq-core/cirq/sim/act_on_args_test.py | 3 +-- cirq-core/cirq/sim/act_on_density_matrix_args.py | 9 +++------ cirq-core/cirq/sim/act_on_state_vector_args.py | 7 +++---- .../sim/clifford/act_on_clifford_tableau_args.py | 5 ++--- .../sim/clifford/act_on_stabilizer_ch_form_args.py | 5 ++--- cirq-core/cirq/sim/operation_target.py | 6 ++---- cirq-core/cirq/sim/simulator_base_test.py | 3 +-- 12 files changed, 31 insertions(+), 47 deletions(-) diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator.py b/cirq-core/cirq/contrib/quimb/mps_simulator.py index 6ad56717e13..8b684409ad1 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator.py @@ -448,14 +448,11 @@ def apply_op(self, op: 'cirq.Operation', prng: np.random.RandomState): def _act_on_fallback_( self, - action: Union['cirq.Operation', 'cirq.Gate'], - qubits: Sequence['cirq.Qid'], + op: 'cirq.Operation', allow_decompose: bool = True, ) -> bool: """Delegates the action to self.apply_op""" - if isinstance(action, ops.Gate): - action = ops.GateOperation(action, qubits) - return self.apply_op(action, self.prng) + return self.apply_op(op, self.prng) def estimation_stats(self): """Returns some statistics about the memory usage and quality of the approximation.""" diff --git a/cirq-core/cirq/protocols/act_on_protocol.py b/cirq-core/cirq/protocols/act_on_protocol.py index be6d8047e10..38c2b082e38 100644 --- a/cirq-core/cirq/protocols/act_on_protocol.py +++ b/cirq-core/cirq/protocols/act_on_protocol.py @@ -150,10 +150,12 @@ def act_on( f'{result!r} from {action!r}._act_on_' ) + if isinstance(action, ops.Gate) and qubits is not None: + action = action.on(*qubits) + arg_fallback = getattr(args, '_act_on_fallback_', None) if arg_fallback is not None: - qubits = action.qubits if isinstance(action, ops.Operation) else qubits - result = arg_fallback(action, qubits=qubits, allow_decompose=allow_decompose) + result = arg_fallback(action, allow_decompose=allow_decompose) if result is True: return if result is not NotImplemented: diff --git a/cirq-core/cirq/protocols/act_on_protocol_test.py b/cirq-core/cirq/protocols/act_on_protocol_test.py index 7e287293725..d0a06708ab2 100644 --- a/cirq-core/cirq/protocols/act_on_protocol_test.py +++ b/cirq-core/cirq/protocols/act_on_protocol_test.py @@ -36,8 +36,7 @@ def copy(self): def _act_on_fallback_( self, - action: Union['cirq.Operation', 'cirq.Gate'], - qubits: Sequence['cirq.Qid'], + op: cirq.Operation, allow_decompose: bool = True, ): return self.fallback_result @@ -87,11 +86,10 @@ def test_act_on_args_axes_deprecation(): class Args(DummyActOnArgs): def _act_on_fallback_( self, - action: Union['cirq.Operation', 'cirq.Gate'], - qubits: Sequence['cirq.Qid'] = None, + op: cirq.Operation, allow_decompose: bool = True, ) -> bool: - self.measurements.append(qubits) + self.measurements.append(op.qubits) return True args = Args() diff --git a/cirq-core/cirq/sim/act_on_args_container.py b/cirq-core/cirq/sim/act_on_args_container.py index cdbe3bc5871..f8573d28f8e 100644 --- a/cirq-core/cirq/sim/act_on_args_container.py +++ b/cirq-core/cirq/sim/act_on_args_container.py @@ -82,17 +82,16 @@ def create_merged_state(self) -> TActOnArgs: def _act_on_fallback_( self, - action: Union['cirq.Operation', 'cirq.Gate'], - qubits: Sequence['cirq.Qid'], + op: 'cirq.Operation', allow_decompose: bool = True, ) -> bool: - gate = action.gate if isinstance(action, ops.Operation) else action + gate = op.gate if isinstance(gate, ops.IdentityGate): return True if isinstance(gate, ops.SwapPowGate) and gate.exponent % 2 == 1 and gate.global_shift == 0: - q0, q1 = qubits + q0, q1 = op.qubits args0 = self.args[q0] args1 = self.args[q1] if args0 is args1: @@ -105,7 +104,7 @@ def _act_on_fallback_( # Go through the op's qubits and join any disparate ActOnArgs states # into a new combined state. op_args_opt: Optional[TActOnArgs] = None - for q in qubits: + for q in op.qubits: if op_args_opt is None: op_args_opt = self.args[q] elif q not in op_args_opt.qubits: @@ -117,14 +116,13 @@ def _act_on_fallback_( self.args[q] = op_args # Act on the args with the operation - act_on_qubits = qubits if isinstance(action, ops.Gate) else None - protocols.act_on(action, op_args, act_on_qubits, allow_decompose=allow_decompose) + protocols.act_on(op, op_args, allow_decompose=allow_decompose) # Decouple any measurements or resets if self.split_untangled_states and isinstance( gate, (ops.MeasurementGate, ops.ResetChannel) ): - for q in qubits: + for q in op.qubits: q_args, op_args = op_args.factor((q,), validate=False) self.args[q] = q_args diff --git a/cirq-core/cirq/sim/act_on_args_container_test.py b/cirq-core/cirq/sim/act_on_args_container_test.py index befea612531..ba5b69bf31b 100644 --- a/cirq-core/cirq/sim/act_on_args_container_test.py +++ b/cirq-core/cirq/sim/act_on_args_container_test.py @@ -34,8 +34,7 @@ def copy(self) -> 'EmptyActOnArgs': def _act_on_fallback_( self, - action: Union['cirq.Operation', 'cirq.Gate'], - qubits: Sequence['cirq.Qid'], + op: cirq.Operation, allow_decompose: bool = True, ) -> bool: return True @@ -135,7 +134,7 @@ def test_identity_does_not_join(): def test_identity_fallback_does_not_join(): args = create_container(qs2) assert len(set(args.values())) == 3 - args._act_on_fallback_(cirq.I, (q0, q1)) + args._act_on_fallback_(cirq.IdentityGate(2)(q0, q1)) assert len(set(args.values())) == 3 assert args[q0] is not args[q1] assert args[q0] is not args[None] diff --git a/cirq-core/cirq/sim/act_on_args_test.py b/cirq-core/cirq/sim/act_on_args_test.py index 7de06275d82..c92fcb96c99 100644 --- a/cirq-core/cirq/sim/act_on_args_test.py +++ b/cirq-core/cirq/sim/act_on_args_test.py @@ -34,8 +34,7 @@ def _perform_measurement(self, qubits): def _act_on_fallback_( self, - action: Union['cirq.Operation', 'cirq.Gate'], - qubits: Sequence['cirq.Qid'], + op: cirq.Operation, allow_decompose: bool = True, ) -> bool: return True diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args.py b/cirq-core/cirq/sim/act_on_density_matrix_args.py index e26e6682fe6..1b19b07d0cf 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args.py @@ -94,8 +94,7 @@ def __init__( def _act_on_fallback_( self, - action: Union['cirq.Operation', 'cirq.Gate'], - qubits: Sequence['cirq.Qid'], + op: 'cirq.Operation', allow_decompose: bool = True, ) -> bool: strats = [ @@ -106,7 +105,7 @@ def _act_on_fallback_( # Try each strategy, stopping if one works. for strat in strats: - result = strat(action, self, qubits) + result = strat(op, self, op.qubits) if result is False: break # coverage: ignore if result is True: @@ -115,9 +114,7 @@ def _act_on_fallback_( raise TypeError( "Can't simulate operations that don't implement " "SupportsUnitary, SupportsConsistentApplyUnitary, " - "SupportsMixture, SupportsChannel or SupportsKraus or is a measurement: {!r}".format( - action - ) + "SupportsMixture, SupportsChannel or SupportsKraus or is a measurement: {!r}".format(op) ) def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]: diff --git a/cirq-core/cirq/sim/act_on_state_vector_args.py b/cirq-core/cirq/sim/act_on_state_vector_args.py index 80694f326f5..ca98d0b9c88 100644 --- a/cirq-core/cirq/sim/act_on_state_vector_args.py +++ b/cirq-core/cirq/sim/act_on_state_vector_args.py @@ -161,8 +161,7 @@ def subspace_index( def _act_on_fallback_( self, - action: Union['cirq.Operation', 'cirq.Gate'], - qubits: Sequence['cirq.Qid'], + op: 'cirq.Operation', allow_decompose: bool = True, ) -> bool: strats = [ @@ -175,7 +174,7 @@ def _act_on_fallback_( # Try each strategy, stopping if one works. for strat in strats: - result = strat(action, self, qubits) + result = strat(op, self, op.qubits) if result is False: break # coverage: ignore if result is True: @@ -184,7 +183,7 @@ def _act_on_fallback_( raise TypeError( "Can't simulate operations that don't implement " "SupportsUnitary, SupportsConsistentApplyUnitary, " - "SupportsMixture or is a measurement: {!r}".format(action) + "SupportsMixture or is a measurement: {!r}".format(op) ) def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]: diff --git a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py index 4e073ae790e..5a025e26843 100644 --- a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py @@ -87,15 +87,14 @@ def __init__( def _act_on_fallback_( self, - action: Union['cirq.Operation', 'cirq.Gate'], - qubits: Sequence['cirq.Qid'], + op: 'cirq.Operation', allow_decompose: bool = True, ) -> Union[bool, NotImplementedType]: strats = [] if allow_decompose: strats.append(_strat_act_on_clifford_tableau_from_single_qubit_decompose) for strat in strats: - result = strat(action, self, qubits) + result = strat(op, self, op.qubits) if result is False: break # coverage: ignore if result is True: diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py index a89dd8f61c4..86702489e15 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py @@ -85,15 +85,14 @@ def __init__( def _act_on_fallback_( self, - action: Union['cirq.Operation', 'cirq.Gate'], - qubits: Sequence['cirq.Qid'], + op: 'cirq.Operation', allow_decompose: bool = True, ) -> Union[bool, NotImplementedType]: strats = [] if allow_decompose: strats.append(_strat_act_on_stabilizer_ch_form_from_single_qubit_decompose) for strat in strats: - result = strat(action, self, qubits) + result = strat(op, self, op.qubits) if result is True: return True assert result is NotImplemented, str(result) diff --git a/cirq-core/cirq/sim/operation_target.py b/cirq-core/cirq/sim/operation_target.py index 0b74abd75fa..08151418897 100644 --- a/cirq-core/cirq/sim/operation_target.py +++ b/cirq-core/cirq/sim/operation_target.py @@ -50,15 +50,13 @@ def create_merged_state(self) -> TActOnArgs: @abc.abstractmethod def _act_on_fallback_( self, - action: Union['cirq.Operation', 'cirq.Gate'], - qubits: Sequence['cirq.Qid'], + op: 'cirq.Operation', allow_decompose: bool = True, ) -> Union[bool, NotImplementedType]: """Handles the act_on protocol fallback implementation. Args: - action: Either a gate or an operation to act on. - qubits: The applicable qubits if a gate is passed as the action. + op: An operation to act on. allow_decompose: Flag to allow decomposition. Returns: diff --git a/cirq-core/cirq/sim/simulator_base_test.py b/cirq-core/cirq/sim/simulator_base_test.py index f3fe26a1b32..e7a96e81587 100644 --- a/cirq-core/cirq/sim/simulator_base_test.py +++ b/cirq-core/cirq/sim/simulator_base_test.py @@ -47,8 +47,7 @@ def copy(self) -> 'CountingActOnArgs': def _act_on_fallback_( self, - action: Union['cirq.Operation', 'cirq.Gate'], - qubits: Sequence['cirq.Qid'], + op: Union['cirq.Operation', 'cirq.Gate'], allow_decompose: bool = True, ) -> bool: self.gate_count += 1 From d93fc5c05ae2d0d0ff508dcf51c8c086d7f59812 Mon Sep 17 00:00:00 2001 From: Dax Fohl Date: Thu, 26 Aug 2021 14:10:46 -0700 Subject: [PATCH 2/6] Unit tests --- cirq-core/cirq/ops/common_gates_test.py | 19 +++++++++++-------- cirq-core/cirq/protocols/act_on_protocol.py | 11 ++++++++++- .../cirq/protocols/act_on_protocol_test.py | 4 ++-- .../sim/act_on_density_matrix_args_test.py | 4 +++- .../cirq/sim/act_on_state_vector_args_test.py | 4 +++- .../act_on_clifford_tableau_args_test.py | 4 +++- 6 files changed, 32 insertions(+), 14 deletions(-) diff --git a/cirq-core/cirq/ops/common_gates_test.py b/cirq-core/cirq/ops/common_gates_test.py index a3475474028..a524addd14b 100644 --- a/cirq-core/cirq/ops/common_gates_test.py +++ b/cirq-core/cirq/ops/common_gates_test.py @@ -596,12 +596,15 @@ def test_cz_act_on_equivalent_to_h_cx_h_tableau(): ) def test_act_on_ch_form(input_gate_sequence, outcome): original_state = cirq.StabilizerStateChForm(num_qubits=5, initial_state=31) - num_qubits = cirq.num_qubits(input_gate_sequence[0]) - if num_qubits == 1: - qubits = [cirq.LineQubit(1)] - else: - assert num_qubits == 2 - qubits = cirq.LineQubit.range(2) + + def qubits(gate): + num_qubits = cirq.num_qubits(gate) + if num_qubits == 1: + return [cirq.LineQubit(1)] + else: + assert num_qubits == 2 + return cirq.LineQubit.range(2) + args = cirq.ActOnStabilizerCHFormArgs( state=original_state.copy(), qubits=cirq.LineQubit.range(2), @@ -614,11 +617,11 @@ def test_act_on_ch_form(input_gate_sequence, outcome): if outcome == 'Error': with pytest.raises(TypeError, match="Failed to act action on state"): for input_gate in input_gate_sequence: - cirq.act_on(input_gate, args, qubits) + cirq.act_on(input_gate, args, qubits(input_gate)) return for input_gate in input_gate_sequence: - cirq.act_on(input_gate, args, qubits) + cirq.act_on(input_gate, args, qubits(input_gate)) if outcome == 'Original': np.testing.assert_allclose(args.state.state_vector(), original_state.state_vector()) diff --git a/cirq-core/cirq/protocols/act_on_protocol.py b/cirq-core/cirq/protocols/act_on_protocol.py index 38c2b082e38..1e5d31c8b02 100644 --- a/cirq-core/cirq/protocols/act_on_protocol.py +++ b/cirq-core/cirq/protocols/act_on_protocol.py @@ -151,7 +151,16 @@ def act_on( ) if isinstance(action, ops.Gate) and qubits is not None: - action = action.on(*qubits) + try: + action = action.on(*qubits) + except ValueError: + raise TypeError( + "Failed to act action on state argument.\n" + "Tried action._act_on_ but gate can't be applied to the qubits.\n" + "\n" + f"Gate: {action}\n" + f"Qubits: {qubits}\n" + ) arg_fallback = getattr(args, '_act_on_fallback_', None) if arg_fallback is not None: diff --git a/cirq-core/cirq/protocols/act_on_protocol_test.py b/cirq-core/cirq/protocols/act_on_protocol_test.py index d0a06708ab2..240c93f9c19 100644 --- a/cirq-core/cirq/protocols/act_on_protocol_test.py +++ b/cirq-core/cirq/protocols/act_on_protocol_test.py @@ -101,8 +101,8 @@ def _act_on_fallback_( with cirq.testing.assert_deprecated( "ActOnArgs.axes", "Use `protocols.act_on` instead.", deadline="v0.13" ): - cirq.act_on(object(), args) # type: ignore - assert args.measurements == [[cirq.LineQubit(1)]] + with pytest.raises(AttributeError, match="object has no attribute 'qubits'"): + cirq.act_on(object(), args) # type: ignore def test_qubits_not_allowed_for_operations(): diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args_test.py b/cirq-core/cirq/sim/act_on_density_matrix_args_test.py index 38b68aca350..073429f9a5e 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args_test.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args_test.py @@ -47,7 +47,9 @@ def _decompose_(self, qubits): def test_cannot_act(): class NoDetails: - pass + @property + def qubits(self): + return [] qid_shape = (2,) tensor = cirq.to_valid_density_matrix( diff --git a/cirq-core/cirq/sim/act_on_state_vector_args_test.py b/cirq-core/cirq/sim/act_on_state_vector_args_test.py index b8bfb30dfc8..a74b5c84a68 100644 --- a/cirq-core/cirq/sim/act_on_state_vector_args_test.py +++ b/cirq-core/cirq/sim/act_on_state_vector_args_test.py @@ -44,7 +44,9 @@ def _decompose_(self, qubits): def test_cannot_act(): class NoDetails: - pass + @property + def qubits(self): + return [] args = cirq.ActOnStateVectorArgs( target_tensor=cirq.one_hot(shape=(2, 2, 2), dtype=np.complex64), diff --git a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args_test.py b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args_test.py index ce53f6a9d56..5752491c76c 100644 --- a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args_test.py +++ b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args_test.py @@ -65,7 +65,9 @@ def _unitary_(self): def test_cannot_act(): class NoDetails: - pass + @property + def qubits(self): + return [] class NoDetailsSingleQubitGate(cirq.SingleQubitGate): pass From 6953e5d33b628f49677748d7733738c56a741433 Mon Sep 17 00:00:00 2001 From: Dax Fohl Date: Thu, 26 Aug 2021 14:22:53 -0700 Subject: [PATCH 3/6] nit --- cirq-core/cirq/sim/simulator_base_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-core/cirq/sim/simulator_base_test.py b/cirq-core/cirq/sim/simulator_base_test.py index e7a96e81587..eeb3397916e 100644 --- a/cirq-core/cirq/sim/simulator_base_test.py +++ b/cirq-core/cirq/sim/simulator_base_test.py @@ -47,7 +47,7 @@ def copy(self) -> 'CountingActOnArgs': def _act_on_fallback_( self, - op: Union['cirq.Operation', 'cirq.Gate'], + op: cirq.Operation, allow_decompose: bool = True, ) -> bool: self.gate_count += 1 From c33fdbcfbe7d36e371a488c8253bba0b456ab9d2 Mon Sep 17 00:00:00 2001 From: Dax Fohl Date: Thu, 26 Aug 2021 14:42:49 -0700 Subject: [PATCH 4/6] lint --- cirq-core/cirq/protocols/act_on_protocol_test.py | 2 +- cirq-core/cirq/sim/act_on_args_container.py | 1 - cirq-core/cirq/sim/act_on_args_container_test.py | 2 +- cirq-core/cirq/sim/act_on_args_test.py | 1 - cirq-core/cirq/sim/act_on_density_matrix_args.py | 4 ++-- cirq-core/cirq/sim/simulator_base_test.py | 2 +- 6 files changed, 5 insertions(+), 7 deletions(-) diff --git a/cirq-core/cirq/protocols/act_on_protocol_test.py b/cirq-core/cirq/protocols/act_on_protocol_test.py index 240c93f9c19..13b83b01246 100644 --- a/cirq-core/cirq/protocols/act_on_protocol_test.py +++ b/cirq-core/cirq/protocols/act_on_protocol_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Tuple, Union, Sequence +from typing import Any, Tuple import numpy as np import pytest diff --git a/cirq-core/cirq/sim/act_on_args_container.py b/cirq-core/cirq/sim/act_on_args_container.py index f8573d28f8e..fea5fd2406d 100644 --- a/cirq-core/cirq/sim/act_on_args_container.py +++ b/cirq-core/cirq/sim/act_on_args_container.py @@ -23,7 +23,6 @@ Any, Tuple, List, - Union, ) import numpy as np diff --git a/cirq-core/cirq/sim/act_on_args_container_test.py b/cirq-core/cirq/sim/act_on_args_container_test.py index ba5b69bf31b..8a45138952e 100644 --- a/cirq-core/cirq/sim/act_on_args_container_test.py +++ b/cirq-core/cirq/sim/act_on_args_container_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Dict, Any, Sequence, Tuple, Optional, Union +from typing import List, Dict, Any, Sequence, Tuple, Optional import cirq diff --git a/cirq-core/cirq/sim/act_on_args_test.py b/cirq-core/cirq/sim/act_on_args_test.py index c92fcb96c99..e6e233a002a 100644 --- a/cirq-core/cirq/sim/act_on_args_test.py +++ b/cirq-core/cirq/sim/act_on_args_test.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence, Union import pytest diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args.py b/cirq-core/cirq/sim/act_on_density_matrix_args.py index 1b19b07d0cf..a5947786e73 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args.py @@ -13,14 +13,14 @@ # limitations under the License. """Objects and methods for acting efficiently on a density matrix.""" -from typing import Any, Dict, List, Tuple, TYPE_CHECKING, Sequence, Iterable, Union +from typing import Any, Dict, List, Tuple, TYPE_CHECKING, Sequence, Iterable import numpy as np from cirq import protocols, sim from cirq._compat import deprecated_parameter -from cirq.sim.act_on_args import ActOnArgs, strat_act_on_from_apply_decompose from cirq.linalg import transformations +from cirq.sim.act_on_args import ActOnArgs, strat_act_on_from_apply_decompose if TYPE_CHECKING: import cirq diff --git a/cirq-core/cirq/sim/simulator_base_test.py b/cirq-core/cirq/sim/simulator_base_test.py index eeb3397916e..6763918c780 100644 --- a/cirq-core/cirq/sim/simulator_base_test.py +++ b/cirq-core/cirq/sim/simulator_base_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import List, Dict, Any, Sequence, Tuple, Union +from typing import List, Dict, Any, Sequence, Tuple import numpy as np import pytest From 981f5fc6bf4f536f024353afaa6e2ff9bd03f4fa Mon Sep 17 00:00:00 2001 From: Dax Fohl Date: Wed, 1 Sep 2021 07:16:20 -0700 Subject: [PATCH 5/6] strats --- cirq-core/cirq/sim/act_on_args.py | 7 ++-- cirq-core/cirq/sim/act_on_args_test.py | 2 +- .../cirq/sim/act_on_density_matrix_args.py | 10 +++--- .../cirq/sim/act_on_state_vector_args.py | 33 +++++++++---------- .../clifford/act_on_clifford_tableau_args.py | 11 ++++--- .../act_on_stabilizer_ch_form_args.py | 11 ++++--- 6 files changed, 36 insertions(+), 38 deletions(-) diff --git a/cirq-core/cirq/sim/act_on_args.py b/cirq-core/cirq/sim/act_on_args.py index 87cbe783413..30071ddda78 100644 --- a/cirq-core/cirq/sim/act_on_args.py +++ b/cirq-core/cirq/sim/act_on_args.py @@ -228,13 +228,12 @@ def axes(self, value: Iterable[int]): def strat_act_on_from_apply_decompose( - val: Any, + val: 'cirq.Operation', args: ActOnArgs, - qubits: Sequence['cirq.Qid'], ) -> bool: operations, qubits1, _ = _try_decompose_into_operations_and_qubits(val) - assert len(qubits1) == len(qubits) - qubit_map = {q: qubits[i] for i, q in enumerate(qubits1)} + assert len(qubits1) == len(val.qubits) + qubit_map = {q: val.qubits[i] for i, q in enumerate(qubits1)} if operations is None: return NotImplemented for operation in operations: diff --git a/cirq-core/cirq/sim/act_on_args_test.py b/cirq-core/cirq/sim/act_on_args_test.py index e6e233a002a..a1667961f8b 100644 --- a/cirq-core/cirq/sim/act_on_args_test.py +++ b/cirq-core/cirq/sim/act_on_args_test.py @@ -54,7 +54,7 @@ def _decompose_(self, qubits): yield cirq.X(*qubits) args = DummyArgs() - assert act_on_args.strat_act_on_from_apply_decompose(Composite(), args, [cirq.LineQubit(0)]) + assert act_on_args.strat_act_on_from_apply_decompose(Composite().on(cirq.LineQubit(0)), args) def test_mapping(): diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args.py b/cirq-core/cirq/sim/act_on_density_matrix_args.py index a5947786e73..2441613f891 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args.py @@ -105,7 +105,7 @@ def _act_on_fallback_( # Try each strategy, stopping if one works. for strat in strats: - result = strat(op, self, op.qubits) + result = strat(op, self) if result is False: break # coverage: ignore if result is True: @@ -217,13 +217,11 @@ def sample( ) -def _strat_apply_channel_to_state( - action: Any, args: ActOnDensityMatrixArgs, qubits: Sequence['cirq.Qid'] -) -> bool: +def _strat_apply_channel_to_state(op: 'cirq.Operation', args: ActOnDensityMatrixArgs) -> bool: """Apply channel to state.""" - axes = args.get_axes(qubits) + axes = args.get_axes(op.qubits) result = protocols.apply_channel( - action, + op, args=protocols.ApplyChannelArgs( target_tensor=args.target_tensor, out_buffer=args.available_buffer[0], diff --git a/cirq-core/cirq/sim/act_on_state_vector_args.py b/cirq-core/cirq/sim/act_on_state_vector_args.py index b1d4118231d..ea1221cf79c 100644 --- a/cirq-core/cirq/sim/act_on_state_vector_args.py +++ b/cirq-core/cirq/sim/act_on_state_vector_args.py @@ -177,7 +177,7 @@ def _act_on_fallback_( # Try each strategy, stopping if one works. for strat in strats: - result = strat(op, self, op.qubits) + result = strat(op, self) if result is False: break # coverage: ignore if result is True: @@ -278,16 +278,15 @@ def sample( def _strat_act_on_state_vector_from_apply_unitary( - unitary_value: Any, + op: 'cirq.Operation', args: 'cirq.ActOnStateVectorArgs', - qubits: Sequence['cirq.Qid'], ) -> bool: new_target_tensor = protocols.apply_unitary( - unitary_value, + op, protocols.ApplyUnitaryArgs( target_tensor=args.target_tensor, available_buffer=args.available_buffer, - axes=args.get_axes(qubits), + axes=args.get_axes(op.qubits), ), allow_decompose=False, default=NotImplemented, @@ -299,30 +298,30 @@ def _strat_act_on_state_vector_from_apply_unitary( def _strat_act_on_state_vector_from_mixture( - action: Any, args: 'cirq.ActOnStateVectorArgs', qubits: Sequence['cirq.Qid'] + op: 'cirq.Operation', args: 'cirq.ActOnStateVectorArgs' ) -> bool: - mixture = protocols.mixture(action, default=None) + mixture = protocols.mixture(op, default=None) if mixture is None: return NotImplemented probabilities, unitaries = zip(*mixture) index = args.prng.choice(range(len(unitaries)), p=probabilities) - shape = protocols.qid_shape(action) * 2 + shape = protocols.qid_shape(op) * 2 unitary = unitaries[index].astype(args.target_tensor.dtype).reshape(shape) linalg.targeted_left_multiply( - unitary, args.target_tensor, args.get_axes(qubits), out=args.available_buffer + unitary, args.target_tensor, args.get_axes(op.qubits), out=args.available_buffer ) args.swap_target_tensor_for(args.available_buffer) - if protocols.is_measurement(action): - key = protocols.measurement_key_name(action) + if protocols.is_measurement(op): + key = protocols.measurement_key_name(op) args.log_of_measurement_results[key] = [index] return True def _strat_act_on_state_vector_from_channel( - action: Any, args: 'cirq.ActOnStateVectorArgs', qubits: Sequence['cirq.Qid'] + op: 'cirq.Operation', args: 'cirq.ActOnStateVectorArgs' ) -> bool: - kraus_operators = protocols.kraus(action, default=None) + kraus_operators = protocols.kraus(op, default=None) if kraus_operators is None: return NotImplemented @@ -330,11 +329,11 @@ def prepare_into_buffer(k: int): linalg.targeted_left_multiply( left_matrix=kraus_tensors[k], right_target=args.target_tensor, - target_axes=args.get_axes(qubits), + target_axes=args.get_axes(op.qubits), out=args.available_buffer, ) - shape = protocols.qid_shape(action) + shape = protocols.qid_shape(op) kraus_tensors = [e.reshape(shape * 2).astype(args.target_tensor.dtype) for e in kraus_operators] p = args.prng.random() weight = None @@ -362,7 +361,7 @@ def prepare_into_buffer(k: int): args.available_buffer /= np.sqrt(weight) args.swap_target_tensor_for(args.available_buffer) - if protocols.is_measurement(action): - key = protocols.measurement_key_name(action) + if protocols.is_measurement(op): + key = protocols.measurement_key_name(op) args.log_of_measurement_results[key] = [index] return True diff --git a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py index 5a025e26843..5ad1df43acb 100644 --- a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py @@ -94,7 +94,7 @@ def _act_on_fallback_( if allow_decompose: strats.append(_strat_act_on_clifford_tableau_from_single_qubit_decompose) for strat in strats: - result = strat(op, self, op.qubits) + result = strat(op, self) if result is False: break # coverage: ignore if result is True: @@ -126,14 +126,15 @@ def sample( def _strat_act_on_clifford_tableau_from_single_qubit_decompose( - val: Any, args: 'cirq.ActOnCliffordTableauArgs', qubits: Sequence['cirq.Qid'] + op: 'cirq.Operation', args: 'cirq.ActOnCliffordTableauArgs' ) -> bool: - if num_qubits(val) == 1: - if not has_unitary(val): + if num_qubits(op) == 1: + if not has_unitary(op): return NotImplemented - u = unitary(val) + u = unitary(op) clifford_gate = SingleQubitCliffordGate.from_unitary(u) if clifford_gate is not None: + qubits = op.qubits for axis, quarter_turns in clifford_gate.decompose_rotation(): if axis == pauli_gates.X: common_gates.XPowGate(exponent=quarter_turns / 2)._act_on_(args, qubits) diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py index 86702489e15..26d3de2f055 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py @@ -92,7 +92,7 @@ def _act_on_fallback_( if allow_decompose: strats.append(_strat_act_on_stabilizer_ch_form_from_single_qubit_decompose) for strat in strats: - result = strat(op, self, op.qubits) + result = strat(op, self) if result is True: return True assert result is NotImplemented, str(result) @@ -128,17 +128,18 @@ def sample( def _strat_act_on_stabilizer_ch_form_from_single_qubit_decompose( - val: Any, args: 'cirq.ActOnStabilizerCHFormArgs', qubits: Sequence['cirq.Qid'] + op: 'cirq.Operation', args: 'cirq.ActOnStabilizerCHFormArgs' ) -> bool: - if num_qubits(val) == 1: - if not has_unitary(val): + if num_qubits(op) == 1: + if not has_unitary(op): return NotImplemented - u = unitary(val) + u = unitary(op) clifford_gate = SingleQubitCliffordGate.from_unitary(u) if clifford_gate is not None: # Gather the effective unitary applied so as to correct for the # global phase later. final_unitary = np.eye(2) + qubits = op.qubits for axis, quarter_turns in clifford_gate.decompose_rotation(): gate = None # type: Optional[cirq.Gate] if axis == pauli_gates.X: From 1952a56de2c471c46250008098c19b36b0e44d28 Mon Sep 17 00:00:00 2001 From: Dax Fohl Date: Thu, 23 Sep 2021 13:40:17 -0700 Subject: [PATCH 6/6] cleanup --- cirq-core/cirq/sim/act_on_state_vector_args.py | 2 +- .../cirq/sim/clifford/act_on_clifford_tableau_args_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/sim/act_on_state_vector_args.py b/cirq-core/cirq/sim/act_on_state_vector_args.py index ea1221cf79c..913ed2c3a1c 100644 --- a/cirq-core/cirq/sim/act_on_state_vector_args.py +++ b/cirq-core/cirq/sim/act_on_state_vector_args.py @@ -173,7 +173,7 @@ def _act_on_fallback_( _strat_act_on_state_vector_from_channel, ] if allow_decompose: - strats.append(strat_act_on_from_apply_decompose) + strats.append(strat_act_on_from_apply_decompose) # type: ignore # Try each strategy, stopping if one works. for strat in strats: diff --git a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args_test.py b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args_test.py index 5752491c76c..f2ecd96bdc0 100644 --- a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args_test.py +++ b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args_test.py @@ -67,7 +67,7 @@ def test_cannot_act(): class NoDetails: @property def qubits(self): - return [] + pass class NoDetailsSingleQubitGate(cirq.SingleQubitGate): pass