From 47833d96c4d7747fe6686f20e0a06d0972745ce0 Mon Sep 17 00:00:00 2001
From: Erik Mackdanz
Date: Sat, 30 Nov 2024 08:54:08 -0600
Subject: [PATCH] add init-model command

---
 src/main.rs | 110 +++++++++++++++++++++++++++++-----------------------
 1 file changed, 62 insertions(+), 48 deletions(-)

diff --git a/src/main.rs b/src/main.rs
index 16aceda..4a11d2b 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -39,6 +39,15 @@
 //! maybe creating database objects
 //! ```
 //!
+//! ## Initialize the model
+//!
+//! Download the model files. This command is optional since the model
+//! files can be downloaded lazily by the index and search actions.
+//!
+//! ```text
+//! $ ./vecsearch init-model
+//! ```
+//!
 //! ## Add documents
 //!
 //! A document is a regular file.
@@ -84,7 +93,6 @@
 //! ## TODO
 //!
 //! - why model from PR not main?
-//! - init-model command
 //! - rename cv-*
 //!
 use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE};
@@ -118,6 +126,9 @@ enum Action {
         dbpassword: String,
     },
 
+    /// Download the model file in advance of the index or search commands
+    InitModel,
+
     /// Read one document and add it to the database index
     Index {
@@ -187,46 +198,44 @@ struct Args {
 }
 
-impl Args {
-    fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
-        let device = candle_core::Device::Cpu;
-        let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
-        let default_revision = "refs/pr/21".to_string();
-        let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
-            (Some(model_id), Some(revision)) => (model_id, revision),
-            (Some(model_id), None) => (model_id, "main".to_string()),
-            (None, Some(revision)) => (default_model, revision),
-            (None, None) => (default_model, default_revision),
-        };
-
-        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
-        let (config_filename, tokenizer_filename, weights_filename) = {
-            let api = Api::new()?;
-            let api = api.repo(repo);
-            let config = api.get("config.json")?;
-            let tokenizer = api.get("tokenizer.json")?;
-            let weights = if self.use_pth {
-                api.get("pytorch_model.bin")?
-            } else {
-                api.get("model.safetensors")?
-            };
-            (config, tokenizer, weights)
-        };
-        let config = std::fs::read_to_string(config_filename)?;
-        let mut config: Config = serde_json::from_str(&config)?;
-        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
-
-        let vb = if self.use_pth {
-            VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
+fn build_model_and_tokenizer(args: &Args) -> Result<(BertModel, Tokenizer)> {
+    let device = candle_core::Device::Cpu;
+    let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
+    let default_revision = "refs/pr/21".to_string();
+    let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
+        (Some(model_id), Some(revision)) => (model_id, revision),
+        (Some(model_id), None) => (model_id, "main".to_string()),
+        (None, Some(revision)) => (default_model, revision),
+        (None, None) => (default_model, default_revision),
+    };
+
+    let repo = Repo::with_revision(model_id, RepoType::Model, revision);
+    let (config_filename, tokenizer_filename, weights_filename) = {
+        let api = Api::new()?;
+        let api = api.repo(repo);
+        let config = api.get("config.json")?;
+        let tokenizer = api.get("tokenizer.json")?;
+        let weights = if args.use_pth {
+            api.get("pytorch_model.bin")?
         } else {
-            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
+            api.get("model.safetensors")?
         };
-        if self.approximate_gelu {
-            config.hidden_act = HiddenAct::GeluApproximate;
-        }
-        let model = BertModel::load(vb, &config)?;
-        Ok((model, tokenizer))
+        (config, tokenizer, weights)
+    };
+    let config = std::fs::read_to_string(config_filename)?;
+    let mut config: Config = serde_json::from_str(&config)?;
+    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+    let vb = if args.use_pth {
+        VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
+    } else {
+        unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
+    };
+    if args.approximate_gelu {
+        config.hidden_act = HiddenAct::GeluApproximate;
     }
+    let model = BertModel::load(vb, &config)?;
+    Ok((model, tokenizer))
 }
 
 fn init_database(dbname: String, dbhost: String, dbuser: String, dbpassword: String) -> Result<()> {
@@ -248,7 +257,7 @@ fn init_database(dbname: String, dbhost: String, dbuser: String, dbpassword: Str
         },
         None => println!("database {} exists already", dbname),
     }
-    let _ = client.close();
+    client.close()?;
 
     println!("maybe creating database objects");
     let mut client = postgres::Config::new()
@@ -294,8 +303,8 @@ fn get_embeddings(input: &String, model: &BertModel, mut tokenizer: Tokenizer) -
     Ok(embeddings)
 }
 
-fn index(dbname: String, dbhost: String, dbuser: String, dbpassword: String,
-         files: Vec<String>, model: BertModel, tokenizer: Tokenizer) -> Result<()> {
+fn index(dbname: &String, dbhost: &String, dbuser: &String, dbpassword: &String,
+         files: &Vec<String>, model: BertModel, tokenizer: Tokenizer) -> Result<()> {
 
     println!("indexing file(s)");
 
@@ -317,12 +326,12 @@ fn index(dbname: String, dbhost: String, dbuser: String, dbpassword: String,
         )?;
     }
 
-    let _ = client.close();
+    client.close()?;
     Ok(())
 }
 
-fn search(dbname: String, dbhost: String, dbuser: String, dbpassword: String,
-          search: String, model: BertModel, tokenizer: Tokenizer) -> Result<()> {
+fn search(dbname: &String, dbhost: &String, dbuser: &String, dbpassword: &String,
+          search: &String, model: BertModel, tokenizer: Tokenizer) -> Result<()> {
 
     println!("searching for document matches");
     let embeddings = get_embeddings(&search,&model,tokenizer.clone())?;
@@ -341,23 +350,28 @@ fn search(dbname: String, dbhost: String, dbuser: String, dbpassword: String,
         println!("{}", content);
     }
 
-    let _ = client.close();
+    client.close()?;
     Ok(())
 }
 
 fn main() -> Result<()> {
     let args = Args::parse();
-    let (model, tokenizer) = args.build_model_and_tokenizer()?;
 
     match args.action {
         Action::InitDatabase{ dbname, dbhost, dbuser, dbpassword } => {
             init_database(dbname, dbhost, dbuser, dbpassword)?;
         }
-        Action::Index{ dbname, dbhost, dbuser, dbpassword, file } => {
+        Action::InitModel => {
+            build_model_and_tokenizer(&args)?;
+        }
+        Action::Index{ ref dbname, ref dbhost, ref dbuser, ref dbpassword, ref file } => {
+            let (model, tokenizer) = build_model_and_tokenizer(&args)?;
             index(dbname, dbhost, dbuser, dbpassword, file, model, tokenizer)?;
         }
-        Action::Search{ dbname, dbhost, dbuser, dbpassword, search: search_term } => {
+        Action::Search{ ref dbname, ref dbhost, ref dbuser, ref dbpassword,
+                        search: ref search_term } => {
+            let (model, tokenizer) = build_model_and_tokenizer(&args)?;
            search(dbname, dbhost, dbuser, dbpassword, search_term, model, tokenizer)?;
         }
    }

-- 
2.52.0
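
A note on the `ref` bindings above, since they are the subtle part of this
change: `build_model_and_tokenizer(&args)` borrows `args` inside the match
arms, so those arms must bind the action's fields by reference. Binding by
value would move the fields out of `args.action`, leaving `args` partially
moved and the later `&args` borrow rejected. A minimal self-contained sketch
of the constraint (hypothetical stand-in types, not the patch's real
definitions):

    // Stand-ins for the real Args, Action, and build_model_and_tokenizer.
    struct Args { action: Action }
    enum Action { Index { file: Vec<String> } }

    fn build_model_and_tokenizer(_args: &Args) { /* stub for the loader */ }

    fn main() {
        let args = Args { action: Action::Index { file: Vec::new() } };
        match args.action {
            // Without `ref`, `file` would move out of `args.action`, and the
            // `&args` borrow below would fail with error[E0382]: borrow of
            // partially moved value: `args`.
            Action::Index { ref file } => {
                build_model_and_tokenizer(&args);
                println!("indexing {} file(s)", file.len());
            }
        }
    }

The `InitDatabase` arm can keep its by-value bindings because it never
touches `args` again after destructuring.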