Skip to content

Commit

Permalink
Expose embeddings dimensions in SentenceEmbeddingsModel (#371)
Browse files Browse the repository at this point in the history
* Expose method to extract sentence embeddings dimensions

* Updated changelog
  • Loading branch information
guillaume-be committed May 7, 2023
1 parent 66944eb commit 9fd7983
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ All notable changes to this project will be documented in this file. The format
- Addition of the [LongT5](https://arxiv.org/abs/2112.07916) model architecture and pretrained weights.
- Addition of `add_tokens` and `add_extra_ids` interface methods to the `TokenizerOption`. Allow building most pipeline with custom tokenizer via `new_with_tokenizer`.
- Addition of `get_tokenizer` and `get_tokenizer_mut` methods to all pipelines allowing to get a (mutable) reference to the pipeline tokenizer.
- Addition of a `get_embedding_dim` method returning the output embedding dimension for sentence embeddings pipelines.

## Changed
- Bumped the tokenizers dependency from 7.x to 8.x, exposing additional options for special token mapping and adding the NLLBTokenizer.
Expand Down
13 changes: 9 additions & 4 deletions src/pipelines/sentence_embeddings/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ pub struct SentenceEmbeddingsModel {
pooling_layer: Pooling,
dense_layer: Option<Dense>,
normalize_embeddings: bool,
embeddings_dim: i64,
}

impl SentenceEmbeddingsModel {
Expand Down Expand Up @@ -196,7 +197,6 @@ impl SentenceEmbeddingsModel {
.validate()?;

// Setup tokenizer

let tokenizer_config = SentenceEmbeddingsTokenizerConfig::from_file(
tokenizer_config_resource.get_local_path()?,
);
Expand All @@ -223,7 +223,6 @@ impl SentenceEmbeddingsModel {
)?;

// Setup transformer

let mut var_store = nn::VarStore::new(device);
let transformer_config = ConfigOption::from_file(
transformer_type,
Expand All @@ -234,15 +233,15 @@ impl SentenceEmbeddingsModel {
var_store.load(transformer_weights_resource.get_local_path()?)?;

// Setup pooling layer

let pooling_config = PoolingConfig::from_file(pooling_config_resource.get_local_path()?);
let mut embeddings_dim = pooling_config.word_embedding_dimension;
let pooling_layer = Pooling::new(pooling_config);

// Setup dense layer

let dense_layer = if modules.dense_module().is_some() {
let dense_config =
DenseConfig::from_file(dense_config_resource.unwrap().get_local_path()?);
embeddings_dim = dense_config.out_features;
Some(Dense::new(
dense_config,
dense_weights_resource.unwrap().get_local_path()?,
Expand All @@ -264,6 +263,7 @@ impl SentenceEmbeddingsModel {
pooling_layer,
dense_layer,
normalize_embeddings,
embeddings_dim,
})
}

Expand All @@ -282,6 +282,11 @@ impl SentenceEmbeddingsModel {
self.tokenizer_truncation_strategy = truncation_strategy;
}

/// Returns the dimension of the sentence embeddings produced by this model.
///
/// This is the pooling layer's word embedding dimension, overridden by the
/// dense layer's output size when a dense module is present.
///
/// # Errors
/// Currently infallible; the `Result` wrapper keeps the signature consistent
/// with the other pipeline accessors.
pub fn get_embedding_dim(&self) -> Result<i64, RustBertError> {
    let dim = self.embeddings_dim;
    Ok(dim)
}

/// Tokenizes the inputs
pub fn tokenize<S>(&self, inputs: &[S]) -> SentenceEmbeddingsTokenizerOutput
where
Expand Down

0 comments on commit 9fd7983

Please sign in to comment.