Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/output type #566

Merged
merged 10 commits into from
Apr 20, 2022
2 changes: 1 addition & 1 deletion include/nncase/targets/neutral_target.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class NNCASE_API neutral_target : public target
void register_target_independent_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr) override;
void register_target_dependent_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr, bool use_ptq) override;
void register_quantize_annotation_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr) override;
void register_quantize_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr, datatype_t quant_type, std::string_view w_quant_type, bool use_mse_quant_w) override;
void register_quantize_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr, datatype_t quant_type, std::string_view w_quant_type, bool use_mse_quant_w, datatype_t output_type, quant_param_t &output_quant_param) override;
void register_allocation_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr) override;
void add_quantization_broadcast(std::unordered_set<ir::node_opcode> &opcodes) override;

Expand Down
2 changes: 1 addition & 1 deletion include/nncase/targets/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class NNCASE_API target
virtual void register_target_dependent_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr, bool use_ptq) = 0;
virtual void register_quantize_annotation_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr);
virtual std::unique_ptr<ir::quantizer> create_quantizer(const module_type_t &type, ir::calibrate_method calib_method);
virtual void register_quantize_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr, datatype_t quant_type, std::string_view w_quant_type, bool use_mse_quant_w);
virtual void register_quantize_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr, datatype_t quant_type, std::string_view w_quant_type, bool use_mse_quant_w, datatype_t output_type, quant_param_t &output_quant_param);
virtual void register_target_dependent_after_quantization_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr);
virtual void register_target_dependent_after_buffer_fusion_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr);
virtual void register_allocation_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr) = 0;
Expand Down
5 changes: 3 additions & 2 deletions include/nncase/transforms/neutral/add_quant_motion.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ class NNCASE_API add_input_dequantize_transform : public transform
class NNCASE_API add_output_quantize_transform : public transform
{
public:
add_output_quantize_transform(datatype_t dt) noexcept
: output_type_(dt) { }
add_output_quantize_transform(datatype_t dt, quant_param_t &output_quant_param) noexcept
: output_type_(dt), output_quant_param_(output_quant_param) { }
void process(transform_context &context) override;

protected:
Expand All @@ -45,5 +45,6 @@ class NNCASE_API add_output_quantize_transform : public transform

private:
datatype_t output_type_;
quant_param_t &output_quant_param_;
};
}
22 changes: 14 additions & 8 deletions src/nncase/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -533,16 +533,10 @@ class compiler_impl : public compiler
quant->broadcast_output(graph, opcodes);
}

ir::transforms::transform_pass p("process i&o node");

if (compile_options_.output_type != "float32")
p.emplace<nncase::ir::transforms::add_output_quantize_transform>(parse_datatype_str(compile_options_.output_type));
pmgr.add_pass(std::move(p));

pmgr.quantizer(quant);
if (compile_options_.dump_ir)
pmgr.dump_dir(compile_options_.dump_dir);
target_->register_quantize_passes(graph.module_type(), pmgr, parse_datatype_str(compile_options_.quant_type), compile_options_.w_quant_type, compile_options_.use_mse_quant_w);
target_->register_quantize_passes(graph.module_type(), pmgr, parse_datatype_str(compile_options_.quant_type), compile_options_.w_quant_type, compile_options_.use_mse_quant_w, parse_datatype_str(compile_options_.output_type), output_quant_params_);
pmgr.run();
dump_graph(graph, "quantize");
};
Expand Down Expand Up @@ -786,7 +780,18 @@ class compiler_impl : public compiler
std::cout << "TOTAL"
<< "\t" << format_size(total_usage) << std::endl;

std::ofstream file(compile_options_.dump_dir / "memory_usage.txt");
std::ofstream file(compile_options_.dump_dir / "kmodel_info.txt");
if (compile_options_.dump_dir.filename().string() == "ptq" and compile_options_.output_type != "float32")
{
file << "\nOUTPUT_QUANT_PARAM" << std::endl;
file << "scale: " << output_quant_params_.scale << std::endl;
file << "zero_point: " << output_quant_params_.zero_point << std::endl;

std::cout << "\nOUTPUT_QUANT_PARAM" << std::endl;
std::cout << "scale: " << output_quant_params_.scale << std::endl;
std::cout << "zero_point: " << output_quant_params_.zero_point << std::endl;
}
file << "\nMEMORY USAGES" << std::endl;
file << "input: " << format_size(mod_builder.max_usage(mem_input)) << std::endl;
file << "output: " << format_size(mod_builder.max_usage(mem_output)) << std::endl;
file << "data: " << format_size(mod_builder.max_usage(mem_data)) << std::endl;
Expand All @@ -811,6 +816,7 @@ class compiler_impl : public compiler
std::unique_ptr<nncase::target> target_;
std::string real_inlayout_ = "";
std::string real_outlayout_ = "";
quant_param_t output_quant_params_ = { 0, 0 };
};
}

Expand Down
10 changes: 8 additions & 2 deletions src/targets/neutral_target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <nncase/schedule/buffer_allocator.h>
#include <nncase/targets/neutral_target.h>
#include <nncase/transforms/neutral/add_quant_checkpoints.h>
#include <nncase/transforms/neutral/add_quant_motion.h>
#include <nncase/transforms/neutral/binary_motion.h>
#include <nncase/transforms/neutral/bitcast_motion.h>
#include <nncase/transforms/neutral/dequantize_motion.h>
Expand Down Expand Up @@ -242,12 +243,12 @@ void neutral_target::register_quantize_annotation_passes([[maybe_unused]] const

{
transform_pass p("annotate_neutral_quantize");
p.emplace<add_quant_checkpoints_transform>(std::in_place, ir::op_fused_unary);
p.emplace<add_quant_checkpoints_transform>(std::in_place, ir::op_fused_unary, ir::op_bitcast, ir::op_dequantize, ir::op_binary);
pass_mgr.add_pass(std::move(p));
}
}

void neutral_target::register_quantize_passes([[maybe_unused]] const module_type_t &type, ir::transforms::pass_manager &pass_mgr, [[maybe_unused]] datatype_t quant_type, [[maybe_unused]] std::string_view w_quant_type, [[maybe_unused]] bool use_mse_quant_w)
void neutral_target::register_quantize_passes([[maybe_unused]] const module_type_t &type, ir::transforms::pass_manager &pass_mgr, [[maybe_unused]] datatype_t quant_type, [[maybe_unused]] std::string_view w_quant_type, [[maybe_unused]] bool use_mse_quant_w, [[maybe_unused]] datatype_t output_type, [[maybe_unused]] quant_param_t &output_quant_param)
{
{
transform_pass p("fused_unary_to_lut");
Expand All @@ -260,6 +261,11 @@ void neutral_target::register_quantize_passes([[maybe_unused]] const module_type
p.emplace<fold_quantize_transform>();
pass_mgr.add_pass(std::move(p));
}
{
transform_pass p("change_output_type");
p.emplace<add_output_quantize_transform>(output_type, output_quant_param);
pass_mgr.add_pass(std::move(p));
}
}

void neutral_target::register_allocation_passes([[maybe_unused]] const module_type_t &type, [[maybe_unused]] ir::transforms::pass_manager &pass_mgr)
Expand Down
2 changes: 1 addition & 1 deletion src/targets/target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void target::register_quantize_annotation_passes([[maybe_unused]] const module_t
{
}

void target::register_quantize_passes([[maybe_unused]] const module_type_t &type, [[maybe_unused]] ir::transforms::pass_manager &pass_mgr, [[maybe_unused]] datatype_t quant_type, [[maybe_unused]] std::string_view w_quant_type, [[maybe_unused]] bool use_mse_quant_w)
void target::register_quantize_passes([[maybe_unused]] const module_type_t &type, [[maybe_unused]] ir::transforms::pass_manager &pass_mgr, [[maybe_unused]] datatype_t quant_type, [[maybe_unused]] std::string_view w_quant_type, [[maybe_unused]] bool use_mse_quant_w, [[maybe_unused]] datatype_t output_type, [[maybe_unused]] quant_param_t &output_quant_param)
{
}

Expand Down
2 changes: 2 additions & 0 deletions src/transforms/neutral/add_quant_motion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ void add_output_quantize_transform::process(transform_context &context)
auto old_range = quantizer.get(output.owner().output_at(0));
auto params = quantizer.get_quant_param(old_range, bits, qm);

// get quant param for qint output
output_quant_param_ = params;
auto q = context.graph.emplace<quantize>(dt_float32, output.shape(), output_type_, params);
auto new_out_node = context.graph.emplace<output_node>(q->output().type(), q->output().shape());
old_out->input().clear_connection();
Expand Down
7 changes: 4 additions & 3 deletions targets/k210/k210_target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <nncase/transforms/k210/kpu_conv2d.h>
#include <nncase/transforms/k210/strided_slice_motion.h>
#include <nncase/transforms/neutral/add_quant_checkpoints.h>
#include <nncase/transforms/neutral/add_quant_motion.h>
#include <nncase/transforms/neutral/add_to_conv2d.h>
#include <nncase/transforms/neutral/dequantize_motion.h>
#include <nncase/transforms/neutral/eliminate_dilated_conv2d.h>
Expand Down Expand Up @@ -160,12 +161,12 @@ void k210_target::register_quantize_annotation_passes(const module_type_t &type,

{
transform_pass p("annotate_kpu_quantize");
p.emplace<add_quant_checkpoints_transform>(std::in_place, ir::op_fused_unary, ir::k210::op_k210_fake_kpu_conv2d, ir::op_bitcast, ir::op_dequantize);
p.emplace<add_quant_checkpoints_transform>(std::in_place, ir::op_fused_unary, ir::k210::op_k210_fake_kpu_conv2d, ir::op_bitcast, ir::op_dequantize, ir::op_binary);
pass_mgr.add_pass(std::move(p));
}
}

void k210_target::register_quantize_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr, [[maybe_unused]] datatype_t quant_type, [[maybe_unused]] std::string_view w_quant_type, [[maybe_unused]] bool use_mse_quant_w)
void k210_target::register_quantize_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr, [[maybe_unused]] datatype_t quant_type, [[maybe_unused]] std::string_view w_quant_type, [[maybe_unused]] bool use_mse_quant_w, [[maybe_unused]] datatype_t output_type, [[maybe_unused]] quant_param_t &output_quant_param)
{
{
transform_pass p("lowering_kpu_conv2d");
Expand All @@ -186,7 +187,7 @@ void k210_target::register_quantize_passes(const module_type_t &type, ir::transf
pass_mgr.add_pass(std::move(p));
}
{
neutral_target::register_quantize_passes(type, pass_mgr, quant_type, w_quant_type, use_mse_quant_w);
neutral_target::register_quantize_passes(type, pass_mgr, quant_type, w_quant_type, use_mse_quant_w, output_type, output_quant_param);

transform_pass p("fold_kpu_data_exchg2");
// p.emplace<fuse_kpu_download_transform>();
Expand Down
2 changes: 1 addition & 1 deletion targets/k210/k210_target.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class k210_target : public neutral_target
void register_allocators(const module_type_t &type, schedule::allocator_map_t &allocators, std::vector<std::shared_ptr<schedule::buffer_allocator>> &allocator_holders) override;
void register_evaluator_ops() override;
void register_quantize_annotation_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr) override;
void register_quantize_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr, datatype_t quant_type, std::string_view w_quant_type, bool use_mse_quant_w) override;
void register_quantize_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr, datatype_t quant_type, std::string_view w_quant_type, bool use_mse_quant_w, datatype_t output_type, quant_param_t &output_quant_param) override;
void register_target_dependent_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr, bool use_ptq) override;
void add_quantization_broadcast(std::unordered_set<ir::node_opcode> &opcodes) override;

Expand Down
1 change: 1 addition & 0 deletions tests/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ case: # case的配置,应该是一个多层次的
dump_import_op_range: false
quant_type: 'uint8'
w_quant_type: 'uint8'
output_type: 'float32'
use_mse_quant_w: true
quant_method: "no_clip"

Expand Down
16 changes: 16 additions & 0 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,14 @@ def _cast_bfloat16_then_float32(values: np.array):
values[i] = value


def deq_output(kmodel_info, data):
    """Dequantize a quantized model output back to float32.

    Args:
        kmodel_info: Path to the text dump that contains a ``scale: <float>``
            line and a ``zero_point: <int>`` line (the OUTPUT_QUANT_PARAM
            section of ``kmodel_info.txt``).
        data: NumPy array holding the quantized output values.

    Returns:
        A float32 NumPy array computed as ``(data - zero_point) * scale``.

    Raises:
        ValueError: If the file lacks a ``scale`` or ``zero_point`` entry, or
            a value cannot be parsed as a number.
    """
    scale = None
    zero_point = None
    with open(kmodel_info, 'r') as f:
        for line in f:
            # Look the entries up by key rather than by fixed line position,
            # so extra or missing blank lines in the dump do not break parsing.
            key, sep, value = line.partition(':')
            if not sep:
                continue
            key = key.strip()
            if key == 'scale' and scale is None:
                scale = float(value)
            elif key == 'zero_point' and zero_point is None:
                zero_point = int(value)
    if scale is None or zero_point is None:
        raise ValueError(
            f"'scale' and/or 'zero_point' not found in {kmodel_info}")
    # Widen to a signed 64-bit integer before subtracting so unsigned inputs
    # cannot wrap around. np.int64 replaces the removed alias np.int.
    return np.float32((data.astype(np.int64) - zero_point) * scale)


def generate_image_dataset(shape: List[int], dtype: np.dtype,
batch_index: int, batch_size: int,
case_dir: str,
Expand Down Expand Up @@ -652,6 +660,7 @@ def nncase_infer(self, cfg, case_dir: str,
compile_options.is_fpga = cfg.compile_opt.is_fpga
compile_options.use_mse_quant_w = cfg.compile_opt.use_mse_quant_w
compile_options.input_type = preprocess['input_type']
compile_options.output_type = cfg.compile_opt.output_type
compile_options.quant_type = cfg.compile_opt.quant_type
compile_options.w_quant_type = cfg.compile_opt.w_quant_type
compile_options.swapRB = preprocess['swapRB']
Expand Down Expand Up @@ -741,6 +750,13 @@ def nncase_infer(self, cfg, case_dir: str,
infer_output_paths.append((
os.path.join(infer_dir, f'nncase_result_{i}.bin'),
os.path.join(infer_dir, f'nncase_result_{i}.txt')))
if cfg.compile_opt.output_type != "float32" and infer_dir.split('/')[-1] == "ptq":
result.tofile(os.path.join(
infer_dir, f'nncase_result_{cfg.compile_opt.output_type}_{i}.bin'))
self.totxtfile(os.path.join(
infer_dir, f'nncase_result_{cfg.compile_opt.output_type}_{i}.txt'), result)
result = deq_output(os.path.join(
infer_dir, f'kmodel_info.txt'), result)
result.tofile(infer_output_paths[-1][0])
self.totxtfile(infer_output_paths[-1][1], result)
return infer_output_paths
Expand Down