From 315d9203bb667baa3daaea4b797a0846a2b70887 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Tue, 19 Dec 2023 07:35:51 +0100 Subject: [PATCH] feat[dace]: Computing SDFG call arguments (#1398) Added a function to get the arguments to call an SDFG. This commit adds a function that allows to generate the arguments needed to call an SDFG, before this was part of `run_dace_iterator()`. This made it very complex to run an SDFG outside this function. One should consider this as an amend to [PR #1379](https://github.com/GridTools/gt4py/pull/1379). --- .../runners/dace_iterator/__init__.py | 79 ++++++++++++------- 1 file changed, 49 insertions(+), 30 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 59569de30b..97dd90eb54 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -90,8 +90,9 @@ def preprocess_program( return fencil_definition -def get_args(params: Sequence[itir.Sym], args: Sequence[Any]) -> dict[str, Any]: - return {name.id: convert_arg(arg) for name, arg in zip(params, args)} +def get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]: + sdfg_params: Sequence[str] = sdfg.arg_names + return {sdfg_param: convert_arg(arg) for sdfg_param, arg in zip(sdfg_params, args)} def _ensure_is_on_device( @@ -127,13 +128,16 @@ def get_shape_args( def get_offset_args( - arrays: Mapping[str, dace.data.Array], params: Sequence[itir.Sym], args: Sequence[Any] + sdfg: dace.SDFG, + args: Sequence[Any], ) -> Mapping[str, int]: + sdfg_arrays: Mapping[str, dace.data.Array] = sdfg.arrays + sdfg_params: Sequence[str] = sdfg.arg_names return { str(sym): -drange.start - for param, arg in zip(params, args) + for sdfg_param, arg in zip(sdfg_params, args) if common.is_field(arg) - for sym, drange in zip(arrays[param.id].offset, get_sorted_dim_ranges(arg.domain)) + for sym, drange in zip(sdfg_arrays[sdfg_param].offset, get_sorted_dim_ranges(arg.domain)) } @@ -189,6 +193,45 @@ def get_cache_id( return m.hexdigest() +def get_sdfg_args(sdfg: dace.SDFG, *args, **kwargs) -> dict[str, Any]: + """Extracts the arguments needed to call the SDFG. + + This function can handle the same arguments that are passed to `run_dace_iterator()`. + + Args: + sdfg: The SDFG for which we want to get the arguments. + """ # noqa: D401 + offset_provider = kwargs["offset_provider"] + on_gpu = kwargs.get("on_gpu", False) + + neighbor_tables = filter_neighbor_tables(offset_provider) + device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU + + dace_args = get_args(sdfg, args) + dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} + dace_conn_args = get_connectivity_args(neighbor_tables, device) + dace_shapes = get_shape_args(sdfg.arrays, dace_field_args) + dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args) + dace_strides = get_stride_args(sdfg.arrays, dace_field_args) + dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args) + dace_offsets = get_offset_args(sdfg, args) + all_args = { + **dace_args, + **dace_conn_args, + **dace_shapes, + **dace_conn_shapes, + **dace_strides, + **dace_conn_strides, + **dace_offsets, + } + expected_args = { + key: value + for key, value in all_args.items() + if key in sdfg.signature_arglist(with_types=False) + } + return expected_args + + def build_sdfg_from_itir( program: itir.FencilDefinition, *args, @@ -248,8 +291,6 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): offset_provider = kwargs["offset_provider"] arg_types = [type_translation.from_value(arg) for arg in args] - device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU - neighbor_tables = filter_neighbor_tables(offset_provider) cache_id = get_cache_id(program, arg_types, column_axis, offset_provider) if build_cache is not None and cache_id in build_cache: @@ -281,29 +322,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): if build_cache is not None: build_cache[cache_id] = sdfg_program - dace_args = get_args(program.params, args) - dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} - dace_conn_args = get_connectivity_args(neighbor_tables, device) - dace_shapes = get_shape_args(sdfg.arrays, dace_field_args) - dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args) - dace_strides = get_stride_args(sdfg.arrays, dace_field_args) - dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args) - dace_offsets = get_offset_args(sdfg.arrays, program.params, args) - - all_args = { - **dace_args, - **dace_conn_args, - **dace_shapes, - **dace_conn_shapes, - **dace_strides, - **dace_conn_strides, - **dace_offsets, - } - expected_args = { - key: value - for key, value in all_args.items() - if key in sdfg.signature_arglist(with_types=False) - } + expected_args = get_sdfg_args(sdfg, *args, **kwargs) with dace.config.temporary_config(): dace.config.Config.set("compiler", "allow_view_arguments", value=True)