-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat[next]: add where
to embedded field view
#1316
Conversation
return _builtin_binary_op | ||
domain_intersection = functools.reduce( | ||
operator.and_, | ||
[f.domain for f in fields if common.is_field(f)], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[f.domain for f in fields if common.is_field(f)], | |
[f.domain for f in fields if not is_scalar(f)], |
I see, intuitively I would say it should ignore scalars then and not just take everything that is a field. Both for readability and robustness.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure, typing says f
is Field or Scalar, and of the 2 we select the one which has .domain
, i.e. is_field
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is technically true (and a valid viewpoint). My perspective is the following: The resulting domain is the intersection of all input domains. For scalars then there is a (conceptual) broadcast to a zero-dimensional field. So for me the domain intersection is everything, but scalars. When I read this I intermediately thought why do only fields participate in the computation of the domain (and not all inputs) and my answer was because they are not scalars. There is an assert below that ensures the elements are either a field or a scalar so I am fine with this.
if f.domain == domain_intersection: | ||
transformed.append(xp.asarray(f.ndarray)) | ||
else: | ||
f_broadcasted = _broadcast(f, domain_intersection.dims) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like the broadcasted array is materialized here. That seems unnecessary overhead. Can we at least add a todo to clean this up?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How would you do that? Maybe I miss something obvious.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using np.newaxis
which is apparently what is being done in _broadcast
... I'm a little skeptical that f_broadcasted.ndarray[f_slices]
works properly in all cases, but we'll see when you write the tests :-P
assert np.allclose(result.ndarray, expected) | ||
|
||
|
||
def test_where_builtin_with_tuple(nd_array_implementation): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think a test where the true and false branches are on different domains (different sizes to test intersection, and different dimensions to test the broadcast) is needed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good point, the pr is older than that feature 🙃
src/gt4py/next/ffront/fbuiltins.py
Outdated
def where_builtin_function( | ||
fun: Callable[[MaskT, FieldT, FieldT], _R] | ||
) -> WhereBuiltinFunction[_R, MaskT, FieldT]: | ||
return WhereBuiltinFunction(fun) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this function really needed? Wouldn't work exactly the same using the WhereBuiltinFunction
class as a decorator (which would call the constructor anyway)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
right...
where
with tuples