Humopery - vecsearch.git/commitdiff
implement search
author Erik Mackdanz <erikmack@gmail.com>
Thu, 28 Nov 2024 16:42:33 +0000 (10:42 -0600)
committer Erik Mackdanz <erikmack@gmail.com>
Thu, 28 Nov 2024 16:42:33 +0000 (10:42 -0600)
src/main.rs
testdata/0 [new file with mode: 0644]
testdata/1 [new file with mode: 0644]
testdata/2 [new file with mode: 0644]

index f85dae531a9b85eb2bd8da5b6fa4ae4c9876d071..494d01308c3e16a0498ca7ef426000d69b3e042a 100644 (file)
@@ -50,7 +50,24 @@ enum Action {
     },
 
     /// Search the database for documents matching --prompt
-    Search,
+    Search {
+
+       #[arg(long,default_value="vsearch")]
+       dbname: String,
+
+       #[arg(long,default_value="localhost")]
+       host: String,
+
+       #[arg(long,default_value="cvmigrator")]
+       user: String,
+
+       #[arg(long,env)]
+       password: String,
+
+       #[arg(long)]
+       /// Search for this
+       search: String,
+    },
 }
 
 #[derive(Parser, Debug)]
@@ -163,23 +180,15 @@ fn init_database(dbname: String, host: String, user: String, password: String) -
     Ok(())
 }
 
-fn index(dbname: String, host: String, user: String, password: String,
-        file: String, model: BertModel, mut tokenizer: Tokenizer) -> Result<()> {
-
-    println!("indexing a file");
-
+fn get_embeddings(input: &String, model: BertModel, mut tokenizer: Tokenizer) -> Result<Vec<f32>> {
     let start = std::time::Instant::now();
-
     let device = &model.device;
-
-    let doc_content = std::fs::read_to_string(file)?;
-
     let tokenizer = tokenizer
         .with_padding(None)
         .with_truncation(None)
         .map_err(E::msg)?;
     let tokens = tokenizer
-        .encode(doc_content.clone(), true)
+        .encode(input.clone(), true)
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
@@ -194,6 +203,16 @@ fn index(dbname: String, host: String, user: String, password: String,
     let embeddings = (embeddings.sum(1)? / (embeddings.dim(1)? as f64))?;
     let embeddings = embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?;
     let embeddings = embeddings.squeeze(0)?.to_vec1::<f32>()?;
+    Ok(embeddings)
+}
+
+fn index(dbname: String, host: String, user: String, password: String,
+        file: String, model: BertModel, tokenizer: Tokenizer) -> Result<()> {
+
+    println!("indexing a file");
+
+    let doc_content = std::fs::read_to_string(file)?;
+    let embeddings = get_embeddings(&doc_content,model,tokenizer)?;
 
     let mut client = postgres::Config::new()
        .dbname(&dbname)
@@ -212,6 +231,31 @@ fn index(dbname: String, host: String, user: String, password: String,
     Ok(())
 }
 
+fn search(dbname: String, host: String, user: String, password: String,
+        search: String, model: BertModel, tokenizer: Tokenizer) -> Result<()> {
+    // Embed the `search` text, then print the stored documents nearest to it.
+    println!("searching the database");
+    let embeddings = get_embeddings(&search,model,tokenizer)?;
+    // Connect without TLS using the supplied credentials.
+    let mut client = postgres::Config::new()
+       .dbname(&dbname)
+       .host(&host)
+       .user(&user)
+       .password(password)
+       .connect(NoTls)?;
+    client.execute("CREATE EXTENSION IF NOT EXISTS vector", &[])?;
+    // `<=>` is pgvector's cosine-distance operator, so the closest rows come first.
+    for row in client.query("SELECT content FROM documents \
+                            ORDER BY embedding <=> $1 LIMIT 5",
+                           &[&Vector::from(embeddings)])? {
+        let content: &str = row.get(0);
+        println!("{}", content);
+    }
+    // The result of close() is discarded; the query output has already been printed.
+    let _ = client.close();
+    Ok(())
+}
+
 fn main() -> Result<()> {
 
     let args = Args::parse();
@@ -224,7 +268,9 @@ fn main() -> Result<()> {
        Action::Index{ dbname, host, user, password, file } => {
            index(dbname, host, user, password, file, model, tokenizer)?;
        }
-       _ => {}
+       Action::Search{ dbname, host, user, password, search: search_term } => {
+           search(dbname, host, user, password, search_term, model, tokenizer)?;
+       }
     }
 
     Ok(())
diff --git a/testdata/0 b/testdata/0
new file mode 100644 (file)
index 0000000..45947c1
--- /dev/null
@@ -0,0 +1 @@
+The dog is barking
\ No newline at end of file
diff --git a/testdata/1 b/testdata/1
new file mode 100644 (file)
index 0000000..5302447
--- /dev/null
@@ -0,0 +1 @@
+The cat is purring
\ No newline at end of file
diff --git a/testdata/2 b/testdata/2
new file mode 100644 (file)
index 0000000..5a3572c
--- /dev/null
@@ -0,0 +1 @@
+The bear is growling
\ No newline at end of file