Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature[next]: Extend Single Static Assignment (SSA) pass to support if statements #1250

Merged
merged 6 commits into from
May 9, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 185 additions & 20 deletions src/gt4py/next/ffront/ast_passes/single_static_assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,105 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later
from __future__ import annotations

import ast
import typing

from gt4py.next.ffront.fbuiltins import TYPE_BUILTIN_NAMES


def _make_assign(target: str, source: str, location_node: ast.AST):
    """
    Build the statement ``<target> = <source>`` from two plain variable names.

    ``ast.copy_location`` is applied to every node of the new statement with
    ``location_node`` as the origin, so the synthetic assignment carries the
    source position of already existing code.
    """
    lhs = ast.Name(ctx=ast.Store(), id=target)
    rhs = ast.Name(ctx=ast.Load(), id=source)
    assignment = ast.Assign(targets=[lhs], value=rhs)
    for piece in ast.walk(assignment):
        ast.copy_location(piece, location_node)
    return assignment


def is_guaranteed_to_return(node: ast.stmt | list[ast.stmt]) -> bool:
    """
    Tell whether ``node`` is certain to hit a ``return`` statement.

    * A ``return`` statement trivially does.
    * An ``if`` statement does only when both its body and its ``orelse`` do.
    * A list of statements does as soon as one of its statements does.
    * Every other kind of statement is treated as not returning.
    """
    if isinstance(node, ast.Return):
        return True

    if isinstance(node, ast.If):
        if not is_guaranteed_to_return(node.body):
            return False
        return is_guaranteed_to_return(node.orelse)

    if isinstance(node, list):
        for stmt in node:
            if is_guaranteed_to_return(stmt):
                return True

    return False


class Versioning:
    """
    Keep track of which names are defined and which version each of them has.

    A name is in one of three states:

    * unknown: neither :meth:`define` nor :meth:`assign` was called for it,
    * defined but unversioned: stored with the version ``None``,
    * versioned: stored with an ``int`` version, counted up from ``0``.
    """

    # invariant: a version that is an `int` is never negative
    _versions: dict[str, None | int]

    def __init__(self):
        self._versions = {}

    def define(self, name: str) -> None:
        """Mark ``name`` as defined; a name that is already known keeps its version."""
        self._versions.setdefault(name, None)

    def assign(self, name: str) -> None:
        """Record an assignment: the first one yields version ``0``, later ones add one."""
        current = self._versions.get(name)
        self._versions[name] = 0 if current is None else current + 1

    def is_defined(self, name: str) -> bool:
        """Return whether ``name`` is known at all (versioned or not)."""
        return name in self._versions

    def is_versioned(self, name: str) -> bool:
        """Return whether ``name`` is known and carries an ``int`` version."""
        return self._versions.get(name) is not None

    def __getitem__(self, name: str) -> None | int:
        """Return the stored version of ``name``; unknown names raise ``KeyError``."""
        return self._versions[name]

    def __iter__(self) -> typing.Iterator[tuple[str, None | int]]:
        """Iterate over ``(name, version)`` pairs."""
        return iter(self._versions.items())

    def copy(self) -> Versioning:
        """Return an independent instance holding the same name/version pairs."""
        duplicate = Versioning()
        duplicate._versions = dict(self._versions)
        return duplicate

    @staticmethod
    def merge(a: Versioning, b: Versioning) -> Versioning:
        """
        Combine two versionings into a new one.

        Only names known to both ``a`` and ``b`` are kept; the version of each
        kept name is decided by :meth:`_merge_versions`.
        """
        left, right = a._versions, b._versions
        shared_names = set(left.keys()) & set(right.keys())

        result = Versioning()
        for name in shared_names:
            result._versions[name] = Versioning._merge_versions(left[name], right[name])
        return result

    @staticmethod
    def _merge_versions(a: None | int, b: None | int) -> None | int:
        """Return the larger of the two versions, where ``None`` loses against any ``int``."""
        present = [version for version in (a, b) if version is not None]
        if not present:
            return None
        return max(present)


class NameEncoder:
    """Produce the identifier under which a (possibly versioned) variable is stored."""

    _separator: str

    def __init__(self, separator: str):
        self._separator = separator

    def encode_name(self, name: str, versions: Versioning) -> str:
        """
        Return ``<name><separator><version>`` when ``versions`` has a version for ``name``.

        A name that is not versioned is returned unchanged.
        """
        if not versions.is_versioned(name):
            return name
        version = versions[name]
        return f"{name}{self._separator}{version}"


class SingleStaticAssignPass(ast.NodeTransformer):
"""
Rename variables in assignments to avoid overwriting.
Expand Down Expand Up @@ -48,11 +141,16 @@ def foo():
a__2 = 3 + a__1
return a__2

Note that each variable name is assigned only once and never updated / overwritten.
Note that each variable name is assigned only once (per branch) and never updated / overwritten.

Note also that after parsing, running the pass and unparsing we get invalid but
readable python code. This is ok because this pass is not intended for
python-to-python translation.

WARNING: This pass is not intended as a general-purpose SSA transformation.
The pass does not support any general Python AST. Known limitations include:
* Nested functions aren't supported
* While loops aren't supported
"""

class RhsRenamer(ast.NodeTransformer):
Expand All @@ -63,17 +161,16 @@ class RhsRenamer(ast.NodeTransformer):
"""

@classmethod
def apply(cls, name_counter, separator, node):
return cls(name_counter, separator).visit(node)
def apply(cls, versioning: Versioning, name_encoder: NameEncoder, node: ast.AST):
return cls(versioning, name_encoder).visit(node)

def __init__(self, name_counter, separator):
def __init__(self, versioning: Versioning, name_encoder: NameEncoder):
super().__init__()
self.name_counter: dict[str, int] = name_counter
self.separator: str = separator
self.versioning: Versioning = versioning
self.name_encoder: NameEncoder = name_encoder

def visit_Name(self, node: ast.Name) -> ast.Name:
if node.id in self.name_counter:
node.id = f"{node.id}{self.separator}{self.name_counter[node.id]}"
node.id = self.name_encoder.encode_name(node.id, self.versioning)
return node

@classmethod
Expand All @@ -82,11 +179,78 @@ def apply(cls, node: ast.AST) -> ast.AST:

def __init__(self, separator="__"):
super().__init__()
self.name_counter: dict[str, int] = {}
self.separator: str = separator
self.versioning: Versioning = Versioning()
self.name_encoder: NameEncoder = NameEncoder(separator)

def _rename(self, node: ast.AST):
    """Run ``RhsRenamer`` over ``node`` with the pass's current versioning and name encoder."""
    renamer = self.RhsRenamer
    return renamer.apply(self.versioning, self.name_encoder, node)

def visit_FunctionDef(self, node: ast.FunctionDef):
# For practical purposes, this is sufficient, but really not general at all.
# However, the algorithm was never intended to be general.

old_versioning = self.versioning.copy()

def _rename(self, node):
return self.RhsRenamer.apply(self.name_counter, self.separator, node)
for arg in node.args.args:
self.versioning.define(arg.arg)

node.body = [self.visit(stmt) for stmt in node.body]

self.versioning = old_versioning
return node

def visit_If(self, node: ast.If) -> ast.If:
    """
    Version the names of both branches of an ``if`` and reconcile them afterwards.

    The condition is renamed with the versioning valid before the statement.
    Each branch is then visited starting from its own copy of that versioning.
    What is valid after the statement depends on which branches are
    guaranteed to return:

    * exactly one branch returns: the versioning of the other branch is used,
    * both return: an empty versioning is used,
    * neither returns: the two versionings are merged, and the branch whose
      version of a name differs from the merged one gets an extra assignment
      appended that copies its latest name into the merged name.
    """
    incoming = self.versioning

    node.test = self._rename(node.test)

    # visit each branch on a private copy so that neither sees the other's assignments
    self.versioning = incoming.copy()
    node.body = [self.visit(stmt) for stmt in node.body]
    body_versioning = self.versioning
    body_returns = is_guaranteed_to_return(node.body)

    self.versioning = incoming.copy()
    node.orelse = [self.visit(stmt) for stmt in node.orelse]
    orelse_versioning = self.versioning
    orelse_returns = is_guaranteed_to_return(node.orelse)

    if body_returns or orelse_returns:
        if body_returns and orelse_returns:
            self.versioning = Versioning()
        elif body_returns:
            self.versioning = orelse_versioning
        else:
            self.versioning = body_versioning
        return node

    merged = Versioning.merge(body_versioning, orelse_versioning)
    self.versioning = merged
    encode = self.name_encoder.encode_name

    # ensure both branches conclude with the same unique names
    for name, merged_version in merged:
        if body_versioning[name] != merged_version:
            branch, branch_versioning = node.body, body_versioning
        elif orelse_versioning[name] != merged_version:
            branch, branch_versioning = node.orelse, orelse_versioning
        else:
            continue
        branch.append(_make_assign(encode(name, merged), encode(name, branch_versioning), node))

    return node

def visit_Assign(self, node: ast.Assign) -> ast.Assign:
# first update rhs names to reference the latest version
Expand All @@ -104,18 +268,19 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign:
node.value = self._rename(node.value)
node.target = self.visit(node.target)
elif isinstance(node.target, ast.Name):
target_id = node.target.id
# A bare annotation (e.g. `a: int`, without a value) applies to the *next*
# assignment of that name. Visiting the target bumps the version, so the
# annotated name is the one the next assignment will receive. Since no
# assignment actually happens here, the previous versioning is restored
# afterwards.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't fully understand this comment, could you explain it to me?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't fully understand the details, but this is referring to this case (test_empty_annotated_assign):

a = 0
a: int
b = a

If you want, I can take a closer look.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I understood the same core idea but don't really get the details. I think we should understand what the comment really means and rephrase it in a more comprehensible way before merging the PR; otherwise it is useless.

old_versioning = self.versioning.copy()
node.target = self.visit(node.target)
self.name_counter[target_id] -= 1
self.versioning = old_versioning
return node

def visit_Name(self, node: ast.Name) -> ast.Name:
if node.id in TYPE_BUILTIN_NAMES:
return node
elif node.id in self.name_counter:
self.name_counter[node.id] += 1
node.id = f"{node.id}{self.separator}{self.name_counter[node.id]}"
else:
self.name_counter[node.id] = 0
node.id = f"{node.id}{self.separator}0"

self.versioning.assign(node.id)
node.id = self.name_encoder.encode_name(node.id, self.versioning)
return node
Loading