Commit

Removed anyhow; error handling got more verbose, but simpler
cpetersen committed Mar 11, 2024
1 parent 541b907 commit 790a005
Showing 2 changed files with 33 additions and 19 deletions.
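
The commit replaces anyhow's blanket `?` conversions with explicit `.map_err(...)` calls that wrap each library's error in a `magnus::Error`, so failures can cross the extension boundary and surface in Ruby as `RuntimeError`s. A minimal sketch of the before/after pattern, assuming only the `magnus` crate (the `read` function and `wrap_io_err` helper are illustrative, not from this commit):

```rust
use magnus::Error;

// Before (anyhow): any std error converts implicitly through `?`.
//   fn read() -> anyhow::Result<String> {
//       Ok(std::fs::read_to_string("config.json")?)
//   }

// After (magnus): the conversion is spelled out at the call site,
// turning the Rust error into a Ruby RuntimeError.
fn wrap_io_err(err: std::io::Error) -> Error {
    Error::new(magnus::exception::runtime_error(), err.to_string())
}

fn read() -> Result<String, Error> {
    std::fs::read_to_string("config.json").map_err(wrap_io_err)
}
```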
ext/candle/Cargo.toml (1 change: 0 additions & 1 deletion)

```diff
@@ -7,7 +7,6 @@ edition = "2021"
 crate-type = ["cdylib"]
 
 [dependencies]
-anyhow = { version = "1", features = ["backtrace"] }
 candle-core = "0.4.1"
 candle-nn = "0.4.1"
 candle-transformers = "0.4.1"
```
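
With the dependency gone, every foreign error type has to be converted by hand at the Ruby boundary, which is what makes the code in the next file more verbose. The three `wrap_*` helpers added there are identical except for their parameter type; one possible consolidation (untested, not part of this commit) is a single generic wrapper over anything implementing `Display`:

```rust
use magnus::Error;

// Hypothetical alternative to wrap_std_err / wrap_candle_err / wrap_hf_err:
// any error that implements Display becomes a Ruby RuntimeError.
fn wrap_err<E: std::fmt::Display>(err: E) -> Error {
    Error::new(magnus::exception::runtime_error(), err.to_string())
}
```

Every call site would then read `.map_err(wrap_err)?` regardless of which library failed; the per-type helpers in the commit trade that brevity for an explicit record of which error each call can produce.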
ext/candle/src/model/config.rs (51 changes: 33 additions & 18 deletions)

```diff
@@ -5,12 +5,23 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;
 
 use candle_transformers::models::jina_bert::{BertModel, Config};
-
-use anyhow::Error as E;
-use anyhow::Result;
 use candle_core::{Device, DType, Module, Tensor};
 use candle_nn::VarBuilder;
+use core::result::Result;
 use tokenizers::Tokenizer;
+use magnus::Error;
 
+pub fn wrap_std_err(err: Box<dyn std::error::Error + Send + Sync>) -> Error {
+    Error::new(magnus::exception::runtime_error(), err.to_string())
+}
+
+pub fn wrap_candle_err(err: candle_core::Error) -> Error {
+    Error::new(magnus::exception::runtime_error(), err.to_string())
+}
+
+pub fn wrap_hf_err(err: hf_hub::api::sync::ApiError) -> Error {
+    Error::new(magnus::exception::runtime_error(), err.to_string())
+}
+
 pub struct ModelConfig {
     device: Device,
@@ -29,56 +40,60 @@ impl ModelConfig {
         }
     }
 
-    pub fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
+    pub fn build_model_and_tokenizer(&self) -> Result<(BertModel, tokenizers::Tokenizer), Error> {
         use hf_hub::{api::sync::Api, Repo, RepoType};
         let model_path = match &self.model_path {
             Some(model_file) => std::path::PathBuf::from(model_file),
-            None => Api::new()?
+            None => Api::new()
+                .map_err(wrap_hf_err)?
                 .repo(Repo::new(
                     "jinaai/jina-embeddings-v2-base-en".to_string(),
                     RepoType::Model,
                 ))
-                .get("model.safetensors")?,
+                .get("model.safetensors")
+                .map_err(wrap_hf_err)?,
         };
         let tokenizer_path = match &self.tokenizer_path {
             Some(file) => std::path::PathBuf::from(file),
-            None => Api::new()?
-                .repo(Repo::new(
+            None => Api::new()
+                .map_err(wrap_hf_err)?
+                .repo(Repo::new(
                     "sentence-transformers/all-MiniLM-L6-v2".to_string(),
                     RepoType::Model,
                 ))
-                .get("tokenizer.json")?,
+                .get("tokenizer.json")
+                .map_err(wrap_hf_err)?,
         };
         // let device = candle_examples::device(self.cpu)?;
         let config = Config::v2_base();
-        let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path).map_err(E::msg)?;
-        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &self.device)? };
-        let model = BertModel::new(vb, &config)?;
+        let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path).map_err(wrap_std_err)?;
+        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &self.device).map_err(wrap_candle_err)? };
+        let model = BertModel::new(vb, &config).map_err(wrap_candle_err)?;
         Ok((model, tokenizer))
     }
 
-    pub fn embedding(&self, input: String) -> anyhow::Result<Tensor> {
+    pub fn embedding(&self, input: String) -> Result<Tensor, Error> {
         let config = ModelConfig::build();
         let (model, tokenizer) = config.build_model_and_tokenizer()?;
         return self.compute_embedding(input, model, tokenizer);
     }
 
-    fn compute_embedding(&self, prompt: String, model: BertModel, mut tokenizer: Tokenizer) -> Result<Tensor, E> {
+    fn compute_embedding(&self, prompt: String, model: BertModel, mut tokenizer: Tokenizer) -> Result<Tensor, Error> {
         let start: std::time::Instant = std::time::Instant::now();
         // let prompt = args.prompt.as_deref().unwrap_or("Hello, world!");
         let tokenizer = tokenizer
             .with_padding(None)
             .with_truncation(None)
-            .map_err(E::msg)?;
+            .map_err(wrap_std_err)?;
         let tokens = tokenizer
             .encode(prompt, true)
-            .map_err(E::msg)?
+            .map_err(wrap_std_err)?
             .get_ids()
             .to_vec();
-        let token_ids = Tensor::new(&tokens[..], &self.device)?.unsqueeze(0)?;
+        let token_ids = Tensor::new(&tokens[..], &self.device).map_err(wrap_candle_err)?.unsqueeze(0).map_err(wrap_candle_err)?;
         println!("Loaded and encoded {:?}", start.elapsed());
         let start: std::time::Instant = std::time::Instant::now();
-        let result = model.forward(&token_ids)?;
+        let result = model.forward(&token_ids).map_err(wrap_candle_err)?;
         println!("{result}");
         println!("Took {:?}", start.elapsed());
         return Ok(result);
```
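
On the Ruby side, returning `Err(magnus::Error)` from a bound method raises the wrapped exception. A sketch of how a method returning one of these `Result`s might be registered with magnus (the module name and the stubbed `embedding` body are illustrative, not taken from this repository):

```rust
use magnus::{function, prelude::*, Error, Ruby};

// Stand-in for ModelConfig::embedding; a real body would build the model
// and wrap candle / hf-hub / tokenizer failures with the helpers above.
fn embedding(_input: String) -> Result<Vec<f32>, Error> {
    Ok(vec![])
}

#[magnus::init]
fn init(ruby: &Ruby) -> Result<(), Error> {
    let module = ruby.define_module("Candle")?;
    module.define_singleton_method("embedding", function!(embedding, 1))?;
    Ok(())
}
```

A failure wrapped by `wrap_hf_err` during model download, for example, reaches Ruby as a `RuntimeError` carrying the hf-hub API error's message.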
