[rust] Add GTE and Gemma2 model #3422

Merged · 2 commits · Aug 19, 2024
extensions/tokenizers/rust/src/models/gemma2.rs (366 additions, 0 deletions)
@@ -0,0 +1,366 @@
use crate::models::Model;
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{linear_b as linear, Activation, Linear, VarBuilder};
use serde::Deserialize;
use std::sync::Arc;

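/// Gemma2 configuration, deserialized from the model's `config.json`.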
#[derive(Deserialize)]
pub struct Gemma2Config {
    #[allow(unused)]
    pub architectures: Vec<String>,
    #[allow(unused)]
    model_type: Option<String>,
    pub attention_bias: bool,
    pub head_dim: usize,
    // The code gemma configs include both hidden_act and hidden_activation.
    pub hidden_act: Option<Activation>,
    pub hidden_activation: Option<Activation>,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_attention_heads: usize,
    pub num_hidden_layers: usize,
    pub num_key_value_heads: usize,
    pub rms_norm_eps: f64,
    pub rope_theta: f64,
    pub vocab_size: usize,
    pub max_position_embeddings: usize,
    pub use_flash_attn: Option<bool>,
}

impl Gemma2Config {
    fn hidden_act(&self) -> Result<Activation> {
        match (self.hidden_act, self.hidden_activation) {
            (None, Some(act)) | (Some(act), None) => Ok(act),
            (Some(act), Some(_)) => Ok(act),
            (None, None) => candle::bail!("none of hidden_act and hidden_activation are set"),
        }
    }
}

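/// RMS normalization with Gemma's `(1 + weight)` scaling.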
#[derive(Debug, Clone)]
struct RmsNorm {
    weight: Tensor,
    eps: f64,
}

impl RmsNorm {
    fn load(vb: VarBuilder, dim: usize, eps: f64) -> Result<Self> {
        let weight = vb.get(dim, "weight")?;
        Ok(Self { weight, eps })
    }
}

impl Module for RmsNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x_dtype = x.dtype();
        let internal_dtype = match x_dtype {
            DType::F16 | DType::BF16 => DType::F32,
            d => d,
        };
        let hidden_size = x.dim(D::Minus1)?;
        let x = x.to_dtype(internal_dtype)?;
        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
        x_normed
            .to_dtype(x_dtype)?
            .broadcast_mul(&(&self.weight + 1.0)?)
    }
}

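/// Precomputed sin/cos tables for rotary position embeddings (RoPE).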
#[derive(Debug, Clone)]
struct RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}

impl RotaryEmbedding {
    fn new(config: &Gemma2Config, dtype: DType, dev: &Device) -> Result<Self> {
        let dim = config.head_dim;
        let max_seq_len = config.max_position_embeddings;
        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / config.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(dtype)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        Ok(Self {
            sin: freqs.sin()?,
            cos: freqs.cos()?,
        })
    }

    fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
        let cos = self.cos.narrow(0, 0, seq_len)?;
        let sin = self.sin.narrow(0, 0, seq_len)?;
        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
        Ok((q_embed, k_embed))
    }
}

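/// Gated feed-forward block: `down_proj(act(gate_proj(x)) * up_proj(x))`.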
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    gate_proj: Linear,
    up_proj: Linear,
    down_proj: Linear,
    act_fn: candle_nn::Activation,
}

impl MLP {
    fn load(vb: VarBuilder, config: &Gemma2Config) -> Result<Self> {
        let hidden_sz = config.hidden_size;
        let intermediate_sz = config.intermediate_size;
        let gate_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("gate_proj"))?;
        let up_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("up_proj"))?;
        let down_proj = linear(intermediate_sz, hidden_sz, false, vb.pp("down_proj"))?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: config.hidden_act()?,
        })
    }
}

impl Module for MLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
        let rhs = xs.apply(&self.up_proj)?;
        (lhs * rhs)?.apply(&self.down_proj)
    }
}

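/// Multi-head self-attention with grouped-query KV heads and optional flash attention.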
#[derive(Debug, Clone)]
struct Attention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    num_heads: usize,
    num_kv_heads: usize,
    num_kv_groups: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    use_flash_attn: bool,
}

impl Attention {
    fn load(
        vb: VarBuilder,
        config: &Gemma2Config,
        rotary_emb: Arc<RotaryEmbedding>,
    ) -> Result<Self> {
        let hidden_sz = config.hidden_size;
        let num_heads = config.num_attention_heads;
        let num_kv_heads = config.num_key_value_heads;
        let num_kv_groups = num_heads / num_kv_heads;
        let head_dim = config.head_dim;
        let bias = config.attention_bias;
        let q_proj = linear(hidden_sz, num_heads * head_dim, bias, vb.pp("q_proj"))?;
        let k_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("k_proj"))?;
        let v_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("v_proj"))?;
        let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            num_kv_groups,
            head_dim,
            rotary_emb,
            use_flash_attn: config.use_flash_attn.unwrap_or(false),
        })
    }

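    /// Attends over `xs` of shape `(batch, seq_len, hidden_size)`; the optional mask is added
    /// to the attention scores (or toggles causal masking on the flash-attn path).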
    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let query_states = self.q_proj.forward(xs)?;
        let key_states = self.k_proj.forward(xs)?;
        let value_states = self.v_proj.forward(xs)?;

        let query_states = query_states
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let key_states = key_states
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;
        let value_states = value_states
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;

        let (query_states, key_states) = self
            .rotary_emb
            .apply_rotary_emb_qkv(&query_states, &key_states)?;

        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
        let value_states =
            crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;

        let attn_output = if self.use_flash_attn {
            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
            let q = query_states.transpose(1, 2)?;
            let k = key_states.transpose(1, 2)?;
            let v = value_states.transpose(1, 2)?;
            let scale = 1f32 / (self.head_dim as f32).sqrt();
            flash_attn(&q, &k, &v, scale, attention_mask.is_some())?.transpose(1, 2)?
        } else {
            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
            let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;

            let attn_weights = match attention_mask {
                None => attn_weights,
                Some(mask) => attn_weights.broadcast_add(mask)?,
            };
            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
            attn_weights.matmul(&value_states)?
        };
        attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, ()))?
            .apply(&self.o_proj)
    }
}

#[cfg(feature = "flash-attn")]
fn flash_attn(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    softmax_scale: f32,
    causal: bool,
) -> Result<Tensor> {
    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}

#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
    unimplemented!("compile with '--features flash-attn'")
}

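/// A single decoder layer: pre-norm attention and pre-norm MLP, each with a residual connection.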
#[derive(Debug, Clone)]
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    fn load(
        vb: VarBuilder,
        config: &Gemma2Config,
        rotary_emb: Arc<RotaryEmbedding>,
    ) -> Result<Self> {
        let self_attn = Attention::load(vb.pp("self_attn"), config, rotary_emb)?;
        let mlp = MLP::load(vb.pp("mlp"), config)?;
        let input_layernorm = RmsNorm::load(
            vb.pp("input_layernorm"),
            config.hidden_size,
            config.rms_norm_eps,
        )?;
        let post_attention_layernorm = RmsNorm::load(
            vb.pp("post_attention_layernorm"),
            config.hidden_size,
            config.rms_norm_eps,
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(&xs, attention_mask)?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
        residual + xs
    }
}

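/// The Gemma2 decoder stack, with the output head tied to the input embedding weights.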
#[derive(Debug)]
pub struct Gemma2Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Linear,
    device: Device,
    dtype: DType,
    hidden_size: usize,
}

impl Gemma2Model {
    pub fn load(vb: VarBuilder, config: &Gemma2Config) -> Result<Self> {
        let embed_tokens =
            candle_nn::embedding(config.vocab_size, config.hidden_size, vb.pp("embed_tokens"))?;
        let rotary_emb = Arc::new(RotaryEmbedding::new(config, vb.dtype(), vb.device())?);
        let mut layers = Vec::with_capacity(config.num_hidden_layers);
        let vb_l = vb.pp("layers");
        for layer_idx in 0..config.num_hidden_layers {
            let layer = DecoderLayer::load(vb_l.pp(layer_idx), config, rotary_emb.clone())?;
            layers.push(layer)
        }
        let norm = RmsNorm::load(vb.pp("norm"), config.hidden_size, config.rms_norm_eps)?;
        let lm_head = Linear::new(embed_tokens.embeddings().clone(), None);
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: vb.device().clone(),
            dtype: vb.dtype(),
            hidden_size: config.hidden_size,
        })
    }

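    /// Builds a causal additive mask of shape `(b_size, 1, tgt_len, tgt_len)`,
    /// with `-inf` above the diagonal and `0` elsewhere.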
    fn prepare_decoder_attention_mask(&self, b_size: usize, tgt_len: usize) -> Result<Tensor> {
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
            .collect();
        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
        mask.expand((b_size, 1, tgt_len, tgt_len))?
            .to_dtype(self.dtype)
    }
}

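// `Model` trait integration: runs the decoder over the tokenized batch and returns the
// tied lm_head output for the final position.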
impl Model for Gemma2Model {
    fn get_input_names(&self) -> Vec<String> {
        vec!["input_ids".to_string(), "attention_mask".to_string()]
    }

    fn forward(
        &self,
        input_ids: &Tensor,
        _attention_mask: &Tensor,
        _token_type_ids: Option<&Tensor>,
    ) -> Result<Tensor> {
        let (b_size, seq_len) = input_ids.dims2()?;
        let attention_mask = if seq_len <= 1 {
            None
        } else {
            let mask = self.prepare_decoder_attention_mask(b_size, seq_len)?;
            Some(mask)
        };
        let xs = self.embed_tokens.forward(input_ids)?;
        // Gemma scales the token embeddings by sqrt(hidden_size) before the decoder stack.
        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
        for layer in self.layers.iter() {
            xs = layer.forward(&xs, attention_mask.as_ref())?
        }
        // Keep only the final position, then apply the output norm and the tied lm_head.
        xs.narrow(1, seq_len - 1, 1)?
            .apply(&self.norm)?
            .apply(&self.lm_head)
    }
}