diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 4f93777215..56031d8555 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -32,6 +32,7 @@ is_scan, ) from .utility import ( + add_mapped_nested_sdfg, as_dace_type, connectivity_identifier, create_memlet_at, @@ -321,7 +322,7 @@ def visit_StencilClosure( array_mapping = {**input_mapping, **conn_mapping} symbol_mapping = map_nested_sdfg_symbols(closure_sdfg, nsdfg, array_mapping) - nsdfg_node, map_entry, map_exit = self._add_mapped_nested_sdfg( + nsdfg_node, map_entry, map_exit = add_mapped_nested_sdfg( closure_state, sdfg=nsdfg, map_ranges=map_domain or {"__dummy": "0"}, @@ -584,76 +585,6 @@ def _visit_parallel_stencil_closure( return context.body, map_domain, [r.value.data for r in results] - def _add_mapped_nested_sdfg( - self, - state: dace.SDFGState, - map_ranges: dict[str, str | dace.subsets.Subset] - | list[tuple[str, str | dace.subsets.Subset]], - inputs: dict[str, dace.Memlet], - outputs: dict[str, dace.Memlet], - sdfg: dace.SDFG, - symbol_mapping: dict[str, Any] | None = None, - schedule: Any = dace.dtypes.ScheduleType.Default, - unroll_map: bool = False, - location: Any = None, - debuginfo: Any = None, - input_nodes: dict[str, dace.nodes.AccessNode] | None = None, - output_nodes: dict[str, dace.nodes.AccessNode] | None = None, - ) -> tuple[dace.nodes.NestedSDFG, dace.nodes.MapEntry, dace.nodes.MapExit]: - if not symbol_mapping: - symbol_mapping = {sym: sym for sym in sdfg.free_symbols} - - nsdfg_node = state.add_nested_sdfg( - sdfg, - None, - set(inputs.keys()), - set(outputs.keys()), - symbol_mapping, - name=sdfg.name, - schedule=schedule, - location=location, - debuginfo=debuginfo, - ) - - map_entry, map_exit = state.add_map( - f"{sdfg.name}_map", map_ranges, schedule, unroll_map, debuginfo - ) - - if input_nodes is None: - input_nodes = { - memlet.data: state.add_access(memlet.data) for name, memlet in inputs.items() - } - if output_nodes is None: - output_nodes = { - memlet.data: state.add_access(memlet.data) for name, memlet in outputs.items() - } - if not inputs: - state.add_edge(map_entry, None, nsdfg_node, None, dace.Memlet()) - for name, memlet in inputs.items(): - state.add_memlet_path( - input_nodes[memlet.data], - map_entry, - nsdfg_node, - memlet=memlet, - src_conn=None, - dst_conn=name, - propagate=True, - ) - if not outputs: - state.add_edge(nsdfg_node, None, map_exit, None, dace.Memlet()) - for name, memlet in outputs.items(): - state.add_memlet_path( - nsdfg_node, - map_exit, - output_nodes[memlet.data], - memlet=memlet, - src_conn=name, - dst_conn=None, - propagate=True, - ) - - return nsdfg_node, map_entry, map_exit - def _visit_domain( self, node: itir.FunCall, context: Context ) -> tuple[tuple[str, tuple[ValueExpr, ValueExpr]], ...]: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 875a23353b..56d66e6436 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -19,6 +19,8 @@ import dace import numpy as np +from dace.transformation.dataflow import MapFusion +from dace.transformation.passes.prune_symbols import RemoveUnusedSymbols import gt4py.eve.codegen 
from gt4py.next import Dimension, type_inference as next_typing @@ -29,12 +31,14 @@ from gt4py.next.type_system import type_specifications as ts from .utility import ( + add_mapped_nested_sdfg, as_dace_type, connectivity_identifier, create_memlet_at, create_memlet_full, filter_neighbor_tables, map_nested_sdfg_symbols, + unique_name, unique_var_name, ) @@ -56,6 +60,21 @@ def itir_type_as_dace_type(type_: next_typing.Type): raise NotImplementedError() +def get_reduce_identity_value(op_name_: str, type_: Any): + if op_name_ == "plus": + init_value = type_(0) + elif op_name_ == "multiplies": + init_value = type_(1) + elif op_name_ == "minimum": + init_value = type_("inf") + elif op_name_ == "maximum": + init_value = type_("-inf") + else: + raise NotImplementedError() + + return init_value + + _MATH_BUILTINS_MAPPING = { "abs": "abs({})", "sin": "math.sin({})", @@ -135,6 +154,21 @@ class Context: body: dace.SDFG state: dace.SDFGState symbol_map: dict[str, IteratorExpr | ValueExpr | SymbolExpr] + # if we encounter a reduction node, the reduction state needs to be pushed to child nodes + reduce_limit: int + reduce_wcr: Optional[str] + + def __init__( + self, + body: dace.SDFG, + state: dace.SDFGState, + symbol_map: dict[str, IteratorExpr | ValueExpr | SymbolExpr], + ): + self.body = body + self.state = state + self.symbol_map = symbol_map + self.reduce_limit = 0 + self.reduce_wcr = None def builtin_neighbors( @@ -166,13 +200,15 @@ def builtin_neighbors( table_name = connectivity_identifier(offset_dim) table_array = sdfg.arrays[table_name] + # generate unique map index name to avoid conflict with other maps inside same state + index_name = unique_name("__neigh_idx") me, mx = state.add_map( f"{offset_dim}_neighbors_map", - ndrange={"neigh_idx": f"0:{table.max_neighbors}"}, + ndrange={index_name: f"0:{table.max_neighbors}"}, ) shift_tasklet = state.add_tasklet( "shift", - code="__result = __table[__idx, neigh_idx]", + code=f"__result = __table[__idx, {index_name}]", inputs={"__table", "__idx"}, outputs={"__result"}, ) @@ -226,7 +262,7 @@ def builtin_neighbors( data_access_tasklet, mx, result_access, - memlet=dace.Memlet(data=result_name, subset="neigh_idx"), + memlet=dace.Memlet(data=result_name, subset=index_name), src_conn="__result", ) @@ -348,6 +384,8 @@ def visit_Lambda( value = IteratorExpr(field, indices, arg.dtype, arg.dimensions) symbol_map[param] = value context = Context(context_sdfg, context_state, symbol_map) + context.reduce_limit = prev_context.reduce_limit + context.reduce_wcr = prev_context.reduce_wcr self.context = context # Add input parameters as arrays @@ -394,7 +432,12 @@ def visit_Lambda( self.context.body.add_scalar(result_name, result.dtype, transient=True) result_access = self.context.state.add_access(result_name) self.context.state.add_edge( - result.value, None, result_access, None, dace.Memlet(f"{result.value.data}[0]") + result.value, + None, + result_access, + None, + # in case of reduction lambda, the output edge from lambda tasklet performs write-conflict resolution + dace.Memlet(f"{result_access.data}[0]", wcr=context.reduce_wcr), ) result = ValueExpr(value=result_access, dtype=result.dtype) else: @@ -530,15 +573,71 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: if not isinstance(iterator, IteratorExpr): # already a list of ValueExpr return iterator - sorted_index = sorted(iterator.indices.items(), key=lambda x: x[0]) - flat_index = [ - ValueExpr(x[1], iterator.dtype) for x in sorted_index if x[0] in iterator.dimensions - ] - args: list[ValueExpr] 
= [ValueExpr(iterator.field, iterator.dtype), *flat_index] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[0]}[{', '.join(internals[1:])}]" - return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref") + args: list[ValueExpr] + if self.context.reduce_limit: + # we are visiting a child node of reduction, so the neighbor index can be used for indirect addressing + result_name = unique_var_name() + self.context.body.add_array( + result_name, + dtype=iterator.dtype, + shape=(self.context.reduce_limit,), + transient=True, + ) + result_access = self.context.state.add_access(result_name) + + # generate unique map index name to avoid conflict with other maps inside same state + index_name = unique_name("__deref_idx") + me, mx = self.context.state.add_map( + "deref_map", + ndrange={index_name: f"0:{self.context.reduce_limit}"}, + ) + + # if dim is not found in iterator indices, we take the neighbor index over the reduction domain + array_index = [ + f"{iterator.indices[dim].data}_v" if dim in iterator.indices else index_name + for dim in sorted(iterator.dimensions) + ] + args = [ValueExpr(iterator.field, iterator.dtype)] + [ + ValueExpr(iterator.indices[dim], iterator.dtype) for dim in iterator.indices + ] + internals = [f"{arg.value.data}_v" for arg in args] + + deref_tasklet = self.context.state.add_tasklet( + name="deref", + inputs=set(internals), + outputs={"__result"}, + code=f"__result = {args[0].value.data}_v[{', '.join(array_index)}]", + ) + + for arg, internal in zip(args, internals): + input_memlet = create_memlet_full( + arg.value.data, self.context.body.arrays[arg.value.data] + ) + self.context.state.add_memlet_path( + arg.value, me, deref_tasklet, memlet=input_memlet, dst_conn=internal + ) + + self.context.state.add_memlet_path( + deref_tasklet, + mx, + result_access, + memlet=dace.Memlet(data=result_name, subset=index_name), + src_conn="__result", + ) + + return [ValueExpr(value=result_access, dtype=iterator.dtype)] + + else: + sorted_index = sorted(iterator.indices.items(), key=lambda x: x[0]) + flat_index = [ + ValueExpr(x[1], iterator.dtype) for x in sorted_index if x[0] in iterator.dimensions + ] + + args = [ValueExpr(iterator.field, iterator.dtype), *flat_index] + internals = [f"{arg.value.data}_v" for arg in args] + expr = f"{internals[0]}[{', '.join(internals[1:])}]" + return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref") def _split_shift_args( self, args: list[itir.Expr] ) @@ -625,47 +724,156 @@ def _visit_indirect_addressing(self, node: itir.FunCall) -> IteratorExpr: return IteratorExpr(iterator.field, shifted_index, iterator.dtype, iterator.dimensions) def _visit_reduce(self, node: itir.FunCall): - assert ( - isinstance(node.args[0], itir.FunCall) - and isinstance(node.args[0].fun, itir.SymRef) - and node.args[0].fun.id == "neighbors" - ) - args = self.visit(node.args) - assert len(args) == 1 - args = args[0] - assert len(args) == 1 - assert isinstance(node.fun, itir.FunCall) - op_name = node.fun.args[0] - assert isinstance(op_name, itir.SymRef) - init = node.fun.args[1] - - nreduce = self.context.body.arrays[args[0].value.data].shape[0] - result_name = unique_var_name() result_access = self.context.state.add_access(result_name) - self.context.body.add_scalar(result_name, args[0].dtype, transient=True) - op_str = _MATH_BUILTINS_MAPPING[str(op_name)].format("__result", "__values[__idx]") - reduce_tasklet = self.context.state.add_tasklet( - "reduce", - code=f"__result = {init}\nfor __idx in
range({nreduce}):\n __result = {op_str}", - inputs={"__values"}, - outputs={"__result"}, - ) - self.context.state.add_edge( - args[0].value, - None, - reduce_tasklet, - "__values", - dace.Memlet(data=args[0].value.data, subset=f"0:{nreduce}"), - ) - self.context.state.add_edge( - reduce_tasklet, - "__result", - result_access, - None, - dace.Memlet(data=result_name, subset="0"), - ) - return [ValueExpr(result_access, args[0].dtype)] + + if len(node.args) == 1: + assert ( + isinstance(node.args[0], itir.FunCall) + and isinstance(node.args[0].fun, itir.SymRef) + and node.args[0].fun.id == "neighbors" + ) + args = self.visit(node.args) + assert len(args) == 1 + args = args[0] + assert len(args) == 1 + neighbors_expr = args[0] + result_dtype = neighbors_expr.dtype + assert isinstance(node.fun, itir.FunCall) + op_name = node.fun.args[0] + assert isinstance(op_name, itir.SymRef) + init = node.fun.args[1] + + nreduce = self.context.body.arrays[neighbors_expr.value.data].shape[0] + + self.context.body.add_scalar(result_name, result_dtype, transient=True) + op_str = _MATH_BUILTINS_MAPPING[str(op_name)].format("__result", "__values[__idx]") + reduce_tasklet = self.context.state.add_tasklet( + "reduce", + code=f"__result = {init}\nfor __idx in range({nreduce}):\n __result = {op_str}", + inputs={"__values"}, + outputs={"__result"}, + ) + self.context.state.add_edge( + neighbors_expr.value, + None, + reduce_tasklet, + "__values", + dace.Memlet(data=neighbors_expr.value.data, subset=f"0:{nreduce}"), + ) + self.context.state.add_edge( + reduce_tasklet, + "__result", + result_access, + None, + dace.Memlet(data=result_name, subset="0"), + ) + else: + assert isinstance(node.fun, itir.FunCall) + assert isinstance(node.fun.args[0], itir.Lambda) + fun_node = node.fun.args[0] + + args = [] + for node_arg in node.args: + if ( + isinstance(node_arg, itir.FunCall) + and isinstance(node_arg.fun, itir.SymRef) + and node_arg.fun.id == "neighbors" + ): + expr = self.visit(node_arg) + args.append(*expr) + else: + args.append(None) + + # first visit only the arguments for neighbor selection; all other arguments are None + neighbor_args = [arg for arg in args if arg] + + # check that all neighbors expressions have the same range + assert ( + len( + set([self.context.body.arrays[expr.value.data].shape for expr in neighbor_args]) + ) + == 1 + ) + + nreduce = self.context.body.arrays[neighbor_args[0].value.data].shape[0] + nreduce_domain = {"__idx": f"0:{nreduce}"} + + result_dtype = neighbor_args[0].dtype + self.context.body.add_scalar(result_name, result_dtype, transient=True) + + assert isinstance(fun_node.expr, itir.FunCall) + op_name = fun_node.expr.fun + assert isinstance(op_name, itir.SymRef) + + # initialize the reduction result based on the type of operation + init_value = get_reduce_identity_value(op_name.id, result_dtype) + init_state = self.context.body.add_state_before(self.context.state, "init") + init_tasklet = init_state.add_tasklet( + "init_reduce", {}, {"__out"}, f"__out = {init_value}" + ) + init_state.add_edge( + init_tasklet, + "__out", + init_state.add_access(result_name), + None, + dace.Memlet.simple(result_name, "0"), + ) + + # set reduction state to enable dereference of neighbors in input fields and to set WCR on reduce tasklet + self.context.reduce_limit = nreduce + self.context.reduce_wcr = "lambda x, y: " + _MATH_BUILTINS_MAPPING[str(op_name)].format( + "x", "y" + ) + + # visit child nodes for input arguments + for i, node_arg in enumerate(node.args): + if not args[i]: + args[i] = self.visit(node_arg)[0] + +
lambda_node = itir.Lambda(expr=fun_node.expr.args[1], params=fun_node.params[1:]) + lambda_context, inner_inputs, inner_outputs = self.visit(lambda_node, args=args) + + # clear context + self.context.reduce_limit = 0 + self.context.reduce_wcr = None + + # the connectivity arrays (neighbor tables) are not needed inside the reduce lambda SDFG + neighbor_tables = filter_neighbor_tables(self.offset_provider) + for conn, _ in neighbor_tables: + var = connectivity_identifier(conn) + lambda_context.body.remove_data(var) + # clean up symbols previously used for shape and stride of connectivity arrays + p = RemoveUnusedSymbols() + p.apply_pass(lambda_context.body, {}) + + input_memlets = [ + create_memlet_at(expr.value.data, ("__idx",)) for arg, expr in zip(node.args, args) + ] + output_memlet = dace.Memlet.simple(result_name, "0") + + input_mapping = {param: arg for (param, _), arg in zip(inner_inputs, input_memlets)} + output_mapping = {inner_outputs[0].value.data: output_memlet} + symbol_mapping = map_nested_sdfg_symbols( + self.context.body, lambda_context.body, input_mapping + ) + + nsdfg_node, map_entry, _ = add_mapped_nested_sdfg( + self.context.state, + sdfg=lambda_context.body, + map_ranges=nreduce_domain, + inputs=input_mapping, + outputs=output_mapping, + symbol_mapping=symbol_mapping, + input_nodes={arg.value.data: arg.value for arg in args}, + output_nodes={result_name: result_access}, + ) + + # we apply map fusion only to the nested SDFG generated for the reduction operator; + # the purpose is to keep the ITIR visitor code simple and to clean up the generated SDFG + self.context.body.apply_transformations_repeated([MapFusion], validate=False) + + return [ValueExpr(result_access, result_dtype)] def _visit_numeric_builtin(self, node: itir.FunCall) -> list[ValueExpr]: assert isinstance(node.fun, itir.SymRef) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 85b1445dd9..889a1ab150 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -81,10 +81,83 @@ def map_nested_sdfg_symbols( return symbol_mapping +def add_mapped_nested_sdfg( + state: dace.SDFGState, + map_ranges: dict[str, str | dace.subsets.Subset] | list[tuple[str, str | dace.subsets.Subset]], + inputs: dict[str, dace.Memlet], + outputs: dict[str, dace.Memlet], + sdfg: dace.SDFG, + symbol_mapping: dict[str, Any] | None = None, + schedule: Any = dace.dtypes.ScheduleType.Default, + unroll_map: bool = False, + location: Any = None, + debuginfo: Any = None, + input_nodes: dict[str, dace.nodes.AccessNode] | None = None, + output_nodes: dict[str, dace.nodes.AccessNode] | None = None, +) -> tuple[dace.nodes.NestedSDFG, dace.nodes.MapEntry, dace.nodes.MapExit]: + if not symbol_mapping: + symbol_mapping = {sym: sym for sym in sdfg.free_symbols} + + nsdfg_node = state.add_nested_sdfg( + sdfg, + None, + set(inputs.keys()), + set(outputs.keys()), + symbol_mapping, + name=sdfg.name, + schedule=schedule, + location=location, + debuginfo=debuginfo, + ) + + map_entry, map_exit = state.add_map( + f"{sdfg.name}_map", map_ranges, schedule, unroll_map, debuginfo + ) + + if input_nodes is None: + input_nodes = { + memlet.data: state.add_access(memlet.data) for name, memlet in inputs.items() + } + if output_nodes is None: + output_nodes = { + memlet.data: state.add_access(memlet.data) for name, memlet in outputs.items() + } + if not inputs:
+ state.add_edge(map_entry, None, nsdfg_node, None, dace.Memlet()) + for name, memlet in inputs.items(): + state.add_memlet_path( + input_nodes[memlet.data], + map_entry, + nsdfg_node, + memlet=memlet, + src_conn=None, + dst_conn=name, + propagate=True, + ) + if not outputs: + state.add_edge(nsdfg_node, None, map_exit, None, dace.Memlet()) + for name, memlet in outputs.items(): + state.add_memlet_path( + nsdfg_node, + map_exit, + output_nodes[memlet.data], + memlet=memlet, + src_conn=name, + dst_conn=None, + propagate=True, + ) + + return nsdfg_node, map_entry, map_exit + + _unique_id = 0 -def unique_var_name(): +def unique_name(prefix): global _unique_id _unique_id += 1 - return f"__var_{_unique_id}" + return f"{prefix}_{_unique_id}" + + +def unique_var_name(): + return unique_name("__var") diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index f2c8525346..7f2b11afff 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -28,9 +28,6 @@ def test_external_local_field(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: reductions over non-field expressions") - @gtx.field_operator def testee( inp: gtx.Field[[Vertex, V2EDim], int32], ones: gtx.Field[[Edge], int32] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 5f19311a32..c10eb533f9 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -92,9 +92,7 @@ def fencil(edge_f: cases.EField, out: cases.VField): def test_reduction_expression_in_call(unstructured_case): if unstructured_case.backend == dace_iterator.run_dace_iterator: - # -edge_f(V2E) * tmp_nbh * 2 gets inlined with the neighbor_sum operation in the reduction in itir, - # so in addition to the skipped reason, currently itir is a lambda instead of the 'plus' operation - pytest.skip("Not supported in DaCe backend: Reductions not directly on a field.") + pytest.xfail("Not supported in DaCe backend: make_const_list") @gtx.field_operator def reduce_expr(edge_f: cases.EField) -> cases.VField: @@ -115,9 +113,6 @@ def fencil(edge_f: cases.EField, out: cases.VField): def test_reduction_with_common_expression(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.skip("Not supported in DaCe backend: Reductions not directly on a field.") - @gtx.field_operator def testee(flux: cases.EField) -> cases.VField: return neighbor_sum(flux(V2E) + flux(V2E), axis=V2EDim) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index b0d04d4379..20d1be25f5 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -93,8 +93,8 @@ def sum_edges_to_vertices_reduce(in_edges): "stencil", [sum_edges_to_vertices, 
sum_edges_to_vertices_list_get_neighbors, sum_edges_to_vertices_reduce], ) -def test_sum_edges_to_vertices(program_processor_no_dace_exec, lift_mode, stencil): - program_processor, validate = program_processor_no_dace_exec +def test_sum_edges_to_vertices(program_processor, lift_mode, stencil): + program_processor, validate = program_processor inp = edge_index_field() out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) ref = np.asarray(list(sum(row) for row in v2e_arr)) @@ -116,10 +116,8 @@ def map_neighbors(in_edges): return reduce(plus, 0)(map_(plus)(neighbors(V2E, in_edges), neighbors(V2E, in_edges))) -def test_map_neighbors(program_processor_no_gtfn_exec, lift_mode): - program_processor, validate = program_processor_no_gtfn_exec - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: map_ builtin, neighbors, reduce") +def test_map_neighbors(program_processor, lift_mode): + program_processor, validate = program_processor inp = edge_index_field() out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) ref = 2 * np.sum(v2e_arr, axis=1) @@ -144,9 +142,7 @@ def map_make_const_list(in_edges): def test_map_make_const_list(program_processor_no_gtfn_exec, lift_mode): program_processor, validate = program_processor_no_gtfn_exec if program_processor == run_dace_iterator: - pytest.xfail( - "Not supported in DaCe backend: map_ builtin, neighbors, reduce, make_const_list" - ) + pytest.xfail("Not supported in DaCe backend: make_const_list") inp = edge_index_field() out = gtx.np_as_located_field(Vertex)(np.zeros([9], inp.dtype)) ref = 2 * np.sum(v2e_arr, axis=1) @@ -168,10 +164,8 @@ def first_vertex_neigh_of_first_edge_neigh_of_cells(in_vertices): return deref(shift(E2V, 0)(shift(C2E, 0)(in_vertices))) -def test_first_vertex_neigh_of_first_edge_neigh_of_cells_fencil( - program_processor_no_dace_exec, lift_mode -): - program_processor, validate = program_processor_no_dace_exec +def test_first_vertex_neigh_of_first_edge_neigh_of_cells_fencil(program_processor, lift_mode): + program_processor, validate = program_processor inp = vertex_index_field() out = gtx.np_as_located_field(Cell)(np.zeros([9], dtype=inp.dtype)) ref = np.asarray(list(v2e_arr[c[0]][0] for c in c2e_arr)) @@ -196,10 +190,10 @@ def sparse_stencil(non_sparse, inp): return reduce(lambda a, b, c: a + c, 0)(neighbors(V2E, non_sparse), deref(inp)) -def test_sparse_input_field(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +def test_sparse_input_field(program_processor, lift_mode): + program_processor, validate = program_processor - non_sparse = gtx.np_as_located_field(Edge)(np.zeros(18)) + non_sparse = gtx.np_as_located_field(Edge)(np.zeros(18, dtype=np.int32)) inp = gtx.np_as_located_field(Vertex, V2EDim)(np.asarray([[1, 2, 3, 4]] * 9, dtype=np.int32)) out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) @@ -219,10 +213,10 @@ def test_sparse_input_field(program_processor_no_dace_exec, lift_mode): assert np.allclose(out, ref) -def test_sparse_input_field_v2v(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +def test_sparse_input_field_v2v(program_processor, lift_mode): + program_processor, validate = program_processor - non_sparse = gtx.np_as_located_field(Edge)(np.zeros(18)) + non_sparse = gtx.np_as_located_field(Edge)(np.zeros(18, dtype=np.int32)) inp = gtx.np_as_located_field(Vertex, V2VDim)(v2v_arr) out = 
gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) @@ -278,10 +272,10 @@ def slice_twice_sparse_stencil(sparse): @pytest.mark.xfail(reason="Field with more than one sparse dimension is not implemented.") -def test_slice_twice_sparse(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +def test_slice_twice_sparse(program_processor, lift_mode): + program_processor, validate = program_processor inp = gtx.np_as_located_field(Vertex, V2VDim, V2VDim)(v2v_arr[v2v_arr]) - out = gtx.np_as_located_field(Vertex)(np.zeros([9])) + out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) ref = v2v_arr[v2v_arr][:, 2, 1] run_processor(
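
Note on the reduction lowering in this patch: it rests on DaCe's write-conflict resolution (WCR). Each iteration of the reduce map writes its per-neighbor value through an output memlet whose wcr lambda folds the value into the accumulator, while a preceding init state seeds the accumulator with the identity element from get_reduce_identity_value. The following is a minimal standalone sketch of that pattern, not part of the patch; the SDFG name, array names, and the size 4 are illustrative assumptions.

# Minimal sketch of the map + WCR accumulation pattern used by _visit_reduce above.
import dace
import numpy as np

sdfg = dace.SDFG("wcr_reduce_sketch")  # hypothetical name
sdfg.add_array("values", shape=(4,), dtype=dace.float64)
sdfg.add_array("acc", shape=(1,), dtype=dace.float64)
state = sdfg.add_state()

me, mx = state.add_map("reduce_map", ndrange={"__idx": "0:4"})
tasklet = state.add_tasklet("identity", {"__in"}, {"__out"}, "__out = __in")

state.add_memlet_path(
    state.add_access("values"), me, tasklet,
    memlet=dace.Memlet(data="values", subset="__idx"), dst_conn="__in",
)
# The wcr lambda plays the role of context.reduce_wcr: concurrent writes to
# acc[0] are combined by the operator instead of overwriting each other.
state.add_memlet_path(
    tasklet, mx, state.add_access("acc"),
    memlet=dace.Memlet(data="acc", subset="0", wcr="lambda x, y: x + y"),
    src_conn="__out",
)

values = np.arange(4, dtype=np.float64)
acc = np.zeros(1)  # pre-seeded with the identity of "plus", as the init state does
sdfg(values=values, acc=acc)
assert acc[0] == values.sum()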