]> Humopery - vecsearch.git/commitdiff
stub out commands
authorErik Mackdanz <erikmack@gmail.com>
Tue, 26 Nov 2024 04:58:35 +0000 (22:58 -0600)
committerErik Mackdanz <erikmack@gmail.com>
Tue, 26 Nov 2024 04:58:35 +0000 (22:58 -0600)
src/main.rs

index 426bc6365d8094d805fb7158772f0ac5768a74c8..3b641adb1f0eaca10493fa16c0ecd8ae758fc733 100644 (file)
@@ -3,20 +3,28 @@ use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE};
 use anyhow::{Error as E, Result};
 use candle_core::Tensor;
 use candle_nn::VarBuilder;
-use clap::Parser;
+use clap::{Parser,Subcommand};
 use hf_hub::{api::sync::Api, Repo, RepoType};
-use tokenizers::{PaddingParams, Tokenizer};
+use tokenizers::Tokenizer;
+
+#[derive(Subcommand,Debug)]
+enum Action {
+
+    /// Initialize the database when the table doesn't exist already
+    InitDatabase,
+
+    /// Add one document to the database index
+    Index,
+
+    /// Search the database for documents matching --prompt
+    Search,
+}
 
 #[derive(Parser, Debug)]
-#[command(author, version, about, long_about = None)]
 struct Args {
-    /// Run on CPU rather than on GPU.
-    #[arg(long)]
-    cpu: bool,
 
-    /// Enable tracing (generates a trace-timestamp.json file).
-    #[arg(long)]
-    tracing: bool,
+    #[command(subcommand)]
+    action: Action,
 
     /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
     #[arg(long)]
@@ -44,11 +52,12 @@ struct Args {
     /// Use tanh based approximation for Gelu instead of erf implementation.
     #[arg(long, default_value = "false")]
     approximate_gelu: bool,
+
 }
 
 impl Args {
     fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
-        let device = candle_examples::device(self.cpu)?;
+        let device = candle_core::Device::Cpu;
         let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
         let default_revision = "refs/pr/21".to_string();
         let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
@@ -89,18 +98,8 @@ impl Args {
 }
 
 fn main() -> Result<()> {
-    use tracing_chrome::ChromeLayerBuilder;
-    use tracing_subscriber::prelude::*;
 
     let args = Args::parse();
-    let _guard = if args.tracing {
-        println!("tracing...");
-        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
-        tracing_subscriber::registry().with(chrome_layer).init();
-        Some(guard)
-    } else {
-        None
-    };
     let start = std::time::Instant::now();
 
     let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
@@ -127,81 +126,6 @@ fn main() -> Result<()> {
             }
             println!("Took {:?}", start.elapsed());
         }
-    } else {
-        let sentences = [
-            "The cat sits outside",
-            "A man is playing guitar",
-            "I love pasta",
-            "The new movie is awesome",
-            "The cat plays in the garden",
-            "A woman watches TV",
-            "The new movie is so great",
-            "Do you like pizza?",
-        ];
-        let n_sentences = sentences.len();
-        if let Some(pp) = tokenizer.get_padding_mut() {
-            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
-        } else {
-            let pp = PaddingParams {
-                strategy: tokenizers::PaddingStrategy::BatchLongest,
-                ..Default::default()
-            };
-            tokenizer.with_padding(Some(pp));
-        }
-        let tokens = tokenizer
-            .encode_batch(sentences.to_vec(), true)
-            .map_err(E::msg)?;
-        let token_ids = tokens
-            .iter()
-            .map(|tokens| {
-                let tokens = tokens.get_ids().to_vec();
-                Ok(Tensor::new(tokens.as_slice(), device)?)
-            })
-            .collect::<Result<Vec<_>>>()?;
-        let attention_mask = tokens
-            .iter()
-            .map(|tokens| {
-                let tokens = tokens.get_attention_mask().to_vec();
-                Ok(Tensor::new(tokens.as_slice(), device)?)
-            })
-            .collect::<Result<Vec<_>>>()?;
-
-        let token_ids = Tensor::stack(&token_ids, 0)?;
-        let attention_mask = Tensor::stack(&attention_mask, 0)?;
-        let token_type_ids = token_ids.zeros_like()?;
-        println!("running inference on batch {:?}", token_ids.shape());
-        let embeddings = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
-        println!("generated embeddings {:?}", embeddings.shape());
-        // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
-        let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
-        let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
-        let embeddings = if args.normalize_embeddings {
-            normalize_l2(&embeddings)?
-        } else {
-            embeddings
-        };
-        println!("pooled embeddings {:?}", embeddings.shape());
-
-        let mut similarities = vec![];
-        for i in 0..n_sentences {
-            let e_i = embeddings.get(i)?;
-            for j in (i + 1)..n_sentences {
-                let e_j = embeddings.get(j)?;
-                let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
-                let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
-                let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
-                let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
-                similarities.push((cosine_similarity, i, j))
-            }
-        }
-        similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
-        for &(score, i, j) in similarities[..5].iter() {
-            println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
-        }
     }
     Ok(())
 }
-
-pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
-    Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
-}