Skip to content

Commit

Permalink
feat[next]: DaCe support for tuple returns (#1343)
Browse files Browse the repository at this point in the history
This PR adds support in DaCe backends for closures with tuple returns. The motivation is that tuple returns are used in icon4py stencils, although the internal expressions do not operate on tuples. Tuples are just a means to aggregate multiple outputs from one stencil. For that reason, this PR does not contain support for scan or conditional expressions with tuples.
  • Loading branch information
edopao authored Oct 20, 2023
1 parent f96ead5 commit d11246e
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 135 deletions.
200 changes: 98 additions & 102 deletions src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
create_memlet_at,
create_memlet_full,
filter_neighbor_tables,
flatten_list,
get_sorted_dims,
map_nested_sdfg_symbols,
unique_var_name,
Expand Down Expand Up @@ -124,6 +125,13 @@ def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset
raise NotImplementedError()
self.storage_types[name] = type_

def get_output_nodes(
    self, closure: itir.StencilClosure, context: Context
) -> dict[str, dace.nodes.AccessNode]:
    """Collect the access nodes that the closure's output expression refers to.

    The closure output is visited with a ``PythonTaskletCodegen`` built from
    this object's ``offset_provider`` and ``node_types`` plus the given
    ``context``. The visit result is flattened with ``flatten_list`` so that
    every element contributes one entry.

    Args:
        closure: Stencil closure whose ``output`` expression is translated.
        context: Translation context handed to the tasklet code generator.

    Returns:
        Mapping from each output's ``value.data`` name to its ``value`` node.
    """
    codegen = PythonTaskletCodegen(self.offset_provider, context, self.node_types)
    nodes_by_name: dict[str, dace.nodes.AccessNode] = {}
    for output_expr in flatten_list(codegen.visit(closure.output)):
        access_node = output_expr.value
        # Later entries with the same data name overwrite earlier ones,
        # exactly as a dict comprehension over the same sequence would.
        nodes_by_name[access_node.data] = access_node
    return nodes_by_name

def visit_FencilDefinition(self, node: itir.FencilDefinition):
program_sdfg = dace.SDFG(name=node.id)
last_state = program_sdfg.add_state("program_entry")
Expand All @@ -145,50 +153,29 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):

# Create a nested SDFG for all stencil closures.
for closure in node.closures:
assert isinstance(closure.output, itir.SymRef)

# filter out arguments with scalar type, because they are passed as symbols
input_names = [
str(inp.id)
for inp in closure.inputs
if isinstance(self.storage_types[inp.id], ts.FieldType)
]
connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables]
output_names = [str(closure.output.id)]

# Translate the closure and its stencil's body to an SDFG.
closure_sdfg = self.visit(closure, array_table=program_sdfg.arrays)
closure_sdfg, input_names, output_names = self.visit(
closure, array_table=program_sdfg.arrays
)

# Create a new state for the closure.
last_state = program_sdfg.add_state_after(last_state)

# Create memlets to transfer the program parameters
input_memlets = [
create_memlet_full(name, program_sdfg.arrays[name]) for name in input_names
]
connectivity_memlets = [
create_memlet_full(name, program_sdfg.arrays[name]) for name in connectivity_names
]
output_memlets = [
create_memlet_full(name, program_sdfg.arrays[name]) for name in output_names
]

input_mapping = {param: arg for param, arg in zip(input_names, input_memlets)}
connectivity_mapping = {
param: arg for param, arg in zip(connectivity_names, connectivity_memlets)
input_mapping = {
name: create_memlet_full(name, program_sdfg.arrays[name]) for name in input_names
}
output_mapping = {
param: arg_memlet for param, arg_memlet in zip(output_names, output_memlets)
name: create_memlet_full(name, program_sdfg.arrays[name]) for name in output_names
}

array_mapping = {**input_mapping, **connectivity_mapping}
symbol_mapping = map_nested_sdfg_symbols(program_sdfg, closure_sdfg, array_mapping)
symbol_mapping = map_nested_sdfg_symbols(program_sdfg, closure_sdfg, input_mapping)

# Insert the closure's SDFG as a nested SDFG of the program.
nsdfg_node = last_state.add_nested_sdfg(
sdfg=closure_sdfg,
parent=program_sdfg,
inputs=set(input_names) | set(connectivity_names),
inputs=set(input_names),
outputs=set(output_names),
symbol_mapping=symbol_mapping,
)
Expand All @@ -198,49 +185,78 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
access_node = last_state.add_access(inner_name)
last_state.add_edge(access_node, None, nsdfg_node, inner_name, memlet)

for inner_name, memlet in connectivity_mapping.items():
access_node = last_state.add_access(inner_name)
last_state.add_edge(access_node, None, nsdfg_node, inner_name, memlet)

for inner_name, memlet in output_mapping.items():
access_node = last_state.add_access(inner_name)
last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet)

program_sdfg.validate()
return program_sdfg

def visit_StencilClosure(
self, node: itir.StencilClosure, array_table: dict[str, dace.data.Array]
) -> dace.SDFG:
) -> tuple[dace.SDFG, list[str], list[str]]:
assert ItirToSDFG._check_no_lifts(node)
assert ItirToSDFG._check_shift_offsets_are_literals(node)
assert isinstance(node.output, itir.SymRef)

neighbor_tables = filter_neighbor_tables(self.offset_provider)
input_names = [str(inp.id) for inp in node.inputs]
conn_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables]
output_name = str(node.output.id)

# Create the closure's nested SDFG and single state.
closure_sdfg = dace.SDFG(name="closure")
closure_state = closure_sdfg.add_state("closure_entry")
closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init")

# Add DaCe arrays for inputs, output and connectivities to closure SDFG.
for name in [*input_names, *conn_names, output_name]:
assert name not in closure_sdfg.arrays or (name in input_names and name == output_name)
program_arg_syms: dict[str, ValueExpr | IteratorExpr | SymbolExpr] = {}
closure_ctx = Context(closure_sdfg, closure_state, program_arg_syms)
neighbor_tables = filter_neighbor_tables(self.offset_provider)

input_names = [str(inp.id) for inp in node.inputs]
conn_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables]

output_nodes = self.get_output_nodes(node, closure_ctx)
output_names = [k for k, _ in output_nodes.items()]

# Add DaCe arrays for inputs, outputs and connectivities to closure SDFG.
input_transients_mapping = {}
for name in [*input_names, *conn_names, *output_names]:
if name in closure_sdfg.arrays:
# in/out parameter, container already added for in parameter
continue
if isinstance(self.storage_types[name], ts.FieldType):
assert name in input_names and name in output_names
# In case of closures with in/out fields, there is a risk of a race condition
# between read/write access nodes in the (asynchronous) map tasklet.
transient_name = unique_var_name()
closure_sdfg.add_array(
transient_name,
shape=array_table[name].shape,
strides=array_table[name].strides,
dtype=array_table[name].dtype,
transient=True,
)
closure_init_state.add_nedge(
closure_init_state.add_access(name),
closure_init_state.add_access(transient_name),
create_memlet_full(name, closure_sdfg.arrays[name]),
)
input_transients_mapping[name] = transient_name
elif isinstance(self.storage_types[name], ts.FieldType):
closure_sdfg.add_array(
name,
shape=array_table[name].shape,
strides=array_table[name].strides,
dtype=array_table[name].dtype,
)
else:
assert isinstance(self.storage_types[name], ts.ScalarType)

# Get output domain of the closure
program_arg_syms: dict[str, ValueExpr | IteratorExpr | SymbolExpr] = {}
input_field_names = [
input_name
for input_name in input_names
if isinstance(self.storage_types[input_name], ts.FieldType)
]

# Closure outputs should all be fields
assert all(
isinstance(self.storage_types[output_name], ts.FieldType)
for output_name in output_names
)

# Update symbol table and get output domain of the closure
for name, type_ in self.storage_types.items():
if isinstance(type_, ts.ScalarType):
if name in input_names:
Expand All @@ -258,73 +274,64 @@ def visit_StencilClosure(
program_arg_syms[name] = value
else:
program_arg_syms[name] = SymbolExpr(name, as_dace_type(type_))
domain_ctx = Context(closure_sdfg, closure_state, program_arg_syms)
closure_domain = self._visit_domain(node.domain, domain_ctx)
closure_domain = self._visit_domain(node.domain, closure_ctx)

# Map SDFG tasklet arguments to parameters
input_access_names = [
input_name
if isinstance(self.storage_types[input_name], ts.FieldType)
input_transients_mapping[input_name]
if input_name in input_transients_mapping
else input_name
if input_name in input_field_names
else cast(ValueExpr, program_arg_syms[input_name]).value.data
for input_name in input_names
]
input_memlets = [
create_memlet_full(name, closure_sdfg.arrays[name]) for name in input_access_names
]
conn_memlet = [create_memlet_full(name, closure_sdfg.arrays[name]) for name in conn_names]
conn_memlets = [create_memlet_full(name, closure_sdfg.arrays[name]) for name in conn_names]

transient_to_arg_name_mapping = {}
# create and write to transient that is then copied back to actual output array to avoid aliasing of
# same memory in nested SDFG with different names
nsdfg_output_name = unique_var_name()
output_descriptor = closure_sdfg.arrays[output_name]
transient_to_arg_name_mapping[nsdfg_output_name] = output_name
output_connectors_mapping = {unique_var_name(): output_name for output_name in output_names}
# scan operator should always be the first function call in a closure
if is_scan(node.stencil):
assert len(output_connectors_mapping) == 1, "Scan does not support multiple outputs"
transient_name, output_name = next(iter(output_connectors_mapping.items()))

nsdfg, map_ranges, scan_dim_index = self._visit_scan_stencil_closure(
node, closure_sdfg.arrays, closure_domain, nsdfg_output_name
node, closure_sdfg.arrays, closure_domain, transient_name
)
results = [nsdfg_output_name]
results = [transient_name]

_, (scan_lb, scan_ub) = closure_domain[scan_dim_index]
output_subset = f"{scan_lb.value}:{scan_ub.value}"

closure_sdfg.add_array(
nsdfg_output_name,
dtype=output_descriptor.dtype,
shape=(output_descriptor.shape[scan_dim_index],),
strides=(output_descriptor.strides[scan_dim_index],),
transient=True,
)

output_memlet = create_memlet_at(
output_name,
tuple(
f"i_{dim}"
if f"i_{dim}" in map_ranges
else f"0:{output_descriptor.shape[scan_dim_index]}"
for dim, _ in closure_domain
),
)
output_memlets = [
create_memlet_at(
output_name,
tuple(
f"i_{dim}"
if f"i_{dim}" in map_ranges
else f"0:{closure_sdfg.arrays[output_name].shape[scan_dim_index]}"
for dim, _ in closure_domain
),
)
]
else:
nsdfg, map_ranges, results = self._visit_parallel_stencil_closure(
node, closure_sdfg.arrays, closure_domain
)
assert len(results) == 1

output_subset = "0"

closure_sdfg.add_scalar(
nsdfg_output_name,
dtype=output_descriptor.dtype,
transient=True,
)

output_memlet = create_memlet_at(output_name, tuple(idx for idx in map_ranges.keys()))
output_memlets = [
create_memlet_at(output_name, tuple(idx for idx in map_ranges.keys()))
for output_name in output_connectors_mapping.values()
]

input_mapping = {param: arg for param, arg in zip(input_names, input_memlets)}
output_mapping = {param: arg_memlet for param, arg_memlet in zip(results, [output_memlet])}
conn_mapping = {param: arg for param, arg in zip(conn_names, conn_memlet)}
output_mapping = {param: arg_memlet for param, arg_memlet in zip(results, output_memlets)}
conn_mapping = {param: arg for param, arg in zip(conn_names, conn_memlets)}

array_mapping = {**input_mapping, **conn_mapping}
symbol_mapping = map_nested_sdfg_symbols(closure_sdfg, nsdfg, array_mapping)
Expand All @@ -336,11 +343,12 @@ def visit_StencilClosure(
inputs=array_mapping,
outputs=output_mapping,
symbol_mapping=symbol_mapping,
output_nodes=output_nodes,
)
access_nodes = {edge.data.data: edge.dst for edge in closure_state.out_edges(map_exit)}
for edge in closure_state.in_edges(map_exit):
memlet = edge.data
if memlet.data not in transient_to_arg_name_mapping:
if memlet.data not in output_connectors_mapping:
continue
transient_access = closure_state.add_access(memlet.data)
closure_state.add_edge(
Expand All @@ -355,21 +363,9 @@ def visit_StencilClosure(
)
closure_state.add_edge(transient_access, None, map_exit, edge.dst_conn, inner_memlet)
closure_state.remove_edge(edge)
access_nodes[memlet.data].data = transient_to_arg_name_mapping[memlet.data]

for _, (lb, ub) in closure_domain:
for b in lb, ub:
if isinstance(b, SymbolExpr):
continue
map_entry.add_in_connector(b.value.data)
closure_state.add_edge(
b.value,
None,
map_entry,
b.value.data,
dace.Memlet.simple(b.value.data, "0"),
)
return closure_sdfg
access_nodes[memlet.data].data = output_connectors_mapping[memlet.data]

return closure_sdfg, input_field_names + conn_names, output_names

def _visit_scan_stencil_closure(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
connectivity_identifier,
create_memlet_full,
filter_neighbor_tables,
flatten_list,
map_nested_sdfg_symbols,
unique_name,
unique_var_name,
Expand Down Expand Up @@ -423,32 +424,36 @@ def visit_Lambda(
context.body.add_array(name, shape=shape, strides=strides, dtype=dtype)

# Translate the function's body
result: ValueExpr | SymbolExpr = self.visit(node.expr)[0]
# Forwarding result through a tasklet needed because empty SDFG states don't properly forward connectors
if isinstance(result, ValueExpr):
result_name = unique_var_name()
self.context.body.add_scalar(result_name, result.dtype, transient=True)
result_access = self.context.state.add_access(result_name)
self.context.state.add_edge(
result.value,
None,
result_access,
None,
# in case of reduction lambda, the output edge from lambda tasklet performs write-conflict resolution
dace.Memlet(f"{result_access.data}[0]", wcr=context.reduce_wcr),
)
result = ValueExpr(value=result_access, dtype=result.dtype)
else:
result = self.add_expr_tasklet([], result.value, result.dtype, "forward")[0]
self.context.body.arrays[result.value.data].transient = False
self.context = prev_context
results: list[ValueExpr] = []
# We are flattening the returned list of value expressions because the multiple outputs of a lambda
# should be a list of nodes without tuple structure. Ideally, an ITIR transformation could do this.
for expr in flatten_list(self.visit(node.expr)):
if isinstance(expr, ValueExpr):
result_name = unique_var_name()
self.context.body.add_scalar(result_name, expr.dtype, transient=True)
result_access = self.context.state.add_access(result_name)
self.context.state.add_edge(
expr.value,
None,
result_access,
None,
# in case of reduction lambda, the output edge from lambda tasklet performs write-conflict resolution
dace.Memlet(f"{result_access.data}[0]", wcr=context.reduce_wcr),
)
result = ValueExpr(value=result_access, dtype=expr.dtype)
else:
# Forwarding the result through a tasklet is needed because empty SDFG states don't properly forward connectors
result = self.add_expr_tasklet([], expr.value, expr.dtype, "forward")[0]
self.context.body.arrays[result.value.data].transient = False
results.append(result)

self.context = prev_context
for node in context.state.nodes():
if isinstance(node, dace.nodes.AccessNode):
if context.state.out_degree(node) == 0 and context.state.in_degree(node) == 0:
context.state.remove_node(node)

return context, inputs, [result]
return context, inputs, results

def visit_SymRef(self, node: itir.SymRef) -> list[ValueExpr | SymbolExpr] | IteratorExpr:
if node.id not in self.context.symbol_map:
Expand Down
Loading

0 comments on commit d11246e

Please sign in to comment.