-
Notifications
You must be signed in to change notification settings - Fork 215
/
sentence_embeddings_local.rs
37 lines (34 loc) · 1.33 KB
/
sentence_embeddings_local.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
use rust_bert::pipelines::sentence_embeddings::SentenceEmbeddingsBuilder;
/// Download model:
/// ```sh
/// git lfs install
/// git -C resources clone https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2
/// ```
/// Prepare model:
/// ```sh
/// python ./utils/convert_model.py resources/all-MiniLM-L12-v2/pytorch_model.bin
/// ```
///
/// For models missing the prefix in their saved weights (e.g. Distil-based models), the
/// conversion needs to be updated to include this prefix so that the weights can be found:
/// ```sh
/// python ./utils/convert_model.py resources/path/to/pytorch_model.bin --prefix distilbert.
/// ```
///
/// For models including a dense projection layer (e.g. Distil-based models), these weights
/// need to be converted as well:
/// ```sh
/// python ./utils/convert_model.py resources/path/to/2_Dense/pytorch_model.bin --suffix
/// ```
fn main() -> anyhow::Result<()> {
    // Point the builder at the locally converted model directory.
    let builder = SentenceEmbeddingsBuilder::local("resources/all-MiniLM-L12-v2");

    // Let tch choose the device, then load the model; a load failure is
    // propagated to the caller through `?`.
    let device = tch::Device::cuda_if_available();
    let encoder = builder.with_device(device).create_model()?;

    // Sentences to embed.
    let inputs = ["this is an example sentence", "each sentence is converted"];

    // Encode the sentences and dump the result with its `Debug` formatting.
    let vectors = encoder.encode(&inputs)?;
    println!("{vectors:?}");

    Ok(())
}