From 40000055fddf437a3e8b938e31d916ed8b1a9e61 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 4 Oct 2023 13:56:41 +0200 Subject: [PATCH] [dace] Minor edit --- .../runners/dace_iterator/itir_to_sdfg.py | 42 ++++++++++--------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 92d9b598ce..72cc38f4e0 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -258,10 +258,24 @@ def visit_StencilClosure( closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init") # Add DaCe arrays for inputs, outputs and connectivities to closure SDFG. - input_transients_mapping = {} for name in [*input_names, *conn_names, *output_names]: - assert name not in closure_sdfg.arrays or (name in input_names and name in output_names) if name in closure_sdfg.arrays: + assert name in input_names and name in output_names + elif isinstance(self.storage_types[name], ts.FieldType): + closure_sdfg.add_array( + name, + shape=array_table[name].shape, + strides=array_table[name].strides, + dtype=array_table[name].dtype, + ) + else: + assert isinstance(self.storage_types[name], ts.ScalarType) + + # Create a copy of all input fields to transient arrays in order to be able to handle in/out fields, + # and let DaCe remove unnecessary data movements. + input_transients_mapping = {} + for name in input_names: + if isinstance(self.storage_types[name], ts.FieldType): # in/out array, create transient for input read access to avoid race conditions transient_name = unique_var_name() closure_sdfg.add_array( @@ -277,15 +291,6 @@ def visit_StencilClosure( create_memlet_full(name, closure_sdfg.arrays[name]), ) input_transients_mapping[name] = transient_name - elif isinstance(self.storage_types[name], ts.FieldType): - closure_sdfg.add_array( - name, - shape=array_table[name].shape, - strides=array_table[name].strides, - dtype=array_table[name].dtype, - ) - else: - assert isinstance(self.storage_types[name], ts.ScalarType) # Get output domain of the closure program_arg_syms: dict[str, ValueExpr | IteratorExpr | SymbolExpr] = {} @@ -312,7 +317,7 @@ def visit_StencilClosure( # Map SDFG tasklet arguments to parameters input_access_names = [ input_transients_mapping[input_name] - if input_name in output_names + if input_name in input_transients_mapping else input_name if isinstance(self.storage_types[input_name], ts.FieldType) else cast(ValueExpr, program_arg_syms[input_name]).value.data @@ -323,13 +328,10 @@ def visit_StencilClosure( ] conn_memlets = [create_memlet_full(name, closure_sdfg.arrays[name]) for name in conn_names] - output_memlets = [] output_connectors_mapping = {} # create and write to transient that is then copied back to actual output array to avoid aliasing of # same memory in nested SDFG with different names - for output_name in output_names: - transient_name = unique_var_name() - output_connectors_mapping[transient_name] = output_name + output_connectors_mapping = {unique_var_name(): output_name for output_name in output_names} # scan operator should always be the first function call in a closure if is_scan(node.stencil): assert len(output_connectors_mapping) == 1, "Scan does not support multiple outputs" @@ -361,10 +363,10 @@ def visit_StencilClosure( output_subset = "0" - for output_name in output_connectors_mapping.values(): - output_memlets.append( - create_memlet_at(output_name, tuple(idx for idx in map_domain.keys())) - ) + output_memlets = [ + create_memlet_at(output_name, tuple(idx for idx in map_domain.keys())) + for output_name in output_connectors_mapping.values() + ] input_mapping = {param: arg for param, arg in zip(input_names, input_memlets)} output_mapping = {param: arg_memlet for param, arg_memlet in zip(results, output_memlets)}