diff --git a/docs/user/cartesian/Makefile b/docs/user/cartesian/Makefile index 091bc3b8d2..13e692b96d 100644 --- a/docs/user/cartesian/Makefile +++ b/docs/user/cartesian/Makefile @@ -2,12 +2,13 @@ # # You can set these variables from the command line. -SPHINXOPTS = -SPHINXBUILD = sphinx-build -PAPER = -SRCDIR = ../../../src/gt4py -AUTODOCDIR = _source -BUILDDIR = _build +SPHINXOPTS = +SPHINXBUILD = sphinx-build +PAPER = +SRCDIR = ../../../src/gt4py +SPHINX_APIDOC_OPTS = --private # private modules for gt4py._core +AUTODOCDIR = _source +BUILDDIR = _build # User-friendly check for sphinx-build ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) @@ -55,7 +56,7 @@ clean: autodoc: @echo @echo "Running sphinx-apidoc..." - sphinx-apidoc ${SPHINX_OPTS} -o ${AUTODOCDIR} ${SRCDIR} + sphinx-apidoc ${SPHINX_APIDOC_OPTS} -o ${AUTODOCDIR} ${SRCDIR} @echo @echo "sphinx-apidoc finished. The generated autodocs are in $(AUTODOCDIR)." diff --git a/docs/user/cartesian/arrays.rst b/docs/user/cartesian/arrays.rst index 6788e2757f..6ef7c6e5c1 100644 --- a/docs/user/cartesian/arrays.rst +++ b/docs/user/cartesian/arrays.rst @@ -39,6 +39,8 @@ Internally, gt4py uses the utilities :code:`gt4py.utils.as_numpy` and :code:`gt4 buffers. GT4Py developers are advised to always use those utilities as to guarantee support across gt4py as the supported interfaces are extended. +.. _cartesian-arrays-dimension-mapping: + Dimension Mapping ^^^^^^^^^^^^^^^^^ @@ -56,6 +58,8 @@ which implements this lookup. Note: Support for xarray can be added manually by the user by means of the mechanism described `here `_. +.. _cartesian-arrays-default-origin: + Default Origin ^^^^^^^^^^^^^^ @@ -180,4 +184,4 @@ Additionally, these **optional** keyword-only parameters are accepted: determine the default layout for the storage. Currently supported will be :code:`"I"`, :code:`"J"`, :code:`"K"` and additional dimensions as string representations of integers, starting at :code:`"0"`. (This information is not retained in the resulting array, and needs to be specified instead - with the :code:`__gt_dims__` interface. ) \ No newline at end of file + with the :code:`__gt_dims__` interface. ) diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 2546ae3e4e..059ba6c24c 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -213,18 +213,13 @@ class DType(Generic[ScalarT]): """ scalar_type: Type[ScalarT] - tensor_shape: TensorShape + tensor_shape: TensorShape = dataclasses.field(default=()) - def __init__( - self, scalar_type: Type[ScalarT], tensor_shape: Sequence[IntegralScalar] = () - ) -> None: - if not isinstance(scalar_type, type): - raise TypeError(f"Invalid scalar type '{scalar_type}'") - if not is_valid_tensor_shape(tensor_shape): - raise TypeError(f"Invalid tensor shape '{tensor_shape}'") - - object.__setattr__(self, "scalar_type", scalar_type) - object.__setattr__(self, "tensor_shape", tensor_shape) + def __post_init__(self) -> None: + if not isinstance(self.scalar_type, type): + raise TypeError(f"Invalid scalar type '{self.scalar_type}'") + if not is_valid_tensor_shape(self.tensor_shape): + raise TypeError(f"Invalid tensor shape '{self.tensor_shape}'") @functools.cached_property def kind(self) -> DTypeKind: @@ -251,6 +246,16 @@ def lanes(self) -> int: def subndim(self) -> int: return len(self.tensor_shape) + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, DType) + and self.scalar_type == other.scalar_type + and self.tensor_shape == other.tensor_shape + ) + + def __hash__(self) -> int: + return hash((self.scalar_type, self.tensor_shape)) + @dataclasses.dataclass(frozen=True) class IntegerDType(DType[IntegralT]): @@ -322,6 +327,11 @@ class Float64DType(FloatingDType[float64]): scalar_type: Final[Type[float64]] = dataclasses.field(default=float64, init=False) +@dataclasses.dataclass(frozen=True) +class BoolDType(DType[bool_]): + scalar_type: Final[Type[bool_]] = dataclasses.field(default=bool_, init=False) + + DTypeLike = Union[DType, npt.DTypeLike] @@ -332,11 +342,29 @@ def dtype(dtype_like: DTypeLike) -> DType: # -- Custom protocols -- class GTDimsInterface(Protocol): - __gt_dims__: Tuple[str, ...] + """ + A `GTDimsInterface` is an object providing the `__gt_dims__` property, naming the buffer dimensions. + + In `gt4py.cartesian` the allowed values are `"I"`, `"J"` and `"K"` with the established semantics. + + See :ref:`cartesian-arrays-dimension-mapping` for details. + """ + + @property + def __gt_dims__(self) -> Tuple[str, ...]: + ... class GTOriginInterface(Protocol): - __gt_origin__: Tuple[int, ...] + """ + A `GTOriginInterface` is an object providing `__gt_origin__`, describing the origin of a buffer. + + See :ref:`cartesian-arrays-default-origin` for details. + """ + + @property + def __gt_origin__(self) -> Tuple[int, ...]: + ... # -- Device representation -- diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index b4d1fc0c09..cc35899668 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -25,6 +25,9 @@ from . import common, ffront, iterator, program_processors, type_inference from .common import Dimension, DimensionKind, Field, GridType +from .embedded import ( # Just for registering field implementations + nd_array_field as _nd_array_field, +) from .ffront import fbuiltins from .ffront.decorator import field_operator, program, scan_operator from .ffront.fbuiltins import * # noqa: F403 # fbuiltins defines __all__ and we explicitly want to reexport everything here diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 866b2aadb7..b85239cd0a 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -20,9 +20,9 @@ import enum import functools import sys -from collections.abc import Sequence, Set -from types import EllipsisType -from typing import ChainMap, TypeGuard, overload +import types +from collections.abc import Mapping, Sequence, Set +from typing import overload import numpy as np import numpy.typing as npt @@ -37,16 +37,18 @@ ParamSpec, Protocol, TypeAlias, + TypeGuard, TypeVar, + cast, extended_runtime_checkable, - final, runtime_checkable, ) from gt4py.eve.type_definitions import StrEnum -DimT = TypeVar("DimT", bound="Dimension") -DimsT = TypeVar("DimsT", bound=Sequence["Dimension"], covariant=True) +DimsT = TypeVar( + "DimsT", covariant=True +) # bound to `Sequence[Dimension]` if instance of Dimension would be a type class Infinity(int): @@ -66,7 +68,7 @@ class DimensionKind(StrEnum): LOCAL = "local" def __str__(self): - return f"{type(self).__name__}.{self.name}" + return self.value @dataclasses.dataclass(frozen=True) @@ -75,7 +77,7 @@ class Dimension: kind: DimensionKind = dataclasses.field(default=DimensionKind.HORIZONTAL) def __str__(self): - return f'Dimension(value="{self.value}", kind={self.kind})' + return f"{self.value}[{self.kind}]" @dataclasses.dataclass(frozen=True) @@ -136,36 +138,139 @@ def __and__(self, other: Set[Any]) -> UnitRange: else: raise NotImplementedError("Can only find the intersection between UnitRange instances.") + def __str__(self) -> str: + return f"({self.start}:{self.stop})" + + +RangeLike: TypeAlias = UnitRange | range | tuple[int, int] + + +def unit_range(r: RangeLike) -> UnitRange: + if isinstance(r, UnitRange): + return r + if isinstance(r, range): + if r.step != 1: + raise ValueError(f"`UnitRange` requires step size 1, got `{r.step}`.") + return UnitRange(r.start, r.stop) + if isinstance(r, tuple) and isinstance(r[0], int) and isinstance(r[1], int): + return UnitRange(r[0], r[1]) + raise ValueError(f"`{r}` cannot be interpreted as `UnitRange`.") -DomainRange: TypeAlias = UnitRange | int + +IntIndex: TypeAlias = int | core_defs.IntegralScalar +NamedIndex: TypeAlias = tuple[Dimension, IntIndex] NamedRange: TypeAlias = tuple[Dimension, UnitRange] -NamedIndex: TypeAlias = tuple[Dimension, int] -DomainSlice: TypeAlias = Sequence[NamedRange | NamedIndex] -FieldSlice: TypeAlias = ( - DomainSlice - | tuple[slice | int | EllipsisType, ...] - | slice - | int - | EllipsisType - | NamedRange - | NamedIndex -) +RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType +AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange +AnyIndexElement: TypeAlias = RelativeIndexElement | AbsoluteIndexElement +AbsoluteIndexSequence: TypeAlias = Sequence[NamedRange | NamedIndex] +RelativeIndexSequence: TypeAlias = tuple[ + slice | IntIndex | types.EllipsisType, ... +] # is a tuple but called Sequence for symmetry +AnyIndexSequence: TypeAlias = RelativeIndexSequence | AbsoluteIndexSequence +AnyIndexSpec: TypeAlias = AnyIndexElement | AnyIndexSequence + + +def is_int_index(p: Any) -> TypeGuard[IntIndex]: + # should be replaced by isinstance(p, IntIndex), but mypy complains with + # `Argument 2 to "isinstance" has incompatible type ""; expected "_ClassInfo" [arg-type]` + return isinstance(p, (int, core_defs.INTEGRAL_TYPES)) + + +def is_named_range(v: AnyIndexSpec) -> TypeGuard[NamedRange]: + return ( + isinstance(v, tuple) + and len(v) == 2 + and isinstance(v[0], Dimension) + and isinstance(v[1], UnitRange) + ) -@dataclasses.dataclass(frozen=True) +def is_named_index(v: AnyIndexSpec) -> TypeGuard[NamedRange]: + return ( + isinstance(v, tuple) and len(v) == 2 and isinstance(v[0], Dimension) and is_int_index(v[1]) + ) + + +def is_any_index_element(v: AnyIndexSpec) -> TypeGuard[AnyIndexElement]: + return ( + is_int_index(v) + or is_named_range(v) + or is_named_index(v) + or isinstance(v, slice) + or v is Ellipsis + ) + + +def is_absolute_index_sequence(v: AnyIndexSequence) -> TypeGuard[AbsoluteIndexSequence]: + return isinstance(v, Sequence) and all(is_named_range(e) or is_named_index(e) for e in v) + + +def is_relative_index_sequence(v: AnyIndexSequence) -> TypeGuard[RelativeIndexSequence]: + return isinstance(v, tuple) and all( + isinstance(e, slice) or is_int_index(e) or e is Ellipsis for e in v + ) + + +def as_any_index_sequence(index: AnyIndexSpec) -> AnyIndexSequence: + # `cast` because mypy/typing doesn't special case 1-element tuples, i.e. `tuple[A|B] != tuple[A]|tuple[B]` + return cast( + AnyIndexSequence, + (index,) if is_any_index_element(index) else index, + ) + + +def named_range(v: tuple[Dimension, RangeLike]) -> NamedRange: + return (v[0], unit_range(v[1])) + + +@dataclasses.dataclass(frozen=True, init=False) class Domain(Sequence[NamedRange]): + """Describes the `Domain` of a `Field` as a `Sequence` of `NamedRange` s.""" + dims: tuple[Dimension, ...] ranges: tuple[UnitRange, ...] - def __post_init__(self): + def __init__( + self, + *args: NamedRange, + dims: Optional[tuple[Dimension, ...]] = None, + ranges: Optional[tuple[UnitRange, ...]] = None, + ) -> None: + if dims is not None or ranges is not None: + if dims is None and ranges is None: + raise ValueError("Either both none of `dims` and `ranges` must be specified.") + if len(args) > 0: + raise ValueError( + "No extra `args` allowed when constructing fomr `dims` and `ranges`." + ) + + assert dims is not None and ranges is not None # for mypy + if not all(isinstance(dim, Dimension) for dim in dims): + raise ValueError( + f"`dims` argument needs to be a `tuple[Dimension, ...], got `{dims}`." + ) + if not all(isinstance(rng, UnitRange) for rng in ranges): + raise ValueError( + f"`ranges` argument needs to be a `tuple[UnitRange, ...], got `{ranges}`." + ) + if len(dims) != len(ranges): + raise ValueError( + f"Number of provided dimensions ({len(dims)}) does not match number of provided ranges ({len(ranges)})." + ) + + object.__setattr__(self, "dims", dims) + object.__setattr__(self, "ranges", ranges) + else: + if not all(is_named_range(arg) for arg in args): + raise ValueError(f"Elements of `Domain` need to be `NamedRange`s, got `{args}`.") + dims, ranges = zip(*args) if args else ((), ()) + object.__setattr__(self, "dims", tuple(dims)) + object.__setattr__(self, "ranges", tuple(ranges)) + if len(set(self.dims)) != len(self.dims): raise NotImplementedError(f"Domain dimensions must be unique, not {self.dims}.") - if len(self.dims) != len(self.ranges): - raise ValueError( - f"Number of provided dimensions ({len(self.dims)}) does not match number of provided ranges ({len(self.ranges)})." - ) - def __len__(self) -> int: return len(self.ranges) @@ -174,7 +279,7 @@ def __getitem__(self, index: int) -> NamedRange: ... @overload - def __getitem__(self, index: slice) -> "Domain": + def __getitem__(self, index: slice) -> Domain: ... @overload @@ -187,7 +292,7 @@ def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain: elif isinstance(index, slice): dims_slice = self.dims[index] ranges_slice = self.ranges[index] - return Domain(dims_slice, ranges_slice) + return Domain(dims=dims_slice, ranges=ranges_slice) elif isinstance(index, Dimension): try: index_pos = self.dims.index(index) @@ -197,7 +302,21 @@ def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain: else: raise KeyError("Invalid index type, must be either int, slice, or Dimension.") - def __and__(self, other: "Domain") -> "Domain": + def __and__(self, other: Domain) -> Domain: + """ + Intersect `Domain`s, missing `Dimension`s are considered infinite. + + Examples: + --------- + >>> I = Dimension("I") + >>> J = Dimension("J") + + >>> Domain((I, UnitRange(-1, 3))) & Domain((I, UnitRange(1, 6))) + Domain(dims=(Dimension(value='I', kind=),), ranges=(UnitRange(1, 3),)) + + >>> Domain((I, UnitRange(-1, 3)), (J, UnitRange(2, 4))) & Domain((I, UnitRange(1, 6))) + Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(1, 3), UnitRange(2, 4))) + """ broadcast_dims = tuple(promote_dims(self.dims, other.dims)) intersected_ranges = tuple( rng1 & rng2 @@ -206,15 +325,49 @@ def __and__(self, other: "Domain") -> "Domain": _broadcast_ranges(broadcast_dims, other.dims, other.ranges), ) ) - return Domain(broadcast_dims, intersected_ranges) + return Domain(dims=broadcast_dims, ranges=intersected_ranges) + + def __str__(self) -> str: + return f"Domain({', '.join(f'{e[0]}={e[1]}' for e in self)})" + + +DomainLike: TypeAlias = ( + Sequence[tuple[Dimension, RangeLike]] | Mapping[Dimension, RangeLike] +) # `Domain` is `Sequence[NamedRange]` and therefore a subset + + +def domain(domain_like: DomainLike) -> Domain: + """ + Construct `Domain` from `DomainLike` object. + + Examples: + --------- + >>> I = Dimension("I") + >>> J = Dimension("J") + + >>> domain(((I, (2, 4)), (J, (3, 5)))) + Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(2, 4), UnitRange(3, 5))) + + >>> domain({I: (2, 4), J: (3, 5)}) + Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(2, 4), UnitRange(3, 5))) + """ + if isinstance(domain_like, Domain): + return domain_like + if isinstance(domain_like, Sequence): + return Domain(*tuple(named_range(d) for d in domain_like)) + if isinstance(domain_like, Mapping): + return Domain( + dims=tuple(domain_like.keys()), + ranges=tuple(unit_range(r) for r in domain_like.values()), + ) + raise ValueError(f"`{domain_like}` is not `DomainLike`.") def _broadcast_ranges( broadcast_dims: Sequence[Dimension], dims: Sequence[Dimension], ranges: Sequence[UnitRange] ) -> tuple[UnitRange, ...]: return tuple( - ranges[dims.index(d)] if d in dims else UnitRange(Infinity.negative(), Infinity.positive()) - for d in broadcast_dims + ranges[dims.index(d)] if d in dims else UnitRange.infinity() for d in broadcast_dims ) @@ -230,8 +383,22 @@ def __call__(self, func: fbuiltins.BuiltInFunction[_R, _P], /) -> Callable[_P, _ ... +class NextGTDimsInterface(Protocol): + """ + A `GTDimsInterface` is an object providing the `__gt_dims__` property, naming :class:`Field` dimensions. + + The dimension names are objects of type :class:`Dimension`, in contrast to :mod:`gt4py.cartesian`, + where the labels are `str` s with implied semantics, see :class:`~gt4py._core.definitions.GTDimsInterface` . + """ + + # TODO(havogt): unify with GTDimsInterface, ideally in backward compatible way + @property + def __gt_dims__(self) -> tuple[Dimension, ...]: + ... + + @extended_runtime_checkable -class Field(Protocol[DimsT, core_defs.ScalarT]): +class Field(NextGTDimsInterface, core_defs.GTOriginInterface, Protocol[DimsT, core_defs.ScalarT]): __gt_builtin_func__: ClassVar[GTBuiltInFuncDispatcher] @property @@ -242,24 +409,19 @@ def domain(self) -> Domain: def dtype(self) -> core_defs.DType[core_defs.ScalarT]: ... - @property - def value_type(self) -> type[core_defs.ScalarT]: - ... - @property def ndarray(self) -> core_defs.NDArrayObject: ... def __str__(self) -> str: - codomain = self.value_type.__name__ - return f"⟨{self.domain!s} → {codomain}⟩" + return f"⟨{self.domain!s} → {self.dtype}⟩" @abc.abstractmethod def remap(self, index_field: Field) -> Field: ... @abc.abstractmethod - def restrict(self, item: FieldSlice) -> Field | core_defs.ScalarT: + def restrict(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: ... # Operators @@ -268,7 +430,7 @@ def __call__(self, index_field: Field) -> Field: ... @abc.abstractmethod - def __getitem__(self, item: FieldSlice) -> Field | core_defs.ScalarT: + def __getitem__(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: ... @abc.abstractmethod @@ -279,6 +441,10 @@ def __abs__(self) -> Field: def __neg__(self) -> Field: ... + @abc.abstractmethod + def __invert__(self) -> Field: + """Only defined for `Field` of value type `bool`.""" + @abc.abstractmethod def __add__(self, other: Field | core_defs.ScalarT) -> Field: ... @@ -323,23 +489,44 @@ def __rtruediv__(self, other: Field | core_defs.ScalarT) -> Field: def __pow__(self, other: Field | core_defs.ScalarT) -> Field: ... + @abc.abstractmethod + def __and__(self, other: Field | core_defs.ScalarT) -> Field: + """Only defined for `Field` of value type `bool`.""" + + @abc.abstractmethod + def __or__(self, other: Field | core_defs.ScalarT) -> Field: + """Only defined for `Field` of value type `bool`.""" + + @abc.abstractmethod + def __xor__(self, other: Field | core_defs.ScalarT) -> Field: + """Only defined for `Field` of value type `bool`.""" + def is_field( v: Any, -) -> TypeGuard[Field]: # this function is introduced to localize the `type: ignore`` +) -> TypeGuard[Field]: + # This function is introduced to localize the `type: ignore` because + # extended_runtime_checkable does not make the protocol runtime_checkable + # for mypy. + # TODO(egparedes): remove it when extended_runtime_checkable is fixed return isinstance(v, Field) # type: ignore[misc] # we use extended_runtime_checkable -class FieldABC(Field[DimsT, core_defs.ScalarT]): - """Abstract base class for implementations of the :class:`Field` protocol.""" +@extended_runtime_checkable +class MutableField(Field[DimsT, core_defs.ScalarT], Protocol[DimsT, core_defs.ScalarT]): + @abc.abstractmethod + def __setitem__(self, index: AnyIndexSpec, value: Field | core_defs.ScalarT) -> None: + ... - @final - def __setattr__(self, key, value) -> None: - raise TypeError("Immutable type") - @final - def __setitem__(self, key, value) -> None: - raise TypeError("Immutable type") +def is_mutable_field( + v: Field, +) -> TypeGuard[MutableField]: + # This function is introduced to localize the `type: ignore` because + # extended_runtime_checkable does not make the protocol runtime_checkable + # for mypy. + # TODO(egparedes): remove it when extended_runtime_checkable is fixed + return isinstance(v, MutableField) # type: ignore[misc] # we use extended_runtime_checkable @functools.singledispatch @@ -347,8 +534,8 @@ def field( definition: Any, /, *, - domain: Optional[Any] = None, # TODO(havogt): provide domain_like to Domain conversion - value_type: Optional[type] = None, + domain: Optional[DomainLike] = None, + dtype: Optional[core_defs.DType] = None, ) -> Field: raise NotImplementedError @@ -470,26 +657,27 @@ def promote_dims(*dims_list: Sequence[Dimension]) -> list[Dimension]: return topologically_sorted_list -def is_named_range(v: Any) -> TypeGuard[NamedRange]: - return isinstance(v, tuple) and isinstance(v[0], Dimension) and isinstance(v[1], UnitRange) - - -def is_named_index(v: Any) -> TypeGuard[NamedIndex]: - return isinstance(v, tuple) and isinstance(v[0], Dimension) and isinstance(v[1], int) - - -def is_domain_slice(index: Any) -> TypeGuard[DomainSlice]: - return isinstance(index, Sequence) and all( - is_named_range(idx) or is_named_index(idx) for idx in index - ) +class FieldBuiltinFuncRegistry: + """ + Mixin for adding `fbuiltins` registry to a `Field`. + Subclasses of a `Field` with `FieldBuiltinFuncRegistry` get their own registry, + dispatching (via ChainMap) to its parent's registries. + """ -class FieldBuiltinFuncRegistry: - _builtin_func_map: ChainMap[fbuiltins.BuiltInFunction, Callable] = collections.ChainMap() + _builtin_func_map: collections.ChainMap[ + fbuiltins.BuiltInFunction, Callable + ] = collections.ChainMap() def __init_subclass__(cls, **kwargs): - # might break in multiple inheritance (if multiple ancestors have `_builtin_func_map`) - cls._builtin_func_map = cls._builtin_func_map.new_child() + cls._builtin_func_map = collections.ChainMap( + {}, # New empty `dict`` for new registrations on this class + *[ + c.__dict__["_builtin_func_map"].maps[0] # adding parent `dict`s in mro order + for c in cls.__mro__ + if "_builtin_func_map" in c.__dict__ + ], + ) @classmethod def register_builtin_func( diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py new file mode 100644 index 0000000000..37ba4954f3 --- /dev/null +++ b/src/gt4py/next/embedded/common.py @@ -0,0 +1,127 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from typing import Any, Optional, Sequence, cast + +from gt4py.next import common +from gt4py.next.embedded import exceptions as embedded_exceptions + + +def sub_domain(domain: common.Domain, index: common.AnyIndexSpec) -> common.Domain: + index_sequence = common.as_any_index_sequence(index) + + if common.is_absolute_index_sequence(index_sequence): + return _absolute_sub_domain(domain, index_sequence) + + if common.is_relative_index_sequence(index_sequence): + return _relative_sub_domain(domain, index_sequence) + + raise IndexError(f"Unsupported index type: {index}") + + +def _relative_sub_domain( + domain: common.Domain, index: common.RelativeIndexSequence +) -> common.Domain: + named_ranges: list[common.NamedRange] = [] + + expanded = _expand_ellipsis(index, len(domain)) + if len(domain) < len(expanded): + raise IndexError(f"Trying to index a `Field` with {len(domain)} dimensions with {index}.") + expanded += (slice(None),) * (len(domain) - len(expanded)) + for (dim, rng), idx in zip(domain, expanded, strict=True): + if isinstance(idx, slice): + try: + sliced = _slice_range(rng, idx) + named_ranges.append((dim, sliced)) + except IndexError: + raise embedded_exceptions.IndexOutOfBounds( + domain=domain, indices=index, index=idx, dim=dim + ) + else: + # not in new domain + assert common.is_int_index(idx) + new_index = (rng.start if idx >= 0 else rng.stop) + idx + if new_index < rng.start or new_index >= rng.stop: + raise embedded_exceptions.IndexOutOfBounds( + domain=domain, indices=index, index=idx, dim=dim + ) + + return common.Domain(*named_ranges) + + +def _absolute_sub_domain( + domain: common.Domain, index: common.AbsoluteIndexSequence +) -> common.Domain: + named_ranges: list[common.NamedRange] = [] + for i, (dim, rng) in enumerate(domain): + if (pos := _find_index_of_dim(dim, index)) is not None: + named_idx = index[pos] + idx = named_idx[1] + if isinstance(idx, common.UnitRange): + if not idx <= rng: + raise embedded_exceptions.IndexOutOfBounds( + domain=domain, indices=index, index=named_idx, dim=dim + ) + + named_ranges.append((dim, idx)) + else: + # not in new domain + assert common.is_int_index(idx) + if idx < rng.start or idx >= rng.stop: + raise embedded_exceptions.IndexOutOfBounds( + domain=domain, indices=index, index=named_idx, dim=dim + ) + else: + # dimension not mentioned in slice + named_ranges.append((dim, domain.ranges[i])) + + return common.Domain(*named_ranges) + + +def _expand_ellipsis( + indices: common.RelativeIndexSequence, target_size: int +) -> tuple[common.IntIndex | slice, ...]: + if Ellipsis in indices: + idx = indices.index(Ellipsis) + indices = ( + indices[:idx] + (slice(None),) * (target_size - (len(indices) - 1)) + indices[idx + 1 :] + ) + return cast(tuple[common.IntIndex | slice, ...], indices) # mypy leave me alone and trust me! + + +def _slice_range(input_range: common.UnitRange, slice_obj: slice) -> common.UnitRange: + if slice_obj == slice(None): + return common.UnitRange(input_range.start, input_range.stop) + + start = ( + input_range.start if slice_obj.start is None or slice_obj.start >= 0 else input_range.stop + ) + (slice_obj.start or 0) + stop = ( + input_range.start if slice_obj.stop is None or slice_obj.stop >= 0 else input_range.stop + ) + (slice_obj.stop or len(input_range)) + + if start < input_range.start or stop > input_range.stop: + raise IndexError("Slice out of range (no clipping following array API standard).") + + return common.UnitRange(start, stop) + + +def _find_index_of_dim( + dim: common.Dimension, + domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], +) -> Optional[int]: + for i, (d, _) in enumerate(domain_slice): + if dim == d: + return i + return None diff --git a/src/gt4py/next/embedded/exceptions.py b/src/gt4py/next/embedded/exceptions.py new file mode 100644 index 0000000000..393123db36 --- /dev/null +++ b/src/gt4py/next/embedded/exceptions.py @@ -0,0 +1,38 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from gt4py.next import common +from gt4py.next.errors import exceptions as gt4py_exceptions + + +class IndexOutOfBounds(gt4py_exceptions.GT4PyError): + domain: common.Domain + indices: common.AnyIndexSpec + index: common.AnyIndexElement + dim: common.Dimension + + def __init__( + self, + domain: common.Domain, + indices: common.AnyIndexSpec, + index: common.AnyIndexElement, + dim: common.Dimension, + ): + super().__init__( + f"Out of bounds: slicing {domain} with index `{indices}`, `{index}` is out of bounds in dimension `{dim}`." + ) + self.domain = domain + self.indices = indices + self.index = index + self.dim = dim diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index ddef77bb78..fcaa09e7eb 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -15,17 +15,16 @@ from __future__ import annotations import dataclasses -import itertools from collections.abc import Callable, Sequence -from types import EllipsisType, ModuleType -from typing import Any, ClassVar, Optional, ParamSpec, TypeAlias, TypeVar, cast +from types import ModuleType +from typing import Any, ClassVar, Optional, ParamSpec, TypeAlias, TypeVar import numpy as np from numpy import typing as npt from gt4py._core import definitions as core_defs from gt4py.next import common -from gt4py.next.common import FieldBuiltinFuncRegistry +from gt4py.next.embedded import common as embedded_common from gt4py.next.ffront import fbuiltins @@ -56,7 +55,7 @@ def _make_binary_array_field_intrinsic_func(builtin_name: str, array_builtin_nam def _builtin_binary_op(a: _BaseNdArrayField, b: common.Field) -> common.Field: xp = a.__class__.array_ns op = getattr(xp, array_builtin_name) - if hasattr(b, "__gt_builtin_func__"): # isinstance(b, common.Field): + if hasattr(b, "__gt_builtin_func__"): # common.is_field(b): if not a.domain == b.domain: domain_intersection = a.domain & b.domain a_broadcasted = _broadcast(a, domain_intersection.dims) @@ -82,7 +81,9 @@ def _builtin_binary_op(a: _BaseNdArrayField, b: common.Field) -> common.Field: @dataclasses.dataclass(frozen=True) -class _BaseNdArrayField(common.FieldABC[common.DimsT, core_defs.ScalarT], FieldBuiltinFuncRegistry): +class _BaseNdArrayField( + common.MutableField[common.DimsT, core_defs.ScalarT], common.FieldBuiltinFuncRegistry +): """ Shared field implementation for NumPy-like fields. @@ -94,7 +95,6 @@ class _BaseNdArrayField(common.FieldABC[common.DimsT, core_defs.ScalarT], FieldB _domain: common.Domain _ndarray: core_defs.NDArrayObject - _value_type: type[core_defs.ScalarT] array_ns: ClassVar[ ModuleType @@ -104,13 +104,28 @@ class _BaseNdArrayField(common.FieldABC[common.DimsT, core_defs.ScalarT], FieldB def domain(self) -> common.Domain: return self._domain + @property + def shape(self) -> tuple[int, ...]: + return self._ndarray.shape + + @property + def __gt_dims__(self) -> tuple[common.Dimension, ...]: + return self._domain.dims + + @property + def __gt_origin__(self) -> tuple[int, ...]: + return tuple(-r.start for _, r in self._domain) + @property def ndarray(self) -> core_defs.NDArrayObject: return self._ndarray + def __array__(self, dtype: npt.DTypeLike = None) -> np.ndarray: + return np.asarray(self._ndarray, dtype) + @property - def value_type(self) -> type[core_defs.ScalarT]: - return self._value_type + def dtype(self) -> core_defs.DType[core_defs.ScalarT]: + return core_defs.dtype(self._ndarray.dtype.type) @classmethod def from_array( @@ -119,38 +134,52 @@ def from_array( | core_defs.NDArrayObject, # TODO: NDArrayObject should be part of ArrayLike /, *, - domain: common.Domain, - value_type: Optional[type] = None, + domain: common.DomainLike, + dtype_like: Optional[core_defs.DType] = None, # TODO define DTypeLike ) -> _BaseNdArrayField: + domain = common.domain(domain) xp = cls.array_ns - dtype = None - if value_type is not None: - dtype = xp.dtype(value_type) - array = xp.asarray(data, dtype=dtype) - value_type = array.dtype.type # TODO add support for Dimensions as value_type + xp_dtype = None if dtype_like is None else xp.dtype(core_defs.dtype(dtype_like).scalar_type) + array = xp.asarray(data, dtype=xp_dtype) + + if dtype_like is not None: + assert array.dtype.type == core_defs.dtype(dtype_like).scalar_type assert issubclass(array.dtype.type, core_defs.SCALAR_TYPES) - assert all(isinstance(d, common.Dimension) for d, r in domain), domain + assert all(isinstance(d, common.Dimension) for d in domain.dims), domain assert len(domain) == array.ndim assert all( - len(nr[1]) == s or (s == 1 and nr[1] == common.UnitRange.infinity()) - for nr, s in zip(domain, array.shape) + len(r) == s or (s == 1 and r == common.UnitRange.infinity()) + for r, s in zip(domain.ranges, array.shape) ) - assert value_type is not None # for mypy - return cls(domain, array, value_type) + return cls(domain, array) def remap(self: _BaseNdArrayField, connectivity) -> _BaseNdArrayField: raise NotImplementedError() + def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.ScalarT: + new_domain, buffer_slice = self._slice(index) + + new_buffer = self.ndarray[buffer_slice] + if len(new_domain) == 0: + assert core_defs.is_scalar_type(new_buffer) + return new_buffer # type: ignore[return-value] # I don't think we can express that we return `ScalarT` here + else: + return self.__class__.from_array(new_buffer, domain=new_domain) + + __getitem__ = restrict + __call__ = None # type: ignore[assignment] # TODO: remap __abs__ = _make_unary_array_field_intrinsic_func("abs", "abs") __neg__ = _make_unary_array_field_intrinsic_func("neg", "negative") + __pos__ = _make_unary_array_field_intrinsic_func("pos", "positive") + __add__ = __radd__ = _make_binary_array_field_intrinsic_func("add", "add") __sub__ = __rsub__ = _make_binary_array_field_intrinsic_func("sub", "subtract") @@ -165,78 +194,51 @@ def remap(self: _BaseNdArrayField, connectivity) -> _BaseNdArrayField: __pow__ = _make_binary_array_field_intrinsic_func("pow", "power") - def __getitem__(self, index: common.FieldSlice) -> common.Field | core_defs.ScalarT: - if ( - not isinstance(index, tuple) - and not common.is_domain_slice(index) - or common.is_named_index(index) - or common.is_named_range(index) - ): - index = cast(common.FieldSlice, (index,)) + __mod__ = __rmod__ = _make_binary_array_field_intrinsic_func("mod", "mod") - if common.is_domain_slice(index): - return self._getitem_absolute_slice(index) + def __and__(self, other: common.Field | core_defs.ScalarT) -> _BaseNdArrayField: + if self.dtype == core_defs.BoolDType(): + return _make_binary_array_field_intrinsic_func("logical_and", "logical_and")( + self, other + ) + raise NotImplementedError("`__and__` not implemented for non-`bool` fields.") - assert isinstance(index, tuple) - if all(isinstance(idx, (slice, int)) or idx is Ellipsis for idx in index): - return self._getitem_relative_slice(index) + __rand__ = __and__ - raise IndexError(f"Unsupported index type: {index}") + def __or__(self, other: common.Field | core_defs.ScalarT) -> _BaseNdArrayField: + if self.dtype == core_defs.BoolDType(): + return _make_binary_array_field_intrinsic_func("logical_or", "logical_or")(self, other) + raise NotImplementedError("`__or__` not implemented for non-`bool` fields.") - restrict = ( - __getitem__ # type:ignore[assignment] # TODO(havogt) I don't see the problem that mypy has - ) + __ror__ = __or__ - def _getitem_absolute_slice( - self, index: common.DomainSlice - ) -> common.Field | core_defs.ScalarT: - slices = _get_slices_from_domain_slice(self.domain, index) - new_ranges = [] - new_dims = [] - new = self.ndarray[slices] - - for i, dim in enumerate(self.domain.dims): - if (pos := _find_index_of_dim(dim, index)) is not None: - index_or_range = index[pos][1] - if isinstance(index_or_range, common.UnitRange): - new_ranges.append(index_or_range) - new_dims.append(dim) - else: - # dimension not mentioned in slice - new_ranges.append(self.domain.ranges[i]) - new_dims.append(dim) - - new_domain = common.Domain(dims=tuple(new_dims), ranges=tuple(new_ranges)) + def __xor__(self, other: common.Field | core_defs.ScalarT) -> _BaseNdArrayField: + if self.dtype == core_defs.BoolDType(): + return _make_binary_array_field_intrinsic_func("logical_xor", "logical_xor")( + self, other + ) + raise NotImplementedError("`__xor__` not implemented for non-`bool` fields.") - if len(new_domain) == 0: - assert core_defs.is_scalar_type(new) - return new # type: ignore[return-value] # I don't think we can express that we return `ScalarT` here - else: - return self.__class__.from_array(new, domain=new_domain, value_type=self.value_type) - - def _getitem_relative_slice( - self, indices: tuple[slice | int | EllipsisType, ...] - ) -> common.Field | core_defs.ScalarT: - new = self.ndarray[indices] - new_dims = [] - new_ranges = [] - - for (dim, rng), idx in itertools.zip_longest( # type: ignore[misc] # "slice" object is not iterable, not sure which slice... - self.domain, _expand_ellipsis(indices, len(self.domain)), fillvalue=slice(None) - ): - if isinstance(idx, slice): - new_dims.append(dim) - new_ranges.append(_slice_range(rng, idx)) - else: - assert isinstance(idx, int) # not in new_domain - - new_domain = common.Domain(dims=tuple(new_dims), ranges=tuple(new_ranges)) + __rxor__ = __xor__ - if len(new_domain) == 0: - assert core_defs.is_scalar_type(new), new - return new # type: ignore[return-value] # I don't think we can express that we return `ScalarT` here - else: - return self.__class__.from_array(new, domain=new_domain, value_type=self.value_type) + def __invert__(self) -> _BaseNdArrayField: + if self.dtype == core_defs.BoolDType(): + return _make_unary_array_field_intrinsic_func("invert", "invert")(self) + raise NotImplementedError("`__invert__` not implemented for non-`bool` fields.") + + def _slice( + self, index: common.AnyIndexSpec + ) -> tuple[common.Domain, common.RelativeIndexSequence]: + new_domain = embedded_common.sub_domain(self.domain, index) + + index_sequence = common.as_any_index_sequence(index) + slice_ = ( + _get_slices_from_domain_slice(self.domain, index_sequence) + if common.is_absolute_index_sequence(index_sequence) + else index_sequence + ) + assert common.is_relative_index_sequence(slice_) + return new_domain, slice_ # -- Specialized implementations for intrinsic operations on array fields -- @@ -266,6 +268,25 @@ def _getitem_relative_slice( fbuiltins.fmod, _make_binary_array_field_intrinsic_func("fmod", "fmod") # type: ignore[attr-defined] ) + +def _np_cp_setitem( + self: _BaseNdArrayField[common.DimsT, core_defs.ScalarT], + index: common.AnyIndexSpec, + value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, +) -> None: + target_domain, target_slice = self._slice(index) + + if common.is_field(value): + if not value.domain == target_domain: + raise ValueError( + f"Incompatible `Domain` in assignment. Source domain = {value.domain}, target domain = {target_domain}." + ) + value = value.ndarray + + assert hasattr(self.ndarray, "__setitem__") + self.ndarray[target_slice] = value + + # -- Concrete array implementations -- # NumPy _nd_array_implementations = [np] @@ -275,6 +296,8 @@ def _getitem_relative_slice( class NumPyArrayField(_BaseNdArrayField): array_ns: ClassVar[ModuleType] = np + __setitem__ = _np_cp_setitem + common.field.register(np.ndarray, NumPyArrayField.from_array) @@ -286,6 +309,8 @@ class NumPyArrayField(_BaseNdArrayField): class CuPyArrayField(_BaseNdArrayField): array_ns: ClassVar[ModuleType] = cp + __setitem__ = _np_cp_setitem + common.field.register(cp.ndarray, CuPyArrayField.from_array) # JAX @@ -296,38 +321,30 @@ class CuPyArrayField(_BaseNdArrayField): class JaxArrayField(_BaseNdArrayField): array_ns: ClassVar[ModuleType] = jnp - common.field.register(jnp.ndarray, JaxArrayField.from_array) - + def __setitem__( + self, + index: common.AnyIndexSpec, + value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, + ) -> None: + # TODO(havogt): use something like `self.ndarray = self.ndarray.at(index).set(value)` + raise NotImplementedError("`__setitem__` for JaxArrayField not yet implemented.") -def _find_index_of_dim( - dim: common.Dimension, - domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], -) -> Optional[int]: - for i, (d, _) in enumerate(domain_slice): - if dim == d: - return i - return None + common.field.register(jnp.ndarray, JaxArrayField.from_array) def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]) -> common.Field: domain_slice: list[slice | None] = [] - new_domain_dims = [] - new_domain_ranges = [] + named_ranges = [] for dim in new_dimensions: - if (pos := _find_index_of_dim(dim, field.domain)) is not None: + if (pos := embedded_common._find_index_of_dim(dim, field.domain)) is not None: domain_slice.append(slice(None)) - new_domain_dims.append(dim) - new_domain_ranges.append(field.domain[pos][1]) + named_ranges.append((dim, field.domain[pos][1])) else: domain_slice.append(np.newaxis) - new_domain_dims.append(dim) - new_domain_ranges.append( - common.UnitRange(common.Infinity.negative(), common.Infinity.positive()) + named_ranges.append( + (dim, common.UnitRange(common.Infinity.negative(), common.Infinity.positive())) ) - return common.field( - field.ndarray[tuple(domain_slice)], - domain=common.Domain(tuple(new_domain_dims), tuple(new_domain_ranges)), - ) + return common.field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges)) def _builtins_broadcast( @@ -344,7 +361,7 @@ def _builtins_broadcast( def _get_slices_from_domain_slice( domain: common.Domain, domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], -) -> tuple[slice | int | None, ...]: +) -> common.RelativeIndexSequence: """Generate slices for sub-array extraction based on named ranges or named indices within a Domain. This function generates a tuple of slices that can be used to extract sub-arrays from a field. The provided @@ -359,10 +376,10 @@ def _get_slices_from_domain_slice( specified in the Domain. If a dimension is not included in the named indices or ranges, a None is used to indicate expansion along that axis. """ - slice_indices: list[slice | int | None] = [] + slice_indices: list[slice | common.IntIndex] = [] for pos_old, (dim, _) in enumerate(domain): - if (pos := _find_index_of_dim(dim, domain_slice)) is not None: + if (pos := embedded_common._find_index_of_dim(dim, domain_slice)) is not None: index_or_range = domain_slice[pos][1] slice_indices.append(_compute_slice(index_or_range, domain, pos_old)) else: @@ -370,7 +387,9 @@ def _get_slices_from_domain_slice( return tuple(slice_indices) -def _compute_slice(rng: common.DomainRange, domain: common.Domain, pos: int) -> slice | int: +def _compute_slice( + rng: common.UnitRange | common.IntIndex, domain: common.Domain, pos: int +) -> slice | common.IntIndex: """Compute a slice or integer based on the provided range, domain, and position. Args: @@ -392,34 +411,7 @@ def _compute_slice(rng: common.DomainRange, domain: common.Domain, pos: int) -> rng.start - domain.ranges[pos].start, rng.stop - domain.ranges[pos].start, ) - elif isinstance(rng, int): + elif common.is_int_index(rng): return rng - domain.ranges[pos].start else: raise ValueError(f"Can only use integer or UnitRange ranges, provided type: {type(rng)}") - - -def _slice_range(input_range: common.UnitRange, slice_obj: slice) -> common.UnitRange: - # handle slice(None) case - if slice_obj == slice(None): - return common.UnitRange(input_range.start, input_range.stop) - - start = ( - input_range.start if slice_obj.start is None or slice_obj.start >= 0 else input_range.stop - ) + (slice_obj.start or 0) - stop = ( - input_range.start if slice_obj.stop is None or slice_obj.stop >= 0 else input_range.stop - ) + (slice_obj.stop or len(input_range)) - - return common.UnitRange(start, stop) - - -def _expand_ellipsis( - indices: tuple[int | slice | EllipsisType, ...], target_size: int -) -> tuple[int | slice, ...]: - expanded_indices: list[int | slice] = [] - for idx in indices: - if idx is Ellipsis: - expanded_indices.extend([slice(None)] * (target_size - (len(indices) - 1))) - else: - expanded_indices.append(idx) - return tuple(expanded_indices) diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index 74230263db..e956858549 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -33,17 +33,19 @@ from . import formatting -class DSLError(Exception): +class GT4PyError(Exception): + @property + def message(self) -> str: + return self.args[0] + + +class DSLError(GT4PyError): location: Optional[SourceLocation] def __init__(self, location: Optional[SourceLocation], message: str) -> None: self.location = location super().__init__(message) - @property - def message(self) -> str: - return self.args[0] - def with_location(self, location: Optional[SourceLocation]) -> DSLError: self.location = location return self diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index ba027be13c..52aae34b3f 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -11,6 +11,7 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later + import dataclasses import inspect from builtins import bool, float, int, tuple diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index 0c5de764f2..006b3057b0 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Any, ClassVar +from typing import Any, ClassVar, TypeGuard, TypeVar class RecursionGuard: @@ -49,3 +49,10 @@ def __enter__(self): def __exit__(self, *exc): self.guarded_objects.remove(id(self.obj)) + + +_T = TypeVar("_T") + + +def is_tuple_of(v: Any, t: type[_T]) -> TypeGuard[tuple[_T, ...]]: + return isinstance(v, tuple) and all(isinstance(e, t) for e in v) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py index 0bc5a98a4e..c1bee4fa2f 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py @@ -82,7 +82,7 @@ def scan(inp: int32) -> int32: expected = textwrap.dedent( f""" - @scan_operator(axis=Dimension(value="KDim", kind=DimensionKind.VERTICAL), forward=False, init=1) + @scan_operator(axis=KDim[vertical], forward=False, init=1) def scan(inp: int32) -> int32: {ssa.unique_name("foo", 0)} = inp return inp diff --git a/tests/next_tests/unit_tests/embedded_tests/test_common.py b/tests/next_tests/unit_tests/embedded_tests/test_common.py new file mode 100644 index 0000000000..640ed326bb --- /dev/null +++ b/tests/next_tests/unit_tests/embedded_tests/test_common.py @@ -0,0 +1,137 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from typing import Sequence + +import pytest + +from gt4py.next import common +from gt4py.next.common import UnitRange +from gt4py.next.embedded import exceptions as embedded_exceptions +from gt4py.next.embedded.common import _slice_range, sub_domain + + +@pytest.mark.parametrize( + "rng, slce, expected", + [ + (UnitRange(2, 10), slice(2, -2), UnitRange(4, 8)), + (UnitRange(2, 10), slice(2, None), UnitRange(4, 10)), + (UnitRange(2, 10), slice(None, -2), UnitRange(2, 8)), + (UnitRange(2, 10), slice(None), UnitRange(2, 10)), + ], +) +def test_slice_range(rng, slce, expected): + result = _slice_range(rng, slce) + assert result == expected + + +I = common.Dimension("I") +J = common.Dimension("J") +K = common.Dimension("K") + + +@pytest.mark.parametrize( + "domain, index, expected", + [ + ([(I, (2, 5))], 1, []), + ([(I, (2, 5))], slice(1, 2), [(I, (3, 4))]), + ([(I, (2, 5))], (I, 2), []), + ([(I, (2, 5))], (I, UnitRange(2, 3)), [(I, (2, 3))]), + ([(I, (-2, 3))], 1, []), + ([(I, (-2, 3))], slice(1, 2), [(I, (-1, 0))]), + ([(I, (-2, 3))], (I, 1), []), + ([(I, (-2, 3))], (I, UnitRange(2, 3)), [(I, (2, 3))]), + ([(I, (-2, 3))], -5, []), + ([(I, (-2, 3))], -6, IndexError), + ([(I, (-2, 3))], slice(-7, -6), IndexError), + ([(I, (-2, 3))], slice(-6, -7), IndexError), + ([(I, (-2, 3))], 4, []), + ([(I, (-2, 3))], 5, IndexError), + ([(I, (-2, 3))], slice(4, 5), [(I, (2, 3))]), + ([(I, (-2, 3))], slice(5, 6), IndexError), + ([(I, (-2, 3))], (I, -3), IndexError), + ([(I, (-2, 3))], (I, UnitRange(-3, -2)), IndexError), + ([(I, (-2, 3))], (I, 3), IndexError), + ([(I, (-2, 3))], (I, UnitRange(3, 4)), IndexError), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + 2, + [(J, (3, 6)), (K, (4, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + slice(2, 3), + [(I, (4, 5)), (J, (3, 6)), (K, (4, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (I, 2), + [(J, (3, 6)), (K, (4, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (I, UnitRange(2, 3)), + [(I, (2, 3)), (J, (3, 6)), (K, (4, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (J, 3), + [(I, (2, 5)), (K, (4, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (J, UnitRange(4, 5)), + [(I, (2, 5)), (J, (4, 5)), (K, (4, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + ((J, 3), (I, 2)), + [(K, (4, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + ((J, UnitRange(4, 5)), (I, 2)), + [(J, (4, 5)), (K, (4, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (slice(1, 2), slice(2, 3)), + [(I, (3, 4)), (J, (5, 6)), (K, (4, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (Ellipsis, slice(2, 3)), + [(I, (2, 5)), (J, (3, 6)), (K, (6, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (slice(1, 2), Ellipsis, slice(2, 3)), + [(I, (3, 4)), (J, (3, 6)), (K, (6, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (slice(1, 2), slice(1, 2), Ellipsis), + [(I, (3, 4)), (J, (4, 5)), (K, (4, 7))], + ), + ], +) +def test_sub_domain(domain, index, expected): + domain = common.domain(domain) + if expected is IndexError: + with pytest.raises(embedded_exceptions.IndexOutOfBounds): + sub_domain(domain, index) + else: + expected = common.domain(expected) + result = sub_domain(domain, index) + assert result == expected diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index a2aa3112bd..95093c8307 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -22,8 +22,8 @@ from gt4py.next import Dimension, common from gt4py.next.common import Domain, UnitRange -from gt4py.next.embedded import nd_array_field -from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice, _slice_range +from gt4py.next.embedded import exceptions as embedded_exceptions, nd_array_field +from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice from gt4py.next.ffront import fbuiltins from next_tests.integration_tests.feature_tests.math_builtin_test_data import math_builtin_test_data @@ -40,16 +40,42 @@ def nd_array_implementation(request): @pytest.fixture( - params=[operator.add, operator.sub, operator.mul, operator.truediv, operator.floordiv], + params=[ + operator.add, + operator.sub, + operator.mul, + operator.truediv, + operator.floordiv, + operator.mod, + ] ) -def binary_op(request): +def binary_arithmetic_op(request): yield request.param -def _make_field(lst: Iterable, nd_array_implementation): +@pytest.fixture( + params=[operator.xor, operator.and_, operator.or_], +) +def binary_logical_op(request): + yield request.param + + +@pytest.fixture(params=[operator.neg, operator.pos]) +def unary_arithmetic_op(request): + yield request.param + + +@pytest.fixture(params=[operator.invert]) +def unary_logical_op(request): + yield request.param + + +def _make_field(lst: Iterable, nd_array_implementation, *, dtype=None): + if not dtype: + dtype = nd_array_implementation.float32 return common.field( - nd_array_implementation.asarray(lst, dtype=nd_array_implementation.float32), - domain=((common.Dimension("foo"), common.UnitRange(0, len(lst))),), + nd_array_implementation.asarray(lst, dtype=dtype), + domain={common.Dimension("foo"): (0, len(lst))}, ) @@ -72,16 +98,57 @@ def test_math_function_builtins(builtin_name: str, inputs, nd_array_implementati assert np.allclose(result.ndarray, expected) -def test_binary_ops(binary_op, nd_array_implementation): +def test_binary_arithmetic_ops(binary_arithmetic_op, nd_array_implementation): inp_a = [-1.0, 4.2, 42] inp_b = [2.0, 3.0, -3.0] inputs = [inp_a, inp_b] - expected = binary_op(*[np.asarray(inp, dtype=np.float32) for inp in inputs]) + expected = binary_arithmetic_op(*[np.asarray(inp, dtype=np.float32) for inp in inputs]) field_inputs = [_make_field(inp, nd_array_implementation) for inp in inputs] - result = binary_op(*field_inputs) + result = binary_arithmetic_op(*field_inputs) + + assert np.allclose(result.ndarray, expected) + + +def test_binary_logical_ops(binary_logical_op, nd_array_implementation): + inp_a = [True, True, False, False] + inp_b = [True, False, True, False] + inputs = [inp_a, inp_b] + + expected = binary_logical_op(*[np.asarray(inp) for inp in inputs]) + + field_inputs = [_make_field(inp, nd_array_implementation, dtype=bool) for inp in inputs] + + result = binary_logical_op(*field_inputs) + + assert np.allclose(result.ndarray, expected) + + +def test_unary_logical_ops(unary_logical_op, nd_array_implementation): + inp = [ + True, + False, + ] + + expected = unary_logical_op(np.asarray(inp)) + + field_input = _make_field(inp, nd_array_implementation, dtype=bool) + + result = unary_logical_op(field_input) + + assert np.allclose(result.ndarray, expected) + + +def test_unary_arithmetic_ops(unary_arithmetic_op, nd_array_implementation): + inp = [1.0, -2.0, 0.0] + + expected = unary_arithmetic_op(np.asarray(inp, dtype=np.float32)) + + field_input = _make_field(inp, nd_array_implementation) + + result = unary_arithmetic_op(field_input) assert np.allclose(result.ndarray, expected) @@ -93,7 +160,7 @@ def test_binary_ops(binary_op, nd_array_implementation): ((JDim,), (None, slice(5, 10))), ], ) -def test_binary_operations_with_intersection(binary_op, dims, expected_indices): +def test_binary_operations_with_intersection(binary_arithmetic_op, dims, expected_indices): arr1 = np.arange(10) arr1_domain = common.Domain(dims=dims, ranges=(UnitRange(0, 10),)) @@ -103,8 +170,8 @@ def test_binary_operations_with_intersection(binary_op, dims, expected_indices): field1 = common.field(arr1, domain=arr1_domain) field2 = common.field(arr2, domain=arr2_domain) - op_result = binary_op(field1, field2) - expected_result = binary_op(arr1[expected_indices[0], expected_indices[1]], arr2) + op_result = binary_arithmetic_op(field1, field2) + expected_result = binary_arithmetic_op(arr1[expected_indices[0], expected_indices[1]], arr2) assert op_result.ndarray.shape == (5, 5) assert np.allclose(op_result.ndarray, expected_result) @@ -122,10 +189,8 @@ def product_nd_array_implementation(request): def test_mixed_fields(product_nd_array_implementation): first_impl, second_impl = product_nd_array_implementation - if (first_impl.__name__ == "cupy" and second_impl.__name__ == "numpy") or ( - first_impl.__name__ == "numpy" and second_impl.__name__ == "cupy" - ): - pytest.skip("Binary operation between CuPy and NumPy requires explicit conversion.") + if "numpy" in first_impl.__name__ and "cupy" in second_impl.__name__: + pytest.skip("Binary operation between NumPy and CuPy requires explicit conversion.") inp_a = [-1.0, 4.2, 42] inp_b = [2.0, 3.0, -3.0] @@ -271,7 +336,7 @@ def test_get_slices_invalid_type(): (JDim, KDim), (2, 15), ), - ((IDim, 1), (JDim, KDim), (10, 15)), + ((IDim, 5), (JDim, KDim), (10, 15)), ((IDim, UnitRange(5, 7)), (IDim, JDim, KDim), (2, 10, 15)), ], ) @@ -282,20 +347,20 @@ def test_absolute_indexing(domain_slice, expected_dimensions, expected_shape): field = common.field(np.ones((5, 10, 15)), domain=domain) indexed_field = field[domain_slice] - assert isinstance(indexed_field, common.Field) + assert common.is_field(indexed_field) assert indexed_field.ndarray.shape == expected_shape assert indexed_field.domain.dims == expected_dimensions def test_absolute_indexing_value_return(): domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(10, 20), UnitRange(5, 15))) - field = common.field(np.ones((10, 10), dtype=np.int32), domain=domain) + field = common.field(np.reshape(np.arange(100, dtype=np.int32), (10, 10)), domain=domain) - named_index = ((IDim, 2), (JDim, 4)) + named_index = ((IDim, 12), (JDim, 6)) value = field[named_index] assert isinstance(value, np.int32) - assert value == 1 + assert value == 21 @pytest.mark.parametrize( @@ -304,19 +369,23 @@ def test_absolute_indexing_value_return(): ( (slice(None, 5), slice(None, 2)), (5, 2), - Domain((IDim, JDim), (UnitRange(5, 10), UnitRange(2, 4))), + Domain((IDim, UnitRange(5, 10)), (JDim, UnitRange(2, 4))), + ), + ((slice(None, 5),), (5, 10), Domain((IDim, UnitRange(5, 10)), (JDim, UnitRange(2, 12)))), + ( + (Ellipsis, 1), + (10,), + Domain((IDim, UnitRange(5, 15))), ), - ((slice(None, 5),), (5, 10), Domain((IDim, JDim), (UnitRange(5, 10), UnitRange(2, 12)))), - ((Ellipsis, 1), (10,), Domain((IDim,), (UnitRange(5, 15),))), ( (slice(2, 3), slice(5, 7)), (1, 2), - Domain((IDim, JDim), (UnitRange(7, 8), UnitRange(7, 9))), + Domain((IDim, UnitRange(7, 8)), (JDim, UnitRange(7, 9))), ), ( (slice(1, 2), 0), (1,), - Domain((IDim,), (UnitRange(6, 7),)), + Domain((IDim, UnitRange(6, 7))), ), ], ) @@ -325,7 +394,7 @@ def test_relative_indexing_slice_2D(index, expected_shape, expected_domain): field = common.field(np.ones((10, 10)), domain=domain) indexed_field = field[index] - assert isinstance(indexed_field, common.Field) + assert common.is_field(indexed_field) assert indexed_field.ndarray.shape == expected_shape assert indexed_field.domain == expected_domain @@ -333,32 +402,44 @@ def test_relative_indexing_slice_2D(index, expected_shape, expected_domain): @pytest.mark.parametrize( "index, expected_shape, expected_domain", [ - ((1, slice(None), 2), (15,), Domain((JDim,), (UnitRange(10, 25),))), + ((1, slice(None), 2), (15,), Domain(dims=(JDim,), ranges=(UnitRange(10, 25),))), ( (slice(None), slice(None), 2), (10, 15), - Domain((IDim, JDim), (UnitRange(5, 15), UnitRange(10, 25))), + Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 15), UnitRange(10, 25))), ), ( (slice(None),), (10, 15, 10), - Domain((IDim, JDim, KDim), (UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), + ), ), ( (slice(None), slice(None), slice(None)), (10, 15, 10), - Domain((IDim, JDim, KDim), (UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), + ), ), ( (slice(None)), (10, 15, 10), - Domain((IDim, JDim, KDim), (UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), + ), ), - ((0, Ellipsis, 0), (15,), Domain((JDim,), (UnitRange(10, 25),))), + ((0, Ellipsis, 0), (15,), Domain(dims=(JDim,), ranges=(UnitRange(10, 25),))), ( Ellipsis, (10, 15, 10), - Domain((IDim, JDim, KDim), (UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), + ), ), ], ) @@ -369,7 +450,7 @@ def test_relative_indexing_slice_3D(index, expected_shape, expected_domain): field = common.field(np.ones((10, 15, 10)), domain=domain) indexed_field = field[index] - assert isinstance(indexed_field, common.Field) + assert common.is_field(indexed_field) assert indexed_field.ndarray.shape == expected_shape assert indexed_field.domain == expected_domain @@ -391,7 +472,7 @@ def test_relative_indexing_out_of_bounds(lazy_slice): domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(3, 13), UnitRange(-5, 5))) field = common.field(np.ones((10, 10)), domain=domain) - with pytest.raises(IndexError): + with pytest.raises((embedded_exceptions.IndexOutOfBounds, IndexError)): lazy_slice(field) @@ -403,10 +484,40 @@ def test_field_unsupported_index(index): field[index] -def test_slice_range(): - input_range = UnitRange(2, 10) - slice_obj = slice(2, -2) - expected = UnitRange(4, 8) +@pytest.mark.parametrize( + "index, value", + [ + ((1, 1), 42.0), + ((1, slice(None)), np.ones((10,)) * 42.0), + ( + (1, slice(None)), + common.field(np.ones((10,)) * 42.0, domain=common.Domain((JDim, UnitRange(0, 10)))), + ), + ], +) +def test_setitem(index, value): + field = common.field( + np.arange(100).reshape(10, 10), + domain=common.Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange(0, 10))), + ) + + expected = np.copy(field.ndarray) + expected[index] = value + + field[index] = value + + assert np.allclose(field.ndarray, expected) + + +def test_setitem_wrong_domain(): + field = common.field( + np.arange(100).reshape(10, 10), + domain=common.Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange(0, 10))), + ) + + value_incompatible = common.field( + np.ones((10,)) * 42.0, domain=common.Domain((JDim, UnitRange(-5, 5))) + ) - result = _slice_range(input_range, slice_obj) - assert result == expected + with pytest.raises(ValueError, match=r"Incompatible `Domain`.*"): + field[(1, slice(None))] = value_incompatible diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index 8cdc96254c..31e35221ab 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -15,7 +15,17 @@ import pytest -from gt4py.next.common import Dimension, DimensionKind, Domain, Infinity, UnitRange, promote_dims +from gt4py.next.common import ( + Dimension, + DimensionKind, + Domain, + Infinity, + UnitRange, + domain, + named_range, + promote_dims, + unit_range, +) IDim = Dimension("IDim") @@ -25,15 +35,8 @@ @pytest.fixture -def domain(): - range1 = UnitRange(0, 10) - range2 = UnitRange(5, 15) - range3 = UnitRange(20, 30) - - dimensions = (IDim, JDim, KDim) - ranges = (range1, range2, range3) - - return Domain(dimensions, ranges) +def a_domain(): + return Domain((IDim, UnitRange(0, 10)), (JDim, UnitRange(5, 15)), (KDim, UnitRange(20, 30))) def test_empty_range(): @@ -53,6 +56,11 @@ def test_unit_range_length(rng): assert len(rng) == 10 +@pytest.mark.parametrize("rng_like", [(2, 4), range(2, 4), UnitRange(2, 4)]) +def test_unit_range_like(rng_like): + assert unit_range(rng_like) == UnitRange(2, 4) + + def test_unit_range_repr(rng): assert repr(rng) == "UnitRange(-5, 5)" @@ -142,54 +150,87 @@ def test_mixed_infinity_range(): assert len(mixed_inf_range) == Infinity.positive() -def test_domain_length(domain): - assert len(domain) == 3 +@pytest.mark.parametrize( + "named_rng_like", + [ + (IDim, (2, 4)), + (IDim, range(2, 4)), + (IDim, UnitRange(2, 4)), + ], +) +def test_named_range_like(named_rng_like): + assert named_range(named_rng_like) == (IDim, UnitRange(2, 4)) + + +def test_domain_length(a_domain): + assert len(a_domain) == 3 -def test_domain_iteration(domain): - iterated_values = [val for val in domain] - assert iterated_values == list(zip(domain.dims, domain.ranges)) +@pytest.mark.parametrize( + "domain_like", + [ + (Domain(dims=(IDim, JDim), ranges=(UnitRange(2, 4), UnitRange(3, 5)))), + ((IDim, (2, 4)), (JDim, (3, 5))), + ({IDim: (2, 4), JDim: (3, 5)}), + ], +) +def test_domain_like(domain_like): + assert domain(domain_like) == Domain( + dims=(IDim, JDim), ranges=(UnitRange(2, 4), UnitRange(3, 5)) + ) + +def test_domain_iteration(a_domain): + iterated_values = [val for val in a_domain] + assert iterated_values == list(zip(a_domain.dims, a_domain.ranges)) -def test_domain_contains_named_range(domain): - assert (IDim, UnitRange(0, 10)) in domain - assert (IDim, UnitRange(-5, 5)) not in domain + +def test_domain_contains_named_range(a_domain): + assert (IDim, UnitRange(0, 10)) in a_domain + assert (IDim, UnitRange(-5, 5)) not in a_domain @pytest.mark.parametrize( "second_domain, expected", [ ( - Domain((IDim, JDim), (UnitRange(2, 12), UnitRange(7, 17))), - Domain((IDim, JDim, KDim), (UnitRange(2, 10), UnitRange(7, 15), UnitRange(20, 30))), + Domain(dims=(IDim, JDim), ranges=(UnitRange(2, 12), UnitRange(7, 17))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(2, 10), UnitRange(7, 15), UnitRange(20, 30)), + ), ), ( - Domain((IDim, KDim), (UnitRange(2, 12), UnitRange(7, 27))), - Domain((IDim, JDim, KDim), (UnitRange(2, 10), UnitRange(5, 15), UnitRange(20, 27))), + Domain(dims=(IDim, KDim), ranges=(UnitRange(2, 12), UnitRange(7, 27))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(2, 10), UnitRange(5, 15), UnitRange(20, 27)), + ), ), ( - Domain((JDim, KDim), (UnitRange(2, 12), UnitRange(4, 27))), - Domain((IDim, JDim, KDim), (UnitRange(0, 10), UnitRange(5, 12), UnitRange(20, 27))), + Domain(dims=(JDim, KDim), ranges=(UnitRange(2, 12), UnitRange(4, 27))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(0, 10), UnitRange(5, 12), UnitRange(20, 27)), + ), ), ], ) -def test_domain_intersection_different_dimensions(domain, second_domain, expected): - result_domain = domain & second_domain +def test_domain_intersection_different_dimensions(a_domain, second_domain, expected): + result_domain = a_domain & second_domain print(result_domain) assert result_domain == expected -def test_domain_intersection_reversed_dimensions(domain): - dimensions = (JDim, IDim) - ranges = (UnitRange(2, 12), UnitRange(7, 17)) - domain2 = Domain(dimensions, ranges) +def test_domain_intersection_reversed_dimensions(a_domain): + domain2 = Domain(dims=(JDim, IDim), ranges=(UnitRange(2, 12), UnitRange(7, 17))) with pytest.raises( ValueError, match="Dimensions can not be promoted. The following dimensions appear in contradicting order: IDim, JDim.", ): - domain & domain2 + a_domain & domain2 @pytest.mark.parametrize( @@ -202,8 +243,8 @@ def test_domain_intersection_reversed_dimensions(domain): (-2, (JDim, UnitRange(5, 15))), ], ) -def test_domain_integer_indexing(domain, index, expected): - result = domain[index] +def test_domain_integer_indexing(a_domain, index, expected): + result = a_domain[index] assert result == expected @@ -214,8 +255,8 @@ def test_domain_integer_indexing(domain, index, expected): (slice(1, None), ((JDim, UnitRange(5, 15)), (KDim, UnitRange(20, 30)))), ], ) -def test_domain_slice_indexing(domain, slice_obj, expected): - result = domain[slice_obj] +def test_domain_slice_indexing(a_domain, slice_obj, expected): + result = a_domain[slice_obj] assert isinstance(result, Domain) assert len(result) == len(expected) assert all(res == exp for res, exp in zip(result, expected)) @@ -228,28 +269,28 @@ def test_domain_slice_indexing(domain, slice_obj, expected): (KDim, (KDim, UnitRange(20, 30))), ], ) -def test_domain_dimension_indexing(domain, index, expected_result): - result = domain[index] +def test_domain_dimension_indexing(a_domain, index, expected_result): + result = a_domain[index] assert result == expected_result -def test_domain_indexing_dimension_missing(domain): +def test_domain_indexing_dimension_missing(a_domain): with pytest.raises(KeyError, match=r"No Dimension of type .* is present in the Domain."): - domain[ECDim] + a_domain[ECDim] -def test_domain_indexing_invalid_type(domain): +def test_domain_indexing_invalid_type(a_domain): with pytest.raises( KeyError, match="Invalid index type, must be either int, slice, or Dimension." ): - domain["foo"] + a_domain["foo"] def test_domain_repeat_dims(): dims = (IDim, JDim, IDim) ranges = (UnitRange(0, 5), UnitRange(0, 8), UnitRange(0, 3)) with pytest.raises(NotImplementedError, match=r"Domain dimensions must be unique, not .*"): - Domain(dims, ranges) + Domain(dims=dims, ranges=ranges) def test_domain_dims_ranges_length_mismatch():