Skip to content

Commit

Permalink
Use PEP-673 Self type annotations (#6057)
Browse files Browse the repository at this point in the history
Review: @pavoljuhas
  • Loading branch information
maffoo authored Apr 13, 2023
1 parent 1445f12 commit 5c2bcda
Show file tree
Hide file tree
Showing 25 changed files with 122 additions and 164 deletions.
17 changes: 7 additions & 10 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
TypeVar,
Union,
)
from typing_extensions import Self

import networkx
import numpy as np
Expand Down Expand Up @@ -236,15 +237,15 @@ def __getitem__(self, key: Tuple[int, Iterable['cirq.Qid']]) -> 'cirq.Moment':
pass

@overload
def __getitem__(self: CIRCUIT_TYPE, key: slice) -> CIRCUIT_TYPE:
def __getitem__(self, key: slice) -> Self:
pass

@overload
def __getitem__(self: CIRCUIT_TYPE, key: Tuple[slice, 'cirq.Qid']) -> CIRCUIT_TYPE:
def __getitem__(self, key: Tuple[slice, 'cirq.Qid']) -> Self:
pass

@overload
def __getitem__(self: CIRCUIT_TYPE, key: Tuple[slice, Iterable['cirq.Qid']]) -> CIRCUIT_TYPE:
def __getitem__(self, key: Tuple[slice, Iterable['cirq.Qid']]) -> Self:
pass

def __getitem__(self, key):
Expand Down Expand Up @@ -913,9 +914,7 @@ def all_operations(self) -> Iterator['cirq.Operation']:
"""
return (op for moment in self for op in moment.operations)

def map_operations(
self: CIRCUIT_TYPE, func: Callable[['cirq.Operation'], 'cirq.OP_TREE']
) -> CIRCUIT_TYPE:
def map_operations(self, func: Callable[['cirq.Operation'], 'cirq.OP_TREE']) -> Self:
"""Applies the given function to all operations in this circuit.
Args:
Expand Down Expand Up @@ -1287,9 +1286,7 @@ def _is_parameterized_(self) -> bool:
def _parameter_names_(self) -> AbstractSet[str]:
return {name for op in self.all_operations() for name in protocols.parameter_names(op)}

def _resolve_parameters_(
self: CIRCUIT_TYPE, resolver: 'cirq.ParamResolver', recursive: bool
) -> CIRCUIT_TYPE:
def _resolve_parameters_(self, resolver: 'cirq.ParamResolver', recursive: bool) -> Self:
changed = False
resolved_moments: List['cirq.Moment'] = []
for moment in self:
Expand Down Expand Up @@ -1540,7 +1537,7 @@ def get_independent_qubit_sets(self) -> List[Set['cirq.Qid']]:
uf.union(*op.qubits)
return sorted([qs for qs in uf.to_sets()], key=min)

def factorize(self: CIRCUIT_TYPE) -> Iterable[CIRCUIT_TYPE]:
def factorize(self) -> Iterable[Self]:
"""Factorize circuit into a sequence of independent circuits (factors).
Factorization is possible when the circuit's qubits can be divided
Expand Down
9 changes: 3 additions & 6 deletions cirq-core/cirq/circuits/moment.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
Sequence,
Tuple,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing_extensions import Self

import numpy as np

Expand All @@ -52,8 +52,6 @@
"text_diagram_drawer", globals(), "cirq.circuits.text_diagram_drawer"
)

TSelf_Moment = TypeVar('TSelf_Moment', bound='Moment')


def _default_breakdown(qid: 'cirq.Qid') -> Tuple[Any, Any]:
# Attempt to convert into a position on the complex plane.
Expand Down Expand Up @@ -373,9 +371,8 @@ def _decompose_(self) -> 'cirq.OP_TREE':
return self._operations

def transform_qubits(
self: TSelf_Moment,
qubit_map: Union[Dict['cirq.Qid', 'cirq.Qid'], Callable[['cirq.Qid'], 'cirq.Qid']],
) -> TSelf_Moment:
self, qubit_map: Union[Dict['cirq.Qid', 'cirq.Qid'], Callable[['cirq.Qid'], 'cirq.Qid']]
) -> Self:
"""Returns the same moment, but with different qubits.
Args:
Expand Down
20 changes: 9 additions & 11 deletions cirq-core/cirq/devices/grid_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from typing import Any, Dict, Iterable, List, Optional, Tuple, Set, TypeVar, TYPE_CHECKING, Union

import abc
import functools
from typing import Any, Dict, Iterable, List, Optional, Tuple, Set, TYPE_CHECKING, Union
from typing_extensions import Self

import numpy as np

Expand All @@ -24,8 +24,6 @@
if TYPE_CHECKING:
import cirq

TSelf = TypeVar('TSelf', bound='_BaseGridQid')


@functools.total_ordering
class _BaseGridQid(ops.Qid):
Expand Down Expand Up @@ -69,13 +67,13 @@ def neighbors(self, qids: Optional[Iterable[ops.Qid]] = None) -> Set['_BaseGridQ
return neighbors

@abc.abstractmethod
def _with_row_col(self: TSelf, row: int, col: int) -> TSelf:
def _with_row_col(self, row: int, col: int) -> Self:
"""Returns a qid with the same type but a different coordinate."""

def __complex__(self) -> complex:
return self.col + 1j * self.row

def __add__(self: TSelf, other: Union[Tuple[int, int], TSelf]) -> 'TSelf':
def __add__(self, other: Union[Tuple[int, int], Self]) -> Self:
if isinstance(other, _BaseGridQid):
if self.dimension != other.dimension:
raise TypeError(
Expand All @@ -94,7 +92,7 @@ def __add__(self: TSelf, other: Union[Tuple[int, int], TSelf]) -> 'TSelf':
)
return self._with_row_col(row=self.row + other[0], col=self.col + other[1])

def __sub__(self: TSelf, other: Union[Tuple[int, int], TSelf]) -> 'TSelf':
def __sub__(self, other: Union[Tuple[int, int], Self]) -> Self:
if isinstance(other, _BaseGridQid):
if self.dimension != other.dimension:
raise TypeError(
Expand All @@ -113,13 +111,13 @@ def __sub__(self: TSelf, other: Union[Tuple[int, int], TSelf]) -> 'TSelf':
)
return self._with_row_col(row=self.row - other[0], col=self.col - other[1])

def __radd__(self: TSelf, other: Tuple[int, int]) -> 'TSelf':
def __radd__(self, other: Tuple[int, int]) -> Self:
return self + other

def __rsub__(self: TSelf, other: Tuple[int, int]) -> 'TSelf':
def __rsub__(self, other: Tuple[int, int]) -> Self:
return -self + other

def __neg__(self: TSelf) -> 'TSelf':
def __neg__(self) -> Self:
return self._with_row_col(row=-self.row, col=-self.col)


Expand Down
20 changes: 9 additions & 11 deletions cirq-core/cirq/devices/line_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, TypeVar, TYPE_CHECKING, Union

import abc
import functools
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, TYPE_CHECKING, Union
from typing_extensions import Self

from cirq import ops, protocols

if TYPE_CHECKING:
import cirq

TSelf = TypeVar('TSelf', bound='_BaseLineQid')


@functools.total_ordering
class _BaseLineQid(ops.Qid):
Expand Down Expand Up @@ -66,10 +64,10 @@ def neighbors(self, qids: Optional[Iterable[ops.Qid]] = None) -> Set['_BaseLineQ
return neighbors

@abc.abstractmethod
def _with_x(self: TSelf, x: int) -> TSelf:
def _with_x(self, x: int) -> Self:
"""Returns a qubit with the same type but a different value of `x`."""

def __add__(self: TSelf, other: Union[int, TSelf]) -> TSelf:
def __add__(self, other: Union[int, Self]) -> Self:
if isinstance(other, _BaseLineQid):
if self.dimension != other.dimension:
raise TypeError(
Expand All @@ -81,7 +79,7 @@ def __add__(self: TSelf, other: Union[int, TSelf]) -> TSelf:
raise TypeError(f"Can only add ints and {type(self).__name__}. Instead was {other}")
return self._with_x(self.x + other)

def __sub__(self: TSelf, other: Union[int, TSelf]) -> TSelf:
def __sub__(self, other: Union[int, Self]) -> Self:
if isinstance(other, _BaseLineQid):
if self.dimension != other.dimension:
raise TypeError(
Expand All @@ -95,13 +93,13 @@ def __sub__(self: TSelf, other: Union[int, TSelf]) -> TSelf:
)
return self._with_x(self.x - other)

def __radd__(self: TSelf, other: int) -> TSelf:
def __radd__(self, other: int) -> Self:
return self + other

def __rsub__(self: TSelf, other: int) -> TSelf:
def __rsub__(self, other: int) -> Self:
return -self + other

def __neg__(self: TSelf) -> TSelf:
def __neg__(self) -> Self:
return self._with_x(-self.x)

def __complex__(self) -> complex:
Expand Down
10 changes: 4 additions & 6 deletions cirq-core/cirq/ops/arithmetic_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

import abc
import itertools
from typing import Union, Iterable, List, Sequence, cast, Tuple, TypeVar, TYPE_CHECKING
from typing import Union, Iterable, List, Sequence, cast, Tuple, TYPE_CHECKING
from typing_extensions import Self

import numpy as np

Expand All @@ -25,9 +26,6 @@
import cirq


TSelfGate = TypeVar('TSelfGate', bound='ArithmeticGate')


class ArithmeticGate(Gate, metaclass=abc.ABCMeta):
r"""A helper gate for implementing reversible classical arithmetic.
Expand Down Expand Up @@ -55,7 +53,7 @@ class ArithmeticGate(Gate, metaclass=abc.ABCMeta):
...
... def with_registers(
... self, *new_registers: 'Union[int, Sequence[int]]'
... ) -> 'TSelfGate':
... ) -> 'Add':
... return Add(*new_registers)
...
... def apply(self, *register_values: int) -> 'Union[int, Iterable[int]]':
Expand Down Expand Up @@ -105,7 +103,7 @@ def registers(self) -> Sequence[Union[int, Sequence[int]]]:
raise NotImplementedError()

@abc.abstractmethod
def with_registers(self: TSelfGate, *new_registers: Union[int, Sequence[int]]) -> TSelfGate:
def with_registers(self, *new_registers: Union[int, Sequence[int]]) -> Self:
"""Returns the same fate targeting different registers.
Args:
Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/ops/common_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def __init__(self, *, rads: value.TParamVal):
self._rads = rads
super().__init__(exponent=rads / _pi(rads), global_shift=-0.5)

def _with_exponent(self: 'Rx', exponent: value.TParamVal) -> 'Rx':
def _with_exponent(self, exponent: value.TParamVal) -> 'Rx':
return Rx(rads=exponent * _pi(exponent))

def _circuit_diagram_info_(
Expand Down Expand Up @@ -541,7 +541,7 @@ def __init__(self, *, rads: value.TParamVal):
self._rads = rads
super().__init__(exponent=rads / _pi(rads), global_shift=-0.5)

def _with_exponent(self: 'Ry', exponent: value.TParamVal) -> 'Ry':
def _with_exponent(self, exponent: value.TParamVal) -> 'Ry':
return Ry(rads=exponent * _pi(exponent))

def _circuit_diagram_info_(
Expand Down Expand Up @@ -891,7 +891,7 @@ def __init__(self, *, rads: value.TParamVal):
self._rads = rads
super().__init__(exponent=rads / _pi(rads), global_shift=-0.5)

def _with_exponent(self: 'Rz', exponent: value.TParamVal) -> 'Rz':
def _with_exponent(self, exponent: value.TParamVal) -> 'Rz':
return Rz(rads=exponent * _pi(exponent))

def _circuit_diagram_info_(
Expand Down
Loading

0 comments on commit 5c2bcda

Please sign in to comment.