Skip to content

Commit

Permalink
[dace] Minor edit
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Oct 4, 2023
1 parent f0f9552 commit 4000005
Showing 1 changed file with 22 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -258,10 +258,24 @@ def visit_StencilClosure(
closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init")

# Add DaCe arrays for inputs, outputs and connectivities to closure SDFG.
input_transients_mapping = {}
for name in [*input_names, *conn_names, *output_names]:
assert name not in closure_sdfg.arrays or (name in input_names and name in output_names)
if name in closure_sdfg.arrays:
assert name in input_names and name in output_names
elif isinstance(self.storage_types[name], ts.FieldType):
closure_sdfg.add_array(
name,
shape=array_table[name].shape,
strides=array_table[name].strides,
dtype=array_table[name].dtype,
)
else:
assert isinstance(self.storage_types[name], ts.ScalarType)

# Create a copy of all input fields to transient arrays in order to be able to handle in/out fields,
# and let DaCe remove unnecessary data movements.
input_transients_mapping = {}
for name in input_names:
if isinstance(self.storage_types[name], ts.FieldType):
# in/out array, create transient for input read access to avoid race conditions
transient_name = unique_var_name()
closure_sdfg.add_array(
Expand All @@ -277,15 +291,6 @@ def visit_StencilClosure(
create_memlet_full(name, closure_sdfg.arrays[name]),
)
input_transients_mapping[name] = transient_name
elif isinstance(self.storage_types[name], ts.FieldType):
closure_sdfg.add_array(
name,
shape=array_table[name].shape,
strides=array_table[name].strides,
dtype=array_table[name].dtype,
)
else:
assert isinstance(self.storage_types[name], ts.ScalarType)

# Get output domain of the closure
program_arg_syms: dict[str, ValueExpr | IteratorExpr | SymbolExpr] = {}
Expand All @@ -312,7 +317,7 @@ def visit_StencilClosure(
# Map SDFG tasklet arguments to parameters
input_access_names = [
input_transients_mapping[input_name]
if input_name in output_names
if input_name in input_transients_mapping
else input_name
if isinstance(self.storage_types[input_name], ts.FieldType)
else cast(ValueExpr, program_arg_syms[input_name]).value.data
Expand All @@ -323,13 +328,10 @@ def visit_StencilClosure(
]
conn_memlets = [create_memlet_full(name, closure_sdfg.arrays[name]) for name in conn_names]

output_memlets = []
output_connectors_mapping = {}
# create and write to transient that is then copied back to actual output array to avoid aliasing of
# same memory in nested SDFG with different names
for output_name in output_names:
transient_name = unique_var_name()
output_connectors_mapping[transient_name] = output_name
output_connectors_mapping = {unique_var_name(): output_name for output_name in output_names}
# scan operator should always be the first function call in a closure
if is_scan(node.stencil):
assert len(output_connectors_mapping) == 1, "Scan does not support multiple outputs"
Expand Down Expand Up @@ -361,10 +363,10 @@ def visit_StencilClosure(

output_subset = "0"

for output_name in output_connectors_mapping.values():
output_memlets.append(
create_memlet_at(output_name, tuple(idx for idx in map_domain.keys()))
)
output_memlets = [
create_memlet_at(output_name, tuple(idx for idx in map_domain.keys()))
for output_name in output_connectors_mapping.values()
]

input_mapping = {param: arg for param, arg in zip(input_names, input_memlets)}
output_mapping = {param: arg_memlet for param, arg_memlet in zip(results, output_memlets)}
Expand Down

0 comments on commit 4000005

Please sign in to comment.