Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

style[next]: more strict typing #1494

Merged
merged 9 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -203,13 +203,16 @@ warn_unused_ignores = false

[[tool.mypy.overrides]]
# # TODO: this should be changed to true after a transition period
disallow_incomplete_defs = false
disallow_incomplete_defs = true
module = 'gt4py.next.*'

[[tool.mypy.overrides]]
# TODO: temporarily to propagate it to all of next
disallow_incomplete_defs = true
module = 'gt4py.next.ffront.*'
disallow_incomplete_defs = false
module = 'gt4py.next.iterator.*'

[[tool.mypy.overrides]]
disallow_incomplete_defs = false
module = 'gt4py.next.program_processors.runners.dace_iterator.*'

[[tool.mypy.overrides]]
ignore_errors = true
Expand Down
17 changes: 17 additions & 0 deletions src/gt4py/eve/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,23 @@ def generic_dump(cls, node: RootNode, **kwargs: Any) -> str:
"""
return str(node)

@overload
def generic_visit(self, node: Node, **kwargs: Any) -> str: ...

@overload
def generic_visit(
self,
node: Union[
list,
tuple,
collections.abc.Set,
collections.abc.Sequence,
dict,
collections.abc.Mapping,
],
**kwargs: Any,
) -> Collection[str]: ...

def generic_visit(self, node: RootNode, **kwargs: Any) -> Union[str, Collection[str]]:
if isinstance(node, Node):
template, key = self.get_template(node)
Expand Down
6 changes: 4 additions & 2 deletions src/gt4py/next/embedded/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import itertools
import operator

from gt4py.eve.extended_typing import Any, Optional, Sequence, cast
from gt4py.eve.extended_typing import Any, Generator, Optional, Sequence, cast
from gt4py.next import common
from gt4py.next.embedded import exceptions as embedded_exceptions

Expand Down Expand Up @@ -148,7 +148,9 @@ def restrict_to_intersection(
)


def iterate_domain(domain: common.Domain):
def iterate_domain(
domain: common.Domain,
) -> Generator[tuple[tuple[common.Dimension, int], ...], None, None]:
for i in itertools.product(*[list(r) for r in domain.ranges]):
yield tuple(zip(domain.dims, i))

Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/embedded/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import contextlib
import contextvars as cvars
from typing import Any
from typing import Any, Generator

import gt4py.eve as eve
import gt4py.next.common as common
Expand All @@ -39,7 +39,7 @@ def new_context(
*,
closure_column_range: common.NamedRange | eve.NothingType = eve.NOTHING,
offset_provider: common.OffsetProvider | eve.NothingType = eve.NOTHING,
):
) -> Generator[cvars.Context, None, None]:
import gt4py.next.embedded.context as this_module

updates: list[tuple[cvars.ContextVar[Any], Any]] = []
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _get_nd_array_class(*fields: common.Field | core_defs.Scalar) -> type[NdArra


def _make_builtin(
builtin_name: str, array_builtin_name: str, reverse=False
builtin_name: str, array_builtin_name: str, reverse: bool = False
) -> Callable[..., NdArrayField]:
def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField:
cls_ = _get_nd_array_class(*fields)
Expand Down
23 changes: 14 additions & 9 deletions src/gt4py/next/embedded/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import dataclasses
from types import ModuleType
from typing import Any, Callable, Generic, ParamSpec, Sequence, TypeVar
from typing import Any, Callable, Generic, Optional, ParamSpec, Sequence, TypeVar

import numpy as np

Expand Down Expand Up @@ -46,7 +46,7 @@ def __call__( # type: ignore[override]
self,
*args: common.Field | core_defs.Scalar,
**kwargs: common.Field | core_defs.Scalar, # type: ignore[override]
) -> common.Field:
) -> common.Field | tuple[common.Field | tuple, ...]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be something like :

Suggested change
) -> common.Field | tuple[common.Field | tuple, ...]:
) -> common.Field[Any, _R] | tuple[common.Field[Any, _R] | tuple, ...]:

to bound the _R typevar to this class?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

init was missing the correct typevar and here it gets a bit more complicated, see solution.

scan_range = embedded_context.closure_column_range.get()
assert self.axis == scan_range[0]
scan_axis = scan_range[0]
Expand Down Expand Up @@ -91,7 +91,7 @@ def _get_out_domain(
])


def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any):
def field_operator_call(op: EmbeddedOperator[_R, _P], args: Any, kwargs: Any) -> Optional[_R]:
if "out" in kwargs:
# called from program or direct field_operator as program
new_context_kwargs = {}
Expand All @@ -118,9 +118,10 @@ def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any):
res = ctx.run(op, *args, **kwargs)
_tuple_assign_field(
out,
res,
res, # type: ignore[arg-type] # maybe can't inferred properly because decorator.py is not properly typed yet
domain=out_domain,
)
return None
else:
# called from other field_operator or missing `out` argument
if "offset_provider" in kwargs:
Expand All @@ -139,9 +140,9 @@ def _tuple_assign_field(
target: tuple[common.MutableField | tuple, ...] | common.MutableField,
source: tuple[common.Field | tuple, ...] | common.Field,
domain: common.Domain,
):
) -> None:
@utils.tree_map
def impl(target: common.MutableField, source: common.Field):
def impl(target: common.MutableField, source: common.Field) -> None:
if common.is_field(source):
target[domain] = source[domain]
else:
Expand Down Expand Up @@ -169,8 +170,12 @@ def _get_array_ns(


def _construct_scan_array(
domain: common.Domain, xp: ModuleType
): # TODO(havogt) introduce a NDArrayNamespace protocol
domain: common.Domain,
xp: ModuleType, # TODO(havogt) introduce a NDArrayNamespace protocol
) -> Callable[
[core_defs.Scalar | tuple[core_defs.Scalar | tuple, ...]],
common.Field | tuple[common.Field | tuple, ...],
]:
@utils.tree_map
def impl(init: core_defs.Scalar) -> common.Field:
return common._field(xp.empty(domain.shape, dtype=type(init)), domain=domain)
Expand All @@ -184,7 +189,7 @@ def _tuple_assign_value(
source: core_defs.Scalar | tuple[core_defs.Scalar | tuple, ...],
) -> None:
@utils.tree_map
def impl(target: common.MutableField, source: core_defs.Scalar):
def impl(target: common.MutableField, source: core_defs.Scalar) -> None:
target[pos] = source

impl(target, source)
Expand Down
7 changes: 5 additions & 2 deletions src/gt4py/next/errors/excepthook.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
"""

import sys
from typing import Callable
import types
from typing import Callable, Optional

from gt4py.next import config

Expand All @@ -41,7 +42,9 @@ def _format_uncaught_error(err: exceptions.DSLError, verbose_exceptions: bool) -
return formatting.format_compilation_error(type(err), err.message, err.location)


def compilation_error_hook(fallback: Callable, type_: type, value: BaseException, tb) -> None:
def compilation_error_hook(
fallback: Callable, type_: type, value: BaseException, tb: Optional[types.TracebackType]
) -> None:
"""
Format `CompilationError`s in a neat way.

Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,6 @@ def _visit_min_over(self, node: foast.Call, **kwargs: Any) -> foast.Call:
return self._visit_reduction(node, **kwargs)

def _visit_astype(self, node: foast.Call, **kwargs: Any) -> foast.Call:
return_type: ts.TupleType | ts.ScalarType | ts.FieldType
value, new_type = node.args
assert isinstance(
value.type, (ts.FieldType, ts.ScalarType, ts.TupleType)
Expand All @@ -891,6 +890,7 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> foast.Call:
primitive_type, getattr(ts.ScalarKind, new_type.id.upper())
),
)
assert isinstance(return_type, (ts.TupleType, ts.ScalarType, ts.FieldType))

return foast.Call(
func=node.func,
Expand Down
4 changes: 3 additions & 1 deletion src/gt4py/next/ffront/type_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,9 @@ def _as_field(dtype: ts.TypeSpec, path: tuple[int, ...]) -> ts.FieldType:
# TODO: we want some generic field type here, but our type system does not support it yet.
return ts.FieldType(dims=[common.Dimension("...")], dtype=dtype)

return type_info.apply_to_primitive_constituents(param, _as_field, with_path_arg=True)
res = type_info.apply_to_primitive_constituents(param, _as_field, with_path_arg=True)
assert isinstance(res, (ts.FieldType, ts.TupleType))
return res


@type_info.function_signature_incompatibilities.register
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/otf/binding/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from gt4py.next.otf import languages


def format_source(settings: languages.LanguageSettings, source):
def format_source(settings: languages.LanguageSettings, source: str) -> str:
return codegen.format_source(settings.formatter_key, source, style=settings.formatter_style)


Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/otf/binding/nanobind.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,11 @@ class BindingCodeGenerator(TemplatedGenerator):

BindingFunction = as_jinja("""module.def("{{exported_name}}", &{{wrapper_name}}, "{{doc}}");""")

def visit_FunctionCall(self, call: FunctionCall):
def visit_FunctionCall(self, call: FunctionCall) -> str:
args = [self.visit(arg) for arg in call.args]
return cpp_interface.render_function_call(call.target, args)

def visit_BufferSID(self, sid: BufferSID, **kwargs):
def visit_BufferSID(self, sid: BufferSID, **kwargs: Any) -> str:
pybuffer = f"{sid.source_buffer}.first"
dims = [self.visit(dim) for dim in sid.dimensions]
origin = f"{sid.source_buffer}.second"
Expand All @@ -158,7 +158,7 @@ def visit_BufferSID(self, sid: BufferSID, **kwargs):
renamed = f"gridtools::sid::rename_numbered_dimensions<{', '.join(dims)}>({shifted})"
return renamed

def visit_CompositeSID(self, node: CompositeSID, **kwargs):
def visit_CompositeSID(self, node: CompositeSID, **kwargs: Any) -> str:
kwargs["composite_ids"] = (
f"gridtools::integral_constant<int,{i}>" for i in range(len(node.elems))
)
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/otf/compilation/build_systems/cmake_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class CMakeListsGenerator(eve.codegen.TemplatedGenerator):
"""
)

def visit_FindDependency(self, dep: FindDependency):
def visit_FindDependency(self, dep: FindDependency) -> str:
# TODO(ricoh): do not add more libraries here
# and do not use this design in a new build system.
# Instead, design this to be extensible (refer to ADR-0016).
Expand All @@ -103,7 +103,7 @@ def visit_FindDependency(self, dep: FindDependency):
case _:
raise ValueError(f"Library '{dep.name}' is not supported")

def visit_LinkDependency(self, dep: LinkDependency):
def visit_LinkDependency(self, dep: LinkDependency) -> str:
# TODO(ricoh): do not add more libraries here
# and do not use this design in a new build system.
# Instead, design this to be extensible (refer to ADR-0016).
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/otf/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def build(self) -> None: ...
class CompiledProgram(Protocol):
"""Executable python representation of a program."""

def __call__(self, *args, **kwargs) -> None: ...
def __call__(self, *args: Any, **kwargs: Any) -> None: ...


def _unique_libs(*args: interface.LibraryDependency) -> tuple[interface.LibraryDependency, ...]:
Expand Down
12 changes: 6 additions & 6 deletions src/gt4py/next/program_processors/codegens/gtfn/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def visit_Literal(self, node: gtfn_ir.Literal, **kwargs: Any) -> str:
TernaryExpr = as_fmt("({cond}?{true_expr}:{false_expr})")
CastExpr = as_fmt("static_cast<{new_dtype}>({obj_expr})")

def visit_TaggedValues(self, node: gtfn_ir.TaggedValues, **kwargs):
def visit_TaggedValues(self, node: gtfn_ir.TaggedValues, **kwargs: Any) -> str:
tags = self.visit(node.tags)
values = self.visit(node.values)
if self.is_cartesian:
Expand All @@ -135,7 +135,7 @@ def visit_OffsetLiteral(self, node: gtfn_ir.OffsetLiteral, **kwargs: Any) -> str
"::gridtools::sid::composite::keys<${','.join(f'::gridtools::integral_constant<int,{i}>' for i in range(len(values)))}>::make_values(${','.join(values)})"
)

def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs):
def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs: Any) -> str:
if (
isinstance(node.fun, gtfn_ir_common.SymRef)
and node.fun.id in self.user_defined_function_ids
Expand Down Expand Up @@ -179,7 +179,7 @@ def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs):
"""
)

def visit_FunctionDefinition(self, node: gtfn_ir.FunctionDefinition, **kwargs):
def visit_FunctionDefinition(self, node: gtfn_ir.FunctionDefinition, **kwargs: Any) -> str:
expr_ = "return " + self.visit(node.expr)
return self.generic_visit(node, expr_=expr_)

Expand Down Expand Up @@ -330,14 +330,14 @@ class GTFNIMCodegen(GTFNCodegen):

ReturnStmt = as_fmt("return {ret};")

def visit_Conditional(self, node: gtfn_im_ir.Conditional, **kwargs):
def visit_Conditional(self, node: gtfn_im_ir.Conditional, **kwargs: Any) -> str:
if_rhs_ = self.visit(node.if_stmt.rhs)
else_rhs_ = self.visit(node.else_stmt.rhs)
return self.generic_visit(node, if_rhs_=if_rhs_, else_rhs_=else_rhs_)

def visit_ImperativeFunctionDefinition(
self, node: gtfn_im_ir.ImperativeFunctionDefinition, **kwargs
):
self, node: gtfn_im_ir.ImperativeFunctionDefinition, **kwargs: Any
) -> str:
expr_ = "".join(self.visit(stmt) for stmt in node.fun)
return self.generic_visit(node, expr_=expr_)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,14 @@ class GTFN_IM_lowering(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait):
# stable across multiple runs (required for caching to properly work)
uids: UIDGenerator = dataclasses.field(init=False, repr=False, default_factory=UIDGenerator)

def visit_SymRef(self, node: gtfn_ir_common.SymRef, **kwargs):
def visit_SymRef(self, node: gtfn_ir_common.SymRef, **kwargs: Any) -> gtfn_ir_common.SymRef:
if "localized_symbols" in kwargs and node.id in kwargs["localized_symbols"]:
return gtfn_ir_common.SymRef(id=kwargs["localized_symbols"][node.id])
return node

def commit_args(self, node: gtfn_ir.FunCall, tmp_id: str, fun_id: str, **kwargs):
def commit_args(
self, node: gtfn_ir.FunCall, tmp_id: str, fun_id: str, **kwargs: Any
) -> gtfn_ir.FunCall:
for i, arg in enumerate(node.args):
expr = self.visit(arg, **kwargs)
self.imp_list_ir.append(InitStmt(lhs=gtfn_ir_common.Sym(id=f"{tmp_id}_{i}"), rhs=expr))
Expand All @@ -182,8 +184,8 @@ def _expand_lambda(
new_args: List[gtfn_ir.FunCall],
red_idx: str,
max_neighbors: int,
**kwargs,
):
**kwargs: Any,
) -> None:
fun, init = node.fun.args # type: ignore
param_to_args = dict(zip([param.id for param in fun.params[1:]], new_args))
acc = fun.params[0]
Expand Down Expand Up @@ -212,8 +214,8 @@ def _expand_symref(
new_args: List[gtfn_ir.FunCall],
red_idx: str,
max_neighbors: int,
**kwargs,
):
**kwargs: Any,
) -> None:
fun, init = node.fun.args # type: ignore

red_lit = gtfn_ir_common.Sym(id=f"{red_idx}")
Expand All @@ -232,7 +234,7 @@ def _expand_symref(
)
self.imp_list_ir.append(AssignStmt(lhs=gtfn_ir_common.SymRef(id=red_idx), rhs=rhs))

def handle_Reduction(self, node: gtfn_ir.FunCall, **kwargs):
def handle_Reduction(self, node: gtfn_ir.FunCall, **kwargs: Any) -> gtfn_ir_common.SymRef:
offset_provider = kwargs["offset_provider"]
assert offset_provider is not None

Expand All @@ -258,7 +260,7 @@ def handle_Reduction(self, node: gtfn_ir.FunCall, **kwargs):

return gtfn_ir_common.SymRef(id=red_idx)

def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs):
def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs: Any) -> gtfn_ir_common.Expr:
if any(
isinstance(
arg,
Expand Down Expand Up @@ -309,7 +311,7 @@ def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs):
args=[self.visit(arg, **kwargs) for arg in node.args],
)

def visit_TernaryExpr(self, node: gtfn_ir.TernaryExpr, **kwargs):
def visit_TernaryExpr(self, node: gtfn_ir.TernaryExpr, **kwargs: Any) -> gtfn_ir_common.SymRef:
cond = self.visit(node.cond, **kwargs)
if_ = self.visit(node.true_expr, **kwargs)
else_ = self.visit(node.false_expr, **kwargs)
Expand Down
Loading
Loading