//! maybe creating database objects
//! ```
//!
+//! ## Initialize the model
+//!
+//! Download the model files. This command is optional since the model
+//! files can be downloaded lazily by the index and search actions.
+//!
+//! ```text
+//! $ ./vecsearch init-model
+//! ```
+//!
//! ## Add documents
//!
//! A document is a regular file.
//! ## TODO
//!
//! - why model from PR not main?
-//! - init-model command
//! - rename cv-*
//!
use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE};
dbpassword: String,
},
+    /// Download the model files in advance of the index or search commands
+ InitModel,
+
/// Read one document and add it to the database index
Index {
}
-impl Args {
- fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
- let device = candle_core::Device::Cpu;
- let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
- let default_revision = "refs/pr/21".to_string();
- let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
- (Some(model_id), Some(revision)) => (model_id, revision),
- (Some(model_id), None) => (model_id, "main".to_string()),
- (None, Some(revision)) => (default_model, revision),
- (None, None) => (default_model, default_revision),
- };
-
- let repo = Repo::with_revision(model_id, RepoType::Model, revision);
- let (config_filename, tokenizer_filename, weights_filename) = {
- let api = Api::new()?;
- let api = api.repo(repo);
- let config = api.get("config.json")?;
- let tokenizer = api.get("tokenizer.json")?;
- let weights = if self.use_pth {
- api.get("pytorch_model.bin")?
- } else {
- api.get("model.safetensors")?
- };
- (config, tokenizer, weights)
- };
- let config = std::fs::read_to_string(config_filename)?;
- let mut config: Config = serde_json::from_str(&config)?;
- let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
-
- let vb = if self.use_pth {
- VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
+fn build_model_and_tokenizer(args: &Args) -> Result<(BertModel, Tokenizer)> {
+ let device = candle_core::Device::Cpu;
+ let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
+ let default_revision = "refs/pr/21".to_string();
+ let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
+ (Some(model_id), Some(revision)) => (model_id, revision),
+ (Some(model_id), None) => (model_id, "main".to_string()),
+ (None, Some(revision)) => (default_model, revision),
+ (None, None) => (default_model, default_revision),
+ };
+
+ let repo = Repo::with_revision(model_id, RepoType::Model, revision);
+ let (config_filename, tokenizer_filename, weights_filename) = {
+ let api = Api::new()?;
+ let api = api.repo(repo);
+ let config = api.get("config.json")?;
+ let tokenizer = api.get("tokenizer.json")?;
+ let weights = if args.use_pth {
+ api.get("pytorch_model.bin")?
} else {
- unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
+ api.get("model.safetensors")?
};
- if self.approximate_gelu {
- config.hidden_act = HiddenAct::GeluApproximate;
- }
- let model = BertModel::load(vb, &config)?;
- Ok((model, tokenizer))
+ (config, tokenizer, weights)
+ };
+ let config = std::fs::read_to_string(config_filename)?;
+ let mut config: Config = serde_json::from_str(&config)?;
+ let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+ let vb = if args.use_pth {
+ VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
+ } else {
+ unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
+ };
+ if args.approximate_gelu {
+ config.hidden_act = HiddenAct::GeluApproximate;
}
+ let model = BertModel::load(vb, &config)?;
+ Ok((model, tokenizer))
}
fn init_database(dbname: String, dbhost: String, dbuser: String, dbpassword: String) -> Result<()> {
},
None => println!("database {} exists already", dbname),
}
- let _ = client.close();
+ client.close()?;
println!("maybe creating database objects");
let mut client = postgres::Config::new()
Ok(embeddings)
}
-fn index(dbname: String, dbhost: String, dbuser: String, dbpassword: String,
- files: Vec<String>, model: BertModel, tokenizer: Tokenizer) -> Result<()> {
+fn index(dbname: &String, dbhost: &String, dbuser: &String, dbpassword: &String,
+ files: &Vec<String>, model: BertModel, tokenizer: Tokenizer) -> Result<()> {
println!("indexing file(s)");
)?;
}
- let _ = client.close();
+ client.close()?;
Ok(())
}
-fn search(dbname: String, dbhost: String, dbuser: String, dbpassword: String,
- search: String, model: BertModel, tokenizer: Tokenizer) -> Result<()> {
+fn search(dbname: &String, dbhost: &String, dbuser: &String, dbpassword: &String,
+ search: &String, model: BertModel, tokenizer: Tokenizer) -> Result<()> {
println!("searching for document matches");
let embeddings = get_embeddings(&search,&model,tokenizer.clone())?;
println!("{}", content);
}
- let _ = client.close();
+ client.close()?;
Ok(())
}
fn main() -> Result<()> {
let args = Args::parse();
- let (model, tokenizer) = args.build_model_and_tokenizer()?;
match args.action {
Action::InitDatabase{ dbname, dbhost, dbuser, dbpassword } => {
init_database(dbname, dbhost, dbuser, dbpassword)?;
}
- Action::Index{ dbname, dbhost, dbuser, dbpassword, file } => {
+ Action::InitModel => {
+ build_model_and_tokenizer(&args)?;
+ }
+ Action::Index{ ref dbname, ref dbhost, ref dbuser, ref dbpassword, ref file } => {
+ let (model, tokenizer) = build_model_and_tokenizer(&args)?;
index(dbname, dbhost, dbuser, dbpassword, file, model, tokenizer)?;
}
- Action::Search{ dbname, dbhost, dbuser, dbpassword, search: search_term } => {
+ Action::Search{ ref dbname, ref dbhost, ref dbuser, ref dbpassword,
+ search: ref search_term } => {
+ let (model, tokenizer) = build_model_and_tokenizer(&args)?;
search(dbname, dbhost, dbuser, dbpassword, search_term, model, tokenizer)?;
}
}