Make cos_cached/sin_cached buffers
mryab committed Sep 3, 2023
1 parent ae8ed93 commit 8060c55
Showing 1 changed file with 18 additions and 88 deletions.
106 changes: 18 additions & 88 deletions src/petals/models/falcon/block.py
@@ -16,10 +16,8 @@
    FalconDecoderLayer,
    FalconLinear,
    FalconMLP,
    FalconModel,
    FalconRotaryEmbedding,
    LayerNorm,
    build_alibi_tensor,
    dropout_add,
    rotate_half,
)
@@ -61,11 +59,26 @@ def _optimized_apply_rotary(self, query, key, cos, sin):
        return tuple(o.detach() for o in self.static_outputs)

    def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
        total_length = seq_len + past_key_values_length
        if self.seq_len_cached == -1:
            # warm up the cache
            super().cos_sin(1, INFERENCE_MAX_LENGTH - 1, device=device, dtype=dtype)
        return super().cos_sin(
            seq_len=seq_len, past_key_values_length=past_key_values_length, device=device, dtype=dtype
            total_length = max(INFERENCE_MAX_LENGTH, total_length)

        if total_length > self.seq_len_cached:
            self.seq_len_cached = total_length
            t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(device)

            if dtype in [torch.float16, torch.bfloat16]:
                emb = emb.float()

            self.register_buffer("cos_cached", emb.cos()[None, :, :].type(dtype))
            self.register_buffer("sin_cached", emb.sin()[None, :, :].type(dtype))

        return (
            self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
            self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length],
        )

    def forward(self, query, key, past_key_values_length=0):
@@ -372,86 +385,3 @@ def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
        state = state[:, :, 0]
        state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)
        return state


class UnoptimizedWrappedFalconBlock(FalconDecoderLayer):
    def forward(
        self,
        hidden_states: torch.Tensor,
        *args,
        attention_mask: Optional[torch.Tensor] = None,
        alibi: Optional[torch.Tensor] = None,
        layer_past: Optional[KVCache] = None,
        use_cache: bool = False,
        **kwargs,
    ):
        batch_size, seq_length = hidden_states.shape[:2]

        if layer_past is not None:
            layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past)
        past_length = 0 if layer_past is None else layer_past[0].shape[1]
        seq_length_with_past = seq_length + past_length

        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        if alibi is None and self.config.alibi:
            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
        attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)

        outputs = super().forward(
            hidden_states,
            *args,
            attention_mask=attention_mask,
            alibi=alibi,
            layer_past=layer_past,
            use_cache=use_cache,
            **kwargs,
        )

        if use_cache:
            present_key_value = outputs[-1]
            present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value)
            outputs = outputs[:-1] + (present_key_value,)

        return outputs

    def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache:
        key_states, value_states = key_value

        key_states = key_states.permute(0, 2, 1)
        assert key_states.shape == value_states.shape  # Both are [batch_size * num_kv_heads, seq_len, head_dim]

        if self.config.new_decoder_architecture:
            key_states = self._expand_states(key_states)
            value_states = self._expand_states(value_states)

        return (key_states, value_states)

    def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache:
        key_states, value_states = key_value

        if self.config.new_decoder_architecture:
            key_states = self._collapse_states(key_states)
            value_states = self._collapse_states(value_states)

        assert key_states.shape == value_states.shape  # Both are [batch_size * num_kv_heads, seq_len, head_dim]
        key_states = key_states.permute(0, 2, 1)

        return (key_states, value_states)

    def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
        batch_size_x_num_kv_heads, seq_len, head_dim = state.shape
        batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads

        state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim)
        state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1)  # No copy
        state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim)  # Involves a copy
        return state

    def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
        batch_size_x_num_attn_heads, seq_len, head_dim = state.shape
        batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads

        state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim)
        state = state[:, :, 0]
        state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)
        return state
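
Aside, not part of the commit: the _expand_states/_collapse_states pair above replicates each key/value head across its query groups and then keeps a single replica again, which is how the grouped-KV cache layout is reconciled with the per-attention-head layout during cache reordering. A minimal standalone sketch of that round trip, with hypothetical function names and explicit head counts in place of self.config:

import torch

def expand_kv(state: torch.Tensor, num_kv_heads: int, num_groups: int) -> torch.Tensor:
    # [batch * num_kv_heads, seq_len, head_dim] -> [batch * num_kv_heads * num_groups, seq_len, head_dim]
    batch_size = state.shape[0] // num_kv_heads
    seq_len, head_dim = state.shape[1:]
    state = state.reshape(batch_size, num_kv_heads, 1, seq_len, head_dim)
    state = state.expand(-1, -1, num_groups, -1, -1)  # broadcast view, no copy yet
    return state.reshape(batch_size * num_kv_heads * num_groups, seq_len, head_dim)  # materializes the copy

def collapse_kv(state: torch.Tensor, num_kv_heads: int, num_groups: int) -> torch.Tensor:
    # Inverse of expand_kv: keep one replica per key/value head.
    batch_size = state.shape[0] // (num_kv_heads * num_groups)
    seq_len, head_dim = state.shape[1:]
    state = state.reshape(batch_size, num_kv_heads, num_groups, seq_len, head_dim)
    return state[:, :, 0].reshape(batch_size * num_kv_heads, seq_len, head_dim)

Round-tripping a [batch * num_kv_heads, seq_len, head_dim] tensor through expand_kv and then collapse_kv returns the original values, mirroring the shape assertions in the _reorder_cache_from_* helpers above.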

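For context on the headline change, a minimal sketch of keeping the cos/sin tables as registered buffers rather than plain attributes; the class name, constructor arguments, and the fixed max_len are hypothetical, and only the register_buffer and slicing pattern mirrors the cos_sin hunk above:

import torch
from torch import nn

class RotaryCache(nn.Module):
    def __init__(self, head_dim: int, max_len: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        t = torch.arange(max_len, dtype=inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # register_buffer (instead of plain attribute assignment) makes the cached
        # tables follow the module through .to(device), .half() and friends.
        self.register_buffer("cos_cached", emb.cos()[None, :, :])
        self.register_buffer("sin_cached", emb.sin()[None, :, :])

    def cos_sin(self, seq_len: int, past_key_values_length: int = 0):
        # Slice the window covering positions [past, past + seq_len), as in the diff.
        return (
            self.cos_cached[:, past_key_values_length : past_key_values_length + seq_len],
            self.sin_cached[:, past_key_values_length : past_key_values_length + seq_len],
        )

Example usage: cos, sin = RotaryCache(head_dim=64, max_len=8192).cos_sin(seq_len=16, past_key_values_length=3) yields two tensors of shape [1, 16, 64].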