Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug[next]: Fix shift / remap lowering #1231

Merged
merged 7 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/field_operator_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ class TernaryExpr(Expr):


class Call(Expr):
func: Name
func: Expr
args: list[Expr]
kwargs: dict[str, Expr]

Expand Down
58 changes: 35 additions & 23 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,29 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call:
arg_types = [arg.type for arg in new_args]
kwarg_types = {name: arg.type for name, arg in new_kwargs.items()}

func_str_repr: str # just for error handling
if isinstance(
new_func.type,
(ts.FunctionType, ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType),
):
# Since we use the `id` attribute in the latter part of the toolchain ensure we
# have the proper format here.
if not isinstance(
new_func,
(foast.FunctionDefinition, foast.FieldOperator, foast.ScanOperator, foast.Name),
):
raise FieldOperatorTypeDeductionError.from_foast_node(
node, msg="Functions can only be called directly!"
)
func_str_repr = new_func.id
elif isinstance(new_func.type, ts.FieldType):
func_str_repr = str(new_func)
else:
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=f"Expression of type `{new_func.type}` is not callable, must be a `Function`, `FieldOperator`, `ScanOperator` or `Field`.",
)

# ensure signature is valid
try:
type_info.accepts_args(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This snippet, including the type_info.accepts_args is confusing. How does that work in case of new_func.type == FieldType? Let's talk about it.

Expand All @@ -556,7 +579,7 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call:
)
except GTTypeError as err:
raise FieldOperatorTypeDeductionError.from_foast_node(
node, msg=f"Invalid argument types in call to `{node.func.id}`!"
node, msg=f"Invalid argument types in call to `{func_str_repr}`!"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once you have implemented the change that you wanted to make to avoid func_str_repr, it should be good to go.

) from err

return_type = type_info.return_type(func_type, with_args=arg_types, with_kwargs=kwarg_types)
Expand All @@ -571,34 +594,23 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call:

if (
isinstance(new_func.type, ts.FunctionType)
and isinstance(new_func, foast.Name)
and new_func.id in fbuiltins.MATH_BUILTIN_NAMES
):
return self._visit_math_built_in(new_node, **kwargs)
elif (
isinstance(new_func.type, ts.FunctionType)
and not type_info.is_concrete(return_type)
and isinstance(new_func, foast.Name)
and new_func.id in fbuiltins.FUN_BUILTIN_NAMES
):
visitor = getattr(self, f"_visit_{new_func.id}")
return visitor(new_node, **kwargs)

return new_node

def _ensure_signature_valid(self, node: foast.Call, **kwargs) -> None:
try:
type_info.accepts_args(
cast(ts.FunctionType, node.func.type),
with_args=[arg.type for arg in node.args],
with_kwargs={keyword: arg.type for keyword, arg in node.kwargs.items()},
raise_exception=True,
)
except GTTypeError as err:
raise FieldOperatorTypeDeductionError.from_foast_node(
node, msg=f"Invalid argument types in call to `{node.func.id}`!"
) from err

def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call:
func_name = node.func.id
func_name = cast(foast.Name, node.func).id

# validate arguments
error_msg_preamble = f"Incompatible argument in call to `{func_name}`."
Expand Down Expand Up @@ -670,7 +682,7 @@ def _visit_reduction(self, node: foast.Call, **kwargs) -> foast.Call:
field_dims_str = ", ".join(str(dim) for dim in field_type.dims)
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=f"Incompatible field argument in call to `{node.func.id}`. "
msg=f"Incompatible field argument in call to `{str(node.func)}`. "
f"Expected a field with dimension {reduction_dim}, but got "
f"{field_dims_str}.",
)
Expand Down Expand Up @@ -726,15 +738,15 @@ def _visit_as_offset(self, node: foast.Call, **kwargs) -> foast.Call:
if not type_info.is_integral(arg_1):
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=f"Incompatible argument in call to `{node.func.id}`. "
msg=f"Incompatible argument in call to `{str(node.func)}`. "
f"Excepted integer for offset field dtype, but got {arg_1.dtype}"
f"{node.location}",
)

if arg_0.source not in arg_1.dims:
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=f"Incompatible argument in call to `{node.func.id}`. "
msg=f"Incompatible argument in call to `{str(node.func)}`. "
f"{arg_0.source} not in list of offset field dimensions {arg_1.dims}. "
f"{node.location}",
)
Expand All @@ -755,7 +767,7 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call:
if not type_info.is_logical(mask_type):
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=f"Incompatible argument in call to `{node.func.id}`. Expected "
msg=f"Incompatible argument in call to `{str(node.func)}`. Expected "
f"a field with dtype bool, but got `{mask_type}`.",
)

Expand All @@ -773,7 +785,7 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call:
):
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=f"Return arguments need to be of same type in {node.func.id}, but got: "
msg=f"Return arguments need to be of same type in {str(node.func)}, but got: "
f"{node.args[1].type} and {node.args[2].type}",
)
else:
Expand All @@ -785,7 +797,7 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call:
except GTTypeError as ex:
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=f"Incompatible argument in call to `{node.func.id}`.",
msg=f"Incompatible argument in call to `{str(node.func)}`.",
) from ex

return foast.Call(
Expand All @@ -803,7 +815,7 @@ def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call:
if any([not (isinstance(elt.type, ts.DimensionType)) for elt in broadcast_dims_expr]):
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=f"Incompatible broadcast dimension type in {node.func.id}. Expected "
msg=f"Incompatible broadcast dimension type in {str(node.func)}. Expected "
f"all broadcast dimensions to be of type Dimension.",
)

Expand All @@ -812,7 +824,7 @@ def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call:
if not set((arg_dims := type_info.extract_dims(arg_type))).issubset(set(broadcast_dims)):
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=f"Incompatible broadcast dimensions in {node.func.id}. Expected "
msg=f"Incompatible broadcast dimensions in {str(node.func)}. Expected "
f"broadcast dimension is missing {set(arg_dims).difference(set(broadcast_dims))}",
)

Expand Down
10 changes: 6 additions & 4 deletions src/gt4py/next/ffront/foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,17 +206,19 @@ def _visit_shift(self, node: foast.Call, **kwargs) -> itir.Expr:
)
case _:
raise FieldOperatorLoweringError("Unexpected shift arguments!")
return shift_offset(self.visit(node.func, **kwargs))
return im.lift_(im.lambda__("it")(im.deref_(shift_offset("it"))))(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the reviewer this is the core of this PR.

self.visit(node.func, **kwargs)
)

def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr:
if type_info.type_class(node.func.type) is ts.FieldType:
return self._visit_shift(node, **kwargs)
elif node.func.id in MATH_BUILTIN_NAMES:
elif isinstance(node.func, foast.Name) and node.func.id in MATH_BUILTIN_NAMES:
return self._visit_math_built_in(node, **kwargs)
elif node.func.id in FUN_BUILTIN_NAMES:
elif isinstance(node.func, foast.Name) and node.func.id in FUN_BUILTIN_NAMES:
visitor = getattr(self, f"_visit_{node.func.id}")
return visitor(node, **kwargs)
elif node.func.id in TYPE_BUILTIN_NAMES:
elif isinstance(node.func, foast.Name) and node.func.id in TYPE_BUILTIN_NAMES:
return self._visit_type_constr(node, **kwargs)
elif isinstance(
node.func.type,
Expand Down
19 changes: 8 additions & 11 deletions src/gt4py/next/ffront/func_to_foast.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,17 +463,14 @@ def _func_name(self, node: ast.Call) -> str:
return node.func.id # type: ignore[attr-defined] # We want this to fail if the attribute does not exist unexpectedly.

def visit_Call(self, node: ast.Call, **kwargs) -> foast.Call:
if not isinstance(node.func, ast.Name):
raise FieldOperatorSyntaxError.from_AST(
node, msg="Functions can only be called directly!"
)

func_name = self._func_name(node)

if func_name in fbuiltins.FUN_BUILTIN_NAMES:
self._verify_builtin_function(node)
if func_name in fbuiltins.TYPE_BUILTIN_NAMES:
self._verify_builtin_type_constructor(node)
# TODO(tehrengruber): is this still needed or redundant with the checks in type deduction?
if isinstance(node.func, ast.Name):
func_name = self._func_name(node)

if func_name in fbuiltins.FUN_BUILTIN_NAMES:
self._verify_builtin_function(node)
if func_name in fbuiltins.TYPE_BUILTIN_NAMES:
self._verify_builtin_type_constructor(node)

return foast.Call(
func=self.visit(node.func, **kwargs),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def debug_itir(tree):

Vertex = Dimension("Vertex")
Edge = Dimension("Edge")
Cell = Dimension("Cell")
EdgeOffset = FieldOffset("EdgeOffset", source=Edge, target=(Edge,))

size = 10
Expand All @@ -65,10 +66,11 @@ def debug_itir(tree):
@pytest.fixture
def reduction_setup():
num_vertices = 9
edge = Dimension("Edge")
vertex = Dimension("Vertex")
num_cells = 8
v2edim = Dimension("V2E", kind=DimensionKind.LOCAL)
e2vdim = Dimension("E2V", kind=DimensionKind.LOCAL)
c2vdim = Dimension("C2V", kind=DimensionKind.LOCAL)
c2edim = Dimension("C2E", kind=DimensionKind.LOCAL)

v2e_arr = np.array(
[
Expand All @@ -84,6 +86,32 @@ def reduction_setup():
]
)

c2v_arr = np.array(
[
[0, 1, 4, 3],
[1, 2, 5, 6],
[3, 4, 7, 6],
[4, 5, 8, 7],
[6, 7, 1, 0],
[7, 8, 2, 1],
[2, 0, 3, 5],
[5, 3, 6, 8],
]
)

c2e_arr = np.array(
[
[0, 10, 3, 9],
[1, 11, 4, 10],
[3, 13, 6, 12],
[4, 14, 7, 13],
[6, 16, 0, 15],
[7, 17, 1, 16],
[2, 9, 5, 11],
[5, 12, 8, 14],
]
)

# create e2v connectivity by inverting v2e
num_edges = np.max(v2e_arr) + 1
e2v_arr = [[] for _ in range(0, num_edges)]
Expand All @@ -98,12 +126,15 @@ def reduction_setup():
[
"num_vertices",
"num_edges",
"Edge",
"Vertex",
"num_cells",
"V2EDim",
"E2VDim",
"C2VDim",
"C2EDim",
"V2E",
"E2V",
"C2V",
"C2E",
"inp",
"out",
"offset_provider",
Expand All @@ -113,18 +144,23 @@ def reduction_setup():
)(
num_vertices=num_vertices,
num_edges=num_edges,
Edge=edge,
Vertex=vertex,
num_cells=num_cells,
V2EDim=v2edim,
E2VDim=e2vdim,
V2E=FieldOffset("V2E", source=edge, target=(vertex, v2edim)),
E2V=FieldOffset("E2V", source=vertex, target=(edge, e2vdim)),
C2VDim=c2vdim,
C2EDim=c2edim,
V2E=FieldOffset("V2E", source=Edge, target=(Vertex, v2edim)),
E2V=FieldOffset("E2V", source=Vertex, target=(Edge, e2vdim)),
C2V=FieldOffset("C2V", source=Vertex, target=(Cell, c2vdim)),
C2E=FieldOffset("C2E", source=Edge, target=(Cell, c2edim)),
# inp=index_field(edge, dtype=np.int64), # TODO enable once we support index_fields in bindings
inp=np_as_located_field(edge)(np.arange(num_edges, dtype=np.int64)),
out=np_as_located_field(vertex)(np.zeros([num_vertices], dtype=np.int64)),
inp=np_as_located_field(Edge)(np.arange(num_edges, dtype=np.int64)),
out=np_as_located_field(Vertex)(np.zeros([num_vertices], dtype=np.int64)),
offset_provider={
"V2E": NeighborTableOffsetProvider(v2e_arr, vertex, edge, 4),
"E2V": NeighborTableOffsetProvider(e2v_arr, edge, vertex, 2, has_skip_values=False),
"V2E": NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4),
"E2V": NeighborTableOffsetProvider(e2v_arr, Edge, Vertex, 2, has_skip_values=False),
"C2V": NeighborTableOffsetProvider(c2v_arr, Cell, Vertex, 4, has_skip_values=False),
"C2E": NeighborTableOffsetProvider(c2e_arr, Cell, Edge, 4, has_skip_values=False),
},
v2e_table=v2e_arr,
e2v_table=e2v_arr,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,24 +93,62 @@ def fencil(inp: Field[[IDim], float64], out: Field[[IDim], float64]) -> None:


def test_unstructured_shift(reduction_setup, fieldview_backend):
Vertex = reduction_setup.Vertex
Edge = reduction_setup.Edge
E2V = reduction_setup.E2V

a = np_as_located_field(Vertex)(np.zeros(reduction_setup.num_vertices))
a = np_as_located_field(Vertex)(np.arange(0, reduction_setup.num_vertices, dtype=np.float64))
b = np_as_located_field(Edge)(np.zeros(reduction_setup.num_edges))

@field_operator(backend=fieldview_backend)
def shift_by_one(inp: Field[[Vertex], float64]) -> Field[[Edge], float64]:
def shift_unstructured(inp: Field[[Vertex], float64]) -> Field[[Edge], float64]:
return inp(E2V[0])

shift_by_one(a, out=b, offset_provider={"E2V": reduction_setup.offset_provider["E2V"]})
shift_unstructured(a, out=b, offset_provider={"E2V": reduction_setup.offset_provider["E2V"]})

ref = np.asarray(a)[reduction_setup.offset_provider["E2V"].table[slice(0, None), 0]]

assert np.allclose(b, ref)


def test_composed_unstructured_shift(reduction_setup, fieldview_backend):
E2V = reduction_setup.E2V
C2E = reduction_setup.C2E
e2v_table = reduction_setup.offset_provider["E2V"].table[slice(0, None), 0]
c2e_table = reduction_setup.offset_provider["C2E"].table[slice(0, None), 0]

a = np_as_located_field(Vertex)(np.arange(0, reduction_setup.num_vertices, dtype=np.float64))
b = np_as_located_field(Cell)(np.zeros(reduction_setup.num_cells))

@field_operator(backend=fieldview_backend)
def composed_shift_unstructured_flat(inp: Field[[Vertex], float64]) -> Field[[Cell], float64]:
return inp(E2V[0])(C2E[0])

@field_operator(backend=fieldview_backend)
def composed_shift_unstructured_intermediate_result(
inp: Field[[Vertex], float64]
) -> Field[[Cell], float64]:
tmp = inp(E2V[0])
return tmp(C2E[0])

@field_operator(backend=fieldview_backend)
def shift_e2v(inp: Field[[Vertex], float64]) -> Field[[Edge], float64]:
return inp(E2V[0])

@field_operator(backend=fieldview_backend)
def composed_shift_unstructured(inp: Field[[Vertex], float64]) -> Field[[Cell], float64]:
return shift_e2v(inp)(C2E[0])

ref = np.asarray(a)[e2v_table][c2e_table]

for field_op in [
composed_shift_unstructured_flat,
composed_shift_unstructured_intermediate_result,
composed_shift_unstructured,
]:
field_op(a, out=b, offset_provider=reduction_setup.offset_provider)

assert np.allclose(b, ref)


def test_fold_shifts(fieldview_backend):
"""Shifting the result of an addition should work."""
a = np_as_located_field(IDim)(np.arange(size + 1, dtype=np.float64))
Expand Down