Merge pull request mistralai#83 from mistralai/moe
Add MoE and Pipelining support

Update readme

Update requirements

Add faster loading

Make sliding window optional and add rope_theta with smart default
diegolascasas committed Dec 12, 2023
2 parents 147c4e6 + b818190 commit eb3d6c2
Showing 8 changed files with 847 additions and 110 deletions.
21 changes: 21 additions & 0 deletions README.md
@@ -66,6 +66,16 @@ To run logits equivalence through chunking and sliding window, launch
python -m test_generate
```

### Running large models

When running models that are too large to fit in a single GPU's memory, use pipeline parallelism (PP) and `torchrun`. This is needed to run `Mixtral-7B-8x`. The command below runs 2-way PP.

```
torchrun --nproc-per-node 2 -m main demo /path/to/mixtral-7B-8x-v0.1/ --num_pipeline_ranks=2
```

> [!NOTE]
> PP is not supported when running in interactive mode.
# Sliding window attention

@@ -112,6 +122,17 @@ For this we can choose as chunk size the window size. For each chunk, we thus ne
![Chunking](assets/chunking.png)


# Sparse Mixture of Experts (SMoE)

Sparse Mixture of Experts decouples throughput from memory cost by activating only a subset of the overall model for each token. In this approach, each token is assigned to one or more "experts" -- a separate set of weights -- and is processed only by those experts. This division happens at the feed-forward layers of the model. The experts specialize in different aspects of the data, allowing them to capture complex patterns and make more accurate predictions.

![SMoE](assets/smoe.png)
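
For intuition only, here is a minimal sketch of such a gated feed-forward block. It is not the implementation added in this commit; names such as `SparseMoE`, `num_experts`, and `top_k` are illustrative assumptions.

```
# Illustrative sketch only (not this repository's code): a top-k gated
# feed-forward block where each token is processed only by its selected experts.
import torch
import torch.nn as nn


class SparseMoE(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, num_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.gate = nn.Linear(dim, num_experts, bias=False)  # learned router
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, dim))
            for _ in range(num_experts)
        )
        self.top_k = top_k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (num_tokens, dim)
        scores = self.gate(x)                               # (num_tokens, num_experts)
        weights, selected = torch.topk(scores, self.top_k)  # per-token expert choice
        weights = torch.softmax(weights, dim=-1)
        out = torch.zeros_like(x)
        for e, expert in enumerate(self.experts):
            tokens, slot = torch.where(selected == e)       # tokens routed to expert e
            if tokens.numel() > 0:
                out[tokens] += weights[tokens, slot, None] * expert(x[tokens])
        return out
```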

## Pipeline Parallelism

Pipeline parallelism is a set of techniques for partitioning a model, enabling the distribution of a large model across multiple GPUs. We provide a simple implementation of pipeline parallelism, which allows our larger models to be executed within the memory constraints of modern GPUs. Note that this implementation favours simplicity over throughput efficiency, and most notably does not include microbatching. A rough sketch of the idea follows.
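
The sketch below is illustrative only and makes simplifying assumptions (a single forward pass, pre-allocated receive buffers, no microbatching); it shows the general idea rather than the code added in this commit.

```
# Illustrative sketch only: each pipeline rank owns a contiguous slice of the
# layers, receives activations from the previous rank, and sends its output on.
import torch
import torch.distributed as dist


def pipeline_forward(x: torch.Tensor, local_layers, rank: int, world_size: int) -> torch.Tensor:
    # On ranks > 0, `x` is assumed to be a pre-allocated buffer of the right
    # shape and dtype, which the previous stage fills in.
    if rank > 0:
        dist.recv(x, src=rank - 1)
    for layer in local_layers:
        x = layer(x)
    if rank < world_size - 1:
        dist.send(x, dst=rank + 1)
    # Only the last rank holds the final activations.
    return x
```

Without microbatching, only one rank is busy at any given time, which is why this approach favours simplicity over throughput.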


## Integrations and related projects


Binary file added assets/smoe.png
42 changes: 33 additions & 9 deletions main.py
@@ -1,4 +1,5 @@
from mistral.cache import RotatingBufferCache
import logging
import torch
import fire
from typing import List
@@ -31,7 +32,7 @@ def sample(logits: torch.Tensor, temperature: float, top_p: float):


@torch.inference_mode()
def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, *, max_tokens: int, chunk_size: int = None, temperature: float = 0.7):
def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, *, max_tokens: int, temperature: float, chunk_size: int = None):
model = model.eval()
B, V = len(prompts), model.args.vocab_size

@@ -40,8 +41,16 @@ def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, *, ma
seqlens = [len(x) for x in encoded_prompts]

# Cache
cache_window = min(model.args.sliding_window, max(seqlens) + max_tokens)
cache = RotatingBufferCache(model.args.n_layers, model.args.max_batch_size, cache_window, model.args.n_kv_heads, model.args.head_dim)
cache_window = max(seqlens) + max_tokens
if model.args.sliding_window is not None and cache_window > model.args.sliding_window:
cache_window = model.args.sliding_window
cache = RotatingBufferCache(
model.n_local_layers,
model.args.max_batch_size,
cache_window,
model.args.n_kv_heads,
model.args.head_dim,
)
cache.to(device=model.device, dtype=model.dtype)
cache.reset()

@@ -81,6 +90,7 @@ def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, *, ma

# decode
generated_tokens = []
assert last_token_prelogits is not None
for i_token in range(max_tokens):
next_token = sample(last_token_prelogits, temperature=temperature, top_p=0.8)

@@ -117,26 +127,40 @@ def interactive(model_path: str, max_tokens: int = 35, temperature: float = 0.7)
print(res[0])
print("=====================")

def demo(model_path: str, max_tokens: int = 35, temperature: float = 0):

def demo(
model_path: str, max_tokens: int = 35, temperature: float = 0, num_pipeline_ranks=1
):
if num_pipeline_ranks > 1:
torch.distributed.init_process_group()
torch.cuda.set_device(torch.distributed.get_rank())
should_print = torch.distributed.get_rank() == 0
else:
should_print = True
tokenizer = Tokenizer(str(Path(model_path) / "tokenizer.model"))
transformer = Transformer.from_folder(Path(model_path), max_batch_size=3)
transformer = Transformer.from_folder(
Path(model_path), max_batch_size=3, num_pipeline_ranks=num_pipeline_ranks
)

res, _logprobs = generate(
[
"This is a test",
"This is another test",
"This is another great test",
"This is a third test, mistral AI is very good at testing. ",
],
transformer,
tokenizer,
max_tokens=max_tokens,
temperature=temperature,
)
for x in res:
print(x)
print("=====================")
if should_print:
for x,l in zip(res, _logprobs):
print(x)
logging.debug('Logprobs: %s',l)
print("=====================")

if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
fire.Fire({
"interactive": interactive,
"demo": demo,
