Skip to content

Commit 5cdc28b

Browse files
committed
Improved examples [skip ci]
1 parent 84103ca commit 5cdc28b

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

examples/cohere/src/main.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ fn main() -> Result<(), Box<dyn Error>> {
2222
"The cat is purring",
2323
"The bear is growling",
2424
];
25-
let embeddings = fetch_embeddings(&input, "search_document")?;
25+
let embeddings = embed(&input, "search_document")?;
2626
for (content, embedding) in input.iter().zip(embeddings) {
2727
let embedding = Bit::from_bytes(&embedding);
2828
client.execute(
@@ -32,7 +32,7 @@ fn main() -> Result<(), Box<dyn Error>> {
3232
}
3333

3434
let query = "forest";
35-
let query_embedding = fetch_embeddings(&[query], "search_query")?;
35+
let query_embedding = embed(&[query], "search_query")?;
3636
for row in client.query(
3737
"SELECT content FROM documents ORDER BY embedding <~> $1 LIMIT 5",
3838
&[&Bit::from_bytes(&query_embedding[0])],
@@ -44,7 +44,7 @@ fn main() -> Result<(), Box<dyn Error>> {
4444
Ok(())
4545
}
4646

47-
fn fetch_embeddings(texts: &[&str], input_type: &str) -> Result<Vec<Vec<u8>>, Box<dyn Error>> {
47+
fn embed(texts: &[&str], input_type: &str) -> Result<Vec<Vec<u8>>, Box<dyn Error>> {
4848
let api_key = std::env::var("CO_API_KEY").or(Err("Set CO_API_KEY"))?;
4949

5050
let response: Value = ureq::post("https://api.cohere.com/v1/embed")

examples/openai/src/main.rs

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,26 +22,28 @@ fn main() -> Result<(), Box<dyn Error>> {
2222
"The cat is purring",
2323
"The bear is growling",
2424
];
25-
let embeddings = fetch_embeddings(&input)?;
26-
25+
let embeddings = embed(&input)?;
2726
for (content, embedding) in input.iter().zip(embeddings) {
28-
let embedding = Vector::from(embedding);
2927
client.execute(
3028
"INSERT INTO documents (content, embedding) VALUES ($1, $2)",
31-
&[&content, &embedding],
29+
&[&content, &Vector::from(embedding)],
3230
)?;
3331
}
3432

35-
let document_id = 2;
36-
for row in client.query("SELECT content FROM documents WHERE id != $1 ORDER BY embedding <=> (SELECT embedding FROM documents WHERE id = $1) LIMIT 5", &[&document_id])? {
33+
let query = "forest";
34+
let query_embedding = embed(&[query])?.drain(..).next().unwrap();
35+
for row in client.query(
36+
"SELECT content FROM documents ORDER BY embedding <=> $1 LIMIT 5",
37+
&[&Vector::from(query_embedding)],
38+
)? {
3739
let content: &str = row.get(0);
3840
println!("{}", content);
3941
}
4042

4143
Ok(())
4244
}
4345

44-
fn fetch_embeddings(input: &[&str]) -> Result<Vec<Vec<f32>>, Box<dyn Error>> {
46+
fn embed(input: &[&str]) -> Result<Vec<Vec<f32>>, Box<dyn Error>> {
4547
let api_key = std::env::var("OPENAI_API_KEY").or(Err("Set OPENAI_API_KEY"))?;
4648

4749
let response: Value = ureq::post("https://api.openai.com/v1/embeddings")

0 commit comments

Comments
 (0)