From 8f741dccd85968a2bed1d0ce4e55b24e1449a5ab Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 28 Aug 2023 06:57:58 +0000 Subject: [PATCH 01/21] refactor[next] Prepare new Field for embedded --- docs/user/cartesian/Makefile | 15 +- docs/user/cartesian/arrays.rst | 6 +- src/gt4py/_core/definitions.py | 52 +++-- src/gt4py/next/__init__.py | 1 + src/gt4py/next/common.py | 107 ++++++---- src/gt4py/next/embedded/nd_array_field.py | 183 +++++++++++++----- src/gt4py/next/ffront/decorator.py | 10 +- .../embedded_tests/test_nd_array_field.py | 95 +++++++-- 8 files changed, 346 insertions(+), 123 deletions(-) diff --git a/docs/user/cartesian/Makefile b/docs/user/cartesian/Makefile index 091bc3b8d2..13e692b96d 100644 --- a/docs/user/cartesian/Makefile +++ b/docs/user/cartesian/Makefile @@ -2,12 +2,13 @@ # # You can set these variables from the command line. -SPHINXOPTS = -SPHINXBUILD = sphinx-build -PAPER = -SRCDIR = ../../../src/gt4py -AUTODOCDIR = _source -BUILDDIR = _build +SPHINXOPTS = +SPHINXBUILD = sphinx-build +PAPER = +SRCDIR = ../../../src/gt4py +SPHINX_APIDOC_OPTS = --private # private modules for gt4py._core +AUTODOCDIR = _source +BUILDDIR = _build # User-friendly check for sphinx-build ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) @@ -55,7 +56,7 @@ clean: autodoc: @echo @echo "Running sphinx-apidoc..." - sphinx-apidoc ${SPHINX_OPTS} -o ${AUTODOCDIR} ${SRCDIR} + sphinx-apidoc ${SPHINX_APIDOC_OPTS} -o ${AUTODOCDIR} ${SRCDIR} @echo @echo "sphinx-apidoc finished. The generated autodocs are in $(AUTODOCDIR)." diff --git a/docs/user/cartesian/arrays.rst b/docs/user/cartesian/arrays.rst index 6788e2757f..6ef7c6e5c1 100644 --- a/docs/user/cartesian/arrays.rst +++ b/docs/user/cartesian/arrays.rst @@ -39,6 +39,8 @@ Internally, gt4py uses the utilities :code:`gt4py.utils.as_numpy` and :code:`gt4 buffers. GT4Py developers are advised to always use those utilities as to guarantee support across gt4py as the supported interfaces are extended. +.. _cartesian-arrays-dimension-mapping: + Dimension Mapping ^^^^^^^^^^^^^^^^^ @@ -56,6 +58,8 @@ which implements this lookup. Note: Support for xarray can be added manually by the user by means of the mechanism described `here `_. +.. _cartesian-arrays-default-origin: + Default Origin ^^^^^^^^^^^^^^ @@ -180,4 +184,4 @@ Additionally, these **optional** keyword-only parameters are accepted: determine the default layout for the storage. Currently supported will be :code:`"I"`, :code:`"J"`, :code:`"K"` and additional dimensions as string representations of integers, starting at :code:`"0"`. (This information is not retained in the resulting array, and needs to be specified instead - with the :code:`__gt_dims__` interface. ) \ No newline at end of file + with the :code:`__gt_dims__` interface. ) diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 2546ae3e4e..f49bac531a 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -213,18 +213,13 @@ class DType(Generic[ScalarT]): """ scalar_type: Type[ScalarT] - tensor_shape: TensorShape + tensor_shape: TensorShape = dataclasses.field(default=()) - def __init__( - self, scalar_type: Type[ScalarT], tensor_shape: Sequence[IntegralScalar] = () - ) -> None: - if not isinstance(scalar_type, type): - raise TypeError(f"Invalid scalar type '{scalar_type}'") - if not is_valid_tensor_shape(tensor_shape): - raise TypeError(f"Invalid tensor shape '{tensor_shape}'") - - object.__setattr__(self, "scalar_type", scalar_type) - object.__setattr__(self, "tensor_shape", tensor_shape) + def __post_init__(self) -> None: + if not isinstance(self.scalar_type, type): + raise TypeError(f"Invalid scalar type '{self.scalar_type}'") + if not is_valid_tensor_shape(self.tensor_shape): + raise TypeError(f"Invalid tensor shape '{self.tensor_shape}'") @functools.cached_property def kind(self) -> DTypeKind: @@ -251,6 +246,14 @@ def lanes(self) -> int: def subndim(self) -> int: return len(self.tensor_shape) + def __eq__(self, other: Any) -> bool: + # TODO: discuss (make concrete subclasses equal to instances of this with the same type) + return ( + isinstance(other, DType) + and self.scalar_type == other.scalar_type + and self.tensor_shape == other.tensor_shape + ) + @dataclasses.dataclass(frozen=True) class IntegerDType(DType[IntegralT]): @@ -322,6 +325,11 @@ class Float64DType(FloatingDType[float64]): scalar_type: Final[Type[float64]] = dataclasses.field(default=float64, init=False) +@dataclasses.dataclass(frozen=True) +class BoolDType(DType[bool_]): + scalar_type: Final[Type[bool_]] = dataclasses.field(default=bool_, init=False) + + DTypeLike = Union[DType, npt.DTypeLike] @@ -332,11 +340,29 @@ def dtype(dtype_like: DTypeLike) -> DType: # -- Custom protocols -- class GTDimsInterface(Protocol): - __gt_dims__: Tuple[str, ...] + """ + A `GTDimsInterface` is an object providing the `__gt_dims__` property, naming the buffer dimensions. + + In `gt4py.cartesian` the allowed values are `"I"`, `"J"` and `"K"` with the established semantics. + + See :ref:`cartesian-arrays-dimension-mapping` for details. + """ + + @property + def __gt_dims__(self) -> Tuple[str, ...]: + ... class GTOriginInterface(Protocol): - __gt_origin__: Tuple[int, ...] + """ + A `GTOriginInterface` is an object providing `__gt_origin__`, describing the origin of a buffer. + + See :ref:`cartesian-arrays-default-origin` for details. + """ + + @property + def __gt_origin__(self) -> Tuple[int, ...]: + ... # -- Device representation -- diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index b4d1fc0c09..5d7a5f480e 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -25,6 +25,7 @@ from . import common, ffront, iterator, program_processors, type_inference from .common import Dimension, DimensionKind, Field, GridType +from .embedded import nd_array_field from .ffront import fbuiltins from .ffront.decorator import field_operator, program, scan_operator from .ffront.fbuiltins import * # noqa: F403 # fbuiltins defines __all__ and we explicitly want to reexport everything here diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index e06f9c54b1..6d261212b6 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -38,14 +38,14 @@ TypeAlias, TypeVar, extended_runtime_checkable, - final, runtime_checkable, ) from gt4py.eve.type_definitions import StrEnum -DimT = TypeVar("DimT", bound="Dimension") -DimsT = TypeVar("DimsT", bound=Sequence["Dimension"], covariant=True) +DimsT = TypeVar( + "DimsT", covariant=True +) # bound to `Sequence[Dimension]` if instance of Dimension would be a type class Infinity(int): @@ -136,21 +136,45 @@ def __and__(self, other: Set[Any]) -> UnitRange: raise NotImplementedError("Can only find the intersection between UnitRange instances.") -DomainRange: TypeAlias = UnitRange | int +IntIndex: TypeAlias = int | np.integer +DomainRange: TypeAlias = UnitRange | IntIndex NamedRange: TypeAlias = tuple[Dimension, UnitRange] -NamedIndex: TypeAlias = tuple[Dimension, int] +NamedIndex: TypeAlias = tuple[Dimension, IntIndex] DomainSlice: TypeAlias = Sequence[NamedRange | NamedIndex] FieldSlice: TypeAlias = ( DomainSlice - | tuple[slice | int | EllipsisType, ...] + | tuple[slice | IntIndex | EllipsisType, ...] | slice - | int + | IntIndex | EllipsisType | NamedRange | NamedIndex ) +def is_int_index(p: Any) -> TypeGuard[IntIndex]: + return isinstance(p, (int, np.integer)) + + +def is_named_range(v: Any) -> TypeGuard[NamedRange]: + return ( + isinstance(v, tuple) + and len(v) == 2 + and isinstance(v[0], Dimension) + and isinstance(v[1], UnitRange) + ) + + +def is_named_index(v: Any) -> TypeGuard[NamedRange]: + return ( + isinstance(v, tuple) and len(v) == 2 and isinstance(v[0], Dimension) and is_int_index(v[1]) + ) + + +def is_domain_slice(v: Any) -> TypeGuard[DomainSlice]: + return isinstance(v, Sequence) and all(is_named_range(e) or is_named_index(e) for e in v) + + @dataclasses.dataclass(frozen=True) class Domain(Sequence[NamedRange]): dims: tuple[Dimension, ...] @@ -212,8 +236,7 @@ def _broadcast_ranges( broadcast_dims: Sequence[Dimension], dims: Sequence[Dimension], ranges: Sequence[UnitRange] ) -> tuple[UnitRange, ...]: return tuple( - ranges[dims.index(d)] if d in dims else UnitRange(Infinity.negative(), Infinity.positive()) - for d in broadcast_dims + ranges[dims.index(d)] if d in dims else UnitRange.infinity() for d in broadcast_dims ) @@ -229,8 +252,22 @@ def __call__(self, func: fbuiltins.BuiltInFunction[_R, _P], /) -> Callable[_P, _ ... +class NextGTDimsInterface(Protocol): + """ + A `GTDimsInterface` is an object providing the `__gt_dims__` property, naming :class:`Field` dimensions. + + The dimension names are objects of type :class:`Dimension`, in contrast to :py:mod:`gt4py.cartesian`, + where the labels are `str` s with implied semantics, see :py:class:`~gt4py._core.definitions.GTDimsInterface` . + """ + + # TODO(havogt): unify with GTDimsInterface, ideally in backward compatible way + @property + def __gt_dims__(self) -> tuple[Dimension, ...]: + ... + + @extended_runtime_checkable -class Field(Protocol[DimsT, core_defs.ScalarT]): +class Field(NextGTDimsInterface, core_defs.GTOriginInterface, Protocol[DimsT, core_defs.ScalarT]): __gt_builtin_func__: ClassVar[GTBuiltInFuncDispatcher] @property @@ -241,17 +278,12 @@ def domain(self) -> Domain: def dtype(self) -> core_defs.DType[core_defs.ScalarT]: ... - @property - def value_type(self) -> type[core_defs.ScalarT]: - ... - @property def ndarray(self) -> core_defs.NDArrayObject: ... def __str__(self) -> str: - codomain = self.value_type.__name__ - return f"⟨{self.domain!s} → {codomain}⟩" + return f"⟨{self.domain!s} → {self.dtype}⟩" @abc.abstractmethod def remap(self, index_field: Field) -> Field: @@ -325,20 +357,29 @@ def __pow__(self, other: Field | core_defs.ScalarT) -> Field: def is_field( v: Any, -) -> TypeGuard[Field]: # this function is introduced to localize the `type: ignore`` +) -> TypeGuard[Field]: + # This function is introduced to localize the `type: ignore` because + # extended_runtime_checkable does not make the protocol runtime_checkable + # for mypy. + # TODO(egparedes): remove it when extended_runtime_checkable is fixed return isinstance(v, Field) # type: ignore[misc] # we use extended_runtime_checkable -class FieldABC(Field[DimsT, core_defs.ScalarT]): - """Abstract base class for implementations of the :class:`Field` protocol.""" +@extended_runtime_checkable +class MutableField(Field[DimsT, core_defs.ScalarT], Protocol[DimsT, core_defs.ScalarT]): + @abc.abstractmethod + def __setitem__(self, index: FieldSlice, value: Field | core_defs.ScalarT) -> None: + ... - @final - def __setattr__(self, key, value) -> None: - raise TypeError("Immutable type") - @final - def __setitem__(self, key, value) -> None: - raise TypeError("Immutable type") +def is_mutable_field( + v: Any, +) -> TypeGuard[MutableField]: + # This function is introduced to localize the `type: ignore` because + # extended_runtime_checkable does not make the protocol runtime_checkable + # for mypy. + # TODO(egparedes): remove it when extended_runtime_checkable is fixed + return isinstance(v, MutableField) # type: ignore[misc] # we use extended_runtime_checkable @functools.singledispatch @@ -347,7 +388,7 @@ def field( /, *, domain: Optional[Any] = None, # TODO(havogt): provide domain_like to Domain conversion - value_type: Optional[type] = None, + dtype: Optional[core_defs.DType] = None, ) -> Field: raise NotImplementedError @@ -467,17 +508,3 @@ def promote_dims(*dims_list: Sequence[Dimension]) -> list[Dimension]: ) return topologically_sorted_list - - -def is_named_range(v: Any) -> TypeGuard[NamedRange]: - return isinstance(v, tuple) and isinstance(v[0], Dimension) and isinstance(v[1], UnitRange) - - -def is_named_index(v: Any) -> TypeGuard[NamedIndex]: - return isinstance(v, tuple) and isinstance(v[0], Dimension) and isinstance(v[1], int) - - -def is_domain_slice(index: Any) -> TypeGuard[DomainSlice]: - return isinstance(index, Sequence) and all( - is_named_range(idx) or is_named_index(idx) for idx in index - ) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 9813efdd22..a82b7f32c6 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -56,7 +56,7 @@ def _make_binary_array_field_intrinsic_func(builtin_name: str, array_builtin_nam def _builtin_binary_op(a: _BaseNdArrayField, b: common.Field) -> common.Field: xp = a.__class__.array_ns op = getattr(xp, array_builtin_name) - if hasattr(b, "__gt_builtin_func__"): # isinstance(b, common.Field): + if hasattr(b, "__gt_builtin_func__"): # common.is_field(b): if not a.domain == b.domain: domain_intersection = a.domain & b.domain a_broadcasted = _broadcast(a, domain_intersection.dims) @@ -82,7 +82,7 @@ def _builtin_binary_op(a: _BaseNdArrayField, b: common.Field) -> common.Field: @dataclasses.dataclass(frozen=True) -class _BaseNdArrayField(common.FieldABC[common.DimsT, core_defs.ScalarT]): +class _BaseNdArrayField(common.MutableField[common.DimsT, core_defs.ScalarT]): """ Shared field implementation for NumPy-like fields. @@ -94,7 +94,6 @@ class _BaseNdArrayField(common.FieldABC[common.DimsT, core_defs.ScalarT]): _domain: common.Domain _ndarray: core_defs.NDArrayObject - _value_type: type[core_defs.ScalarT] array_ns: ClassVar[ ModuleType @@ -133,13 +132,28 @@ def register_builtin_func( def domain(self) -> common.Domain: return self._domain + @property + def shape(self) -> tuple[int, ...]: + return self._ndarray.shape + + @property + def __gt_dims__(self) -> tuple[common.Dimension, ...]: + return self._domain.dims + + @property + def __gt_origin__(self) -> tuple[int, ...]: + return tuple(-r.start for _, r in self._domain) + @property def ndarray(self) -> core_defs.NDArrayObject: return self._ndarray + def __array__(self, dtype: npt.DTypeLike = None) -> np.ndarray: + return np.asarray(self._ndarray, dtype) + @property - def value_type(self) -> type[core_defs.ScalarT]: - return self._value_type + def dtype(self) -> core_defs.DType[core_defs.ScalarT]: + return core_defs.dtype(self._ndarray.dtype.type) @classmethod def from_array( @@ -149,37 +163,54 @@ def from_array( /, *, domain: common.Domain, - value_type: Optional[type] = None, + dtype_like: Optional[core_defs.DType] = None, # TODO define DTypeLike ) -> _BaseNdArrayField: xp = cls.array_ns - dtype = None - if value_type is not None: - dtype = xp.dtype(value_type) - array = xp.asarray(data, dtype=dtype) - value_type = array.dtype.type # TODO add support for Dimensions as value_type + xp_dtype = None if dtype_like is None else xp.dtype(core_defs.dtype(dtype_like).scalar_type) + array = xp.asarray(data, dtype=xp_dtype) + + if dtype_like is not None: + assert array.dtype.type == core_defs.dtype(dtype_like).scalar_type assert issubclass(array.dtype.type, core_defs.SCALAR_TYPES) - assert all(isinstance(d, common.Dimension) for d, r in domain), domain + assert all(isinstance(d, common.Dimension) for d in domain.dims), domain assert len(domain) == array.ndim assert all( - len(nr[1]) == s or (s == 1 and nr[1] == common.UnitRange.infinity()) - for nr, s in zip(domain, array.shape) + len(r) == s or (s == 1 and r == common.UnitRange.infinity()) + for r, s in zip(domain.ranges, array.shape) ) - assert value_type is not None # for mypy - return cls(domain, array, value_type) + return cls(domain, array) def remap(self: _BaseNdArrayField, connectivity) -> _BaseNdArrayField: raise NotImplementedError() + def restrict(self, index: common.FieldSlice) -> common.Field | core_defs.ScalarT: + index = _tuplize_field_slice(index) + + if common.is_domain_slice(index): + return self._getitem_absolute_slice(index) + + assert isinstance(index, tuple) + if all( + isinstance(idx, slice) or common.is_int_index(idx) or idx is Ellipsis for idx in index + ): + return self._getitem_relative_slice(index) + + raise IndexError(f"Unsupported index type: {index}") + + __getitem__ = restrict + __call__ = None # type: ignore[assignment] # TODO: remap __abs__ = _make_unary_array_field_intrinsic_func("abs", "abs") __neg__ = _make_unary_array_field_intrinsic_func("neg", "negative") + __pos__ = _make_unary_array_field_intrinsic_func("pos", "positive") + __add__ = __radd__ = _make_binary_array_field_intrinsic_func("add", "add") __sub__ = __rsub__ = _make_binary_array_field_intrinsic_func("sub", "subtract") @@ -194,27 +225,37 @@ def remap(self: _BaseNdArrayField, connectivity) -> _BaseNdArrayField: __pow__ = _make_binary_array_field_intrinsic_func("pow", "power") - def __getitem__(self, index: common.FieldSlice) -> common.Field | core_defs.ScalarT: - if ( - not isinstance(index, tuple) - and not common.is_domain_slice(index) - or common.is_named_index(index) - or common.is_named_range(index) - ): - index = cast(common.FieldSlice, (index,)) + __mod__ = __rmod__ = _make_binary_array_field_intrinsic_func("mod", "mod") - if common.is_domain_slice(index): - return self._getitem_absolute_slice(index) + def __and__(self, other: common.Field) -> _BaseNdArrayField: + if self.dtype == core_defs.BoolDType(): + return _make_binary_array_field_intrinsic_func("logical_and", "logical_and")( + self, other + ) + raise NotImplementedError("`__and__` not implemented for non-`bool` fields.") - assert isinstance(index, tuple) - if all(isinstance(idx, (slice, int)) or idx is Ellipsis for idx in index): - return self._getitem_relative_slice(index) + __rand__ = __and__ - raise IndexError(f"Unsupported index type: {index}") + def __or__(self, other: common.Field) -> _BaseNdArrayField: + if self.dtype == core_defs.BoolDType(): + return _make_binary_array_field_intrinsic_func("logical_or", "logical_or")(self, other) + raise NotImplementedError("`__or__` not implemented for non-`bool` fields.") - restrict = ( - __getitem__ # type:ignore[assignment] # TODO(havogt) I don't see the problem that mypy has - ) + __ror__ = __or__ + + def __xor__(self, other: common.Field) -> _BaseNdArrayField: + if self.dtype == core_defs.BoolDType(): + return _make_binary_array_field_intrinsic_func("logical_xor", "logical_xor")( + self, other + ) + raise NotImplementedError("`__xor__` not implemented for non-`bool` fields.") + + __rxor__ = __xor__ + + def __invert__(self) -> _BaseNdArrayField: + if self.dtype == core_defs.BoolDType(): + return _make_unary_array_field_intrinsic_func("invert", "invert")(self) + raise NotImplementedError("`__invert__` not implemented for non-`bool` fields.") def _getitem_absolute_slice( self, index: common.DomainSlice @@ -241,10 +282,10 @@ def _getitem_absolute_slice( assert core_defs.is_scalar_type(new) return new # type: ignore[return-value] # I don't think we can express that we return `ScalarT` here else: - return self.__class__.from_array(new, domain=new_domain, value_type=self.value_type) + return self.__class__.from_array(new, domain=new_domain) def _getitem_relative_slice( - self, indices: tuple[slice | int | EllipsisType, ...] + self, indices: tuple[slice | common.IntIndex | EllipsisType, ...] ) -> common.Field | core_defs.ScalarT: new = self.ndarray[indices] new_dims = [] @@ -257,7 +298,7 @@ def _getitem_relative_slice( new_dims.append(dim) new_ranges.append(_slice_range(rng, idx)) else: - assert isinstance(idx, int) # not in new_domain + assert common.is_int_index(idx) # not in new_domain new_domain = common.Domain(dims=tuple(new_dims), ranges=tuple(new_ranges)) @@ -265,7 +306,7 @@ def _getitem_relative_slice( assert core_defs.is_scalar_type(new), new return new # type: ignore[return-value] # I don't think we can express that we return `ScalarT` here else: - return self.__class__.from_array(new, domain=new_domain, value_type=self.value_type) + return self.__class__.from_array(new, domain=new_domain) # -- Specialized implementations for intrinsic operations on array fields -- @@ -295,6 +336,30 @@ def _getitem_relative_slice( fbuiltins.fmod, _make_binary_array_field_intrinsic_func("fmod", "fmod") # type: ignore[attr-defined] ) + +def _np_cp_setitem( + self, + index: common.FieldSlice, + value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, +) -> None: + index = _tuplize_field_slice(index) + if common.is_field(value): + # TODO(havogt): in case of `is_field(value)` we should additionally check that `value.domain == self[slice].domain` + value = value.ndarray + + if common.is_domain_slice(index): + slices = _get_slices_from_domain_slice(self.domain, index) + self.ndarray[slices] = value + return + + assert isinstance(index, tuple) + if all(isinstance(idx, slice) or common.is_int_index(idx) or idx is Ellipsis for idx in index): + self.ndarray[index] = value + return + + raise IndexError(f"Unsupported index type: {index}") + + # -- Concrete array implementations -- # NumPy _nd_array_implementations = [np] @@ -304,6 +369,8 @@ def _getitem_relative_slice( class NumPyArrayField(_BaseNdArrayField): array_ns: ClassVar[ModuleType] = np + __setitem__ = _np_cp_setitem + common.field.register(np.ndarray, NumPyArrayField.from_array) @@ -315,6 +382,8 @@ class NumPyArrayField(_BaseNdArrayField): class CuPyArrayField(_BaseNdArrayField): array_ns: ClassVar[ModuleType] = cp + __setitem__ = _np_cp_setitem + common.field.register(cp.ndarray, CuPyArrayField.from_array) # JAX @@ -325,9 +394,33 @@ class CuPyArrayField(_BaseNdArrayField): class JaxArrayField(_BaseNdArrayField): array_ns: ClassVar[ModuleType] = jnp + def __setitem__( + self, + index: common.FieldSlice, + value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, + ) -> None: + # use `self.ndarray.at(index).set(value)` + raise NotImplementedError("`__setitem__` for JaxArrayField not yet implemented.") + common.field.register(jnp.ndarray, JaxArrayField.from_array) +def _tuplize_field_slice(v: common.FieldSlice) -> common.FieldSlice: + """ + Wrap a single index/slice/range into a tuple. + + Note: the condition is complex as `NamedRange`, `NamedIndex` are implemented as `tuple`. + """ + if ( + not isinstance(v, tuple) + and not common.is_domain_slice(v) + or common.is_named_index(v) + or common.is_named_range(v) + ): + return cast(common.FieldSlice, (v,)) + return v + + def _find_index_of_dim( dim: common.Dimension, domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], @@ -373,7 +466,7 @@ def _builtins_broadcast( def _get_slices_from_domain_slice( domain: common.Domain, domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], -) -> tuple[slice | int | None, ...]: +) -> tuple[slice | common.IntIndex | None, ...]: """Generate slices for sub-array extraction based on named ranges or named indices within a Domain. This function generates a tuple of slices that can be used to extract sub-arrays from a field. The provided @@ -388,7 +481,7 @@ def _get_slices_from_domain_slice( specified in the Domain. If a dimension is not included in the named indices or ranges, a None is used to indicate expansion along that axis. """ - slice_indices: list[slice | int | None] = [] + slice_indices: list[slice | common.IntIndex | None] = [] for pos_old, (dim, _) in enumerate(domain): if (pos := _find_index_of_dim(dim, domain_slice)) is not None: @@ -399,7 +492,9 @@ def _get_slices_from_domain_slice( return tuple(slice_indices) -def _compute_slice(rng: common.DomainRange, domain: common.Domain, pos: int) -> slice | int: +def _compute_slice( + rng: common.DomainRange, domain: common.Domain, pos: int +) -> slice | common.IntIndex: """Compute a slice or integer based on the provided range, domain, and position. Args: @@ -421,7 +516,7 @@ def _compute_slice(rng: common.DomainRange, domain: common.Domain, pos: int) -> rng.start - domain.ranges[pos].start, rng.stop - domain.ranges[pos].start, ) - elif isinstance(rng, int): + elif common.is_int_index(rng): return rng - domain.ranges[pos].start else: raise ValueError(f"Can only use integer or UnitRange ranges, provided type: {type(rng)}") @@ -443,9 +538,9 @@ def _slice_range(input_range: common.UnitRange, slice_obj: slice) -> common.Unit def _expand_ellipsis( - indices: tuple[int | slice | EllipsisType, ...], target_size: int -) -> tuple[int | slice, ...]: - expanded_indices: list[int | slice] = [] + indices: tuple[common.IntIndex | slice | EllipsisType, ...], target_size: int +) -> tuple[common.IntIndex | slice, ...]: + expanded_indices: list[common.IntIndex | slice] = [] for idx in indices: if idx is Ellipsis: expanded_indices.extend([slice(None)] * (target_size - (len(indices) - 1))) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 12ab3955ab..1dc68d0f1e 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -132,9 +132,11 @@ def _field_constituents_shape_and_dims( yield from _field_constituents_shape_and_dims(el, el_type) elif isinstance(arg_type, ts.FieldType): dims = type_info.extract_dims(arg_type) - if hasattr(arg, "shape"): - assert len(arg.shape) == len(dims) - yield (arg.shape, dims) + if hasattr(arg, "domain"): + assert len(arg.domain) == len(dims) + assert all(rg.start == 0 for _, rg in arg.domain) + shape = tuple(rg.stop - rg.start for _, rg in arg.domain) + yield (shape, dims) else: yield (None, dims) elif isinstance(arg_type, ts.ScalarType): @@ -796,7 +798,7 @@ def scan_operator( ... def scan_operator(carry: float, val: float) -> float: ... return carry+val >>> scan_operator(inp, out=out, offset_provider={}) # doctest: +SKIP - >>> out.array() # doctest: +SKIP + >>> out.ndarray # doctest: +SKIP array([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) """ # TODO(tehrengruber): enable doctests again. For unknown / obscure reasons diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index a2aa3112bd..bda6ffb871 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -40,16 +40,42 @@ def nd_array_implementation(request): @pytest.fixture( - params=[operator.add, operator.sub, operator.mul, operator.truediv, operator.floordiv], + params=[ + operator.add, + operator.sub, + operator.mul, + operator.truediv, + operator.floordiv, + operator.mod, + ] ) -def binary_op(request): +def binary_arithmetic_op(request): yield request.param -def _make_field(lst: Iterable, nd_array_implementation): +@pytest.fixture( + params=[operator.xor, operator.and_, operator.or_], +) +def binary_logical_op(request): + yield request.param + + +@pytest.fixture(params=[operator.neg, operator.pos]) +def unary_arithmetic_op(request): + yield request.param + + +@pytest.fixture(params=[operator.invert]) +def unary_logical_op(request): + yield request.param + + +def _make_field(lst: Iterable, nd_array_implementation, *, dtype=None): + if not dtype: + dtype = nd_array_implementation.float32 return common.field( - nd_array_implementation.asarray(lst, dtype=nd_array_implementation.float32), - domain=((common.Dimension("foo"), common.UnitRange(0, len(lst))),), + nd_array_implementation.asarray(lst, dtype=dtype), + domain=common.Domain((common.Dimension("foo"),), (common.UnitRange(0, len(lst)),)), ) @@ -72,16 +98,57 @@ def test_math_function_builtins(builtin_name: str, inputs, nd_array_implementati assert np.allclose(result.ndarray, expected) -def test_binary_ops(binary_op, nd_array_implementation): +def test_binary_arithmetic_ops(binary_arithmetic_op, nd_array_implementation): inp_a = [-1.0, 4.2, 42] inp_b = [2.0, 3.0, -3.0] inputs = [inp_a, inp_b] - expected = binary_op(*[np.asarray(inp, dtype=np.float32) for inp in inputs]) + expected = binary_arithmetic_op(*[np.asarray(inp, dtype=np.float32) for inp in inputs]) field_inputs = [_make_field(inp, nd_array_implementation) for inp in inputs] - result = binary_op(*field_inputs) + result = binary_arithmetic_op(*field_inputs) + + assert np.allclose(result.ndarray, expected) + + +def test_binary_logical_ops(binary_logical_op, nd_array_implementation): + inp_a = [True, True, False, False] + inp_b = [True, False, True, False] + inputs = [inp_a, inp_b] + + expected = binary_logical_op(*[np.asarray(inp) for inp in inputs]) + + field_inputs = [_make_field(inp, nd_array_implementation, dtype=bool) for inp in inputs] + + result = binary_logical_op(*field_inputs) + + assert np.allclose(result.ndarray, expected) + + +def test_unary_logical_ops(unary_logical_op, nd_array_implementation): + inp = [ + True, + False, + ] + + expected = unary_logical_op(np.asarray(inp)) + + field_input = _make_field(inp, nd_array_implementation, dtype=bool) + + result = unary_logical_op(field_input) + + assert np.allclose(result.ndarray, expected) + + +def test_unary_arithmetic_ops(unary_arithmetic_op, nd_array_implementation): + inp = [1.0, -2.0, 0.0] + + expected = unary_arithmetic_op(np.asarray(inp, dtype=np.float32)) + + field_input = _make_field(inp, nd_array_implementation) + + result = unary_arithmetic_op(field_input) assert np.allclose(result.ndarray, expected) @@ -93,7 +160,7 @@ def test_binary_ops(binary_op, nd_array_implementation): ((JDim,), (None, slice(5, 10))), ], ) -def test_binary_operations_with_intersection(binary_op, dims, expected_indices): +def test_binary_operations_with_intersection(binary_arithmetic_op, dims, expected_indices): arr1 = np.arange(10) arr1_domain = common.Domain(dims=dims, ranges=(UnitRange(0, 10),)) @@ -103,8 +170,8 @@ def test_binary_operations_with_intersection(binary_op, dims, expected_indices): field1 = common.field(arr1, domain=arr1_domain) field2 = common.field(arr2, domain=arr2_domain) - op_result = binary_op(field1, field2) - expected_result = binary_op(arr1[expected_indices[0], expected_indices[1]], arr2) + op_result = binary_arithmetic_op(field1, field2) + expected_result = binary_arithmetic_op(arr1[expected_indices[0], expected_indices[1]], arr2) assert op_result.ndarray.shape == (5, 5) assert np.allclose(op_result.ndarray, expected_result) @@ -282,7 +349,7 @@ def test_absolute_indexing(domain_slice, expected_dimensions, expected_shape): field = common.field(np.ones((5, 10, 15)), domain=domain) indexed_field = field[domain_slice] - assert isinstance(indexed_field, common.Field) + assert common.is_field(indexed_field) assert indexed_field.ndarray.shape == expected_shape assert indexed_field.domain.dims == expected_dimensions @@ -325,7 +392,7 @@ def test_relative_indexing_slice_2D(index, expected_shape, expected_domain): field = common.field(np.ones((10, 10)), domain=domain) indexed_field = field[index] - assert isinstance(indexed_field, common.Field) + assert common.is_field(indexed_field) assert indexed_field.ndarray.shape == expected_shape assert indexed_field.domain == expected_domain @@ -369,7 +436,7 @@ def test_relative_indexing_slice_3D(index, expected_shape, expected_domain): field = common.field(np.ones((10, 15, 10)), domain=domain) indexed_field = field[index] - assert isinstance(indexed_field, common.Field) + assert common.is_field(indexed_field) assert indexed_field.ndarray.shape == expected_shape assert indexed_field.domain == expected_domain From 3d9ae6cca4c508a1bf048a003dfe9d9108565356 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 28 Aug 2023 07:36:43 +0000 Subject: [PATCH 02/21] remove changes belonging to itir.embedded pr --- src/gt4py/next/ffront/decorator.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 1dc68d0f1e..12ab3955ab 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -132,11 +132,9 @@ def _field_constituents_shape_and_dims( yield from _field_constituents_shape_and_dims(el, el_type) elif isinstance(arg_type, ts.FieldType): dims = type_info.extract_dims(arg_type) - if hasattr(arg, "domain"): - assert len(arg.domain) == len(dims) - assert all(rg.start == 0 for _, rg in arg.domain) - shape = tuple(rg.stop - rg.start for _, rg in arg.domain) - yield (shape, dims) + if hasattr(arg, "shape"): + assert len(arg.shape) == len(dims) + yield (arg.shape, dims) else: yield (None, dims) elif isinstance(arg_type, ts.ScalarType): @@ -798,7 +796,7 @@ def scan_operator( ... def scan_operator(carry: float, val: float) -> float: ... return carry+val >>> scan_operator(inp, out=out, offset_provider={}) # doctest: +SKIP - >>> out.ndarray # doctest: +SKIP + >>> out.array() # doctest: +SKIP array([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) """ # TODO(tehrengruber): enable doctests again. For unknown / obscure reasons From 6f7241715d6425a8bea37f97346e1d1a5cb33423 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 28 Aug 2023 08:16:41 +0000 Subject: [PATCH 03/21] cleanup setitem --- src/gt4py/next/common.py | 9 +- src/gt4py/next/embedded/nd_array_field.py | 91 +++++++++---------- .../embedded_tests/test_nd_array_field.py | 39 ++++++++ 3 files changed, 86 insertions(+), 53 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 6d261212b6..68e2a4d0b9 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -141,14 +141,9 @@ def __and__(self, other: Set[Any]) -> UnitRange: NamedRange: TypeAlias = tuple[Dimension, UnitRange] NamedIndex: TypeAlias = tuple[Dimension, IntIndex] DomainSlice: TypeAlias = Sequence[NamedRange | NamedIndex] +BufferSlice: TypeAlias = tuple[slice | IntIndex | EllipsisType, ...] FieldSlice: TypeAlias = ( - DomainSlice - | tuple[slice | IntIndex | EllipsisType, ...] - | slice - | IntIndex - | EllipsisType - | NamedRange - | NamedIndex + DomainSlice | BufferSlice | slice | IntIndex | EllipsisType | NamedRange | NamedIndex ) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index a82b7f32c6..6fdb220fe5 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -188,18 +188,14 @@ def remap(self: _BaseNdArrayField, connectivity) -> _BaseNdArrayField: raise NotImplementedError() def restrict(self, index: common.FieldSlice) -> common.Field | core_defs.ScalarT: - index = _tuplize_field_slice(index) - - if common.is_domain_slice(index): - return self._getitem_absolute_slice(index) + new_domain, buffer_slice = self._slice(index) - assert isinstance(index, tuple) - if all( - isinstance(idx, slice) or common.is_int_index(idx) or idx is Ellipsis for idx in index - ): - return self._getitem_relative_slice(index) - - raise IndexError(f"Unsupported index type: {index}") + new_buffer = self.ndarray[buffer_slice] + if len(new_domain) == 0: + assert core_defs.is_scalar_type(new_buffer) + return new_buffer # type: ignore[return-value] # I don't think we can express that we return `ScalarT` here + else: + return self.__class__.from_array(new_buffer, domain=new_domain) __getitem__ = restrict @@ -257,13 +253,26 @@ def __invert__(self) -> _BaseNdArrayField: return _make_unary_array_field_intrinsic_func("invert", "invert")(self) raise NotImplementedError("`__invert__` not implemented for non-`bool` fields.") - def _getitem_absolute_slice( + def _slice(self, index: common.FieldSlice) -> tuple[common.Domain, common.BufferSlice]: + index = _tuplize_field_slice(index) + + if common.is_domain_slice(index): + return self._absolute_slice(index) + + assert isinstance(index, tuple) + if all( + isinstance(idx, slice) or common.is_int_index(idx) or idx is Ellipsis for idx in index + ): + return self._relative_slice(index) + + raise IndexError(f"Unsupported index type: {index}") + + def _absolute_slice( self, index: common.DomainSlice - ) -> common.Field | core_defs.ScalarT: + ) -> tuple[common.Domain, common.BufferSlice]: slices = _get_slices_from_domain_slice(self.domain, index) new_ranges = [] new_dims = [] - new = self.ndarray[slices] for i, dim in enumerate(self.domain.dims): if (pos := _find_index_of_dim(dim, index)) is not None: @@ -278,21 +287,21 @@ def _getitem_absolute_slice( new_domain = common.Domain(dims=tuple(new_dims), ranges=tuple(new_ranges)) - if len(new_domain) == 0: - assert core_defs.is_scalar_type(new) - return new # type: ignore[return-value] # I don't think we can express that we return `ScalarT` here - else: - return self.__class__.from_array(new, domain=new_domain) + return new_domain, slices - def _getitem_relative_slice( - self, indices: tuple[slice | common.IntIndex | EllipsisType, ...] - ) -> common.Field | core_defs.ScalarT: - new = self.ndarray[indices] + def _relative_slice( + self, indices: common.BufferSlice + ) -> tuple[common.Domain, common.BufferSlice]: new_dims = [] new_ranges = [] + expanded = _expand_ellipsis(indices, len(self.domain)) + if len(self.domain) < len(expanded): + raise IndexError( + f"Trying to index a `Field` with {len(self.domain)} dimensions with {indices}." + ) for (dim, rng), idx in itertools.zip_longest( # type: ignore[misc] # "slice" object is not iterable, not sure which slice... - self.domain, _expand_ellipsis(indices, len(self.domain)), fillvalue=slice(None) + self.domain, expanded, fillvalue=slice(None) ): if isinstance(idx, slice): new_dims.append(dim) @@ -301,12 +310,7 @@ def _getitem_relative_slice( assert common.is_int_index(idx) # not in new_domain new_domain = common.Domain(dims=tuple(new_dims), ranges=tuple(new_ranges)) - - if len(new_domain) == 0: - assert core_defs.is_scalar_type(new), new - return new # type: ignore[return-value] # I don't think we can express that we return `ScalarT` here - else: - return self.__class__.from_array(new, domain=new_domain) + return new_domain, indices # -- Specialized implementations for intrinsic operations on array fields -- @@ -338,26 +342,21 @@ def _getitem_relative_slice( def _np_cp_setitem( - self, + self: _BaseNdArrayField[common.DimsT, core_defs.ScalarT], index: common.FieldSlice, value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, ) -> None: - index = _tuplize_field_slice(index) + target_domain, target_slice = self._slice(index) + if common.is_field(value): - # TODO(havogt): in case of `is_field(value)` we should additionally check that `value.domain == self[slice].domain` + if not value.domain == target_domain: + raise ValueError( + f"Incompatible `Domain` in assignment. Source domain = {value.domain}, target domain = {target_domain}." + ) value = value.ndarray - if common.is_domain_slice(index): - slices = _get_slices_from_domain_slice(self.domain, index) - self.ndarray[slices] = value - return - - assert isinstance(index, tuple) - if all(isinstance(idx, slice) or common.is_int_index(idx) or idx is Ellipsis for idx in index): - self.ndarray[index] = value - return - - raise IndexError(f"Unsupported index type: {index}") + assert hasattr(self.ndarray, "__setitem__") + self.ndarray[target_slice] = value # -- Concrete array implementations -- @@ -466,7 +465,7 @@ def _builtins_broadcast( def _get_slices_from_domain_slice( domain: common.Domain, domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], -) -> tuple[slice | common.IntIndex | None, ...]: +) -> tuple[slice | common.IntIndex, ...]: """Generate slices for sub-array extraction based on named ranges or named indices within a Domain. This function generates a tuple of slices that can be used to extract sub-arrays from a field. The provided @@ -481,7 +480,7 @@ def _get_slices_from_domain_slice( specified in the Domain. If a dimension is not included in the named indices or ranges, a None is used to indicate expansion along that axis. """ - slice_indices: list[slice | common.IntIndex | None] = [] + slice_indices: list[slice | common.IntIndex] = [] for pos_old, (dim, _) in enumerate(domain): if (pos := _find_index_of_dim(dim, domain_slice)) is not None: diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index bda6ffb871..6ce2eb7f2a 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -477,3 +477,42 @@ def test_slice_range(): result = _slice_range(input_range, slice_obj) assert result == expected + + +@pytest.mark.parametrize( + "index, value", + [ + ((1, 1), 42.0), + ((1, slice(None)), np.ones((10,)) * 42.0), + ( + (1, slice(None)), + common.field(np.ones((10,)) * 42.0, domain=common.Domain((JDim,), (UnitRange(0, 10),))), + ), + ], +) +def test_setitem(index, value): + field = common.field( + np.arange(100).reshape(10, 10), + domain=common.Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange(0, 10))), + ) + + expected = np.copy(field.ndarray) + expected[index] = value + + field[index] = value + + assert np.allclose(field.ndarray, expected) + + +def test_setitem_wrong_domain(): + field = common.field( + np.arange(100).reshape(10, 10), + domain=common.Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange(0, 10))), + ) + + value_incompatible = common.field( + np.ones((10,)) * 42.0, domain=common.Domain((JDim,), (UnitRange(-5, 5),)) + ) + + with pytest.raises(ValueError, match=r"Incompatible `Domain`.*"): + field[(1, slice(None))] = value_incompatible From 907d9bcd168fed0b9a7fbe3d48ab9d3ff3381581 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 28 Aug 2023 12:08:39 +0000 Subject: [PATCH 04/21] extract domain slicing --- src/gt4py/next/embedded/common.py | 122 ++++++++++++++++ src/gt4py/next/embedded/nd_array_field.py | 131 +++--------------- .../embedded_tests/test_nd_array_field.py | 3 +- 3 files changed, 140 insertions(+), 116 deletions(-) create mode 100644 src/gt4py/next/embedded/common.py diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py new file mode 100644 index 0000000000..8ccdac122e --- /dev/null +++ b/src/gt4py/next/embedded/common.py @@ -0,0 +1,122 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import itertools +from types import EllipsisType +from typing import Any, Optional, Sequence, cast + +from gt4py.next import common + + +def sub_domain(domain: common.Domain, index: common.FieldSlice) -> common.Domain: + index = _tuplize_field_slice(index) + + if common.is_domain_slice(index): + return _absolute_sub_domain(domain, index) + + assert isinstance(index, tuple) + if all(isinstance(idx, slice) or common.is_int_index(idx) or idx is Ellipsis for idx in index): + return _relative_sub_domain(domain, index) + + raise IndexError(f"Unsupported index type: {index}") + + +def _relative_sub_domain(domain: common.Domain, index: common.BufferSlice) -> common.Domain: + new_dims = [] + new_ranges = [] + + expanded = _expand_ellipsis(index, len(domain)) + if len(domain) < len(expanded): + raise IndexError(f"Trying to index a `Field` with {len(domain)} dimensions with {index}.") + for (dim, rng), idx in itertools.zip_longest( # type: ignore[misc] # "slice" object is not iterable, not sure which slice... + domain, expanded, fillvalue=slice(None) + ): + if isinstance(idx, slice): + new_dims.append(dim) + new_ranges.append(_slice_range(rng, idx)) + else: + assert common.is_int_index(idx) # not in new_domain + + return common.Domain(dims=tuple(new_dims), ranges=tuple(new_ranges)) + + +def _absolute_sub_domain(domain: common.Domain, index: common.DomainSlice) -> common.Domain: + new_ranges = [] + new_dims = [] + + for i, dim in enumerate(domain.dims): + if (pos := _find_index_of_dim(dim, index)) is not None: + index_or_range = index[pos][1] + if isinstance(index_or_range, common.UnitRange): + new_ranges.append(index_or_range) + new_dims.append(dim) + else: + # dimension not mentioned in slice + new_ranges.append(domain.ranges[i]) + new_dims.append(dim) + + return common.Domain(dims=tuple(new_dims), ranges=tuple(new_ranges)) + + +def _tuplize_field_slice(v: common.FieldSlice) -> common.FieldSlice: + """ + Wrap a single index/slice/range into a tuple. + + Note: the condition is complex as `NamedRange`, `NamedIndex` are implemented as `tuple`. + """ + if ( + not isinstance(v, tuple) + and not common.is_domain_slice(v) + or common.is_named_index(v) + or common.is_named_range(v) + ): + return cast(common.FieldSlice, (v,)) + return v + + +def _expand_ellipsis( + indices: tuple[common.IntIndex | slice | EllipsisType, ...], target_size: int +) -> tuple[common.IntIndex | slice, ...]: + expanded_indices: list[common.IntIndex | slice] = [] + for idx in indices: + if idx is Ellipsis: + expanded_indices.extend([slice(None)] * (target_size - (len(indices) - 1))) + else: + expanded_indices.append(idx) + return tuple(expanded_indices) + + +def _slice_range(input_range: common.UnitRange, slice_obj: slice) -> common.UnitRange: + # handle slice(None) case + if slice_obj == slice(None): + return common.UnitRange(input_range.start, input_range.stop) + + start = ( + input_range.start if slice_obj.start is None or slice_obj.start >= 0 else input_range.stop + ) + (slice_obj.start or 0) + stop = ( + input_range.start if slice_obj.stop is None or slice_obj.stop >= 0 else input_range.stop + ) + (slice_obj.stop or len(input_range)) + + return common.UnitRange(start, stop) + + +def _find_index_of_dim( + dim: common.Dimension, + domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], +) -> Optional[int]: + for i, (d, _) in enumerate(domain_slice): + if dim == d: + return i + return None diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 6fdb220fe5..ecc158e820 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -16,16 +16,16 @@ import dataclasses import functools -import itertools from collections.abc import Callable, Sequence -from types import EllipsisType, ModuleType -from typing import Any, ClassVar, Optional, ParamSpec, TypeAlias, TypeVar, cast, overload +from types import ModuleType +from typing import Any, ClassVar, Optional, ParamSpec, TypeAlias, TypeVar, overload import numpy as np from numpy import typing as npt from gt4py._core import definitions as core_defs from gt4py.next import common +from gt4py.next.embedded import common as embedded_common from gt4py.next.ffront import fbuiltins @@ -254,63 +254,17 @@ def __invert__(self) -> _BaseNdArrayField: raise NotImplementedError("`__invert__` not implemented for non-`bool` fields.") def _slice(self, index: common.FieldSlice) -> tuple[common.Domain, common.BufferSlice]: - index = _tuplize_field_slice(index) - - if common.is_domain_slice(index): - return self._absolute_slice(index) - - assert isinstance(index, tuple) - if all( - isinstance(idx, slice) or common.is_int_index(idx) or idx is Ellipsis for idx in index - ): - return self._relative_slice(index) - - raise IndexError(f"Unsupported index type: {index}") - - def _absolute_slice( - self, index: common.DomainSlice - ) -> tuple[common.Domain, common.BufferSlice]: - slices = _get_slices_from_domain_slice(self.domain, index) - new_ranges = [] - new_dims = [] - - for i, dim in enumerate(self.domain.dims): - if (pos := _find_index_of_dim(dim, index)) is not None: - index_or_range = index[pos][1] - if isinstance(index_or_range, common.UnitRange): - new_ranges.append(index_or_range) - new_dims.append(dim) - else: - # dimension not mentioned in slice - new_ranges.append(self.domain.ranges[i]) - new_dims.append(dim) - - new_domain = common.Domain(dims=tuple(new_dims), ranges=tuple(new_ranges)) - - return new_domain, slices - - def _relative_slice( - self, indices: common.BufferSlice - ) -> tuple[common.Domain, common.BufferSlice]: - new_dims = [] - new_ranges = [] - - expanded = _expand_ellipsis(indices, len(self.domain)) - if len(self.domain) < len(expanded): - raise IndexError( - f"Trying to index a `Field` with {len(self.domain)} dimensions with {indices}." - ) - for (dim, rng), idx in itertools.zip_longest( # type: ignore[misc] # "slice" object is not iterable, not sure which slice... - self.domain, expanded, fillvalue=slice(None) - ): - if isinstance(idx, slice): - new_dims.append(dim) - new_ranges.append(_slice_range(rng, idx)) - else: - assert common.is_int_index(idx) # not in new_domain + new_domain = embedded_common.sub_domain(self.domain, index) + + index = embedded_common._tuplize_field_slice(index) - new_domain = common.Domain(dims=tuple(new_dims), ranges=tuple(new_ranges)) - return new_domain, indices + slice_ = ( + _get_slices_from_domain_slice(self.domain, index) + if common.is_domain_slice(index) + else index + ) + assert isinstance(slice_, common.BufferSlice) + return new_domain, slice_ # -- Specialized implementations for intrinsic operations on array fields -- @@ -404,38 +358,12 @@ def __setitem__( common.field.register(jnp.ndarray, JaxArrayField.from_array) -def _tuplize_field_slice(v: common.FieldSlice) -> common.FieldSlice: - """ - Wrap a single index/slice/range into a tuple. - - Note: the condition is complex as `NamedRange`, `NamedIndex` are implemented as `tuple`. - """ - if ( - not isinstance(v, tuple) - and not common.is_domain_slice(v) - or common.is_named_index(v) - or common.is_named_range(v) - ): - return cast(common.FieldSlice, (v,)) - return v - - -def _find_index_of_dim( - dim: common.Dimension, - domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], -) -> Optional[int]: - for i, (d, _) in enumerate(domain_slice): - if dim == d: - return i - return None - - def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]) -> common.Field: domain_slice: list[slice | None] = [] new_domain_dims = [] new_domain_ranges = [] for dim in new_dimensions: - if (pos := _find_index_of_dim(dim, field.domain)) is not None: + if (pos := embedded_common._find_index_of_dim(dim, field.domain)) is not None: domain_slice.append(slice(None)) new_domain_dims.append(dim) new_domain_ranges.append(field.domain[pos][1]) @@ -465,7 +393,7 @@ def _builtins_broadcast( def _get_slices_from_domain_slice( domain: common.Domain, domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], -) -> tuple[slice | common.IntIndex, ...]: +) -> common.BufferSlice: """Generate slices for sub-array extraction based on named ranges or named indices within a Domain. This function generates a tuple of slices that can be used to extract sub-arrays from a field. The provided @@ -483,7 +411,7 @@ def _get_slices_from_domain_slice( slice_indices: list[slice | common.IntIndex] = [] for pos_old, (dim, _) in enumerate(domain): - if (pos := _find_index_of_dim(dim, domain_slice)) is not None: + if (pos := embedded_common._find_index_of_dim(dim, domain_slice)) is not None: index_or_range = domain_slice[pos][1] slice_indices.append(_compute_slice(index_or_range, domain, pos_old)) else: @@ -519,30 +447,3 @@ def _compute_slice( return rng - domain.ranges[pos].start else: raise ValueError(f"Can only use integer or UnitRange ranges, provided type: {type(rng)}") - - -def _slice_range(input_range: common.UnitRange, slice_obj: slice) -> common.UnitRange: - # handle slice(None) case - if slice_obj == slice(None): - return common.UnitRange(input_range.start, input_range.stop) - - start = ( - input_range.start if slice_obj.start is None or slice_obj.start >= 0 else input_range.stop - ) + (slice_obj.start or 0) - stop = ( - input_range.start if slice_obj.stop is None or slice_obj.stop >= 0 else input_range.stop - ) + (slice_obj.stop or len(input_range)) - - return common.UnitRange(start, stop) - - -def _expand_ellipsis( - indices: tuple[common.IntIndex | slice | EllipsisType, ...], target_size: int -) -> tuple[common.IntIndex | slice, ...]: - expanded_indices: list[common.IntIndex | slice] = [] - for idx in indices: - if idx is Ellipsis: - expanded_indices.extend([slice(None)] * (target_size - (len(indices) - 1))) - else: - expanded_indices.append(idx) - return tuple(expanded_indices) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 6ce2eb7f2a..1b90eb164e 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -23,7 +23,8 @@ from gt4py.next import Dimension, common from gt4py.next.common import Domain, UnitRange from gt4py.next.embedded import nd_array_field -from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice, _slice_range +from gt4py.next.embedded.common import _slice_range +from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice from gt4py.next.ffront import fbuiltins from next_tests.integration_tests.feature_tests.math_builtin_test_data import math_builtin_test_data From dd8e0a68f564911ed382deed198e7d2d98a18ca8 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 28 Aug 2023 13:53:17 +0000 Subject: [PATCH 05/21] add tests --- src/gt4py/next/common.py | 6 + src/gt4py/next/embedded/nd_array_field.py | 2 +- .../unit_tests/embedded_tests/test_common.py | 131 ++++++++++++++++++ .../embedded_tests/test_nd_array_field.py | 10 -- 4 files changed, 138 insertions(+), 11 deletions(-) create mode 100644 tests/next_tests/unit_tests/embedded_tests/test_common.py diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 68e2a4d0b9..c197007fa7 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -170,6 +170,12 @@ def is_domain_slice(v: Any) -> TypeGuard[DomainSlice]: return isinstance(v, Sequence) and all(is_named_range(e) or is_named_index(e) for e in v) +def is_buffer_slice(v: Any) -> TypeGuard[BufferSlice]: + return isinstance(v, tuple) and all( + isinstance(e, slice) or is_int_index(e) or e is Ellipsis for e in v + ) + + @dataclasses.dataclass(frozen=True) class Domain(Sequence[NamedRange]): dims: tuple[Dimension, ...] diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index ecc158e820..55fbafd8e4 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -263,7 +263,7 @@ def _slice(self, index: common.FieldSlice) -> tuple[common.Domain, common.Buffer if common.is_domain_slice(index) else index ) - assert isinstance(slice_, common.BufferSlice) + assert common.is_buffer_slice(slice_), slice_ return new_domain, slice_ diff --git a/tests/next_tests/unit_tests/embedded_tests/test_common.py b/tests/next_tests/unit_tests/embedded_tests/test_common.py new file mode 100644 index 0000000000..0a101d8273 --- /dev/null +++ b/tests/next_tests/unit_tests/embedded_tests/test_common.py @@ -0,0 +1,131 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from typing import Sequence + +import pytest + +from gt4py.next import common +from gt4py.next.common import UnitRange +from gt4py.next.embedded.common import _slice_range, sub_domain + + +def _d(*dom: tuple[common.Dimension, tuple[int, int]]): + dims = [] + rngs = [] + for dim, (start, stop) in dom: + dims.append(dim) + rngs.append(common.UnitRange(start, stop)) + return common.Domain(tuple(dims), tuple(rngs)) + + +def test_slice_range(): + input_range = UnitRange(2, 10) + slice_obj = slice(2, -2) + expected = UnitRange(4, 8) + + result = _slice_range(input_range, slice_obj) + assert result == expected + + +I = common.Dimension("I") +J = common.Dimension("J") +K = common.Dimension("K") + + +@pytest.mark.parametrize( + "domain, index, expected", + [ + (_d((I, (2, 5))), 1, _d()), + (_d((I, (2, 5))), slice(1, 2), _d((I, (3, 4)))), + (_d((I, (2, 5))), (I, 2), _d()), + (_d((I, (2, 5))), (I, UnitRange(2, 3)), _d((I, (2, 3)))), + (_d((I, (-2, 3))), 1, _d()), + (_d((I, (-2, 3))), slice(1, 2), _d((I, (-1, 0)))), + (_d((I, (-2, 3))), (I, 1), _d()), + (_d((I, (-2, 3))), (I, UnitRange(2, 3)), _d((I, (2, 3)))), + (_d((I, (-2, 3))), -5, _d()), + # (_d((I, (-2, 3))), -6, IndexError), + # (_d((I, (-2, 3))), slice(-6, -7), IndexError), + (_d((I, (-2, 3))), 4, _d()), + # (_d((I, (-2, 3))), 5, IndexError), + # (_d((I, (-2, 3))), slice(4, 5), IndexError), + # (_d((I, (-2, 3))), (I, -3), IndexError), + # (_d((I, (-2, 3))), (I, UnitRange(-3, -2)), IndexError), + # (_d((I, (-2, 3))), (I, 3), IndexError), + # (_d((I, (-2, 3))), (I, UnitRange(3, 4)), IndexError), + ( + _d((I, (2, 5)), (J, (3, 6)), (K, (4, 7))), + 2, + _d((J, (3, 6)), (K, (4, 7))), + ), + ( + _d((I, (2, 5)), (J, (3, 6)), (K, (4, 7))), + slice(2, 3), + _d((I, (4, 5)), (J, (3, 6)), (K, (4, 7))), + ), + ( + _d((I, (2, 5)), (J, (3, 6)), (K, (4, 7))), + (I, 2), + _d((J, (3, 6)), (K, (4, 7))), + ), + ( + _d((I, (2, 5)), (J, (3, 6)), (K, (4, 7))), + (I, UnitRange(2, 3)), + _d((I, (2, 3)), (J, (3, 6)), (K, (4, 7))), + ), + ( + _d((I, (2, 5)), (J, (3, 6)), (K, (4, 7))), + (J, 3), + _d((I, (2, 5)), (K, (4, 7))), + ), + ( + _d((I, (2, 5)), (J, (3, 6)), (K, (4, 7))), + (J, UnitRange(4, 5)), + _d((I, (2, 5)), (J, (4, 5)), (K, (4, 7))), + ), + ( + _d((I, (2, 5)), (J, (3, 6)), (K, (4, 7))), + ((J, 3), (I, 2)), + _d((K, (4, 7))), + ), + ( + _d((I, (2, 5)), (J, (3, 6)), (K, (4, 7))), + ((J, UnitRange(4, 5)), (I, 2)), + _d((J, (4, 5)), (K, (4, 7))), + ), + ( + _d((I, (2, 5)), (J, (3, 6)), (K, (4, 7))), + (slice(1, 2), slice(2, 3)), + _d((I, (3, 4)), (J, (5, 6)), (K, (4, 7))), + ), + ( + _d((I, (2, 5)), (J, (3, 6)), (K, (4, 7))), + (Ellipsis, slice(2, 3)), + _d((I, (2, 5)), (J, (3, 6)), (K, (6, 7))), + ), + ( + _d((I, (2, 5)), (J, (3, 6)), (K, (4, 7))), + (slice(1, 2), Ellipsis, slice(2, 3)), + _d((I, (3, 4)), (J, (3, 6)), (K, (6, 7))), + ), + ], +) +def test_sub_domain(domain, index, expected): + if expected is IndexError: + with pytest.raises(IndexError): + print(sub_domain(domain, index)) + else: + result = sub_domain(domain, index) + assert result == expected diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 1b90eb164e..a81df1a977 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -23,7 +23,6 @@ from gt4py.next import Dimension, common from gt4py.next.common import Domain, UnitRange from gt4py.next.embedded import nd_array_field -from gt4py.next.embedded.common import _slice_range from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice from gt4py.next.ffront import fbuiltins @@ -471,15 +470,6 @@ def test_field_unsupported_index(index): field[index] -def test_slice_range(): - input_range = UnitRange(2, 10) - slice_obj = slice(2, -2) - expected = UnitRange(4, 8) - - result = _slice_range(input_range, slice_obj) - assert result == expected - - @pytest.mark.parametrize( "index, value", [ From 29ee689f8bc742199f383503d2a176d85a85746a Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 28 Aug 2023 23:32:22 +0200 Subject: [PATCH 06/21] domain_like and domain construction --- src/gt4py/next/common.py | 109 ++++++++++++++++-- src/gt4py/next/embedded/common.py | 21 ++-- src/gt4py/next/embedded/nd_array_field.py | 20 ++-- src/gt4py/next/utils.py | 6 +- .../unit_tests/embedded_tests/test_common.py | 93 +++++++-------- .../embedded_tests/test_nd_array_field.py | 52 ++++++--- tests/next_tests/unit_tests/test_common.py | 79 ++++++++++--- 7 files changed, 256 insertions(+), 124 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index c197007fa7..341831063b 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -19,7 +19,7 @@ import enum import functools import sys -from collections.abc import Sequence, Set +from collections.abc import Mapping, Sequence, Set from types import EllipsisType from typing import TypeGuard, overload @@ -90,6 +90,17 @@ def __post_init__(self): object.__setattr__(self, "start", 0) object.__setattr__(self, "stop", 0) + @classmethod + def from_unit_range_like(cls, r: UnitRangeLike) -> UnitRange: + assert is_unit_range_like(r) + if isinstance(r, UnitRange): + return r + if isinstance(r, range) and r.step == 1: + return cls(r.start, r.stop) + if isinstance(r, tuple) and isinstance(r[0], int) and isinstance(r[1], int): + return cls(r[0], r[1]) + raise RuntimeError("Unreachable") + @classmethod def infinity(cls) -> UnitRange: return cls(Infinity.negative(), Infinity.positive()) @@ -145,6 +156,10 @@ def __and__(self, other: Set[Any]) -> UnitRange: FieldSlice: TypeAlias = ( DomainSlice | BufferSlice | slice | IntIndex | EllipsisType | NamedRange | NamedIndex ) +UnitRangeLike: TypeAlias = UnitRange | range | tuple[int, int] +DomainLike: TypeAlias = ( + Sequence[tuple[Dimension, UnitRangeLike]] | Mapping[Dimension, UnitRangeLike] +) def is_int_index(p: Any) -> TypeGuard[IntIndex]: @@ -176,19 +191,86 @@ def is_buffer_slice(v: Any) -> TypeGuard[BufferSlice]: ) -@dataclasses.dataclass(frozen=True) +def is_unit_range_like(v: Any) -> TypeGuard[UnitRangeLike]: + return ( + isinstance(v, UnitRange) + or (isinstance(v, range) and v.step == 1) + or (isinstance(v, tuple) and isinstance(v[0], int) and isinstance(v[1], int)) + ) + + +def is_domain_like(v: Any) -> TypeGuard[DomainLike]: + return ( + isinstance(v, Sequence) + and all( + isinstance(e, tuple) and isinstance(e[0], Dimension) and is_unit_range_like(e[1]) + for e in v + ) + ) or ( + isinstance(v, Mapping) + and all(isinstance(d, Dimension) and is_unit_range_like(r) for d, r in v.items()) + ) + + +def to_named_range(v: tuple[Dimension, UnitRangeLike]) -> NamedRange: + return (v[0], UnitRange.from_unit_range_like(v[1])) + + +@dataclasses.dataclass(frozen=True, init=False) class Domain(Sequence[NamedRange]): dims: tuple[Dimension, ...] ranges: tuple[UnitRange, ...] - def __post_init__(self): + def __init__( + self, + *args: NamedRange, + dims: Optional[tuple[Dimension, ...]] = None, + ranges: Optional[tuple[UnitRange, ...]] = None, + ): + if dims is not None or ranges is not None: + if dims is None and ranges is None: + raise ValueError("Either both none of `dims` and `ranges` must be specified.") + if len(args) > 0: + raise ValueError( + "No extra `args` allowed when constructing fomr `dims` and `ranges`." + ) + + assert dims is not None + assert ranges is not None + if len(dims) != len(ranges): + raise ValueError( + f"Number of provided dimensions ({len(dims)}) does not match number of provided ranges ({len(ranges)})." + ) + + object.__setattr__(self, "dims", dims) + object.__setattr__(self, "ranges", ranges) + else: + assert all(is_named_range(arg) for arg in args) + dims, ranges = zip(*args) if len(args) > 0 else ((), ()) + object.__setattr__(self, "dims", tuple(dims)) + object.__setattr__(self, "ranges", tuple(ranges)) + if len(set(self.dims)) != len(self.dims): raise NotImplementedError(f"Domain dimensions must be unique, not {self.dims}.") - if len(self.dims) != len(self.ranges): - raise ValueError( - f"Number of provided dimensions ({len(self.dims)}) does not match number of provided ranges ({len(self.ranges)})." + @classmethod + def from_domain_like(cls, domain_like: DomainLike) -> Domain: + assert is_domain_like(domain_like) + if isinstance(domain_like, Domain): + return domain_like + if isinstance(domain_like, Sequence) and all( + isinstance(e, tuple) and isinstance(e[0], Dimension) and is_unit_range_like(e[1]) + for e in domain_like + ): + return cls(*tuple(to_named_range(d) for d in domain_like)) + if isinstance(domain_like, Mapping) and all( + isinstance(d, Dimension) and is_unit_range_like(r) for d, r in domain_like.items() + ): + return cls( + dims=tuple(domain_like.keys()), + ranges=tuple(UnitRange.from_unit_range_like(r) for r in domain_like.values()), ) + raise RuntimeError("Unreachable") def __len__(self) -> int: return len(self.ranges) @@ -211,7 +293,7 @@ def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain: elif isinstance(index, slice): dims_slice = self.dims[index] ranges_slice = self.ranges[index] - return Domain(dims_slice, ranges_slice) + return Domain(dims=dims_slice, ranges=ranges_slice) elif isinstance(index, Dimension): try: index_pos = self.dims.index(index) @@ -221,7 +303,7 @@ def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain: else: raise KeyError("Invalid index type, must be either int, slice, or Dimension.") - def __and__(self, other: "Domain") -> "Domain": + def __and__(self, other: Domain) -> Domain: broadcast_dims = tuple(promote_dims(self.dims, other.dims)) intersected_ranges = tuple( rng1 & rng2 @@ -230,7 +312,7 @@ def __and__(self, other: "Domain") -> "Domain": _broadcast_ranges(broadcast_dims, other.dims, other.ranges), ) ) - return Domain(broadcast_dims, intersected_ranges) + return Domain(dims=broadcast_dims, ranges=intersected_ranges) def _broadcast_ranges( @@ -388,7 +470,7 @@ def field( definition: Any, /, *, - domain: Optional[Any] = None, # TODO(havogt): provide domain_like to Domain conversion + domain: Optional[DomainLike] = None, dtype: Optional[core_defs.DType] = None, ) -> Field: raise NotImplementedError @@ -509,3 +591,10 @@ def promote_dims(*dims_list: Sequence[Dimension]) -> list[Dimension]: ) return topologically_sorted_list + return topologically_sorted_list + return topologically_sorted_list + return topologically_sorted_list + return topologically_sorted_list + return topologically_sorted_list + return topologically_sorted_list + return topologically_sorted_list diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 8ccdac122e..53c8cc30a6 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -33,8 +33,7 @@ def sub_domain(domain: common.Domain, index: common.FieldSlice) -> common.Domain def _relative_sub_domain(domain: common.Domain, index: common.BufferSlice) -> common.Domain: - new_dims = [] - new_ranges = [] + named_ranges: list[common.NamedRange] = [] expanded = _expand_ellipsis(index, len(domain)) if len(domain) < len(expanded): @@ -43,30 +42,25 @@ def _relative_sub_domain(domain: common.Domain, index: common.BufferSlice) -> co domain, expanded, fillvalue=slice(None) ): if isinstance(idx, slice): - new_dims.append(dim) - new_ranges.append(_slice_range(rng, idx)) + named_ranges.append((dim, _slice_range(rng, idx))) else: assert common.is_int_index(idx) # not in new_domain - return common.Domain(dims=tuple(new_dims), ranges=tuple(new_ranges)) + return common.Domain(*named_ranges) def _absolute_sub_domain(domain: common.Domain, index: common.DomainSlice) -> common.Domain: - new_ranges = [] - new_dims = [] - + named_ranges: list[common.NamedRange] = [] for i, dim in enumerate(domain.dims): if (pos := _find_index_of_dim(dim, index)) is not None: index_or_range = index[pos][1] if isinstance(index_or_range, common.UnitRange): - new_ranges.append(index_or_range) - new_dims.append(dim) + named_ranges.append((dim, index_or_range)) else: # dimension not mentioned in slice - new_ranges.append(domain.ranges[i]) - new_dims.append(dim) + named_ranges.append((dim, domain.ranges[i])) - return common.Domain(dims=tuple(new_dims), ranges=tuple(new_ranges)) + return common.Domain(*named_ranges) def _tuplize_field_slice(v: common.FieldSlice) -> common.FieldSlice: @@ -120,3 +114,4 @@ def _find_index_of_dim( if dim == d: return i return None + return None diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 55fbafd8e4..7867b02b32 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -162,9 +162,10 @@ def from_array( | core_defs.NDArrayObject, # TODO: NDArrayObject should be part of ArrayLike /, *, - domain: common.Domain, + domain: common.DomainLike, dtype_like: Optional[core_defs.DType] = None, # TODO define DTypeLike ) -> _BaseNdArrayField: + domain = common.Domain.from_domain_like(domain) xp = cls.array_ns xp_dtype = None if dtype_like is None else xp.dtype(core_defs.dtype(dtype_like).scalar_type) @@ -360,23 +361,17 @@ def __setitem__( def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]) -> common.Field: domain_slice: list[slice | None] = [] - new_domain_dims = [] - new_domain_ranges = [] + named_ranges = [] for dim in new_dimensions: if (pos := embedded_common._find_index_of_dim(dim, field.domain)) is not None: domain_slice.append(slice(None)) - new_domain_dims.append(dim) - new_domain_ranges.append(field.domain[pos][1]) + named_ranges.append((dim, field.domain[pos][1])) else: domain_slice.append(np.newaxis) - new_domain_dims.append(dim) - new_domain_ranges.append( - common.UnitRange(common.Infinity.negative(), common.Infinity.positive()) + named_ranges.append( + (dim, common.UnitRange(common.Infinity.negative(), common.Infinity.positive())) ) - return common.field( - field.ndarray[tuple(domain_slice)], - domain=common.Domain(tuple(new_domain_dims), tuple(new_domain_ranges)), - ) + return common.field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges)) def _builtins_broadcast( @@ -447,3 +442,4 @@ def _compute_slice( return rng - domain.ranges[pos].start else: raise ValueError(f"Can only use integer or UnitRange ranges, provided type: {type(rng)}") + raise ValueError(f"Can only use integer or UnitRange ranges, provided type: {type(rng)}") diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index 0c5de764f2..3e7dc2d4d3 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Any, ClassVar +from typing import Any, ClassVar, TypeGuard class RecursionGuard: @@ -49,3 +49,7 @@ def __enter__(self): def __exit__(self, *exc): self.guarded_objects.remove(id(self.obj)) + + +def is_tuple_of(v: Any, t: type) -> TypeGuard[tuple]: + return isinstance(v, tuple) and all(isinstance(e, t) for e in v) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_common.py b/tests/next_tests/unit_tests/embedded_tests/test_common.py index 0a101d8273..065a8ccd6f 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_common.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_common.py @@ -21,15 +21,6 @@ from gt4py.next.embedded.common import _slice_range, sub_domain -def _d(*dom: tuple[common.Dimension, tuple[int, int]]): - dims = [] - rngs = [] - for dim, (start, stop) in dom: - dims.append(dim) - rngs.append(common.UnitRange(start, stop)) - return common.Domain(tuple(dims), tuple(rngs)) - - def test_slice_range(): input_range = UnitRange(2, 10) slice_obj = slice(2, -2) @@ -47,85 +38,87 @@ def test_slice_range(): @pytest.mark.parametrize( "domain, index, expected", [ - (_d((I, (2, 5))), 1, _d()), - (_d((I, (2, 5))), slice(1, 2), _d((I, (3, 4)))), - (_d((I, (2, 5))), (I, 2), _d()), - (_d((I, (2, 5))), (I, UnitRange(2, 3)), _d((I, (2, 3)))), - (_d((I, (-2, 3))), 1, _d()), - (_d((I, (-2, 3))), slice(1, 2), _d((I, (-1, 0)))), - (_d((I, (-2, 3))), (I, 1), _d()), - (_d((I, (-2, 3))), (I, UnitRange(2, 3)), _d((I, (2, 3)))), - (_d((I, (-2, 3))), -5, _d()), - # (_d((I, (-2, 3))), -6, IndexError), - # (_d((I, (-2, 3))), slice(-6, -7), IndexError), - (_d((I, (-2, 3))), 4, _d()), - # (_d((I, (-2, 3))), 5, IndexError), - # (_d((I, (-2, 3))), slice(4, 5), IndexError), - # (_d((I, (-2, 3))), (I, -3), IndexError), - # (_d((I, (-2, 3))), (I, UnitRange(-3, -2)), IndexError), - # (_d((I, (-2, 3))), (I, 3), IndexError), - # (_d((I, (-2, 3))), (I, UnitRange(3, 4)), IndexError), + ([(I, (2, 5))], 1, []), + ([(I, (2, 5))], slice(1, 2), [(I, (3, 4))]), + ([(I, (2, 5))], (I, 2), []), + ([(I, (2, 5))], (I, UnitRange(2, 3)), [(I, (2, 3))]), + ([(I, (-2, 3))], 1, []), + ([(I, (-2, 3))], slice(1, 2), [(I, (-1, 0))]), + ([(I, (-2, 3))], (I, 1), []), + ([(I, (-2, 3))], (I, UnitRange(2, 3)), [(I, (2, 3))]), + ([(I, (-2, 3))], -5, []), + # ([(I, (-2, 3))], -6, IndexError), + # ([(I, (-2, 3))], slice(-6, -7), IndexError), + ([(I, (-2, 3))], 4, []), + # ([(I, (-2, 3))], 5, IndexError), + # ([(I, (-2, 3))], slice(4, 5), IndexError), + # ([(I, (-2, 3))], (I, -3), IndexError), + # ([(I, (-2, 3))], (I, UnitRange(-3, -2)), IndexError), + # ([(I, (-2, 3))], (I, 3), IndexError), + # ([(I, (-2, 3))], (I, UnitRange(3, 4)), IndexError), ( - _d((I, (2, 5)), (J, (3, 6)), (K, (4, 7))), + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], 2, - _d((J, (3, 6)), (K, (4, 7))), + [(J, (3, 6)), (K, (4, 7))], ), ( - _d((I, (2, 5)), (J, (3, 6)), (K, (4, 7))), + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], slice(2, 3), - _d((I, (4, 5)), (J, (3, 6)), (K, (4, 7))), + [(I, (4, 5)), (J, (3, 6)), (K, (4, 7))], ), ( - _d((I, (2, 5)), (J, (3, 6)), (K, (4, 7))), + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], (I, 2), - _d((J, (3, 6)), (K, (4, 7))), + [(J, (3, 6)), (K, (4, 7))], ), ( - _d((I, (2, 5)), (J, (3, 6)), (K, (4, 7))), + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], (I, UnitRange(2, 3)), - _d((I, (2, 3)), (J, (3, 6)), (K, (4, 7))), + [(I, (2, 3)), (J, (3, 6)), (K, (4, 7))], ), ( - _d((I, (2, 5)), (J, (3, 6)), (K, (4, 7))), + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], (J, 3), - _d((I, (2, 5)), (K, (4, 7))), + [(I, (2, 5)), (K, (4, 7))], ), ( - _d((I, (2, 5)), (J, (3, 6)), (K, (4, 7))), + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], (J, UnitRange(4, 5)), - _d((I, (2, 5)), (J, (4, 5)), (K, (4, 7))), + [(I, (2, 5)), (J, (4, 5)), (K, (4, 7))], ), ( - _d((I, (2, 5)), (J, (3, 6)), (K, (4, 7))), + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], ((J, 3), (I, 2)), - _d((K, (4, 7))), + [(K, (4, 7))], ), ( - _d((I, (2, 5)), (J, (3, 6)), (K, (4, 7))), + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], ((J, UnitRange(4, 5)), (I, 2)), - _d((J, (4, 5)), (K, (4, 7))), + [(J, (4, 5)), (K, (4, 7))], ), ( - _d((I, (2, 5)), (J, (3, 6)), (K, (4, 7))), + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], (slice(1, 2), slice(2, 3)), - _d((I, (3, 4)), (J, (5, 6)), (K, (4, 7))), + [(I, (3, 4)), (J, (5, 6)), (K, (4, 7))], ), ( - _d((I, (2, 5)), (J, (3, 6)), (K, (4, 7))), + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], (Ellipsis, slice(2, 3)), - _d((I, (2, 5)), (J, (3, 6)), (K, (6, 7))), + [(I, (2, 5)), (J, (3, 6)), (K, (6, 7))], ), ( - _d((I, (2, 5)), (J, (3, 6)), (K, (4, 7))), + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], (slice(1, 2), Ellipsis, slice(2, 3)), - _d((I, (3, 4)), (J, (3, 6)), (K, (6, 7))), + [(I, (3, 4)), (J, (3, 6)), (K, (6, 7))], ), ], ) def test_sub_domain(domain, index, expected): + domain = common.Domain.from_domain_like(domain) if expected is IndexError: with pytest.raises(IndexError): - print(sub_domain(domain, index)) + sub_domain(domain, index) else: + expected = common.Domain.from_domain_like(expected) result = sub_domain(domain, index) assert result == expected diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index a81df1a977..16711e3203 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -75,7 +75,7 @@ def _make_field(lst: Iterable, nd_array_implementation, *, dtype=None): dtype = nd_array_implementation.float32 return common.field( nd_array_implementation.asarray(lst, dtype=dtype), - domain=common.Domain((common.Dimension("foo"),), (common.UnitRange(0, len(lst)),)), + domain={common.Dimension("foo"): (0, len(lst))}, ) @@ -189,10 +189,8 @@ def product_nd_array_implementation(request): def test_mixed_fields(product_nd_array_implementation): first_impl, second_impl = product_nd_array_implementation - if (first_impl.__name__ == "cupy" and second_impl.__name__ == "numpy") or ( - first_impl.__name__ == "numpy" and second_impl.__name__ == "cupy" - ): - pytest.skip("Binary operation between CuPy and NumPy requires explicit conversion.") + if "numpy" in first_impl.__name__ and "cupy" in second_impl.__name__: + pytest.skip("Binary operation between NumPy and CuPy requires explicit conversion.") inp_a = [-1.0, 4.2, 42] inp_b = [2.0, 3.0, -3.0] @@ -371,19 +369,23 @@ def test_absolute_indexing_value_return(): ( (slice(None, 5), slice(None, 2)), (5, 2), - Domain((IDim, JDim), (UnitRange(5, 10), UnitRange(2, 4))), + Domain((IDim, UnitRange(5, 10)), (JDim, UnitRange(2, 4))), + ), + ((slice(None, 5),), (5, 10), Domain((IDim, UnitRange(5, 10)), (JDim, UnitRange(2, 12)))), + ( + (Ellipsis, 1), + (10,), + Domain((IDim, UnitRange(5, 15))), ), - ((slice(None, 5),), (5, 10), Domain((IDim, JDim), (UnitRange(5, 10), UnitRange(2, 12)))), - ((Ellipsis, 1), (10,), Domain((IDim,), (UnitRange(5, 15),))), ( (slice(2, 3), slice(5, 7)), (1, 2), - Domain((IDim, JDim), (UnitRange(7, 8), UnitRange(7, 9))), + Domain((IDim, UnitRange(7, 8)), (JDim, UnitRange(7, 9))), ), ( (slice(1, 2), 0), (1,), - Domain((IDim,), (UnitRange(6, 7),)), + Domain((IDim, UnitRange(6, 7))), ), ], ) @@ -400,32 +402,44 @@ def test_relative_indexing_slice_2D(index, expected_shape, expected_domain): @pytest.mark.parametrize( "index, expected_shape, expected_domain", [ - ((1, slice(None), 2), (15,), Domain((JDim,), (UnitRange(10, 25),))), + ((1, slice(None), 2), (15,), Domain(dims=(JDim,), ranges=(UnitRange(10, 25),))), ( (slice(None), slice(None), 2), (10, 15), - Domain((IDim, JDim), (UnitRange(5, 15), UnitRange(10, 25))), + Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 15), UnitRange(10, 25))), ), ( (slice(None),), (10, 15, 10), - Domain((IDim, JDim, KDim), (UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), + ), ), ( (slice(None), slice(None), slice(None)), (10, 15, 10), - Domain((IDim, JDim, KDim), (UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), + ), ), ( (slice(None)), (10, 15, 10), - Domain((IDim, JDim, KDim), (UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), + ), ), - ((0, Ellipsis, 0), (15,), Domain((JDim,), (UnitRange(10, 25),))), + ((0, Ellipsis, 0), (15,), Domain(dims=(JDim,), ranges=(UnitRange(10, 25),))), ( Ellipsis, (10, 15, 10), - Domain((IDim, JDim, KDim), (UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), + ), ), ], ) @@ -477,7 +491,7 @@ def test_field_unsupported_index(index): ((1, slice(None)), np.ones((10,)) * 42.0), ( (1, slice(None)), - common.field(np.ones((10,)) * 42.0, domain=common.Domain((JDim,), (UnitRange(0, 10),))), + common.field(np.ones((10,)) * 42.0, domain=common.Domain((JDim, UnitRange(0, 10)))), ), ], ) @@ -502,7 +516,7 @@ def test_setitem_wrong_domain(): ) value_incompatible = common.field( - np.ones((10,)) * 42.0, domain=common.Domain((JDim,), (UnitRange(-5, 5),)) + np.ones((10,)) * 42.0, domain=common.Domain((JDim, UnitRange(-5, 5))) ) with pytest.raises(ValueError, match=r"Incompatible `Domain`.*"): diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index 8cdc96254c..3eba53963c 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -15,7 +15,15 @@ import pytest -from gt4py.next.common import Dimension, DimensionKind, Domain, Infinity, UnitRange, promote_dims +from gt4py.next.common import ( + Dimension, + DimensionKind, + Domain, + Infinity, + UnitRange, + promote_dims, + to_named_range, +) IDim = Dimension("IDim") @@ -26,14 +34,7 @@ @pytest.fixture def domain(): - range1 = UnitRange(0, 10) - range2 = UnitRange(5, 15) - range3 = UnitRange(20, 30) - - dimensions = (IDim, JDim, KDim) - ranges = (range1, range2, range3) - - return Domain(dimensions, ranges) + return Domain((IDim, UnitRange(0, 10)), (JDim, UnitRange(5, 15)), (KDim, UnitRange(20, 30))) def test_empty_range(): @@ -53,6 +54,11 @@ def test_unit_range_length(rng): assert len(rng) == 10 +@pytest.mark.parametrize("rng_like", [(2, 4), range(2, 4), UnitRange(2, 4)]) +def test_from_unit_range_like(rng_like): + assert UnitRange.from_unit_range_like(rng_like) == UnitRange(2, 4) + + def test_unit_range_repr(rng): assert repr(rng) == "UnitRange(-5, 5)" @@ -142,10 +148,36 @@ def test_mixed_infinity_range(): assert len(mixed_inf_range) == Infinity.positive() +@pytest.mark.parametrize( + "named_rng_like", + [ + (IDim, (2, 4)), + (IDim, range(2, 4)), + (IDim, UnitRange(2, 4)), + ], +) +def test_to_named_range(named_rng_like): + assert to_named_range(named_rng_like) == (IDim, UnitRange(2, 4)) + + def test_domain_length(domain): assert len(domain) == 3 +@pytest.mark.parametrize( + "domain_like", + [ + (Domain(dims=(IDim, JDim), ranges=(UnitRange(2, 4), UnitRange(3, 5)))), + ((IDim, (2, 4)), (JDim, (3, 5))), + ({IDim: (2, 4), JDim: (3, 5)}), + ], +) +def test_from_domain_like(domain_like): + assert Domain.from_domain_like(domain_like) == Domain( + dims=(IDim, JDim), ranges=(UnitRange(2, 4), UnitRange(3, 5)) + ) + + def test_domain_iteration(domain): iterated_values = [val for val in domain] assert iterated_values == list(zip(domain.dims, domain.ranges)) @@ -160,16 +192,25 @@ def test_domain_contains_named_range(domain): "second_domain, expected", [ ( - Domain((IDim, JDim), (UnitRange(2, 12), UnitRange(7, 17))), - Domain((IDim, JDim, KDim), (UnitRange(2, 10), UnitRange(7, 15), UnitRange(20, 30))), + Domain(dims=(IDim, JDim), ranges=(UnitRange(2, 12), UnitRange(7, 17))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(2, 10), UnitRange(7, 15), UnitRange(20, 30)), + ), ), ( - Domain((IDim, KDim), (UnitRange(2, 12), UnitRange(7, 27))), - Domain((IDim, JDim, KDim), (UnitRange(2, 10), UnitRange(5, 15), UnitRange(20, 27))), + Domain(dims=(IDim, KDim), ranges=(UnitRange(2, 12), UnitRange(7, 27))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(2, 10), UnitRange(5, 15), UnitRange(20, 27)), + ), ), ( - Domain((JDim, KDim), (UnitRange(2, 12), UnitRange(4, 27))), - Domain((IDim, JDim, KDim), (UnitRange(0, 10), UnitRange(5, 12), UnitRange(20, 27))), + Domain(dims=(JDim, KDim), ranges=(UnitRange(2, 12), UnitRange(4, 27))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(0, 10), UnitRange(5, 12), UnitRange(20, 27)), + ), ), ], ) @@ -181,9 +222,7 @@ def test_domain_intersection_different_dimensions(domain, second_domain, expecte def test_domain_intersection_reversed_dimensions(domain): - dimensions = (JDim, IDim) - ranges = (UnitRange(2, 12), UnitRange(7, 17)) - domain2 = Domain(dimensions, ranges) + domain2 = Domain(dims=(JDim, IDim), ranges=(UnitRange(2, 12), UnitRange(7, 17))) with pytest.raises( ValueError, @@ -249,7 +288,7 @@ def test_domain_repeat_dims(): dims = (IDim, JDim, IDim) ranges = (UnitRange(0, 5), UnitRange(0, 8), UnitRange(0, 3)) with pytest.raises(NotImplementedError, match=r"Domain dimensions must be unique, not .*"): - Domain(dims, ranges) + Domain(dims=dims, ranges=ranges) def test_domain_dims_ranges_length_mismatch(): @@ -305,3 +344,5 @@ def test_dimension_promotion( promote_dims(*dim_list) assert exc_info.match(expected_error_msg) + assert exc_info.match(expected_error_msg) + assert exc_info.match(expected_error_msg) From 0c858f327f4d97d522fd3aa28789d13d6216bb60 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 29 Aug 2023 17:51:16 +0200 Subject: [PATCH 07/21] add out-of-bounds check --- src/gt4py/next/common.py | 5 +-- src/gt4py/next/embedded/common.py | 42 +++++++++++++++---- src/gt4py/next/embedded/exceptions.py | 29 +++++++++++++ src/gt4py/next/errors/exceptions.py | 12 +++--- .../unit_tests/embedded_tests/test_common.py | 21 ++++++---- 5 files changed, 85 insertions(+), 24 deletions(-) create mode 100644 src/gt4py/next/embedded/exceptions.py diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 341831063b..d13ca4f03a 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -151,11 +151,10 @@ def __and__(self, other: Set[Any]) -> UnitRange: DomainRange: TypeAlias = UnitRange | IntIndex NamedRange: TypeAlias = tuple[Dimension, UnitRange] NamedIndex: TypeAlias = tuple[Dimension, IntIndex] +AnyIndex: TypeAlias = IntIndex | NamedRange | NamedIndex | slice | EllipsisType DomainSlice: TypeAlias = Sequence[NamedRange | NamedIndex] BufferSlice: TypeAlias = tuple[slice | IntIndex | EllipsisType, ...] -FieldSlice: TypeAlias = ( - DomainSlice | BufferSlice | slice | IntIndex | EllipsisType | NamedRange | NamedIndex -) +FieldSlice: TypeAlias = DomainSlice | BufferSlice | AnyIndex UnitRangeLike: TypeAlias = UnitRange | range | tuple[int, int] DomainLike: TypeAlias = ( Sequence[tuple[Dimension, UnitRangeLike]] | Mapping[Dimension, UnitRangeLike] diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 53c8cc30a6..9f2303d95d 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -17,6 +17,7 @@ from typing import Any, Optional, Sequence, cast from gt4py.next import common +from gt4py.next.embedded import exceptions as embedded_exceptions def sub_domain(domain: common.Domain, index: common.FieldSlice) -> common.Domain: @@ -42,20 +43,45 @@ def _relative_sub_domain(domain: common.Domain, index: common.BufferSlice) -> co domain, expanded, fillvalue=slice(None) ): if isinstance(idx, slice): - named_ranges.append((dim, _slice_range(rng, idx))) + try: + sliced = _slice_range(rng, idx) + named_ranges.append((dim, sliced)) + except IndexError: + raise embedded_exceptions.IndexOutOfBounds( + domain=domain, indices=index, index=idx, dim=dim + ) else: - assert common.is_int_index(idx) # not in new_domain + # not in new domain + assert common.is_int_index(idx) + new_index = (rng.start if idx >= 0 else rng.stop) + idx + if new_index < rng.start or new_index >= rng.stop: + raise embedded_exceptions.IndexOutOfBounds( + domain=domain, indices=index, index=idx, dim=dim + ) return common.Domain(*named_ranges) def _absolute_sub_domain(domain: common.Domain, index: common.DomainSlice) -> common.Domain: named_ranges: list[common.NamedRange] = [] - for i, dim in enumerate(domain.dims): + for i, (dim, rng) in enumerate(domain): if (pos := _find_index_of_dim(dim, index)) is not None: - index_or_range = index[pos][1] - if isinstance(index_or_range, common.UnitRange): - named_ranges.append((dim, index_or_range)) + named_idx = index[pos] + idx = named_idx[1] + if isinstance(idx, common.UnitRange): + if not idx <= rng: + raise embedded_exceptions.IndexOutOfBounds( + domain=domain, indices=index, index=named_idx, dim=dim + ) + + named_ranges.append((dim, idx)) + else: + # not in new domain + assert common.is_int_index(idx) + if idx < rng.start or idx >= rng.stop: + raise embedded_exceptions.IndexOutOfBounds( + domain=domain, indices=index, index=named_idx, dim=dim + ) else: # dimension not mentioned in slice named_ranges.append((dim, domain.ranges[i])) @@ -92,7 +118,6 @@ def _expand_ellipsis( def _slice_range(input_range: common.UnitRange, slice_obj: slice) -> common.UnitRange: - # handle slice(None) case if slice_obj == slice(None): return common.UnitRange(input_range.start, input_range.stop) @@ -103,6 +128,9 @@ def _slice_range(input_range: common.UnitRange, slice_obj: slice) -> common.Unit input_range.start if slice_obj.stop is None or slice_obj.stop >= 0 else input_range.stop ) + (slice_obj.stop or len(input_range)) + if start < input_range.start or stop > input_range.stop: + raise IndexError() + return common.UnitRange(start, stop) diff --git a/src/gt4py/next/embedded/exceptions.py b/src/gt4py/next/embedded/exceptions.py new file mode 100644 index 0000000000..78c8be5f4b --- /dev/null +++ b/src/gt4py/next/embedded/exceptions.py @@ -0,0 +1,29 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from gt4py.next import common +from gt4py.next.errors import exceptions as gt4py_exceptions + + +class IndexOutOfBounds(gt4py_exceptions.GT4PyError): + def __init__( + self, + domain: common.Domain, + indices: common.FieldSlice, + index: common.AnyIndex, + dim: common.Dimension, + ): + super().__init__( + f"Out of bounds: slicing {domain} with index `{indices}`, `{index}` is out of bounds in dimension `{dim}`." + ) diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index 74230263db..e956858549 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -33,17 +33,19 @@ from . import formatting -class DSLError(Exception): +class GT4PyError(Exception): + @property + def message(self) -> str: + return self.args[0] + + +class DSLError(GT4PyError): location: Optional[SourceLocation] def __init__(self, location: Optional[SourceLocation], message: str) -> None: self.location = location super().__init__(message) - @property - def message(self) -> str: - return self.args[0] - def with_location(self, location: Optional[SourceLocation]) -> DSLError: self.location = location return self diff --git a/tests/next_tests/unit_tests/embedded_tests/test_common.py b/tests/next_tests/unit_tests/embedded_tests/test_common.py index 065a8ccd6f..2e43a42d1b 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_common.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_common.py @@ -18,6 +18,7 @@ from gt4py.next import common from gt4py.next.common import UnitRange +from gt4py.next.embedded import exceptions as embedded_exceptions from gt4py.next.embedded.common import _slice_range, sub_domain @@ -47,15 +48,17 @@ def test_slice_range(): ([(I, (-2, 3))], (I, 1), []), ([(I, (-2, 3))], (I, UnitRange(2, 3)), [(I, (2, 3))]), ([(I, (-2, 3))], -5, []), - # ([(I, (-2, 3))], -6, IndexError), - # ([(I, (-2, 3))], slice(-6, -7), IndexError), + ([(I, (-2, 3))], -6, IndexError), + ([(I, (-2, 3))], slice(-7, -6), IndexError), + ([(I, (-2, 3))], slice(-6, -7), IndexError), ([(I, (-2, 3))], 4, []), - # ([(I, (-2, 3))], 5, IndexError), - # ([(I, (-2, 3))], slice(4, 5), IndexError), - # ([(I, (-2, 3))], (I, -3), IndexError), - # ([(I, (-2, 3))], (I, UnitRange(-3, -2)), IndexError), - # ([(I, (-2, 3))], (I, 3), IndexError), - # ([(I, (-2, 3))], (I, UnitRange(3, 4)), IndexError), + ([(I, (-2, 3))], 5, IndexError), + ([(I, (-2, 3))], slice(4, 5), [(I, (2, 3))]), + ([(I, (-2, 3))], slice(5, 6), IndexError), + ([(I, (-2, 3))], (I, -3), IndexError), + ([(I, (-2, 3))], (I, UnitRange(-3, -2)), IndexError), + ([(I, (-2, 3))], (I, 3), IndexError), + ([(I, (-2, 3))], (I, UnitRange(3, 4)), IndexError), ( [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], 2, @@ -116,7 +119,7 @@ def test_slice_range(): def test_sub_domain(domain, index, expected): domain = common.Domain.from_domain_like(domain) if expected is IndexError: - with pytest.raises(IndexError): + with pytest.raises(embedded_exceptions.IndexOutOfBounds): sub_domain(domain, index) else: expected = common.Domain.from_domain_like(expected) From 9851de2ddf8c6f709cdaa2897791d9fa7679cf5f Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 29 Aug 2023 21:29:16 +0200 Subject: [PATCH 08/21] fix tests --- .../unit_tests/embedded_tests/test_nd_array_field.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 16711e3203..95093c8307 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -22,7 +22,7 @@ from gt4py.next import Dimension, common from gt4py.next.common import Domain, UnitRange -from gt4py.next.embedded import nd_array_field +from gt4py.next.embedded import exceptions as embedded_exceptions, nd_array_field from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice from gt4py.next.ffront import fbuiltins @@ -336,7 +336,7 @@ def test_get_slices_invalid_type(): (JDim, KDim), (2, 15), ), - ((IDim, 1), (JDim, KDim), (10, 15)), + ((IDim, 5), (JDim, KDim), (10, 15)), ((IDim, UnitRange(5, 7)), (IDim, JDim, KDim), (2, 10, 15)), ], ) @@ -354,13 +354,13 @@ def test_absolute_indexing(domain_slice, expected_dimensions, expected_shape): def test_absolute_indexing_value_return(): domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(10, 20), UnitRange(5, 15))) - field = common.field(np.ones((10, 10), dtype=np.int32), domain=domain) + field = common.field(np.reshape(np.arange(100, dtype=np.int32), (10, 10)), domain=domain) - named_index = ((IDim, 2), (JDim, 4)) + named_index = ((IDim, 12), (JDim, 6)) value = field[named_index] assert isinstance(value, np.int32) - assert value == 1 + assert value == 21 @pytest.mark.parametrize( @@ -472,7 +472,7 @@ def test_relative_indexing_out_of_bounds(lazy_slice): domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(3, 13), UnitRange(-5, 5))) field = common.field(np.ones((10, 10)), domain=domain) - with pytest.raises(IndexError): + with pytest.raises((embedded_exceptions.IndexOutOfBounds, IndexError)): lazy_slice(field) From 03880a09c5acc626ac7675358df2246e57306c3d Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 29 Aug 2023 21:37:35 +0200 Subject: [PATCH 09/21] test alike constructors --- src/gt4py/next/common.py | 64 +++++++++---------- src/gt4py/next/embedded/nd_array_field.py | 2 +- .../unit_tests/embedded_tests/test_common.py | 4 +- tests/next_tests/unit_tests/test_common.py | 62 +++++++++--------- 4 files changed, 67 insertions(+), 65 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index d13ca4f03a..8857093057 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -90,17 +90,6 @@ def __post_init__(self): object.__setattr__(self, "start", 0) object.__setattr__(self, "stop", 0) - @classmethod - def from_unit_range_like(cls, r: UnitRangeLike) -> UnitRange: - assert is_unit_range_like(r) - if isinstance(r, UnitRange): - return r - if isinstance(r, range) and r.step == 1: - return cls(r.start, r.stop) - if isinstance(r, tuple) and isinstance(r[0], int) and isinstance(r[1], int): - return cls(r[0], r[1]) - raise RuntimeError("Unreachable") - @classmethod def infinity(cls) -> UnitRange: return cls(Infinity.negative(), Infinity.positive()) @@ -147,6 +136,17 @@ def __and__(self, other: Set[Any]) -> UnitRange: raise NotImplementedError("Can only find the intersection between UnitRange instances.") +def unit_range(r: UnitRangeLike) -> UnitRange: + assert is_unit_range_like(r) + if isinstance(r, UnitRange): + return r + if isinstance(r, range) and r.step == 1: + return UnitRange(r.start, r.stop) + if isinstance(r, tuple) and isinstance(r[0], int) and isinstance(r[1], int): + return UnitRange(r[0], r[1]) + raise RuntimeError("Unreachable") + + IntIndex: TypeAlias = int | np.integer DomainRange: TypeAlias = UnitRange | IntIndex NamedRange: TypeAlias = tuple[Dimension, UnitRange] @@ -211,8 +211,8 @@ def is_domain_like(v: Any) -> TypeGuard[DomainLike]: ) -def to_named_range(v: tuple[Dimension, UnitRangeLike]) -> NamedRange: - return (v[0], UnitRange.from_unit_range_like(v[1])) +def named_range(v: tuple[Dimension, UnitRangeLike]) -> NamedRange: + return (v[0], unit_range(v[1])) @dataclasses.dataclass(frozen=True, init=False) @@ -252,25 +252,6 @@ def __init__( if len(set(self.dims)) != len(self.dims): raise NotImplementedError(f"Domain dimensions must be unique, not {self.dims}.") - @classmethod - def from_domain_like(cls, domain_like: DomainLike) -> Domain: - assert is_domain_like(domain_like) - if isinstance(domain_like, Domain): - return domain_like - if isinstance(domain_like, Sequence) and all( - isinstance(e, tuple) and isinstance(e[0], Dimension) and is_unit_range_like(e[1]) - for e in domain_like - ): - return cls(*tuple(to_named_range(d) for d in domain_like)) - if isinstance(domain_like, Mapping) and all( - isinstance(d, Dimension) and is_unit_range_like(r) for d, r in domain_like.items() - ): - return cls( - dims=tuple(domain_like.keys()), - ranges=tuple(UnitRange.from_unit_range_like(r) for r in domain_like.values()), - ) - raise RuntimeError("Unreachable") - def __len__(self) -> int: return len(self.ranges) @@ -314,6 +295,25 @@ def __and__(self, other: Domain) -> Domain: return Domain(dims=broadcast_dims, ranges=intersected_ranges) +def domain(domain_like: DomainLike) -> Domain: + assert is_domain_like(domain_like) + if isinstance(domain_like, Domain): + return domain_like + if isinstance(domain_like, Sequence) and all( + isinstance(e, tuple) and isinstance(e[0], Dimension) and is_unit_range_like(e[1]) + for e in domain_like + ): + return Domain(*tuple(named_range(d) for d in domain_like)) + if isinstance(domain_like, Mapping) and all( + isinstance(d, Dimension) and is_unit_range_like(r) for d, r in domain_like.items() + ): + return Domain( + dims=tuple(domain_like.keys()), + ranges=tuple(unit_range(r) for r in domain_like.values()), + ) + raise RuntimeError("Unreachable") + + def _broadcast_ranges( broadcast_dims: Sequence[Dimension], dims: Sequence[Dimension], ranges: Sequence[UnitRange] ) -> tuple[UnitRange, ...]: diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 7867b02b32..2b4e49048f 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -165,7 +165,7 @@ def from_array( domain: common.DomainLike, dtype_like: Optional[core_defs.DType] = None, # TODO define DTypeLike ) -> _BaseNdArrayField: - domain = common.Domain.from_domain_like(domain) + domain = common.domain(domain) xp = cls.array_ns xp_dtype = None if dtype_like is None else xp.dtype(core_defs.dtype(dtype_like).scalar_type) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_common.py b/tests/next_tests/unit_tests/embedded_tests/test_common.py index 2e43a42d1b..444978097c 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_common.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_common.py @@ -117,11 +117,11 @@ def test_slice_range(): ], ) def test_sub_domain(domain, index, expected): - domain = common.Domain.from_domain_like(domain) + domain = common.domain(domain) if expected is IndexError: with pytest.raises(embedded_exceptions.IndexOutOfBounds): sub_domain(domain, index) else: - expected = common.Domain.from_domain_like(expected) + expected = common.domain(expected) result = sub_domain(domain, index) assert result == expected diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index 3eba53963c..66eed059c1 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -21,8 +21,10 @@ Domain, Infinity, UnitRange, + domain, + named_range, promote_dims, - to_named_range, + unit_range, ) @@ -33,7 +35,7 @@ @pytest.fixture -def domain(): +def a_domain(): return Domain((IDim, UnitRange(0, 10)), (JDim, UnitRange(5, 15)), (KDim, UnitRange(20, 30))) @@ -55,8 +57,8 @@ def test_unit_range_length(rng): @pytest.mark.parametrize("rng_like", [(2, 4), range(2, 4), UnitRange(2, 4)]) -def test_from_unit_range_like(rng_like): - assert UnitRange.from_unit_range_like(rng_like) == UnitRange(2, 4) +def test_unit_range_like(rng_like): + assert unit_range(rng_like) == UnitRange(2, 4) def test_unit_range_repr(rng): @@ -156,12 +158,12 @@ def test_mixed_infinity_range(): (IDim, UnitRange(2, 4)), ], ) -def test_to_named_range(named_rng_like): - assert to_named_range(named_rng_like) == (IDim, UnitRange(2, 4)) +def test_named_range_like(named_rng_like): + assert named_range(named_rng_like) == (IDim, UnitRange(2, 4)) -def test_domain_length(domain): - assert len(domain) == 3 +def test_domain_length(a_domain): + assert len(a_domain) == 3 @pytest.mark.parametrize( @@ -172,20 +174,20 @@ def test_domain_length(domain): ({IDim: (2, 4), JDim: (3, 5)}), ], ) -def test_from_domain_like(domain_like): - assert Domain.from_domain_like(domain_like) == Domain( +def test_domain_like(domain_like): + assert domain(domain_like) == Domain( dims=(IDim, JDim), ranges=(UnitRange(2, 4), UnitRange(3, 5)) ) -def test_domain_iteration(domain): - iterated_values = [val for val in domain] - assert iterated_values == list(zip(domain.dims, domain.ranges)) +def test_domain_iteration(a_domain): + iterated_values = [val for val in a_domain] + assert iterated_values == list(zip(a_domain.dims, a_domain.ranges)) -def test_domain_contains_named_range(domain): - assert (IDim, UnitRange(0, 10)) in domain - assert (IDim, UnitRange(-5, 5)) not in domain +def test_domain_contains_named_range(a_domain): + assert (IDim, UnitRange(0, 10)) in a_domain + assert (IDim, UnitRange(-5, 5)) not in a_domain @pytest.mark.parametrize( @@ -214,21 +216,21 @@ def test_domain_contains_named_range(domain): ), ], ) -def test_domain_intersection_different_dimensions(domain, second_domain, expected): - result_domain = domain & second_domain +def test_domain_intersection_different_dimensions(a_domain, second_domain, expected): + result_domain = a_domain & second_domain print(result_domain) assert result_domain == expected -def test_domain_intersection_reversed_dimensions(domain): +def test_domain_intersection_reversed_dimensions(a_domain): domain2 = Domain(dims=(JDim, IDim), ranges=(UnitRange(2, 12), UnitRange(7, 17))) with pytest.raises( ValueError, match="Dimensions can not be promoted. The following dimensions appear in contradicting order: IDim, JDim.", ): - domain & domain2 + a_domain & domain2 @pytest.mark.parametrize( @@ -241,8 +243,8 @@ def test_domain_intersection_reversed_dimensions(domain): (-2, (JDim, UnitRange(5, 15))), ], ) -def test_domain_integer_indexing(domain, index, expected): - result = domain[index] +def test_domain_integer_indexing(a_domain, index, expected): + result = a_domain[index] assert result == expected @@ -253,8 +255,8 @@ def test_domain_integer_indexing(domain, index, expected): (slice(1, None), ((JDim, UnitRange(5, 15)), (KDim, UnitRange(20, 30)))), ], ) -def test_domain_slice_indexing(domain, slice_obj, expected): - result = domain[slice_obj] +def test_domain_slice_indexing(a_domain, slice_obj, expected): + result = a_domain[slice_obj] assert isinstance(result, Domain) assert len(result) == len(expected) assert all(res == exp for res, exp in zip(result, expected)) @@ -267,21 +269,21 @@ def test_domain_slice_indexing(domain, slice_obj, expected): (KDim, (KDim, UnitRange(20, 30))), ], ) -def test_domain_dimension_indexing(domain, index, expected_result): - result = domain[index] +def test_domain_dimension_indexing(a_domain, index, expected_result): + result = a_domain[index] assert result == expected_result -def test_domain_indexing_dimension_missing(domain): +def test_domain_indexing_dimension_missing(a_domain): with pytest.raises(KeyError, match=r"No Dimension of type .* is present in the Domain."): - domain[ECDim] + a_domain[ECDim] -def test_domain_indexing_invalid_type(domain): +def test_domain_indexing_invalid_type(a_domain): with pytest.raises( KeyError, match="Invalid index type, must be either int, slice, or Dimension." ): - domain["foo"] + a_domain["foo"] def test_domain_repeat_dims(): From 02ca5346f2e35865f48fb847bc4cf4847ac0fc35 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 29 Aug 2023 21:38:29 +0200 Subject: [PATCH 10/21] cleanup --- tests/next_tests/unit_tests/test_common.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index 66eed059c1..31e35221ab 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -346,5 +346,3 @@ def test_dimension_promotion( promote_dims(*dim_list) assert exc_info.match(expected_error_msg) - assert exc_info.match(expected_error_msg) - assert exc_info.match(expected_error_msg) From 9df8329e82a8e4ab1cc98e90d568b9d4d28ed3c3 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 30 Aug 2023 08:40:32 +0200 Subject: [PATCH 11/21] remove duplicated lines --- src/gt4py/next/common.py | 7 ------- src/gt4py/next/embedded/common.py | 1 - src/gt4py/next/embedded/nd_array_field.py | 1 - 3 files changed, 9 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 8857093057..60ce4ab550 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -590,10 +590,3 @@ def promote_dims(*dims_list: Sequence[Dimension]) -> list[Dimension]: ) return topologically_sorted_list - return topologically_sorted_list - return topologically_sorted_list - return topologically_sorted_list - return topologically_sorted_list - return topologically_sorted_list - return topologically_sorted_list - return topologically_sorted_list diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 9f2303d95d..a9a77aeeee 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -142,4 +142,3 @@ def _find_index_of_dim( if dim == d: return i return None - return None diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 2b4e49048f..119cd89bbc 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -442,4 +442,3 @@ def _compute_slice( return rng - domain.ranges[pos].start else: raise ValueError(f"Can only use integer or UnitRange ranges, provided type: {type(rng)}") - raise ValueError(f"Can only use integer or UnitRange ranges, provided type: {type(rng)}") From 1d5315571666f9df6d49a5cfbc3826a7d70ddf07 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 30 Aug 2023 09:16:00 +0200 Subject: [PATCH 12/21] more compact str representation of Domain --- src/gt4py/next/common.py | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 60ce4ab550..ef96cc8320 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -64,17 +64,23 @@ class DimensionKind(StrEnum): VERTICAL = "vertical" LOCAL = "local" - def __str__(self): + def __repr__(self): return f"{type(self).__name__}.{self.name}" + def __str__(self): + return self.value + @dataclasses.dataclass(frozen=True) class Dimension: value: str kind: DimensionKind = dataclasses.field(default=DimensionKind.HORIZONTAL) + def __repr__(self): + return f'Dimension(value="{self.value}", kind={repr(self.kind)})' + def __str__(self): - return f'Dimension(value="{self.value}", kind={self.kind})' + return f"{self.value}[{self.kind}]" @dataclasses.dataclass(frozen=True) @@ -135,6 +141,9 @@ def __and__(self, other: Set[Any]) -> UnitRange: else: raise NotImplementedError("Can only find the intersection between UnitRange instances.") + def __str__(self) -> str: + return f"({self.start}:{self.stop})" + def unit_range(r: UnitRangeLike) -> UnitRange: assert is_unit_range_like(r) @@ -217,6 +226,8 @@ def named_range(v: tuple[Dimension, UnitRangeLike]) -> NamedRange: @dataclasses.dataclass(frozen=True, init=False) class Domain(Sequence[NamedRange]): + """Describes the `Domain` of a `Field` as a `Sequence` of `NamedRange` s.""" + dims: tuple[Dimension, ...] ranges: tuple[UnitRange, ...] @@ -294,8 +305,25 @@ def __and__(self, other: Domain) -> Domain: ) return Domain(dims=broadcast_dims, ranges=intersected_ranges) + def __str__(self) -> str: + return f"Domain({', '.join(f'{e[0]}={e[1]}' for e in self)})" + def domain(domain_like: DomainLike) -> Domain: + """ + Construct `Domain` from `DomainLike` object. + + Examples: + --------- + >>> I = Dimension("I") + >>> J = Dimension("J") + + >>> domain(((I, (2, 4)), (J, (3, 5)))) + Domain(dims=(Dimension(value="I", kind=DimensionKind.HORIZONTAL), Dimension(value="J", kind=DimensionKind.HORIZONTAL)), ranges=(UnitRange(2, 4), UnitRange(3, 5))) + + >>> domain({I: (2, 4), J: (3, 5)}) + Domain(dims=(Dimension(value="I", kind=DimensionKind.HORIZONTAL), Dimension(value="J", kind=DimensionKind.HORIZONTAL)), ranges=(UnitRange(2, 4), UnitRange(3, 5))) + """ assert is_domain_like(domain_like) if isinstance(domain_like, Domain): return domain_like From e85abe466d987fb323fc28a792fb526e3661290c Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 30 Aug 2023 09:20:45 +0200 Subject: [PATCH 13/21] fix error --- src/gt4py/next/common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index ef96cc8320..b46c040f0a 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -153,7 +153,7 @@ def unit_range(r: UnitRangeLike) -> UnitRange: return UnitRange(r.start, r.stop) if isinstance(r, tuple) and isinstance(r[0], int) and isinstance(r[1], int): return UnitRange(r[0], r[1]) - raise RuntimeError("Unreachable") + raise ValueError(f"`{r}` is not `UnitRangeLike`.") IntIndex: TypeAlias = int | np.integer @@ -339,7 +339,7 @@ def domain(domain_like: DomainLike) -> Domain: dims=tuple(domain_like.keys()), ranges=tuple(unit_range(r) for r in domain_like.values()), ) - raise RuntimeError("Unreachable") + raise ValueError(f"`{domain_like}` is not `DomainLike`.") def _broadcast_ranges( From 901ac15e1c59be31431730ce5ec258afe42c4787 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 30 Aug 2023 09:36:34 +0200 Subject: [PATCH 14/21] remove reps --- src/gt4py/next/common.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index b46c040f0a..889c73d61f 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -64,9 +64,6 @@ class DimensionKind(StrEnum): VERTICAL = "vertical" LOCAL = "local" - def __repr__(self): - return f"{type(self).__name__}.{self.name}" - def __str__(self): return self.value @@ -76,9 +73,6 @@ class Dimension: value: str kind: DimensionKind = dataclasses.field(default=DimensionKind.HORIZONTAL) - def __repr__(self): - return f'Dimension(value="{self.value}", kind={repr(self.kind)})' - def __str__(self): return f"{self.value}[{self.kind}]" @@ -319,10 +313,10 @@ def domain(domain_like: DomainLike) -> Domain: >>> J = Dimension("J") >>> domain(((I, (2, 4)), (J, (3, 5)))) - Domain(dims=(Dimension(value="I", kind=DimensionKind.HORIZONTAL), Dimension(value="J", kind=DimensionKind.HORIZONTAL)), ranges=(UnitRange(2, 4), UnitRange(3, 5))) + Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(2, 4), UnitRange(3, 5))) >>> domain({I: (2, 4), J: (3, 5)}) - Domain(dims=(Dimension(value="I", kind=DimensionKind.HORIZONTAL), Dimension(value="J", kind=DimensionKind.HORIZONTAL)), ranges=(UnitRange(2, 4), UnitRange(3, 5))) + Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(2, 4), UnitRange(3, 5))) """ assert is_domain_like(domain_like) if isinstance(domain_like, Domain): From 643ffaabf84a39c37c66273cdd97b15e19dd12de Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 30 Aug 2023 10:29:16 +0200 Subject: [PATCH 15/21] fix test --- .../feature_tests/ffront_tests/test_foast_pretty_printer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py index 0bc5a98a4e..c1bee4fa2f 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py @@ -82,7 +82,7 @@ def scan(inp: int32) -> int32: expected = textwrap.dedent( f""" - @scan_operator(axis=Dimension(value="KDim", kind=DimensionKind.VERTICAL), forward=False, init=1) + @scan_operator(axis=KDim[vertical], forward=False, init=1) def scan(inp: int32) -> int32: {ssa.unique_name("foo", 0)} = inp return inp From e16c4aead942fde64e317681276f4374719af175 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 31 Aug 2023 16:09:45 +0200 Subject: [PATCH 16/21] refactor typealiases --- src/gt4py/_core/definitions.py | 4 +- src/gt4py/next/common.py | 125 ++++++++++++---------- src/gt4py/next/embedded/common.py | 39 +++---- src/gt4py/next/embedded/exceptions.py | 4 +- src/gt4py/next/embedded/nd_array_field.py | 23 ++-- 5 files changed, 97 insertions(+), 98 deletions(-) diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index f49bac531a..059ba6c24c 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -247,13 +247,15 @@ def subndim(self) -> int: return len(self.tensor_shape) def __eq__(self, other: Any) -> bool: - # TODO: discuss (make concrete subclasses equal to instances of this with the same type) return ( isinstance(other, DType) and self.scalar_type == other.scalar_type and self.tensor_shape == other.tensor_shape ) + def __hash__(self) -> int: + return hash((self.scalar_type, self.tensor_shape)) + @dataclasses.dataclass(frozen=True) class IntegerDType(DType[IntegralT]): diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 889c73d61f..a1f40c2f92 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -21,7 +21,7 @@ import sys from collections.abc import Mapping, Sequence, Set from types import EllipsisType -from typing import TypeGuard, overload +from typing import overload import numpy as np import numpy.typing as npt @@ -36,7 +36,9 @@ ParamSpec, Protocol, TypeAlias, + TypeGuard, TypeVar, + cast, extended_runtime_checkable, runtime_checkable, ) @@ -139,36 +141,42 @@ def __str__(self) -> str: return f"({self.start}:{self.stop})" -def unit_range(r: UnitRangeLike) -> UnitRange: - assert is_unit_range_like(r) +RangeLike: TypeAlias = UnitRange | range | tuple[int, int] + + +def unit_range(r: RangeLike) -> UnitRange: if isinstance(r, UnitRange): return r - if isinstance(r, range) and r.step == 1: + if isinstance(r, range): + if r.step != 1: + raise ValueError(f"`UnitRange` requires step size 1, got `{r.step}`.") return UnitRange(r.start, r.stop) if isinstance(r, tuple) and isinstance(r[0], int) and isinstance(r[1], int): return UnitRange(r[0], r[1]) - raise ValueError(f"`{r}` is not `UnitRangeLike`.") + raise ValueError(f"`{r}` cannot be interpreted as `UnitRange`.") -IntIndex: TypeAlias = int | np.integer -DomainRange: TypeAlias = UnitRange | IntIndex -NamedRange: TypeAlias = tuple[Dimension, UnitRange] +IntIndex: TypeAlias = int | core_defs.IntegralScalar NamedIndex: TypeAlias = tuple[Dimension, IntIndex] -AnyIndex: TypeAlias = IntIndex | NamedRange | NamedIndex | slice | EllipsisType -DomainSlice: TypeAlias = Sequence[NamedRange | NamedIndex] -BufferSlice: TypeAlias = tuple[slice | IntIndex | EllipsisType, ...] -FieldSlice: TypeAlias = DomainSlice | BufferSlice | AnyIndex -UnitRangeLike: TypeAlias = UnitRange | range | tuple[int, int] -DomainLike: TypeAlias = ( - Sequence[tuple[Dimension, UnitRangeLike]] | Mapping[Dimension, UnitRangeLike] -) +NamedRange: TypeAlias = tuple[Dimension, UnitRange] +RelativeIndexElement: TypeAlias = IntIndex | slice | EllipsisType +AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange +AnyIndexElement: TypeAlias = RelativeIndexElement | AbsoluteIndexElement +AbsoluteIndexSequence: TypeAlias = Sequence[NamedRange | NamedIndex] +RelativeIndexSequence: TypeAlias = tuple[ + slice | IntIndex | EllipsisType, ... +] # is a tuple but called Sequence for symmetry +AnyIndexSequence: TypeAlias = RelativeIndexSequence | AbsoluteIndexSequence +AnyIndex: TypeAlias = AnyIndexElement | AnyIndexSequence def is_int_index(p: Any) -> TypeGuard[IntIndex]: - return isinstance(p, (int, np.integer)) + # should be replaced by isinstance(p, IntIndex), but mypy complains with + # `Argument 2 to "isinstance" has incompatible type ""; expected "_ClassInfo" [arg-type]` + return isinstance(p, (int, core_defs.INTEGRAL_TYPES)) -def is_named_range(v: Any) -> TypeGuard[NamedRange]: +def is_named_range(v: AnyIndex) -> TypeGuard[NamedRange]: return ( isinstance(v, tuple) and len(v) == 2 @@ -177,44 +185,41 @@ def is_named_range(v: Any) -> TypeGuard[NamedRange]: ) -def is_named_index(v: Any) -> TypeGuard[NamedRange]: +def is_named_index(v: AnyIndex) -> TypeGuard[NamedRange]: return ( isinstance(v, tuple) and len(v) == 2 and isinstance(v[0], Dimension) and is_int_index(v[1]) ) -def is_domain_slice(v: Any) -> TypeGuard[DomainSlice]: +def is_any_index_element(v: AnyIndex) -> TypeGuard[AnyIndexElement]: + return ( + is_int_index(v) + or is_named_range(v) + or is_named_index(v) + or isinstance(v, slice) + or v is Ellipsis + ) + + +def is_absolute_index_sequence(v: AnyIndexSequence) -> TypeGuard[AbsoluteIndexSequence]: return isinstance(v, Sequence) and all(is_named_range(e) or is_named_index(e) for e in v) -def is_buffer_slice(v: Any) -> TypeGuard[BufferSlice]: +def is_relative_index_sequence(v: AnyIndexSequence) -> TypeGuard[RelativeIndexSequence]: return isinstance(v, tuple) and all( isinstance(e, slice) or is_int_index(e) or e is Ellipsis for e in v ) -def is_unit_range_like(v: Any) -> TypeGuard[UnitRangeLike]: - return ( - isinstance(v, UnitRange) - or (isinstance(v, range) and v.step == 1) - or (isinstance(v, tuple) and isinstance(v[0], int) and isinstance(v[1], int)) - ) - - -def is_domain_like(v: Any) -> TypeGuard[DomainLike]: - return ( - isinstance(v, Sequence) - and all( - isinstance(e, tuple) and isinstance(e[0], Dimension) and is_unit_range_like(e[1]) - for e in v - ) - ) or ( - isinstance(v, Mapping) - and all(isinstance(d, Dimension) and is_unit_range_like(r) for d, r in v.items()) +def as_any_index_sequence(index: AnyIndex) -> AnyIndexSequence: + # `cast` because mypy/typing doesn't special case 1-element tuples, i.e. `tuple[A|B] != tuple[A]|tuple[B]` + return cast( + AnyIndexSequence, + (index,) if is_any_index_element(index) else index, ) -def named_range(v: tuple[Dimension, UnitRangeLike]) -> NamedRange: +def named_range(v: tuple[Dimension, RangeLike]) -> NamedRange: return (v[0], unit_range(v[1])) @@ -230,7 +235,8 @@ def __init__( *args: NamedRange, dims: Optional[tuple[Dimension, ...]] = None, ranges: Optional[tuple[UnitRange, ...]] = None, - ): + ) -> None: + # TODO throw user error in case pre-conditions are not met if dims is not None or ranges is not None: if dims is None and ranges is None: raise ValueError("Either both none of `dims` and `ranges` must be specified.") @@ -239,8 +245,15 @@ def __init__( "No extra `args` allowed when constructing fomr `dims` and `ranges`." ) - assert dims is not None - assert ranges is not None + assert dims is not None and ranges is not None # for mypy + if not all(isinstance(dim, Dimension) for dim in dims): + raise ValueError( + f"`dims` argument needs to be a `tuple[Dimension, ...], got `{dims}`." + ) + if not all(isinstance(rng, Dimension) for rng in ranges): + raise ValueError( + f"`ranges` argument needs to be a `tuple[UnitRange, ...], got `{ranges}`." + ) if len(dims) != len(ranges): raise ValueError( f"Number of provided dimensions ({len(dims)}) does not match number of provided ranges ({len(ranges)})." @@ -249,7 +262,8 @@ def __init__( object.__setattr__(self, "dims", dims) object.__setattr__(self, "ranges", ranges) else: - assert all(is_named_range(arg) for arg in args) + if not all(is_named_range(arg) for arg in args): + raise ValueError(f"Elements of `Domain` need to be `NamedRange`s, got `{args}`.") dims, ranges = zip(*args) if len(args) > 0 else ((), ()) object.__setattr__(self, "dims", tuple(dims)) object.__setattr__(self, "ranges", tuple(ranges)) @@ -265,7 +279,7 @@ def __getitem__(self, index: int) -> NamedRange: ... @overload - def __getitem__(self, index: slice) -> "Domain": + def __getitem__(self, index: slice) -> Domain: ... @overload @@ -303,6 +317,9 @@ def __str__(self) -> str: return f"Domain({', '.join(f'{e[0]}={e[1]}' for e in self)})" +DomainLike: TypeAlias = Sequence[tuple[Dimension, RangeLike]] | Mapping[Dimension, RangeLike] + + def domain(domain_like: DomainLike) -> Domain: """ Construct `Domain` from `DomainLike` object. @@ -318,17 +335,11 @@ def domain(domain_like: DomainLike) -> Domain: >>> domain({I: (2, 4), J: (3, 5)}) Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(2, 4), UnitRange(3, 5))) """ - assert is_domain_like(domain_like) if isinstance(domain_like, Domain): return domain_like - if isinstance(domain_like, Sequence) and all( - isinstance(e, tuple) and isinstance(e[0], Dimension) and is_unit_range_like(e[1]) - for e in domain_like - ): + if isinstance(domain_like, Sequence): return Domain(*tuple(named_range(d) for d in domain_like)) - if isinstance(domain_like, Mapping) and all( - isinstance(d, Dimension) and is_unit_range_like(r) for d, r in domain_like.items() - ): + if isinstance(domain_like, Mapping): return Domain( dims=tuple(domain_like.keys()), ranges=tuple(unit_range(r) for r in domain_like.values()), @@ -394,7 +405,7 @@ def remap(self, index_field: Field) -> Field: ... @abc.abstractmethod - def restrict(self, item: FieldSlice) -> Field | core_defs.ScalarT: + def restrict(self, item: AnyIndex) -> Field | core_defs.ScalarT: ... # Operators @@ -403,7 +414,7 @@ def __call__(self, index_field: Field) -> Field: ... @abc.abstractmethod - def __getitem__(self, item: FieldSlice) -> Field | core_defs.ScalarT: + def __getitem__(self, item: AnyIndex) -> Field | core_defs.ScalarT: ... @abc.abstractmethod @@ -472,12 +483,12 @@ def is_field( @extended_runtime_checkable class MutableField(Field[DimsT, core_defs.ScalarT], Protocol[DimsT, core_defs.ScalarT]): @abc.abstractmethod - def __setitem__(self, index: FieldSlice, value: Field | core_defs.ScalarT) -> None: + def __setitem__(self, index: AnyIndex, value: Field | core_defs.ScalarT) -> None: ... def is_mutable_field( - v: Any, + v: Field, ) -> TypeGuard[MutableField]: # This function is introduced to localize the `type: ignore` because # extended_runtime_checkable does not make the protocol runtime_checkable diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index a9a77aeeee..d02ac9d44c 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -14,26 +14,27 @@ import itertools from types import EllipsisType -from typing import Any, Optional, Sequence, cast +from typing import Any, Optional, Sequence from gt4py.next import common from gt4py.next.embedded import exceptions as embedded_exceptions -def sub_domain(domain: common.Domain, index: common.FieldSlice) -> common.Domain: - index = _tuplize_field_slice(index) +def sub_domain(domain: common.Domain, index: common.AnyIndex) -> common.Domain: + index_sequence = common.as_any_index_sequence(index) - if common.is_domain_slice(index): - return _absolute_sub_domain(domain, index) + if common.is_absolute_index_sequence(index_sequence): + return _absolute_sub_domain(domain, index_sequence) - assert isinstance(index, tuple) - if all(isinstance(idx, slice) or common.is_int_index(idx) or idx is Ellipsis for idx in index): - return _relative_sub_domain(domain, index) + if common.is_relative_index_sequence(index_sequence): + return _relative_sub_domain(domain, index_sequence) raise IndexError(f"Unsupported index type: {index}") -def _relative_sub_domain(domain: common.Domain, index: common.BufferSlice) -> common.Domain: +def _relative_sub_domain( + domain: common.Domain, index: common.RelativeIndexSequence +) -> common.Domain: named_ranges: list[common.NamedRange] = [] expanded = _expand_ellipsis(index, len(domain)) @@ -62,7 +63,9 @@ def _relative_sub_domain(domain: common.Domain, index: common.BufferSlice) -> co return common.Domain(*named_ranges) -def _absolute_sub_domain(domain: common.Domain, index: common.DomainSlice) -> common.Domain: +def _absolute_sub_domain( + domain: common.Domain, index: common.AbsoluteIndexSequence +) -> common.Domain: named_ranges: list[common.NamedRange] = [] for i, (dim, rng) in enumerate(domain): if (pos := _find_index_of_dim(dim, index)) is not None: @@ -89,22 +92,6 @@ def _absolute_sub_domain(domain: common.Domain, index: common.DomainSlice) -> co return common.Domain(*named_ranges) -def _tuplize_field_slice(v: common.FieldSlice) -> common.FieldSlice: - """ - Wrap a single index/slice/range into a tuple. - - Note: the condition is complex as `NamedRange`, `NamedIndex` are implemented as `tuple`. - """ - if ( - not isinstance(v, tuple) - and not common.is_domain_slice(v) - or common.is_named_index(v) - or common.is_named_range(v) - ): - return cast(common.FieldSlice, (v,)) - return v - - def _expand_ellipsis( indices: tuple[common.IntIndex | slice | EllipsisType, ...], target_size: int ) -> tuple[common.IntIndex | slice, ...]: diff --git a/src/gt4py/next/embedded/exceptions.py b/src/gt4py/next/embedded/exceptions.py index 78c8be5f4b..b190d1a821 100644 --- a/src/gt4py/next/embedded/exceptions.py +++ b/src/gt4py/next/embedded/exceptions.py @@ -20,8 +20,8 @@ class IndexOutOfBounds(gt4py_exceptions.GT4PyError): def __init__( self, domain: common.Domain, - indices: common.FieldSlice, - index: common.AnyIndex, + indices: common.AnyIndex, + index: common.AnyIndexElement, dim: common.Dimension, ): super().__init__( diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 119cd89bbc..7e2dc598cd 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -188,7 +188,7 @@ def from_array( def remap(self: _BaseNdArrayField, connectivity) -> _BaseNdArrayField: raise NotImplementedError() - def restrict(self, index: common.FieldSlice) -> common.Field | core_defs.ScalarT: + def restrict(self, index: common.AnyIndex) -> common.Field | core_defs.ScalarT: new_domain, buffer_slice = self._slice(index) new_buffer = self.ndarray[buffer_slice] @@ -254,17 +254,16 @@ def __invert__(self) -> _BaseNdArrayField: return _make_unary_array_field_intrinsic_func("invert", "invert")(self) raise NotImplementedError("`__invert__` not implemented for non-`bool` fields.") - def _slice(self, index: common.FieldSlice) -> tuple[common.Domain, common.BufferSlice]: + def _slice(self, index: common.AnyIndex) -> tuple[common.Domain, common.RelativeIndexSequence]: new_domain = embedded_common.sub_domain(self.domain, index) - index = embedded_common._tuplize_field_slice(index) - + index_sequence = common.as_any_index_sequence(index) slice_ = ( - _get_slices_from_domain_slice(self.domain, index) - if common.is_domain_slice(index) - else index + _get_slices_from_domain_slice(self.domain, index_sequence) + if common.is_absolute_index_sequence(index_sequence) + else index_sequence ) - assert common.is_buffer_slice(slice_), slice_ + assert common.is_relative_index_sequence(slice_), slice_ return new_domain, slice_ @@ -298,7 +297,7 @@ def _slice(self, index: common.FieldSlice) -> tuple[common.Domain, common.Buffer def _np_cp_setitem( self: _BaseNdArrayField[common.DimsT, core_defs.ScalarT], - index: common.FieldSlice, + index: common.AnyIndex, value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, ) -> None: target_domain, target_slice = self._slice(index) @@ -350,7 +349,7 @@ class JaxArrayField(_BaseNdArrayField): def __setitem__( self, - index: common.FieldSlice, + index: common.AnyIndex, value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, ) -> None: # use `self.ndarray.at(index).set(value)` @@ -388,7 +387,7 @@ def _builtins_broadcast( def _get_slices_from_domain_slice( domain: common.Domain, domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], -) -> common.BufferSlice: +) -> common.RelativeIndexSequence: """Generate slices for sub-array extraction based on named ranges or named indices within a Domain. This function generates a tuple of slices that can be used to extract sub-arrays from a field. The provided @@ -415,7 +414,7 @@ def _get_slices_from_domain_slice( def _compute_slice( - rng: common.DomainRange, domain: common.Domain, pos: int + rng: common.UnitRange | common.IntIndex, domain: common.Domain, pos: int ) -> slice | common.IntIndex: """Compute a slice or integer based on the provided range, domain, and position. From 00897377e6df97d84d8f854d1ed32cb8a8b8b231 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 31 Aug 2023 17:25:48 +0200 Subject: [PATCH 17/21] remaining review comments --- src/gt4py/next/common.py | 44 ++++++++++++++++--- src/gt4py/next/embedded/common.py | 24 +++++----- src/gt4py/next/embedded/exceptions.py | 9 ++++ src/gt4py/next/embedded/nd_array_field.py | 10 ++--- src/gt4py/next/utils.py | 7 ++- .../unit_tests/embedded_tests/test_common.py | 22 +++++++--- 6 files changed, 82 insertions(+), 34 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index a1f40c2f92..9708d2a6d2 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -19,8 +19,8 @@ import enum import functools import sys +import types from collections.abc import Mapping, Sequence, Set -from types import EllipsisType from typing import overload import numpy as np @@ -159,12 +159,12 @@ def unit_range(r: RangeLike) -> UnitRange: IntIndex: TypeAlias = int | core_defs.IntegralScalar NamedIndex: TypeAlias = tuple[Dimension, IntIndex] NamedRange: TypeAlias = tuple[Dimension, UnitRange] -RelativeIndexElement: TypeAlias = IntIndex | slice | EllipsisType +RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange AnyIndexElement: TypeAlias = RelativeIndexElement | AbsoluteIndexElement AbsoluteIndexSequence: TypeAlias = Sequence[NamedRange | NamedIndex] RelativeIndexSequence: TypeAlias = tuple[ - slice | IntIndex | EllipsisType, ... + slice | IntIndex | types.EllipsisType, ... ] # is a tuple but called Sequence for symmetry AnyIndexSequence: TypeAlias = RelativeIndexSequence | AbsoluteIndexSequence AnyIndex: TypeAlias = AnyIndexElement | AnyIndexSequence @@ -250,7 +250,7 @@ def __init__( raise ValueError( f"`dims` argument needs to be a `tuple[Dimension, ...], got `{dims}`." ) - if not all(isinstance(rng, Dimension) for rng in ranges): + if not all(isinstance(rng, UnitRange) for rng in ranges): raise ValueError( f"`ranges` argument needs to be a `tuple[UnitRange, ...], got `{ranges}`." ) @@ -264,7 +264,7 @@ def __init__( else: if not all(is_named_range(arg) for arg in args): raise ValueError(f"Elements of `Domain` need to be `NamedRange`s, got `{args}`.") - dims, ranges = zip(*args) if len(args) > 0 else ((), ()) + dims, ranges = zip(*args) if args else ((), ()) object.__setattr__(self, "dims", tuple(dims)) object.__setattr__(self, "ranges", tuple(ranges)) @@ -303,6 +303,20 @@ def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain: raise KeyError("Invalid index type, must be either int, slice, or Dimension.") def __and__(self, other: Domain) -> Domain: + """ + Intersect `Domain`s, missing `Dimension`s are considered infinite. + + Examples: + --------- + >>> I = Dimension("I") + >>> J = Dimension("J") + + >>> Domain((I, UnitRange(-1, 3))) & Domain((I, UnitRange(1, 6))) + Domain(dims=(Dimension(value='I', kind=),), ranges=(UnitRange(1, 3),)) + + >>> Domain((I, UnitRange(-1, 3)), (J, UnitRange(2, 4))) & Domain((I, UnitRange(1, 6))) + Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(1, 3), UnitRange(2, 4))) + """ broadcast_dims = tuple(promote_dims(self.dims, other.dims)) intersected_ranges = tuple( rng1 & rng2 @@ -371,8 +385,8 @@ class NextGTDimsInterface(Protocol): """ A `GTDimsInterface` is an object providing the `__gt_dims__` property, naming :class:`Field` dimensions. - The dimension names are objects of type :class:`Dimension`, in contrast to :py:mod:`gt4py.cartesian`, - where the labels are `str` s with implied semantics, see :py:class:`~gt4py._core.definitions.GTDimsInterface` . + The dimension names are objects of type :class:`Dimension`, in contrast to :mod:`gt4py.cartesian`, + where the labels are `str` s with implied semantics, see :class:`~gt4py._core.definitions.GTDimsInterface` . """ # TODO(havogt): unify with GTDimsInterface, ideally in backward compatible way @@ -425,6 +439,10 @@ def __abs__(self) -> Field: def __neg__(self) -> Field: ... + @abc.abstractmethod + def __invert__(self) -> Field: + """Only defined for `Field` of value type `bool`.""" + @abc.abstractmethod def __add__(self, other: Field | core_defs.ScalarT) -> Field: ... @@ -469,6 +487,18 @@ def __rtruediv__(self, other: Field | core_defs.ScalarT) -> Field: def __pow__(self, other: Field | core_defs.ScalarT) -> Field: ... + @abc.abstractmethod + def __and__(self, other: Field | core_defs.ScalarT) -> Field: + """Only defined for `Field` of value type `bool`.""" + + @abc.abstractmethod + def __or__(self, other: Field | core_defs.ScalarT) -> Field: + """Only defined for `Field` of value type `bool`.""" + + @abc.abstractmethod + def __xor__(self, other: Field | core_defs.ScalarT) -> Field: + """Only defined for `Field` of value type `bool`.""" + def is_field( v: Any, diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index d02ac9d44c..6ad909fc1b 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -12,9 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import itertools -from types import EllipsisType -from typing import Any, Optional, Sequence +from typing import Any, Optional, Sequence, cast from gt4py.next import common from gt4py.next.embedded import exceptions as embedded_exceptions @@ -40,9 +38,8 @@ def _relative_sub_domain( expanded = _expand_ellipsis(index, len(domain)) if len(domain) < len(expanded): raise IndexError(f"Trying to index a `Field` with {len(domain)} dimensions with {index}.") - for (dim, rng), idx in itertools.zip_longest( # type: ignore[misc] # "slice" object is not iterable, not sure which slice... - domain, expanded, fillvalue=slice(None) - ): + expanded += (slice(None),) * (len(domain) - len(expanded)) + for (dim, rng), idx in zip(domain, expanded, strict=True): if isinstance(idx, slice): try: sliced = _slice_range(rng, idx) @@ -93,15 +90,14 @@ def _absolute_sub_domain( def _expand_ellipsis( - indices: tuple[common.IntIndex | slice | EllipsisType, ...], target_size: int + indices: common.RelativeIndexSequence, target_size: int ) -> tuple[common.IntIndex | slice, ...]: - expanded_indices: list[common.IntIndex | slice] = [] - for idx in indices: - if idx is Ellipsis: - expanded_indices.extend([slice(None)] * (target_size - (len(indices) - 1))) - else: - expanded_indices.append(idx) - return tuple(expanded_indices) + if Ellipsis in indices: + idx = indices.index(Ellipsis) + indices = ( + indices[:idx] + (slice(None),) * (target_size - (len(indices) - 1)) + indices[idx + 1 :] + ) + return cast(tuple[common.IntIndex | slice, ...], indices) # mypy leave me alone and trust me! def _slice_range(input_range: common.UnitRange, slice_obj: slice) -> common.UnitRange: diff --git a/src/gt4py/next/embedded/exceptions.py b/src/gt4py/next/embedded/exceptions.py index b190d1a821..c115487367 100644 --- a/src/gt4py/next/embedded/exceptions.py +++ b/src/gt4py/next/embedded/exceptions.py @@ -17,6 +17,11 @@ class IndexOutOfBounds(gt4py_exceptions.GT4PyError): + domain: common.Domain + indices: common.AnyIndex + index: common.AnyIndexElement + dim: common.Dimension + def __init__( self, domain: common.Domain, @@ -27,3 +32,7 @@ def __init__( super().__init__( f"Out of bounds: slicing {domain} with index `{indices}`, `{index}` is out of bounds in dimension `{dim}`." ) + self.domain = domain + self.indices = indices + self.index = index + self.dim = dim diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 7e2dc598cd..effd5e2694 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -224,7 +224,7 @@ def restrict(self, index: common.AnyIndex) -> common.Field | core_defs.ScalarT: __mod__ = __rmod__ = _make_binary_array_field_intrinsic_func("mod", "mod") - def __and__(self, other: common.Field) -> _BaseNdArrayField: + def __and__(self, other: common.Field | core_defs.ScalarT) -> _BaseNdArrayField: if self.dtype == core_defs.BoolDType(): return _make_binary_array_field_intrinsic_func("logical_and", "logical_and")( self, other @@ -233,14 +233,14 @@ def __and__(self, other: common.Field) -> _BaseNdArrayField: __rand__ = __and__ - def __or__(self, other: common.Field) -> _BaseNdArrayField: + def __or__(self, other: common.Field | core_defs.ScalarT) -> _BaseNdArrayField: if self.dtype == core_defs.BoolDType(): return _make_binary_array_field_intrinsic_func("logical_or", "logical_or")(self, other) raise NotImplementedError("`__or__` not implemented for non-`bool` fields.") __ror__ = __or__ - def __xor__(self, other: common.Field) -> _BaseNdArrayField: + def __xor__(self, other: common.Field | core_defs.ScalarT) -> _BaseNdArrayField: if self.dtype == core_defs.BoolDType(): return _make_binary_array_field_intrinsic_func("logical_xor", "logical_xor")( self, other @@ -263,7 +263,7 @@ def _slice(self, index: common.AnyIndex) -> tuple[common.Domain, common.Relative if common.is_absolute_index_sequence(index_sequence) else index_sequence ) - assert common.is_relative_index_sequence(slice_), slice_ + assert common.is_relative_index_sequence(slice_) return new_domain, slice_ @@ -352,7 +352,7 @@ def __setitem__( index: common.AnyIndex, value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, ) -> None: - # use `self.ndarray.at(index).set(value)` + # TODO(havogt): use something like `self.ndarray = self.ndarray.at(index).set(value)` raise NotImplementedError("`__setitem__` for JaxArrayField not yet implemented.") common.field.register(jnp.ndarray, JaxArrayField.from_array) diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index 3e7dc2d4d3..006b3057b0 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Any, ClassVar, TypeGuard +from typing import Any, ClassVar, TypeGuard, TypeVar class RecursionGuard: @@ -51,5 +51,8 @@ def __exit__(self, *exc): self.guarded_objects.remove(id(self.obj)) -def is_tuple_of(v: Any, t: type) -> TypeGuard[tuple]: +_T = TypeVar("_T") + + +def is_tuple_of(v: Any, t: type[_T]) -> TypeGuard[tuple[_T, ...]]: return isinstance(v, tuple) and all(isinstance(e, t) for e in v) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_common.py b/tests/next_tests/unit_tests/embedded_tests/test_common.py index 444978097c..640ed326bb 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_common.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_common.py @@ -22,12 +22,17 @@ from gt4py.next.embedded.common import _slice_range, sub_domain -def test_slice_range(): - input_range = UnitRange(2, 10) - slice_obj = slice(2, -2) - expected = UnitRange(4, 8) - - result = _slice_range(input_range, slice_obj) +@pytest.mark.parametrize( + "rng, slce, expected", + [ + (UnitRange(2, 10), slice(2, -2), UnitRange(4, 8)), + (UnitRange(2, 10), slice(2, None), UnitRange(4, 10)), + (UnitRange(2, 10), slice(None, -2), UnitRange(2, 8)), + (UnitRange(2, 10), slice(None), UnitRange(2, 10)), + ], +) +def test_slice_range(rng, slce, expected): + result = _slice_range(rng, slce) assert result == expected @@ -114,6 +119,11 @@ def test_slice_range(): (slice(1, 2), Ellipsis, slice(2, 3)), [(I, (3, 4)), (J, (3, 6)), (K, (6, 7))], ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (slice(1, 2), slice(1, 2), Ellipsis), + [(I, (3, 4)), (J, (4, 5)), (K, (4, 7))], + ), ], ) def test_sub_domain(domain, index, expected): From 9ff35ab74da602b515579f9cfe02e70d656b5d9a Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 5 Sep 2023 09:27:38 +0200 Subject: [PATCH 18/21] Apply suggestions from code review Co-authored-by: Enrique G. Paredes <18477+egparedes@users.noreply.github.com> --- src/gt4py/next/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index 5d7a5f480e..f5af3d1f97 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -25,7 +25,7 @@ from . import common, ffront, iterator, program_processors, type_inference from .common import Dimension, DimensionKind, Field, GridType -from .embedded import nd_array_field +from .embedded import nd_array_field as _nd_array_field # Just for registering field implementations from .ffront import fbuiltins from .ffront.decorator import field_operator, program, scan_operator from .ffront.fbuiltins import * # noqa: F403 # fbuiltins defines __all__ and we explicitly want to reexport everything here From 0ceaaf4dd28b38061770cc386200d7a4a997bde7 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 5 Sep 2023 09:28:30 +0200 Subject: [PATCH 19/21] address review comments --- src/gt4py/next/common.py | 17 ++++++++--------- src/gt4py/next/embedded/common.py | 2 +- src/gt4py/next/embedded/exceptions.py | 4 ++-- src/gt4py/next/embedded/nd_array_field.py | 10 ++++++---- 4 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 84e0d1f145..f5c96f07df 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -168,7 +168,7 @@ def unit_range(r: RangeLike) -> UnitRange: slice | IntIndex | types.EllipsisType, ... ] # is a tuple but called Sequence for symmetry AnyIndexSequence: TypeAlias = RelativeIndexSequence | AbsoluteIndexSequence -AnyIndex: TypeAlias = AnyIndexElement | AnyIndexSequence +AnyIndexSpec: TypeAlias = AnyIndexElement | AnyIndexSequence def is_int_index(p: Any) -> TypeGuard[IntIndex]: @@ -177,7 +177,7 @@ def is_int_index(p: Any) -> TypeGuard[IntIndex]: return isinstance(p, (int, core_defs.INTEGRAL_TYPES)) -def is_named_range(v: AnyIndex) -> TypeGuard[NamedRange]: +def is_named_range(v: AnyIndexSpec) -> TypeGuard[NamedRange]: return ( isinstance(v, tuple) and len(v) == 2 @@ -186,13 +186,13 @@ def is_named_range(v: AnyIndex) -> TypeGuard[NamedRange]: ) -def is_named_index(v: AnyIndex) -> TypeGuard[NamedRange]: +def is_named_index(v: AnyIndexSpec) -> TypeGuard[NamedRange]: return ( isinstance(v, tuple) and len(v) == 2 and isinstance(v[0], Dimension) and is_int_index(v[1]) ) -def is_any_index_element(v: AnyIndex) -> TypeGuard[AnyIndexElement]: +def is_any_index_element(v: AnyIndexSpec) -> TypeGuard[AnyIndexElement]: return ( is_int_index(v) or is_named_range(v) @@ -212,7 +212,7 @@ def is_relative_index_sequence(v: AnyIndexSequence) -> TypeGuard[RelativeIndexSe ) -def as_any_index_sequence(index: AnyIndex) -> AnyIndexSequence: +def as_any_index_sequence(index: AnyIndexSpec) -> AnyIndexSequence: # `cast` because mypy/typing doesn't special case 1-element tuples, i.e. `tuple[A|B] != tuple[A]|tuple[B]` return cast( AnyIndexSequence, @@ -237,7 +237,6 @@ def __init__( dims: Optional[tuple[Dimension, ...]] = None, ranges: Optional[tuple[UnitRange, ...]] = None, ) -> None: - # TODO throw user error in case pre-conditions are not met if dims is not None or ranges is not None: if dims is None and ranges is None: raise ValueError("Either both none of `dims` and `ranges` must be specified.") @@ -420,7 +419,7 @@ def remap(self, index_field: Field) -> Field: ... @abc.abstractmethod - def restrict(self, item: AnyIndex) -> Field | core_defs.ScalarT: + def restrict(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: ... # Operators @@ -429,7 +428,7 @@ def __call__(self, index_field: Field) -> Field: ... @abc.abstractmethod - def __getitem__(self, item: AnyIndex) -> Field | core_defs.ScalarT: + def __getitem__(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: ... @abc.abstractmethod @@ -514,7 +513,7 @@ def is_field( @extended_runtime_checkable class MutableField(Field[DimsT, core_defs.ScalarT], Protocol[DimsT, core_defs.ScalarT]): @abc.abstractmethod - def __setitem__(self, index: AnyIndex, value: Field | core_defs.ScalarT) -> None: + def __setitem__(self, index: AnyIndexSpec, value: Field | core_defs.ScalarT) -> None: ... diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 6ad909fc1b..3799923d87 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -18,7 +18,7 @@ from gt4py.next.embedded import exceptions as embedded_exceptions -def sub_domain(domain: common.Domain, index: common.AnyIndex) -> common.Domain: +def sub_domain(domain: common.Domain, index: common.AnyIndexSpec) -> common.Domain: index_sequence = common.as_any_index_sequence(index) if common.is_absolute_index_sequence(index_sequence): diff --git a/src/gt4py/next/embedded/exceptions.py b/src/gt4py/next/embedded/exceptions.py index c115487367..393123db36 100644 --- a/src/gt4py/next/embedded/exceptions.py +++ b/src/gt4py/next/embedded/exceptions.py @@ -18,14 +18,14 @@ class IndexOutOfBounds(gt4py_exceptions.GT4PyError): domain: common.Domain - indices: common.AnyIndex + indices: common.AnyIndexSpec index: common.AnyIndexElement dim: common.Dimension def __init__( self, domain: common.Domain, - indices: common.AnyIndex, + indices: common.AnyIndexSpec, index: common.AnyIndexElement, dim: common.Dimension, ): diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index e684f3e24c..fcaa09e7eb 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -160,7 +160,7 @@ def from_array( def remap(self: _BaseNdArrayField, connectivity) -> _BaseNdArrayField: raise NotImplementedError() - def restrict(self, index: common.AnyIndex) -> common.Field | core_defs.ScalarT: + def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.ScalarT: new_domain, buffer_slice = self._slice(index) new_buffer = self.ndarray[buffer_slice] @@ -226,7 +226,9 @@ def __invert__(self) -> _BaseNdArrayField: return _make_unary_array_field_intrinsic_func("invert", "invert")(self) raise NotImplementedError("`__invert__` not implemented for non-`bool` fields.") - def _slice(self, index: common.AnyIndex) -> tuple[common.Domain, common.RelativeIndexSequence]: + def _slice( + self, index: common.AnyIndexSpec + ) -> tuple[common.Domain, common.RelativeIndexSequence]: new_domain = embedded_common.sub_domain(self.domain, index) index_sequence = common.as_any_index_sequence(index) @@ -269,7 +271,7 @@ def _slice(self, index: common.AnyIndex) -> tuple[common.Domain, common.Relative def _np_cp_setitem( self: _BaseNdArrayField[common.DimsT, core_defs.ScalarT], - index: common.AnyIndex, + index: common.AnyIndexSpec, value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, ) -> None: target_domain, target_slice = self._slice(index) @@ -321,7 +323,7 @@ class JaxArrayField(_BaseNdArrayField): def __setitem__( self, - index: common.AnyIndex, + index: common.AnyIndexSpec, value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, ) -> None: # TODO(havogt): use something like `self.ndarray = self.ndarray.at(index).set(value)` From d7366cba9114b24213c9abf3cdc644681360aa0b Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 5 Sep 2023 09:34:29 +0200 Subject: [PATCH 20/21] fix formatting --- src/gt4py/next/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index f5af3d1f97..cc35899668 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -25,7 +25,9 @@ from . import common, ffront, iterator, program_processors, type_inference from .common import Dimension, DimensionKind, Field, GridType -from .embedded import nd_array_field as _nd_array_field # Just for registering field implementations +from .embedded import ( # Just for registering field implementations + nd_array_field as _nd_array_field, +) from .ffront import fbuiltins from .ffront.decorator import field_operator, program, scan_operator from .ffront.fbuiltins import * # noqa: F403 # fbuiltins defines __all__ and we explicitly want to reexport everything here From f2a9030c81cd18781877fbc004c6442fab5acb9c Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 5 Sep 2023 10:10:02 +0200 Subject: [PATCH 21/21] fix comments --- src/gt4py/next/common.py | 21 ++++++++++++++++++--- src/gt4py/next/embedded/common.py | 2 +- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index f5c96f07df..b85239cd0a 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -331,7 +331,9 @@ def __str__(self) -> str: return f"Domain({', '.join(f'{e[0]}={e[1]}' for e in self)})" -DomainLike: TypeAlias = Sequence[tuple[Dimension, RangeLike]] | Mapping[Dimension, RangeLike] +DomainLike: TypeAlias = ( + Sequence[tuple[Dimension, RangeLike]] | Mapping[Dimension, RangeLike] +) # `Domain` is `Sequence[NamedRange]` and therefore a subset def domain(domain_like: DomainLike) -> Domain: @@ -656,13 +658,26 @@ def promote_dims(*dims_list: Sequence[Dimension]) -> list[Dimension]: class FieldBuiltinFuncRegistry: + """ + Mixin for adding `fbuiltins` registry to a `Field`. + + Subclasses of a `Field` with `FieldBuiltinFuncRegistry` get their own registry, + dispatching (via ChainMap) to its parent's registries. + """ + _builtin_func_map: collections.ChainMap[ fbuiltins.BuiltInFunction, Callable ] = collections.ChainMap() def __init_subclass__(cls, **kwargs): - # might break in multiple inheritance (if multiple ancestors have `_builtin_func_map`) - cls._builtin_func_map = cls._builtin_func_map.new_child() + cls._builtin_func_map = collections.ChainMap( + {}, # New empty `dict`` for new registrations on this class + *[ + c.__dict__["_builtin_func_map"].maps[0] # adding parent `dict`s in mro order + for c in cls.__mro__ + if "_builtin_func_map" in c.__dict__ + ], + ) @classmethod def register_builtin_func( diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 3799923d87..37ba4954f3 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -112,7 +112,7 @@ def _slice_range(input_range: common.UnitRange, slice_obj: slice) -> common.Unit ) + (slice_obj.stop or len(input_range)) if start < input_range.start or stop > input_range.stop: - raise IndexError() + raise IndexError("Slice out of range (no clipping following array API standard).") return common.UnitRange(start, stop)