From: Erik Mackdanz Date: Thu, 28 Nov 2024 16:42:33 +0000 (-0600) Subject: implement search X-Git-Url: https://git.humopery.space/?a=commitdiff_plain;h=968e24d253123648e3a4a36a46ae9f52b7046eef;p=vecsearch.git implement search --- diff --git a/src/main.rs b/src/main.rs index f85dae5..494d013 100644 --- a/src/main.rs +++ b/src/main.rs @@ -50,7 +50,24 @@ enum Action { }, /// Search the database for documents matching --prompt - Search, + Search { + + #[arg(long,default_value="vsearch")] + dbname: String, + + #[arg(long,default_value="localhost")] + host: String, + + #[arg(long,default_value="cvmigrator")] + user: String, + + #[arg(long,env)] + password: String, + + #[arg(long)] + /// Search for this + search: String, + }, } #[derive(Parser, Debug)] @@ -163,23 +180,15 @@ fn init_database(dbname: String, host: String, user: String, password: String) - Ok(()) } -fn index(dbname: String, host: String, user: String, password: String, - file: String, model: BertModel, mut tokenizer: Tokenizer) -> Result<()> { - - println!("indexing a file"); - +fn get_embeddings(input: &String, model: BertModel, mut tokenizer: Tokenizer) -> Result<Vec<f32>> { let start = std::time::Instant::now(); - let device = &model.device; - - let doc_content = std::fs::read_to_string(file)?; - let tokenizer = tokenizer .with_padding(None) .with_truncation(None) .map_err(E::msg)?; let tokens = tokenizer - .encode(doc_content.clone(), true) + .encode(input.clone(), true) .map_err(E::msg)? .get_ids() .to_vec(); @@ -194,6 +203,16 @@ fn index(dbname: String, host: String, user: String, password: String, let embeddings = (embeddings.sum(1)? / (embeddings.dim(1)? 
as f64))?; let embeddings = embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?; let embeddings = embeddings.squeeze(0)?.to_vec1::<f32>()?; + Ok(embeddings) +} + +fn index(dbname: String, host: String, user: String, password: String, + file: String, model: BertModel, tokenizer: Tokenizer) -> Result<()> { + + println!("indexing a file"); + + let doc_content = std::fs::read_to_string(file)?; + let embeddings = get_embeddings(&doc_content,model,tokenizer)?; let mut client = postgres::Config::new() .dbname(&dbname) @@ -212,6 +231,31 @@ fn index(dbname: String, host: String, user: String, password: String, Ok(()) } +fn search(dbname: String, host: String, user: String, password: String, + search: String, model: BertModel, tokenizer: Tokenizer) -> Result<()> { + + println!("searching"); + let embeddings = get_embeddings(&search,model,tokenizer)?; + + let mut client = postgres::Config::new() + .dbname(&dbname) + .host(&host) + .user(&user) + .password(password) + .connect(NoTls)?; + client.execute("CREATE EXTENSION IF NOT EXISTS vector", &[])?; + + for row in client.query("SELECT content FROM documents \ + ORDER BY embedding <=> $1 LIMIT 5", + &[&Vector::from(embeddings)])? 
{ + let content: &str = row.get(0); + println!("{}", content); + } + + let _ = client.close(); + Ok(()) +} + fn main() -> Result<()> { let args = Args::parse(); @@ -224,7 +268,9 @@ fn main() -> Result<()> { Action::Index{ dbname, host, user, password, file } => { index(dbname, host, user, password, file, model, tokenizer)?; } - _ => {} + Action::Search{ dbname, host, user, password, search: search_term } => { + search(dbname, host, user, password, search_term, model, tokenizer)?; + } } Ok(()) diff --git a/testdata/0 b/testdata/0 new file mode 100644 index 0000000..45947c1 --- /dev/null +++ b/testdata/0 @@ -0,0 +1 @@ +The dog is barking \ No newline at end of file diff --git a/testdata/1 b/testdata/1 new file mode 100644 index 0000000..5302447 --- /dev/null +++ b/testdata/1 @@ -0,0 +1 @@ +The cat is purring \ No newline at end of file diff --git a/testdata/2 b/testdata/2 new file mode 100644 index 0000000..5a3572c --- /dev/null +++ b/testdata/2 @@ -0,0 +1 @@ +The bear is growling \ No newline at end of file