Relative paths and model configuration (Lightning-AI#28)
Co-authored-by: Luca Antiga <luca@lightning.ai>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
3 people authored Mar 27, 2023
1 parent f51c1a7 commit 17d7f7e
Showing 7 changed files with 37 additions and 15 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -5,6 +5,7 @@ __pycache__
 
 # data
 data
+checkpoints
 !data/shakespeare/prepare.py
 
 # downloaded by scripts/compare.py
5 changes: 1 addition & 4 deletions README.md
@@ -74,10 +74,7 @@ python scripts/convert_checkpoint.py \
 You can now run inference:
 
 ```bash
-python scripts/generate.py \
-    --prompt "Hello, my name is" \
-    --checkpoint_path checkpoints/lit-llama/7B/state_dict.pt \
-    --tokenizer_path checkpoints/lit-llama/tokenizer.model
+python scripts/generate.py --prompt "Hello, my name is"
 ```
 
 This will run using the 7B model and will require roughly 26 GB of GPU memory (A100 GPU).
18 changes: 13 additions & 5 deletions generate.py
@@ -1,11 +1,13 @@
 import os
 import sys
 import time
-import torch
+from typing import Optional
 
 import lightning as L
+import torch
 
-from model import LLaMA, LLaMAConfig
+from model import LLaMA
 from quantization.bnb import quantize as quantize_model
 from tokenizer import Tokenizer

@@ -66,8 +68,9 @@ def main(
     # compilation fails as it does not support torch.complex64 for RoPE
     # compile: bool = False,
     accelerator: str = "auto",
-    checkpoint_path: str = "/srv/data/checkpoints/llama/converted_nano/7B/state_dict.pth",
-    tokenizer_path: str = "/srv/data/checkpoints/llama/converted_nano/tokenizer.model",
+    checkpoint_path: Optional[str] = None,
+    tokenizer_path: Optional[str] = None,
+    model_size: str = "7B",
     quantize: bool = False,
 ):
     """Generates text samples based on a pre-trained LLaMA model and tokenizer.
@@ -86,6 +89,11 @@ def main(
         tokenizer_path: The tokenizer path to load.
         quantize: Whether to quantize the model using the `LLM.int8()` method
     """
+    if not checkpoint_path:
+        checkpoint_path = f"./checkpoints/lit-llama/{model_size}/state_dict.pth"
+    if not tokenizer_path:
+        tokenizer_path = "./checkpoints/lit-llama/tokenizer.model"
+
     assert os.path.isfile(checkpoint_path)
     assert os.path.isfile(tokenizer_path)

@@ -94,14 +102,14 @@
     if quantize:
         print("Running quantization. This may take a minute ...")
         # TODO: Initializing the model directly on the device does not work with quantization
-        model = LLaMA(LLaMAConfig())
+        model = LLaMA.from_name(model_size)
         # The output layer can be sensitive to quantization, we keep it in default precision
         model = quantize_model(model, skip=("lm_head", "output"))
         checkpoint = torch.load(checkpoint_path)
         model.load_state_dict(checkpoint)
     else:
         with fabric.device:
-            model = LLaMA(LLaMAConfig())
+            model = LLaMA.from_name(model_size)
         checkpoint = torch.load(checkpoint_path)
         model.load_state_dict(checkpoint)

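As an illustration of the `generate.py` change above, here is a minimal standalone sketch of how the new defaults resolve when no explicit paths are passed on the command line; the values mirror the hunk, and the `13B` size is chosen only for the example:

```python
# Sketch of the default-path logic added to main(); model_size here is arbitrary.
model_size = "13B"
checkpoint_path = None  # i.e. --checkpoint_path was not given
tokenizer_path = None   # i.e. --tokenizer_path was not given

if not checkpoint_path:
    checkpoint_path = f"./checkpoints/lit-llama/{model_size}/state_dict.pth"
if not tokenizer_path:
    tokenizer_path = "./checkpoints/lit-llama/tokenizer.model"

print(checkpoint_path)  # ./checkpoints/lit-llama/13B/state_dict.pth
print(tokenizer_path)   # ./checkpoints/lit-llama/tokenizer.model
```

The derived checkpoint path follows `model_size`, so switching sizes no longer requires editing any hard-coded absolute path.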
18 changes: 17 additions & 1 deletion model.py
@@ -140,14 +140,26 @@ def forward(self, x):
         return x
 
 
+llama_configs = {
+    "7B": dict(n_layer=32, n_head=32, n_embd=4096),
+    "13B": dict(n_layer=40, n_head=40, n_embd=5120),
+    "30B": dict(n_layer=60, n_head=52, n_embd=6656),
+    "65B": dict(n_layer=80, n_head=64, n_embd=8192),
+}
+
+
 @dataclass
 class LLaMAConfig:
-    block_size: int = 4096  # 7B
+    block_size: int = 4096
     vocab_size: int = 32000
     n_layer: int = 32
     n_head: int = 32
     n_embd: int = 4096
 
+    @classmethod
+    def from_name(cls, name: str):
+        return cls(**llama_configs[name])
+
 
 class LLaMA(nn.Module):
     def __init__(self, config):
@@ -200,3 +212,7 @@ def step(self, idx, targets):
         logits = self(idx)
         loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
         return loss
+
+    @classmethod
+    def from_name(cls, name: str):
+        return cls(LLaMAConfig.from_name(name))
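As a usage sketch (assuming `model.py` is importable from the repository root), the new classmethods let callers select a configuration by size name instead of hard-coding the 7B hyperparameters:

```python
from model import LLaMAConfig

# Look up the hyperparameters for a named model size from llama_configs.
config = LLaMAConfig.from_name("13B")
print(config.n_layer, config.n_head, config.n_embd)  # 40 40 5120

# The companion LLaMA.from_name classmethod wraps the same lookup and builds the
# module directly, e.g. `model = LLaMA.from_name("7B")` (this materializes the full
# parameter set, so it needs a large amount of memory).
```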
6 changes: 3 additions & 3 deletions scripts/convert_checkpoint.py
@@ -55,9 +55,9 @@ def convert_state_dict(state_dict):
 
 def meta_weights_for_nano_model(
     *,
-    output_dir: Path,
-    ckpt_dir: Path = Path("/srv/data/checkpoints/llama/raw"),
-    tokenizer_path: Path = Path("/srv/data/checkpoints/llama/raw/tokenizer.model"),
+    output_dir: Path = Path("checkpoints/lit-llama"),
+    ckpt_dir: Path = Path("checkpoints/llama/"),
+    tokenizer_path: Path = Path("checkpoints/llama/tokenizer.model"),
     model_size: str = "7B",
 ):
     output_dir = output_dir / model_size
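Taken together with the `generate.py` defaults above, the new relative paths imply a checkpoint layout under the repository root roughly like the sketch below. The exact converted filename (`state_dict.pth`) is inferred from `generate.py`'s default and is not shown in this file's hunk, so treat it as an assumption:

```python
from pathlib import Path

# Raw Meta weights and tokenizer, where scripts/convert_checkpoint.py looks by default:
raw_ckpt_dir = Path("checkpoints/llama/")
raw_tokenizer = Path("checkpoints/llama/tokenizer.model")

# Converted output goes to output_dir / model_size, which is what generate.py's
# default checkpoint path points at:
converted_dir = Path("checkpoints/lit-llama") / "7B"
converted_ckpt = converted_dir / "state_dict.pth"  # assumed filename, per generate.py

for p in (raw_ckpt_dir, raw_tokenizer, converted_dir, converted_ckpt):
    print(p)
```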
2 changes: 1 addition & 1 deletion scripts/prepare_shakespeare.py
@@ -28,7 +28,7 @@
 
 
 def prepare(
-    tokenizer_path: str = "/srv/data/checkpoints/llama/converted_meta/tokenizer.model",
+    tokenizer_path: str = "checkpoints/llama/tokenizer.model",
     destination_path: str = "data/shakespeare",
 ):
     os.makedirs(destination_path, exist_ok=True)
2 changes: 1 addition & 1 deletion train.py
@@ -44,7 +44,7 @@ def main():
 
     train_data, val_data = load_datasets()
 
-    config = LLaMAConfig
+    config = LLaMAConfig.from_name("7B")
     config.block_size = block_size
 
     with fabric.device:
