Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change (gate, qubits) to GateOperation before act_on_fallback #4475

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions cirq-core/cirq/contrib/quimb/mps_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,14 +445,11 @@ def apply_op(self, op: 'cirq.Operation', prng: np.random.RandomState):

def _act_on_fallback_(
self,
action: Union['cirq.Operation', 'cirq.Gate'],
qubits: Sequence['cirq.Qid'],
op: 'cirq.Operation',
allow_decompose: bool = True,
) -> bool:
"""Delegates the action to self.apply_op"""
if isinstance(action, ops.Gate):
action = ops.GateOperation(action, qubits)
return self.apply_op(action, self.prng)
return self.apply_op(op, self.prng)

def estimation_stats(self):
"""Returns some statistics about the memory usage and quality of the approximation."""
Expand Down
19 changes: 11 additions & 8 deletions cirq-core/cirq/ops/common_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,12 +596,15 @@ def test_cz_act_on_equivalent_to_h_cx_h_tableau():
)
def test_act_on_ch_form(input_gate_sequence, outcome):
original_state = cirq.StabilizerStateChForm(num_qubits=5, initial_state=31)
num_qubits = cirq.num_qubits(input_gate_sequence[0])
if num_qubits == 1:
qubits = [cirq.LineQubit(1)]
else:
assert num_qubits == 2
qubits = cirq.LineQubit.range(2)

def qubits(gate):
num_qubits = cirq.num_qubits(gate)
if num_qubits == 1:
return [cirq.LineQubit(1)]
else:
assert num_qubits == 2
return cirq.LineQubit.range(2)

args = cirq.ActOnStabilizerCHFormArgs(
state=original_state.copy(),
qubits=cirq.LineQubit.range(2),
Expand All @@ -614,11 +617,11 @@ def test_act_on_ch_form(input_gate_sequence, outcome):
if outcome == 'Error':
with pytest.raises(TypeError, match="Failed to act action on state"):
for input_gate in input_gate_sequence:
cirq.act_on(input_gate, args, qubits)
cirq.act_on(input_gate, args, qubits(input_gate))
return

for input_gate in input_gate_sequence:
cirq.act_on(input_gate, args, qubits)
cirq.act_on(input_gate, args, qubits(input_gate))

if outcome == 'Original':
np.testing.assert_allclose(args.state.state_vector(), original_state.state_vector())
Expand Down
15 changes: 13 additions & 2 deletions cirq-core/cirq/protocols/act_on_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,21 @@ def act_on(
f'{result!r} from {action!r}._act_on_'
)

if isinstance(action, ops.Gate) and qubits is not None:
try:
action = action.on(*qubits)
except ValueError:
raise TypeError(
"Failed to act action on state argument.\n"
"Tried action._act_on_ but gate can't be applied to the qubits.\n"
"\n"
f"Gate: {action}\n"
f"Qubits: {qubits}\n"
)

arg_fallback = getattr(args, '_act_on_fallback_', None)
if arg_fallback is not None:
qubits = action.qubits if isinstance(action, ops.Operation) else qubits
result = arg_fallback(action, qubits=qubits, allow_decompose=allow_decompose)
result = arg_fallback(action, allow_decompose=allow_decompose)
if result is True:
return
if result is not NotImplemented:
Expand Down
14 changes: 6 additions & 8 deletions cirq-core/cirq/protocols/act_on_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Tuple, Union, Sequence
from typing import Any, Tuple

import numpy as np
import pytest
Expand All @@ -36,8 +36,7 @@ def copy(self):

def _act_on_fallback_(
self,
action: Union['cirq.Operation', 'cirq.Gate'],
qubits: Sequence['cirq.Qid'],
op: cirq.Operation,
allow_decompose: bool = True,
):
return self.fallback_result
Expand Down Expand Up @@ -87,11 +86,10 @@ def test_act_on_args_axes_deprecation():
class Args(DummyActOnArgs):
def _act_on_fallback_(
self,
action: Union['cirq.Operation', 'cirq.Gate'],
qubits: Sequence['cirq.Qid'] = None,
op: cirq.Operation,
allow_decompose: bool = True,
) -> bool:
self.measurements.append(qubits)
self.measurements.append(op.qubits)
return True

args = Args()
Expand All @@ -103,8 +101,8 @@ def _act_on_fallback_(
with cirq.testing.assert_deprecated(
"ActOnArgs.axes", "Use `protocols.act_on` instead.", deadline="v0.13"
):
cirq.act_on(object(), args) # type: ignore
assert args.measurements == [[cirq.LineQubit(1)]]
with pytest.raises(AttributeError, match="object has no attribute 'qubits'"):
cirq.act_on(object(), args) # type: ignore


def test_qubits_not_allowed_for_operations():
Expand Down
7 changes: 3 additions & 4 deletions cirq-core/cirq/sim/act_on_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,13 +285,12 @@ def axes(self, value: Iterable[int]):


def strat_act_on_from_apply_decompose(
val: Any,
val: 'cirq.Operation',
args: ActOnArgs,
qubits: Sequence['cirq.Qid'],
) -> bool:
operations, qubits1, _ = _try_decompose_into_operations_and_qubits(val)
assert len(qubits1) == len(qubits)
qubit_map = {q: qubits[i] for i, q in enumerate(qubits1)}
assert len(qubits1) == len(val.qubits)
qubit_map = {q: val.qubits[i] for i, q in enumerate(qubits1)}
if operations is None:
return NotImplemented
for operation in operations:
Expand Down
15 changes: 6 additions & 9 deletions cirq-core/cirq/sim/act_on_args_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
Any,
Tuple,
List,
Union,
)

import numpy as np
Expand Down Expand Up @@ -82,17 +81,16 @@ def create_merged_state(self) -> TActOnArgs:

def _act_on_fallback_(
self,
action: Union['cirq.Operation', 'cirq.Gate'],
qubits: Sequence['cirq.Qid'],
op: 'cirq.Operation',
allow_decompose: bool = True,
) -> bool:
gate = action.gate if isinstance(action, ops.Operation) else action
gate = op.gate

if isinstance(gate, ops.IdentityGate):
return True

if isinstance(gate, ops.SwapPowGate) and gate.exponent % 2 == 1 and gate.global_shift == 0:
q0, q1 = qubits
q0, q1 = op.qubits
args0 = self.args[q0]
args1 = self.args[q1]
if args0 is args1:
Expand All @@ -105,7 +103,7 @@ def _act_on_fallback_(
# Go through the op's qubits and join any disparate ActOnArgs states
# into a new combined state.
op_args_opt: Optional[TActOnArgs] = None
for q in qubits:
for q in op.qubits:
if op_args_opt is None:
op_args_opt = self.args[q]
elif q not in op_args_opt.qubits:
Expand All @@ -117,14 +115,13 @@ def _act_on_fallback_(
self.args[q] = op_args

# Act on the args with the operation
act_on_qubits = qubits if isinstance(action, ops.Gate) else None
protocols.act_on(action, op_args, act_on_qubits, allow_decompose=allow_decompose)
protocols.act_on(op, op_args, allow_decompose=allow_decompose)

# Decouple any measurements or resets
if self.split_untangled_states and isinstance(
gate, (ops.MeasurementGate, ops.ResetChannel)
):
for q in qubits:
for q in op.qubits:
q_args, op_args = op_args.factor((q,), validate=False)
self.args[q] = q_args

Expand Down
7 changes: 3 additions & 4 deletions cirq-core/cirq/sim/act_on_args_container_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Dict, Any, Sequence, Tuple, Optional, Union
from typing import List, Dict, Any, Sequence, Tuple, Optional

import cirq

Expand All @@ -34,8 +34,7 @@ def copy(self) -> 'EmptyActOnArgs':

def _act_on_fallback_(
self,
action: Union['cirq.Operation', 'cirq.Gate'],
qubits: Sequence['cirq.Qid'],
op: cirq.Operation,
allow_decompose: bool = True,
) -> bool:
return True
Expand Down Expand Up @@ -118,7 +117,7 @@ def test_identity_does_not_join():
def test_identity_fallback_does_not_join():
args = create_container(qs2)
assert len(set(args.values())) == 3
args._act_on_fallback_(cirq.I, (q0, q1))
args._act_on_fallback_(cirq.IdentityGate(2)(q0, q1))
assert len(set(args.values())) == 3
assert args[q0] is not args[q1]
assert args[q0] is not args[None]
Expand Down
6 changes: 2 additions & 4 deletions cirq-core/cirq/sim/act_on_args_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Sequence, Union

import pytest

Expand All @@ -31,8 +30,7 @@ def _perform_measurement(self, qubits):

def _act_on_fallback_(
self,
action: Union['cirq.Operation', 'cirq.Gate'],
qubits: Sequence['cirq.Qid'],
op: cirq.Operation,
allow_decompose: bool = True,
) -> bool:
return True
Expand All @@ -53,7 +51,7 @@ def _decompose_(self, qubits):
yield cirq.X(*qubits)

args = DummyArgs()
assert act_on_args.strat_act_on_from_apply_decompose(Composite(), args, [cirq.LineQubit(0)])
assert act_on_args.strat_act_on_from_apply_decompose(Composite().on(cirq.LineQubit(0)), args)


def test_mapping():
Expand Down
21 changes: 8 additions & 13 deletions cirq-core/cirq/sim/act_on_density_matrix_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
# limitations under the License.
"""Objects and methods for acting efficiently on a density matrix."""

from typing import Any, Dict, List, Tuple, TYPE_CHECKING, Sequence, Iterable, Union
from typing import Any, Dict, List, Tuple, TYPE_CHECKING, Sequence, Iterable

import numpy as np

from cirq import protocols, sim
from cirq._compat import deprecated_parameter
from cirq.sim.act_on_args import ActOnArgs, strat_act_on_from_apply_decompose
from cirq.linalg import transformations
from cirq.sim.act_on_args import ActOnArgs, strat_act_on_from_apply_decompose

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -94,8 +94,7 @@ def __init__(

def _act_on_fallback_(
self,
action: Union['cirq.Operation', 'cirq.Gate'],
qubits: Sequence['cirq.Qid'],
op: 'cirq.Operation',
allow_decompose: bool = True,
) -> bool:
strats = [
Expand All @@ -106,7 +105,7 @@ def _act_on_fallback_(

# Try each strategy, stopping if one works.
for strat in strats:
result = strat(action, self, qubits)
result = strat(op, self)
if result is False:
break # coverage: ignore
if result is True:
Expand All @@ -115,9 +114,7 @@ def _act_on_fallback_(
raise TypeError(
"Can't simulate operations that don't implement "
"SupportsUnitary, SupportsConsistentApplyUnitary, "
"SupportsMixture, SupportsChannel or SupportsKraus or is a measurement: {!r}".format(
action
)
"SupportsMixture, SupportsChannel or SupportsKraus or is a measurement: {!r}".format(op)
)

def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
Expand Down Expand Up @@ -198,13 +195,11 @@ def sample(
)


def _strat_apply_channel_to_state(
action: Any, args: ActOnDensityMatrixArgs, qubits: Sequence['cirq.Qid']
) -> bool:
def _strat_apply_channel_to_state(op: 'cirq.Operation', args: ActOnDensityMatrixArgs) -> bool:
"""Apply channel to state."""
axes = args.get_axes(qubits)
axes = args.get_axes(op.qubits)
result = protocols.apply_channel(
action,
op,
args=protocols.ApplyChannelArgs(
target_tensor=args.target_tensor,
out_buffer=args.available_buffer[0],
Expand Down
4 changes: 3 additions & 1 deletion cirq-core/cirq/sim/act_on_density_matrix_args_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def _decompose_(self, qubits):

def test_cannot_act():
class NoDetails:
pass
@property
def qubits(self):
return []

qid_shape = (2,)
tensor = cirq.to_valid_density_matrix(
Expand Down
Loading