Skip to content

Commit

Permalink
refactor[next]: Simplify ir_makers.domain (#1903)
Browse files Browse the repository at this point in the history
The `domain` ir maker now only accepts Dimensions, not strings. This
simplifies the typing in some places and is less error prone, since one
can not accidentally create a domain with the wrong kind, e.g. by using
`"KDim"`.

Co-authored-by: Till Ehrengruber <till.ehrengruber@cscs.ch>
  • Loading branch information
SF-N and tehrengruber authored Mar 6, 2025
1 parent c610561 commit 098d325
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/ir_utils/domain_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def from_expr(cls, node: itir.Node) -> SymbolicDomain:
return cls(node.fun.id, ranges) # type: ignore[attr-defined] # ensure by assert above

def as_expr(self) -> itir.FunCall:
converted_ranges: dict[common.Dimension | str, tuple[itir.Expr, itir.Expr]] = {
converted_ranges: dict[common.Dimension, tuple[itir.Expr, itir.Expr]] = {
key: (value.start, value.stop) for key, value in self.ranges.items()
}
return im.domain(self.grid_type, converted_ranges)
Expand Down
18 changes: 5 additions & 13 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,22 +402,14 @@ def _impl(*its: itir.Expr) -> itir.FunCall:

def domain(
grid_type: Union[common.GridType, str],
ranges: dict[Union[common.Dimension, str], tuple[itir.Expr, itir.Expr]],
ranges: dict[common.Dimension, tuple[itir.Expr, itir.Expr]],
) -> itir.FunCall:
"""
>>> str(
... domain(
... common.GridType.CARTESIAN,
... {
... common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL): (0, 10),
... common.Dimension(value="JDim", kind=common.DimensionKind.HORIZONTAL): (0, 20),
... },
... )
... )
>>> IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL)
>>> JDim = common.Dimension(value="JDim", kind=common.DimensionKind.HORIZONTAL)
>>> str(domain(common.GridType.CARTESIAN, {IDim: (0, 10), JDim: (0, 20)}))
'c⟨ IDimₕ: [0, 10[, JDimₕ: [0, 20[ ⟩'
>>> str(domain(common.GridType.CARTESIAN, {"IDim": (0, 10), "JDim": (0, 20)}))
'c⟨ IDimₕ: [0, 10[, JDimₕ: [0, 20[ ⟩'
>>> str(domain(common.GridType.UNSTRUCTURED, {"IDim": (0, 10), "JDim": (0, 20)}))
>>> str(domain(common.GridType.UNSTRUCTURED, {IDim: (0, 10), JDim: (0, 20)}))
'u⟨ IDimₕ: [0, 10[, JDimₕ: [0, 20[ ⟩'
"""
if isinstance(grid_type, common.GridType):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def test_cond(offset_provider):

testee = im.if_(cond, field_1, field_2)

domain = im.domain(common.GridType.CARTESIAN, {"IDim": (0, 11)})
domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)})
domain_tmp = translate_domain(domain, {"Ioff": -1}, offset_provider)
expected_domains_dict = {"in_field1": {IDim: (0, 12)}, "in_field2": {IDim: (-2, 12)}}
expected_tmp2 = im.as_fieldop(tmp_stencil2, domain_tmp)(
Expand Down

0 comments on commit 098d325

Please sign in to comment.