Skip to content

Commit

Permalink
Change generate return type to Result (#437)
Browse files Browse the repository at this point in the history
* Changed the return type of the `generate` method to be `Result`, removed fallible unwraps

* Fix doctests
  • Loading branch information
guillaume-be committed Dec 4, 2023
1 parent 9f2cd17 commit 1f4d344
Show file tree
Hide file tree
Showing 31 changed files with 134 additions and 113 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ All notable changes to this project will be documented in this file. The format

## Changed
- (BREAKING) Upgraded to `torch` 2.1 (via `tch` 0.14.0).
- (BREAKING) Text generation traits and pipelines (including conversation, summarization and translation) now return a `Result` for improved error handling.

## [0.21.0] - 2023-06-03
## Added
Expand Down
4 changes: 2 additions & 2 deletions examples/buffer_resource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ about exoplanets like K2-18b."];
let summarization_model = SummarizationModel::new(config(Device::Cpu, weights.clone()))?;

// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let output = summarization_model.summarize(&input);
let output = summarization_model.summarize(&input)?;
for sentence in output {
println!("{sentence}");
}
Expand All @@ -58,7 +58,7 @@ about exoplanets like K2-18b."];
SummarizationModel::new(config(Device::cuda_if_available(), weights))?;

// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let output = summarization_model.summarize(&input);
let output = summarization_model.summarize(&input)?;
for sentence in output {
println!("{sentence}");
}
Expand Down
2 changes: 1 addition & 1 deletion examples/generation_gpt2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ fn main() -> anyhow::Result<()> {

let input_context = "The dog";
// let second_input_context = "The cat was";
let output = model.generate(&[input_context], None);
let output = model.generate(&[input_context], None)?;

for sentence in output {
println!("{sentence:?}");
Expand Down
2 changes: 1 addition & 1 deletion examples/generation_gpt2_hf_tokenizers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ fn main() -> anyhow::Result<()> {

let input_context = "The dog";
// let second_input_context = "The cat was";
let output = model.generate(&[input_context], None);
let output = model.generate(&[input_context], None)?;

for sentence in output {
println!("{sentence:?}");
Expand Down
2 changes: 1 addition & 1 deletion examples/generation_gpt_neo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ fn main() -> anyhow::Result<()> {

let input_context_1 = "It was a very nice and sunny";
let input_context_2 = "It was a gloom winter night, and";
let output = model.generate(&[input_context_1, input_context_2], None);
let output = model.generate(&[input_context_1, input_context_2], None)?;

for sentence in output {
println!("{sentence}");
Expand Down
2 changes: 1 addition & 1 deletion examples/generation_gptj.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ fn main() -> anyhow::Result<()> {
"It was a very nice and sunny",
"It was a gloom winter night, and",
];
let output = model.generate(&prompts, None);
let output = model.generate(&prompts, None)?;

assert_eq!(output.len(), 2);
assert_eq!(output[0], "It was a very nice and sunny day, and I was sitting in the garden of my house, enjoying the sun and the fresh air. I was thinking");
Expand Down
2 changes: 1 addition & 1 deletion examples/generation_reformer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ fn main() -> anyhow::Result<()> {

let input_context_1 = "The really great men must, I think,";
let input_context_2 = "It was a gloom winter night, and";
let output = model.generate(&[input_context_1, input_context_2], None);
let output = model.generate(&[input_context_1, input_context_2], None)?;

for sentence in output {
println!("{sentence}");
Expand Down
2 changes: 1 addition & 1 deletion examples/generation_xlnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ fn main() -> anyhow::Result<()> {
let model = TextGenerationModel::new(generate_config)?;

let input_context = "Once upon a time,";
let output = model.generate(&[input_context], None);
let output = model.generate(&[input_context], None)?;

for sentence in output {
println!("{sentence}");
Expand Down
2 changes: 1 addition & 1 deletion examples/summarization_bart.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ telescope — scheduled for launch in 2021 — and the European Space Agency's 2
about exoplanets like K2-18b."];

// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let _output = summarization_model.summarize(&input);
let _output = summarization_model.summarize(&input)?;
for sentence in _output {
println!("{sentence}");
}
Expand Down
2 changes: 1 addition & 1 deletion examples/summarization_pegasus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ telescope — scheduled for launch in 2021 — and the European Space Agency's 2
about exoplanets like K2-18b."];

// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let _output = summarization_model.summarize(&input);
let _output = summarization_model.summarize(&input)?;
for sentence in _output {
println!("{sentence}");
}
Expand Down
2 changes: 1 addition & 1 deletion examples/summarization_prophetnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ telescope — scheduled for launch in 2021 — and the European Space Agency's 2
about exoplanets like K2-18b."];

// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let _output = summarization_model.summarize(&input);
let _output = summarization_model.summarize(&input)?;
for sentence in _output {
println!("{sentence}");
}
Expand Down
2 changes: 1 addition & 1 deletion examples/summarization_t5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ telescope — scheduled for launch in 2021 — and the European Space Agency's 2
about exoplanets like K2-18b."];

// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let _output = summarization_model.summarize(&input);
let _output = summarization_model.summarize(&input)?;
for sentence in _output {
println!("{sentence}");
}
Expand Down
2 changes: 1 addition & 1 deletion src/models/gpt_neo/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
//!
//! let input_context_1 = "It was a very nice and sunny";
//! let input_context_2 = "It was a gloom winter night, and";
//! let output = model.generate(&[input_context_1, input_context_2], None);
//! let output = model.generate(&[input_context_1, input_context_2], None)?;
//!
//! for sentence in output {
//! println!("{}", sentence);
Expand Down
2 changes: 1 addition & 1 deletion src/models/prophetnet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
//! about exoplanets like K2-18b."];
//!
//! // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
//! let _output = summarization_model.summarize(&input);
//! let _output = summarization_model.summarize(&input)?;
//! for sentence in _output {
//! println!("{}", sentence);
//! }
Expand Down
17 changes: 9 additions & 8 deletions src/pipelines/conversation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -763,14 +763,14 @@ impl ConversationOption {
&self,
input_ids: Tensor,
attention_mask: Option<Tensor>,
) -> Vec<Vec<i64>> {
match *self {
) -> Result<Vec<Vec<i64>>, RustBertError> {
Ok(match *self {
Self::GPT2(ref model) => model
.generate_from_ids_and_past(input_ids, attention_mask, None)
.generate_from_ids_and_past(input_ids, attention_mask, None)?
.into_iter()
.map(|output| output.indices)
.collect(),
}
})
}
}

Expand Down Expand Up @@ -887,9 +887,9 @@ impl ConversationModel {
pub fn generate_responses<'a>(
&self,
conversation_manager: &'a mut ConversationManager,
) -> HashMap<&'a Uuid, &'a str> {
) -> Result<HashMap<&'a Uuid, &'a str>, RustBertError> {
let (active_uuid, active_conversations) = conversation_manager.get_active_conversations();
if !active_uuid.is_empty() {
let updated_conversations = if !active_uuid.is_empty() {
let texts = active_conversations
.iter()
.map(|c| c.new_user_input.as_ref().unwrap().as_str())
Expand All @@ -906,7 +906,7 @@ impl ConversationModel {
let input_length = *input_tensor.size().last().unwrap() as usize;
let mut generated = self
.model
.generate_from_ids_and_past(input_tensor, Some(attention_mask));
.generate_from_ids_and_past(input_tensor, Some(attention_mask))?;
let removed_padding_quantities = self.clean_padding_indices(&mut generated);

let mut output = HashMap::with_capacity(active_uuid.len());
Expand Down Expand Up @@ -936,7 +936,8 @@ impl ConversationModel {
output
} else {
HashMap::new()
}
};
Ok(updated_conversations)
}

fn clean_padding_indices(&self, model_output: &mut Vec<Vec<i64>>) -> Vec<(usize, usize)> {
Expand Down
43 changes: 27 additions & 16 deletions src/pipelines/generation_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1775,11 +1775,11 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
&self,
prompt_texts: Option<&[S]>,
generate_options: Option<GenerateOptions>,
) -> Vec<GeneratedTextOutput>
) -> Result<Vec<GeneratedTextOutput>, RustBertError>
where
S: AsRef<str> + Send + Sync,
{
let indices_outputs = self.generate_indices(prompt_texts, generate_options);
let indices_outputs = self.generate_indices(prompt_texts, generate_options)?;
let mut output = Vec::with_capacity(indices_outputs.len());
for generated_sequence in indices_outputs {
output.push(GeneratedTextOutput {
Expand All @@ -1789,7 +1789,7 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
score: generated_sequence.score,
});
}
output
Ok(output)
}

/// Generate token indices without decoding (useful for token-level operations before returning final text or as validation step during training).
Expand Down Expand Up @@ -1869,7 +1869,7 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
&self,
prompt_texts: Option<&[S]>,
generate_options: Option<GenerateOptions>,
) -> Vec<GeneratedIndicesOutput>
) -> Result<Vec<GeneratedIndicesOutput>, RustBertError>
where
S: AsRef<str> + Send + Sync,
{
Expand All @@ -1896,11 +1896,12 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
}
None => match self.get_bos_id() {
Some(bos_id) => Tensor::ones([1, 1], (Int64, self.get_device())) * bos_id,
None => panic!(
None => return Err(RustBertError::ValueError(
"A model with a BOS token must be used to start generation with an empty input"
),
.to_string(),
)),
},
_ => return Vec::new(),
_ => return Ok(Vec::new()),
};
self.generate_from_ids_and_past(input_ids, None, generate_options)
}
Expand Down Expand Up @@ -1960,7 +1961,7 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
mut input_ids: Tensor,
mut attention_mask: Option<Tensor>,
generate_options: Option<GenerateOptions>,
) -> Vec<GeneratedIndicesOutput> {
) -> Result<Vec<GeneratedIndicesOutput>, RustBertError> {
let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).cloned();

let config = PrivateLanguageGenerator::get_config(self);
Expand Down Expand Up @@ -2033,7 +2034,9 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
};

let encoder_outputs = if self.is_encoder_decoder() {
let encoder_outputs = self.encode(&input_ids, Some(&attention_mask)).unwrap();
let encoder_outputs = self
.encode(&input_ids, Some(&attention_mask))
.ok_or(RustBertError::UnsupportedError)?;
let expanded_batch_indices = Tensor::arange(batch_size, (Int64, input_ids.device()))
.view((-1, 1))
.repeat([1, num_beams * effective_batch_mult])
Expand Down Expand Up @@ -2067,10 +2070,11 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
(input_ids, attention_mask)
}
} else {
let decoder_start_token_id = decoder_start_token_id.unwrap_or_else(|| {
self.get_decoder_start_id()
.expect("decoder start id must be specified for encoder decoders")
});
let decoder_start_token_id = decoder_start_token_id
.or(self.get_decoder_start_id())
.ok_or(RustBertError::ValueError(
"decoder start id must be specified for encoder decoders".to_string(),
))?;
let input_ids = Tensor::full(
[effective_batch_size * num_beams, 1],
decoder_start_token_id,
Expand Down Expand Up @@ -2103,9 +2107,16 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
config.max_length
};

if let Some(max_length) = max_length {
if input_ids.size2()?.1 > max_length {
return Err(RustBertError::ValueError("The input ids exceeds the maximum length for generation.\
Reduce the size of the provided input ids or increase the allowable maximum generation length.".to_string()));
}
}

if max_length.is_none() & eos_token_ids.is_none() {
panic!("No maximum length given for a model without an EOS token. \
This would lead to an infinite generation loop. Please provide a `max_length` or `max_new_tokens`")
return Err(RustBertError::InvalidConfigurationError("No maximum length given for a model without an EOS token. \
This would lead to an infinite generation loop. Please provide a `max_length` or `max_new_tokens`".to_string()));
}

let gen_opt = InternalGenerateOptions {
Expand Down Expand Up @@ -2182,7 +2193,7 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
token_scores,
});
}
output
Ok(output)
}

/// Returns a reference to the text generator's tokenizer
Expand Down
20 changes: 10 additions & 10 deletions src/pipelines/summarization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,43 +338,43 @@ impl SummarizationOption {
}

/// Interface method to generate() of the particular models.
pub fn generate<S>(&self, prompt_texts: Option<&[S]>) -> Vec<String>
pub fn generate<S>(&self, prompt_texts: Option<&[S]>) -> Result<Vec<String>, RustBertError>
where
S: AsRef<str> + Send + Sync,
{
match *self {
Ok(match *self {
Self::Bart(ref model) => model
.generate(prompt_texts, None)
.generate(prompt_texts, None)?
.into_iter()
.map(|output| output.text)
.collect(),
Self::T5(ref model) => model
.generate(prompt_texts, None)
.generate(prompt_texts, None)?
.into_iter()
.map(|output| output.text)
.collect(),
Self::LongT5(ref model) => model
.generate(prompt_texts, None)
.generate(prompt_texts, None)?
.into_iter()
.map(|output| output.text)
.collect(),
Self::ProphetNet(ref model) => model
.generate(prompt_texts, None)
.generate(prompt_texts, None)?
.into_iter()
.map(|output| output.text)
.collect(),
Self::Pegasus(ref model) => model
.generate(prompt_texts, None)
.generate(prompt_texts, None)?
.into_iter()
.map(|output| output.text)
.collect(),
#[cfg(feature = "onnx")]
Self::ONNX(ref model) => model
.generate(prompt_texts, None)
.generate(prompt_texts, None)?
.into_iter()
.map(|output| output.text)
.collect(),
}
})
}
}

Expand Down Expand Up @@ -506,7 +506,7 @@ impl SummarizationModel {
/// # }
/// ```
/// (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
pub fn summarize<S>(&self, texts: &[S]) -> Vec<String>
pub fn summarize<S>(&self, texts: &[S]) -> Result<Vec<String>, RustBertError>
where
S: AsRef<str> + Send + Sync,
{
Expand Down
Loading

0 comments on commit 1f4d344

Please sign in to comment.