feat(lmdeploy): fix review
tpoisonooo committed Nov 27, 2023
1 parent 54fbc0c commit 0f47007
Showing 7 changed files with 58 additions and 45 deletions.
3 changes: 2 additions & 1 deletion lmdeploy/pytorch_poc/config.py
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass
from dataclasses import dataclass, field


@dataclass
@@ -32,6 +32,7 @@ class ModelConfig:
eos_token_id: int
dtype: str
multi_query_attention: bool = False
json_config: dict = field(default_factory=dict)

def get_head_size(self):
return self.hidden_size // self.num_heads
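The raw Hugging Face config dict now lives on `ModelConfig` as `json_config`, so it no longer has to be threaded through the engine and cache layers as a separate argument. Below is a minimal sketch (a stripped-down stand-in, not the real class's full field list) of why `field(default_factory=dict)` is used for the new attribute:

```python
from dataclasses import dataclass, field

@dataclass
class ModelConfigSketch:
    hidden_size: int
    num_layers: int
    num_heads: int
    # A plain `json_config: dict = {}` default is rejected by dataclasses
    # (mutable default) and would be shared across instances; a factory
    # gives every config object its own dict.
    json_config: dict = field(default_factory=dict)

cfg = ModelConfigSketch(hidden_size=4096, num_layers=32, num_heads=32)
cfg.json_config['rerope'] = True  # only this instance is affected
```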
6 changes: 3 additions & 3 deletions lmdeploy/pytorch_poc/engine/cache_engine.py
@@ -29,7 +29,6 @@ def __init__(
self,
cache_config: CacheConfig,
model_config: ModelConfig,
json_config: dict,
rank: int = 0,
world_size: int = 1,
device_mesh: DeviceMesh = None,
@@ -51,8 +50,9 @@ def __init__(
self.num_layers = model_config.num_layers
self.num_heads = model_config.num_heads

if 'kv_cache_dtype' in json_config:
self.kv_cache_dtype = eval(json_config['kv_cache_dtype'])
if 'kv_cache_dtype' in model_config.json_config:
self.kv_cache_dtype = eval(
model_config.json_config['kv_cache_dtype'])
else:
self.kv_cache_dtype = model_config.dtype

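With `json_config` attached to `ModelConfig`, `CacheEngine` drops its separate `json_config` parameter and resolves the KV-cache dtype from `model_config.json_config['kv_cache_dtype']` (a string such as `'torch.int8'`, judging by the int8 path in `quant_kv`), falling back to `model_config.dtype`. A hedged sketch of that lookup follows; the diff uses `eval` on the config string, while the small whitelist shown here is one way to avoid evaluating arbitrary strings — the exact set of supported dtypes is an assumption:

```python
import torch

def resolve_kv_cache_dtype(json_config: dict, default_dtype: torch.dtype) -> torch.dtype:
    """Pick the KV-cache dtype from the raw HF config dict, else use the model dtype."""
    raw = json_config.get('kv_cache_dtype')
    if raw is None:
        return default_dtype
    # The commit does `eval(raw)`; a lookup table sidesteps executing config strings.
    supported = {
        'torch.int8': torch.int8,
        'torch.float16': torch.float16,
        'torch.bfloat16': torch.bfloat16,
    }
    return supported.get(raw, default_dtype)
```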
42 changes: 18 additions & 24 deletions lmdeploy/pytorch_poc/engine/engine.py
@@ -84,48 +84,45 @@ def _build_model_config(model_path: str, hf_config: Any):
bos_token_id=hf_config.bos_token_id,
eos_token_id=hf_config.eos_token_id,
dtype=torch_dtype,
multi_query_attention=hf_config.multi_query)
multi_query_attention=hf_config.multi_query,
json_config=hf_config.to_dict())
elif 'chatglm' in model_path:
model_config = ModelConfig(
hf_config.hidden_size // hf_config.num_attention_heads *
hf_config.multi_query_group_num,
hf_config.num_layers,
hf_config.multi_query_group_num,
bos_token_id=hf_config.bos_token_id,
eos_token_id=hf_config.eos_token_id,
dtype=torch_dtype,
)
model_config = ModelConfig(hf_config.hidden_size //
hf_config.num_attention_heads *
hf_config.multi_query_group_num,
hf_config.num_layers,
hf_config.multi_query_group_num,
bos_token_id=hf_config.bos_token_id,
eos_token_id=hf_config.eos_token_id,
dtype=torch_dtype,
json_config=hf_config.to_dict())
else:
model_config = ModelConfig(
hf_config.hidden_size,
hf_config.num_hidden_layers,
hf_config.num_attention_heads,
bos_token_id=hf_config.bos_token_id,
eos_token_id=hf_config.eos_token_id,
dtype=torch_dtype,
)
model_config = ModelConfig(hf_config.hidden_size,
hf_config.num_hidden_layers,
hf_config.num_attention_heads,
bos_token_id=hf_config.bos_token_id,
eos_token_id=hf_config.eos_token_id,
dtype=torch_dtype,
json_config=hf_config.to_dict())

return model_config


def _build_model_agent(model_path: str,
model_config: ModelConfig,
cache_config: CacheConfig,
json_config: dict,
trust_remote_code: bool,
tp: int = 1):
"""create model agent."""
if tp == 1:
model_agent = BaseModelAgent(model_path,
model_config=model_config,
cache_config=cache_config,
json_config=json_config,
trust_remote_code=trust_remote_code)
else:
model_agent = TPModelAgent(model_path,
model_config=model_config,
cache_config=cache_config,
json_config=json_config,
world_size=tp,
trust_remote_code=trust_remote_code)
return model_agent
@@ -163,15 +160,12 @@ def __init__(self,
torch_dtype = _get_torch_dtype(hf_config)
self.torch_dtype = torch_dtype

self.json_config = hf_config.to_dict()

model_config = _build_model_config(model_path, hf_config)

self.model_agent = _build_model_agent(
model_path,
model_config=model_config,
cache_config=cache_config,
json_config=self.json_config,
trust_remote_code=trust_remote_code,
tp=tp)

18 changes: 5 additions & 13 deletions lmdeploy/pytorch_poc/engine/model_agent.py
@@ -221,11 +221,9 @@ def __init__(self,
model_path: str,
model_config: ModelConfig,
cache_config: CacheConfig,
json_config: dict,
trust_remote_code: bool = True):
self.model_config = model_config
self.cache_config = cache_config
self.json_config = json_config
torch_dtype = model_config.dtype

self.patched_model = self._build_model(
@@ -235,8 +233,7 @@ def __init__(self,

_update_cache_config(model_config, cache_config)

self.cache_engine = CacheEngine(cache_config, model_config,
json_config)
self.cache_engine = CacheEngine(cache_config, model_config)
self.stream = torch.cuda.Stream()

def _build_model(self,
@@ -271,7 +268,7 @@ def forward(self, inputs: ModelInputs, swap_in_map: Dict[int, int],
self.patched_model,
inputs,
self.cache_engine,
self.json_config,
self.model_config.json_config,
world_size=1,
stream=self.stream,
)
@@ -416,7 +413,6 @@ def _tp_model_loop(
model_path: str,
model_config: ModelConfig,
cache_config: CacheConfig,
json_config: dict,
in_que: mp.Queue,
out_que: mp.Queue,
world_size: int,
@@ -452,7 +448,7 @@ def _tp_model_loop(
patched_model,
inputs,
cache_engine,
json_config,
model_config.json_config,
world_size=world_size,
stream=stream,
)
@@ -508,29 +504,26 @@ def __init__(self,
model_path: str,
model_config: ModelConfig,
cache_config: CacheConfig,
json_config: dict,
world_size: int,
trust_remote_code: bool = True) -> None:
mp.set_start_method('spawn')
self.world_size = world_size
self.model_config = model_config
self.cache_config = cache_config
self.json_config = json_config
self.tp_model_in_que = mp.Queue(10)
self.tp_model_out_que = mp.Queue(10)

self.patch_model_tp(model_path,
model_config=model_config,
cache_config=cache_config,
json_config=json_config,
in_que=self.tp_model_in_que,
out_que=self.tp_model_out_que,
world_size=world_size,
trust_remote_code=trust_remote_code)

def patch_model_tp(self, model_path: str, model_config: ModelConfig,
cache_config: CacheConfig, json_config: dict,
in_que: mp.Queue, out_que: mp.Queue, world_size: int,
cache_config: CacheConfig, in_que: mp.Queue,
out_que: mp.Queue, world_size: int,
trust_remote_code: bool):
"""Start tensor parallel sub process.
@@ -553,7 +546,6 @@ def patch_model_tp(self, model_path: str, model_config: ModelConfig,
(model_path, ),
dict(model_config=model_config,
cache_config=cache_config,
json_config=json_config,
in_que=in_que,
out_que=out_que,
world_size=world_size,
23 changes: 23 additions & 0 deletions lmdeploy/pytorch_poc/models/functional.py
@@ -235,6 +235,13 @@ def attention_forward_with_paged_attention(


def quant_kv(key: torch.Tensor, value: torch.Tensor, out_type: torch.dtype):
"""Quantize key and value of attention to `out_type`.
Args:
key (torch.Tensor): Attention key.
value (torch.Tensor): Attention value.
out_type (torch.dtype): Output data type.
"""
assert out_type is torch.int8
# quantize key and value
_min = torch.min(key, axis=-1).values
@@ -263,6 +270,15 @@ def quant_kv(key: torch.Tensor, value: torch.Tensor, out_type: torch.dtype):

def dequant_kv(context: Any, layer_id: str, key_int8: torch.Tensor,
value_int8: torch.Tensor, out_type: torch.dtype):
"""Dequantize key and value of attention to `out_type`.
Args:
context (Any): StepContext during inference.
layer_id (str): Layer object id.
key (torch.Tensor): Quantized attention key.
value (torch.Tensor): Quantized attention value.
out_type (torch.dtype): output data type.
"""
qparams = context.get_output(layer_id)

key_scale = qparams['key_scale']
@@ -278,6 +294,13 @@ def dequant_kv(context: Any, layer_id: str, key_int8: torch.Tensor,


def sync_qparam_to_context(context: Any, layer_id: str, qparams: dict):
"""Merge quantization param to context.
Args:
context (Any): StepContext during inference.
layer_id (str): Layer object id.
qparams (dict): Quantization param of current step.
"""
if context.inputs.meta is not None:
last_qparam = context.inputs.meta[layer_id]
for _k in last_qparam.keys():
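The new docstrings cover the int8 KV-cache path: `quant_kv` asserts `out_type is torch.int8` and starts from a per-row minimum (`torch.min(key, axis=-1)`), which points at min-max (asymmetric) quantization along the head dimension, with scale and zero-point carried via `sync_qparam_to_context` and read back in `dequant_kv`. The sketch below shows that style of quantization round-trip; the exact scale/zero-point layout and naming inside lmdeploy are assumptions, since the full function bodies are not visible in this diff:

```python
import torch

def quant_int8_minmax(x: torch.Tensor):
    """Min-max int8 quantization along the last axis (hedged sketch)."""
    x_min = x.min(dim=-1, keepdim=True).values
    x_max = x.max(dim=-1, keepdim=True).values
    scale = (x_max - x_min).clamp(min=1e-5) / 255.0
    zero_point = x_min
    x_int8 = ((x - zero_point) / scale - 128).round().clamp(-128, 127).to(torch.int8)
    return x_int8, scale, zero_point

def dequant_int8_minmax(x_int8: torch.Tensor, scale: torch.Tensor,
                        zero_point: torch.Tensor, out_type=torch.float16):
    """Inverse of the sketch above: map int8 codes back to floating point."""
    return ((x_int8.to(torch.float32) + 128) * scale + zero_point).to(out_type)
```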
8 changes: 5 additions & 3 deletions lmdeploy/pytorch_poc/models/llama.py
@@ -66,11 +66,12 @@ def _contiguous_batching_forward_rerope_impl(
def apply_rotary_pos_emb_rerope(q, k, cos, sin, position_ids):
assert 1 == position_ids.shape[0]

_, __, seq_len, dim = cos.shape
_, seq_len = position_ids.shape
_, dim = cos.shape

cos = cos[0, 0][position_ids].reshape(
cos = cos[position_ids].reshape(
seq_len, 1, dim) # [bs, seq_len, dim] to [seq_len, 1, dim]
sin = sin[0, 0][position_ids].reshape(
sin = sin[position_ids].reshape(
seq_len, 1, dim) # [bs, seq_len, dim] to [seq_len, 1, dim]

q_embed = ((q * cos[-q.shape[0]:]) +
@@ -219,6 +220,7 @@ def _contiguous_batching_forward_impl(
support.
"""
assert not output_attentions

json_config = self.context.context.json_config
use_rerope = 'rerope' in json_config and json_config['rerope']
if use_rerope:
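The ReRoPE fix reflects `cos`/`sin` now arriving as 2-D `[max_seq_len, dim]` tensors rather than the older 4-D `[1, 1, seq_len, dim]` layout, so they are indexed directly with `position_ids` before the `[seq_len, 1, dim]` reshape. A small sketch of that gather; the shapes are taken from the comments in the diff and the tensor names are illustrative:

```python
import torch

def gather_rotary(cos: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
    # cos: [max_seq_len, dim]; position_ids: [1, seq_len] (batch size asserted to 1).
    _, seq_len = position_ids.shape
    _, dim = cos.shape
    # Fancy indexing yields [1, seq_len, dim]; reshape to [seq_len, 1, dim] so it
    # broadcasts against q/k (presumably [seq_len, num_heads, dim]).
    return cos[position_ids].reshape(seq_len, 1, dim)
```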
3 changes: 2 additions & 1 deletion lmdeploy/pytorch_poc/passkey_retrieval.py
@@ -92,7 +92,7 @@ def parse_config():
help='interval for evaluation')
parser.add_argument('--num_tests',
type=int,
default=30,
default=1,
help='number of repeat testing for each length')
args = parser.parse_args()
return args
@@ -136,6 +136,7 @@ def main(args):
# This is a rough ratio to control the number of texts and tokens
for val in range(4096, args.max_tokens, args.interval):
n_garbage = int(3.75 * val // 1024 * 1024)
assert n_garbage > 0
passed_tests = 0
total_tokens = 0

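The added `assert n_garbage > 0` documents that the garbage-text formula must produce something to bury the passkey in. `*` and `//` associate left to right, so the expression is `((3.75 * val) // 1024) * 1024`, which only floors to zero for `val` below 274 — impossible here since the loop starts at 4096. A worked example for the first iteration:

```python
val = 4096
# 3.75 * 4096 = 15360.0 -> // 1024 = 15.0 -> * 1024 = 15360.0
n_garbage = int(3.75 * val // 1024 * 1024)
assert n_garbage == 15360 and n_garbage > 0
```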
