
Commit

no cache
timlacroix committed Oct 17, 2023
1 parent a165337 commit adf138d
Showing 2 changed files with 36 additions and 16 deletions.
6 changes: 3 additions & 3 deletions main.py
@@ -60,8 +60,8 @@ def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, *, ma
         assert all(len(p) > 0 for p in prompt_chunks)
         prelogits = model.forward(
             torch.tensor(sum(prompt_chunks, []), device=model.device, dtype=torch.long),
-            cache,
-            seqlens=[len(p) for p in prompt_chunks]
+            seqlens=[len(p) for p in prompt_chunks],
+            cache=cache
         )
         logits = torch.log_softmax(prelogits, dim=-1)

@@ -89,7 +89,7 @@ def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, *, ma
             logprobs[i].append(last_token_logits[i, next_token[i]].item())

         generated_tokens.append(next_token[:, None])
-        last_token_prelogits = model.forward(next_token, cache, seqlens=[1] * len(prompts))
+        last_token_prelogits = model.forward(next_token, seqlens=[1] * len(prompts), cache=cache)
         assert last_token_prelogits.shape == (B, V)

     generated_words = []
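Both call sites above now pass seqlens before the cache and pass the cache by keyword, matching the reordered Transformer.forward signature in mistral/model.py below. A minimal sketch of the prefill call under the new signature, with and without a cache; model, cache and prompt_chunks are the names already in scope inside generate(), and the cache-free variant relies on cache defaulting to None:

    # With a RotatingBufferCache, exactly as generate() now does:
    prelogits = model.forward(
        torch.tensor(sum(prompt_chunks, []), device=model.device, dtype=torch.long),
        seqlens=[len(p) for p in prompt_chunks],
        cache=cache,
    )

    # Without a cache (the new default), e.g. one-shot scoring of the first prompt only:
    prelogits = model.forward(
        torch.tensor(prompt_chunks[0], device=model.device, dtype=torch.long),
        seqlens=[len(prompt_chunks[0])],
    )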
46 changes: 33 additions & 13 deletions mistral/model.py
@@ -3,7 +3,7 @@
 from dataclasses import dataclass
 from pathlib import Path
 import json
-from typing import List
+from typing import List, Optional

 from mistral.rope import precompute_freqs_cis, apply_rotary_emb
 from mistral.cache import CacheView, RotatingBufferCache
@@ -28,6 +28,20 @@ class ModelArgs:
     max_batch_size: int = 0


+@dataclass
+class SimpleInputMetadata:
+    # rope absolute positions
+    positions: torch.Tensor
+
+    @staticmethod
+    def from_seqlens(seqlens: List[int], device: torch.device) -> "SimpleInputMetadata":
+        return SimpleInputMetadata(
+            positions = torch.cat(
+                [torch.arange(0, seqlen) for seqlen in seqlens]
+            ).to(device=device, dtype=torch.long)
+        )
+
+
 def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int):
     keys = torch.repeat_interleave(keys, repeats=repeats, dim=dim)
     values = torch.repeat_interleave(values, repeats=repeats, dim=dim)
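SimpleInputMetadata.from_seqlens above only concatenates a zero-based position range per sequence, which is all the rope lookup in forward_partial needs when no cache is tracking absolute positions. A standalone sketch of the same computation on made-up sequence lengths:

    import torch

    # Each flattened sequence contributes positions 0..seqlen-1, in batch order.
    seqlens = [3, 2]  # hypothetical lengths of two sequences packed into one tensor
    positions = torch.cat([torch.arange(0, s) for s in seqlens]).to(dtype=torch.long)
    print(positions)  # tensor([0, 1, 2, 0, 1])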
@@ -72,7 +86,7 @@ def __init__(self, args: ModelArgs):
     def forward(
         self, x: torch.Tensor,
         freqs_cis: torch.Tensor,
-        cache: CacheView,
+        cache: Optional[CacheView],
     ) -> torch.Tensor:
         seqlen_sum, _ = x.shape

@@ -82,7 +96,9 @@ def forward(
         xv = xv.view(seqlen_sum, self.n_kv_heads, self.args.head_dim)
         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

-        if cache.prefill:
+        if cache is None:
+            key, val = xk, xv
+        elif cache.prefill:
             key, val = cache.interleave_kv(xk, xv)
             cache.update(xk, xv)
         else:
@@ -96,7 +112,7 @@ def forward(

         # xformers requires (B=1, S, H, D)
         xq, key, val = xq[None, ...], key[None, ...], val[None, ...]
-        output = memory_efficient_attention(xq, key, val, cache.mask)
+        output = memory_efficient_attention(xq, key, val, None if cache is None else cache.mask)

         return self.wo(output.view_as(x))
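On the cache-free path, the attention above runs over just the current tokens' keys and values and passes None as the attention bias; with a cache it still uses cache.mask. A shape-only sketch of that xformers call, assuming xformers is installed and a CUDA device is available; the sizes are made up and the (B=1, S, H, D) layout is the one noted in the code comment:

    import torch
    from xformers.ops import memory_efficient_attention

    S, H, D = 6, 8, 64  # hypothetical seqlen_sum, n_heads, head_dim
    xq = torch.randn(1, S, H, D, dtype=torch.float16, device="cuda")
    key = torch.randn(1, S, H, D, dtype=torch.float16, device="cuda")
    val = torch.randn(1, S, H, D, dtype=torch.float16, device="cuda")

    # attn_bias is None when cache is None, cache.mask otherwise.
    out = memory_efficient_attention(xq, key, val, None)
    assert out.shape == (1, S, H, D)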

@@ -151,7 +167,7 @@ def __init__(self, args: ModelArgs):
         self.args = args

     def forward(
-        self, x: torch.Tensor, freqs_cis: torch.Tensor, cache: CacheView
+        self, x: torch.Tensor, freqs_cis: torch.Tensor, cache: Optional[CacheView]
     ) -> torch.Tensor:
         r = self.attention.forward(self.attention_norm(x), freqs_cis, cache)
         h = x + r
@@ -195,31 +211,35 @@ def device(self) -> torch.device:
     def forward_partial(
         self,
         input_ids: torch.Tensor,
-        cache: RotatingBufferCache,
         seqlens: List[int],
+        cache: Optional[RotatingBufferCache]=None,
     ) -> torch.Tensor:
         assert len(seqlens) <= self.args.max_batch_size, f"Max batch size is {self.args.max_batch_size}, got batch size of {len(seqlens)}"
         assert sum(seqlens) == input_ids.shape[0], (sum(seqlens), input_ids.shape[0])

-        input_metadata = cache.get_input_metadata(seqlens)
+        if cache is not None:
+            input_metadata = cache.get_input_metadata(seqlens)
+        else:
+            input_metadata = SimpleInputMetadata.from_seqlens(seqlens, self.device)
         h = self.tok_embeddings(input_ids)
         freqs_cis = self.freqs_cis[input_metadata.positions]

         for layer_id, layer in enumerate(self.layers):
-            h = layer(h, freqs_cis, cache.get_view(layer_id, input_metadata))
-
-        cache.update_seqlens(seqlens)
+            cache_view = None if cache is None else cache.get_view(layer_id, input_metadata)
+            h = layer(h, freqs_cis, cache_view)
+
+        if cache is not None:
+            cache.update_seqlens(seqlens)

         return self.norm(h)

     def forward(
         self,
         input_ids: torch.Tensor,
-        cache: RotatingBufferCache,
         seqlens: List[int],
+        cache: Optional[RotatingBufferCache]=None,
     ) -> torch.Tensor:
         return self.output(self.forward_partial(
-            input_ids, cache, seqlens
+            input_ids, seqlens, cache=cache
         )).float()

     @staticmethod
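Taken together, these model.py changes make the cache optional all the way down the stack: Attention, TransformerBlock, forward_partial and forward all accept cache=None. A hedged end-to-end sketch of the cache-free path, assuming a loaded Transformer named model whose args carry vocab_size as in ModelArgs:

    import torch

    # Positions come from SimpleInputMetadata, every layer gets cache_view=None,
    # and no cache bookkeeping (update_seqlens) runs.
    seqlens = [5]  # a single flattened sequence of 5 tokens
    input_ids = torch.randint(
        0, model.args.vocab_size, (sum(seqlens),),
        device=model.device, dtype=torch.long,
    )
    logits = model.forward(input_ids, seqlens=seqlens)  # cache defaults to None
    assert logits.shape == (sum(seqlens), model.args.vocab_size)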
