Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[red-knot] Disallow more invalid type expressions #16427

Merged
merged 1 commit into from
Feb 28, 2025

Conversation

AlexWaygood
Copy link
Member

Summary

As part of working on #16302, I'm auditing all uses of Type::to_instance(). I noticed that one use in Type::in_type_expression has a pretty clear bug, in that it treats Type::SubclassOf() the same as Type::ClassLiteral. That's incorrect: an object can only have a type[] type if it's a dynamic variable of some kind, and we should reject such variables in type expressions. The practical effect of this is that on main, we have this incorrect behaviour, where we should be reject y's type annotation altogether, but instead we accept it:

def _(x: type[int]):
    def f(y: x):
        reveal_type(y)  # revealed: int

This PR fixes the bug by making Type::in_type_expression() exhaustive over all Type variants and rejecting most of them when they appear in type expressions. (The method is not yet exhaustive over all KnownInstanceType variants; some of these would require special error messages, and it seemed out of the scope of this bugfix PR.)

Test Plan

Added mdtests.

Copy link
Contributor

@carljm carljm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

@AlexWaygood AlexWaygood force-pushed the alex/reject-bad-type-exprs branch 2 times, most recently from 1f5f98b to 2809aa4 Compare February 28, 2025 09:59
@AlexWaygood AlexWaygood force-pushed the alex/reject-bad-type-exprs branch from 2809aa4 to e1bb7b1 Compare February 28, 2025 10:00
@AlexWaygood AlexWaygood enabled auto-merge (squash) February 28, 2025 10:01
@AlexWaygood AlexWaygood merged commit 09d0b22 into main Feb 28, 2025
20 checks passed
@AlexWaygood AlexWaygood deleted the alex/reject-bad-type-exprs branch February 28, 2025 10:04
dcreager added a commit that referenced this pull request Feb 28, 2025
* main:
  [red-knot] Switch to a handwritten parser for mdtest error assertions (#16422)
  [red-knot] Disallow more invalid type expressions (#16427)
  Bump version to Ruff 0.9.9 (#16434)
  Check `LinterSettings::preview` for version-related syntax errors (#16429)
  Avoid caching files with unsupported syntax errors (#16425)
  Prioritize "bug" label for changelog sections (#16433)
  [`flake8-copyright`] Add links to applicable options (`CPY001`) (#16421)
  Fix string-length limit in documentation for PYI054 (#16432)
  Show version-related syntax errors in the playground (#16419)
  Allow passing `ParseOptions` to inline tests (#16357)
  Bump version to 0.9.8 (#16414)
  [red-knot] Ignore surrounding whitespace when looking for `<!-- snapshot-diagnostics -->` directives in mdtests (#16380)
  Notify users for invalid client settings (#16361)
  Avoid indexing the project if `configurationPreference` is `editorOnly` (#16381)
dcreager added a commit that referenced this pull request Feb 28, 2025
* dcreager/alist-type:
  Remove unneeded shear override
  Update property test CI
  Move alist into red_knot_python_semantic
  [red-knot] Switch to a handwritten parser for mdtest error assertions (#16422)
  [red-knot] Disallow more invalid type expressions (#16427)
  Bump version to Ruff 0.9.9 (#16434)
  Check `LinterSettings::preview` for version-related syntax errors (#16429)
  Avoid caching files with unsupported syntax errors (#16425)
  Prioritize "bug" label for changelog sections (#16433)
  [`flake8-copyright`] Add links to applicable options (`CPY001`) (#16421)
  Fix string-length limit in documentation for PYI054 (#16432)
  Show version-related syntax errors in the playground (#16419)
  Allow passing `ParseOptions` to inline tests (#16357)
  Bump version to 0.9.8 (#16414)
  [red-knot] Ignore surrounding whitespace when looking for `<!-- snapshot-diagnostics -->` directives in mdtests (#16380)
  Notify users for invalid client settings (#16361)
  Avoid indexing the project if `configurationPreference` is `editorOnly` (#16381)
AlexWaygood added a commit that referenced this pull request Mar 11, 2025
)

## Summary

This PR fixes #16302.

The PR reworks `Type::to_instance()` to return `Option<Type>` rather
than `Type`. This reflects more accurately the fact that some variants
cannot be "turned into an instance", since they _already_ represent
instances of some kind. On `main`, we silently fallback to `Unknown` for
these variants, but this implicit behaviour can be somewhat surprising
and lead to unexpected bugs.

Returning `Option<Type>` rather than `Type` means that each callsite has
to account for the possibility that the type might already represent an
instance, and decide what to do about it.
In general, I think this increases the robustness of the code. Working
on this PR revealed two latent bugs in the code:
- One which has already been fixed by
#16427
- One which is fixed as part of #16608

I added special handling to `KnownClass::to_instance()`: If we fail to find one of these classes and the `test` feature is
_not_ enabled, we log a warning to the terminal saying that we failed to
find the class in typeshed and that we will be falling back to
`Type::Unknown`. A cache is maintained so that we record all classes
that we have already logged a warning for; we only log a warning for
failing to lookup a `KnownClass` if we know that it's the first time
we're looking it up.

## Test Plan

- All existing tests pass
- I ran the property tests via `QUICKCHECK_TESTS=1000000 cargo test
--release -p red_knot_python_semantic -- --ignored
types::property_tests::stable`

I also manually checked that warnings are appropriately printed to the
terminal when `KnownClass::to_instance()` falls back to `Unknown` and
the `test` feature is not enabled. To do this, I applied this diff to
the PR branch:

<details>
<summary>Patch deleting `int` and `str` from buitins</summary>

```diff
diff --git a/crates/red_knot_vendored/vendor/typeshed/stdlib/builtins.pyi b/crates/red_knot_vendored/vendor/typeshed/stdlib/builtins.pyi
index 0a6dc57b0..86636a05b 100644
--- a/crates/red_knot_vendored/vendor/typeshed/stdlib/builtins.pyi
+++ b/crates/red_knot_vendored/vendor/typeshed/stdlib/builtins.pyi
@@ -228,111 +228,6 @@ _PositiveInteger: TypeAlias = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
 _NegativeInteger: TypeAlias = Literal[-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20]
 _LiteralInteger = _PositiveInteger | _NegativeInteger | Literal[0]  # noqa: Y026  # TODO: Use TypeAlias once mypy bugs are fixed
 
-class int:
-    @overload
-    def __new__(cls, x: ConvertibleToInt = ..., /) -> Self: ...
-    @overload
-    def __new__(cls, x: str | bytes | bytearray, /, base: SupportsIndex) -> Self: ...
-    def as_integer_ratio(self) -> tuple[int, Literal[1]]: ...
-    @Property
-    def real(self) -> int: ...
-    @Property
-    def imag(self) -> Literal[0]: ...
-    @Property
-    def numerator(self) -> int: ...
-    @Property
-    def denominator(self) -> Literal[1]: ...
-    def conjugate(self) -> int: ...
-    def bit_length(self) -> int: ...
-    if sys.version_info >= (3, 10):
-        def bit_count(self) -> int: ...
-
-    if sys.version_info >= (3, 11):
-        def to_bytes(
-            self, length: SupportsIndex = 1, byteorder: Literal["little", "big"] = "big", *, signed: bool = False
-        ) -> bytes: ...
-        @classmethod
-        def from_bytes(
-            cls,
-            bytes: Iterable[SupportsIndex] | SupportsBytes | ReadableBuffer,
-            byteorder: Literal["little", "big"] = "big",
-            *,
-            signed: bool = False,
-        ) -> Self: ...
-    else:
-        def to_bytes(self, length: SupportsIndex, byteorder: Literal["little", "big"], *, signed: bool = False) -> bytes: ...
-        @classmethod
-        def from_bytes(
-            cls,
-            bytes: Iterable[SupportsIndex] | SupportsBytes | ReadableBuffer,
-            byteorder: Literal["little", "big"],
-            *,
-            signed: bool = False,
-        ) -> Self: ...
-
-    if sys.version_info >= (3, 12):
-        def is_integer(self) -> Literal[True]: ...
-
-    def __add__(self, value: int, /) -> int: ...
-    def __sub__(self, value: int, /) -> int: ...
-    def __mul__(self, value: int, /) -> int: ...
-    def __floordiv__(self, value: int, /) -> int: ...
-    def __truediv__(self, value: int, /) -> float: ...
-    def __mod__(self, value: int, /) -> int: ...
-    def __divmod__(self, value: int, /) -> tuple[int, int]: ...
-    def __radd__(self, value: int, /) -> int: ...
-    def __rsub__(self, value: int, /) -> int: ...
-    def __rmul__(self, value: int, /) -> int: ...
-    def __rfloordiv__(self, value: int, /) -> int: ...
-    def __rtruediv__(self, value: int, /) -> float: ...
-    def __rmod__(self, value: int, /) -> int: ...
-    def __rdivmod__(self, value: int, /) -> tuple[int, int]: ...
-    @overload
-    def __pow__(self, x: Literal[0], /) -> Literal[1]: ...
-    @overload
-    def __pow__(self, value: Literal[0], mod: None, /) -> Literal[1]: ...
-    @overload
-    def __pow__(self, value: _PositiveInteger, mod: None = None, /) -> int: ...
-    @overload
-    def __pow__(self, value: _NegativeInteger, mod: None = None, /) -> float: ...
-    # positive __value -> int; negative __value -> float
-    # return type must be Any as `int | float` causes too many false-positive errors
-    @overload
-    def __pow__(self, value: int, mod: None = None, /) -> Any: ...
-    @overload
-    def __pow__(self, value: int, mod: int, /) -> int: ...
-    def __rpow__(self, value: int, mod: int | None = None, /) -> Any: ...
-    def __and__(self, value: int, /) -> int: ...
-    def __or__(self, value: int, /) -> int: ...
-    def __xor__(self, value: int, /) -> int: ...
-    def __lshift__(self, value: int, /) -> int: ...
-    def __rshift__(self, value: int, /) -> int: ...
-    def __rand__(self, value: int, /) -> int: ...
-    def __ror__(self, value: int, /) -> int: ...
-    def __rxor__(self, value: int, /) -> int: ...
-    def __rlshift__(self, value: int, /) -> int: ...
-    def __rrshift__(self, value: int, /) -> int: ...
-    def __neg__(self) -> int: ...
-    def __pos__(self) -> int: ...
-    def __invert__(self) -> int: ...
-    def __trunc__(self) -> int: ...
-    def __ceil__(self) -> int: ...
-    def __floor__(self) -> int: ...
-    def __round__(self, ndigits: SupportsIndex = ..., /) -> int: ...
-    def __getnewargs__(self) -> tuple[int]: ...
-    def __eq__(self, value: object, /) -> bool: ...
-    def __ne__(self, value: object, /) -> bool: ...
-    def __lt__(self, value: int, /) -> bool: ...
-    def __le__(self, value: int, /) -> bool: ...
-    def __gt__(self, value: int, /) -> bool: ...
-    def __ge__(self, value: int, /) -> bool: ...
-    def __float__(self) -> float: ...
-    def __int__(self) -> int: ...
-    def __abs__(self) -> int: ...
-    def __hash__(self) -> int: ...
-    def __bool__(self) -> bool: ...
-    def __index__(self) -> int: ...
-
 class float:
     def __new__(cls, x: ConvertibleToFloat = ..., /) -> Self: ...
     def as_integer_ratio(self) -> tuple[int, int]: ...
@@ -437,190 +332,6 @@ class _FormatMapMapping(Protocol):
 class _TranslateTable(Protocol):
     def __getitem__(self, key: int, /) -> str | int | None: ...
 
-class str(Sequence[str]):
-    @overload
-    def __new__(cls, object: object = ...) -> Self: ...
-    @overload
-    def __new__(cls, object: ReadableBuffer, encoding: str = ..., errors: str = ...) -> Self: ...
-    @overload
-    def capitalize(self: LiteralString) -> LiteralString: ...
-    @overload
-    def capitalize(self) -> str: ...  # type: ignore[misc]
-    @overload
-    def casefold(self: LiteralString) -> LiteralString: ...
-    @overload
-    def casefold(self) -> str: ...  # type: ignore[misc]
-    @overload
-    def center(self: LiteralString, width: SupportsIndex, fillchar: LiteralString = " ", /) -> LiteralString: ...
-    @overload
-    def center(self, width: SupportsIndex, fillchar: str = " ", /) -> str: ...  # type: ignore[misc]
-    def count(self, sub: str, start: SupportsIndex | None = ..., end: SupportsIndex | None = ..., /) -> int: ...
-    def encode(self, encoding: str = "utf-8", errors: str = "strict") -> bytes: ...
-    def endswith(
-        self, suffix: str | tuple[str, ...], start: SupportsIndex | None = ..., end: SupportsIndex | None = ..., /
-    ) -> bool: ...
-    @overload
-    def expandtabs(self: LiteralString, tabsize: SupportsIndex = 8) -> LiteralString: ...
-    @overload
-    def expandtabs(self, tabsize: SupportsIndex = 8) -> str: ...  # type: ignore[misc]
-    def find(self, sub: str, start: SupportsIndex | None = ..., end: SupportsIndex | None = ..., /) -> int: ...
-    @overload
-    def format(self: LiteralString, *args: LiteralString, **kwargs: LiteralString) -> LiteralString: ...
-    @overload
-    def format(self, *args: object, **kwargs: object) -> str: ...
-    def format_map(self, mapping: _FormatMapMapping, /) -> str: ...
-    def index(self, sub: str, start: SupportsIndex | None = ..., end: SupportsIndex | None = ..., /) -> int: ...
-    def isalnum(self) -> bool: ...
-    def isalpha(self) -> bool: ...
-    def isascii(self) -> bool: ...
-    def isdecimal(self) -> bool: ...
-    def isdigit(self) -> bool: ...
-    def isidentifier(self) -> bool: ...
-    def islower(self) -> bool: ...
-    def isnumeric(self) -> bool: ...
-    def isprintable(self) -> bool: ...
-    def isspace(self) -> bool: ...
-    def istitle(self) -> bool: ...
-    def isupper(self) -> bool: ...
-    @overload
-    def join(self: LiteralString, iterable: Iterable[LiteralString], /) -> LiteralString: ...
-    @overload
-    def join(self, iterable: Iterable[str], /) -> str: ...  # type: ignore[misc]
-    @overload
-    def ljust(self: LiteralString, width: SupportsIndex, fillchar: LiteralString = " ", /) -> LiteralString: ...
-    @overload
-    def ljust(self, width: SupportsIndex, fillchar: str = " ", /) -> str: ...  # type: ignore[misc]
-    @overload
-    def lower(self: LiteralString) -> LiteralString: ...
-    @overload
-    def lower(self) -> str: ...  # type: ignore[misc]
-    @overload
-    def lstrip(self: LiteralString, chars: LiteralString | None = None, /) -> LiteralString: ...
-    @overload
-    def lstrip(self, chars: str | None = None, /) -> str: ...  # type: ignore[misc]
-    @overload
-    def partition(self: LiteralString, sep: LiteralString, /) -> tuple[LiteralString, LiteralString, LiteralString]: ...
-    @overload
-    def partition(self, sep: str, /) -> tuple[str, str, str]: ...  # type: ignore[misc]
-    if sys.version_info >= (3, 13):
-        @overload
-        def replace(
-            self: LiteralString, old: LiteralString, new: LiteralString, /, count: SupportsIndex = -1
-        ) -> LiteralString: ...
-        @overload
-        def replace(self, old: str, new: str, /, count: SupportsIndex = -1) -> str: ...  # type: ignore[misc]
-    else:
-        @overload
-        def replace(
-            self: LiteralString, old: LiteralString, new: LiteralString, count: SupportsIndex = -1, /
-        ) -> LiteralString: ...
-        @overload
-        def replace(self, old: str, new: str, count: SupportsIndex = -1, /) -> str: ...  # type: ignore[misc]
-    if sys.version_info >= (3, 9):
-        @overload
-        def removeprefix(self: LiteralString, prefix: LiteralString, /) -> LiteralString: ...
-        @overload
-        def removeprefix(self, prefix: str, /) -> str: ...  # type: ignore[misc]
-        @overload
-        def removesuffix(self: LiteralString, suffix: LiteralString, /) -> LiteralString: ...
-        @overload
-        def removesuffix(self, suffix: str, /) -> str: ...  # type: ignore[misc]
-
-    def rfind(self, sub: str, start: SupportsIndex | None = ..., end: SupportsIndex | None = ..., /) -> int: ...
-    def rindex(self, sub: str, start: SupportsIndex | None = ..., end: SupportsIndex | None = ..., /) -> int: ...
-    @overload
-    def rjust(self: LiteralString, width: SupportsIndex, fillchar: LiteralString = " ", /) -> LiteralString: ...
-    @overload
-    def rjust(self, width: SupportsIndex, fillchar: str = " ", /) -> str: ...  # type: ignore[misc]
-    @overload
-    def rpartition(self: LiteralString, sep: LiteralString, /) -> tuple[LiteralString, LiteralString, LiteralString]: ...
-    @overload
-    def rpartition(self, sep: str, /) -> tuple[str, str, str]: ...  # type: ignore[misc]
-    @overload
-    def rsplit(self: LiteralString, sep: LiteralString | None = None, maxsplit: SupportsIndex = -1) -> list[LiteralString]: ...
-    @overload
-    def rsplit(self, sep: str | None = None, maxsplit: SupportsIndex = -1) -> list[str]: ...  # type: ignore[misc]
-    @overload
-    def rstrip(self: LiteralString, chars: LiteralString | None = None, /) -> LiteralString: ...
-    @overload
-    def rstrip(self, chars: str | None = None, /) -> str: ...  # type: ignore[misc]
-    @overload
-    def split(self: LiteralString, sep: LiteralString | None = None, maxsplit: SupportsIndex = -1) -> list[LiteralString]: ...
-    @overload
-    def split(self, sep: str | None = None, maxsplit: SupportsIndex = -1) -> list[str]: ...  # type: ignore[misc]
-    @overload
-    def splitlines(self: LiteralString, keepends: bool = False) -> list[LiteralString]: ...
-    @overload
-    def splitlines(self, keepends: bool = False) -> list[str]: ...  # type: ignore[misc]
-    def startswith(
-        self, prefix: str | tuple[str, ...], start: SupportsIndex | None = ..., end: SupportsIndex | None = ..., /
-    ) -> bool: ...
-    @overload
-    def strip(self: LiteralString, chars: LiteralString | None = None, /) -> LiteralString: ...
-    @overload
-    def strip(self, chars: str | None = None, /) -> str: ...  # type: ignore[misc]
-    @overload
-    def swapcase(self: LiteralString) -> LiteralString: ...
-    @overload
-    def swapcase(self) -> str: ...  # type: ignore[misc]
-    @overload
-    def title(self: LiteralString) -> LiteralString: ...
-    @overload
-    def title(self) -> str: ...  # type: ignore[misc]
-    def translate(self, table: _TranslateTable, /) -> str: ...
-    @overload
-    def upper(self: LiteralString) -> LiteralString: ...
-    @overload
-    def upper(self) -> str: ...  # type: ignore[misc]
-    @overload
-    def zfill(self: LiteralString, width: SupportsIndex, /) -> LiteralString: ...
-    @overload
-    def zfill(self, width: SupportsIndex, /) -> str: ...  # type: ignore[misc]
-    @staticmethod
-    @overload
-    def maketrans(x: dict[int, _T] | dict[str, _T] | dict[str | int, _T], /) -> dict[int, _T]: ...
-    @staticmethod
-    @overload
-    def maketrans(x: str, y: str, /) -> dict[int, int]: ...
-    @staticmethod
-    @overload
-    def maketrans(x: str, y: str, z: str, /) -> dict[int, int | None]: ...
-    @overload
-    def __add__(self: LiteralString, value: LiteralString, /) -> LiteralString: ...
-    @overload
-    def __add__(self, value: str, /) -> str: ...  # type: ignore[misc]
-    # Incompatible with Sequence.__contains__
-    def __contains__(self, key: str, /) -> bool: ...  # type: ignore[override]
-    def __eq__(self, value: object, /) -> bool: ...
-    def __ge__(self, value: str, /) -> bool: ...
-    @overload
-    def __getitem__(self: LiteralString, key: SupportsIndex | slice, /) -> LiteralString: ...
-    @overload
-    def __getitem__(self, key: SupportsIndex | slice, /) -> str: ...  # type: ignore[misc]
-    def __gt__(self, value: str, /) -> bool: ...
-    def __hash__(self) -> int: ...
-    @overload
-    def __iter__(self: LiteralString) -> Iterator[LiteralString]: ...
-    @overload
-    def __iter__(self) -> Iterator[str]: ...  # type: ignore[misc]
-    def __le__(self, value: str, /) -> bool: ...
-    def __len__(self) -> int: ...
-    def __lt__(self, value: str, /) -> bool: ...
-    @overload
-    def __mod__(self: LiteralString, value: LiteralString | tuple[LiteralString, ...], /) -> LiteralString: ...
-    @overload
-    def __mod__(self, value: Any, /) -> str: ...
-    @overload
-    def __mul__(self: LiteralString, value: SupportsIndex, /) -> LiteralString: ...
-    @overload
-    def __mul__(self, value: SupportsIndex, /) -> str: ...  # type: ignore[misc]
-    def __ne__(self, value: object, /) -> bool: ...
-    @overload
-    def __rmul__(self: LiteralString, value: SupportsIndex, /) -> LiteralString: ...
-    @overload
-    def __rmul__(self, value: SupportsIndex, /) -> str: ...  # type: ignore[misc]
-    def __getnewargs__(self) -> tuple[str]: ...
-
 class bytes(Sequence[int]):
```

</details>

And then ran red-knot on my
[typeshed-stats](https://github.com/AlexWaygood/typeshed-stats) project
using the command

```
cargo run -p red_knot -- check --project ../typeshed-stats --python-version="3.12" --verbose
```

I observed that the following logs were printed to the terminal, but
that each warning was only printed once (the desired behaviour):

```
INFO Python version: Python 3.12, platform: all
INFO Indexed 15 file(s)
INFO Could not find class `builtins.int` in typeshed on Python 3.12. Falling back to `Unknown` for the symbol instead.
INFO Could not find class `builtins.str` in typeshed on Python 3.12. Falling back to `Unknown` for the symbol instead.
```
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
red-knot Multi-file analysis & type inference
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants