-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add optimised 'Indirect BGEMM' binary convolution kernels.
To start, add portable 4x2 C++ kernels for float/int8/bitpacked output. Facilitate easy implementation of new indirect bgemm kernels, including architecture-specific variations.
- Loading branch information
1 parent
5f75001
commit 11e57cd
Showing
11 changed files
with
713 additions
and
56 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
68 changes: 68 additions & 0 deletions
68
larq_compute_engine/core/bconv2d/optimized_indirect_bgemm.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
#ifndef COMPUTE_ENGINE_CORE_BCONV2D_OPTIMIZED_INDIRECT_BGEMM_H_ | ||
#define COMPUTE_ENGINE_CORE_BCONV2D_OPTIMIZED_INDIRECT_BGEMM_H_ | ||
|
||
#include "larq_compute_engine/core/bconv2d/zero_padding_correction.h" | ||
#include "larq_compute_engine/core/indirect_bgemm/kernel.h" | ||
#include "ruy/profiler/instrumentation.h" | ||
#include "tensorflow/lite/kernels/internal/types.h" | ||
|
||
namespace compute_engine { | ||
namespace core { | ||
namespace bconv2d { | ||
|
||
template <typename AccumScalar, typename DstScalar> | ||
inline void BConv2DOptimizedIndirectBGEMM( | ||
const indirect_bgemm::IndirectBGEMMKernel<DstScalar> kernel, | ||
const compute_engine::tflite::bconv2d::TfLiteBConv2DParams* conv_params, | ||
const RuntimeShape& bitpacked_input_shape, const RuntimeShape& output_shape, | ||
const OutputTransform<DstScalar>& output_transform, | ||
const TBitpacked** indirection_buffer, const TBitpacked* packed_weights, | ||
DstScalar* output_data, const float* padding_buffer, const int pad_value) { | ||
TF_LITE_ASSERT_EQ(bitpacked_input_shape.DimensionsCount(), 4); | ||
TF_LITE_ASSERT_EQ(output_shape.DimensionsCount(), 4); | ||
|
||
ruy::profiler::ScopeLabel label("BConv2D (optimized, indirect BGEMM)"); | ||
|
||
const std::int32_t conv_kernel_size = | ||
conv_params->filter_height * conv_params->filter_width; | ||
const std::int32_t bitpacked_input_channels = bitpacked_input_shape.Dims(3); | ||
const std::int32_t output_size = output_shape.Dims(1) * output_shape.Dims(2); | ||
const std::int32_t output_channels = conv_params->channels_out; | ||
|
||
indirect_bgemm::RunKernel(kernel, conv_kernel_size, bitpacked_input_channels, | ||
output_size, output_channels, output_transform, | ||
indirection_buffer, packed_weights, output_data); | ||
|
||
if (std::is_same<DstScalar, float>::value && | ||
conv_params->padding_type == TfLitePadding::kTfLitePaddingSame && | ||
pad_value == 0) { | ||
ruy::profiler::ScopeLabel label("Zero padding correction"); | ||
|
||
const int stride_width = conv_params->stride_width; | ||
const int stride_height = conv_params->stride_height; | ||
const int dilation_width_factor = conv_params->dilation_width_factor; | ||
const int dilation_height_factor = conv_params->dilation_height_factor; | ||
const int batches = MatchingDim(bitpacked_input_shape, 0, output_shape, 0); | ||
const int input_depth = conv_params->channels_in; | ||
const int input_width = bitpacked_input_shape.Dims(2); | ||
const int input_height = bitpacked_input_shape.Dims(1); | ||
const int filter_height = conv_params->filter_height; | ||
const int filter_width = conv_params->filter_width; | ||
const int output_depth = output_shape.Dims(3); | ||
const int output_width = output_shape.Dims(2); | ||
const int output_height = output_shape.Dims(1); | ||
|
||
zero_padding_correction::ApplyCorrection( | ||
batches, input_height, input_width, input_depth, filter_height, | ||
filter_width, output_depth, stride_height, stride_width, | ||
dilation_height_factor, dilation_width_factor, | ||
reinterpret_cast<float*>(output_data), output_height, output_width, | ||
padding_buffer); | ||
} | ||
} | ||
|
||
} // namespace bconv2d | ||
} // namespace core | ||
} // namespace compute_engine | ||
|
||
#endif // COMPUTE_ENGINE_CORE_BCONV2D_OPTIMIZED_INDIRECT_BGEMM_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
licenses(["notice"])  # Apache 2.0

package(default_visibility = ["//visibility:public"])

# Header-only preparation logic for the indirect BGEMM kernels (prepare.h).
cc_library(
    name = "prepare",
    hdrs = [
        "prepare.h",
    ],
    deps = [
        "//larq_compute_engine/core:types",
        "//larq_compute_engine/tflite/kernels:bconv2d_params",
        "@org_tensorflow//tensorflow/lite/kernels/internal:types",
    ],
)

# The indirect BGEMM micro-kernels: the runtime kernel-selection header
# (kernel.h) plus the portable 4x2 implementation (kernel_4x2_portable.h).
cc_library(
    name = "kernels",
    hdrs = [
        "kernel.h",
        "kernel_4x2_portable.h",
    ],
    deps = [
        "//larq_compute_engine/core:types",
        "//larq_compute_engine/core/bconv2d:output_transform",
        "//larq_compute_engine/tflite/kernels:bconv2d_params",
        "@org_tensorflow//tensorflow/lite/kernels/internal:types",
        "@ruy//ruy/profiler:instrumentation",
    ],
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
|
||
#ifndef COMPUTE_ENGINE_INDIRECT_BGEMM_KERNEL_H_
#define COMPUTE_ENGINE_INDIRECT_BGEMM_KERNEL_H_

#include <cstdint>
#include <type_traits>

#include "larq_compute_engine/core/indirect_bgemm/kernel_4x2_portable.h"
#include "larq_compute_engine/core/types.h"
#include "larq_compute_engine/tflite/kernels/bconv2d_params.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/kernels/internal/types.h"

// NOTE(review): `using namespace` at header scope injects all of `tflite`
// into every translation unit that includes this header and risks name
// collisions; prefer qualifying `tflite::RuntimeShape` etc. at use sites.
using namespace tflite;

namespace compute_engine {
namespace core {
namespace indirect_bgemm {

using compute_engine::tflite::bconv2d::TfLiteBConv2DParams;
||
// Describes one indirect-BGEMM micro-kernel: the function pointer plus the
// tile ("block") sizes that `RunKernel` uses to step over the output.
template <typename DstScalar>
struct IndirectBGEMMKernel {
  // Micro-kernel signature; the arguments passed by `RunKernel` are, in
  // order: block pixel count, conv kernel size (filter_h * filter_w),
  // bitpacked input channels, output channels, output transform,
  // indirection-buffer pointer, packed weights pointer, output pointer.
  using MicroKernelFunction = void(const std::int32_t, const std::int32_t,
                                   const std::int32_t, const std::int32_t,
                                   const bconv2d::OutputTransform<DstScalar>&,
                                   const TBitpacked**, const TBitpacked*,
                                   DstScalar*);
  MicroKernelFunction* micro_kernel_function;
  // Presumably the number of output channels per micro-kernel tile (4 for
  // the 4x2 portable kernel) — not consumed in this header; confirm in the
  // kernel implementations.
  const std::int32_t block_size_output_channels;
  // Number of output pixels processed per micro-kernel invocation; used as
  // the loop stride in `RunKernel`.
  const std::int32_t block_size_pixels;
};
|
||
// This function allows us to select which kernel to use at runtime based on any | ||
// parameter we choose: destination scalar; conv params; input/output shapes; | ||
// even detected CPU features. | ||
// It is very important that this function is deterministic, as we rely on | ||
// the fact that the same kernel is selected for each call to `Eval` (as long as | ||
// the input shape doesn't change). | ||
template <typename DstScalar> | ||
inline IndirectBGEMMKernel<DstScalar> SelectRuntimeKernel( | ||
const TfLiteBConv2DParams* conv_params, | ||
const RuntimeShape& bitpacked_input_shape, | ||
const RuntimeShape& output_shape) { | ||
// For now there is only one kernel available. | ||
return IndirectBGEMMKernel<DstScalar>{&Kernel4x2Portable<DstScalar>, 4, 2}; | ||
} | ||
|
||
template <typename DstScalar> | ||
void RunKernel(const IndirectBGEMMKernel<DstScalar>& kernel, | ||
const std::int32_t conv_kernel_size, | ||
const std::int32_t bitpacked_input_channels, | ||
const std::int32_t output_size, | ||
const std::int32_t output_channels, | ||
const bconv2d::OutputTransform<DstScalar>& output_transform, | ||
const TBitpacked** indirection_buffer, | ||
const TBitpacked* packed_weights_ptr, DstScalar* output_ptr) { | ||
// TODO: implement multithreading here. | ||
for (std::int32_t pixel_start = 0; pixel_start < output_size; | ||
pixel_start += kernel.block_size_pixels) { | ||
const std::int32_t output_stride = | ||
std::is_same<DstScalar, TBitpacked>::value | ||
? bitpacking::GetBitpackedSize(output_channels) | ||
: output_channels; | ||
kernel.micro_kernel_function( | ||
std::min(output_size - pixel_start, kernel.block_size_pixels), | ||
conv_kernel_size, bitpacked_input_channels, output_channels, | ||
output_transform, indirection_buffer + pixel_start * conv_kernel_size, | ||
packed_weights_ptr, output_ptr + pixel_start * output_stride); | ||
} | ||
} | ||
|
||
} // namespace indirect_bgemm | ||
} // namespace core | ||
} // namespace compute_engine | ||
|
||
#endif // COMPUTE_ENGINE_INDIRECT_BGEMM_KERNEL_H_ |
Oops, something went wrong.