Skip to content

Commit

Permalink
Mixed resources (guillaume-be#291)
Browse files Browse the repository at this point in the history
* Made the `merges` resource optional for all pipelines
* Allowed mixing local and remote resources for pipelines

* Updated changelog

* Fixed Clippy warnings
  • Loading branch information
guillaume-be authored and Miezhiko committed Mar 21, 2023
1 parent 5b0ddbe commit 672c6ae
Show file tree
Hide file tree
Showing 44 changed files with 203 additions and 150 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ All notable changes to this project will be documented in this file. The format

## Changed
- Addition of type aliases for the controlled generation (`PrefixAllowedFunction`) and zero-shot classification (`ZeroShotTemplate`)
- (BREAKING) `merges_resource` is now optional for all pipelines
- Allow mixing local and remote resources in pipelines

## Fixed
- Fixed configuration check for RoBERTa models for sentence classification.
Expand Down
4 changes: 3 additions & 1 deletion benches/generation_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ fn create_text_generation_model() -> TextGenerationModel {
model_resource: Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)),
config_resource: Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)),
vocab_resource: Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)),
merges_resource: Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2)),
merges_resource: Some(Box::new(RemoteResource::from_pretrained(
Gpt2MergesResources::GPT2,
))),
min_length: 0,
max_length: 30,
do_sample: true,
Expand Down
2 changes: 1 addition & 1 deletion examples/generation_gpt_neo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ fn main() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: Some(merges_resource),
min_length: 10,
max_length: 32,
do_sample: false,
Expand Down
5 changes: 1 addition & 4 deletions examples/generation_reformer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@ fn main() -> anyhow::Result<()> {
let vocab_resource = Box::new(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT,
));
let merges_resource = Box::new(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT,
));
let model_resource = Box::new(RemoteResource::from_pretrained(
ReformerModelResources::CRIME_AND_PUNISHMENT,
));
Expand All @@ -41,7 +38,7 @@ fn main() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: None,
min_length: 100,
max_length: 100,
do_sample: true,
Expand Down
5 changes: 1 addition & 4 deletions examples/generation_xlnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@ fn main() -> anyhow::Result<()> {
let vocab_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED,
));
let merges_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED,
));
let model_resource = Box::new(RemoteResource::from_pretrained(
XLNetModelResources::XLNET_BASE_CASED,
));
Expand All @@ -39,7 +36,7 @@ fn main() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: None,
max_length: 32,
do_sample: false,
num_beams: 3,
Expand Down
2 changes: 1 addition & 1 deletion examples/summarization_bart.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ fn main() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: Some(merges_resource),
num_beams: 1,
length_penalty: 1.0,
min_length: 56,
Expand Down
4 changes: 2 additions & 2 deletions examples/summarization_pegasus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ fn main() -> anyhow::Result<()> {
model_type: ModelType::Pegasus,
model_resource: weights_resource,
config_resource,
vocab_resource: vocab_resource.clone(),
merges_resource: vocab_resource,
vocab_resource,
merges_resource: None,
length_penalty: 1.0,
num_beams: 4,
no_repeat_ngram_size: 3,
Expand Down
4 changes: 2 additions & 2 deletions examples/summarization_prophetnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ fn main() -> anyhow::Result<()> {
model_type: ModelType::ProphetNet,
model_resource: weights_resource,
config_resource,
vocab_resource: vocab_resource.clone(),
merges_resource: vocab_resource,
vocab_resource,
merges_resource: None,
length_penalty: 1.2,
num_beams: 4,
no_repeat_ngram_size: 3,
Expand Down
2 changes: 1 addition & 1 deletion examples/summarization_t5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ fn main() -> anyhow::Result<()> {
ModelType::T5,
weights_resource,
config_resource,
vocab_resource.clone(),
vocab_resource,
None,
);
let summarization_model = SummarizationModel::new(summarization_config)?;

Expand Down
2 changes: 1 addition & 1 deletion examples/translation_m2m100.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ fn main() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
Some(merges_resource),
source_languages,
target_languages,
Device::cuda_if_available(),
Expand Down
2 changes: 1 addition & 1 deletion examples/translation_marian.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ fn main() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
Some(merges_resource),
source_languages,
target_languages,
Device::cuda_if_available(),
Expand Down
4 changes: 1 addition & 3 deletions examples/translation_mbart.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ fn main() -> anyhow::Result<()> {
let config_resource =
RemoteResource::from_pretrained(MBartConfigResources::MBART50_MANY_TO_MANY);
let vocab_resource = RemoteResource::from_pretrained(MBartVocabResources::MBART50_MANY_TO_MANY);
let merges_resource =
RemoteResource::from_pretrained(MBartVocabResources::MBART50_MANY_TO_MANY);

let source_languages = MBartSourceLanguages::MBART50_MANY_TO_MANY;
let target_languages = MBartTargetLanguages::MBART50_MANY_TO_MANY;
Expand All @@ -37,7 +35,7 @@ fn main() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
None,
source_languages,
target_languages,
Device::cuda_if_available(),
Expand Down
3 changes: 1 addition & 2 deletions examples/translation_t5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ fn main() -> anyhow::Result<()> {
let model_resource = RemoteResource::from_pretrained(T5ModelResources::T5_BASE);
let config_resource = RemoteResource::from_pretrained(T5ConfigResources::T5_BASE);
let vocab_resource = RemoteResource::from_pretrained(T5VocabResources::T5_BASE);
let merges_resource = RemoteResource::from_pretrained(T5VocabResources::T5_BASE);

let source_languages = [
Language::English,
Expand All @@ -42,7 +41,7 @@ fn main() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
None,
source_languages,
target_languages,
Device::cuda_if_available(),
Expand Down
10 changes: 9 additions & 1 deletion src/bart/bart_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1067,7 +1067,15 @@ impl BartGenerator {
/// ```
pub fn new(generate_config: GenerateConfig) -> Result<BartGenerator, RustBertError> {
let vocab_path = generate_config.vocab_resource.get_local_path()?;
let merges_path = generate_config.merges_resource.get_local_path()?;
let merges_path = generate_config
.merges_resource
.as_ref()
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"BART expects a merges resources to be provided".to_string(),
)
})?
.get_local_path()?;

let tokenizer = TokenizerOption::from_file(
ModelType::Bart,
Expand Down
10 changes: 9 additions & 1 deletion src/gpt2/gpt2_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,15 @@ impl GPT2Generator {
/// ```
pub fn new(generate_config: GenerateConfig) -> Result<GPT2Generator, RustBertError> {
let vocab_path = generate_config.vocab_resource.get_local_path()?;
let merges_path = generate_config.merges_resource.get_local_path()?;
let merges_path = generate_config
.merges_resource
.as_ref()
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"GPT2 expects a merges resources to be provided".to_string(),
)
})?
.get_local_path()?;

let tokenizer = TokenizerOption::from_file(
ModelType::GPT2,
Expand Down
10 changes: 9 additions & 1 deletion src/gpt_neo/gpt_neo_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,15 @@ impl GptNeoGenerator {
/// ```
pub fn new(generate_config: GenerateConfig) -> Result<GptNeoGenerator, RustBertError> {
let vocab_path = generate_config.vocab_resource.get_local_path()?;
let merges_path = generate_config.merges_resource.get_local_path()?;
let merges_path = generate_config
.merges_resource
.as_ref()
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"GPT-Neo expects a merges resources to be provided".to_string(),
)
})?
.get_local_path()?;

let tokenizer = TokenizerOption::from_file(
ModelType::GPTNeo,
Expand Down
2 changes: 1 addition & 1 deletion src/gpt_neo/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
//! model_resource,
//! config_resource,
//! vocab_resource,
//! merges_resource,
//! merges_resource: Some(merges_resource),
//! num_beams: 4,
//! no_repeat_ngram_size: 3,
//! device: Device::cuda_if_available(),
Expand Down
10 changes: 9 additions & 1 deletion src/m2m_100/m2m_100_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,15 @@ impl M2M100Generator {
/// ```
pub fn new(generate_config: GenerateConfig) -> Result<M2M100Generator, RustBertError> {
let vocab_path = generate_config.vocab_resource.get_local_path()?;
let merges_path = generate_config.merges_resource.get_local_path()?;
let merges_path = generate_config
.merges_resource
.as_ref()
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"M2M100 expects a merges resources to be provided".to_string(),
)
})?
.get_local_path()?;

let tokenizer = TokenizerOption::from_file(
ModelType::M2M100,
Expand Down
11 changes: 10 additions & 1 deletion src/marian/marian_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,16 @@ impl MarianGenerator {
/// ```
pub fn new(generate_config: GenerateConfig) -> Result<MarianGenerator, RustBertError> {
let vocab_path = generate_config.vocab_resource.get_local_path()?;
let sentence_piece_path = generate_config.merges_resource.get_local_path()?;
let sentence_piece_path = generate_config
.merges_resource
.as_ref()
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"Marian expects a merges (SentencePiece model) resources to be provided"
.to_string(),
)
})?
.get_local_path()?;

let tokenizer = TokenizerOption::from_file(
ModelType::Marian,
Expand Down
10 changes: 9 additions & 1 deletion src/openai_gpt/openai_gpt_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,15 @@ impl OpenAIGenerator {
/// ```
pub fn new(generate_config: GenerateConfig) -> Result<OpenAIGenerator, RustBertError> {
let vocab_path = generate_config.vocab_resource.get_local_path()?;
let merges_path = generate_config.merges_resource.get_local_path()?;
let merges_path = generate_config
.merges_resource
.as_ref()
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"GPT expects a merges resources to be provided".to_string(),
)
})?
.get_local_path()?;

let tokenizer = TokenizerOption::from_file(
ModelType::OpenAiGpt,
Expand Down
6 changes: 3 additions & 3 deletions src/pipelines/conversation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ pub struct ConversationConfig {
/// Vocab resource (default: DialoGPT-medium)
pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource (default: DialoGPT-medium)
pub merges_resource: Box<dyn ResourceProvider + Send>,
pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Minimum sequence length (default: 0)
pub min_length: i64,
/// Maximum sequence length (default: 20)
Expand Down Expand Up @@ -131,9 +131,9 @@ impl Default for ConversationConfig {
vocab_resource: Box::new(RemoteResource::from_pretrained(
Gpt2VocabResources::DIALOGPT_MEDIUM,
)),
merges_resource: Box::new(RemoteResource::from_pretrained(
merges_resource: Some(Box::new(RemoteResource::from_pretrained(
Gpt2MergesResources::DIALOGPT_MEDIUM,
)),
))),
min_length: 0,
max_length: 1000,
min_length_for_response: 64,
Expand Down
6 changes: 4 additions & 2 deletions src/pipelines/generation_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ pub struct GenerateConfig {
/// Vocab resource (default: pretrained GPT2 model)
pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource (default: pretrained GPT2 model)
pub merges_resource: Box<dyn ResourceProvider + Send>,
pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Minimum sequence length (default: 0)
pub min_length: i64,
/// Maximum sequence length (default: 20)
Expand Down Expand Up @@ -143,7 +143,9 @@ impl Default for GenerateConfig {
model_resource: Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)),
config_resource: Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)),
vocab_resource: Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)),
merges_resource: Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2)),
merges_resource: Some(Box::new(RemoteResource::from_pretrained(
Gpt2MergesResources::GPT2,
))),
min_length: 0,
max_length: 20,
do_sample: true,
Expand Down
28 changes: 16 additions & 12 deletions src/pipelines/question_answering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,18 +166,20 @@ impl QuestionAnsweringConfig {
/// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges_resource - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
pub fn new<R>(
pub fn new<RM, RC, RV>(
model_type: ModelType,
model_resource: R,
config_resource: R,
vocab_resource: R,
merges_resource: Option<R>,
model_resource: RM,
config_resource: RC,
vocab_resource: RV,
merges_resource: Option<RV>,
lower_case: bool,
strip_accents: impl Into<Option<bool>>,
add_prefix_space: impl Into<Option<bool>>,
) -> QuestionAnsweringConfig
where
R: ResourceProvider + Send + 'static,
RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
{
QuestionAnsweringConfig {
model_type,
Expand Down Expand Up @@ -210,12 +212,12 @@ impl QuestionAnsweringConfig {
/// * max_query_length - Optional maximum question token length. Defaults to 64.
/// * doc_stride - Optional stride to apply if a sliding window is required to process the input context. Represents the number of overlapping tokens between sliding windows. This should be lower than the max_seq_length minus max_query_length (otherwise there is a risk for the sliding window not to progress). Defaults to 128.
/// * max_answer_length - Optional maximum token length for the extracted answer. Defaults to 15.
pub fn custom_new<R>(
pub fn custom_new<RM, RC, RV>(
model_type: ModelType,
model_resource: R,
config_resource: R,
vocab_resource: R,
merges_resource: Option<R>,
model_resource: RM,
config_resource: RC,
vocab_resource: RV,
merges_resource: Option<RV>,
lower_case: bool,
strip_accents: impl Into<Option<bool>>,
add_prefix_space: impl Into<Option<bool>>,
Expand All @@ -225,7 +227,9 @@ impl QuestionAnsweringConfig {
max_answer_length: impl Into<Option<usize>>,
) -> QuestionAnsweringConfig
where
R: ResourceProvider + Send + 'static,
RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
{
QuestionAnsweringConfig {
model_type,
Expand Down
14 changes: 8 additions & 6 deletions src/pipelines/sequence_classification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,18 +134,20 @@ impl SequenceClassificationConfig {
/// * vocab - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * vocab - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
pub fn new<R>(
pub fn new<RM, RC, RV>(
model_type: ModelType,
model_resource: R,
config_resource: R,
vocab_resource: R,
merges_resource: Option<R>,
model_resource: RM,
config_resource: RC,
vocab_resource: RV,
merges_resource: Option<RV>,
lower_case: bool,
strip_accents: impl Into<Option<bool>>,
add_prefix_space: impl Into<Option<bool>>,
) -> SequenceClassificationConfig
where
R: ResourceProvider + Send + 'static,
RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
{
SequenceClassificationConfig {
model_type,
Expand Down
Loading

0 comments on commit 672c6ae

Please sign in to comment.