#include "ATen/ATen.h"
#include "ATen/NativeFunctions.h"
#include "ATen/WrapDimUtilsMulti.h"

#ifdef USE_FBGEMM
#include "fbgemm/Fbgemm.h"
#include "fbgemm/QuantUtils.h"
#endif // USE_FBGEMM

#include <array>
#include <cctype>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

namespace at {
namespace native {

#ifdef USE_FBGEMM

Tensor fbgemm_linear_int8_weight(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& packed,
    const Tensor& col_offsets,
    Scalar weight_scale,
    Scalar weight_zero_point,
    const Tensor& bias) {
  // We make a strong guarantee that models using these operators will have the
  // same numerics across different machines. Therefore, we do not provide a
  // fallback path and rather fail loudly if we cannot run FBGEMM.
  AT_ASSERTM(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");

  // We call `contiguous` on `input` and `weight` here because these APIs all
  // expect row-major tensor buffers. Hold the contiguous tensors in locals so
  // the raw pointers below remain valid for the rest of this function.
  auto input_contig = input.contiguous();
  auto weight_contig = weight.contiguous();
  auto* input_ptr = input_contig.data<float>();
  auto* weight_ptr = weight_contig.data<int8_t>();

  AT_ASSERT(input.dim() >= 2);
  int64_t M = 1;
  for (int64_t i = 0; i < input.dim() - 1; ++i) {
    M *= input.size(i);
  }
  int64_t K = input.size(input.dim() - 1);
  AT_ASSERT(weight.dim() == 2);
  AT_ASSERT(K == weight.size(1));
  auto N = weight.size(0);
  AT_ASSERT(bias.dim() == 1);
  AT_ASSERT(bias.size(0) == N);
  AT_ASSERT(weight_scale.isFloatingPoint());
  AT_ASSERT(weight_zero_point.isIntegral());

  // Calculate statistics for quantization of the input Tensor
  float x_min, x_max;
  fbgemm::FindMinMax(
      /*m=*/input_ptr,
      /*min=*/&x_min,
      /*max=*/&x_max,
      /*len=*/input.numel());

  // Input tensor is quantized as 8-bit unsigned values
  static constexpr int precision = 8;
  static constexpr bool is_signed = false;

  // Calculate scale and zero point for quantization of input tensor
  auto q_params = fbgemm::ChooseQuantizationParams(
      /*min=*/x_min,
      /*max=*/x_max,
      /*qmin=*/is_signed ? -(1 << (precision - 1)) : 0,
      /*qmax=*/is_signed ? ((1 << (precision - 1)) - 1) : (1 << precision) - 1,
      /*preserve_sparsity=*/false);

  q_params.precision = precision;
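  // (For reference: with these parameters a real value x maps roughly to
  // q = clamp(round(x / scale) + zero_point, qmin, qmax), with scale on the
  // order of (x_max - x_min) / (qmax - qmin).)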

  // This operation does the following:
  // 1) Quantizes the input matrix given the statistics we've calculated above
  // 2) Creates a "row buffer" vector with offset values that must be added
  //    to the integer matrix multiplication operation to ensure correctness
  // 3) Packs the resulting quantized matrix into vector-register and cache
  //    friendly tiles.
  //
  // Note this is not executed eagerly, but rather within the fbgemmPacked call
  // below.
  fbgemm::PackAWithQuantRowOffset<uint8_t> packA(
      /*trans=*/fbgemm::matrix_op_t::NoTranspose,
      /*nRow=*/M,
      /*nCol=*/K,
      /*smat=*/input_ptr,
      /*ld=*/K,
      /*pmat=*/nullptr, // packA manages ownership of `pmat`
      /*scale=*/q_params.scale,
      /*zero_pt=*/q_params.zero_point);

  // ReQuantizeForFloat requires pointers to the scale and zero point values,
  // since in the case of rowwise quantization these will be arrays rather than
  // scalars. But in this case we're doing whole-tensor quantization, so we
  // just pass a pointer to the single scale value (and internally
  // ReQuantizeForFloat won't index past 0).
  float weight_scale_float = static_cast<float>(weight_scale.to<double>());
  int32_t weight_zero_point_int32 =
      static_cast<int32_t>(weight_zero_point.to<int64_t>());

  // This is the end of the pipeline, pass the resulting matrix through
  fbgemm::DoNothing<float, float> doNothingObj{};

  // After the uint8 * int8 matrix multiplication is performed, this operation
  // does:
  // 1) Add in row and column offsets to the rows and columns, respectively
  // 2) Dequantize the results into floating point
  // 3) Add in the bias term
  fbgemm::ReQuantizeForFloat<false /* FUSE_RELU*/> outputProcObj(
      /*nextop=*/doNothingObj,
      /*Aq_scale=*/q_params.scale,
      /*Bq_scale=*/&weight_scale_float,
      /*Aq_zero_point=*/q_params.zero_point,
      /*Bq_zero_point=*/&weight_zero_point_int32,
      /*row_offsets=*/packA.getRowOffsetBuffer(),
      /*col_offsets=*/col_offsets.data<int32_t>(),
      /*bias=*/bias.contiguous().data<float>(),
      /*ncol=*/N);
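  // Roughly, the dequantized output produced by this epilogue is
  //   Y[i][j] = Aq_scale * Bq_scale[j] *
  //       (C[i][j] - Aq_zero_point * col_offsets[j]
  //                - Bq_zero_point[j] * row_offsets[i]) + bias[j]
  // where C is the int32 product of the quantized matrices; here the Bq
  // arrays each hold a single entry, since the weight is quantized per-tensor.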

  // Allocate output Tensor and a buffer for fbgemmPacked to use
  auto output = at::zeros_like(bias).to(at::kFloat).expand({M, N}).contiguous();
  auto buffer = at::zeros_like(output).to(at::kInt).contiguous();

  // Pull out the PackBMatrix instance from the owning tensor
  auto* packB = reinterpret_cast<fbgemm::PackBMatrix<int8_t>*>(
      packed.storage().data_ptr().get());
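  // Note: `packed` is expected to wrap a PackBMatrix created by
  // fbgemm_pack_quantized_matrix() below; that tensor's storage owns the
  // instance through its deleter, so we only borrow the pointer here.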

  // Do the GEMM
  fbgemm::fbgemmPacked(
      /*packA=*/packA,
      /*packB=*/*packB,
      /*C=*/output.data<float>(),
      /*C_buffer=*/buffer.data<int32_t>(),
      /*ldc=*/N,
      /*outProcess=*/outputProcObj,
      /*thread_id=*/0,
      /*num_threads=*/1);

  // The resulting matrix here is 2-D, let's view it with the original
  // left hand dimensions of the input.
  std::vector<int64_t> out_sizes = input.sizes().vec();
  out_sizes.back() = N;
  return output.view(out_sizes);
}

namespace {
// Calculate the column offsets
// Note this includes the sum of the columns as well as the scalar term
// B_zero_point * K, whereas the row_offsets created by PackAWithQuantRowOffset
// is only the sum of the A rows.
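// With this definition, the A_zero_point * B_zero_point * K cross term that
// appears when expanding (Aq - A_zero_point) * (Bq - B_zero_point) cancels,
// leaving only the per-row and per-column corrections at dequantization time.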
void calc_col_offsets_transpose(
    int K,
    int N,
    const int8_t* Bint8,
    int32_t B_zero_point,
    int32_t* col_offsets) {
  for (int i = 0; i < N; ++i) {
    int32_t sum = 0;
    for (int j = 0; j < K; ++j) {
      sum += Bint8[i * K + j];
    }
    col_offsets[i] = sum - B_zero_point * K;
  }
}
} // namespace

std::tuple<Tensor, Tensor, double, int64_t> fbgemm_linear_quantize_weight(
    const Tensor& weight) {
  // We make a strong guarantee that models using these operators will have the
  // same numerics across different machines. Therefore, we do not provide a
  // fallback path and rather fail loudly if we cannot run FBGEMM.
  AT_ASSERTM(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
  auto weight_contig = weight.contiguous();

  // Calculate weight statistics
  float w_min, w_max;
  fbgemm::FindMinMax(
      /*m=*/weight_contig.data<float>(),
      /*min=*/&w_min,
      /*max=*/&w_max,
      /*len=*/weight_contig.numel());

  // Choose parameters for quantizing the weight as 8-bit signed integer
  static constexpr bool is_signed = true;
  static constexpr int precision = 8;
  auto q_params = fbgemm::ChooseQuantizationParams(
      /*min=*/w_min,
      /*max=*/w_max,
      /*qmin=*/is_signed ? -(1 << (precision - 1)) : 0,
      /*qmax=*/is_signed ? ((1 << (precision - 1)) - 1) : (1 << precision) - 1,
      /*preserve_sparsity=*/false);

  q_params.precision = precision;

  auto quantized = at::zeros_like(weight_contig).to(at::kChar).contiguous();
  fbgemm::Quantize<int8_t>(
      /*src=*/weight_contig.data<float>(),
      /*dst=*/quantized.data<int8_t>(),
      /*len=*/weight_contig.numel(),
      /*qparams=*/q_params);

  // Calculate column offsets of the weight and store them away in a tensor.
  // Similarly to quantization, this can be done once and cached.
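  // (The zeros_like(quantized).sum({1}) expression below is just a convenient
  // way to materialize a 1-D int32 tensor with N entries; the actual offsets
  // are written by calc_col_offsets_transpose.)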
  auto col_offsets =
      at::zeros_like(quantized).sum({1}).to(at::kInt).contiguous();
  calc_col_offsets_transpose(
      /*K=*/quantized.size(1),
      /*N=*/quantized.size(0),
      /*Bint8=*/quantized.data<int8_t>(),
      /*B_zero_point=*/q_params.zero_point,
      /*col_offsets=*/col_offsets.data<int32_t>());

  return std::make_tuple(
      quantized, col_offsets, q_params.scale, q_params.zero_point);
}

bool fbgemm_is_cpu_supported() {
  return fbgemm::fbgemmSupportedCPU();
}

Tensor fbgemm_pack_quantized_matrix(
    const Tensor& weight,
    int64_t K,
    int64_t N) {
  // We make a strong guarantee that models using these operators will have the
  // same numerics across different machines. Therefore, we do not provide a
  // fallback path and rather fail loudly if we cannot run FBGEMM.
  AT_ASSERTM(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
  // Hold the contiguous tensor in a local so the raw pointer stays valid
  // while the matrix is packed below.
  auto weight_contig = weight.contiguous();
  auto* contiguous_ptr = weight_contig.data<int8_t>();
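  // `weight` is the N x K quantized matrix in row-major order, i.e. the K x N
  // operand B stored transposed, hence matrix_op_t::Transpose below with
  // leading dimension K.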
  auto* ptr = new fbgemm::PackBMatrix<int8_t>(
      /*trans=*/fbgemm::matrix_op_t::Transpose,
      /*nRow=*/K,
      /*nCol=*/N,
      /*smat=*/contiguous_ptr,
      /*ld=*/K,
      /*pmat=*/nullptr, // PackBMatrix manages ownership of pmat
      /*groups=*/1);

  // We store this instance away in a Tensor and register a deleter function
  // so that we do not leak memory. On the other side, we pull out the storage's
  // data_ptr and get the PackBMatrix's pointer.
  at::DataPtr at_ptr(
      ptr,
      ptr,
      [](void* ptr) {
        fbgemm::PackBMatrix<int8_t>* typed_ptr =
            reinterpret_cast<fbgemm::PackBMatrix<int8_t>*>(ptr);
        delete typed_ptr;
      },
      at::kCPU);

  auto retval = at::empty(
      {sizeof(fbgemm::PackBMatrix<int8_t>)}, weight.options().dtype(at::kByte));

  retval.storage().set_data_ptr(std::move(at_ptr));

  return retval;
}
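
// A rough sketch of how the three ops above are intended to be chained
// (illustrative only; `input_fp32`, `weight_fp32`, and `bias` are hypothetical
// caller-side tensors):
//
//   Tensor weight_q, col_offsets;
//   double w_scale;
//   int64_t w_zero_point;
//   std::tie(weight_q, col_offsets, w_scale, w_zero_point) =
//       fbgemm_linear_quantize_weight(weight_fp32);
//   Tensor packed = fbgemm_pack_quantized_matrix(
//       weight_q, /*K=*/weight_q.size(1), /*N=*/weight_q.size(0));
//   Tensor output = fbgemm_linear_int8_weight(
//       input_fp32, weight_q, packed, col_offsets, w_scale, w_zero_point, bias);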

#else // USE_FBGEMM

Tensor fbgemm_linear_int8_weight(
    const Tensor& /*input*/,
    const Tensor& /*weight*/,
    const Tensor& /*packed*/,
    const Tensor& /*col_offsets*/,
    Scalar /*weight_scale*/,
    Scalar /*weight_zero_point*/,
    const Tensor& /*bias*/) {
  // We make a strong guarantee that models using these operators will have the
  // same numerics across different machines. Therefore, we do not provide a
  // fallback path and rather fail loudly if we cannot run FBGEMM.
  AT_ASSERTM(
      false, "This PyTorch installation was not built with FBGEMM operators");
}

std::tuple<Tensor, Tensor, double, int64_t> fbgemm_linear_quantize_weight(
    const Tensor& /*weight*/) {
  // We make a strong guarantee that models using these operators will have the
  // same numerics across different machines. Therefore, we do not provide a
  // fallback path and rather fail loudly if we cannot run FBGEMM.
  AT_ASSERTM(
      false, "This PyTorch installation was not built with FBGEMM operators");
}

Tensor fbgemm_pack_quantized_matrix(
    const Tensor& /*input*/,
    int64_t /*K*/,
    int64_t /*N*/) {
  // We make a strong guarantee that models using these operators will have the
  // same numerics across different machines. Therefore, we do not provide a
  // fallback path and rather fail loudly if we cannot run FBGEMM.
  AT_ASSERTM(
      false, "This PyTorch installation was not built with FBGEMM operators");
}

bool fbgemm_is_cpu_supported() {
  return false;
}

#endif // USE_FBGEMM

} // namespace native
} // namespace at