Skip to content

Commit

Permalink
feat[dace]: Computing SDFG call arguments (#1398)
Browse files Browse the repository at this point in the history
Added a function to get the arguments to call an SDFG.

This commit adds a function that allows to generate the arguments needed to call an SDFG, before this was part of `run_dace_iterator()`.
This made it very complex to run an SDFG outside this function.

One should consider this as an amend to [PR #1379](#1379).
  • Loading branch information
philip-paul-mueller authored Dec 19, 2023
1 parent 6c7c5d5 commit 315d920
Showing 1 changed file with 49 additions and 30 deletions.
79 changes: 49 additions & 30 deletions src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,9 @@ def preprocess_program(
return fencil_definition


def get_args(params: Sequence[itir.Sym], args: Sequence[Any]) -> dict[str, Any]:
return {name.id: convert_arg(arg) for name, arg in zip(params, args)}
def get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]:
sdfg_params: Sequence[str] = sdfg.arg_names
return {sdfg_param: convert_arg(arg) for sdfg_param, arg in zip(sdfg_params, args)}


def _ensure_is_on_device(
Expand Down Expand Up @@ -127,13 +128,16 @@ def get_shape_args(


def get_offset_args(
arrays: Mapping[str, dace.data.Array], params: Sequence[itir.Sym], args: Sequence[Any]
sdfg: dace.SDFG,
args: Sequence[Any],
) -> Mapping[str, int]:
sdfg_arrays: Mapping[str, dace.data.Array] = sdfg.arrays
sdfg_params: Sequence[str] = sdfg.arg_names
return {
str(sym): -drange.start
for param, arg in zip(params, args)
for sdfg_param, arg in zip(sdfg_params, args)
if common.is_field(arg)
for sym, drange in zip(arrays[param.id].offset, get_sorted_dim_ranges(arg.domain))
for sym, drange in zip(sdfg_arrays[sdfg_param].offset, get_sorted_dim_ranges(arg.domain))
}


Expand Down Expand Up @@ -189,6 +193,45 @@ def get_cache_id(
return m.hexdigest()


def get_sdfg_args(sdfg: dace.SDFG, *args, **kwargs) -> dict[str, Any]:
"""Extracts the arguments needed to call the SDFG.
This function can handle the same arguments that are passed to `run_dace_iterator()`.
Args:
sdfg: The SDFG for which we want to get the arguments.
""" # noqa: D401
offset_provider = kwargs["offset_provider"]
on_gpu = kwargs.get("on_gpu", False)

neighbor_tables = filter_neighbor_tables(offset_provider)
device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU

dace_args = get_args(sdfg, args)
dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)}
dace_conn_args = get_connectivity_args(neighbor_tables, device)
dace_shapes = get_shape_args(sdfg.arrays, dace_field_args)
dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args)
dace_strides = get_stride_args(sdfg.arrays, dace_field_args)
dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args)
dace_offsets = get_offset_args(sdfg, args)
all_args = {
**dace_args,
**dace_conn_args,
**dace_shapes,
**dace_conn_shapes,
**dace_strides,
**dace_conn_strides,
**dace_offsets,
}
expected_args = {
key: value
for key, value in all_args.items()
if key in sdfg.signature_arglist(with_types=False)
}
return expected_args


def build_sdfg_from_itir(
program: itir.FencilDefinition,
*args,
Expand Down Expand Up @@ -248,8 +291,6 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):
offset_provider = kwargs["offset_provider"]

arg_types = [type_translation.from_value(arg) for arg in args]
device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU
neighbor_tables = filter_neighbor_tables(offset_provider)

cache_id = get_cache_id(program, arg_types, column_axis, offset_provider)
if build_cache is not None and cache_id in build_cache:
Expand Down Expand Up @@ -281,29 +322,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):
if build_cache is not None:
build_cache[cache_id] = sdfg_program

dace_args = get_args(program.params, args)
dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)}
dace_conn_args = get_connectivity_args(neighbor_tables, device)
dace_shapes = get_shape_args(sdfg.arrays, dace_field_args)
dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args)
dace_strides = get_stride_args(sdfg.arrays, dace_field_args)
dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args)
dace_offsets = get_offset_args(sdfg.arrays, program.params, args)

all_args = {
**dace_args,
**dace_conn_args,
**dace_shapes,
**dace_conn_shapes,
**dace_strides,
**dace_conn_strides,
**dace_offsets,
}
expected_args = {
key: value
for key, value in all_args.items()
if key in sdfg.signature_arglist(with_types=False)
}
expected_args = get_sdfg_args(sdfg, *args, **kwargs)

with dace.config.temporary_config():
dace.config.Config.set("compiler", "allow_view_arguments", value=True)
Expand Down

0 comments on commit 315d920

Please sign in to comment.