]> Humopery - vecsearch.git/commitdiff
add init-model command
authorErik Mackdanz <erikmack@gmail.com>
Sat, 30 Nov 2024 14:54:08 +0000 (08:54 -0600)
committerErik Mackdanz <erikmack@gmail.com>
Sat, 30 Nov 2024 14:54:08 +0000 (08:54 -0600)
src/main.rs

index 16aceda83a8675e9573ee38c738eb6b9b1e742d6..4a11d2bfdf97809afbcbb11eb2c575c24b25d10a 100644 (file)
 //! maybe creating database objects
 //! ```
 //! 
+//! ## Initialize the model
+//! 
+//! Download the model files. This command is optional since the model
+//! files can be downloaded lazily by the index and search actions.
+//! 
+//! ```text
+//! $ ./vecsearch init-model
+//! ```
+//!
 //! ## Add documents
 //! 
 //! A document is a regular file.
@@ -84,7 +93,6 @@
 //! ## TODO
 //! 
 //! - why model from PR not main?
-//! - init-model command
 //! - rename cv-*
 //! 
 use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE};
@@ -118,6 +126,9 @@ enum Action {
        dbpassword: String,
     },
 
+    /// Download the model file in advance of the index or search commands
+    InitModel,
+
     /// Read one document and add it to the database index
     Index {
 
@@ -187,46 +198,44 @@ struct Args {
 
 }
 
-impl Args {
-    fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
-        let device = candle_core::Device::Cpu;
-        let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
-        let default_revision = "refs/pr/21".to_string();
-        let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
-            (Some(model_id), Some(revision)) => (model_id, revision),
-            (Some(model_id), None) => (model_id, "main".to_string()),
-            (None, Some(revision)) => (default_model, revision),
-            (None, None) => (default_model, default_revision),
-        };
-
-        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
-        let (config_filename, tokenizer_filename, weights_filename) = {
-            let api = Api::new()?;
-            let api = api.repo(repo);
-            let config = api.get("config.json")?;
-            let tokenizer = api.get("tokenizer.json")?;
-            let weights = if self.use_pth {
-                api.get("pytorch_model.bin")?
-            } else {
-                api.get("model.safetensors")?
-            };
-            (config, tokenizer, weights)
-        };
-        let config = std::fs::read_to_string(config_filename)?;
-        let mut config: Config = serde_json::from_str(&config)?;
-        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
-
-        let vb = if self.use_pth {
-            VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
+fn build_model_and_tokenizer(args: &Args) -> Result<(BertModel, Tokenizer)> {
+    let device = candle_core::Device::Cpu;
+    let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
+    let default_revision = "refs/pr/21".to_string();
+    let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
+        (Some(model_id), Some(revision)) => (model_id, revision),
+        (Some(model_id), None) => (model_id, "main".to_string()),
+        (None, Some(revision)) => (default_model, revision),
+        (None, None) => (default_model, default_revision),
+    };
+
+    let repo = Repo::with_revision(model_id, RepoType::Model, revision);
+    let (config_filename, tokenizer_filename, weights_filename) = {
+        let api = Api::new()?;
+        let api = api.repo(repo);
+        let config = api.get("config.json")?;
+        let tokenizer = api.get("tokenizer.json")?;
+        let weights = if args.use_pth {
+            api.get("pytorch_model.bin")?
         } else {
-            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
+            api.get("model.safetensors")?
         };
-        if self.approximate_gelu {
-            config.hidden_act = HiddenAct::GeluApproximate;
-        }
-        let model = BertModel::load(vb, &config)?;
-        Ok((model, tokenizer))
+        (config, tokenizer, weights)
+    };
+    let config = std::fs::read_to_string(config_filename)?;
+    let mut config: Config = serde_json::from_str(&config)?;
+    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+    let vb = if args.use_pth {
+        VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
+    } else {
+        unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
+    };
+    if args.approximate_gelu {
+        config.hidden_act = HiddenAct::GeluApproximate;
     }
+    let model = BertModel::load(vb, &config)?;
+    Ok((model, tokenizer))
 }
 
 fn init_database(dbname: String, dbhost: String, dbuser: String, dbpassword: String) -> Result<()> {
@@ -248,7 +257,7 @@ fn init_database(dbname: String, dbhost: String, dbuser: String, dbpassword: Str
        },
        None => println!("database {} exists already", dbname),
     }
-    let _ = client.close();
+    client.close()?;
 
     println!("maybe creating database objects");
     let mut client = postgres::Config::new()
@@ -294,8 +303,8 @@ fn get_embeddings(input: &String, model: &BertModel, mut tokenizer: Tokenizer) -
     Ok(embeddings)
 }
 
-fn index(dbname: String, dbhost: String, dbuser: String, dbpassword: String,
-        files: Vec<String>, model: BertModel, tokenizer: Tokenizer) -> Result<()> {
+fn index(dbname: &String, dbhost: &String, dbuser: &String, dbpassword: &String,
+        files: &Vec<String>, model: BertModel, tokenizer: Tokenizer) -> Result<()> {
 
     println!("indexing file(s)");
 
@@ -317,12 +326,12 @@ fn index(dbname: String, dbhost: String, dbuser: String, dbpassword: String,
        )?;
     }
 
-    let _ = client.close();
+    client.close()?;
     Ok(())
 }
 
-fn search(dbname: String, dbhost: String, dbuser: String, dbpassword: String,
-        search: String, model: BertModel, tokenizer: Tokenizer) -> Result<()> {
+fn search(dbname: &String, dbhost: &String, dbuser: &String, dbpassword: &String,
+        search: &String, model: BertModel, tokenizer: Tokenizer) -> Result<()> {
 
     println!("searching for document matches");
     let embeddings = get_embeddings(&search,&model,tokenizer.clone())?;
@@ -341,23 +350,28 @@ fn search(dbname: String, dbhost: String, dbuser: String, dbpassword: String,
         println!("{}", content);
     }
 
-    let _ = client.close();
+    client.close()?;
     Ok(())
 }
 
 fn main() -> Result<()> {
 
     let args = Args::parse();
-    let (model, tokenizer) = args.build_model_and_tokenizer()?;
 
     match args.action {
        Action::InitDatabase{ dbname, dbhost, dbuser, dbpassword } => {
            init_database(dbname, dbhost, dbuser, dbpassword)?;
        }
-       Action::Index{ dbname, dbhost, dbuser, dbpassword, file } => {
+       Action::InitModel => {
+           build_model_and_tokenizer(&args)?;
+       }
+       Action::Index{ ref dbname, ref dbhost, ref dbuser, ref dbpassword, ref file } => {
+           let (model, tokenizer) = build_model_and_tokenizer(&args)?;
            index(dbname, dbhost, dbuser, dbpassword, file, model, tokenizer)?;
        }
-       Action::Search{ dbname, dbhost, dbuser, dbpassword, search: search_term } => {
+       Action::Search{ ref dbname, ref dbhost, ref dbuser, ref dbpassword,
+                       search: ref search_term } => {
+           let (model, tokenizer) = build_model_and_tokenizer(&args)?;
            search(dbname, dbhost, dbuser, dbpassword, search_term, model, tokenizer)?;
        }
     }