Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug[next]: itir.embedded: fix shift inside scan pass #1280

Merged
merged 7 commits into from
Jul 4, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -1185,7 +1185,6 @@ class ColumnDescriptor:
class ScanArgIterator:
wrapped_iter: ItIterator
k_pos: int
offsets: Sequence[OffsetPart] = dataclasses.field(default_factory=list, kw_only=True)

def deref(self) -> Any:
if not self.can_deref():
Expand All @@ -1196,7 +1195,7 @@ def can_deref(self) -> bool:
return self.wrapped_iter.can_deref()

def shift(self, *offsets: OffsetPart) -> ScanArgIterator:
return ScanArgIterator(self.wrapped_iter, self.k_pos, offsets=[*offsets, *self.offsets])
return ScanArgIterator(self.wrapped_iter.shift(*offsets), self.k_pos)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the important part.



def shifted_scan_arg(k_pos: int) -> Callable[[ItIterator], ScanArgIterator]:
Expand Down Expand Up @@ -1392,7 +1391,12 @@ def _closure_runner():
col_pos[column.axis] = k
assert _is_concrete_position(col_pos)
ordered_indices = get_ordered_indices(out.axes, col_pos)
out.field_setitem(ordered_indices, res[k])
if isinstance(res, tuple):
out.field_setitem(
ordered_indices, tuple(res[i][k] for i in range(len(res)))
) # TODO(tehrengruber): only works for scalars
else:
out.field_setitem(ordered_indices, res[k])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this fixed in https://github.com/GridTools/gt4py/pull/1141/files? or is it another instance of that problem?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure actually, might be two different things. This is not clean however as it only works for tuples not nested tuples.


ctx = cvars.copy_context()
ctx.run(_closure_runner)
Expand Down
27 changes: 18 additions & 9 deletions src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,19 @@ def _inline_lifts(ir, lift_mode):
return ir


def _inline_into_scan(ir, *, max_iter=10):
for _ in range(10):
# in case there are multiple levels of lambdas around the scan we have to do multiple iterations
inlined = InlineIntoScan().visit(ir)
inlined = InlineLambdas.apply(inlined, opcount_preserving=True, force_inline_lift=True)
if inlined == ir:
break
ir = inlined
else:
raise RuntimeError(f"Inlining into scan did not converge with {max_iter} iterations.")
return ir


def apply_common_transforms(
ir: ir.Node,
*,
Expand Down Expand Up @@ -87,15 +100,7 @@ def apply_common_transforms(

if lift_mode == LiftMode.FORCE_INLINE:
ir = CollapseTuple.apply(ir, ignore_tuple_size=unconditionally_collapse_tuples)
for _ in range(10):
# in case there are multiple levels of lambdas around the scan we have to do multiple iterations
inlined = InlineIntoScan().visit(ir)
inlined = InlineLambdas.apply(inlined, opcount_preserving=True, force_inline_lift=True)
if inlined == ir:
break
ir = inlined
else:
raise RuntimeError("Inlining into scan did not converge.")
ir = _inline_into_scan(ir)

ir = NormalizeShifts().visit(ir)

Expand All @@ -118,6 +123,10 @@ def apply_common_transforms(
assert offset_provider is not None
ir = CreateGlobalTmps().visit(ir, offset_provider=offset_provider)
ir = InlineLifts().visit(ir)
# If after creating temporaries, the scan is not at the top, we inline.
# The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it.
# λ(inp) → scan(λ(state, k, kp) → state + ·k + ·kp, True, 0.0)(inp, ⟪Koffₒ, 1ₒ⟫(inp))`
ir = _inline_into_scan(ir)

ir = EtaReduction().visit(ir)
ir = ScanEtaReduction().visit(ir)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,31 @@ def testee_op(
cases.verify(cartesian_case, testee_op, qc, tuple_scalar, out=qc, ref=expected)


def test_scalar_scan_vertical_offset(cartesian_case): # noqa: F811 # fixtures
@gtx.scan_operator(axis=KDim, forward=True, init=(0.0))
def testee_scan(state: float, inp: float) -> float:
return inp

@gtx.field_operator
def testee(inp: gtx.Field[[KDim], float]) -> gtx.Field[[KDim], float]:
return testee_scan(inp(Koff[1]))

inp = cases.allocate(
cartesian_case,
testee,
"inp",
extend={KDim: (0, 1)},
strategy=cases.UniqueInitializer(start=1),
)()
out = cases.allocate(cartesian_case, testee, "inp").zeros()()
ksize = cartesian_case.default_sizes[KDim]
expected = np.full((ksize), np.arange(start=1, stop=ksize + 1, step=1).astype(float64))

cases.run(cartesian_case, testee, inp, out=out)

cases.verify(cartesian_case, testee, inp, out=out, ref=expected)


def test_astype_int(cartesian_case): # noqa: F811 # fixtures
@gtx.field_operator
def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ def test_scan_in_stencil(program_processor, lift_mode):
isize = 1
ksize = 3
Koff = offset("Koff")
inp = gtx.np_as_located_field(IDim, KDim)(np.ones((isize, ksize)))
inp = gtx.np_as_located_field(IDim, KDim)(
np.copy(np.broadcast_to(np.arange(0, ksize), (isize, ksize)))
)
out = gtx.np_as_located_field(IDim, KDim)(np.zeros((isize, ksize)))

reference = np.zeros((isize, ksize - 1))
Expand All @@ -40,13 +42,9 @@ def test_scan_in_stencil(program_processor, lift_mode):
def sum(state, k, kp):
return state + deref(k) + deref(kp)

@fundef
def shifted(inp):
return deref(shift(Koff, 1)(inp))

@fundef
def wrapped(inp):
return scan(sum, True, 0.0)(inp, lift(shifted)(inp))
return scan(sum, True, 0.0)(inp, shift(Koff, 1)(inp))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change breaks something else, if we undo we could merge.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I addressed this issue by adding another inline_into_scan pass after temporaries are created

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We concluded that this is fine for now even though we are not sure whether the ordering of the passes makes sense anymore. This will be addressed as a whole in another PR.


run_processor(
wrapped[cartesian_domain(named_range(IDim, 0, isize), named_range(KDim, 0, ksize - 1))],
Expand Down
13 changes: 0 additions & 13 deletions tests/next_tests/unit_tests/ffront_tests/test_icon_like_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,6 @@ class setup:
def test_solve_nonhydro_stencil_52_like_z_q(test_setup, fieldview_backend):
if fieldview_backend in [gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative]:
pytest.xfail("Needs implementation of scan projector.")
if fieldview_backend == roundtrip.executor:
pytest.xfail("Inline into scan breaks embedded execution.")

solve_nonhydro_stencil_52_like_z_q.with_backend(fieldview_backend)(
test_setup.z_alpha,
Expand All @@ -230,11 +228,6 @@ def test_solve_nonhydro_stencil_52_like_z_q(test_setup, fieldview_backend):


def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup, fieldview_backend):
if fieldview_backend == roundtrip.executor:
pytest.xfail(
"Inline into scan breaks embedded execution and relies on CollapseTuple ignore_tuple_size==True."
)

solve_nonhydro_stencil_52_like_z_q_tup.with_backend(fieldview_backend)(
test_setup.z_alpha,
test_setup.z_beta,
Expand All @@ -248,9 +241,6 @@ def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup, fieldview_backend):


def test_solve_nonhydro_stencil_52_like(test_setup, fieldview_backend):
if fieldview_backend == roundtrip.executor:
pytest.xfail("Inline into scan breaks embedded execution.")

solve_nonhydro_stencil_52_like.with_backend(fieldview_backend)(
test_setup.z_alpha,
test_setup.z_beta,
Expand All @@ -265,9 +255,6 @@ def test_solve_nonhydro_stencil_52_like(test_setup, fieldview_backend):


def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup, fieldview_backend):
if fieldview_backend == roundtrip.executor:
pytest.xfail("Only working in gtfn with CollapseTuple ignore_tuple_size==True.")

solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge.with_backend(fieldview_backend)(
test_setup.z_alpha,
test_setup.z_beta,
Expand Down