diff --git a/CHANGELOG.md b/CHANGELOG.md index bcb8a709..8bbaf7bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,9 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ## [Unreleased] +## Added +- Addition of `new_with_tokenizer` constructor for `SentenceEmbeddingsModel` allowing passing custom tokenizers for sentence embeddings pipelines + ## Fixed - (BREAKING) Fixed the keyword extraction pipeline for n-gram sizes > 2. Add new configuration option `tokenizer_forbidden_ngram_chars` to specify characters that should be excluded from n-grams (allows filtering m-grams spanning multiple sentences). - Improved MPS device compatibility setting the `sparse_grad` flag to false for `gather` operations diff --git a/src/pipelines/sentence_embeddings/pipeline.rs b/src/pipelines/sentence_embeddings/pipeline.rs index 7be7e135..f96e6383 100644 --- a/src/pipelines/sentence_embeddings/pipeline.rs +++ b/src/pipelines/sentence_embeddings/pipeline.rs @@ -177,32 +177,18 @@ impl SentenceEmbeddingsModel { /// /// * `config` - `SentenceEmbeddingsConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU) pub fn new(config: SentenceEmbeddingsConfig) -> Result { - let SentenceEmbeddingsConfig { - modules_config_resource, - sentence_bert_config_resource, - tokenizer_config_resource, - tokenizer_vocab_resource, - tokenizer_merges_resource, - transformer_type, - transformer_config_resource, - transformer_weights_resource, - pooling_config_resource, - dense_config_resource, - dense_weights_resource, - device, - } = config; - - let modules = - SentenceEmbeddingsModulesConfig::from_file(modules_config_resource.get_local_path()?) - .validate()?; - - // Setup tokenizer + let transformer_type = config.transformer_type; + let tokenizer_vocab_resource = &config.tokenizer_vocab_resource; + let tokenizer_merges_resource = &config.tokenizer_merges_resource; + let tokenizer_config_resource = &config.tokenizer_config_resource; + let sentence_bert_config_resource = &config.sentence_bert_config_resource; let tokenizer_config = SentenceEmbeddingsTokenizerConfig::from_file( tokenizer_config_resource.get_local_path()?, ); let sentence_bert_config = SentenceEmbeddingsSentenceBertConfig::from_file( sentence_bert_config_resource.get_local_path()?, ); + let tokenizer = TokenizerOption::from_file( transformer_type, tokenizer_vocab_resource @@ -222,6 +208,44 @@ impl SentenceEmbeddingsModel { tokenizer_config.add_prefix_space, )?; + Self::new_with_tokenizer(config, tokenizer) + } + + /// Build a new `ONNXCausalGenerator` from a `GenerateConfig` and `TokenizerOption`. + /// + /// A tokenizer must be provided by the user and can be customized to use non-default settings. + /// + /// # Arguments + /// + /// * `config` - `SentenceEmbeddingsConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU) + /// * `tokenizer` - `TokenizerOption` tokenizer to use for question answering. + pub fn new_with_tokenizer( + config: SentenceEmbeddingsConfig, + tokenizer: TokenizerOption, + ) -> Result { + let SentenceEmbeddingsConfig { + modules_config_resource, + sentence_bert_config_resource, + tokenizer_config_resource: _, + tokenizer_vocab_resource: _, + tokenizer_merges_resource: _, + transformer_type, + transformer_config_resource, + transformer_weights_resource, + pooling_config_resource, + dense_config_resource, + dense_weights_resource, + device, + } = config; + + let modules = + SentenceEmbeddingsModulesConfig::from_file(modules_config_resource.get_local_path()?) + .validate()?; + + let sentence_bert_config = SentenceEmbeddingsSentenceBertConfig::from_file( + sentence_bert_config_resource.get_local_path()?, + ); + // Setup transformer let mut var_store = nn::VarStore::new(device); let transformer_config = ConfigOption::from_file(