Optimize the Falcon block for inference #500

Merged · 18 commits · Sep 4, 2023
Fix rotary embeddings
mryab committed Sep 4, 2023
commit d56f57acd2255b41a9d02eef4675c0122a85ab40
13 changes: 10 additions & 3 deletions in src/petals/models/falcon/block.py
@@ -43,6 +43,13 @@ def __init__(self, head_dim: int, base=10000):
         self.input_surface = None
         self.static_outputs = None
 
+        self.cos_sin(
+            seq_len=INFERENCE_MAX_LENGTH,
+            past_key_values_length=0,
+            device=self.inv_freq.device,
+            dtype=torch.get_default_dtype(),
+        )
+
     def _optimized_apply_rotary(self, query, key, cos, sin):
         if self.cuda_graph is None:
             self.cuda_graph = torch.cuda.CUDAGraph()
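For context on this hunk: the new call fills the cos/sin caches once at construction time, up to `INFERENCE_MAX_LENGTH`, which matters because `_optimized_apply_rotary` below captures a CUDA graph, and graph capture expects the tensors it touches to already exist at stable addresses. A minimal sketch of the pattern, assuming a guard that only rebuilds the cache when a request outgrows it; the standalone class and the `INFERENCE_MAX_LENGTH` value here are illustrative, not the actual Petals implementation:

```python
import torch

INFERENCE_MAX_LENGTH = 8192  # assumed value; the real constant lives elsewhere in petals

class RotaryEmbeddingSketch(torch.nn.Module):
    """Minimal sketch of the precompute-once pattern this hunk introduces."""

    def __init__(self, head_dim: int, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Warm the cache up front, mirroring the call added in this commit,
        # so later CUDA-graph capture sees already-allocated buffers.
        self.cos_sin(
            seq_len=INFERENCE_MAX_LENGTH,
            past_key_values_length=0,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16):
        total_length = seq_len + past_key_values_length
        # Rebuild the tables only when a request outgrows the cache
        # (assumption: the real method guards itself similarly).
        if getattr(self, "cos_cached", None) is None or self.cos_cached.shape[1] < total_length:
            t = torch.arange(total_length, device=device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).float()
            self.register_buffer("cos_cached", emb.cos()[None, :, :].type(dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, :, :].type(dtype), persistent=False)
        return (
            self.cos_cached[:, past_key_values_length:total_length].type(dtype),
            self.sin_cached[:, past_key_values_length:total_length].type(dtype),
        )
```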
@@ -80,11 +87,11 @@ def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype
             emb = emb.float()
 
         self.register_buffer("cos_cached", emb.cos()[None, :, :].type(dtype))
-        self.register_buffer("sin_cached", emb.cos()[None, :, :].type(dtype))
+        self.register_buffer("sin_cached", emb.sin()[None, :, :].type(dtype))
 
         return (
-            self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
-            self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length],
+            self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype),
+            self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype),
         )
 
     def forward(self, query, key, past_key_values_length=0):
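Two fixes land in this second hunk: `sin_cached` was previously filled from `emb.cos()` instead of `emb.sin()` (so both caches held cosines), and the returned slices now get an explicit `.type(dtype)` so callers receive a consistent dtype even if the buffers were cached earlier under a different one. A quick standalone check (hypothetical, not Petals code) of why the cos/sin mix-up breaks things: rotary embeddings rotate each channel pair by a position-dependent angle, and a genuine rotation preserves vector norms, while the buggy variant does not:

```python
import torch

def rotate_half(x):
    # Standard rotary helper: split the last dim in half, swap and negate.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

# Build the cos/sin tables the same way the diff does: duplicated frequency halves.
head_dim, seq_len = 64, 16
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.einsum("i,j->ij", torch.arange(seq_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()

q = torch.randn(seq_len, head_dim)
correct = q * cos + rotate_half(q) * sin  # genuine rotation
buggy = q * cos + rotate_half(q) * cos    # what the old sin_cached effectively produced

# A rotation preserves the norm of each position's vector; the buggy variant does not
# (e.g. at position 0, where cos=1 and sin=0, the buggy output is sqrt(2) times too long).
print(torch.allclose(correct.norm(dim=-1), q.norm(dim=-1), atol=1e-4))  # True
print(torch.allclose(buggy.norm(dim=-1), q.norm(dim=-1), atol=1e-4))    # False
```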