Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug[next]: Fix shift / remap lowering #1231

Merged
merged 7 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/field_operator_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ class TernaryExpr(Expr):


class Call(Expr):
func: Name
func: Expr
args: list[Expr]
kwargs: dict[str, Expr]

Expand Down
56 changes: 33 additions & 23 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,27 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call:
arg_types = [arg.type for arg in new_args]
kwarg_types = {name: arg.type for name, arg in new_kwargs.items()}

if isinstance(
new_func.type,
(ts.FunctionType, ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType),
):
# Since we use the `id` attribute in the latter part of the toolchain ensure we
# have the proper format here.
if not isinstance(
new_func,
(foast.FunctionDefinition, foast.FieldOperator, foast.ScanOperator, foast.Name),
):
raise FieldOperatorTypeDeductionError.from_foast_node(
node, msg="Functions can only be called directly!"
)
elif isinstance(new_func.type, ts.FieldType):
pass
else:
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=f"Expression of type `{new_func.type}` is not callable, must be a `Function`, `FieldOperator`, `ScanOperator` or `Field`.",
)

# ensure signature is valid
try:
type_info.accepts_args(
Expand All @@ -556,7 +577,7 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call:
)
except GTTypeError as err:
raise FieldOperatorTypeDeductionError.from_foast_node(
node, msg=f"Invalid argument types in call to `{node.func.id}`!"
node, msg=f"Invalid argument types in call to `{new_func}`!"
) from err

return_type = type_info.return_type(func_type, with_args=arg_types, with_kwargs=kwarg_types)
Expand All @@ -571,34 +592,23 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call:

if (
isinstance(new_func.type, ts.FunctionType)
and isinstance(new_func, foast.Name)
and new_func.id in fbuiltins.MATH_BUILTIN_NAMES
):
return self._visit_math_built_in(new_node, **kwargs)
elif (
isinstance(new_func.type, ts.FunctionType)
and not type_info.is_concrete(return_type)
and isinstance(new_func, foast.Name)
and new_func.id in fbuiltins.FUN_BUILTIN_NAMES
):
visitor = getattr(self, f"_visit_{new_func.id}")
return visitor(new_node, **kwargs)

return new_node

def _ensure_signature_valid(self, node: foast.Call, **kwargs) -> None:
try:
type_info.accepts_args(
cast(ts.FunctionType, node.func.type),
with_args=[arg.type for arg in node.args],
with_kwargs={keyword: arg.type for keyword, arg in node.kwargs.items()},
raise_exception=True,
)
except GTTypeError as err:
raise FieldOperatorTypeDeductionError.from_foast_node(
node, msg=f"Invalid argument types in call to `{node.func.id}`!"
) from err

def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call:
func_name = node.func.id
func_name = cast(foast.Name, node.func).id

# validate arguments
error_msg_preamble = f"Incompatible argument in call to `{func_name}`."
Expand Down Expand Up @@ -670,7 +680,7 @@ def _visit_reduction(self, node: foast.Call, **kwargs) -> foast.Call:
field_dims_str = ", ".join(str(dim) for dim in field_type.dims)
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=f"Incompatible field argument in call to `{node.func.id}`. "
msg=f"Incompatible field argument in call to `{str(node.func)}`. "
f"Expected a field with dimension {reduction_dim}, but got "
f"{field_dims_str}.",
)
Expand Down Expand Up @@ -726,15 +736,15 @@ def _visit_as_offset(self, node: foast.Call, **kwargs) -> foast.Call:
if not type_info.is_integral(arg_1):
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=f"Incompatible argument in call to `{node.func.id}`. "
msg=f"Incompatible argument in call to `{str(node.func)}`. "
f"Excepted integer for offset field dtype, but got {arg_1.dtype}"
f"{node.location}",
)

if arg_0.source not in arg_1.dims:
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=f"Incompatible argument in call to `{node.func.id}`. "
msg=f"Incompatible argument in call to `{str(node.func)}`. "
f"{arg_0.source} not in list of offset field dimensions {arg_1.dims}. "
f"{node.location}",
)
Expand All @@ -755,7 +765,7 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call:
if not type_info.is_logical(mask_type):
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=f"Incompatible argument in call to `{node.func.id}`. Expected "
msg=f"Incompatible argument in call to `{str(node.func)}`. Expected "
f"a field with dtype bool, but got `{mask_type}`.",
)

Expand All @@ -773,7 +783,7 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call:
):
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=f"Return arguments need to be of same type in {node.func.id}, but got: "
msg=f"Return arguments need to be of same type in {str(node.func)}, but got: "
f"{node.args[1].type} and {node.args[2].type}",
)
else:
Expand All @@ -785,7 +795,7 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call:
except GTTypeError as ex:
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=f"Incompatible argument in call to `{node.func.id}`.",
msg=f"Incompatible argument in call to `{str(node.func)}`.",
) from ex

return foast.Call(
Expand All @@ -803,7 +813,7 @@ def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call:
if any([not (isinstance(elt.type, ts.DimensionType)) for elt in broadcast_dims_expr]):
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=f"Incompatible broadcast dimension type in {node.func.id}. Expected "
msg=f"Incompatible broadcast dimension type in {str(node.func)}. Expected "
f"all broadcast dimensions to be of type Dimension.",
)

Expand All @@ -812,7 +822,7 @@ def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call:
if not set((arg_dims := type_info.extract_dims(arg_type))).issubset(set(broadcast_dims)):
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=f"Incompatible broadcast dimensions in {node.func.id}. Expected "
msg=f"Incompatible broadcast dimensions in {str(node.func)}. Expected "
f"broadcast dimension is missing {set(arg_dims).difference(set(broadcast_dims))}",
)

Expand Down
10 changes: 6 additions & 4 deletions src/gt4py/next/ffront/foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,17 +206,19 @@ def _visit_shift(self, node: foast.Call, **kwargs) -> itir.Expr:
)
case _:
raise FieldOperatorLoweringError("Unexpected shift arguments!")
return shift_offset(self.visit(node.func, **kwargs))
return im.lift_(im.lambda__("it")(im.deref_(shift_offset("it"))))(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the reviewer this is the core of this PR.

self.visit(node.func, **kwargs)
)

def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr:
if type_info.type_class(node.func.type) is ts.FieldType:
return self._visit_shift(node, **kwargs)
elif node.func.id in MATH_BUILTIN_NAMES:
elif isinstance(node.func, foast.Name) and node.func.id in MATH_BUILTIN_NAMES:
return self._visit_math_built_in(node, **kwargs)
elif node.func.id in FUN_BUILTIN_NAMES:
elif isinstance(node.func, foast.Name) and node.func.id in FUN_BUILTIN_NAMES:
visitor = getattr(self, f"_visit_{node.func.id}")
return visitor(node, **kwargs)
elif node.func.id in TYPE_BUILTIN_NAMES:
elif isinstance(node.func, foast.Name) and node.func.id in TYPE_BUILTIN_NAMES:
return self._visit_type_constr(node, **kwargs)
elif isinstance(
node.func.type,
Expand Down
19 changes: 8 additions & 11 deletions src/gt4py/next/ffront/func_to_foast.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,17 +463,14 @@ def _func_name(self, node: ast.Call) -> str:
return node.func.id # type: ignore[attr-defined] # We want this to fail if the attribute does not exist unexpectedly.

def visit_Call(self, node: ast.Call, **kwargs) -> foast.Call:
if not isinstance(node.func, ast.Name):
raise FieldOperatorSyntaxError.from_AST(
node, msg="Functions can only be called directly!"
)

func_name = self._func_name(node)

if func_name in fbuiltins.FUN_BUILTIN_NAMES:
self._verify_builtin_function(node)
if func_name in fbuiltins.TYPE_BUILTIN_NAMES:
self._verify_builtin_type_constructor(node)
# TODO(tehrengruber): is this still needed or redundant with the checks in type deduction?
if isinstance(node.func, ast.Name):
func_name = self._func_name(node)

if func_name in fbuiltins.FUN_BUILTIN_NAMES:
self._verify_builtin_function(node)
if func_name in fbuiltins.TYPE_BUILTIN_NAMES:
self._verify_builtin_type_constructor(node)

return foast.Call(
func=self.visit(node.func, **kwargs),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def debug_itir(tree):

Vertex = Dimension("Vertex")
Edge = Dimension("Edge")
Cell = Dimension("Cell")
EdgeOffset = FieldOffset("EdgeOffset", source=Edge, target=(Edge,))

size = 10
Expand All @@ -65,10 +66,11 @@ def debug_itir(tree):
@pytest.fixture
def reduction_setup():
num_vertices = 9
edge = Dimension("Edge")
vertex = Dimension("Vertex")
num_cells = 8
v2edim = Dimension("V2E", kind=DimensionKind.LOCAL)
e2vdim = Dimension("E2V", kind=DimensionKind.LOCAL)
c2vdim = Dimension("C2V", kind=DimensionKind.LOCAL)
c2edim = Dimension("C2E", kind=DimensionKind.LOCAL)

v2e_arr = np.array(
[
Expand All @@ -84,6 +86,32 @@ def reduction_setup():
]
)

c2v_arr = np.array(
[
[0, 1, 4, 3],
[1, 2, 5, 6],
[3, 4, 7, 6],
[4, 5, 8, 7],
[6, 7, 1, 0],
[7, 8, 2, 1],
[2, 0, 3, 5],
[5, 3, 6, 8],
]
)

c2e_arr = np.array(
[
[0, 10, 3, 9],
[1, 11, 4, 10],
[3, 13, 6, 12],
[4, 14, 7, 13],
[6, 16, 0, 15],
[7, 17, 1, 16],
[2, 9, 5, 11],
[5, 12, 8, 14],
]
)

# create e2v connectivity by inverting v2e
num_edges = np.max(v2e_arr) + 1
e2v_arr = [[] for _ in range(0, num_edges)]
Expand All @@ -98,12 +126,15 @@ def reduction_setup():
[
"num_vertices",
"num_edges",
"Edge",
"Vertex",
"num_cells",
"V2EDim",
"E2VDim",
"C2VDim",
"C2EDim",
"V2E",
"E2V",
"C2V",
"C2E",
"inp",
"out",
"offset_provider",
Expand All @@ -113,18 +144,23 @@ def reduction_setup():
)(
num_vertices=num_vertices,
num_edges=num_edges,
Edge=edge,
Vertex=vertex,
num_cells=num_cells,
V2EDim=v2edim,
E2VDim=e2vdim,
V2E=FieldOffset("V2E", source=edge, target=(vertex, v2edim)),
E2V=FieldOffset("E2V", source=vertex, target=(edge, e2vdim)),
C2VDim=c2vdim,
C2EDim=c2edim,
V2E=FieldOffset("V2E", source=Edge, target=(Vertex, v2edim)),
E2V=FieldOffset("E2V", source=Vertex, target=(Edge, e2vdim)),
C2V=FieldOffset("C2V", source=Vertex, target=(Cell, c2vdim)),
C2E=FieldOffset("C2E", source=Edge, target=(Cell, c2edim)),
# inp=index_field(edge, dtype=np.int64), # TODO enable once we support index_fields in bindings
inp=np_as_located_field(edge)(np.arange(num_edges, dtype=np.int64)),
out=np_as_located_field(vertex)(np.zeros([num_vertices], dtype=np.int64)),
inp=np_as_located_field(Edge)(np.arange(num_edges, dtype=np.int64)),
out=np_as_located_field(Vertex)(np.zeros([num_vertices], dtype=np.int64)),
offset_provider={
"V2E": NeighborTableOffsetProvider(v2e_arr, vertex, edge, 4),
"E2V": NeighborTableOffsetProvider(e2v_arr, edge, vertex, 2, has_skip_values=False),
"V2E": NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4),
"E2V": NeighborTableOffsetProvider(e2v_arr, Edge, Vertex, 2, has_skip_values=False),
"C2V": NeighborTableOffsetProvider(c2v_arr, Cell, Vertex, 4, has_skip_values=False),
"C2E": NeighborTableOffsetProvider(c2e_arr, Cell, Edge, 4, has_skip_values=False),
},
v2e_table=v2e_arr,
e2v_table=e2v_arr,
Expand Down
Loading