Skip to content

Commit

Permalink
Fix DeciLM (vllm-project#2883)
Browse files Browse the repository at this point in the history
  • Loading branch information
pcmoritz authored and jimpang committed Feb 20, 2024
1 parent 5d34065 commit 2e6db3e
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion vllm/model_executor/models/decilm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import torch
from transformers import PretrainedConfig

from vllm.config import LoRAConfig
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.weight_utils import (default_weight_loader,
Expand Down Expand Up @@ -56,10 +57,13 @@ def __init__(
self,
config: Optional[PretrainedConfig] = None,
linear_method: Optional[LinearMethodBase] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
delattr(config, "num_key_value_heads_per_layer")
super().__init__(config=config, linear_method=linear_method)
super().__init__(config=config,
linear_method=linear_method,
lora_config=lora_config)

def load_weights(self,
model_name_or_path: str,
Expand Down

0 comments on commit 2e6db3e

Please sign in to comment.