Skip to content
This repository was archived by the owner on Oct 1, 2020. It is now read-only.

Per channel quant #51

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 168 additions & 0 deletions bench/q8gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,105 @@ class Q8GEMM_XZP : public Q8GEMM {
qnnp_q31_requantization_params requantizationParams_;
};

/*
 * Benchmark fixture for Q8 GEMM micro-kernels that use per-channel
 * quantization parameters.
 *
 * SetUp fills the input (a_), kernel (k_) and bias (b_) buffers with random
 * data, builds per-channel kernel zero points and requantization scales,
 * packs the weights with pack_q8gemm_w_per_channel, and computes the
 * quantization parameters handed to the micro-kernel.
 * TearDown reports the processed item count and releases all buffers.
 */
class Q8GEMM_PER_CHANNEL : public Q8GEMM {
 public:
  inline Q8GEMM_PER_CHANNEL(uint32_t mr, uint32_t nr, uint32_t np, uint32_t kr) : Q8GEMM(mr, nr, np, kr) {}

  virtual void SetUp(const benchmark::State&) override
  {
    std::random_device randomDevice;
    auto rng = std::mt19937(randomDevice());
    auto s32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), rng);
    auto u8rng = std::bind(std::uniform_int_distribution<uint8_t>(), rng);

    a_.resize(mc() * kc());
    std::generate(a_.begin(), a_.end(), std::ref(u8rng));
    k_.resize(nc() * kc());
    std::generate(k_.begin(), k_.end(), std::ref(u8rng));
    /*
     * One bias value per output channel. The packed buffer below reserves
     * ncStride() int32 bias slots and packing is driven by nc(), so the
     * source bias vector must hold nc() entries. Sizing it by mc() lets the
     * packer read past the end of the buffer whenever mc() < nc().
     */
    b_.resize(nc());
    std::generate(b_.begin(), b_.end(), std::ref(s32rng));
    /* Packed weights: kcStride() x ncStride() bytes plus ncStride() int32 values. */
    w_.resize(kcStride() * ncStride() + ncStride() * sizeof(int32_t) / sizeof(uint8_t));

    const uint8_t kernel_zero_point_center = 127;
    /*
     * NOTE(review): the per-channel vectors hold nr() entries, while
     * pack_q8gemm_w_per_channel is called with nc() channels. If the packer
     * indexes zero points by absolute channel, these need nc() (or
     * ncStride()) entries - confirm against the packing routine.
     */
    kernelZeroPointPerChannel_.resize(nr());
    requantizationScalePerChannel_.resize(nr());
    multiplierPerChannel_.resize(nr());
    rightShiftPerChannel_.resize(nr());

    const float scale_min = 0.5f;
    const float scale_max = 0.99999f;
    const int half_channels = static_cast<int>(nr() / 2);
    for (size_t i = 0; i < nr(); ++i) {
      /*
       * Spread the zero points around the center value. The offset is
       * computed in signed arithmetic so that channels below nr()/2 yield a
       * negative offset without relying on unsigned wrap-around.
       */
      const int zero_point_offset = static_cast<int>(i) - half_channels;
      kernelZeroPointPerChannel_[i] =
        static_cast<uint8_t>(std::min(255, std::max(0, kernel_zero_point_center + zero_point_offset)));
      /* Scales grow linearly from scale_min towards (but never reaching) scale_max. */
      requantizationScalePerChannel_[i] = scale_min + i * (scale_max - scale_min) / nr();
    }

    std::fill(w_.begin(), w_.end(), kernel_zero_point_center);
    pack_q8gemm_w_per_channel(
      nc(), kc(),
      nr(), np(), kr(),
      127, kernelZeroPointPerChannel_.data(),
      k(), b(), w());

    c_.resize(mc() * nc());
    /* Pre-fill the output with a recognizable byte pattern. */
    std::fill(c_.begin(), c_.end(), 0xA5);

    quantizationParams_ =
      qnnp_compute_conv_quantization_params_per_channel(
        127, nr(), kernelZeroPointPerChannel_.data(),
        requantizationScalePerChannel_.data(), multiplierPerChannel_.data(), rightShiftPerChannel_.data(),
        127, 1, 254);
  }

  virtual void TearDown(benchmark::State& state) override
  {
    /* Items processed = iterations * 2 * mc * nc * kc. */
    state.SetItemsProcessed(uint64_t(state.iterations()) * 2 * mc() * nc() * kc());
    a_.clear();
    k_.clear();
    b_.clear();
    w_.clear();
    c_.clear();
    kernelZeroPointPerChannel_.clear();
    kernelAndInputScalePerChannel_.clear();
    requantizationScalePerChannel_.clear();
    multiplierPerChannel_.clear();
    rightShiftPerChannel_.clear();
  }

 protected:
  /* Kernel zero point for each of the nr() channels set up in SetUp. */
  std::vector<uint8_t> kernelZeroPointPerChannel_;
  /* Never populated by this fixture; only cleared in TearDown. */
  std::vector<float> kernelAndInputScalePerChannel_;
  /* Requantization scale for each channel, in [0.5, 0.99999). */
  std::vector<float> requantizationScalePerChannel_;
  /* Output buffers handed to qnnp_compute_conv_quantization_params_per_channel. */
  std::vector<int32_t> multiplierPerChannel_;
  std::vector<int32_t> rightShiftPerChannel_;
};

/*
 * Per-channel GEMM fixture whose reduction dimension kc is derived from the
 * L1 data cache size reported by cpuinfo.
 *
 * kc is the largest value for which mr*kc + nr*kc + mr*nr bytes fit into the
 * L1 data cache minus a 512-byte reserve, rounded down to a multiple of kr
 * (or of nr when kr == 1).
 */
template <uint32_t MR, uint32_t NR, uint32_t NP, uint32_t KR>
class Q8GEMM_PER_CHANNEL_L1 : public Q8GEMM_PER_CHANNEL {
 public:
  Q8GEMM_PER_CHANNEL_L1() : Q8GEMM_PER_CHANNEL(MR, NR, NP, KR)
  {
    cpuinfo_initialize();

    const size_t cacheBytes = cpuinfo_get_l1d_cache(0)->size;
    const size_t reservedBytes = 512;
    const size_t usableElements = (cacheBytes - reservedBytes) / sizeof(uint8_t);

    /* Solve mr*kc + nr*kc + mr*nr <= usableElements for kc. */
    kc_ = (usableElements - mr() * nr()) / (mr() + nr());

    /* Round kc down to the granularity used for this kernel. */
    const auto granularity = (kr() != 1) ? kr() : nr();
    kc_ = kc_ / granularity * granularity;
  }
};

template <uint32_t MR, uint32_t NR, uint32_t NP, uint32_t KR>
class Q8GEMM_PER_CHANNEL_Op : public Q8GEMM_PER_CHANNEL {
public:
inline Q8GEMM_PER_CHANNEL_Op() : Q8GEMM_PER_CHANNEL(MR, NR, NP, KR) {}

virtual void SetUp(const benchmark::State& state) override
{
mc_ = state.range(0);
nc_ = state.range(1);
kc_ = state.range(2);

Q8GEMM_PER_CHANNEL::SetUp(state);
}
};

template <uint32_t MR, uint32_t NR, uint32_t NP, uint32_t KR>
class Q8GEMM_XZP_L1 : public Q8GEMM_XZP {
public:
Expand Down Expand Up @@ -647,6 +746,40 @@ BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__aarch32_neon)->Apply(MobileNetV1GemmArgumen
BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__aarch32_neon)->Apply(SqueezeNetV10GemmArguments);
BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__aarch32_neon)->Apply(GemmArguments);

/*
 * Benchmark of q8gemm_ukernel_4x8__aarch32_neon_per_channel on the
 * Q8GEMM_PER_CHANNEL_L1 fixture instantiated with MR=4, NR=8, NP=8, KR=1.
 * Each benchmark iteration performs exactly one micro-kernel call with
 * mr(), nr() and kc() as its size arguments.
 */
BENCHMARK_TEMPLATE_F(Q8GEMM_PER_CHANNEL_L1, 4x8__aarch32_neon_per_channel, 4, 8, 8, 1)(benchmark::State& state)
{
for (auto _ : state) {
q8gemm_ukernel_4x8__aarch32_neon_per_channel(
mr(), nr(), kc(),
a(), kc() * sizeof(uint8_t),
w(),
/* NOTE(review): the output stride here is mr() bytes, whereas the Op benchmark passes nc() bytes in this position - confirm this is intended. */
c(), mr() * sizeof(uint8_t),
/* NOTE(review): meaning of the trailing 0 is not visible here - confirm against the micro-kernel signature. */
quantizationParams(), 0);
}
}

/*
 * Benchmark of q8gemm_ukernel_4x8__aarch32_neon_per_channel on the
 * Q8GEMM_PER_CHANNEL_Op fixture instantiated with MR=4, NR=8, NP=8, KR=1.
 * Each benchmark iteration walks the mc() x nc() output in steps of
 * mr() x nr() and calls the micro-kernel once per tile.
 */
BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMM_PER_CHANNEL_Op, 4x8__aarch32_neon_per_channel, 4, 8, 8, 1)(benchmark::State& state)
{
for (auto _ : state) {
for (uint32_t m = 0; m < mc(); m += mr()) {
/* Clamp the tile height for the last, possibly partial, row block. */
const uint32_t mrr = min(mc() - m, mr());
for (uint32_t n = 0; n < nc(); n += nr()) {
/* Clamp the tile width for the last, possibly partial, column block. */
const uint32_t nrr = min(nc() - n, nr());
q8gemm_ukernel_4x8__aarch32_neon_per_channel(
mrr, nrr, kc(),
a() + m * kc(), kc() * sizeof(uint8_t),
/* Advance n channels into the packed weights: kcStride() bytes plus one int32 per channel. */
w() + n * (kcStride() * sizeof(uint8_t) + sizeof(int32_t)),
c() + m * nc() + n, nc() * sizeof(uint8_t),
/* NOTE(review): the same quantizationParams() and trailing 0 are passed for every column block; if either is meant to select per-channel data it may need to depend on n - confirm. */
quantizationParams(), 0);
}
}
}
}
/* Run the benchmark over each of the GEMM argument sets below. */
BENCHMARK_REGISTER_F(Q8GEMM_PER_CHANNEL_Op, 4x8__aarch32_neon_per_channel)->Apply(ShuffleNetV1G1GemmArguments);
BENCHMARK_REGISTER_F(Q8GEMM_PER_CHANNEL_Op, 4x8__aarch32_neon_per_channel)->Apply(MobileNetV1GemmArguments);
BENCHMARK_REGISTER_F(Q8GEMM_PER_CHANNEL_Op, 4x8__aarch32_neon_per_channel)->Apply(SqueezeNetV10GemmArguments);
BENCHMARK_REGISTER_F(Q8GEMM_PER_CHANNEL_Op, 4x8__aarch32_neon_per_channel)->Apply(GemmArguments);

BENCHMARK_TEMPLATE_F(Q8GEMM_XZP_L1, 4x8c2__aarch32_neon, 4, 8, 8, 2)(benchmark::State& state)
{
for (auto _ : state) {
Expand Down Expand Up @@ -770,6 +903,41 @@ BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__neon)->Apply(MobileNetV1GemmArguments);
BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__neon)->Apply(SqueezeNetV10GemmArguments);
BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__neon)->Apply(GemmArguments);

/*
 * Benchmark of q8gemm_ukernel_4x8__neon_per_channel on the
 * Q8GEMM_PER_CHANNEL_L1 fixture instantiated with MR=4, NR=8, NP=8, KR=1.
 * Each benchmark iteration performs exactly one micro-kernel call with
 * mr(), nr() and kc() as its size arguments.
 */
BENCHMARK_TEMPLATE_F(Q8GEMM_PER_CHANNEL_L1, 4x8__neon_per_channel, 4, 8, 8, 1)(benchmark::State& state)
{
for (auto _ : state) {
q8gemm_ukernel_4x8__neon_per_channel(
mr(), nr(), kc(),
a(), kc() * sizeof(uint8_t),
w(),
/* NOTE(review): the output stride here is mr() bytes, whereas the Op benchmark passes nc() bytes in this position - confirm this is intended. */
c(), mr() * sizeof(uint8_t),
/* NOTE(review): meaning of the trailing 0 is not visible here - confirm against the micro-kernel signature. */
quantizationParams(), 0);
}
}

/*
 * Benchmark of q8gemm_ukernel_4x8__neon_per_channel on the
 * Q8GEMM_PER_CHANNEL_Op fixture instantiated with MR=4, NR=8, NP=8, KR=1.
 * Each benchmark iteration walks the mc() x nc() output in steps of
 * mr() x nr() and calls the micro-kernel once per tile.
 */
BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMM_PER_CHANNEL_Op, 4x8__neon_per_channel, 4, 8, 8, 1)(benchmark::State& state)
{
for (auto _ : state) {
for (uint32_t m = 0; m < mc(); m += mr()) {
/* Clamp the tile height for the last, possibly partial, row block. */
const uint32_t mrr = min(mc() - m, mr());
for (uint32_t n = 0; n < nc(); n += nr()) {
/* Clamp the tile width for the last, possibly partial, column block. */
const uint32_t nrr = min(nc() - n, nr());
q8gemm_ukernel_4x8__neon_per_channel(
mrr, nrr, kc(),
a() + m * kc(), kc() * sizeof(uint8_t),
/* Advance n channels into the packed weights: kcStride() bytes plus one int32 per channel. */
w() + n * (kcStride() * sizeof(uint8_t) + sizeof(int32_t)),
c() + m * nc() + n, nc() * sizeof(uint8_t),
/* NOTE(review): the same quantizationParams() and trailing 0 are passed for every column block; if either is meant to select per-channel data it may need to depend on n - confirm. */
quantizationParams(), 0);
}
}
}
}

/* Run the benchmark over each of the GEMM argument sets below. */
BENCHMARK_REGISTER_F(Q8GEMM_PER_CHANNEL_Op, 4x8__neon_per_channel)->Apply(ShuffleNetV1G1GemmArguments);
BENCHMARK_REGISTER_F(Q8GEMM_PER_CHANNEL_Op, 4x8__neon_per_channel)->Apply(MobileNetV1GemmArguments);
BENCHMARK_REGISTER_F(Q8GEMM_PER_CHANNEL_Op, 4x8__neon_per_channel)->Apply(SqueezeNetV10GemmArguments);
BENCHMARK_REGISTER_F(Q8GEMM_PER_CHANNEL_Op, 4x8__neon_per_channel)->Apply(GemmArguments);

BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMM_Op, 8x8__neon, 8, 8, 8, 1)(benchmark::State& state)
{
for (auto _ : state) {
Expand Down
2 changes: 2 additions & 0 deletions configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def main(args):
build.cc("q8gavgpool/up8xm-neon.c"),
build.cc("q8gemm/4x-sumrows-neon.c"),
build.cc("q8gemm/4x8-neon.c"),
build.cc("q8gemm/4x8-neon_per_channel.c"),
build.cc("q8gemm/4x8c2-xzp-neon.c"),
build.cc("q8gemm/6x4-neon.c"),
build.cc("q8gemm/8x8-neon.c"),
Expand All @@ -128,6 +129,7 @@ def main(args):
build.cc("q8conv/4x8-aarch32-neon.S"),
build.cc("q8dwconv/up8x9-aarch32-neon.S"),
build.cc("q8gemm/4x8-aarch32-neon.S"),
build.cc("q8gemm/4x8-aarch32-neon-per-channel.S"),
build.cc("q8gemm/4x8c2-xzp-aarch32-neon.S"),
]
if build.target.is_arm64:
Expand Down
Loading