Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Torch Backend in Transpiler #28860

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
# global
import ivy


# local
import gast
from ...configurations.base_transformer_config import (
BaseTransformerConfig,
)
from ...transformer import Transformer
from ....utils.ast_utils import (
ast_to_source_code,
)
from ....utils.api_utils import (
get_native_array_str_from_backend,
get_native_module_str_from_backend,
)
from .ivy_postprocessing_transformer import (
IvyCodePostProcessor,
)


class IvyToTorchCodePostProcessor(IvyCodePostProcessor):
    """
    Perform post-processing for the PyTorch backend.

    Rewrites ivy-specific constructs left in the transpiled AST
    (``ivy.Array``, ``ivy.Variable``, ``ivy.Module``, isinstance checks,
    and inplace-update helpers) into their native ``torch`` equivalents.
    """

    def __init__(
        self,
        root,
        transformer: Transformer,
        configuration: BaseTransformerConfig,
        new_name="tensor",
    ) -> None:
        super().__init__(root, transformer, configuration, new_name=new_name)
        # NOTE(review): these are presumably also set by the base-class
        # __init__ (they are forwarded to super above); re-assigning here is
        # harmless and keeps the subclass self-contained — confirm and drop
        # if redundant.
        self.root = root
        self.transformer = transformer
        self.configuration = configuration

    def _handle_ivy_array(self, node):
        """Replace an ``ivy.Array`` reference with the backend's native array type."""
        new_name = get_native_array_str_from_backend(ivy.backend)
        return gast.parse(f"{ivy.backend}.{new_name}").body[0].value

    def _handle_ivy_variable(self, node):
        """Replace an ``ivy.Variable`` reference with ``torch.nn.Parameter``."""
        return gast.parse("torch.nn.Parameter").body[0].value

    def _handle_ivy_module(self, node):
        """Replace an ``ivy.Module`` reference with the flattened native module name."""
        new_name = get_native_module_str_from_backend(
            backend_str=ivy.backend,
            is_root_obj=self.transformer.object_like.is_root_obj,
            depth=self.transformer.object_like.depth,
        )
        # Flatten the dotted path into a single identifier
        # (e.g. "torch.nn.Module" -> "torch_nn_Module").
        new_name = new_name.replace(".", "_")
        return gast.parse(f"{new_name}").body[0].value

    def _handle_assign_transform(self, node):
        """
        Rewrite the RHS call of an assignment into ``torch.nn.Parameter(...)``,
        preserving the original positional and keyword arguments.
        """
        # BUGFIX: the previous implementation built
        # gast.Attribute(value=Name("torch"), attr="nn.Parameter"), which is a
        # malformed AST node — an Attribute's ``attr`` must be a plain
        # identifier, not a dotted path. Parse the dotted name instead, as the
        # sibling handlers do.
        new_func = gast.parse("torch.nn.Parameter").body[0].value
        return gast.Call(
            func=new_func,
            args=node.value.args,
            keywords=node.value.keywords,
        )

    def _transform_isinstance_check(self, node):
        """
        Widen a translated-module isinstance check to also accept native modules:

        ``isinstance(module, torch_nn_Module)`` -->
        ``isinstance(module, (torch_nn_Module, torch.nn.Module))``
        """
        node.args = [
            node.args[0],
            gast.Tuple(
                elts=[
                    node.args[1],
                    gast.parse("torch.nn.Module").body[0].value,
                ],
                ctx=gast.Load(),
            ),
        ]
        return node

    def _get_forward_name(self, node):
        """Name of the call method for torch modules."""
        return "forward"

    def _maybe_convert_device_attribute(self, node):
        """No-op: ``device`` is directly accessible on torch tensors/modules."""
        return node

    def _maybe_replace_with_native_array_calls(self, node):
        """Rewrite ``torch.Tensor(...)`` / ``Tensor(...)`` / ``ivy.Array(...)``
        constructor calls into ``torch.tensor(...)``."""
        func_str = ast_to_source_code(node.func).strip()
        if func_str in ("torch.Tensor", "Tensor", "ivy.Array"):
            new_func = gast.Attribute(
                value=gast.Name(
                    id="torch",
                    ctx=gast.Load(),
                    annotation=None,
                    type_comment=None,
                ),
                attr="tensor",
                ctx=gast.Load(),
            )
            node.func = gast.fix_missing_locations(new_func)
        return node

    def _replace_ivy_array_pattern(self, elts):
        """
        Transform the type-check argument of an isinstance call to replace any
        occurrence of the consecutive pair ``(ivy.Array, ivy.Array)`` with
        ``(torch.Tensor, torch.nn.Parameter)``. Non-matching elements are kept.
        """
        # Pattern to look for: two consecutive ivy.Array attribute nodes.
        ivy_array = lambda: gast.Attribute(
            value=gast.Name(id="ivy", ctx=gast.Load()),
            attr="Array",
            ctx=gast.Load(),
        )
        # Compare via gast.dump since AST nodes have no structural __eq__.
        pattern_dump = [gast.dump(ivy_array()), gast.dump(ivy_array())]

        replacement = [
            gast.Attribute(
                value=gast.Name(id="torch", ctx=gast.Load()),
                attr="Tensor",
                ctx=gast.Load(),
            ),
            gast.Attribute(
                value=gast.Attribute(
                    value=gast.Name(id="torch", ctx=gast.Load()),
                    attr="nn",
                    ctx=gast.Load(),
                ),
                attr="Parameter",
                ctx=gast.Load(),
            ),
        ]

        transformed_elts = []
        i = 0
        while i < len(elts):
            elts_dump = [gast.dump(node) for node in elts[i : i + 2]]
            if elts_dump == pattern_dump:
                transformed_elts.extend(replacement)
                i += 2  # Skip the matched pair
            else:
                transformed_elts.append(elts[i])
                i += 1

        return transformed_elts

    def _maybe_modify_inplace_update_fn(self, node):
        """Apply the three torch-specific rewrites to an ``inplace_update`` function."""
        if "inplace_update" in node.name:
            # Step 1: Modify the default value of keep_input_dtype to True
            self._modify_keep_input_dtype_kwarg(node)

            # Step 2: Modify assignment nodes to use val_native on the RHS
            self._modify_assignments_to_val_native(node)

            # Step 3: Replace conditional blocks with direct assignment
            self._replace_conditional_blocks_for_torch(node)

        return node

    def _modify_keep_input_dtype_kwarg(self, node):
        """Step 1: Default ``keep_input_dtype`` to True in the inplace-update signature."""
        for kwarg, default in zip(node.args.kwonlyargs, node.args.kw_defaults):
            if ast_to_source_code(kwarg).strip() == "keep_input_dtype":
                # Only rewrite an explicit constant default.
                if isinstance(default, gast.Constant):
                    default.value = True
                break

    def _modify_assignments_to_val_native(self, node):
        """Step 2: Rewrite ``x = ...`` assignments to use ``val_native`` on the RHS."""

        class AssignVisitor(gast.NodeTransformer):
            def visit_Assign(self, assign_node):
                for target in assign_node.targets:
                    if ast_to_source_code(target).strip() == "x":
                        val_native_node = gast.Name(id="val_native", ctx=gast.Load())
                        # If the RHS is a call with positional args, substitute
                        # its first argument; otherwise replace the whole RHS.
                        # BUGFIX: guard against a zero-argument call, which
                        # previously raised IndexError.
                        if (
                            isinstance(assign_node.value, gast.Call)
                            and assign_node.value.args
                        ):
                            assign_node.value.args[0] = val_native_node
                        else:
                            assign_node.value = val_native_node
                self.generic_visit(assign_node)
                return assign_node

        AssignVisitor().visit(node)

    def _replace_conditional_blocks_for_torch(self, node):
        """Step 3: Replace ivy-array-check conditionals with ``x = x_native``."""

        class IfVisitor(gast.NodeTransformer):
            def visit_If(self, if_node):
                condition_str = ast_to_source_code(if_node.test).strip()
                # Match both the backend-mangled and plain helper names.
                if (
                    "torch_is_ivy_array_bknd" in condition_str
                    or "is_ivy_array_bknd" in condition_str
                ):
                    # Replace the whole conditional with: x = x_native
                    return gast.Assign(
                        targets=[gast.Name(id="x", ctx=gast.Store())],
                        value=gast.Name(id="x_native", ctx=gast.Load()),
                        type_comment=None,
                    )
                self.generic_visit(if_node)
                return if_node

        IfVisitor().visit(node)
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@
from ...transformations.transformers.postprocessing_transformer.ivy_to_numpy_postprocessing_transformer import (
IvyToNumpyCodePostProcessor,
)
from ...transformations.transformers.postprocessing_transformer.ivy_to_torch_postprocessing_transformer import (
IvyToTorchCodePostProcessor,
)
from ...transformations.transformers.recursive_transformer.ivy_recursive_transformer import (
IvyRecurser,
)
Expand All @@ -75,6 +78,7 @@ class IvyToSourceTranslatorConfig(BaseTranslatorConfig):
IvyToTFCodePostProcessor: IvyCodePostProcessorConfig,
IvyToJAXCodePostProcessor: IvyCodePostProcessorConfig,
IvyToNumpyCodePostProcessor: IvyCodePostProcessorConfig,
IvyToTorchCodePostProcessor: IvyCodePostProcessorConfig,
}

def __init__(self, source="ivy", target="tensorflow", base_output_dir="") -> None:
Expand Down Expand Up @@ -122,6 +126,21 @@ def __init__(self, source="ivy", target="tensorflow", base_output_dir="") -> Non
PytorchToFlaxLayer,
HFPretrainedFlaxTransformer,
]
elif target == "torch":
self.transformers: List[BaseTransformer] = [
IvyNodeDeleter,
IvyDecoratorRemover,
# BaseTypeHintRemover,
BaseDocstringRemover,
# BaseTypeAnnotationRemover,
IvyMethodToFunctionConverter,
BaseDundersTransformer,
IvyCodePreProcessor,
BaseNameCanonicalizer,
BaseGlobalsTransformer,
IvyRecurser,
IvyToTorchCodePostProcessor,
]
elif target == "numpy":
self.transformers: List[BaseTransformer] = [
IvyNodeDeleter,
Expand Down