
Commit

Merge pull request mistralai#52 from mistralai/tlax/add_forward_partial
Add simple classification example
timlacroix committed Oct 19, 2023
2 parents 745c58a + 9b844cc commit 7fbbfb3
Showing 3 changed files with 635 additions and 15 deletions.
main.py (6 changes: 3 additions & 3 deletions)
@@ -60,8 +60,8 @@ def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, *, ma
         assert all(len(p) > 0 for p in prompt_chunks)
         prelogits = model.forward(
             torch.tensor(sum(prompt_chunks, []), device=model.device, dtype=torch.long),
-            cache,
-            seqlens=[len(p) for p in prompt_chunks]
+            seqlens=[len(p) for p in prompt_chunks],
+            cache=cache
         )
         logits = torch.log_softmax(prelogits, dim=-1)

@@ -89,7 +89,7 @@ def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, *, ma
             logprobs[i].append(last_token_logits[i, next_token[i]].item())

         generated_tokens.append(next_token[:, None])
-        last_token_prelogits = model.forward(next_token, cache, seqlens=[1] * len(prompts))
+        last_token_prelogits = model.forward(next_token, seqlens=[1] * len(prompts), cache=cache)
         assert last_token_prelogits.shape == (B, V)

     generated_words = []
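In generate(), all prompts are packed into one flat token tensor and the per-prompt lengths are passed as seqlens; after this change the cache is supplied as a keyword argument following seqlens. A minimal sketch of that packing convention (the token ids are invented, and the model and cache objects are assumed to be set up as elsewhere in generate()):

import torch

# Two prompts packed into a single flat sequence, as generate() does:
prompt_chunks = [[1, 5, 9], [2, 7]]                      # hypothetical token ids
input_ids = torch.tensor(sum(prompt_chunks, []), dtype=torch.long)
seqlens = [len(p) for p in prompt_chunks]                # [3, 2]

# With the reordered signature the cache is passed by keyword:
#   prelogits = model.forward(input_ids.to(model.device), seqlens=seqlens, cache=cache)
# prelogits then has one row per packed token, i.e. shape (sum(seqlens), vocab_size).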
mistral/model.py (54 changes: 42 additions & 12 deletions)
@@ -3,7 +3,7 @@
 from dataclasses import dataclass
 from pathlib import Path
 import json
-from typing import List
+from typing import List, Optional

 from mistral.rope import precompute_freqs_cis, apply_rotary_emb
 from mistral.cache import CacheView, RotatingBufferCache
@@ -28,6 +28,20 @@ class ModelArgs:
     max_batch_size: int = 0


+@dataclass
+class SimpleInputMetadata:
+    # rope absolute positions
+    positions: torch.Tensor
+
+    @staticmethod
+    def from_seqlens(seqlens: List[int], device: torch.device) -> "SimpleInputMetadata":
+        return SimpleInputMetadata(
+            positions = torch.cat(
+                [torch.arange(0, seqlen) for seqlen in seqlens]
+            ).to(device=device, dtype=torch.long)
+        )
+
+
 def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int):
     keys = torch.repeat_interleave(keys, repeats=repeats, dim=dim)
     values = torch.repeat_interleave(values, repeats=repeats, dim=dim)
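SimpleInputMetadata provides the rotary-embedding positions when no cache is used, so positions simply restart at 0 for each packed sequence. A small sketch of what from_seqlens returns (assuming the repository at this commit is on the Python path):

import torch
from mistral.model import SimpleInputMetadata

# Two packed sequences of lengths 3 and 2:
meta = SimpleInputMetadata.from_seqlens([3, 2], torch.device("cpu"))
print(meta.positions)  # tensor([0, 1, 2, 0, 1])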
@@ -72,7 +86,7 @@ def __init__(self, args: ModelArgs):
     def forward(
         self, x: torch.Tensor,
         freqs_cis: torch.Tensor,
-        cache: CacheView,
+        cache: Optional[CacheView],
     ) -> torch.Tensor:
         seqlen_sum, _ = x.shape

@@ -82,7 +96,9 @@ def forward(
         xv = xv.view(seqlen_sum, self.n_kv_heads, self.args.head_dim)
         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

-        if cache.prefill:
+        if cache is None:
+            key, val = xk, xv
+        elif cache.prefill:
             key, val = cache.interleave_kv(xk, xv)
             cache.update(xk, xv)
         else:
@@ -96,7 +112,7 @@

         # xformers requires (B=1, S, H, D)
         xq, key, val = xq[None, ...], key[None, ...], val[None, ...]
-        output = memory_efficient_attention(xq, key, val, cache.mask)
+        output = memory_efficient_attention(xq, key, val, None if cache is None else cache.mask)

         return self.wo(output.view_as(x))

@@ -151,7 +167,7 @@ def __init__(self, args: ModelArgs):
         self.args = args

     def forward(
-        self, x: torch.Tensor, freqs_cis: torch.Tensor, cache: CacheView
+        self, x: torch.Tensor, freqs_cis: torch.Tensor, cache: Optional[CacheView]
     ) -> torch.Tensor:
         r = self.attention.forward(self.attention_norm(x), freqs_cis, cache)
         h = x + r
@@ -192,25 +208,39 @@ def dtype(self) -> torch.dtype:
     def device(self) -> torch.device:
         return self.tok_embeddings.weight.device

-    def forward(
+    def forward_partial(
         self,
         input_ids: torch.Tensor,
-        cache: RotatingBufferCache,
         seqlens: List[int],
+        cache: Optional[RotatingBufferCache]=None,
     ) -> torch.Tensor:
         assert len(seqlens) <= self.args.max_batch_size, f"Max batch size is {self.args.max_batch_size}, got batch size of {len(seqlens)}"
         assert sum(seqlens) == input_ids.shape[0], (sum(seqlens), input_ids.shape[0])

-        input_metadata = cache.get_input_metadata(seqlens)
+        if cache is not None:
+            input_metadata = cache.get_input_metadata(seqlens)
+        else:
+            input_metadata = SimpleInputMetadata.from_seqlens(seqlens, self.device)
         h = self.tok_embeddings(input_ids)
         freqs_cis = self.freqs_cis[input_metadata.positions]

         for layer_id, layer in enumerate(self.layers):
-            h = layer(h, freqs_cis, cache.get_view(layer_id, input_metadata))
+            cache_view = None if cache is None else cache.get_view(layer_id, input_metadata)
+            h = layer(h, freqs_cis, cache_view)

-        cache.update_seqlens(seqlens)
+        if cache is not None:
+            cache.update_seqlens(seqlens)

-        return self.output(self.norm(h)).float()
+        return self.norm(h)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        seqlens: List[int],
+        cache: Optional[RotatingBufferCache]=None,
+    ) -> torch.Tensor:
+        return self.output(self.forward_partial(
+            input_ids, seqlens, cache=cache
+        )).float()

     @staticmethod
     def from_folder(folder: Path, max_batch_size: int = 1, device="cuda", dtype=torch.float16) -> "Transformer":
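Taken together, forward_partial runs the embeddings, the transformer layers, and the final norm with an optional cache and stops before the output head, while forward applies self.output on top of it and casts to float32. A minimal sketch of calling both (the checkpoint path and token ids are placeholders, not part of this commit):

import torch
from pathlib import Path
from mistral.model import Transformer

model = Transformer.from_folder(Path("mistral-7B-v0.1"), max_batch_size=2)  # hypothetical local path

# Two packed sequences of made-up token ids; no cache is needed for a single pass.
input_ids = torch.tensor([1, 5, 9, 2, 7], device=model.device, dtype=torch.long)
seqlens = [3, 2]

hidden = model.forward_partial(input_ids, seqlens)  # (sum(seqlens), dim): normed hidden states, no output head
logits = model.forward(input_ids, seqlens)          # (sum(seqlens), vocab_size): output head applied, float32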
tutorials/classifier.ipynb (590 changes: 590 additions & 0 deletions)

Large diffs are not rendered by default.
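The notebook itself is not rendered above. As an illustration only (the pooling choice and names here are assumptions, not necessarily what the tutorial does), a simple classifier can be built on forward_partial by taking the hidden state of the last token of each packed sequence as its feature vector:

import torch
import torch.nn as nn

def last_token_features(model, input_ids, seqlens):
    # Hidden states for every packed token, computed without a cache or output head.
    hidden = model.forward_partial(input_ids, seqlens)               # (sum(seqlens), dim)
    # Index of the last token of each sequence within the packed batch.
    ends = torch.cumsum(torch.tensor(seqlens, device=hidden.device), dim=0) - 1
    return hidden[ends]                                              # (len(seqlens), dim)

# A linear head over these features (num_classes is task-specific):
# classifier = nn.Linear(model.args.dim, num_classes)
# class_logits = classifier(last_token_features(model, input_ids, seqlens).float())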
