diff --git a/extensions/tokenizers/rust/src/models/camembert.rs b/extensions/tokenizers/rust/src/models/camembert.rs index 23e10cade39..b53b65151f4 100644 --- a/extensions/tokenizers/rust/src/models/camembert.rs +++ b/extensions/tokenizers/rust/src/models/camembert.rs @@ -480,12 +480,17 @@ pub struct CamembertModel { impl CamembertModel { pub fn load(vb: VarBuilder, config: &CamembertConfig) -> Result { let (embeddings, encoder) = match ( - BertEmbeddings::load(vb.pp("roberta.embeddings"), config), - BertEncoder::load(vb.pp("roberta.encoder"), config), + BertEmbeddings::load(vb.pp("embeddings"), config), + BertEncoder::load(vb.pp("encoder"), config), ) { (Ok(embeddings), Ok(encoder)) => (embeddings, encoder), (Err(err), _) | (_, Err(err)) => { if let (Ok(embeddings), Ok(encoder)) = ( + BertEmbeddings::load(vb.pp("roberta.embeddings".to_string()), config), + BertEncoder::load(vb.pp("roberta.encoder".to_string()), config), + ) { + (embeddings, encoder) + } else if let (Ok(embeddings), Ok(encoder)) = ( BertEmbeddings::load(vb.pp("deberta.embeddings".to_string()), config), BertEncoder::load(vb.pp("deberta.encoder".to_string()), config), ) {