-
Notifications
You must be signed in to change notification settings - Fork 215
/
hf_tokenizers.rs
112 lines (98 loc) · 4.45 KB
/
hf_tokenizers.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
// Integration tests for pipelines that are given a tokenizer built from a
// Hugging Face `tokenizer.json` file instead of the vocab/merges resources.
// The whole module is compiled only when the `hf-tokenizers` feature is enabled.
#[cfg(feature = "hf-tokenizers")]
mod tests {
    use rust_bert::gpt2::{Gpt2ConfigResources, Gpt2ModelResources};
    use rust_bert::pipelines::common::{ModelResource, ModelType, TokenizerOption};
    use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
    use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
    use rust_bert::resources::{LocalResource, RemoteResource, ResourceProvider};
    use std::fs::File;
    use std::io::Write;
    use tch::Device;
    use tempfile::TempDir;

    /// Generates text with GPT-2 using a tokenizer loaded through
    /// `TokenizerOption::from_hf_tokenizer_file`.
    ///
    /// The configuration requests three sequences of at most 20 tokens on the
    /// CPU, with sampling disabled and five beams; each returned sequence is
    /// compared against a fixed expected string.
    #[test]
    fn gpt2_generation() -> anyhow::Result<()> {
        let generation_config = TextGenerationConfig {
            model_type: ModelType::GPT2,
            model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
                Gpt2ModelResources::GPT2,
            ))),
            config_resource: Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)),
            // Placeholder vocabulary with a default (empty) path; presumably never
            // read because the tokenizer is passed in explicitly below.
            vocab_resource: Box::new(LocalResource {
                local_path: Default::default(),
            }),
            merges_resource: None,
            max_length: Some(20),
            do_sample: false,
            num_beams: 5,
            temperature: 1.2,
            device: Device::Cpu,
            num_return_sequences: 3,
            ..Default::default()
        };

        // Write the special-token map into a scratch directory. The doubled braces
        // are format-string escapes, so the file receives single JSON braces.
        let scratch_dir = TempDir::new()?;
        let token_map_path = scratch_dir.path().join("special_token_map.json");
        let mut token_map_file = File::create(&token_map_path)?;
        writeln!(
            token_map_file,
            r#"{{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}}"#
        )?;

        // Resolve the remote `tokenizer.json` to a local path and build the tokenizer.
        let tokenizer_json = RemoteResource::from_pretrained((
            "gpt2/tokenizer",
            "https://huggingface.co/gpt2/resolve/main/tokenizer.json",
        ))
        .get_local_path()?;
        let hf_tokenizer = TokenizerOption::from_hf_tokenizer_file(tokenizer_json, token_map_path)?;

        let generator = TextGenerationModel::new_with_tokenizer(generation_config, hf_tokenizer)?;
        let prompt = "The dog";
        let output = generator.generate(&[prompt], None)?;

        let expected = [
            "The dog was found in the backyard of a home in the 6200 block of South Main Street.",
            "The dog was found in the backyard of a home in the 6500 block of South Main Street.",
            "The dog was found in the backyard of a home in the 6200 block of South Main Street,",
        ];
        assert_eq!(output.len(), 3);
        for (idx, wanted) in expected.iter().enumerate() {
            assert_eq!(output[idx], *wanted);
        }
        Ok(())
    }

    /// Runs the question-answering pipeline with its default configuration and a
    /// tokenizer loaded through `TokenizerOption::from_hf_tokenizer_file`.
    ///
    /// Checks that exactly one answer comes back for the single input, and that
    /// its span, score (within `1e-4` of `0.9978`) and text match the expected values.
    #[test]
    fn distilbert_question_answering() -> anyhow::Result<()> {
        // Write the special-token map into a scratch directory. The doubled braces
        // are format-string escapes, so the file receives single JSON braces.
        let scratch_dir = TempDir::new()?;
        let token_map_path = scratch_dir.path().join("special_token_map.json");
        let mut token_map_file = File::create(&token_map_path)?;
        writeln!(
            token_map_file,
            r#"{{"pad_token": "[PAD]", "sep_token": "[SEP]", "cls_token": "[CLS]", "mask_token": "[MASK]", "unk_token": "[UNK]"}}"#
        )?;

        // Resolve the remote `tokenizer.json` to a local path and build the tokenizer.
        let tokenizer_json = RemoteResource::from_pretrained((
            "distilbert-base-cased-distilled-squad/tokenizer",
            "https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/tokenizer.json",
        ))
        .get_local_path()?;
        let hf_tokenizer = TokenizerOption::from_hf_tokenizer_file(tokenizer_json, token_map_path)?;

        // Question-answering model with the default configuration.
        let qa_model =
            QuestionAnsweringModel::new_with_tokenizer(Default::default(), hf_tokenizer)?;

        // One question/context pair; ask for the single best answer, batch size 32.
        let qa_input = QaInput {
            question: "Where does Amy live ?".to_string(),
            context: "Amy lives in Amsterdam".to_string(),
        };
        let answers = qa_model.predict(&[qa_input], 1, 32);

        assert_eq!(answers.len(), 1);
        assert_eq!(answers[0].len(), 1);
        let best = &answers[0][0];
        assert_eq!(best.start, 13);
        assert_eq!(best.end, 22);
        assert!((best.score - 0.9978).abs() < 1e-4);
        assert_eq!(best.answer, "Amsterdam");
        Ok(())
    }
}