-
Notifications
You must be signed in to change notification settings - Fork 2.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TF FE]: Support complex tensors for Equal operation #29339
base: master
Are you sure you want to change the base?
Conversation
@@ -289,6 +291,52 @@ ov::Output<ov::Node> ComplexTypeMark::div(const NodeContext& context, | |||
return {result}; | |||
} | |||
|
|||
ov::Output<ov::Node> ComplexTypeMark::equal(const NodeContext& context, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please move this functionality to common translators. We will re-use it in PyTorch FE.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should I move it to src\frontends\common_translators\src\op\complex.cpp
or put it in a new file?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
new file equal.cpp
in src\frontends\common_translators\src\op\
auto lr = lhs_complex->get_real(); | ||
auto li = lhs_complex->get_imag(); | ||
auto rr = rhs_complex->get_real(); | ||
auto ri = rhs_complex->get_imag(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please think how to implement it using get_data()
and how it will simplify this decomposition
} else if (lhs_complex) { | ||
// rhs is of a real type | ||
auto lhs_real = lhs_complex->get_real(); | ||
auto lhs_imag = lhs_complex->get_imag(); | ||
|
||
auto eq_real = context.mark_node(make_shared<v1::Equal>(lhs_real, rhs)); | ||
auto zero_const = context.mark_node(make_shared<v0::Constant>(lhs_imag.get_element_type(), Shape{}, 0)); | ||
auto eq_imag = context.mark_node(make_shared<v1::Equal>(lhs_imag, zero_const)); | ||
|
||
auto result = context.mark_node(make_shared<v1::LogicalAnd>(eq_real, eq_imag)); | ||
return {result}; | ||
} else if (rhs_complex) { | ||
// lhs is of a real type | ||
auto rhs_real = rhs_complex->get_real(); | ||
auto rhs_imag = rhs_complex->get_imag(); | ||
|
||
auto eq_real = context.mark_node(make_shared<v1::Equal>(lhs, rhs_real)); | ||
auto zero_const = context.mark_node(make_shared<v0::Constant>(rhs_imag.get_element_type(), Shape{}, 0)); | ||
auto eq_imag = context.mark_node(make_shared<v1::Equal>(rhs_imag, zero_const)); | ||
|
||
auto result = context.mark_node(make_shared<v1::LogicalAnd>(eq_real, eq_imag)); | ||
return {result}; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these two cases are not supported by TF. Both operands must be of the same type: https://www.tensorflow.org/api_docs/python/tf/raw_ops/Equal
@@ -8,6 +8,7 @@ | |||
#include <map> | |||
#include <string> | |||
|
|||
#include "common_translators.hpp" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not needed
#include "common_translators.hpp" |
@@ -19,6 +19,8 @@ COMMON_OP_CONVERTER(translate_imag); | |||
COMMON_OP_CONVERTER(translate_atan2); | |||
COMMON_OP_CONVERTER(translate_angle); | |||
|
|||
COMMON_OP_CONVERTER(translate_equal_op); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
COMMON_OP_CONVERTER(translate_equal_op); | |
COMMON_OP_CONVERTER(translate_equal); |
let us have names without _op
suffix. It is obvious that this is for ops:)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
left couple of comments. Other part looks good to me.
@mvafin, fyi. it can be reused for PT FE |
build_jenkins |
OutputVector translate_equal_op(const NodeContext& node) { | ||
default_op_checks(node, 2, {"Equal"}, true); | ||
|
||
auto result = common_translators::translate_equal(node); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@rkazants It seems that remove #include "common_translators.hpp"
from src/frontends/tensorflow_common/include/common_op_table.hpp
will result in an error here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any suggestions about the implementation of translate_equal_op
function here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please include "common_translators.hpp" in "binary_op.cpp". No need to insert to header file, we use it internally in this case.
@rkazants Updated |
build_jenkins |
Details:
Tickets: