Humopery - vecsearch.git/commitdiff
index a single document
author    Erik Mackdanz <erikmack@gmail.com>
          Thu, 28 Nov 2024 15:31:35 +0000 (09:31 -0600)
committer Erik Mackdanz <erikmack@gmail.com>
          Thu, 28 Nov 2024 15:31:35 +0000 (09:31 -0600)
Cargo.lock
Cargo.toml
src/main.rs

diff --git a/Cargo.lock b/Cargo.lock
index e03fcbf4e53a9ce1d4667d18322ca8fc547664ba..055b283cbef2c13d98f5f812cfc8d91dea427735 100644
@@ -332,6 +332,7 @@ dependencies = [
  "candle-transformers",
  "clap",
  "hf-hub",
+ "pgvector",
  "postgres",
  "serde_json",
  "tokenizers 0.20.3",
@@ -1849,6 +1850,16 @@ version = "2.3.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
 
+[[package]]
+name = "pgvector"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e0e8871b6d7ca78348c6cd29b911b94851f3429f0cd403130ca17f26c1fb91a6"
+dependencies = [
+ "bytes",
+ "postgres-types",
+]
+
 [[package]]
 name = "phf"
 version = "0.11.2"
diff --git a/Cargo.toml b/Cargo.toml
index ccd6a2c3bfb910dc86b15bca509ed6944b18a1fd..6fba55915bce424352f9a089e36dea2954dc66ef 100644
@@ -11,6 +11,7 @@ candle-nn = "0.8.0"
 candle-transformers = "0.8.0"
 clap = { version = "4.5.21", features = ["derive", "env"] }
 hf-hub = "0.3.2"
+pgvector = { version = "0.4.0", features = ["postgres"] }
 postgres = "0.19.9"
 serde_json = "1.0.133"
 tokenizers = "0.20.3"
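The pgvector dependency is added with its "postgres" feature, which implements the ToSql/FromSql traits for pgvector::Vector so embeddings can be bound directly as statement parameters of the postgres crate. A minimal sketch of what the feature enables (the connection string and the items table are illustrative only, not part of this commit):

    use pgvector::Vector;
    use postgres::{Client, NoTls};

    fn demo() -> Result<(), postgres::Error> {
        // Placeholder connection parameters for this sketch.
        let mut client = Client::connect("host=localhost user=postgres", NoTls)?;
        // With the "postgres" feature, Vector implements ToSql and can be
        // passed like any other query parameter.
        let v = Vector::from(vec![1.0_f32, 2.0, 3.0]);
        client.execute("INSERT INTO items (embedding) VALUES ($1)", &[&v])?;
        Ok(())
    }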
diff --git a/src/main.rs b/src/main.rs
index 98c1db85775b0d70d2656b69f8b9191642fcf683..1def53029d7a02de4d6b1a398e4fb64fb46cd90c 100644
@@ -5,13 +5,15 @@ use candle_core::Tensor;
 use candle_nn::VarBuilder;
 use clap::{Parser,Subcommand};
 use hf_hub::{api::sync::Api, Repo, RepoType};
+use pgvector::Vector;
 use postgres::NoTls;
 use tokenizers::Tokenizer;
 
 #[derive(Subcommand,Debug)]
 enum Action {
 
-    /// Initialize the database when the table doesn't exist already
+    /// Initialize the database when the database or
+    /// table doesn't exist already
     InitDatabase {
 
        #[arg(long,default_value="vsearch")]
@@ -27,8 +29,25 @@ enum Action {
        password: String,
     },
 
-    /// Add one document to the database index
-    Index,
+    /// Read one document and add it to the database index
+    Index {
+
+       #[arg(long,default_value="vsearch")]
+       dbname: String,
+
+       #[arg(long,default_value="localhost")]
+       host: String,
+
+       #[arg(long,default_value="cvmigrator")]
+       user: String,
+
+       #[arg(long,env)]
+       password: String,
+
+       #[arg(long)]
+       /// The file containing document contents
+       file: String,
+    },
 
     /// Search the database for documents matching --prompt
     Search,
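The new Index variant mirrors InitDatabase's connection arguments and adds a required --file argument for the document to ingest. Because the password field is declared with #[arg(long,env)], clap's derive defaults also read the value from the PASSWORD environment variable when --password is omitted, so an invocation might look like this (the binary name vecsearch is an assumption, not confirmed by the diff; dbname, host, and user fall back to their defaults):

    PASSWORD=secret vecsearch index --file notes.txt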
@@ -55,10 +74,6 @@ struct Args {
     #[arg(long)]
     use_pth: bool,
 
-    /// The number of times to run the prompt.
-    #[arg(long, default_value = "1")]
-    n: usize,
-
     /// L2 normalization for embeddings.
     #[arg(long, default_value = "true")]
     normalize_embeddings: bool,
@@ -148,41 +163,66 @@ fn init_database(dbname: String, host: String, user: String, password: String) -
     Ok(())
 }
 
-fn main() -> Result<()> {
+fn index(dbname: String, host: String, user: String, password: String,
+        file: String, model: BertModel, mut tokenizer: Tokenizer) -> Result<()> {
+
+    println!("indexing a file");
 
-    let args = Args::parse();
     let start = std::time::Instant::now();
 
-    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
     let device = &model.device;
 
-    if let Some(prompt) = args.prompt {
-        let tokenizer = tokenizer
-            .with_padding(None)
-            .with_truncation(None)
-            .map_err(E::msg)?;
-        let tokens = tokenizer
-            .encode(prompt, true)
-            .map_err(E::msg)?
-            .get_ids()
-            .to_vec();
-        let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
-        let token_type_ids = token_ids.zeros_like()?;
-        println!("Loaded and encoded {:?}", start.elapsed());
-        for idx in 0..args.n {
-            let start = std::time::Instant::now();
-            let ys = model.forward(&token_ids, &token_type_ids, None)?;
-            if idx == 0 {
-                println!("{ys}");
-            }
-            println!("Took {:?}", start.elapsed());
-        }
-    }
+    let doc_content = std::fs::read_to_string(file)?;
+
+    let tokenizer = tokenizer
+        .with_padding(None)
+        .with_truncation(None)
+        .map_err(E::msg)?;
+    let tokens = tokenizer
+        .encode(doc_content.clone(), true)
+        .map_err(E::msg)?
+        .get_ids()
+        .to_vec();
+    let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+    let token_type_ids = token_ids.zeros_like()?;
+    println!("Loaded and encoded {:?}", start.elapsed());
+
+    let start = std::time::Instant::now();
+    let embeddings = model.forward(&token_ids, &token_type_ids, None)?;
+    println!("Took {:?}", start.elapsed());
+
+    let embeddings = (embeddings.sum(1)? / (embeddings.dim(1)? as f64))?;
+    let embeddings = embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?;
+    let embeddings = embeddings.squeeze(0)?.to_vec1::<f32>()?;
+
+    let mut client = postgres::Config::new()
+       .dbname(&dbname)
+       .host(&host)
+       .user(&user)
+       .password(password)
+       .connect(NoTls)?;
+
+    client.execute("INSERT INTO documents (content, embedding) \
+                   values ($1, $2)",
+                  &[&doc_content,&Vector::from(embeddings)],
+    )?;
+
+    let _ = client.close();
+    Ok(())
+}
+
+fn main() -> Result<()> {
+
+    let args = Args::parse();
+    let (model, tokenizer) = args.build_model_and_tokenizer()?;
 
     match args.action {
        Action::InitDatabase{ dbname, host, user, password } => {
            init_database(dbname, host, user, password)?;
        }
+       Action::Index{ dbname, host, user, password, file } => {
+           index(dbname, host, user, password, file, model, tokenizer)?;
+       }
        _ => {}
     }