Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: add where to embedded field view #1316

Merged
merged 11 commits into from
Nov 16, 2023
74 changes: 34 additions & 40 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,36 +43,31 @@
from gt4py.next.ffront import fbuiltins


def _make_unary_array_field_intrinsic_func(builtin_name: str, array_builtin_name: str) -> Callable:
def _builtin_unary_op(a: _BaseNdArrayField) -> common.Field:
xp = a.__class__.array_ns
def _make_nary_intrinsic(builtin_name: str, array_builtin_name: str) -> Callable:
def _builtin_op(*fields: common.Field) -> common.Field:
first = fields[0]
assert isinstance(first, _BaseNdArrayField)
xp = first.__class__.array_ns
op = getattr(xp, array_builtin_name)
new_data = op(a.ndarray)

return a.__class__.from_array(new_data, domain=a.domain)
others_transformed = []
if len(fields) > 1:
for other in fields[1:]:
if hasattr(other, "__gt_builtin_func__"): # isinstance(b, common.Field):
if not first.domain == other.domain:
raise NotImplementedError(
f"support for different domain not implemented: {first.domain}, {other.domain}"
)
others_transformed.append(xp.asarray(other.ndarray))
else:
assert isinstance(other, definitions.SCALAR_TYPES)
others_transformed.append(other)

_builtin_unary_op.__name__ = builtin_name
return _builtin_unary_op
new_data = op(first.ndarray, *others_transformed)
return first.__class__.from_array(new_data, domain=first.domain)


def _make_binary_array_field_intrinsic_func(builtin_name: str, array_builtin_name: str) -> Callable:
def _builtin_binary_op(a: _BaseNdArrayField, b: common.Field) -> common.Field:
xp = a.__class__.array_ns
op = getattr(xp, array_builtin_name)
if hasattr(b, "__gt_builtin_func__"): # isinstance(b, common.Field):
if not a.domain == b.domain:
raise NotImplementedError(
f"support for different domain not implemented: {a.domain}, {b.domain}"
)
new_data = op(a.ndarray, xp.asarray(b.ndarray))
else:
assert isinstance(b, definitions.SCALAR_TYPES)
new_data = op(a.ndarray, b)

return a.__class__.from_array(new_data, domain=a.domain)

_builtin_binary_op.__name__ = builtin_name
return _builtin_binary_op
_builtin_op.__name__ = builtin_name
return _builtin_op


_Value: TypeAlias = common.Field | ScalarT
Expand Down Expand Up @@ -176,23 +171,21 @@ def restrict(self: _BaseNdArrayField, domain) -> _BaseNdArrayField:

__getitem__ = None # type: ignore[assignment] # TODO: restrict

__abs__ = _make_unary_array_field_intrinsic_func("abs", "abs")
__abs__ = _make_nary_intrinsic("abs", "abs")

__neg__ = _make_unary_array_field_intrinsic_func("neg", "negative")
__neg__ = _make_nary_intrinsic("neg", "negative")

__add__ = __radd__ = _make_binary_array_field_intrinsic_func("add", "add")
__add__ = __radd__ = _make_nary_intrinsic("add", "add")

__sub__ = __rsub__ = _make_binary_array_field_intrinsic_func("sub", "subtract")
__sub__ = __rsub__ = _make_nary_intrinsic("sub", "subtract")

__mul__ = __rmul__ = _make_binary_array_field_intrinsic_func("mul", "multiply")
__mul__ = __rmul__ = _make_nary_intrinsic("mul", "multiply")

__truediv__ = __rtruediv__ = _make_binary_array_field_intrinsic_func("div", "divide")
__truediv__ = __rtruediv__ = _make_nary_intrinsic("div", "divide")

__floordiv__ = __rfloordiv__ = _make_binary_array_field_intrinsic_func(
"floordiv", "floor_divide"
)
__floordiv__ = __rfloordiv__ = _make_nary_intrinsic("floordiv", "floor_divide")

__pow__ = _make_binary_array_field_intrinsic_func("pow", "power")
__pow__ = _make_nary_intrinsic("pow", "power")


# -- Specialized implementations for intrinsic operations on array fields --
Expand All @@ -209,18 +202,19 @@ def restrict(self: _BaseNdArrayField, domain) -> _BaseNdArrayField:
if name in ["abs", "power", "gamma"]:
continue
_BaseNdArrayField.register_builtin_func(
getattr(fbuiltins, name), _make_unary_array_field_intrinsic_func(name, name)
getattr(fbuiltins, name), _make_nary_intrinsic(name, name)
)

_BaseNdArrayField.register_builtin_func(
fbuiltins.minimum, _make_binary_array_field_intrinsic_func("minimum", "minimum") # type: ignore[attr-defined]
fbuiltins.minimum, _make_nary_intrinsic("minimum", "minimum") # type: ignore[attr-defined]
)
_BaseNdArrayField.register_builtin_func(
fbuiltins.maximum, _make_binary_array_field_intrinsic_func("maximum", "maximum") # type: ignore[attr-defined]
fbuiltins.maximum, _make_nary_intrinsic("maximum", "maximum") # type: ignore[attr-defined]
)
_BaseNdArrayField.register_builtin_func(
fbuiltins.fmod, _make_binary_array_field_intrinsic_func("fmod", "fmod") # type: ignore[attr-defined]
fbuiltins.fmod, _make_nary_intrinsic("fmod", "fmod") # type: ignore[attr-defined]
)
_BaseNdArrayField.register_builtin_func(fbuiltins.where, _make_nary_intrinsic("where", "where"))

# -- Concrete array implementations --
# NumPy
Expand Down
6 changes: 6 additions & 0 deletions src/gt4py/next/ffront/fbuiltins.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:

def dispatch(self, *args: Any) -> Callable[_P, _R]:
arg_types = tuple(type(arg) for arg in args)
if any(t == tuple for t in arg_types):
return self.function
for atype in arg_types:
# current strategy is to select the implementation of the first arg that supports the operation
# TODO: define a strategy that converts or prevents conversion
Expand Down Expand Up @@ -172,6 +174,10 @@ def where(
false_field: Field | gt4py_defs.ScalarT | Tuple,
/,
) -> Field | Tuple:
if isinstance(true_field, tuple) and isinstance(false_field, tuple):
if len(true_field) != len(false_field):
raise ValueError("Tuple of different size not allowed")
return tuple(where(mask, t, f) for t, f in zip(true_field, false_field))
raise NotImplementedError()


Expand Down
37 changes: 35 additions & 2 deletions tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ def binary_op(request):
yield request.param


def _make_field(lst: Iterable, nd_array_implementation):
def _make_field(lst: Iterable, nd_array_implementation, *, dtype=None):
if not dtype:
dtype = nd_array_implementation.float32
return common.field(
nd_array_implementation.asarray(lst, dtype=nd_array_implementation.float32),
nd_array_implementation.asarray(lst, dtype=dtype),
domain=((common.Dimension("foo"), common.UnitRange(0, len(lst))),),
)

Expand All @@ -65,6 +67,37 @@ def test_math_function_builtins(builtin_name: str, inputs, nd_array_implementati
assert np.allclose(result.ndarray, expected)


def test_where_builtin(nd_array_implementation):
cond = np.asarray([True, False])
true_ = np.asarray([1.0, 2.0], dtype=np.float32)
false_ = np.asarray([3.0, 4.0], dtype=np.float32)

field_inputs = [_make_field(inp, nd_array_implementation) for inp in [cond, true_, false_]]
expected = np.where(cond, true_, false_)

result = fbuiltins.where(*field_inputs)
assert np.allclose(result.ndarray, expected)


def test_where_builtin(nd_array_implementation):
cond = np.asarray([True, False])
true0 = np.asarray([1.0, 2.0], dtype=np.float32)
false0 = np.asarray([3.0, 4.0], dtype=np.float32)
true1 = np.asarray([11.0, 12.0], dtype=np.float32)
false1 = np.asarray([13.0, 14.0], dtype=np.float32)

expected0 = np.where(cond, true0, false0)
expected1 = np.where(cond, true1, false1)

cond_field = _make_field(cond, nd_array_implementation, dtype=bool)
field_true = tuple(_make_field(inp, nd_array_implementation) for inp in [true0, true1])
field_false = tuple(_make_field(inp, nd_array_implementation) for inp in [false0, false1])

result = fbuiltins.where(cond_field, field_true, field_false)
assert np.allclose(result[0].ndarray, expected0)
assert np.allclose(result[1].ndarray, expected1)


def test_binary_ops(binary_op, nd_array_implementation):
inp_a = [-1.0, 4.2, 42]
inp_b = [2.0, 3.0, -3.0]
Expand Down