From 2d65a0a52361485815cbe3960cdf64f6fb0ef867 Mon Sep 17 00:00:00 2001 From: Erik Mackdanz Date: Fri, 29 Nov 2024 10:30:35 -0600 Subject: [PATCH] Can index multiple files together --- src/main.rs | 48 +++++++++++++++++++++++++----------------------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/src/main.rs b/src/main.rs index 0b56c69..d0db012 100644 --- a/src/main.rs +++ b/src/main.rs @@ -46,16 +46,17 @@ //! The user for this operation requires only write access to the //! table (not superuser). //! +//! Specifying multiple files is more efficient than indexing one file +//! in each invocation. +//! //! ```text //! $ export PASSWORD=$(gpg -d pw-cvmigrator.gpg) -//! $ vecsearch index --file testdata/0 -//! indexing a file -//! Loaded and encoded 59.479µs -//! Took 14.982262ms -//! $ vecsearch index --file testdata/1 -//! ... -//! $ vecsearch index --file testdata/7 -//! ... +//! $ vecsearch index --file testdata/0 --file testdata/1 +//! indexing file(s) +//! Loaded and encoded 58.565µs +//! Took 15.628167ms +//! Loaded and encoded 55.513µs +//! Took 8.018493ms //! ``` //! //! ## Search @@ -82,7 +83,6 @@ //! //! ## TODO //! -//! - index multiple files //! - model from main not PR //! - env support for all args //! @@ -133,8 +133,8 @@ enum Action { password: String, #[arg(long)] - /// The file containing document contents - file: String, + /// The file containing document contents. Specify multiple + file: Vec, }, /// Search the database for documents matching --search @@ -265,7 +265,7 @@ fn init_database(dbname: String, host: String, user: String, password: String) - Ok(()) } -fn get_embeddings(input: &String, model: BertModel, mut tokenizer: Tokenizer) -> Result> { +fn get_embeddings(input: &String, model: &BertModel, mut tokenizer: Tokenizer) -> Result> { let start = std::time::Instant::now(); let device = &model.device; let tokenizer = tokenizer @@ -292,12 +292,9 @@ fn get_embeddings(input: &String, model: BertModel, mut tokenizer: Tokenizer) -> } fn index(dbname: String, host: String, user: String, password: String, - file: String, model: BertModel, tokenizer: Tokenizer) -> Result<()> { - - println!("indexing a file"); + files: Vec, model: BertModel, tokenizer: Tokenizer) -> Result<()> { - let doc_content = std::fs::read_to_string(file)?; - let embeddings = get_embeddings(&doc_content,model,tokenizer)?; + println!("indexing file(s)"); let mut client = postgres::Config::new() .dbname(&dbname) @@ -306,11 +303,16 @@ fn index(dbname: String, host: String, user: String, password: String, .password(password) .connect(NoTls)?; - client.execute("INSERT INTO documents (content, embedding) \ - values ($1, $2) \ - ON CONFLICT (content) DO UPDATE SET embedding = $2", - &[&doc_content,&Vector::from(embeddings)], - )?; + for file in files { + let doc_content = std::fs::read_to_string(file)?; + let embeddings = get_embeddings(&doc_content,&model,tokenizer.clone())?; + + client.execute("INSERT INTO documents (content, embedding) \ + values ($1, $2) \ + ON CONFLICT (content) DO UPDATE SET embedding = $2", + &[&doc_content,&Vector::from(embeddings)], + )?; + } let _ = client.close(); Ok(()) @@ -320,7 +322,7 @@ fn search(dbname: String, host: String, user: String, password: String, search: String, model: BertModel, tokenizer: Tokenizer) -> Result<()> { println!("searching for document matches"); - let embeddings = get_embeddings(&search,model,tokenizer)?; + let embeddings = get_embeddings(&search,&model,tokenizer.clone())?; let mut client = postgres::Config::new() .dbname(&dbname) -- 2.52.0