Skip to content

Commit

Permalink
fix[next]: DaCe field addressing in builtin_neighbors (#1349)
Browse files Browse the repository at this point in the history
Bugfix in DaCe backend to make field addressing in builtin_neighbors consistent with the canonical representation (field dimensions alphabetically sorted).
  • Loading branch information
edopao authored Oct 17, 2023
1 parent 90eea30 commit f96ead5
Showing 1 changed file with 6 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,8 @@ def builtin_neighbors(
)
# select full shape only in the neighbor-axis dimension
field_subset = [
f"0:{sdfg.arrays[iterator.field.data].shape[idx]}"
if dim == table.neighbor_axis.value
else f"i_{dim}"
for idx, dim in enumerate(iterator.dimensions)
f"0:{shape}" if dim == table.neighbor_axis.value else f"i_{dim}"
for dim, shape in zip(sorted(iterator.dimensions), sdfg.arrays[iterator.field.data].shape)
]
state.add_memlet_path(
iterator.field,
Expand Down Expand Up @@ -575,6 +573,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]:
return iterator

args: list[ValueExpr]
sorted_dims = sorted(iterator.dimensions)
if self.context.reduce_limit:
# we are visiting a child node of reduction, so the neighbor index can be used for indirect addressing
result_name = unique_var_name()
Expand All @@ -596,7 +595,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]:
# if dim is not found in iterator indices, we take the neighbor index over the reduction domain
flat_index = [
f"{iterator.indices[dim].data}_v" if dim in iterator.indices else index_name
for dim in sorted(iterator.dimensions)
for dim in sorted_dims
]
args = [ValueExpr(iterator.field, iterator.dtype)] + [
ValueExpr(iterator.indices[dim], iterator.dtype) for dim in iterator.indices
Expand Down Expand Up @@ -629,11 +628,9 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]:
return [ValueExpr(value=result_access, dtype=iterator.dtype)]

else:
sorted_index = sorted(iterator.indices.items(), key=lambda x: x[0])
flat_index = [
ValueExpr(x[1], iterator.dtype) for x in sorted_index if x[0] in iterator.dimensions
args = [ValueExpr(iterator.field, iterator.dtype)] + [
ValueExpr(iterator.indices[dim], iterator.dtype) for dim in sorted_dims
]
args = [ValueExpr(iterator.field, iterator.dtype), *flat_index]
internals = [f"{arg.value.data}_v" for arg in args]
expr = f"{internals[0]}[{', '.join(internals[1:])}]"
return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref")
Expand Down

0 comments on commit f96ead5

Please sign in to comment.