From 3b734ff8031c0abb5b4526d5f0e0a572bde8c2c5 Mon Sep 17 00:00:00 2001
From: Erik Mackdanz
Date: Thu, 28 Nov 2024 09:31:35 -0600
Subject: [PATCH] index a single document

---
 Cargo.lock  |  11 ++++++
 Cargo.toml  |   1 +
 src/main.rs | 104 ++++++++++++++++++++++++++++++++++++----------------
 3 files changed, 84 insertions(+), 32 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index e03fcbf..055b283 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -332,6 +332,7 @@ dependencies = [
  "candle-transformers",
  "clap",
  "hf-hub",
+ "pgvector",
  "postgres",
  "serde_json",
  "tokenizers 0.20.3",
@@ -1849,6 +1850,16 @@ version = "2.3.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
 
+[[package]]
+name = "pgvector"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e0e8871b6d7ca78348c6cd29b911b94851f3429f0cd403130ca17f26c1fb91a6"
+dependencies = [
+ "bytes",
+ "postgres-types",
+]
+
 [[package]]
 name = "phf"
 version = "0.11.2"
diff --git a/Cargo.toml b/Cargo.toml
index ccd6a2c..6fba559 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -11,6 +11,7 @@ candle-nn = "0.8.0"
 candle-transformers = "0.8.0"
 clap = { version = "4.5.21", features = ["derive", "env"] }
 hf-hub = "0.3.2"
+pgvector = { version = "0.4.0", features = ["postgres"] }
 postgres = "0.19.9"
 serde_json = "1.0.133"
 tokenizers = "0.20.3"
diff --git a/src/main.rs b/src/main.rs
index 98c1db8..1def530 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -5,13 +5,15 @@ use candle_core::Tensor;
 use candle_nn::VarBuilder;
 use clap::{Parser,Subcommand};
 use hf_hub::{api::sync::Api, Repo, RepoType};
+use pgvector::Vector;
 use postgres::NoTls;
 use tokenizers::Tokenizer;
 
 #[derive(Subcommand,Debug)]
 enum Action {
 
-    /// Initialize the database when the table doesn't exist already
+    /// Initialize the database when the database or
+    /// table doesn't exist already
     InitDatabase {
 
         #[arg(long,default_value="vsearch")]
@@ -27,8 +29,25 @@ enum Action {
         password: String,
     },
 
-    /// Add one document to the database index
-    Index,
+    /// Read one document and add it to the database index
+    Index {
+
+        #[arg(long,default_value="vsearch")]
+        dbname: String,
+
+        #[arg(long,default_value="localhost")]
+        host: String,
+
+        #[arg(long,default_value="cvmigrator")]
+        user: String,
+
+        #[arg(long,env)]
+        password: String,
+
+        #[arg(long)]
+        /// The file containing document contents
+        file: String,
+    },
 
     /// Search the database for documents matching --prompt
     Search,
@@ -55,10 +74,6 @@ struct Args {
     #[arg(long)]
     use_pth: bool,
 
-    /// The number of times to run the prompt.
-    #[arg(long, default_value = "1")]
-    n: usize,
-
     /// L2 normalization for embeddings.
     #[arg(long, default_value = "true")]
     normalize_embeddings: bool,
@@ -148,41 +163,66 @@ fn init_database(dbname: String, host: String, user: String, password: String) -
     Ok(())
 }
 
-fn main() -> Result<()> {
+fn index(dbname: String, host: String, user: String, password: String,
+         file: String, model: BertModel, mut tokenizer: Tokenizer) -> Result<()> {
+
+    println!("indexing a file");
 
-    let args = Args::parse();
     let start = std::time::Instant::now();
-    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
     let device = &model.device;
 
-    if let Some(prompt) = args.prompt {
-        let tokenizer = tokenizer
-            .with_padding(None)
-            .with_truncation(None)
-            .map_err(E::msg)?;
-        let tokens = tokenizer
-            .encode(prompt, true)
-            .map_err(E::msg)?
-            .get_ids()
-            .to_vec();
-        let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
-        let token_type_ids = token_ids.zeros_like()?;
-        println!("Loaded and encoded {:?}", start.elapsed());
-        for idx in 0..args.n {
-            let start = std::time::Instant::now();
-            let ys = model.forward(&token_ids, &token_type_ids, None)?;
-            if idx == 0 {
-                println!("{ys}");
-            }
-            println!("Took {:?}", start.elapsed());
-        }
-    }
+    let doc_content = std::fs::read_to_string(file)?;
+
+    let tokenizer = tokenizer
+        .with_padding(None)
+        .with_truncation(None)
+        .map_err(E::msg)?;
+    let tokens = tokenizer
+        .encode(doc_content.clone(), true)
+        .map_err(E::msg)?
+        .get_ids()
+        .to_vec();
+    let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+    let token_type_ids = token_ids.zeros_like()?;
+    println!("Loaded and encoded {:?}", start.elapsed());
+
+    let start = std::time::Instant::now();
+    let embeddings = model.forward(&token_ids, &token_type_ids, None)?;
+    println!("Took {:?}", start.elapsed());
+
+    let embeddings = (embeddings.sum(1)? / (embeddings.dim(1)? as f64))?;
+    let embeddings = embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?;
+    let embeddings = embeddings.squeeze(0)?.to_vec1::<f32>()?;
+
+    let mut client = postgres::Config::new()
+        .dbname(&dbname)
+        .host(&host)
+        .user(&user)
+        .password(password)
+        .connect(NoTls)?;
+
+    client.execute("INSERT INTO documents (content, embedding) \
+                    values ($1, $2)",
+                   &[&doc_content, &Vector::from(embeddings)],
+    )?;
+
+    let _ = client.close();
+    Ok(())
+}
+
+fn main() -> Result<()> {
+
+    let args = Args::parse();
+    let (model, tokenizer) = args.build_model_and_tokenizer()?;
 
     match args.action {
         Action::InitDatabase{ dbname, host, user, password } => {
             init_database(dbname, host, user, password)?;
         }
+        Action::Index{ dbname, host, user, password, file } => {
+            index(dbname, host, user, password, file, model, tokenizer)?;
+        }
         _ => {}
     }
-- 
2.52.0