diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index bd81e60d26d..a51b8802dd8 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -47,6 +47,7 @@ TypeVar, Union, ) +from typing_extensions import Self import networkx import numpy as np @@ -236,15 +237,15 @@ def __getitem__(self, key: Tuple[int, Iterable['cirq.Qid']]) -> 'cirq.Moment': pass @overload - def __getitem__(self: CIRCUIT_TYPE, key: slice) -> CIRCUIT_TYPE: + def __getitem__(self, key: slice) -> Self: pass @overload - def __getitem__(self: CIRCUIT_TYPE, key: Tuple[slice, 'cirq.Qid']) -> CIRCUIT_TYPE: + def __getitem__(self, key: Tuple[slice, 'cirq.Qid']) -> Self: pass @overload - def __getitem__(self: CIRCUIT_TYPE, key: Tuple[slice, Iterable['cirq.Qid']]) -> CIRCUIT_TYPE: + def __getitem__(self, key: Tuple[slice, Iterable['cirq.Qid']]) -> Self: pass def __getitem__(self, key): @@ -913,9 +914,7 @@ def all_operations(self) -> Iterator['cirq.Operation']: """ return (op for moment in self for op in moment.operations) - def map_operations( - self: CIRCUIT_TYPE, func: Callable[['cirq.Operation'], 'cirq.OP_TREE'] - ) -> CIRCUIT_TYPE: + def map_operations(self, func: Callable[['cirq.Operation'], 'cirq.OP_TREE']) -> Self: """Applies the given function to all operations in this circuit. Args: @@ -1287,9 +1286,7 @@ def _is_parameterized_(self) -> bool: def _parameter_names_(self) -> AbstractSet[str]: return {name for op in self.all_operations() for name in protocols.parameter_names(op)} - def _resolve_parameters_( - self: CIRCUIT_TYPE, resolver: 'cirq.ParamResolver', recursive: bool - ) -> CIRCUIT_TYPE: + def _resolve_parameters_(self, resolver: 'cirq.ParamResolver', recursive: bool) -> Self: changed = False resolved_moments: List['cirq.Moment'] = [] for moment in self: @@ -1540,7 +1537,7 @@ def get_independent_qubit_sets(self) -> List[Set['cirq.Qid']]: uf.union(*op.qubits) return sorted([qs for qs in uf.to_sets()], key=min) - def factorize(self: CIRCUIT_TYPE) -> Iterable[CIRCUIT_TYPE]: + def factorize(self) -> Iterable[Self]: """Factorize circuit into a sequence of independent circuits (factors). Factorization is possible when the circuit's qubits can be divided diff --git a/cirq-core/cirq/circuits/moment.py b/cirq-core/cirq/circuits/moment.py index a46cca70c05..e79d8e6d586 100644 --- a/cirq-core/cirq/circuits/moment.py +++ b/cirq-core/cirq/circuits/moment.py @@ -30,9 +30,9 @@ Sequence, Tuple, TYPE_CHECKING, - TypeVar, Union, ) +from typing_extensions import Self import numpy as np @@ -52,8 +52,6 @@ "text_diagram_drawer", globals(), "cirq.circuits.text_diagram_drawer" ) -TSelf_Moment = TypeVar('TSelf_Moment', bound='Moment') - def _default_breakdown(qid: 'cirq.Qid') -> Tuple[Any, Any]: # Attempt to convert into a position on the complex plane. @@ -373,9 +371,8 @@ def _decompose_(self) -> 'cirq.OP_TREE': return self._operations def transform_qubits( - self: TSelf_Moment, - qubit_map: Union[Dict['cirq.Qid', 'cirq.Qid'], Callable[['cirq.Qid'], 'cirq.Qid']], - ) -> TSelf_Moment: + self, qubit_map: Union[Dict['cirq.Qid', 'cirq.Qid'], Callable[['cirq.Qid'], 'cirq.Qid']] + ) -> Self: """Returns the same moment, but with different qubits. Args: diff --git a/cirq-core/cirq/devices/grid_qubit.py b/cirq-core/cirq/devices/grid_qubit.py index 87ceda7a956..b41fdcacedc 100644 --- a/cirq-core/cirq/devices/grid_qubit.py +++ b/cirq-core/cirq/devices/grid_qubit.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools -from typing import Any, Dict, Iterable, List, Optional, Tuple, Set, TypeVar, TYPE_CHECKING, Union - import abc +import functools +from typing import Any, Dict, Iterable, List, Optional, Tuple, Set, TYPE_CHECKING, Union +from typing_extensions import Self import numpy as np @@ -24,8 +24,6 @@ if TYPE_CHECKING: import cirq -TSelf = TypeVar('TSelf', bound='_BaseGridQid') - @functools.total_ordering class _BaseGridQid(ops.Qid): @@ -69,13 +67,13 @@ def neighbors(self, qids: Optional[Iterable[ops.Qid]] = None) -> Set['_BaseGridQ return neighbors @abc.abstractmethod - def _with_row_col(self: TSelf, row: int, col: int) -> TSelf: + def _with_row_col(self, row: int, col: int) -> Self: """Returns a qid with the same type but a different coordinate.""" def __complex__(self) -> complex: return self.col + 1j * self.row - def __add__(self: TSelf, other: Union[Tuple[int, int], TSelf]) -> 'TSelf': + def __add__(self, other: Union[Tuple[int, int], Self]) -> Self: if isinstance(other, _BaseGridQid): if self.dimension != other.dimension: raise TypeError( @@ -94,7 +92,7 @@ def __add__(self: TSelf, other: Union[Tuple[int, int], TSelf]) -> 'TSelf': ) return self._with_row_col(row=self.row + other[0], col=self.col + other[1]) - def __sub__(self: TSelf, other: Union[Tuple[int, int], TSelf]) -> 'TSelf': + def __sub__(self, other: Union[Tuple[int, int], Self]) -> Self: if isinstance(other, _BaseGridQid): if self.dimension != other.dimension: raise TypeError( @@ -113,13 +111,13 @@ def __sub__(self: TSelf, other: Union[Tuple[int, int], TSelf]) -> 'TSelf': ) return self._with_row_col(row=self.row - other[0], col=self.col - other[1]) - def __radd__(self: TSelf, other: Tuple[int, int]) -> 'TSelf': + def __radd__(self, other: Tuple[int, int]) -> Self: return self + other - def __rsub__(self: TSelf, other: Tuple[int, int]) -> 'TSelf': + def __rsub__(self, other: Tuple[int, int]) -> Self: return -self + other - def __neg__(self: TSelf) -> 'TSelf': + def __neg__(self) -> Self: return self._with_row_col(row=-self.row, col=-self.col) diff --git a/cirq-core/cirq/devices/line_qubit.py b/cirq-core/cirq/devices/line_qubit.py index f53ec69e595..2937558a9ef 100644 --- a/cirq-core/cirq/devices/line_qubit.py +++ b/cirq-core/cirq/devices/line_qubit.py @@ -12,18 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools -from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, TypeVar, TYPE_CHECKING, Union - import abc +import functools +from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, TYPE_CHECKING, Union +from typing_extensions import Self from cirq import ops, protocols if TYPE_CHECKING: import cirq -TSelf = TypeVar('TSelf', bound='_BaseLineQid') - @functools.total_ordering class _BaseLineQid(ops.Qid): @@ -66,10 +64,10 @@ def neighbors(self, qids: Optional[Iterable[ops.Qid]] = None) -> Set['_BaseLineQ return neighbors @abc.abstractmethod - def _with_x(self: TSelf, x: int) -> TSelf: + def _with_x(self, x: int) -> Self: """Returns a qubit with the same type but a different value of `x`.""" - def __add__(self: TSelf, other: Union[int, TSelf]) -> TSelf: + def __add__(self, other: Union[int, Self]) -> Self: if isinstance(other, _BaseLineQid): if self.dimension != other.dimension: raise TypeError( @@ -81,7 +79,7 @@ def __add__(self: TSelf, other: Union[int, TSelf]) -> TSelf: raise TypeError(f"Can only add ints and {type(self).__name__}. Instead was {other}") return self._with_x(self.x + other) - def __sub__(self: TSelf, other: Union[int, TSelf]) -> TSelf: + def __sub__(self, other: Union[int, Self]) -> Self: if isinstance(other, _BaseLineQid): if self.dimension != other.dimension: raise TypeError( @@ -95,13 +93,13 @@ def __sub__(self: TSelf, other: Union[int, TSelf]) -> TSelf: ) return self._with_x(self.x - other) - def __radd__(self: TSelf, other: int) -> TSelf: + def __radd__(self, other: int) -> Self: return self + other - def __rsub__(self: TSelf, other: int) -> TSelf: + def __rsub__(self, other: int) -> Self: return -self + other - def __neg__(self: TSelf) -> TSelf: + def __neg__(self) -> Self: return self._with_x(-self.x) def __complex__(self) -> complex: diff --git a/cirq-core/cirq/ops/arithmetic_operation.py b/cirq-core/cirq/ops/arithmetic_operation.py index 020c4609496..37e87be9c1b 100644 --- a/cirq-core/cirq/ops/arithmetic_operation.py +++ b/cirq-core/cirq/ops/arithmetic_operation.py @@ -15,7 +15,8 @@ import abc import itertools -from typing import Union, Iterable, List, Sequence, cast, Tuple, TypeVar, TYPE_CHECKING +from typing import Union, Iterable, List, Sequence, cast, Tuple, TYPE_CHECKING +from typing_extensions import Self import numpy as np @@ -25,9 +26,6 @@ import cirq -TSelfGate = TypeVar('TSelfGate', bound='ArithmeticGate') - - class ArithmeticGate(Gate, metaclass=abc.ABCMeta): r"""A helper gate for implementing reversible classical arithmetic. @@ -55,7 +53,7 @@ class ArithmeticGate(Gate, metaclass=abc.ABCMeta): ... ... def with_registers( ... self, *new_registers: 'Union[int, Sequence[int]]' - ... ) -> 'TSelfGate': + ... ) -> 'Add': ... return Add(*new_registers) ... ... def apply(self, *register_values: int) -> 'Union[int, Iterable[int]]': @@ -105,7 +103,7 @@ def registers(self) -> Sequence[Union[int, Sequence[int]]]: raise NotImplementedError() @abc.abstractmethod - def with_registers(self: TSelfGate, *new_registers: Union[int, Sequence[int]]) -> TSelfGate: + def with_registers(self, *new_registers: Union[int, Sequence[int]]) -> Self: """Returns the same fate targeting different registers. Args: diff --git a/cirq-core/cirq/ops/common_gates.py b/cirq-core/cirq/ops/common_gates.py index 01965cc3281..1d1359ec82c 100644 --- a/cirq-core/cirq/ops/common_gates.py +++ b/cirq-core/cirq/ops/common_gates.py @@ -357,7 +357,7 @@ def __init__(self, *, rads: value.TParamVal): self._rads = rads super().__init__(exponent=rads / _pi(rads), global_shift=-0.5) - def _with_exponent(self: 'Rx', exponent: value.TParamVal) -> 'Rx': + def _with_exponent(self, exponent: value.TParamVal) -> 'Rx': return Rx(rads=exponent * _pi(exponent)) def _circuit_diagram_info_( @@ -541,7 +541,7 @@ def __init__(self, *, rads: value.TParamVal): self._rads = rads super().__init__(exponent=rads / _pi(rads), global_shift=-0.5) - def _with_exponent(self: 'Ry', exponent: value.TParamVal) -> 'Ry': + def _with_exponent(self, exponent: value.TParamVal) -> 'Ry': return Ry(rads=exponent * _pi(exponent)) def _circuit_diagram_info_( @@ -891,7 +891,7 @@ def __init__(self, *, rads: value.TParamVal): self._rads = rads super().__init__(exponent=rads / _pi(rads), global_shift=-0.5) - def _with_exponent(self: 'Rz', exponent: value.TParamVal) -> 'Rz': + def _with_exponent(self, exponent: value.TParamVal) -> 'Rz': return Rz(rads=exponent * _pi(exponent)) def _circuit_diagram_info_( diff --git a/cirq-core/cirq/ops/dense_pauli_string.py b/cirq-core/cirq/ops/dense_pauli_string.py index 24622155b9f..6cf97c4eb31 100644 --- a/cirq-core/cirq/ops/dense_pauli_string.py +++ b/cirq-core/cirq/ops/dense_pauli_string.py @@ -27,11 +27,10 @@ overload, Sequence, Tuple, - Type, TYPE_CHECKING, - TypeVar, Union, ) +from typing_extensions import Self import numpy as np import sympy @@ -53,8 +52,6 @@ pauli_gates.Z, ] -TCls = TypeVar('TCls', bound='BaseDensePauliString') - @value.value_equality(approximate=True, distinct_child_types=True) class BaseDensePauliString(raw_types.Gate, metaclass=abc.ABCMeta): @@ -132,7 +129,7 @@ def _value_equality_values_(self): return self.coefficient, tuple(PAULI_CHARS[p] for p in self.pauli_mask) @classmethod - def one_hot(cls: Type[TCls], *, index: int, length: int, pauli: 'cirq.PAULI_GATE_LIKE') -> TCls: + def one_hot(cls, *, index: int, length: int, pauli: 'cirq.PAULI_GATE_LIKE') -> Self: """Creates a dense pauli string with only one non-identity Pauli. Args: @@ -149,7 +146,7 @@ def one_hot(cls: Type[TCls], *, index: int, length: int, pauli: 'cirq.PAULI_GATE return concrete_cls(pauli_mask=mask) @classmethod - def eye(cls: Type[TCls], length: int) -> TCls: + def eye(cls, length: int) -> Self: """Creates a dense pauli string containing only identity gates. Args: @@ -198,7 +195,7 @@ def _is_parameterized_(self) -> bool: def _parameter_names_(self) -> AbstractSet[str]: return protocols.parameter_names(self.coefficient) - def _resolve_parameters_(self: TCls, resolver: 'cirq.ParamResolver', recursive: bool) -> TCls: + def _resolve_parameters_(self, resolver: 'cirq.ParamResolver', recursive: bool) -> Self: return self.copy( coefficient=protocols.resolve_parameters(self.coefficient, resolver, recursive) ) @@ -206,7 +203,7 @@ def _resolve_parameters_(self: TCls, resolver: 'cirq.ParamResolver', recursive: def __pos__(self): return self - def __pow__(self: TCls, power: Union[int, float]) -> Union[NotImplementedType, TCls]: + def __pow__(self, power: Union[int, float]) -> Union[NotImplementedType, Self]: concrete_class = type(self) if isinstance(power, int): i_group = [1, +1j, -1, -1j] @@ -221,11 +218,11 @@ def __pow__(self: TCls, power: Union[int, float]) -> Union[NotImplementedType, T return NotImplemented @overload - def __getitem__(self: TCls, item: int) -> Union['cirq.Pauli', 'cirq.IdentityGate']: + def __getitem__(self, item: int) -> Union['cirq.Pauli', 'cirq.IdentityGate']: pass @overload - def __getitem__(self: TCls, item: slice) -> TCls: + def __getitem__(self, item: slice) -> Self: pass def __getitem__(self, item): @@ -304,7 +301,7 @@ def __rmul__(self, other): return NotImplemented - def tensor_product(self: TCls, other: 'BaseDensePauliString') -> TCls: + def tensor_product(self, other: 'BaseDensePauliString') -> Self: """Concatenates dense pauli strings and multiplies their coefficients. Args: @@ -319,7 +316,7 @@ def tensor_product(self: TCls, other: 'BaseDensePauliString') -> TCls: pauli_mask=np.concatenate([self.pauli_mask, other.pauli_mask]), ) - def __abs__(self: TCls) -> TCls: + def __abs__(self) -> Self: coef = self.coefficient return type(self)( coefficient=sympy.Abs(coef) if isinstance(coef, sympy.Expr) else abs(coef), @@ -405,10 +402,10 @@ def mutable_copy(self) -> 'MutableDensePauliString': @abc.abstractmethod def copy( - self: TCls, + self, coefficient: Optional[Union[sympy.Expr, int, float, complex]] = None, pauli_mask: Union[None, str, Iterable[int], np.ndarray] = None, - ) -> TCls: + ) -> Self: """Returns a copy with possibly modified contents. Args: @@ -492,17 +489,15 @@ class MutableDensePauliString(BaseDensePauliString): """ @overload - def __setitem__( - self: 'MutableDensePauliString', key: int, value: 'cirq.PAULI_GATE_LIKE' - ) -> 'MutableDensePauliString': + def __setitem__(self, key: int, value: 'cirq.PAULI_GATE_LIKE') -> Self: pass @overload def __setitem__( - self: 'MutableDensePauliString', + self, key: slice, value: Union[Iterable['cirq.PAULI_GATE_LIKE'], np.ndarray, BaseDensePauliString], - ) -> 'MutableDensePauliString': + ) -> Self: pass def __setitem__(self, key, value): diff --git a/cirq-core/cirq/ops/eigen_gate.py b/cirq-core/cirq/ops/eigen_gate.py index 837da19a359..b23e332b74b 100644 --- a/cirq-core/cirq/ops/eigen_gate.py +++ b/cirq-core/cirq/ops/eigen_gate.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc import fractions +import math +import numbers from typing import ( AbstractSet, Any, @@ -23,12 +26,8 @@ Optional, Tuple, TYPE_CHECKING, - TypeVar, Union, ) -import abc -import math -import numbers import numpy as np import sympy @@ -40,8 +39,6 @@ if TYPE_CHECKING: import cirq -TSelf = TypeVar('TSelf', bound='EigenGate') - EigenComponent = NamedTuple( 'EigenComponent', @@ -135,7 +132,7 @@ def global_shift(self) -> float: return self._global_shift # virtual method - def _with_exponent(self: TSelf, exponent: value.TParamVal) -> 'EigenGate': + def _with_exponent(self, exponent: value.TParamVal) -> 'EigenGate': """Return the same kind of gate, but with a different exponent. Child classes should override this method if they have an __init__ @@ -301,7 +298,7 @@ def _period(self) -> Optional[float]: real_periods = [abs(2 / e) for e in exponents if e != 0] return _approximate_common_period(real_periods) - def __pow__(self: TSelf, exponent: Union[float, sympy.Symbol]) -> 'EigenGate': + def __pow__(self, exponent: Union[float, sympy.Symbol]) -> 'EigenGate': new_exponent = protocols.mul(self._exponent, exponent, NotImplemented) if new_exponent is NotImplemented: return NotImplemented diff --git a/cirq-core/cirq/ops/gate_operation.py b/cirq-core/cirq/ops/gate_operation.py index 67ca2bed61b..472610214e0 100644 --- a/cirq-core/cirq/ops/gate_operation.py +++ b/cirq-core/cirq/ops/gate_operation.py @@ -31,6 +31,7 @@ Union, List, ) +from typing_extensions import Self import numpy as np @@ -42,9 +43,6 @@ import cirq -TSelf = TypeVar('TSelf', bound='GateOperation') - - @value.value_equality(approximate=True) class GateOperation(raw_types.Operation): """An application of a gate to a sequence of qubits. @@ -73,8 +71,8 @@ def qubits(self) -> Tuple['cirq.Qid', ...]: """The qubits targeted by the operation.""" return self._qubits - def with_qubits(self: TSelf, *new_qubits: 'cirq.Qid') -> TSelf: - return cast(TSelf, self.gate.on(*new_qubits)) + def with_qubits(self, *new_qubits: 'cirq.Qid') -> Self: + return cast(Self, self.gate.on(*new_qubits)) def with_gate(self, new_gate: 'cirq.Gate') -> 'cirq.Operation': if self.gate is new_gate: diff --git a/cirq-core/cirq/ops/named_qubit.py b/cirq-core/cirq/ops/named_qubit.py index e5f622a03c6..76bf2391fca 100644 --- a/cirq-core/cirq/ops/named_qubit.py +++ b/cirq-core/cirq/ops/named_qubit.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -from typing import Any, Dict, List, TYPE_CHECKING, TypeVar +from typing import Any, Dict, List, TYPE_CHECKING from cirq import protocols from cirq.ops import raw_types @@ -21,8 +21,6 @@ if TYPE_CHECKING: import cirq -TSelf = TypeVar('TSelf', bound='_BaseNamedQid') - @functools.total_ordering class _BaseNamedQid(raw_types.Qid): diff --git a/cirq-core/cirq/ops/parity_gates.py b/cirq-core/cirq/ops/parity_gates.py index b65baa3a1cb..1839260d035 100644 --- a/cirq-core/cirq/ops/parity_gates.py +++ b/cirq-core/cirq/ops/parity_gates.py @@ -15,6 +15,7 @@ """Quantum gates that phase with respect to product-of-pauli observables.""" from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING +from typing_extensions import Self import numpy as np @@ -328,7 +329,7 @@ def __init__(self, *, rads: float): # Forces keyword args. XXPowGate.__init__(self, exponent=rads * 2 / np.pi, global_shift=-0.5) self.rads = rads - def _with_exponent(self: 'MSGate', exponent: value.TParamVal) -> 'MSGate': + def _with_exponent(self, exponent: value.TParamVal) -> Self: return type(self)(rads=exponent * np.pi / 2) def _circuit_diagram_info_( diff --git a/cirq-core/cirq/ops/pauli_gates.py b/cirq-core/cirq/ops/pauli_gates.py index 7d4ed3ee293..f8ea35f1c63 100644 --- a/cirq-core/cirq/ops/pauli_gates.py +++ b/cirq-core/cirq/ops/pauli_gates.py @@ -112,10 +112,10 @@ def __init__(self): Pauli.__init__(self, index=0, name='X') common_gates.XPowGate.__init__(self, exponent=1.0) - def __pow__(self: '_PauliX', exponent: 'cirq.TParamVal') -> common_gates.XPowGate: + def __pow__(self, exponent: 'cirq.TParamVal') -> common_gates.XPowGate: return common_gates.XPowGate(exponent=exponent) if exponent != 1 else _PauliX() - def _with_exponent(self: '_PauliX', exponent: 'cirq.TParamVal') -> common_gates.XPowGate: + def _with_exponent(self, exponent: 'cirq.TParamVal') -> common_gates.XPowGate: return self.__pow__(exponent) @classmethod @@ -125,7 +125,7 @@ def _from_json_dict_(cls, exponent, global_shift, **kwargs): return Pauli._XYZ[0] @property - def basis(self: '_PauliX') -> Dict[int, '_XEigenState']: + def basis(self) -> Dict[int, '_XEigenState']: from cirq.value.product_state import _XEigenState return {+1: _XEigenState(+1), -1: _XEigenState(-1)} @@ -136,10 +136,10 @@ def __init__(self): Pauli.__init__(self, index=1, name='Y') common_gates.YPowGate.__init__(self, exponent=1.0) - def __pow__(self: '_PauliY', exponent: 'cirq.TParamVal') -> common_gates.YPowGate: + def __pow__(self, exponent: 'cirq.TParamVal') -> common_gates.YPowGate: return common_gates.YPowGate(exponent=exponent) if exponent != 1 else _PauliY() - def _with_exponent(self: '_PauliY', exponent: 'cirq.TParamVal') -> common_gates.YPowGate: + def _with_exponent(self, exponent: 'cirq.TParamVal') -> common_gates.YPowGate: return self.__pow__(exponent) @classmethod @@ -149,7 +149,7 @@ def _from_json_dict_(cls, exponent, global_shift, **kwargs): return Pauli._XYZ[1] @property - def basis(self: '_PauliY') -> Dict[int, '_YEigenState']: + def basis(self) -> Dict[int, '_YEigenState']: from cirq.value.product_state import _YEigenState return {+1: _YEigenState(+1), -1: _YEigenState(-1)} @@ -160,10 +160,10 @@ def __init__(self): Pauli.__init__(self, index=2, name='Z') common_gates.ZPowGate.__init__(self, exponent=1.0) - def __pow__(self: '_PauliZ', exponent: 'cirq.TParamVal') -> common_gates.ZPowGate: + def __pow__(self, exponent: 'cirq.TParamVal') -> common_gates.ZPowGate: return common_gates.ZPowGate(exponent=exponent) if exponent != 1 else _PauliZ() - def _with_exponent(self: '_PauliZ', exponent: 'cirq.TParamVal') -> common_gates.ZPowGate: + def _with_exponent(self, exponent: 'cirq.TParamVal') -> common_gates.ZPowGate: return self.__pow__(exponent) @classmethod @@ -173,7 +173,7 @@ def _from_json_dict_(cls, exponent, global_shift, **kwargs): return Pauli._XYZ[2] @property - def basis(self: '_PauliZ') -> Dict[int, '_ZEigenState']: + def basis(self) -> Dict[int, '_ZEigenState']: from cirq.value.product_state import _ZEigenState return {+1: _ZEigenState(+1), -1: _ZEigenState(-1)} diff --git a/cirq-core/cirq/ops/pauli_string_raw_types.py b/cirq-core/cirq/ops/pauli_string_raw_types.py index f102c834ede..8aa8982419b 100644 --- a/cirq-core/cirq/ops/pauli_string_raw_types.py +++ b/cirq-core/cirq/ops/pauli_string_raw_types.py @@ -13,7 +13,8 @@ # limitations under the License. import abc -from typing import Any, Dict, Sequence, Tuple, TypeVar, TYPE_CHECKING +from typing import Any, Dict, Sequence, Tuple, TYPE_CHECKING +from typing_extensions import Self from cirq import protocols from cirq.ops import pauli_string as ps, raw_types @@ -21,10 +22,6 @@ if TYPE_CHECKING: import cirq -TSelf_PauliStringGateOperation = TypeVar( - 'TSelf_PauliStringGateOperation', bound='PauliStringGateOperation' -) - class PauliStringGateOperation(raw_types.Operation, metaclass=abc.ABCMeta): def __init__(self, pauli_string: ps.PauliString) -> None: @@ -38,16 +35,12 @@ def validate_args(self, qubits: Sequence[raw_types.Qid]) -> None: if len(qubits) != len(self.pauli_string): raise ValueError('Incorrect number of qubits for gate') - def with_qubits( - self: TSelf_PauliStringGateOperation, *new_qubits: 'cirq.Qid' - ) -> TSelf_PauliStringGateOperation: + def with_qubits(self, *new_qubits: 'cirq.Qid') -> Self: self.validate_args(new_qubits) return self.map_qubits(dict(zip(self.pauli_string.qubits, new_qubits))) @abc.abstractmethod - def map_qubits( - self: TSelf_PauliStringGateOperation, qubit_map: Dict[raw_types.Qid, raw_types.Qid] - ) -> TSelf_PauliStringGateOperation: + def map_qubits(self, qubit_map: Dict[raw_types.Qid, raw_types.Qid]) -> Self: """Return an equivalent operation on new qubits with its Pauli string mapped to new qubits. diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index e04f9335840..3bdf3bc9a58 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -31,10 +31,10 @@ Optional, Sequence, Tuple, - TypeVar, TYPE_CHECKING, Union, ) +from typing_extensions import Self import numpy as np import sympy @@ -483,9 +483,6 @@ def _json_dict_(self) -> Dict[str, Any]: return protocols.obj_to_dict_helper(self, attribute_names=[]) -TSelf = TypeVar('TSelf', bound='Operation') - - class Operation(metaclass=abc.ABCMeta): """An effect applied to a collection of qubits. @@ -514,7 +511,7 @@ def _qid_shape_(self) -> Tuple[int, ...]: return protocols.qid_shape(self.qubits) @abc.abstractmethod - def with_qubits(self: TSelf, *new_qubits: 'cirq.Qid') -> TSelf: + def with_qubits(self, *new_qubits: 'cirq.Qid') -> Self: """Returns the same operation, but applied to different qubits. Args: @@ -556,9 +553,8 @@ def with_tags(self, *new_tags: Hashable) -> 'cirq.Operation': return TaggedOperation(self, *new_tags) def transform_qubits( - self: TSelf, - qubit_map: Union[Dict['cirq.Qid', 'cirq.Qid'], Callable[['cirq.Qid'], 'cirq.Qid']], - ) -> TSelf: + self, qubit_map: Union[Dict['cirq.Qid', 'cirq.Qid'], Callable[['cirq.Qid'], 'cirq.Qid']] + ) -> Self: """Returns the same operation, but with different qubits. Args: diff --git a/cirq-core/cirq/protocols/act_on_protocol_test.py b/cirq-core/cirq/protocols/act_on_protocol_test.py index 2cc31855b78..f526695fd20 100644 --- a/cirq-core/cirq/protocols/act_on_protocol_test.py +++ b/cirq-core/cirq/protocols/act_on_protocol_test.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Sequence, Tuple +from typing_extensions import Self import numpy as np import pytest import cirq -from cirq.ops.raw_types import TSelf class DummyQuantumState(cirq.QuantumStateRepresentation): @@ -65,7 +65,7 @@ class Op(cirq.Operation): def qubits(self) -> Tuple['cirq.Qid', ...]: # type: ignore[empty-body] pass - def with_qubits(self: TSelf, *new_qubits: 'cirq.Qid') -> TSelf: # type: ignore[empty-body] + def with_qubits(self, *new_qubits: 'cirq.Qid') -> Self: # type: ignore[empty-body] pass def _act_on_(self, sim_state): @@ -82,7 +82,7 @@ class Op(cirq.Operation): def qubits(self) -> Tuple['cirq.Qid', ...]: # type: ignore[empty-body] pass - def with_qubits(self: TSelf, *new_qubits: 'cirq.Qid') -> TSelf: # type: ignore[empty-body] + def with_qubits(self, *new_qubits: 'cirq.Qid') -> Self: # type: ignore[empty-body] pass state = DummySimulationState() diff --git a/cirq-core/cirq/protocols/circuit_diagram_info_protocol.py b/cirq-core/cirq/protocols/circuit_diagram_info_protocol.py index 3c89fb853f4..00d1e5b65de 100644 --- a/cirq-core/cirq/protocols/circuit_diagram_info_protocol.py +++ b/cirq-core/cirq/protocols/circuit_diagram_info_protocol.py @@ -121,9 +121,7 @@ def _wire_symbols_including_formatted_exponent( result[k] += '^' + exponent return result - def _formatted_exponent( - self: 'cirq.CircuitDiagramInfo', args: 'cirq.CircuitDiagramInfoArgs' - ) -> Optional[str]: + def _formatted_exponent(self, args: 'cirq.CircuitDiagramInfoArgs') -> Optional[str]: if protocols.is_parameterized(self.exponent): name = str(self.exponent) diff --git a/cirq-core/cirq/protocols/phase_protocol.py b/cirq-core/cirq/protocols/phase_protocol.py index 76dd5035a6d..da072968ddb 100644 --- a/cirq-core/cirq/protocols/phase_protocol.py +++ b/cirq-core/cirq/protocols/phase_protocol.py @@ -30,7 +30,7 @@ class SupportsPhase(Protocol): """An effect that can be phased around the Z axis of target qubits.""" @doc_private - def _phase_by_(self: Any, phase_turns: float, qubit_index: int): + def _phase_by_(self, phase_turns: float, qubit_index: int): """Returns a phased version of the effect. Specifically, returns an object with matrix P U P^-1 (up to global diff --git a/cirq-core/cirq/protocols/resolve_parameters.py b/cirq-core/cirq/protocols/resolve_parameters.py index a1ddfeddb8f..3243f26b555 100644 --- a/cirq-core/cirq/protocols/resolve_parameters.py +++ b/cirq-core/cirq/protocols/resolve_parameters.py @@ -14,6 +14,7 @@ import numbers from typing import AbstractSet, Any, cast, TYPE_CHECKING, TypeVar +from typing_extensions import Self import sympy from typing_extensions import Protocol @@ -33,13 +34,13 @@ class SupportsParameterization(Protocol): via a ParamResolver""" @doc_private - def _is_parameterized_(self: Any) -> bool: + def _is_parameterized_(self) -> bool: """Whether the object is parameterized by any Symbols that require resolution. Returns True if the object has any unresolved Symbols and False otherwise.""" @doc_private - def _parameter_names_(self: Any) -> AbstractSet[str]: + def _parameter_names_(self) -> AbstractSet[str]: """Returns a collection of string names of parameters that require resolution. If _is_parameterized_ is False, the collection is empty. The converse is not necessarily true, because some objects may report @@ -48,7 +49,7 @@ def _parameter_names_(self: Any) -> AbstractSet[str]: """ @doc_private - def _resolve_parameters_(self: T, resolver: 'cirq.ParamResolver', recursive: bool) -> T: + def _resolve_parameters_(self, resolver: 'cirq.ParamResolver', recursive: bool) -> Self: """Resolve the parameters in the effect.""" diff --git a/cirq-core/cirq/qis/quantum_state_representation.py b/cirq-core/cirq/qis/quantum_state_representation.py index b2f6587adea..41b808a3983 100644 --- a/cirq-core/cirq/qis/quantum_state_representation.py +++ b/cirq-core/cirq/qis/quantum_state_representation.py @@ -13,7 +13,9 @@ # limitations under the License. import abc -from typing import List, Sequence, Tuple, TYPE_CHECKING, TypeVar +from typing import List, Sequence, Tuple, TYPE_CHECKING +from typing_extensions import Self + import numpy as np from cirq import value @@ -21,12 +23,10 @@ if TYPE_CHECKING: import cirq -TSelf = TypeVar('TSelf', bound='QuantumStateRepresentation') - class QuantumStateRepresentation(metaclass=abc.ABCMeta): @abc.abstractmethod - def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf: + def copy(self, deep_copy_buffers: bool = True) -> Self: """Creates a copy of the object. Args: deep_copy_buffers: If True, buffers will also be deep-copied. @@ -71,17 +71,15 @@ def sample( measurements.append(state.measure(axes, prng)) return np.array(measurements, dtype=np.uint8) - def kron(self: TSelf, other: TSelf) -> TSelf: + def kron(self, other: Self) -> Self: """Joins two state spaces together.""" raise NotImplementedError() - def factor( - self: TSelf, axes: Sequence[int], *, validate=True, atol=1e-07 - ) -> Tuple[TSelf, TSelf]: + def factor(self, axes: Sequence[int], *, validate=True, atol=1e-07) -> Tuple[Self, Self]: """Splits two state spaces after a measurement or reset.""" raise NotImplementedError() - def reindex(self: TSelf, axes: Sequence[int]) -> TSelf: + def reindex(self, axes: Sequence[int]) -> Self: """Physically reindexes the state by the new basis. Args: axes: The desired axis order. diff --git a/cirq-core/cirq/sim/simulation_state.py b/cirq-core/cirq/sim/simulation_state.py index fcaeb95f9f2..88d1aa43221 100644 --- a/cirq-core/cirq/sim/simulation_state.py +++ b/cirq-core/cirq/sim/simulation_state.py @@ -27,6 +27,7 @@ TYPE_CHECKING, Tuple, ) +from typing_extensions import Self import numpy as np @@ -34,7 +35,6 @@ from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits from cirq.sim.simulation_state_base import SimulationStateBase -TSelf = TypeVar('TSelf', bound='SimulationState') TState = TypeVar('TState', bound='cirq.QuantumStateRepresentation') if TYPE_CHECKING: @@ -146,7 +146,7 @@ def sample( return self._state.sample(self.get_axes(qubits), repetitions, seed) raise NotImplementedError() - def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf: + def copy(self, deep_copy_buffers: bool = True) -> Self: """Creates a copy of the object. Args: @@ -162,11 +162,11 @@ def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf: args._state = self._state.copy(deep_copy_buffers=deep_copy_buffers) return args - def create_merged_state(self: TSelf) -> TSelf: + def create_merged_state(self) -> Self: """Creates a final merged state.""" return self - def kronecker_product(self: TSelf, other: TSelf, *, inplace=False) -> TSelf: + def kronecker_product(self, other: Self, *, inplace=False) -> Self: """Joins two state spaces together.""" args = self if inplace else copy.copy(self) args._state = self._state.kron(other._state) @@ -174,8 +174,8 @@ def kronecker_product(self: TSelf, other: TSelf, *, inplace=False) -> TSelf: return args def factor( - self: TSelf, qubits: Sequence['cirq.Qid'], *, validate=True, atol=1e-07, inplace=False - ) -> Tuple[TSelf, TSelf]: + self, qubits: Sequence['cirq.Qid'], *, validate=True, atol=1e-07, inplace=False + ) -> Tuple[Self, Self]: """Splits two state spaces after a measurement or reset.""" extracted = copy.copy(self) remainder = self if inplace else copy.copy(self) @@ -191,9 +191,7 @@ def allows_factoring(self): """Subclasses that allow factorization should override this.""" return self._state.supports_factor if self._state is not None else False - def transpose_to_qubit_order( - self: TSelf, qubits: Sequence['cirq.Qid'], *, inplace=False - ) -> TSelf: + def transpose_to_qubit_order(self, qubits: Sequence['cirq.Qid'], *, inplace=False) -> Self: """Physically reindexes the state by the new basis. Args: @@ -276,7 +274,7 @@ def rename(self, q1: 'cirq.Qid', q2: 'cirq.Qid', *, inplace=False): args._set_qubits(qubits) return args - def __getitem__(self: TSelf, item: Optional['cirq.Qid']) -> TSelf: + def __getitem__(self, item: Optional['cirq.Qid']) -> Self: if item not in self.qubit_map: raise IndexError(f'{item} not in {self.qubits}') return self diff --git a/cirq-core/cirq/sim/simulation_state_base.py b/cirq-core/cirq/sim/simulation_state_base.py index 044ee94c4a1..684813f07d5 100644 --- a/cirq-core/cirq/sim/simulation_state_base.py +++ b/cirq-core/cirq/sim/simulation_state_base.py @@ -27,6 +27,7 @@ TypeVar, Union, ) +from typing_extensions import Self import numpy as np @@ -37,7 +38,6 @@ import cirq -TSelfTarget = TypeVar('TSelfTarget', bound='SimulationStateBase') TSimulationState = TypeVar('TSimulationState', bound='cirq.SimulationState') @@ -98,7 +98,7 @@ def apply_operation(self, op: 'cirq.Operation'): protocols.act_on(op, self) @abc.abstractmethod - def copy(self: TSelfTarget, deep_copy_buffers: bool = True) -> TSelfTarget: + def copy(self, deep_copy_buffers: bool = True) -> Self: """Creates a copy of the object. Args: diff --git a/cirq-core/cirq/sim/simulator_base_test.py b/cirq-core/cirq/sim/simulator_base_test.py index 044e1afb609..3058552b805 100644 --- a/cirq-core/cirq/sim/simulator_base_test.py +++ b/cirq-core/cirq/sim/simulator_base_test.py @@ -34,7 +34,7 @@ def measure( self.measurement_count += 1 return [self.gate_count] - def kron(self: 'CountingState', other: 'CountingState') -> 'CountingState': + def kron(self, other: 'CountingState') -> 'CountingState': return CountingState( self.data, self.gate_count + other.gate_count, @@ -43,13 +43,13 @@ def kron(self: 'CountingState', other: 'CountingState') -> 'CountingState': ) def factor( - self: 'CountingState', axes: Sequence[int], *, validate=True, atol=1e-07 + self, axes: Sequence[int], *, validate=True, atol=1e-07 ) -> Tuple['CountingState', 'CountingState']: return CountingState( self.data, self.gate_count, self.measurement_count, self.copy_count ), CountingState(self.data) - def reindex(self: 'CountingState', axes: Sequence[int]) -> 'CountingState': + def reindex(self, axes: Sequence[int]) -> 'CountingState': return CountingState(self.data, self.gate_count, self.measurement_count, self.copy_count) def copy(self, deep_copy_buffers: bool = True) -> 'CountingState': diff --git a/cirq-core/cirq/value/classical_data.py b/cirq-core/cirq/value/classical_data.py index 532212e84fe..be8955c324a 100644 --- a/cirq-core/cirq/value/classical_data.py +++ b/cirq-core/cirq/value/classical_data.py @@ -14,7 +14,8 @@ import abc import enum -from typing import Dict, List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING, TypeVar +from typing import Dict, List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING +from typing_extensions import Self from cirq.value import digits, value_equality_attr @@ -41,9 +42,6 @@ def __repr__(self): return f'cirq.{str(self)}' -TSelf = TypeVar('TSelf', bound='ClassicalDataStoreReader') - - class ClassicalDataStoreReader(abc.ABC): @abc.abstractmethod def keys(self) -> Tuple['cirq.MeasurementKey', ...]: @@ -101,7 +99,7 @@ def get_digits(self, key: 'cirq.MeasurementKey', index=-1) -> Tuple[int, ...]: """ @abc.abstractmethod - def copy(self: TSelf) -> TSelf: + def copy(self) -> Self: """Creates a copy of the object.""" diff --git a/cirq-core/cirq/value/linear_dict.py b/cirq-core/cirq/value/linear_dict.py index 461c9ec102e..6b8e9d3a91d 100644 --- a/cirq-core/cirq/value/linear_dict.py +++ b/cirq-core/cirq/value/linear_dict.py @@ -32,6 +32,7 @@ Union, ValuesView, ) +from typing_extensions import Self Scalar = Union[complex, float, numbers.Complex] TVector = TypeVar('TVector') @@ -113,8 +114,6 @@ def __init__( if terms is not None: self.update(terms) - TSelf = TypeVar('TSelf', bound='LinearDict[TVector]') - @classmethod def fromkeys(cls, vectors, coefficient=0): return LinearDict(dict.fromkeys(vectors, complex(coefficient))) @@ -123,14 +122,14 @@ def _check_vector_valid(self, vector: TVector) -> None: if not self._is_valid(vector): raise ValueError(f'{vector} is not compatible with linear combination {self}') - def clean(self: 'TSelf', *, atol: float = 1e-9) -> 'TSelf': + def clean(self, *, atol: float = 1e-9) -> Self: """Remove terms with coefficients of absolute value atol or less.""" negligible = [v for v, c in self._terms.items() if abs(c) <= atol] # type: ignore[operator] for v in negligible: del self._terms[v] return self - def copy(self: 'TSelf') -> 'TSelf': + def copy(self) -> Self: factory = type(self) return factory(self._terms.copy()) @@ -206,19 +205,19 @@ def __iter__(self) -> Iterator[TVector]: def __len__(self) -> int: return len([v for v, c in self._terms.items() if c != 0]) - def __iadd__(self: 'TSelf', other: 'TSelf') -> 'TSelf': + def __iadd__(self, other: Self) -> Self: for vector, other_coefficient in other.items(): old_coefficient = self._terms.get(vector, 0) new_coefficient = old_coefficient + other_coefficient self[vector] = new_coefficient return self.clean(atol=0) - def __add__(self: 'TSelf', other: 'TSelf') -> 'TSelf': + def __add__(self, other: Self) -> Self: result = self.copy() result += other return result - def __isub__(self: 'TSelf', other: 'TSelf') -> 'TSelf': + def __isub__(self, other: Self) -> Self: for vector, other_coefficient in other.items(): old_coefficient = self._terms.get(vector, 0) new_coefficient = old_coefficient - other_coefficient @@ -226,30 +225,30 @@ def __isub__(self: 'TSelf', other: 'TSelf') -> 'TSelf': self.clean(atol=0) return self - def __sub__(self: 'TSelf', other: 'TSelf') -> 'TSelf': + def __sub__(self, other: Self) -> Self: result = self.copy() result -= other return result - def __neg__(self: 'TSelf') -> 'TSelf': + def __neg__(self) -> Self: factory = type(self) return factory({v: -c for v, c in self.items()}) - def __imul__(self: 'TSelf', a: Scalar) -> 'TSelf': + def __imul__(self, a: Scalar) -> Self: for vector in self: self._terms[vector] *= a self.clean(atol=0) return self - def __mul__(self: 'TSelf', a: Scalar) -> 'TSelf': + def __mul__(self, a: Scalar) -> Self: result = self.copy() result *= a return result - def __rmul__(self: 'TSelf', a: Scalar) -> 'TSelf': + def __rmul__(self, a: Scalar) -> Self: return self.__mul__(a) - def __truediv__(self: 'TSelf', a: Scalar) -> 'TSelf': + def __truediv__(self, a: Scalar) -> Self: return self.__mul__(1 / a) def __bool__(self) -> bool: diff --git a/cirq-google/cirq_google/experimental/ops/coupler_pulse.py b/cirq-google/cirq_google/experimental/ops/coupler_pulse.py index 3d467094309..415d8d623bf 100644 --- a/cirq-google/cirq_google/experimental/ops/coupler_pulse.py +++ b/cirq-google/cirq_google/experimental/ops/coupler_pulse.py @@ -115,7 +115,7 @@ def _is_parameterized_(self) -> bool: or cirq.is_parameterized(self.q1_detune_mhz) ) - def _parameter_names_(self: Any) -> AbstractSet[str]: + def _parameter_names_(self) -> AbstractSet[str]: return ( cirq.parameter_names(self.hold_time) | cirq.parameter_names(self.coupling_mhz)