Skip to content

Commit

Permalink
enable t5 text generation pipeline (#367)
Browse files Browse the repository at this point in the history
  • Loading branch information
drbh committed May 4, 2023
1 parent 06e6b32 commit 0eda850
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions src/pipelines/text_generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguag
use crate::pipelines::generation_utils::{GenerateConfig, GenerateOptions, LanguageGenerator};
use crate::reformer::ReformerGenerator;
use crate::resources::ResourceProvider;
use crate::t5::T5Generator;
use crate::xlnet::XLNetGenerator;

#[cfg(feature = "remote")]
Expand Down Expand Up @@ -199,6 +200,8 @@ pub enum TextGenerationOption {
XLNet(XLNetGenerator),
/// Text Generator based on Reformer model
Reformer(ReformerGenerator),
/// Text Generator based on T5 model
T5(T5Generator),
}

impl TextGenerationOption {
Expand All @@ -222,6 +225,7 @@ impl TextGenerationOption {
ModelType::GPTJ => Ok(TextGenerationOption::GPTJ(GptJGenerator::new(
config.into(),
)?)),
ModelType::T5 => Ok(TextGenerationOption::T5(T5Generator::new(config.into())?)),
_ => Err(RustBertError::InvalidConfigurationError(format!(
"Text generation not implemented for {:?}!",
config.model_type
Expand Down Expand Up @@ -252,6 +256,10 @@ impl TextGenerationOption {
ModelType::GPTJ => Ok(TextGenerationOption::GPTJ(
GptJGenerator::new_with_tokenizer(config.into(), tokenizer)?,
)),
ModelType::T5 => Ok(TextGenerationOption::T5(T5Generator::new_with_tokenizer(
config.into(),
tokenizer,
)?)),
_ => Err(RustBertError::InvalidConfigurationError(format!(
"Text generation not implemented for {:?}!",
config.model_type
Expand All @@ -268,6 +276,7 @@ impl TextGenerationOption {
Self::GPTJ(_) => ModelType::GPTJ,
Self::XLNet(_) => ModelType::XLNet,
Self::Reformer(_) => ModelType::Reformer,
Self::T5(_) => ModelType::T5,
}
}

Expand All @@ -280,6 +289,7 @@ impl TextGenerationOption {
Self::GPTJ(model_ref) => model_ref._get_tokenizer(),
Self::XLNet(model_ref) => model_ref._get_tokenizer(),
Self::Reformer(model_ref) => model_ref._get_tokenizer(),
Self::T5(model_ref) => model_ref._get_tokenizer(),
}
}

Expand All @@ -292,6 +302,7 @@ impl TextGenerationOption {
Self::GPTJ(model_ref) => model_ref._get_tokenizer_mut(),
Self::XLNet(model_ref) => model_ref._get_tokenizer_mut(),
Self::Reformer(model_ref) => model_ref._get_tokenizer_mut(),
Self::T5(model_ref) => model_ref._get_tokenizer_mut(),
}
}

Expand Down Expand Up @@ -341,6 +352,11 @@ impl TextGenerationOption {
.into_iter()
.map(|output| output.indices)
.collect(),
Self::T5(ref model) => model
.generate_indices(prompt_texts, generate_options)
.into_iter()
.map(|output| output.indices)
.collect(),
}
}

Expand All @@ -352,6 +368,7 @@ impl TextGenerationOption {
Self::GPTJ(model_ref) => model_ref.half(),
Self::XLNet(model_ref) => model_ref.half(),
Self::Reformer(model_ref) => model_ref.half(),
Self::T5(model_ref) => model_ref.half(),
}
}

Expand All @@ -363,6 +380,7 @@ impl TextGenerationOption {
Self::GPTJ(model_ref) => model_ref.float(),
Self::XLNet(model_ref) => model_ref.float(),
Self::Reformer(model_ref) => model_ref.float(),
Self::T5(model_ref) => model_ref.float(),
}
}

Expand All @@ -374,6 +392,7 @@ impl TextGenerationOption {
Self::GPTJ(model_ref) => model_ref.set_device(device),
Self::XLNet(model_ref) => model_ref.set_device(device),
Self::Reformer(model_ref) => model_ref.set_device(device),
Self::T5(model_ref) => model_ref.set_device(device),
}
}
}
Expand Down

0 comments on commit 0eda850

Please sign in to comment.