-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
bug[next]: Fix shift / remap lowering #1231
Changes from 2 commits
8e4eccd
6611614
4ebba26
2cc8c7f
addd51f
1c5cc8d
cc14de9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -546,6 +546,29 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call: | |
arg_types = [arg.type for arg in new_args] | ||
kwarg_types = {name: arg.type for name, arg in new_kwargs.items()} | ||
|
||
func_str_repr: str # just for error handling | ||
if isinstance( | ||
new_func.type, | ||
(ts.FunctionType, ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType), | ||
): | ||
# Since we use the `id` attribute in the latter part of the toolchain ensure we | ||
# have the proper format here. | ||
if not isinstance( | ||
new_func, | ||
(foast.FunctionDefinition, foast.FieldOperator, foast.ScanOperator, foast.Name), | ||
): | ||
raise FieldOperatorTypeDeductionError.from_foast_node( | ||
node, msg="Functions can only be called directly!" | ||
) | ||
func_str_repr = new_func.id | ||
elif isinstance(new_func.type, ts.FieldType): | ||
func_str_repr = str(new_func) | ||
else: | ||
raise FieldOperatorTypeDeductionError.from_foast_node( | ||
node, | ||
msg=f"Expression of type `{new_func.type}` is not callable, must be a `Function`, `FieldOperator`, `ScanOperator` or `Field`.", | ||
) | ||
|
||
# ensure signature is valid | ||
try: | ||
type_info.accepts_args( | ||
|
@@ -556,7 +579,7 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call: | |
) | ||
except GTTypeError as err: | ||
raise FieldOperatorTypeDeductionError.from_foast_node( | ||
node, msg=f"Invalid argument types in call to `{node.func.id}`!" | ||
node, msg=f"Invalid argument types in call to `{func_str_repr}`!" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Once you have implemented the change that you wanted to make to avoid func_str_repr, it should be good to go. |
||
) from err | ||
|
||
return_type = type_info.return_type(func_type, with_args=arg_types, with_kwargs=kwarg_types) | ||
|
@@ -571,34 +594,23 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call: | |
|
||
if ( | ||
isinstance(new_func.type, ts.FunctionType) | ||
and isinstance(new_func, foast.Name) | ||
and new_func.id in fbuiltins.MATH_BUILTIN_NAMES | ||
): | ||
return self._visit_math_built_in(new_node, **kwargs) | ||
elif ( | ||
isinstance(new_func.type, ts.FunctionType) | ||
and not type_info.is_concrete(return_type) | ||
and isinstance(new_func, foast.Name) | ||
and new_func.id in fbuiltins.FUN_BUILTIN_NAMES | ||
): | ||
visitor = getattr(self, f"_visit_{new_func.id}") | ||
return visitor(new_node, **kwargs) | ||
|
||
return new_node | ||
|
||
def _ensure_signature_valid(self, node: foast.Call, **kwargs) -> None: | ||
try: | ||
type_info.accepts_args( | ||
cast(ts.FunctionType, node.func.type), | ||
with_args=[arg.type for arg in node.args], | ||
with_kwargs={keyword: arg.type for keyword, arg in node.kwargs.items()}, | ||
raise_exception=True, | ||
) | ||
except GTTypeError as err: | ||
raise FieldOperatorTypeDeductionError.from_foast_node( | ||
node, msg=f"Invalid argument types in call to `{node.func.id}`!" | ||
) from err | ||
|
||
def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call: | ||
func_name = node.func.id | ||
func_name = cast(foast.Name, node.func).id | ||
|
||
# validate arguments | ||
error_msg_preamble = f"Incompatible argument in call to `{func_name}`." | ||
|
@@ -670,7 +682,7 @@ def _visit_reduction(self, node: foast.Call, **kwargs) -> foast.Call: | |
field_dims_str = ", ".join(str(dim) for dim in field_type.dims) | ||
raise FieldOperatorTypeDeductionError.from_foast_node( | ||
node, | ||
msg=f"Incompatible field argument in call to `{node.func.id}`. " | ||
msg=f"Incompatible field argument in call to `{str(node.func)}`. " | ||
f"Expected a field with dimension {reduction_dim}, but got " | ||
f"{field_dims_str}.", | ||
) | ||
|
@@ -726,15 +738,15 @@ def _visit_as_offset(self, node: foast.Call, **kwargs) -> foast.Call: | |
if not type_info.is_integral(arg_1): | ||
raise FieldOperatorTypeDeductionError.from_foast_node( | ||
node, | ||
msg=f"Incompatible argument in call to `{node.func.id}`. " | ||
msg=f"Incompatible argument in call to `{str(node.func)}`. " | ||
f"Excepted integer for offset field dtype, but got {arg_1.dtype}" | ||
f"{node.location}", | ||
) | ||
|
||
if arg_0.source not in arg_1.dims: | ||
raise FieldOperatorTypeDeductionError.from_foast_node( | ||
node, | ||
msg=f"Incompatible argument in call to `{node.func.id}`. " | ||
msg=f"Incompatible argument in call to `{str(node.func)}`. " | ||
f"{arg_0.source} not in list of offset field dimensions {arg_1.dims}. " | ||
f"{node.location}", | ||
) | ||
|
@@ -755,7 +767,7 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: | |
if not type_info.is_logical(mask_type): | ||
raise FieldOperatorTypeDeductionError.from_foast_node( | ||
node, | ||
msg=f"Incompatible argument in call to `{node.func.id}`. Expected " | ||
msg=f"Incompatible argument in call to `{str(node.func)}`. Expected " | ||
f"a field with dtype bool, but got `{mask_type}`.", | ||
) | ||
|
||
|
@@ -773,7 +785,7 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: | |
): | ||
raise FieldOperatorTypeDeductionError.from_foast_node( | ||
node, | ||
msg=f"Return arguments need to be of same type in {node.func.id}, but got: " | ||
msg=f"Return arguments need to be of same type in {str(node.func)}, but got: " | ||
f"{node.args[1].type} and {node.args[2].type}", | ||
) | ||
else: | ||
|
@@ -785,7 +797,7 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: | |
except GTTypeError as ex: | ||
raise FieldOperatorTypeDeductionError.from_foast_node( | ||
node, | ||
msg=f"Incompatible argument in call to `{node.func.id}`.", | ||
msg=f"Incompatible argument in call to `{str(node.func)}`.", | ||
) from ex | ||
|
||
return foast.Call( | ||
|
@@ -803,7 +815,7 @@ def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call: | |
if any([not (isinstance(elt.type, ts.DimensionType)) for elt in broadcast_dims_expr]): | ||
raise FieldOperatorTypeDeductionError.from_foast_node( | ||
node, | ||
msg=f"Incompatible broadcast dimension type in {node.func.id}. Expected " | ||
msg=f"Incompatible broadcast dimension type in {str(node.func)}. Expected " | ||
f"all broadcast dimensions to be of type Dimension.", | ||
) | ||
|
||
|
@@ -812,7 +824,7 @@ def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call: | |
if not set((arg_dims := type_info.extract_dims(arg_type))).issubset(set(broadcast_dims)): | ||
raise FieldOperatorTypeDeductionError.from_foast_node( | ||
node, | ||
msg=f"Incompatible broadcast dimensions in {node.func.id}. Expected " | ||
msg=f"Incompatible broadcast dimensions in {str(node.func)}. Expected " | ||
f"broadcast dimension is missing {set(arg_dims).difference(set(broadcast_dims))}", | ||
) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -206,17 +206,19 @@ def _visit_shift(self, node: foast.Call, **kwargs) -> itir.Expr: | |
) | ||
case _: | ||
raise FieldOperatorLoweringError("Unexpected shift arguments!") | ||
return shift_offset(self.visit(node.func, **kwargs)) | ||
return im.lift_(im.lambda__("it")(im.deref_(shift_offset("it"))))( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For the reviewer this is the core of this PR. |
||
self.visit(node.func, **kwargs) | ||
) | ||
|
||
def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: | ||
if type_info.type_class(node.func.type) is ts.FieldType: | ||
return self._visit_shift(node, **kwargs) | ||
elif node.func.id in MATH_BUILTIN_NAMES: | ||
elif isinstance(node.func, foast.Name) and node.func.id in MATH_BUILTIN_NAMES: | ||
return self._visit_math_built_in(node, **kwargs) | ||
elif node.func.id in FUN_BUILTIN_NAMES: | ||
elif isinstance(node.func, foast.Name) and node.func.id in FUN_BUILTIN_NAMES: | ||
visitor = getattr(self, f"_visit_{node.func.id}") | ||
return visitor(node, **kwargs) | ||
elif node.func.id in TYPE_BUILTIN_NAMES: | ||
elif isinstance(node.func, foast.Name) and node.func.id in TYPE_BUILTIN_NAMES: | ||
return self._visit_type_constr(node, **kwargs) | ||
elif isinstance( | ||
node.func.type, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This snippet, including the
type_info.accepts_args
is confusing. How does that work in case ofnew_func.type == FieldType
? Let's talk about it.