Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sanitize type annotations in cirq.sim #4773

Merged
merged 3 commits into from
Dec 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
import numpy as np

import cirq._version
from cirq import devices, ops, protocols, value, qis
from cirq import devices, ops, protocols, qis
from cirq.circuits._bucket_priority_queue import BucketPriorityQueue
from cirq.circuits.circuit_operation import CircuitOperation
from cirq.circuits.insert_strategy import InsertStrategy
Expand Down Expand Up @@ -886,10 +886,10 @@ def qid_shape(
qids = ops.QubitOrder.as_qubit_order(qubit_order).order_for(self.all_qubits())
return protocols.qid_shape(qids)

def all_measurement_key_objs(self) -> AbstractSet[value.MeasurementKey]:
def all_measurement_key_objs(self) -> AbstractSet['cirq.MeasurementKey']:
return {key for op in self.all_operations() for key in protocols.measurement_key_objs(op)}

def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]:
def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']:
return self.all_measurement_key_objs()

def all_measurement_key_names(self) -> AbstractSet[str]:
Expand Down Expand Up @@ -1537,7 +1537,7 @@ def factorize(self: CIRCUIT_TYPE) -> Iterable[CIRCUIT_TYPE]:
self._with_sliced_moments([m[qubits] for m in self.moments]) for qubits in qubit_factors
)

def _control_keys_(self) -> FrozenSet[value.MeasurementKey]:
def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']:
controls = frozenset(k for op in self.all_operations() for k in protocols.control_keys(op))
return controls - protocols.measurement_key_objs(self)

Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/circuits/circuit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class CircuitOperation(ops.Operation):
"""

_hash: Optional[int] = dataclasses.field(default=None, init=False)
_cached_measurement_key_objs: Optional[AbstractSet[value.MeasurementKey]] = dataclasses.field(
_cached_measurement_key_objs: Optional[AbstractSet['cirq.MeasurementKey']] = dataclasses.field(
default=None, init=False
)

Expand Down Expand Up @@ -184,7 +184,7 @@ def _qid_shape_(self) -> Tuple[int, ...]:
def _is_measurement_(self) -> bool:
return self.circuit._is_measurement_()

def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]:
def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']:
if self._cached_measurement_key_objs is None:
circuit_keys = protocols.measurement_key_objs(self.circuit)
if self.repetition_ids is not None:
Expand All @@ -207,7 +207,7 @@ def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]:
def _measurement_key_names_(self) -> AbstractSet[str]:
return {str(key) for key in self._measurement_key_objs_()}

def _control_keys_(self) -> AbstractSet[value.MeasurementKey]:
def _control_keys_(self) -> AbstractSet['cirq.MeasurementKey']:
if not protocols.control_keys(self.circuit):
return frozenset()
return protocols.control_keys(self.mapped_circuit())
Expand Down
12 changes: 6 additions & 6 deletions cirq-core/cirq/circuits/frozen_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

import numpy as np

from cirq import devices, ops, protocols, value
from cirq import devices, ops, protocols
from cirq.circuits import AbstractCircuit, Alignment, Circuit
from cirq.circuits.insert_strategy import InsertStrategy
from cirq.type_workarounds import NotImplementedType
Expand Down Expand Up @@ -74,9 +74,9 @@ def __init__(
self._all_qubits: Optional[FrozenSet['cirq.Qid']] = None
self._all_operations: Optional[Tuple[ops.Operation, ...]] = None
self._has_measurements: Optional[bool] = None
self._all_measurement_key_objs: Optional[AbstractSet[value.MeasurementKey]] = None
self._all_measurement_key_objs: Optional[AbstractSet['cirq.MeasurementKey']] = None
self._are_all_measurements_terminal: Optional[bool] = None
self._control_keys: Optional[FrozenSet[value.MeasurementKey]] = None
self._control_keys: Optional[FrozenSet['cirq.MeasurementKey']] = None

@property
def moments(self) -> Sequence['cirq.Moment']:
Expand Down Expand Up @@ -126,15 +126,15 @@ def has_measurements(self) -> bool:
self._has_measurements = super().has_measurements()
return self._has_measurements

def all_measurement_key_objs(self) -> AbstractSet[value.MeasurementKey]:
def all_measurement_key_objs(self) -> AbstractSet['cirq.MeasurementKey']:
if self._all_measurement_key_objs is None:
self._all_measurement_key_objs = super().all_measurement_key_objs()
return self._all_measurement_key_objs

def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]:
def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']:
return self.all_measurement_key_objs()

def _control_keys_(self) -> FrozenSet[value.MeasurementKey]:
def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']:
if self._control_keys is None:
self._control_keys = super()._control_keys_()
return self._control_keys
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/contrib/qcircuit/qcircuit_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _render(diagram: circuits.TextDiagramDrawer) -> str:


def circuit_to_latex_using_qcircuit(
circuit: circuits.Circuit, qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT
circuit: 'cirq.Circuit', qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT
) -> str:
"""Returns a QCircuit-based latex diagram of the given circuit.

Expand Down
12 changes: 6 additions & 6 deletions cirq-core/cirq/contrib/quimb/mps_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import numpy as np
import quimb.tensor as qtn

from cirq import devices, study, ops, protocols, value
from cirq import devices, ops, protocols, value
from cirq.sim import simulator_base
from cirq.sim.act_on_args import ActOnArgs

Expand Down Expand Up @@ -126,7 +126,7 @@ def _create_step_result(

def _create_simulator_trial_result(
self,
params: study.ParamResolver,
params: 'cirq.ParamResolver',
measurements: Dict[str, np.ndarray],
final_step_result: 'MPSSimulatorStepResult',
) -> 'MPSTrialResult':
Expand All @@ -151,7 +151,7 @@ class MPSTrialResult(simulator_base.SimulationTrialResultBase['MPSState', 'MPSSt

def __init__(
self,
params: study.ParamResolver,
params: 'cirq.ParamResolver',
measurements: Dict[str, np.ndarray],
final_step_result: 'MPSSimulatorStepResult',
) -> None:
Expand Down Expand Up @@ -321,7 +321,7 @@ def state_vector(self) -> np.ndarray:
sorted_ind = tuple(sorted(state_vector.inds))
return state_vector.fuse({'i': sorted_ind}).data

def partial_trace(self, keep_qubits: Set[ops.Qid]) -> np.ndarray:
def partial_trace(self, keep_qubits: Set['cirq.Qid']) -> np.ndarray:
"""Traces out all qubits except keep_qubits.

Args:
Expand Down Expand Up @@ -475,7 +475,7 @@ def estimation_stats(self):
}

def perform_measurement(
self, qubits: Sequence[ops.Qid], prng: np.random.RandomState, collapse_state_vector=True
self, qubits: Sequence['cirq.Qid'], prng: np.random.RandomState, collapse_state_vector=True
) -> List[int]:
"""Performs a measurement over one or more qubits.

Expand Down Expand Up @@ -533,7 +533,7 @@ def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:

def sample(
self,
qubits: Sequence[ops.Qid],
qubits: Sequence['cirq.Qid'],
repetitions: int = 1,
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
) -> np.ndarray:
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/ops/classically_controlled_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def _with_rescoped_keys_(
path: Tuple[str, ...],
bindable_keys: FrozenSet['cirq.MeasurementKey'],
) -> 'ClassicallyControlledOperation':
def map_key(key: value.MeasurementKey) -> value.MeasurementKey:
def map_key(key: 'cirq.MeasurementKey') -> 'cirq.MeasurementKey':
for i in range(len(path) + 1):
back_path = path[: len(path) - i]
new_key = key.with_key_path_prefix(*back_path)
Expand All @@ -195,7 +195,7 @@ def map_key(key: value.MeasurementKey) -> value.MeasurementKey:
sub_operation = protocols.with_rescoped_keys(self._sub_operation, path, bindable_keys)
return sub_operation.with_classical_controls(*[map_key(k) for k in self._control_keys])

def _control_keys_(self) -> FrozenSet[value.MeasurementKey]:
def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']:
return frozenset(self._control_keys).union(protocols.control_keys(self._sub_operation))

def _qasm_(self, args: 'cirq.QasmArgs') -> Optional[str]:
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/ops/gate_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,13 @@ def _measurement_key_names_(self) -> Optional[AbstractSet[str]]:
return getter()
return NotImplemented

def _measurement_key_obj_(self) -> Optional[value.MeasurementKey]:
def _measurement_key_obj_(self) -> Optional['cirq.MeasurementKey']:
getter = getattr(self.gate, '_measurement_key_obj_', None)
if getter is not None:
return getter()
return NotImplemented

def _measurement_key_objs_(self) -> Optional[AbstractSet[value.MeasurementKey]]:
def _measurement_key_objs_(self) -> Optional[AbstractSet['cirq.MeasurementKey']]:
getter = getattr(self.gate, '_measurement_key_objs_', None)
if getter is not None:
return getter()
Expand Down
13 changes: 8 additions & 5 deletions cirq-core/cirq/ops/kraus_channel.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# pylint: disable=wrong-or-nonexistent-copyright-notice
from typing import Any, Dict, FrozenSet, Iterable, Tuple, Union
from typing import Any, Dict, FrozenSet, Iterable, Tuple, TYPE_CHECKING, Union
import numpy as np

from cirq import linalg, protocols, value
from cirq._compat import proper_repr
from cirq.ops import raw_types

if TYPE_CHECKING:
import cirq


# TODO(#3241): support qudits and non-square operators.
class KrausChannel(raw_types.Gate):
Expand All @@ -25,7 +28,7 @@ class KrausChannel(raw_types.Gate):
def __init__(
self,
kraus_ops: Iterable[np.ndarray],
key: Union[str, value.MeasurementKey, None] = None,
key: Union[str, 'cirq.MeasurementKey', None] = None,
validate: bool = False,
):
kraus_ops = list(kraus_ops)
Expand All @@ -52,7 +55,7 @@ def __init__(
self._key = key

@staticmethod
def from_channel(channel: 'KrausChannel', key: Union[str, value.MeasurementKey, None] = None):
def from_channel(channel: 'KrausChannel', key: Union[str, 'cirq.MeasurementKey', None] = None):
"""Creates a copy of a channel with the given measurement key."""
return KrausChannel(kraus_ops=list(protocols.kraus(channel)), key=key)

Expand All @@ -76,7 +79,7 @@ def _measurement_key_name_(self) -> str:
return NotImplemented
return str(self._key)

def _measurement_key_obj_(self) -> value.MeasurementKey:
def _measurement_key_obj_(self) -> 'cirq.MeasurementKey':
if self._key is None:
return NotImplemented
return self._key
Expand All @@ -99,7 +102,7 @@ def _with_key_path_prefix_(self, prefix: Tuple[str, ...]):
def _with_rescoped_keys_(
self,
path: Tuple[str, ...],
bindable_keys: FrozenSet[value.MeasurementKey],
bindable_keys: FrozenSet['cirq.MeasurementKey'],
):
return KrausChannel(
kraus_ops=self._kraus_ops,
Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/ops/measure_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import numpy as np

from cirq import protocols, value
from cirq import protocols
from cirq.ops import raw_types, pauli_string
from cirq.ops.measurement_gate import MeasurementGate
from cirq.ops.pauli_measurement_gate import PauliMeasurementGate
Expand All @@ -31,7 +31,7 @@ def _default_measurement_key(qubits: Iterable[raw_types.Qid]) -> str:

def measure_single_paulistring(
pauli_observable: pauli_string.PauliString,
key: Optional[Union[str, value.MeasurementKey]] = None,
key: Optional[Union[str, 'cirq.MeasurementKey']] = None,
) -> raw_types.Operation:
"""Returns a single PauliMeasurementGate which measures the pauli observable

Expand Down Expand Up @@ -83,7 +83,7 @@ def measure_paulistring_terms(

def measure(
*target: 'cirq.Qid',
key: Optional[Union[str, value.MeasurementKey]] = None,
key: Optional[Union[str, 'cirq.MeasurementKey']] = None,
invert_mask: Tuple[bool, ...] = (),
) -> raw_types.Operation:
"""Returns a single MeasurementGate applied to all the given qubits.
Expand Down
8 changes: 4 additions & 4 deletions cirq-core/cirq/ops/measurement_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class MeasurementGate(raw_types.Gate):
def __init__(
self,
num_qubits: Optional[int] = None,
key: Union[str, value.MeasurementKey] = '',
key: Union[str, 'cirq.MeasurementKey'] = '',
invert_mask: Tuple[bool, ...] = (),
qid_shape: Tuple[int, ...] = None,
) -> None:
Expand Down Expand Up @@ -75,7 +75,7 @@ def key(self) -> str:
return str(self.mkey)

@key.setter
def key(self, key: Union[str, value.MeasurementKey]):
def key(self, key: Union[str, 'cirq.MeasurementKey']):
if isinstance(key, value.MeasurementKey):
self.mkey = key
else:
Expand All @@ -84,7 +84,7 @@ def key(self, key: Union[str, value.MeasurementKey]):
def _qid_shape_(self) -> Tuple[int, ...]:
return self._qid_shape

def with_key(self, key: Union[str, value.MeasurementKey]) -> 'MeasurementGate':
def with_key(self, key: Union[str, 'cirq.MeasurementKey']) -> 'MeasurementGate':
"""Creates a measurement gate with a new key but otherwise identical."""
if key == self.key:
return self
Expand Down Expand Up @@ -139,7 +139,7 @@ def _is_measurement_(self) -> bool:
def _measurement_key_name_(self) -> str:
return self.key

def _measurement_key_obj_(self) -> value.MeasurementKey:
def _measurement_key_obj_(self) -> 'cirq.MeasurementKey':
return self.mkey

def _kraus_(self):
Expand Down
13 changes: 8 additions & 5 deletions cirq-core/cirq/ops/mixed_unitary_channel.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# pylint: disable=wrong-or-nonexistent-copyright-notice
from typing import Any, Dict, FrozenSet, Iterable, Tuple, Union
from typing import Any, Dict, FrozenSet, Iterable, Tuple, TYPE_CHECKING, Union
import numpy as np

from cirq import linalg, protocols, value
from cirq._compat import proper_repr
from cirq.ops import raw_types

if TYPE_CHECKING:
import cirq


class MixedUnitaryChannel(raw_types.Gate):
"""A generic mixture that can record the index of its selected operator.
Expand All @@ -24,7 +27,7 @@ class MixedUnitaryChannel(raw_types.Gate):
def __init__(
self,
mixture: Iterable[Tuple[float, np.ndarray]],
key: Union[str, value.MeasurementKey, None] = None,
key: Union[str, 'cirq.MeasurementKey', None] = None,
validate: bool = False,
):
mixture = list(mixture)
Expand Down Expand Up @@ -54,7 +57,7 @@ def __init__(

@staticmethod
def from_mixture(
mixture: 'protocols.SupportsMixture', key: Union[str, value.MeasurementKey, None] = None
mixture: 'protocols.SupportsMixture', key: Union[str, 'cirq.MeasurementKey', None] = None
):
"""Creates a copy of a mixture with the given measurement key."""
return MixedUnitaryChannel(mixture=list(protocols.mixture(mixture)), key=key)
Expand Down Expand Up @@ -85,7 +88,7 @@ def _measurement_key_name_(self) -> str:
return NotImplemented
return str(self._key)

def _measurement_key_obj_(self) -> value.MeasurementKey:
def _measurement_key_obj_(self) -> 'cirq.MeasurementKey':
if self._key is None:
return NotImplemented
return self._key
Expand All @@ -110,7 +113,7 @@ def _with_key_path_prefix_(self, prefix: Tuple[str, ...]):
def _with_rescoped_keys_(
self,
path: Tuple[str, ...],
bindable_keys: FrozenSet[value.MeasurementKey],
bindable_keys: FrozenSet['cirq.MeasurementKey'],
):
return MixedUnitaryChannel(
mixture=self._mixture,
Expand Down
Loading