forked from mistralai/mistral-inference
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
253 lines (204 loc) · 7.61 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
import torch
from torch import nn
from dataclasses import dataclass
from pathlib import Path
import json
from typing import List, Optional
from mistral.rope import precompute_freqs_cis, apply_rotary_emb
from mistral.cache import CacheView, RotatingBufferCache
from xformers.ops.fmha import (
memory_efficient_attention,
)
@dataclass
class ModelArgs:
    """Hyper-parameters that size the transformer defined in this file.

    All fields are required except ``max_batch_size``, which defaults to 0.
    """

    # Width of the token embeddings and of every block's input/output.
    dim: int
    # Number of TransformerBlock layers stacked in the model.
    n_layers: int
    # Size of a single attention head.
    head_dim: int
    # Inner width of the feed-forward projections.
    hidden_dim: int
    # Number of query heads.
    n_heads: int
    # Number of key/value heads; each is repeated n_heads // n_kv_heads times.
    n_kv_heads: int
    # Attention window size (stored on each Attention module).
    sliding_window: int
    # Epsilon handed to every RMSNorm.
    norm_eps: float
    # Number of rows in the embedding table and of output logits.
    vocab_size: int
    # Upper bound on the number of sequences accepted by one forward pass.
    max_batch_size: int = 0
@dataclass
class SimpleInputMetadata:
    """Input metadata used when the model runs without a cache.

    It only carries the position of every token of the packed batch.
    """

    # rope absolute positions
    positions: torch.Tensor

    @staticmethod
    def from_seqlens(seqlens: List[int], device: torch.device) -> "SimpleInputMetadata":
        """Build the metadata for sequences packed one after another.

        Args:
            seqlens: Length of each sequence in the batch.
            device: Device on which the position tensor is placed.

        Returns:
            Metadata whose ``positions`` is a 1-D ``torch.long`` tensor with
            ``sum(seqlens)`` entries; counting restarts at 0 for each sequence.
            An empty ``seqlens`` yields an empty position tensor.
        """
        if not seqlens:
            # torch.cat refuses an empty list of tensors, so build the empty
            # result directly instead of crashing.
            return SimpleInputMetadata(
                positions=torch.empty(0, dtype=torch.long, device=device)
            )
        per_sequence = [torch.arange(0, seqlen) for seqlen in seqlens]
        return SimpleInputMetadata(
            positions=torch.cat(per_sequence).to(device=device, dtype=torch.long)
        )
def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int):
    """Repeat each slice of ``keys`` and ``values`` along ``dim``.

    Every element along ``dim`` is duplicated ``repeats`` times in place
    (``a, b`` becomes ``a, a, b, b`` for ``repeats=2``).

    Returns:
        A ``(keys, values)`` tuple of the expanded tensors.
    """
    expanded_keys, expanded_values = (
        tensor.repeat_interleave(repeats, dim=dim) for tensor in (keys, values)
    )
    return expanded_keys, expanded_values
class Attention(nn.Module):
    """Multi-head attention with fewer key/value heads than query heads.

    Queries use ``n_heads`` heads while keys and values use ``n_kv_heads``;
    the key/value heads are repeated ``n_heads // n_kv_heads`` times before
    the attention kernel is called.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.n_heads: int = args.n_heads
        self.n_kv_heads: int = args.n_kv_heads

        # How many query heads share one key/value head.
        self.repeats = self.n_heads // self.n_kv_heads
        self.sliding_window = self.args.sliding_window

        # Kept for callers that read it; forward() does not pass it on.
        self.scale = self.args.head_dim**-0.5

        self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)

    def forward(
        self, x: torch.Tensor,
        freqs_cis: torch.Tensor,
        cache: Optional[CacheView],
    ) -> torch.Tensor:
        """Apply attention to a packed batch of tokens.

        Args:
            x: 2-D tensor whose first dimension is the total number of tokens
                across all sequences of the batch.
            freqs_cis: Rotary-embedding factors forwarded to
                ``apply_rotary_emb``.
            cache: Per-layer cache view, or ``None`` to attend only over the
                keys/values computed from ``x``.

        Returns:
            The result of the output projection ``wo``, one row per token.
        """
        seqlen_sum, _ = x.shape

        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(seqlen_sum, self.n_heads, self.args.head_dim)
        xk = xk.view(seqlen_sum, self.n_kv_heads, self.args.head_dim)
        xv = xv.view(seqlen_sum, self.n_kv_heads, self.args.head_dim)
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        if cache is None:
            key, val = xk, xv
        elif cache.prefill:
            # Keys/values to attend over are built before the cache is written.
            key, val = cache.interleave_kv(xk, xv)
            cache.update(xk, xv)
        else:
            # Write the new entries first, then attend over the cache content.
            cache.update(xk, xv)
            key, val = cache.key, cache.value
            key = key.view(seqlen_sum * cache.sliding_window, self.n_kv_heads, self.args.head_dim)
            val = val.view(seqlen_sum * cache.sliding_window, self.n_kv_heads, self.args.head_dim)

        # Repeat keys and values to match number of query heads
        key, val = repeat_kv(key, val, self.repeats, dim=1)

        # xformers requires (B=1, S, H, D)
        xq, key, val = xq[None, ...], key[None, ...], val[None, ...]
        # NOTE(review): no attention bias is passed when cache is None.
        output = memory_efficient_attention(xq, key, val, None if cache is None else cache.mask)

        # Flatten the heads to the width ``wo`` expects. Using ``view_as(x)``
        # here would only work when n_heads * head_dim happens to equal dim.
        return self.wo(output.view(seqlen_sum, self.n_heads * self.args.head_dim))
class FeedForward(nn.Module):
    """Gated feed-forward layer computing ``w2(silu(w1(x)) * w3(x))``.

    ``w1`` and ``w3`` project from ``args.dim`` to ``args.hidden_dim``;
    ``w2`` projects back. None of the projections has a bias.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        model_dim, inner_dim = args.dim, args.hidden_dim
        self.w1 = nn.Linear(model_dim, inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, model_dim, bias=False)
        self.w3 = nn.Linear(model_dim, inner_dim, bias=False)

    def forward(self, x) -> torch.Tensor:
        """Return the gated projection of ``x``."""
        gate = nn.functional.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension.

    The input is divided by ``sqrt(mean(x**2) + eps)`` and then multiplied by
    a learned per-feature ``weight`` that starts at one.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def _norm(self, x):
        """Scale ``x`` by the reciprocal RMS of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        """Normalise in float32, cast back to ``x``'s type, apply ``weight``."""
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight
class TransformerBlock(nn.Module):
    """One transformer layer: attention then feed-forward.

    Each sub-layer is applied to an RMS-normalised copy of its input and the
    result is added back to that input (residual connection).
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.n_heads = args.n_heads
        self.dim = args.dim

        # Sub-modules are registered in the same order as before so that
        # parameter and state-dict ordering is unchanged.
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args=args)
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self, x: torch.Tensor, freqs_cis: torch.Tensor, cache: Optional[CacheView]
    ) -> torch.Tensor:
        """Run the block on ``x``.

        Args:
            x: Input activations.
            freqs_cis: Rotary-embedding factors passed through to attention.
            cache: Cache view for this layer, or ``None``.

        Returns:
            ``x`` plus the attention and feed-forward residual updates.
        """
        attn_out = self.attention.forward(self.attention_norm(x), freqs_cis, cache)
        after_attention = x + attn_out
        ffn_out = self.feed_forward.forward(self.ffn_norm(after_attention))
        return after_attention + ffn_out
class Transformer(nn.Module):
    """Decoder stack: token embedding, ``n_layers`` blocks, final norm, logits.

    Inputs are packed: the tokens of all sequences of a batch are concatenated
    into one 1-D ``input_ids`` tensor and ``seqlens`` gives each length.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        assert self.vocab_size > 0

        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)

        self.layers = torch.nn.ModuleList(
            [TransformerBlock(args=args) for _ in range(args.n_layers)]
        )

        self.norm = RMSNorm(args.dim, eps=args.norm_eps)

        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

        # freqs_cis is a plain attribute (not a parameter or registered
        # buffer), so Module.to() will not move it. Prefer CUDA as before, but
        # fall back to CPU so construction works on hosts without a GPU;
        # forward_partial moves it next to the activations when needed.
        freqs_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000).to(freqs_device)

    @property
    def dtype(self) -> torch.dtype:
        """Data type of the token-embedding weights."""
        return self.tok_embeddings.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device holding the token-embedding weights."""
        return self.tok_embeddings.weight.device

    def forward_partial(
        self,
        input_ids: torch.Tensor,
        seqlens: List[int],
        cache: Optional[RotatingBufferCache]=None,
    ) -> torch.Tensor:
        """Run the model up to and including the final norm (no logits).

        Args:
            input_ids: Packed token ids; its first dimension must equal
                ``sum(seqlens)``.
            seqlens: Length of each sequence in the batch; at most
                ``args.max_batch_size`` entries.
            cache: Optional cache. When given it supplies the input metadata
                and per-layer views, and its sequence lengths are updated
                after all layers have run.

        Returns:
            The normalised hidden states, one row per input token.

        Raises:
            AssertionError: If the batch is larger than ``max_batch_size`` or
                ``seqlens`` does not add up to the number of tokens.
        """
        assert len(seqlens) <= self.args.max_batch_size, f"Max batch size is {self.args.max_batch_size}, got batch size of {len(seqlens)}"
        assert sum(seqlens) == input_ids.shape[0], (sum(seqlens), input_ids.shape[0])
        if cache is not None:
            input_metadata = cache.get_input_metadata(seqlens)
        else:
            input_metadata = SimpleInputMetadata.from_seqlens(seqlens, self.device)
        h = self.tok_embeddings(input_ids)

        # Keep the rotary table on the same device as the activations; it is
        # not moved by Module.to(), so it can lag behind the weights.
        if self.freqs_cis.device != h.device:
            self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[input_metadata.positions]

        for layer_id, layer in enumerate(self.layers):
            cache_view = None if cache is None else cache.get_view(layer_id, input_metadata)
            h = layer(h, freqs_cis, cache_view)

        if cache is not None:
            cache.update_seqlens(seqlens)

        return self.norm(h)

    def forward(
        self,
        input_ids: torch.Tensor,
        seqlens: List[int],
        cache: Optional[RotatingBufferCache]=None,
    ) -> torch.Tensor:
        """Return float32 logits over the vocabulary for every input token.

        Arguments are the same as for ``forward_partial``.
        """
        return self.output(self.forward_partial(
            input_ids, seqlens, cache=cache
        )).float()

    @staticmethod
    def from_folder(folder: Path, max_batch_size: int = 1, device="cuda", dtype=torch.float16) -> "Transformer":
        """Build a model from ``params.json`` and ``consolidated.00.pth``.

        Args:
            folder: Directory containing both files (a ``Path`` or a string).
            max_batch_size: Stored on the model arguments, overriding any
                value read from ``params.json``.
            device: Device the model is moved to.
            dtype: Data type the model is converted to.

        Returns:
            The model with the checkpoint weights loaded.
        """
        # Accept plain strings too; the "/" joins below need a Path.
        folder = Path(folder)
        with open(folder / 'params.json', 'r') as f:
            model_args = ModelArgs(**json.load(f))
        model_args.max_batch_size = max_batch_size
        model = Transformer(model_args).to(device=device, dtype=dtype)
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        # map_location="cpu" lets a checkpoint saved from GPU tensors load on
        # any host; load_state_dict then copies the values into the model's
        # parameters on their own device.
        loaded = torch.load(folder / 'consolidated.00.pth', map_location="cpu")
        model.load_state_dict(loaded)
        return model