Skip to content

Commit

Permalink
Add new_with_tokenizer for SentenceEmbeddingsModel (#407)
Browse files Browse the repository at this point in the history
* Add option to create sentence embeddings model with custom tokenizer

* updated changelog

* Fix Clippy warning

* Fix config lookup

* trigger ci
  • Loading branch information
guillaume-be committed Aug 3, 2023
1 parent a655b3c commit af3839e
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 20 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## [Unreleased]
## Added
- Addition of `new_with_tokenizer` constructor for `SentenceEmbeddingsModel` allowing passing custom tokenizers for sentence embeddings pipelines

## Fixed
- (BREAKING) Fixed the keyword extraction pipeline for n-gram sizes > 2. Add new configuration option `tokenizer_forbidden_ngram_chars` to specify characters that should be excluded from n-grams (allows filtering n-grams spanning multiple sentences).
- Improved MPS device compatibility by setting the `sparse_grad` flag to false for `gather` operations
Expand Down
64 changes: 44 additions & 20 deletions src/pipelines/sentence_embeddings/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,32 +177,18 @@ impl SentenceEmbeddingsModel {
///
/// * `config` - `SentenceEmbeddingsConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
pub fn new(config: SentenceEmbeddingsConfig) -> Result<Self, RustBertError> {
let SentenceEmbeddingsConfig {
modules_config_resource,
sentence_bert_config_resource,
tokenizer_config_resource,
tokenizer_vocab_resource,
tokenizer_merges_resource,
transformer_type,
transformer_config_resource,
transformer_weights_resource,
pooling_config_resource,
dense_config_resource,
dense_weights_resource,
device,
} = config;

let modules =
SentenceEmbeddingsModulesConfig::from_file(modules_config_resource.get_local_path()?)
.validate()?;

// Setup tokenizer
let transformer_type = config.transformer_type;
let tokenizer_vocab_resource = &config.tokenizer_vocab_resource;
let tokenizer_merges_resource = &config.tokenizer_merges_resource;
let tokenizer_config_resource = &config.tokenizer_config_resource;
let sentence_bert_config_resource = &config.sentence_bert_config_resource;
let tokenizer_config = SentenceEmbeddingsTokenizerConfig::from_file(
tokenizer_config_resource.get_local_path()?,
);
let sentence_bert_config = SentenceEmbeddingsSentenceBertConfig::from_file(
sentence_bert_config_resource.get_local_path()?,
);

let tokenizer = TokenizerOption::from_file(
transformer_type,
tokenizer_vocab_resource
Expand All @@ -222,6 +208,44 @@ impl SentenceEmbeddingsModel {
tokenizer_config.add_prefix_space,
)?;

Self::new_with_tokenizer(config, tokenizer)
}

/// Build a new `SentenceEmbeddingsModel` from a `SentenceEmbeddingsConfig` and `TokenizerOption`.
///
/// A tokenizer must be provided by the user and can be customized to use non-default settings.
///
/// # Arguments
///
/// * `config` - `SentenceEmbeddingsConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
/// * `tokenizer` - `TokenizerOption` tokenizer to use for sentence embeddings.
pub fn new_with_tokenizer(
config: SentenceEmbeddingsConfig,
tokenizer: TokenizerOption,
) -> Result<Self, RustBertError> {
let SentenceEmbeddingsConfig {
modules_config_resource,
sentence_bert_config_resource,
tokenizer_config_resource: _,
tokenizer_vocab_resource: _,
tokenizer_merges_resource: _,
transformer_type,
transformer_config_resource,
transformer_weights_resource,
pooling_config_resource,
dense_config_resource,
dense_weights_resource,
device,
} = config;

let modules =
SentenceEmbeddingsModulesConfig::from_file(modules_config_resource.get_local_path()?)
.validate()?;

let sentence_bert_config = SentenceEmbeddingsSentenceBertConfig::from_file(
sentence_bert_config_resource.get_local_path()?,
);

// Setup transformer
let mut var_store = nn::VarStore::new(device);
let transformer_config = ConfigOption::from_file(
Expand Down

0 comments on commit af3839e

Please sign in to comment.