Skip to content

Commit

Permalink
bug[next]: itir.embedded: fix shift inside scan pass (#1280)
Browse files Browse the repository at this point in the history
Shifts were ignored in the `ScanArgIterator`, i.e. all shifts in the scan_pass were ignored. This partially fixes the `test_icon_like_scan`.

Additional:
- Run `InlineIntoScan` after temporary creation to inline remaining operations that block bubbling scans to the top.
  • Loading branch information
tehrengruber authored Jul 4, 2023
1 parent f0bc756 commit a61efd0
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 26 deletions.
3 changes: 1 addition & 2 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -1185,7 +1185,6 @@ class ColumnDescriptor:
class ScanArgIterator:
wrapped_iter: ItIterator
k_pos: int
offsets: Sequence[OffsetPart] = dataclasses.field(default_factory=list, kw_only=True)

def deref(self) -> Any:
if not self.can_deref():
Expand All @@ -1196,7 +1195,7 @@ def can_deref(self) -> bool:
return self.wrapped_iter.can_deref()

def shift(self, *offsets: OffsetPart) -> ScanArgIterator:
return ScanArgIterator(self.wrapped_iter, self.k_pos, offsets=[*offsets, *self.offsets])
return ScanArgIterator(self.wrapped_iter.shift(*offsets), self.k_pos)


def shifted_scan_arg(k_pos: int) -> Callable[[ItIterator], ScanArgIterator]:
Expand Down
27 changes: 18 additions & 9 deletions src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,19 @@ def _inline_lifts(ir, lift_mode):
return ir


def _inline_into_scan(ir, *, max_iter=10):
    """Repeatedly inline into scans until the IR no longer changes.

    Each round applies ``InlineIntoScan`` followed by ``InlineLambdas`` (with
    ``opcount_preserving=True`` and ``force_inline_lift=True``) and compares
    the result with the input of that round.

    Args:
        ir: The IR node to transform.
        max_iter: Upper bound on the number of rounds to attempt.

    Returns:
        The IR once a round leaves it unchanged.

    Raises:
        RuntimeError: If no round leaves the IR unchanged within ``max_iter``
            rounds (this includes ``max_iter <= 0``, where no round is run).
    """
    # Honour the caller-supplied bound; it is also the number reported in the
    # error message below, so the two must agree.
    for _ in range(max_iter):
        # in case there are multiple levels of lambdas around the scan we have to do multiple iterations
        inlined = InlineIntoScan().visit(ir)
        inlined = InlineLambdas.apply(inlined, opcount_preserving=True, force_inline_lift=True)
        if inlined == ir:
            break
        ir = inlined
    else:
        # Only reached when the loop ran out of rounds without hitting `break`.
        raise RuntimeError(f"Inlining into scan did not converge with {max_iter} iterations.")
    return ir


def apply_common_transforms(
ir: ir.Node,
*,
Expand Down Expand Up @@ -87,15 +100,7 @@ def apply_common_transforms(

if lift_mode == LiftMode.FORCE_INLINE:
ir = CollapseTuple.apply(ir, ignore_tuple_size=unconditionally_collapse_tuples)
for _ in range(10):
# in case there are multiple levels of lambdas around the scan we have to do multiple iterations
inlined = InlineIntoScan().visit(ir)
inlined = InlineLambdas.apply(inlined, opcount_preserving=True, force_inline_lift=True)
if inlined == ir:
break
ir = inlined
else:
raise RuntimeError("Inlining into scan did not converge.")
ir = _inline_into_scan(ir)

ir = NormalizeShifts().visit(ir)

Expand All @@ -118,6 +123,10 @@ def apply_common_transforms(
assert offset_provider is not None
ir = CreateGlobalTmps().visit(ir, offset_provider=offset_provider)
ir = InlineLifts().visit(ir)
# If after creating temporaries, the scan is not at the top, we inline.
# The following example doesn't have a lift around the shift, i.e. the temporary pass will not extract it.
# λ(inp) → scan(λ(state, k, kp) → state + ·k + ·kp, True, 0.0)(inp, ⟪Koffₒ, 1ₒ⟫(inp))
ir = _inline_into_scan(ir)

ir = EtaReduction().visit(ir)
ir = ScanEtaReduction().visit(ir)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,31 @@ def testee_op(
cases.verify(cartesian_case, testee_op, qc, tuple_scalar, out=qc, ref=expected)


def test_scalar_scan_vertical_offset(cartesian_case):  # noqa: F811 # fixtures
    """Run a scan operator on a ``Koff[1]``-shifted field and compare with a reference.

    The scan simply returns its input, the input field is allocated with one
    extra level at the upper end of ``KDim`` and initialised via
    ``UniqueInitializer(start=2)``, and the reference is the float range
    ``3 .. ksize + 2``.
    """

    @gtx.scan_operator(axis=KDim, forward=True, init=(0.0))
    def testee_scan(state: float, inp: float) -> float:
        return inp

    @gtx.field_operator
    def testee(inp: gtx.Field[[KDim], float]) -> gtx.Field[[KDim], float]:
        return testee_scan(inp(Koff[1]))

    # Input: extended by one level in K so that the shifted access stays in bounds.
    input_allocator = cases.allocate(
        cartesian_case,
        testee,
        "inp",
        extend={KDim: (0, 1)},
        strategy=cases.UniqueInitializer(start=2),
    )
    shifted_input = input_allocator()

    # Output: allocated from the same parameter description, zero-initialised.
    output_allocator = cases.allocate(cartesian_case, testee, "inp").zeros()
    result_field = output_allocator()

    ksize = cartesian_case.default_sizes[KDim]
    shifted_values = np.arange(start=3, stop=ksize + 3, step=1).astype(float64)
    reference = np.full((ksize), shifted_values)

    cases.run(cartesian_case, testee, shifted_input, out=result_field)

    cases.verify(cartesian_case, testee, shifted_input, out=result_field, ref=reference)


def test_astype_int(cartesian_case): # noqa: F811 # fixtures
@gtx.field_operator
def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ def test_scan_in_stencil(program_processor, lift_mode):
isize = 1
ksize = 3
Koff = offset("Koff")
inp = gtx.np_as_located_field(IDim, KDim)(np.ones((isize, ksize)))
inp = gtx.np_as_located_field(IDim, KDim)(
np.copy(np.broadcast_to(np.arange(0, ksize), (isize, ksize)))
)
out = gtx.np_as_located_field(IDim, KDim)(np.zeros((isize, ksize)))

reference = np.zeros((isize, ksize - 1))
Expand All @@ -40,13 +42,9 @@ def test_scan_in_stencil(program_processor, lift_mode):
def sum(state, k, kp):
return state + deref(k) + deref(kp)

@fundef
def shifted(inp):
return deref(shift(Koff, 1)(inp))

@fundef
def wrapped(inp):
return scan(sum, True, 0.0)(inp, lift(shifted)(inp))
return scan(sum, True, 0.0)(inp, shift(Koff, 1)(inp))

run_processor(
wrapped[cartesian_domain(named_range(IDim, 0, isize), named_range(KDim, 0, ksize - 1))],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,6 @@ class setup:
def test_solve_nonhydro_stencil_52_like_z_q(test_setup, fieldview_backend):
if fieldview_backend in [gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative]:
pytest.xfail("Needs implementation of scan projector.")
if fieldview_backend == roundtrip.executor:
pytest.xfail("Inline into scan breaks embedded execution.")

solve_nonhydro_stencil_52_like_z_q.with_backend(fieldview_backend)(
test_setup.z_alpha,
Expand All @@ -231,9 +229,7 @@ def test_solve_nonhydro_stencil_52_like_z_q(test_setup, fieldview_backend):

def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup, fieldview_backend):
if fieldview_backend == roundtrip.executor:
pytest.xfail(
"Inline into scan breaks embedded execution and relies on CollapseTuple ignore_tuple_size==True."
)
pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].")

solve_nonhydro_stencil_52_like_z_q_tup.with_backend(fieldview_backend)(
test_setup.z_alpha,
Expand All @@ -248,9 +244,6 @@ def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup, fieldview_backend):


def test_solve_nonhydro_stencil_52_like(test_setup, fieldview_backend):
if fieldview_backend == roundtrip.executor:
pytest.xfail("Inline into scan breaks embedded execution.")

solve_nonhydro_stencil_52_like.with_backend(fieldview_backend)(
test_setup.z_alpha,
test_setup.z_beta,
Expand All @@ -266,7 +259,7 @@ def test_solve_nonhydro_stencil_52_like(test_setup, fieldview_backend):

def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup, fieldview_backend):
if fieldview_backend == roundtrip.executor:
pytest.xfail("Only working in gtfn with CollapseTuple ignore_tuple_size==True.")
pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].")

solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge.with_backend(fieldview_backend)(
test_setup.z_alpha,
Expand Down

0 comments on commit a61efd0

Please sign in to comment.