Skip to content

Commit

Permalink
Add comments on RoPE initialization (vllm-project#1176)
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon authored Sep 26, 2023
1 parent a425bd9 commit 03ffd0a
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion vllm/model_executor/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,15 @@ def __init__(
self.is_neox_style = is_neox_style

# Create the cos and sin cache.
# NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
# However, we use `torch.arange(..., dtype=torch.float)` instead to
# avoid numerical issues with large base values (e.g., 10000000).
# This may cause a slight numerical difference between the HF
# implementation and ours.
# NOTE(woosuk): To exactly match the HF implementation, we need to
# use CPU to compute the cache and then move it to GPU. However, we
# create the cache on GPU for faster initialization. This may cause
# a slight numerical difference between the HF implementation and ours.
inv_freq = 1.0 / (base**(torch.arange(
0, rotary_dim, 2, dtype=torch.float, device="cuda") / rotary_dim))
t = torch.arange(max_position, dtype=torch.float, device="cuda")
Expand All @@ -274,7 +283,6 @@ def __init__(

# FIXME(woosuk): This assumes that we configure the default dtype when
# initializing the model.
# TODO(woosuk): Make it more robust.
torch_dtype = torch.get_default_dtype()
cache = cache.to(torch_dtype)
# Embedding size: [max_position, rotary_dim]
Expand Down

0 comments on commit 03ffd0a

Please sign in to comment.