use candle_nn::VarBuilder;
use clap::{Parser,Subcommand};
use hf_hub::{api::sync::Api, Repo, RepoType};
+use pgvector::Vector;
use postgres::NoTls;
use tokenizers::Tokenizer;
#[derive(Subcommand,Debug)]
enum Action {
- /// Initialize the database when the table doesn't exist already
+ /// Initialize the database when the database or
+ /// table doesn't already exist
 InitDatabase {
 #[arg(long,default_value="vsearch")]
 dbname: String,
 #[arg(long,default_value="localhost")]
 host: String,
 #[arg(long,default_value="cvmigrator")]
 user: String,
 #[arg(long,env)]
 password: String,
 },
- /// Add one document to the database index
- Index,
+ /// Read one document and add it to the database index
+ Index {
+
+ #[arg(long,default_value="vsearch")]
+ dbname: String,
+
+ #[arg(long,default_value="localhost")]
+ host: String,
+
+ #[arg(long,default_value="cvmigrator")]
+ user: String,
+
+ #[arg(long,env)]
+ password: String,
+
+ /// The file containing the document contents
+ #[arg(long)]
+ file: String,
+ },
/// Search the database for documents matching --prompt
Search,
#[arg(long)]
use_pth: bool,
- /// The number of times to run the prompt.
- #[arg(long, default_value = "1")]
- n: usize,
-
/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
Ok(())
}
-fn main() -> Result<()> {
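+/// Reads the document in `file`, embeds its contents with the BERT
+/// model, and stores the text together with its embedding in Postgres.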
+fn index(dbname: String, host: String, user: String, password: String,
+ file: String, model: BertModel, mut tokenizer: Tokenizer) -> Result<()> {
+
+ println!("indexing a file");
- let args = Args::parse();
let start = std::time::Instant::now();
- let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
let device = &model.device;
- if let Some(prompt) = args.prompt {
- let tokenizer = tokenizer
- .with_padding(None)
- .with_truncation(None)
- .map_err(E::msg)?;
- let tokens = tokenizer
- .encode(prompt, true)
- .map_err(E::msg)?
- .get_ids()
- .to_vec();
- let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
- let token_type_ids = token_ids.zeros_like()?;
- println!("Loaded and encoded {:?}", start.elapsed());
- for idx in 0..args.n {
- let start = std::time::Instant::now();
- let ys = model.forward(&token_ids, &token_type_ids, None)?;
- if idx == 0 {
- println!("{ys}");
- }
- println!("Took {:?}", start.elapsed());
- }
- }
+ let doc_content = std::fs::read_to_string(file)?;
+
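+ // Tokenize the whole document as a single sequence. Padding and
+ // truncation are disabled, so a document longer than the model's
+ // maximum input length may fail to embed.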
+ let tokenizer = tokenizer
+ .with_padding(None)
+ .with_truncation(None)
+ .map_err(E::msg)?;
+ let tokens = tokenizer
+ .encode(doc_content.clone(), true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
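+ // Batch of one: add a leading batch dimension, and use all-zero
+ // token type ids since there is only a single segment.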
+ let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+ let token_type_ids = token_ids.zeros_like()?;
+ println!("Loaded and encoded {:?}", start.elapsed());
+
+ let start = std::time::Instant::now();
+ let embeddings = model.forward(&token_ids, &token_type_ids, None)?;
+ println!("Took {:?}", start.elapsed());
+
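+ // Mean-pool the per-token embeddings into a single vector, then
+ // L2-normalize it so dot products behave like cosine similarity.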
+ let embeddings = (embeddings.sum(1)? / (embeddings.dim(1)? as f64))?;
+ let embeddings = embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?;
+ let embeddings = embeddings.squeeze(0)?.to_vec1::<f32>()?;
+
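+ // Connect with the same options the init-database action uses; thanks
+ // to clap's `env` attribute, the password can also be supplied through
+ // the PASSWORD environment variable.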
+ let mut client = postgres::Config::new()
+ .dbname(&dbname)
+ .host(&host)
+ .user(&user)
+ .password(password)
+ .connect(NoTls)?;
+
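+ // pgvector's `Vector` implements `ToSql`, so the embedding binds as a
+ // regular query parameter; this assumes `documents` has a text `content`
+ // column and a pgvector `embedding` column.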
+ client.execute(
+     "INSERT INTO documents (content, embedding) VALUES ($1, $2)",
+     &[&doc_content, &Vector::from(embeddings)],
+ )?;
+
+ let _ = client.close();
+ Ok(())
+}
+
+fn main() -> Result<()> {
+
+ let args = Args::parse();
+ let (model, tokenizer) = args.build_model_and_tokenizer()?;
match args.action {
Action::InitDatabase{ dbname, host, user, password } => {
init_database(dbname, host, user, password)?;
}
+ Action::Index{ dbname, host, user, password, file } => {
+ index(dbname, host, user, password, file, model, tokenizer)?;
+ }
_ => {}
}