From: Erik Mackdanz Date: Tue, 26 Nov 2024 04:58:35 +0000 (-0600) Subject: stub out commands X-Git-Url: https://git.humopery.space/?a=commitdiff_plain;h=e4f8ed55226d0530c05baea9f50404bda2dd63f9;p=vecsearch.git stub out commands --- diff --git a/src/main.rs b/src/main.rs index 426bc63..3b641ad 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,20 +3,28 @@ use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE}; use anyhow::{Error as E, Result}; use candle_core::Tensor; use candle_nn::VarBuilder; -use clap::Parser; +use clap::{Parser,Subcommand}; use hf_hub::{api::sync::Api, Repo, RepoType}; -use tokenizers::{PaddingParams, Tokenizer}; +use tokenizers::Tokenizer; + +#[derive(Subcommand,Debug)] +enum Action { + + /// Initialize the database when the table doesn't exist already + InitDatabase, + + /// Add one document to the database index + Index, + + /// Search the database for documents matching --prompt + Search, +} #[derive(Parser, Debug)] -#[command(author, version, about, long_about = None)] struct Args { - /// Run on CPU rather than on GPU. - #[arg(long)] - cpu: bool, - /// Enable tracing (generates a trace-timestamp.json file). - #[arg(long)] - tracing: bool, + #[command(subcommand)] + action: Action, /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending #[arg(long)] @@ -44,11 +52,12 @@ struct Args { /// Use tanh based approximation for Gelu instead of erf implementation. 
#[arg(long, default_value = "false")] approximate_gelu: bool, + } impl Args { fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> { - let device = candle_examples::device(self.cpu)?; + let device = candle_core::Device::Cpu; let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string(); let default_revision = "refs/pr/21".to_string(); let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) { @@ -89,18 +98,8 @@ impl Args { } fn main() -> Result<()> { - use tracing_chrome::ChromeLayerBuilder; - use tracing_subscriber::prelude::*; let args = Args::parse(); - let _guard = if args.tracing { - println!("tracing..."); - let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); - tracing_subscriber::registry().with(chrome_layer).init(); - Some(guard) - } else { - None - }; let start = std::time::Instant::now(); let (model, mut tokenizer) = args.build_model_and_tokenizer()?; @@ -127,81 +126,6 @@ fn main() -> Result<()> { } println!("Took {:?}", start.elapsed()); } - } else { - let sentences = [ - "The cat sits outside", - "A man is playing guitar", - "I love pasta", - "The new movie is awesome", - "The cat plays in the garden", - "A woman watches TV", - "The new movie is so great", - "Do you like pizza?", - ]; - let n_sentences = sentences.len(); - if let Some(pp) = tokenizer.get_padding_mut() { - pp.strategy = tokenizers::PaddingStrategy::BatchLongest - } else { - let pp = PaddingParams { - strategy: tokenizers::PaddingStrategy::BatchLongest, - ..Default::default() - }; - tokenizer.with_padding(Some(pp)); - } - let tokens = tokenizer - .encode_batch(sentences.to_vec(), true) - .map_err(E::msg)?; - let token_ids = tokens - .iter() - .map(|tokens| { - let tokens = tokens.get_ids().to_vec(); - Ok(Tensor::new(tokens.as_slice(), device)?) - }) - .collect::<Result<Vec<_>>>()?; - let attention_mask = tokens - .iter() - .map(|tokens| { - let tokens = tokens.get_attention_mask().to_vec(); - Ok(Tensor::new(tokens.as_slice(), device)?) 
- }) - .collect::<Result<Vec<_>>>()?; - - let token_ids = Tensor::stack(&token_ids, 0)?; - let attention_mask = Tensor::stack(&attention_mask, 0)?; - let token_type_ids = token_ids.zeros_like()?; - println!("running inference on batch {:?}", token_ids.shape()); - let embeddings = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))?; - println!("generated embeddings {:?}", embeddings.shape()); - // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding) - let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?; - let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?; - let embeddings = if args.normalize_embeddings { - normalize_l2(&embeddings)? - } else { - embeddings - }; - println!("pooled embeddings {:?}", embeddings.shape()); - - let mut similarities = vec![]; - for i in 0..n_sentences { - let e_i = embeddings.get(i)?; - for j in (i + 1)..n_sentences { - let e_j = embeddings.get(j)?; - let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?; - let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?; - let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?; - let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt(); - similarities.push((cosine_similarity, i, j)) - } - } - similarities.sort_by(|u, v| v.0.total_cmp(&u.0)); - for &(score, i, j) in similarities[..5].iter() { - println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j]) - } } Ok(()) } - -pub fn normalize_l2(v: &Tensor) -> Result<Tensor> { - Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?) -}