Support .bin, .pt, .pth extensions (#557)
* Support .bin, .pt, .pth extensions

* Fix extension match
EricLBuehler authored Jul 8, 2024
1 parent dab7339 commit ff8557f
Showing 2 changed files with 75 additions and 6 deletions.
20 changes: 19 additions & 1 deletion mistralrs-core/src/pipeline/paths.rs
@@ -274,7 +274,25 @@ pub fn get_model_paths(
         }
         None => {
             let mut filenames = vec![];
-            for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) {
+            let listing = api_dir_list!(api, model_id);
+            let safetensors = listing
+                .clone()
+                .filter(|x| x.ends_with(".safetensors"))
+                .collect::<Vec<_>>();
+            let pickles = listing
+                .clone()
+                .filter(|x| x.ends_with(".pth") || x.ends_with(".pt") || x.ends_with(".bin"))
+                .collect::<Vec<_>>();
+            let files = if !safetensors.is_empty() {
+                // Always prefer safetensors
+                safetensors
+            } else if !pickles.is_empty() {
+                // Fall back to pickle
+                pickles
+            } else {
+                anyhow::bail!("Expected file with extension one of .safetensors, .pth, .pt, .bin.");
+            };
+            for rfilename in files {
                 filenames.push(api_get_file!(api, &rfilename, model_id));
             }
             Ok(filenames)
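The selection logic added above reads most clearly in isolation: given a directory listing, always take the .safetensors files, and only fall back to pickle-format weights (.pth, .pt, .bin) when no safetensors are present, so a repository that ships only PyTorch checkpoints can now be downloaded. Below is a minimal sketch of the same rule, assuming a plain slice of file names instead of the api_dir_list! iterator; select_weight_files is a hypothetical helper, not part of the crate:

// Hypothetical standalone helper mirroring the preference rule in get_model_paths:
// safetensors first, pickle formats only as a fallback.
fn select_weight_files(listing: &[String]) -> anyhow::Result<Vec<String>> {
    let safetensors: Vec<String> = listing
        .iter()
        .filter(|x| x.ends_with(".safetensors"))
        .cloned()
        .collect();
    let pickles: Vec<String> = listing
        .iter()
        .filter(|x| x.ends_with(".pth") || x.ends_with(".pt") || x.ends_with(".bin"))
        .cloned()
        .collect();
    if !safetensors.is_empty() {
        // Always prefer safetensors when both formats are present
        Ok(safetensors)
    } else if !pickles.is_empty() {
        Ok(pickles)
    } else {
        anyhow::bail!("Expected file with extension one of .safetensors, .pth, .pt, .bin.")
    }
}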
61 changes: 56 additions & 5 deletions mistralrs-core/src/utils/varbuilder_utils.rs
@@ -2,7 +2,9 @@
 use std::{collections::HashMap, path::PathBuf, thread::JoinHandle};
 
-use candle_core::{DType, Device, Result, Tensor};
+use candle_core::{
+    pickle::PthTensors, safetensors::MmapedSafetensors, DType, Device, Result, Tensor,
+};
 use candle_nn::{
     var_builder::{SimpleBackend, VarBuilderArgs},
     VarBuilder,
@@ -15,6 +15,43 @@ use derive_new::new;

 use super::progress::{Joinable, NonThreadingHandle, Parellelize};
 
+trait TensorLoaderBackend {
+    fn get_names(&self) -> Vec<String>;
+    fn load_name(&self, name: &str, device: &Device, dtype: DType) -> Result<Tensor>;
+}
+
+struct SafetensorBackend(MmapedSafetensors);
+
+impl TensorLoaderBackend for SafetensorBackend {
+    fn get_names(&self) -> Vec<String> {
+        self.0
+            .tensors()
+            .into_iter()
+            .map(|(name, _)| name)
+            .collect::<Vec<_>>()
+    }
+    fn load_name(&self, name: &str, device: &Device, dtype: DType) -> Result<Tensor> {
+        self.0.load(name, device)?.to_dtype(dtype)
+    }
+}
+
+struct PickleBackend(PthTensors);
+
+impl TensorLoaderBackend for PickleBackend {
+    fn get_names(&self) -> Vec<String> {
+        self.0.tensor_infos().keys().cloned().collect::<Vec<_>>()
+    }
+    fn load_name(&self, name: &str, device: &Device, dtype: DType) -> Result<Tensor> {
+        self.0
+            .get(name)?
+            .ok_or(candle_core::Error::Msg(format!(
+                "Could not load tensor {name}"
+            )))?
+            .to_device(device)?
+            .to_dtype(dtype)
+    }
+}
+
 /// Load tensors into a VarBuilder backed by a VarMap using MmapedSafetensors.
 /// Set `silent` to not show a progress bar.
 /// Only include keys for which predicate evaluates to true
@@ -101,21 +140,33 @@ trait LoadTensors {
         is_silent: bool,
         predicate: impl Fn(String) -> bool,
     ) -> Result<HashMap<String, Tensor>> {
-        let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(path)? };
+        let tensors: Box<dyn TensorLoaderBackend> = match path
+            .extension()
+            .expect("Expected extension")
+            .to_str()
+            .expect("Expected to convert")
+        {
+            "safetensors" => Box::new(SafetensorBackend(unsafe {
+                candle_core::safetensors::MmapedSafetensors::new(path)?
+            })),
+            "pth" | "pt" | "bin" => Box::new(PickleBackend(
+                candle_core::pickle::PthTensors::new(path, None)?
+            )),
+            other => candle_core::bail!("Unexpected extension `{other}`, this should have been handled by `get_model_paths`."),
+        };
 
         // Extracts the tensor name and processes it, filtering tensors and deriving the key name:
         let names_only = tensors
-            .tensors()
+            .get_names()
             .into_iter()
-            .map(|(name, _)| name)
             .filter(|x| predicate(x.to_string()));
         let iter = self.get_name_key_pairs(names_only).collect::<Vec<_>>();
 
         // Take the filtered list of tensors to load, store with derived lookup key:
         let mut loaded_tensors = HashMap::new();
         if !iter.is_empty() {
             for (load_name, key_name) in iter.into_iter().with_progress(is_silent) {
-                let tensor = tensors.load(&load_name, device)?.to_dtype(dtype)?;
+                let tensor = tensors.load_name(&load_name, device, dtype)?;
 
                 loaded_tensors.insert(key_name, tensor);
             }
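The two hunks above form one abstraction: the TensorLoaderBackend trait hides whether tensors come from a memory-mapped safetensors file or a PyTorch pickle, and load_tensors_from_path only picks the backend from the file extension. A minimal caller-side sketch of the same dispatch, using the candle_core APIs the backends wrap; list_tensor_names, the main function, and the model.pt path are illustrative and not part of the crate:

use std::path::Path;

use candle_core::{pickle::PthTensors, safetensors::MmapedSafetensors, Result};

// Illustrative helper: open a checkpoint by extension and return its tensor names.
fn list_tensor_names(path: &Path) -> Result<Vec<String>> {
    match path.extension().and_then(|ext| ext.to_str()) {
        Some("safetensors") => {
            // Memory-map the file; unsafe because the mapping must remain valid while used.
            let st = unsafe { MmapedSafetensors::new(path)? };
            Ok(st.tensors().into_iter().map(|(name, _)| name).collect())
        }
        Some("pth") | Some("pt") | Some("bin") => {
            // Pickle-format checkpoints expose their tensor metadata eagerly.
            let pth = PthTensors::new(path, None)?;
            Ok(pth.tensor_infos().keys().cloned().collect())
        }
        _ => candle_core::bail!("unsupported weight file extension"),
    }
}

fn main() -> Result<()> {
    for name in list_tensor_names(Path::new("model.pt"))? {
        println!("{name}");
    }
    Ok(())
}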
