Skip to content

Commit

Permalink
Add support for type aliases
Browse files Browse the repository at this point in the history
Summary:
Treat module-level assignments where the RHS is a type as equivalent to class declarations.

Type aliases stopped working when we decided not to store inferred types for
top-level variables, due to corner cases with `global` later mutating those
variables (see facebookincubator/cinder#116 for
discussion).

However, in this specific case, it should be safe to treat the aliases as
declared rather than inferred, since their values will be used by type
annotations before we run any functions that could modify them via `global`.

Reviewed By: alexmalyshev

Differential Revision: D69254419

fbshipit-source-id: d29a268ed15a035dbc1a3f4d1a2eb32280637e45
  • Loading branch information
Martin DeMello authored and facebook-github-bot committed Feb 7, 2025
1 parent fe9adb5 commit 0019b52
Show file tree
Hide file tree
Showing 8 changed files with 258 additions and 7 deletions.
10 changes: 10 additions & 0 deletions PythonLib/cinderx/compiler/static/declaration_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from __future__ import annotations

import ast

from ast import (
AnnAssign,
Assign,
Expand Down Expand Up @@ -60,6 +62,10 @@ def declare_variable(self, node: AnnAssign, module: ModuleTable) -> None:
def declare_variables(self, node: Assign, module: ModuleTable) -> None:
pass

# pyre-ignore[11]: Annotation `ast.TypeAlias` is not defined as a type
def declare_type_alias(self, node: ast.TypeAlias) -> None:
pass


TScopeTypes = Union[ModuleTable, Class, Function, NestedScope]

Expand Down Expand Up @@ -104,6 +110,10 @@ def visitAnnAssign(self, node: AnnAssign) -> None:
def visitAssign(self, node: Assign) -> None:
self.parent_scope().declare_variables(node, self.module)

# pyre-ignore[11]: Annotation `ast.TypeAlias` is not defined as a type
def visitTypeAlias(self, node: ast.TypeAlias) -> None:
self.parent_scope().declare_type_alias(node)

def visitClassDef(self, node: ClassDef) -> None:
parent_scope = self.parent_scope()
qualname = make_qualname(parent_scope.qualname, node.name)
Expand Down
69 changes: 62 additions & 7 deletions PythonLib/cinderx/compiler/static/module_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,16 +133,20 @@ def resolve_annotation(
# TODO until we support runtime checking of unions, we must for
# safety resolve union annotations to dynamic (except for
# optionals, which we can check at runtime)
if (
isinstance(klass, UnionType)
and klass is not self.type_env.union
and klass is not self.type_env.optional
and klass.opt_type is None
):
if self.is_unsupported_union_type(klass):
return None

return klass

def is_unsupported_union_type(self, klass: Value) -> bool:
"""Returns True if klass is an unsupported union type."""
return (
isinstance(klass, UnionType)
and klass is not self.type_env.union
and klass is not self.type_env.optional
and klass.opt_type is None
)

def visitSubscript(self, node: Subscript) -> Value | None:
target = self.resolve_annotation(node.value, is_declaration=True)
if target is None:
Expand Down Expand Up @@ -267,6 +271,7 @@ def __init__(
self.flags: set[ModuleFlag] = set()
self.decls: list[tuple[AST, str | None, Value | None]] = []
self.implicit_decl_names: set[str] = set()
self.type_alias_names: set[str] = set()
self.compile_non_static: set[AST] = set()
# {local-name: {(mod, qualname)}} for decl-time deps
self.decl_deps: dict[str, set[tuple[str, str]]] = {}
Expand Down Expand Up @@ -353,6 +358,52 @@ def error_context(self, node: AST | None) -> ContextManager[None]:
return nullcontext()
return self.compiler.error_sink.error_context(self.filename, node)

def maybe_set_type_alias(
self,
# pyre-ignore[11]: Annotation `ast.TypeAlias` is not defined as a type
node: ast.Assign | ast.TypeAlias,
name: str,
*,
require_type: bool = False,
) -> None:
"""
Check if we are assigning a Class or Union value to a variable at
module scope, and if so, store it as a type alias.
"""
try:
value = self.resolve_type(node.value, name)
except SyntaxError:
# We should not crash here if we raise an error when analysing a
# top-level assignment.
# TODO: The SyntaxError is raised when we call ast.parse in
# AnnotationVisitor.visitConstant; we need to make that more
# robust.
value = None

if value and (
value.klass is self.compiler.type_env.type or isinstance(value, UnionType)
):
if self.ann_visitor.is_unsupported_union_type(value):
# While union types are currently unsupported by the static
# compiler, they are syntactically valid and should therefore
# not raise an error even if require_type is set.
self.implicit_decl_names.add(name)
else:
# Treat this similarly to a class declaration
self.decls.append((node, name, value))
self._children[name] = value
self.type_alias_names.add(name)
else:
# Treat the type as dynamic if it is an assignment,
# raise an error if it is an explicit type alias.
if require_type:
raise TypedSyntaxError(f"RHS of type alias {name} is not a type")
self.implicit_decl_names.add(name)

# pyre-ignore[11]: Annotation `ast.TypeAlias` is not defined as a type
def declare_type_alias(self, node: ast.TypeAlias) -> None:
self.maybe_set_type_alias(node, node.name.id, require_type=True)

def declare_class(self, node: ClassDef, klass: Class) -> None:
if self.first_pass_done:
raise ModuleTableException(
Expand Down Expand Up @@ -456,7 +507,7 @@ def finish_bind(self) -> None:

def validate_overrides(self) -> None:
for _node, name, _value in self.decls:
if name is None:
if name is None or name in self.type_alias_names:
continue

child = self._children.get(name, None)
Expand Down Expand Up @@ -552,6 +603,10 @@ def declare_variable(self, node: ast.AnnAssign, module: ModuleTable) -> None:

def declare_variables(self, node: ast.Assign, module: ModuleTable) -> None:
targets = node.targets
if len(targets) == 1 and isinstance(targets[0], ast.Name):
# pyre-ignore[16]: `ast.expr` has no attribute `id`
return self.maybe_set_type_alias(node, targets[0].id)

for target in targets:
if isinstance(target, ast.Name):
self.implicit_decl_names.add(target.id)
Expand Down
8 changes: 8 additions & 0 deletions PythonLib/cinderx/compiler/static/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1605,6 +1605,10 @@ def declare_variable(self, node: AnnAssign, module: ModuleTable) -> None:
def declare_variables(self, node: Assign, module: ModuleTable) -> None:
pass

# pyre-ignore[11]: Annotation `ast.TypeAlias` is not defined as a type.
def declare_type_alias(self, node: ast.TypeAlias) -> None:
pass

def reflected_method_types(self, type_env: TypeEnvironment) -> dict[str, Class]:
return {}

Expand Down Expand Up @@ -3887,6 +3891,10 @@ def declare_function(self, func: Function) -> None:
def declare_variables(self, node: Assign, module: ModuleTable) -> None:
pass

# pyre-ignore[11]: Annotation `ast.TypeAlias` is not defined as a type.
def declare_type_alias(self, node: ast.TypeAlias) -> None:
pass

def bind_call(
self, node: ast.Call, visitor: TypeBinder, type_ctx: Class | None
) -> NarrowingEffect:
Expand Down
174 changes: 174 additions & 0 deletions PythonLib/test_cinderx/test_compiler/test_static/type_alias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import sys
import unittest
from unittest import skipIf

from re import escape

from cinderx.compiler.static.types import TypedSyntaxError

from .common import StaticTestBase


@skipIf(sys.version_info < (3, 12), "New in 3.12")
class TypeAliasTests(StaticTestBase):

def test_assign(self):
codestr = """
class B: pass
A = B
def f(x: A):
pass
"""
with self.in_module(codestr) as mod:
mod.f(mod.B())
with self.assertRaises(TypeError):
mod.f("hello")

def test_alias(self):
codestr = """
type A = int
def f(x: A):
pass
"""
with self.in_module(codestr) as mod:
mod.f(42)
with self.assertRaises(TypeError):
mod.f("hello")

def test_optional_assign(self):
codestr = """
A = int | None
def f(x: A):
pass
"""
with self.in_module(codestr) as mod:
mod.f(42)
mod.f(None)
with self.assertRaises(TypeError):
mod.f("hello")

def test_optional_alias(self):
codestr = """
type A = int | None
def f(x: A):
pass
"""
with self.in_module(codestr) as mod:
mod.f(42)
mod.f(None)
with self.assertRaises(TypeError):
mod.f("hello")

def test_transitive_alias(self):
codestr = """
type A = int | None
type B = A
def f(x: B):
pass
"""
with self.in_module(codestr) as mod:
mod.f(42)
mod.f(None)
with self.assertRaises(TypeError):
mod.f("hello")

def test_transitive_assign(self):
codestr = """
A = int | None
B = A
def f(x: B):
pass
"""
with self.in_module(codestr) as mod:
mod.f(42)
mod.f(None)
with self.assertRaises(TypeError):
mod.f("hello")

def test_transitive_alias_and_assign(self):
codestr = """
A = int | None
type B = A
def f(x: B):
pass
"""
with self.in_module(codestr) as mod:
mod.f(42)
mod.f(None)
with self.assertRaises(TypeError):
mod.f("hello")

def test_transitive_alias_in_optional(self):
codestr = """
A = int
type B = A | None
def f(x: B):
pass
"""
with self.in_module(codestr) as mod:
mod.f(42)
mod.f(None)
with self.assertRaises(TypeError):
mod.f("hello")

def test_alias_check_in_module(self):
codestr = """
class B: pass
A = B
def f(x: A):
pass
f("hello")
"""
with self.assertRaises(TypedSyntaxError):
with self.in_module(codestr):
pass

def test_overload(self):
codestr = """
A = int
class B:
def f(self, x: A):
pass
class D(B):
def f(self, x: int):
super().f(x)
D().f(10)
"""
with self.in_module(codestr):
pass

def test_type_alias_error(self):
codestr = """
type A = 42
"""
with self.assertRaisesRegex(TypedSyntaxError, "A is not a type"):
with self.in_module(codestr):
pass

def test_assign_to_constructor(self):
# Regression test for a crash when calling resolve_type on the rhs of B
codestr = """
B = str("<unknown>")
"""
with self.in_module(codestr):
pass


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,7 @@
"test_cinderx.test_compiler.test_static.super",
"test_cinderx.test_compiler.test_static.sys_hexversion",
"test_cinderx.test_compiler.test_static.top_level",
"test_cinderx.test_compiler.test_static.type_alias",
"test_cinderx.test_compiler.test_static.type_params",
"test_cinderx.test_compiler.test_static.typed_dict",
"test_cinderx.test_compiler.test_static.union",
Expand Down
1 change: 1 addition & 0 deletions TestScripts/3.12-python-expected-tests-dev-nosan.json
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@
"test_cinderx.test_compiler.test_static.super",
"test_cinderx.test_compiler.test_static.sys_hexversion",
"test_cinderx.test_compiler.test_static.top_level",
"test_cinderx.test_compiler.test_static.type_alias",
"test_cinderx.test_compiler.test_static.type_params",
"test_cinderx.test_compiler.test_static.typed_dict",
"test_cinderx.test_compiler.test_static.union",
Expand Down
1 change: 1 addition & 0 deletions TestScripts/3.12-python-expected-tests-dev.json
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@
"test_cinderx.test_compiler.test_static.super",
"test_cinderx.test_compiler.test_static.sys_hexversion",
"test_cinderx.test_compiler.test_static.top_level",
"test_cinderx.test_compiler.test_static.type_alias",
"test_cinderx.test_compiler.test_static.type_params",
"test_cinderx.test_compiler.test_static.typed_dict",
"test_cinderx.test_compiler.test_static.union",
Expand Down
1 change: 1 addition & 0 deletions TestScripts/3.12-python-expected-tests-opt.json
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@
"test_cinderx.test_compiler.test_static.super",
"test_cinderx.test_compiler.test_static.sys_hexversion",
"test_cinderx.test_compiler.test_static.top_level",
"test_cinderx.test_compiler.test_static.type_alias",
"test_cinderx.test_compiler.test_static.type_params",
"test_cinderx.test_compiler.test_static.typed_dict",
"test_cinderx.test_compiler.test_static.union",
Expand Down

0 comments on commit 0019b52

Please sign in to comment.