Skip to content

Commit

Permalink
bug[next]: Fix shift / remap lowering (#1231)
Browse files Browse the repository at this point in the history
While working on something else I recognized a rather sneaky bug in the lowering of shifts from FOAST to ITIR. Contrary to the rest of the lowering we were not wrapping the shifts inside a lifted stencil, which can lead to bugs in user code. This is particularly devastating when executing with the gtfn backend in Release Mode, as you might just run into a segfault or incorrect results. I am actually rather surprised this did not surface earlier somewhere as something like this
```python
@field_operator(backend=fieldview_backend)
def composed_shift_unstructured_intermediate_result(
    inp: Field[[Vertex], float64]
) -> Field[[Cell], float64]:
    tmp = inp(E2V[0])
    return tmp(C2E[0])
```
lowers to the (incorrect) ITIR:
```
λ(inp) → ·(λ(tmp__0) → ⟪C2Eₒ, 0ₒ⟫(tmp__0))(⟪E2Vₒ, 0ₒ⟫(inp))
```
(Note how the ordering of the shifts is wrong). I had found this bug much earlier while working on (#965), but sadly the change went lost while merging (was very easy to overlook so @havogt and me bost missed it). What I did not recognize back then was that this bug could actually be triggered. I assumed it only occurred for expression like `field(E2V[0])(C2E[0])` which were not allowed (for unrelated reasons). For completeness this PR also adds support for such expressions.
  • Loading branch information
tehrengruber authored Apr 24, 2023
1 parent cb76afa commit 4998dec
Show file tree
Hide file tree
Showing 8 changed files with 148 additions and 78 deletions.
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/field_operator_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ class TernaryExpr(Expr):


class Call(Expr):
func: Name
func: Expr
args: list[Expr]
kwargs: dict[str, Expr]

Expand Down
56 changes: 33 additions & 23 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,27 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call:
arg_types = [arg.type for arg in new_args]
kwarg_types = {name: arg.type for name, arg in new_kwargs.items()}

if isinstance(
new_func.type,
(ts.FunctionType, ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType),
):
# Since we use the `id` attribute in the latter part of the toolchain ensure we
# have the proper format here.
if not isinstance(
new_func,
(foast.FunctionDefinition, foast.FieldOperator, foast.ScanOperator, foast.Name),
):
raise FieldOperatorTypeDeductionError.from_foast_node(
node, msg="Functions can only be called directly!"
)
elif isinstance(new_func.type, ts.FieldType):
pass
else:
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=f"Expression of type `{new_func.type}` is not callable, must be a `Function`, `FieldOperator`, `ScanOperator` or `Field`.",
)

# ensure signature is valid
try:
type_info.accepts_args(
Expand All @@ -556,7 +577,7 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call:
)
except GTTypeError as err:
raise FieldOperatorTypeDeductionError.from_foast_node(
node, msg=f"Invalid argument types in call to `{node.func.id}`!"
node, msg=f"Invalid argument types in call to `{new_func}`!"
) from err

return_type = type_info.return_type(func_type, with_args=arg_types, with_kwargs=kwarg_types)
Expand All @@ -571,34 +592,23 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call:

if (
isinstance(new_func.type, ts.FunctionType)
and isinstance(new_func, foast.Name)
and new_func.id in fbuiltins.MATH_BUILTIN_NAMES
):
return self._visit_math_built_in(new_node, **kwargs)
elif (
isinstance(new_func.type, ts.FunctionType)
and not type_info.is_concrete(return_type)
and isinstance(new_func, foast.Name)
and new_func.id in fbuiltins.FUN_BUILTIN_NAMES
):
visitor = getattr(self, f"_visit_{new_func.id}")
return visitor(new_node, **kwargs)

return new_node

def _ensure_signature_valid(self, node: foast.Call, **kwargs) -> None:
try:
type_info.accepts_args(
cast(ts.FunctionType, node.func.type),
with_args=[arg.type for arg in node.args],
with_kwargs={keyword: arg.type for keyword, arg in node.kwargs.items()},
raise_exception=True,
)
except GTTypeError as err:
raise FieldOperatorTypeDeductionError.from_foast_node(
node, msg=f"Invalid argument types in call to `{node.func.id}`!"
) from err

def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call:
func_name = node.func.id
func_name = cast(foast.Name, node.func).id

# validate arguments
error_msg_preamble = f"Incompatible argument in call to `{func_name}`."
Expand Down Expand Up @@ -670,7 +680,7 @@ def _visit_reduction(self, node: foast.Call, **kwargs) -> foast.Call:
field_dims_str = ", ".join(str(dim) for dim in field_type.dims)
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=f"Incompatible field argument in call to `{node.func.id}`. "
msg=f"Incompatible field argument in call to `{str(node.func)}`. "
f"Expected a field with dimension {reduction_dim}, but got "
f"{field_dims_str}.",
)
Expand Down Expand Up @@ -726,15 +736,15 @@ def _visit_as_offset(self, node: foast.Call, **kwargs) -> foast.Call:
if not type_info.is_integral(arg_1):
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=f"Incompatible argument in call to `{node.func.id}`. "
msg=f"Incompatible argument in call to `{str(node.func)}`. "
f"Excepted integer for offset field dtype, but got {arg_1.dtype}"
f"{node.location}",
)

if arg_0.source not in arg_1.dims:
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=f"Incompatible argument in call to `{node.func.id}`. "
msg=f"Incompatible argument in call to `{str(node.func)}`. "
f"{arg_0.source} not in list of offset field dimensions {arg_1.dims}. "
f"{node.location}",
)
Expand All @@ -755,7 +765,7 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call:
if not type_info.is_logical(mask_type):
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=f"Incompatible argument in call to `{node.func.id}`. Expected "
msg=f"Incompatible argument in call to `{str(node.func)}`. Expected "
f"a field with dtype bool, but got `{mask_type}`.",
)

Expand All @@ -773,7 +783,7 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call:
):
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=f"Return arguments need to be of same type in {node.func.id}, but got: "
msg=f"Return arguments need to be of same type in {str(node.func)}, but got: "
f"{node.args[1].type} and {node.args[2].type}",
)
else:
Expand All @@ -785,7 +795,7 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call:
except GTTypeError as ex:
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=f"Incompatible argument in call to `{node.func.id}`.",
msg=f"Incompatible argument in call to `{str(node.func)}`.",
) from ex

return foast.Call(
Expand All @@ -803,7 +813,7 @@ def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call:
if any([not (isinstance(elt.type, ts.DimensionType)) for elt in broadcast_dims_expr]):
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=f"Incompatible broadcast dimension type in {node.func.id}. Expected "
msg=f"Incompatible broadcast dimension type in {str(node.func)}. Expected "
f"all broadcast dimensions to be of type Dimension.",
)

Expand All @@ -812,7 +822,7 @@ def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call:
if not set((arg_dims := type_info.extract_dims(arg_type))).issubset(set(broadcast_dims)):
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=f"Incompatible broadcast dimensions in {node.func.id}. Expected "
msg=f"Incompatible broadcast dimensions in {str(node.func)}. Expected "
f"broadcast dimension is missing {set(arg_dims).difference(set(broadcast_dims))}",
)

Expand Down
10 changes: 6 additions & 4 deletions src/gt4py/next/ffront/foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,17 +200,19 @@ def _visit_shift(self, node: foast.Call, **kwargs) -> itir.Expr:
)
case _:
raise FieldOperatorLoweringError("Unexpected shift arguments!")
return shift_offset(self.visit(node.func, **kwargs))
return im.lift_(im.lambda__("it")(im.deref_(shift_offset("it"))))(
self.visit(node.func, **kwargs)
)

def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr:
if type_info.type_class(node.func.type) is ts.FieldType:
return self._visit_shift(node, **kwargs)
elif node.func.id in MATH_BUILTIN_NAMES:
elif isinstance(node.func, foast.Name) and node.func.id in MATH_BUILTIN_NAMES:
return self._visit_math_built_in(node, **kwargs)
elif node.func.id in FUN_BUILTIN_NAMES:
elif isinstance(node.func, foast.Name) and node.func.id in FUN_BUILTIN_NAMES:
visitor = getattr(self, f"_visit_{node.func.id}")
return visitor(node, **kwargs)
elif node.func.id in TYPE_BUILTIN_NAMES:
elif isinstance(node.func, foast.Name) and node.func.id in TYPE_BUILTIN_NAMES:
return self._visit_type_constr(node, **kwargs)
elif isinstance(
node.func.type,
Expand Down
19 changes: 8 additions & 11 deletions src/gt4py/next/ffront/func_to_foast.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,17 +463,14 @@ def _func_name(self, node: ast.Call) -> str:
return node.func.id # type: ignore[attr-defined] # We want this to fail if the attribute does not exist unexpectedly.

def visit_Call(self, node: ast.Call, **kwargs) -> foast.Call:
if not isinstance(node.func, ast.Name):
raise FieldOperatorSyntaxError.from_AST(
node, msg="Functions can only be called directly!"
)

func_name = self._func_name(node)

if func_name in fbuiltins.FUN_BUILTIN_NAMES:
self._verify_builtin_function(node)
if func_name in fbuiltins.TYPE_BUILTIN_NAMES:
self._verify_builtin_type_constructor(node)
# TODO(tehrengruber): is this still needed or redundant with the checks in type deduction?
if isinstance(node.func, ast.Name):
func_name = self._func_name(node)

if func_name in fbuiltins.FUN_BUILTIN_NAMES:
self._verify_builtin_function(node)
if func_name in fbuiltins.TYPE_BUILTIN_NAMES:
self._verify_builtin_type_constructor(node)

return foast.Call(
func=self.visit(node.func, **kwargs),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def debug_itir(tree):

Vertex = Dimension("Vertex")
Edge = Dimension("Edge")
Cell = Dimension("Cell")
EdgeOffset = FieldOffset("EdgeOffset", source=Edge, target=(Edge,))

size = 10
Expand All @@ -65,10 +66,11 @@ def debug_itir(tree):
@pytest.fixture
def reduction_setup():
num_vertices = 9
edge = Dimension("Edge")
vertex = Dimension("Vertex")
num_cells = 8
v2edim = Dimension("V2E", kind=DimensionKind.LOCAL)
e2vdim = Dimension("E2V", kind=DimensionKind.LOCAL)
c2vdim = Dimension("C2V", kind=DimensionKind.LOCAL)
c2edim = Dimension("C2E", kind=DimensionKind.LOCAL)

v2e_arr = np.array(
[
Expand All @@ -84,6 +86,32 @@ def reduction_setup():
]
)

c2v_arr = np.array(
[
[0, 1, 4, 3],
[1, 2, 5, 6],
[3, 4, 7, 6],
[4, 5, 8, 7],
[6, 7, 1, 0],
[7, 8, 2, 1],
[2, 0, 3, 5],
[5, 3, 6, 8],
]
)

c2e_arr = np.array(
[
[0, 10, 3, 9],
[1, 11, 4, 10],
[3, 13, 6, 12],
[4, 14, 7, 13],
[6, 16, 0, 15],
[7, 17, 1, 16],
[2, 9, 5, 11],
[5, 12, 8, 14],
]
)

# create e2v connectivity by inverting v2e
num_edges = np.max(v2e_arr) + 1
e2v_arr = [[] for _ in range(0, num_edges)]
Expand All @@ -98,12 +126,15 @@ def reduction_setup():
[
"num_vertices",
"num_edges",
"Edge",
"Vertex",
"num_cells",
"V2EDim",
"E2VDim",
"C2VDim",
"C2EDim",
"V2E",
"E2V",
"C2V",
"C2E",
"inp",
"out",
"offset_provider",
Expand All @@ -113,18 +144,23 @@ def reduction_setup():
)(
num_vertices=num_vertices,
num_edges=num_edges,
Edge=edge,
Vertex=vertex,
num_cells=num_cells,
V2EDim=v2edim,
E2VDim=e2vdim,
V2E=FieldOffset("V2E", source=edge, target=(vertex, v2edim)),
E2V=FieldOffset("E2V", source=vertex, target=(edge, e2vdim)),
C2VDim=c2vdim,
C2EDim=c2edim,
V2E=FieldOffset("V2E", source=Edge, target=(Vertex, v2edim)),
E2V=FieldOffset("E2V", source=Vertex, target=(Edge, e2vdim)),
C2V=FieldOffset("C2V", source=Vertex, target=(Cell, c2vdim)),
C2E=FieldOffset("C2E", source=Edge, target=(Cell, c2edim)),
# inp=index_field(edge, dtype=np.int64), # TODO enable once we support index_fields in bindings
inp=np_as_located_field(edge)(np.arange(num_edges, dtype=np.int64)),
out=np_as_located_field(vertex)(np.zeros([num_vertices], dtype=np.int64)),
inp=np_as_located_field(Edge)(np.arange(num_edges, dtype=np.int64)),
out=np_as_located_field(Vertex)(np.zeros([num_vertices], dtype=np.int64)),
offset_provider={
"V2E": NeighborTableOffsetProvider(v2e_arr, vertex, edge, 4),
"E2V": NeighborTableOffsetProvider(e2v_arr, edge, vertex, 2, has_skip_values=False),
"V2E": NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4),
"E2V": NeighborTableOffsetProvider(e2v_arr, Edge, Vertex, 2, has_skip_values=False),
"C2V": NeighborTableOffsetProvider(c2v_arr, Cell, Vertex, 4, has_skip_values=False),
"C2E": NeighborTableOffsetProvider(c2e_arr, Cell, Edge, 4, has_skip_values=False),
},
v2e_table=v2e_arr,
e2v_table=e2v_arr,
Expand Down
Loading

0 comments on commit 4998dec

Please sign in to comment.