Skip to content

Commit

Permalink
Feature/output type (#566)
Browse files Browse the repository at this point in the history
* change output type for quant

* support output type for quant

* support changing the output type and getting the output quant param

* split output type from preprocess options to compile options

* dequantize qint8 output for result comparison

* update config for output type

* update default output type
  • Loading branch information
curioyang authored Apr 20, 2022
1 parent 1ccea59 commit f2c0a1e
Show file tree
Hide file tree
Showing 11 changed files with 52 additions and 19 deletions.
2 changes: 1 addition & 1 deletion include/nncase/targets/neutral_target.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class NNCASE_API neutral_target : public target
void register_target_independent_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr) override;
void register_target_dependent_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr, bool use_ptq) override;
void register_quantize_annotation_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr) override;
void register_quantize_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr, datatype_t quant_type, std::string_view w_quant_type, bool use_mse_quant_w) override;
void register_quantize_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr, datatype_t quant_type, std::string_view w_quant_type, bool use_mse_quant_w, datatype_t output_type, quant_param_t &output_quant_param) override;
void register_allocation_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr) override;
void add_quantization_broadcast(std::unordered_set<ir::node_opcode> &opcodes) override;

Expand Down
2 changes: 1 addition & 1 deletion include/nncase/targets/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class NNCASE_API target
virtual void register_target_dependent_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr, bool use_ptq) = 0;
virtual void register_quantize_annotation_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr);
virtual std::unique_ptr<ir::quantizer> create_quantizer(const module_type_t &type, ir::calibrate_method calib_method);
virtual void register_quantize_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr, datatype_t quant_type, std::string_view w_quant_type, bool use_mse_quant_w);
virtual void register_quantize_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr, datatype_t quant_type, std::string_view w_quant_type, bool use_mse_quant_w, datatype_t output_type, quant_param_t &output_quant_param);
virtual void register_target_dependent_after_quantization_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr);
virtual void register_target_dependent_after_buffer_fusion_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr);
virtual void register_allocation_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr) = 0;
Expand Down
5 changes: 3 additions & 2 deletions include/nncase/transforms/neutral/add_quant_motion.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ class NNCASE_API add_input_dequantize_transform : public transform
class NNCASE_API add_output_quantize_transform : public transform
{
public:
add_output_quantize_transform(datatype_t dt) noexcept
: output_type_(dt) { }
add_output_quantize_transform(datatype_t dt, quant_param_t &output_quant_param) noexcept
: output_type_(dt), output_quant_param_(output_quant_param) { }
void process(transform_context &context) override;

protected:
Expand All @@ -45,5 +45,6 @@ class NNCASE_API add_output_quantize_transform : public transform

private:
datatype_t output_type_;
quant_param_t &output_quant_param_;
};
}
22 changes: 14 additions & 8 deletions src/nncase/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -533,16 +533,10 @@ class compiler_impl : public compiler
quant->broadcast_output(graph, opcodes);
}

ir::transforms::transform_pass p("process i&o node");

if (compile_options_.output_type != "float32")
p.emplace<nncase::ir::transforms::add_output_quantize_transform>(parse_datatype_str(compile_options_.output_type));
pmgr.add_pass(std::move(p));

pmgr.quantizer(quant);
if (compile_options_.dump_ir)
pmgr.dump_dir(compile_options_.dump_dir);
target_->register_quantize_passes(graph.module_type(), pmgr, parse_datatype_str(compile_options_.quant_type), compile_options_.w_quant_type, compile_options_.use_mse_quant_w);
target_->register_quantize_passes(graph.module_type(), pmgr, parse_datatype_str(compile_options_.quant_type), compile_options_.w_quant_type, compile_options_.use_mse_quant_w, parse_datatype_str(compile_options_.output_type), output_quant_params_);
pmgr.run();
dump_graph(graph, "quantize");
};
Expand Down Expand Up @@ -786,7 +780,18 @@ class compiler_impl : public compiler
std::cout << "TOTAL"
<< "\t" << format_size(total_usage) << std::endl;

std::ofstream file(compile_options_.dump_dir / "memory_usage.txt");
std::ofstream file(compile_options_.dump_dir / "kmodel_info.txt");
if (compile_options_.dump_dir.filename().string() == "ptq" and compile_options_.output_type != "float32")
{
file << "\nOUTPUT_QUANT_PARAM" << std::endl;
file << "scale: " << output_quant_params_.scale << std::endl;
file << "zero_point: " << output_quant_params_.zero_point << std::endl;

std::cout << "\nOUTPUT_QUANT_PARAM" << std::endl;
std::cout << "scale: " << output_quant_params_.scale << std::endl;
std::cout << "zero_point: " << output_quant_params_.zero_point << std::endl;
}
file << "\nMEMORY USAGES" << std::endl;
file << "input: " << format_size(mod_builder.max_usage(mem_input)) << std::endl;
file << "output: " << format_size(mod_builder.max_usage(mem_output)) << std::endl;
file << "data: " << format_size(mod_builder.max_usage(mem_data)) << std::endl;
Expand All @@ -811,6 +816,7 @@ class compiler_impl : public compiler
std::unique_ptr<nncase::target> target_;
std::string real_inlayout_ = "";
std::string real_outlayout_ = "";
quant_param_t output_quant_params_ = { 0, 0 };
};
}

Expand Down
10 changes: 8 additions & 2 deletions src/targets/neutral_target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <nncase/schedule/buffer_allocator.h>
#include <nncase/targets/neutral_target.h>
#include <nncase/transforms/neutral/add_quant_checkpoints.h>
#include <nncase/transforms/neutral/add_quant_motion.h>
#include <nncase/transforms/neutral/binary_motion.h>
#include <nncase/transforms/neutral/bitcast_motion.h>
#include <nncase/transforms/neutral/dequantize_motion.h>
Expand Down Expand Up @@ -242,12 +243,12 @@ void neutral_target::register_quantize_annotation_passes([[maybe_unused]] const

{
transform_pass p("annotate_neutral_quantize");
p.emplace<add_quant_checkpoints_transform>(std::in_place, ir::op_fused_unary);
p.emplace<add_quant_checkpoints_transform>(std::in_place, ir::op_fused_unary, ir::op_bitcast, ir::op_dequantize, ir::op_binary);
pass_mgr.add_pass(std::move(p));
}
}

void neutral_target::register_quantize_passes([[maybe_unused]] const module_type_t &type, ir::transforms::pass_manager &pass_mgr, [[maybe_unused]] datatype_t quant_type, [[maybe_unused]] std::string_view w_quant_type, [[maybe_unused]] bool use_mse_quant_w)
void neutral_target::register_quantize_passes([[maybe_unused]] const module_type_t &type, ir::transforms::pass_manager &pass_mgr, [[maybe_unused]] datatype_t quant_type, [[maybe_unused]] std::string_view w_quant_type, [[maybe_unused]] bool use_mse_quant_w, [[maybe_unused]] datatype_t output_type, [[maybe_unused]] quant_param_t &output_quant_param)
{
{
transform_pass p("fused_unary_to_lut");
Expand All @@ -260,6 +261,11 @@ void neutral_target::register_quantize_passes([[maybe_unused]] const module_type
p.emplace<fold_quantize_transform>();
pass_mgr.add_pass(std::move(p));
}
{
transform_pass p("change_output_type");
p.emplace<add_output_quantize_transform>(output_type, output_quant_param);
pass_mgr.add_pass(std::move(p));
}
}

void neutral_target::register_allocation_passes([[maybe_unused]] const module_type_t &type, [[maybe_unused]] ir::transforms::pass_manager &pass_mgr)
Expand Down
2 changes: 1 addition & 1 deletion src/targets/target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void target::register_quantize_annotation_passes([[maybe_unused]] const module_t
{
}

void target::register_quantize_passes([[maybe_unused]] const module_type_t &type, [[maybe_unused]] ir::transforms::pass_manager &pass_mgr, [[maybe_unused]] datatype_t quant_type, [[maybe_unused]] std::string_view w_quant_type, [[maybe_unused]] bool use_mse_quant_w)
void target::register_quantize_passes([[maybe_unused]] const module_type_t &type, [[maybe_unused]] ir::transforms::pass_manager &pass_mgr, [[maybe_unused]] datatype_t quant_type, [[maybe_unused]] std::string_view w_quant_type, [[maybe_unused]] bool use_mse_quant_w, [[maybe_unused]] datatype_t output_type, [[maybe_unused]] quant_param_t &output_quant_param)
{
}

Expand Down
2 changes: 2 additions & 0 deletions src/transforms/neutral/add_quant_motion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ void add_output_quantize_transform::process(transform_context &context)
auto old_range = quantizer.get(output.owner().output_at(0));
auto params = quantizer.get_quant_param(old_range, bits, qm);

// get quant param for qint output
output_quant_param_ = params;
auto q = context.graph.emplace<quantize>(dt_float32, output.shape(), output_type_, params);
auto new_out_node = context.graph.emplace<output_node>(q->output().type(), q->output().shape());
old_out->input().clear_connection();
Expand Down
7 changes: 4 additions & 3 deletions targets/k210/k210_target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <nncase/transforms/k210/kpu_conv2d.h>
#include <nncase/transforms/k210/strided_slice_motion.h>
#include <nncase/transforms/neutral/add_quant_checkpoints.h>
#include <nncase/transforms/neutral/add_quant_motion.h>
#include <nncase/transforms/neutral/add_to_conv2d.h>
#include <nncase/transforms/neutral/dequantize_motion.h>
#include <nncase/transforms/neutral/eliminate_dilated_conv2d.h>
Expand Down Expand Up @@ -160,12 +161,12 @@ void k210_target::register_quantize_annotation_passes(const module_type_t &type,

{
transform_pass p("annotate_kpu_quantize");
p.emplace<add_quant_checkpoints_transform>(std::in_place, ir::op_fused_unary, ir::k210::op_k210_fake_kpu_conv2d, ir::op_bitcast, ir::op_dequantize);
p.emplace<add_quant_checkpoints_transform>(std::in_place, ir::op_fused_unary, ir::k210::op_k210_fake_kpu_conv2d, ir::op_bitcast, ir::op_dequantize, ir::op_binary);
pass_mgr.add_pass(std::move(p));
}
}

void k210_target::register_quantize_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr, [[maybe_unused]] datatype_t quant_type, [[maybe_unused]] std::string_view w_quant_type, [[maybe_unused]] bool use_mse_quant_w)
void k210_target::register_quantize_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr, [[maybe_unused]] datatype_t quant_type, [[maybe_unused]] std::string_view w_quant_type, [[maybe_unused]] bool use_mse_quant_w, [[maybe_unused]] datatype_t output_type, [[maybe_unused]] quant_param_t &output_quant_param)
{
{
transform_pass p("lowering_kpu_conv2d");
Expand All @@ -186,7 +187,7 @@ void k210_target::register_quantize_passes(const module_type_t &type, ir::transf
pass_mgr.add_pass(std::move(p));
}
{
neutral_target::register_quantize_passes(type, pass_mgr, quant_type, w_quant_type, use_mse_quant_w);
neutral_target::register_quantize_passes(type, pass_mgr, quant_type, w_quant_type, use_mse_quant_w, output_type, output_quant_param);

transform_pass p("fold_kpu_data_exchg2");
// p.emplace<fuse_kpu_download_transform>();
Expand Down
2 changes: 1 addition & 1 deletion targets/k210/k210_target.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class k210_target : public neutral_target
void register_allocators(const module_type_t &type, schedule::allocator_map_t &allocators, std::vector<std::shared_ptr<schedule::buffer_allocator>> &allocator_holders) override;
void register_evaluator_ops() override;
void register_quantize_annotation_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr) override;
void register_quantize_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr, datatype_t quant_type, std::string_view w_quant_type, bool use_mse_quant_w) override;
void register_quantize_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr, datatype_t quant_type, std::string_view w_quant_type, bool use_mse_quant_w, datatype_t output_type, quant_param_t &output_quant_param) override;
void register_target_dependent_passes(const module_type_t &type, ir::transforms::pass_manager &pass_mgr, bool use_ptq) override;
void add_quantization_broadcast(std::unordered_set<ir::node_opcode> &opcodes) override;

Expand Down
1 change: 1 addition & 0 deletions tests/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ case: # case的配置,应该是一个多层次的
dump_import_op_range: false
quant_type: 'uint8'
w_quant_type: 'uint8'
output_type: 'float32'
use_mse_quant_w: true
quant_method: "no_clip"

Expand Down
16 changes: 16 additions & 0 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,14 @@ def _cast_bfloat16_then_float32(values: np.array):
values[i] = value


def deq_output(kmodel_info, data):
    """Dequantize a quantized model output back to float32.

    Reads the ``scale`` and ``zero_point`` entries from the kmodel info text
    file and applies ``(data - zero_point) * scale``.

    Args:
        kmodel_info: Path to a text file containing ``scale: <float>`` and
            ``zero_point: <int>`` lines (the first occurrence of each is used).
        data: NumPy array holding the quantized output values.

    Returns:
        A float32 NumPy array with the dequantized values.

    Raises:
        ValueError: If the file has no ``scale`` or no ``zero_point`` entry,
            or if either value cannot be parsed as a number.
    """
    scale = None
    zero_point = None
    with open(kmodel_info, 'r') as f:
        for line in f:
            # Look the entries up by key instead of by fixed line position so
            # that extra or missing header lines do not silently shift which
            # values get parsed.
            key, sep, value = line.partition(':')
            if not sep:
                continue
            key = key.strip()
            if key == 'scale' and scale is None:
                scale = float(value)
            elif key == 'zero_point' and zero_point is None:
                zero_point = int(value)

    if scale is None or zero_point is None:
        raise ValueError(
            "'scale' and/or 'zero_point' entry not found in %s" % kmodel_info)

    # np.int was removed in NumPy 1.24. Widen to int64 explicitly before the
    # subtraction so unsigned / narrow inputs cannot wrap around.
    return np.float32((data.astype(np.int64) - zero_point) * scale)


def generate_image_dataset(shape: List[int], dtype: np.dtype,
batch_index: int, batch_size: int,
case_dir: str,
Expand Down Expand Up @@ -652,6 +660,7 @@ def nncase_infer(self, cfg, case_dir: str,
compile_options.is_fpga = cfg.compile_opt.is_fpga
compile_options.use_mse_quant_w = cfg.compile_opt.use_mse_quant_w
compile_options.input_type = preprocess['input_type']
compile_options.output_type = cfg.compile_opt.output_type
compile_options.quant_type = cfg.compile_opt.quant_type
compile_options.w_quant_type = cfg.compile_opt.w_quant_type
compile_options.swapRB = preprocess['swapRB']
Expand Down Expand Up @@ -741,6 +750,13 @@ def nncase_infer(self, cfg, case_dir: str,
infer_output_paths.append((
os.path.join(infer_dir, f'nncase_result_{i}.bin'),
os.path.join(infer_dir, f'nncase_result_{i}.txt')))
if cfg.compile_opt.output_type != "float32" and infer_dir.split('/')[-1] == "ptq":
result.tofile(os.path.join(
infer_dir, f'nncase_result_{cfg.compile_opt.output_type}_{i}.bin'))
self.totxtfile(os.path.join(
infer_dir, f'nncase_result_{cfg.compile_opt.output_type}_{i}.txt'), result)
result = deq_output(os.path.join(
infer_dir, f'kmodel_info.txt'), result)
result.tofile(infer_output_paths[-1][0])
self.totxtfile(infer_output_paths[-1][1], result)
return infer_output_paths
Expand Down

0 comments on commit f2c0a1e

Please sign in to comment.