Skip to content

Commit

Permalink
Add GGUF BF16 support (#17)
Browse files Browse the repository at this point in the history
* Add GGUF bf16 type support

* Add non avx impl for vec_dot_bf16

* Fix from_u32

* Fix loading

* Fix dequant of bf16
  • Loading branch information
EricLBuehler committed Nov 26, 2024
1 parent c12db59 commit b482af4
Show file tree
Hide file tree
Showing 7 changed files with 205 additions and 9 deletions.
83 changes: 81 additions & 2 deletions candle-core/src/cpu/avx.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use super::{Cpu, CpuF16};
use super::{Cpu, CpuBF16, CpuF16};
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

use half::f16;
use half::{bf16, f16};

pub struct CurrentCpu {}

Expand Down Expand Up @@ -146,3 +146,82 @@ impl CpuF16<ARR> for CurrentCpuF16 {
*y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
}
}

pub struct CurrentCpuBF16 {}
impl CpuBF16<ARR> for CurrentCpuBF16 {
type Unit = __m256;
type Array = [__m256; ARR];

const STEP: usize = STEP;
const EPR: usize = EPR;

fn n() -> usize {
ARR
}

unsafe fn zero() -> Self::Unit {
_mm256_setzero_ps()
}

unsafe fn zero_array() -> Self::Array {
[Self::zero(); ARR]
}

unsafe fn from_f32(v: f32) -> Self::Unit {
_mm256_set1_ps(v)
}

#[cfg(target_feature = "f16c")]
unsafe fn load(mem_addr: *const bf16) -> Self::Unit {
_mm256_cvtph_ps(_mm_loadu_si128(mem_addr as *const __m128i))
}

#[cfg(not(target_feature = "f16c"))]
unsafe fn load(mem_addr: *const bf16) -> Self::Unit {
let mut tmp = [0.0f32; 8];
for i in 0..8 {
tmp[i] = (*mem_addr.add(i)).to_f32();
}
_mm256_loadu_ps(tmp.as_ptr())
}

unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
_mm256_add_ps(a, b)
}

unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
_mm256_add_ps(_mm256_mul_ps(b, c), a)
}

#[cfg(target_feature = "f16c")]
unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) {
_mm_storeu_si128(mem_addr as *mut __m128i, _mm256_cvtps_ph(a, 0))
}

#[cfg(not(target_feature = "f16c"))]
unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) {
let mut tmp = [0.0f32; 8];
_mm256_storeu_ps(tmp.as_mut_ptr(), a);
for i in 0..8 {
*mem_addr.add(i) = bf16::from_f32(tmp[i]);
}
}

unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
let mut offset = ARR >> 1;
for i in 0..offset {
x[i] = _mm256_add_ps(x[i], x[offset + i]);
}
offset >>= 1;
for i in 0..offset {
x[i] = _mm256_add_ps(x[i], x[offset + i]);
}
offset >>= 1;
for i in 0..offset {
x[i] = _mm256_add_ps(x[i], x[offset + i]);
}
let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1));
let t1 = _mm_hadd_ps(t0, t0);
*y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
}
}
7 changes: 7 additions & 0 deletions candle-core/src/cpu/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,13 @@ impl VecOps for half::bf16 {
fn max(self, other: Self) -> Self {
Self::max(self, other)
}

#[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
let mut res_f32 = 0f32;
super::vec_dot_bf16(lhs, rhs, &mut res_f32, len);
*res = half::bf16::from_f32(res_f32);
}
}
impl VecOps for u8 {
#[inline(always)]
Expand Down
62 changes: 60 additions & 2 deletions candle-core/src/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,33 @@ trait CpuF16<const ARR: usize> {
unsafe fn from_f32(v: f32) -> Self::Unit;
unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit);
}
use half::f16;

#[allow(unused)]
trait CpuBF16<const ARR: usize> {
type Unit;
type Array;
const STEP: usize;
const EPR: usize;

fn n() -> usize;
unsafe fn zero() -> Self::Unit;
unsafe fn zero_array() -> Self::Array;
unsafe fn load(mem_addr: *const bf16) -> Self::Unit;
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
unsafe fn from_f32(v: f32) -> Self::Unit;
unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit);
}

use half::{bf16, f16};

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(target_feature = "avx")]
pub mod avx;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(target_feature = "avx")]
pub use avx::{CurrentCpu, CurrentCpuF16};
pub use avx::{CurrentCpu, CurrentCpuBF16, CurrentCpuF16};

#[cfg(target_arch = "wasm32")]
#[cfg(target_feature = "simd128")]
Expand Down Expand Up @@ -172,6 +191,34 @@ pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f
*c = sumf;
}

#[cfg(target_feature = "avx")]
#[inline(always)]
pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) {
let mut sumf = 0.0f32;
let np = k & !(CurrentCpuBF16::STEP - 1);

let mut sum = CurrentCpuBF16::zero_array();
let mut ax = CurrentCpuBF16::zero_array();
let mut ay = CurrentCpuBF16::zero_array();

for i in (0..np).step_by(CurrentCpuBF16::STEP) {
for j in 0..CurrentCpuBF16::n() {
ax[j] = CurrentCpuBF16::load(a_row.add(i + j * CurrentCpuBF16::EPR));
ay[j] = CurrentCpuBF16::load(b_row.add(i + j * CurrentCpuBF16::EPR));

sum[j] = CurrentCpuBF16::vec_fma(sum[j], ax[j], ay[j]);
}
}

CurrentCpuBF16::vec_reduce(sum, &mut sumf);

// leftovers
for i in np..k {
sumf += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
}
*c = sumf;
}

#[cfg(not(target_feature = "avx"))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
Expand All @@ -182,3 +229,14 @@ pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f
}
*c = sum;
}

#[cfg(not(target_feature = "avx"))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) {
// leftovers
let mut sum = 0.0;
for i in 0..k {
sum += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
}
*c = sum;
}
1 change: 1 addition & 0 deletions candle-core/src/quantized/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,7 @@ impl QCudaStorage {
match self.dtype {
GgmlDType::F32 => deq::<f32>(&buffer, block_len, &mut out)?,
GgmlDType::F16 => deq::<half::f16>(&buffer, block_len, &mut out)?,
GgmlDType::BF16 => deq::<half::bf16>(&buffer, block_len, &mut out)?,
GgmlDType::Q4_0 => deq::<crate::quantized::BlockQ4_0>(&buffer, block_len, &mut out)?,
GgmlDType::Q4_1 => deq::<crate::quantized::BlockQ4_1>(&buffer, block_len, &mut out)?,
GgmlDType::Q5_0 => deq::<crate::quantized::BlockQ5_0>(&buffer, block_len, &mut out)?,
Expand Down
1 change: 1 addition & 0 deletions candle-core/src/quantized/ggml_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ pub fn qtensor_from_ggml(
match ggml_dtype {
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
GgmlDType::BF16 => from_raw_data::<half::bf16>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q4_0 => {
from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
}
Expand Down
46 changes: 45 additions & 1 deletion candle-core/src/quantized/k_quants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use super::utils::{
use super::GgmlDType;
use crate::Result;
use byteorder::{ByteOrder, LittleEndian};
use half::f16;
use half::{bf16, f16};
use rayon::prelude::*;

// Default to QK_K 256 rather than 64.
Expand Down Expand Up @@ -1963,3 +1963,47 @@ impl GgmlType for f16 {
Ok(())
}
}

impl GgmlType for bf16 {
const DTYPE: GgmlDType = GgmlDType::BF16;
const BLCK_SIZE: usize = 1;
type VecDotType = bf16;

fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
Self::vec_dot_unopt(n, xs, ys)
}

fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
if xs.len() < n {
crate::bail!("size mismatch {} < {n}", xs.len())
}
if ys.len() < n {
crate::bail!("size mismatch {} < {n}", ys.len())
}
let mut res = 0f32;
unsafe { crate::cpu::vec_dot_bf16(xs.as_ptr(), ys.as_ptr(), &mut res, n) };
Ok(res)
}

fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
if xs.len() != ys.len() {
crate::bail!("size mismatch {} {}", xs.len(), ys.len());
}
// TODO: vectorize
for (x, y) in xs.iter().zip(ys.iter_mut()) {
*y = bf16::from_f32(*x)
}
Ok(())
}

fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
if xs.len() != ys.len() {
crate::bail!("size mismatch {} {}", xs.len(), ys.len());
}
// TODO: vectorize
for (x, y) in xs.iter().zip(ys.iter_mut()) {
*y = x.to_f32()
}
Ok(())
}
}
14 changes: 10 additions & 4 deletions candle-core/src/quantized/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub mod neon;
#[cfg(target_feature = "simd128")]
pub mod simd128;
pub mod utils;
use half::f16;
use half::{bf16, f16};

pub use k_quants::GgmlType;

Expand Down Expand Up @@ -134,6 +134,7 @@ impl QStorage {
pub enum GgmlDType {
F32,
F16,
BF16,
Q4_0,
Q4_1,
Q5_0,
Expand Down Expand Up @@ -165,6 +166,8 @@ impl GgmlDType {
13 => Self::Q5K,
14 => Self::Q6K,
15 => Self::Q8K,
// https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
30 => Self::BF16,
_ => crate::bail!("unknown dtype for tensor {u}"),
};
Ok(dtype)
Expand All @@ -186,6 +189,8 @@ impl GgmlDType {
Self::Q5K => 13,
Self::Q6K => 14,
Self::Q8K => 15,
// https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
Self::BF16 => 30,
}
}

Expand All @@ -206,14 +211,15 @@ impl GgmlDType {
Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
Self::BF16 => Box::new(vec![bf16::zeros(); elem_count]),
}
}
/// The type size for blocks in bytes.
pub fn type_size(&self) -> usize {
use k_quants::*;
match self {
Self::F32 => 4,
Self::F16 => 2,
Self::F16 | Self::BF16 => 2,
Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(),
Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(),
Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(),
Expand All @@ -234,7 +240,7 @@ impl GgmlDType {
pub fn block_size(&self) -> usize {
match self {
Self::F32 => 1,
Self::F16 => 1,
Self::F16 | Self::BF16 => 1,
Self::Q4_0 => k_quants::QK4_0,
Self::Q4_1 => k_quants::QK4_1,
Self::Q5_0 => k_quants::QK5_0,
Expand Down Expand Up @@ -422,7 +428,7 @@ thread_local! {
impl QMatMul {
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
let dequantize = match qtensor.dtype() {
GgmlDType::F32 | GgmlDType::F16 => true,
GgmlDType::F32 | GgmlDType::F16 | GgmlDType::BF16 => true,
_ => DEQUANTIZE_ALL.with(|b| *b),
};
let t = if dequantize {
Expand Down

0 comments on commit b482af4

Please sign in to comment.