Skip to content

Commit

Permalink
support NMT/TTS (#623)
Browse files Browse the repository at this point in the history
* support import erf

* fix argmax min value

* set arg_max init value by std library

* replace dequant with convert

* add more condition for merge child regions

* add shapeinference for cpu infer[onnx]

* apply code-format changes

* rewrite circle check

* apply code-format changes

* modify comment

* fix ci

Co-authored-by: curioyang <curioyang@users.noreply.github.com>
  • Loading branch information
curioyang and curioyang authored Aug 30, 2022
1 parent ed2e479 commit 6925ee3
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 7 deletions.
5 changes: 4 additions & 1 deletion include/nncase/runtime/datatypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ typedef enum _unary_op
unary_square,
unary_tanh,
unary_bitwise_not,
unary_logical_not
unary_logical_not,
unary_erf
} unary_op_t;

inline std::string unary_op_to_string(unary_op_t op)
Expand Down Expand Up @@ -301,6 +302,8 @@ inline std::string unary_op_to_string(unary_op_t op)
return "unary_bitwise_not";
case unary_logical_not:
return "unary_logical_not";
case unary_erf:
return "unary_erf";
}
return "unknown";
}
Expand Down
2 changes: 1 addition & 1 deletion include/nncase/runtime/runtime_op_utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ inline bool is_optimized_binary_op(binary_op_t op)

inline bool is_optimized_unary_op(unary_op_t op)
{
return op == unary_abs || op == unary_ceil || op == unary_cos || op == unary_exp || op == unary_floor || op == unary_log || op == unary_neg || op == unary_round || op == unary_rsqrt || op == unary_sign || op == unary_sin || op == unary_sqrt || op == unary_square || op == unary_tanh;
return op == unary_abs || op == unary_ceil || op == unary_cos || op == unary_exp || op == unary_floor || op == unary_log || op == unary_neg || op == unary_round || op == unary_sign || op == unary_sin || op == unary_sqrt || op == unary_square || op == unary_tanh;
}

template <class TShape>
Expand Down
3 changes: 3 additions & 0 deletions src/evaluator/ops/neutral/neutral_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,9 @@ void register_neutral_evaluators()
case unary_tanh:
unary([](auto a) { return tanh(a); });
break;
case unary_erf:
unary([](auto a) { return erf(a); });
break;
default:
throw std::runtime_error("Not supported unary");
} });
Expand Down
1 change: 1 addition & 0 deletions src/importer/onnx/opcode.def
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ DEFINE_OPCODE(Elu)
DEFINE_OPCODE(Exp)
DEFINE_OPCODE(Expand)
DEFINE_OPCODE(Equal)
DEFINE_OPCODE(Erf)
DEFINE_OPCODE(Flatten)
DEFINE_OPCODE(Floor)
DEFINE_OPCODE(Gather)
Expand Down
5 changes: 5 additions & 0 deletions src/importer/onnx/ops/unary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ void onnx_importer::convert_op_Tanh(const onnx::NodeProto &node)
convert_unary(node, unary_tanh);
}

void onnx_importer::convert_op_Erf(const onnx::NodeProto &node)
{
convert_unary(node, unary_erf);
}

void onnx_importer::convert_unary(const onnx::NodeProto &node, const unary_op_t unary_op)
{
assert(node.input().size() == 1);
Expand Down
5 changes: 2 additions & 3 deletions src/importer/onnx/ops/where.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,9 @@ void onnx_importer::convert_op_Where(const onnx::NodeProto &node)
const auto &input_c = node.input()[2];
const auto &output = node.output()[0];

quant_param_t qparam { 0, 1.f };
datatype_t dtype = dt_float32;
auto deq_a = graph_.emplace<dequantize>(get_datatype(input_a).value(), get_shape(input_a), dtype, qparam);
deq_a->name(op_name + "/deq_a");
auto deq_a = graph_.emplace<convert>(get_datatype(input_a).value(), get_shape(input_a), dtype);
deq_a->name(op_name + "/cvt");

auto op = graph_.emplace<ternary>(dtype, get_datatype(input_b).value(), deq_a->output().shape(), get_shape(input_b), get_shape(input_c));
op->name(op_name + "/ternary");
Expand Down
28 changes: 26 additions & 2 deletions src/ir/graph.partition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <nncase/ir/ops/constant.h>
#include <nncase/ir/visitor.h>
#include <nncase/runtime/stackvm/runtime_module.h>
#include <queue>
#include <unordered_set>

using namespace nncase;
Expand Down Expand Up @@ -184,6 +185,27 @@ class graph_merger
} while (changed);
}

bool check_circle(std::list<region>::iterator &ita, std::list<region>::iterator &itb)
{
/*
总共判断两层就可以了
*/
for (auto it : itb->region_inputs)
{
for (auto mid = regions_.begin(); mid != regions_.end(); mid++)
{
if (mid == ita || mid == itb)
continue;
if (mid->outputs.contains(it->connection()) && mid->module_type != ita->module_type && !mid->is_all_noaction
&& std::any_of(mid->region_inputs.begin(), mid->region_inputs.end(), [&](input_connector *in) { return ita->outputs.contains(in->connection()); }))
{
return false;
}
}
}
return true;
}

bool merge_child_region()
{
bool ever_changed = false;
Expand All @@ -202,9 +224,11 @@ class graph_merger
&& itb->module_type == runtime::stackvm::stackvm_module_type))
continue;

// itb's inputs all connect to ita's output
//// itb's inputs all connect to ita's output
// itb's has inputs connect to ita's output without circle
if ((ita->module_type == itb->module_type || itb->is_all_noaction)
&& std::all_of(itb->region_inputs.begin(), itb->region_inputs.end(), [&](input_connector *in) { return ita->outputs.contains(in->connection()); }))
&& std::any_of(itb->region_inputs.begin(), itb->region_inputs.end(), [&](input_connector *in) { return ita->outputs.contains(in->connection()); })
&& check_circle(ita, itb))
to_be_merge.emplace_back(itb);
}

Expand Down
1 change: 1 addition & 0 deletions src/kernels/cpu/reference/unary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ result<void> reference::unary(unary_op_t op, const float *input, float *output,
UNARY_IMPL(unary_sqrt, sqrtf);
UNARY_IMPL(unary_square, [](float v) { return v * v; });
UNARY_IMPL(unary_tanh, tanhf);
UNARY_IMPL(unary_erf, erff);
default:
return err(std::errc::not_supported);
}
Expand Down
1 change: 1 addition & 0 deletions tests/onnx_test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def cpu_infer(self, case_dir: str, model_file: bytes, type: str, mode: str):
onnx_model = onnx.load(model_file)
onnx_model = version_converter.convert_version(onnx_model, 8)
model_file = os.path.join(case_dir, 'converted.onnx')
onnx_model = onnx.shape_inference(onnx_model)
onnx.save_model(onnx_model, model_file)
sess = ort.InferenceSession(model_file)

Expand Down

0 comments on commit 6925ee3

Please sign in to comment.