From d11246ec828acd6b904dadb80ac535ebd21b5359 Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 20 Oct 2023 10:37:57 +0200 Subject: [PATCH] feat[next]: DaCe support for tuple returns (#1343) This PR adds support in DaCe backends for closures with tuple returns. The motivation is that tuple returns are used in icon4py stencils, although the internal expressions do not operate on tuples. Tuples are just a means to aggregate multiple outputs from one stencil. For that reason, this PR does not contain support for scan or conditional expressions with tuples. --- .../runners/dace_iterator/itir_to_sdfg.py | 200 +++++++++--------- .../runners/dace_iterator/itir_to_tasklet.py | 45 ++-- .../runners/dace_iterator/utility.py | 10 +- .../ffront_tests/test_execution.py | 5 +- .../ffront_tests/test_program.py | 3 - .../iterator_tests/test_tuple.py | 6 +- .../iterator_tests/test_column_stencil.py | 2 +- 7 files changed, 136 insertions(+), 135 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 7017815688..580486aa4a 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -38,6 +38,7 @@ create_memlet_at, create_memlet_full, filter_neighbor_tables, + flatten_list, get_sorted_dims, map_nested_sdfg_symbols, unique_var_name, @@ -124,6 +125,13 @@ def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset raise NotImplementedError() self.storage_types[name] = type_ + def get_output_nodes( + self, closure: itir.StencilClosure, context: Context + ) -> dict[str, dace.nodes.AccessNode]: + translator = PythonTaskletCodegen(self.offset_provider, context, self.node_types) + output_nodes = flatten_list(translator.visit(closure.output)) + return {node.value.data: node.value for node in output_nodes} + def visit_FencilDefinition(self, node: 
itir.FencilDefinition): program_sdfg = dace.SDFG(name=node.id) last_state = program_sdfg.add_state("program_entry") @@ -145,50 +153,29 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): # Create a nested SDFG for all stencil closures. for closure in node.closures: - assert isinstance(closure.output, itir.SymRef) - - # filter out arguments with scalar type, because they are passed as symbols - input_names = [ - str(inp.id) - for inp in closure.inputs - if isinstance(self.storage_types[inp.id], ts.FieldType) - ] - connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] - output_names = [str(closure.output.id)] - # Translate the closure and its stencil's body to an SDFG. - closure_sdfg = self.visit(closure, array_table=program_sdfg.arrays) + closure_sdfg, input_names, output_names = self.visit( + closure, array_table=program_sdfg.arrays + ) # Create a new state for the closure. last_state = program_sdfg.add_state_after(last_state) # Create memlets to transfer the program parameters - input_memlets = [ - create_memlet_full(name, program_sdfg.arrays[name]) for name in input_names - ] - connectivity_memlets = [ - create_memlet_full(name, program_sdfg.arrays[name]) for name in connectivity_names - ] - output_memlets = [ - create_memlet_full(name, program_sdfg.arrays[name]) for name in output_names - ] - - input_mapping = {param: arg for param, arg in zip(input_names, input_memlets)} - connectivity_mapping = { - param: arg for param, arg in zip(connectivity_names, connectivity_memlets) + input_mapping = { + name: create_memlet_full(name, program_sdfg.arrays[name]) for name in input_names } output_mapping = { - param: arg_memlet for param, arg_memlet in zip(output_names, output_memlets) + name: create_memlet_full(name, program_sdfg.arrays[name]) for name in output_names } - array_mapping = {**input_mapping, **connectivity_mapping} - symbol_mapping = map_nested_sdfg_symbols(program_sdfg, closure_sdfg, array_mapping) + 
symbol_mapping = map_nested_sdfg_symbols(program_sdfg, closure_sdfg, input_mapping) # Insert the closure's SDFG as a nested SDFG of the program. nsdfg_node = last_state.add_nested_sdfg( sdfg=closure_sdfg, parent=program_sdfg, - inputs=set(input_names) | set(connectivity_names), + inputs=set(input_names), outputs=set(output_names), symbol_mapping=symbol_mapping, ) @@ -198,49 +185,78 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): access_node = last_state.add_access(inner_name) last_state.add_edge(access_node, None, nsdfg_node, inner_name, memlet) - for inner_name, memlet in connectivity_mapping.items(): - access_node = last_state.add_access(inner_name) - last_state.add_edge(access_node, None, nsdfg_node, inner_name, memlet) - for inner_name, memlet in output_mapping.items(): access_node = last_state.add_access(inner_name) last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet) + program_sdfg.validate() return program_sdfg def visit_StencilClosure( self, node: itir.StencilClosure, array_table: dict[str, dace.data.Array] - ) -> dace.SDFG: + ) -> tuple[dace.SDFG, list[str], list[str]]: assert ItirToSDFG._check_no_lifts(node) assert ItirToSDFG._check_shift_offsets_are_literals(node) - assert isinstance(node.output, itir.SymRef) - - neighbor_tables = filter_neighbor_tables(self.offset_provider) - input_names = [str(inp.id) for inp in node.inputs] - conn_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] - output_name = str(node.output.id) # Create the closure's nested SDFG and single state. closure_sdfg = dace.SDFG(name="closure") closure_state = closure_sdfg.add_state("closure_entry") closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init") - # Add DaCe arrays for inputs, output and connectivities to closure SDFG. 
- for name in [*input_names, *conn_names, output_name]: - assert name not in closure_sdfg.arrays or (name in input_names and name == output_name) + program_arg_syms: dict[str, ValueExpr | IteratorExpr | SymbolExpr] = {} + closure_ctx = Context(closure_sdfg, closure_state, program_arg_syms) + neighbor_tables = filter_neighbor_tables(self.offset_provider) + + input_names = [str(inp.id) for inp in node.inputs] + conn_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] + + output_nodes = self.get_output_nodes(node, closure_ctx) + output_names = [k for k, _ in output_nodes.items()] + + # Add DaCe arrays for inputs, outputs and connectivities to closure SDFG. + input_transients_mapping = {} + for name in [*input_names, *conn_names, *output_names]: if name in closure_sdfg.arrays: - # in/out parameter, container already added for in parameter - continue - if isinstance(self.storage_types[name], ts.FieldType): + assert name in input_names and name in output_names + # In case of closures with in/out fields, there is risk of race condition + # between read/write access nodes in the (asynchronous) map tasklet. 
+ transient_name = unique_var_name() + closure_sdfg.add_array( + transient_name, + shape=array_table[name].shape, + strides=array_table[name].strides, + dtype=array_table[name].dtype, + transient=True, + ) + closure_init_state.add_nedge( + closure_init_state.add_access(name), + closure_init_state.add_access(transient_name), + create_memlet_full(name, closure_sdfg.arrays[name]), + ) + input_transients_mapping[name] = transient_name + elif isinstance(self.storage_types[name], ts.FieldType): closure_sdfg.add_array( name, shape=array_table[name].shape, strides=array_table[name].strides, dtype=array_table[name].dtype, ) + else: + assert isinstance(self.storage_types[name], ts.ScalarType) - # Get output domain of the closure - program_arg_syms: dict[str, ValueExpr | IteratorExpr | SymbolExpr] = {} + input_field_names = [ + input_name + for input_name in input_names + if isinstance(self.storage_types[input_name], ts.FieldType) + ] + + # Closure outputs should all be fields + assert all( + isinstance(self.storage_types[output_name], ts.FieldType) + for output_name in output_names + ) + + # Update symbol table and get output domain of the closure for name, type_ in self.storage_types.items(): if isinstance(type_, ts.ScalarType): if name in input_names: @@ -258,73 +274,64 @@ def visit_StencilClosure( program_arg_syms[name] = value else: program_arg_syms[name] = SymbolExpr(name, as_dace_type(type_)) - domain_ctx = Context(closure_sdfg, closure_state, program_arg_syms) - closure_domain = self._visit_domain(node.domain, domain_ctx) + closure_domain = self._visit_domain(node.domain, closure_ctx) # Map SDFG tasklet arguments to parameters input_access_names = [ - input_name - if isinstance(self.storage_types[input_name], ts.FieldType) + input_transients_mapping[input_name] + if input_name in input_transients_mapping + else input_name + if input_name in input_field_names else cast(ValueExpr, program_arg_syms[input_name]).value.data for input_name in input_names ] input_memlets = [ 
create_memlet_full(name, closure_sdfg.arrays[name]) for name in input_access_names ] - conn_memlet = [create_memlet_full(name, closure_sdfg.arrays[name]) for name in conn_names] + conn_memlets = [create_memlet_full(name, closure_sdfg.arrays[name]) for name in conn_names] - transient_to_arg_name_mapping = {} # create and write to transient that is then copied back to actual output array to avoid aliasing of # same memory in nested SDFG with different names - nsdfg_output_name = unique_var_name() - output_descriptor = closure_sdfg.arrays[output_name] - transient_to_arg_name_mapping[nsdfg_output_name] = output_name + output_connectors_mapping = {unique_var_name(): output_name for output_name in output_names} # scan operator should always be the first function call in a closure if is_scan(node.stencil): + assert len(output_connectors_mapping) == 1, "Scan does not support multiple outputs" + transient_name, output_name = next(iter(output_connectors_mapping.items())) + nsdfg, map_ranges, scan_dim_index = self._visit_scan_stencil_closure( - node, closure_sdfg.arrays, closure_domain, nsdfg_output_name + node, closure_sdfg.arrays, closure_domain, transient_name ) - results = [nsdfg_output_name] + results = [transient_name] _, (scan_lb, scan_ub) = closure_domain[scan_dim_index] output_subset = f"{scan_lb.value}:{scan_ub.value}" - closure_sdfg.add_array( - nsdfg_output_name, - dtype=output_descriptor.dtype, - shape=(output_descriptor.shape[scan_dim_index],), - strides=(output_descriptor.strides[scan_dim_index],), - transient=True, - ) - - output_memlet = create_memlet_at( - output_name, - tuple( - f"i_{dim}" - if f"i_{dim}" in map_ranges - else f"0:{output_descriptor.shape[scan_dim_index]}" - for dim, _ in closure_domain - ), - ) + output_memlets = [ + create_memlet_at( + output_name, + tuple( + f"i_{dim}" + if f"i_{dim}" in map_ranges + else f"0:{closure_sdfg.arrays[output_name].shape[scan_dim_index]}" + for dim, _ in closure_domain + ), + ) + ] else: nsdfg, map_ranges, 
results = self._visit_parallel_stencil_closure( node, closure_sdfg.arrays, closure_domain ) - assert len(results) == 1 output_subset = "0" - closure_sdfg.add_scalar( - nsdfg_output_name, - dtype=output_descriptor.dtype, - transient=True, - ) - - output_memlet = create_memlet_at(output_name, tuple(idx for idx in map_ranges.keys())) + output_memlets = [ + create_memlet_at(output_name, tuple(idx for idx in map_ranges.keys())) + for output_name in output_connectors_mapping.values() + ] input_mapping = {param: arg for param, arg in zip(input_names, input_memlets)} - output_mapping = {param: arg_memlet for param, arg_memlet in zip(results, [output_memlet])} - conn_mapping = {param: arg for param, arg in zip(conn_names, conn_memlet)} + output_mapping = {param: arg_memlet for param, arg_memlet in zip(results, output_memlets)} + conn_mapping = {param: arg for param, arg in zip(conn_names, conn_memlets)} array_mapping = {**input_mapping, **conn_mapping} symbol_mapping = map_nested_sdfg_symbols(closure_sdfg, nsdfg, array_mapping) @@ -336,11 +343,12 @@ def visit_StencilClosure( inputs=array_mapping, outputs=output_mapping, symbol_mapping=symbol_mapping, + output_nodes=output_nodes, ) access_nodes = {edge.data.data: edge.dst for edge in closure_state.out_edges(map_exit)} for edge in closure_state.in_edges(map_exit): memlet = edge.data - if memlet.data not in transient_to_arg_name_mapping: + if memlet.data not in output_connectors_mapping: continue transient_access = closure_state.add_access(memlet.data) closure_state.add_edge( @@ -355,21 +363,9 @@ def visit_StencilClosure( ) closure_state.add_edge(transient_access, None, map_exit, edge.dst_conn, inner_memlet) closure_state.remove_edge(edge) - access_nodes[memlet.data].data = transient_to_arg_name_mapping[memlet.data] - - for _, (lb, ub) in closure_domain: - for b in lb, ub: - if isinstance(b, SymbolExpr): - continue - map_entry.add_in_connector(b.value.data) - closure_state.add_edge( - b.value, - None, - map_entry, - 
b.value.data, - dace.Memlet.simple(b.value.data, "0"), - ) - return closure_sdfg + access_nodes[memlet.data].data = output_connectors_mapping[memlet.data] + + return closure_sdfg, input_field_names + conn_names, output_names def _visit_scan_stencil_closure( self, diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 610698646a..b28703feef 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -36,6 +36,7 @@ connectivity_identifier, create_memlet_full, filter_neighbor_tables, + flatten_list, map_nested_sdfg_symbols, unique_name, unique_var_name, @@ -423,32 +424,36 @@ def visit_Lambda( context.body.add_array(name, shape=shape, strides=strides, dtype=dtype) # Translate the function's body - result: ValueExpr | SymbolExpr = self.visit(node.expr)[0] - # Forwarding result through a tasklet needed because empty SDFG states don't properly forward connectors - if isinstance(result, ValueExpr): - result_name = unique_var_name() - self.context.body.add_scalar(result_name, result.dtype, transient=True) - result_access = self.context.state.add_access(result_name) - self.context.state.add_edge( - result.value, - None, - result_access, - None, - # in case of reduction lambda, the output edge from lambda tasklet performs write-conflict resolution - dace.Memlet(f"{result_access.data}[0]", wcr=context.reduce_wcr), - ) - result = ValueExpr(value=result_access, dtype=result.dtype) - else: - result = self.add_expr_tasklet([], result.value, result.dtype, "forward")[0] - self.context.body.arrays[result.value.data].transient = False - self.context = prev_context + results: list[ValueExpr] = [] + # We are flattening the returned list of value expressions because the multiple outputs of a lambda + should be a list of nodes without tuple structure. 
Ideally, an ITIR transformation could do this. + for expr in flatten_list(self.visit(node.expr)): + if isinstance(expr, ValueExpr): + result_name = unique_var_name() + self.context.body.add_scalar(result_name, expr.dtype, transient=True) + result_access = self.context.state.add_access(result_name) + self.context.state.add_edge( + expr.value, + None, + result_access, + None, + # in case of reduction lambda, the output edge from lambda tasklet performs write-conflict resolution + dace.Memlet(f"{result_access.data}[0]", wcr=context.reduce_wcr), + ) + result = ValueExpr(value=result_access, dtype=expr.dtype) + else: + # Forwarding result through a tasklet needed because empty SDFG states don't properly forward connectors + result = self.add_expr_tasklet([], expr.value, expr.dtype, "forward")[0] + self.context.body.arrays[result.value.data].transient = False + results.append(result) + self.context = prev_context for node in context.state.nodes(): if isinstance(node, dace.nodes.AccessNode): if context.state.out_degree(node) == 0 and context.state.in_degree(node) == 0: context.state.remove_node(node) - return context, inputs, [result] + return context, inputs, results def visit_SymRef(self, node: itir.SymRef) -> list[ValueExpr | SymbolExpr] | IteratorExpr: if node.id not in self.context.symbol_map: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 7e6fe13ac7..1fdd022a49 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -11,7 +11,7 @@ # distribution for a copy of the license or check . 
# # SPDX-License-Identifier: GPL-3.0-or-later - +import itertools from typing import Any, Sequence import dace @@ -166,3 +166,11 @@ def unique_name(prefix): def unique_var_name(): return unique_name("__var") + + +def flatten_list(node_list: list[Any]) -> list[Any]: + return list( + itertools.chain.from_iterable( + [flatten_list(e) if e.__class__ == list else [e] for e in node_list] + ) + ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 865950eeab..61b34460ef 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -159,7 +159,6 @@ def testee(a: cases.IJKField, b: cases.IJKField) -> cases.IJKField: cases.verify(cartesian_case, testee, a, b, out=out, ref=a.ndarray[1:] + b.ndarray[2:]) -@pytest.mark.uses_tuple_returns def test_tuples(cartesian_case): # noqa: F811 # fixtures @gtx.field_operator def testee(a: cases.IJKFloatField, b: cases.IJKFloatField) -> cases.IJKFloatField: @@ -400,7 +399,6 @@ def testee(a: cases.IKField, offset_field: cases.IKField) -> gtx.Field[[IDim, KD assert np.allclose(out, ref) -@pytest.mark.uses_tuple_returns def test_nested_tuple_return(cartesian_case): @gtx.field_operator def pack_tuple( @@ -476,7 +474,7 @@ def testee(a: cases.EField, b: cases.EField) -> tuple[cases.VField, cases.VField ) -@pytest.mark.uses_tuple_returns +@pytest.mark.uses_constant_fields def test_tuple_with_local_field_in_reduction_shifted(unstructured_case): @gtx.field_operator def reduce_tuple_element(e: cases.EField, v: cases.VField) -> cases.EField: @@ -840,7 +838,6 @@ def program_domain( ) -@pytest.mark.uses_tuple_returns def test_domain_tuple(cartesian_case): @gtx.field_operator def fieldop_domain_tuple( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py 
b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index f489126fa7..d86bc21679 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -128,7 +128,6 @@ def fo_from_fo_program(in_field: cases.IFloatField, out: cases.IFloatField): ) -@pytest.mark.uses_tuple_returns def test_tuple_program_return_constructed_inside(cartesian_case): @gtx.field_operator def pack_tuple( @@ -155,7 +154,6 @@ def prog( assert np.allclose((a, b), (out_a, out_b)) -@pytest.mark.uses_tuple_returns def test_tuple_program_return_constructed_inside_with_slicing(cartesian_case): @gtx.field_operator def pack_tuple( @@ -183,7 +181,6 @@ def prog( assert out_a[0] == 0 and out_b[0] == 0 -@pytest.mark.uses_tuple_returns def test_tuple_program_return_constructed_inside_nested(cartesian_case): @gtx.field_operator def pack_tuple( diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py index bd5a717bb2..67b439507c 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py @@ -148,7 +148,6 @@ def stencil(inp1, inp2, inp3, inp4): "stencil", [tuple_output1, tuple_output2], ) -@pytest.mark.uses_tuple_returns def test_tuple_of_field_output_constructed_inside(program_processor, stencil): program_processor, validate = program_processor @@ -194,7 +193,6 @@ def fencil(size0, size1, size2, inp1, inp2, out1, out2): assert np.allclose(inp2, out2) -@pytest.mark.uses_tuple_returns def test_asymetric_nested_tuple_of_field_output_constructed_inside(program_processor): program_processor, validate = program_processor @@ -288,7 +286,7 @@ def tuple_input(inp): return tuple_get(0, inp_deref) + tuple_get(1, inp_deref) -@pytest.mark.uses_tuple_returns 
+@pytest.mark.uses_tuple_args def test_tuple_field_input(program_processor): program_processor, validate = program_processor @@ -348,7 +346,7 @@ def tuple_tuple_input(inp): ) -@pytest.mark.uses_tuple_returns +@pytest.mark.uses_tuple_args def test_tuple_of_tuple_of_field_input(program_processor): program_processor, validate = program_processor diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index 41d6c8f0f9..04cf8c6f9c 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -149,7 +149,7 @@ def k_level_condition_upper_tuple(k_idx, k_level): ), ], ) -@pytest.mark.uses_tuple_returns +@pytest.mark.uses_tuple_args def test_k_level_condition(program_processor, lift_mode, fun, k_level, inp_function, ref_function): program_processor, validate = program_processor