From 41d70711798780c04d27fd1f581c962c633ff67d Mon Sep 17 00:00:00 2001 From: Ali Hamdan Date: Sat, 18 Nov 2023 21:43:10 +0100 Subject: [PATCH] stubgen: use PEP 604 unions everywhere Fixes #12920 --- mypy/stubgen.py | 22 ++++++++++------------ mypy/stubutil.py | 25 +++++++++++++++++-------- test-data/unit/stubgen.test | 30 ++++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 20 deletions(-) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 837cd723c4108..2e1fcaf5b13f2 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -22,7 +22,7 @@ => Generate out/urllib/parse.pyi. $ stubgen -p urllib - => Generate stubs for whole urlib package (recursively). + => Generate stubs for whole urllib package (recursively). For C modules, you can get more precise function signatures by parsing .rst (Sphinx) documentation for extra information. For this, use the --doc-dir option: @@ -305,6 +305,13 @@ def visit_str_expr(self, node: StrExpr) -> str: return repr(node.value) def visit_index_expr(self, node: IndexExpr) -> str: + base_fullname = self.stubgen.get_fullname(node.base) + if base_fullname == "typing.Union": + if isinstance(node.index, TupleExpr): + return " | ".join([item.accept(self) for item in node.index.items]) + return node.index.accept(self) + if base_fullname == "typing.Optional": + return f"{node.index.accept(self)} | None" base = node.base.accept(self) index = node.index.accept(self) if len(index) > 2 and index.startswith("(") and index.endswith(")"): @@ -675,7 +682,7 @@ def process_decorator(self, o: Decorator) -> None: self.add_decorator(qualname, require_name=False) def get_fullname(self, expr: Expression) -> str: - """Return the full name resolving imports and import aliases.""" + """Return the expression's full name.""" if ( self.analyzed and isinstance(expr, (NameExpr, MemberExpr)) @@ -684,16 +691,7 @@ def get_fullname(self, expr: Expression) -> str: ): return expr.fullname name = get_qualified_name(expr) - if "." not in name: - real_module = self.import_tracker.module_for.get(name) - real_short = self.import_tracker.reverse_alias.get(name, name) - if real_module is None and real_short not in self.defined_names: - real_module = "builtins" # not imported and not defined, must be a builtin - else: - name_module, real_short = name.split(".", 1) - real_module = self.import_tracker.reverse_alias.get(name_module, name_module) - resolved_name = real_short if real_module is None else f"{real_module}.{real_short}" - return resolved_name + return self.resolve_name(name) def visit_class_def(self, o: ClassDef) -> None: self._current_class = o diff --git a/mypy/stubutil.py b/mypy/stubutil.py index 5ec2400871453..ed34759b5732b 100644 --- a/mypy/stubutil.py +++ b/mypy/stubutil.py @@ -226,6 +226,11 @@ def visit_any(self, t: AnyType) -> str: def visit_unbound_type(self, t: UnboundType) -> str: s = t.name + fullname = self.stubgen.resolve_name(s) + if fullname == "typing.Union": + return " | ".join([item.accept(self) for item in t.args]) + if fullname == "typing.Optional": + return f"{t.args[0].accept(self)} | None" if self.known_modules is not None and "." in s: # see if this object is from any of the modules that we're currently processing. # reverse sort so that subpackages come before parents: e.g. "foo.bar" before "foo". @@ -580,14 +585,18 @@ def __init__( def get_sig_generators(self) -> list[SignatureGenerator]: return [] - def refers_to_fullname(self, name: str, fullname: str | tuple[str, ...]) -> bool: - """Return True if the variable name identifies the same object as the given fullname(s).""" - if isinstance(fullname, tuple): - return any(self.refers_to_fullname(name, fname) for fname in fullname) - module, short = fullname.rsplit(".", 1) - return self.import_tracker.module_for.get(name) == module and ( - name == short or self.import_tracker.reverse_alias.get(name) == short - ) + def resolve_name(self, name: str) -> str: + """Return the full name resolving imports and import aliases.""" + if "." not in name: + real_module = self.import_tracker.module_for.get(name) + real_short = self.import_tracker.reverse_alias.get(name, name) + if real_module is None and real_short not in self.defined_names: + real_module = "builtins" # not imported and not defined, must be a builtin + else: + name_module, real_short = name.split(".", 1) + real_module = self.import_tracker.reverse_alias.get(name_module, name_module) + resolved_name = real_short if real_module is None else f"{real_module}.{real_short}" + return resolved_name def add_name(self, fullname: str, require: bool = True) -> str: """Add a name to be imported and return the name reference. diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 2a43ce16383de..c6216ec96999a 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -4090,3 +4090,33 @@ from dataclasses import dataclass class X(missing.Base): a: int def __init__(self, *selfa_, a, **selfa__) -> None: ... + +[case testAlwaysUsePEP604Union] +import typing +import typing as t +from typing import Optional, Union, Optional as O, Union as U +import x + +union = Union[int, str] +bad_union = Union[int] +nested_union = Optional[Union[int, str]] +not_union = x.Union[int, str] +u = U[int, str] +o = O[int] + +def f1(a: Union["int", Optional[tuple[int, t.Optional[int]]]]) -> int: ... +def f2(a: typing.Union[int | x.Union[int, int], O[float]]) -> int: ... + +[out] +import x +from _typeshed import Incomplete + +union = int | str +bad_union = int +nested_union = int | str | None +not_union: Incomplete +u = int | str +o = int | None + +def f1(a: int | tuple[int, int | None] | None) -> int: ... +def f2(a: int | x.Union[int, int] | float | None) -> int: ...