Support for HF Tokenizers #408

Merged
8 commits merged on Aug 13, 2023
Changes from 1 commit
tokenizers output type conversion
guillaume-be committed Aug 5, 2023
commit 08c430e10ba088c4d3c71a670940c8c41597d9cd
2 changes: 2 additions & 0 deletions Cargo.toml
@@ -69,6 +69,7 @@ download-libtorch = ["tch/download-libtorch"]
onnx = ["ort", "ndarray"]
rustls-tls = ["cached-path/rustls-tls"]
default-tls = ["cached-path/default-tls"]
hf-tokenizers = ["tokenizers"]

[package.metadata.docs.rs]
features = ["doc-only"]
@@ -89,6 +90,7 @@ dirs = { version = "4", optional = true }
lazy_static = { version = "1", optional = true }
ort = {version="~1.14.8", optional = true, default-features = false, features = ["half"]}
ndarray = {version="0.15", optional = true}
tokenizers = {version="0.13.3", optional=true}

[dev-dependencies]
anyhow = "1"
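The two Cargo.toml additions above register tokenizers as an optional dependency and expose it behind a new hf-tokenizers feature. Below is a minimal sketch of how code would typically be gated on that feature; the cfg gating is an assumption about intended usage rather than part of this commit, and the only library call used (Tokenizer::from_file) also appears later in this diff.

// Sketch: compile Hugging Face tokenizer support only when the optional
// hf-tokenizers feature (and therefore the tokenizers crate) is enabled.
#[cfg(feature = "hf-tokenizers")]
use tokenizers::tokenizer::Tokenizer as HFBaseTokenizer;

#[cfg(feature = "hf-tokenizers")]
fn load_base_tokenizer(path: &str) -> Result<HFBaseTokenizer, tokenizers::tokenizer::Error> {
    // Loads a tokenizer.json produced by the Hugging Face tokenizers library.
    HFBaseTokenizer::from_file(path)
}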
33 changes: 11 additions & 22 deletions src/pipelines/common.rs
@@ -56,12 +56,15 @@ use rust_tokenizers::{TokenIdsWithOffsets, TokenizedInput, TokensWithOffsets};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::convert::TryFrom;

use std::fmt::Debug;

use std::path::{Path, PathBuf};
use tch::{Device, Kind, Tensor};

#[cfg(feature = "onnx")]
use crate::pipelines::onnx::ONNXModelConfig;
use crate::pipelines::tokenizers::HFTokenizer;

#[derive(Debug, Default)]
/// Container for ONNX model resources, containing 3 optional resources (Encoder, Decoder and Decoder with past)
@@ -288,6 +291,8 @@ pub enum TokenizerOption {
FNet(FNetTokenizer),
/// Bart Tokenizer
Bart(RobertaTokenizer),
/// HF Tokenizer
HFTokenizer(HFTokenizer),
}

impl ConfigOption {
@@ -913,28 +918,12 @@ impl TokenizerOption {
Ok(tokenizer)
}

/// Returns the model type
pub fn model_type(&self) -> ModelType {
match *self {
Self::Bert(_) => ModelType::Bert,
Self::Deberta(_) => ModelType::Deberta,
Self::DebertaV2(_) => ModelType::DebertaV2,
Self::Roberta(_) => ModelType::Roberta,
Self::Bart(_) => ModelType::Bart,
Self::XLMRoberta(_) => ModelType::XLMRoberta,
Self::Marian(_) => ModelType::Marian,
Self::T5(_) => ModelType::T5,
Self::Albert(_) => ModelType::Albert,
Self::XLNet(_) => ModelType::XLNet,
Self::GPT2(_) => ModelType::GPT2,
Self::OpenAiGpt(_) => ModelType::OpenAiGpt,
Self::Reformer(_) => ModelType::Reformer,
Self::ProphetNet(_) => ModelType::ProphetNet,
Self::Pegasus(_) => ModelType::Pegasus,
Self::MBart50(_) => ModelType::MBart,
Self::M2M100(_) | Self::NLLB(_) => ModelType::M2M100,
Self::FNet(_) => ModelType::FNet,
}
pub fn from_hf_tokenizer_file<P: AsRef<Path>, S: AsRef<Path>>(
tokenizer_file: P,
special_token_map: S,
) -> Result<Self, RustBertError> {
let hf_tokenizer = HFTokenizer::from_file(tokenizer_file, special_token_map)?;
Ok(TokenizerOption::HFTokenizer(hf_tokenizer))
}

/// Interface method
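Taken together, the common.rs changes add an HFTokenizer variant to TokenizerOption and a from_hf_tokenizer_file constructor that builds it from a tokenizer.json plus a special-token map. A hedged usage sketch follows, assuming the items stay reachable as rust_bert::pipelines::common::TokenizerOption and rust_bert::RustBertError, and that the file names follow the usual Hugging Face conventions.

use rust_bert::pipelines::common::TokenizerOption;
use rust_bert::RustBertError;

// Illustrative only: build a TokenizerOption backed by a Hugging Face tokenizer,
// so it can be handed to existing pipelines in place of a rust_tokenizers tokenizer.
fn load_hf_tokenizer_option() -> Result<TokenizerOption, RustBertError> {
    TokenizerOption::from_hf_tokenizer_file("tokenizer.json", "special_tokens_map.json")
}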
1 change: 1 addition & 0 deletions src/pipelines/mod.rs
@@ -488,6 +488,7 @@ pub mod sequence_classification;
pub mod summarization;
pub mod text_generation;
pub mod token_classification;
pub mod tokenizers;
pub mod translation;
pub mod zero_shot_classification;

23 changes: 1 addition & 22 deletions src/pipelines/token_classification.rs
@@ -128,7 +128,6 @@ use crate::resources::ResourceProvider;
use crate::roberta::RobertaForTokenClassification;
use crate::xlnet::XLNetForTokenClassification;
use ordered_float::OrderedFloat;
use rust_tokenizers::tokenizer::Tokenizer;
use rust_tokenizers::{
ConsolidatableTokens, ConsolidatedTokenIterator, Mask, Offset, TokenIdsWithOffsets, TokenTrait,
TokenizedInput,
@@ -1103,27 +1102,7 @@ impl TokenClassificationModel {
let offsets = &sentence_tokens.offsets[position_idx as usize];

let text = match offsets {
None => match self.tokenizer {
TokenizerOption::Bert(ref tokenizer) => {
Tokenizer::decode(tokenizer, &[token_id], false, false)
}
TokenizerOption::Roberta(ref tokenizer) => {
Tokenizer::decode(tokenizer, &[token_id], false, false)
}
TokenizerOption::XLMRoberta(ref tokenizer) => {
Tokenizer::decode(tokenizer, &[token_id], false, false)
}
TokenizerOption::Albert(ref tokenizer) => {
Tokenizer::decode(tokenizer, &[token_id], false, false)
}
TokenizerOption::XLNet(ref tokenizer) => {
Tokenizer::decode(tokenizer, &[token_id], false, false)
}
_ => panic!(
"Token classification not implemented for {:?}!",
self.tokenizer.model_type()
),
},
None => self.tokenizer.decode(&[token_id], false, false),
Some(offsets) => {
let (start_char, end_char) = (offsets.begin as usize, offsets.end as usize);
let end_char = min(end_char, original_sentence_chars.len());
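The net effect of this hunk is that the per-variant decode dispatch moves out of TokenClassificationModel and into TokenizerOption::decode itself, which is what lets the HF-backed tokenizer participate without a new match arm. A small illustrative sketch (the free function is an assumption for demonstration, not code from the PR):

use rust_bert::pipelines::common::TokenizerOption;

// Decoding a single token id is now backend-agnostic; the flags mirror the
// call in the diff above (no special-token skipping, no space clean-up).
fn decode_single_token(tokenizer: &TokenizerOption, token_id: i64) -> String {
    tokenizer.decode(&[token_id], false, false)
}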
184 changes: 184 additions & 0 deletions src/pipelines/tokenizers.rs
@@ -0,0 +1,184 @@
use crate::RustBertError;
use rust_tokenizers::{Mask, Offset, OffsetSize, TokenizedInput};
use serde::{de, Deserialize, Deserializer};
use std::collections::HashSet;
use std::fmt;
use std::fs::File;
use std::io::BufReader;
use std::path::Path;
use tokenizers::tokenizer::Tokenizer as HFBaseTokenizer;
use tokenizers::Encoding;

impl From<tokenizers::tokenizer::Error> for RustBertError {
fn from(error: tokenizers::tokenizer::Error) -> Self {
RustBertError::TokenizerError(error.to_string())
}
}

#[derive(Debug, Default, Clone, Deserialize)]
pub struct SpecialTokenMap {
pub unk_token: String,
#[serde(default)]
#[serde(deserialize_with = "string_or_added_token_struct")]
pub pad_token: Option<String>,
#[serde(default)]
#[serde(deserialize_with = "string_or_added_token_struct")]
pub bos_token: Option<String>,
#[serde(default)]
#[serde(deserialize_with = "string_or_added_token_struct")]
pub sep_token: Option<String>,
#[serde(default)]
#[serde(deserialize_with = "string_or_added_token_struct")]
pub cls_token: Option<String>,
#[serde(default)]
#[serde(deserialize_with = "string_or_added_token_struct")]
pub eos_token: Option<String>,
#[serde(default)]
#[serde(deserialize_with = "string_or_added_token_struct")]
pub mask_token: Option<String>,
pub additional_special_tokens: Option<HashSet<String>>,
}

fn string_or_added_token_struct<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
where
D: Deserializer<'de>,
{
struct StringOrStruct;

impl<'de> de::Visitor<'de> for StringOrStruct {
type Value = Option<String>;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("string or map")
}

fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(Some(value.to_string()))
}

fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error>
where
M: de::MapAccess<'de>,
{
let mut value = None;
while let Some(key) = map.next_key::<String>()? {
if key == "content" {
value = Some(map.next_value::<String>()?);
} else {
_ = map.next_value::<String>();
}
}
Ok(value)
}
}

Ok(deserializer.deserialize_any(StringOrStruct)?)
}

pub struct HFTokenizer {
tokenizer: HFBaseTokenizer,
special_token_map: SpecialTokenMap,
}

impl HFTokenizer {
pub fn from_file<P: AsRef<Path>, S: AsRef<Path>>(
tokenizer_file: P,
special_token_map: S,
) -> Result<Self, RustBertError> {
let tokenizer = HFBaseTokenizer::from_file(tokenizer_file)?;
let f = File::open(&special_token_map).map_err(|e| {
RustBertError::IOError(format!(
"{} special token map file not found :{}",
special_token_map.as_ref().display(),
e
))
})?;
let br = BufReader::new(f);
let special_token_map = serde_json::from_reader(br).map_err(|e| {
RustBertError::IOError(format!("Invalid special token mapping file {e}"))
})?;
Ok(Self {
tokenizer,
special_token_map,
})
}

fn encoding_to_tokenized_input(encoding: Encoding) -> TokenizedInput {
let token_ids = encoding
.get_ids()
.into_iter()
.map(|token_id| *token_id as i64)
.collect();
let segment_ids = encoding
.get_type_ids()
.into_iter()
.map(|segment_id| *segment_id as i8)
.collect();
let special_tokens_mask = encoding
.get_special_tokens_mask()
.into_iter()
.map(|segment_id| *segment_id as i8)
.collect();
let overflowing_tokens: Vec<i64> = encoding
.get_overflowing()
.iter()
.map(|encoding| encoding.get_ids())
.flatten()
.map(|token_id| *token_id as i64)
.collect();
let num_truncated_tokens = overflowing_tokens.len();
let token_offsets = encoding
.get_offsets()
.iter()
.map(|offset| {
Some(Offset {
begin: offset.0 as OffsetSize,
end: offset.1 as OffsetSize,
})
})
.collect();
let reference_offsets = encoding
.get_offsets()
.iter()
.map(|offset| (offset.0 as OffsetSize..offset.1 as OffsetSize).collect())
.collect();
let mask = encoding
.get_special_tokens_mask()
.into_iter()
.map(|segment_id| {
if *segment_id == 0 {
Mask::None
} else {
Mask::Special
}
})
.collect();
TokenizedInput {
token_ids,
segment_ids,
special_tokens_mask,
overflowing_tokens,
num_truncated_tokens,
token_offsets,
reference_offsets,
mask,
}
}

pub fn encode_list<S>(&self, text_list: &[S]) -> Result<Vec<TokenizedInput>, RustBertError>
where
S: AsRef<str> + Sync + Send + Clone,
{
let encoding_inputs = text_list.iter().map(|text| text.as_ref()).collect();
let mut encodings = self.tokenizer.encode_batch(encoding_inputs, true)?;
let mut tokenized_inputs: Vec<TokenizedInput> = Vec::with_capacity(encodings.len());
for encoding in encodings {
tokenized_inputs.push(Self::encoding_to_tokenized_input(encoding));
}

Ok(tokenized_inputs)
}
}
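End to end, the new module loads a tokenizer.json through the tokenizers crate, reads the special-token map from a separate JSON file, and converts each Encoding into the TokenizedInput type the rest of rust-bert expects. A hedged usage sketch, assuming the module is reachable as rust_bert::pipelines::tokenizers and that the input files follow the usual Hugging Face naming:

use rust_bert::pipelines::tokenizers::HFTokenizer;
use rust_bert::RustBertError;

fn main() -> Result<(), RustBertError> {
    // tokenizer.json and special_tokens_map.json as typically downloaded
    // from a Hugging Face model repository (file names are illustrative).
    let tokenizer = HFTokenizer::from_file("tokenizer.json", "special_tokens_map.json")?;

    // Batch-encode; each element is a rust_tokenizers::TokenizedInput with ids,
    // segment ids, offsets and masks produced by encoding_to_tokenized_input.
    let inputs = tokenizer.encode_list(&["Hello world!", "HF tokenizers interop."])?;
    for input in &inputs {
        println!("{:?}", input.token_ids);
    }
    Ok(())
}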