From 37c2f64832120c359411716a4c5f8281ed8937b8 Mon Sep 17 00:00:00 2001 From: "Mihai.Olinovici" Date: Thu, 19 Sep 2024 07:29:05 +0000 Subject: [PATCH 01/17] Add RVV qs8/qu8-f32-vcvt kernels and configs. --- cmake/gen/rvv_microkernels.cmake | 4 ++ gen/rvv_microkernels.bzl | 4 ++ scripts/generate-qs8-f32-vcvt.sh | 7 +++ src/configs/unary-elementwise-config.c | 14 +++--- src/qs8-f32-vcvt/gen/qs8-f32-vcvt-rvv-u1v.c | 47 +++++++++++++++++++ src/qs8-f32-vcvt/gen/qs8-f32-vcvt-rvv-u2v.c | 47 +++++++++++++++++++ src/qs8-f32-vcvt/qs8-f32-vcvt.h | 5 ++ src/qs8-f32-vcvt/rvv.c.in | 52 +++++++++++++++++++++ src/qu8-f32-vcvt/gen/qu8-f32-vcvt-rvv-u1v.c | 47 +++++++++++++++++++ src/qu8-f32-vcvt/gen/qu8-f32-vcvt-rvv-u2v.c | 47 +++++++++++++++++++ src/qu8-f32-vcvt/qu8-f32-vcvt.h | 5 ++ 11 files changed, 273 insertions(+), 6 deletions(-) create mode 100644 src/qs8-f32-vcvt/gen/qs8-f32-vcvt-rvv-u1v.c create mode 100644 src/qs8-f32-vcvt/gen/qs8-f32-vcvt-rvv-u2v.c create mode 100644 src/qs8-f32-vcvt/rvv.c.in create mode 100644 src/qu8-f32-vcvt/gen/qu8-f32-vcvt-rvv-u1v.c create mode 100644 src/qu8-f32-vcvt/gen/qu8-f32-vcvt-rvv-u2v.c diff --git a/cmake/gen/rvv_microkernels.cmake b/cmake/gen/rvv_microkernels.cmake index cd41d9717bb..7df64dde964 100644 --- a/cmake/gen/rvv_microkernels.cmake +++ b/cmake/gen/rvv_microkernels.cmake @@ -54,8 +54,10 @@ SET(PROD_RVV_MICROKERNEL_SRCS src/f32-vrnd/gen/f32-vrndu-rvv-u4v.c src/f32-vrnd/gen/f32-vrndz-rvv-u4v.c src/f32-vrsqrt/gen/f32-vrsqrt-rvv-rsqrt-u4v.c + src/qs8-f32-vcvt/gen/qs8-f32-vcvt-rvv-u2v.c src/qs8-vmul/gen/qs8-vmul-minmax-f32-rvv-u2v.c src/qs8-vmulc/gen/qs8-vmulc-minmax-f32-rvv-u2v.c + src/qu8-f32-vcvt/gen/qu8-f32-vcvt-rvv-u2v.c src/qu8-vmul/gen/qu8-vmul-minmax-f32-rvv-u2v.c src/qu8-vmulc/gen/qu8-vmulc-minmax-f32-rvv-u2v.c src/x32-packw/gen/x32-packw-x4v-gemm-goi-rvv-u8.c @@ -180,8 +182,10 @@ SET(NON_PROD_RVV_MICROKERNEL_SRCS src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x4v-minmax-rvv.c src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x4v-minmax-rvv.c src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x4v-minmax-rvv.c + src/qs8-f32-vcvt/gen/qs8-f32-vcvt-rvv-u1v.c src/qs8-vmul/gen/qs8-vmul-minmax-f32-rvv-u1v.c src/qs8-vmulc/gen/qs8-vmulc-minmax-f32-rvv-u1v.c + src/qu8-f32-vcvt/gen/qu8-f32-vcvt-rvv-u1v.c src/qu8-vmul/gen/qu8-vmul-minmax-f32-rvv-u1v.c src/qu8-vmulc/gen/qu8-vmulc-minmax-f32-rvv-u1v.c src/x32-packw/gen/x32-packw-x1v-gemm-goi-rvv-u2.c diff --git a/gen/rvv_microkernels.bzl b/gen/rvv_microkernels.bzl index 397b67d9954..2d7222ca9f0 100644 --- a/gen/rvv_microkernels.bzl +++ b/gen/rvv_microkernels.bzl @@ -50,8 +50,10 @@ PROD_RVV_MICROKERNEL_SRCS = [ "src/f32-vrnd/gen/f32-vrndu-rvv-u4v.c", "src/f32-vrnd/gen/f32-vrndz-rvv-u4v.c", "src/f32-vrsqrt/gen/f32-vrsqrt-rvv-rsqrt-u4v.c", + "src/qs8-f32-vcvt/gen/qs8-f32-vcvt-rvv-u2v.c", "src/qs8-vmul/gen/qs8-vmul-minmax-f32-rvv-u2v.c", "src/qs8-vmulc/gen/qs8-vmulc-minmax-f32-rvv-u2v.c", + "src/qu8-f32-vcvt/gen/qu8-f32-vcvt-rvv-u2v.c", "src/qu8-vmul/gen/qu8-vmul-minmax-f32-rvv-u2v.c", "src/qu8-vmulc/gen/qu8-vmulc-minmax-f32-rvv-u2v.c", "src/x32-packw/gen/x32-packw-x4v-gemm-goi-rvv-u8.c", @@ -177,8 +179,10 @@ NON_PROD_RVV_MICROKERNEL_SRCS = [ "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x4v-minmax-rvv.c", "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x4v-minmax-rvv.c", "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x4v-minmax-rvv.c", + "src/qs8-f32-vcvt/gen/qs8-f32-vcvt-rvv-u1v.c", "src/qs8-vmul/gen/qs8-vmul-minmax-f32-rvv-u1v.c", "src/qs8-vmulc/gen/qs8-vmulc-minmax-f32-rvv-u1v.c", + "src/qu8-f32-vcvt/gen/qu8-f32-vcvt-rvv-u1v.c", 
"src/qu8-vmul/gen/qu8-vmul-minmax-f32-rvv-u1v.c", "src/qu8-vmulc/gen/qu8-vmulc-minmax-f32-rvv-u1v.c", "src/x32-packw/gen/x32-packw-x1v-gemm-goi-rvv-u2.c", diff --git a/scripts/generate-qs8-f32-vcvt.sh b/scripts/generate-qs8-f32-vcvt.sh index 2dbf4529b33..f307861d254 100755 --- a/scripts/generate-qs8-f32-vcvt.sh +++ b/scripts/generate-qs8-f32-vcvt.sh @@ -15,6 +15,13 @@ tools/xngen src/qs8-f32-vcvt/neon.c.in -D BATCH_TILE=16 -D DATATYPE=QU8 -o src/q tools/xngen src/qs8-f32-vcvt/neon.c.in -D BATCH_TILE=24 -D DATATYPE=QU8 -o src/qu8-f32-vcvt/gen/qu8-f32-vcvt-neon-u24.c & tools/xngen src/qs8-f32-vcvt/neon.c.in -D BATCH_TILE=32 -D DATATYPE=QU8 -o src/qu8-f32-vcvt/gen/qu8-f32-vcvt-neon-u32.c & +################################ RISC-V Vector ################################ +tools/xngen src/qs8-f32-vcvt/rvv.c.in -D LMUL=1 -D DATATYPE=QS8 -o src/qs8-f32-vcvt/gen/qs8-f32-vcvt-rvv-u1v.c & +tools/xngen src/qs8-f32-vcvt/rvv.c.in -D LMUL=2 -D DATATYPE=QS8 -o src/qs8-f32-vcvt/gen/qs8-f32-vcvt-rvv-u2v.c & + +tools/xngen src/qs8-f32-vcvt/rvv.c.in -D LMUL=1 -D DATATYPE=QU8 -o src/qu8-f32-vcvt/gen/qu8-f32-vcvt-rvv-u1v.c & +tools/xngen src/qs8-f32-vcvt/rvv.c.in -D LMUL=2 -D DATATYPE=QU8 -o src/qu8-f32-vcvt/gen/qu8-f32-vcvt-rvv-u2v.c & + ################################# x86 128-bit ################################# tools/xngen src/qs8-f32-vcvt/sse2.c.in -D BATCH_TILE=8 -D DATATYPE=QS8 -o src/qs8-f32-vcvt/gen/qs8-f32-vcvt-sse2-u8.c & tools/xngen src/qs8-f32-vcvt/sse2.c.in -D BATCH_TILE=16 -D DATATYPE=QS8 -o src/qs8-f32-vcvt/gen/qs8-f32-vcvt-sse2-u16.c & diff --git a/src/configs/unary-elementwise-config.c b/src/configs/unary-elementwise-config.c index 6e794819768..bfd10be0ba5 100644 --- a/src/configs/unary-elementwise-config.c +++ b/src/configs/unary-elementwise-config.c @@ -2066,10 +2066,11 @@ static void init_qs8_to_f32_cvt_config(void) { qs8_to_f32_cvt_config.ukernel = (xnn_vunary_ukernel_fn) xnn_qs8_f32_vcvt_ukernel__scalar_u1; qs8_to_f32_cvt_config.init.qs8_f32_cvt = xnn_init_qs8_f32_cvt_scalar_params; qs8_to_f32_cvt_config.element_tile = 1; - #elif XNN_ARCH_RISCV - qs8_to_f32_cvt_config.ukernel = (xnn_vunary_ukernel_fn) xnn_qs8_f32_vcvt_ukernel__scalar_u4; + #elif XNN_ARCH_RISCV && XNN_ENABLE_RISCV_VECTOR + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); + qs8_to_f32_cvt_config.ukernel = (xnn_vunary_ukernel_fn) xnn_qs8_f32_vcvt_ukernel__rvv_u2v; qs8_to_f32_cvt_config.init.qs8_f32_cvt = xnn_init_qs8_f32_cvt_scalar_params; - qs8_to_f32_cvt_config.element_tile = 4; + qs8_to_f32_cvt_config.element_tile = hardware_config->vlenb / sizeof(int8_t) * 2; // (VLENB/sizeof)*LMUL; #else qs8_to_f32_cvt_config.ukernel = (xnn_vunary_ukernel_fn) xnn_qs8_f32_vcvt_ukernel__scalar_u4; qs8_to_f32_cvt_config.init.qs8_f32_cvt = xnn_init_qs8_f32_cvt_scalar_params; @@ -2288,10 +2289,11 @@ static void init_qu8_to_f32_cvt_config(void) { qu8_to_f32_cvt_config.ukernel = (xnn_vunary_ukernel_fn) xnn_qu8_f32_vcvt_ukernel__scalar_u1; qu8_to_f32_cvt_config.init.qu8_f32_cvt = xnn_init_qu8_f32_cvt_scalar_params; qu8_to_f32_cvt_config.element_tile = 1; - #elif XNN_ARCH_RISCV - qu8_to_f32_cvt_config.ukernel = (xnn_vunary_ukernel_fn) xnn_qu8_f32_vcvt_ukernel__scalar_u4; + #elif XNN_ARCH_RISCV && XNN_ENABLE_RISCV_VECTOR + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); + qu8_to_f32_cvt_config.ukernel = (xnn_vunary_ukernel_fn) xnn_qu8_f32_vcvt_ukernel__rvv_u2v; qu8_to_f32_cvt_config.init.qu8_f32_cvt = xnn_init_qu8_f32_cvt_scalar_params; - qu8_to_f32_cvt_config.element_tile = 4; + 
qu8_to_f32_cvt_config.element_tile = hardware_config->vlenb / sizeof(uint8_t) * 2; // (VLENB/sizeof)*LMUL;
   #else
     qu8_to_f32_cvt_config.ukernel = (xnn_vunary_ukernel_fn) xnn_qu8_f32_vcvt_ukernel__scalar_u4;
     qu8_to_f32_cvt_config.init.qu8_f32_cvt = xnn_init_qu8_f32_cvt_scalar_params;
diff --git a/src/qs8-f32-vcvt/gen/qs8-f32-vcvt-rvv-u1v.c b/src/qs8-f32-vcvt/gen/qs8-f32-vcvt-rvv-u1v.c
new file mode 100644
index 00000000000..6c5c3e2bb1a
--- /dev/null
+++ b/src/qs8-f32-vcvt/gen/qs8-f32-vcvt-rvv-u1v.c
@@ -0,0 +1,47 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-f32-vcvt/rvv.c.in
+// Generator: tools/xngen
+//
+// Copyright 2024 Imagination Technologies, inc.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <riscv_vector.h>
+
+#include "xnnpack/common.h"
+#include "xnnpack/intrinsics-polyfill.h"
+#include "xnnpack/vcvt.h"
+
+
+void xnn_qs8_f32_vcvt_ukernel__rvv_u1v(
+    size_t batch,
+    const int8_t* input,
+    float* output,
+    const struct xnn_qs8_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)])
+{
+  assert(batch != 0);
+  assert(batch % sizeof(int8_t) == 0);
+  assert(input != NULL);
+  assert(output != NULL);
+
+  batch >>= XNN_LOG2_SIZEOF_INT8_T;
+
+  const float scale = params->scalar.scale;
+  const int32_t minus_zero_point = -params->scalar.zero_point;
+
+  for (; batch > 0; ) {
+    const int32_t n = __riscv_vsetvl_e8m1(batch); batch -= n;
+
+    vint8m1_t x_i8v = __riscv_vle8_v_i8m1(input, n); input += n;
+
+    vint32m4_t wx_i32v = __riscv_vsext_vf4_i32m4(x_i8v, n);
+    wx_i32v = __riscv_vadd_vx_i32m4(wx_i32v, minus_zero_point, n);
+    vfloat32m4_t y_f32v = __riscv_vfcvt_f_x_v_f32m4(wx_i32v, n);
+    y_f32v = __riscv_vfmul_vf_f32m4(y_f32v, scale, n);
+
+    __riscv_vse32_v_f32m4(output, y_f32v, n); output += n;
+  }
+}
diff --git a/src/qs8-f32-vcvt/gen/qs8-f32-vcvt-rvv-u2v.c b/src/qs8-f32-vcvt/gen/qs8-f32-vcvt-rvv-u2v.c
new file mode 100644
index 00000000000..c6961b8451e
--- /dev/null
+++ b/src/qs8-f32-vcvt/gen/qs8-f32-vcvt-rvv-u2v.c
@@ -0,0 +1,47 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-f32-vcvt/rvv.c.in
+// Generator: tools/xngen
+//
+// Copyright 2024 Imagination Technologies, inc.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <riscv_vector.h>
+
+#include "xnnpack/common.h"
+#include "xnnpack/intrinsics-polyfill.h"
+#include "xnnpack/vcvt.h"
+
+
+void xnn_qs8_f32_vcvt_ukernel__rvv_u2v(
+    size_t batch,
+    const int8_t* input,
+    float* output,
+    const struct xnn_qs8_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)])
+{
+  assert(batch != 0);
+  assert(batch % sizeof(int8_t) == 0);
+  assert(input != NULL);
+  assert(output != NULL);
+
+  batch >>= XNN_LOG2_SIZEOF_INT8_T;
+
+  const float scale = params->scalar.scale;
+  const int32_t minus_zero_point = -params->scalar.zero_point;
+
+  for (; batch > 0; ) {
+    const int32_t n = __riscv_vsetvl_e8m2(batch); batch -= n;
+
+    vint8m2_t x_i8v = __riscv_vle8_v_i8m2(input, n); input += n;
+
+    vint32m8_t wx_i32v = __riscv_vsext_vf4_i32m8(x_i8v, n);
+    wx_i32v = __riscv_vadd_vx_i32m8(wx_i32v, minus_zero_point, n);
+    vfloat32m8_t y_f32v = __riscv_vfcvt_f_x_v_f32m8(wx_i32v, n);
+    y_f32v = __riscv_vfmul_vf_f32m8(y_f32v, scale, n);
+
+    __riscv_vse32_v_f32m8(output, y_f32v, n); output += n;
+  }
+}
diff --git a/src/qs8-f32-vcvt/qs8-f32-vcvt.h b/src/qs8-f32-vcvt/qs8-f32-vcvt.h
index a5f0b4286ce..e7de4a98713 100644
--- a/src/qs8-f32-vcvt/qs8-f32-vcvt.h
+++ b/src/qs8-f32-vcvt/qs8-f32-vcvt.h
@@ -53,6 +53,11 @@ XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_qs8_f32_vcvt_ukernel__wasmsimd_u24, 24, false
 XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_qs8_f32_vcvt_ukernel__wasmsimd_u32, 32, false, int8_t, float, struct xnn_qs8_f32_cvt_params, xnn_init_qs8_f32_cvt_scalar_params)
 #endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
 
+#if XNN_ARCH_RISCV && XNN_ENABLE_RISCV_VECTOR
+XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_qs8_f32_vcvt_ukernel__rvv_u1v, 1, true, int8_t, float, struct xnn_qs8_f32_cvt_params, xnn_init_qs8_f32_cvt_scalar_params)
+XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_qs8_f32_vcvt_ukernel__rvv_u2v, 2, true, int8_t, float, struct xnn_qs8_f32_cvt_params, xnn_init_qs8_f32_cvt_scalar_params)
+#endif // XNN_ARCH_RISCV && XNN_ENABLE_RISCV_VECTOR
+
 XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_qs8_f32_vcvt_ukernel__scalar_u1, 1, false, int8_t, float, struct xnn_qs8_f32_cvt_params, xnn_init_qs8_f32_cvt_scalar_params)
 XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_qs8_f32_vcvt_ukernel__scalar_u2, 2, false, int8_t, float, struct xnn_qs8_f32_cvt_params, xnn_init_qs8_f32_cvt_scalar_params)
 XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_qs8_f32_vcvt_ukernel__scalar_u3, 3, false, int8_t, float, struct xnn_qs8_f32_cvt_params, xnn_init_qs8_f32_cvt_scalar_params)
diff --git a/src/qs8-f32-vcvt/rvv.c.in b/src/qs8-f32-vcvt/rvv.c.in
new file mode 100644
index 00000000000..424fa6dfa11
--- /dev/null
+++ b/src/qs8-f32-vcvt/rvv.c.in
@@ -0,0 +1,52 @@
+// Copyright 2024 Imagination Technologies, inc.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert LMUL in [1, 2]
+$VXINT = {"QS8": "vint", "QU8": "vuint"}[DATATYPE]
+$XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE]
+$XLOAD = {"QS8": "__riscv_vle8_v_i8", "QU8": "__riscv_vle8_v_u8"}[DATATYPE]
+#include <assert.h>
+
+#include <riscv_vector.h>
+
+#include "xnnpack/common.h"
+#include "xnnpack/intrinsics-polyfill.h"
+#include "xnnpack/vcvt.h"
+
+
+void xnn_${DATATYPE.lower()}_f32_vcvt_ukernel__rvv_u${LMUL}v(
+    size_t batch,
+    const ${XINT8_T}* input,
+    float* output,
+    const struct xnn_${DATATYPE.lower()}_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)])
+{
+  assert(batch != 0);
+  assert(batch % sizeof(int8_t) == 0);
+  assert(input != NULL);
+  assert(output != NULL);
+
+  batch >>= XNN_LOG2_SIZEOF_INT8_T;
+
+  const float scale = params->scalar.scale;
+  const int32_t minus_zero_point = -params->scalar.zero_point;
+
+  for (; batch > 0; ) {
+    const int32_t n = __riscv_vsetvl_e8m${LMUL}(batch); batch -= n;
+
+    $if DATATYPE == "QS8":
+      vint8m${LMUL}_t x_i8v = __riscv_vle8_v_i8m${LMUL}(input, n); input += n;
+
+      vint32m${LMUL*4}_t wx_i32v = __riscv_vsext_vf4_i32m${LMUL*4}(x_i8v, n);
+    $else:
+      vuint8m${LMUL}_t x_u8v = __riscv_vle8_v_u8m${LMUL}(input, n); input += n;
+
+      vint32m${LMUL*4}_t wx_i32v = __riscv_vreinterpret_v_u32m${LMUL*4}_i32m${LMUL*4}(__riscv_vzext_vf4_u32m${LMUL*4}(x_u8v, n));
+    wx_i32v = __riscv_vadd_vx_i32m${LMUL*4}(wx_i32v, minus_zero_point, n);
+    vfloat32m${LMUL*4}_t y_f32v = __riscv_vfcvt_f_x_v_f32m${LMUL*4}(wx_i32v, n);
+    y_f32v = __riscv_vfmul_vf_f32m${LMUL*4}(y_f32v, scale, n);
+
+    __riscv_vse32_v_f32m${LMUL*4}(output, y_f32v, n); output += n;
+  }
+}
diff --git a/src/qu8-f32-vcvt/gen/qu8-f32-vcvt-rvv-u1v.c b/src/qu8-f32-vcvt/gen/qu8-f32-vcvt-rvv-u1v.c
new file mode 100644
index 00000000000..e4365f09ac3
--- /dev/null
+++ b/src/qu8-f32-vcvt/gen/qu8-f32-vcvt-rvv-u1v.c
@@ -0,0 +1,47 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-f32-vcvt/rvv.c.in
+// Generator: tools/xngen
+//
+// Copyright 2024 Imagination Technologies, inc.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <riscv_vector.h>
+
+#include "xnnpack/common.h"
+#include "xnnpack/intrinsics-polyfill.h"
+#include "xnnpack/vcvt.h"
+
+
+void xnn_qu8_f32_vcvt_ukernel__rvv_u1v(
+    size_t batch,
+    const uint8_t* input,
+    float* output,
+    const struct xnn_qu8_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)])
+{
+  assert(batch != 0);
+  assert(batch % sizeof(int8_t) == 0);
+  assert(input != NULL);
+  assert(output != NULL);
+
+  batch >>= XNN_LOG2_SIZEOF_INT8_T;
+
+  const float scale = params->scalar.scale;
+  const int32_t minus_zero_point = -params->scalar.zero_point;
+
+  for (; batch > 0; ) {
+    const int32_t n = __riscv_vsetvl_e8m1(batch); batch -= n;
+
+    vuint8m1_t x_u8v = __riscv_vle8_v_u8m1(input, n); input += n;
+
+    vint32m4_t wx_i32v = __riscv_vreinterpret_v_u32m4_i32m4(__riscv_vzext_vf4_u32m4(x_u8v, n));
+    wx_i32v = __riscv_vadd_vx_i32m4(wx_i32v, minus_zero_point, n);
+    vfloat32m4_t y_f32v = __riscv_vfcvt_f_x_v_f32m4(wx_i32v, n);
+    y_f32v = __riscv_vfmul_vf_f32m4(y_f32v, scale, n);
+
+    __riscv_vse32_v_f32m4(output, y_f32v, n); output += n;
+  }
+}
diff --git a/src/qu8-f32-vcvt/gen/qu8-f32-vcvt-rvv-u2v.c b/src/qu8-f32-vcvt/gen/qu8-f32-vcvt-rvv-u2v.c
new file mode 100644
index 00000000000..4ac23cdc184
--- /dev/null
+++ b/src/qu8-f32-vcvt/gen/qu8-f32-vcvt-rvv-u2v.c
@@ -0,0 +1,47 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-f32-vcvt/rvv.c.in +// Generator: tools/xngen +// +// Copyright 2024 Imagination Technologies, inc. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include + +#include + +#include "xnnpack/common.h" +#include "xnnpack/intrinsics-polyfill.h" +#include "xnnpack/vcvt.h" + + +void xnn_qu8_f32_vcvt_ukernel__rvv_u2v( + size_t batch, + const uint8_t* input, + float* output, + const struct xnn_qu8_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(batch != 0); + assert(batch % sizeof(int8_t) == 0); + assert(input != NULL); + assert(output != NULL); + + batch >>= XNN_LOG2_SIZEOF_INT8_T; + + const float scale = params->scalar.scale; + const int32_t minus_zero_point = -params->scalar.zero_point; + + for (; batch > 0; ) { + const int32_t n = __riscv_vsetvl_e8m2(batch); batch -= n; + + vuint8m2_t x_u8v = __riscv_vle8_v_u8m2(input, n); input += n; + + vint32m8_t wx_i32v = __riscv_vreinterpret_v_u32m8_i32m8(__riscv_vzext_vf4_u32m8(x_u8v, n)); + wx_i32v = __riscv_vadd_vx_i32m8(wx_i32v, minus_zero_point, n); + vfloat32m8_t y_f32v = __riscv_vfcvt_f_x_v_f32m8(wx_i32v, n); + y_f32v = __riscv_vfmul_vf_f32m8(y_f32v, scale, n); + + __riscv_vse32_v_f32m8(output, y_f32v, n); output += n; + } +} diff --git a/src/qu8-f32-vcvt/qu8-f32-vcvt.h b/src/qu8-f32-vcvt/qu8-f32-vcvt.h index 17cfe586c7d..3bcbc008f05 100644 --- a/src/qu8-f32-vcvt/qu8-f32-vcvt.h +++ b/src/qu8-f32-vcvt/qu8-f32-vcvt.h @@ -53,6 +53,11 @@ XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_qu8_f32_vcvt_ukernel__wasmsimd_u24, 24, false XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_qu8_f32_vcvt_ukernel__wasmsimd_u32, 32, false, uint8_t, float, struct xnn_qu8_f32_cvt_params, xnn_init_qu8_f32_cvt_scalar_params) #endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD +#if XNN_ARCH_RISCV && XNN_ENABLE_RISCV_VECTOR +XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_qu8_f32_vcvt_ukernel__rvv_u1v, 1, true, uint8_t, float, struct xnn_qu8_f32_cvt_params, xnn_init_qu8_f32_cvt_scalar_params) +XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_qu8_f32_vcvt_ukernel__rvv_u2v, 2, true, uint8_t, float, struct xnn_qu8_f32_cvt_params, xnn_init_qu8_f32_cvt_scalar_params) +#endif // XNN_ARCH_RISCV && XNN_ENABLE_RISCV_VECTOR + XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_qu8_f32_vcvt_ukernel__scalar_u1, 1, false, uint8_t, float, struct xnn_qu8_f32_cvt_params, xnn_init_qu8_f32_cvt_scalar_params) XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_qu8_f32_vcvt_ukernel__scalar_u2, 2, false, uint8_t, float, struct xnn_qu8_f32_cvt_params, xnn_init_qu8_f32_cvt_scalar_params) XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_qu8_f32_vcvt_ukernel__scalar_u3, 3, false, uint8_t, float, struct xnn_qu8_f32_cvt_params, xnn_init_qu8_f32_cvt_scalar_params) From e545f9e0125df6ee4f00df929ffa949280027363 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Tue, 3 Sep 2024 17:30:17 -0700 Subject: [PATCH 02/17] Neon Dot QP8 ukernels, tests, benchmarks --- CMakeLists.txt | 1 + bench/qp8-f32-qb4w-gemm.cc | 55 ++++++ cmake/gen/neondot_aarch64_microkernels.cmake | 2 + gen/neondot_aarch64_microkernels.bzl | 2 + scripts/generate-tests.sh | 1 + src/packing.cc | 59 ++++++ ...b4w-gemm-minmax-1x4c16s2-aarch64-neondot.c | 32 ++++ ...b4w-gemm-minmax-1x8c16s2-aarch64-neondot.c | 32 ++++ src/xnnpack/gemm.h | 34 +++- src/xnnpack/microfnptr.h | 12 ++ src/xnnpack/pack.h | 24 +++ src/xnnpack/packq.h | 31 ++++ test/gemm-microkernel-tester.cc | 175 +++++++++++++++++- test/gemm-microkernel-tester.h | 5 + test/qp8-f32-qb4w-gemm-minmax.cc | 155 ++++++++++++++++ test/qp8-f32-qb4w-gemm-minmax.yaml 
| 19 ++ 16 files changed, 629 insertions(+), 10 deletions(-) create mode 100644 bench/qp8-f32-qb4w-gemm.cc create mode 100644 src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x4c16s2-aarch64-neondot.c create mode 100644 src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x8c16s2-aarch64-neondot.c create mode 100644 test/qp8-f32-qb4w-gemm-minmax.cc create mode 100644 test/qp8-f32-qb4w-gemm-minmax.yaml diff --git a/CMakeLists.txt b/CMakeLists.txt index 7c56cbe3019..b682be0f020 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1615,6 +1615,7 @@ IF(XNNPACK_BUILD_TESTS) qd8-f32-qc4w-gemm-minmax qd8-f32-qc8w-igemm-minmax qp8-f32-qc4w-gemm-minmax + qp8-f32-qb4w-gemm-minmax qs8-qc8w-gemm-minmax-fp32 qs8-qc8w-igemm-minmax-fp32 qu8-gemm-minmax-fp32 diff --git a/bench/qp8-f32-qb4w-gemm.cc b/bench/qp8-f32-qb4w-gemm.cc new file mode 100644 index 00000000000..193f8ac48f8 --- /dev/null +++ b/bench/qp8-f32-qb4w-gemm.cc @@ -0,0 +1,55 @@ +// Copyright 2023 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. +// +// Auto-generated file. Do not edit! +// Specification: test/qp8-f32-qb4w-gemm-minmax.yaml +// Generator: tools/generate-gemm-test.py + +#include +#include "bench/gemm-benchmark.h" +#include "bench/utils.h" +#include "xnnpack/common.h" +#include "xnnpack/gemm.h" +#include "xnnpack/isa-checks.h" +#include "xnnpack/microfnptr.h" +#include "xnnpack/microparams-init.h" +#include "xnnpack/pack.h" +#include "xnnpack/packw.h" + + +#if XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 + #if XNN_ENABLE_KLEIDIAI + static void qp8_f32_qb4w_gemm_minmax_ukernel_1x4c16s2__aarch64_neondot(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x4c16s2__aarch64_neondot, + xnn_init_f32_qb4w_minmax_scalar_params, + xnn_pack_kai_qb4_weights_and_biases, + xnn_packed_stride_kai_qb4_weights_and_biases, + /*mr=*/1, /*nr=*/4, /*kr=*/16, /*sr=*/2, + /*mr_packed=*/1, + benchmark::utils::CheckNEONDOT); + } + + BENCHMARK_GEMM_BL(qp8_f32_qb4w_gemm_minmax_ukernel_1x4c16s2__aarch64_neondot) + + static void qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot, + xnn_init_f32_qb4w_minmax_scalar_params, + xnn_pack_kai_qb4_weights_and_biases, + xnn_packed_stride_kai_qb4_weights_and_biases, + /*mr=*/1, /*nr=*/8, /*kr=*/16, /*sr=*/2, + /*mr_packed=*/1, + benchmark::utils::CheckNEONDOT); + } + + BENCHMARK_GEMM_BL(qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot) + #endif // XNN_ENABLE_KLEIDIAI +#endif // XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 + + +#ifndef XNNPACK_BENCHMARK_NO_MAIN +BENCHMARK_MAIN(); +#endif diff --git a/cmake/gen/neondot_aarch64_microkernels.cmake b/cmake/gen/neondot_aarch64_microkernels.cmake index e782a394375..e25d17f73c3 100644 --- a/cmake/gen/neondot_aarch64_microkernels.cmake +++ b/cmake/gen/neondot_aarch64_microkernels.cmake @@ -10,6 +10,7 @@ SET(PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS + src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x8c16s2-aarch64-neondot.c src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-1x8c16s2-aarch64-neondot.c) SET(NON_PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS @@ -17,6 +18,7 @@ SET(NON_PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c8-minmax-aarch64-neondot-ld128.c src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x8c8-minmax-aarch64-neondot-ld128.c 
src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c8-minmax-aarch64-neondot-ld128.c + src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x4c16s2-aarch64-neondot.c src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-1x4c16s2-aarch64-neondot.c src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-aarch64-neondot-ld128.c src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-aarch64-neondot-ld128.c diff --git a/gen/neondot_aarch64_microkernels.bzl b/gen/neondot_aarch64_microkernels.bzl index 740847dfd76..6d2c41fe275 100644 --- a/gen/neondot_aarch64_microkernels.bzl +++ b/gen/neondot_aarch64_microkernels.bzl @@ -6,6 +6,7 @@ Auto-generated file. Do not edit! """ PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS = [ + "src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x8c16s2-aarch64-neondot.c", "src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-1x8c16s2-aarch64-neondot.c", ] @@ -14,6 +15,7 @@ NON_PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS = [ "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c8-minmax-aarch64-neondot-ld128.c", "src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x8c8-minmax-aarch64-neondot-ld128.c", "src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c8-minmax-aarch64-neondot-ld128.c", + "src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x4c16s2-aarch64-neondot.c", "src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-1x4c16s2-aarch64-neondot.c", "src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-aarch64-neondot-ld128.c", "src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-aarch64-neondot-ld128.c", diff --git a/scripts/generate-tests.sh b/scripts/generate-tests.sh index 5d295936fc3..29dabdff3cf 100755 --- a/scripts/generate-tests.sh +++ b/scripts/generate-tests.sh @@ -48,6 +48,7 @@ tools/generate-gemm-test.py --spec test/qd8-f32-qc4w-gemm-minmax.yaml --output-t tools/generate-gemm-test.py --spec test/qd8-f32-qb4w-gemm-minmax.yaml --output-test test/qd8-f32-qb4w-gemm-minmax.cc --output-bench bench/qd8-f32-qb4w-gemm.cc & tools/generate-gemm-test.py --spec test/qp8-f32-qc4w-gemm-minmax.yaml --output-test test/qp8-f32-qc4w-gemm-minmax.cc --output-bench bench/qp8-f32-qc4w-gemm.cc & +tools/generate-gemm-test.py --spec test/qp8-f32-qb4w-gemm-minmax.yaml --output-test test/qp8-f32-qb4w-gemm-minmax.cc --output-bench bench/qp8-f32-qb4w-gemm.cc & tools/generate-gemm-test.py --spec test/qs8-qc8w-gemm-minmax-fp32.yaml --output-test test/qs8-qc8w-gemm-minmax-fp32.cc --output-test test/qs8-qc8w-gemm-minmax-fp32-2.cc --output-test test/qs8-qc8w-gemm-minmax-fp32-3.cc --output-bench bench/qs8-qc8w-gemm-fp32.cc & diff --git a/src/packing.cc b/src/packing.cc index e584a3f2147..ee737b5f3a3 100644 --- a/src/packing.cc +++ b/src/packing.cc @@ -24,6 +24,7 @@ #if XNN_ENABLE_KLEIDIAI #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h" #endif // XNN_ENABLE_KLEIDIAI #include @@ -1678,6 +1679,64 @@ void xnn_pack_kai_qs4_weights_and_biases( &kai_params); } } + +size_t xnn_packed_stride_kai_qb4_weights_and_biases( + const struct xnn_gemm_config* gemm_config, size_t k, size_t block_size, + size_t extra_bytes) { + const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr; + + const size_t kai_num_bytes_sum_rhs = sizeof(float); + const size_t kai_num_bytes_bias = sizeof(float); + // perhaps derive Bf16 from gemm-config? 
+ // This needs to be updated in the kleidi branch to be in header + // return kai_rhs_packed_stride(k, /*nr=*/1, kr, block_size, Bf16); + const size_t num_bytes_multiplier_rhs = sizeof(uint16_t); + const size_t num_blocks_per_row = k/block_size; + const size_t num_bytes_per_block = (block_size / 2) + num_bytes_multiplier_rhs; + return 1 * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias); +} + +void xnn_pack_kai_qb4_weights_and_biases( + uint32_t flags, const struct xnn_gemm_config* gemm_config, + size_t input_channels, size_t output_channels, size_t groups, + size_t block_size, const void* accumulator_init, const void* weights, + xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0, + size_t extra_data0_element_size, + xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1, + size_t extra_data1_element_size, void* packed_weights_ptr, + const void* params) { + const uint32_t nr = gemm_config->nr; + const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr; + const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr; + const struct xnn_qs8_qc4w_packing_params* xnn_params = + reinterpret_cast(params); + + if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) { + // no nxk as of now + xnn_log_fatal( + "KleidiAI does not currently have gio packing routine" + ); + } else { + // Repack the packing params. + struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params kai_params; + kai_params.lhs_zero_point = xnn_params->input_zero_point; + kai_params.rhs_zero_point = xnn_params->kernel_zero_point; + kai_params.scale_dt = kai_datatype::kai_dt_bf16; + size_t rhs_stride = round_up_po2(input_channels, 2) / 2; + size_t blocks_per_row = (input_channels + block_size - 1) / block_size; + kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + groups, output_channels, input_channels, nr, kr, sr, + /*bl=*/block_size, + /*rhs=*/reinterpret_cast(weights), + /*rhs_stride=*/rhs_stride, + /*bias=*/reinterpret_cast(extra_data0), + /*scale=*/reinterpret_cast(extra_data1), + /*scale_stride=*/blocks_per_row * sizeof(uint16_t), + /*rhs_packed*/packed_weights_ptr, + /*extra_bytes=*/0, + &kai_params); + } +} #endif // XNN_ENABLE_KLEIDIAI void xnn_pack_f32_qs8w_gemm_gio_w( diff --git a/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x4c16s2-aarch64-neondot.c b/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x4c16s2-aarch64-neondot.c new file mode 100644 index 00000000000..f505c42b583 --- /dev/null +++ b/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x4c16s2-aarch64-neondot.c @@ -0,0 +1,32 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include "xnnpack/log.h" +#include "xnnpack/microparams.h" + +#if XNN_ENABLE_KLEIDIAI +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" +#endif // XNN_ENABLE_KLEIDIAI + +// Wraps the +// `kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod` GEMM +// microkernel with a name that is compatible with our tooling. 
+void xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x4c16s2__aarch64_neondot( + size_t m, size_t n, size_t k, const void* lhs_packed, + const void* rhs_packed, float* dst, size_t dst_stride_row, + size_t dst_stride_col, + const struct xnn_f32_qb4w_minmax_params + minmax_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]) { +#if XNN_ENABLE_KLEIDIAI + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( + m, n, k, minmax_params->scalar.blocksize, lhs_packed, rhs_packed, dst, dst_stride_row, dst_stride_col, + minmax_params->scalar.min, minmax_params->scalar.max); +#else + xnn_log_fatal( + "Calling KleidiAI microkernel wrapper, but XNNPACK was compiled without " + "`XNN_ENABLE_KLEIDIAI`."); +#endif // XNN_ENABLE_KLEIDIAI +} diff --git a/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x8c16s2-aarch64-neondot.c b/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x8c16s2-aarch64-neondot.c new file mode 100644 index 00000000000..9fa6c95889b --- /dev/null +++ b/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x8c16s2-aarch64-neondot.c @@ -0,0 +1,32 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include "xnnpack/log.h" +#include "xnnpack/microparams.h" + +#if XNN_ENABLE_KLEIDIAI +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h" +#endif // XNN_ENABLE_KLEIDIAI + +// Wraps the +// `kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod` GEMM +// microkernel with a name that is compatible with our tooling. +void xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot( + size_t m, size_t n, size_t k, const void* lhs_packed, + const void* rhs_packed, float* dst, size_t dst_stride_row, + size_t dst_stride_col, + const struct xnn_f32_qb4w_minmax_params + minmax_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]) { +#if XNN_ENABLE_KLEIDIAI + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( + m, n, k, minmax_params->scalar.blocksize, lhs_packed, rhs_packed, dst, dst_stride_row, dst_stride_col, + minmax_params->scalar.min, minmax_params->scalar.max); +#else + xnn_log_fatal( + "Calling KleidiAI microkernel wrapper, but XNNPACK was compiled without " + "`XNN_ENABLE_KLEIDIAI`."); +#endif // XNN_ENABLE_KLEIDIAI +} diff --git a/src/xnnpack/gemm.h b/src/xnnpack/gemm.h index 469bf0fcf13..29a0a834221 100644 --- a/src/xnnpack/gemm.h +++ b/src/xnnpack/gemm.h @@ -2221,6 +2221,17 @@ DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_u DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_12x16c8__avx512vnnigfni_prfm) DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni_prfm) +DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x4c8__sse41_ld128) +DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x4c8__sse41_ld128) +DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x4c8__sse41_ld128) +DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x4c8__sse41_ld128) + +DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x4c8__sse41_ld64) +DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x4c8__sse41_ld64) 
+DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x4c8__sse41_ld64) +DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x4c8__sse41_ld64) + + #define DECLARE_QP8_F32_QC4W_GEMM_MINMAX_UKERNEL_FUNCTION(fn_name) \ XNN_INTERNAL void fn_name( \ size_t m, \ @@ -2241,16 +2252,21 @@ DECLARE_QP8_F32_QC4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qp8_f32_qc4w_gemm_minmax_u DECLARE_QP8_F32_QC4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qp8_f32_qc4w_gemm_minmax_ukernel_8x4c16s2__neoni8mm_mstep2) DECLARE_QP8_F32_QC4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qp8_f32_qc4w_gemm_minmax_ukernel_8x8c16s2__neoni8mm_mstep2) -DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x4c8__sse41_ld128) -DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x4c8__sse41_ld128) -DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x4c8__sse41_ld128) -DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x4c8__sse41_ld128) - -DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x4c8__sse41_ld64) -DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x4c8__sse41_ld64) -DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x4c8__sse41_ld64) -DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x4c8__sse41_ld64) +#define DECLARE_QP8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(fn_name) \ + XNN_INTERNAL void fn_name( \ + size_t m, \ + size_t n, \ + size_t k, \ + const void* lhs_packed, \ + const void* rhs_packed, \ + float* dst, \ + size_t dst_stride_row, \ + size_t dst_stride_col, \ + const struct xnn_f32_qb4w_minmax_params \ + minmax_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]); +DECLARE_QP8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x4c16s2__aarch64_neondot) +DECLARE_QP8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot) #define DECLARE_QD8_F16_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(fn_name) \ diff --git a/src/xnnpack/microfnptr.h b/src/xnnpack/microfnptr.h index a6d41063a59..26e98a22aaf 100644 --- a/src/xnnpack/microfnptr.h +++ b/src/xnnpack/microfnptr.h @@ -339,6 +339,18 @@ typedef void (*xnn_qp8_f32_qc4w_gemm_minmax_ukernel_fn)( union xnn_f32_minmax_params minmax_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]); +typedef void (*xnn_qp8_f32_qb4w_gemm_minmax_ukernel_fn)( + size_t m, + size_t n, + size_t k, + const void* lhs_packed, + const void* rhs_packed, + float* dst, + size_t dst_stride_row, + size_t dst_stride_col, + const struct xnn_f32_qb4w_minmax_params + minmax_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]); + // GEMMINC: GEMM INCremental with Min+Max activation typedef void (*xnn_f32_gemminc_minmax_ukernel_fn)( diff --git a/src/xnnpack/pack.h b/src/xnnpack/pack.h index a86c84cfb21..5d50d12cb79 100644 --- a/src/xnnpack/pack.h +++ b/src/xnnpack/pack.h @@ -483,6 +483,30 @@ XNN_INTERNAL size_t xnn_packed_stride_kai_qs4_weights_and_biases( size_t k, // size_t k_stride, // size_t extra_bytes); + +XNN_INTERNAL void xnn_pack_kai_qb4_weights_and_biases( + uint32_t flags, // + const struct xnn_gemm_config* gemm_config, // + size_t input_channels, // + size_t output_channels, // + size_t groups, // + size_t block_size, // + const void* accumulator_init, // + const void* weights, // + xnn_init_scale_params_fn init_extra_data0_fn, // + const void* 
extra_data0, // + size_t extra_data0_element_size, // + xnn_init_scale_params_fn init_extra_data1_fn, // + const void* extra_data1, // + size_t extra_data1_element_size, // + void* packed_weights_ptr, // + const void* params); + +XNN_INTERNAL size_t xnn_packed_stride_kai_qb4_weights_and_biases( + const struct xnn_gemm_config* gemm_config, // + size_t k, // + size_t block_size, // + size_t extra_bytes); #endif // XNN_ENABLE_KLEIDIAI XNN_INTERNAL void xnn_pack_qs8_to_qu8_gemm_gio_w( diff --git a/src/xnnpack/packq.h b/src/xnnpack/packq.h index 2328ef07896..abc93399eb6 100644 --- a/src/xnnpack/packq.h +++ b/src/xnnpack/packq.h @@ -86,6 +86,37 @@ XNN_INLINE static int8_t xnn_x8_packq_f32qp8_get_quantized( return *dst_ptr; } +XNN_INLINE static float xnn_x8_packq_f32qp8_get_recip_scale( + size_t m_idx, const int8_t* lhs_packed, size_t k, + size_t mr_packed, size_t kr, size_t sr) { + const size_t k_internal = k_roundedup(k, kr, sr); + const size_t dst_x = (m_idx % mr_packed); + const size_t packed_offset = + xnn_x8_packq_f32qp8_packed_offset(m_idx, k, mr_packed, kr, sr); + + // Get the quantization parameters. + const int8_t* dst_ptr = lhs_packed + packed_offset + mr_packed * k_internal; + dst_ptr += dst_x * sizeof(int32_t); + dst_ptr += mr_packed * sizeof(float); + const float recip_scale = *(const float*)dst_ptr; + return recip_scale; +} + +XNN_INLINE static float xnn_x8_packq_f32qp8_get_neg_nudged_zp( + size_t m_idx, const int8_t* lhs_packed, size_t k, + size_t mr_packed, size_t kr, size_t sr) { + const size_t k_internal = k_roundedup(k, kr, sr); + const size_t dst_x = (m_idx % mr_packed); + const size_t packed_offset = + xnn_x8_packq_f32qp8_packed_offset(m_idx, k, mr_packed, kr, sr); + + // Get the quantization parameters. + const int8_t* dst_ptr = lhs_packed + packed_offset + mr_packed * k_internal; + dst_ptr += dst_x * sizeof(int32_t); + const int32_t neg_nudged_zero_point = *(const int32_t*)dst_ptr; + return neg_nudged_zero_point; +} + XNN_INLINE static float xnn_x8_packq_f32qp8_get_dequantized( size_t m_idx, size_t k_idx, const int8_t* lhs_packed, size_t k, size_t mr_packed, size_t kr, size_t sr) { diff --git a/test/gemm-microkernel-tester.cc b/test/gemm-microkernel-tester.cc index 0fbf14d8c8e..929c4a98381 100644 --- a/test/gemm-microkernel-tester.cc +++ b/test/gemm-microkernel-tester.cc @@ -3,6 +3,7 @@ #include #include +#include "kai/kai_common.h" #include #include #include @@ -1838,6 +1839,179 @@ void GemmMicrokernelTester::Test( } } +void GemmMicrokernelTester::Test( + xnn_qp8_f32_qb4w_gemm_minmax_ukernel_fn gemm, + xnn_init_f32_qb4w_minmax_params_fn init_minmax_params, + xnn_pack_weights_and_biases_fn pack, + xnn_packed_stride_weights_and_biases_fn packed_stride){ + ASSERT_LE(m(), mr()); + + xnnpack::ReplicableRandomDevice rng; + auto f32rng = std::bind(std::uniform_real_distribution(-5.f, 5.f), + std::ref(rng)); + auto scalerng = std::bind(std::uniform_real_distribution(0.5f, 2.f), + std::ref(rng)); + auto w8rng = std::bind(std::uniform_int_distribution( + 0, std::numeric_limits::max()), + std::ref(rng)); + + const size_t k2 = round_up_po2(k(), 2); // tester assumes byte aligned rows + + const size_t packed_k2 = round_up_po2(k(), kr() * sr()); // 2 blocks for nibbles + const size_t packed_k_bytes = (packed_k2 + 1)/ 2; + const size_t num_blocks = packed_k2 / bl(); + + std::vector input_f32(m() * k2); + std::vector b(n() * k2 / 2); + std::vector bias(n(), 0.0f); + std::vector kernel_scale2d(n() * packed_k2 / bl()); + std::vector c((mr() - 1) * cm_stride() + + ((n() - 1) / nr()) * 
cn_stride() + (n() - 1) % nr() + 1); + std::vector acc(m() * n()); + std::vector c_ref(m() * n(), 0); + + // Create a fake `gemm_config` for the packing functions. + struct xnn_gemm_config gemm_config; + gemm_config.mr = static_cast(mr()); + gemm_config.mr_packed = static_cast(mr_packed()); + gemm_config.nr = static_cast(nr()); + gemm_config.log2_kr = static_cast(31 - math_clz_nonzero_u32(kr())); + gemm_config.log2_sr = static_cast(31 - math_clz_nonzero_u32(sr())); + + const size_t packed_w_stride = + packed_stride(&gemm_config, k2, /*k_stride=*/bl(), /*extra_bytes=*/0); + const size_t packed_w_size = packed_w_stride * round_up(n(), nr()); + std::vector> packed_w(packed_w_size); + + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(input_f32.begin(), input_f32.end(), std::ref(f32rng)); + std::generate(b.begin(), b.end(), std::ref(w8rng)); + std::generate(bias.begin(), bias.end(), std::ref(f32rng)); + std::generate(kernel_scale2d.begin(), kernel_scale2d.end(), [&]() { return math_cvt_bf16_fp32(scalerng()); }); + std::fill(c.begin(), c.end(), nanf("")); + std::fill(packed_w.begin(), packed_w.end(), 0); + + + // Quantize the left-hand operand. + const size_t input_packed_size = + xnn_x8_packq_f32qp8_packed_size(m(), k2, mr_packed(), kr(), sr()); + std::vector input_qp8(input_packed_size); + xnn_x8_packq_f32qp8_ukernel__scalar_u1( + m(), k2, mr_packed(), kr(), sr(), + /*m_idx_start=*/0, reinterpret_cast(input_f32.data()), + /*lhs_stride=*/k2 * sizeof(float), + input_qp8.data() + ); + + + // RHS packing. + struct xnn_qs8_qc4w_packing_params params; + params.input_zero_point = 1; + params.kernel_zero_point = b_zero_point(); + pack(/*flags=*/0, &gemm_config, k2, n(), + /*groups=*/1, /*k_stride=*/bl(), + /*accumulator_init=*/nullptr, + /*weights=*/b.data(), + /*int_extra_data0_fn=*/nullptr, + /*extra_data0=*/nullptr, + /*extra_data0_size=*/0, + /*init_extra_data1_fn=*/ + nullptr, + /*extra_data1=*/kernel_scale2d.data(), + /*extra_data1_size=*/sizeof(float), + /*packed_weights_ptr=*/packed_w.data(), ¶ms); + + size_t stride = nr() * (packed_k_bytes + /* scales= */ num_blocks * sizeof(uint16_t) + /* ksum= */ sizeof(float) + /* bias= */ sizeof(float)); + size_t block_stride = (bl() / 2 + sizeof(uint16_t)) * nr(); + size_t start_offset = nr() * (packed_k_bytes / num_blocks + sizeof(float)); + uintptr_t start = (uintptr_t) packed_w.data() + stride - sizeof(float) * nr(); + + xnn_init_qs8_qc8w_scale_fp32_params( + n(), nr(), nr(), + stride, + stride, + 0, + bias.data(), + (void*) start); + + // Compute 32-bit results and output quantization arguments. + std::fill(c_ref.begin(), c_ref.end(), 0); + for (size_t m_index = 0; m_index < m(); m_index++) { + for (size_t n_index = 0; n_index < n(); n_index++) { + float kfsum = 0.0; + for (size_t bl_index=0; bl_index < num_blocks; ++bl_index) { + int32_t ksum = 0; + int32_t c_ref_acc = 0; + for (size_t kr_index = 0; kr_index < bl(); kr_index++) { + const size_t k_index = bl_index * bl() + kr_index; + const size_t nb_index = (n_index * k2 + k_index) / 2; + const int32_t bv = int32_t((k_index % 2 == 0) ? 
(b[nb_index] & UINT8_C(0xF)) : (b[nb_index] >> 4)) - b_zero_point(); + ksum += bv; + c_ref_acc += int32_t(xnn_x8_packq_f32qp8_get_quantized(m_index, k_index, input_qp8.data(), + k2, mr_packed(), kr(), sr())) * int32_t(bv); + } + size_t scale_index = n_index * num_blocks + bl_index; + float scale = math_cvt_fp32_bf16(kernel_scale2d[scale_index]); + c_ref[m_index * n() + n_index] += c_ref_acc * scale; + kfsum += scale * ksum; + } + float inv_scale = xnn_x8_packq_f32qp8_get_recip_scale(m_index, input_qp8.data(), k2, mr_packed(), kr(), sr()); + int32_t neg_nudged_zero_point= xnn_x8_packq_f32qp8_get_neg_nudged_zp(m_index, input_qp8.data(), k2, mr_packed(), kr(), sr()); + c_ref[m_index * n() + n_index] += (neg_nudged_zero_point * kfsum); + c_ref[m_index * n() + n_index] *= inv_scale; + c_ref[m_index * n() + n_index] += bias[n_index]; + } + } + + const float accumulated_min = + *std::min_element(c_ref.cbegin(), c_ref.cend()); + const float accumulated_max = + *std::max_element(c_ref.cbegin(), c_ref.cend()); + const float c_min = + qmin() == std::numeric_limits::min() + ? -std::numeric_limits::infinity() + : accumulated_min + (accumulated_max - accumulated_min) / 255.0f * + static_cast(qmin()); + const float c_max = + qmax() == std::numeric_limits::max() + ? std::numeric_limits::infinity() + : accumulated_max - (accumulated_max - accumulated_min) / 255.0f * + static_cast(255 - qmax()); + + // Prepare parameters. + xnn_f32_qb4w_minmax_params minmax_params; + init_minmax_params(&minmax_params, c_min, c_max, 8, bl()); + + for (size_t m_index = 0; m_index < m(); m_index++) { + for (size_t n_index = 0; n_index < n(); n_index++) { + c_ref[m_index * n() + n_index] = + std::max(std::min(c_ref[m_index * n() + n_index], c_max), c_min); + } + } + + gemm(m(), n(), k2, input_qp8.data(), packed_w.data(), c.data(), + cm_stride() * sizeof(float), sizeof(float), &minmax_params); + + for (size_t i = 0; i < m(); i++) { + for (size_t j = 0; j < n(); j++) { + // Extract tolerance into variable to workaround test failures on Linux + // AArch64. 
+ const float tolerance = + std::max(1.1e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f); + ASSERT_NEAR(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], + c_ref[i * n() + j], tolerance) + << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] + << " (accumulator = " << acc[i * n() + j] << "), optimized = " + << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] + << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() + << ", M x N x K = " << m() << " x " << n() << " x " << k2 + << ", cn_stride = " << cn_stride() + << ", cm_stride = " << cm_stride(); + } + } + } +} + void GemmMicrokernelTester::Test( xnn_qs8_gemm_minmax_ukernel_fn gemm, xnn_init_qs8_conv_minmax_params_fn init_params, @@ -3499,4 +3673,3 @@ void GemmMicrokernelTester::Test( } } } - diff --git a/test/gemm-microkernel-tester.h b/test/gemm-microkernel-tester.h index 64d4702062f..f402a590c26 100644 --- a/test/gemm-microkernel-tester.h +++ b/test/gemm-microkernel-tester.h @@ -398,6 +398,11 @@ class GemmMicrokernelTester { xnn_pack_weights_and_biases_fn pack, xnn_packed_stride_weights_and_biases_fn packed_stride); + void Test(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_fn gemm, + xnn_init_f32_qb4w_minmax_params_fn init_minmax_params, + xnn_pack_weights_and_biases_fn pack, + xnn_packed_stride_weights_and_biases_fn packed_stride); + private: size_t mr_{1}; size_t nr_{1}; diff --git a/test/qp8-f32-qb4w-gemm-minmax.cc b/test/qp8-f32-qb4w-gemm-minmax.cc new file mode 100644 index 00000000000..2dc63d41508 --- /dev/null +++ b/test/qp8-f32-qb4w-gemm-minmax.cc @@ -0,0 +1,155 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// Copyright 2019 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. +// +// Auto-generated file. Do not edit! 
+// Specification: test/qp8-f32-qb4w-gemm-minmax.yaml +// Generator: tools/generate-gemm-test.py + +#include +#include +#include +#include + +#include +#include "xnnpack/allocator.h" +#include "xnnpack/common.h" +#include "xnnpack/gemm.h" +#include "xnnpack/igemm.h" +#include "xnnpack/isa-checks.h" +#include "xnnpack/microparams-init.h" +#include "xnnpack/pack.h" +#include "xnnpack/packw.h" +#include "xnnpack/ppmm.h" +#include "xnnpack/requantization.h" +#include "gemm-microkernel-tester.h" +#include "next_prime.h" + +namespace { + +std::vector CreateTests1( + size_t k_block, size_t adj_k_block, + size_t mr, size_t nr, size_t kr, size_t sr, + size_t mr_packed, + bool is_igemm, + std::function test_func, + std::function isa_check = nullptr) { + std::string kbs = std::to_string(k_block); + std::string kb2s = std::to_string(k_block * 2); + std::string akbs = std::to_string(adj_k_block); + std::string nrs = std::to_string(nr); + + const GemmMicrokernelTester tester = GemmMicrokernelTester() + .mr(mr).nr(nr).kr(kr).sr(sr).mr_packed(mr_packed); + + std::vector gemm_tests; + gemm_tests.reserve(42); + + gemm_tests.push_back(GemmTestParams( + "k_eq_" + kbs, + tester.clone() + .m(mr).n(nr).k(k_block) + .b_zero_point(8) + .bl(32) + , test_func, isa_check)); + if (!is_igemm) { + gemm_tests.push_back(GemmTestParams( + "k_eq_" + kbs + "_strided_a", + tester.clone() + .m(mr).n(nr).k(k_block) + .a_stride(xnnpack::NextPrime(k_block + 1)) + .b_zero_point(8) + .bl(32) + , test_func, isa_check)); + } + gemm_tests.push_back(GemmTestParams( + "k_eq_" + kbs + "_subtile", + tester.clone() + .k(k_block).iterations(1) + .b_zero_point(8) + .bl(32) + , test_func, isa_check) + .loop_n(1, nr) + .loop_m(1, mr)); + gemm_tests.push_back(GemmTestParams( + "k_eq_" + kbs + "_subtile_m", + tester.clone() + .n(nr).k(k_block).iterations(1) + .b_zero_point(8) + .bl(32) + , test_func, isa_check) + .loop_m(1, mr)); + gemm_tests.push_back(GemmTestParams( + "k_eq_" + kbs + "_subtile_n", + tester.clone() + .m(mr).k(k_block).iterations(1) + .b_zero_point(8) + .bl(32) + , test_func, isa_check) + .loop_n(1, nr)); + gemm_tests.push_back(GemmTestParams( + "bl", + tester.clone() + .m(mr).n(nr).k(k_block * 12) + .b_zero_point(8) + , test_func, isa_check) + .loop_k(k_block, k_block * 12, k_block, LoopStepType::Linear) + .loop_bl(32, k_block * 32, 32)); + + return gemm_tests; +} + +} // namespace + + +#if XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 + #if XNN_ENABLE_KLEIDIAI + INSTANTIATE_TEST_SUITE_P( + QP8_F32_QB4W_GEMM_MINMAX_1X4C16S2__AARCH64_NEONDOT, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/32, + /*adj_k_block=*/32, + /*mr=*/1, /*nr=*/4, /*kr=*/16, /*sr=*/2, + /*mr_packed=*/1, + /*is_igemm=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x4c16s2__aarch64_neondot, + xnn_init_f32_qb4w_minmax_scalar_params, + xnn_pack_kai_qb4_weights_and_biases, + xnn_packed_stride_kai_qb4_weights_and_biases); + }, + []() { + TEST_REQUIRES_ARM_NEON_DOT; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + + INSTANTIATE_TEST_SUITE_P( + QP8_F32_QB4W_GEMM_MINMAX_1X8C16S2__AARCH64_NEONDOT, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/32, + /*adj_k_block=*/32, + /*mr=*/1, /*nr=*/8, /*kr=*/16, /*sr=*/2, + /*mr_packed=*/1, + /*is_igemm=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot, + xnn_init_f32_qb4w_minmax_scalar_params, + xnn_pack_kai_qb4_weights_and_biases, + 
xnn_packed_stride_kai_qb4_weights_and_biases); + }, + []() { + TEST_REQUIRES_ARM_NEON_DOT; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + #endif // XNN_ENABLE_KLEIDIAI +#endif // XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 diff --git a/test/qp8-f32-qb4w-gemm-minmax.yaml b/test/qp8-f32-qb4w-gemm-minmax.yaml new file mode 100644 index 00000000000..5c7c3b41d1d --- /dev/null +++ b/test/qp8-f32-qb4w-gemm-minmax.yaml @@ -0,0 +1,19 @@ +# Copyright 2023 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Arm KleidiAI kernels +- name: xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x4c16s2__aarch64_neondot + init: xnn_init_f32_qb4w_minmax_scalar_params + pack: xnn_pack_kai_qb4_weights_and_biases + packed-stride: xnn_packed_stride_kai_qb4_weights_and_biases + k-block: 32 + cpp-check: XNN_ENABLE_KLEIDIAI + +- name: xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot + init: xnn_init_f32_qb4w_minmax_scalar_params + pack: xnn_pack_kai_qb4_weights_and_biases + packed-stride: xnn_packed_stride_kai_qb4_weights_and_biases + k-block: 32 + cpp-check: XNN_ENABLE_KLEIDIAI From 4649ee20279446fb7c68c5f1d6b5a94af15a86c5 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Tue, 3 Sep 2024 17:37:34 -0700 Subject: [PATCH 03/17] Wired QP8_QB4W Operator APIs --- CMakeLists.txt | 1 + src/configs/gemm-config.c | 38 +++++ src/enums/operator-type.c | 13 +- src/enums/operator-type.yaml | 2 + src/operator-run.c | 27 +++ src/operators/fully-connected-nc.c | 219 +++++++++++++++++++++++-- src/xnnpack/compute.h | 11 ++ src/xnnpack/config.h | 10 ++ src/xnnpack/internal.h | 27 +++ src/xnnpack/microfnptr.h | 6 +- src/xnnpack/operator-type.h | 1 + test/fully-connected-nc.cc | 151 ++++++++++++++++- test/fully-connected-operator-tester.h | 211 ++++++++++++++++++++++++ 13 files changed, 696 insertions(+), 21 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b682be0f020..6a7c7a2358f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1269,6 +1269,7 @@ IF(XNNPACK_BUILD_TESTS) IF(XNNPACK_BUILD_LIBRARY) # ---[ Launch heavy tests first. 
SET(LIBRARY_SHARDED_TESTS + fully-connected-nc batch-matrix-multiply-nc batch-matrix-multiply deconvolution-nhwc diff --git a/src/configs/gemm-config.c b/src/configs/gemm-config.c index 6561a32b260..1f5c732c04f 100644 --- a/src/configs/gemm-config.c +++ b/src/configs/gemm-config.c @@ -40,6 +40,7 @@ static struct xnn_gemm_config qd8_f32_qb4w_gemm_config = {0}; static struct xnn_gemm_config qd8_f32_qc4w_gemm_config = {0}; static struct xnn_gemm_config qd8_f32_qc8w_gemm_config = {0}; static struct xnn_gemm_config qp8_f32_qc4w_gemm_config = {0}; +static struct xnn_gemm_config qp8_f32_qb4w_gemm_config = {0}; static struct xnn_gemm_config qs8_qc8w_gemm_config = {0}; static struct xnn_gemm_config qu8_gemm_config = {0}; @@ -55,6 +56,7 @@ XNN_INIT_ONCE_GUARD(qd8_f32_qb4w_gemm); XNN_INIT_ONCE_GUARD(qd8_f32_qc4w_gemm); XNN_INIT_ONCE_GUARD(qd8_f32_qc8w_gemm); XNN_INIT_ONCE_GUARD(qp8_f32_qc4w_gemm); +XNN_INIT_ONCE_GUARD(qp8_f32_qb4w_gemm); XNN_INIT_ONCE_GUARD(qs8_qc8w_gemm); XNN_INIT_ONCE_GUARD(qu8_gemm); @@ -1731,6 +1733,28 @@ static void init_qp8_f32_qc4w_gemm_config(void) { #endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI } +static void init_qp8_f32_qb4w_gemm_config(void) { +#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI + const struct xnn_hardware_config* hardware_config = + xnn_init_hardware_config(); + assert(hardware_config != NULL); + if (XNN_ENABLE_ARM_DOTPROD && hardware_config->use_arm_neon_dot) { +#if XNN_ENABLE_ARM_DOTPROD + qp8_f32_qb4w_gemm_config.minmax.qp8gemm_bl[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_qp8gemm_bl_ukernel(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot); + qp8_f32_qb4w_gemm_config.init.f32_qb4w = xnn_init_f32_qb4w_minmax_scalar_params; + qp8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_kai_qb4_weights_and_biases; + qp8_f32_qb4w_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_kai_qb4_weights_and_biases; + qp8_f32_qb4w_gemm_config.mr = 1; + qp8_f32_qb4w_gemm_config.nr = 8; + qp8_f32_qb4w_gemm_config.log2_kr = 4; + qp8_f32_qb4w_gemm_config.log2_sr = 1; + qp8_f32_qb4w_gemm_config.planes = 2; + qp8_f32_qb4w_gemm_config.mr_packed = 1; +#endif // XNN_ENABLE_ARM_DOTPROD + } +#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI +} + static void init_qd8_f32_qb4w_gemm_config(void) { qd8_f32_qb4w_gemm_config.pack_gemm_goi_bl = (xnn_packw_gemm_goi_bl_ukernel_fn) xnn_pack_qs8_qb4w_gemm_goi_w; @@ -3880,6 +3904,20 @@ XNN_INIT_ONCE(qp8_f32_qc4w_gemm); return NULL; } +const struct xnn_gemm_config* xnn_init_qp8_f32_qb4w_gemm_config() { + const struct xnn_hardware_config* hardware_config = + xnn_init_hardware_config(); + if (hardware_config == NULL) { + return NULL; + } +XNN_INIT_ONCE(qp8_f32_qb4w_gemm); + // Only return the config pointer if it actually provides a kernel. 
+ if (qp8_f32_qb4w_gemm_config.minmax.qp8gemm[0].function[0] != NULL) { + return &qp8_f32_qb4w_gemm_config; + } + return NULL; +} + const struct xnn_gemm_config* xnn_init_qs8_qc8w_gemm_config() { const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); if (hardware_config == NULL) { diff --git a/src/enums/operator-type.c b/src/enums/operator-type.c index a1d9acaf394..23b6e2e6cca 100644 --- a/src/enums/operator-type.c +++ b/src/enums/operator-type.c @@ -12,16 +12,16 @@ #include "xnnpack/operator-type.h" -static const uint16_t offset[171] = { +static const uint16_t offset[172] = { 0, 8, 22, 36, 50, 64, 78, 92, 119, 147, 175, 203, 230, 257, 289, 321, 364, 382, 400, 425, 451, 467, 483, 498, 513, 535, 558, 581, 604, 627, 650, 673, 696, 719, 742, 760, 783, 806, 830, 848, 871, 895, 919, 943, 967, 1002, 1037, 1061, 1085, 1109, 1123, 1138, 1153, 1173, 1199, 1225, 1262, 1288, 1318, 1344, 1376, 1408, 1434, 1461, 1488, 1505, 1522, 1556, 1590, 1604, 1618, 1632, 1646, 1662, 1678, 1704, 1730, 1762, 1794, 1831, 1868, 1905, 1942, 1979, 2016, 2053, - 2079, 2111, 2137, 2152, 2186, 2220, 2254, 2288, 2322, 2356, 2386, 2416, 2436, 2456, 2477, 2498, 2519, 2540, 2554, - 2578, 2602, 2625, 2648, 2666, 2684, 2699, 2714, 2729, 2744, 2762, 2780, 2799, 2818, 2837, 2856, 2875, 2892, 2909, - 2925, 2941, 2974, 3007, 3035, 3063, 3091, 3119, 3146, 3173, 3190, 3207, 3248, 3289, 3307, 3325, 3343, 3361, 3376, - 3392, 3408, 3426, 3444, 3462, 3488, 3515, 3542, 3559, 3576, 3598, 3620, 3649, 3678, 3697, 3716, 3735, 3754, 3769, - 3784, 3799, 3814, 3833, 3853, 3873, 3893, 3914, 3935 + 2090, 2116, 2148, 2174, 2189, 2223, 2257, 2291, 2325, 2359, 2393, 2423, 2453, 2473, 2493, 2514, 2535, 2556, 2577, + 2591, 2615, 2639, 2662, 2685, 2703, 2721, 2736, 2751, 2766, 2781, 2799, 2817, 2836, 2855, 2874, 2893, 2912, 2929, + 2946, 2962, 2978, 3011, 3044, 3072, 3100, 3128, 3156, 3183, 3210, 3227, 3244, 3285, 3326, 3344, 3362, 3380, 3398, + 3413, 3429, 3445, 3463, 3481, 3499, 3525, 3552, 3579, 3596, 3613, 3635, 3657, 3686, 3715, 3734, 3753, 3772, 3791, + 3806, 3821, 3836, 3851, 3870, 3890, 3910, 3930, 3951, 3972 }; static const char data[] = @@ -110,6 +110,7 @@ static const char data[] = "Fully Connected (NC, QD8, F32, QC4W)\0" "Fully Connected (NC, QD8, F32, QC8W)\0" "Fully Connected (NC, QP8, F32, QC4W)\0" + "Fully Connected (NC, QP8, F32, QB4W)\0" "Fully Connected (NC, QS8)\0" "Fully Connected (NC, QS8, QC8W)\0" "Fully Connected (NC, QU8)\0" diff --git a/src/enums/operator-type.yaml b/src/enums/operator-type.yaml index 8a2741526ea..2e79e3e7ec9 100644 --- a/src/enums/operator-type.yaml +++ b/src/enums/operator-type.yaml @@ -175,6 +175,8 @@ string: "Fully Connected (NC, QD8, F32, QC8W)" - name: xnn_operator_type_fully_connected_nc_qp8_f32_qc4w string: "Fully Connected (NC, QP8, F32, QC4W)" +- name: xnn_operator_type_fully_connected_nc_qp8_f32_qb4w + string: "Fully Connected (NC, QP8, F32, QB4W)" - name: xnn_operator_type_fully_connected_nc_qs8 string: "Fully Connected (NC, QS8)" - name: xnn_operator_type_fully_connected_nc_qs8_qc8w diff --git a/src/operator-run.c b/src/operator-run.c index 24aaafcd1bb..850de62f8a2 100644 --- a/src/operator-run.c +++ b/src/operator-run.c @@ -529,6 +529,33 @@ void xnn_compute_qp8gemm( nr_block_start, mr_block_size, nr_block_size); } +void xnn_compute_hmp_qp8gemm_bl( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start, + size_t mr_block_size, size_t nr_block_size) { + const size_t a_offset = 
xnn_x8_packq_f32qp8_packed_offset( + mr_block_start, context->k_scaled, context->mr, context->kr, context->sr); + const size_t cm_stride = context->cm_stride; + + context->qp8_bl_ukernel.function[uarch_index]( + mr_block_size, nr_block_size, context->k_scaled, + (const void*)((uintptr_t)context->a + a_offset), + (const void*)((uintptr_t)context->packed_w + + nr_block_start * context->w_stride), + (void*)((uintptr_t)context->c + mr_block_start * cm_stride + + (nr_block_start << context->log2_csize)), + cm_stride, + /*dst_stride_col=*/sizeof(float), context->fused_params); +} + +void xnn_compute_qp8gemm_bl( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t mr_block_start, size_t nr_block_start, size_t mr_block_size, + size_t nr_block_size) { + xnn_compute_hmp_qp8gemm_bl(context, XNN_UARCH_DEFAULT, mr_block_start, + nr_block_start, mr_block_size, nr_block_size); +} + void xnn_compute_hmp_dqgemm_bl( const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], uint32_t uarch_index, diff --git a/src/operators/fully-connected-nc.c b/src/operators/fully-connected-nc.c index fe5a4f4d26f..045b026d0eb 100644 --- a/src/operators/fully-connected-nc.c +++ b/src/operators/fully-connected-nc.c @@ -154,7 +154,7 @@ static enum xnn_status create_fully_connected_nc( const size_t weights_stride = gemm_config->packed_stride_weights_and_biases ? gemm_config->packed_stride_weights_and_biases( - gemm_config, input_channels, k_stride, extra_weights_bytes) + gemm_config, input_channels, block_wise ? block_size : k_stride, extra_weights_bytes) : (k_stride << log2_filter_element_size) + bias_element_size + extra_weights_bytes + block_scale_bytes; const size_t packed_weights_size = n_stride * weights_stride; @@ -190,7 +190,8 @@ static enum xnn_status create_fully_connected_nc( if (gemm_config->pack_weights_and_biases) { gemm_config->pack_weights_and_biases( flags, gemm_config, input_channels, output_channels, - /*groups=*/1, k_stride, + /*groups=*/1, + block_wise ? block_size : k_stride, /*accumulator_init=*/bias, /*weights=*/kernel, /*int_extra_data0_fn=*/(xnn_init_scale_params_fn)init_scale_params, @@ -198,10 +199,20 @@ static enum xnn_status create_fully_connected_nc( /*extra_data0_size=*/init_scale_params != NULL ? sizeof(float) : 0, /*init_extra_data1_fn=*/ (xnn_init_scale_params_fn)init_kernel_scale_params, - /*extra_data1=*/kernel_scale_params, + /*extra_data1=*/block_wise ? (void *) blockwise_kernel_scale_params : (void *) kernel_scale_params, /*extra_data1_size=*/init_kernel_scale_params != NULL ? 
sizeof(float) : 0, /*packed_weights_ptr=*/weights_ptr, packing_params); + + if (block_wise && bias != NULL) { + void* weights_start = (void*) ((uintptr_t) weights_ptr + + gemm_config->nr * (sizeof(float) + (block_size * sizeof(int8_t) / 2))); + weights_start = (void*) ((uintptr_t) weights_ptr + gemm_config->nr * (weights_stride - sizeof(float))) ; + xnn_init_qs8_qc8w_scale_fp32_params( + output_channels, gemm_config->nr, gemm_config->nr, + gemm_config->nr * weights_stride, gemm_config->nr * weights_stride, 0, + bias, weights_start); + } } else { if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) { pack_gemm_gio_w( @@ -822,6 +833,130 @@ enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qc4w( /*weights_cache=*/weights_cache, fully_connected_op_out); } +enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qb4w( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + size_t block_size, + uint8_t kernel_zero_point, + const uint16_t* kernel_scale, + const void* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out) +{ + if (isnan(output_min)) { + xnn_log_error( + "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN", + xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qp8_f32_qb4w)); + return xnn_status_invalid_parameter; + } + + if (isnan(output_max)) { + xnn_log_error( + "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN", + xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qp8_f32_qb4w)); + return xnn_status_invalid_parameter; + } + + if (output_min > output_max) { + xnn_log_error( + "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be less than or equal to upper bound", + xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qp8_f32_qb4w), output_min, output_max); + return xnn_status_invalid_parameter; + } + + const struct xnn_gemm_config* gemm_config = xnn_init_qp8_f32_qb4w_gemm_config(); + if (gemm_config == NULL) { + xnn_log_error("failed to create %s operator: unsupported hardware configuration", + xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qp8_f32_qb4w)); + return xnn_status_unsupported_hardware; + } + + const struct gemm_fused_ukernels* gemm_ukernels = &gemm_config->minmax; + const bool linear_activation = (output_max == INFINITY) && (output_min == -output_max); + if (linear_activation && gemm_config->linear.gemm[gemm_config->mr-1].function[XNN_UARCH_DEFAULT] != NULL) { + gemm_ukernels = &gemm_config->linear; + } + + if (block_size < XNN_MIN_BLOCKSIZE || block_size % XNN_MIN_BLOCKSIZE != 0) { + xnn_log_error( + "failed to create %s operator with block_size: %zu: expecting block_size to be a multiple of %d.", + xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qp8_f32_qb4w), block_size, XNN_MIN_BLOCKSIZE); + return xnn_status_invalid_parameter; + } + + if (input_channels % block_size != 0) { + xnn_log_error( + "failed to create %s operator with input_channels: %zu, and block_size: %zu: expecting input_channels %% block_size == 0.", + xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qp8_f32_qb4w), input_channels, block_size); + return xnn_status_invalid_parameter; + } + + if (kernel_zero_point != 8) { + xnn_log_error( + "failed to create %s operator with %" PRIu8 " kernel zero point: kernel zero point must be equal to 8 " 
+ "(unsigned weights) or 0 (signed weights)", + xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qp8_f32_qb4w), kernel_zero_point); + return xnn_status_invalid_parameter; + } + // Assuming kernel_scale.size() is output_channels * num_blocks. + size_t num_blocks = input_channels / block_size; + for (size_t output_channel = 0; output_channel < output_channels; output_channel++) { + for(size_t block_index=0; block_index < num_blocks; block_index++) { + size_t scale_index = output_channel * num_blocks + block_index; + float fp32_scale = math_cvt_fp32_bf16(kernel_scale[scale_index]); + if (fp32_scale <= 0.0f || !isnormal(fp32_scale)) { + xnn_log_error( + "failed to create %s operator with %.7g kernel scale in output channel #%zu, block #%zu: scale must be finite and positive", + xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qp8_f32_qb4w), + fp32_scale, output_channel, block_index); + return xnn_status_invalid_parameter; + } + } + } + + struct xnn_f32_qb4w_minmax_params params; + if XNN_LIKELY(gemm_config->init.f32_qb4w != NULL) { + gemm_config->init.f32_qb4w(¶ms, output_min, output_max, kernel_zero_point, block_size); + } + + // We don't know input zero point until runtime, row sum is multiplied by it during packing, so set it to 1. + const struct xnn_qs8_qc4w_packing_params packing_params = { /*input_zero_point=*/1, kernel_zero_point }; + + return create_fully_connected_nc( + input_channels, output_channels, + input_stride, output_stride, + kernel, bias, flags, + /*block_size=*/block_size, + /*extra_bl_bytes=*/sizeof(uint16_t), + /*blockwise_kernel_scale_params=*/kernel_scale, + /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, + /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T, + /*filter_is_nibble=*/true, + /*bias_element_size=*/sizeof(float), + /*pack_gemm_gio_w,=*/ NULL, + /*pack_gemm_goi_w=*/ NULL, + /*pack_gemm_goi_bl_w=*/gemm_config->pack_gemm_goi_bl, + &packing_params, + /*packed_weights_padding_byte=*/0, + /*extra_weights_bytes=*/0, + /*init_scale_params=*/NULL, + /*scale_params=*/NULL, + /*init_kernel_scale_params=*/NULL, + /*kernel_scale_params=*/NULL, + ¶ms, sizeof(params), + gemm_config, gemm_ukernels, + xnn_operator_type_fully_connected_nc_qp8_f32_qb4w, + /*weights_cache=*/weights_cache, + fully_connected_op_out); +} + enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qb4w( size_t input_channels, size_t output_channels, @@ -1774,7 +1909,14 @@ static enum xnn_status reshape_fully_connected_nc( } const bool is_qp8_ukernel = fully_connected_op->type == - xnn_operator_type_fully_connected_nc_qp8_f32_qc4w; + xnn_operator_type_fully_connected_nc_qp8_f32_qc4w || + fully_connected_op->type == + xnn_operator_type_fully_connected_nc_qp8_f32_qb4w; + + const bool is_blockwise_kernel = fully_connected_op->type == + xnn_operator_type_fully_connected_nc_qd8_f32_qb4w || + fully_connected_op->type == + xnn_operator_type_fully_connected_nc_qp8_f32_qb4w; fully_connected_op->context.gemm.gemm.gemm = (struct gemm_context){ .k_scaled = input_channels << log2_input_element_size, @@ -1813,20 +1955,38 @@ static enum xnn_status reshape_fully_connected_nc( if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) { fully_connected_op->compute[0].type = xnn_parallelization_type_2d_tile_2d_with_uarch; if (dynamic_quantization) { - fully_connected_op->compute[0].task_2d_tile_2d_with_id = (pthreadpool_task_2d_tile_2d_with_id_t) xnn_compute_hmp_dqgemm; + if (is_blockwise_kernel) { + fully_connected_op->compute[0].task_2d_tile_2d_with_id = (pthreadpool_task_2d_tile_2d_with_id_t) 
xnn_compute_hmp_dqgemm_bl; + } else { + fully_connected_op->compute[0].task_2d_tile_2d_with_id = (pthreadpool_task_2d_tile_2d_with_id_t) xnn_compute_hmp_dqgemm; + } } else if (is_qp8_ukernel) { - fully_connected_op->compute[0].task_2d_tile_2d_with_id = - (pthreadpool_task_2d_tile_2d_with_id_t)xnn_compute_hmp_qp8gemm; + if (is_blockwise_kernel) { + fully_connected_op->compute[0].task_2d_tile_2d_with_id = + (pthreadpool_task_2d_tile_2d_with_id_t)xnn_compute_hmp_qp8gemm_bl; + } else { + fully_connected_op->compute[0].task_2d_tile_2d_with_id = + (pthreadpool_task_2d_tile_2d_with_id_t)xnn_compute_hmp_qp8gemm; + } } else { fully_connected_op->compute[0].task_2d_tile_2d_with_id = (pthreadpool_task_2d_tile_2d_with_id_t) xnn_compute_hmp_gemm; } } else { fully_connected_op->compute[0].type = xnn_parallelization_type_2d_tile_2d; if (dynamic_quantization) { - fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_dqgemm; + if (is_blockwise_kernel) { + fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_dqgemm_bl; + } else { + fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_dqgemm; + } } else if (is_qp8_ukernel) { - fully_connected_op->compute[0].task_2d_tile_2d = - (pthreadpool_task_2d_tile_2d_t)xnn_compute_qp8gemm; + if (is_blockwise_kernel) { + fully_connected_op->compute[0].task_2d_tile_2d = + (pthreadpool_task_2d_tile_2d_t)xnn_compute_qp8gemm_bl; + } else { + fully_connected_op->compute[0].task_2d_tile_2d = + (pthreadpool_task_2d_tile_2d_t)xnn_compute_qp8gemm; + } } else { fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm; } @@ -1834,10 +1994,19 @@ static enum xnn_status reshape_fully_connected_nc( #else fully_connected_op->compute[0].type = xnn_parallelization_type_2d_tile_2d; if (dynamic_quantization) { - fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_dqgemm; + if (is_blockwise_kernel) { + fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_dqgemm_bl; + } else { + fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_dqgemm; + } } else if (is_qp8_ukernel) { - fully_connected_op->compute[0].task_2d_tile_2d = - (pthreadpool_task_2d_tile_2d_t)xnn_compute_qp8gemm; + if (is_blockwise_kernel) { + fully_connected_op->compute[0].task_2d_tile_2d = + (pthreadpool_task_2d_tile_2d_t)xnn_compute_qp8gemm_bl; + } else { + fully_connected_op->compute[0].task_2d_tile_2d = + (pthreadpool_task_2d_tile_2d_t)xnn_compute_qp8gemm; + } } else { fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm; } @@ -2061,6 +2230,23 @@ enum xnn_status xnn_reshape_fully_connected_nc_qp8_f32_qc4w( sizeof(fully_connected_op->params.f32_minmax), threadpool); } +enum xnn_status xnn_reshape_fully_connected_nc_qp8_f32_qb4w( + xnn_operator_t fully_connected_op, size_t batch_size, + pthreadpool_t threadpool) { + return reshape_fully_connected_nc( + fully_connected_op, xnn_operator_type_fully_connected_nc_qp8_f32_qb4w, + batch_size, + /*log2_input_element_size=*/0, + // Pass 1 byte even though it is half byte, we handle the division via + // filter_is_nibble == true. 
+ /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T, + /*filter_is_nibble=*/true, + /*dynamic_quantization=*/false, + /*log2_output_element_size=*/XNN_LOG2_SIZEOF_FLOAT, + &fully_connected_op->params.f32_qb4w_minmax, + sizeof(fully_connected_op->params.f32_qb4w_minmax), threadpool); +} + enum xnn_status xnn_reshape_fully_connected_nc_qs8( xnn_operator_t fully_connected_op, size_t batch_size, @@ -2270,6 +2456,13 @@ enum xnn_status xnn_setup_fully_connected_nc_qp8_f32_qc4w( input, output, /*quantization_params=*/NULL); } +enum xnn_status xnn_setup_fully_connected_nc_qp8_f32_qb4w( + xnn_operator_t fully_connected_op, const int8_t* input, float* output) { + return setup_fully_connected_nc( + fully_connected_op, xnn_operator_type_fully_connected_nc_qp8_f32_qb4w, + input, output, /*quantization_params=*/NULL); +} + enum xnn_status xnn_setup_fully_connected_nc_qd8_f32_qc8w( xnn_operator_t fully_connected_op, const int8_t* input, diff --git a/src/xnnpack/compute.h b/src/xnnpack/compute.h index a15512a8011..7bd9c8dda72 100644 --- a/src/xnnpack/compute.h +++ b/src/xnnpack/compute.h @@ -326,6 +326,7 @@ struct gemm_context { struct xnn_hmp_dqgemm_ukernel dq_ukernel; struct xnn_hmp_qp8gemm_ukernel qp8_ukernel; struct xnn_hmp_dqgemm_bl_ukernel dq_bl_ukernel; + struct xnn_hmp_qp8gemm_bl_ukernel qp8_bl_ukernel; }; // Parameters for dynamically quantized inputs. const struct xnn_qd8_quantization_params* quantization_params; @@ -377,6 +378,11 @@ struct gemm_context { size_t mr_block_start, size_t nr_block_start, size_t mr_block_size, size_t nr_block_size); + XNN_PRIVATE void xnn_compute_qp8gemm_bl( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t mr_block_start, size_t nr_block_start, size_t mr_block_size, + size_t nr_block_size); + #if XNN_MAX_UARCH_TYPES > 1 XNN_PRIVATE void xnn_compute_hmp_grouped_gemm( const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], @@ -415,6 +421,11 @@ struct gemm_context { size_t nr_block_start, size_t mr_block_size, size_t nr_block_size); + + XNN_PRIVATE void xnn_compute_hmp_qp8gemm_bl( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start, + size_t mr_block_size, size_t nr_block_size); #endif // XNN_MAX_UARCH_TYPES > 1 #endif diff --git a/src/xnnpack/config.h b/src/xnnpack/config.h index 340c60b6cc2..59eaf94ced5 100644 --- a/src/xnnpack/config.h +++ b/src/xnnpack/config.h @@ -194,6 +194,15 @@ static inline struct xnn_hmp_qp8gemm_ukernel xnn_init_hmp_qp8gemm_ukernel( return ukernel; } +static inline struct xnn_hmp_qp8gemm_bl_ukernel xnn_init_hmp_qp8gemm_bl_ukernel( + xnn_qp8_f32_qb4w_gemm_minmax_ukernel_fn function) { + struct xnn_hmp_qp8gemm_bl_ukernel ukernel = {{function}}; + for (size_t i = 1; i < XNN_MAX_UARCH_TYPES; i++) { + ukernel.function[i] = function; + } + return ukernel; +} + static inline struct xnn_hmp_gemm_ukernel xnn_init_hmp_gemm_ukernel(xnn_gemm_ukernel_fn function) { struct xnn_hmp_gemm_ukernel ukernel = {{ function }}; for (size_t i = 1; i < XNN_MAX_UARCH_TYPES; i++) { @@ -248,6 +257,7 @@ XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qd8_f32_qb4w_gemm_config(); XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qd8_f32_qc4w_gemm_config(); XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qd8_f32_qc8w_gemm_config(); XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qp8_f32_qc4w_gemm_config(); +XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qp8_f32_qb4w_gemm_config(); XNN_INTERNAL const struct xnn_gemm_config* 
xnn_init_qs8_qc8w_gemm_config(); XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qu8_gemm_config(); diff --git a/src/xnnpack/internal.h b/src/xnnpack/internal.h index bd6059b7d9c..7f7f1550f30 100644 --- a/src/xnnpack/internal.h +++ b/src/xnnpack/internal.h @@ -60,6 +60,33 @@ enum xnn_status xnn_setup_convert_nc_f32_qp8(xnn_operator_t convert_op, // const float* input, // int8_t* output); +enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qb4w( + size_t input_channels, // + size_t output_channels, // + size_t input_stride, // + size_t output_stride, // + size_t block_size, // + uint8_t kernel_zero_point, // + const uint16_t* kernel_scale, // + const void* kernel, // + const float* bias, // + float output_min, // + float output_max, // + uint32_t flags, // + xnn_code_cache_t code_cache, // + xnn_weights_cache_t weights_cache, // + xnn_operator_t* fully_connected_op_out); + +enum xnn_status xnn_setup_fully_connected_nc_qp8_f32_qb4w( + xnn_operator_t fully_connected_op, // + const int8_t* input, // + float* output); + +enum xnn_status xnn_reshape_fully_connected_nc_qp8_f32_qb4w( + xnn_operator_t fully_connected_op, // + size_t batch_size, // + pthreadpool_t threadpool); + #ifdef __cplusplus } // extern "C" #endif diff --git a/src/xnnpack/microfnptr.h b/src/xnnpack/microfnptr.h index 26e98a22aaf..6f3bcacd3f8 100644 --- a/src/xnnpack/microfnptr.h +++ b/src/xnnpack/microfnptr.h @@ -3003,6 +3003,10 @@ struct xnn_hmp_qp8gemm_ukernel { xnn_qp8_f32_qc4w_gemm_minmax_ukernel_fn function[XNN_MAX_UARCH_TYPES]; }; +struct xnn_hmp_qp8gemm_bl_ukernel { + xnn_qp8_f32_qb4w_gemm_minmax_ukernel_fn function[XNN_MAX_UARCH_TYPES]; +}; + // Largest GEMM/IGEMM MR used in init.c is 16 (x86 AVX512AMX). // Largest GEMM/IGEMM MR is 8 in e2e benchmarks. #define XNN_MAX_MR 16 @@ -3013,10 +3017,10 @@ struct gemm_fused_ukernels { struct xnn_hmp_dqgemm_ukernel dqgemm[XNN_MAX_MR]; struct xnn_hmp_qp8gemm_ukernel qp8gemm[XNN_MAX_MR]; struct xnn_hmp_dqgemm_bl_ukernel dqgemm_bl[XNN_MAX_MR]; + struct xnn_hmp_qp8gemm_bl_ukernel qp8gemm_bl[XNN_MAX_MR]; }; union { struct xnn_hmp_igemm_ukernel igemm[XNN_MAX_MR]; struct xnn_hmp_dqigemm_ukernel dqigemm[XNN_MAX_MR]; }; }; - diff --git a/src/xnnpack/operator-type.h b/src/xnnpack/operator-type.h index 0a6749d1667..53348bd4b9d 100644 --- a/src/xnnpack/operator-type.h +++ b/src/xnnpack/operator-type.h @@ -102,6 +102,7 @@ enum xnn_operator_type { xnn_operator_type_fully_connected_nc_qd8_f32_qc4w, xnn_operator_type_fully_connected_nc_qd8_f32_qc8w, xnn_operator_type_fully_connected_nc_qp8_f32_qc4w, + xnn_operator_type_fully_connected_nc_qp8_f32_qb4w, xnn_operator_type_fully_connected_nc_qs8, xnn_operator_type_fully_connected_nc_qs8_qc8w, xnn_operator_type_fully_connected_nc_qu8, diff --git a/test/fully-connected-nc.cc b/test/fully-connected-nc.cc index 88aa9b9ddac..e2c7218bdbf 100644 --- a/test/fully-connected-nc.cc +++ b/test/fully-connected-nc.cc @@ -2070,8 +2070,157 @@ TEST(FULLY_CONNECTED_NC_QD8_F16_QB4W, bl_no_bias) { .input_channels(ic) .block_size(bs) .kernel_zero_point(8) - .iterations(3) + .iterations(1) .TestQD8F16QB4W(); + } } +} + +TEST(FULLY_CONNECTED_NC_QP8_F32_QB4W, BL_BIAS) { + for (size_t ic=32; ic<=256; ic*=2){ + for (size_t bs=32; bs<=ic; bs=bs*2) { + FullyConnectedOperatorTester() + .has_bias(true) + .batch_size(1) + .output_channels(16) + .input_channels(ic) + .block_size(bs) + .kernel_zero_point(8) + .iterations(3) + .TestQP8F32QB4W(); } } } + +TEST(FULLY_CONNECTED_NC_QP8_F32_QB4W, BL_NO_BIAS) { + for (size_t ic=32; ic<=256; ic*=2){ + for (size_t bs=32; bs<=ic; 
bs=bs*2) { + FullyConnectedOperatorTester() + .has_bias(false) + .batch_size(1) + .output_channels(16) + .input_channels(ic) + .block_size(bs) + .kernel_zero_point(8) + .iterations(3) + .TestQP8F32QB4W(); + } + } +} + +TEST(FULLY_CONNECTED_NC_QP8_F32_QB4W, unit_batch) { + FullyConnectedOperatorTester() + .batch_size(1) + .input_channels(32) + .output_channels(19) + .block_size(32) + .kernel_zero_point(8) + .iterations(3) + .TestQP8F32QB4W(); +} + +TEST(FULLY_CONNECTED_NC_QP8_F32_QB4W, unit_batch_with_qmin) { + FullyConnectedOperatorTester() + .batch_size(1) + .input_channels(32) + .output_channels(19) + .block_size(32) + .kernel_zero_point(8) + .qmin(128) + .iterations(3) + .TestQP8F32QB4W(); +} + +TEST(FULLY_CONNECTED_NC_QP8_F32_QB4W, unit_batch_with_qmax) { + FullyConnectedOperatorTester() + .batch_size(1) + .input_channels(32) + .output_channels(19) + .block_size(32) + .kernel_zero_point(8) + .qmax(128) + .iterations(3) + .TestQP8F32QB4W(); +} + +TEST(FULLY_CONNECTED_NC_QP8_F32_QB4W, unit_batch_with_input_stride) { + FullyConnectedOperatorTester() + .batch_size(1) + .input_channels(32) + .block_size(32) + .input_stride(38) + .output_channels(19) + .kernel_zero_point(8) + .iterations(3) + .TestQP8F32QB4W(); +} + +TEST(FULLY_CONNECTED_NC_QP8_F32_QB4W, unit_batch_with_output_stride) { + FullyConnectedOperatorTester() + .batch_size(1) + .input_channels(32) + .block_size(32) + .output_channels(19) + .kernel_zero_point(8) + .output_stride(29) + .iterations(3) + .TestQP8F32QB4W(); +} + +TEST(FULLY_CONNECTED_NC_QP8_F32_QB4W, small_batch) { + FullyConnectedOperatorTester() + .batch_size(12) + .input_channels(32) + .block_size(32) + .output_channels(19) + .kernel_zero_point(8) + .iterations(3) + .TestQP8F32QB4W(); +} + +TEST(FULLY_CONNECTED_NC_QP8_F32_QB4W, small_batch_with_qmin) { + FullyConnectedOperatorTester() + .batch_size(12) + .input_channels(32) + .block_size(32) + .output_channels(19) + .kernel_zero_point(8) + .qmin(128) + .iterations(3) + .TestQP8F32QB4W(); +} + +TEST(FULLY_CONNECTED_NC_QP8_F32_QB4W, small_batch_with_qmax) { + FullyConnectedOperatorTester() + .batch_size(12) + .input_channels(32) + .block_size(32) + .output_channels(19) + .kernel_zero_point(8) + .qmax(128) + .iterations(3) + .TestQP8F32QB4W(); +} + +TEST(FULLY_CONNECTED_NC_QP8_F32_QB4W, small_batch_with_input_stride) { + FullyConnectedOperatorTester() + .batch_size(12) + .input_channels(32) + .block_size(32) + .input_stride(38) + .output_channels(19) + .kernel_zero_point(8) + .iterations(3) + .TestQP8F32QB4W(); +} + +TEST(FULLY_CONNECTED_NC_QP8_F32_QB4W, small_batch_with_output_stride) { + FullyConnectedOperatorTester() + .batch_size(12) + .input_channels(32) + .block_size(32) + .output_channels(19) + .kernel_zero_point(8) + .output_stride(29) + .iterations(3) + .TestQP8F32QB4W(); +} diff --git a/test/fully-connected-operator-tester.h b/test/fully-connected-operator-tester.h index 5fa9c7e570c..c965961081a 100644 --- a/test/fully-connected-operator-tester.h +++ b/test/fully-connected-operator-tester.h @@ -1217,6 +1217,217 @@ class FullyConnectedOperatorTester { } } + void TestQP8F32QB4W() const { + // Get the parameters of this GEMM, skip if not available. + const struct xnn_gemm_config* gemm_config = + xnn_init_qp8_f32_qb4w_gemm_config(); + if (gemm_config == nullptr) { + GTEST_SKIP(); + } + + // Note that the microkernel will force `mr` to 1 if `mc` is 1, so we have + // to anticipate that when packing the left-hand operand. + const uint32_t mr_packed = batch_size() > 1 ? 
gemm_config->mr_packed : 1; + const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr; + const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr; + + ASSERT_EQ(weights_type(), WeightsType::Default); + + xnnpack::ReplicableRandomDevice rng; + std::uniform_real_distribution f32dist(-1.f, 1.f); + std::uniform_real_distribution f32idist(0.5f, 2.0f); + std::uniform_int_distribution w8dist( + std::numeric_limits::min(), std::numeric_limits::max()); + + const size_t k2 = + round_up_po2(input_channels(), 2); // tester assumes byte aligned rows + + std::vector input(XNN_EXTRA_BYTES / sizeof(float) + + (batch_size() - 1) * input_stride() + + input_channels()); + const size_t kernel_stride = calc_kernel_stride(); + std::vector kernel((output_channels()) * + kernel_stride); + std::vector bias(output_channels()); + std::vector output((batch_size() - 1) * output_stride() + + output_channels()); + std::vector output_ref(batch_size() * output_channels()); + size_t num_blocks = k2 / block_size(); + std::vector kernel_scale2d(output_channels() * num_blocks); + + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); + std::generate(kernel.begin(), kernel.end(), + [&]() { return w8dist(rng); }); + std::generate(kernel_scale2d.begin(), kernel_scale2d.end(), [&]() { return math_cvt_bf16_fp32(f32idist(rng)); }); + std::fill(output.begin(), output.end(), nanf("")); + std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); + + // Quantize the left-hand operand. + const size_t input_packed_size = + xnn_x8_packq_f32qp8_packed_size(batch_size(), k2, mr_packed, kr, sr); + std::vector input_qp8(input_packed_size); + xnn_x8_packq_f32qp8_ukernel__scalar_u1( + batch_size(), k2, mr_packed, kr, sr, + /*m_idx_start=*/0, input.data(), + /*lhs_stride=*/k2 * sizeof(float), input_qp8.data()); + + // Compute reference results, without renormalization. + std::fill(output_ref.begin(), output_ref.end(), 0); + + for (size_t mi = 0; mi < batch_size(); mi++) { + for (size_t ni = 0; ni < output_channels(); ni++) { + float kfsum = 0.0; + for (size_t bi = 0; bi < num_blocks; bi++){ + int32_t ksum = 0; + int32_t c_ref_acc = 0; + for (size_t ki = 0; ki < block_size(); ki++) { + const size_t k_index = bi * block_size() + ki; + const size_t nb_index = (ni * k2 + k_index) / 2; + const int32_t kernel_value = int32_t((k_index % 2 == 0) ? (kernel[nb_index] & UINT8_C(0xF)) : (kernel[nb_index] >> 4)) - kernel_zero_point(); + ksum += kernel_value; + c_ref_acc += int32_t(xnn_x8_packq_f32qp8_get_quantized(mi, k_index, input_qp8.data(), + k2, mr_packed, kr, sr)) * int32_t(kernel_value); + } + size_t scale_index = ni * num_blocks + bi; + float scale = math_cvt_fp32_bf16(kernel_scale2d[scale_index]); + output_ref[mi * output_channels() + ni] += c_ref_acc * scale; + kfsum += scale * ksum; + } + float inv_scale = xnn_x8_packq_f32qp8_get_recip_scale(mi, input_qp8.data(), k2, mr_packed, kr, sr); + int32_t neg_nudged_zero_point= xnn_x8_packq_f32qp8_get_neg_nudged_zp(mi, input_qp8.data(), k2, mr_packed, kr, sr); + output_ref[mi * output_channels() + ni] += (neg_nudged_zero_point * kfsum); + output_ref[mi * output_channels() + ni] *= inv_scale; + if (has_bias()) { + output_ref[mi * output_channels() + ni] += bias[ni]; + } + } + } + + // Compute clamping parameters. 
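      // The bounds below are derived from the observed reference range:
      // qmin()/qmax() pick points on a 255-step grid spanning
      // [accumulated_min, accumulated_max], i.e.
      //   output_min = accumulated_min + (accumulated_max - accumulated_min) * qmin / 255
      //   output_max = accumulated_max - (accumulated_max - accumulated_min) * (255 - qmax) / 255
      // with qmin == 0 / qmax == 255 leaving the corresponding side unbounded.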
+ const float accumulated_max = + *std::max_element(output_ref.cbegin(), output_ref.cend()); + const float accumulated_min = + *std::min_element(output_ref.cbegin(), output_ref.cend()); + + const float output_min = + qmin() == 0 + ? -std::numeric_limits::infinity() + : accumulated_min + (accumulated_max - accumulated_min) / 255.0f * + static_cast(qmin()); + const float output_max = + qmax() == 255 + ? std::numeric_limits::infinity() + : accumulated_max - (accumulated_max - accumulated_min) / 255.0f * + static_cast(255 - qmax()); + + // Clamp reference results. + for (float& value : output_ref) { + value = std::max(std::min(value, output_max), output_min); + } + + // Create, setup, run, and destroy Fully Connected operator. + ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); + xnn_operator_t fully_connected_op = nullptr; + + struct xnn_internal_weights_cache* internal_weights_cache = nullptr; + std::unique_ptr auto_weights_cache( + nullptr, xnn_delete_weights_cache); + if (use_weights_cache()) { + xnn_weights_cache_t weights_cache = nullptr; + xnn_create_weights_cache(&weights_cache); + auto_weights_cache.reset(weights_cache); + if (weights_cache) { + internal_weights_cache = (struct xnn_internal_weights_cache*) weights_cache->context; + } + } + + const xnn_status status = xnn_create_fully_connected_nc_qp8_f32_qb4w( + input_channels(), output_channels(), + input_stride(), output_stride(), + block_size(), + kernel_zero_point(), + kernel_scale2d.data(), + kernel.data(), + has_bias() ? bias.data() : nullptr, + output_min, output_max, + 0, //TODO Handle XNN_FLAG_TRANSPOSE_WEIGHTS + nullptr, auto_weights_cache.get(), &fully_connected_op); + if (status == xnn_status_unsupported_hardware) { + GTEST_SKIP(); + } + + ASSERT_EQ(xnn_status_success, status); + ASSERT_NE(nullptr, fully_connected_op); + if (use_weights_cache()) { + ASSERT_EQ(xnn_status_success, + xnn_finalize_weights_cache(auto_weights_cache.get(), xnn_weights_cache_finalization_kind_soft)); + } + + // Smart pointer to automatically delete fully_connected_op. + std::unique_ptr + auto_fully_connected_op(fully_connected_op, xnn_delete_operator); + + ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qp8_f32_qb4w( + fully_connected_op, batch_size(), + /*threadpool=*/nullptr)); + + ASSERT_EQ(xnn_status_success, + xnn_setup_fully_connected_nc_qp8_f32_qb4w( + fully_connected_op, input_qp8.data(), output.data())); + + ASSERT_EQ(xnn_status_success, + xnn_run_operator(fully_connected_op, /*threadpool=*/nullptr)); + + // Verify results. + VerifyF32(output, output_ref, output_max, output_min); + + if (use_weights_cache()) { + // Create another operator with the same weights cache. + xnn_operator_t fully_connected_op2 = nullptr; + size_t old_weights_cache_size = internal_weights_cache->cache.weights.size; + + ASSERT_EQ(xnn_status_success, xnn_create_fully_connected_nc_qp8_f32_qb4w( + input_channels(), output_channels(), + input_stride(), output_stride(), + /*batch_size=*/ block_size(), + kernel_zero_point(), + kernel_scale2d.data(), + kernel.data(), has_bias() ? bias.data() : nullptr, + output_min, output_max, + 0, + nullptr, auto_weights_cache.get(), + &fully_connected_op2)); + ASSERT_NE(nullptr, fully_connected_op2); + + // Smart pointer to automatically delete fully_connected_op. 
+ std::unique_ptr + auto_fully_connected_op(fully_connected_op2, xnn_delete_operator); + + ASSERT_EQ(xnn_status_success, + xnn_reshape_fully_connected_nc_qp8_f32_qb4w( + fully_connected_op2, + batch_size(), + /*threadpool=*/nullptr)); + + std::vector output2(output.size(), nanf("")); + ASSERT_EQ(xnn_status_success, + xnn_setup_fully_connected_nc_qp8_f32_qb4w( + fully_connected_op2, + input_qp8.data(), output2.data())); + + + ASSERT_EQ( + xnn_status_success, + xnn_run_operator(fully_connected_op2, /*threadpool=*/nullptr)); + + VerifyWeightsCache(*internal_weights_cache, old_weights_cache_size); + + VerifyF32(output, output_ref, output_max, output_min); + } + } + } + void TestQD8F16QC8W() const { ASSERT_EQ(weights_type(), WeightsType::Default); From 3b710774d4f1b08f74c44ef8964199f7aee9ad9e Mon Sep 17 00:00:00 2001 From: Max Ren Date: Tue, 3 Sep 2024 17:38:49 -0700 Subject: [PATCH 04/17] Wired QP8_QB4W Subgraph APIs --- src/subgraph/fully-connected.c | 44 +++++++- test/fully-connected.cc | 180 +++++++++++++++++++++++++++++++++ 2 files changed, 223 insertions(+), 1 deletion(-) diff --git a/src/subgraph/fully-connected.c b/src/subgraph/fully-connected.c index c11f7b181f3..2a6459b818b 100644 --- a/src/subgraph/fully-connected.c +++ b/src/subgraph/fully-connected.c @@ -42,6 +42,7 @@ enum fully_connected_op_type { fc_type_qs8_qs8_qc8w = 16, fc_type_qs8_qs8_qs8 = 17, fc_type_qu8_qu8_qu8 = 18, + fc_type_qp8_f32_qb4w = 19, }; enum fully_connected_op_type get_fully_connected_op_type( @@ -95,7 +96,14 @@ enum fully_connected_op_type get_fully_connected_op_type( return fc_type_f32_f32_f32; } case xnn_datatype_qbint4: - return fc_type_qd8_f32_qb4w; + switch (input_datatype) { + case xnn_datatype_qdint8: + return fc_type_qd8_f32_qb4w; + case xnn_datatype_qpint8: + return fc_type_qp8_f32_qb4w; + default: + XNN_UNREACHABLE; + } case xnn_datatype_qcint4: switch (input_datatype) { case xnn_datatype_fp32: @@ -273,6 +281,18 @@ static enum xnn_status create_fully_connected_operator( node->activation.output_max, node->flags, code_cache, weights_cache, &opdata->operator_objects[0]); break; + case fc_type_qp8_f32_qb4w: + status = xnn_create_fully_connected_nc_qp8_f32_qb4w( + input_channels, output_channels, + /*input_stride=*/input_channels, + /*output_stride=*/output_channels, + /*block_size=*/values[filter_id].quantization.block_size, + /*kernel_zero_point=*/values[filter_id].quantization.zero_point, + (const uint16_t*)values[filter_id].quantization.blockwise_scale, + kernel_data, bias_data, node->activation.output_min, + node->activation.output_max, node->flags, code_cache, weights_cache, + &opdata->operator_objects[0]); + break; case fc_type_f32_f32_qc4w: status = xnn_create_fully_connected_nc_f32_qc4w( input_channels, output_channels, @@ -522,6 +542,12 @@ static enum xnn_status reshape_fully_connected_operator( status = xnn_reshape_fully_connected_nc_qp8_f32_qc4w( opdata->operator_objects[0], batch_size, threadpool); break; + case xnn_operator_type_fully_connected_nc_qp8_f32_qb4w: + status = xnn_reshape_fully_connected_nc_qp8_f32_qb4w( + opdata->operator_objects[0], + batch_size, + threadpool); + break; case xnn_operator_type_fully_connected_nc_qs8: status = xnn_reshape_fully_connected_nc_qs8(opdata->operator_objects[0], batch_size, threadpool); @@ -688,6 +714,15 @@ static enum xnn_status setup_fully_connected_operator( return xnn_setup_fully_connected_nc_qp8_f32_qc4w( opdata->operator_objects[0], input_data, output_data); } + case xnn_operator_type_fully_connected_nc_qp8_f32_qb4w: + { + assert(kernel_data == 
NULL); + assert(bias_data == NULL); + return xnn_setup_fully_connected_nc_qp8_f32_qb4w( + opdata->operator_objects[0], + input_data, + output_data); + } case xnn_operator_type_fully_connected_nc_qs8: assert(kernel_data == NULL); assert(bias_data == NULL); @@ -759,6 +794,11 @@ static inline enum xnn_compute_type validate_datatypes_with_bias( bias_datatype == xnn_datatype_fp32 && output_datatype == xnn_datatype_fp16) { return xnn_compute_type_qd8_to_fp16; + } else if (input_datatype == xnn_datatype_qpint8 && + bias_datatype == xnn_datatype_fp32 && + output_datatype == xnn_datatype_fp32) + { + return xnn_compute_type_qp8_to_fp32; } break; case xnn_datatype_qcint8: @@ -846,6 +886,8 @@ static inline enum xnn_compute_type validate_datatypes_without_bias( } else if (input_datatype == xnn_datatype_qdint8 && output_datatype == xnn_datatype_fp16) { return xnn_compute_type_qd8_to_fp16; + } else if (input_datatype == xnn_datatype_qpint8 && output_datatype == xnn_datatype_fp32) { + return xnn_compute_type_qp8_to_fp32; } break; case xnn_datatype_qcint8: diff --git a/test/fully-connected.cc b/test/fully-connected.cc index e00fc6219b2..afdda2a403c 100644 --- a/test/fully-connected.cc +++ b/test/fully-connected.cc @@ -876,6 +876,9 @@ class FullyConnectedTestF32QC4W : public FullyConnectedTestBase { }; +class FullyConnectedTestQP8F32QB4W + : public FullyConnectedTestBase {}; + using FullyConnectedTestQC8 = QuantizedFullyConnectedTestBase; using FullyConnectedTestQS8 = QuantizedFullyConnectedTestBase; @@ -4228,3 +4231,180 @@ TEST_F(FullyConnectedTestF32, reshape) size_t num_output_elements = std::accumulate(new_input_dims.begin(), new_input_dims.end() - 1, size_t{1}, std::multiplies()) * kernel_shape->dim[0]; ASSERT_EQ(runtime->values[node->outputs[0]].size, num_output_elements * sizeof(float)); } + +TEST_F(FullyConnectedTestQP8F32QB4W, define) +{ + size_t block_size = 32; + input_channels = round_up_po2(input_channels, block_size); + + input_dims[input_dims.size() - 1] = input_channels; + kernel_dims[kernel_dims.size() - 1] = input_channels; + + ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); + + xnn_subgraph_t subgraph = nullptr; + ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); + std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + + uint32_t input_id = XNN_INVALID_VALUE_ID; + ASSERT_EQ( + xnn_status_success, xnn_define_dynamically_quantized_tensor_value( + subgraph, xnn_datatype_qpint8, input_dims.size(), /*num_nonbatch_dims=*/1, input_dims.data(), + /*external_id=*/0, /*flags=*/0, &input_id)); + ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); + + // Adjust number of kernel elements for QB4W. input_channels should be padded to byte boundary, hence even. 
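  // Reminder on the scale layout consumed by the qb4w path: one bf16 scale per
  // (output channel, block) pair, i.e. output_channels * (input_channels / block_size)
  // entries, matching the per-block validation loop in
  // xnn_create_fully_connected_nc_qp8_f32_qb4w above.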
+ const size_t rounded_input_channels = round_up_po2(input_channels, 2); + kernel = std::vector(output_channels * rounded_input_channels); + const uint8_t kernel_zero_point = 8; + std::vector kernel_scale(output_channels * block_size); + std::generate(kernel_scale.begin(), kernel_scale.end(), [&]() { return math_cvt_bf16_fp32(scale_dist(rng)); }); + uint32_t kernel_id = XNN_INVALID_VALUE_ID; + ASSERT_EQ( + xnn_status_success, xnn_define_blockwise_quantized_tensor_value( + subgraph, xnn_datatype_qbint4, kernel_zero_point, kernel_scale.data(), kernel_dims.size(), + /*channel_dim=*/0, block_size, kernel_dims.data(), kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); + + uint32_t bias_id = XNN_INVALID_VALUE_ID; + ASSERT_EQ( + xnn_status_success, xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); + + uint32_t output_id = XNN_INVALID_VALUE_ID; + ASSERT_EQ( + xnn_status_success, xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/0, &output_id)); + ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); + + ASSERT_EQ( + xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); + + ASSERT_EQ(subgraph->num_nodes, 1); + const struct xnn_node* node = &subgraph->nodes[0]; + ASSERT_EQ(node->type, xnn_node_type_fully_connected); + ASSERT_EQ(node->compute_type, xnn_compute_type_qp8_to_fp32); + ASSERT_EQ(node->activation.output_min, output_min); + ASSERT_EQ(node->activation.output_max, output_max); + ASSERT_EQ(node->num_inputs, 3); + ASSERT_EQ(node->inputs[0], input_id); + ASSERT_EQ(node->inputs[1], kernel_id); + ASSERT_EQ(node->inputs[2], bias_id); + ASSERT_EQ(node->num_outputs, 1); + ASSERT_EQ(node->outputs[0], output_id); + ASSERT_EQ(node->flags, 0); +} + +TEST_F(FullyConnectedTestQP8F32QB4W, internally_allocated_dynamic_quantization_parameters) +{ + size_t block_size = 32; + input_channels = round_up_po2(input_channels, block_size); + + input_dims[input_dims.size() - 1] = input_channels; + kernel_dims[kernel_dims.size() - 1] = input_channels; + + ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); + xnn_subgraph_t subgraph = nullptr; + ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph)); + std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + uint32_t input_id = XNN_INVALID_NODE_ID; + std::vector convert_input(batch_size * input_channels + XNN_EXTRA_BYTES / sizeof(float)); + std::vector operator_dq_data(batch_size * input_channels + XNN_EXTRA_BYTES); + std::vector subgraph_output(batch_size * output_channels); + std::vector operator_output(batch_size * output_channels); + std::fill(operator_output.begin(), operator_output.end(), nanf("")); + std::fill(subgraph_output.begin(), subgraph_output.end(), nanf("")); + std::vector quantization_params(batch_size + XNN_EXTRA_QUANTIZATION_PARAMS); + + std::vector kernel_scale(output_channels * block_size); + std::generate(kernel_scale.begin(), kernel_scale.end(), [&]() { return math_cvt_bf16_fp32(scale_dist(rng)); }); + std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); }); + std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); + std::generate(convert_input.begin(), convert_input.end(), [&]() { return f32dist(rng); }); + std::generate(quantization_params.begin(), 
quantization_params.end(), [&]() { return xnn_dynamic_quantization_params{w8dist(rng), f32dist(rng)}; }); + + const size_t rounded_input_channels = round_up_po2(input_channels, 2); + kernel = std::vector(output_channels * rounded_input_channels); + + const float output_min = -std::numeric_limits::infinity(); + const float output_max = std::numeric_limits::infinity(); + + const uint8_t kernel_zero_point = 8; + + // Call operator API. + xnn_operator_t convert_op = nullptr; + xnn_operator_t fc_op = nullptr; + xnn_status status = xnn_create_convert_nc_f32_qp8( + /*flags=*/0, &convert_op); + std::unique_ptr auto_convert_op(convert_op, xnn_delete_operator); + if (status == xnn_status_unsupported_hardware) { + GTEST_SKIP(); + } + ASSERT_EQ(xnn_status_success, status); + ASSERT_NE(nullptr, convert_op); + ASSERT_EQ(xnn_status_success, xnn_reshape_convert_nc_f32_qp8(convert_op, batch_size, input_channels, input_channels, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, xnn_setup_convert_nc_f32_qp8(convert_op, convert_input.data(), + operator_dq_data.data())); + ASSERT_EQ(xnn_status_success, xnn_run_operator(convert_op, /*threadpool=*/nullptr)); + + status = xnn_create_fully_connected_nc_qp8_f32_qb4w( + input_channels, output_channels, input_channels, output_channels, block_size, kernel_zero_point, kernel_scale.data(), + kernel.data(), bias.data(), output_min, output_max, + /*flags=*/0, nullptr, nullptr, &fc_op); + std::unique_ptr auto_fc_op(fc_op, xnn_delete_operator); + + if (status == xnn_status_unsupported_hardware) { + GTEST_SKIP(); + } + + ASSERT_EQ(xnn_status_success, status); + ASSERT_NE(nullptr, fc_op); + ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qp8_f32_qb4w(fc_op, batch_size, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, + xnn_setup_fully_connected_nc_qp8_f32_qb4w(fc_op, operator_dq_data.data(), operator_output.data())); + ASSERT_EQ(xnn_status_success, xnn_run_operator(fc_op, /*threadpool=*/nullptr)); + + // Call subgraph API. 
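  // The subgraph mirrors the operator-level path above: an fp32 input value feeds
  // xnn_define_convert() (with XNN_FLAG_MAYBE_PACK_FOR_GEMM, so the qdint8 value
  // may be packed to qpint8 when a QP8 GEMM config is available), and the converted
  // value is consumed by xnn_define_fully_connected() together with the qbint4
  // blockwise-quantized weights and fp32 bias, filling subgraph_output alongside
  // the operator_output computed above.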
+ ASSERT_EQ( + xnn_status_success, xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, /*external_id=*/0, + /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); + ASSERT_NE(input_id, XNN_INVALID_NODE_ID); + + uint32_t dq_quantized_id = XNN_INVALID_NODE_ID; + ASSERT_EQ( + xnn_status_success, xnn_define_dynamically_quantized_tensor_value( + subgraph, xnn_datatype_qdint8, input_dims.size(), /*num_nonbatch_dims=*/1, input_dims.data(), + XNN_INVALID_VALUE_ID, /*flags=*/0, &dq_quantized_id)); + ASSERT_NE(dq_quantized_id, XNN_INVALID_NODE_ID); + uint32_t kernel_id = XNN_INVALID_VALUE_ID; + ASSERT_EQ( + xnn_status_success, xnn_define_blockwise_quantized_tensor_value( + subgraph, xnn_datatype_qbint4, kernel_zero_point, kernel_scale.data(), kernel_dims.size(), + /*channel_dim=*/0, block_size, kernel_dims.data(), kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); + + uint32_t bias_id = XNN_INVALID_VALUE_ID; + ASSERT_EQ( + xnn_status_success, xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); + uint32_t output_id = XNN_INVALID_NODE_ID; + ASSERT_EQ( xnn_status_success, xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); + ASSERT_NE(output_id, XNN_INVALID_NODE_ID); + + xnn_runtime_t runtime = nullptr; + ASSERT_EQ(xnn_status_success, xnn_define_convert(subgraph, input_id, dq_quantized_id, /*flags=*/XNN_FLAG_MAYBE_PACK_FOR_GEMM)); + ASSERT_EQ(xnn_status_success, xnn_define_fully_connected(subgraph, output_min, output_max, dq_quantized_id, + kernel_id, bias_id, output_id, /*flags=*/0)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_NE(nullptr, runtime); + std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); + std::array external = { + xnn_external_value{input_id, convert_input.data()}, xnn_external_value{output_id, subgraph_output.data()}}; + ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data())); + ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); +} From ac93c3ebe92e27846e93906d3a1911950e7dc00a Mon Sep 17 00:00:00 2001 From: Max Ren Date: Thu, 5 Sep 2024 13:37:15 -0700 Subject: [PATCH 05/17] i8mm QP8 Kernels, Tests, Benchmarks, and Gemm Config --- bench/qp8-f32-qb4w-gemm.cc | 31 ++++++++++++ cmake/gen/neondot_aarch64_microkernels.cmake | 2 +- cmake/gen/neoni8mm_microkernels.cmake | 2 + gen/neondot_aarch64_microkernels.bzl | 2 +- gen/neoni8mm_microkernels.bzl | 2 + src/configs/gemm-config.c | 16 +++++- ...8-f32-qb4w-gemm-minmax-4x8c16s2-neoni8mm.c | 32 ++++++++++++ ...b4w-gemm-minmax-8x4c16s2-mstep2-neoni8mm.c | 32 ++++++++++++ src/xnnpack/gemm.h | 2 + test/qp8-f32-qb4w-gemm-minmax.cc | 49 +++++++++++++++++++ test/qp8-f32-qb4w-gemm-minmax.yaml | 14 ++++++ 11 files changed, 181 insertions(+), 3 deletions(-) create mode 100644 src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-4x8c16s2-neoni8mm.c create mode 100644 src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-8x4c16s2-mstep2-neoni8mm.c diff --git a/bench/qp8-f32-qb4w-gemm.cc b/bench/qp8-f32-qb4w-gemm.cc index 193f8ac48f8..75de5c9c8dc 100644 --- a/bench/qp8-f32-qb4w-gemm.cc +++ b/bench/qp8-f32-qb4w-gemm.cc @@ -19,6 +19,37 @@ #include "xnnpack/packw.h" +#if XNN_ENABLE_ARM_I8MM && XNN_ARCH_ARM64 + #if XNN_ENABLE_KLEIDIAI + static void 
qp8_f32_qb4w_gemm_minmax_ukernel_4x8c16s2__aarch64_neoni8mm(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qp8_f32_qb4w_gemm_minmax_ukernel_4x8c16s2__neoni8mm, + xnn_init_f32_qb4w_minmax_scalar_params, + xnn_pack_kai_qb4_weights_and_biases, + xnn_packed_stride_kai_qb4_weights_and_biases, + /*mr=*/4, /*nr=*/8, /*kr=*/16, /*sr=*/2, + /*mr_packed=*/4, + benchmark::utils::CheckNEONI8MM); + } + + BENCHMARK_GEMM_BL(qp8_f32_qb4w_gemm_minmax_ukernel_4x8c16s2__aarch64_neoni8mm) + + static void qp8_f32_qb4w_gemm_minmax_ukernel_8x4c16s2__aarch64_neoni8mm_mstep2(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qp8_f32_qb4w_gemm_minmax_ukernel_8x4c16s2__neoni8mm_mstep2, + xnn_init_f32_qb4w_minmax_scalar_params, + xnn_pack_kai_qb4_weights_and_biases, + xnn_packed_stride_kai_qb4_weights_and_biases, + /*mr=*/8, /*nr=*/4, /*kr=*/16, /*sr=*/2, + /*mr_packed=*/4, + benchmark::utils::CheckNEONI8MM); + } + + BENCHMARK_GEMM_BL(qp8_f32_qb4w_gemm_minmax_ukernel_8x4c16s2__aarch64_neoni8mm_mstep2) + #endif // XNN_ENABLE_KLEIDIAI +#endif // XNN_ENABLE_ARM_I8MM && XNN_ARCH_ARM64 + + #if XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 #if XNN_ENABLE_KLEIDIAI static void qp8_f32_qb4w_gemm_minmax_ukernel_1x4c16s2__aarch64_neondot(benchmark::State& state, const char* net) { diff --git a/cmake/gen/neondot_aarch64_microkernels.cmake b/cmake/gen/neondot_aarch64_microkernels.cmake index e25d17f73c3..0d54cdc6bfc 100644 --- a/cmake/gen/neondot_aarch64_microkernels.cmake +++ b/cmake/gen/neondot_aarch64_microkernels.cmake @@ -10,6 +10,7 @@ SET(PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS + src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x4c16s2-aarch64-neondot.c src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x8c16s2-aarch64-neondot.c src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-1x8c16s2-aarch64-neondot.c) @@ -18,7 +19,6 @@ SET(NON_PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c8-minmax-aarch64-neondot-ld128.c src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x8c8-minmax-aarch64-neondot-ld128.c src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c8-minmax-aarch64-neondot-ld128.c - src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x4c16s2-aarch64-neondot.c src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-1x4c16s2-aarch64-neondot.c src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-aarch64-neondot-ld128.c src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-aarch64-neondot-ld128.c diff --git a/cmake/gen/neoni8mm_microkernels.cmake b/cmake/gen/neoni8mm_microkernels.cmake index 56990174531..e3c3fbdc9d5 100644 --- a/cmake/gen/neoni8mm_microkernels.cmake +++ b/cmake/gen/neoni8mm_microkernels.cmake @@ -26,6 +26,7 @@ SET(PROD_NEONI8MM_MICROKERNEL_SRCS src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c8-minmax-neoni8mm.c src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c8-minmax-neoni8mm.c src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x16c8-minmax-neoni8mm.c + src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-8x4c16s2-mstep2-neoni8mm.c src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-8x8c16s2-mstep2-neoni8mm.c src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-neoni8mm.c src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c8-minmax-fp32-neoni8mm.c @@ -185,6 +186,7 @@ SET(NON_PROD_NEONI8MM_MICROKERNEL_SRCS src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-6x16c8-minmax-neoni8mm.c src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-neoni8mm.c src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x16c8-minmax-neoni8mm.c + 
src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-4x8c16s2-neoni8mm.c src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-4x4c16s2-neoni8mm.c src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-4x8c16s2-neoni8mm.c src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-8x4c16s2-mstep2-neoni8mm.c diff --git a/gen/neondot_aarch64_microkernels.bzl b/gen/neondot_aarch64_microkernels.bzl index 6d2c41fe275..ec73493e4d9 100644 --- a/gen/neondot_aarch64_microkernels.bzl +++ b/gen/neondot_aarch64_microkernels.bzl @@ -6,6 +6,7 @@ Auto-generated file. Do not edit! """ PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS = [ + "src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x4c16s2-aarch64-neondot.c", "src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x8c16s2-aarch64-neondot.c", "src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-1x8c16s2-aarch64-neondot.c", ] @@ -15,7 +16,6 @@ NON_PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS = [ "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c8-minmax-aarch64-neondot-ld128.c", "src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x8c8-minmax-aarch64-neondot-ld128.c", "src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c8-minmax-aarch64-neondot-ld128.c", - "src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x4c16s2-aarch64-neondot.c", "src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-1x4c16s2-aarch64-neondot.c", "src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-aarch64-neondot-ld128.c", "src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-aarch64-neondot-ld128.c", diff --git a/gen/neoni8mm_microkernels.bzl b/gen/neoni8mm_microkernels.bzl index 8fe48f6099a..2898b08e259 100644 --- a/gen/neoni8mm_microkernels.bzl +++ b/gen/neoni8mm_microkernels.bzl @@ -22,6 +22,7 @@ PROD_NEONI8MM_MICROKERNEL_SRCS = [ "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c8-minmax-neoni8mm.c", "src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c8-minmax-neoni8mm.c", "src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x16c8-minmax-neoni8mm.c", + "src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-8x4c16s2-mstep2-neoni8mm.c", "src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-8x8c16s2-mstep2-neoni8mm.c", "src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-neoni8mm.c", "src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c8-minmax-fp32-neoni8mm.c", @@ -182,6 +183,7 @@ NON_PROD_NEONI8MM_MICROKERNEL_SRCS = [ "src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-6x16c8-minmax-neoni8mm.c", "src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-neoni8mm.c", "src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x16c8-minmax-neoni8mm.c", + "src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-4x8c16s2-neoni8mm.c", "src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-4x4c16s2-neoni8mm.c", "src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-4x8c16s2-neoni8mm.c", "src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-8x4c16s2-mstep2-neoni8mm.c", diff --git a/src/configs/gemm-config.c b/src/configs/gemm-config.c index 1f5c732c04f..f924282d252 100644 --- a/src/configs/gemm-config.c +++ b/src/configs/gemm-config.c @@ -1738,7 +1738,21 @@ static void init_qp8_f32_qb4w_gemm_config(void) { const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); - if (XNN_ENABLE_ARM_DOTPROD && hardware_config->use_arm_neon_dot) { + if (XNN_ENABLE_ARM_I8MM && hardware_config->use_arm_neon_i8mm) { +#if XNN_ENABLE_ARM_I8MM + qp8_f32_qb4w_gemm_config.minmax.qp8gemm_bl[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_qp8gemm_bl_ukernel(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x4c16s2__aarch64_neondot); + qp8_f32_qb4w_gemm_config.minmax.qp8gemm_bl[XNN_MR_TO_INDEX(8)] = 
xnn_init_hmp_qp8gemm_bl_ukernel(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_8x4c16s2__aarch64_neoni8mm_mstep2); + qp8_f32_qb4w_gemm_config.init.f32_qb4w = xnn_init_f32_qb4w_minmax_scalar_params; + qp8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_kai_qb4_weights_and_biases; + qp8_f32_qb4w_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_kai_qb4_weights_and_biases; + qp8_f32_qb4w_gemm_config.mr = 8; + qp8_f32_qb4w_gemm_config.nr = 4; + qp8_f32_qb4w_gemm_config.log2_kr = 4; + qp8_f32_qb4w_gemm_config.log2_sr = 1; + qp8_f32_qb4w_gemm_config.planes = 2; + qp8_f32_qb4w_gemm_config.mr_packed = 4; +#endif // XNN_ENABLE_ARM_I8MM + } else if (XNN_ENABLE_ARM_DOTPROD && hardware_config->use_arm_neon_dot) { #if XNN_ENABLE_ARM_DOTPROD qp8_f32_qb4w_gemm_config.minmax.qp8gemm_bl[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_qp8gemm_bl_ukernel(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot); qp8_f32_qb4w_gemm_config.init.f32_qb4w = xnn_init_f32_qb4w_minmax_scalar_params; diff --git a/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-4x8c16s2-neoni8mm.c b/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-4x8c16s2-neoni8mm.c new file mode 100644 index 00000000000..327ab7279c6 --- /dev/null +++ b/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-4x8c16s2-neoni8mm.c @@ -0,0 +1,32 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include "xnnpack/log.h" +#include "xnnpack/microparams.h" + +#if XNN_ENABLE_KLEIDIAI +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h" +#endif // XNN_ENABLE_KLEIDIAI + +// Wraps the +// `kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm` GEMM +// microkernel with a name that is compatible with our tooling. +void xnn_qp8_f32_qb4w_gemm_minmax_ukernel_4x8c16s2__neoni8mm( + size_t m, size_t n, size_t k, const void* lhs_packed, + const void* rhs_packed, float* dst, size_t dst_stride_row, + size_t dst_stride_col, + const struct xnn_f32_qb4w_minmax_params + minmax_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]) { +#if XNN_ENABLE_KLEIDIAI + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( + m, n, k, minmax_params->scalar.blocksize, lhs_packed, rhs_packed, dst, dst_stride_row, dst_stride_col, + minmax_params->scalar.min, minmax_params->scalar.max); +#else + xnn_log_fatal( + "Calling KleidiAI microkernel wrapper, but XNNPACK was compiled without " + "`XNN_ENABLE_KLEIDIAI`."); +#endif // XNN_ENABLE_KLEIDIAI +} diff --git a/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-8x4c16s2-mstep2-neoni8mm.c b/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-8x4c16s2-mstep2-neoni8mm.c new file mode 100644 index 00000000000..8fc32aab2aa --- /dev/null +++ b/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-8x4c16s2-mstep2-neoni8mm.c @@ -0,0 +1,32 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include "xnnpack/log.h" +#include "xnnpack/microparams.h" + +#if XNN_ENABLE_KLEIDIAI +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h" +#endif // XNN_ENABLE_KLEIDIAI + +// Wraps the +// `kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm` GEMM +// microkernel with a name that is compatible with our tooling. 
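// Naming sketch: "8x4c16s2" appears to encode mr=8, nr=4, kr=16, sr=2, and
// "mstep2" a kernel that consumes an LHS packed with mr_packed = mr/2 = 4,
// consistent with the mr=8, nr=4, log2_kr=4, log2_sr=1, mr_packed=4 values
// registered for this kernel in init_qp8_f32_qb4w_gemm_config() above.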
+void xnn_qp8_f32_qb4w_gemm_minmax_ukernel_8x4c16s2__neoni8mm_mstep2( + size_t m, size_t n, size_t k, const void* lhs_packed, + const void* rhs_packed, float* dst, size_t dst_stride_row, + size_t dst_stride_col, + const struct xnn_f32_qb4w_minmax_params + minmax_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]) { +#if XNN_ENABLE_KLEIDIAI + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( + m, n, k, minmax_params->scalar.blocksize, lhs_packed, rhs_packed, dst, dst_stride_row, dst_stride_col, + minmax_params->scalar.min, minmax_params->scalar.max); +#else + xnn_log_fatal( + "Calling KleidiAI microkernel wrapper, but XNNPACK was compiled without " + "`XNN_ENABLE_KLEIDIAI`."); +#endif // XNN_ENABLE_KLEIDIAI +} diff --git a/src/xnnpack/gemm.h b/src/xnnpack/gemm.h index 29a0a834221..72c81ea8277 100644 --- a/src/xnnpack/gemm.h +++ b/src/xnnpack/gemm.h @@ -2267,6 +2267,8 @@ DECLARE_QP8_F32_QC4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qp8_f32_qc4w_gemm_minmax_u DECLARE_QP8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x4c16s2__aarch64_neondot) DECLARE_QP8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot) +DECLARE_QP8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_4x8c16s2__neoni8mm) +DECLARE_QP8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_8x4c16s2__neoni8mm_mstep2) #define DECLARE_QD8_F16_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(fn_name) \ diff --git a/test/qp8-f32-qb4w-gemm-minmax.cc b/test/qp8-f32-qb4w-gemm-minmax.cc index 2dc63d41508..033592c3d95 100644 --- a/test/qp8-f32-qb4w-gemm-minmax.cc +++ b/test/qp8-f32-qb4w-gemm-minmax.cc @@ -153,3 +153,52 @@ std::vector CreateTests1( #endif // XNN_ENABLE_KLEIDIAI #endif // XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 + + +#if XNN_ENABLE_ARM_I8MM && XNN_ARCH_ARM64 + #if XNN_ENABLE_KLEIDIAI + INSTANTIATE_TEST_SUITE_P( + QP8_F32_QB4W_GEMM_MINMAX_4X8C16S2__AARCH64_NEONI8MM, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/32, + /*adj_k_block=*/32, + /*mr=*/4, /*nr=*/8, /*kr=*/16, /*sr=*/2, + /*mr_packed=*/4, + /*is_igemm=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_4x8c16s2__neoni8mm, + xnn_init_f32_qb4w_minmax_scalar_params, + xnn_pack_kai_qb4_weights_and_biases, + xnn_packed_stride_kai_qb4_weights_and_biases); + }, + []() { + TEST_REQUIRES_ARM_NEON_I8MM; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + + INSTANTIATE_TEST_SUITE_P( + QP8_F32_QB4W_GEMM_MINMAX_8X4C16S2__AARCH64_NEONI8MM_MSTEP2, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/32, + /*adj_k_block=*/32, + /*mr=*/8, /*nr=*/4, /*kr=*/16, /*sr=*/2, + /*mr_packed=*/4, + /*is_igemm=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_8x4c16s2__neoni8mm_mstep2, + xnn_init_f32_qb4w_minmax_scalar_params, + xnn_pack_kai_qb4_weights_and_biases, + xnn_packed_stride_kai_qb4_weights_and_biases); + }, + []() { + TEST_REQUIRES_ARM_NEON_I8MM; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + #endif // XNN_ENABLE_KLEIDIAI +#endif // XNN_ENABLE_ARM_I8MM && XNN_ARCH_ARM64 diff --git a/test/qp8-f32-qb4w-gemm-minmax.yaml b/test/qp8-f32-qb4w-gemm-minmax.yaml index 5c7c3b41d1d..bd44ddc3dcb 100644 --- a/test/qp8-f32-qb4w-gemm-minmax.yaml +++ b/test/qp8-f32-qb4w-gemm-minmax.yaml @@ -17,3 +17,17 @@ packed-stride: xnn_packed_stride_kai_qb4_weights_and_biases k-block: 32 cpp-check: 
XNN_ENABLE_KLEIDIAI + +- name: xnn_qp8_f32_qb4w_gemm_minmax_ukernel_4x8c16s2__neoni8mm + init: xnn_init_f32_qb4w_minmax_scalar_params + pack: xnn_pack_kai_qb4_weights_and_biases + packed-stride: xnn_packed_stride_kai_qb4_weights_and_biases + k-block: 32 + cpp-check: XNN_ENABLE_KLEIDIAI + +- name: xnn_qp8_f32_qb4w_gemm_minmax_ukernel_8x4c16s2__neoni8mm_mstep2 + init: xnn_init_f32_qb4w_minmax_scalar_params + pack: xnn_pack_kai_qb4_weights_and_biases + packed-stride: xnn_packed_stride_kai_qb4_weights_and_biases + k-block: 32 + cpp-check: XNN_ENABLE_KLEIDIAI From 379bd1c3f3200b43d83380f13008a87566204674 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Fri, 6 Sep 2024 00:49:30 -0700 Subject: [PATCH 06/17] Add Benchmarking for Kernels --- CMakeLists.txt | 1 + bench/gemm-benchmark.cc | 132 +++++++++++++++++++++++++++++++++++++++- bench/gemm-benchmark.h | 9 +++ 3 files changed, 141 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6a7c7a2358f..6551c0ef8d9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2039,6 +2039,7 @@ IF(XNNPACK_BUILD_BENCHMARKS) qd8-f32-qc4w-gemm qd8-f32-qc8w-gemm qp8-f32-qc4w-gemm + qp8-f32-qb4w-gemm qs16-qs8-vcvt qs8-dwconv qs8-f16-vcvt diff --git a/bench/gemm-benchmark.cc b/bench/gemm-benchmark.cc index ccb0b219bcf..646df889a11 100644 --- a/bench/gemm-benchmark.cc +++ b/bench/gemm-benchmark.cc @@ -879,6 +879,137 @@ void GEMMBenchmark(benchmark::State& state, benchmark::Counter::kIsRate); } + +void GEMMBenchmark(benchmark::State& state, + xnn_qp8_f32_qb4w_gemm_minmax_ukernel_fn gemm, + xnn_init_f32_qb4w_minmax_params_fn init_params, + xnn_pack_weights_and_biases_fn pack_weights, + xnn_packed_stride_weights_and_biases_fn packed_stride, + size_t mr, size_t nr, size_t kr, size_t sr, size_t mr_packed, + benchmark::utils::IsaCheckFunction isa_check) { + if (isa_check != nullptr && !isa_check(state)) { + return; + } + + const size_t mc = state.range(0); + const size_t nc = state.range(1); + const size_t bl = state.range(3); + const size_t kc = round_up(state.range(2), 2UL); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto f32rng = std::bind(std::uniform_real_distribution(-10.0f, 10.0f), + std::ref(rng)); + auto u8rng = std::bind(std::uniform_int_distribution( + 0, std::numeric_limits::max()), + std::ref(rng)); + auto scalerng = std::bind(std::uniform_real_distribution(0.5f, 2.f), + std::ref(rng)); + + const size_t planes = 2; // 4 bit is 2 planes - low nibbles and high nibbles + const size_t k2 = round_up_po2(kc, 2); // tester assumes byte aligned rows + const size_t packed_k2 = + round_up_po2(kc, kr * sr * planes); // 2 blocks for nibbles + + const size_t packed_k_bytes = (packed_k2 + 1) / 2; + const size_t num_blocks = packed_k2 / bl; + const size_t packed_n = round_up_po2(nc, nr); + + std::vector a(mc * k2); + std::generate(a.begin(), a.end(), std::ref(f32rng)); + std::vector k(nc * k2 / 2); + std::generate(k.begin(), k.end(), std::ref(u8rng)); + + // Create a fake `gemm_config` for the packing functions. 
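// The packing callbacks used below read the tile geometry (nr, log2_kr,
// log2_sr) from this config; for the power-of-two kr/sr used here,
// 31 - clz(x) recovers log2(x), e.g. kr = 16 gives log2_kr = 4.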
+ struct xnn_gemm_config gemm_config; + gemm_config.mr = static_cast(mr); + gemm_config.mr_packed = static_cast(mr_packed); + gemm_config.nr = static_cast(nr); + gemm_config.log2_kr = static_cast(31 - math_clz_nonzero_u32(kr)); + gemm_config.log2_sr = static_cast(31 - math_clz_nonzero_u32(sr)); + + const size_t packed_w_stride = + packed_stride(&gemm_config, k2, /*k_stride=*/bl, /*extra_bytes=*/0); + const size_t packed_w_size = packed_w_stride * round_up(nc, nr); + + const size_t c_elements = mc * nc; + const size_t num_buffers = + 1 + benchmark::utils::DivideRoundUp( + benchmark::utils::GetMaxCacheSize(), + sizeof(float) * (packed_w_size + c_elements)); + + std::vector> w(packed_w_size * num_buffers); + std::fill(w.begin(), w.end(), 0); + + // Quantize the left-hand operand. + const size_t input_packed_size = + xnn_x8_packq_f32qp8_packed_size(mc, k2, mr_packed, kr, sr); + std::vector input_qp8(input_packed_size); + xnn_x8_packq_f32qp8_ukernel__scalar_u1(mc, k2, mr_packed, kr, sr, + /*m_idx_start=*/0, a.data(), + /*lhs_stride=*/k2 * sizeof(float), + input_qp8.data()); + + // RHS packing + std::vector kernel_scale2d(nc * k2 / bl); + std::generate(kernel_scale2d.begin(), kernel_scale2d.end(), + [&]() { return math_cvt_bf16_fp32(scalerng()); }); + const xnn_qs8_qc4w_packing_params packing_params = {/*input_zero_point=*/1, + /*kernel_zero_point=*/8}; + pack_weights(/*flags=*/0, &gemm_config, k2, nc, + /*groups=*/1, /*k_stride=*/bl, + /*accumulator_init=*/nullptr, + /*weights=*/k.data(), + /*int_extra_data0_fn=*/nullptr, + /*extra_data0=*/nullptr, + /*extra_data0_size=*/0, + /*init_extra_data1_fn=*/ + nullptr, + /*extra_data1=*/kernel_scale2d.data(), + /*extra_data1_size=*/sizeof(float), + /*packed_weights_ptr=*/w.data(), &packing_params); + + std::vector c(c_elements * num_buffers); + std::fill(c.begin(), c.end(), std::nanf("")); + + // Prepare parameters. 
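// init_params fills the qb4w params with the output clamp bounds, the 4-bit
// kernel zero point (8 here) and the block size. Note that
// std::numeric_limits<float>::min() is the smallest positive normal float,
// so the lower clamp ends up at roughly 1.2e-38 rather than at -FLT_MAX.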
+ xnn_f32_qb4w_minmax_params minmax_params; + init_params(&minmax_params, std::numeric_limits::min(), + std::numeric_limits::max(), 8, bl); + + size_t buffer_index = 0; + for (auto _ : state) { + // Use circular buffers (exceeding cache size) and prefetch to control cache + // state: + // - A is always in L1 cache (if fits, otherwise L2, L3, etc) + // - W is not in cache (for any cache level) + // - C is not in cache (for any cache level) + state.PauseTiming(); + benchmark::utils::PrefetchToL1(a.data(), a.size()); + buffer_index = (buffer_index + 1) % num_buffers; + state.ResumeTiming(); + + for (uint32_t m = 0; m < mc; m += mr) { + const uint32_t mb = min(mc - m, mr); + gemm(mb, nc, kc, + input_qp8.data() + + xnn_x8_packq_f32qp8_packed_offset(m, kc, mr, kr, sr), + w.data() + packed_w_size * buffer_index, + c.data() + (buffer_index * mc + m) * nc, nc * sizeof(float), + sizeof(float), &minmax_params); + } + } + + const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency(); + if (cpu_frequency != 0) { + state.counters["cpufreq"] = cpu_frequency; + } + + state.counters["OPS"] = benchmark::Counter( + static_cast(state.iterations()) * 2 * mc * nc * kc, + benchmark::Counter::kIsRate); +} + void GEMMBenchmark(benchmark::State& state, xnn_qu8_gemm_minmax_ukernel_fn gemm, xnn_init_qu8_conv_minmax_params_fn init_params, xnn_pack_qu8_gemm_fn pack, size_t mr, size_t nr, size_t kr, @@ -1194,4 +1325,3 @@ void GEMMBenchmark(benchmark::State& state, xnn_f16_gemm_minmax_ukernel_fn gemm, benchmark::Counter(uint64_t(state.iterations()) * 2 * mc * nc * kc, benchmark::Counter::kIsRate); } - diff --git a/bench/gemm-benchmark.h b/bench/gemm-benchmark.h index 942b76217c3..b582d08aa3c 100644 --- a/bench/gemm-benchmark.h +++ b/bench/gemm-benchmark.h @@ -80,6 +80,15 @@ void GEMMBenchmark(benchmark::State& state, size_t nr, size_t kr, size_t sr, size_t mr_packed, benchmark::utils::IsaCheckFunction isa_check); +void GEMMBenchmark(benchmark::State& state, + xnn_qp8_f32_qb4w_gemm_minmax_ukernel_fn gemm, + xnn_init_f32_qb4w_minmax_params_fn init_params, + xnn_pack_weights_and_biases_fn pack_weights, + xnn_packed_stride_weights_and_biases_fn packed_stride, + size_t mr, + size_t nr, size_t kr, size_t sr, size_t mr_packed, + benchmark::utils::IsaCheckFunction isa_check); + void GEMMBenchmark(benchmark::State& state, xnn_qu8_gemm_minmax_ukernel_fn gemm, xnn_init_qu8_conv_minmax_params_fn init_params, xnn_pack_qu8_gemm_fn pack, size_t mr, size_t nr, size_t kr, From e64df81e04295f27304cd210a394ee85ffdf9990 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Fri, 6 Sep 2024 19:35:28 -0700 Subject: [PATCH 07/17] fix init_qp8_qb4_config --- src/configs/gemm-config.c | 68 +++++++++++++++++++-------------------- src/subgraph/convert.c | 3 +- 2 files changed, 36 insertions(+), 35 deletions(-) diff --git a/src/configs/gemm-config.c b/src/configs/gemm-config.c index f924282d252..079622e7e32 100644 --- a/src/configs/gemm-config.c +++ b/src/configs/gemm-config.c @@ -1734,39 +1734,39 @@ static void init_qp8_f32_qc4w_gemm_config(void) { } static void init_qp8_f32_qb4w_gemm_config(void) { -#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI - const struct xnn_hardware_config* hardware_config = - xnn_init_hardware_config(); - assert(hardware_config != NULL); - if (XNN_ENABLE_ARM_I8MM && hardware_config->use_arm_neon_i8mm) { -#if XNN_ENABLE_ARM_I8MM - qp8_f32_qb4w_gemm_config.minmax.qp8gemm_bl[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_qp8gemm_bl_ukernel(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x4c16s2__aarch64_neondot); - 
qp8_f32_qb4w_gemm_config.minmax.qp8gemm_bl[XNN_MR_TO_INDEX(8)] = xnn_init_hmp_qp8gemm_bl_ukernel(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_8x4c16s2__aarch64_neoni8mm_mstep2); - qp8_f32_qb4w_gemm_config.init.f32_qb4w = xnn_init_f32_qb4w_minmax_scalar_params; - qp8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_kai_qb4_weights_and_biases; - qp8_f32_qb4w_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_kai_qb4_weights_and_biases; - qp8_f32_qb4w_gemm_config.mr = 8; - qp8_f32_qb4w_gemm_config.nr = 4; - qp8_f32_qb4w_gemm_config.log2_kr = 4; - qp8_f32_qb4w_gemm_config.log2_sr = 1; - qp8_f32_qb4w_gemm_config.planes = 2; - qp8_f32_qb4w_gemm_config.mr_packed = 4; -#endif // XNN_ENABLE_ARM_I8MM - } else if (XNN_ENABLE_ARM_DOTPROD && hardware_config->use_arm_neon_dot) { -#if XNN_ENABLE_ARM_DOTPROD - qp8_f32_qb4w_gemm_config.minmax.qp8gemm_bl[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_qp8gemm_bl_ukernel(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot); - qp8_f32_qb4w_gemm_config.init.f32_qb4w = xnn_init_f32_qb4w_minmax_scalar_params; - qp8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_kai_qb4_weights_and_biases; - qp8_f32_qb4w_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_kai_qb4_weights_and_biases; - qp8_f32_qb4w_gemm_config.mr = 1; - qp8_f32_qb4w_gemm_config.nr = 8; - qp8_f32_qb4w_gemm_config.log2_kr = 4; - qp8_f32_qb4w_gemm_config.log2_sr = 1; - qp8_f32_qb4w_gemm_config.planes = 2; - qp8_f32_qb4w_gemm_config.mr_packed = 1; -#endif // XNN_ENABLE_ARM_DOTPROD - } -#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI + #if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI + const struct xnn_hardware_config* hardware_config = + xnn_init_hardware_config(); + assert(hardware_config != NULL); + if (XNN_ENABLE_ARM_I8MM && hardware_config->use_arm_neon_i8mm) { + #if XNN_ENABLE_ARM_I8MM + qp8_f32_qb4w_gemm_config.minmax.qp8gemm_bl[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_qp8gemm_bl_ukernel(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x4c16s2__aarch64_neondot); + qp8_f32_qb4w_gemm_config.minmax.qp8gemm_bl[XNN_MR_TO_INDEX(8)] = xnn_init_hmp_qp8gemm_bl_ukernel(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_8x4c16s2__neoni8mm_mstep2); + qp8_f32_qb4w_gemm_config.init.f32_qb4w = xnn_init_f32_qb4w_minmax_scalar_params; + qp8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_kai_qb4_weights_and_biases; + qp8_f32_qb4w_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_kai_qb4_weights_and_biases; + qp8_f32_qb4w_gemm_config.mr = 8; + qp8_f32_qb4w_gemm_config.nr = 4; + qp8_f32_qb4w_gemm_config.log2_kr = 4; + qp8_f32_qb4w_gemm_config.log2_sr = 1; + qp8_f32_qb4w_gemm_config.planes = 2; + qp8_f32_qb4w_gemm_config.mr_packed = 4; + #endif // XNN_ENABLE_ARM_I8MM + } else if (XNN_ENABLE_ARM_DOTPROD && hardware_config->use_arm_neon_dot) { + #if XNN_ENABLE_ARM_DOTPROD + qp8_f32_qb4w_gemm_config.minmax.qp8gemm_bl[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_qp8gemm_bl_ukernel(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot); + qp8_f32_qb4w_gemm_config.init.f32_qb4w = xnn_init_f32_qb4w_minmax_scalar_params; + qp8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_kai_qb4_weights_and_biases; + qp8_f32_qb4w_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_kai_qb4_weights_and_biases; + qp8_f32_qb4w_gemm_config.mr = 1; + qp8_f32_qb4w_gemm_config.nr = 8; + qp8_f32_qb4w_gemm_config.log2_kr = 4; + qp8_f32_qb4w_gemm_config.log2_sr = 1; + qp8_f32_qb4w_gemm_config.planes = 2; + qp8_f32_qb4w_gemm_config.mr_packed = 1; + #endif // XNN_ENABLE_ARM_DOTPROD + } + #endif // XNN_ARCH_ARM64 
&& XNN_ENABLE_KLEIDIAI } static void init_qd8_f32_qb4w_gemm_config(void) { @@ -3926,7 +3926,7 @@ const struct xnn_gemm_config* xnn_init_qp8_f32_qb4w_gemm_config() { } XNN_INIT_ONCE(qp8_f32_qb4w_gemm); // Only return the config pointer if it actually provides a kernel. - if (qp8_f32_qb4w_gemm_config.minmax.qp8gemm[0].function[0] != NULL) { + if (qp8_f32_qb4w_gemm_config.minmax.qp8gemm_bl[0].function[0] != NULL) { return &qp8_f32_qb4w_gemm_config; } return NULL; diff --git a/src/subgraph/convert.c b/src/subgraph/convert.c index c1940f00b6b..76b70e98a6c 100644 --- a/src/subgraph/convert.c +++ b/src/subgraph/convert.c @@ -530,7 +530,8 @@ enum xnn_status xnn_define_convert( if ((flags & XNN_FLAG_MAYBE_PACK_FOR_GEMM) && input_value->datatype == xnn_datatype_fp32 && output_value->datatype == xnn_datatype_qdint8 && - xnn_init_qp8_f32_qc4w_gemm_config() != NULL) { + (xnn_init_qp8_f32_qc4w_gemm_config() != NULL || + xnn_init_qp8_f32_qb4w_gemm_config() != NULL)) { xnn_log_debug("Coercing type of output ID #%" PRIu32 " of %s operator from `%s` to `%s`.", output_id, xnn_node_type_to_string(xnn_node_type_convert), From ba0d2b583d6364bc07d4ba7b59412f4efc9c827f Mon Sep 17 00:00:00 2001 From: Max Ren Date: Sun, 22 Sep 2024 22:25:53 -0700 Subject: [PATCH 08/17] Remove log from ukernels --- .../qp8-f32-qb4w-gemm-minmax-1x4c16s2-aarch64-neondot.c | 5 ++--- .../qp8-f32-qb4w-gemm-minmax-1x8c16s2-aarch64-neondot.c | 5 ++--- .../qp8-f32-qb4w-gemm-minmax-4x8c16s2-neoni8mm.c | 5 ++--- .../qp8-f32-qb4w-gemm-minmax-8x4c16s2-mstep2-neoni8mm.c | 5 ++--- 4 files changed, 8 insertions(+), 12 deletions(-) diff --git a/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x4c16s2-aarch64-neondot.c b/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x4c16s2-aarch64-neondot.c index f505c42b583..0ee26a8fc93 100644 --- a/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x4c16s2-aarch64-neondot.c +++ b/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x4c16s2-aarch64-neondot.c @@ -4,7 +4,6 @@ // LICENSE file in the root directory of this source tree. #include -#include "xnnpack/log.h" #include "xnnpack/microparams.h" #if XNN_ENABLE_KLEIDIAI @@ -25,8 +24,8 @@ void xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x4c16s2__aarch64_neondot( m, n, k, minmax_params->scalar.blocksize, lhs_packed, rhs_packed, dst, dst_stride_row, dst_stride_col, minmax_params->scalar.min, minmax_params->scalar.max); #else - xnn_log_fatal( + assert( "Calling KleidiAI microkernel wrapper, but XNNPACK was compiled without " - "`XNN_ENABLE_KLEIDIAI`."); + "`XNN_ENABLE_KLEIDIAI`." && 0); #endif // XNN_ENABLE_KLEIDIAI } diff --git a/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x8c16s2-aarch64-neondot.c b/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x8c16s2-aarch64-neondot.c index 9fa6c95889b..f5903c86097 100644 --- a/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x8c16s2-aarch64-neondot.c +++ b/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x8c16s2-aarch64-neondot.c @@ -4,7 +4,6 @@ // LICENSE file in the root directory of this source tree. #include -#include "xnnpack/log.h" #include "xnnpack/microparams.h" #if XNN_ENABLE_KLEIDIAI @@ -25,8 +24,8 @@ void xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot( m, n, k, minmax_params->scalar.blocksize, lhs_packed, rhs_packed, dst, dst_stride_row, dst_stride_col, minmax_params->scalar.min, minmax_params->scalar.max); #else - xnn_log_fatal( + assert( "Calling KleidiAI microkernel wrapper, but XNNPACK was compiled without " - "`XNN_ENABLE_KLEIDIAI`."); + "`XNN_ENABLE_KLEIDIAI`." 
&& 0); #endif // XNN_ENABLE_KLEIDIAI } diff --git a/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-4x8c16s2-neoni8mm.c b/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-4x8c16s2-neoni8mm.c index 327ab7279c6..b532fb72ec7 100644 --- a/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-4x8c16s2-neoni8mm.c +++ b/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-4x8c16s2-neoni8mm.c @@ -4,7 +4,6 @@ // LICENSE file in the root directory of this source tree. #include -#include "xnnpack/log.h" #include "xnnpack/microparams.h" #if XNN_ENABLE_KLEIDIAI @@ -25,8 +24,8 @@ void xnn_qp8_f32_qb4w_gemm_minmax_ukernel_4x8c16s2__neoni8mm( m, n, k, minmax_params->scalar.blocksize, lhs_packed, rhs_packed, dst, dst_stride_row, dst_stride_col, minmax_params->scalar.min, minmax_params->scalar.max); #else - xnn_log_fatal( + assert( "Calling KleidiAI microkernel wrapper, but XNNPACK was compiled without " - "`XNN_ENABLE_KLEIDIAI`."); + "`XNN_ENABLE_KLEIDIAI`." && 0); #endif // XNN_ENABLE_KLEIDIAI } diff --git a/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-8x4c16s2-mstep2-neoni8mm.c b/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-8x4c16s2-mstep2-neoni8mm.c index 8fc32aab2aa..c449df40893 100644 --- a/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-8x4c16s2-mstep2-neoni8mm.c +++ b/src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-8x4c16s2-mstep2-neoni8mm.c @@ -4,7 +4,6 @@ // LICENSE file in the root directory of this source tree. #include -#include "xnnpack/log.h" #include "xnnpack/microparams.h" #if XNN_ENABLE_KLEIDIAI @@ -25,8 +24,8 @@ void xnn_qp8_f32_qb4w_gemm_minmax_ukernel_8x4c16s2__neoni8mm_mstep2( m, n, k, minmax_params->scalar.blocksize, lhs_packed, rhs_packed, dst, dst_stride_row, dst_stride_col, minmax_params->scalar.min, minmax_params->scalar.max); #else - xnn_log_fatal( + assert( "Calling KleidiAI microkernel wrapper, but XNNPACK was compiled without " - "`XNN_ENABLE_KLEIDIAI`."); + "`XNN_ENABLE_KLEIDIAI`." && 0); #endif // XNN_ENABLE_KLEIDIAI } From 7c92a68425a0c7781a89caa1cd4cee59a7450e7d Mon Sep 17 00:00:00 2001 From: Max Ren Date: Sun, 22 Sep 2024 22:46:07 -0700 Subject: [PATCH 09/17] add XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM flag --- src/subgraph/convert.c | 14 ++++++++++---- src/xnnpack/internal.h | 1 + test/fully-connected.cc | 2 +- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/subgraph/convert.c b/src/subgraph/convert.c index 76b70e98a6c..16ccb7bd015 100644 --- a/src/subgraph/convert.c +++ b/src/subgraph/convert.c @@ -527,11 +527,17 @@ enum xnn_status xnn_define_convert( // available. // TODO(b/340399245) - Remove xnn_init_qp8_f32_qc4w_gemm_config check once we // have full qp8 support. 
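A minimal usage sketch of the new flag (the value ids below are placeholders, mirroring the test change further down in this patch): a caller that wants the convert output packed for the blockwise qb4w GEMM passes the dedicated flag instead of XNN_FLAG_MAYBE_PACK_FOR_GEMM.

    // Sketch only; fp32_input_id and dq_output_id are hypothetical value ids.
    xnn_define_convert(subgraph, fp32_input_id, dq_output_id,
                       /*flags=*/XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM);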
- if ((flags & XNN_FLAG_MAYBE_PACK_FOR_GEMM) && + bool pack_activation_for_qc4w = ( + (flags & XNN_FLAG_MAYBE_PACK_FOR_GEMM) && + xnn_init_qp8_f32_qc4w_gemm_config() != NULL + ); + bool pack_activation_for_qb4w = ( + (flags & XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM) && + xnn_init_qp8_f32_qb4w_gemm_config() != NULL + ); + if ((pack_activation_for_qb4w || pack_activation_for_qc4w) && input_value->datatype == xnn_datatype_fp32 && - output_value->datatype == xnn_datatype_qdint8 && - (xnn_init_qp8_f32_qc4w_gemm_config() != NULL || - xnn_init_qp8_f32_qb4w_gemm_config() != NULL)) { + output_value->datatype == xnn_datatype_qdint8) { xnn_log_debug("Coercing type of output ID #%" PRIu32 " of %s operator from `%s` to `%s`.", output_id, xnn_node_type_to_string(xnn_node_type_convert), diff --git a/src/xnnpack/internal.h b/src/xnnpack/internal.h index 7f7f1550f30..369eaaf5b0b 100644 --- a/src/xnnpack/internal.h +++ b/src/xnnpack/internal.h @@ -20,6 +20,7 @@ extern "C" { /// If set, try to pack the quantized values for use by a GEMM. #define XNN_FLAG_MAYBE_PACK_FOR_GEMM 0x00000080 +#define XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM 0x00000100 enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qc4w( size_t input_channels, // diff --git a/test/fully-connected.cc b/test/fully-connected.cc index afdda2a403c..85be3af0225 100644 --- a/test/fully-connected.cc +++ b/test/fully-connected.cc @@ -4397,7 +4397,7 @@ TEST_F(FullyConnectedTestQP8F32QB4W, internally_allocated_dynamic_quantization_p ASSERT_NE(output_id, XNN_INVALID_NODE_ID); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_define_convert(subgraph, input_id, dq_quantized_id, /*flags=*/XNN_FLAG_MAYBE_PACK_FOR_GEMM)); + ASSERT_EQ(xnn_status_success, xnn_define_convert(subgraph, input_id, dq_quantized_id, /*flags=*/XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM)); ASSERT_EQ(xnn_status_success, xnn_define_fully_connected(subgraph, output_min, output_max, dq_quantized_id, kernel_id, bias_id, output_id, /*flags=*/0)); ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); From 3c19d8cf2f19d7ab25f13424003f5bb6726c9a5e Mon Sep 17 00:00:00 2001 From: Max Ren Date: Sun, 22 Sep 2024 22:59:10 -0700 Subject: [PATCH 10/17] rm xnn_compute_*_bl compute fns --- src/operator-run.c | 62 ------------------------------ src/operators/fully-connected-nc.c | 46 ++++------------------ src/xnnpack/compute.h | 28 -------------- 3 files changed, 7 insertions(+), 129 deletions(-) diff --git a/src/operator-run.c b/src/operator-run.c index 850de62f8a2..ec9d1cb0a9f 100644 --- a/src/operator-run.c +++ b/src/operator-run.c @@ -529,68 +529,6 @@ void xnn_compute_qp8gemm( nr_block_start, mr_block_size, nr_block_size); } -void xnn_compute_hmp_qp8gemm_bl( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start, - size_t mr_block_size, size_t nr_block_size) { - const size_t a_offset = xnn_x8_packq_f32qp8_packed_offset( - mr_block_start, context->k_scaled, context->mr, context->kr, context->sr); - const size_t cm_stride = context->cm_stride; - - context->qp8_bl_ukernel.function[uarch_index]( - mr_block_size, nr_block_size, context->k_scaled, - (const void*)((uintptr_t)context->a + a_offset), - (const void*)((uintptr_t)context->packed_w + - nr_block_start * context->w_stride), - (void*)((uintptr_t)context->c + mr_block_start * cm_stride + - (nr_block_start << context->log2_csize)), - cm_stride, - /*dst_stride_col=*/sizeof(float), context->fused_params); -} - -void 
xnn_compute_qp8gemm_bl( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t mr_block_start, size_t nr_block_start, size_t mr_block_size, - size_t nr_block_size) { - xnn_compute_hmp_qp8gemm_bl(context, XNN_UARCH_DEFAULT, mr_block_start, - nr_block_start, mr_block_size, nr_block_size); -} - -void xnn_compute_hmp_dqgemm_bl( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - uint32_t uarch_index, - size_t mr_block_start, - size_t nr_block_start, - size_t mr_block_size, - size_t nr_block_size) -{ - const size_t a_stride = context->a_stride; - const size_t cm_stride = context->cm_stride; - - context->dq_bl_ukernel.function[uarch_index]( - mr_block_size, - nr_block_size, - context->k_scaled, - (const void*) ((uintptr_t) context->a + mr_block_start * a_stride), - a_stride, - (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride), - (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)), - cm_stride, - context->cn_stride, - context->fused_params, - (const void*) ((uintptr_t) &context->quantization_params[mr_block_start])); -} - -void xnn_compute_dqgemm_bl( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t mr_block_start, - size_t nr_block_start, - size_t mr_block_size, - size_t nr_block_size) -{ - return xnn_compute_hmp_dqgemm_bl(context, /*uarch_index=*/0, mr_block_start, nr_block_start, mr_block_size, nr_block_size); -} - void xnn_compute_spmm( const struct spmm_context context[restrict XNN_MIN_ELEMENTS(1)], size_t batch_index, diff --git a/src/operators/fully-connected-nc.c b/src/operators/fully-connected-nc.c index 045b026d0eb..8223029e10a 100644 --- a/src/operators/fully-connected-nc.c +++ b/src/operators/fully-connected-nc.c @@ -1913,11 +1913,6 @@ static enum xnn_status reshape_fully_connected_nc( fully_connected_op->type == xnn_operator_type_fully_connected_nc_qp8_f32_qb4w; - const bool is_blockwise_kernel = fully_connected_op->type == - xnn_operator_type_fully_connected_nc_qd8_f32_qb4w || - fully_connected_op->type == - xnn_operator_type_fully_connected_nc_qp8_f32_qb4w; - fully_connected_op->context.gemm.gemm.gemm = (struct gemm_context){ .k_scaled = input_channels << log2_input_element_size, .w_stride = fully_connected_op->weights_stride, @@ -1955,38 +1950,20 @@ static enum xnn_status reshape_fully_connected_nc( if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) { fully_connected_op->compute[0].type = xnn_parallelization_type_2d_tile_2d_with_uarch; if (dynamic_quantization) { - if (is_blockwise_kernel) { - fully_connected_op->compute[0].task_2d_tile_2d_with_id = (pthreadpool_task_2d_tile_2d_with_id_t) xnn_compute_hmp_dqgemm_bl; - } else { - fully_connected_op->compute[0].task_2d_tile_2d_with_id = (pthreadpool_task_2d_tile_2d_with_id_t) xnn_compute_hmp_dqgemm; - } + fully_connected_op->compute[0].task_2d_tile_2d_with_id = (pthreadpool_task_2d_tile_2d_with_id_t) xnn_compute_hmp_dqgemm; } else if (is_qp8_ukernel) { - if (is_blockwise_kernel) { - fully_connected_op->compute[0].task_2d_tile_2d_with_id = - (pthreadpool_task_2d_tile_2d_with_id_t)xnn_compute_hmp_qp8gemm_bl; - } else { - fully_connected_op->compute[0].task_2d_tile_2d_with_id = - (pthreadpool_task_2d_tile_2d_with_id_t)xnn_compute_hmp_qp8gemm; - } + fully_connected_op->compute[0].task_2d_tile_2d_with_id = + (pthreadpool_task_2d_tile_2d_with_id_t)xnn_compute_hmp_qp8gemm; } else { fully_connected_op->compute[0].task_2d_tile_2d_with_id = (pthreadpool_task_2d_tile_2d_with_id_t) xnn_compute_hmp_gemm; } } 
else { fully_connected_op->compute[0].type = xnn_parallelization_type_2d_tile_2d; if (dynamic_quantization) { - if (is_blockwise_kernel) { - fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_dqgemm_bl; - } else { - fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_dqgemm; - } + fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_dqgemm; } else if (is_qp8_ukernel) { - if (is_blockwise_kernel) { - fully_connected_op->compute[0].task_2d_tile_2d = - (pthreadpool_task_2d_tile_2d_t)xnn_compute_qp8gemm_bl; - } else { fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t)xnn_compute_qp8gemm; - } } else { fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm; } @@ -1994,19 +1971,10 @@ static enum xnn_status reshape_fully_connected_nc( #else fully_connected_op->compute[0].type = xnn_parallelization_type_2d_tile_2d; if (dynamic_quantization) { - if (is_blockwise_kernel) { - fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_dqgemm_bl; - } else { - fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_dqgemm; - } + fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_dqgemm; } else if (is_qp8_ukernel) { - if (is_blockwise_kernel) { - fully_connected_op->compute[0].task_2d_tile_2d = - (pthreadpool_task_2d_tile_2d_t)xnn_compute_qp8gemm_bl; - } else { - fully_connected_op->compute[0].task_2d_tile_2d = - (pthreadpool_task_2d_tile_2d_t)xnn_compute_qp8gemm; - } + fully_connected_op->compute[0].task_2d_tile_2d = + (pthreadpool_task_2d_tile_2d_t)xnn_compute_qp8gemm; } else { fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm; } diff --git a/src/xnnpack/compute.h b/src/xnnpack/compute.h index 7bd9c8dda72..0f7030ea92b 100644 --- a/src/xnnpack/compute.h +++ b/src/xnnpack/compute.h @@ -325,8 +325,6 @@ struct gemm_context { struct xnn_hmp_gemm_ukernel ukernel; struct xnn_hmp_dqgemm_ukernel dq_ukernel; struct xnn_hmp_qp8gemm_ukernel qp8_ukernel; - struct xnn_hmp_dqgemm_bl_ukernel dq_bl_ukernel; - struct xnn_hmp_qp8gemm_bl_ukernel qp8_bl_ukernel; }; // Parameters for dynamically quantized inputs. 
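// (An inference, not stated in this change: the qp8 kernels carry their
// quantization inside the LHS packed by xnn_x8_packq_f32qp8, so the qp8gemm
// compute paths do not pass these params.)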
const struct xnn_qd8_quantization_params* quantization_params; @@ -366,23 +364,10 @@ struct gemm_context { size_t mr_block_size, size_t nr_block_size); - XNN_PRIVATE void xnn_compute_dqgemm_bl( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t mr_block_start, - size_t nr_block_start, - size_t mr_block_size, - size_t nr_block_size); - XNN_PRIVATE void xnn_compute_qp8gemm( const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], size_t mr_block_start, size_t nr_block_start, size_t mr_block_size, size_t nr_block_size); - - XNN_PRIVATE void xnn_compute_qp8gemm_bl( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t mr_block_start, size_t nr_block_start, size_t mr_block_size, - size_t nr_block_size); - #if XNN_MAX_UARCH_TYPES > 1 XNN_PRIVATE void xnn_compute_hmp_grouped_gemm( const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], @@ -413,19 +398,6 @@ struct gemm_context { const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start, size_t mr_block_size, size_t nr_block_size); - - XNN_PRIVATE void xnn_compute_hmp_dqgemm_bl( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - uint32_t uarch_index, - size_t mr_block_start, - size_t nr_block_start, - size_t mr_block_size, - size_t nr_block_size); - - XNN_PRIVATE void xnn_compute_hmp_qp8gemm_bl( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start, - size_t mr_block_size, size_t nr_block_size); #endif // XNN_MAX_UARCH_TYPES > 1 #endif From 11093cdd554636700f097be96e686bb2f6383071 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Sun, 22 Sep 2024 23:57:51 -0700 Subject: [PATCH 11/17] add kxn packing, and use kleidi's packed stride fn --- src/operators/fully-connected-nc.c | 4 +-- src/packing.cc | 42 +++++++++++++++++--------- test/fully-connected-nc.cc | 34 +++++++++++++++++++++ test/fully-connected-operator-tester.h | 13 +++++--- 4 files changed, 72 insertions(+), 21 deletions(-) diff --git a/src/operators/fully-connected-nc.c b/src/operators/fully-connected-nc.c index 8223029e10a..f384f5047fc 100644 --- a/src/operators/fully-connected-nc.c +++ b/src/operators/fully-connected-nc.c @@ -207,10 +207,10 @@ static enum xnn_status create_fully_connected_nc( if (block_wise && bias != NULL) { void* weights_start = (void*) ((uintptr_t) weights_ptr + gemm_config->nr * (sizeof(float) + (block_size * sizeof(int8_t) / 2))); - weights_start = (void*) ((uintptr_t) weights_ptr + gemm_config->nr * (weights_stride - sizeof(float))) ; + weights_start = (void*) ((uintptr_t) weights_ptr + (weights_stride - gemm_config->nr * sizeof(float))) ; xnn_init_qs8_qc8w_scale_fp32_params( output_channels, gemm_config->nr, gemm_config->nr, - gemm_config->nr * weights_stride, gemm_config->nr * weights_stride, 0, + weights_stride, weights_stride, 0, bias, weights_start); } } else { diff --git a/src/packing.cc b/src/packing.cc index ee737b5f3a3..0c30876da08 100644 --- a/src/packing.cc +++ b/src/packing.cc @@ -25,6 +25,7 @@ #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h" #endif // XNN_ENABLE_KLEIDIAI #include @@ -1684,16 +1685,16 @@ size_t xnn_packed_stride_kai_qb4_weights_and_biases( const struct 
xnn_gemm_config* gemm_config, size_t k, size_t block_size, size_t extra_bytes) { const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr; + const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr; + const uint32_t nr = gemm_config->nr; - const size_t kai_num_bytes_sum_rhs = sizeof(float); - const size_t kai_num_bytes_bias = sizeof(float); - // perhaps derive Bf16 from gemm-config? - // This needs to be updated in the kleidi branch to be in header - // return kai_rhs_packed_stride(k, /*nr=*/1, kr, block_size, Bf16); - const size_t num_bytes_multiplier_rhs = sizeof(uint16_t); - const size_t num_blocks_per_row = k/block_size; - const size_t num_bytes_per_block = (block_size / 2) + num_bytes_multiplier_rhs; - return 1 * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias); + return kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + k, + nr, + kr, + sr, + block_size, + kai_datatype::kai_dt_bf16); } void xnn_pack_kai_qb4_weights_and_biases( @@ -1712,17 +1713,30 @@ void xnn_pack_kai_qb4_weights_and_biases( reinterpret_cast(params); if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) { - // no nxk as of now - xnn_log_fatal( - "KleidiAI does not currently have gio packing routine" - ); + struct kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params kai_params; + kai_params.lhs_zero_point = xnn_params->input_zero_point; + kai_params.rhs_zero_point = xnn_params->kernel_zero_point; + kai_params.scale_dt = kai_datatype::kai_dt_bf16; + size_t rhs_stride = (output_channels + 1) / 2; + size_t blocks_per_row = (input_channels + block_size - 1) / block_size; + kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( + groups, output_channels, input_channels, nr, kr, sr, + /*bl=*/block_size, + /*rhs=*/reinterpret_cast(weights), + /*rhs_stride=*/rhs_stride, + /*bias=*/reinterpret_cast(extra_data0), + /*scale=*/reinterpret_cast(extra_data1), + /*scale_stride=*/blocks_per_row * sizeof(uint16_t), + /*rhs_packed*/packed_weights_ptr, + /*extra_bytes=*/0, + &kai_params); } else { // Repack the packing params. 
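// Non-transposed (nxk) weights: each of the output_channels rows holds
// input_channels 4-bit values packed two per byte, so the byte stride per row
// is ceil(input_channels / 2), the counterpart of the ceil(output_channels / 2)
// stride used in the kxn branch above.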
struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params kai_params; kai_params.lhs_zero_point = xnn_params->input_zero_point; kai_params.rhs_zero_point = xnn_params->kernel_zero_point; kai_params.scale_dt = kai_datatype::kai_dt_bf16; - size_t rhs_stride = round_up_po2(input_channels, 2) / 2; + size_t rhs_stride = (input_channels + 1) / 2; size_t blocks_per_row = (input_channels + block_size - 1) / block_size; kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( groups, output_channels, input_channels, nr, kr, sr, diff --git a/test/fully-connected-nc.cc b/test/fully-connected-nc.cc index e2c7218bdbf..2a89078fa51 100644 --- a/test/fully-connected-nc.cc +++ b/test/fully-connected-nc.cc @@ -2107,6 +2107,40 @@ TEST(FULLY_CONNECTED_NC_QP8_F32_QB4W, BL_NO_BIAS) { } } +TEST(FULLY_CONNECTED_NC_QP8_F32_QB4W, BL_BIAS_TRANSPOSE_W) { + for (size_t ic=32; ic<=256; ic*=2){ + for (size_t bs=32; bs<=ic; bs=bs*2) { + FullyConnectedOperatorTester() + .transpose_weights(true) + .has_bias(true) + .batch_size(1) + .output_channels(16) + .input_channels(ic) + .block_size(bs) + .kernel_zero_point(8) + .iterations(3) + .TestQP8F32QB4W(); + } + } +} + +TEST(FULLY_CONNECTED_NC_QP8_F32_QB4W, BL_NO_BIAS_TRANSPOSE_W) { + for (size_t ic=32; ic<=256; ic*=2){ + for (size_t bs=32; bs<=ic; bs=bs*2) { + FullyConnectedOperatorTester() + .transpose_weights(true) + .has_bias(false) + .batch_size(1) + .output_channels(16) + .input_channels(ic) + .block_size(bs) + .kernel_zero_point(8) + .iterations(3) + .TestQP8F32QB4W(); + } + } +} + TEST(FULLY_CONNECTED_NC_QP8_F32_QB4W, unit_batch) { FullyConnectedOperatorTester() .batch_size(1) diff --git a/test/fully-connected-operator-tester.h b/test/fully-connected-operator-tester.h index c965961081a..817d8db2358 100644 --- a/test/fully-connected-operator-tester.h +++ b/test/fully-connected-operator-tester.h @@ -1246,7 +1246,7 @@ class FullyConnectedOperatorTester { (batch_size() - 1) * input_stride() + input_channels()); const size_t kernel_stride = calc_kernel_stride(); - std::vector kernel((output_channels()) * + std::vector kernel((transpose_weights() ? k2 : output_channels()) * kernel_stride); std::vector bias(output_channels()); std::vector output((batch_size() - 1) * output_stride() + @@ -1283,8 +1283,11 @@ class FullyConnectedOperatorTester { int32_t c_ref_acc = 0; for (size_t ki = 0; ki < block_size(); ki++) { const size_t k_index = bi * block_size() + ki; - const size_t nb_index = (ni * k2 + k_index) / 2; - const int32_t kernel_value = int32_t((k_index % 2 == 0) ? (kernel[nb_index] & UINT8_C(0xF)) : (kernel[nb_index] >> 4)) - kernel_zero_point(); + const size_t nb_index = transpose_weights() ? + (k_index * kernel_stride) + (ni / 2) : + (ni * kernel_stride) + (k_index / 2); + const size_t plane_idx = transpose_weights() ? ni : ki; + const int32_t kernel_value = int32_t((plane_idx % 2 == 0) ? (kernel[nb_index] & UINT8_C(0xF)) : (kernel[nb_index] >> 4)) - kernel_zero_point(); ksum += kernel_value; c_ref_acc += int32_t(xnn_x8_packq_f32qp8_get_quantized(mi, k_index, input_qp8.data(), k2, mr_packed, kr, sr)) * int32_t(kernel_value); @@ -1351,7 +1354,7 @@ class FullyConnectedOperatorTester { kernel.data(), has_bias() ? bias.data() : nullptr, output_min, output_max, - 0, //TODO Handle XNN_FLAG_TRANSPOSE_WEIGHTS + transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0, nullptr, auto_weights_cache.get(), &fully_connected_op); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); @@ -1395,7 +1398,7 @@ class FullyConnectedOperatorTester { kernel_scale2d.data(), kernel.data(), has_bias() ? 
bias.data() : nullptr, output_min, output_max, - 0, + transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0, nullptr, auto_weights_cache.get(), &fully_connected_op2)); ASSERT_NE(nullptr, fully_connected_op2); From 586b3905b8182493ac02891167a0dea0639f0425 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Sun, 22 Sep 2024 23:59:20 -0700 Subject: [PATCH 12/17] fix double new line --- test/gemm-microkernel-tester.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/test/gemm-microkernel-tester.cc b/test/gemm-microkernel-tester.cc index 929c4a98381..18e9cf0f880 100644 --- a/test/gemm-microkernel-tester.cc +++ b/test/gemm-microkernel-tester.cc @@ -1903,7 +1903,6 @@ void GemmMicrokernelTester::Test( input_qp8.data() ); - // RHS packing. struct xnn_qs8_qc4w_packing_params params; params.input_zero_point = 1; From 2ba5867c9555d0e7f0fd82b15a6da9887f9ae8d3 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Tue, 24 Sep 2024 02:02:17 -0700 Subject: [PATCH 13/17] remove kai_common.h from microkernel tester, fix indentation --- src/operators/fully-connected-nc.c | 4 ++-- test/gemm-microkernel-tester.cc | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/operators/fully-connected-nc.c b/src/operators/fully-connected-nc.c index f384f5047fc..fa3233cbb50 100644 --- a/src/operators/fully-connected-nc.c +++ b/src/operators/fully-connected-nc.c @@ -1962,8 +1962,8 @@ static enum xnn_status reshape_fully_connected_nc( if (dynamic_quantization) { fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_dqgemm; } else if (is_qp8_ukernel) { - fully_connected_op->compute[0].task_2d_tile_2d = - (pthreadpool_task_2d_tile_2d_t)xnn_compute_qp8gemm; + fully_connected_op->compute[0].task_2d_tile_2d = + (pthreadpool_task_2d_tile_2d_t)xnn_compute_qp8gemm; } else { fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm; } diff --git a/test/gemm-microkernel-tester.cc b/test/gemm-microkernel-tester.cc index 18e9cf0f880..5cbc1944f49 100644 --- a/test/gemm-microkernel-tester.cc +++ b/test/gemm-microkernel-tester.cc @@ -3,7 +3,6 @@ #include #include -#include "kai/kai_common.h" #include #include #include From a95ace194a0bc7817762369ae230a1a6c412ce91 Mon Sep 17 00:00:00 2001 From: Alan Kelly Date: Wed, 25 Sep 2024 06:39:12 -0700 Subject: [PATCH 14/17] Set the number of dims for newly inserted clamp nodes. 
Some subgraph optimizations requires the number of dims to be known PiperOrigin-RevId: 678678608 --- src/subgraph.c | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/subgraph.c b/src/subgraph.c index f7dfd9a6603..5a0e2dfee1d 100644 --- a/src/subgraph.c +++ b/src/subgraph.c @@ -33,25 +33,29 @@ enum xnn_status xnn_insert_clamp_node(xnn_subgraph_t subgraph, float output_min, struct xnn_value* output_value = &subgraph->values[output_id]; uint32_t new_id = XNN_INVALID_VALUE_ID; enum xnn_status status; + const size_t num_dims = output_value->shape.num_dims; + const size_t* dims = output_value->shape.dim; switch (output_value->datatype) { case xnn_datatype_fp16: status = xnn_define_tensor_value( - subgraph, xnn_datatype_fp16, 0, NULL, NULL, + subgraph, xnn_datatype_fp16, num_dims, dims, NULL, /*external_id=*/XNN_INVALID_VALUE_ID, /*flags=*/0, &new_id); break; case xnn_datatype_fp32: status = xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, 0, NULL, NULL, + subgraph, xnn_datatype_fp32, num_dims, dims, NULL, /*external_id=*/XNN_INVALID_VALUE_ID, /*flags=*/0, &new_id); break; case xnn_datatype_quint8: status = xnn_define_quantized_tensor_value( - subgraph, xnn_datatype_quint8, output_value->quantization.zero_point, output_value->quantization.scale, /*num_dims=*/0, /*dims=*/NULL, NULL, + subgraph, xnn_datatype_quint8, output_value->quantization.zero_point, + output_value->quantization.scale, num_dims, dims, NULL, /*external_id=*/XNN_INVALID_VALUE_ID, /*flags=*/0, &new_id); break; case xnn_datatype_qint8: status = xnn_define_quantized_tensor_value( - subgraph, xnn_datatype_qint8, output_value->quantization.zero_point, output_value->quantization.scale, /*num_dims=*/0, /*dims=*/NULL, NULL, + subgraph, xnn_datatype_qint8, output_value->quantization.zero_point, + output_value->quantization.scale, num_dims, dims, NULL, /*external_id=*/XNN_INVALID_VALUE_ID, /*flags=*/0, &new_id); break; default: From f208ebe9c17c8fb85ff1171449fa6c7e4706ac39 Mon Sep 17 00:00:00 2001 From: XNNPACK Team Date: Wed, 25 Sep 2024 10:50:56 -0700 Subject: [PATCH 15/17] Internal code fix PiperOrigin-RevId: 678763304 --- bench/models/qs8-mobilenet-v2.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bench/models/qs8-mobilenet-v2.cc b/bench/models/qs8-mobilenet-v2.cc index ad7cff160fb..e2aea9b6965 100644 --- a/bench/models/qs8-mobilenet-v2.cc +++ b/bench/models/qs8-mobilenet-v2.cc @@ -41,7 +41,7 @@ xnn_subgraph_t QS8MobileNetV2() { subgraph, xnn_datatype_fp32, v0_dims.size(), v0_dims.data(), /*data=*/nullptr, - 1, XNN_VALUE_FLAG_EXTERNAL_INPUT, &v0); + 0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &v0); if (status != xnn_status_success) { std::cerr << "failed to create tensor v0" << std::endl; return nullptr; @@ -963,7 +963,7 @@ xnn_subgraph_t QS8MobileNetV2() { subgraph, xnn_datatype_fp32, v66_dims.size(), v66_dims.data(), /*data=*/nullptr, - 0, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &v66); + 1, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &v66); if (status != xnn_status_success) { std::cerr << "failed to create tensor v66" << std::endl; return nullptr; From 3ac5962feb740c51071bf9068712d130c12d497a Mon Sep 17 00:00:00 2001 From: Alan Kelly Date: Wed, 25 Sep 2024 11:35:28 -0700 Subject: [PATCH 16/17] Initialize the new value size to 0 PiperOrigin-RevId: 678784133 --- src/subgraph.c | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/subgraph.c b/src/subgraph.c index 5a0e2dfee1d..1eaca96d431 100644 --- a/src/subgraph.c +++ b/src/subgraph.c @@ -64,6 +64,8 @@ enum xnn_status 
xnn_insert_clamp_node(xnn_subgraph_t subgraph, float output_min, if (status != xnn_status_success) { return status; } + struct xnn_value* new_value = &subgraph->values[new_id]; + new_value->size = 0; node->outputs[0] = new_id; node->activation.output_min = -INFINITY; node->activation.output_max = INFINITY; From 228671590de107bb07675431d889eb34674b8ede Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Wed, 25 Sep 2024 12:37:57 -0700 Subject: [PATCH 17/17] Fix use after free in xnn_insert_clamp_node xnn_define_tensor_value can invalidate the output_value pointer, so copy data referred from it before calling xnn_define_tensor_value. PiperOrigin-RevId: 678807574 --- src/subgraph.c | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/subgraph.c b/src/subgraph.c index 1eaca96d431..d227add1624 100644 --- a/src/subgraph.c +++ b/src/subgraph.c @@ -33,8 +33,9 @@ enum xnn_status xnn_insert_clamp_node(xnn_subgraph_t subgraph, float output_min, struct xnn_value* output_value = &subgraph->values[output_id]; uint32_t new_id = XNN_INVALID_VALUE_ID; enum xnn_status status; - const size_t num_dims = output_value->shape.num_dims; - const size_t* dims = output_value->shape.dim; + size_t num_dims = output_value->shape.num_dims; + size_t dims[XNN_MAX_TENSOR_DIMS]; + memcpy(dims, output_value->shape.dim, num_dims * sizeof(size_t)); switch (output_value->datatype) { case xnn_datatype_fp16: status = xnn_define_tensor_value(