Skip to content

Commit

Permalink
Rollback CUDA graphs
Browse files Browse the repository at this point in the history
  • Loading branch information
mryab committed Sep 4, 2023
1 parent b941df5 commit 177669e
Showing 1 changed file with 13 additions and 134 deletions.
147 changes: 13 additions & 134 deletions src/petals/models/falcon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
See commit history for authorship.
"""
import math
from functools import partial
from typing import Optional, Tuple

import torch
Expand All @@ -17,25 +16,15 @@
FalconLinear,
FalconMLP,
FalconModel,
FalconRotaryEmbedding,
LayerNorm,
build_alibi_tensor,
dropout_add,
rotate_half,
)

KVCache = Tuple[torch.Tensor, torch.Tensor]
INFERENCE_MAX_LENGTH = 8192

# @torch.jit.script
def rotate_half(x):
x1, x2 = torch.chunk(x, 2, dim=2)
return torch.cat((-x2, x1), dim=-1)


# @torch.jit.script
def apply_rotary(query, key, cos, sin):
return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)


class OptimizedFalconRotaryEmbedding(nn.Module):
def __init__(self, head_dim: int, base=10000):
Expand All @@ -45,55 +34,24 @@ def __init__(self, head_dim: int, base=10000):
self.head_dim = head_dim
self.seq_len_cached = -1

self.cuda_graph = None
self.input_surface = None
self.static_outputs = None

self.cos_sin(
seq_len=INFERENCE_MAX_LENGTH,
past_key_values_length=0,
device=self.inv_freq.device,
dtype=torch.get_default_dtype(),
)

def _optimized_apply_rotary(self, query, key, cos, sin):
if self.cuda_graph is None:
self.cuda_graph = torch.cuda.CUDAGraph()
self.input_surface = (query, key, cos, sin)

s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
apply_rotary(*self.input_surface)
torch.cuda.current_stream().wait_stream(s)

with torch.cuda.graph(self.cuda_graph):
self.static_outputs = apply_rotary(*self.input_surface)

inputs = (query, key, cos, sin)
for static_input, data in zip(self.input_surface, inputs):
static_input.copy_(data)
self.cuda_graph.replay()
return tuple(o.detach() for o in self.static_outputs)

def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
total_length = seq_len + past_key_values_length
if self.seq_len_cached == -1:
# warm up the cache
total_length = max(INFERENCE_MAX_LENGTH, total_length)

if total_length > self.seq_len_cached:
self.seq_len_cached = total_length
t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(device)
with torch.inference_mode(False):
self.seq_len_cached = total_length
t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(device)

if dtype in [torch.float16, torch.bfloat16]:
emb = emb.float()
if dtype in [torch.float16, torch.bfloat16]:
emb = emb.float()

self.register_buffer("cos_cached", emb.cos()[None, :, :].type(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, :, :].type(dtype), persistent=False)
self.register_buffer("cos_cached", emb.cos()[None, :, :].type(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, :, :].type(dtype), persistent=False)

return (
self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype),
Expand All @@ -103,25 +61,7 @@ def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype
def forward(self, query, key, past_key_values_length=0):
batch, seq_len, head_dim = query.shape
cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
# print(cos, sin)
# if seq_len == 1 and torch.is_inference_mode_enabled():
# return self._optimized_apply_rotary(query, key, cos, sin)
# else:
return apply_rotary(query, key, cos, sin)


# @torch.jit.script
def split_heads(
fused_qkv: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
batch, seq_len, _ = fused_qkv.shape
qkv = fused_qkv.view(batch, seq_len, -1, num_heads // num_kv_heads + 2, head_dim)
query, key, value = torch.split(qkv, [num_heads // num_kv_heads, 1, 1], dim=3)
key = torch.broadcast_to(key, query.shape)
value = torch.broadcast_to(value, query.shape)

query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
return query, key, value
return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)


class OptimizedFalconAttention(FalconAttention):
Expand Down Expand Up @@ -158,35 +98,6 @@ def __init__(self, config: FalconConfig):
self.attention_dropout = nn.Dropout(config.attention_dropout)
self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1

if self.new_decoder_architecture:
self._split_heads = partial(
split_heads, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim
)
self.qkv_graph = None
self.input_surface = None
self.static_outputs = None

def _optimized_apply_qkv(self, hidden_states):
if self.qkv_graph is None:
self.qkv_graph = torch.cuda.CUDAGraph()
self.input_surface = torch.randn_like(hidden_states)

s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
fused_qkv = self.query_key_value(self.input_surface)
self._split_heads(fused_qkv)
torch.cuda.current_stream().wait_stream(s)

with torch.cuda.graph(self.qkv_graph):
static_fused_qkv = self.query_key_value(self.input_surface)
self.static_outputs = self._split_heads(static_fused_qkv)

self.input_surface.copy_(hidden_states)
self.qkv_graph.replay()
return tuple(o.detach() for o in self.static_outputs)

def forward(
self,
hidden_states: torch.Tensor,
Expand All @@ -199,9 +110,6 @@ def forward(
):
assert not output_attentions

# if self.new_decoder_architecture and hidden_states.size(1) == 1 and torch.is_inference_mode_enabled():
# query_layer, key_layer, value_layer = self._optimized_apply_qkv(hidden_states)
# else:
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
Expand Down Expand Up @@ -320,32 +228,6 @@ def __init__(self, config: FalconConfig):
self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

self.ln_graph = None
self.static_input = None
self.static_outputs = None

def _optimized_apply_ln(self, hidden_states):
if self.ln_graph is None:
self.ln_graph = torch.cuda.CUDAGraph()
self.static_input = hidden_states

s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
self.ln_attn(hidden_states)
self.ln_mlp(hidden_states)
torch.cuda.current_stream().wait_stream(s)

with torch.cuda.graph(self.ln_graph):
ln_attn_output = self.ln_attn(hidden_states)
ln_mlp_output = self.ln_mlp(hidden_states)
self.static_outputs = (ln_attn_output, ln_mlp_output)

self.static_input.copy_(hidden_states)
self.ln_graph.replay()
return tuple(o.detach() for o in self.static_outputs)

def forward(
self,
hidden_states: torch.Tensor,
Expand All @@ -359,11 +241,8 @@ def forward(
residual = hidden_states

if self.config.new_decoder_architecture:
if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled():
attention_layernorm_out, mlp_layernorm_out = self._optimized_apply_ln(hidden_states)
else:
attention_layernorm_out = self.ln_attn(hidden_states)
mlp_layernorm_out = self.ln_mlp(hidden_states)
attention_layernorm_out = self.ln_attn(hidden_states)
mlp_layernorm_out = self.ln_mlp(hidden_states)
else:
attention_layernorm_out = self.input_layernorm(hidden_states)

Expand Down

0 comments on commit 177669e

Please sign in to comment.