-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
bug[next]: itir.embedded: fix shift inside scan pass #1280
Changes from 3 commits
16990ec
d4160bb
12a6ae9
1fea44a
ccaf485
04efcff
e91b4be
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1185,7 +1185,6 @@ class ColumnDescriptor: | |
class ScanArgIterator: | ||
wrapped_iter: ItIterator | ||
k_pos: int | ||
offsets: Sequence[OffsetPart] = dataclasses.field(default_factory=list, kw_only=True) | ||
|
||
def deref(self) -> Any: | ||
if not self.can_deref(): | ||
|
@@ -1196,7 +1195,7 @@ def can_deref(self) -> bool: | |
return self.wrapped_iter.can_deref() | ||
|
||
def shift(self, *offsets: OffsetPart) -> ScanArgIterator: | ||
return ScanArgIterator(self.wrapped_iter, self.k_pos, offsets=[*offsets, *self.offsets]) | ||
return ScanArgIterator(self.wrapped_iter.shift(*offsets), self.k_pos) | ||
|
||
|
||
def shifted_scan_arg(k_pos: int) -> Callable[[ItIterator], ScanArgIterator]: | ||
|
@@ -1392,7 +1391,12 @@ def _closure_runner(): | |
col_pos[column.axis] = k | ||
assert _is_concrete_position(col_pos) | ||
ordered_indices = get_ordered_indices(out.axes, col_pos) | ||
out.field_setitem(ordered_indices, res[k]) | ||
if isinstance(res, tuple): | ||
out.field_setitem( | ||
ordered_indices, tuple(res[i][k] for i in range(len(res))) | ||
) # TODO(tehrengruber): only works for scalars | ||
else: | ||
out.field_setitem(ordered_indices, res[k]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this fixed in https://github.com/GridTools/gt4py/pull/1141/files? or is it another instance of that problem? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure actually, might be two different things. This is not clean however as it only works for tuples not nested tuples. |
||
|
||
ctx = cvars.copy_context() | ||
ctx.run(_closure_runner) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,7 +28,9 @@ def test_scan_in_stencil(program_processor, lift_mode): | |
isize = 1 | ||
ksize = 3 | ||
Koff = offset("Koff") | ||
inp = gtx.np_as_located_field(IDim, KDim)(np.ones((isize, ksize))) | ||
inp = gtx.np_as_located_field(IDim, KDim)( | ||
np.copy(np.broadcast_to(np.arange(0, ksize), (isize, ksize))) | ||
) | ||
out = gtx.np_as_located_field(IDim, KDim)(np.zeros((isize, ksize))) | ||
|
||
reference = np.zeros((isize, ksize - 1)) | ||
|
@@ -40,13 +42,9 @@ def test_scan_in_stencil(program_processor, lift_mode): | |
def sum(state, k, kp): | ||
return state + deref(k) + deref(kp) | ||
|
||
@fundef | ||
def shifted(inp): | ||
return deref(shift(Koff, 1)(inp)) | ||
|
||
@fundef | ||
def wrapped(inp): | ||
return scan(sum, True, 0.0)(inp, lift(shifted)(inp)) | ||
return scan(sum, True, 0.0)(inp, shift(Koff, 1)(inp)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change breaks something else, if we undo we could merge. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I addressed this issue by adding another inline_into_scan pass after temporaries are created There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We concluded that this is fine for now even though we are not sure whether the ordering of the passes makes sense anymore. This will be addressed as a whole in another PR. |
||
|
||
run_processor( | ||
wrapped[cartesian_domain(named_range(IDim, 0, isize), named_range(KDim, 0, ksize - 1))], | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the important part.