support inference with multi-GPUs
hiyouga committed May 18, 2023
1 parent ecf5760 commit cfc024a
Showing 10 changed files with 76 additions and 33 deletions.
13 changes: 0 additions & 13 deletions README.md
@@ -204,19 +204,6 @@ CUDA_VISIBLE_DEVICES=0 python src/web_demo.py \
     --checkpoint_dir path_to_checkpoint
 ```
 
-### Deploy the Fine-tuned Model
-
-```python
-import sys
-sys.path.append("src")
-from src import load_pretrained, ModelArguments
-model_args = ModelArguments(checkpoint_dir=path_to_checkpoint)
-model, tokenizer = load_pretrained(model_args)
-model = model.cuda()
-model.eval()
-# model.generate, model.chat()...
-```
-
 ### Hardware Requirements
 
 | Fine-tune method | Batch size | Mode | GRAM | Speed |
13 changes: 0 additions & 13 deletions README_zh.md
@@ -209,19 +209,6 @@ CUDA_VISIBLE_DEVICES=0 python src/web_demo.py \
     --checkpoint_dir path_to_checkpoint
 ```
 
-### Model Deployment
-
-```python
-import sys
-sys.path.append("src")
-from src import load_pretrained, ModelArguments
-model_args = ModelArguments(checkpoint_dir=path_to_checkpoint)
-model, tokenizer = load_pretrained(model_args)
-model = model.cuda()
-model.eval()
-# model.generate, model.chat()...
-```
-
 ### Hardware Requirements
 
 | Fine-tuning method | Batch size | Mode | GPU memory | Speed |
19 changes: 19 additions & 0 deletions examples/deploy_model.py
@@ -0,0 +1,19 @@
+# coding=utf-8
+
+import sys
+sys.path.append("../src")
+import torch
+from src import ModelArguments, auto_configure_device_map, load_pretrained
+
+if __name__ == "__main__":
+    model_args = ModelArguments(checkpoint_dir="path_to_lora_checkpoint")
+    model, tokenizer = load_pretrained(model_args)
+    if torch.cuda.device_count() > 1:
+        from accelerate import dispatch_model
+        device_map = auto_configure_device_map(torch.cuda.device_count())
+        model = dispatch_model(model, device_map)
+    else:
+        model = model.cuda()
+    model.eval()
+    response, history = model.chat(tokenizer, "你好", history=[])
+    print(response)
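A usage note (an assumption, not stated in the commit): because the script branches on `torch.cuda.device_count()`, the number of GPUs it dispatches across can be narrowed with `CUDA_VISIBLE_DEVICES`, provided the variable is set before torch initializes CUDA. A minimal sketch:

```python
# Sketch (assumption): restrict visible GPUs before torch touches CUDA,
# so torch.cuda.device_count() reflects the restriction.
import os
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1")

import torch

if torch.cuda.device_count() > 1:
    print("would dispatch the model across", torch.cuda.device_count(), "GPUs")
else:
    print("would fall back to a single GPU via model.cuda()")
```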
13 changes: 13 additions & 0 deletions examples/export_lora_model.py
@@ -0,0 +1,13 @@
+# coding=utf-8
+
+import sys
+sys.path.append("../src")
+from src import ModelArguments, load_pretrained
+
+if __name__ == "__main__":
+    model_args = ModelArguments(checkpoint_dir="path_to_lora_checkpoint")
+    model, tokenizer = load_pretrained(model_args)
+    model = model.get_base_model()
+    model._keys_to_ignore_on_save = "lora"
+    model.save_pretrained("path_to_save_model", max_shard_size="1GB")
+    tokenizer.save_pretrained("path_to_save_model")
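The exported directory is a plain transformers checkpoint, so it can be loaded back without this repository's code. A minimal sketch, assuming the base model is ChatGLM-6B (its custom modeling code requires `trust_remote_code=True`) and reusing the `path_to_save_model` placeholder from the script above:

```python
# Sketch: reload the exported checkpoint with vanilla transformers.
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path_to_save_model", trust_remote_code=True)
model = AutoModel.from_pretrained("path_to_save_model", trust_remote_code=True).half().cuda()
model.eval()
response, _ = model.chat(tokenizer, "Hello", history=[])
print(response)
```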
1 change: 1 addition & 0 deletions src/__init__.py
@@ -1,4 +1,5 @@
 from .utils import (
+    auto_configure_device_map,
     load_pretrained,
     ModelArguments
 )
10 changes: 8 additions & 2 deletions src/infer.py
@@ -4,10 +4,11 @@
 
 
 import os
+import torch
 import signal
 import platform
 
-from utils import ModelArguments, load_pretrained
+from utils import ModelArguments, auto_configure_device_map, load_pretrained
 from transformers import HfArgumentParser


@@ -36,7 +37,12 @@ def main():
     parser = HfArgumentParser(ModelArguments)
     model_args, = parser.parse_args_into_dataclasses()
     model, tokenizer = load_pretrained(model_args)
-    model = model.cuda()
+    if torch.cuda.device_count() > 1:
+        from accelerate import dispatch_model
+        device_map = auto_configure_device_map(torch.cuda.device_count())
+        model = dispatch_model(model, device_map)
+    else:
+        model = model.cuda()
     model.eval()
 
     history = []
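For reference, `accelerate.dispatch_model` places each named submodule on the device given in the map and installs hooks that move activations across device boundaries at call time. A toy sketch with a hypothetical two-layer module (not from this repo, needs two visible GPUs):

```python
# Toy illustration of dispatch_model with an explicit device_map.
import torch
from accelerate import dispatch_model

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
device_map = {"0": 0, "1": 1}  # keys are names from model.named_children()
model = dispatch_model(model, device_map=device_map)

out = model(torch.randn(1, 8))  # hooks move the input to GPU 0, then GPU 1
print(out.device)               # cuda:1
```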
2 changes: 1 addition & 1 deletion src/utils/__init__.py
@@ -12,4 +12,4 @@
 from .ppo import PPOTrainerForChatGLM
 
 from .config import ModelArguments
-from .other import get_logits_processor, plot_loss
+from .other import auto_configure_device_map, get_logits_processor, plot_loss
23 changes: 23 additions & 0 deletions src/utils/other.py
@@ -162,6 +162,29 @@ def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -
     model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
 
 
+def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
+    r"""
+    Configures device map for ChatGLM.
+
+    Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/dev_multi_gpu/utils.py#L8
+    """
+    num_layers = 28
+    layers_per_gpu = 30 / num_gpus
+    device_map = {"transformer.word_embeddings": 0, "transformer.final_layernorm": 0, "lm_head": 0}
+    added_layers = 2
+    target_gpu = 0
+
+    for i in range(num_layers):
+        if added_layers >= layers_per_gpu:
+            target_gpu += 1
+            added_layers = 0
+        assert target_gpu < num_gpus
+        device_map[f"transformer.layers.{i}"] = target_gpu
+        added_layers += 1
+
+    return device_map
+
+
 def smooth(scalars: List[float], weight: Optional[float] = 0.95) -> List[float]:
     """
     EMA implementation according to TensorBoard.
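Tracing the loop by hand for `num_gpus=2`: `layers_per_gpu` is 15, and the embeddings, final layernorm, and LM head are pinned to GPU 0 and pre-counted as two of the 30 budget slots (`added_layers = 2`), so the boundary falls after layer 12 — GPU 0 hosts `transformer.layers.0` through `.12` plus those three modules, and GPU 1 hosts layers 13 through 27. A quick check of that expectation (assumes the repository root is the working directory, as in the deleted README snippet):

```python
# Hand-derived expectation for a 2-GPU split of ChatGLM's 28 layers.
import sys
sys.path.append("src")
from src import auto_configure_device_map

device_map = auto_configure_device_map(2)
assert device_map["transformer.word_embeddings"] == 0  # pinned to GPU 0
assert device_map["transformer.layers.12"] == 0        # last block on GPU 0
assert device_map["transformer.layers.13"] == 1        # boundary crosses here
assert device_map["transformer.layers.27"] == 1
```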
3 changes: 2 additions & 1 deletion src/utils/ppo.py
@@ -108,7 +108,7 @@ def ppo_train(self, max_target_length: int) -> None:
         loss_meter = AverageMeter()
         reward_meter = AverageMeter()
 
-        for step in tqdm(range(max_steps)):
+        for step in tqdm(range(max_steps), disable=not self.is_world_process_zero()):
 
             for _ in range(self.config.gradient_accumulation_steps):
 
@@ -289,6 +289,7 @@ def save_model(self, output_dir: Optional[str] = None) -> None:
         output_dir = output_dir if output_dir is not None else self.training_args.output_dir
         os.makedirs(output_dir, exist_ok=True)
         logger.info(f"Saving model checkpoint to {output_dir}")
+        self.accelerator.wait_for_everyone()
         save_trainable_params(output_dir, self.model)
         torch.save(self.training_args, os.path.join(output_dir, TRAINING_ARGS_NAME))
         torch.save(self.finetuning_args, os.path.join(output_dir, FINETUNING_ARGS_NAME))
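Both changes follow the usual accelerate multi-process idioms: progress bars are silenced on every rank except the main one, and all ranks synchronize before anything is written to disk so no process saves a half-finished state. A minimal standalone sketch of the same pattern, using accelerate's own `is_main_process` in place of the trainer's `is_world_process_zero()` (hypothetical script, not from this repo):

```python
# Sketch of rank-aware logging and saving with accelerate.
from accelerate import Accelerator
from tqdm import tqdm

accelerator = Accelerator()

for step in tqdm(range(100), disable=not accelerator.is_main_process):
    pass  # one training step per iteration

accelerator.wait_for_everyone()   # all ranks reach this point before saving
if accelerator.is_main_process:
    print("rank 0 writes the checkpoint here")
```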
12 changes: 9 additions & 3 deletions src/web_demo.py
@@ -3,17 +3,23 @@
 # This code is largely borrowed from https://github.com/THUDM/ChatGLM-6B/blob/main/web_demo.py
 
 
-import gradio as gr
+import torch
 import mdtex2html
+import gradio as gr
 
-from utils import ModelArguments, load_pretrained
+from utils import ModelArguments, auto_configure_device_map, load_pretrained
 from transformers import HfArgumentParser
 
 
 parser = HfArgumentParser(ModelArguments)
 model_args, = parser.parse_args_into_dataclasses()
 model, tokenizer = load_pretrained(model_args)
-model = model.cuda()
+if torch.cuda.device_count() > 1:
+    from accelerate import dispatch_model
+    device_map = auto_configure_device_map(torch.cuda.device_count())
+    model = dispatch_model(model, device_map)
+else:
+    model = model.cuda()
 model.eval()


