From be48ce78768b6da7d6cdc2b6fcb7075f078a3a5a Mon Sep 17 00:00:00 2001 From: Xin Yang <105740670+xyang16@users.noreply.github.com> Date: Mon, 26 Aug 2024 15:43:41 -0700 Subject: [PATCH] [rust] Fix bert model classifier loading (#3441) --- extensions/tokenizers/rust/src/models/bert.rs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/extensions/tokenizers/rust/src/models/bert.rs b/extensions/tokenizers/rust/src/models/bert.rs index 0e7b1722f72..da7602315c8 100644 --- a/extensions/tokenizers/rust/src/models/bert.rs +++ b/extensions/tokenizers/rust/src/models/bert.rs @@ -492,7 +492,16 @@ impl BertClassificationHead { Ok(layer) => Some(layer), Err(_) => None, }; - let output = Linear::load(vb.pp("classifier"), config.hidden_size, n_classes, None)?; + let output = match Linear::load(vb.pp("classifier"), config.hidden_size, n_classes, None) { + Ok(output) => output, + Err(err) => { + if let Ok(output) = Linear::load(vb, config.hidden_size, n_classes, None) { + output + } else { + return Err(err); + } + } + }; Ok(Self { pooler,