},
/// Search the database for documents matching --prompt
- Search,
+ Search {
+
+ #[arg(long,default_value="vsearch")]
+ dbname: String,
+
+ #[arg(long,default_value="localhost")]
+ host: String,
+
+ #[arg(long,default_value="cvmigrator")]
+ user: String,
+
+ #[arg(long,env)]
+ password: String,
+
+ #[arg(long)]
+ /// Search for this
+ search: String,
+ },
}
#[derive(Parser, Debug)]
Ok(())
}
-fn index(dbname: String, host: String, user: String, password: String,
- file: String, model: BertModel, mut tokenizer: Tokenizer) -> Result<()> {
-
- println!("indexing a file");
-
+fn get_embeddings(input: &String, model: BertModel, mut tokenizer: Tokenizer) -> Result<Vec<f32>> {
let start = std::time::Instant::now();
-
let device = &model.device;
-
- let doc_content = std::fs::read_to_string(file)?;
-
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?;
let tokens = tokenizer
- .encode(doc_content.clone(), true)
+ .encode(input.clone(), true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let embeddings = (embeddings.sum(1)? / (embeddings.dim(1)? as f64))?;
let embeddings = embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?;
let embeddings = embeddings.squeeze(0)?.to_vec1::<f32>()?;
+ Ok(embeddings)
+}
+
+fn index(dbname: String, host: String, user: String, password: String,
+ file: String, model: BertModel, tokenizer: Tokenizer) -> Result<()> {
+
+ println!("indexing a file");
+
+ let doc_content = std::fs::read_to_string(file)?;
+ let embeddings = get_embeddings(&doc_content,model,tokenizer)?;
let mut client = postgres::Config::new()
.dbname(&dbname)
Ok(())
}
+fn search(dbname: String, host: String, user: String, password: String,
+ search: String, model: BertModel, tokenizer: Tokenizer) -> Result<()> {
+
+ println!("indexing a file");
+ let embeddings = get_embeddings(&search,model,tokenizer)?;
+
+ let mut client = postgres::Config::new()
+ .dbname(&dbname)
+ .host(&host)
+ .user(&user)
+ .password(password)
+ .connect(NoTls)?;
+ client.execute("CREATE EXTENSION IF NOT EXISTS vector", &[])?;
+
+ for row in client.query("SELECT content FROM documents \
+ ORDER BY embedding <=> $1 LIMIT 5",
+ &[&Vector::from(embeddings)])? {
+ let content: &str = row.get(0);
+ println!("{}", content);
+ }
+
+ let _ = client.close();
+ Ok(())
+}
+
fn main() -> Result<()> {
let args = Args::parse();
Action::Index{ dbname, host, user, password, file } => {
index(dbname, host, user, password, file, model, tokenizer)?;
}
- _ => {}
+ Action::Search{ dbname, host, user, password, search: search_term } => {
+ search(dbname, host, user, password, search_term, model, tokenizer)?;
+ }
}
Ok(())