Skip to content

Commit 16fa9b2

Browse files
author
Marat Dukhan
committed
SDWCONV PSIMD micro-kernel
1 parent 6e35d1c commit 16fa9b2

File tree

3 files changed

+197
-0
lines changed

3 files changed

+197
-0
lines changed

configure.py

+1
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def main(args):
8888
with build.options(isa=arm.neon if build.target.is_arm else None):
8989
qnnpack_objects += [
9090
build.cc("sconv/6x8-psimd.c"),
91+
build.cc("sdwconv/up4x9-psimd.c"),
9192
build.cc("sgemm/6x8-psimd.c"),
9293
]
9394

src/qnnpack/sdwconv.h

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <stddef.h>
12+
#include <stdint.h>
13+
14+
#include <qnnpack/params.h>
15+
#include <qnnpack/common.h>
16+
17+
#ifdef __cplusplus
18+
extern "C" {
19+
#endif
20+
21+
#define DECLARE_SUPDWCONV_UKERNEL_FUNCTION(fn_name) \
22+
QNNP_INTERNAL void fn_name( \
23+
size_t channels, \
24+
size_t output_width, \
25+
const float** input, \
26+
const float* weights, \
27+
float* output, \
28+
size_t input_stride, \
29+
size_t output_increment, \
30+
const struct qnnp_fp32_clamping_params* clamping_params);
31+
32+
DECLARE_SUPDWCONV_UKERNEL_FUNCTION(sdwconv_ukernel_up4x9__psimd)
33+
34+
#define DECLARE_SMPDWCONV_UKERNEL_FUNCTION(fn_name) \
35+
QNNP_INTERNAL void fn_name( \
36+
size_t channels, \
37+
size_t output_width, \
38+
const uint8_t** input, \
39+
const void* weights, \
40+
int32_t* buffer, \
41+
uint8_t* output, \
42+
size_t input_stride, \
43+
size_t output_increment, \
44+
const struct qnnp_fp32_clamping_params* clamping_params);
45+
46+
#ifdef __cplusplus
47+
} /* extern "C" */
48+
#endif

src/sdwconv/up4x9-psimd.c

+148
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <psimd.h>
10+
11+
#include <qnnpack/sdwconv.h>
12+
13+
14+
void sdwconv_ukernel_up4x9__psimd(
15+
size_t channels,
16+
size_t output_width,
17+
const float** input,
18+
const float* weights,
19+
float* output,
20+
size_t input_stride,
21+
size_t output_increment,
22+
const struct qnnp_fp32_clamping_params clamping_params[restrict static 1])
23+
{
24+
const psimd_f32 vmax = psimd_splat_f32(clamping_params->max);
25+
const psimd_f32 vmin = psimd_splat_f32(clamping_params->min);
26+
do {
27+
const float* i0 = input[0];
28+
const float* i1 = input[1];
29+
const float* i2 = input[2];
30+
const float* i3 = input[3];
31+
const float* i4 = input[4];
32+
const float* i5 = input[5];
33+
const float* i6 = input[6];
34+
const float* i7 = input[7];
35+
const float* i8 = input[8];
36+
37+
input = (const float**) ((uintptr_t) input + input_stride);
38+
39+
size_t c = channels;
40+
const float* w = weights;
41+
for (; c >= 4; c -= 4) {
42+
psimd_f32 vacc = psimd_load_f32(w);
43+
44+
const psimd_f32 vi0 = psimd_load_f32(i0); i0 += 4;
45+
const psimd_f32 vk0 = psimd_load_f32(w + 8);
46+
vacc += vi0 * vk0;
47+
48+
const psimd_f32 vi1 = psimd_load_f32(i1); i1 += 4;
49+
const psimd_f32 vk1 = psimd_load_f32(w + 12);
50+
psimd_f32 vacc2 = vi1 * vk1;
51+
52+
const psimd_f32 vi2 = psimd_load_f32(i2); i2 += 4;
53+
const psimd_f32 vk2 = psimd_load_f32(w + 16);
54+
vacc += vi2 * vk2;
55+
56+
const psimd_f32 vi3 = psimd_load_f32(i3); i3 += 4;
57+
const psimd_f32 vk3 = psimd_load_f32(w + 20);
58+
vacc2 += vi3 * vk3;
59+
60+
const psimd_f32 vi4 = psimd_load_f32(i4); i4 += 4;
61+
const psimd_f32 vk4 = psimd_load_f32(w + 24);
62+
vacc += vi4 * vk4;
63+
64+
const psimd_f32 vi5 = psimd_load_f32(i5); i5 += 4;
65+
const psimd_f32 vk5 = psimd_load_f32(w + 28);
66+
vacc2 += vi5 * vk5;
67+
68+
const psimd_f32 vi6 = psimd_load_f32(i6); i6 += 4;
69+
const psimd_f32 vk6 = psimd_load_f32(w + 32);
70+
vacc += vi6 * vk6;
71+
72+
const psimd_f32 vi7 = psimd_load_f32(i7); i7 += 4;
73+
const psimd_f32 vk7 = psimd_load_f32(w + 36);
74+
vacc2 += vi7 * vk7;
75+
76+
const psimd_f32 vi8 = psimd_load_f32(i8); i8 += 4;
77+
const psimd_f32 vk8 = psimd_load_f32(w + 40);
78+
vacc += vi8 * vk8;
79+
80+
vacc += vacc2;
81+
82+
vacc = psimd_min_f32(vacc, vmax);
83+
vacc = psimd_max_f32(vacc, vmin);
84+
85+
psimd_store_f32(output, vacc);
86+
w += 44;
87+
}
88+
if (c != 0) {
89+
psimd_f32 vacc = psimd_load_f32(w);
90+
c *= sizeof(float);
91+
92+
i0 = (const float*) ((uintptr_t) i0 - c);
93+
const psimd_f32 vi0 = psimd_load_f32(i0);
94+
const psimd_f32 vk0 = psimd_load_f32(w + 8);
95+
vacc += vi0 * vk0;
96+
97+
i1 = (const float*) ((uintptr_t) i1 - c);
98+
const psimd_f32 vi1 = psimd_load_f32(i1);
99+
const psimd_f32 vk1 = psimd_load_f32(w + 12);
100+
psimd_f32 vacc2 = vi1 * vk1;
101+
102+
i2 = (const float*) ((uintptr_t) i2 - c);
103+
const psimd_f32 vi2 = psimd_load_f32(i2);
104+
const psimd_f32 vk2 = psimd_load_f32(w + 16);
105+
vacc += vi2 * vk2;
106+
107+
i3 = (const float*) ((uintptr_t) i3 - c);
108+
const psimd_f32 vi3 = psimd_load_f32(i3);
109+
const psimd_f32 vk3 = psimd_load_f32(w + 20);
110+
vacc2 += vi3 * vk3;
111+
112+
i4 = (const float*) ((uintptr_t) i4 - c);
113+
const psimd_f32 vi4 = psimd_load_f32(i4);
114+
const psimd_f32 vk4 = psimd_load_f32(w + 24);
115+
vacc += vi4 * vk4;
116+
117+
i5 = (const float*) ((uintptr_t) i5 - c);
118+
const psimd_f32 vi5 = psimd_load_f32(i5);
119+
const psimd_f32 vk5 = psimd_load_f32(w + 28);
120+
vacc2 += vi5 * vk5;
121+
122+
i6 = (const float*) ((uintptr_t) i6 - c);
123+
const psimd_f32 vi6 = psimd_load_f32(i6);
124+
const psimd_f32 vk6 = psimd_load_f32(w + 32);
125+
vacc += vi6 * vk6;
126+
127+
i7 = (const float*) ((uintptr_t) i7 - c);
128+
const psimd_f32 vi7 = psimd_load_f32(i7);
129+
const psimd_f32 vk7 = psimd_load_f32(w + 36);
130+
vacc2 += vi7 * vk7;
131+
132+
i8 = (const float*) ((uintptr_t) i8 - c);
133+
const psimd_f32 vi8 = psimd_load_f32(i8);
134+
const psimd_f32 vk8 = psimd_load_f32(w + 40);
135+
vacc += vi8 * vk8;
136+
137+
vacc += vacc2;
138+
139+
vacc = psimd_min_f32(vacc, vmax);
140+
vacc = psimd_max_f32(vacc, vmin);
141+
142+
output = (float*) ((uintptr_t) output - c);
143+
psimd_store_f32(output, vacc);
144+
}
145+
146+
output = (float*) ((uintptr_t) output + output_increment);
147+
} while (--output_width != 0);
148+
}

0 commit comments

Comments
 (0)