Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: add where to embedded field view #1316

Merged
merged 11 commits into from
Nov 16, 2023

Conversation

havogt
Copy link
Contributor

@havogt havogt commented Aug 10, 2023

  • unifies unary and binary builtin in NdArrayField
  • special case for where with tuples

@havogt havogt requested a review from tehrengruber August 21, 2023 12:46
return _builtin_binary_op
domain_intersection = functools.reduce(
operator.and_,
[f.domain for f in fields if common.is_field(f)],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
[f.domain for f in fields if common.is_field(f)],
[f.domain for f in fields if not is_scalar(f)],

I see, intuitively I would say it should ignore scalars then and not just take everything that is a field. Both for readability and robustness.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure, typing says f is Field or Scalar, and of the 2 we select the one which has .domain, i.e. is_field.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is technically true (and a valid viewpoint). My perspective is the following: The resulting domain is the intersection of all input domains. For scalars then there is a (conceptual) broadcast to a zero-dimensional field. So for me the domain intersection is everything, but scalars. When I read this I intermediately thought why do only fields participate in the computation of the domain (and not all inputs) and my answer was because they are not scalars. There is an assert below that ensures the elements are either a field or a scalar so I am fine with this.

if f.domain == domain_intersection:
transformed.append(xp.asarray(f.ndarray))
else:
f_broadcasted = _broadcast(f, domain_intersection.dims)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the broadcasted array is materialized here. That seems unnecessary overhead. Can we at least add a todo to clean this up?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How would you do that? Maybe I miss something obvious.

Copy link
Contributor

@tehrengruber tehrengruber Nov 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using np.newaxis which is apparently what is being done in _broadcast... I'm a little skeptical that f_broadcasted.ndarray[f_slices] works properly in all cases, but we'll see when you write the tests :-P

assert np.allclose(result.ndarray, expected)


def test_where_builtin_with_tuple(nd_array_implementation):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a test where the true and false branches are on different domains (different sizes to test intersection, and different dimensions to test the broadcast) is needed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point, the pr is older than that feature 🙃

Comment on lines 154 to 157
def where_builtin_function(
fun: Callable[[MaskT, FieldT, FieldT], _R]
) -> WhereBuiltinFunction[_R, MaskT, FieldT]:
return WhereBuiltinFunction(fun)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this function really needed? Wouldn't work exactly the same using the WhereBuiltinFunction class as a decorator (which would call the constructor anyway)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right...

@havogt havogt requested a review from tehrengruber November 16, 2023 09:09
@havogt havogt merged commit b8cda74 into GridTools:main Nov 16, 2023
@havogt havogt deleted the field_view_where branch November 16, 2023 10:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants