-
Notifications
You must be signed in to change notification settings - Fork 209
/
gpt_neo.rs
153 lines (139 loc) · 5.33 KB
/
gpt_neo.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
use rust_bert::gpt_neo::{
GptNeoConfig, GptNeoConfigResources, GptNeoForCausalLM, GptNeoMergesResources,
GptNeoModelResources, GptNeoVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
/// Validates a single forward pass of GPT-Neo 125M with the causal LM head:
/// the greedy next-token prediction is pinned, and the requested attention /
/// hidden-state outputs are checked for count and shape.
#[test]
fn gpt_neo_lm() -> anyhow::Result<()> {
    // Download (or reuse cached) pretrained artefacts and resolve local paths.
    let config_path = RemoteResource::from_pretrained(GptNeoConfigResources::GPT_NEO_125M)
        .get_local_path()?;
    let vocab_path = RemoteResource::from_pretrained(GptNeoVocabResources::GPT_NEO_125M)
        .get_local_path()?;
    let merges_path = RemoteResource::from_pretrained(GptNeoMergesResources::GPT_NEO_125M)
        .get_local_path()?;
    let weights_path = RemoteResource::from_pretrained(GptNeoModelResources::GPT_NEO_125M)
        .get_local_path()?;

    // Build tokenizer and model on CPU.
    let device = Device::Cpu;
    let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
        vocab_path.to_str().unwrap(),
        merges_path.to_str().unwrap(),
        false,
    )?;
    let mut config = GptNeoConfig::from_file(config_path);
    // Request attentions and hidden states so their shapes can be asserted below.
    config.output_attentions = Some(true);
    config.output_hidden_states = Some(true);
    let mut vs = nn::VarStore::new(device);
    let gpt_neo_model = GptNeoForCausalLM::new(vs.root(), &config)?;
    vs.load(weights_path)?;

    // Tokenize the prompt and right-pad every sequence to the batch maximum.
    let prompts = ["It was a sunny"];
    let encodings = tokenizer.encode_list(&prompts, 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = encodings
        .iter()
        .map(|encoding| encoding.token_ids.len())
        .max()
        .unwrap();
    let padded = encodings
        .iter()
        .map(|encoding| {
            let mut ids = encoding.token_ids.clone();
            ids.resize(max_len, 0); // right-pad with token id 0
            Tensor::of_slice(&ids)
        })
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(&padded, 0).to(device);

    // Single forward pass in eval mode (train = false).
    let model_output =
        gpt_neo_model.forward_t(Some(&input_tensor), None, None, None, None, None, false)?;

    // Greedy next token: argmax over the vocabulary at the last position.
    let last_logits = model_output.lm_logits.get(0).get(-1);
    let next_word_id = last_logits.argmax(-1, true).int64_value(&[0]);
    let next_score = last_logits.double_value(&[next_word_id]);
    let next_word = tokenizer.decode(&[next_word_id], true, true);

    // Logits cover [batch = 1, seq = 4, vocab = 50257]; " day" is the greedy pick.
    assert_eq!(model_output.lm_logits.size(), vec!(1, 4, 50257));
    assert_eq!(next_word_id, 1110_i64);
    assert!((next_score - (-0.0279)).abs() < 1e-4);
    assert_eq!(next_word, String::from(" day"));

    // 12 attention maps, each [batch, heads, seq, seq].
    assert!(model_output.all_attentions.is_some());
    let attentions = model_output.all_attentions.as_ref().unwrap();
    assert_eq!(attentions.len(), 12);
    assert_eq!(attentions[0].size(), vec![1, 12, 4, 4]);
    assert_eq!(attentions[1].size(), vec![1, 12, 4, 4]);

    // Hidden states per layer, each [batch, seq, hidden = 768].
    assert!(model_output.all_hidden_states.is_some());
    let hidden_states = model_output.all_hidden_states.as_ref().unwrap();
    assert_eq!(hidden_states.len(), 12);
    assert_eq!(hidden_states[0].size(), vec![1, 4, 768]);
    Ok(())
}
/// End-to-end text generation through the high-level pipeline: deterministic
/// beam search over two prompts, compared against pinned completions.
#[test]
fn test_generation_gpt_neo() -> anyhow::Result<()> {
    // Pretrained GPT-Neo 125M artefacts, fetched remotely on first use.
    let model_resource = Box::new(RemoteResource::from_pretrained(
        GptNeoModelResources::GPT_NEO_125M,
    ));
    let config_resource = Box::new(RemoteResource::from_pretrained(
        GptNeoConfigResources::GPT_NEO_125M,
    ));
    let vocab_resource = Box::new(RemoteResource::from_pretrained(
        GptNeoVocabResources::GPT_NEO_125M,
    ));
    let merges_resource = Box::new(RemoteResource::from_pretrained(
        GptNeoMergesResources::GPT_NEO_125M,
    ));

    // Deterministic 4-beam search (no sampling) on CPU, one sequence per prompt.
    let generation_config = TextGenerationConfig {
        model_type: ModelType::GPTNeo,
        model_resource,
        config_resource,
        vocab_resource,
        merges_resource: Some(merges_resource),
        min_length: 10,
        max_length: Some(32),
        do_sample: false,
        early_stopping: true,
        num_beams: 4,
        num_return_sequences: 1,
        device: Device::Cpu,
        ..Default::default()
    };
    let model = TextGenerationModel::new(generation_config)?;

    // Generate one continuation per prompt and pin the exact outputs.
    let prompts = [
        "It was a very nice and sunny",
        "It was a gloom winter night, and",
    ];
    let output = model.generate(&prompts, None);

    assert_eq!(output.len(), 2);
    assert_eq!(output[0], "It was a very nice and sunny day. The sun was shining through the clouds, and the sky was clear. The wind was blowing through the trees,");
    assert_eq!(output[1], "It was a gloom winter night, and the sky was dark and cold, and the wind was blowing thick and heavy.\n\n\"What\'s the matter?\"");
    Ok(())
}